pull/22446/merge
Kumbham Ajay Goud 10 months ago committed by GitHub
commit ece31290e0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,7 +1,7 @@
from enum import Enum, StrEnum from enum import Enum, StrEnum
from typing import Any, Literal, Union from typing import Any, Literal, Union
from pydantic import BaseModel from pydantic import BaseModel, Field
from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.tools.entities.tool_entities import ToolSelector from core.tools.entities.tool_entities import ToolSelector
@ -17,6 +17,7 @@ class AgentNodeData(BaseNodeData):
class AgentInput(BaseModel): class AgentInput(BaseModel):
value: Union[list[str], list[ToolSelector], Any] value: Union[list[str], list[ToolSelector], Any]
type: Literal["mixed", "variable", "constant"] type: Literal["mixed", "variable", "constant"]
description: str | None = Field(default=None, description="Optional description for this input parameter.")
agent_parameters: dict[str, AgentInput] agent_parameters: dict[str, AgentInput]

@ -1,6 +1,6 @@
from typing import Any, Literal, Union from typing import Any, Literal, Union
from pydantic import BaseModel, field_validator from pydantic import BaseModel, field_validator, Field
from pydantic_core.core_schema import ValidationInfo from pydantic_core.core_schema import ValidationInfo
from core.tools.entities.tool_entities import ToolProviderType from core.tools.entities.tool_entities import ToolProviderType
@ -36,6 +36,7 @@ class ToolNodeData(BaseNodeData, ToolEntity):
# TODO: check this type # TODO: check this type
value: Union[Any, list[str]] value: Union[Any, list[str]]
type: Literal["mixed", "variable", "constant"] type: Literal["mixed", "variable", "constant"]
description: str | None = Field(default=None, description="Optional description for this input parameter.")
@field_validator("type", mode="before") @field_validator("type", mode="before")
@classmethod @classmethod

@ -0,0 +1,11 @@
def test_agent_input_description():
from core.workflow.nodes.agent.entities import AgentNodeData
# Description provided
input_with_desc = AgentNodeData.AgentInput(value=["foo"], type="mixed", description="A test description.")
assert input_with_desc.description == "A test description."
# Description omitted
input_without_desc = AgentNodeData.AgentInput(value=["bar"], type="mixed")
assert input_without_desc.description is None
# Serialization
data = input_with_desc.model_dump()
assert data["description"] == "A test description."

@ -113,3 +113,16 @@ def test_tool_node_on_tool_invoke_error(monkeypatch: pytest.MonkeyPatch):
assert "oops" in result.error assert "oops" in result.error
assert "Failed to transform tool message:" in result.error assert "Failed to transform tool message:" in result.error
assert result.error_type == "ToolInvokeError" assert result.error_type == "ToolInvokeError"
def test_tool_input_description():
from core.workflow.nodes.tool.entities import ToolNodeData
# Description provided
input_with_desc = ToolNodeData.ToolInput(value="foo", type="mixed", description="A test description.")
assert input_with_desc.description == "A test description."
# Description omitted
input_without_desc = ToolNodeData.ToolInput(value="bar", type="mixed")
assert input_without_desc.description is None
# Serialization
data = input_with_desc.model_dump()
assert data["description"] == "A test description."

Loading…
Cancel
Save