feat: Introduce Ark SDK v3 and ensure compatibility with models of SDK v2 (#7579)
Co-authored-by: crazywoola <427733928@qq.com>pull/6531/head
parent
b035c02f78
commit
efc136cce5
@ -0,0 +1,134 @@
|
|||||||
|
import re
|
||||||
|
from collections.abc import Callable, Generator
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from core.model_runtime.entities.message_entities import (
|
||||||
|
AssistantPromptMessage,
|
||||||
|
ImagePromptMessageContent,
|
||||||
|
PromptMessage,
|
||||||
|
PromptMessageContentType,
|
||||||
|
PromptMessageTool,
|
||||||
|
SystemPromptMessage,
|
||||||
|
ToolPromptMessage,
|
||||||
|
UserPromptMessage,
|
||||||
|
)
|
||||||
|
from core.model_runtime.model_providers.volcengine_maas.legacy.errors import wrap_error
|
||||||
|
from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import ChatRole, MaasException, MaasService
|
||||||
|
|
||||||
|
|
||||||
|
class MaaSClient(MaasService):
|
||||||
|
def __init__(self, host: str, region: str):
|
||||||
|
self.endpoint_id = None
|
||||||
|
super().__init__(host, region)
|
||||||
|
|
||||||
|
def set_endpoint_id(self, endpoint_id: str):
|
||||||
|
self.endpoint_id = endpoint_id
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_credential(cls, credentials: dict) -> 'MaaSClient':
|
||||||
|
host = credentials['api_endpoint_host']
|
||||||
|
region = credentials['volc_region']
|
||||||
|
ak = credentials['volc_access_key_id']
|
||||||
|
sk = credentials['volc_secret_access_key']
|
||||||
|
endpoint_id = credentials['endpoint_id']
|
||||||
|
|
||||||
|
client = cls(host, region)
|
||||||
|
client.set_endpoint_id(endpoint_id)
|
||||||
|
client.set_ak(ak)
|
||||||
|
client.set_sk(sk)
|
||||||
|
return client
|
||||||
|
|
||||||
|
def chat(self, params: dict, messages: list[PromptMessage], stream=False, **extra_model_kwargs) -> Generator | dict:
|
||||||
|
req = {
|
||||||
|
'parameters': params,
|
||||||
|
'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages],
|
||||||
|
**extra_model_kwargs,
|
||||||
|
}
|
||||||
|
if not stream:
|
||||||
|
return super().chat(
|
||||||
|
self.endpoint_id,
|
||||||
|
req,
|
||||||
|
)
|
||||||
|
return super().stream_chat(
|
||||||
|
self.endpoint_id,
|
||||||
|
req,
|
||||||
|
)
|
||||||
|
|
||||||
|
def embeddings(self, texts: list[str]) -> dict:
|
||||||
|
req = {
|
||||||
|
'input': texts
|
||||||
|
}
|
||||||
|
return super().embeddings(self.endpoint_id, req)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def convert_prompt_message_to_maas_message(message: PromptMessage) -> dict:
|
||||||
|
if isinstance(message, UserPromptMessage):
|
||||||
|
message = cast(UserPromptMessage, message)
|
||||||
|
if isinstance(message.content, str):
|
||||||
|
message_dict = {"role": ChatRole.USER,
|
||||||
|
"content": message.content}
|
||||||
|
else:
|
||||||
|
content = []
|
||||||
|
for message_content in message.content:
|
||||||
|
if message_content.type == PromptMessageContentType.TEXT:
|
||||||
|
raise ValueError(
|
||||||
|
'Content object type only support image_url')
|
||||||
|
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||||
|
message_content = cast(
|
||||||
|
ImagePromptMessageContent, message_content)
|
||||||
|
image_data = re.sub(
|
||||||
|
r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
|
||||||
|
content.append({
|
||||||
|
'type': 'image_url',
|
||||||
|
'image_url': {
|
||||||
|
'url': '',
|
||||||
|
'image_bytes': image_data,
|
||||||
|
'detail': message_content.detail,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
message_dict = {'role': ChatRole.USER, 'content': content}
|
||||||
|
elif isinstance(message, AssistantPromptMessage):
|
||||||
|
message = cast(AssistantPromptMessage, message)
|
||||||
|
message_dict = {'role': ChatRole.ASSISTANT,
|
||||||
|
'content': message.content}
|
||||||
|
if message.tool_calls:
|
||||||
|
message_dict['tool_calls'] = [
|
||||||
|
{
|
||||||
|
'name': call.function.name,
|
||||||
|
'arguments': call.function.arguments
|
||||||
|
} for call in message.tool_calls
|
||||||
|
]
|
||||||
|
elif isinstance(message, SystemPromptMessage):
|
||||||
|
message = cast(SystemPromptMessage, message)
|
||||||
|
message_dict = {'role': ChatRole.SYSTEM,
|
||||||
|
'content': message.content}
|
||||||
|
elif isinstance(message, ToolPromptMessage):
|
||||||
|
message = cast(ToolPromptMessage, message)
|
||||||
|
message_dict = {'role': ChatRole.FUNCTION,
|
||||||
|
'content': message.content,
|
||||||
|
'name': message.tool_call_id}
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Got unknown PromptMessage type {message}")
|
||||||
|
|
||||||
|
return message_dict
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator:
|
||||||
|
try:
|
||||||
|
resp = fn()
|
||||||
|
except MaasException as e:
|
||||||
|
raise wrap_error(e)
|
||||||
|
|
||||||
|
return resp
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def transform_tool_prompt_to_maas_config(tool: PromptMessageTool):
|
||||||
|
return {
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": tool.name,
|
||||||
|
"description": tool.description,
|
||||||
|
"parameters": tool.parameters,
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1,4 +1,4 @@
|
|||||||
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
|
from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import MaasException
|
||||||
|
|
||||||
|
|
||||||
class ClientSDKRequestError(MaasException):
|
class ClientSDKRequestError(MaasException):
|
||||||
Loading…
Reference in New Issue