refactor(api): Inject conv_var_updater_factory into v1.VariableAssignerNode

pull/20699/head
QuantumGhost 12 months ago
parent 83cd796b4d
commit 46a2476185

@ -1,5 +1,5 @@
from collections.abc import Callable, Mapping, Sequence
from typing import Any, ClassVar, TypeAlias
from typing import TYPE_CHECKING, Any, Optional, TypeAlias
from core.variables import SegmentType, Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
@ -10,18 +10,44 @@ from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
from factories import variable_factory
from ..common.impl import conversation_variable_updater_factory
from .node_data import VariableAssignerData, WriteMode
if TYPE_CHECKING:
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
class VariableAssignerNode(BaseNode[VariableAssignerData]):
_node_data_cls = VariableAssignerData
_node_type = NodeType.VARIABLE_ASSIGNER
_conv_var_updater_factory: ClassVar[_CONV_VAR_UPDATER_FACTORY] = staticmethod(conversation_variable_updater_factory)
_conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph: "Graph",
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
previous_node_id=previous_node_id,
thread_pool_id=thread_pool_id,
)
self._conv_var_updater_factory = conv_var_updater_factory
@classmethod
def version(cls) -> str:

Loading…
Cancel
Save