diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py index 075a41fb2f..af09139ec8 100644 --- a/api/core/workflow/nodes/agent/entities.py +++ b/api/core/workflow/nodes/agent/entities.py @@ -1,7 +1,7 @@ from enum import Enum, StrEnum from typing import Any, Literal, Union -from pydantic import BaseModel +from pydantic import BaseModel, Field from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.tools.entities.tool_entities import ToolSelector @@ -17,6 +17,7 @@ class AgentNodeData(BaseNodeData): class AgentInput(BaseModel): value: Union[list[str], list[ToolSelector], Any] type: Literal["mixed", "variable", "constant"] + description: str | None = Field(default=None, description="Optional description for this input parameter.") agent_parameters: dict[str, AgentInput] diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index 88c5160d14..1d1c5c4678 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -1,6 +1,6 @@ from typing import Any, Literal, Union -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, field_validator, Field from pydantic_core.core_schema import ValidationInfo from core.tools.entities.tool_entities import ToolProviderType @@ -36,6 +36,7 @@ class ToolNodeData(BaseNodeData, ToolEntity): # TODO: check this type value: Union[Any, list[str]] type: Literal["mixed", "variable", "constant"] + description: str | None = Field(default=None, description="Optional description for this input parameter.") @field_validator("type", mode="before") @classmethod diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/test_agent_node.py b/api/tests/unit_tests/core/workflow/nodes/agent/test_agent_node.py new file mode 100644 index 0000000000..f408ed603d --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/agent/test_agent_node.py @@ -0,0 +1,11 @@ +def test_agent_input_description(): + from core.workflow.nodes.agent.entities import AgentNodeData + # Description provided + input_with_desc = AgentNodeData.AgentInput(value=["foo"], type="mixed", description="A test description.") + assert input_with_desc.description == "A test description." + # Description omitted + input_without_desc = AgentNodeData.AgentInput(value=["bar"], type="mixed") + assert input_without_desc.description is None + # Serialization + data = input_with_desc.model_dump() + assert data["description"] == "A test description." \ No newline at end of file diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index 0eaabd0c40..4ba72fdc30 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -113,3 +113,16 @@ def test_tool_node_on_tool_invoke_error(monkeypatch: pytest.MonkeyPatch): assert "oops" in result.error assert "Failed to transform tool message:" in result.error assert result.error_type == "ToolInvokeError" + + +def test_tool_input_description(): + from core.workflow.nodes.tool.entities import ToolNodeData + # Description provided + input_with_desc = ToolNodeData.ToolInput(value="foo", type="mixed", description="A test description.") + assert input_with_desc.description == "A test description." + # Description omitted + input_without_desc = ToolNodeData.ToolInput(value="bar", type="mixed") + assert input_without_desc.description is None + # Serialization + data = input_with_desc.model_dump() + assert data["description"] == "A test description."