|
|
|
|
@ -63,6 +63,9 @@ from core.model_runtime.model_providers.xinference.xinference_helper import (
|
|
|
|
|
)
|
|
|
|
|
from core.model_runtime.utils import helper
|
|
|
|
|
|
|
|
|
|
DEFAULT_MAX_RETRIES = 3
|
|
|
|
|
DEFAULT_INVOKE_TIMEOUT = 60
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
def _invoke(
|
|
|
|
|
@ -315,7 +318,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
message_dict = {"role": "system", "content": message.content}
|
|
|
|
|
elif isinstance(message, ToolPromptMessage):
|
|
|
|
|
message = cast(ToolPromptMessage, message)
|
|
|
|
|
message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
|
|
|
|
|
message_dict = {
|
|
|
|
|
"tool_call_id": message.tool_call_id,
|
|
|
|
|
"role": "tool",
|
|
|
|
|
"content": message.content,
|
|
|
|
|
"name": message.name,
|
|
|
|
|
}
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(f"Unknown message type {type(message)}")
|
|
|
|
|
|
|
|
|
|
@ -466,8 +474,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
client = OpenAI(
|
|
|
|
|
base_url=f'{credentials["server_url"]}/v1',
|
|
|
|
|
api_key=api_key,
|
|
|
|
|
max_retries=3,
|
|
|
|
|
timeout=60,
|
|
|
|
|
max_retries=int(credentials.get("max_retries") or DEFAULT_MAX_RETRIES),
|
|
|
|
|
timeout=int(credentials.get("invoke_timeout") or DEFAULT_INVOKE_TIMEOUT),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
xinference_client = Client(
|
|
|
|
|
|