feat: agent node custom tool input

feat/custom-tool-input
Novice Lee 1 year ago
parent 933b6abc13
commit 95eeb7b0d1

@ -8,12 +8,12 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.plugin.manager.exc import PluginDaemonClientSideError from core.plugin.manager.exc import PluginDaemonClientSideError
from core.plugin.manager.plugin import PluginInstallationManager from core.plugin.manager.plugin import PluginInstallationManager
from core.tools.entities.tool_entities import ToolProviderType from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.agent.entities import AgentNodeData from core.workflow.nodes.agent.entities import AgentNodeData, ParamsAutoGenerated
from core.workflow.nodes.base.entities import BaseNodeData from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event.event import RunCompletedEvent from core.workflow.nodes.event.event import RunCompletedEvent
@ -162,11 +162,28 @@ class AgentNode(ToolNode):
tool_value = [] tool_value = []
for tool in value: for tool in value:
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN.value)) provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN.value))
# handle the original settings
original_parameters = tool.get("parameters", {})
setting_params = tool.get("settings", {})
manual_input_params = []
# handle legacy data compatibility
if not all(isinstance(v, dict) for _, v in original_parameters.items()):
parameters = original_parameters
else:
params = {}
for key, param in original_parameters.items():
if param.get("auto", ParamsAutoGenerated.OPEN.value) == ParamsAutoGenerated.CLOSE.value:
params[key] = param.get("value", "")
manual_input_params.append(key)
else:
params[key] = None
settings = {k: v.get("value", None) for k, v in setting_params.items()}
parameters = {**params, **settings}
entity = AgentToolEntity( entity = AgentToolEntity(
provider_id=tool.get("provider_name", ""), provider_id=tool.get("provider_name", ""),
provider_type=provider_type, provider_type=provider_type,
tool_name=tool.get("tool_name", ""), tool_name=tool.get("tool_name", ""),
tool_parameters=tool.get("parameters", {}), tool_parameters=parameters,
plugin_unique_identifier=tool.get("plugin_unique_identifier", None), plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
) )
@ -179,14 +196,27 @@ class AgentNode(ToolNode):
tool_runtime.entity.description.llm = ( tool_runtime.entity.description.llm = (
extra.get("descrption", "") or tool_runtime.entity.description.llm extra.get("descrption", "") or tool_runtime.entity.description.llm
) )
for params in tool_runtime.entity.parameters:
tool_value.append( params.form = (
{ ToolParameter.ToolParameterForm.FORM
**tool_runtime.entity.model_dump(mode="json"), if params.name in manual_input_params
"runtime_parameters": tool_runtime.runtime.runtime_parameters, else params.form
"provider_type": provider_type.value, )
if tool_runtime.entity.parameters:
manual_input_value = {
key: value for key, value in parameters.items() if key in manual_input_params
} }
) runtime_parameters = {
**tool_runtime.runtime.runtime_parameters,
**manual_input_value,
}
tool_value.append(
{
**tool_runtime.entity.model_dump(mode="json"),
"runtime_parameters": runtime_parameters,
"provider_type": provider_type.value,
}
)
value = tool_value value = tool_value
if parameter.type == "model-selector": if parameter.type == "model-selector":
value = cast(dict[str, Any], value) value = cast(dict[str, Any], value)

@ -1,3 +1,4 @@
from enum import Enum
from typing import Any, Literal, Union from typing import Any, Literal, Union
from pydantic import BaseModel from pydantic import BaseModel
@ -16,3 +17,8 @@ class AgentNodeData(BaseNodeData):
type: Literal["mixed", "variable", "constant"] type: Literal["mixed", "variable", "constant"]
agent_parameters: dict[str, AgentInput] agent_parameters: dict[str, AgentInput]
class ParamsAutoGenerated(Enum):
CLOSE = 0
OPEN = 1

Loading…
Cancel
Save