From d72028750425ecf20b3d22374187f977f3de0e1b Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Thu, 29 May 2025 02:28:23 +0800 Subject: [PATCH] feat(api): utilize `ConversationVariableUpdater` in variable assigner nodes Removed the original logic --- .../workflow/conversation_variable_updater.py | 39 +++++++++++++++++++ .../nodes/variable_assigner/common/helpers.py | 20 +--------- .../nodes/variable_assigner/common/impl.py | 38 ++++++++++++++++++ .../nodes/variable_assigner/v1/node.py | 13 ++++++- .../nodes/variable_assigner/v2/node.py | 14 +++++-- 5 files changed, 101 insertions(+), 23 deletions(-) create mode 100644 api/core/workflow/conversation_variable_updater.py create mode 100644 api/core/workflow/nodes/variable_assigner/common/impl.py diff --git a/api/core/workflow/conversation_variable_updater.py b/api/core/workflow/conversation_variable_updater.py new file mode 100644 index 0000000000..808b65627d --- /dev/null +++ b/api/core/workflow/conversation_variable_updater.py @@ -0,0 +1,39 @@ +import abc +from typing import Protocol + +from core.variables import Variable + + +class ConversationVariableUpdater(Protocol): + """ + ConversationVariableUpdater defines an abstraction for updating conversation variable values. + + It is intended for use by `v1.VariableAssignerNode` and `v2.VariableAssignerNode` when updating + conversation variables. + + Implementations may choose to batch updates. If batching is used, the `flush` method + should be implemented to persist buffered changes, and `update_conversation_variable` + should handle buffering accordingly. + + Note: Since implementations may buffer updates, instances of ConversationVariableUpdater + are not thread-safe. Each VariableAssignerNode should create its own instance during execution. + """ + + @abc.abstractmethod + def update(self, conversation_id: str, variable: "Variable") -> None: + """ + Updates the value of the specified conversation variable in the underlying storage. + + :param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`. + :param variable: The `Variable` instance containing the updated value. + """ + pass + + @abc.abstractmethod + def flush(self): + """ + Flushes all pending updates to the underlying storage system. + + If the implementation does not buffer updates, this method can be a no-op. + """ + pass diff --git a/api/core/workflow/nodes/variable_assigner/common/helpers.py b/api/core/workflow/nodes/variable_assigner/common/helpers.py index 2455b69025..f1b66f63ff 100644 --- a/api/core/workflow/nodes/variable_assigner/common/helpers.py +++ b/api/core/workflow/nodes/variable_assigner/common/helpers.py @@ -1,26 +1,8 @@ from collections.abc import Sequence from typing import Any, TypedDict -from sqlalchemy import select -from sqlalchemy.orm import Session - -from core.variables import Segment, SegmentType, Variable +from core.variables import Segment, SegmentType from core.variables.consts import MIN_SELECTORS_LENGTH -from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from extensions.ext_database import db -from models import ConversationVariable - - -def update_conversation_variable(conversation_id: str, variable: Variable): - stmt = select(ConversationVariable).where( - ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id - ) - with Session(db.engine) as session: - row = session.scalar(stmt) - if not row: - raise VariableOperatorNodeError("conversation variable not found in the database") - row.data = variable.model_dump_json() - session.commit() class VariableOutput(TypedDict): diff --git a/api/core/workflow/nodes/variable_assigner/common/impl.py b/api/core/workflow/nodes/variable_assigner/common/impl.py new file mode 100644 index 0000000000..8f7a44bb62 --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/common/impl.py @@ -0,0 +1,38 @@ +from sqlalchemy import Engine, select +from sqlalchemy.orm import Session + +from core.variables.variables import Variable +from models.engine import db +from models.workflow import ConversationVariable + +from .exc import VariableOperatorNodeError + + +class ConversationVariableUpdaterImpl: + _engine: Engine | None + + def __init__(self, engine: Engine | None = None) -> None: + self._engine = engine + + def _get_engine(self) -> Engine: + if self._engine: + return self._engine + return db.engine + + def update(self, conversation_id: str, variable: Variable): + stmt = select(ConversationVariable).where( + ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id + ) + with Session(self._get_engine()) as session: + row = session.scalar(stmt) + if not row: + raise VariableOperatorNodeError("conversation variable not found in the database") + row.data = variable.model_dump_json() + session.commit() + + def flush(self): + pass + + +def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl: + return ConversationVariableUpdaterImpl() diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index b6a6a5f319..005d506a13 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -1,18 +1,27 @@ +from collections.abc import Callable, Mapping, Sequence +from typing import Any, ClassVar, TypeAlias + from core.variables import SegmentType, Variable +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID +from core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.entities.node_entities import NodeRunResult from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError +from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory from factories import variable_factory from models.workflow import WorkflowNodeExecutionStatus from .node_data import VariableAssignerData, WriteMode +_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater] + class VariableAssignerNode(BaseNode[VariableAssignerData]): _node_data_cls = VariableAssignerData _node_type = NodeType.VARIABLE_ASSIGNER + _conv_var_updater_factory: ClassVar[_CONV_VAR_UPDATER_FACTORY] = staticmethod(conversation_variable_updater_factory) @classmethod def version(cls) -> str: @@ -73,7 +82,9 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]): conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"]) if not conversation_id: raise VariableOperatorNodeError("conversation_id not found") - common_helpers.update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable) + conv_var_updater = self._conv_var_updater_factory() + conv_var_updater.update(conversation_id=conversation_id.text, variable=updated_variable) + conv_var_updater.flush() return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py index d89e334206..096889a23c 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -1,15 +1,17 @@ import json -from collections.abc import Sequence -from typing import Any, cast +from collections.abc import Callable, Mapping, Sequence +from typing import Any, ClassVar, TypeAlias, cast from core.app.entities.app_invoke_entities import InvokeFrom from core.variables import SegmentType, Variable from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID +from core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.entities.node_entities import NodeRunResult from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError +from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory from models.workflow import WorkflowNodeExecutionStatus from . import helpers @@ -24,11 +26,15 @@ from .exc import ( VariableNotFoundError, ) +_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater] + class VariableAssignerNode(BaseNode[VariableAssignerNodeData]): _node_data_cls = VariableAssignerNodeData _node_type = NodeType.VARIABLE_ASSIGNER + _conv_var_updater_factory: ClassVar[_CONV_VAR_UPDATER_FACTORY] = staticmethod(conversation_variable_updater_factory) + @classmethod def version(cls) -> str: return "2" @@ -136,6 +142,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]): # remove the duplicated items first. updated_variable_selectors = list(set(map(tuple, updated_variable_selectors))) + conv_var_updater = self._conv_var_updater_factory() # Update variables for selector in updated_variable_selectors: variable = self.graph_runtime_state.variable_pool.get(selector) @@ -150,10 +157,11 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]): raise ConversationIDNotFoundError else: conversation_id = conversation_id.value - common_helpers.update_conversation_variable( + conv_var_updater.update( conversation_id=cast(str, conversation_id), variable=variable, ) + conv_var_updater.flush() return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED,