diff --git a/api/core/file/constants.py b/api/core/file/constants.py index 81ad59f4c0..ce1d238e93 100644 --- a/api/core/file/constants.py +++ b/api/core/file/constants.py @@ -1,21 +1 @@ -from typing import Any - FILE_MODEL_IDENTITY = "__dify__file__" - -# DUMMY_OUTPUT_IDENTITY is a placeholder output for workflow nodes. -# Its sole possible value is `None`. -# -# This is used to signal the execution of a workflow node when it has no other outputs. -_DUMMY_OUTPUT_IDENTITY = "__dummy__" -_DUMMY_OUTPUT_VALUE: None = None - - -def add_dummy_output(original: dict[str, Any] | None) -> dict[str, Any]: - if original is None: - original = {} - original[_DUMMY_OUTPUT_IDENTITY] = _DUMMY_OUTPUT_VALUE - return original - - -def is_dummy_output_variable(name: str) -> bool: - return name == _DUMMY_OUTPUT_IDENTITY diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index fa2592842e..1f77be4eb8 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,4 +1,3 @@ -from core.file.constants import add_dummy_output from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.entities.node_entities import NodeRunResult from core.workflow.nodes.base import BaseNode @@ -24,8 +23,5 @@ class StartNode(BaseNode[StartNodeData]): for var in system_inputs: node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var] outputs = dict(node_inputs) - # Need special handling for `Start` node, as all other output variables - # are treated as systemd variables. 
- add_dummy_output(outputs) return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 6d38e5a201..7133fd0c77 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -8,7 +8,6 @@ from configs import dify_config from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.file.models import File -from core.variables import Variable from core.workflow.callbacks import WorkflowCallback from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID from core.workflow.entities.node_entities import NodeRunMetadataKey @@ -124,8 +123,8 @@ class WorkflowEntry: node_id: str, user_id: str, user_inputs: dict, - conversation_variables: dict | None = None, - variable_loader: VariableLoader = _DUMMY_VARIABLE_LOADER, + variable_pool: VariablePool, + variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, ) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]: """ Single step run workflow node @@ -135,35 +134,22 @@ class WorkflowEntry: :param user_inputs: user inputs :return: """ - # fetch node info from workflow graph - workflow_graph = workflow.graph_dict - if not workflow_graph: - raise ValueError("workflow graph not found") - - nodes = workflow_graph.get("nodes") - if not nodes: - raise ValueError("nodes not found in workflow graph") - - # fetch node config from node id - try: - node_config = next(filter(lambda node: node["id"] == node_id, nodes)) - except StopIteration: - raise ValueError("node id not found in workflow graph") - + node_config = workflow.get_node_config_by_id(node_id) node_config_data = node_config.get("data", {}) # Get node class node_type = NodeType(node_config_data.get("type")) node_version = node_config_data.get("version", "1") + if node_type == NodeType.START: + # special handling for
start node. + # + # 1. create conversation variables and system variables + # 2. create environment variables + pass + node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] metadata_attacher = _attach_execution_metadata_based_on_node_config(node_config_data) - # init variable pool - variable_pool = VariablePool( - environment_variables=workflow.environment_variables, - conversation_variable=conversation_variables or {}, - ) - # init graph graph = Graph.init(graph_config=workflow.graph_dict) @@ -196,16 +182,12 @@ class WorkflowEntry: # Loading missing variable from draft var here, and set it into # variable_pool. - variables_to_load: list[list[str]] = [] - for key, selector in variable_mapping.items(): - trimmed_key = key.removeprefix(f"{node_id}.") - if trimmed_key in user_inputs: - continue - if variable_pool.get(selector) is None: - variables_to_load.append(list(selector)) - loaded = variable_loader.load_variables(variables_to_load) - for var in loaded: - variable_pool.add(var.selector, var.value) + load_into_variable_pool( + variable_loader=variable_loader, + variable_pool=variable_pool, + variable_mapping=variable_mapping, + user_inputs=user_inputs, + ) cls.mapping_user_inputs_to_variable_pool( variable_mapping=variable_mapping, @@ -213,7 +195,6 @@ class WorkflowEntry: variable_pool=variable_pool, tenant_id=workflow.tenant_id, ) - cls._load_persisted_draft_var_and_populate_pool(app_id=workflow.app_id, variable_pool=variable_pool) try: # run node @@ -348,16 +329,6 @@ class WorkflowEntry: return value.to_dict() return value - @classmethod - def _load_persisted_draft_var_and_populate_pool(cls, app_id: str, variable_pool: VariablePool) -> None: - """ - Load persisted draft variables and populate the variable pool. - :param app_id: The application ID. - :param variable_pool: The variable pool to populate. 
- """ - # TODO(QuantumGhost): - pass - @classmethod def mapping_user_inputs_to_variable_pool( cls, @@ -367,6 +338,13 @@ class WorkflowEntry: variable_pool: VariablePool, tenant_id: str, ) -> None: + # NOTE(QuantumGhost): This logic should remain synchronized with + # the implementation of `load_into_variable_pool`, specifically the logic about + # variable existence checking. + + # WARNING(QuantumGhost): The semantics of this method are not clearly defined, + # and multiple parts of the codebase depend on its current behavior. + # Modify with caution. for node_variable, variable_selector in variable_mapping.items(): # fetch node id and variable key from node_variable node_variable_list = node_variable.split(".") diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 504d693742..b0dedd78e2 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -1,7 +1,7 @@ import dataclasses import logging from collections.abc import Mapping, Sequence -from typing import Any +from typing import Any, ClassVar from sqlalchemy import Engine, orm from sqlalchemy.dialects.postgresql import insert @@ -9,14 +9,16 @@ from sqlalchemy.orm import Session from sqlalchemy.sql.expression import and_, or_ from core.app.entities.app_invoke_entities import InvokeFrom -from core.variables import Segment, Variable +from core.variables import Segment, StringSegment, Variable from core.variables.consts import MIN_SELECTORS_LENGTH from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from core.workflow.enums import SystemVariableKey from core.workflow.nodes import NodeType from core.workflow.variable_loader import VariableLoader from factories import variable_factory from factories.variable_factory import build_segment, segment_to_variable -from models.workflow import WorkflowDraftVariable, 
is_system_variable_editable +from models import App, Conversation +from models.workflow import Workflow, WorkflowDraftVariable, is_system_variable_editable _logger = logging.getLogger(__name__) @@ -88,54 +90,6 @@ class WorkflowDraftVariableService: ) return variables - def save_output_variables(self, app_id: str, node_id: str, node_type: NodeType, output: Mapping[str, Any]): - variable_builder = _DraftVariableBuilder(app_id=app_id) - variable_builder.build(node_id=node_id, node_type=node_type, output=output) - draft_variables = variable_builder.get_variables() - # draft_variables = _build_variables_from_output_mapping(app_id, node_id, node_type, output) - if not draft_variables: - return - - # Although we could use SQLAlchemy ORM operations here, we choose not to for several reasons: - # - # 1. The variable saving process involves writing multiple rows to the - # `workflow_draft_variables` table. Batch insertion significantly improves performance. - # 2. Using the ORM would require either: - # - # a. Checking for the existence of each variable before insertion, - # resulting in 2n SQL statements for n variables and potential concurrency issues. - # b. Attempting insertion first, then updating if a unique index violation occurs, - # which still results in n to 2n SQL statements. - # - # Both approaches are inefficient and suboptimal. - # 3. We do not need to retrieve the results of the SQL execution or populate ORM - # model instances with the returned values. - # 4. Batch insertion with `ON CONFLICT DO UPDATE` allows us to insert or update all - # variables in a single SQL statement, avoiding the issues above. - # - # For these reasons, we use the SQLAlchemy query builder and rely on dialect-specific - # insert operations instead of the ORM layer. - if node_type == NodeType.CODE: - # Clear existing variable for code node. 
- self._session.query(WorkflowDraftVariable).filter( - WorkflowDraftVariable.app_id == app_id, - WorkflowDraftVariable.node_id == node_id, - ).delete(synchronize_session=False) - stmt = insert(WorkflowDraftVariable).values([_model_to_insertion_dict(v) for v in draft_variables]) - stmt = stmt.on_conflict_do_update( - index_elements=WorkflowDraftVariable.unique_app_id_node_id_name(), - set_={ - "updated_at": stmt.excluded.updated_at, - "last_edited_at": stmt.excluded.last_edited_at, - "description": stmt.excluded.description, - "value_type": stmt.excluded.value_type, - "value": stmt.excluded.value, - "visible": stmt.excluded.visible, - "editable": stmt.excluded.editable, - }, - ) - self._session.execute(stmt) - def list_variables_without_values(self, app_id: str, page: int, limit: int) -> WorkflowDraftVariableList: criteria = WorkflowDraftVariable.app_id == app_id total = None @@ -224,6 +178,116 @@ class WorkflowDraftVariableService: WorkflowDraftVariable.node_id == node_id, ).delete() + def _get_conversation_id_from_draft_variable(self, app_id: str) -> str | None: + draft_var = self._get_variable( + app_id=app_id, + node_id=SYSTEM_VARIABLE_NODE_ID, + name=str(SystemVariableKey.CONVERSATION_ID), + ) + if draft_var is None: + return None + segment = draft_var.get_value() + if not isinstance(segment, StringSegment): + _logger.warning( + "sys.conversation_id variable is not a string: app_id=%s, id=%s", + app_id, + draft_var.id, + ) + return None + return segment.value + + def create_conversation_and_set_conversation_variables( + self, + account_id: str, + app: App, + workflow: Workflow, + ) -> str: + conv_id = self._get_conversation_id_from_draft_variable(workflow.app_id) + + if conv_id is not None: + conversation = ( + self._session.query(Conversation) + .filter( + Conversation.id == conv_id, + Conversation.app_id == workflow.app_id, + ) + .first() + ) + # Only return the conversation ID if it exists and is valid (has a corresponding conversation record in DB).
+ if conversation is not None: + return conv_id + conversation = Conversation( + app_id=workflow.app_id, + app_model_config_id=app.app_model_config_id, + model_provider=None, + model_id="", + override_model_configs=None, + mode=app.mode, + name="Draft Debugging Conversation", + inputs={}, + introduction="", + system_instruction="", + system_instruction_tokens=0, + status="normal", + invoke_from=InvokeFrom.DEBUGGER.value, + from_source="console", + from_end_user_id=None, + from_account_id=account_id, + ) + + self._session.add(conversation) + self._session.flush() + draft_conv_vars: list[WorkflowDraftVariable] = [] + for conv_var in workflow.conversation_variables: + draft_var = WorkflowDraftVariable.new_conversation_variable( + app_id=workflow.app_id, + name=conv_var.name, + value=conv_var, + description=conv_var.description, + ) + draft_conv_vars.append(draft_var) + + _batch_upsert_draft_varaible(self._session, draft_conv_vars) + return conversation.id + + +def _batch_upsert_draft_varaible(session: Session, draft_vars: Sequence[WorkflowDraftVariable]): + if not draft_vars: + return + # Although we could use SQLAlchemy ORM operations here, we choose not to for several reasons: + # + # 1. The variable saving process involves writing multiple rows to the + # `workflow_draft_variables` table. Batch insertion significantly improves performance. + # 2. Using the ORM would require either: + # + # a. Checking for the existence of each variable before insertion, + # resulting in 2n SQL statements for n variables and potential concurrency issues. + # b. Attempting insertion first, then updating if a unique index violation occurs, + # which still results in n to 2n SQL statements. + # + # Both approaches are inefficient and suboptimal. + # 3. We do not need to retrieve the results of the SQL execution or populate ORM + # model instances with the returned values. + # 4. 
Batch insertion with `ON CONFLICT DO UPDATE` allows us to insert or update all + # variables in a single SQL statement, avoiding the issues above. + # + # For these reasons, we use the SQLAlchemy query builder and rely on dialect-specific + # insert operations instead of the ORM layer. + stmt = insert(WorkflowDraftVariable).values([_model_to_insertion_dict(v) for v in draft_vars]) + stmt = stmt.on_conflict_do_update( + index_elements=WorkflowDraftVariable.unique_app_id_node_id_name(), + set_={ + "updated_at": stmt.excluded.updated_at, + "last_edited_at": stmt.excluded.last_edited_at, + "description": stmt.excluded.description, + "value_type": stmt.excluded.value_type, + "value": stmt.excluded.value, + "visible": stmt.excluded.visible, + "editable": stmt.excluded.editable, + }, + ) + session.execute(stmt) + def _model_to_insertion_dict(model: WorkflowDraftVariable) -> dict[str, Any]: d: dict[str, Any] = { @@ -248,30 +312,69 @@ def _model_to_insertion_dict(model: WorkflowDraftVariable) -> dict[str, Any]: return d -def should_save_output_variables_for_draft( - invoke_from: InvokeFrom, loop_id: str | None, iteration_id: str | None -) -> bool: - # Only save output variables for debugging execution of workflow. - if invoke_from != InvokeFrom.DEBUGGER: - return False - - # Currently we do not save output variables for nodes inside loop or iteration. - if loop_id is not None: - return False - if iteration_id is not None: - return False - return True +class DraftVariableSaver: + # _DUMMY_OUTPUT_IDENTITY is a placeholder output for workflow nodes. + # Its sole possible value is `None`. + # + # This is used to signal the execution of a workflow node when it has no other outputs. + _DUMMY_OUTPUT_IDENTITY: ClassVar[str] = "__dummy__" + _DUMMY_OUTPUT_VALUE: ClassVar[None] = None + # Database session used for persisting draft variables. + _session: Session -class _DraftVariableBuilder: + # The application ID associated with the draft variables. 
+ # This should match the `Workflow.app_id` of the workflow to which the current node belongs. _app_id: str + + # The ID of the node for which DraftVariableSaver is saving output variables. + _node_id: str + + # The type of the current node (see NodeType). + _node_type: NodeType + + # Indicates how the workflow execution was triggered (see InvokeFrom). + _invoke_from: InvokeFrom + + # _enclosing_node_id identifies the container node that the current node belongs to. + # For example, if the current node is an LLM node inside an Iteration node + # or Loop node, then `_enclosing_node_id` refers to the ID of + # the containing Iteration or Loop node. + # + # If the current node is not nested within another node, `_enclosing_node_id` is + # `None`. + _enclosing_node_id: str | None + + # pending variables to save. _draft_vars: list[WorkflowDraftVariable] - def __init__(self, app_id: str): + def __init__( + self, + session: Session, + app_id: str, + node_id: str, + node_type: NodeType, + invoke_from: InvokeFrom, + enclosing_node_id: str | None = None, + ): + self._session = session self._app_id = app_id - self._draft_vars: list[WorkflowDraftVariable] = [] + self._node_id = node_id + self._node_type = node_type + self._invoke_from = invoke_from + self._enclosing_node_id = enclosing_node_id + + def _should_save_output_variables_for_draft(self) -> bool: + # Only save output variables for debugging execution of workflow. + if self._invoke_from != InvokeFrom.DEBUGGER: + return False + if self._enclosing_node_id is not None and self._node_type != NodeType.VARIABLE_ASSIGNER: + # Currently we do not save output variables for nodes inside loop or iteration. 
+ return False + return True - def _build_from_variable_assigner_mapping(self, node_id: str, output: Mapping[str, Any]): + def _build_from_variable_assigner_mapping(self, output: Mapping[str, Any]) -> list[WorkflowDraftVariable]: + draft_vars: list[WorkflowDraftVariable] = [] updated_variables = output.get("updated_variables", []) for item in updated_variables: selector = item.get("selector") @@ -294,36 +397,37 @@ class _DraftVariableBuilder: var_seg = variable_factory.build_segment(new_value) if var_seg.value_type != value_type: raise Exception("value_type mismatch!") - self._draft_vars.append( + draft_vars.append( WorkflowDraftVariable.new_conversation_variable( app_id=self._app_id, name=name, value=var_seg, ) ) + return draft_vars - def _build_variables_from_start_mapping( - self, - node_id: str, - output: Mapping[str, Any], - ): - original_node_id = node_id + def _build_variables_from_start_mapping(self, output: Mapping[str, Any]) -> list[WorkflowDraftVariable]: + draft_vars = [] + has_non_sys_variables = False for name, value in output.items(): value_seg = variable_factory.build_segment(value) - node_id, name = self._normalize_variable_for_start_node(node_id, name) + node_id, name = self._normalize_variable_for_start_node(name) + # If node_id is not `sys`, it means that the variable is a user-defined input field + # in `Start` node. 
if node_id != SYSTEM_VARIABLE_NODE_ID: - self._draft_vars.append( + draft_vars.append( WorkflowDraftVariable.new_node_variable( app_id=self._app_id, - node_id=original_node_id, + node_id=self._node_id, name=name, value=value_seg, - visible=False, - editable=False, + visible=True, + editable=True, ) ) + has_non_sys_variables = True else: - self._draft_vars.append( + draft_vars.append( WorkflowDraftVariable.new_sys_variable( app_id=self._app_id, name=name, @@ -331,47 +435,57 @@ class _DraftVariableBuilder: editable=self._should_variable_be_editable(node_id, name), ) ) + if not has_non_sys_variables: + draft_vars.append( + WorkflowDraftVariable.new_node_variable( + app_id=self._app_id, + node_id=self._node_id, + name=self._DUMMY_OUTPUT_IDENTITY, + value=build_segment(self._DUMMY_OUTPUT_VALUE), + visible=False, + editable=False, + ) + ) + return draft_vars - @staticmethod - def _normalize_variable_for_start_node(node_id: str, name: str) -> tuple[str, str]: + def _normalize_variable_for_start_node(self, name: str) -> tuple[str, str]: if not name.startswith(f"{SYSTEM_VARIABLE_NODE_ID}."): - return node_id, name - node_id, name_ = name.split(".", maxsplit=1) - return node_id, name_ + return self._node_id, name + _, name_ = name.split(".", maxsplit=1) + return SYSTEM_VARIABLE_NODE_ID, name_ - def _build_variables_from_mapping( - self, - node_id: str, - node_type: NodeType, - output: Mapping[str, Any], - ): + def _build_variables_from_mapping(self, output: Mapping[str, Any]) -> list[WorkflowDraftVariable]: + draft_vars = [] for name, value in output.items(): value_seg = variable_factory.build_segment(value) - self._draft_vars.append( + draft_vars.append( WorkflowDraftVariable.new_node_variable( app_id=self._app_id, - node_id=node_id, + node_id=self._node_id, name=name, value=value_seg, - visible=self._should_variable_be_visible(node_type, node_id, name), + visible=self._should_variable_be_visible(self._node_id, self._node_type, name), ) ) + return draft_vars - def build( - 
self, - node_id: str, - node_type: NodeType, - output: Mapping[str, Any], - ): - if node_type == NodeType.VARIABLE_ASSIGNER: - self._build_from_variable_assigner_mapping(node_id, output) - elif node_type == NodeType.START: - self._build_variables_from_start_mapping(node_id, output) + def save(self, output: Mapping[str, Any] | None): + draft_vars: list[WorkflowDraftVariable] = [] + if output is None: + output = {} + if not self._should_save_output_variables_for_draft(): + return + if self._node_type == NodeType.VARIABLE_ASSIGNER: + draft_vars = self._build_from_variable_assigner_mapping(output) + elif self._node_type == NodeType.START: + draft_vars = self._build_variables_from_start_mapping(output) + elif self._node_type == NodeType.LOOP: + # Do not save output variables for loop node. + # (since the loop variables are inaccessible outside the loop node.) + return else: - self._build_variables_from_mapping(node_id, node_type, output) - - def get_variables(self) -> Sequence[WorkflowDraftVariable]: - return self._draft_vars + draft_vars = self._build_variables_from_mapping(output) + _batch_upsert_draft_varaible(self._session, draft_vars) @staticmethod def _should_variable_be_editable(node_id: str, name: str) -> bool: @@ -382,7 +496,7 @@ class _DraftVariableBuilder: return True @staticmethod - def _should_variable_be_visible(node_type: NodeType, node_id: str, name: str) -> bool: + def _should_variable_be_visible(node_id: str, node_type: NodeType, name: str) -> bool: if node_type in (NodeType.IF_ELSE, NodeType.START): return False if node_id == SYSTEM_VARIABLE_NODE_ID and not is_system_variable_editable(name): diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 23e19235b6..c902721bdb 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -1,7 +1,8 @@ import json import logging import time -from collections.abc import Callable, Generator, Sequence +import uuid +from collections.abc import Callable, 
Generator, Mapping, Sequence from datetime import UTC, datetime from typing import Any, Optional from uuid import uuid4 @@ -12,10 +13,13 @@ from sqlalchemy.orm import Session from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.entities.app_invoke_entities import InvokeFrom +from core.file import File from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.variables import Variable from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.entities.node_execution_entities import NodeExecution, NodeExecutionStatus +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.nodes import NodeType @@ -27,6 +31,7 @@ from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_M from core.workflow.workflow_entry import WorkflowEntry from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db +from factories.variable_factory import segment_to_variable from libs import gen_utils from models.account import Account from models.model import App, AppMode @@ -43,9 +48,9 @@ from services.workflow.workflow_converter import WorkflowConverter from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError from .workflow_draft_variable_service import ( + DraftVariableSaver, DraftVarLoader, WorkflowDraftVariableService, - should_save_output_variables_for_draft, ) @@ -108,6 +113,8 @@ class WorkflowService: ) .first() ) + if not workflow: + return None if workflow.version == Workflow.VERSION_DRAFT: raise IsDraftWorkflowError(f"Workflow is draft version, id={workflow_id}") return workflow 
@@ -304,7 +311,14 @@ class WorkflowService: return default_config def run_draft_workflow_node( - self, app_model: App, node_id: str, user_inputs: dict, account: Account + self, + app_model: App, + node_id: str, + user_inputs: dict, + account: Account, + query: str = "", + files: list[File] | None = None, + conversation_id: str | None = None, ) -> WorkflowNodeExecution: """ Run draft workflow node @@ -319,16 +333,51 @@ class WorkflowService: with Session(bind=db.engine) as session: draft_var_srv = WorkflowDraftVariableService(session) - conv_vars_list = draft_var_srv.list_conversation_variables(app_id=app_model.id) - conv_var_mapping = {v.name: v.get_value().value for v in conv_vars_list.variables} + conv_vars_models = draft_var_srv.list_conversation_variables(app_id=app_model.id) + conv_vars = [ + segment_to_variable(segment=v.get_value(), id=v.id, selector=v.get_selector()) + for v in conv_vars_models.variables + ] + + node_config = draft_workflow.get_node_config_by_id(node_id) + node_type = NodeType(node_config.get("data", {}).get("type")) + if node_type == NodeType.START: + with Session(bind=db.engine) as session, session.begin(): + draft_var_srv = WorkflowDraftVariableService(session) + conversation_id = draft_var_srv.create_conversation_and_set_conversation_variables( + account_id=account.id, + app=app_model, + workflow=draft_workflow, + ) + + # init variable pool + variable_pool = _setup_variable_pool( + query=query, + files=files or [], + user_id=account.id, + user_inputs=user_inputs, + workflow=draft_workflow, + conversation_variables=conv_vars, + node_type=node_type, + conversation_id=conversation_id, + ) + + else: + variable_pool = VariablePool( + system_variables={}, + user_inputs=user_inputs, + environment_variables=draft_workflow.environment_variables, + conversation_variables=[], + ) variable_loader = DraftVarLoader(engine=db.engine, app_id=app_model.id) + run = WorkflowEntry.single_step_run( workflow=draft_workflow, node_id=node_id, 
user_inputs=user_inputs, user_id=account.id, - conversation_variables=conv_var_mapping, + variable_pool=variable_pool, variable_loader=variable_loader, ) @@ -358,24 +407,21 @@ class WorkflowService: exec_metadata = workflow_node_execution.execution_metadata_dict or {} - should_save = should_save_output_variables_for_draft( - invoke_from=InvokeFrom.DEBUGGER, - loop_id=exec_metadata.get(NodeRunMetadataKey.LOOP_ID, None), - iteration_id=exec_metadata.get(NodeRunMetadataKey.ITERATION_ID, None), - ) - if not should_save: - return workflow_node_execution + loop_id = exec_metadata.get(NodeRunMetadataKey.LOOP_ID, None) + iteration_id = exec_metadata.get(NodeRunMetadataKey.ITERATION_ID, None) + # TODO(QuantumGhost): single step does not include loop_id or iteration_id in execution_metadata. - with Session(bind=db.engine) as session: - draft_var_srv = WorkflowDraftVariableService(session) - draft_var_srv.save_output_variables( + with Session(bind=db.engine) as session, session.begin(): + draft_var_saver = DraftVariableSaver( + session=session, app_id=app_model.id, node_id=workflow_node_execution.node_id, node_type=NodeType(workflow_node_execution.node_type), - output=output, + invoke_from=InvokeFrom.DEBUGGER, + enclosing_node_id=loop_id or iteration_id or None, ) + draft_var_saver.save(output) session.commit() - return workflow_node_execution def run_free_workflow_node( @@ -616,3 +662,44 @@ class WorkflowService: session.delete(workflow) return True + + +def _setup_variable_pool( + query: str, + files: Sequence[File], + user_id: str, + user_inputs: Mapping[str, Any], + workflow: Workflow, + node_type: NodeType, + conversation_id: str, + conversation_variables: list[Variable], +): + # Only inject system variables for START node type. + if node_type == NodeType.START: + # Create a variable pool. 
+ system_inputs = { + # From inputs: + SystemVariableKey.QUERY: query, + SystemVariableKey.FILES: files, + SystemVariableKey.USER_ID: user_id, + # From sysvar + SystemVariableKey.CONVERSATION_ID: conversation_id, + SystemVariableKey.DIALOGUE_COUNT: 0, + # From workflow model + SystemVariableKey.APP_ID: workflow.app_id, + SystemVariableKey.WORKFLOW_ID: workflow.id, + # Randomly generated. + SystemVariableKey.WORKFLOW_RUN_ID: str(uuid.uuid4()), + } + else: + system_inputs = {} + + # init variable pool + variable_pool = VariablePool( + system_variables=system_inputs, + user_inputs=user_inputs, + environment_variables=workflow.environment_variables, + conversation_variables=conversation_variables, + ) + + return variable_pool