feat(api): utilize `ConversationVariableUpdater` in variable assigner nodes

Removed the original logic
pull/20699/head
QuantumGhost 12 months ago
parent 1b234de81f
commit d720287504

@ -0,0 +1,39 @@
import abc
from typing import Protocol
from core.variables import Variable
class ConversationVariableUpdater(Protocol):
"""
ConversationVariableUpdater defines an abstraction for updating conversation variable values.
It is intended for use by `v1.VariableAssignerNode` and `v2.VariableAssignerNode` when updating
conversation variables.
Implementations may choose to batch updates. If batching is used, the `flush` method
should be implemented to persist buffered changes, and `update_conversation_variable`
should handle buffering accordingly.
Note: Since implementations may buffer updates, instances of ConversationVariableUpdater
are not thread-safe. Each VariableAssignerNode should create its own instance during execution.
"""
@abc.abstractmethod
def update(self, conversation_id: str, variable: "Variable") -> None:
"""
Updates the value of the specified conversation variable in the underlying storage.
:param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`.
:param variable: The `Variable` instance containing the updated value.
"""
pass
@abc.abstractmethod
def flush(self):
"""
Flushes all pending updates to the underlying storage system.
If the implementation does not buffer updates, this method can be a no-op.
"""
pass

@ -1,26 +1,8 @@
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any, TypedDict from typing import Any, TypedDict
from sqlalchemy import select from core.variables import Segment, SegmentType
from sqlalchemy.orm import Session
from core.variables import Segment, SegmentType, Variable
from core.variables.consts import MIN_SELECTORS_LENGTH from core.variables.consts import MIN_SELECTORS_LENGTH
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from extensions.ext_database import db
from models import ConversationVariable
def update_conversation_variable(conversation_id: str, variable: Variable):
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with Session(db.engine) as session:
row = session.scalar(stmt)
if not row:
raise VariableOperatorNodeError("conversation variable not found in the database")
row.data = variable.model_dump_json()
session.commit()
class VariableOutput(TypedDict): class VariableOutput(TypedDict):

@ -0,0 +1,38 @@
from sqlalchemy import Engine, select
from sqlalchemy.orm import Session
from core.variables.variables import Variable
from models.engine import db
from models.workflow import ConversationVariable
from .exc import VariableOperatorNodeError
class ConversationVariableUpdaterImpl:
_engine: Engine | None
def __init__(self, engine: Engine | None = None) -> None:
self._engine = engine
def _get_engine(self) -> Engine:
if self._engine:
return self._engine
return db.engine
def update(self, conversation_id: str, variable: Variable):
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with Session(self._get_engine()) as session:
row = session.scalar(stmt)
if not row:
raise VariableOperatorNodeError("conversation variable not found in the database")
row.data = variable.model_dump_json()
session.commit()
def flush(self):
pass
def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl:
return ConversationVariableUpdaterImpl()

@ -1,18 +1,27 @@
from collections.abc import Callable, Mapping, Sequence
from typing import Any, ClassVar, TypeAlias
from core.variables import SegmentType, Variable from core.variables import SegmentType, Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
from factories import variable_factory from factories import variable_factory
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
from .node_data import VariableAssignerData, WriteMode from .node_data import VariableAssignerData, WriteMode
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
class VariableAssignerNode(BaseNode[VariableAssignerData]): class VariableAssignerNode(BaseNode[VariableAssignerData]):
_node_data_cls = VariableAssignerData _node_data_cls = VariableAssignerData
_node_type = NodeType.VARIABLE_ASSIGNER _node_type = NodeType.VARIABLE_ASSIGNER
_conv_var_updater_factory: ClassVar[_CONV_VAR_UPDATER_FACTORY] = staticmethod(conversation_variable_updater_factory)
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
@ -73,7 +82,9 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"]) conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
if not conversation_id: if not conversation_id:
raise VariableOperatorNodeError("conversation_id not found") raise VariableOperatorNodeError("conversation_id not found")
common_helpers.update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable) conv_var_updater = self._conv_var_updater_factory()
conv_var_updater.update(conversation_id=conversation_id.text, variable=updated_variable)
conv_var_updater.flush()
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,

@ -1,15 +1,17 @@
import json import json
from collections.abc import Sequence from collections.abc import Callable, Mapping, Sequence
from typing import Any, cast from typing import Any, ClassVar, TypeAlias, cast
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import SegmentType, Variable from core.variables import SegmentType, Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
from . import helpers from . import helpers
@ -24,11 +26,15 @@ from .exc import (
VariableNotFoundError, VariableNotFoundError,
) )
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
class VariableAssignerNode(BaseNode[VariableAssignerNodeData]): class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
_node_data_cls = VariableAssignerNodeData _node_data_cls = VariableAssignerNodeData
_node_type = NodeType.VARIABLE_ASSIGNER _node_type = NodeType.VARIABLE_ASSIGNER
_conv_var_updater_factory: ClassVar[_CONV_VAR_UPDATER_FACTORY] = staticmethod(conversation_variable_updater_factory)
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "2" return "2"
@ -136,6 +142,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
# remove the duplicated items first. # remove the duplicated items first.
updated_variable_selectors = list(set(map(tuple, updated_variable_selectors))) updated_variable_selectors = list(set(map(tuple, updated_variable_selectors)))
conv_var_updater = self._conv_var_updater_factory()
# Update variables # Update variables
for selector in updated_variable_selectors: for selector in updated_variable_selectors:
variable = self.graph_runtime_state.variable_pool.get(selector) variable = self.graph_runtime_state.variable_pool.get(selector)
@ -150,10 +157,11 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
raise ConversationIDNotFoundError raise ConversationIDNotFoundError
else: else:
conversation_id = conversation_id.value conversation_id = conversation_id.value
common_helpers.update_conversation_variable( conv_var_updater.update(
conversation_id=cast(str, conversation_id), conversation_id=cast(str, conversation_id),
variable=variable, variable=variable,
) )
conv_var_updater.flush()
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,

Loading…
Cancel
Save