|
|
|
|
@ -25,6 +25,7 @@ from core.model_runtime.entities.model_entities import (
|
|
|
|
|
AIModelEntity,
|
|
|
|
|
DefaultParameterName,
|
|
|
|
|
FetchFrom,
|
|
|
|
|
ModelFeature,
|
|
|
|
|
ModelPropertyKey,
|
|
|
|
|
ModelType,
|
|
|
|
|
ParameterRule,
|
|
|
|
|
@ -166,11 +167,23 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
|
|
|
|
"""
|
|
|
|
|
generate custom model entities from credentials
|
|
|
|
|
"""
|
|
|
|
|
support_function_call = False
|
|
|
|
|
features = []
|
|
|
|
|
function_calling_type = credentials.get('function_calling_type', 'no_call')
|
|
|
|
|
if function_calling_type == 'function_call':
|
|
|
|
|
features = [ModelFeature.TOOL_CALL]
|
|
|
|
|
support_function_call = True
|
|
|
|
|
endpoint_url = credentials["endpoint_url"]
|
|
|
|
|
# if not endpoint_url.endswith('/'):
|
|
|
|
|
# endpoint_url += '/'
|
|
|
|
|
# if 'https://api.openai.com/v1/' == endpoint_url:
|
|
|
|
|
# features = [ModelFeature.STREAM_TOOL_CALL]
|
|
|
|
|
entity = AIModelEntity(
|
|
|
|
|
model=model,
|
|
|
|
|
label=I18nObject(en_US=model),
|
|
|
|
|
model_type=ModelType.LLM,
|
|
|
|
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
|
|
|
|
features=features if support_function_call else [],
|
|
|
|
|
model_properties={
|
|
|
|
|
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', "4096")),
|
|
|
|
|
ModelPropertyKey.MODE: credentials.get('mode'),
|
|
|
|
|
@ -194,14 +207,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
|
|
|
|
max=1,
|
|
|
|
|
precision=2
|
|
|
|
|
),
|
|
|
|
|
ParameterRule(
|
|
|
|
|
name="top_k",
|
|
|
|
|
label=I18nObject(en_US="Top K"),
|
|
|
|
|
type=ParameterType.INT,
|
|
|
|
|
default=int(credentials.get('top_k', 1)),
|
|
|
|
|
min=1,
|
|
|
|
|
max=100
|
|
|
|
|
),
|
|
|
|
|
ParameterRule(
|
|
|
|
|
name=DefaultParameterName.FREQUENCY_PENALTY.value,
|
|
|
|
|
label=I18nObject(en_US="Frequency Penalty"),
|
|
|
|
|
@ -232,7 +237,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
|
|
|
|
output=Decimal(credentials.get('output_price', 0)),
|
|
|
|
|
unit=Decimal(credentials.get('unit', 0)),
|
|
|
|
|
currency=credentials.get('currency', "USD")
|
|
|
|
|
)
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if credentials['mode'] == 'chat':
|
|
|
|
|
@ -292,14 +297,22 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
|
|
|
|
raise ValueError("Unsupported completion type for model configuration.")
|
|
|
|
|
|
|
|
|
|
# annotate tools with names, descriptions, etc.
|
|
|
|
|
function_calling_type = credentials.get('function_calling_type', 'no_call')
|
|
|
|
|
formatted_tools = []
|
|
|
|
|
if tools:
|
|
|
|
|
data["tool_choice"] = "auto"
|
|
|
|
|
if function_calling_type == 'function_call':
|
|
|
|
|
data['functions'] = [{
|
|
|
|
|
"name": tool.name,
|
|
|
|
|
"description": tool.description,
|
|
|
|
|
"parameters": tool.parameters
|
|
|
|
|
} for tool in tools]
|
|
|
|
|
elif function_calling_type == 'tool_call':
|
|
|
|
|
data["tool_choice"] = "auto"
|
|
|
|
|
|
|
|
|
|
for tool in tools:
|
|
|
|
|
formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))
|
|
|
|
|
for tool in tools:
|
|
|
|
|
formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))
|
|
|
|
|
|
|
|
|
|
data["tools"] = formatted_tools
|
|
|
|
|
data["tools"] = formatted_tools
|
|
|
|
|
|
|
|
|
|
if stop:
|
|
|
|
|
data["stop"] = stop
|
|
|
|
|
@ -367,7 +380,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
|
|
|
|
|
|
|
|
|
for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
|
|
|
|
|
if chunk:
|
|
|
|
|
#ignore sse comments
|
|
|
|
|
# ignore sse comments
|
|
|
|
|
if chunk.startswith(':'):
|
|
|
|
|
continue
|
|
|
|
|
decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
|
|
|
|
|
@ -452,10 +465,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
|
|
|
|
|
|
|
|
|
response_content = ''
|
|
|
|
|
tool_calls = None
|
|
|
|
|
|
|
|
|
|
function_calling_type = credentials.get('function_calling_type', 'no_call')
|
|
|
|
|
if completion_type is LLMMode.CHAT:
|
|
|
|
|
response_content = output.get('message', {})['content']
|
|
|
|
|
tool_calls = output.get('message', {}).get('tool_calls')
|
|
|
|
|
if function_calling_type == 'tool_call':
|
|
|
|
|
tool_calls = output.get('message', {}).get('tool_calls')
|
|
|
|
|
elif function_calling_type == 'function_call':
|
|
|
|
|
tool_calls = output.get('message', {}).get('function_call')
|
|
|
|
|
|
|
|
|
|
elif completion_type is LLMMode.COMPLETION:
|
|
|
|
|
response_content = output['text']
|
|
|
|
|
@ -463,7 +479,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
|
|
|
|
assistant_message = AssistantPromptMessage(content=response_content, tool_calls=[])
|
|
|
|
|
|
|
|
|
|
if tool_calls:
|
|
|
|
|
assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls)
|
|
|
|
|
if function_calling_type == 'tool_call':
|
|
|
|
|
assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls)
|
|
|
|
|
elif function_calling_type == 'function_call':
|
|
|
|
|
assistant_message.tool_calls = [self._extract_response_function_call(tool_calls)]
|
|
|
|
|
|
|
|
|
|
usage = response_json.get("usage")
|
|
|
|
|
if usage:
|
|
|
|
|
@ -522,33 +541,34 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
|
|
|
|
message = cast(AssistantPromptMessage, message)
|
|
|
|
|
message_dict = {"role": "assistant", "content": message.content}
|
|
|
|
|
if message.tool_calls:
|
|
|
|
|
message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
|
|
|
|
|
in
|
|
|
|
|
message.tool_calls]
|
|
|
|
|
# function_call = message.tool_calls[0]
|
|
|
|
|
# message_dict["function_call"] = {
|
|
|
|
|
# "name": function_call.function.name,
|
|
|
|
|
# "arguments": function_call.function.arguments,
|
|
|
|
|
# }
|
|
|
|
|
# message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
|
|
|
|
|
# in
|
|
|
|
|
# message.tool_calls]
|
|
|
|
|
|
|
|
|
|
function_call = message.tool_calls[0]
|
|
|
|
|
message_dict["function_call"] = {
|
|
|
|
|
"name": function_call.function.name,
|
|
|
|
|
"arguments": function_call.function.arguments,
|
|
|
|
|
}
|
|
|
|
|
elif isinstance(message, SystemPromptMessage):
|
|
|
|
|
message = cast(SystemPromptMessage, message)
|
|
|
|
|
message_dict = {"role": "system", "content": message.content}
|
|
|
|
|
elif isinstance(message, ToolPromptMessage):
|
|
|
|
|
message = cast(ToolPromptMessage, message)
|
|
|
|
|
message_dict = {
|
|
|
|
|
"role": "tool",
|
|
|
|
|
"content": message.content,
|
|
|
|
|
"tool_call_id": message.tool_call_id
|
|
|
|
|
}
|
|
|
|
|
# message_dict = {
|
|
|
|
|
# "role": "function",
|
|
|
|
|
# "role": "tool",
|
|
|
|
|
# "content": message.content,
|
|
|
|
|
# "name": message.tool_call_id
|
|
|
|
|
# "tool_call_id": message.tool_call_id
|
|
|
|
|
# }
|
|
|
|
|
message_dict = {
|
|
|
|
|
"role": "function",
|
|
|
|
|
"content": message.content,
|
|
|
|
|
"name": message.tool_call_id
|
|
|
|
|
}
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(f"Got unknown type {message}")
|
|
|
|
|
|
|
|
|
|
if message.name is not None:
|
|
|
|
|
if message.name:
|
|
|
|
|
message_dict["name"] = message.name
|
|
|
|
|
|
|
|
|
|
return message_dict
|
|
|
|
|
@ -693,3 +713,26 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
|
|
|
|
tool_calls.append(tool_call)
|
|
|
|
|
|
|
|
|
|
return tool_calls
|
|
|
|
|
|
|
|
|
|
def _extract_response_function_call(self, response_function_call) \
        -> "AssistantPromptMessage.ToolCall | None":
    """
    Extract a legacy (pre-tools) function call from a model response.

    Fix: the original annotation claimed a non-optional ``ToolCall``, but the
    method returns ``None`` whenever *response_function_call* is empty/falsy,
    so the return type is now declared optional (string annotation, no new
    imports required).

    :param response_function_call: the ``function_call`` mapping from the
        response; expected to carry ``'name'`` and ``'arguments'`` keys —
        TODO confirm shape against the upstream API payload
    :return: the extracted tool call, or ``None`` when no function call
        is present in the response
    """
    tool_call = None
    if response_function_call:
        function = AssistantPromptMessage.ToolCall.ToolCallFunction(
            name=response_function_call['name'],
            arguments=response_function_call['arguments']
        )

        # Legacy function_call payloads carry no id; reuse the function
        # name so downstream tool-call handling still receives a stable id.
        tool_call = AssistantPromptMessage.ToolCall(
            id=response_function_call['name'],
            type="function",
            function=function
        )

    return tool_call
|
|
|
|
|
|