From fbfb7fa131b8da0ed1a01e012889bfa04378e151 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 14 Jul 2025 18:10:54 +0800 Subject: [PATCH] refactor(ndoes): Refactors node data management in workflow engine Signed-off-by: -LAN- --- api/core/workflow/graph_engine/__init__.py | 3 ++- api/core/workflow/graph_engine/graph_engine.py | 4 ++-- api/core/workflow/nodes/answer/answer_node.py | 8 ++++++-- api/core/workflow/nodes/base/node.py | 18 ++++++------------ api/core/workflow/nodes/code/code_node.py | 8 ++++++-- .../workflow/nodes/document_extractor/node.py | 8 ++++++-- api/core/workflow/nodes/end/end_node.py | 11 +++++++++-- api/core/workflow/nodes/http_request/node.py | 8 ++++++-- .../workflow/nodes/if_else/if_else_node.py | 8 ++++++-- .../workflow/nodes/iteration/iteration_node.py | 8 ++++++-- .../nodes/iteration/iteration_start_node.py | 11 +++++++++-- api/core/workflow/nodes/list_operator/node.py | 10 +++++++--- api/core/workflow/nodes/llm/node.py | 12 +++++++----- api/core/workflow/nodes/loop/loop_end_node.py | 11 +++++++++-- api/core/workflow/nodes/loop/loop_node.py | 8 ++++++-- .../workflow/nodes/loop/loop_start_node.py | 11 +++++++++-- .../parameter_extractor_node.py | 7 +++++-- api/core/workflow/nodes/start/start_node.py | 11 +++++++++-- .../template_transform_node.py | 8 ++++++-- api/core/workflow/nodes/tool/tool_node.py | 14 +++++++++++--- .../variable_aggregator_node.py | 9 +++++++-- .../nodes/variable_assigner/v1/node.py | 8 ++++++-- .../nodes/variable_assigner/v2/node.py | 8 ++++++-- 23 files changed, 152 insertions(+), 60 deletions(-) diff --git a/api/core/workflow/graph_engine/__init__.py b/api/core/workflow/graph_engine/__init__.py index 2fee3d7fad..12e1de464b 100644 --- a/api/core/workflow/graph_engine/__init__.py +++ b/api/core/workflow/graph_engine/__init__.py @@ -1,3 +1,4 @@ from .entities import Graph, GraphInitParams, GraphRuntimeState, RuntimeRouteState +from .graph_engine import GraphEngine -__all__ = ["Graph", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"] +__all__ = ["Graph", "GraphEngine", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"] diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 5a2915e2d3..701b90ae0d 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -265,7 +265,7 @@ class GraphEngine: previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None # init workflow run state - node_instance = node_cls( # type: ignore + node_instance = node_cls( id=route_node_state.id, config=node_config, graph_init_params=self.init_params, @@ -274,7 +274,7 @@ class GraphEngine: previous_node_id=previous_node_id, thread_pool_id=self.thread_pool_id, ) - node_instance = cast(BaseNode[BaseNodeData], node_instance) + node_instance.from_dict(node_config.get("data", {})) try: # run node generator = self._run_node( diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index 38c2bcbdf5..f9e1bc4b04 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -16,10 +16,14 @@ from core.workflow.nodes.enums import NodeType from core.workflow.utils.variable_template_parser import VariableTemplateParser -class AnswerNode(BaseNode[AnswerNodeData]): - _node_data_cls = AnswerNodeData +class AnswerNode(BaseNode): _node_type = NodeType.ANSWER + node_data: AnswerNodeData + + def from_dict(self, data: Mapping[str, Any]) -> None: + self.node_data = AnswerNodeData(**data) + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index 6973401429..0fdf77c163 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -1,28 +1,21 @@ import logging from abc import abstractmethod -from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, cast +from collections.abc import Callable, Generator, Mapping, Sequence +from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent -from .entities import BaseNodeData - if TYPE_CHECKING: + from core.workflow.graph_engine import Graph, GraphEngine, GraphInitParams, GraphRuntimeState from core.workflow.graph_engine.entities.event import InNodeEvent - from core.workflow.graph_engine.entities.graph import Graph - from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams - from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState logger = logging.getLogger(__name__) -GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData) - -class BaseNode(Generic[GenericNodeData]): - _node_data_cls: type[GenericNodeData] +class BaseNode: _node_type: ClassVar[NodeType] def __init__( @@ -57,7 +50,8 @@ class BaseNode(Generic[GenericNodeData]): self.node_id = node_id node_data = self._node_data_cls.model_validate(config.get("data", {})) - self.node_data = node_data + @abstractmethod + def from_dict(self, data: Mapping[str, Any]) -> None: ... @abstractmethod def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]: diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 1adabf7247..cdf3913197 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -21,10 +21,14 @@ from .exc import ( ) -class CodeNode(BaseNode[CodeNodeData]): - _node_data_cls = CodeNodeData +class CodeNode(BaseNode): _node_type = NodeType.CODE + node_data: CodeNodeData + + def from_dict(self, data: Mapping[str, Any]) -> None: + self.node_data = CodeNodeData(**data) + @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: """ diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index 8e6150f9cc..d62a1c6cf2 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -36,15 +36,19 @@ from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, logger = logging.getLogger(__name__) -class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]): +class DocumentExtractorNode(BaseNode): """ Extracts text content from various file types. Supports plain text, PDF, and DOC/DOCX files. """ - _node_data_cls = DocumentExtractorNodeData _node_type = NodeType.DOCUMENT_EXTRACTOR + node_data: DocumentExtractorNodeData + + def from_dict(self, data: Mapping[str, Any]) -> None: + self.node_data = DocumentExtractorNodeData(**data) + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 17a0b3adeb..c0d7e084c2 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -1,3 +1,6 @@ +from collections.abc import Mapping +from typing import Any + from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode @@ -5,10 +8,14 @@ from core.workflow.nodes.end.entities import EndNodeData from core.workflow.nodes.enums import NodeType -class EndNode(BaseNode[EndNodeData]): - _node_data_cls = EndNodeData +class EndNode(BaseNode): _node_type = NodeType.END + node_data: EndNodeData + + def from_dict(self, data: Mapping[str, Any]) -> None: + self.node_data = EndNodeData(**data) + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index 971e0f73e7..f164794630 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -32,10 +32,14 @@ HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout( logger = logging.getLogger(__name__) -class HttpRequestNode(BaseNode[HttpRequestNodeData]): - _node_data_cls = HttpRequestNodeData +class HttpRequestNode(BaseNode): _node_type = NodeType.HTTP_REQUEST + node_data: HttpRequestNodeData + + def from_dict(self, data: Mapping[str, Any]) -> None: + self.node_data = HttpRequestNodeData(**data) + @classmethod def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict: return { diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index 22b748030c..0e36a2a135 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -13,10 +13,14 @@ from core.workflow.utils.condition.entities import Condition from core.workflow.utils.condition.processor import ConditionProcessor -class IfElseNode(BaseNode[IfElseNodeData]): - _node_data_cls = IfElseNodeData +class IfElseNode(BaseNode): _node_type = NodeType.IF_ELSE + node_data: IfElseNodeData + + def from_dict(self, data: Mapping[str, Any]) -> None: + self.node_data = IfElseNodeData(**data) + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 8b566c83cd..d3ee642e81 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -56,14 +56,18 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class IterationNode(BaseNode[IterationNodeData]): +class IterationNode(BaseNode): """ Iteration Node. """ - _node_data_cls = IterationNodeData _node_type = NodeType.ITERATION + node_data: IterationNodeData + + def from_dict(self, data: Mapping[str, Any]) -> None: + self.node_data = IterationNodeData(**data) + @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: return { diff --git a/api/core/workflow/nodes/iteration/iteration_start_node.py b/api/core/workflow/nodes/iteration/iteration_start_node.py index 9900aa225d..d01382e37e 100644 --- a/api/core/workflow/nodes/iteration/iteration_start_node.py +++ b/api/core/workflow/nodes/iteration/iteration_start_node.py @@ -1,3 +1,6 @@ +from collections.abc import Mapping +from typing import Any + from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode @@ -5,14 +8,18 @@ from core.workflow.nodes.enums import NodeType from core.workflow.nodes.iteration.entities import IterationStartNodeData -class IterationStartNode(BaseNode[IterationStartNodeData]): +class IterationStartNode(BaseNode): """ Iteration Start Node. """ - _node_data_cls = IterationStartNodeData _node_type = NodeType.ITERATION_START + node_data: IterationStartNodeData + + def from_dict(self, data: Mapping[str, Any]) -> None: + self.node_data = IterationStartNodeData(**data) + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index 3c9ba44cf1..cc4393769b 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -1,4 +1,4 @@ -from collections.abc import Callable, Sequence +from collections.abc import Callable, Mapping, Sequence from typing import Any, Literal, Union from core.file import File @@ -13,10 +13,14 @@ from .entities import ListOperatorNodeData from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError -class ListOperatorNode(BaseNode[ListOperatorNodeData]): - _node_data_cls = ListOperatorNodeData +class ListOperatorNode(BaseNode): _node_type = NodeType.LIST_OPERATOR + node_data: ListOperatorNodeData + + def from_dict(self, data: Mapping[str, Any]) -> None: + self.node_data = ListOperatorNodeData(**data) + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index be0675a0f2..cd2d4a7970 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -90,17 +90,16 @@ from .file_saver import FileSaverImpl, LLMFileSaver if TYPE_CHECKING: from core.file.models import File - from core.workflow.graph_engine.entities.graph import Graph - from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams - from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState + from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState logger = logging.getLogger(__name__) -class LLMNode(BaseNode[LLMNodeData]): - _node_data_cls = LLMNodeData +class LLMNode(BaseNode): _node_type = NodeType.LLM + node_data: LLMNodeData + # Instance attributes specific to LLMNode. # Output variable for file _file_outputs: list["File"] @@ -138,6 +137,9 @@ class LLMNode(BaseNode[LLMNodeData]): ) self._llm_file_saver = llm_file_saver + def from_dict(self, data: Mapping[str, Any]) -> None: + self.node_data = LLMNodeData(**data) + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/loop/loop_end_node.py b/api/core/workflow/nodes/loop/loop_end_node.py index b144021bab..2657165d61 100644 --- a/api/core/workflow/nodes/loop/loop_end_node.py +++ b/api/core/workflow/nodes/loop/loop_end_node.py @@ -1,3 +1,6 @@ +from collections.abc import Mapping +from typing import Any + from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode @@ -5,14 +8,18 @@ from core.workflow.nodes.enums import NodeType from core.workflow.nodes.loop.entities import LoopEndNodeData -class LoopEndNode(BaseNode[LoopEndNodeData]): +class LoopEndNode(BaseNode): """ Loop End Node. """ - _node_data_cls = LoopEndNodeData _node_type = NodeType.LOOP_END + node_data: LoopEndNodeData + + def from_dict(self, data: Mapping[str, Any]) -> None: + self.node_data = LoopEndNodeData(**data) + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 20501d0317..004ec8bdf4 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -43,14 +43,18 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class LoopNode(BaseNode[LoopNodeData]): +class LoopNode(BaseNode): """ Loop Node. """ - _node_data_cls = LoopNodeData _node_type = NodeType.LOOP + node_data: LoopNodeData + + def from_dict(self, data: Mapping[str, Any]) -> None: + self.node_data = LoopNodeData(**data) + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/loop/loop_start_node.py b/api/core/workflow/nodes/loop/loop_start_node.py index f5e38b7516..93db39ab12 100644 --- a/api/core/workflow/nodes/loop/loop_start_node.py +++ b/api/core/workflow/nodes/loop/loop_start_node.py @@ -1,3 +1,6 @@ +from collections.abc import Mapping +from typing import Any + from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode @@ -5,14 +8,18 @@ from core.workflow.nodes.enums import NodeType from core.workflow.nodes.loop.entities import LoopStartNodeData -class LoopStartNode(BaseNode[LoopStartNodeData]): +class LoopStartNode(BaseNode): """ Loop Start Node. """ - _node_data_cls = LoopStartNodeData _node_type = NodeType.LOOP_START + node_data: LoopStartNodeData + + def from_dict(self, data: Mapping[str, Any]) -> None: + self.node_data = LoopStartNodeData(**data) + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 25a534256b..0fa7d0147e 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -91,10 +91,13 @@ class ParameterExtractorNode(BaseNode): Parameter Extractor Node. """ - # FIXME: figure out why here is different from super class - _node_data_cls = ParameterExtractorNodeData # type: ignore _node_type = NodeType.PARAMETER_EXTRACTOR + node_data: ParameterExtractorNodeData + + def from_dict(self, data: Mapping[str, Any]) -> None: + self.node_data = ParameterExtractorNodeData(**data) + _model_instance: Optional[ModelInstance] = None _model_config: Optional[ModelConfigWithCredentialsEntity] = None diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index e215591888..49efdfb317 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,3 +1,6 @@ +from collections.abc import Mapping +from typing import Any + from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus @@ -6,10 +9,14 @@ from core.workflow.nodes.enums import NodeType from core.workflow.nodes.start.entities import StartNodeData -class StartNode(BaseNode[StartNodeData]): - _node_data_cls = StartNodeData +class StartNode(BaseNode): _node_type = NodeType.START + node_data: StartNodeData + + def from_dict(self, data: Mapping[str, Any]) -> None: + self.node_data = StartNodeData(**data) + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index ba573074c3..4ca1eb8a94 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -12,10 +12,14 @@ from core.workflow.nodes.template_transform.entities import TemplateTransformNod MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000")) -class TemplateTransformNode(BaseNode[TemplateTransformNodeData]): - _node_data_cls = TemplateTransformNodeData +class TemplateTransformNode(BaseNode): _node_type = NodeType.TEMPLATE_TRANSFORM + node_data: TemplateTransformNodeData + + def from_dict(self, data: Mapping[str, Any]) -> None: + self.node_data = TemplateTransformNodeData(**data) + @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: """ diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 3853a5d920..790c37fb36 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -37,14 +37,18 @@ from .exc import ( ) -class ToolNode(BaseNode[ToolNodeData]): +class ToolNode(BaseNode): """ Tool Node """ - _node_data_cls = ToolNodeData _node_type = NodeType.TOOL + node_data: ToolNodeData + + def from_dict(self, data: Mapping[str, Any]) -> None: + self.node_data = ToolNodeData(**data) + @classmethod def version(cls) -> str: return "1" @@ -124,7 +128,11 @@ class ToolNode(BaseNode[ToolNodeData]): try: # convert tool messages - yield from self._transform_message(message_stream, tool_info, parameters_for_log) + yield from self._transform_message( + messages=message_stream, + tool_info=tool_info, + parameters_for_log=parameters_for_log, + ) except (PluginDaemonClientSideError, ToolInvokeError) as e: yield RunCompletedEvent( run_result=NodeRunResult( diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index 96bb3e793a..08ec38e734 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -1,4 +1,5 @@ from collections.abc import Mapping +from typing import Any from core.variables.segments import Segment from core.workflow.entities.node_entities import NodeRunResult @@ -8,10 +9,14 @@ from core.workflow.nodes.enums import NodeType from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData -class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]): - _node_data_cls = VariableAssignerNodeData +class VariableAggregatorNode(BaseNode): _node_type = NodeType.VARIABLE_AGGREGATOR + node_data: VariableAssignerNodeData + + def from_dict(self, data: Mapping[str, Any]) -> None: + self.node_data = VariableAssignerNodeData(**data) + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index 1864b13784..899372f5e5 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -22,11 +22,15 @@ if TYPE_CHECKING: _CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater] -class VariableAssignerNode(BaseNode[VariableAssignerData]): - _node_data_cls = VariableAssignerData +class VariableAssignerNode(BaseNode): _node_type = NodeType.VARIABLE_ASSIGNER _conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY + node_data: VariableAssignerData + + def from_dict(self, data: Mapping[str, Any]) -> None: + self.node_data = VariableAssignerData(**data) + def __init__( self, id: str, diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py index 9292da6f1c..daa962904b 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -54,10 +54,14 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_ mapping[key] = selector -class VariableAssignerNode(BaseNode[VariableAssignerNodeData]): - _node_data_cls = VariableAssignerNodeData +class VariableAssignerNode(BaseNode): _node_type = NodeType.VARIABLE_ASSIGNER + node_data: VariableAssignerNodeData + + def from_dict(self, data: Mapping[str, Any]) -> None: + self.node_data = VariableAssignerNodeData(**data) + def _conv_var_updater_factory(self) -> ConversationVariableUpdater: return conversation_variable_updater_factory()