refactor(tool_node, agent_node): Refactors agent node message handling

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/22581/head
-LAN- 10 months ago
parent fbfb7fa131
commit 4f7f37f398
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -11,8 +11,10 @@ from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter from core.agent.plugin_entities import AgentStrategyParameter
from core.agent.strategy.plugin import PluginAgentStrategy from core.agent.strategy.plugin import PluginAgentStrategy
from core.file import File, FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.plugin.entities.request import InvokeCredentials from core.plugin.entities.request import InvokeCredentials
from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.exc import PluginDaemonClientSideError
@ -25,29 +27,46 @@ from core.tools.entities.tool_entities import (
ToolProviderType, ToolProviderType,
) )
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.variables.segments import StringSegment from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayFileSegment, StringSegment
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import AgentLogEvent
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event.event import RunCompletedEvent from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.utils.variable_template_parser import VariableTemplateParser from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db from extensions.ext_database import db
from factories import file_factory
from factories.agent_factory import get_plugin_agent_strategy from factories.agent_factory import get_plugin_agent_strategy
from models import ToolFile
from models.model import Conversation from models.model import Conversation
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .exc import (
AgentInputTypeError,
AgentInvocationError,
AgentMessageTransformError,
AgentVariableNotFoundError,
AgentVariableTypeError,
ToolFileNotFoundError,
)
class AgentNode(ToolNode): class AgentNode(BaseNode):
""" """
Agent Node Agent Node
""" """
_node_data_cls = AgentNodeData # type: ignore
_node_type = NodeType.AGENT _node_type = NodeType.AGENT
node_data: AgentNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = AgentNodeData(**data)
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
@ -105,11 +124,12 @@ class AgentNode(ToolNode):
credentials=credentials, credentials=credentials,
) )
except Exception as e: except Exception as e:
error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e)
yield RunCompletedEvent( yield RunCompletedEvent(
run_result=NodeRunResult( run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log, inputs=parameters_for_log,
error=f"Failed to invoke agent: {str(e)}", error=str(error),
) )
) )
return return
@ -138,26 +158,26 @@ class AgentNode(ToolNode):
), ),
) )
def enhanced_message_stream():
yield thought_log_message
yield from message_stream
yield from self._transform_message( yield from self._transform_message(
message_stream, messages=message_stream,
{ tool_info={
"icon": self.agent_strategy_icon, "icon": self.agent_strategy_icon,
"agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name, "agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
}, },
parameters_for_log, parameters_for_log=parameters_for_log,
agent_thoughts, user_id=self.user_id,
tenant_id=self.tenant_id,
node_type=self.node_type,
node_id=self.node_id,
node_execution_id=self.id,
) )
except PluginDaemonClientSideError as e: except PluginDaemonClientSideError as e:
error = AgentMessageTransformError(f"Failed to transform agent message: {str(e)}", original_error=e)
yield RunCompletedEvent( yield RunCompletedEvent(
run_result=NodeRunResult( run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log, inputs=parameters_for_log,
error=f"Failed to transform agent message: {str(e)}", error=str(error),
) )
) )
@ -194,7 +214,7 @@ class AgentNode(ToolNode):
if agent_input.type == "variable": if agent_input.type == "variable":
variable = variable_pool.get(agent_input.value) # type: ignore variable = variable_pool.get(agent_input.value) # type: ignore
if variable is None: if variable is None:
raise ValueError(f"Variable {agent_input.value} does not exist") raise AgentVariableNotFoundError(str(agent_input.value))
parameter_value = variable.value parameter_value = variable.value
elif agent_input.type in {"mixed", "constant"}: elif agent_input.type in {"mixed", "constant"}:
# variable_pool.convert_template expects a string template, # variable_pool.convert_template expects a string template,
@ -216,7 +236,7 @@ class AgentNode(ToolNode):
except json.JSONDecodeError: except json.JSONDecodeError:
parameter_value = parameter_value parameter_value = parameter_value
else: else:
raise ValueError(f"Unknown agent input type '{agent_input.type}'") raise AgentInputTypeError(agent_input.type)
value = parameter_value value = parameter_value
if parameter.type == "array[tools]": if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value) value = cast(list[dict[str, Any]], value)
@ -448,3 +468,236 @@ class AgentNode(ToolNode):
return tools return tools
else: else:
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP.value] return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP.value]
def _transform_message(
self,
messages: Generator[ToolInvokeMessage, None, None],
tool_info: Mapping[str, Any],
parameters_for_log: dict[str, Any],
user_id: str,
tenant_id: str,
node_type: NodeType,
node_id: str,
node_execution_id: str,
) -> Generator:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
"""
# transform message and handle file storage
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=messages,
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
)
text = ""
files: list[File] = []
json: list[dict] = []
agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
llm_usage: LLMUsage | None = None
variables: dict[str, Any] = {}
for message in message_stream:
if message.type in {
ToolInvokeMessage.MessageType.IMAGE_LINK,
ToolInvokeMessage.MessageType.BINARY_LINK,
ToolInvokeMessage.MessageType.IMAGE,
}:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
url = message.message.text
if message.meta:
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
else:
transfer_method = FileTransferMethod.TOOL_FILE
tool_file_id = str(url).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileNotFoundError(tool_file_id)
mapping = {
"tool_file_id": tool_file_id,
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
"transfer_method": transfer_method,
"url": url,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
)
files.append(file)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
# get tool file id
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
assert message.meta
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileNotFoundError(tool_file_id)
mapping = {
"tool_file_id": tool_file_id,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
files.append(
file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
)
)
elif message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
text += message.message.text
yield RunStreamChunkEvent(chunk_content=message.message.text, from_variable_selector=[node_id, "text"])
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if node_type == NodeType.AGENT:
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
llm_usage = LLMUsage.from_metadata(msg_metadata)
agent_execution_metadata = {
WorkflowNodeExecutionMetadataKey(key): value
for key, value in msg_metadata.items()
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
}
if message.message.json_object is not None:
json.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise AgentVariableTypeError(
"When 'stream' is True, 'variable_value' must be a string.",
variable_name=variable_name,
expected_type="str",
actual_type=type(variable_value).__name__,
)
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
yield RunStreamChunkEvent(
chunk_content=variable_value, from_variable_selector=[node_id, variable_name]
)
else:
variables[variable_name] = variable_value
elif message.type == ToolInvokeMessage.MessageType.FILE:
assert message.meta is not None
assert isinstance(message.meta, File)
files.append(message.meta["file"])
elif message.type == ToolInvokeMessage.MessageType.LOG:
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
if message.message.metadata:
icon = tool_info.get("icon", "")
dict_metadata = dict(message.message.metadata)
if dict_metadata.get("provider"):
manager = PluginInstaller()
plugins = manager.list_plugins(tenant_id)
try:
current_plugin = next(
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
)
icon = current_plugin.declaration.icon
except StopIteration:
pass
icon_dark = None
try:
builtin_tool = next(
provider
for provider in BuiltinToolManageService.list_builtin_tools(
user_id,
tenant_id,
)
if provider.name == dict_metadata["provider"]
)
icon = builtin_tool.icon
icon_dark = builtin_tool.icon_dark
except StopIteration:
pass
dict_metadata["icon"] = icon
dict_metadata["icon_dark"] = icon_dark
message.message.metadata = dict_metadata
agent_log = AgentLogEvent(
id=message.message.id,
node_execution_id=node_execution_id,
parent_id=message.message.parent_id,
error=message.message.error,
status=message.message.status.value,
data=message.message.data,
label=message.message.label,
metadata=message.message.metadata,
node_id=node_id,
)
# check if the agent log is already in the list
for log in agent_logs:
if log.id == agent_log.id:
# update the log
log.data = agent_log.data
log.status = agent_log.status
log.error = agent_log.error
log.label = agent_log.label
log.metadata = agent_log.metadata
break
else:
agent_logs.append(agent_log)
yield agent_log
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
json_output: list[dict[str, Any]] = []
# Step 1: append each agent log as its own dict.
if agent_logs:
for log in agent_logs:
json_output.append(
{
"id": log.id,
"parent_id": log.parent_id,
"error": log.error,
"status": log.status,
"data": log.data,
"label": log.label,
"metadata": log.metadata,
"node_id": log.node_id,
}
)
# Step 2: normalize JSON into {"data": [...]}.change json to list[dict]
if json:
json_output.extend(json)
else:
json_output.append({"data": []})
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
metadata={
**agent_execution_metadata,
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
},
inputs=parameters_for_log,
llm_usage=llm_usage,
)
)

@ -0,0 +1,124 @@
from typing import Optional
class AgentNodeError(Exception):
"""Base exception for all agent node errors."""
def __init__(self, message: str):
self.message = message
super().__init__(self.message)
class AgentStrategyError(AgentNodeError):
"""Exception raised when there's an error with the agent strategy."""
def __init__(self, message: str, strategy_name: Optional[str] = None, provider_name: Optional[str] = None):
self.strategy_name = strategy_name
self.provider_name = provider_name
super().__init__(message)
class AgentStrategyNotFoundError(AgentStrategyError):
"""Exception raised when the specified agent strategy is not found."""
def __init__(self, strategy_name: str, provider_name: Optional[str] = None):
super().__init__(
f"Agent strategy '{strategy_name}' not found"
+ (f" for provider '{provider_name}'" if provider_name else ""),
strategy_name,
provider_name,
)
class AgentInvocationError(AgentNodeError):
"""Exception raised when there's an error invoking the agent."""
def __init__(self, message: str, original_error: Optional[Exception] = None):
self.original_error = original_error
super().__init__(message)
class AgentParameterError(AgentNodeError):
"""Exception raised when there's an error with agent parameters."""
def __init__(self, message: str, parameter_name: Optional[str] = None):
self.parameter_name = parameter_name
super().__init__(message)
class AgentVariableError(AgentNodeError):
"""Exception raised when there's an error with variables in the agent node."""
def __init__(self, message: str, variable_name: Optional[str] = None):
self.variable_name = variable_name
super().__init__(message)
class AgentVariableNotFoundError(AgentVariableError):
"""Exception raised when a variable is not found in the variable pool."""
def __init__(self, variable_name: str):
super().__init__(f"Variable '{variable_name}' does not exist", variable_name)
class AgentInputTypeError(AgentNodeError):
"""Exception raised when an unknown agent input type is encountered."""
def __init__(self, input_type: str):
super().__init__(f"Unknown agent input type '{input_type}'")
class ToolFileError(AgentNodeError):
"""Exception raised when there's an error with a tool file."""
def __init__(self, message: str, file_id: Optional[str] = None):
self.file_id = file_id
super().__init__(message)
class ToolFileNotFoundError(ToolFileError):
"""Exception raised when a tool file is not found."""
def __init__(self, file_id: str):
super().__init__(f"Tool file '{file_id}' does not exist", file_id)
class AgentMessageTransformError(AgentNodeError):
"""Exception raised when there's an error transforming agent messages."""
def __init__(self, message: str, original_error: Optional[Exception] = None):
self.original_error = original_error
super().__init__(message)
class AgentModelError(AgentNodeError):
"""Exception raised when there's an error with the model used by the agent."""
def __init__(self, message: str, model_name: Optional[str] = None, provider: Optional[str] = None):
self.model_name = model_name
self.provider = provider
super().__init__(message)
class AgentMemoryError(AgentNodeError):
"""Exception raised when there's an error with the agent's memory."""
def __init__(self, message: str, conversation_id: Optional[str] = None):
self.conversation_id = conversation_id
super().__init__(message)
class AgentVariableTypeError(AgentNodeError):
"""Exception raised when a variable has an unexpected type."""
def __init__(
self,
message: str,
variable_name: Optional[str] = None,
expected_type: Optional[str] = None,
actual_type: Optional[str] = None,
):
self.variable_name = variable_name
self.expected_type = expected_type
self.actual_type = actual_type
super().__init__(message)

@ -1,12 +1,11 @@
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast from typing import Any, cast
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file import File, FileTransferMethod from core.file import File, FileTransferMethod
from core.model_runtime.entities.llm_entities import LLMUsage
from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.exc import PluginDaemonClientSideError
from core.plugin.impl.plugin import PluginInstaller from core.plugin.impl.plugin import PluginInstaller
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
@ -19,7 +18,6 @@ from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import AgentLogEvent
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
@ -132,6 +130,9 @@ class ToolNode(BaseNode):
messages=message_stream, messages=message_stream,
tool_info=tool_info, tool_info=tool_info,
parameters_for_log=parameters_for_log, parameters_for_log=parameters_for_log,
user_id=self.user_id,
tenant_id=self.tenant_id,
node_id=self.node_id,
) )
except (PluginDaemonClientSideError, ToolInvokeError) as e: except (PluginDaemonClientSideError, ToolInvokeError) as e:
yield RunCompletedEvent( yield RunCompletedEvent(
@ -199,7 +200,9 @@ class ToolNode(BaseNode):
messages: Generator[ToolInvokeMessage, None, None], messages: Generator[ToolInvokeMessage, None, None],
tool_info: Mapping[str, Any], tool_info: Mapping[str, Any],
parameters_for_log: dict[str, Any], parameters_for_log: dict[str, Any],
agent_thoughts: Optional[list] = None, user_id: str,
tenant_id: str,
node_id: str,
) -> Generator: ) -> Generator:
""" """
Convert ToolInvokeMessages into tuple[plain_text, files] Convert ToolInvokeMessages into tuple[plain_text, files]
@ -207,8 +210,8 @@ class ToolNode(BaseNode):
# transform message and handle file storage # transform message and handle file storage
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages( message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=messages, messages=messages,
user_id=self.user_id, user_id=user_id,
tenant_id=self.tenant_id, tenant_id=tenant_id,
conversation_id=None, conversation_id=None,
) )
@ -216,9 +219,6 @@ class ToolNode(BaseNode):
files: list[File] = [] files: list[File] = []
json: list[dict] = [] json: list[dict] = []
agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
llm_usage: LLMUsage | None = None
variables: dict[str, Any] = {} variables: dict[str, Any] = {}
for message in message_stream: for message in message_stream:
@ -251,7 +251,7 @@ class ToolNode(BaseNode):
} }
file = file_factory.build_from_mapping( file = file_factory.build_from_mapping(
mapping=mapping, mapping=mapping,
tenant_id=self.tenant_id, tenant_id=tenant_id,
) )
files.append(file) files.append(file)
elif message.type == ToolInvokeMessage.MessageType.BLOB: elif message.type == ToolInvokeMessage.MessageType.BLOB:
@ -274,45 +274,36 @@ class ToolNode(BaseNode):
files.append( files.append(
file_factory.build_from_mapping( file_factory.build_from_mapping(
mapping=mapping, mapping=mapping,
tenant_id=self.tenant_id, tenant_id=tenant_id,
) )
) )
elif message.type == ToolInvokeMessage.MessageType.TEXT: elif message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage) assert isinstance(message.message, ToolInvokeMessage.TextMessage)
text += message.message.text text += message.message.text
yield RunStreamChunkEvent( yield RunStreamChunkEvent(chunk_content=message.message.text, from_variable_selector=[node_id, "text"])
chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
)
elif message.type == ToolInvokeMessage.MessageType.JSON: elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage) assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if self.node_type == NodeType.AGENT: # JSON message handling for tool node
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
llm_usage = LLMUsage.from_metadata(msg_metadata)
agent_execution_metadata = {
WorkflowNodeExecutionMetadataKey(key): value
for key, value in msg_metadata.items()
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
}
if message.message.json_object is not None: if message.message.json_object is not None:
json.append(message.message.json_object) json.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK: elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage) assert isinstance(message.message, ToolInvokeMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n" stream_text = f"Link: {message.message.text}\n"
text += stream_text text += stream_text
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"]) yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])
elif message.type == ToolInvokeMessage.MessageType.VARIABLE: elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
assert isinstance(message.message, ToolInvokeMessage.VariableMessage) assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
variable_name = message.message.variable_name variable_name = message.message.variable_name
variable_value = message.message.variable_value variable_value = message.message.variable_value
if message.message.stream: if message.message.stream:
if not isinstance(variable_value, str): if not isinstance(variable_value, str):
raise ValueError("When 'stream' is True, 'variable_value' must be a string.") raise ToolNodeError("When 'stream' is True, 'variable_value' must be a string.")
if variable_name not in variables: if variable_name not in variables:
variables[variable_name] = "" variables[variable_name] = ""
variables[variable_name] += variable_value variables[variable_name] += variable_value
yield RunStreamChunkEvent( yield RunStreamChunkEvent(
chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name] chunk_content=variable_value, from_variable_selector=[node_id, variable_name]
) )
else: else:
variables[variable_name] = variable_value variables[variable_name] = variable_value
@ -327,7 +318,7 @@ class ToolNode(BaseNode):
dict_metadata = dict(message.message.metadata) dict_metadata = dict(message.message.metadata)
if dict_metadata.get("provider"): if dict_metadata.get("provider"):
manager = PluginInstaller() manager = PluginInstaller()
plugins = manager.list_plugins(self.tenant_id) plugins = manager.list_plugins(tenant_id)
try: try:
current_plugin = next( current_plugin = next(
plugin plugin
@ -342,8 +333,8 @@ class ToolNode(BaseNode):
builtin_tool = next( builtin_tool = next(
provider provider
for provider in BuiltinToolManageService.list_builtin_tools( for provider in BuiltinToolManageService.list_builtin_tools(
self.user_id, user_id,
self.tenant_id, tenant_id,
) )
if provider.name == dict_metadata["provider"] if provider.name == dict_metadata["provider"]
) )
@ -355,57 +346,10 @@ class ToolNode(BaseNode):
dict_metadata["icon"] = icon dict_metadata["icon"] = icon
dict_metadata["icon_dark"] = icon_dark dict_metadata["icon_dark"] = icon_dark
message.message.metadata = dict_metadata message.message.metadata = dict_metadata
agent_log = AgentLogEvent(
id=message.message.id,
node_execution_id=self.id,
parent_id=message.message.parent_id,
error=message.message.error,
status=message.message.status.value,
data=message.message.data,
label=message.message.label,
metadata=message.message.metadata,
node_id=self.node_id,
)
# check if the agent log is already in the list
for log in agent_logs:
if log.id == agent_log.id:
# update the log
log.data = agent_log.data
log.status = agent_log.status
log.error = agent_log.error
log.label = agent_log.label
log.metadata = agent_log.metadata
break
else:
agent_logs.append(agent_log)
yield agent_log
elif message.type == ToolInvokeMessage.MessageType.RETRIEVER_RESOURCES:
assert isinstance(message.message, ToolInvokeMessage.RetrieverResourceMessage)
yield RunRetrieverResourceEvent(
retriever_resources=message.message.retriever_resources,
context=message.message.context,
)
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process # Add agent_logs to outputs['json'] to ensure frontend can access thinking process
json_output: list[dict[str, Any]] = [] json_output: list[dict[str, Any]] = []
# Step 1: append each agent log as its own dict.
if agent_logs:
for log in agent_logs:
json_output.append(
{
"id": log.id,
"parent_id": log.parent_id,
"error": log.error,
"status": log.status,
"data": log.data,
"label": log.label,
"metadata": log.metadata,
"node_id": log.node_id,
}
)
# Step 2: normalize JSON into {"data": [...]}.change json to list[dict] # Step 2: normalize JSON into {"data": [...]}.change json to list[dict]
if json: if json:
json_output.extend(json) json_output.extend(json)
@ -417,12 +361,9 @@ class ToolNode(BaseNode):
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables}, outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
metadata={ metadata={
**agent_execution_metadata,
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
}, },
inputs=parameters_for_log, inputs=parameters_for_log,
llm_usage=llm_usage,
) )
) )

Loading…
Cancel
Save