|
|
|
|
@ -5,7 +5,7 @@ from typing import Union, Generator, Dict, Any, Tuple, List
|
|
|
|
|
|
|
|
|
|
from core.model_runtime.entities.message_entities import PromptMessage, UserPromptMessage,\
|
|
|
|
|
SystemPromptMessage, AssistantPromptMessage, ToolPromptMessage, PromptMessageTool
|
|
|
|
|
from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult, LLMUsage
|
|
|
|
|
from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult, LLMUsage, LLMResultChunkDelta
|
|
|
|
|
from core.model_manager import ModelInstance
|
|
|
|
|
from core.application_queue_manager import PublishFrom
|
|
|
|
|
|
|
|
|
|
@ -20,8 +20,7 @@ from models.model import Conversation, Message, MessageAgentThought
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
|
|
|
|
def run(self, model_instance: ModelInstance,
|
|
|
|
|
conversation: Conversation,
|
|
|
|
|
def run(self, conversation: Conversation,
|
|
|
|
|
message: Message,
|
|
|
|
|
query: str,
|
|
|
|
|
) -> Generator[LLMResultChunk, None, None]:
|
|
|
|
|
@ -81,6 +80,8 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
|
|
|
|
llm_usage.prompt_price += usage.prompt_price
|
|
|
|
|
llm_usage.completion_price += usage.completion_price
|
|
|
|
|
|
|
|
|
|
model_instance = self.model_instance
|
|
|
|
|
|
|
|
|
|
while function_call_state and iteration_step <= max_iteration_steps:
|
|
|
|
|
function_call_state = False
|
|
|
|
|
|
|
|
|
|
@ -101,12 +102,12 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
|
|
|
|
# recale llm max tokens
|
|
|
|
|
self.recale_llm_max_tokens(self.model_config, prompt_messages)
|
|
|
|
|
# invoke model
|
|
|
|
|
chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
|
|
|
|
|
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
|
|
|
|
|
prompt_messages=prompt_messages,
|
|
|
|
|
model_parameters=app_orchestration_config.model_config.parameters,
|
|
|
|
|
tools=prompt_messages_tools,
|
|
|
|
|
stop=app_orchestration_config.model_config.stop,
|
|
|
|
|
stream=True,
|
|
|
|
|
stream=self.stream_tool_call,
|
|
|
|
|
user=self.user_id,
|
|
|
|
|
callbacks=[],
|
|
|
|
|
)
|
|
|
|
|
@ -122,11 +123,41 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
|
|
|
|
|
|
|
|
|
current_llm_usage = None
|
|
|
|
|
|
|
|
|
|
for chunk in chunks:
|
|
|
|
|
if self.stream_tool_call:
|
|
|
|
|
for chunk in chunks:
|
|
|
|
|
# check if there is any tool call
|
|
|
|
|
if self.check_tool_calls(chunk):
|
|
|
|
|
function_call_state = True
|
|
|
|
|
tool_calls.extend(self.extract_tool_calls(chunk))
|
|
|
|
|
tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
|
|
|
|
|
try:
|
|
|
|
|
tool_call_inputs = json.dumps({
|
|
|
|
|
tool_call[1]: tool_call[2] for tool_call in tool_calls
|
|
|
|
|
}, ensure_ascii=False)
|
|
|
|
|
except json.JSONDecodeError as e:
|
|
|
|
|
# ensure ascii to avoid encoding error
|
|
|
|
|
tool_call_inputs = json.dumps({
|
|
|
|
|
tool_call[1]: tool_call[2] for tool_call in tool_calls
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
if chunk.delta.message and chunk.delta.message.content:
|
|
|
|
|
if isinstance(chunk.delta.message.content, list):
|
|
|
|
|
for content in chunk.delta.message.content:
|
|
|
|
|
response += content.data
|
|
|
|
|
else:
|
|
|
|
|
response += chunk.delta.message.content
|
|
|
|
|
|
|
|
|
|
if chunk.delta.usage:
|
|
|
|
|
increase_usage(llm_usage, chunk.delta.usage)
|
|
|
|
|
current_llm_usage = chunk.delta.usage
|
|
|
|
|
|
|
|
|
|
yield chunk
|
|
|
|
|
else:
|
|
|
|
|
result: LLMResult = chunks
|
|
|
|
|
# check if there is any tool call
|
|
|
|
|
if self.check_tool_calls(chunk):
|
|
|
|
|
if self.check_blocking_tool_calls(result):
|
|
|
|
|
function_call_state = True
|
|
|
|
|
tool_calls.extend(self.extract_tool_calls(chunk))
|
|
|
|
|
tool_calls.extend(self.extract_blocking_tool_calls(result))
|
|
|
|
|
tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
|
|
|
|
|
try:
|
|
|
|
|
tool_call_inputs = json.dumps({
|
|
|
|
|
@ -138,18 +169,44 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
|
|
|
|
tool_call[1]: tool_call[2] for tool_call in tool_calls
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
if chunk.delta.message and chunk.delta.message.content:
|
|
|
|
|
if isinstance(chunk.delta.message.content, list):
|
|
|
|
|
for content in chunk.delta.message.content:
|
|
|
|
|
if result.usage:
|
|
|
|
|
increase_usage(llm_usage, result.usage)
|
|
|
|
|
current_llm_usage = result.usage
|
|
|
|
|
|
|
|
|
|
if result.message and result.message.content:
|
|
|
|
|
if isinstance(result.message.content, list):
|
|
|
|
|
for content in result.message.content:
|
|
|
|
|
response += content.data
|
|
|
|
|
else:
|
|
|
|
|
response += chunk.delta.message.content
|
|
|
|
|
|
|
|
|
|
if chunk.delta.usage:
|
|
|
|
|
increase_usage(llm_usage, chunk.delta.usage)
|
|
|
|
|
current_llm_usage = chunk.delta.usage
|
|
|
|
|
response += result.message.content
|
|
|
|
|
|
|
|
|
|
if not result.message.content:
|
|
|
|
|
result.message.content = ''
|
|
|
|
|
|
|
|
|
|
yield LLMResultChunk(
|
|
|
|
|
model=model_instance.model,
|
|
|
|
|
prompt_messages=result.prompt_messages,
|
|
|
|
|
system_fingerprint=result.system_fingerprint,
|
|
|
|
|
delta=LLMResultChunkDelta(
|
|
|
|
|
index=0,
|
|
|
|
|
message=result.message,
|
|
|
|
|
usage=result.usage,
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
yield chunk
|
|
|
|
|
if tool_calls:
|
|
|
|
|
prompt_messages.append(AssistantPromptMessage(
|
|
|
|
|
content='',
|
|
|
|
|
name='',
|
|
|
|
|
tool_calls=[AssistantPromptMessage.ToolCall(
|
|
|
|
|
id=tool_call[0],
|
|
|
|
|
type='function',
|
|
|
|
|
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
|
|
|
|
name=tool_call[1],
|
|
|
|
|
arguments=json.dumps(tool_call[2], ensure_ascii=False)
|
|
|
|
|
)
|
|
|
|
|
) for tool_call in tool_calls]
|
|
|
|
|
))
|
|
|
|
|
|
|
|
|
|
# save thought
|
|
|
|
|
self.save_agent_thought(
|
|
|
|
|
@ -167,6 +224,12 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
|
|
|
|
|
|
|
|
|
final_answer += response + '\n'
|
|
|
|
|
|
|
|
|
|
# update prompt messages
|
|
|
|
|
if response.strip():
|
|
|
|
|
prompt_messages.append(AssistantPromptMessage(
|
|
|
|
|
content=response,
|
|
|
|
|
))
|
|
|
|
|
|
|
|
|
|
# call tools
|
|
|
|
|
tool_responses = []
|
|
|
|
|
for tool_call_id, tool_call_name, tool_call_args in tool_calls:
|
|
|
|
|
@ -256,12 +319,6 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
|
|
|
|
)
|
|
|
|
|
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
|
|
|
|
|
|
|
|
|
|
# update prompt messages
|
|
|
|
|
if response.strip():
|
|
|
|
|
prompt_messages.append(AssistantPromptMessage(
|
|
|
|
|
content=response,
|
|
|
|
|
))
|
|
|
|
|
|
|
|
|
|
# update prompt tool
|
|
|
|
|
for prompt_tool in prompt_messages_tools:
|
|
|
|
|
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
|
|
|
|
|
@ -287,6 +344,14 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
|
|
|
|
if llm_result_chunk.delta.message.tool_calls:
|
|
|
|
|
return True
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
|
|
|
|
|
"""
|
|
|
|
|
Check if there is any blocking tool call in llm result
|
|
|
|
|
"""
|
|
|
|
|
if llm_result.message.tool_calls:
|
|
|
|
|
return True
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
|
|
|
|
|
"""
|
|
|
|
|
@ -304,6 +369,23 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
|
|
|
|
))
|
|
|
|
|
|
|
|
|
|
return tool_calls
|
|
|
|
|
|
|
|
|
|
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
|
|
|
|
|
"""
|
|
|
|
|
Extract blocking tool calls from llm result
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
|
|
|
|
|
"""
|
|
|
|
|
tool_calls = []
|
|
|
|
|
for prompt_message in llm_result.message.tool_calls:
|
|
|
|
|
tool_calls.append((
|
|
|
|
|
prompt_message.id,
|
|
|
|
|
prompt_message.function.name,
|
|
|
|
|
json.loads(prompt_message.function.arguments),
|
|
|
|
|
))
|
|
|
|
|
|
|
|
|
|
return tool_calls
|
|
|
|
|
|
|
|
|
|
def organize_prompt_messages(self, prompt_template: str,
|
|
|
|
|
query: str = None,
|
|
|
|
|
|