|
|
|
|
@ -119,8 +119,15 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
if stop:
|
|
|
|
|
req_params['stop'] = stop
|
|
|
|
|
|
|
|
|
|
extra_model_kwargs = {}
|
|
|
|
|
|
|
|
|
|
if tools:
|
|
|
|
|
extra_model_kwargs['tools'] = [
|
|
|
|
|
MaaSClient.transform_tool_prompt_to_maas_config(tool) for tool in tools
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
resp = MaaSClient.wrap_exception(
|
|
|
|
|
lambda: client.chat(req_params, prompt_messages, stream))
|
|
|
|
|
lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs))
|
|
|
|
|
if not stream:
|
|
|
|
|
return self._handle_chat_response(model, credentials, prompt_messages, resp)
|
|
|
|
|
return self._handle_stream_chat_response(model, credentials, prompt_messages, resp)
|
|
|
|
|
@ -156,12 +163,26 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
choice = choices[0]
|
|
|
|
|
message = choice['message']
|
|
|
|
|
|
|
|
|
|
# parse tool calls
|
|
|
|
|
tool_calls = []
|
|
|
|
|
if message['tool_calls']:
|
|
|
|
|
for call in message['tool_calls']:
|
|
|
|
|
tool_call = AssistantPromptMessage.ToolCall(
|
|
|
|
|
id=call['function']['name'],
|
|
|
|
|
type=call['type'],
|
|
|
|
|
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
|
|
|
|
name=call['function']['name'],
|
|
|
|
|
arguments=call['function']['arguments']
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
tool_calls.append(tool_call)
|
|
|
|
|
|
|
|
|
|
return LLMResult(
|
|
|
|
|
model=model,
|
|
|
|
|
prompt_messages=prompt_messages,
|
|
|
|
|
message=AssistantPromptMessage(
|
|
|
|
|
content=message['content'] if message['content'] else '',
|
|
|
|
|
tool_calls=[],
|
|
|
|
|
tool_calls=tool_calls,
|
|
|
|
|
),
|
|
|
|
|
usage=self._calc_usage(model, credentials, resp['usage']),
|
|
|
|
|
)
|
|
|
|
|
@ -252,6 +273,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
if credentials.get('context_size'):
|
|
|
|
|
model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
|
|
|
|
|
credentials.get('context_size', 4096))
|
|
|
|
|
|
|
|
|
|
model_features = ModelConfigs.get(
|
|
|
|
|
credentials['base_model_name'], {}).get('features', [])
|
|
|
|
|
|
|
|
|
|
entity = AIModelEntity(
|
|
|
|
|
model=model,
|
|
|
|
|
label=I18nObject(
|
|
|
|
|
@ -260,7 +285,8 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
|
|
|
|
model_type=ModelType.LLM,
|
|
|
|
|
model_properties=model_properties,
|
|
|
|
|
parameter_rules=rules
|
|
|
|
|
parameter_rules=rules,
|
|
|
|
|
features=model_features,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return entity
|
|
|
|
|
|