From 46a2476185cbef89981d669d135ee7e94f499b4b Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Thu, 5 Jun 2025 16:00:46 +0800 Subject: [PATCH] refactor(api): Inject conv_var_updater_factory into v1.VariableAssignerNode --- .../nodes/variable_assigner/v1/node.py | 32 +++++++++++++++++-- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index 2b2e9ff567..438c4a7d06 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -1,5 +1,5 @@ from collections.abc import Callable, Mapping, Sequence -from typing import Any, ClassVar, TypeAlias +from typing import TYPE_CHECKING, Any, Optional, TypeAlias from core.variables import SegmentType, Variable from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID @@ -10,18 +10,44 @@ from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory from factories import variable_factory +from ..common.impl import conversation_variable_updater_factory from .node_data import VariableAssignerData, WriteMode +if TYPE_CHECKING: + from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState + + _CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater] class VariableAssignerNode(BaseNode[VariableAssignerData]): _node_data_cls = VariableAssignerData _node_type = NodeType.VARIABLE_ASSIGNER - _conv_var_updater_factory: ClassVar[_CONV_VAR_UPDATER_FACTORY] = staticmethod(conversation_variable_updater_factory) + _conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY + + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph: "Graph", + graph_runtime_state: "GraphRuntimeState", + previous_node_id: Optional[str] = None, + thread_pool_id: Optional[str] = None, + conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory, + ) -> None: + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph=graph, + graph_runtime_state=graph_runtime_state, + previous_node_id=previous_node_id, + thread_pool_id=thread_pool_id, + ) + self._conv_var_updater_factory = conv_var_updater_factory @classmethod def version(cls) -> str: