|
|
|
|
@ -1,6 +1,7 @@
|
|
|
|
|
import json
|
|
|
|
|
import logging
|
|
|
|
|
from collections.abc import Generator
|
|
|
|
|
from copy import deepcopy
|
|
|
|
|
from typing import Any, Union
|
|
|
|
|
|
|
|
|
|
from core.agent.base_agent_runner import BaseAgentRunner
|
|
|
|
|
@ -10,20 +11,21 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
|
|
|
|
|
from core.model_runtime.entities.message_entities import (
|
|
|
|
|
AssistantPromptMessage,
|
|
|
|
|
PromptMessage,
|
|
|
|
|
PromptMessageContentType,
|
|
|
|
|
PromptMessageTool,
|
|
|
|
|
SystemPromptMessage,
|
|
|
|
|
TextPromptMessageContent,
|
|
|
|
|
ToolPromptMessage,
|
|
|
|
|
UserPromptMessage,
|
|
|
|
|
)
|
|
|
|
|
from core.tools.entities.tool_entities import ToolInvokeMeta
|
|
|
|
|
from core.tools.tool_engine import ToolEngine
|
|
|
|
|
from models.model import Conversation, Message, MessageAgentThought
|
|
|
|
|
from models.model import Message
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
class FunctionCallAgentRunner(BaseAgentRunner):
|
|
|
|
|
def run(self, conversation: Conversation,
|
|
|
|
|
message: Message,
|
|
|
|
|
def run(self, message: Message,
|
|
|
|
|
query: str,
|
|
|
|
|
) -> Generator[LLMResultChunk, None, None]:
|
|
|
|
|
"""
|
|
|
|
|
@ -35,11 +37,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|
|
|
|
|
|
|
|
|
prompt_template = app_config.prompt_template.simple_prompt_template or ''
|
|
|
|
|
prompt_messages = self.history_prompt_messages
|
|
|
|
|
prompt_messages = self.organize_prompt_messages(
|
|
|
|
|
prompt_template=prompt_template,
|
|
|
|
|
query=query,
|
|
|
|
|
prompt_messages=prompt_messages
|
|
|
|
|
)
|
|
|
|
|
prompt_messages = self._init_system_message(prompt_template, prompt_messages)
|
|
|
|
|
prompt_messages = self._organize_user_query(query, prompt_messages)
|
|
|
|
|
|
|
|
|
|
# convert tools into ModelRuntime Tool format
|
|
|
|
|
prompt_messages_tools: list[PromptMessageTool] = []
|
|
|
|
|
@ -68,7 +67,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|
|
|
|
|
|
|
|
|
# continue to run until there is not any tool call
|
|
|
|
|
function_call_state = True
|
|
|
|
|
agent_thoughts: list[MessageAgentThought] = []
|
|
|
|
|
llm_usage = {
|
|
|
|
|
'usage': None
|
|
|
|
|
}
|
|
|
|
|
@ -287,9 +285,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
tool_responses.append(tool_response)
|
|
|
|
|
prompt_messages = self.organize_prompt_messages(
|
|
|
|
|
prompt_template=prompt_template,
|
|
|
|
|
query=None,
|
|
|
|
|
prompt_messages = self._organize_assistant_message(
|
|
|
|
|
tool_call_id=tool_call_id,
|
|
|
|
|
tool_call_name=tool_call_name,
|
|
|
|
|
tool_response=tool_response['tool_response'],
|
|
|
|
|
@ -324,6 +320,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|
|
|
|
|
|
|
|
|
iteration_step += 1
|
|
|
|
|
|
|
|
|
|
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
|
|
|
|
|
|
|
|
|
|
self.update_db_variables(self.variables_pool, self.db_variables_pool)
|
|
|
|
|
# publish end event
|
|
|
|
|
self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
|
|
|
|
|
@ -386,29 +384,68 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|
|
|
|
|
|
|
|
|
return tool_calls
|
|
|
|
|
|
|
|
|
|
def organize_prompt_messages(self, prompt_template: str,
|
|
|
|
|
query: str = None,
|
|
|
|
|
tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
|
|
|
|
|
prompt_messages: list[PromptMessage] = None
|
|
|
|
|
) -> list[PromptMessage]:
|
|
|
|
|
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
|
|
|
|
|
"""
|
|
|
|
|
Organize prompt messages
|
|
|
|
|
Initialize system message
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
if not prompt_messages:
|
|
|
|
|
prompt_messages = [
|
|
|
|
|
if not prompt_messages and prompt_template:
|
|
|
|
|
return [
|
|
|
|
|
SystemPromptMessage(content=prompt_template),
|
|
|
|
|
UserPromptMessage(content=query),
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
|
|
|
|
|
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
|
|
|
|
|
|
|
|
|
|
return prompt_messages
|
|
|
|
|
|
|
|
|
|
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
|
|
|
|
|
"""
|
|
|
|
|
Organize user query
|
|
|
|
|
"""
|
|
|
|
|
if self.files:
|
|
|
|
|
prompt_message_contents = [TextPromptMessageContent(data=query)]
|
|
|
|
|
for file_obj in self.files:
|
|
|
|
|
prompt_message_contents.append(file_obj.prompt_message_content)
|
|
|
|
|
|
|
|
|
|
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
|
|
|
|
else:
|
|
|
|
|
if tool_response:
|
|
|
|
|
prompt_messages = prompt_messages.copy()
|
|
|
|
|
prompt_messages.append(
|
|
|
|
|
ToolPromptMessage(
|
|
|
|
|
content=tool_response,
|
|
|
|
|
tool_call_id=tool_call_id,
|
|
|
|
|
name=tool_call_name,
|
|
|
|
|
)
|
|
|
|
|
prompt_messages.append(UserPromptMessage(content=query))
|
|
|
|
|
|
|
|
|
|
return prompt_messages
|
|
|
|
|
|
|
|
|
|
def _organize_assistant_message(self, tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
|
|
|
|
|
prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
|
|
|
|
|
"""
|
|
|
|
|
Organize assistant message
|
|
|
|
|
"""
|
|
|
|
|
prompt_messages = deepcopy(prompt_messages)
|
|
|
|
|
|
|
|
|
|
if tool_response is not None:
|
|
|
|
|
prompt_messages.append(
|
|
|
|
|
ToolPromptMessage(
|
|
|
|
|
content=tool_response,
|
|
|
|
|
tool_call_id=tool_call_id,
|
|
|
|
|
name=tool_call_name,
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return prompt_messages
|
|
|
|
|
|
|
|
|
|
def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
|
|
|
|
"""
|
|
|
|
|
As for now, gpt supports both fc and vision at the first iteration.
|
|
|
|
|
We need to remove the image messages from the prompt messages at the first iteration.
|
|
|
|
|
"""
|
|
|
|
|
prompt_messages = deepcopy(prompt_messages)
|
|
|
|
|
|
|
|
|
|
for prompt_message in prompt_messages:
|
|
|
|
|
if isinstance(prompt_message, UserPromptMessage):
|
|
|
|
|
if isinstance(prompt_message.content, list):
|
|
|
|
|
prompt_message.content = '\n'.join([
|
|
|
|
|
content.data if content.type == PromptMessageContentType.TEXT else
|
|
|
|
|
'[image]' if content.type == PromptMessageContentType.IMAGE else
|
|
|
|
|
'[file]'
|
|
|
|
|
for content in prompt_message.content
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
return prompt_messages
|