diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 987f670acb..fbb32e8f02 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -15,7 +15,7 @@ from core.plugin.impl.plugin import PluginInstaller from core.provider_manager import ProviderManager from core.tools.entities.tool_entities import ToolParameter, ToolProviderType from core.tools.tool_manager import ToolManager -from core.variables.segments import StringSegment +from core.variables.segments import StringSegment, FileSegment, ArrayFileSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus @@ -29,6 +29,11 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db from factories.agent_factory import get_plugin_agent_strategy from models.model import Conversation +from core.model_runtime.entities import ( + UserPromptMessage, + TextPromptMessageContent +) +from core.file import file_manager class AgentNode(ToolNode): @@ -268,6 +273,35 @@ class AgentNode(ToolNode): value["entity"] = model_schema.model_dump(mode="json") else: value["entity"] = None + + current_query_segment = variable_pool.get(["sys", SystemVariableKey.QUERY.value]) + current_files_segment = variable_pool.get(["sys", SystemVariableKey.FILES.value]) + + if current_query_segment is not None: + if isinstance(current_query_segment, StringSegment): + current_query = current_query_segment.value + + prompt_contents: list = [TextPromptMessageContent(data=current_query)] + + files: list = [] + if isinstance(current_files_segment, FileSegment): + files = [current_files_segment.value] + elif isinstance(current_files_segment, ArrayFileSegment): + files = list(current_files_segment.value) + + if files: + for f in files: + try: + prompt_contents.append( + file_manager.to_prompt_message_content( + f + ) + ) + except Exception: + continue + + synthetic_user_prompt = UserPromptMessage(content=prompt_contents if files else current_query) + history_prompt_messages.append(synthetic_user_prompt.model_dump(mode="json")) result[parameter_name] = value return result