Merge branch 'fix/chore-fix' into dev/plugin-deploy
commit
7cea6c1713
@ -1,3 +1,3 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
poetry install -C api
|
cd api && poetry install
|
||||||
@ -0,0 +1,6 @@
|
|||||||
|
from werkzeug.exceptions import HTTPException
|
||||||
|
|
||||||
|
|
||||||
|
class FilenameNotExistsError(HTTPException):
|
||||||
|
code = 400
|
||||||
|
description = "The specified filename does not exist."
|
||||||
@ -0,0 +1,58 @@
|
|||||||
|
import mimetypes
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import urllib.parse
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class FileInfo(BaseModel):
|
||||||
|
filename: str
|
||||||
|
extension: str
|
||||||
|
mimetype: str
|
||||||
|
size: int
|
||||||
|
|
||||||
|
|
||||||
|
def guess_file_info_from_response(response: httpx.Response):
|
||||||
|
url = str(response.url)
|
||||||
|
# Try to extract filename from URL
|
||||||
|
parsed_url = urllib.parse.urlparse(url)
|
||||||
|
url_path = parsed_url.path
|
||||||
|
filename = os.path.basename(url_path)
|
||||||
|
|
||||||
|
# If filename couldn't be extracted, use Content-Disposition header
|
||||||
|
if not filename:
|
||||||
|
content_disposition = response.headers.get("Content-Disposition")
|
||||||
|
if content_disposition:
|
||||||
|
filename_match = re.search(r'filename="?(.+)"?', content_disposition)
|
||||||
|
if filename_match:
|
||||||
|
filename = filename_match.group(1)
|
||||||
|
|
||||||
|
# If still no filename, generate a unique one
|
||||||
|
if not filename:
|
||||||
|
unique_name = str(uuid4())
|
||||||
|
filename = f"{unique_name}"
|
||||||
|
|
||||||
|
# Guess MIME type from filename first, then URL
|
||||||
|
mimetype, _ = mimetypes.guess_type(filename)
|
||||||
|
if mimetype is None:
|
||||||
|
mimetype, _ = mimetypes.guess_type(url)
|
||||||
|
if mimetype is None:
|
||||||
|
# If guessing fails, use Content-Type from response headers
|
||||||
|
mimetype = response.headers.get("Content-Type", "application/octet-stream")
|
||||||
|
|
||||||
|
extension = os.path.splitext(filename)[1]
|
||||||
|
|
||||||
|
# Ensure filename has an extension
|
||||||
|
if not extension:
|
||||||
|
extension = mimetypes.guess_extension(mimetype) or ".bin"
|
||||||
|
filename = f"{filename}{extension}"
|
||||||
|
|
||||||
|
return FileInfo(
|
||||||
|
filename=filename,
|
||||||
|
extension=extension,
|
||||||
|
mimetype=mimetype,
|
||||||
|
size=int(response.headers.get("Content-Length", -1)),
|
||||||
|
)
|
||||||
@ -0,0 +1,25 @@
|
|||||||
|
from libs.exception import BaseHTTPException
|
||||||
|
|
||||||
|
|
||||||
|
class FileTooLargeError(BaseHTTPException):
|
||||||
|
error_code = "file_too_large"
|
||||||
|
description = "File size exceeded. {message}"
|
||||||
|
code = 413
|
||||||
|
|
||||||
|
|
||||||
|
class UnsupportedFileTypeError(BaseHTTPException):
|
||||||
|
error_code = "unsupported_file_type"
|
||||||
|
description = "File type not allowed."
|
||||||
|
code = 415
|
||||||
|
|
||||||
|
|
||||||
|
class TooManyFilesError(BaseHTTPException):
|
||||||
|
error_code = "too_many_files"
|
||||||
|
description = "Only one file is allowed."
|
||||||
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
|
class NoFileUploadedError(BaseHTTPException):
|
||||||
|
error_code = "no_file_uploaded"
|
||||||
|
description = "Please upload your file."
|
||||||
|
code = 400
|
||||||
@ -0,0 +1,71 @@
|
|||||||
|
import urllib.parse
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from flask_login import current_user
|
||||||
|
from flask_restful import Resource, marshal_with, reqparse
|
||||||
|
|
||||||
|
from controllers.common import helpers
|
||||||
|
from core.file import helpers as file_helpers
|
||||||
|
from core.helper import ssrf_proxy
|
||||||
|
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
|
||||||
|
from models.account import Account
|
||||||
|
from services.file_service import FileService
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteFileInfoApi(Resource):
|
||||||
|
@marshal_with(remote_file_info_fields)
|
||||||
|
def get(self, url):
|
||||||
|
decoded_url = urllib.parse.unquote(url)
|
||||||
|
try:
|
||||||
|
response = ssrf_proxy.head(decoded_url)
|
||||||
|
return {
|
||||||
|
"file_type": response.headers.get("Content-Type", "application/octet-stream"),
|
||||||
|
"file_length": int(response.headers.get("Content-Length", 0)),
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": str(e)}, 400
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteFileUploadApi(Resource):
|
||||||
|
@marshal_with(file_fields_with_signed_url)
|
||||||
|
def post(self):
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("url", type=str, required=True, help="URL is required")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
url = args["url"]
|
||||||
|
|
||||||
|
response = ssrf_proxy.head(url)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
file_info = helpers.guess_file_info_from_response(response)
|
||||||
|
|
||||||
|
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
|
||||||
|
return {"error": "File size exceeded"}, 400
|
||||||
|
|
||||||
|
response = ssrf_proxy.get(url)
|
||||||
|
response.raise_for_status()
|
||||||
|
content = response.content
|
||||||
|
|
||||||
|
try:
|
||||||
|
user = cast(Account, current_user)
|
||||||
|
upload_file = FileService.upload_file(
|
||||||
|
filename=file_info.filename,
|
||||||
|
content=content,
|
||||||
|
mimetype=file_info.mimetype,
|
||||||
|
user=user,
|
||||||
|
source_url=url,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": str(e)}, 400
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": upload_file.id,
|
||||||
|
"name": upload_file.name,
|
||||||
|
"size": upload_file.size,
|
||||||
|
"extension": upload_file.extension,
|
||||||
|
"url": file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
|
||||||
|
"mime_type": upload_file.mime_type,
|
||||||
|
"created_by": upload_file.created_by,
|
||||||
|
"created_at": upload_file.created_at,
|
||||||
|
}, 201
|
||||||
@ -1,56 +0,0 @@
|
|||||||
import urllib.parse
|
|
||||||
|
|
||||||
from flask import request
|
|
||||||
from flask_restful import marshal_with, reqparse
|
|
||||||
|
|
||||||
import services
|
|
||||||
from controllers.web import api
|
|
||||||
from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError
|
|
||||||
from controllers.web.wraps import WebApiResource
|
|
||||||
from core.helper import ssrf_proxy
|
|
||||||
from fields.file_fields import file_fields, remote_file_info_fields
|
|
||||||
from services.file_service import FileService
|
|
||||||
|
|
||||||
|
|
||||||
class FileApi(WebApiResource):
|
|
||||||
@marshal_with(file_fields)
|
|
||||||
def post(self, app_model, end_user):
|
|
||||||
# get file from request
|
|
||||||
file = request.files["file"]
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("source", type=str, required=False, location="args")
|
|
||||||
source = parser.parse_args().get("source")
|
|
||||||
|
|
||||||
# check file
|
|
||||||
if "file" not in request.files:
|
|
||||||
raise NoFileUploadedError()
|
|
||||||
|
|
||||||
if len(request.files) > 1:
|
|
||||||
raise TooManyFilesError()
|
|
||||||
try:
|
|
||||||
upload_file = FileService.upload_file(file=file, user=end_user, source=source)
|
|
||||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
|
||||||
raise FileTooLargeError(file_too_large_error.description)
|
|
||||||
except services.errors.file.UnsupportedFileTypeError:
|
|
||||||
raise UnsupportedFileTypeError()
|
|
||||||
|
|
||||||
return upload_file, 201
|
|
||||||
|
|
||||||
|
|
||||||
class RemoteFileInfoApi(WebApiResource):
|
|
||||||
@marshal_with(remote_file_info_fields)
|
|
||||||
def get(self, url):
|
|
||||||
decoded_url = urllib.parse.unquote(url)
|
|
||||||
try:
|
|
||||||
response = ssrf_proxy.head(decoded_url)
|
|
||||||
return {
|
|
||||||
"file_type": response.headers.get("Content-Type", "application/octet-stream"),
|
|
||||||
"file_length": int(response.headers.get("Content-Length", -1)),
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
return {"error": str(e)}, 400
|
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(FileApi, "/files/upload")
|
|
||||||
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
|
|
||||||
@ -0,0 +1,43 @@
|
|||||||
|
from flask import request
|
||||||
|
from flask_restful import marshal_with
|
||||||
|
|
||||||
|
import services
|
||||||
|
from controllers.common.errors import FilenameNotExistsError
|
||||||
|
from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError
|
||||||
|
from controllers.web.wraps import WebApiResource
|
||||||
|
from fields.file_fields import file_fields
|
||||||
|
from services.file_service import FileService
|
||||||
|
|
||||||
|
|
||||||
|
class FileApi(WebApiResource):
|
||||||
|
@marshal_with(file_fields)
|
||||||
|
def post(self, app_model, end_user):
|
||||||
|
file = request.files["file"]
|
||||||
|
source = request.form.get("source")
|
||||||
|
|
||||||
|
if "file" not in request.files:
|
||||||
|
raise NoFileUploadedError()
|
||||||
|
|
||||||
|
if len(request.files) > 1:
|
||||||
|
raise TooManyFilesError()
|
||||||
|
|
||||||
|
if not file.filename:
|
||||||
|
raise FilenameNotExistsError
|
||||||
|
|
||||||
|
if source not in ("datasets", None):
|
||||||
|
source = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
upload_file = FileService.upload_file(
|
||||||
|
filename=file.filename,
|
||||||
|
content=file.read(),
|
||||||
|
mimetype=file.mimetype,
|
||||||
|
user=end_user,
|
||||||
|
source=source,
|
||||||
|
)
|
||||||
|
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||||
|
raise FileTooLargeError(file_too_large_error.description)
|
||||||
|
except services.errors.file.UnsupportedFileTypeError:
|
||||||
|
raise UnsupportedFileTypeError()
|
||||||
|
|
||||||
|
return upload_file, 201
|
||||||
@ -0,0 +1,69 @@
|
|||||||
|
import urllib.parse
|
||||||
|
|
||||||
|
from flask_login import current_user
|
||||||
|
from flask_restful import marshal_with, reqparse
|
||||||
|
|
||||||
|
from controllers.common import helpers
|
||||||
|
from controllers.web.wraps import WebApiResource
|
||||||
|
from core.file import helpers as file_helpers
|
||||||
|
from core.helper import ssrf_proxy
|
||||||
|
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
|
||||||
|
from services.file_service import FileService
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteFileInfoApi(WebApiResource):
|
||||||
|
@marshal_with(remote_file_info_fields)
|
||||||
|
def get(self, url):
|
||||||
|
decoded_url = urllib.parse.unquote(url)
|
||||||
|
try:
|
||||||
|
response = ssrf_proxy.head(decoded_url)
|
||||||
|
return {
|
||||||
|
"file_type": response.headers.get("Content-Type", "application/octet-stream"),
|
||||||
|
"file_length": int(response.headers.get("Content-Length", -1)),
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": str(e)}, 400
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteFileUploadApi(WebApiResource):
|
||||||
|
@marshal_with(file_fields_with_signed_url)
|
||||||
|
def post(self):
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("url", type=str, required=True, help="URL is required")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
url = args["url"]
|
||||||
|
|
||||||
|
response = ssrf_proxy.head(url)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
file_info = helpers.guess_file_info_from_response(response)
|
||||||
|
|
||||||
|
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
|
||||||
|
return {"error": "File size exceeded"}, 400
|
||||||
|
|
||||||
|
response = ssrf_proxy.get(url)
|
||||||
|
response.raise_for_status()
|
||||||
|
content = response.content
|
||||||
|
|
||||||
|
try:
|
||||||
|
upload_file = FileService.upload_file(
|
||||||
|
filename=file_info.filename,
|
||||||
|
content=content,
|
||||||
|
mimetype=file_info.mimetype,
|
||||||
|
user=current_user,
|
||||||
|
source_url=url,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": str(e)}, 400
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": upload_file.id,
|
||||||
|
"name": upload_file.name,
|
||||||
|
"size": upload_file.size,
|
||||||
|
"extension": upload_file.extension,
|
||||||
|
"url": file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
|
||||||
|
"mime_type": upload_file.mime_type,
|
||||||
|
"created_by": upload_file.created_by,
|
||||||
|
"created_at": upload_file.created_at,
|
||||||
|
}, 201
|
||||||
@ -1,9 +0,0 @@
|
|||||||
- claude-3-5-sonnet-20241022
|
|
||||||
- claude-3-5-sonnet-20240620
|
|
||||||
- claude-3-haiku-20240307
|
|
||||||
- claude-3-opus-20240229
|
|
||||||
- claude-3-sonnet-20240229
|
|
||||||
- claude-2.1
|
|
||||||
- claude-instant-1.2
|
|
||||||
- claude-2
|
|
||||||
- claude-instant-1
|
|
||||||
@ -1,39 +0,0 @@
|
|||||||
model: claude-3-5-sonnet-20241022
|
|
||||||
label:
|
|
||||||
en_US: claude-3-5-sonnet-20241022
|
|
||||||
model_type: llm
|
|
||||||
features:
|
|
||||||
- agent-thought
|
|
||||||
- vision
|
|
||||||
- tool-call
|
|
||||||
- stream-tool-call
|
|
||||||
model_properties:
|
|
||||||
mode: chat
|
|
||||||
context_size: 200000
|
|
||||||
parameter_rules:
|
|
||||||
- name: temperature
|
|
||||||
use_template: temperature
|
|
||||||
- name: top_p
|
|
||||||
use_template: top_p
|
|
||||||
- name: top_k
|
|
||||||
label:
|
|
||||||
zh_Hans: 取样数量
|
|
||||||
en_US: Top k
|
|
||||||
type: int
|
|
||||||
help:
|
|
||||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
|
||||||
en_US: Only sample from the top K options for each subsequent token.
|
|
||||||
required: false
|
|
||||||
- name: max_tokens
|
|
||||||
use_template: max_tokens
|
|
||||||
required: true
|
|
||||||
default: 8192
|
|
||||||
min: 1
|
|
||||||
max: 8192
|
|
||||||
- name: response_format
|
|
||||||
use_template: response_format
|
|
||||||
pricing:
|
|
||||||
input: '3.00'
|
|
||||||
output: '15.00'
|
|
||||||
unit: '0.000001'
|
|
||||||
currency: USD
|
|
||||||
@ -1,764 +0,0 @@
|
|||||||
import copy
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from collections.abc import Generator, Sequence
|
|
||||||
from typing import Optional, Union, cast
|
|
||||||
|
|
||||||
import tiktoken
|
|
||||||
from openai import AzureOpenAI, Stream
|
|
||||||
from openai.types import Completion
|
|
||||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall
|
|
||||||
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
|
|
||||||
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
|
||||||
from core.model_runtime.entities.message_entities import (
|
|
||||||
AssistantPromptMessage,
|
|
||||||
ImagePromptMessageContent,
|
|
||||||
PromptMessage,
|
|
||||||
PromptMessageContentType,
|
|
||||||
PromptMessageFunction,
|
|
||||||
PromptMessageTool,
|
|
||||||
SystemPromptMessage,
|
|
||||||
TextPromptMessageContent,
|
|
||||||
ToolPromptMessage,
|
|
||||||
UserPromptMessage,
|
|
||||||
)
|
|
||||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey
|
|
||||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
|
||||||
from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
|
|
||||||
from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS
|
|
||||||
from core.model_runtime.utils import helper
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|
||||||
def _invoke(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
credentials: dict,
|
|
||||||
prompt_messages: list[PromptMessage],
|
|
||||||
model_parameters: dict,
|
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
|
||||||
stop: Optional[list[str]] = None,
|
|
||||||
stream: bool = True,
|
|
||||||
user: Optional[str] = None,
|
|
||||||
) -> Union[LLMResult, Generator]:
|
|
||||||
base_model_name = self._get_base_model_name(credentials)
|
|
||||||
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
|
|
||||||
|
|
||||||
if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
|
|
||||||
# chat model
|
|
||||||
return self._chat_generate(
|
|
||||||
model=model,
|
|
||||||
credentials=credentials,
|
|
||||||
prompt_messages=prompt_messages,
|
|
||||||
model_parameters=model_parameters,
|
|
||||||
tools=tools,
|
|
||||||
stop=stop,
|
|
||||||
stream=stream,
|
|
||||||
user=user,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# text completion model
|
|
||||||
return self._generate(
|
|
||||||
model=model,
|
|
||||||
credentials=credentials,
|
|
||||||
prompt_messages=prompt_messages,
|
|
||||||
model_parameters=model_parameters,
|
|
||||||
stop=stop,
|
|
||||||
stream=stream,
|
|
||||||
user=user,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_num_tokens(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
credentials: dict,
|
|
||||||
prompt_messages: list[PromptMessage],
|
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
|
||||||
) -> int:
|
|
||||||
base_model_name = self._get_base_model_name(credentials)
|
|
||||||
model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
|
|
||||||
if not model_entity:
|
|
||||||
raise ValueError(f"Base Model Name {base_model_name} is invalid")
|
|
||||||
model_mode = model_entity.entity.model_properties.get(ModelPropertyKey.MODE)
|
|
||||||
|
|
||||||
if model_mode == LLMMode.CHAT.value:
|
|
||||||
# chat model
|
|
||||||
return self._num_tokens_from_messages(credentials, prompt_messages, tools)
|
|
||||||
else:
|
|
||||||
# text completion model, do not support tool calling
|
|
||||||
content = prompt_messages[0].content
|
|
||||||
assert isinstance(content, str)
|
|
||||||
return self._num_tokens_from_string(credentials, content)
|
|
||||||
|
|
||||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
|
||||||
if "openai_api_base" not in credentials:
|
|
||||||
raise CredentialsValidateFailedError("Azure OpenAI API Base Endpoint is required")
|
|
||||||
|
|
||||||
if "openai_api_key" not in credentials:
|
|
||||||
raise CredentialsValidateFailedError("Azure OpenAI API key is required")
|
|
||||||
|
|
||||||
if "base_model_name" not in credentials:
|
|
||||||
raise CredentialsValidateFailedError("Base Model Name is required")
|
|
||||||
|
|
||||||
base_model_name = self._get_base_model_name(credentials)
|
|
||||||
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
|
|
||||||
|
|
||||||
if not ai_model_entity:
|
|
||||||
raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
|
|
||||||
|
|
||||||
try:
|
|
||||||
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
|
|
||||||
|
|
||||||
if model.startswith("o1"):
|
|
||||||
client.chat.completions.create(
|
|
||||||
messages=[{"role": "user", "content": "ping"}],
|
|
||||||
model=model,
|
|
||||||
temperature=1,
|
|
||||||
max_completion_tokens=20,
|
|
||||||
stream=False,
|
|
||||||
)
|
|
||||||
elif ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
|
|
||||||
# chat model
|
|
||||||
client.chat.completions.create(
|
|
||||||
messages=[{"role": "user", "content": "ping"}],
|
|
||||||
model=model,
|
|
||||||
temperature=0,
|
|
||||||
max_tokens=20,
|
|
||||||
stream=False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# text completion model
|
|
||||||
client.completions.create(
|
|
||||||
prompt="ping",
|
|
||||||
model=model,
|
|
||||||
temperature=0,
|
|
||||||
max_tokens=20,
|
|
||||||
stream=False,
|
|
||||||
)
|
|
||||||
except Exception as ex:
|
|
||||||
raise CredentialsValidateFailedError(str(ex))
|
|
||||||
|
|
||||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
|
||||||
base_model_name = self._get_base_model_name(credentials)
|
|
||||||
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
|
|
||||||
return ai_model_entity.entity if ai_model_entity else None
|
|
||||||
|
|
||||||
def _generate(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
credentials: dict,
|
|
||||||
prompt_messages: list[PromptMessage],
|
|
||||||
model_parameters: dict,
|
|
||||||
stop: Optional[list[str]] = None,
|
|
||||||
stream: bool = True,
|
|
||||||
user: Optional[str] = None,
|
|
||||||
) -> Union[LLMResult, Generator]:
|
|
||||||
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
|
|
||||||
|
|
||||||
extra_model_kwargs = {}
|
|
||||||
|
|
||||||
if stop:
|
|
||||||
extra_model_kwargs["stop"] = stop
|
|
||||||
|
|
||||||
if user:
|
|
||||||
extra_model_kwargs["user"] = user
|
|
||||||
|
|
||||||
# text completion model
|
|
||||||
response = client.completions.create(
|
|
||||||
prompt=prompt_messages[0].content, model=model, stream=stream, **model_parameters, **extra_model_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if stream:
|
|
||||||
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
|
|
||||||
|
|
||||||
return self._handle_generate_response(model, credentials, response, prompt_messages)
|
|
||||||
|
|
||||||
def _handle_generate_response(
|
|
||||||
self, model: str, credentials: dict, response: Completion, prompt_messages: list[PromptMessage]
|
|
||||||
):
|
|
||||||
assistant_text = response.choices[0].text
|
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
|
||||||
assistant_prompt_message = AssistantPromptMessage(content=assistant_text)
|
|
||||||
|
|
||||||
# calculate num tokens
|
|
||||||
if response.usage:
|
|
||||||
# transform usage
|
|
||||||
prompt_tokens = response.usage.prompt_tokens
|
|
||||||
completion_tokens = response.usage.completion_tokens
|
|
||||||
else:
|
|
||||||
# calculate num tokens
|
|
||||||
content = prompt_messages[0].content
|
|
||||||
assert isinstance(content, str)
|
|
||||||
prompt_tokens = self._num_tokens_from_string(credentials, content)
|
|
||||||
completion_tokens = self._num_tokens_from_string(credentials, assistant_text)
|
|
||||||
|
|
||||||
# transform usage
|
|
||||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
|
||||||
|
|
||||||
# transform response
|
|
||||||
result = LLMResult(
|
|
||||||
model=response.model,
|
|
||||||
prompt_messages=prompt_messages,
|
|
||||||
message=assistant_prompt_message,
|
|
||||||
usage=usage,
|
|
||||||
system_fingerprint=response.system_fingerprint,
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _handle_generate_stream_response(
|
|
||||||
self, model: str, credentials: dict, response: Stream[Completion], prompt_messages: list[PromptMessage]
|
|
||||||
) -> Generator:
|
|
||||||
full_text = ""
|
|
||||||
for chunk in response:
|
|
||||||
if len(chunk.choices) == 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
delta = chunk.choices[0]
|
|
||||||
|
|
||||||
if delta.finish_reason is None and (delta.text is None or delta.text == ""):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
|
||||||
text = delta.text or ""
|
|
||||||
assistant_prompt_message = AssistantPromptMessage(content=text)
|
|
||||||
|
|
||||||
full_text += text
|
|
||||||
|
|
||||||
if delta.finish_reason is not None:
|
|
||||||
# calculate num tokens
|
|
||||||
if chunk.usage:
|
|
||||||
# transform usage
|
|
||||||
prompt_tokens = chunk.usage.prompt_tokens
|
|
||||||
completion_tokens = chunk.usage.completion_tokens
|
|
||||||
else:
|
|
||||||
# calculate num tokens
|
|
||||||
content = prompt_messages[0].content
|
|
||||||
assert isinstance(content, str)
|
|
||||||
prompt_tokens = self._num_tokens_from_string(credentials, content)
|
|
||||||
completion_tokens = self._num_tokens_from_string(credentials, full_text)
|
|
||||||
|
|
||||||
# transform usage
|
|
||||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
|
||||||
|
|
||||||
yield LLMResultChunk(
|
|
||||||
model=chunk.model,
|
|
||||||
prompt_messages=prompt_messages,
|
|
||||||
system_fingerprint=chunk.system_fingerprint,
|
|
||||||
delta=LLMResultChunkDelta(
|
|
||||||
index=delta.index,
|
|
||||||
message=assistant_prompt_message,
|
|
||||||
finish_reason=delta.finish_reason,
|
|
||||||
usage=usage,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
yield LLMResultChunk(
|
|
||||||
model=chunk.model,
|
|
||||||
prompt_messages=prompt_messages,
|
|
||||||
system_fingerprint=chunk.system_fingerprint,
|
|
||||||
delta=LLMResultChunkDelta(
|
|
||||||
index=delta.index,
|
|
||||||
message=assistant_prompt_message,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def _chat_generate(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
credentials: dict,
|
|
||||||
prompt_messages: list[PromptMessage],
|
|
||||||
model_parameters: dict,
|
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
|
||||||
stop: Optional[list[str]] = None,
|
|
||||||
stream: bool = True,
|
|
||||||
user: Optional[str] = None,
|
|
||||||
) -> Union[LLMResult, Generator]:
|
|
||||||
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
|
|
||||||
|
|
||||||
response_format = model_parameters.get("response_format")
|
|
||||||
if response_format:
|
|
||||||
if response_format == "json_schema":
|
|
||||||
json_schema = model_parameters.get("json_schema")
|
|
||||||
if not json_schema:
|
|
||||||
raise ValueError("Must define JSON Schema when the response format is json_schema")
|
|
||||||
try:
|
|
||||||
schema = json.loads(json_schema)
|
|
||||||
except:
|
|
||||||
raise ValueError(f"not correct json_schema format: {json_schema}")
|
|
||||||
model_parameters.pop("json_schema")
|
|
||||||
model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema}
|
|
||||||
else:
|
|
||||||
model_parameters["response_format"] = {"type": response_format}
|
|
||||||
|
|
||||||
extra_model_kwargs = {}
|
|
||||||
|
|
||||||
if tools:
|
|
||||||
extra_model_kwargs["tools"] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
|
|
||||||
|
|
||||||
if stop:
|
|
||||||
extra_model_kwargs["stop"] = stop
|
|
||||||
|
|
||||||
if user:
|
|
||||||
extra_model_kwargs["user"] = user
|
|
||||||
|
|
||||||
# clear illegal prompt messages
|
|
||||||
prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
|
|
||||||
|
|
||||||
block_as_stream = False
|
|
||||||
if model.startswith("o1"):
|
|
||||||
if stream:
|
|
||||||
block_as_stream = True
|
|
||||||
stream = False
|
|
||||||
|
|
||||||
if "stream_options" in extra_model_kwargs:
|
|
||||||
del extra_model_kwargs["stream_options"]
|
|
||||||
|
|
||||||
if "stop" in extra_model_kwargs:
|
|
||||||
del extra_model_kwargs["stop"]
|
|
||||||
|
|
||||||
# chat model
|
|
||||||
response = client.chat.completions.create(
|
|
||||||
messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
|
|
||||||
model=model,
|
|
||||||
stream=stream,
|
|
||||||
**model_parameters,
|
|
||||||
**extra_model_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if stream:
|
|
||||||
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
|
|
||||||
|
|
||||||
block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
|
|
||||||
|
|
||||||
if block_as_stream:
|
|
||||||
return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
|
|
||||||
|
|
||||||
return block_result
|
|
||||||
|
|
||||||
def _handle_chat_block_as_stream_response(
|
|
||||||
self,
|
|
||||||
block_result: LLMResult,
|
|
||||||
prompt_messages: list[PromptMessage],
|
|
||||||
stop: Optional[list[str]] = None,
|
|
||||||
) -> Generator[LLMResultChunk, None, None]:
|
|
||||||
"""
|
|
||||||
Handle llm chat response
|
|
||||||
|
|
||||||
:param model: model name
|
|
||||||
:param credentials: credentials
|
|
||||||
:param response: response
|
|
||||||
:param prompt_messages: prompt messages
|
|
||||||
:param tools: tools for tool calling
|
|
||||||
:param stop: stop words
|
|
||||||
:return: llm response chunk generator
|
|
||||||
"""
|
|
||||||
text = block_result.message.content
|
|
||||||
text = cast(str, text)
|
|
||||||
|
|
||||||
if stop:
|
|
||||||
text = self.enforce_stop_tokens(text, stop)
|
|
||||||
|
|
||||||
yield LLMResultChunk(
|
|
||||||
model=block_result.model,
|
|
||||||
prompt_messages=prompt_messages,
|
|
||||||
system_fingerprint=block_result.system_fingerprint,
|
|
||||||
delta=LLMResultChunkDelta(
|
|
||||||
index=0,
|
|
||||||
message=AssistantPromptMessage(content=text),
|
|
||||||
finish_reason="stop",
|
|
||||||
usage=block_result.usage,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
|
||||||
"""
|
|
||||||
Clear illegal prompt messages for OpenAI API
|
|
||||||
|
|
||||||
:param model: model name
|
|
||||||
:param prompt_messages: prompt messages
|
|
||||||
:return: cleaned prompt messages
|
|
||||||
"""
|
|
||||||
checklist = ["gpt-4-turbo", "gpt-4-turbo-2024-04-09"]
|
|
||||||
|
|
||||||
if model in checklist:
|
|
||||||
# count how many user messages are there
|
|
||||||
user_message_count = len([m for m in prompt_messages if isinstance(m, UserPromptMessage)])
|
|
||||||
if user_message_count > 1:
|
|
||||||
for prompt_message in prompt_messages:
|
|
||||||
if isinstance(prompt_message, UserPromptMessage):
|
|
||||||
if isinstance(prompt_message.content, list):
|
|
||||||
prompt_message.content = "\n".join(
|
|
||||||
[
|
|
||||||
item.data
|
|
||||||
if item.type == PromptMessageContentType.TEXT
|
|
||||||
else "[IMAGE]"
|
|
||||||
if item.type == PromptMessageContentType.IMAGE
|
|
||||||
else ""
|
|
||||||
for item in prompt_message.content
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
if model.startswith("o1"):
|
|
||||||
system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)])
|
|
||||||
if system_message_count > 0:
|
|
||||||
new_prompt_messages = []
|
|
||||||
for prompt_message in prompt_messages:
|
|
||||||
if isinstance(prompt_message, SystemPromptMessage):
|
|
||||||
prompt_message = UserPromptMessage(
|
|
||||||
content=prompt_message.content,
|
|
||||||
name=prompt_message.name,
|
|
||||||
)
|
|
||||||
|
|
||||||
new_prompt_messages.append(prompt_message)
|
|
||||||
prompt_messages = new_prompt_messages
|
|
||||||
|
|
||||||
return prompt_messages
|
|
||||||
|
|
||||||
def _handle_chat_generate_response(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
credentials: dict,
|
|
||||||
response: ChatCompletion,
|
|
||||||
prompt_messages: list[PromptMessage],
|
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
|
||||||
):
|
|
||||||
assistant_message = response.choices[0].message
|
|
||||||
assistant_message_tool_calls = assistant_message.tool_calls
|
|
||||||
|
|
||||||
# extract tool calls from response
|
|
||||||
tool_calls = []
|
|
||||||
self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=assistant_message_tool_calls)
|
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
|
||||||
assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls)
|
|
||||||
|
|
||||||
# calculate num tokens
|
|
||||||
if response.usage:
|
|
||||||
# transform usage
|
|
||||||
prompt_tokens = response.usage.prompt_tokens
|
|
||||||
completion_tokens = response.usage.completion_tokens
|
|
||||||
else:
|
|
||||||
# calculate num tokens
|
|
||||||
prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools)
|
|
||||||
completion_tokens = self._num_tokens_from_messages(credentials, [assistant_prompt_message])
|
|
||||||
|
|
||||||
# transform usage
|
|
||||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
|
||||||
|
|
||||||
# transform response
|
|
||||||
result = LLMResult(
|
|
||||||
model=response.model or model,
|
|
||||||
prompt_messages=prompt_messages,
|
|
||||||
message=assistant_prompt_message,
|
|
||||||
usage=usage,
|
|
||||||
system_fingerprint=response.system_fingerprint,
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _handle_chat_generate_stream_response(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
credentials: dict,
|
|
||||||
response: Stream[ChatCompletionChunk],
|
|
||||||
prompt_messages: list[PromptMessage],
|
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
|
||||||
):
|
|
||||||
index = 0
|
|
||||||
full_assistant_content = ""
|
|
||||||
real_model = model
|
|
||||||
system_fingerprint = None
|
|
||||||
completion = ""
|
|
||||||
tool_calls = []
|
|
||||||
for chunk in response:
|
|
||||||
if len(chunk.choices) == 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
delta = chunk.choices[0]
|
|
||||||
# NOTE: For fix https://github.com/langgenius/dify/issues/5790
|
|
||||||
if delta.delta is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# extract tool calls from response
|
|
||||||
self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=delta.delta.tool_calls)
|
|
||||||
|
|
||||||
# Handling exceptions when content filters' streaming mode is set to asynchronous modified filter
|
|
||||||
if delta.finish_reason is None and not delta.delta.content:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
|
||||||
assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls)
|
|
||||||
|
|
||||||
full_assistant_content += delta.delta.content or ""
|
|
||||||
|
|
||||||
real_model = chunk.model
|
|
||||||
system_fingerprint = chunk.system_fingerprint
|
|
||||||
completion += delta.delta.content or ""
|
|
||||||
|
|
||||||
yield LLMResultChunk(
|
|
||||||
model=real_model,
|
|
||||||
prompt_messages=prompt_messages,
|
|
||||||
system_fingerprint=system_fingerprint,
|
|
||||||
delta=LLMResultChunkDelta(
|
|
||||||
index=index,
|
|
||||||
message=assistant_prompt_message,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
index += 1
|
|
||||||
|
|
||||||
# calculate num tokens
|
|
||||||
prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools)
|
|
||||||
|
|
||||||
full_assistant_prompt_message = AssistantPromptMessage(content=completion)
|
|
||||||
completion_tokens = self._num_tokens_from_messages(credentials, [full_assistant_prompt_message])
|
|
||||||
|
|
||||||
# transform usage
|
|
||||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
|
||||||
|
|
||||||
yield LLMResultChunk(
|
|
||||||
model=real_model,
|
|
||||||
prompt_messages=prompt_messages,
|
|
||||||
system_fingerprint=system_fingerprint,
|
|
||||||
delta=LLMResultChunkDelta(
|
|
||||||
index=index, message=AssistantPromptMessage(content=""), finish_reason="stop", usage=usage
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _update_tool_calls(
|
|
||||||
tool_calls: list[AssistantPromptMessage.ToolCall],
|
|
||||||
tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]],
|
|
||||||
) -> None:
|
|
||||||
if tool_calls_response:
|
|
||||||
for response_tool_call in tool_calls_response:
|
|
||||||
if isinstance(response_tool_call, ChatCompletionMessageToolCall):
|
|
||||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
|
||||||
name=response_tool_call.function.name, arguments=response_tool_call.function.arguments
|
|
||||||
)
|
|
||||||
|
|
||||||
tool_call = AssistantPromptMessage.ToolCall(
|
|
||||||
id=response_tool_call.id, type=response_tool_call.type, function=function
|
|
||||||
)
|
|
||||||
tool_calls.append(tool_call)
|
|
||||||
elif isinstance(response_tool_call, ChoiceDeltaToolCall):
|
|
||||||
index = response_tool_call.index
|
|
||||||
if index < len(tool_calls):
|
|
||||||
tool_calls[index].id = response_tool_call.id or tool_calls[index].id
|
|
||||||
tool_calls[index].type = response_tool_call.type or tool_calls[index].type
|
|
||||||
if response_tool_call.function:
|
|
||||||
tool_calls[index].function.name = (
|
|
||||||
response_tool_call.function.name or tool_calls[index].function.name
|
|
||||||
)
|
|
||||||
tool_calls[index].function.arguments += response_tool_call.function.arguments or ""
|
|
||||||
else:
|
|
||||||
assert response_tool_call.id is not None
|
|
||||||
assert response_tool_call.type is not None
|
|
||||||
assert response_tool_call.function is not None
|
|
||||||
assert response_tool_call.function.name is not None
|
|
||||||
assert response_tool_call.function.arguments is not None
|
|
||||||
|
|
||||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
|
||||||
name=response_tool_call.function.name, arguments=response_tool_call.function.arguments
|
|
||||||
)
|
|
||||||
tool_call = AssistantPromptMessage.ToolCall(
|
|
||||||
id=response_tool_call.id, type=response_tool_call.type, function=function
|
|
||||||
)
|
|
||||||
tool_calls.append(tool_call)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _convert_prompt_message_to_dict(message: PromptMessage):
|
|
||||||
if isinstance(message, UserPromptMessage):
|
|
||||||
message = cast(UserPromptMessage, message)
|
|
||||||
if isinstance(message.content, str):
|
|
||||||
message_dict = {"role": "user", "content": message.content}
|
|
||||||
else:
|
|
||||||
sub_messages = []
|
|
||||||
assert message.content is not None
|
|
||||||
for message_content in message.content:
|
|
||||||
if message_content.type == PromptMessageContentType.TEXT:
|
|
||||||
message_content = cast(TextPromptMessageContent, message_content)
|
|
||||||
sub_message_dict = {"type": "text", "text": message_content.data}
|
|
||||||
sub_messages.append(sub_message_dict)
|
|
||||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
|
||||||
message_content = cast(ImagePromptMessageContent, message_content)
|
|
||||||
sub_message_dict = {
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {"url": message_content.data, "detail": message_content.detail.value},
|
|
||||||
}
|
|
||||||
sub_messages.append(sub_message_dict)
|
|
||||||
message_dict = {"role": "user", "content": sub_messages}
|
|
||||||
elif isinstance(message, AssistantPromptMessage):
|
|
||||||
# message = cast(AssistantPromptMessage, message)
|
|
||||||
message_dict = {"role": "assistant", "content": message.content}
|
|
||||||
if message.tool_calls:
|
|
||||||
message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in message.tool_calls]
|
|
||||||
elif isinstance(message, SystemPromptMessage):
|
|
||||||
message = cast(SystemPromptMessage, message)
|
|
||||||
message_dict = {"role": "system", "content": message.content}
|
|
||||||
elif isinstance(message, ToolPromptMessage):
|
|
||||||
message = cast(ToolPromptMessage, message)
|
|
||||||
message_dict = {
|
|
||||||
"role": "tool",
|
|
||||||
"name": message.name,
|
|
||||||
"content": message.content,
|
|
||||||
"tool_call_id": message.tool_call_id,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Got unknown type {message}")
|
|
||||||
|
|
||||||
if message.name:
|
|
||||||
message_dict["name"] = message.name
|
|
||||||
|
|
||||||
return message_dict
|
|
||||||
|
|
||||||
def _num_tokens_from_string(
|
|
||||||
self, credentials: dict, text: str, tools: Optional[list[PromptMessageTool]] = None
|
|
||||||
) -> int:
|
|
||||||
try:
|
|
||||||
encoding = tiktoken.encoding_for_model(credentials["base_model_name"])
|
|
||||||
except KeyError:
|
|
||||||
encoding = tiktoken.get_encoding("cl100k_base")
|
|
||||||
|
|
||||||
num_tokens = len(encoding.encode(text))
|
|
||||||
|
|
||||||
if tools:
|
|
||||||
num_tokens += self._num_tokens_for_tools(encoding, tools)
|
|
||||||
|
|
||||||
return num_tokens
|
|
||||||
|
|
||||||
def _num_tokens_from_messages(
|
|
||||||
self, credentials: dict, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
|
|
||||||
) -> int:
|
|
||||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
|
||||||
|
|
||||||
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
|
||||||
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
|
|
||||||
model = credentials["base_model_name"]
|
|
||||||
try:
|
|
||||||
encoding = tiktoken.encoding_for_model(model)
|
|
||||||
except KeyError:
|
|
||||||
logger.warning("Warning: model not found. Using cl100k_base encoding.")
|
|
||||||
model = "cl100k_base"
|
|
||||||
encoding = tiktoken.get_encoding(model)
|
|
||||||
|
|
||||||
if model.startswith("gpt-35-turbo-0301"):
|
|
||||||
# every message follows <im_start>{role/name}\n{content}<im_end>\n
|
|
||||||
tokens_per_message = 4
|
|
||||||
# if there's a name, the role is omitted
|
|
||||||
tokens_per_name = -1
|
|
||||||
elif model.startswith("gpt-35-turbo") or model.startswith("gpt-4") or model.startswith("o1"):
|
|
||||||
tokens_per_message = 3
|
|
||||||
tokens_per_name = 1
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"get_num_tokens_from_messages() is not presently implemented "
|
|
||||||
f"for model {model}."
|
|
||||||
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
|
|
||||||
"information on how messages are converted to tokens."
|
|
||||||
)
|
|
||||||
num_tokens = 0
|
|
||||||
messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
|
|
||||||
for message in messages_dict:
|
|
||||||
num_tokens += tokens_per_message
|
|
||||||
for key, value in message.items():
|
|
||||||
# Cast str(value) in case the message value is not a string
|
|
||||||
# This occurs with function messages
|
|
||||||
# TODO: The current token calculation method for the image type is not implemented,
|
|
||||||
# which need to download the image and then get the resolution for calculation,
|
|
||||||
# and will increase the request delay
|
|
||||||
if isinstance(value, list):
|
|
||||||
text = ""
|
|
||||||
for item in value:
|
|
||||||
if isinstance(item, dict) and item["type"] == "text":
|
|
||||||
text += item["text"]
|
|
||||||
|
|
||||||
value = text
|
|
||||||
|
|
||||||
if key == "tool_calls":
|
|
||||||
for tool_call in value:
|
|
||||||
assert isinstance(tool_call, dict)
|
|
||||||
for t_key, t_value in tool_call.items():
|
|
||||||
num_tokens += len(encoding.encode(t_key))
|
|
||||||
if t_key == "function":
|
|
||||||
for f_key, f_value in t_value.items():
|
|
||||||
num_tokens += len(encoding.encode(f_key))
|
|
||||||
num_tokens += len(encoding.encode(f_value))
|
|
||||||
else:
|
|
||||||
num_tokens += len(encoding.encode(t_key))
|
|
||||||
num_tokens += len(encoding.encode(t_value))
|
|
||||||
else:
|
|
||||||
num_tokens += len(encoding.encode(str(value)))
|
|
||||||
|
|
||||||
if key == "name":
|
|
||||||
num_tokens += tokens_per_name
|
|
||||||
|
|
||||||
# every reply is primed with <im_start>assistant
|
|
||||||
num_tokens += 3
|
|
||||||
|
|
||||||
if tools:
|
|
||||||
num_tokens += self._num_tokens_for_tools(encoding, tools)
|
|
||||||
|
|
||||||
return num_tokens
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _num_tokens_for_tools(encoding: tiktoken.Encoding, tools: list[PromptMessageTool]) -> int:
|
|
||||||
num_tokens = 0
|
|
||||||
for tool in tools:
|
|
||||||
num_tokens += len(encoding.encode("type"))
|
|
||||||
num_tokens += len(encoding.encode("function"))
|
|
||||||
|
|
||||||
# calculate num tokens for function object
|
|
||||||
num_tokens += len(encoding.encode("name"))
|
|
||||||
num_tokens += len(encoding.encode(tool.name))
|
|
||||||
num_tokens += len(encoding.encode("description"))
|
|
||||||
num_tokens += len(encoding.encode(tool.description))
|
|
||||||
parameters = tool.parameters
|
|
||||||
num_tokens += len(encoding.encode("parameters"))
|
|
||||||
if "title" in parameters:
|
|
||||||
num_tokens += len(encoding.encode("title"))
|
|
||||||
num_tokens += len(encoding.encode(parameters["title"]))
|
|
||||||
num_tokens += len(encoding.encode("type"))
|
|
||||||
num_tokens += len(encoding.encode(parameters["type"]))
|
|
||||||
if "properties" in parameters:
|
|
||||||
num_tokens += len(encoding.encode("properties"))
|
|
||||||
for key, value in parameters["properties"].items():
|
|
||||||
num_tokens += len(encoding.encode(key))
|
|
||||||
for field_key, field_value in value.items():
|
|
||||||
num_tokens += len(encoding.encode(field_key))
|
|
||||||
if field_key == "enum":
|
|
||||||
for enum_field in field_value:
|
|
||||||
num_tokens += 3
|
|
||||||
num_tokens += len(encoding.encode(enum_field))
|
|
||||||
else:
|
|
||||||
num_tokens += len(encoding.encode(field_key))
|
|
||||||
num_tokens += len(encoding.encode(str(field_value)))
|
|
||||||
if "required" in parameters:
|
|
||||||
num_tokens += len(encoding.encode("required"))
|
|
||||||
for required_field in parameters["required"]:
|
|
||||||
num_tokens += 3
|
|
||||||
num_tokens += len(encoding.encode(required_field))
|
|
||||||
|
|
||||||
return num_tokens
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_ai_model_entity(base_model_name: str, model: str):
|
|
||||||
for ai_model_entity in LLM_BASE_MODELS:
|
|
||||||
if ai_model_entity.base_model_name == base_model_name:
|
|
||||||
ai_model_entity_copy = copy.deepcopy(ai_model_entity)
|
|
||||||
ai_model_entity_copy.entity.model = model
|
|
||||||
ai_model_entity_copy.entity.label.en_US = model
|
|
||||||
ai_model_entity_copy.entity.label.zh_Hans = model
|
|
||||||
return ai_model_entity_copy
|
|
||||||
|
|
||||||
def _get_base_model_name(self, credentials: dict) -> str:
|
|
||||||
base_model_name = credentials.get("base_model_name")
|
|
||||||
if not base_model_name:
|
|
||||||
raise ValueError("Base Model Name is required")
|
|
||||||
return base_model_name
|
|
||||||
@ -1,450 +0,0 @@
|
|||||||
import base64
|
|
||||||
import io
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from collections.abc import Generator
|
|
||||||
from typing import Optional, Union, cast
|
|
||||||
|
|
||||||
import google.ai.generativelanguage as glm
|
|
||||||
import google.generativeai as genai
|
|
||||||
import requests
|
|
||||||
from google.api_core import exceptions
|
|
||||||
from google.generativeai.client import _ClientManager
|
|
||||||
from google.generativeai.types import ContentType, GenerateContentResponse
|
|
||||||
from google.generativeai.types.content_types import to_part
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
|
||||||
from core.model_runtime.entities.message_entities import (
|
|
||||||
AssistantPromptMessage,
|
|
||||||
ImagePromptMessageContent,
|
|
||||||
PromptMessage,
|
|
||||||
PromptMessageContentType,
|
|
||||||
PromptMessageTool,
|
|
||||||
SystemPromptMessage,
|
|
||||||
ToolPromptMessage,
|
|
||||||
UserPromptMessage,
|
|
||||||
)
|
|
||||||
from core.model_runtime.errors.invoke import (
|
|
||||||
InvokeAuthorizationError,
|
|
||||||
InvokeBadRequestError,
|
|
||||||
InvokeConnectionError,
|
|
||||||
InvokeError,
|
|
||||||
InvokeRateLimitError,
|
|
||||||
InvokeServerUnavailableError,
|
|
||||||
)
|
|
||||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
|
|
||||||
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
|
|
||||||
if you are not sure about the structure.
|
|
||||||
|
|
||||||
<instructions>
|
|
||||||
{{instructions}}
|
|
||||||
</instructions>
|
|
||||||
""" # noqa: E501
|
|
||||||
|
|
||||||
|
|
||||||
class GoogleLargeLanguageModel(LargeLanguageModel):
|
|
||||||
def _invoke(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
credentials: dict,
|
|
||||||
prompt_messages: list[PromptMessage],
|
|
||||||
model_parameters: dict,
|
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
|
||||||
stop: Optional[list[str]] = None,
|
|
||||||
stream: bool = True,
|
|
||||||
user: Optional[str] = None,
|
|
||||||
) -> Union[LLMResult, Generator]:
|
|
||||||
"""
|
|
||||||
Invoke large language model
|
|
||||||
|
|
||||||
:param model: model name
|
|
||||||
:param credentials: model credentials
|
|
||||||
:param prompt_messages: prompt messages
|
|
||||||
:param model_parameters: model parameters
|
|
||||||
:param tools: tools for tool calling
|
|
||||||
:param stop: stop words
|
|
||||||
:param stream: is stream response
|
|
||||||
:param user: unique user id
|
|
||||||
:return: full response or stream response chunk generator result
|
|
||||||
"""
|
|
||||||
# invoke model
|
|
||||||
return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
|
||||||
|
|
||||||
def get_num_tokens(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
credentials: dict,
|
|
||||||
prompt_messages: list[PromptMessage],
|
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
|
||||||
) -> int:
|
|
||||||
"""
|
|
||||||
Get number of tokens for given prompt messages
|
|
||||||
|
|
||||||
:param model: model name
|
|
||||||
:param credentials: model credentials
|
|
||||||
:param prompt_messages: prompt messages
|
|
||||||
:param tools: tools for tool calling
|
|
||||||
:return:md = genai.GenerativeModel(model)
|
|
||||||
"""
|
|
||||||
prompt = self._convert_messages_to_prompt(prompt_messages)
|
|
||||||
|
|
||||||
return self._get_num_tokens_by_gpt2(prompt)
|
|
||||||
|
|
||||||
def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
|
|
||||||
"""
|
|
||||||
Format a list of messages into a full prompt for the Google model
|
|
||||||
|
|
||||||
:param messages: List of PromptMessage to combine.
|
|
||||||
:return: Combined string with necessary human_prompt and ai_prompt tags.
|
|
||||||
"""
|
|
||||||
messages = messages.copy() # don't mutate the original list
|
|
||||||
|
|
||||||
text = "".join(self._convert_one_message_to_text(message) for message in messages)
|
|
||||||
|
|
||||||
return text.rstrip()
|
|
||||||
|
|
||||||
def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
|
|
||||||
"""
|
|
||||||
Convert tool messages to glm tools
|
|
||||||
|
|
||||||
:param tools: tool messages
|
|
||||||
:return: glm tools
|
|
||||||
"""
|
|
||||||
function_declarations = []
|
|
||||||
for tool in tools:
|
|
||||||
properties = {}
|
|
||||||
for key, value in tool.parameters.get("properties", {}).items():
|
|
||||||
properties[key] = {
|
|
||||||
"type_": glm.Type.STRING,
|
|
||||||
"description": value.get("description", ""),
|
|
||||||
"enum": value.get("enum", []),
|
|
||||||
}
|
|
||||||
|
|
||||||
if properties:
|
|
||||||
parameters = glm.Schema(
|
|
||||||
type=glm.Type.OBJECT,
|
|
||||||
properties=properties,
|
|
||||||
required=tool.parameters.get("required", []),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
parameters = None
|
|
||||||
|
|
||||||
function_declaration = glm.FunctionDeclaration(
|
|
||||||
name=tool.name,
|
|
||||||
parameters=parameters,
|
|
||||||
description=tool.description,
|
|
||||||
)
|
|
||||||
function_declarations.append(function_declaration)
|
|
||||||
|
|
||||||
return glm.Tool(function_declarations=function_declarations)
|
|
||||||
|
|
||||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
|
||||||
"""
|
|
||||||
Validate model credentials
|
|
||||||
|
|
||||||
:param model: model name
|
|
||||||
:param credentials: model credentials
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
ping_message = SystemPromptMessage(content="ping")
|
|
||||||
self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
|
|
||||||
|
|
||||||
except Exception as ex:
|
|
||||||
raise CredentialsValidateFailedError(str(ex))
|
|
||||||
|
|
||||||
def _generate(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
credentials: dict,
|
|
||||||
prompt_messages: list[PromptMessage],
|
|
||||||
model_parameters: dict,
|
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
|
||||||
stop: Optional[list[str]] = None,
|
|
||||||
stream: bool = True,
|
|
||||||
user: Optional[str] = None,
|
|
||||||
) -> Union[LLMResult, Generator]:
|
|
||||||
"""
|
|
||||||
Invoke large language model
|
|
||||||
|
|
||||||
:param model: model name
|
|
||||||
:param credentials: credentials kwargs
|
|
||||||
:param prompt_messages: prompt messages
|
|
||||||
:param model_parameters: model parameters
|
|
||||||
:param stop: stop words
|
|
||||||
:param stream: is stream response
|
|
||||||
:param user: unique user id
|
|
||||||
:return: full response or stream response chunk generator result
|
|
||||||
"""
|
|
||||||
config_kwargs = model_parameters.copy()
|
|
||||||
config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None)
|
|
||||||
|
|
||||||
if stop:
|
|
||||||
config_kwargs["stop_sequences"] = stop
|
|
||||||
|
|
||||||
google_model = genai.GenerativeModel(model_name=model)
|
|
||||||
|
|
||||||
history = []
|
|
||||||
|
|
||||||
# hack for gemini-pro-vision, which currently does not support multi-turn chat
|
|
||||||
if model == "gemini-pro-vision":
|
|
||||||
last_msg = prompt_messages[-1]
|
|
||||||
content = self._format_message_to_glm_content(last_msg)
|
|
||||||
history.append(content)
|
|
||||||
else:
|
|
||||||
for msg in prompt_messages: # makes message roles strictly alternating
|
|
||||||
content = self._format_message_to_glm_content(msg)
|
|
||||||
if history and history[-1]["role"] == content["role"]:
|
|
||||||
history[-1]["parts"].extend(content["parts"])
|
|
||||||
else:
|
|
||||||
history.append(content)
|
|
||||||
|
|
||||||
# Create a new ClientManager with tenant's API key
|
|
||||||
new_client_manager = _ClientManager()
|
|
||||||
new_client_manager.configure(api_key=credentials["google_api_key"])
|
|
||||||
new_custom_client = new_client_manager.make_client("generative")
|
|
||||||
|
|
||||||
google_model._client = new_custom_client
|
|
||||||
|
|
||||||
response = google_model.generate_content(
|
|
||||||
contents=history,
|
|
||||||
generation_config=genai.types.GenerationConfig(**config_kwargs),
|
|
||||||
stream=stream,
|
|
||||||
tools=self._convert_tools_to_glm_tool(tools) if tools else None,
|
|
||||||
request_options={"timeout": 600},
|
|
||||||
)
|
|
||||||
|
|
||||||
if stream:
|
|
||||||
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
|
|
||||||
|
|
||||||
return self._handle_generate_response(model, credentials, response, prompt_messages)
|
|
||||||
|
|
||||||
def _handle_generate_response(
|
|
||||||
self, model: str, credentials: dict, response: GenerateContentResponse, prompt_messages: list[PromptMessage]
|
|
||||||
) -> LLMResult:
|
|
||||||
"""
|
|
||||||
Handle llm response
|
|
||||||
|
|
||||||
:param model: model name
|
|
||||||
:param credentials: credentials
|
|
||||||
:param response: response
|
|
||||||
:param prompt_messages: prompt messages
|
|
||||||
:return: llm response
|
|
||||||
"""
|
|
||||||
# transform assistant message to prompt message
|
|
||||||
assistant_prompt_message = AssistantPromptMessage(content=response.text)
|
|
||||||
|
|
||||||
# calculate num tokens
|
|
||||||
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
|
||||||
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
|
|
||||||
|
|
||||||
# transform usage
|
|
||||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
|
||||||
|
|
||||||
# transform response
|
|
||||||
result = LLMResult(
|
|
||||||
model=model,
|
|
||||||
prompt_messages=prompt_messages,
|
|
||||||
message=assistant_prompt_message,
|
|
||||||
usage=usage,
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _handle_generate_stream_response(
|
|
||||||
self, model: str, credentials: dict, response: GenerateContentResponse, prompt_messages: list[PromptMessage]
|
|
||||||
) -> Generator:
|
|
||||||
"""
|
|
||||||
Handle llm stream response
|
|
||||||
|
|
||||||
:param model: model name
|
|
||||||
:param credentials: credentials
|
|
||||||
:param response: response
|
|
||||||
:param prompt_messages: prompt messages
|
|
||||||
:return: llm response chunk generator result
|
|
||||||
"""
|
|
||||||
index = -1
|
|
||||||
for chunk in response:
|
|
||||||
for part in chunk.parts:
|
|
||||||
assistant_prompt_message = AssistantPromptMessage(content="")
|
|
||||||
|
|
||||||
if part.text:
|
|
||||||
assistant_prompt_message.content += part.text
|
|
||||||
|
|
||||||
if part.function_call:
|
|
||||||
assistant_prompt_message.tool_calls = [
|
|
||||||
AssistantPromptMessage.ToolCall(
|
|
||||||
id=part.function_call.name,
|
|
||||||
type="function",
|
|
||||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
|
||||||
name=part.function_call.name,
|
|
||||||
arguments=json.dumps(dict(part.function_call.args.items())),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
index += 1
|
|
||||||
|
|
||||||
if not response._done:
|
|
||||||
# transform assistant message to prompt message
|
|
||||||
yield LLMResultChunk(
|
|
||||||
model=model,
|
|
||||||
prompt_messages=prompt_messages,
|
|
||||||
delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# calculate num tokens
|
|
||||||
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
|
|
||||||
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
|
|
||||||
|
|
||||||
# transform usage
|
|
||||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
|
||||||
|
|
||||||
yield LLMResultChunk(
|
|
||||||
model=model,
|
|
||||||
prompt_messages=prompt_messages,
|
|
||||||
delta=LLMResultChunkDelta(
|
|
||||||
index=index,
|
|
||||||
message=assistant_prompt_message,
|
|
||||||
finish_reason=str(chunk.candidates[0].finish_reason),
|
|
||||||
usage=usage,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def _convert_one_message_to_text(self, message: PromptMessage) -> str:
|
|
||||||
"""
|
|
||||||
Convert a single message to a string.
|
|
||||||
|
|
||||||
:param message: PromptMessage to convert.
|
|
||||||
:return: String representation of the message.
|
|
||||||
"""
|
|
||||||
human_prompt = "\n\nuser:"
|
|
||||||
ai_prompt = "\n\nmodel:"
|
|
||||||
|
|
||||||
content = message.content
|
|
||||||
if isinstance(content, list):
|
|
||||||
content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE)
|
|
||||||
|
|
||||||
if isinstance(message, UserPromptMessage):
|
|
||||||
message_text = f"{human_prompt} {content}"
|
|
||||||
elif isinstance(message, AssistantPromptMessage):
|
|
||||||
message_text = f"{ai_prompt} {content}"
|
|
||||||
elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
|
|
||||||
message_text = f"{human_prompt} {content}"
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Got unknown type {message}")
|
|
||||||
|
|
||||||
return message_text
|
|
||||||
|
|
||||||
def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType:
|
|
||||||
"""
|
|
||||||
Format a single message into glm.Content for Google API
|
|
||||||
|
|
||||||
:param message: one PromptMessage
|
|
||||||
:return: glm Content representation of message
|
|
||||||
"""
|
|
||||||
if isinstance(message, UserPromptMessage):
|
|
||||||
glm_content = {"role": "user", "parts": []}
|
|
||||||
if isinstance(message.content, str):
|
|
||||||
glm_content["parts"].append(to_part(message.content))
|
|
||||||
else:
|
|
||||||
for c in message.content:
|
|
||||||
if c.type == PromptMessageContentType.TEXT:
|
|
||||||
glm_content["parts"].append(to_part(c.data))
|
|
||||||
elif c.type == PromptMessageContentType.IMAGE:
|
|
||||||
message_content = cast(ImagePromptMessageContent, c)
|
|
||||||
if message_content.data.startswith("data:"):
|
|
||||||
metadata, base64_data = c.data.split(",", 1)
|
|
||||||
mime_type = metadata.split(";", 1)[0].split(":")[1]
|
|
||||||
else:
|
|
||||||
# fetch image data from url
|
|
||||||
try:
|
|
||||||
image_content = requests.get(message_content.data).content
|
|
||||||
with Image.open(io.BytesIO(image_content)) as img:
|
|
||||||
mime_type = f"image/{img.format.lower()}"
|
|
||||||
base64_data = base64.b64encode(image_content).decode("utf-8")
|
|
||||||
except Exception as ex:
|
|
||||||
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
|
|
||||||
blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}}
|
|
||||||
glm_content["parts"].append(blob)
|
|
||||||
|
|
||||||
return glm_content
|
|
||||||
elif isinstance(message, AssistantPromptMessage):
|
|
||||||
glm_content = {"role": "model", "parts": []}
|
|
||||||
if message.content:
|
|
||||||
glm_content["parts"].append(to_part(message.content))
|
|
||||||
if message.tool_calls:
|
|
||||||
glm_content["parts"].append(
|
|
||||||
to_part(
|
|
||||||
glm.FunctionCall(
|
|
||||||
name=message.tool_calls[0].function.name,
|
|
||||||
args=json.loads(message.tool_calls[0].function.arguments),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return glm_content
|
|
||||||
elif isinstance(message, SystemPromptMessage):
|
|
||||||
return {"role": "user", "parts": [to_part(message.content)]}
|
|
||||||
elif isinstance(message, ToolPromptMessage):
|
|
||||||
return {
|
|
||||||
"role": "function",
|
|
||||||
"parts": [
|
|
||||||
glm.Part(
|
|
||||||
function_response=glm.FunctionResponse(
|
|
||||||
name=message.name, response={"response": message.content}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
],
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Got unknown type {message}")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
|
||||||
"""
|
|
||||||
Map model invoke error to unified error
|
|
||||||
The key is the ermd = genai.GenerativeModel(model) error type thrown to the caller
|
|
||||||
The value is the md = genai.GenerativeModel(model) error type thrown by the model,
|
|
||||||
which needs to be converted into a unified error type for the caller.
|
|
||||||
|
|
||||||
:return: Invoke emd = genai.GenerativeModel(model) error mapping
|
|
||||||
"""
|
|
||||||
return {
|
|
||||||
InvokeConnectionError: [exceptions.RetryError],
|
|
||||||
InvokeServerUnavailableError: [
|
|
||||||
exceptions.ServiceUnavailable,
|
|
||||||
exceptions.InternalServerError,
|
|
||||||
exceptions.BadGateway,
|
|
||||||
exceptions.GatewayTimeout,
|
|
||||||
exceptions.DeadlineExceeded,
|
|
||||||
],
|
|
||||||
InvokeRateLimitError: [exceptions.ResourceExhausted, exceptions.TooManyRequests],
|
|
||||||
InvokeAuthorizationError: [
|
|
||||||
exceptions.Unauthenticated,
|
|
||||||
exceptions.PermissionDenied,
|
|
||||||
exceptions.Unauthenticated,
|
|
||||||
exceptions.Forbidden,
|
|
||||||
],
|
|
||||||
InvokeBadRequestError: [
|
|
||||||
exceptions.BadRequest,
|
|
||||||
exceptions.InvalidArgument,
|
|
||||||
exceptions.FailedPrecondition,
|
|
||||||
exceptions.OutOfRange,
|
|
||||||
exceptions.NotFound,
|
|
||||||
exceptions.MethodNotAllowed,
|
|
||||||
exceptions.Conflict,
|
|
||||||
exceptions.AlreadyExists,
|
|
||||||
exceptions.Aborted,
|
|
||||||
exceptions.LengthRequired,
|
|
||||||
exceptions.PreconditionFailed,
|
|
||||||
exceptions.RequestRangeNotSatisfiable,
|
|
||||||
exceptions.Cancelled,
|
|
||||||
],
|
|
||||||
}
|
|
||||||
@ -1,330 +0,0 @@
|
|||||||
import json
|
|
||||||
from collections.abc import Generator
|
|
||||||
from typing import Optional, Union, cast
|
|
||||||
|
|
||||||
import requests
|
|
||||||
|
|
||||||
from core.model_runtime.entities.common_entities import I18nObject
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
|
||||||
from core.model_runtime.entities.message_entities import (
|
|
||||||
AssistantPromptMessage,
|
|
||||||
ImagePromptMessageContent,
|
|
||||||
PromptMessage,
|
|
||||||
PromptMessageContent,
|
|
||||||
PromptMessageContentType,
|
|
||||||
PromptMessageTool,
|
|
||||||
SystemPromptMessage,
|
|
||||||
ToolPromptMessage,
|
|
||||||
UserPromptMessage,
|
|
||||||
)
|
|
||||||
from core.model_runtime.entities.model_entities import (
|
|
||||||
AIModelEntity,
|
|
||||||
FetchFrom,
|
|
||||||
ModelFeature,
|
|
||||||
ModelPropertyKey,
|
|
||||||
ModelType,
|
|
||||||
ParameterRule,
|
|
||||||
ParameterType,
|
|
||||||
)
|
|
||||||
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
|
|
||||||
|
|
||||||
|
|
||||||
class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
|
||||||
def _invoke(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
credentials: dict,
|
|
||||||
prompt_messages: list[PromptMessage],
|
|
||||||
model_parameters: dict,
|
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
|
||||||
stop: Optional[list[str]] = None,
|
|
||||||
stream: bool = True,
|
|
||||||
user: Optional[str] = None,
|
|
||||||
) -> Union[LLMResult, Generator]:
|
|
||||||
self._add_custom_parameters(credentials)
|
|
||||||
self._add_function_call(model, credentials)
|
|
||||||
user = user[:32] if user else None
|
|
||||||
# {"response_format": "json_object"} need convert to {"response_format": {"type": "json_object"}}
|
|
||||||
if "response_format" in model_parameters:
|
|
||||||
model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
|
|
||||||
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
|
||||||
|
|
||||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
|
||||||
self._add_custom_parameters(credentials)
|
|
||||||
super().validate_credentials(model, credentials)
|
|
||||||
|
|
||||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
|
||||||
return AIModelEntity(
|
|
||||||
model=model,
|
|
||||||
label=I18nObject(en_US=model, zh_Hans=model),
|
|
||||||
model_type=ModelType.LLM,
|
|
||||||
features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
|
|
||||||
if credentials.get("function_calling_type") == "tool_call"
|
|
||||||
else [],
|
|
||||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
|
||||||
model_properties={
|
|
||||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 4096)),
|
|
||||||
ModelPropertyKey.MODE: LLMMode.CHAT.value,
|
|
||||||
},
|
|
||||||
parameter_rules=[
|
|
||||||
ParameterRule(
|
|
||||||
name="temperature",
|
|
||||||
use_template="temperature",
|
|
||||||
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
|
|
||||||
type=ParameterType.FLOAT,
|
|
||||||
),
|
|
||||||
ParameterRule(
|
|
||||||
name="max_tokens",
|
|
||||||
use_template="max_tokens",
|
|
||||||
default=512,
|
|
||||||
min=1,
|
|
||||||
max=int(credentials.get("max_tokens", 4096)),
|
|
||||||
label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"),
|
|
||||||
type=ParameterType.INT,
|
|
||||||
),
|
|
||||||
ParameterRule(
|
|
||||||
name="top_p",
|
|
||||||
use_template="top_p",
|
|
||||||
label=I18nObject(en_US="Top P", zh_Hans="Top P"),
|
|
||||||
type=ParameterType.FLOAT,
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
def _add_custom_parameters(self, credentials: dict) -> None:
|
|
||||||
credentials["mode"] = "chat"
|
|
||||||
if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
|
|
||||||
credentials["endpoint_url"] = "https://api.moonshot.cn/v1"
|
|
||||||
|
|
||||||
def _add_function_call(self, model: str, credentials: dict) -> None:
|
|
||||||
model_schema = self.get_model_schema(model, credentials)
|
|
||||||
if model_schema and {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}.intersection(
|
|
||||||
model_schema.features or []
|
|
||||||
):
|
|
||||||
credentials["function_calling_type"] = "tool_call"
|
|
||||||
|
|
||||||
def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: Optional[dict] = None) -> dict:
|
|
||||||
"""
|
|
||||||
Convert PromptMessage to dict for OpenAI API format
|
|
||||||
"""
|
|
||||||
if isinstance(message, UserPromptMessage):
|
|
||||||
message = cast(UserPromptMessage, message)
|
|
||||||
if isinstance(message.content, str):
|
|
||||||
message_dict = {"role": "user", "content": message.content}
|
|
||||||
else:
|
|
||||||
sub_messages = []
|
|
||||||
for message_content in message.content:
|
|
||||||
if message_content.type == PromptMessageContentType.TEXT:
|
|
||||||
message_content = cast(PromptMessageContent, message_content)
|
|
||||||
sub_message_dict = {"type": "text", "text": message_content.data}
|
|
||||||
sub_messages.append(sub_message_dict)
|
|
||||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
|
||||||
message_content = cast(ImagePromptMessageContent, message_content)
|
|
||||||
sub_message_dict = {
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {"url": message_content.data, "detail": message_content.detail.value},
|
|
||||||
}
|
|
||||||
sub_messages.append(sub_message_dict)
|
|
||||||
message_dict = {"role": "user", "content": sub_messages}
|
|
||||||
elif isinstance(message, AssistantPromptMessage):
|
|
||||||
message = cast(AssistantPromptMessage, message)
|
|
||||||
message_dict = {"role": "assistant", "content": message.content}
|
|
||||||
if message.tool_calls:
|
|
||||||
message_dict["tool_calls"] = []
|
|
||||||
for function_call in message.tool_calls:
|
|
||||||
message_dict["tool_calls"].append(
|
|
||||||
{
|
|
||||||
"id": function_call.id,
|
|
||||||
"type": function_call.type,
|
|
||||||
"function": {
|
|
||||||
"name": function_call.function.name,
|
|
||||||
"arguments": function_call.function.arguments,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
elif isinstance(message, ToolPromptMessage):
|
|
||||||
message = cast(ToolPromptMessage, message)
|
|
||||||
message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id}
|
|
||||||
elif isinstance(message, SystemPromptMessage):
|
|
||||||
message = cast(SystemPromptMessage, message)
|
|
||||||
message_dict = {"role": "system", "content": message.content}
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Got unknown type {message}")
|
|
||||||
|
|
||||||
if message.name:
|
|
||||||
message_dict["name"] = message.name
|
|
||||||
|
|
||||||
return message_dict
|
|
||||||
|
|
||||||
def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]:
|
|
||||||
"""
|
|
||||||
Extract tool calls from response
|
|
||||||
|
|
||||||
:param response_tool_calls: response tool calls
|
|
||||||
:return: list of tool calls
|
|
||||||
"""
|
|
||||||
tool_calls = []
|
|
||||||
if response_tool_calls:
|
|
||||||
for response_tool_call in response_tool_calls:
|
|
||||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
|
||||||
name=response_tool_call["function"]["name"]
|
|
||||||
if response_tool_call.get("function", {}).get("name")
|
|
||||||
else "",
|
|
||||||
arguments=response_tool_call["function"]["arguments"]
|
|
||||||
if response_tool_call.get("function", {}).get("arguments")
|
|
||||||
else "",
|
|
||||||
)
|
|
||||||
|
|
||||||
tool_call = AssistantPromptMessage.ToolCall(
|
|
||||||
id=response_tool_call["id"] if response_tool_call.get("id") else "",
|
|
||||||
type=response_tool_call["type"] if response_tool_call.get("type") else "",
|
|
||||||
function=function,
|
|
||||||
)
|
|
||||||
tool_calls.append(tool_call)
|
|
||||||
|
|
||||||
return tool_calls
|
|
||||||
|
|
||||||
def _handle_generate_stream_response(
|
|
||||||
self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
|
|
||||||
) -> Generator:
|
|
||||||
"""
|
|
||||||
Handle llm stream response
|
|
||||||
|
|
||||||
:param model: model name
|
|
||||||
:param credentials: model credentials
|
|
||||||
:param response: streamed response
|
|
||||||
:param prompt_messages: prompt messages
|
|
||||||
:return: llm response chunk generator
|
|
||||||
"""
|
|
||||||
full_assistant_content = ""
|
|
||||||
chunk_index = 0
|
|
||||||
|
|
||||||
def create_final_llm_result_chunk(
|
|
||||||
index: int, message: AssistantPromptMessage, finish_reason: str
|
|
||||||
) -> LLMResultChunk:
|
|
||||||
# calculate num tokens
|
|
||||||
prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
|
|
||||||
completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
|
|
||||||
|
|
||||||
# transform usage
|
|
||||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
|
||||||
|
|
||||||
return LLMResultChunk(
|
|
||||||
model=model,
|
|
||||||
prompt_messages=prompt_messages,
|
|
||||||
delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
|
|
||||||
)
|
|
||||||
|
|
||||||
tools_calls: list[AssistantPromptMessage.ToolCall] = []
|
|
||||||
finish_reason = "Unknown"
|
|
||||||
|
|
||||||
def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
|
|
||||||
def get_tool_call(tool_name: str):
|
|
||||||
if not tool_name:
|
|
||||||
return tools_calls[-1]
|
|
||||||
|
|
||||||
tool_call = next((tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None)
|
|
||||||
if tool_call is None:
|
|
||||||
tool_call = AssistantPromptMessage.ToolCall(
|
|
||||||
id="",
|
|
||||||
type="",
|
|
||||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""),
|
|
||||||
)
|
|
||||||
tools_calls.append(tool_call)
|
|
||||||
|
|
||||||
return tool_call
|
|
||||||
|
|
||||||
for new_tool_call in new_tool_calls:
|
|
||||||
# get tool call
|
|
||||||
tool_call = get_tool_call(new_tool_call.function.name)
|
|
||||||
# update tool call
|
|
||||||
if new_tool_call.id:
|
|
||||||
tool_call.id = new_tool_call.id
|
|
||||||
if new_tool_call.type:
|
|
||||||
tool_call.type = new_tool_call.type
|
|
||||||
if new_tool_call.function.name:
|
|
||||||
tool_call.function.name = new_tool_call.function.name
|
|
||||||
if new_tool_call.function.arguments:
|
|
||||||
tool_call.function.arguments += new_tool_call.function.arguments
|
|
||||||
|
|
||||||
for chunk in response.iter_lines(decode_unicode=True, delimiter="\n\n"):
|
|
||||||
if chunk:
|
|
||||||
# ignore sse comments
|
|
||||||
if chunk.startswith(":"):
|
|
||||||
continue
|
|
||||||
decoded_chunk = chunk.strip().lstrip("data: ").lstrip()
|
|
||||||
chunk_json = None
|
|
||||||
try:
|
|
||||||
chunk_json = json.loads(decoded_chunk)
|
|
||||||
# stream ended
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
yield create_final_llm_result_chunk(
|
|
||||||
index=chunk_index + 1,
|
|
||||||
message=AssistantPromptMessage(content=""),
|
|
||||||
finish_reason="Non-JSON encountered.",
|
|
||||||
)
|
|
||||||
break
|
|
||||||
if not chunk_json or len(chunk_json["choices"]) == 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
choice = chunk_json["choices"][0]
|
|
||||||
finish_reason = chunk_json["choices"][0].get("finish_reason")
|
|
||||||
chunk_index += 1
|
|
||||||
|
|
||||||
if "delta" in choice:
|
|
||||||
delta = choice["delta"]
|
|
||||||
delta_content = delta.get("content")
|
|
||||||
|
|
||||||
assistant_message_tool_calls = delta.get("tool_calls", None)
|
|
||||||
# assistant_message_function_call = delta.delta.function_call
|
|
||||||
|
|
||||||
# extract tool calls from response
|
|
||||||
if assistant_message_tool_calls:
|
|
||||||
tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
|
|
||||||
increase_tool_call(tool_calls)
|
|
||||||
|
|
||||||
if delta_content is None or delta_content == "":
|
|
||||||
continue
|
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
|
||||||
content=delta_content, tool_calls=tool_calls if assistant_message_tool_calls else []
|
|
||||||
)
|
|
||||||
|
|
||||||
full_assistant_content += delta_content
|
|
||||||
elif "text" in choice:
|
|
||||||
choice_text = choice.get("text", "")
|
|
||||||
if choice_text == "":
|
|
||||||
continue
|
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
|
||||||
assistant_prompt_message = AssistantPromptMessage(content=choice_text)
|
|
||||||
full_assistant_content += choice_text
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# check payload indicator for completion
|
|
||||||
yield LLMResultChunk(
|
|
||||||
model=model,
|
|
||||||
prompt_messages=prompt_messages,
|
|
||||||
delta=LLMResultChunkDelta(
|
|
||||||
index=chunk_index,
|
|
||||||
message=assistant_prompt_message,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
chunk_index += 1
|
|
||||||
|
|
||||||
if tools_calls:
|
|
||||||
yield LLMResultChunk(
|
|
||||||
model=model,
|
|
||||||
prompt_messages=prompt_messages,
|
|
||||||
delta=LLMResultChunkDelta(
|
|
||||||
index=chunk_index,
|
|
||||||
message=AssistantPromptMessage(tool_calls=tools_calls, content=""),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
yield create_final_llm_result_chunk(
|
|
||||||
index=chunk_index, message=AssistantPromptMessage(content=""), finish_reason=finish_reason
|
|
||||||
)
|
|
||||||
@ -1,847 +0,0 @@
|
|||||||
import json
|
|
||||||
import logging
|
|
||||||
from collections.abc import Generator
|
|
||||||
from decimal import Decimal
|
|
||||||
from typing import Optional, Union, cast
|
|
||||||
from urllib.parse import urljoin
|
|
||||||
|
|
||||||
import requests
|
|
||||||
|
|
||||||
from core.model_runtime.entities.common_entities import I18nObject
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
|
||||||
from core.model_runtime.entities.message_entities import (
|
|
||||||
AssistantPromptMessage,
|
|
||||||
ImagePromptMessageContent,
|
|
||||||
PromptMessage,
|
|
||||||
PromptMessageContent,
|
|
||||||
PromptMessageContentType,
|
|
||||||
PromptMessageFunction,
|
|
||||||
PromptMessageTool,
|
|
||||||
SystemPromptMessage,
|
|
||||||
ToolPromptMessage,
|
|
||||||
UserPromptMessage,
|
|
||||||
)
|
|
||||||
from core.model_runtime.entities.model_entities import (
|
|
||||||
AIModelEntity,
|
|
||||||
DefaultParameterName,
|
|
||||||
FetchFrom,
|
|
||||||
ModelFeature,
|
|
||||||
ModelPropertyKey,
|
|
||||||
ModelType,
|
|
||||||
ParameterRule,
|
|
||||||
ParameterType,
|
|
||||||
PriceConfig,
|
|
||||||
)
|
|
||||||
from core.model_runtime.errors.invoke import InvokeError
|
|
||||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
|
||||||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
|
|
||||||
from core.model_runtime.utils import helper
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
|
||||||
"""
|
|
||||||
Model class for OpenAI large language model.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _invoke(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
credentials: dict,
|
|
||||||
prompt_messages: list[PromptMessage],
|
|
||||||
model_parameters: dict,
|
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
|
||||||
stop: Optional[list[str]] = None,
|
|
||||||
stream: bool = True,
|
|
||||||
user: Optional[str] = None,
|
|
||||||
) -> Union[LLMResult, Generator]:
|
|
||||||
"""
|
|
||||||
Invoke large language model
|
|
||||||
|
|
||||||
:param model: model name
|
|
||||||
:param credentials: model credentials
|
|
||||||
:param prompt_messages: prompt messages
|
|
||||||
:param model_parameters: model parameters
|
|
||||||
:param tools: tools for tool calling
|
|
||||||
:param stop: stop words
|
|
||||||
:param stream: is stream response
|
|
||||||
:param user: unique user id
|
|
||||||
:return: full response or stream response chunk generator result
|
|
||||||
"""
|
|
||||||
|
|
||||||
# text completion model
|
|
||||||
return self._generate(
|
|
||||||
model=model,
|
|
||||||
credentials=credentials,
|
|
||||||
prompt_messages=prompt_messages,
|
|
||||||
model_parameters=model_parameters,
|
|
||||||
tools=tools,
|
|
||||||
stop=stop,
|
|
||||||
stream=stream,
|
|
||||||
user=user,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_num_tokens(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
credentials: dict,
|
|
||||||
prompt_messages: list[PromptMessage],
|
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
|
||||||
) -> int:
|
|
||||||
"""
|
|
||||||
Get number of tokens for given prompt messages
|
|
||||||
|
|
||||||
:param model:
|
|
||||||
:param credentials:
|
|
||||||
:param prompt_messages:
|
|
||||||
:param tools: tools for tool calling
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
return self._num_tokens_from_messages(model, prompt_messages, tools, credentials)
|
|
||||||
|
|
||||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
|
||||||
"""
|
|
||||||
Validate model credentials using requests to ensure compatibility with all providers following
|
|
||||||
OpenAI's API standard.
|
|
||||||
|
|
||||||
:param model: model name
|
|
||||||
:param credentials: model credentials
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
headers = {"Content-Type": "application/json"}
|
|
||||||
|
|
||||||
api_key = credentials.get("api_key")
|
|
||||||
if api_key:
|
|
||||||
headers["Authorization"] = f"Bearer {api_key}"
|
|
||||||
|
|
||||||
endpoint_url = credentials["endpoint_url"]
|
|
||||||
if not endpoint_url.endswith("/"):
|
|
||||||
endpoint_url += "/"
|
|
||||||
|
|
||||||
# prepare the payload for a simple ping to the model
|
|
||||||
data = {"model": model, "max_tokens": 5}
|
|
||||||
|
|
||||||
completion_type = LLMMode.value_of(credentials["mode"])
|
|
||||||
|
|
||||||
if completion_type is LLMMode.CHAT:
|
|
||||||
data["messages"] = [
|
|
||||||
{"role": "user", "content": "ping"},
|
|
||||||
]
|
|
||||||
endpoint_url = urljoin(endpoint_url, "chat/completions")
|
|
||||||
elif completion_type is LLMMode.COMPLETION:
|
|
||||||
data["prompt"] = "ping"
|
|
||||||
endpoint_url = urljoin(endpoint_url, "completions")
|
|
||||||
else:
|
|
||||||
raise ValueError("Unsupported completion type for model configuration.")
|
|
||||||
|
|
||||||
# send a post request to validate the credentials
|
|
||||||
response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300))
|
|
||||||
|
|
||||||
if response.status_code != 200:
|
|
||||||
raise CredentialsValidateFailedError(
|
|
||||||
f"Credentials validation failed with status code {response.status_code}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
json_result = response.json()
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error")
|
|
||||||
|
|
||||||
if completion_type is LLMMode.CHAT and json_result.get("object", "") == "":
|
|
||||||
json_result["object"] = "chat.completion"
|
|
||||||
elif completion_type is LLMMode.COMPLETION and json_result.get("object", "") == "":
|
|
||||||
json_result["object"] = "text_completion"
|
|
||||||
|
|
||||||
if completion_type is LLMMode.CHAT and (
|
|
||||||
"object" not in json_result or json_result["object"] != "chat.completion"
|
|
||||||
):
|
|
||||||
raise CredentialsValidateFailedError(
|
|
||||||
"Credentials validation failed: invalid response object, must be 'chat.completion'"
|
|
||||||
)
|
|
||||||
elif completion_type is LLMMode.COMPLETION and (
|
|
||||||
"object" not in json_result or json_result["object"] != "text_completion"
|
|
||||||
):
|
|
||||||
raise CredentialsValidateFailedError(
|
|
||||||
"Credentials validation failed: invalid response object, must be 'text_completion'"
|
|
||||||
)
|
|
||||||
except CredentialsValidateFailedError:
|
|
||||||
raise
|
|
||||||
except Exception as ex:
|
|
||||||
raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {str(ex)}")
|
|
||||||
|
|
||||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
|
||||||
"""
|
|
||||||
generate custom model entities from credentials
|
|
||||||
"""
|
|
||||||
features = []
|
|
||||||
|
|
||||||
function_calling_type = credentials.get("function_calling_type", "no_call")
|
|
||||||
if function_calling_type == "function_call":
|
|
||||||
features.append(ModelFeature.TOOL_CALL)
|
|
||||||
elif function_calling_type == "tool_call":
|
|
||||||
features.append(ModelFeature.MULTI_TOOL_CALL)
|
|
||||||
|
|
||||||
stream_function_calling = credentials.get("stream_function_calling", "supported")
|
|
||||||
if stream_function_calling == "supported":
|
|
||||||
features.append(ModelFeature.STREAM_TOOL_CALL)
|
|
||||||
|
|
||||||
vision_support = credentials.get("vision_support", "not_support")
|
|
||||||
if vision_support == "support":
|
|
||||||
features.append(ModelFeature.VISION)
|
|
||||||
|
|
||||||
entity = AIModelEntity(
|
|
||||||
model=model,
|
|
||||||
label=I18nObject(en_US=model),
|
|
||||||
model_type=ModelType.LLM,
|
|
||||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
|
||||||
features=features,
|
|
||||||
model_properties={
|
|
||||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "4096")),
|
|
||||||
ModelPropertyKey.MODE: credentials.get("mode"),
|
|
||||||
},
|
|
||||||
parameter_rules=[
|
|
||||||
ParameterRule(
|
|
||||||
name=DefaultParameterName.TEMPERATURE.value,
|
|
||||||
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
|
|
||||||
help=I18nObject(
|
|
||||||
en_US="Kernel sampling threshold. Used to determine the randomness of the results."
|
|
||||||
"The higher the value, the stronger the randomness."
|
|
||||||
"The higher the possibility of getting different answers to the same question.",
|
|
||||||
zh_Hans="核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。",
|
|
||||||
),
|
|
||||||
type=ParameterType.FLOAT,
|
|
||||||
default=float(credentials.get("temperature", 0.7)),
|
|
||||||
min=0,
|
|
||||||
max=2,
|
|
||||||
precision=2,
|
|
||||||
),
|
|
||||||
ParameterRule(
|
|
||||||
name=DefaultParameterName.TOP_P.value,
|
|
||||||
label=I18nObject(en_US="Top P", zh_Hans="Top P"),
|
|
||||||
help=I18nObject(
|
|
||||||
en_US="The probability threshold of the nucleus sampling method during the generation process."
|
|
||||||
"The larger the value is, the higher the randomness of generation will be."
|
|
||||||
"The smaller the value is, the higher the certainty of generation will be.",
|
|
||||||
zh_Hans="生成过程中核采样方法概率阈值。取值越大,生成的随机性越高;取值越小,生成的确定性越高。",
|
|
||||||
),
|
|
||||||
type=ParameterType.FLOAT,
|
|
||||||
default=float(credentials.get("top_p", 1)),
|
|
||||||
min=0,
|
|
||||||
max=1,
|
|
||||||
precision=2,
|
|
||||||
),
|
|
||||||
ParameterRule(
|
|
||||||
name=DefaultParameterName.FREQUENCY_PENALTY.value,
|
|
||||||
label=I18nObject(en_US="Frequency Penalty", zh_Hans="频率惩罚"),
|
|
||||||
help=I18nObject(
|
|
||||||
en_US="For controlling the repetition rate of words used by the model."
|
|
||||||
"Increasing this can reduce the repetition of the same words in the model's output.",
|
|
||||||
zh_Hans="用于控制模型已使用字词的重复率。 提高此项可以降低模型在输出中重复相同字词的重复度。",
|
|
||||||
),
|
|
||||||
type=ParameterType.FLOAT,
|
|
||||||
default=float(credentials.get("frequency_penalty", 0)),
|
|
||||||
min=-2,
|
|
||||||
max=2,
|
|
||||||
),
|
|
||||||
ParameterRule(
|
|
||||||
name=DefaultParameterName.PRESENCE_PENALTY.value,
|
|
||||||
label=I18nObject(en_US="Presence Penalty", zh_Hans="存在惩罚"),
|
|
||||||
help=I18nObject(
|
|
||||||
en_US="Used to control the repetition rate when generating models."
|
|
||||||
"Increasing this can reduce the repetition rate of model generation.",
|
|
||||||
zh_Hans="用于控制模型生成时的重复度。提高此项可以降低模型生成的重复度。",
|
|
||||||
),
|
|
||||||
type=ParameterType.FLOAT,
|
|
||||||
default=float(credentials.get("presence_penalty", 0)),
|
|
||||||
min=-2,
|
|
||||||
max=2,
|
|
||||||
),
|
|
||||||
ParameterRule(
|
|
||||||
name=DefaultParameterName.MAX_TOKENS.value,
|
|
||||||
label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"),
|
|
||||||
help=I18nObject(
|
|
||||||
en_US="Maximum length of tokens for the model response.", zh_Hans="模型回答的tokens的最大长度。"
|
|
||||||
),
|
|
||||||
type=ParameterType.INT,
|
|
||||||
default=512,
|
|
||||||
min=1,
|
|
||||||
max=int(credentials.get("max_tokens_to_sample", 4096)),
|
|
||||||
),
|
|
||||||
],
|
|
||||||
pricing=PriceConfig(
|
|
||||||
input=Decimal(credentials.get("input_price", 0)),
|
|
||||||
output=Decimal(credentials.get("output_price", 0)),
|
|
||||||
unit=Decimal(credentials.get("unit", 0)),
|
|
||||||
currency=credentials.get("currency", "USD"),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
if credentials["mode"] == "chat":
|
|
||||||
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value
|
|
||||||
elif credentials["mode"] == "completion":
|
|
||||||
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown completion type {credentials['completion_type']}")
|
|
||||||
|
|
||||||
return entity
|
|
||||||
|
|
||||||
# validate_credentials method has been rewritten to use the requests library for compatibility with all providers
|
|
||||||
# following OpenAI's API standard.
|
|
||||||
def _generate(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
credentials: dict,
|
|
||||||
prompt_messages: list[PromptMessage],
|
|
||||||
model_parameters: dict,
|
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
|
||||||
stop: Optional[list[str]] = None,
|
|
||||||
stream: bool = True,
|
|
||||||
user: Optional[str] = None,
|
|
||||||
) -> Union[LLMResult, Generator]:
|
|
||||||
"""
|
|
||||||
Invoke llm completion model
|
|
||||||
|
|
||||||
:param model: model name
|
|
||||||
:param credentials: credentials
|
|
||||||
:param prompt_messages: prompt messages
|
|
||||||
:param model_parameters: model parameters
|
|
||||||
:param stop: stop words
|
|
||||||
:param stream: is stream response
|
|
||||||
:param user: unique user id
|
|
||||||
:return: full response or stream response chunk generator result
|
|
||||||
"""
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Accept-Charset": "utf-8",
|
|
||||||
}
|
|
||||||
extra_headers = credentials.get("extra_headers")
|
|
||||||
if extra_headers is not None:
|
|
||||||
headers = {
|
|
||||||
**headers,
|
|
||||||
**extra_headers,
|
|
||||||
}
|
|
||||||
|
|
||||||
api_key = credentials.get("api_key")
|
|
||||||
if api_key:
|
|
||||||
headers["Authorization"] = f"Bearer {api_key}"
|
|
||||||
|
|
||||||
endpoint_url = credentials["endpoint_url"]
|
|
||||||
if not endpoint_url.endswith("/"):
|
|
||||||
endpoint_url += "/"
|
|
||||||
|
|
||||||
data = {"model": model, "stream": stream, **model_parameters}
|
|
||||||
|
|
||||||
completion_type = LLMMode.value_of(credentials["mode"])
|
|
||||||
|
|
||||||
if completion_type is LLMMode.CHAT:
|
|
||||||
endpoint_url = urljoin(endpoint_url, "chat/completions")
|
|
||||||
data["messages"] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages]
|
|
||||||
elif completion_type is LLMMode.COMPLETION:
|
|
||||||
endpoint_url = urljoin(endpoint_url, "completions")
|
|
||||||
data["prompt"] = prompt_messages[0].content
|
|
||||||
else:
|
|
||||||
raise ValueError("Unsupported completion type for model configuration.")
|
|
||||||
|
|
||||||
# annotate tools with names, descriptions, etc.
|
|
||||||
function_calling_type = credentials.get("function_calling_type", "no_call")
|
|
||||||
formatted_tools = []
|
|
||||||
if tools:
|
|
||||||
if function_calling_type == "function_call":
|
|
||||||
data["functions"] = [
|
|
||||||
{"name": tool.name, "description": tool.description, "parameters": tool.parameters}
|
|
||||||
for tool in tools
|
|
||||||
]
|
|
||||||
elif function_calling_type == "tool_call":
|
|
||||||
data["tool_choice"] = "auto"
|
|
||||||
|
|
||||||
for tool in tools:
|
|
||||||
formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))
|
|
||||||
|
|
||||||
data["tools"] = formatted_tools
|
|
||||||
|
|
||||||
if stop:
|
|
||||||
data["stop"] = stop
|
|
||||||
|
|
||||||
if user:
|
|
||||||
data["user"] = user
|
|
||||||
|
|
||||||
response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream)
|
|
||||||
|
|
||||||
if response.encoding is None or response.encoding == "ISO-8859-1":
|
|
||||||
response.encoding = "utf-8"
|
|
||||||
|
|
||||||
if response.status_code != 200:
|
|
||||||
raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}")
|
|
||||||
|
|
||||||
if stream:
|
|
||||||
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
|
|
||||||
|
|
||||||
return self._handle_generate_response(model, credentials, response, prompt_messages)
|
|
||||||
|
|
||||||
def _handle_generate_stream_response(
|
|
||||||
self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
|
|
||||||
) -> Generator:
|
|
||||||
"""
|
|
||||||
Handle llm stream response
|
|
||||||
|
|
||||||
:param model: model name
|
|
||||||
:param credentials: model credentials
|
|
||||||
:param response: streamed response
|
|
||||||
:param prompt_messages: prompt messages
|
|
||||||
:return: llm response chunk generator
|
|
||||||
"""
|
|
||||||
full_assistant_content = ""
|
|
||||||
chunk_index = 0
|
|
||||||
|
|
||||||
def create_final_llm_result_chunk(
|
|
||||||
id: Optional[str], index: int, message: AssistantPromptMessage, finish_reason: str, usage: dict
|
|
||||||
) -> LLMResultChunk:
|
|
||||||
# calculate num tokens
|
|
||||||
prompt_tokens = usage and usage.get("prompt_tokens")
|
|
||||||
if prompt_tokens is None:
|
|
||||||
prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
|
|
||||||
completion_tokens = usage and usage.get("completion_tokens")
|
|
||||||
if completion_tokens is None:
|
|
||||||
completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
|
|
||||||
|
|
||||||
# transform usage
|
|
||||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
|
||||||
|
|
||||||
return LLMResultChunk(
|
|
||||||
id=id,
|
|
||||||
model=model,
|
|
||||||
prompt_messages=prompt_messages,
|
|
||||||
delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
|
|
||||||
)
|
|
||||||
|
|
||||||
# delimiter for stream response, need unicode_escape
|
|
||||||
import codecs
|
|
||||||
|
|
||||||
delimiter = credentials.get("stream_mode_delimiter", "\n\n")
|
|
||||||
delimiter = codecs.decode(delimiter, "unicode_escape")
|
|
||||||
|
|
||||||
tools_calls: list[AssistantPromptMessage.ToolCall] = []
|
|
||||||
|
|
||||||
def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
|
|
||||||
def get_tool_call(tool_call_id: str):
|
|
||||||
if not tool_call_id:
|
|
||||||
return tools_calls[-1]
|
|
||||||
|
|
||||||
tool_call = next((tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None)
|
|
||||||
if tool_call is None:
|
|
||||||
tool_call = AssistantPromptMessage.ToolCall(
|
|
||||||
id=tool_call_id,
|
|
||||||
type="function",
|
|
||||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
|
|
||||||
)
|
|
||||||
tools_calls.append(tool_call)
|
|
||||||
|
|
||||||
return tool_call
|
|
||||||
|
|
||||||
for new_tool_call in new_tool_calls:
|
|
||||||
# get tool call
|
|
||||||
tool_call = get_tool_call(new_tool_call.function.name)
|
|
||||||
# update tool call
|
|
||||||
if new_tool_call.id:
|
|
||||||
tool_call.id = new_tool_call.id
|
|
||||||
if new_tool_call.type:
|
|
||||||
tool_call.type = new_tool_call.type
|
|
||||||
if new_tool_call.function.name:
|
|
||||||
tool_call.function.name = new_tool_call.function.name
|
|
||||||
if new_tool_call.function.arguments:
|
|
||||||
tool_call.function.arguments += new_tool_call.function.arguments
|
|
||||||
|
|
||||||
finish_reason = None # The default value of finish_reason is None
|
|
||||||
message_id, usage = None, None
|
|
||||||
for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
|
|
||||||
chunk = chunk.strip()
|
|
||||||
if chunk:
|
|
||||||
# ignore sse comments
|
|
||||||
if chunk.startswith(":"):
|
|
||||||
continue
|
|
||||||
decoded_chunk = chunk.strip().lstrip("data: ").lstrip()
|
|
||||||
if decoded_chunk == "[DONE]": # Some provider returns "data: [DONE]"
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
chunk_json: dict = json.loads(decoded_chunk)
|
|
||||||
# stream ended
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
yield create_final_llm_result_chunk(
|
|
||||||
id=message_id,
|
|
||||||
index=chunk_index + 1,
|
|
||||||
message=AssistantPromptMessage(content=""),
|
|
||||||
finish_reason="Non-JSON encountered.",
|
|
||||||
usage=usage,
|
|
||||||
)
|
|
||||||
break
|
|
||||||
if chunk_json:
|
|
||||||
if u := chunk_json.get("usage"):
|
|
||||||
usage = u
|
|
||||||
if not chunk_json or len(chunk_json["choices"]) == 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
choice = chunk_json["choices"][0]
|
|
||||||
finish_reason = chunk_json["choices"][0].get("finish_reason")
|
|
||||||
message_id = chunk_json.get("id")
|
|
||||||
chunk_index += 1
|
|
||||||
|
|
||||||
if "delta" in choice:
|
|
||||||
delta = choice["delta"]
|
|
||||||
delta_content = delta.get("content")
|
|
||||||
|
|
||||||
assistant_message_tool_calls = None
|
|
||||||
|
|
||||||
if "tool_calls" in delta and credentials.get("function_calling_type", "no_call") == "tool_call":
|
|
||||||
assistant_message_tool_calls = delta.get("tool_calls", None)
|
|
||||||
elif (
|
|
||||||
"function_call" in delta
|
|
||||||
and credentials.get("function_calling_type", "no_call") == "function_call"
|
|
||||||
):
|
|
||||||
assistant_message_tool_calls = [
|
|
||||||
{"id": "tool_call_id", "type": "function", "function": delta.get("function_call", {})}
|
|
||||||
]
|
|
||||||
|
|
||||||
# assistant_message_function_call = delta.delta.function_call
|
|
||||||
|
|
||||||
# extract tool calls from response
|
|
||||||
if assistant_message_tool_calls:
|
|
||||||
tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
|
|
||||||
increase_tool_call(tool_calls)
|
|
||||||
|
|
||||||
if delta_content is None or delta_content == "":
|
|
||||||
continue
|
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
|
||||||
assistant_prompt_message = AssistantPromptMessage(
|
|
||||||
content=delta_content,
|
|
||||||
)
|
|
||||||
|
|
||||||
# reset tool calls
|
|
||||||
tool_calls = []
|
|
||||||
full_assistant_content += delta_content
|
|
||||||
elif "text" in choice:
|
|
||||||
choice_text = choice.get("text", "")
|
|
||||||
if choice_text == "":
|
|
||||||
continue
|
|
||||||
|
|
||||||
# transform assistant message to prompt message
|
|
||||||
assistant_prompt_message = AssistantPromptMessage(content=choice_text)
|
|
||||||
full_assistant_content += choice_text
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
|
|
||||||
yield LLMResultChunk(
|
|
||||||
id=message_id,
|
|
||||||
model=model,
|
|
||||||
prompt_messages=prompt_messages,
|
|
||||||
delta=LLMResultChunkDelta(
|
|
||||||
index=chunk_index,
|
|
||||||
message=assistant_prompt_message,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
chunk_index += 1
|
|
||||||
|
|
||||||
if tools_calls:
|
|
||||||
yield LLMResultChunk(
|
|
||||||
id=message_id,
|
|
||||||
model=model,
|
|
||||||
prompt_messages=prompt_messages,
|
|
||||||
delta=LLMResultChunkDelta(
|
|
||||||
index=chunk_index,
|
|
||||||
message=AssistantPromptMessage(tool_calls=tools_calls, content=""),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
yield create_final_llm_result_chunk(
|
|
||||||
id=message_id,
|
|
||||||
index=chunk_index,
|
|
||||||
message=AssistantPromptMessage(content=""),
|
|
||||||
finish_reason=finish_reason,
|
|
||||||
usage=usage,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _handle_generate_response(
|
|
||||||
self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
|
|
||||||
) -> LLMResult:
|
|
||||||
response_json: dict = response.json()
|
|
||||||
|
|
||||||
completion_type = LLMMode.value_of(credentials["mode"])
|
|
||||||
|
|
||||||
output = response_json["choices"][0]
|
|
||||||
message_id = response_json.get("id")
|
|
||||||
|
|
||||||
response_content = ""
|
|
||||||
tool_calls = None
|
|
||||||
function_calling_type = credentials.get("function_calling_type", "no_call")
|
|
||||||
if completion_type is LLMMode.CHAT:
|
|
||||||
response_content = output.get("message", {})["content"]
|
|
||||||
if function_calling_type == "tool_call":
|
|
||||||
tool_calls = output.get("message", {}).get("tool_calls")
|
|
||||||
elif function_calling_type == "function_call":
|
|
||||||
tool_calls = output.get("message", {}).get("function_call")
|
|
||||||
|
|
||||||
elif completion_type is LLMMode.COMPLETION:
|
|
||||||
response_content = output["text"]
|
|
||||||
|
|
||||||
assistant_message = AssistantPromptMessage(content=response_content, tool_calls=[])
|
|
||||||
|
|
||||||
if tool_calls:
|
|
||||||
if function_calling_type == "tool_call":
|
|
||||||
assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls)
|
|
||||||
elif function_calling_type == "function_call":
|
|
||||||
assistant_message.tool_calls = [self._extract_response_function_call(tool_calls)]
|
|
||||||
|
|
||||||
usage = response_json.get("usage")
|
|
||||||
if usage:
|
|
||||||
# transform usage
|
|
||||||
prompt_tokens = usage["prompt_tokens"]
|
|
||||||
completion_tokens = usage["completion_tokens"]
|
|
||||||
else:
|
|
||||||
# calculate num tokens
|
|
||||||
prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
|
|
||||||
completion_tokens = self._num_tokens_from_string(model, assistant_message.content)
|
|
||||||
|
|
||||||
# transform usage
|
|
||||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
|
||||||
|
|
||||||
# transform response
|
|
||||||
result = LLMResult(
|
|
||||||
id=message_id,
|
|
||||||
model=response_json["model"],
|
|
||||||
prompt_messages=prompt_messages,
|
|
||||||
message=assistant_message,
|
|
||||||
usage=usage,
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: Optional[dict] = None) -> dict:
|
|
||||||
"""
|
|
||||||
Convert PromptMessage to dict for OpenAI API format
|
|
||||||
"""
|
|
||||||
if isinstance(message, UserPromptMessage):
|
|
||||||
message = cast(UserPromptMessage, message)
|
|
||||||
if isinstance(message.content, str):
|
|
||||||
message_dict = {"role": "user", "content": message.content}
|
|
||||||
else:
|
|
||||||
sub_messages = []
|
|
||||||
for message_content in message.content:
|
|
||||||
if message_content.type == PromptMessageContentType.TEXT:
|
|
||||||
message_content = cast(PromptMessageContent, message_content)
|
|
||||||
sub_message_dict = {"type": "text", "text": message_content.data}
|
|
||||||
sub_messages.append(sub_message_dict)
|
|
||||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
|
||||||
message_content = cast(ImagePromptMessageContent, message_content)
|
|
||||||
sub_message_dict = {
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {"url": message_content.data, "detail": message_content.detail.value},
|
|
||||||
}
|
|
||||||
sub_messages.append(sub_message_dict)
|
|
||||||
|
|
||||||
message_dict = {"role": "user", "content": sub_messages}
|
|
||||||
elif isinstance(message, AssistantPromptMessage):
|
|
||||||
message = cast(AssistantPromptMessage, message)
|
|
||||||
message_dict = {"role": "assistant", "content": message.content}
|
|
||||||
if message.tool_calls:
|
|
||||||
function_calling_type = credentials.get("function_calling_type", "no_call")
|
|
||||||
if function_calling_type == "tool_call":
|
|
||||||
message_dict["tool_calls"] = [tool_call.dict() for tool_call in message.tool_calls]
|
|
||||||
elif function_calling_type == "function_call":
|
|
||||||
function_call = message.tool_calls[0]
|
|
||||||
message_dict["function_call"] = {
|
|
||||||
"name": function_call.function.name,
|
|
||||||
"arguments": function_call.function.arguments,
|
|
||||||
}
|
|
||||||
elif isinstance(message, SystemPromptMessage):
|
|
||||||
message = cast(SystemPromptMessage, message)
|
|
||||||
message_dict = {"role": "system", "content": message.content}
|
|
||||||
elif isinstance(message, ToolPromptMessage):
|
|
||||||
message = cast(ToolPromptMessage, message)
|
|
||||||
function_calling_type = credentials.get("function_calling_type", "no_call")
|
|
||||||
if function_calling_type == "tool_call":
|
|
||||||
message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id}
|
|
||||||
elif function_calling_type == "function_call":
|
|
||||||
message_dict = {"role": "function", "content": message.content, "name": message.tool_call_id}
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Got unknown type {message}")
|
|
||||||
|
|
||||||
if message.name and message_dict.get("role", "") != "tool":
|
|
||||||
message_dict["name"] = message.name
|
|
||||||
|
|
||||||
return message_dict
|
|
||||||
|
|
||||||
def _num_tokens_from_string(
|
|
||||||
self, model: str, text: Union[str, list[PromptMessageContent]], tools: Optional[list[PromptMessageTool]] = None
|
|
||||||
) -> int:
|
|
||||||
"""
|
|
||||||
Approximate num tokens for model with gpt2 tokenizer.
|
|
||||||
|
|
||||||
:param model: model name
|
|
||||||
:param text: prompt text
|
|
||||||
:param tools: tools for tool calling
|
|
||||||
:return: number of tokens
|
|
||||||
"""
|
|
||||||
if isinstance(text, str):
|
|
||||||
full_text = text
|
|
||||||
else:
|
|
||||||
full_text = ""
|
|
||||||
for message_content in text:
|
|
||||||
if message_content.type == PromptMessageContentType.TEXT:
|
|
||||||
message_content = cast(PromptMessageContent, message_content)
|
|
||||||
full_text += message_content.data
|
|
||||||
|
|
||||||
num_tokens = self._get_num_tokens_by_gpt2(full_text)
|
|
||||||
|
|
||||||
if tools:
|
|
||||||
num_tokens += self._num_tokens_for_tools(tools)
|
|
||||||
|
|
||||||
return num_tokens
|
|
||||||
|
|
||||||
def _num_tokens_from_messages(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
messages: list[PromptMessage],
|
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
|
||||||
credentials: Optional[dict] = None,
|
|
||||||
) -> int:
|
|
||||||
"""
|
|
||||||
Approximate num tokens with GPT2 tokenizer.
|
|
||||||
"""
|
|
||||||
|
|
||||||
tokens_per_message = 3
|
|
||||||
tokens_per_name = 1
|
|
||||||
|
|
||||||
num_tokens = 0
|
|
||||||
messages_dict = [self._convert_prompt_message_to_dict(m, credentials) for m in messages]
|
|
||||||
for message in messages_dict:
|
|
||||||
num_tokens += tokens_per_message
|
|
||||||
for key, value in message.items():
|
|
||||||
# Cast str(value) in case the message value is not a string
|
|
||||||
# This occurs with function messages
|
|
||||||
# TODO: The current token calculation method for the image type is not implemented,
|
|
||||||
# which need to download the image and then get the resolution for calculation,
|
|
||||||
# and will increase the request delay
|
|
||||||
if isinstance(value, list):
|
|
||||||
text = ""
|
|
||||||
for item in value:
|
|
||||||
if isinstance(item, dict) and item["type"] == "text":
|
|
||||||
text += item["text"]
|
|
||||||
|
|
||||||
value = text
|
|
||||||
|
|
||||||
if key == "tool_calls":
|
|
||||||
for tool_call in value:
|
|
||||||
for t_key, t_value in tool_call.items():
|
|
||||||
num_tokens += self._get_num_tokens_by_gpt2(t_key)
|
|
||||||
if t_key == "function":
|
|
||||||
for f_key, f_value in t_value.items():
|
|
||||||
num_tokens += self._get_num_tokens_by_gpt2(f_key)
|
|
||||||
num_tokens += self._get_num_tokens_by_gpt2(f_value)
|
|
||||||
else:
|
|
||||||
num_tokens += self._get_num_tokens_by_gpt2(t_key)
|
|
||||||
num_tokens += self._get_num_tokens_by_gpt2(t_value)
|
|
||||||
else:
|
|
||||||
num_tokens += self._get_num_tokens_by_gpt2(str(value))
|
|
||||||
|
|
||||||
if key == "name":
|
|
||||||
num_tokens += tokens_per_name
|
|
||||||
|
|
||||||
# every reply is primed with <im_start>assistant
|
|
||||||
num_tokens += 3
|
|
||||||
|
|
||||||
if tools:
|
|
||||||
num_tokens += self._num_tokens_for_tools(tools)
|
|
||||||
|
|
||||||
return num_tokens
|
|
||||||
|
|
||||||
def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
|
|
||||||
"""
|
|
||||||
Calculate num tokens for tool calling with tiktoken package.
|
|
||||||
|
|
||||||
:param tools: tools for tool calling
|
|
||||||
:return: number of tokens
|
|
||||||
"""
|
|
||||||
num_tokens = 0
|
|
||||||
for tool in tools:
|
|
||||||
num_tokens += self._get_num_tokens_by_gpt2("type")
|
|
||||||
num_tokens += self._get_num_tokens_by_gpt2("function")
|
|
||||||
num_tokens += self._get_num_tokens_by_gpt2("function")
|
|
||||||
|
|
||||||
# calculate num tokens for function object
|
|
||||||
num_tokens += self._get_num_tokens_by_gpt2("name")
|
|
||||||
num_tokens += self._get_num_tokens_by_gpt2(tool.name)
|
|
||||||
num_tokens += self._get_num_tokens_by_gpt2("description")
|
|
||||||
num_tokens += self._get_num_tokens_by_gpt2(tool.description)
|
|
||||||
parameters = tool.parameters
|
|
||||||
num_tokens += self._get_num_tokens_by_gpt2("parameters")
|
|
||||||
if "title" in parameters:
|
|
||||||
num_tokens += self._get_num_tokens_by_gpt2("title")
|
|
||||||
num_tokens += self._get_num_tokens_by_gpt2(parameters.get("title"))
|
|
||||||
num_tokens += self._get_num_tokens_by_gpt2("type")
|
|
||||||
num_tokens += self._get_num_tokens_by_gpt2(parameters.get("type"))
|
|
||||||
if "properties" in parameters:
|
|
||||||
num_tokens += self._get_num_tokens_by_gpt2("properties")
|
|
||||||
for key, value in parameters.get("properties").items():
|
|
||||||
num_tokens += self._get_num_tokens_by_gpt2(key)
|
|
||||||
for field_key, field_value in value.items():
|
|
||||||
num_tokens += self._get_num_tokens_by_gpt2(field_key)
|
|
||||||
if field_key == "enum":
|
|
||||||
for enum_field in field_value:
|
|
||||||
num_tokens += 3
|
|
||||||
num_tokens += self._get_num_tokens_by_gpt2(enum_field)
|
|
||||||
else:
|
|
||||||
num_tokens += self._get_num_tokens_by_gpt2(field_key)
|
|
||||||
num_tokens += self._get_num_tokens_by_gpt2(str(field_value))
|
|
||||||
if "required" in parameters:
|
|
||||||
num_tokens += self._get_num_tokens_by_gpt2("required")
|
|
||||||
for required_field in parameters["required"]:
|
|
||||||
num_tokens += 3
|
|
||||||
num_tokens += self._get_num_tokens_by_gpt2(required_field)
|
|
||||||
|
|
||||||
return num_tokens
|
|
||||||
|
|
||||||
def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]:
|
|
||||||
"""
|
|
||||||
Extract tool calls from response
|
|
||||||
|
|
||||||
:param response_tool_calls: response tool calls
|
|
||||||
:return: list of tool calls
|
|
||||||
"""
|
|
||||||
tool_calls = []
|
|
||||||
if response_tool_calls:
|
|
||||||
for response_tool_call in response_tool_calls:
|
|
||||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
|
||||||
name=response_tool_call.get("function", {}).get("name", ""),
|
|
||||||
arguments=response_tool_call.get("function", {}).get("arguments", ""),
|
|
||||||
)
|
|
||||||
|
|
||||||
tool_call = AssistantPromptMessage.ToolCall(
|
|
||||||
id=response_tool_call.get("id", ""), type=response_tool_call.get("type", ""), function=function
|
|
||||||
)
|
|
||||||
tool_calls.append(tool_call)
|
|
||||||
|
|
||||||
return tool_calls
|
|
||||||
|
|
||||||
def _extract_response_function_call(self, response_function_call) -> AssistantPromptMessage.ToolCall:
|
|
||||||
"""
|
|
||||||
Extract function call from response
|
|
||||||
|
|
||||||
:param response_function_call: response function call
|
|
||||||
:return: tool call
|
|
||||||
"""
|
|
||||||
tool_call = None
|
|
||||||
if response_function_call:
|
|
||||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
|
||||||
name=response_function_call.get("name", ""), arguments=response_function_call.get("arguments", "")
|
|
||||||
)
|
|
||||||
|
|
||||||
tool_call = AssistantPromptMessage.ToolCall(
|
|
||||||
id=response_function_call.get("id", ""), type="function", function=function
|
|
||||||
)
|
|
||||||
|
|
||||||
return tool_call
|
|
||||||
Binary file not shown.
|
After Width: | Height: | Size: 11 KiB |
@ -0,0 +1,3 @@
|
|||||||
|
<svg width="1200" height="925" viewBox="0 0 1200 925" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||||
|
<path d="M780.152 250.999L907.882 462.174C907.882 462.174 880.925 510.854 867.43 535.21C834.845 594.039 764.171 612.49 710.442 508.333L420.376 0H0L459.926 803.307C552.303 964.663 787.366 964.663 879.743 803.307C989.874 610.952 1089.87 441.97 1200 249.646L1052.28 0H639.519L780.152 250.999Z" fill="#3366FF"/>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 417 B |
@ -0,0 +1,83 @@
|
|||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMMode
|
||||||
|
from core.model_runtime.entities.model_entities import (
|
||||||
|
AIModelEntity,
|
||||||
|
DefaultParameterName,
|
||||||
|
FetchFrom,
|
||||||
|
ModelPropertyKey,
|
||||||
|
ModelType,
|
||||||
|
ParameterRule,
|
||||||
|
ParameterType,
|
||||||
|
PriceConfig,
|
||||||
|
)
|
||||||
|
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
|
||||||
|
|
||||||
|
|
||||||
|
class VesslAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||||
|
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||||
|
features = []
|
||||||
|
|
||||||
|
entity = AIModelEntity(
|
||||||
|
model=model,
|
||||||
|
label=I18nObject(en_US=model),
|
||||||
|
model_type=ModelType.LLM,
|
||||||
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
|
features=features,
|
||||||
|
model_properties={
|
||||||
|
ModelPropertyKey.MODE: credentials.get("mode"),
|
||||||
|
},
|
||||||
|
parameter_rules=[
|
||||||
|
ParameterRule(
|
||||||
|
name=DefaultParameterName.TEMPERATURE.value,
|
||||||
|
label=I18nObject(en_US="Temperature"),
|
||||||
|
type=ParameterType.FLOAT,
|
||||||
|
default=float(credentials.get("temperature", 0.7)),
|
||||||
|
min=0,
|
||||||
|
max=2,
|
||||||
|
precision=2,
|
||||||
|
),
|
||||||
|
ParameterRule(
|
||||||
|
name=DefaultParameterName.TOP_P.value,
|
||||||
|
label=I18nObject(en_US="Top P"),
|
||||||
|
type=ParameterType.FLOAT,
|
||||||
|
default=float(credentials.get("top_p", 1)),
|
||||||
|
min=0,
|
||||||
|
max=1,
|
||||||
|
precision=2,
|
||||||
|
),
|
||||||
|
ParameterRule(
|
||||||
|
name=DefaultParameterName.TOP_K.value,
|
||||||
|
label=I18nObject(en_US="Top K"),
|
||||||
|
type=ParameterType.INT,
|
||||||
|
default=int(credentials.get("top_k", 50)),
|
||||||
|
min=-2147483647,
|
||||||
|
max=2147483647,
|
||||||
|
precision=0,
|
||||||
|
),
|
||||||
|
ParameterRule(
|
||||||
|
name=DefaultParameterName.MAX_TOKENS.value,
|
||||||
|
label=I18nObject(en_US="Max Tokens"),
|
||||||
|
type=ParameterType.INT,
|
||||||
|
default=512,
|
||||||
|
min=1,
|
||||||
|
max=int(credentials.get("max_tokens_to_sample", 4096)),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
pricing=PriceConfig(
|
||||||
|
input=Decimal(credentials.get("input_price", 0)),
|
||||||
|
output=Decimal(credentials.get("output_price", 0)),
|
||||||
|
unit=Decimal(credentials.get("unit", 0)),
|
||||||
|
currency=credentials.get("currency", "USD"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if credentials["mode"] == "chat":
|
||||||
|
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value
|
||||||
|
elif credentials["mode"] == "completion":
|
||||||
|
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown completion type {credentials['completion_type']}")
|
||||||
|
|
||||||
|
return entity
|
||||||
@ -0,0 +1,10 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class VesslAIProvider(ModelProvider):
|
||||||
|
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||||
|
pass
|
||||||
@ -0,0 +1,56 @@
|
|||||||
|
provider: vessl_ai
|
||||||
|
label:
|
||||||
|
en_US: vessl_ai
|
||||||
|
icon_small:
|
||||||
|
en_US: icon_s_en.svg
|
||||||
|
icon_large:
|
||||||
|
en_US: icon_l_en.png
|
||||||
|
background: "#F1EFED"
|
||||||
|
help:
|
||||||
|
title:
|
||||||
|
en_US: How to deploy VESSL AI LLM Model Endpoint
|
||||||
|
url:
|
||||||
|
en_US: https://docs.vessl.ai/guides/get-started/llama3-deployment
|
||||||
|
supported_model_types:
|
||||||
|
- llm
|
||||||
|
configurate_methods:
|
||||||
|
- customizable-model
|
||||||
|
model_credential_schema:
|
||||||
|
model:
|
||||||
|
label:
|
||||||
|
en_US: Model Name
|
||||||
|
placeholder:
|
||||||
|
en_US: Enter your model name
|
||||||
|
credential_form_schemas:
|
||||||
|
- variable: endpoint_url
|
||||||
|
label:
|
||||||
|
en_US: endpoint url
|
||||||
|
type: text-input
|
||||||
|
required: true
|
||||||
|
placeholder:
|
||||||
|
en_US: Enter the url of your endpoint url
|
||||||
|
- variable: api_key
|
||||||
|
required: true
|
||||||
|
label:
|
||||||
|
en_US: API Key
|
||||||
|
type: secret-input
|
||||||
|
placeholder:
|
||||||
|
en_US: Enter your VESSL AI secret key
|
||||||
|
- variable: mode
|
||||||
|
show_on:
|
||||||
|
- variable: __model_type
|
||||||
|
value: llm
|
||||||
|
label:
|
||||||
|
en_US: Completion mode
|
||||||
|
type: select
|
||||||
|
required: false
|
||||||
|
default: chat
|
||||||
|
placeholder:
|
||||||
|
en_US: Select completion mode
|
||||||
|
options:
|
||||||
|
- value: completion
|
||||||
|
label:
|
||||||
|
en_US: Completion
|
||||||
|
- value: chat
|
||||||
|
label:
|
||||||
|
en_US: Chat
|
||||||
@ -1,42 +0,0 @@
|
|||||||
from typing import Any
|
|
||||||
|
|
||||||
import requests
|
|
||||||
|
|
||||||
|
|
||||||
class AliYuqueTool:
|
|
||||||
# yuque service url
|
|
||||||
server_url = "https://www.yuque.com"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def auth(token):
|
|
||||||
session = requests.Session()
|
|
||||||
session.headers.update({"Accept": "application/json", "X-Auth-Token": token})
|
|
||||||
login = session.request("GET", AliYuqueTool.server_url + "/api/v2/user")
|
|
||||||
login.raise_for_status()
|
|
||||||
resp = login.json()
|
|
||||||
return resp
|
|
||||||
|
|
||||||
def request(self, method: str, token, tool_parameters: dict[str, Any], path: str) -> str:
|
|
||||||
if not token:
|
|
||||||
raise Exception("token is required")
|
|
||||||
session = requests.Session()
|
|
||||||
session.headers.update({"accept": "application/json", "X-Auth-Token": token})
|
|
||||||
new_params = {**tool_parameters}
|
|
||||||
|
|
||||||
replacements = {k: v for k, v in new_params.items() if f"{{{k}}}" in path}
|
|
||||||
|
|
||||||
for key, value in replacements.items():
|
|
||||||
path = path.replace(f"{{{key}}}", str(value))
|
|
||||||
del new_params[key]
|
|
||||||
|
|
||||||
if method.upper() in {"POST", "PUT"}:
|
|
||||||
session.headers.update(
|
|
||||||
{
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
response = session.request(method.upper(), self.server_url + path, json=new_params)
|
|
||||||
else:
|
|
||||||
response = session.request(method, self.server_url + path, params=new_params)
|
|
||||||
response.raise_for_status()
|
|
||||||
return response.text
|
|
||||||
@ -1,15 +0,0 @@
|
|||||||
from typing import Any, Union
|
|
||||||
|
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
|
||||||
from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool
|
|
||||||
from core.tools.tool.builtin_tool import BuiltinTool
|
|
||||||
|
|
||||||
|
|
||||||
class AliYuqueCreateDocumentTool(AliYuqueTool, BuiltinTool):
|
|
||||||
def _invoke(
|
|
||||||
self, user_id: str, tool_parameters: dict[str, Any]
|
|
||||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
|
||||||
token = self.runtime.credentials.get("token", None)
|
|
||||||
if not token:
|
|
||||||
raise Exception("token is required")
|
|
||||||
return self.create_text_message(self.request("POST", token, tool_parameters, "/api/v2/repos/{book_id}/docs"))
|
|
||||||
@ -1,17 +0,0 @@
|
|||||||
from typing import Any, Union
|
|
||||||
|
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
|
||||||
from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool
|
|
||||||
from core.tools.tool.builtin_tool import BuiltinTool
|
|
||||||
|
|
||||||
|
|
||||||
class AliYuqueDeleteDocumentTool(AliYuqueTool, BuiltinTool):
|
|
||||||
def _invoke(
|
|
||||||
self, user_id: str, tool_parameters: dict[str, Any]
|
|
||||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
|
||||||
token = self.runtime.credentials.get("token", None)
|
|
||||||
if not token:
|
|
||||||
raise Exception("token is required")
|
|
||||||
return self.create_text_message(
|
|
||||||
self.request("DELETE", token, tool_parameters, "/api/v2/repos/{book_id}/docs/{id}")
|
|
||||||
)
|
|
||||||
@ -1,37 +0,0 @@
|
|||||||
identity:
|
|
||||||
name: aliyuque_delete_document
|
|
||||||
author: 佐井
|
|
||||||
label:
|
|
||||||
en_US: Delete Document
|
|
||||||
zh_Hans: 删除文档
|
|
||||||
icon: icon.svg
|
|
||||||
description:
|
|
||||||
human:
|
|
||||||
en_US: Delete Document
|
|
||||||
zh_Hans: 根据id删除文档
|
|
||||||
llm: Delete document.
|
|
||||||
|
|
||||||
parameters:
|
|
||||||
- name: book_id
|
|
||||||
type: string
|
|
||||||
required: true
|
|
||||||
form: llm
|
|
||||||
label:
|
|
||||||
en_US: Knowledge Base ID
|
|
||||||
zh_Hans: 知识库ID
|
|
||||||
human_description:
|
|
||||||
en_US: The unique identifier of the knowledge base where the document will be created.
|
|
||||||
zh_Hans: 文档将被创建的知识库的唯一标识。
|
|
||||||
llm_description: ID of the target knowledge base.
|
|
||||||
|
|
||||||
- name: id
|
|
||||||
type: string
|
|
||||||
required: true
|
|
||||||
form: llm
|
|
||||||
label:
|
|
||||||
en_US: Document ID or Path
|
|
||||||
zh_Hans: 文档 ID or 路径
|
|
||||||
human_description:
|
|
||||||
en_US: Document ID or path.
|
|
||||||
zh_Hans: 文档 ID or 路径。
|
|
||||||
llm_description: Document ID or path.
|
|
||||||
@ -1,17 +0,0 @@
|
|||||||
from typing import Any, Union
|
|
||||||
|
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
|
||||||
from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool
|
|
||||||
from core.tools.tool.builtin_tool import BuiltinTool
|
|
||||||
|
|
||||||
|
|
||||||
class AliYuqueDescribeBookIndexPageTool(AliYuqueTool, BuiltinTool):
|
|
||||||
def _invoke(
|
|
||||||
self, user_id: str, tool_parameters: dict[str, Any]
|
|
||||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
|
||||||
token = self.runtime.credentials.get("token", None)
|
|
||||||
if not token:
|
|
||||||
raise Exception("token is required")
|
|
||||||
return self.create_text_message(
|
|
||||||
self.request("GET", token, tool_parameters, "/api/v2/repos/{group_login}/{book_slug}/index_page")
|
|
||||||
)
|
|
||||||
@ -1,15 +0,0 @@
|
|||||||
from typing import Any, Union
|
|
||||||
|
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
|
||||||
from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool
|
|
||||||
from core.tools.tool.builtin_tool import BuiltinTool
|
|
||||||
|
|
||||||
|
|
||||||
class YuqueDescribeBookTableOfContentsTool(AliYuqueTool, BuiltinTool):
|
|
||||||
def _invoke(
|
|
||||||
self, user_id: str, tool_parameters: dict[str, Any]
|
|
||||||
) -> (Union)[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
|
||||||
token = self.runtime.credentials.get("token", None)
|
|
||||||
if not token:
|
|
||||||
raise Exception("token is required")
|
|
||||||
return self.create_text_message(self.request("GET", token, tool_parameters, "/api/v2/repos/{book_id}/toc"))
|
|
||||||
@ -1,25 +0,0 @@
|
|||||||
identity:
|
|
||||||
name: aliyuque_describe_book_table_of_contents
|
|
||||||
author: 佐井
|
|
||||||
label:
|
|
||||||
en_US: Get Book's Table of Contents
|
|
||||||
zh_Hans: 获取知识库的目录
|
|
||||||
icon: icon.svg
|
|
||||||
description:
|
|
||||||
human:
|
|
||||||
en_US: Get Book's Table of Contents.
|
|
||||||
zh_Hans: 获取知识库的目录。
|
|
||||||
llm: Get Book's Table of Contents.
|
|
||||||
|
|
||||||
parameters:
|
|
||||||
- name: book_id
|
|
||||||
type: string
|
|
||||||
required: true
|
|
||||||
form: llm
|
|
||||||
label:
|
|
||||||
en_US: Book ID
|
|
||||||
zh_Hans: 知识库 ID
|
|
||||||
human_description:
|
|
||||||
en_US: Book ID.
|
|
||||||
zh_Hans: 知识库 ID。
|
|
||||||
llm_description: Book ID.
|
|
||||||
@ -1,52 +0,0 @@
|
|||||||
import json
|
|
||||||
from typing import Any, Union
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
|
||||||
from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool
|
|
||||||
from core.tools.tool.builtin_tool import BuiltinTool
|
|
||||||
|
|
||||||
|
|
||||||
class AliYuqueDescribeDocumentContentTool(AliYuqueTool, BuiltinTool):
|
|
||||||
def _invoke(
|
|
||||||
self, user_id: str, tool_parameters: dict[str, Any]
|
|
||||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
|
||||||
new_params = {**tool_parameters}
|
|
||||||
token = new_params.pop("token")
|
|
||||||
if not token or token.lower() == "none":
|
|
||||||
token = self.runtime.credentials.get("token", None)
|
|
||||||
if not token:
|
|
||||||
raise Exception("token is required")
|
|
||||||
new_params = {**tool_parameters}
|
|
||||||
url = new_params.pop("url")
|
|
||||||
if not url or not url.startswith("http"):
|
|
||||||
raise Exception("url is not valid")
|
|
||||||
|
|
||||||
parsed_url = urlparse(url)
|
|
||||||
path_parts = parsed_url.path.strip("/").split("/")
|
|
||||||
if len(path_parts) < 3:
|
|
||||||
raise Exception("url is not correct")
|
|
||||||
doc_id = path_parts[-1]
|
|
||||||
book_slug = path_parts[-2]
|
|
||||||
group_id = path_parts[-3]
|
|
||||||
|
|
||||||
new_params["group_login"] = group_id
|
|
||||||
new_params["book_slug"] = book_slug
|
|
||||||
index_page = json.loads(
|
|
||||||
self.request("GET", token, new_params, "/api/v2/repos/{group_login}/{book_slug}/index_page")
|
|
||||||
)
|
|
||||||
book_id = index_page.get("data", {}).get("book", {}).get("id")
|
|
||||||
if not book_id:
|
|
||||||
raise Exception(f"can not parse book_id from {index_page}")
|
|
||||||
new_params["book_id"] = book_id
|
|
||||||
new_params["id"] = doc_id
|
|
||||||
data = self.request("GET", token, new_params, "/api/v2/repos/{book_id}/docs/{id}")
|
|
||||||
data = json.loads(data)
|
|
||||||
body_only = tool_parameters.get("body_only") or ""
|
|
||||||
if body_only.lower() == "true":
|
|
||||||
return self.create_text_message(data.get("data").get("body"))
|
|
||||||
else:
|
|
||||||
raw = data.get("data")
|
|
||||||
del raw["body_lake"]
|
|
||||||
del raw["body_html"]
|
|
||||||
return self.create_text_message(json.dumps(data))
|
|
||||||
@ -1,17 +0,0 @@
|
|||||||
from typing import Any, Union
|
|
||||||
|
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
|
||||||
from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool
|
|
||||||
from core.tools.tool.builtin_tool import BuiltinTool
|
|
||||||
|
|
||||||
|
|
||||||
class AliYuqueDescribeDocumentsTool(AliYuqueTool, BuiltinTool):
|
|
||||||
def _invoke(
|
|
||||||
self, user_id: str, tool_parameters: dict[str, Any]
|
|
||||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
|
||||||
token = self.runtime.credentials.get("token", None)
|
|
||||||
if not token:
|
|
||||||
raise Exception("token is required")
|
|
||||||
return self.create_text_message(
|
|
||||||
self.request("GET", token, tool_parameters, "/api/v2/repos/{book_id}/docs/{id}")
|
|
||||||
)
|
|
||||||
@ -1,38 +0,0 @@
|
|||||||
identity:
|
|
||||||
name: aliyuque_describe_documents
|
|
||||||
author: 佐井
|
|
||||||
label:
|
|
||||||
en_US: Get Doc Detail
|
|
||||||
zh_Hans: 获取文档详情
|
|
||||||
icon: icon.svg
|
|
||||||
|
|
||||||
description:
|
|
||||||
human:
|
|
||||||
en_US: Retrieves detailed information of a specific document identified by its ID or path within a knowledge base.
|
|
||||||
zh_Hans: 根据知识库ID和文档ID或路径获取文档详细信息。
|
|
||||||
llm: Fetches detailed doc info using ID/path from a knowledge base; supports doc lookup in Yuque.
|
|
||||||
|
|
||||||
parameters:
|
|
||||||
- name: book_id
|
|
||||||
type: string
|
|
||||||
required: true
|
|
||||||
form: llm
|
|
||||||
label:
|
|
||||||
en_US: Knowledge Base ID
|
|
||||||
zh_Hans: 知识库 ID
|
|
||||||
human_description:
|
|
||||||
en_US: Identifier for the knowledge base where the document resides.
|
|
||||||
zh_Hans: 文档所属知识库的唯一标识。
|
|
||||||
llm_description: ID of the knowledge base holding the document.
|
|
||||||
|
|
||||||
- name: id
|
|
||||||
type: string
|
|
||||||
required: true
|
|
||||||
form: llm
|
|
||||||
label:
|
|
||||||
en_US: Document ID or Path
|
|
||||||
zh_Hans: 文档 ID 或路径
|
|
||||||
human_description:
|
|
||||||
en_US: The unique identifier or path of the document to retrieve.
|
|
||||||
zh_Hans: 需要获取的文档的ID或其在知识库中的路径。
|
|
||||||
llm_description: Unique doc ID or its path for retrieval.
|
|
||||||
@ -1,21 +0,0 @@
|
|||||||
from typing import Any, Union
|
|
||||||
|
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
|
||||||
from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool
|
|
||||||
from core.tools.tool.builtin_tool import BuiltinTool
|
|
||||||
|
|
||||||
|
|
||||||
class YuqueDescribeBookTableOfContentsTool(AliYuqueTool, BuiltinTool):
|
|
||||||
def _invoke(
|
|
||||||
self, user_id: str, tool_parameters: dict[str, Any]
|
|
||||||
) -> (Union)[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
|
||||||
token = self.runtime.credentials.get("token", None)
|
|
||||||
if not token:
|
|
||||||
raise Exception("token is required")
|
|
||||||
|
|
||||||
doc_ids = tool_parameters.get("doc_ids")
|
|
||||||
if doc_ids:
|
|
||||||
doc_ids = [int(doc_id.strip()) for doc_id in doc_ids.split(",")]
|
|
||||||
tool_parameters["doc_ids"] = doc_ids
|
|
||||||
|
|
||||||
return self.create_text_message(self.request("PUT", token, tool_parameters, "/api/v2/repos/{book_id}/toc"))
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue