feat: add variable to tool node config

pull/22036/head
Novice 12 months ago
parent 08024fe6de
commit c0684a40e4

@ -148,8 +148,6 @@ class Tool(ABC):
tool_parameter.default = parameter.default tool_parameter.default = parameter.default
tool_parameter.options = parameter.options tool_parameter.options = parameter.options
tool_parameter.llm_description = parameter.llm_description tool_parameter.llm_description = parameter.llm_description
if parameter.input_schema:
tool_parameter.input_schema = parameter.input_schema
break break
else: else:
# add new parameter # add new parameter

@ -4,7 +4,7 @@ import mimetypes
from collections.abc import Generator from collections.abc import Generator
from os import listdir, path from os import listdir, path
from threading import Lock from threading import Lock
from typing import TYPE_CHECKING, Any, Union, cast from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
from yarl import URL from yarl import URL
@ -18,6 +18,7 @@ from core.tools.mcp_tool.tool import MCPTool
from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.plugin_tool.tool import PluginTool from core.tools.plugin_tool.tool import PluginTool
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.workflow.entities.variable_pool import VariablePool
from services.tools.mcp_tools_mange_service import MCPToolManageService from services.tools.mcp_tools_mange_service import MCPToolManageService
if TYPE_CHECKING: if TYPE_CHECKING:
@ -307,6 +308,7 @@ class ToolManager:
app_id: str, app_id: str,
agent_tool: AgentToolEntity, agent_tool: AgentToolEntity,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
variable_pool: Optional[VariablePool] = None,
) -> Tool: ) -> Tool:
""" """
get the agent tool runtime get the agent tool runtime
@ -321,24 +323,9 @@ class ToolManager:
) )
runtime_parameters = {} runtime_parameters = {}
parameters = tool_entity.get_merged_runtime_parameters() parameters = tool_entity.get_merged_runtime_parameters()
for parameter in parameters: runtime_parameters = cls._convert_tool_parameters_type(
# check file types parameters, variable_pool, agent_tool.tool_parameters, typ="agent"
if ( )
parameter.type
in {
ToolParameter.ToolParameterType.SYSTEM_FILES,
ToolParameter.ToolParameterType.FILE,
ToolParameter.ToolParameterType.FILES,
}
and parameter.required
):
raise ValueError(f"file type parameter {parameter.name} not supported in agent")
if parameter.form == ToolParameter.ToolParameterForm.FORM:
# save tool parameter to tool entity memory
value = parameter.init_frontend_parameter(agent_tool.tool_parameters.get(parameter.name))
runtime_parameters[parameter.name] = value
# decrypt runtime parameters # decrypt runtime parameters
encryption_manager = ToolParameterConfigurationManager( encryption_manager = ToolParameterConfigurationManager(
tenant_id=tenant_id, tenant_id=tenant_id,
@ -362,10 +349,12 @@ class ToolManager:
node_id: str, node_id: str,
workflow_tool: "ToolEntity", workflow_tool: "ToolEntity",
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
variable_pool: Optional[VariablePool] = None,
) -> Tool: ) -> Tool:
""" """
get the workflow tool runtime get the workflow tool runtime
""" """
tool_runtime = cls.get_tool_runtime( tool_runtime = cls.get_tool_runtime(
provider_type=workflow_tool.provider_type, provider_type=workflow_tool.provider_type,
provider_id=workflow_tool.provider_id, provider_id=workflow_tool.provider_id,
@ -374,15 +363,11 @@ class ToolManager:
invoke_from=invoke_from, invoke_from=invoke_from,
tool_invoke_from=ToolInvokeFrom.WORKFLOW, tool_invoke_from=ToolInvokeFrom.WORKFLOW,
) )
runtime_parameters = {}
parameters = tool_runtime.get_merged_runtime_parameters()
for parameter in parameters:
# save tool parameter to tool entity memory
if parameter.form == ToolParameter.ToolParameterForm.FORM:
value = parameter.init_frontend_parameter(workflow_tool.tool_configurations.get(parameter.name))
runtime_parameters[parameter.name] = value
parameters = tool_runtime.get_merged_runtime_parameters()
runtime_parameters = cls._convert_tool_parameters_type(
parameters, variable_pool, workflow_tool.tool_configurations, typ="workflow"
)
# decrypt runtime parameters # decrypt runtime parameters
encryption_manager = ToolParameterConfigurationManager( encryption_manager = ToolParameterConfigurationManager(
tenant_id=tenant_id, tenant_id=tenant_id,
@ -922,5 +907,53 @@ class ToolManager:
else: else:
raise ValueError(f"provider type {provider_type} not found") raise ValueError(f"provider type {provider_type} not found")
@classmethod
def _convert_tool_parameters_type(
cls,
parameters: list[ToolParameter],
variable_pool: Optional[VariablePool],
tool_configurations: dict[str, Any],
typ: Literal["agent", "workflow", "tool"] = "workflow",
) -> dict[str, Any]:
"""
Convert tool parameters type
"""
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.nodes.tool.exc import ToolParameterError
runtime_parameters = {}
for parameter in parameters:
if (
parameter.type
in {
ToolParameter.ToolParameterType.SYSTEM_FILES,
ToolParameter.ToolParameterType.FILE,
ToolParameter.ToolParameterType.FILES,
}
and parameter.required
and typ == "agent"
):
raise ValueError(f"file type parameter {parameter.name} not supported in agent")
# save tool parameter to tool entity memory
if parameter.form == ToolParameter.ToolParameterForm.FORM:
if variable_pool:
tool_input = ToolNodeData.ToolInput(**tool_configurations.get(parameter.name, {}))
if tool_input.type == "variable":
variable = variable_pool.get(tool_input.value)
if variable is None:
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
parameter_value = variable.value
elif tool_input.type in {"mixed", "constant"}:
segment_group = variable_pool.convert_template(str(tool_input.value))
parameter_value = segment_group.text
else:
raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
runtime_parameters[parameter.name] = parameter_value
else:
value = parameter.init_frontend_parameter(tool_configurations.get(parameter.name))
runtime_parameters[parameter.name] = value
return runtime_parameters
ToolManager.load_hardcoded_providers_cache() ToolManager.load_hardcoded_providers_cache()

@ -213,9 +213,9 @@ class AgentNode(ToolNode):
) )
extra = tool.get("extra", {}) extra = tool.get("extra", {})
runtime_variable_pool = variable_pool if self.node_data.version != "1" else None
tool_runtime = ToolManager.get_agent_tool_runtime( tool_runtime = ToolManager.get_agent_tool_runtime(
self.tenant_id, self.app_id, entity, self.invoke_from self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool
) )
if tool_runtime.entity.description: if tool_runtime.entity.description:
tool_runtime.entity.description.llm = ( tool_runtime.entity.description.llm = (

@ -68,6 +68,7 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
}, },
NodeType.TOOL: { NodeType.TOOL: {
LATEST_VERSION: ToolNode, LATEST_VERSION: ToolNode,
"2": ToolNode,
"1": ToolNode, "1": ToolNode,
}, },
NodeType.VARIABLE_AGGREGATOR: { NodeType.VARIABLE_AGGREGATOR: {
@ -117,6 +118,7 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
}, },
NodeType.AGENT: { NodeType.AGENT: {
LATEST_VERSION: AgentNode, LATEST_VERSION: AgentNode,
"2": AgentNode,
"1": AgentNode, "1": AgentNode,
}, },
} }

@ -62,8 +62,9 @@ class ToolNode(BaseNode[ToolNodeData]):
try: try:
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
variable_pool = self.graph_runtime_state.variable_pool if self.node_data.version != "1" else None
tool_runtime = ToolManager.get_workflow_tool_runtime( tool_runtime = ToolManager.get_workflow_tool_runtime(
self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from, variable_pool
) )
except ToolNodeError as e: except ToolNodeError as e:
yield RunCompletedEvent( yield RunCompletedEvent(
@ -90,7 +91,6 @@ class ToolNode(BaseNode[ToolNodeData]):
node_data=self.node_data, node_data=self.node_data,
for_log=True, for_log=True,
) )
# get conversation id # get conversation id
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])

Loading…
Cancel
Save