feat(api): utilize `ConversationVariableUpdater` in variable assigner nodes

Removed the original logic
pull/20699/head
QuantumGhost 12 months ago
parent 1b234de81f
commit d720287504

@ -0,0 +1,39 @@
import abc
from typing import Protocol
from core.variables import Variable
class ConversationVariableUpdater(Protocol):
"""
ConversationVariableUpdater defines an abstraction for updating conversation variable values.
It is intended for use by `v1.VariableAssignerNode` and `v2.VariableAssignerNode` when updating
conversation variables.
Implementations may choose to batch updates. If batching is used, the `flush` method
should be implemented to persist buffered changes, and `update_conversation_variable`
should handle buffering accordingly.
Note: Since implementations may buffer updates, instances of ConversationVariableUpdater
are not thread-safe. Each VariableAssignerNode should create its own instance during execution.
"""
@abc.abstractmethod
def update(self, conversation_id: str, variable: "Variable") -> None:
"""
Updates the value of the specified conversation variable in the underlying storage.
:param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`.
:param variable: The `Variable` instance containing the updated value.
"""
pass
@abc.abstractmethod
def flush(self):
"""
Flushes all pending updates to the underlying storage system.
If the implementation does not buffer updates, this method can be a no-op.
"""
pass

@ -1,26 +1,8 @@
from collections.abc import Sequence
from typing import Any, TypedDict
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.variables import Segment, SegmentType, Variable
from core.variables import Segment, SegmentType
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from extensions.ext_database import db
from models import ConversationVariable
def update_conversation_variable(conversation_id: str, variable: Variable):
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with Session(db.engine) as session:
row = session.scalar(stmt)
if not row:
raise VariableOperatorNodeError("conversation variable not found in the database")
row.data = variable.model_dump_json()
session.commit()
class VariableOutput(TypedDict):

@ -0,0 +1,38 @@
from sqlalchemy import Engine, select
from sqlalchemy.orm import Session
from core.variables.variables import Variable
from models.engine import db
from models.workflow import ConversationVariable
from .exc import VariableOperatorNodeError
class ConversationVariableUpdaterImpl:
_engine: Engine | None
def __init__(self, engine: Engine | None = None) -> None:
self._engine = engine
def _get_engine(self) -> Engine:
if self._engine:
return self._engine
return db.engine
def update(self, conversation_id: str, variable: Variable):
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)
with Session(self._get_engine()) as session:
row = session.scalar(stmt)
if not row:
raise VariableOperatorNodeError("conversation variable not found in the database")
row.data = variable.model_dump_json()
session.commit()
def flush(self):
pass
def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl:
return ConversationVariableUpdaterImpl()

@ -1,18 +1,27 @@
from collections.abc import Callable, Mapping, Sequence
from typing import Any, ClassVar, TypeAlias
from core.variables import SegmentType, Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
from factories import variable_factory
from models.workflow import WorkflowNodeExecutionStatus
from .node_data import VariableAssignerData, WriteMode
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
class VariableAssignerNode(BaseNode[VariableAssignerData]):
_node_data_cls = VariableAssignerData
_node_type = NodeType.VARIABLE_ASSIGNER
_conv_var_updater_factory: ClassVar[_CONV_VAR_UPDATER_FACTORY] = staticmethod(conversation_variable_updater_factory)
@classmethod
def version(cls) -> str:
@ -73,7 +82,9 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
if not conversation_id:
raise VariableOperatorNodeError("conversation_id not found")
common_helpers.update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
conv_var_updater = self._conv_var_updater_factory()
conv_var_updater.update(conversation_id=conversation_id.text, variable=updated_variable)
conv_var_updater.flush()
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,

@ -1,15 +1,17 @@
import json
from collections.abc import Sequence
from typing import Any, cast
from collections.abc import Callable, Mapping, Sequence
from typing import Any, ClassVar, TypeAlias, cast
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import SegmentType, Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
from models.workflow import WorkflowNodeExecutionStatus
from . import helpers
@ -24,11 +26,15 @@ from .exc import (
VariableNotFoundError,
)
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
_node_data_cls = VariableAssignerNodeData
_node_type = NodeType.VARIABLE_ASSIGNER
_conv_var_updater_factory: ClassVar[_CONV_VAR_UPDATER_FACTORY] = staticmethod(conversation_variable_updater_factory)
@classmethod
def version(cls) -> str:
return "2"
@ -136,6 +142,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
# remove the duplicated items first.
updated_variable_selectors = list(set(map(tuple, updated_variable_selectors)))
conv_var_updater = self._conv_var_updater_factory()
# Update variables
for selector in updated_variable_selectors:
variable = self.graph_runtime_state.variable_pool.get(selector)
@ -150,10 +157,11 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
raise ConversationIDNotFoundError
else:
conversation_id = conversation_id.value
common_helpers.update_conversation_variable(
conv_var_updater.update(
conversation_id=cast(str, conversation_id),
variable=variable,
)
conv_var_updater.flush()
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,

Loading…
Cancel
Save