diff --git a/api/models/workflow.py b/api/models/workflow.py index a0f3a9990a..28b661670e 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -7,8 +7,11 @@ from typing import TYPE_CHECKING, Any, Optional, Union from uuid import uuid4 from flask_login import current_user +from sqlalchemy import orm +from core.file.models import File from core.variables import utils as variable_utils +from core.variables.segments import ArrayFileSegment, FileSegment from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from factories.variable_factory import build_segment @@ -874,6 +877,8 @@ class WorkflowDraftVariable(Base): __tablename__ = "workflow_draft_variables" __table_args__ = (UniqueConstraint(*unique_app_id_node_id_name()),) + # Required for instance variable annotation. + __allow_unmapped__ = True # id is the unique identifier of a draft variable. id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) @@ -939,6 +944,36 @@ class WorkflowDraftVariable(Base): visible: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True) editable: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False) + # Cache for deserialized value + # + # NOTE(QuantumGhost): This field serves two purposes: + # + # 1. Caches deserialized values to reduce repeated parsing costs + # 2. Allows modification of the deserialized value after retrieval, + # particularly important for `File` variables which require database + # lookups to obtain storage_key and other metadata + # + # Use double underscore prefix for better encapsulation, + # making this attribute harder to access from outside the class. + __value: Segment | None + + def __init__(self, *args, **kwargs): + """ + The constructor of `WorkflowDraftVariable` is not intended for + direct use outside this file. Its sole purpose is to set up private state + used by the model instance. 
+ + Please use the factory methods + (`new_conversation_variable`, `new_sys_variable`, `new_node_variable`) + defined below to create instances of this class. + """ + super().__init__(*args, **kwargs) + self.__value = None + + @orm.reconstructor + def _init_on_load(self): + self.__value = None + def get_selector(self) -> list[str]: selector = json.loads(self.selector) if not isinstance(selector, list): @@ -953,8 +988,39 @@ class WorkflowDraftVariable(Base): def _set_selector(self, value: list[str]): self.selector = json.dumps(value) + def _loads_value(self) -> Segment: + value = json.loads(self.value) + value_type = self.value_type + if value_type == SegmentType.FILE: + file = File.model_validate(value) + return FileSegment(value=file) + elif value_type == SegmentType.ARRAY_FILE: + files = [File.model_validate(i) for i in value] + return ArrayFileSegment(value=files) + else: + return build_segment(value) + def get_value(self) -> Segment: - return build_segment(json.loads(self.value)) + """Decode the serialized value into its corresponding `Segment` object. + + This method caches the result, so repeated calls will return the same + object instance without re-parsing the serialized data. + + If you need to modify the returned `Segment`, use `value.model_copy()` + to create a copy first to avoid affecting the cached instance. + + For more information about the caching mechanism, see the documentation + of the `__value` field. + + Returns: + Segment: The deserialized value as a Segment object. + """ + + if self.__value is not None: + return self.__value + value = self._loads_value() + self.__value = value + return value def set_name(self, name: str): self.name = name