From 838630c39e1d9031b163e7d55ac68ff2e07ceb80 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Thu, 17 Jul 2025 21:23:01 +0800 Subject: [PATCH 01/14] feat(api): Add model for workflow suspension --- .../workflow/entities/workflow_suspension.py | 28 ++++ ...b9ee0_add_workflowsuspension_model_add_.py | 51 +++++++ api/models/workflow.py | 142 ++++++++++++++++-- 3 files changed, 211 insertions(+), 10 deletions(-) create mode 100644 api/core/workflow/entities/workflow_suspension.py create mode 100644 api/migrations/versions/2025_07_17_2020-1091956b9ee0_add_workflowsuspension_model_add_.py diff --git a/api/core/workflow/entities/workflow_suspension.py b/api/core/workflow/entities/workflow_suspension.py new file mode 100644 index 0000000000..9304e6cadd --- /dev/null +++ b/api/core/workflow/entities/workflow_suspension.py @@ -0,0 +1,28 @@ +from enum import StrEnum +from uuid import UUID + +from pydantic import BaseModel, Field + +from libs.uuid_utils import uuidv7 + + +class StateVersion(StrEnum): + # `V1` is `GraphRuntimeState` serialized as JSON by dumping with Pydantic. 
+ V1 = "v1" + + +class WorkflowSuspension(BaseModel): + id: UUID = Field(default_factory=uuidv7) + + # Correspond to WorkflowExecution.id_ + execution_id: str + + workflow_id: str + + continuation_node_id: str + + state: str + + state_version: StateVersion = StateVersion.V1 + + inputs: str diff --git a/api/migrations/versions/2025_07_17_2020-1091956b9ee0_add_workflowsuspension_model_add_.py b/api/migrations/versions/2025_07_17_2020-1091956b9ee0_add_workflowsuspension_model_add_.py new file mode 100644 index 0000000000..49182e6474 --- /dev/null +++ b/api/migrations/versions/2025_07_17_2020-1091956b9ee0_add_workflowsuspension_model_add_.py @@ -0,0 +1,51 @@ +"""Add WorkflowSuspension model, add suspension_id to WorkflowRun + +Revision ID: 1091956b9ee0 +Revises: 1c9ba48be8e4 +Create Date: 2025-07-17 20:20:43.710683 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '1091956b9ee0' +down_revision = '1c9ba48be8e4' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('workflow_suspensions', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), + sa.Column('workflow_id', models.types.StringUUID(), nullable=False), + sa.Column('workflow_run_id', models.types.StringUUID(), nullable=False), + sa.Column('resumed_at', sa.DateTime(), nullable=True), + sa.Column('continuation_node_id', sa.String(length=255), nullable=False), + sa.Column('state_version', sa.String(length=20), nullable=False), + sa.Column('state', sa.Text(), nullable=False), + sa.Column('inputs', sa.Text(), nullable=True), + sa.Column('form_code', sa.String(length=32), nullable=False), + sa.PrimaryKeyConstraint('id', name=op.f('workflow_suspensions_pkey')), + sa.UniqueConstraint('form_code', name=op.f('workflow_suspensions_form_code_key')) + ) + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.add_column(sa.Column('suspension_id', models.types.StringUUID(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.drop_column('suspension_id') + + op.drop_table('workflow_suspensions') + # ### end Alembic commands ### diff --git a/api/models/workflow.py b/api/models/workflow.py index 124fb3bb4c..8c153af741 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -14,10 +14,12 @@ from core.file.models import File from core.variables import utils as variable_utils from core.variables.variables import FloatVariable, IntegerVariable, StringVariable from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from core.workflow.entities.workflow_execution import WorkflowExecutionStatus +from core.workflow.entities.workflow_suspension import StateVersion from core.workflow.nodes.enums import NodeType from factories.variable_factory import TypeMismatchError, build_segment_with_type from libs.datetime_utils import naive_utc_now -from libs.helper import extract_tenant_id +from libs.helper import extract_tenant_id, generate_string from ._workflow_exc import NodeNotFoundError, WorkflowDataError @@ -508,7 +510,10 @@ class WorkflowRun(Base): version: Mapped[str] = mapped_column(db.String(255)) graph: Mapped[Optional[str]] = mapped_column(db.Text) inputs: Mapped[Optional[str]] = mapped_column(db.Text) - status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded + status: Mapped[str] = mapped_column( + EnumText(WorkflowExecutionStatus, length=255), + nullable=False, + ) # running, succeeded, failed, stopped, partial-succeeded outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}") error: Mapped[Optional[str]] = mapped_column(db.Text) elapsed_time: Mapped[float] = mapped_column(db.Float, nullable=False, server_default=sa.text("0")) @@ -520,6 +525,10 @@ class WorkflowRun(Base): finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) exceptions_count: Mapped[int] = mapped_column(db.Integer, 
server_default=db.text("0"), nullable=True) + # Represents the suspension details of a suspended workflow. + # This field is non-null when `status == SUSPENDED` and null otherwise. + suspension_id: Mapped[StringUUID] = mapped_column(StringUUID, nullable=True) + @property def created_by_account(self): created_by_role = CreatorUserRole(self.created_by_role) @@ -907,10 +916,6 @@ class ConversationVariable(Base): _EDITABLE_SYSTEM_VARIABLE = frozenset(["query", "files"]) -def _naive_utc_datetime(): - return naive_utc_now() - - class WorkflowDraftVariable(Base): """`WorkflowDraftVariable` record variables and outputs generated during debugging worfklow or chatflow. @@ -941,14 +946,14 @@ class WorkflowDraftVariable(Base): created_at: Mapped[datetime] = mapped_column( db.DateTime, nullable=False, - default=_naive_utc_datetime, + default=naive_utc_now, server_default=func.current_timestamp(), ) updated_at: Mapped[datetime] = mapped_column( db.DateTime, nullable=False, - default=_naive_utc_datetime, + default=naive_utc_now, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), ) @@ -1173,8 +1178,8 @@ class WorkflowDraftVariable(Base): description: str = "", ) -> "WorkflowDraftVariable": variable = WorkflowDraftVariable() - variable.created_at = _naive_utc_datetime() - variable.updated_at = _naive_utc_datetime() + variable.created_at = naive_utc_now() + variable.updated_at = naive_utc_now() variable.description = description variable.app_id = app_id variable.node_id = node_id @@ -1254,3 +1259,120 @@ class WorkflowDraftVariable(Base): def is_system_variable_editable(name: str) -> bool: return name in _EDITABLE_SYSTEM_VARIABLE + + +_SUSPENSION_FORM_CODE_LENGTH = 22 + + +def _generate_suspension_form_code(): + return generate_string(_SUSPENSION_FORM_CODE_LENGTH) + + +class WorkflowSuspension(Base): + __tablename__ = "workflow_suspensions" + + # id is the unique identifier of a suspension + id: Mapped[str] = mapped_column( + StringUUID, + primary_key=True, + 
# NOTE: The server default acts as a fallback mechanism.
+        # The application generates the ID for new `WorkflowSuspension` records
+        # to streamline the insertion process and minimize database roundtrips.
+        server_default=db.text("uuidv7()"),
+    )
+
+    created_at: Mapped[datetime] = mapped_column(
+        sa.DateTime,
+        nullable=False,
+        default=naive_utc_now,
+        server_default=func.current_timestamp(),
+    )
+
+    updated_at: Mapped[datetime] = mapped_column(
+        sa.DateTime,
+        nullable=False,
+        default=naive_utc_now,
+        server_default=func.current_timestamp(),
+        onupdate=func.current_timestamp(),
+    )
+
+    # `tenant_id` identifies the tenant associated with this suspension,
+    # corresponding to the `id` field in the `Tenant` model.
+    tenant_id: Mapped[str] = mapped_column(
+        StringUUID,
+        nullable=False,
+    )
+
+    # `app_id` represents the application identifier associated with this state.
+    # It corresponds to the `id` field in the `App` model.
+    #
+    # While this field is technically redundant (as the corresponding app can be
+    # determined by querying the `Workflow`), it is retained to simplify data
+    # cleanup and management processes.
+    app_id: Mapped[str] = mapped_column(
+        StringUUID,
+        nullable=False,
+    )
+
+    # `workflow_id` represents the unique identifier of the workflow associated with this suspension.
+    # It corresponds to the `id` field in the `Workflow` model.
+    #
+    # Since an application can have multiple versions of a workflow, each with its own unique ID,
+    # the `app_id` alone is insufficient to determine which workflow version should be loaded
+    # when resuming a suspended workflow.
+    workflow_id: Mapped[str] = mapped_column(
+        StringUUID,
+        nullable=False,
+    )
+
+    # `workflow_run_id` represents the identifier of the workflow execution,
+    # corresponding to the `id` field of the `WorkflowRun` model.
+ workflow_run_id: Mapped[str] = mapped_column( + StringUUID, + nullable=False, + ) + + # `resumed_at` records the timestamp when the suspended workflow was resumed. + # It is set to `NULL` if the workflow has not been resumed. + resumed_at: Mapped[Optional[datetime]] = mapped_column( + sa.DateTime, + nullable=True, + default=sa.null, + ) + + # `continuation_node_id` specifies the next node to execute when the workflow resumes. + # + # Although this information is embedded within the `state` field, it is extracted + # into a separate field to facilitate debugging and data analysis. + continuation_node_id: Mapped[str] = mapped_column( + sa.String(length=255), + nullable=False, + ) + + # The version of the serialized execution state data. Currently, the only supported value is `v1`. + state_version: Mapped[StateVersion] = mapped_column( + EnumText(StateVersion), + nullable=False, + ) + + # `state` contains the serialized runtime state of the `GraphEngine`, + # capturing the workflow's execution context at the time of suspension. + # + # The value of `state` is a JSON-formatted string representing a JSON object (e.g., `{}`). + state: Mapped[str] = mapped_column(sa.Text, nullable=False) + + # The inputs provided by the user when resuming the suspended workflow. + # These inputs are serialized as a JSON-formatted string (e.g., `{}`). + # + # This field is `NULL` if no inputs were submitted by the user. + inputs: Mapped[str] = mapped_column(sa.Text, nullable=True) + + form_code: Mapped[str] = mapped_column( + # A 32-character string can store a base64-encoded value with 192 bits of entropy + # or a base62-encoded value with over 180 bits of entropy, providing sufficient + # uniqueness for most use cases. 
+ sa.String(32), + nullable=False, + unique=True, + default=_generate_suspension_form_code, + ) From 51774589994a6ef85110faa4b8884505d7938185 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Mon, 21 Jul 2025 12:12:05 +0800 Subject: [PATCH 02/14] feat(api): move `max_execution_steps` and `max_execution_time` to GraphInitParams. --- .../entities/graph_init_params.py | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/api/core/workflow/graph_engine/entities/graph_init_params.py b/api/core/workflow/graph_engine/entities/graph_init_params.py index a0ecd824f4..0bb7ed8a1b 100644 --- a/api/core/workflow/graph_engine/entities/graph_init_params.py +++ b/api/core/workflow/graph_engine/entities/graph_init_params.py @@ -1,14 +1,25 @@ from collections.abc import Mapping from typing import Any -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, PositiveInt +from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from models.enums import UserFrom from models.workflow import WorkflowType class GraphInitParams(BaseModel): + """GraphInitParams encapsulates the configurations and contextual information + that remain constant throughout a single execution of the graph engine. + + A single execution is defined as follows: as long as the execution has not reached + its conclusion, it is considered one execution. For instance, if a workflow is suspended + and later resumed, it is still regarded as a single execution, not two. + + For the state diagram of workflow execution, refer to `WorkflowExecutionStatus`. 
+ """ + # init params tenant_id: str = Field(..., description="tenant / workspace id") app_id: str = Field(..., description="app id") @@ -19,3 +30,10 @@ class GraphInitParams(BaseModel): user_from: UserFrom = Field(..., description="user from, account or end-user") invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger") call_depth: int = Field(..., description="call depth") + + # max_execution_steps records the maximum steps allowed during the execution of a workflow. + max_execution_steps: PositiveInt = Field( + default=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, description="max_execution_steps" + ) + # max_execution_time records the max execution time for the workflow, measured in seconds + max_execution_time: PositiveInt = Field(default=dify_config.WORKFLOW_MAX_EXECUTION_TIME, description="") From 9d6774c87bf30497529d1308e3de1fe61db4464d Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Mon, 21 Jul 2025 12:14:09 +0800 Subject: [PATCH 03/14] feat(api): introduce SUSPENDED status for workflow, add correspond events. 
--- api/core/app/entities/queue_entities.py | 8 ++ .../workflow/entities/workflow_execution.py | 97 +++++++++++++++++++ .../workflow/graph_engine/entities/event.py | 4 + 3 files changed, 109 insertions(+) diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 42e6a1519c..c8377a6cd1 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -29,6 +29,7 @@ class QueueEvent(StrEnum): WORKFLOW_SUCCEEDED = "workflow_succeeded" WORKFLOW_FAILED = "workflow_failed" WORKFLOW_PARTIAL_SUCCEEDED = "workflow_partial_succeeded" + WORKFLOW_SUSPENDED = "workflow_suspended" ITERATION_START = "iteration_start" ITERATION_NEXT = "iteration_next" ITERATION_COMPLETED = "iteration_completed" @@ -326,6 +327,13 @@ class QueueWorkflowStartedEvent(AppQueueEvent): graph_runtime_state: GraphRuntimeState +class QueueWorkflowSuspendedEvent(AppQueueEvent): + event: QueueEvent = QueueEvent.WORKFLOW_SUSPENDED + # next_node_id records the next node to execute after resuming + # workflow. 
+ next_node_id: str + + class QueueWorkflowSucceededEvent(AppQueueEvent): """ QueueWorkflowSucceededEvent entity diff --git a/api/core/workflow/entities/workflow_execution.py b/api/core/workflow/entities/workflow_execution.py index 781be4b3c6..a43d75bbed 100644 --- a/api/core/workflow/entities/workflow_execution.py +++ b/api/core/workflow/entities/workflow_execution.py @@ -23,12 +23,109 @@ class WorkflowType(StrEnum): class WorkflowExecutionStatus(StrEnum): + # State diagram for the workflw status: + # (@) means start, (*) means end + # + # ┌------------------>------------------------->------------------->--------------┐ + # | | + # | ┌-----------------------<--------------------┐ | + # ^ | | | + # | | ^ | + # | V | | + # ┌-----------┐ ┌-----------------------┐ ┌-----------┐ V + # | Scheduled |------->| Running |---------------------->| Suspended | | + # └-----------┘ └-----------------------┘ └-----------┘ | + # | | | | | | | + # | | | | | | | + # ^ | | | V V | + # | | | | | ┌---------┐ | + # (@) | | | └------------------------>| Stopped |<----┘ + # | | | └---------┘ + # | | | | + # | | V V + # | | ┌-----------┐ | + # | | | Succeeded |------------->--------------┤ + # | | └-----------┘ | + # | V V + # | +--------┐ | + # | | Failed |---------------------->----------------┤ + # | └--------┘ | + # V V + # ┌---------------------┐ | + # | Partially Succeeded |---------------------->-----------------┘--------> (*) + # └---------------------┘ + # + # Mermaid diagram: + # + # --- + # title: State diagram for Workflow run state + # --- + # stateDiagram-v2 + # scheduled: Scheduled + # running: Running + # succeeded: Succeeded + # failed: Failed + # partial_succeeded: Partial Succeeded + # suspended: Suspended + # stopped: Stopped + # + # [*] --> scheduled: + # scheduled --> running: Start Execution + # running --> suspended: Human input required + # suspended --> running: human input added + # suspended --> stopped: User stops execution + # running --> succeeded: Execution 
finishes without any error
+    #     running --> failed: Execution finishes with errors
+    #     running --> stopped: User stops execution
+    #     running --> partial_succeeded: some errors occurred and were handled during execution
+    #
+    #     scheduled --> stopped: User stops execution
+    #
+    #     succeeded --> [*]
+    #     failed --> [*]
+    #     partial_succeeded --> [*]
+    #     stopped --> [*]
+
+    # `SCHEDULED` means that the workflow is scheduled to run, but has not
+    # started running yet. (maybe due to possible worker saturation.)
+    SCHEDULED = "scheduled"
+
+    # `RUNNING` means the workflow is executing.
     RUNNING = "running"
+
+    # `SUCCEEDED` means the execution of workflow succeeded without any error.
     SUCCEEDED = "succeeded"
+
+    # `FAILED` means the execution of workflow failed with some errors.
     FAILED = "failed"
+
+    # `STOPPED` means the execution of workflow was stopped, either manually
+    # by the user, or automatically by the Dify application (e.g. the moderation
+    # mechanism.)
     STOPPED = "stopped"
+
+    # `PARTIAL_SUCCEEDED` indicates that some errors occurred during the workflow
+    # execution, but they were successfully handled (e.g., by using an error
+    # strategy such as "fail branch" or "default value").
     PARTIAL_SUCCEEDED = "partial-succeeded"
 
+    # `SUSPENDED` indicates that the workflow execution is temporarily paused
+    # (e.g., awaiting human input) and is expected to resume later.
+    SUSPENDED = "suspended"
+
+    def is_ended(self) -> bool:
+        return self in _END_STATE
+
+
+_END_STATE = frozenset(
+    [
+        WorkflowExecutionStatus.SUCCEEDED,
+        WorkflowExecutionStatus.FAILED,
+        WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
+        WorkflowExecutionStatus.STOPPED,
+    ]
+)
+
 class WorkflowExecution(BaseModel):
     """
diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py
index e57e9e4d64..d44b4e0418 100644
--- a/api/core/workflow/graph_engine/entities/event.py
+++ b/api/core/workflow/graph_engine/entities/event.py
@@ -43,6 +43,10 @@ class GraphRunPartialSucceededEvent(BaseGraphEvent):
     outputs: Optional[dict[str, Any]] = None
 
 
+class GraphRunSuspendedEvent(BaseGraphEvent):
+    next_node_id: str = Field(..., description="the next node id to execute while resumed.")
+
+
 ###########################################
 # Node Events
 ###########################################
From 55c2c4a6b6e054a279a66bde2078747fec1cc866 Mon Sep 17 00:00:00 2001
From: QuantumGhost
Date: Mon, 21 Jul 2025 12:16:49 +0800
Subject: [PATCH 04/14] feat(api): track routing information in RouteNodeState

---
 .../entities/runtime_route_state.py | 35 ++++++++++++++++++-
 1 file changed, 34 insertions(+), 1 deletion(-)

diff --git a/api/core/workflow/graph_engine/entities/runtime_route_state.py b/api/core/workflow/graph_engine/entities/runtime_route_state.py
index f2d9c98936..8dad09a085 100644
--- a/api/core/workflow/graph_engine/entities/runtime_route_state.py
+++ b/api/core/workflow/graph_engine/entities/runtime_route_state.py
@@ -1,7 +1,7 @@
 import uuid
 from datetime import UTC, datetime
 from enum import Enum
-from typing import Optional
+from typing import Any, Optional
 
 from pydantic import BaseModel, Field
 
@@ -44,6 +44,8 @@ class RouteNodeState(BaseModel):
     paused_by: Optional[str] = None
     """paused by"""
 
+    # The `index` is used to record the execution order for a given node.
+    # Nodes executed earlier get smaller `index` values.
index: int = 1 def set_finished(self, run_result: NodeRunResult) -> None: @@ -79,10 +81,25 @@ class RuntimeRouteState(BaseModel): default_factory=dict, description="graph state routes (source_node_state_id: target_node_state_id)" ) + # A mapping from node_id to its routing state. node_state_mapping: dict[str, RouteNodeState] = Field( default_factory=dict, description="node state mapping (route_node_state_id: route_node_state)" ) + next_node_id: Optional[str] = Field( + default=None, description="The next node id to run when resumed from suspension." + ) + + # If `previous_node_id` is not `None`, then the correspond node has state in the dict + # `node_state_mapping`. + previous_node_state_id: Optional[str] = Field(None, description="The state of last executed node.") + + _state_by_id: dict[str, RouteNodeState] + + def model_post_init(self, context: Any) -> None: + super().model_post_init(context) + self._state_by_id = {v.id: v for v in self.node_state_mapping.values()} + def create_node_state(self, node_id: str) -> RouteNodeState: """ Create node state @@ -91,6 +108,7 @@ class RuntimeRouteState(BaseModel): """ state = RouteNodeState(node_id=node_id, start_at=datetime.now(UTC).replace(tzinfo=None)) self.node_state_mapping[state.id] = state + self._state_by_id[state.id] = state return state def add_route(self, source_node_state_id: str, target_node_state_id: str) -> None: @@ -115,3 +133,18 @@ class RuntimeRouteState(BaseModel): return [ self.node_state_mapping[target_state_id] for target_state_id in self.routes.get(source_node_state_id, []) ] + + # def get_node_state(self, node_id: str) -> RouteNodeState | None: + # return self.node_state_mapping.get(node_id) + + def get_previous_route_node_state(self) -> RouteNodeState | None: + if self.previous_node_state_id is None: + return None + return self._state_by_id[self.previous_node_state_id] + + @property + def previous_node_id(self): + if self.previous_node_state_id is None: + return None + state = 
self._state_by_id[self.previous_node_state_id] + return state.node_id From d99ad77837f413244682786ef423593001212f2f Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Mon, 21 Jul 2025 12:17:35 +0800 Subject: [PATCH 05/14] feat(api): tracking execution time in GraphRuntimeState use wall clock for time measurement (As the execution may be continued on another node) --- .../workflow/graph_engine/_engine_utils.py | 15 +++++++ .../entities/graph_runtime_state.py | 45 +++++++++++++++++-- .../entities/runtime_route_state.py | 8 ++-- .../nodes/iteration/iteration_node.py | 2 +- api/core/workflow/nodes/loop/loop_node.py | 2 +- api/core/workflow/workflow_entry.py | 41 ++++++++++------- .../workflow/nodes/test_code.py | 5 ++- .../workflow/nodes/test_http.py | 5 ++- .../workflow/nodes/test_llm.py | 5 ++- .../nodes/test_parameter_extractor.py | 5 ++- .../workflow/nodes/test_template_transform.py | 5 ++- .../workflow/nodes/test_tool.py | 5 ++- .../entities/test_graph_runtime_state.py | 2 - .../graph_engine/test_graph_engine.py | 20 +++++---- .../nodes/iteration/test_iteration.py | 21 ++++++--- .../core/workflow/nodes/test_answer.py | 5 ++- .../workflow/nodes/test_continue_on_error.py | 5 ++- .../core/workflow/nodes/test_if_else.py | 9 ++-- .../v1/test_variable_assigner_v1.py | 7 ++- .../v2/test_variable_assigner_v2.py | 9 ++-- 20 files changed, 154 insertions(+), 67 deletions(-) create mode 100644 api/core/workflow/graph_engine/_engine_utils.py diff --git a/api/core/workflow/graph_engine/_engine_utils.py b/api/core/workflow/graph_engine/_engine_utils.py new file mode 100644 index 0000000000..28898268fe --- /dev/null +++ b/api/core/workflow/graph_engine/_engine_utils.py @@ -0,0 +1,15 @@ +import time + + +def get_timestamp() -> float: + """Retrieve a timestamp as a float point numer representing the number of seconds + since the Unix epoch. + + This function is primarily used to measure the execution time of the workflow engine. 
+ Since workflow execution may be paused and resumed on a different machine, + `time.perf_counter` cannot be used as it is inconsistent across machines. + + To address this, the function uses the wall clock as the time source. + However, it assumes that the clocks of all servers are properly synchronized. + """ + return round(time.time()) diff --git a/api/core/workflow/graph_engine/entities/graph_runtime_state.py b/api/core/workflow/graph_engine/entities/graph_runtime_state.py index a62ffe46c9..304a827b30 100644 --- a/api/core/workflow/graph_engine/entities/graph_runtime_state.py +++ b/api/core/workflow/graph_engine/entities/graph_runtime_state.py @@ -1,18 +1,32 @@ from typing import Any -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, PrivateAttr from core.model_runtime.entities.llm_entities import LLMUsage from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine._engine_utils import get_timestamp from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState +_SECOND_TO_US = 1_000_000 + class GraphRuntimeState(BaseModel): + """`GraphRuntimeState` encapsulates the runtime state of workflow execution, + including scheduling details, variable values, and timing information. + + Values that are initialized prior to workflow execution and remain constant + throughout the execution should be part of `GraphInitParams` instead. + """ + variable_pool: VariablePool = Field(..., description="variable pool") """variable pool""" - start_at: float = Field(..., description="start time") - """start time""" + # The `start_at` field records the execution start time of the workflow. + # + # This field is automatically generated, and its value or interpretation may evolve. + # Avoid manually setting this field to ensure compatibility with future updates. 
+ start_at: float = Field(description="start time", default_factory=get_timestamp) + total_tokens: int = 0 """total tokens""" llm_usage: LLMUsage = LLMUsage.empty_usage() @@ -29,3 +43,28 @@ class GraphRuntimeState(BaseModel): node_run_state: RuntimeRouteState = RuntimeRouteState() """node run state""" + + # `execution_time_us` tracks the total execution time of the workflow in microseconds. + # Time spent in suspension is excluded from this calculation. + # + # This field is used to persist the time already spent while suspending a workflow. + execution_time_us: int = 0 + + # `_last_execution_started_at` records the timestamp of the most recent resume start. + # It is updated when the workflow resumes from a suspended state. + _last_execution_started_at: float = PrivateAttr(default_factory=get_timestamp) + + def is_timed_out(self, max_execution_time_seconds: int) -> bool: + """Checks if the workflow execution has exceeded the specified `max_execution_time_seconds`.""" + remaining_time_us = max_execution_time_seconds * _SECOND_TO_US - self.execution_time_us + if remaining_time_us <= 0: + return False + return int(get_timestamp() - self._last_execution_started_at) * _SECOND_TO_US > remaining_time_us + + def record_suspend_state(self, next_node_id: str): + """Record the time already spent in executing workflow. + + This function should be called when suspending the workflow. 
+ """ + self.execution_time_us = int(get_timestamp() - self._last_execution_started_at) * _SECOND_TO_US + self.node_run_state.next_node_id = next_node_id diff --git a/api/core/workflow/graph_engine/entities/runtime_route_state.py b/api/core/workflow/graph_engine/entities/runtime_route_state.py index 8dad09a085..f8f1679110 100644 --- a/api/core/workflow/graph_engine/entities/runtime_route_state.py +++ b/api/core/workflow/graph_engine/entities/runtime_route_state.py @@ -3,7 +3,7 @@ from datetime import UTC, datetime from enum import Enum from typing import Any, Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, PrivateAttr from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus @@ -92,9 +92,11 @@ class RuntimeRouteState(BaseModel): # If `previous_node_id` is not `None`, then the correspond node has state in the dict # `node_state_mapping`. - previous_node_state_id: Optional[str] = Field(None, description="The state of last executed node.") + previous_node_state_id: Optional[str] = Field(default=None, description="The state of last executed node.") - _state_by_id: dict[str, RouteNodeState] + # `_state_by_id` serves as a mapping from the unique identifier (`id`) of each `RouteNodeState` + # instance to the corresponding `RouteNodeState` object itself. 
+ _state_by_id: dict[str, RouteNodeState] = PrivateAttr(default={}) def model_post_init(self, context: Any) -> None: super().model_post_init(context) diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 5842c8d64b..84a5732fdd 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -160,7 +160,7 @@ class IterationNode(BaseNode): from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool) graph_engine = GraphEngine( tenant_id=self.tenant_id, diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 655de9362f..53ac04d6dc 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -124,7 +124,7 @@ class LoopNode(BaseNode): from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.graph_engine import GraphEngine - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool) graph_engine = GraphEngine( tenant_id=self.tenant_id, diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index d2375da39c..224e8036f7 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -1,11 +1,10 @@ import logging -import time import uuid from collections.abc import Generator, Mapping, Sequence from typing import Any, Optional, cast from configs import dify_config -from core.app.apps.exc import GenerateTaskStoppedError +from 
core.app.apps.base_app_queue_manager import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.file.models import File from core.workflow.callbacks import WorkflowCallback @@ -70,7 +69,8 @@ class WorkflowEntry: raise ValueError("Max workflow call depth {} reached.".format(workflow_call_max_depth)) # init workflow run state - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool) + self.graph_engine = GraphEngine( tenant_id=tenant_id, app_id=app_id, @@ -146,7 +146,7 @@ class WorkflowEntry: graph = Graph.init(graph_config=workflow.graph_dict) # init workflow run state - node = node_cls( + node_instance = node_cls( id=str(uuid.uuid4()), config=node_config, graph_init_params=GraphInitParams( @@ -161,7 +161,7 @@ class WorkflowEntry: call_depth=0, ), graph=graph, - graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool), ) try: @@ -190,11 +190,17 @@ class WorkflowEntry: try: # run node - generator = node.run() + generator = node_instance.run() except Exception as e: - logger.exception(f"error while running node, {workflow.id=}, {node.id=}, {node.type_=}, {node.version()=}") - raise WorkflowNodeRunFailedError(node=node, err_msg=str(e)) - return node, generator + logger.exception( + "error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s", + workflow.id, + node_instance.id, + node_instance.node_type, + node_instance.version(), + ) + raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) + return node_instance, generator @classmethod def run_free_node( @@ -256,7 +262,7 @@ class WorkflowEntry: node_cls = cast(type[BaseNode], node_cls) # init workflow run state - node: BaseNode = node_cls( + node_instance: BaseNode = node_cls( id=str(uuid.uuid4()), 
config=node_config, graph_init_params=GraphInitParams( @@ -271,7 +277,7 @@ class WorkflowEntry: call_depth=0, ), graph=graph, - graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool), ) try: @@ -291,12 +297,17 @@ class WorkflowEntry: ) # run node - generator = node.run() + generator = node_instance.run() - return node, generator + return node_instance, generator except Exception as e: - logger.exception(f"error while running node, {node.id=}, {node.type_=}, {node.version()=}") - raise WorkflowNodeRunFailedError(node=node, err_msg=str(e)) + logger.exception( + "error while running node_instance, node_id=%s, type=%s, version=%s", + node_instance.id, + node_instance.node_type, + node_instance.version(), + ) + raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) @staticmethod def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None: diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index 707b28e6d8..682ac55352 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -1,4 +1,3 @@ -import time import uuid from os import getenv from typing import cast @@ -62,7 +61,9 @@ def init_code_node(code_config: dict): id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, - graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + graph_runtime_state=GraphRuntimeState( + variable_pool=variable_pool, + ), config=code_config, ) diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index d7856129a3..987f81bdd7 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -1,4 
+1,3 @@ -import time import uuid from urllib.parse import urlencode @@ -56,7 +55,9 @@ def init_http_node(config: dict): id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, - graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + graph_runtime_state=GraphRuntimeState( + variable_pool=variable_pool, + ), config=config, ) diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index a14791bc67..02825470d9 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -1,5 +1,4 @@ import json -import time import uuid from collections.abc import Generator from unittest.mock import MagicMock, patch @@ -73,7 +72,9 @@ def init_llm_node(config: dict) -> LLMNode: id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, - graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + graph_runtime_state=GraphRuntimeState( + variable_pool=variable_pool, + ), config=config, ) diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index edd70193a8..f2ea5759a3 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -1,5 +1,4 @@ import os -import time import uuid from typing import Optional from unittest.mock import MagicMock @@ -78,7 +77,9 @@ def init_parameter_extractor_node(config: dict): id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, - graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + graph_runtime_state=GraphRuntimeState( + variable_pool=variable_pool, + ), config=config, ) node.init_node_data(config.get("data", {})) diff --git 
a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index f71a5ee140..f5c7397511 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -1,4 +1,3 @@ -import time import uuid import pytest @@ -73,7 +72,9 @@ def test_execute_code(setup_code_executor_mock): id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, - graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + graph_runtime_state=GraphRuntimeState( + variable_pool=variable_pool, + ), config=config, ) node.init_node_data(config.get("data", {})) diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 8476c1f874..d3eaa16acd 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -1,4 +1,3 @@ -import time import uuid from unittest.mock import MagicMock @@ -54,7 +53,9 @@ def init_tool_node(config: dict): id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, - graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + graph_runtime_state=GraphRuntimeState( + variable_pool=variable_pool, + ), config=config, ) node.init_node_data(config.get("data", {})) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/graph_engine/entities/test_graph_runtime_state.py index cf7cee8710..cf2dbcceb3 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/entities/test_graph_runtime_state.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/entities/test_graph_runtime_state.py @@ -49,7 +49,6 @@ def create_test_graph_runtime_state() -> GraphRuntimeState: return GraphRuntimeState( 
variable_pool=variable_pool, - start_at=time.perf_counter(), total_tokens=100, llm_usage=llm_usage, outputs={ @@ -106,7 +105,6 @@ def test_empty_outputs_round_trip(): variable_pool = VariablePool.empty() original_state = GraphRuntimeState( variable_pool=variable_pool, - start_at=time.perf_counter(), outputs={}, # Empty outputs ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index ed4e42425e..c8a55538aa 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -175,8 +175,9 @@ def test_run_parallel_in_workflow(mock_close, mock_remove): user_inputs={"query": "hi"}, ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - graph_engine = GraphEngine( + graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + ) tenant_id="111", app_id="222", workflow_type=WorkflowType.WORKFLOW, @@ -303,8 +304,9 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove): user_inputs={}, ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - graph_engine = GraphEngine( + graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + ) tenant_id="111", app_id="222", workflow_type=WorkflowType.CHAT, @@ -484,8 +486,9 @@ def test_run_branch(mock_close, mock_remove): user_inputs={"uid": "takato"}, ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - graph_engine = GraphEngine( + graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + ) tenant_id="111", app_id="222", workflow_type=WorkflowType.CHAT, @@ -823,8 +826,9 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app): user_inputs={"query": "hi"}, ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, 
start_at=time.perf_counter()) - graph_engine = GraphEngine( + graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + ) tenant_id="111", app_id="222", workflow_type=WorkflowType.CHAT, diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py index f53f391433..8f782bab03 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py @@ -1,4 +1,3 @@ -import time import uuid from unittest.mock import patch @@ -179,7 +178,9 @@ def test_run(): id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, - graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + graph_runtime_state=GraphRuntimeState( + variable_pool=pool, + ), config=node_config, ) @@ -401,7 +402,9 @@ def test_run_parallel(): id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, - graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + graph_runtime_state=GraphRuntimeState( + variable_pool=pool, + ), config=node_config, ) @@ -623,7 +626,9 @@ def test_iteration_run_in_parallel_mode(): id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, - graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + graph_runtime_state=GraphRuntimeState( + variable_pool=pool, + ), config=parallel_node_config, ) @@ -647,7 +652,9 @@ def test_iteration_run_in_parallel_mode(): id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, - graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + graph_runtime_state=GraphRuntimeState( + variable_pool=pool, + ), config=sequential_node_config, ) @@ -857,7 +864,9 @@ def test_iteration_run_error_handle(): id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, - 
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + graph_runtime_state=GraphRuntimeState( + variable_pool=pool, + ), config=error_node_config, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/test_answer.py index 466d7bad06..5ec32b55eb 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_answer.py @@ -1,4 +1,3 @@ -import time import uuid from unittest.mock import MagicMock @@ -74,7 +73,9 @@ def test_execute_answer(): id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, - graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + graph_runtime_state=GraphRuntimeState( + variable_pool=variable_pool, + ), config=node_config, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py index 3f83428834..499457f9cc 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py @@ -1,4 +1,3 @@ -import time from unittest.mock import patch from core.app.entities.app_invoke_entities import InvokeFrom @@ -175,7 +174,9 @@ class ContinueOnErrorTestHelper: ), user_inputs=user_inputs or {"uid": "takato"}, ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + ) return GraphEngine( tenant_id="111", diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 8383aee0e4..05b2176548 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -1,4 +1,3 @@ -import time import uuid from unittest.mock import 
MagicMock, Mock @@ -106,7 +105,9 @@ def test_execute_if_else_result_true(): id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, - graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + graph_runtime_state=GraphRuntimeState( + variable_pool=pool, + ), config=node_config, ) @@ -192,7 +193,9 @@ def test_execute_if_else_result_false(): id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, - graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + graph_runtime_state=GraphRuntimeState( + variable_pool=pool, + ), config=node_config, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py index ee51339427..8b02360e48 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py @@ -1,4 +1,3 @@ -import time import uuid from unittest import mock from uuid import uuid4 @@ -96,7 +95,7 @@ def test_overwrite_string_variable(): id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, - graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool), config=node_config, conv_var_updater_factory=mock_conv_var_updater_factory, ) @@ -197,7 +196,7 @@ def test_append_variable_to_array(): id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, - graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool), config=node_config, conv_var_updater_factory=mock_conv_var_updater_factory, ) @@ -289,7 +288,7 @@ def test_clear_array(): id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, 
- graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool), config=node_config, conv_var_updater_factory=mock_conv_var_updater_factory, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py index 987eaf7534..dec0d6da2e 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py @@ -1,4 +1,3 @@ -import time import uuid from uuid import uuid4 @@ -135,7 +134,7 @@ def test_remove_first_from_array(): id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, - graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool), config=node_config, ) @@ -227,7 +226,7 @@ def test_remove_last_from_array(): id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, - graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool), config=node_config, ) @@ -311,7 +310,7 @@ def test_remove_first_from_empty_array(): id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, - graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool), config=node_config, ) @@ -395,7 +394,7 @@ def test_remove_last_from_empty_array(): id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, - graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool), 
config=node_config, ) From f900a92ee7bc76a028fb989dcf02f638eb2d5b04 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Mon, 21 Jul 2025 14:20:32 +0800 Subject: [PATCH 06/14] refactor(api): Simplify the constructor of `GraphEngine` Move most contextual arguments into `GraphInitParams`. --- .../workflow/graph_engine/graph_engine.py | 26 ++------------- api/core/workflow/nodes/base/node.py | 4 +-- .../nodes/iteration/iteration_node.py | 17 ++-------- api/core/workflow/nodes/loop/loop_node.py | 16 ++-------- api/core/workflow/workflow_entry.py | 9 ++++-- .../entities/test_graph_runtime_state.py | 1 - .../graph_engine/test_graph_engine.py | 32 ++++++++++++++----- .../core/workflow/nodes/answer/test_answer.py | 5 +-- 8 files changed, 41 insertions(+), 69 deletions(-) diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index b315129763..920d30b39b 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -3,7 +3,7 @@ import logging import queue import time import uuid -from collections.abc import Generator, Mapping +from collections.abc import Generator from concurrent.futures import ThreadPoolExecutor, wait from copy import copy, deepcopy from datetime import UTC, datetime @@ -91,19 +91,9 @@ class GraphEngine: def __init__( self, - tenant_id: str, - app_id: str, - workflow_type: WorkflowType, - workflow_id: str, - user_id: str, - user_from: UserFrom, - invoke_from: InvokeFrom, - call_depth: int, graph: Graph, - graph_config: Mapping[str, Any], + graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState, - max_execution_steps: int, - max_execution_time: int, thread_pool_id: Optional[str] = None, ) -> None: thread_pool_max_submit_count = dify_config.MAX_SUBMIT_COUNT @@ -126,17 +116,7 @@ class GraphEngine: GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] = self.thread_pool self.graph = graph - self.init_params = GraphInitParams( - 
tenant_id=tenant_id, - app_id=app_id, - workflow_type=workflow_type, - workflow_id=workflow_id, - graph_config=graph_config, - user_id=user_id, - user_from=user_from, - invoke_from=invoke_from, - call_depth=call_depth, - ) + self.init_params = graph_init_params self.graph_runtime_state = graph_runtime_state diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index fb5ec55453..b7a5e3eeec 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -32,9 +32,6 @@ class BaseNode: self.id = id self.tenant_id = graph_init_params.tenant_id self.app_id = graph_init_params.app_id - self.workflow_type = graph_init_params.workflow_type - self.workflow_id = graph_init_params.workflow_id - self.graph_config = graph_init_params.graph_config self.user_id = graph_init_params.user_id self.user_from = graph_init_params.user_from self.invoke_from = graph_init_params.invoke_from @@ -43,6 +40,7 @@ class BaseNode: self.graph_runtime_state = graph_runtime_state self.previous_node_id = previous_node_id self.thread_pool_id = thread_pool_id + self._init_params = graph_init_params node_id = config.get("id") if not node_id: diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 84a5732fdd..d07a9971b9 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -1,6 +1,5 @@ import contextvars import logging -import time import uuid from collections.abc import Generator, Mapping, Sequence from concurrent.futures import Future, wait @@ -137,15 +136,13 @@ class IterationNode(BaseNode): inputs = {"iterator_selector": iterator_list_value} - graph_config = self.graph_config - if not self._node_data.start_node_id: raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self.node_id} not found") root_node_id = self._node_data.start_node_id # init graph - iteration_graph = 
Graph.init(graph_config=graph_config, root_node_id=root_node_id) + iteration_graph = Graph.init(graph_config=self._init_params.graph_config, root_node_id=root_node_id) if not iteration_graph: raise IterationGraphNotFoundError("iteration graph not found") @@ -163,19 +160,9 @@ class IterationNode(BaseNode): graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool) graph_engine = GraphEngine( - tenant_id=self.tenant_id, - app_id=self.app_id, - workflow_type=self.workflow_type, - workflow_id=self.workflow_id, - user_id=self.user_id, - user_from=self.user_from, - invoke_from=self.invoke_from, - call_depth=self.workflow_call_depth, graph=iteration_graph, - graph_config=graph_config, graph_runtime_state=graph_runtime_state, - max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, - max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, + graph_init_params=self._init_params, thread_pool_id=self.thread_pool_id, ) diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 53ac04d6dc..cd12e809d5 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -1,11 +1,9 @@ import json import logging -import time from collections.abc import Generator, Mapping, Sequence from datetime import UTC, datetime from typing import TYPE_CHECKING, Any, Literal, Optional, cast -from configs import dify_config from core.variables import ( IntegerSegment, Segment, @@ -91,7 +89,7 @@ class LoopNode(BaseNode): raise ValueError(f"field start_node_id in loop {self.node_id} not found") # Initialize graph - loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self._node_data.start_node_id) + loop_graph = Graph.init(graph_config=self._init_params.graph_config, root_node_id=self._node_data.start_node_id) if not loop_graph: raise ValueError("loop graph not found") @@ -127,19 +125,9 @@ class LoopNode(BaseNode): graph_runtime_state = 
GraphRuntimeState(variable_pool=variable_pool) graph_engine = GraphEngine( - tenant_id=self.tenant_id, - app_id=self.app_id, - workflow_type=self.workflow_type, - workflow_id=self.workflow_id, - user_id=self.user_id, - user_from=self.user_from, - invoke_from=self.invoke_from, - call_depth=self.workflow_call_depth, graph=loop_graph, - graph_config=self.graph_config, + graph_init_params=self._init_params, graph_runtime_state=graph_runtime_state, - max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, - max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, thread_pool_id=self.thread_pool_id, ) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 224e8036f7..a1b6f69289 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -71,7 +71,7 @@ class WorkflowEntry: # init workflow run state graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool) - self.graph_engine = GraphEngine( + graph_init_params = GraphInitParams( tenant_id=tenant_id, app_id=app_id, workflow_type=workflow_type, @@ -80,11 +80,14 @@ class WorkflowEntry: user_from=user_from, invoke_from=invoke_from, call_depth=call_depth, - graph=graph, graph_config=graph_config, - graph_runtime_state=graph_runtime_state, max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, + ) + self.graph_engine = GraphEngine( + graph=graph, + graph_runtime_state=graph_runtime_state, + graph_init_params=graph_init_params, thread_pool_id=thread_pool_id, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/graph_engine/entities/test_graph_runtime_state.py index cf2dbcceb3..86c842b78d 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/entities/test_graph_runtime_state.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/entities/test_graph_runtime_state.py @@ -1,4 +1,3 @@ 
-import time from decimal import Decimal from core.model_runtime.entities.llm_entities import LLMUsage diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index c8a55538aa..408b09e30d 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -178,6 +178,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove): graph_runtime_state = GraphRuntimeState( variable_pool=variable_pool, ) + init_params = GraphInitParams( tenant_id="111", app_id="222", workflow_type=WorkflowType.WORKFLOW, @@ -187,11 +188,14 @@ def test_run_parallel_in_workflow(mock_close, mock_remove): user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.WEB_APP, call_depth=0, - graph=graph, - graph_runtime_state=graph_runtime_state, max_execution_steps=500, max_execution_time=1200, ) + graph_engine = GraphEngine( + graph=graph, + graph_runtime_state=graph_runtime_state, + graph_init_params=init_params, + ) def llm_generator(self): contents = ["hi", "bye", "good morning"] @@ -307,6 +311,7 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove): graph_runtime_state = GraphRuntimeState( variable_pool=variable_pool, ) + graph_init_params = GraphInitParams( tenant_id="111", app_id="222", workflow_type=WorkflowType.CHAT, @@ -316,11 +321,14 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove): user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.WEB_APP, call_depth=0, - graph=graph, - graph_runtime_state=graph_runtime_state, max_execution_steps=500, max_execution_time=1200, ) + graph_engine = GraphEngine( + graph=graph, + graph_runtime_state=graph_runtime_state, + graph_init_params=graph_init_params, + ) # print("") @@ -489,6 +497,7 @@ def test_run_branch(mock_close, mock_remove): graph_runtime_state = GraphRuntimeState( variable_pool=variable_pool, ) + graph_init_params = GraphInitParams( 
tenant_id="111", app_id="222", workflow_type=WorkflowType.CHAT, @@ -498,11 +507,14 @@ def test_run_branch(mock_close, mock_remove): user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.WEB_APP, call_depth=0, - graph=graph, - graph_runtime_state=graph_runtime_state, max_execution_steps=500, max_execution_time=1200, ) + graph_engine = GraphEngine( + graph=graph, + graph_runtime_state=graph_runtime_state, + graph_init_params=graph_init_params, + ) # print("") @@ -829,6 +841,7 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app): graph_runtime_state = GraphRuntimeState( variable_pool=variable_pool, ) + graph_init_params = GraphInitParams( tenant_id="111", app_id="222", workflow_type=WorkflowType.CHAT, @@ -838,11 +851,14 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app): user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.WEB_APP, call_depth=0, - graph=graph, - graph_runtime_state=graph_runtime_state, max_execution_steps=500, max_execution_time=1200, ) + graph_engine = GraphEngine( + graph=graph, + graph_runtime_state=graph_runtime_state, + graph_init_params=graph_init_params, + ) def qc_generator(self): yield RunCompletedEvent( diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index 1ef024f46b..d9f21e1460 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -1,4 +1,3 @@ -import time import uuid from unittest.mock import MagicMock @@ -71,7 +70,9 @@ def test_execute_answer(): id=str(uuid.uuid4()), graph_init_params=init_params, graph=graph, - graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + graph_runtime_state=GraphRuntimeState( + variable_pool=pool, + ), config=node_config, ) From e0343febde939f96c70f10422cd5d2b091ba3588 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Mon, 21 Jul 2025 14:23:12 +0800 
Subject: [PATCH 07/14] feat(api): support the suspension of graph engine Add a simple test case --- .../graph_engine/execution_decision.py | 25 +++ .../workflow/graph_engine/graph_engine.py | 104 +++++++-- .../graph_engine/test_graph_engine.py | 207 +++++++++++++++++- 3 files changed, 316 insertions(+), 20 deletions(-) create mode 100644 api/core/workflow/graph_engine/execution_decision.py diff --git a/api/core/workflow/graph_engine/execution_decision.py b/api/core/workflow/graph_engine/execution_decision.py new file mode 100644 index 0000000000..bd65819a40 --- /dev/null +++ b/api/core/workflow/graph_engine/execution_decision.py @@ -0,0 +1,25 @@ +from collections.abc import Callable +from dataclasses import dataclass +from enum import StrEnum +from typing import TypeAlias + +from core.workflow.nodes.base import BaseNode + + +class ExecutionDecision(StrEnum): + SUSPEND = "suspend" + STOP = "stop" + CONTINUE = "continue" + + +@dataclass(frozen=True) +class DecisionParams: + # `next_node_instance` is the instance of the next node to run. + next_node_instance: BaseNode + + +# `ExecutionDecisionHook` is a callable that takes a single argument of type `DecisionParams` and +# returns an `ExecutionDecision` indicating whether the graph engine should suspend, continue, or stop. +# +# It must not modify the data inside `DecisionParams`, including any attributes within its fields. 
+ExecutionDecisionHook: TypeAlias = Callable[[DecisionParams], ExecutionDecision] diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 920d30b39b..f3a11c9a98 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -6,14 +6,15 @@ import uuid from collections.abc import Generator from concurrent.futures import ThreadPoolExecutor, wait from copy import copy, deepcopy +from dataclasses import dataclass from datetime import UTC, datetime from typing import Any, Optional, cast from flask import Flask, current_app +from pydantic import BaseModel from configs import dify_config from core.app.apps.exc import GenerateTaskStoppedError -from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus @@ -27,6 +28,7 @@ from core.workflow.graph_engine.entities.event import ( GraphRunPartialSucceededEvent, GraphRunStartedEvent, GraphRunSucceededEvent, + GraphRunSuspendedEvent, NodeRunExceptionEvent, NodeRunFailedEvent, NodeRunRetrieverResourceEvent, @@ -53,9 +55,10 @@ from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.utils import variable_utils from libs.flask_utils import preserve_flask_contexts -from models.enums import UserFrom from models.workflow import WorkflowType +from .execution_decision import DecisionParams, ExecutionDecision, ExecutionDecisionHook + logger = logging.getLogger(__name__) @@ -86,6 +89,10 @@ class GraphEngineThreadPool(ThreadPoolExecutor): raise ValueError(f"Max submit count {self.max_submit_count} of workflow 
thread pool reached.") +def _default_hook(params: DecisionParams) -> ExecutionDecision: + return ExecutionDecision.CONTINUE + + class GraphEngine: workflow_thread_pool_mapping: dict[str, GraphEngineThreadPool] = {} @@ -95,7 +102,12 @@ class GraphEngine: graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState, thread_pool_id: Optional[str] = None, + execution_decision_hook: ExecutionDecisionHook = _default_hook, ) -> None: + """Create a graph from the given state. + + The + """ thread_pool_max_submit_count = dify_config.MAX_SUBMIT_COUNT thread_pool_max_workers = 10 @@ -120,8 +132,7 @@ class GraphEngine: self.graph_runtime_state = graph_runtime_state - self.max_execution_steps = max_execution_steps - self.max_execution_time = max_execution_time + self._exec_decision_hook = execution_decision_hook def run(self) -> Generator[GraphEngineEvent, None, None]: # trigger graph run start event @@ -140,12 +151,18 @@ class GraphEngine: ) # run graph + + next_node_to_run = self.graph.root_node_id + if (next_node_id := self.graph_runtime_state.node_run_state.next_node_id) is not None: + next_node_to_run = next_node_id generator = stream_processor.process( - self._run(start_node_id=self.graph.root_node_id, handle_exceptions=handle_exceptions) + self._run(start_node_id=next_node_to_run, handle_exceptions=handle_exceptions) ) for item in generator: try: yield item + if isinstance(item, GraphRunSuspendedEvent): + return if isinstance(item, NodeRunFailedEvent): yield GraphRunFailedEvent( error=item.route_node_state.failed_reason or "Unknown error.", @@ -209,22 +226,22 @@ class GraphEngine: parent_parallel_start_node_id: Optional[str] = None, handle_exceptions: list[str] = [], ) -> Generator[GraphEngineEvent, None, None]: + # Hint: the `_run` method is used both when running a the main graph, + # and also running parallel branches. 
parallel_start_node_id = None if in_parallel_id: parallel_start_node_id = start_node_id next_node_id = start_node_id previous_route_node_state: Optional[RouteNodeState] = None + while True: # max steps reached - if self.graph_runtime_state.node_run_steps > self.max_execution_steps: - raise GraphRunFailedError("Max steps {} reached.".format(self.max_execution_steps)) + if self.graph_runtime_state.node_run_steps > self.init_params.max_execution_steps: + raise GraphRunFailedError("Max steps {} reached.".format(self.init_params.max_execution_steps)) - # or max execution time reached - if self._is_timed_out( - start_at=self.graph_runtime_state.start_at, max_execution_time=self.max_execution_time - ): - raise GraphRunFailedError("Max execution time {}s reached.".format(self.max_execution_time)) + if self.graph_runtime_state.is_timed_out(self.init_params.max_execution_time): + raise GraphRunFailedError("Max execution time {}s reached.".format(self.init_params.max_execution_time)) # init route node state route_node_state = self.graph_runtime_state.node_run_state.create_node_state(node_id=next_node_id) @@ -257,6 +274,23 @@ class GraphEngine: thread_pool_id=self.thread_pool_id, ) node.init_node_data(node_config.get("data", {})) + # Determine if the execution should be suspended or stopped at this point. + # If so, yield the corresponding event. + # + # Note: Suspension is not allowed while the graph engine is running in parallel mode. + if in_parallel_id is None: + hook_result = self._exec_decision_hook(DecisionParams(next_node_instance=node)) + if hook_result == ExecutionDecision.SUSPEND: + self.graph_runtime_state.record_suspend_state(next_node_id) + yield GraphRunSuspendedEvent(next_node_id=next_node_id) + return + elif hook_result == ExecutionDecision.STOP: + # TODO: STOP the execution of worklow. 
+ return + elif hook_result == ExecutionDecision.CONTINUE: + pass + else: + raise AssertionError("unreachable statement.") try: # run node generator = self._run_node( @@ -383,8 +417,8 @@ class GraphEngine: ) for parallel_result in parallel_generator: - if isinstance(parallel_result, str): - final_node_id = parallel_result + if isinstance(parallel_result, _ParallelBranchResult): + final_node_id = parallel_result.final_node_id else: yield parallel_result @@ -409,8 +443,8 @@ class GraphEngine: ) for generated_item in parallel_generator: - if isinstance(generated_item, str): - final_node_id = generated_item + if isinstance(generated_item, _ParallelBranchResult): + final_node_id = generated_item.final_node_id else: yield generated_item @@ -428,7 +462,7 @@ class GraphEngine: in_parallel_id: Optional[str] = None, parallel_start_node_id: Optional[str] = None, handle_exceptions: list[str] = [], - ) -> Generator[GraphEngineEvent | str, None, None]: + ) -> Generator["GraphEngineEvent | _ParallelBranchResult", None, None]: # if nodes has no run conditions, parallel run all nodes parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id) if not parallel_id: @@ -506,7 +540,7 @@ class GraphEngine: # get final node id final_node_id = parallel.end_to_node_id if final_node_id: - yield final_node_id + yield _ParallelBranchResult(final_node_id) def _run_parallel_node( self, @@ -908,7 +942,41 @@ class GraphEngine: ) return error_result + def save(self) -> str: + """save serializes the state inside this graph engine. + + This method should be called when suspension of the execution is necessary. 
+ """ + state = _GraphEngineState(init_params=self.init_params, graph_runtime_state=self.graph_runtime_state) + return state.model_dump_json() + + @classmethod + def resume( + cls, + state: str, + graph: Graph, + execution_decision_hook: ExecutionDecisionHook = _default_hook, + ) -> "GraphEngine": + """`resume` continues a suspended execution.""" + state_ = _GraphEngineState.model_validate_json(state) + return cls( + graph=graph, + graph_init_params=state_.init_params, + graph_runtime_state=state_.graph_runtime_state, + execution_decision_hook=execution_decision_hook, + ) + class GraphRunFailedError(Exception): def __init__(self, error: str): self.error = error + + +@dataclass +class _ParallelBranchResult: + final_node_id: str + + +class _GraphEngineState(BaseModel): + init_params: GraphInitParams + graph_runtime_state: GraphRuntimeState diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index 408b09e30d..ecbf53cf80 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -1,10 +1,10 @@ -import time from unittest.mock import patch import pytest from flask import Flask from core.app.entities.app_invoke_entities import InvokeFrom +from core.variables.segments import ArrayFileSegment from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus @@ -13,15 +13,18 @@ from core.workflow.graph_engine.entities.event import ( GraphRunFailedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, + GraphRunSuspendedEvent, NodeRunFailedEvent, NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) from core.workflow.graph_engine.entities.graph import Graph +from 
core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState -from core.workflow.graph_engine.graph_engine import GraphEngine +from core.workflow.graph_engine.execution_decision import DecisionParams +from core.workflow.graph_engine.graph_engine import ExecutionDecision, GraphEngine from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent from core.workflow.nodes.llm.node import LLMNode @@ -904,3 +907,203 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app): assert item.outputs is not None answer = item.outputs["answer"] assert all(rc not in answer for rc in wrong_content) + + +def test_suspend_and_resume(): + graph_config = { + "edges": [ + { + "data": {"isInLoop": False, "sourceType": "start", "targetType": "if-else"}, + "id": "1753041723554-source-1753041730748-target", + "source": "1753041723554", + "sourceHandle": "source", + "target": "1753041730748", + "targetHandle": "target", + "type": "custom", + "zIndex": 0, + }, + { + "data": {"isInLoop": False, "sourceType": "if-else", "targetType": "answer"}, + "id": "1753041730748-true-answer-target", + "source": "1753041730748", + "sourceHandle": "true", + "target": "answer", + "targetHandle": "target", + "type": "custom", + "zIndex": 0, + }, + { + "data": { + "isInIteration": False, + "isInLoop": False, + "sourceType": "if-else", + "targetType": "answer", + }, + "id": "1753041730748-false-1753041952799-target", + "source": "1753041730748", + "sourceHandle": "false", + "target": "1753041952799", + "targetHandle": "target", + "type": "custom", + "zIndex": 0, + }, + ], + "nodes": [ + { + "data": {"desc": "", "selected": False, "title": "Start", "type": "start", "variables": []}, + "height": 54, + "id": "1753041723554", + "position": {"x": 
32, "y": 282}, + "positionAbsolute": {"x": 32, "y": 282}, + "selected": False, + "sourcePosition": "right", + "targetPosition": "left", + "type": "custom", + "width": 244, + }, + { + "data": { + "cases": [ + { + "case_id": "true", + "conditions": [ + { + "comparison_operator": "contains", + "id": "5db4103a-7e62-4e71-a0a6-c45ac11c0b3d", + "value": "a", + "varType": "string", + "variable_selector": ["sys", "query"], + } + ], + "id": "true", + "logical_operator": "and", + } + ], + "desc": "", + "selected": False, + "title": "IF/ELSE", + "type": "if-else", + }, + "height": 126, + "id": "1753041730748", + "position": {"x": 368, "y": 282}, + "positionAbsolute": {"x": 368, "y": 282}, + "selected": False, + "sourcePosition": "right", + "targetPosition": "left", + "type": "custom", + "width": 244, + }, + { + "data": { + "answer": "A", + "desc": "", + "selected": False, + "title": "Answer A", + "type": "answer", + "variables": [], + }, + "height": 102, + "id": "answer", + "position": {"x": 746, "y": 282}, + "positionAbsolute": {"x": 746, "y": 282}, + "selected": False, + "sourcePosition": "right", + "targetPosition": "left", + "type": "custom", + "width": 244, + }, + { + "data": { + "answer": "Else", + "desc": "", + "selected": False, + "title": "Answer Else", + "type": "answer", + "variables": [], + }, + "height": 102, + "id": "1753041952799", + "position": {"x": 746, "y": 426}, + "positionAbsolute": {"x": 746, "y": 426}, + "selected": True, + "sourcePosition": "right", + "targetPosition": "left", + "type": "custom", + "width": 244, + }, + ], + "viewport": {"x": -420, "y": -76.5, "zoom": 1}, + } + graph = Graph.init(graph_config) + variable_pool = VariablePool( + system_variables=SystemVariable( + user_id="aaa", + files=[], + query="hello", + conversation_id="abababa", + ), + user_inputs={"uid": "takato"}, + ) + + graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + ) + graph_init_params = GraphInitParams( + tenant_id="111", + app_id="222", + 
workflow_type=WorkflowType.CHAT, + workflow_id="333", + graph_config=graph_config, + user_id="444", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.WEB_APP, + call_depth=0, + max_execution_steps=500, + max_execution_time=1200, + ) + + _IF_ELSE_NODE_ID = "1753041730748" + + def exec_decision_hook(params: DecisionParams) -> ExecutionDecision: + # requires the engine to suspend before the execution + # of If-Else node. + if params.next_node_instance.node_id == _IF_ELSE_NODE_ID: + return ExecutionDecision.SUSPEND + else: + return ExecutionDecision.CONTINUE + + graph_engine = GraphEngine( + graph=graph, + graph_runtime_state=graph_runtime_state, + graph_init_params=graph_init_params, + execution_decision_hook=exec_decision_hook, + ) + events = list(graph_engine.run()) + last_event = events[-1] + assert isinstance(last_event, GraphRunSuspendedEvent) + assert last_event.next_node_id == _IF_ELSE_NODE_ID + state = graph_engine.save() + assert state != "" + + engine2 = GraphEngine.resume( + state=state, + graph=graph, + ) + events = list(engine2.run()) + assert isinstance(events[-1], GraphRunSucceededEvent) + node_run_succeeded_events = [i for i in events if isinstance(i, NodeRunSucceededEvent)] + assert node_run_succeeded_events + start_events = [i for i in node_run_succeeded_events if i.node_id == "1753041723554"] + assert not start_events + ifelse_succeeded_events = [i for i in node_run_succeeded_events if i.node_id == _IF_ELSE_NODE_ID] + assert ifelse_succeeded_events + answer_else_events = [i for i in node_run_succeeded_events if i.node_id == "1753041952799"] + assert answer_else_events + assert answer_else_events[0].route_node_state.node_run_result.outputs == { + "answer": "Else", + "files": ArrayFileSegment(value=[]), + } + + answer_a_events = [i for i in node_run_succeeded_events if i.node_id == "answer"] + assert not answer_a_events From b804f7179f50293b50d0a10d4f261c83478c50f0 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Mon, 21 Jul 2025 14:31:38 +0800 
Subject: [PATCH 08/14] docs(api): Update docs for `generate_string` and `BaseNode` --- api/core/workflow/nodes/base/node.py | 9 +++++++++ api/libs/helper.py | 20 ++++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index b7a5e3eeec..25751601b1 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -17,6 +17,15 @@ logger = logging.getLogger(__name__) class BaseNode: + """BaseNode serves as the foundational class for all node implementations. + + Nodes are allowed to maintain transient states (e.g., `LLMNode` uses the `_file_output` + attribute to track files generated by the LLM). However, these states are not persisted + when the workflow is suspended or resumed. If a node needs its state to be preserved + across workflow suspension and resumption, it should include the relevant state data + in its output. + """ + _node_type: ClassVar[NodeType] def __init__( diff --git a/api/libs/helper.py b/api/libs/helper.py index 00772d530a..c5178d2459 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -181,6 +181,26 @@ def timezone(timezone_string): def generate_string(n): + """ + Generates a cryptographically secure random string of the specified length. + + This function uses a cryptographically secure pseudorandom number generator (CSPRNG) + to create a string composed of ASCII letters (both uppercase and lowercase) and digits. + + Each character in the generated string provides approximately 5.95 bits of entropy + (log2(62)). To ensure a minimum of 128 bits of entropy for security purposes, the + length of the string (`n`) should be at least 22 characters. + + Args: + n (int): The length of the random string to generate. For secure usage, + `n` should be 22 or greater. + + Returns: + str: A random string of length `n` composed of ASCII letters and digits. 
+ + Note: + This function is suitable for generating credentials or other secure tokens. + """ letters_digits = string.ascii_letters + string.digits result = "" for i in range(n): From ff7f1c1f173bc3bcd0ce1a0bd8b743b651bfc57b Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Mon, 21 Jul 2025 14:35:29 +0800 Subject: [PATCH 09/14] refactor(api): Rename `continuation_node_id` to `next_node_id` in `WorkflowSuspension`. This aligns with the naming convension of the codebase. --- api/core/workflow/entities/workflow_suspension.py | 2 +- ...7_2020-1091956b9ee0_add_workflowsuspension_model_add_.py | 2 +- api/models/workflow.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/api/core/workflow/entities/workflow_suspension.py b/api/core/workflow/entities/workflow_suspension.py index 9304e6cadd..61da6a6a8e 100644 --- a/api/core/workflow/entities/workflow_suspension.py +++ b/api/core/workflow/entities/workflow_suspension.py @@ -19,7 +19,7 @@ class WorkflowSuspension(BaseModel): workflow_id: str - continuation_node_id: str + next_node_id: str state: str diff --git a/api/migrations/versions/2025_07_17_2020-1091956b9ee0_add_workflowsuspension_model_add_.py b/api/migrations/versions/2025_07_17_2020-1091956b9ee0_add_workflowsuspension_model_add_.py index 49182e6474..b5945b4706 100644 --- a/api/migrations/versions/2025_07_17_2020-1091956b9ee0_add_workflowsuspension_model_add_.py +++ b/api/migrations/versions/2025_07_17_2020-1091956b9ee0_add_workflowsuspension_model_add_.py @@ -28,7 +28,7 @@ def upgrade(): sa.Column('workflow_id', models.types.StringUUID(), nullable=False), sa.Column('workflow_run_id', models.types.StringUUID(), nullable=False), sa.Column('resumed_at', sa.DateTime(), nullable=True), - sa.Column('continuation_node_id', sa.String(length=255), nullable=False), + sa.Column('next_node_id', sa.String(length=255), nullable=False), sa.Column('state_version', sa.String(length=20), nullable=False), sa.Column('state', sa.Text(), nullable=False), 
sa.Column('inputs', sa.Text(), nullable=True), diff --git a/api/models/workflow.py b/api/models/workflow.py index 8c153af741..3bdb6dbde6 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1340,12 +1340,12 @@ class WorkflowSuspension(Base): default=sa.null, ) - # `continuation_node_id` specifies the next node to execute when the workflow resumes. + # `next_node_id` specifies the next node to execute when the workflow resumes. # # Although this information is embedded within the `state` field, it is extracted # into a separate field to facilitate debugging and data analysis. - continuation_node_id: Mapped[str] = mapped_column( - sa.String(length=255), + next_node_id: Mapped[str] = mapped_column( + __name_pos=sa.String(length=255), nullable=False, ) From 5de663d52aefbdcf51c715f9a532e857f639d3f2 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Mon, 21 Jul 2025 15:33:12 +0800 Subject: [PATCH 10/14] test(api): fix broken tests --- api/core/workflow/workflow_entry.py | 10 +++++----- .../core/workflow/nodes/test_continue_on_error.py | 10 +++++++--- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index a1b6f69289..605b6eaf6d 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -4,7 +4,7 @@ from collections.abc import Generator, Mapping, Sequence from typing import Any, Optional, cast from configs import dify_config -from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError +from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.file.models import File from core.workflow.callbacks import WorkflowCallback @@ -199,10 +199,10 @@ class WorkflowEntry: "error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s", workflow.id, node_instance.id, - node_instance.node_type, + node_instance.type_, node_instance.version(), ) - 
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) + raise WorkflowNodeRunFailedError(node=node_instance, err_msg=str(e)) return node_instance, generator @classmethod @@ -307,10 +307,10 @@ class WorkflowEntry: logger.exception( "error while running node_instance, node_id=%s, type=%s, version=%s", node_instance.id, - node_instance.node_type, + node_instance.type_, node_instance.version(), ) - raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) + raise WorkflowNodeRunFailedError(node=node_instance, err_msg=str(e)) @staticmethod def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None: diff --git a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py index 499457f9cc..367c2c6596 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py @@ -11,6 +11,7 @@ from core.workflow.graph_engine.entities.event import ( NodeRunStreamChunkEvent, ) from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.graph_engine import GraphEngine from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent @@ -178,7 +179,7 @@ class ContinueOnErrorTestHelper: variable_pool=variable_pool, ) - return GraphEngine( + graph_init_params = GraphInitParams( tenant_id="111", app_id="222", workflow_type=WorkflowType.CHAT, @@ -188,11 +189,14 @@ class ContinueOnErrorTestHelper: user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.WEB_APP, call_depth=0, - graph=graph, - graph_runtime_state=graph_runtime_state, max_execution_steps=500, max_execution_time=1200, ) + return GraphEngine( + graph=graph, + 
graph_runtime_state=graph_runtime_state, + graph_init_params=graph_init_params, + ) DEFAULT_VALUE_EDGE = [ From 3af1a6d8c44f111d1ad7de7f0dd4f1dd8374ae94 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Mon, 21 Jul 2025 15:55:49 +0800 Subject: [PATCH 11/14] refactor(api): rename ExecutionDecisionHook to CommandSource Use structured types for commands. --- .../workflow/graph_engine/command_source.py | 69 +++++++++++++++++++ .../graph_engine/execution_decision.py | 25 ------- .../workflow/graph_engine/graph_engine.py | 22 +++--- .../graph_engine/test_graph_engine.py | 14 ++-- 4 files changed, 87 insertions(+), 43 deletions(-) create mode 100644 api/core/workflow/graph_engine/command_source.py delete mode 100644 api/core/workflow/graph_engine/execution_decision.py diff --git a/api/core/workflow/graph_engine/command_source.py b/api/core/workflow/graph_engine/command_source.py new file mode 100644 index 0000000000..1cafcf20f9 --- /dev/null +++ b/api/core/workflow/graph_engine/command_source.py @@ -0,0 +1,69 @@ +import abc +from collections.abc import Callable +from dataclasses import dataclass +from enum import StrEnum +from typing import Annotated, TypeAlias, final + +from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator + +from core.workflow.nodes.base import BaseNode + + +@dataclass(frozen=True) +class CommandParams: + # `next_node_instance` is the instance of the next node to run. 
+ next_node: BaseNode + + +class _CommandTag(StrEnum): + SUSPEND = "suspend" + STOP = "stop" + CONTINUE = "continue" + + +class Command(BaseModel, abc.ABC): + model_config = ConfigDict(frozen=True) + + tag: _CommandTag + + @field_validator("tag") + @classmethod + def validate_value_type(cls, value): + if value != cls.model_fields["tag"].default: + raise ValueError("Cannot modify 'tag'") + return value + + +@final +class StopCommand(Command): + tag: _CommandTag = _CommandTag.STOP + + +@final +class SuspendCommand(Command): + tag: _CommandTag = _CommandTag.SUSPEND + + +@final +class ContinueCommand(Command): + tag: _CommandTag = _CommandTag.CONTINUE + + +def _get_command_tag(command: Command): + return command.tag + + +CommandTypes: TypeAlias = Annotated[ + ( + Annotated[StopCommand, Tag(_CommandTag.STOP)] + | Annotated[SuspendCommand, Tag(_CommandTag.SUSPEND)] + | Annotated[ContinueCommand, Tag(_CommandTag.CONTINUE)] + ), + Discriminator(_get_command_tag), +] + +# `CommandSource` is a callable that takes a single argument of type `CommandParams` and +# returns a `Command` object to the engine, indicating whether the graph engine should suspend, continue, or stop. +# +# It must not modify the data inside `CommandParams`, including any attributes within its fields. 
+CommandSource: TypeAlias = Callable[[CommandParams], CommandTypes] diff --git a/api/core/workflow/graph_engine/execution_decision.py b/api/core/workflow/graph_engine/execution_decision.py deleted file mode 100644 index bd65819a40..0000000000 --- a/api/core/workflow/graph_engine/execution_decision.py +++ /dev/null @@ -1,25 +0,0 @@ -from collections.abc import Callable -from dataclasses import dataclass -from enum import StrEnum -from typing import TypeAlias - -from core.workflow.nodes.base import BaseNode - - -class ExecutionDecision(StrEnum): - SUSPEND = "suspend" - STOP = "stop" - CONTINUE = "continue" - - -@dataclass(frozen=True) -class DecisionParams: - # `next_node_instance` is the instance of the next node to run. - next_node_instance: BaseNode - - -# `ExecutionDecisionHook` is a callable that takes a single argument of type `DecisionParams` and -# returns an `ExecutionDecision` indicating whether the graph engine should suspend, continue, or stop. -# -# It must not modify the data inside `DecisionParams`, including any attributes within its fields. 
-ExecutionDecisionHook: TypeAlias = Callable[[DecisionParams], ExecutionDecision] diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index f3a11c9a98..b89aee7712 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -57,7 +57,7 @@ from core.workflow.utils import variable_utils from libs.flask_utils import preserve_flask_contexts from models.workflow import WorkflowType -from .execution_decision import DecisionParams, ExecutionDecision, ExecutionDecisionHook +from .command_source import Command, CommandParams, CommandSource, ContinueCommand, StopCommand, SuspendCommand logger = logging.getLogger(__name__) @@ -89,8 +89,8 @@ class GraphEngineThreadPool(ThreadPoolExecutor): raise ValueError(f"Max submit count {self.max_submit_count} of workflow thread pool reached.") -def _default_hook(params: DecisionParams) -> ExecutionDecision: - return ExecutionDecision.CONTINUE +def _default_source(params: CommandParams) -> Command: + return ContinueCommand() class GraphEngine: @@ -102,7 +102,7 @@ class GraphEngine: graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState, thread_pool_id: Optional[str] = None, - execution_decision_hook: ExecutionDecisionHook = _default_hook, + command_source: CommandSource = _default_source, ) -> None: """Create a graph from the given state. @@ -132,7 +132,7 @@ class GraphEngine: self.graph_runtime_state = graph_runtime_state - self._exec_decision_hook = execution_decision_hook + self._command_source = command_source def run(self) -> Generator[GraphEngineEvent, None, None]: # trigger graph run start event @@ -279,15 +279,15 @@ class GraphEngine: # # Note: Suspension is not allowed while the graph engine is running in parallel mode. 
if in_parallel_id is None: - hook_result = self._exec_decision_hook(DecisionParams(next_node_instance=node)) - if hook_result == ExecutionDecision.SUSPEND: + command = self._command_source(CommandParams(next_node=node)) + if isinstance(command, SuspendCommand): self.graph_runtime_state.record_suspend_state(next_node_id) yield GraphRunSuspendedEvent(next_node_id=next_node_id) return - elif hook_result == ExecutionDecision.STOP: + elif isinstance(command, StopCommand): # TODO: STOP the execution of worklow. return - elif hook_result == ExecutionDecision.CONTINUE: + elif isinstance(command, ContinueCommand): pass else: raise AssertionError("unreachable statement.") @@ -955,7 +955,7 @@ class GraphEngine: cls, state: str, graph: Graph, - execution_decision_hook: ExecutionDecisionHook = _default_hook, + command_source: CommandSource = _default_source, ) -> "GraphEngine": """`resume` continues a suspended execution.""" state_ = _GraphEngineState.model_validate_json(state) @@ -963,7 +963,7 @@ class GraphEngine: graph=graph, graph_init_params=state_.init_params, graph_runtime_state=state_.graph_runtime_state, - execution_decision_hook=execution_decision_hook, + command_source=command_source, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index ecbf53cf80..92e46d4abd 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -8,6 +8,7 @@ from core.variables.segments import ArrayFileSegment from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from core.workflow.graph_engine.command_source import CommandParams, CommandTypes, ContinueCommand, SuspendCommand from 
core.workflow.graph_engine.entities.event import ( BaseNodeEvent, GraphRunFailedEvent, @@ -23,8 +24,7 @@ from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState -from core.workflow.graph_engine.execution_decision import DecisionParams -from core.workflow.graph_engine.graph_engine import ExecutionDecision, GraphEngine +from core.workflow.graph_engine.graph_engine import GraphEngine from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent from core.workflow.nodes.llm.node import LLMNode @@ -1065,19 +1065,19 @@ def test_suspend_and_resume(): _IF_ELSE_NODE_ID = "1753041730748" - def exec_decision_hook(params: DecisionParams) -> ExecutionDecision: + def command_source(params: CommandParams) -> CommandTypes: # requires the engine to suspend before the execution # of If-Else node. 
- if params.next_node_instance.node_id == _IF_ELSE_NODE_ID: - return ExecutionDecision.SUSPEND + if params.next_node.node_id == _IF_ELSE_NODE_ID: + return SuspendCommand() else: - return ExecutionDecision.CONTINUE + return ContinueCommand() graph_engine = GraphEngine( graph=graph, graph_runtime_state=graph_runtime_state, graph_init_params=graph_init_params, - execution_decision_hook=exec_decision_hook, + command_source=command_source, ) events = list(graph_engine.run()) last_event = events[-1] From d50bdc1d70c73a32707d96f7f38cce4eea8bdb3e Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Mon, 21 Jul 2025 16:07:25 +0800 Subject: [PATCH 12/14] chore(api): rebase migrations --- ..._07_21_0935-1a83934ad6d1_adjust_mcp_string_fields_length.py} | 2 +- ...7_21_1605-1091956b9ee0_add_workflowsuspension_model_add_.py} | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) rename api/migrations/versions/{2025_07_21_0935-1a83934ad6d1_update_models.py => 2025_07_21_0935-1a83934ad6d1_adjust_mcp_string_fields_length.py} (96%) rename api/migrations/versions/{2025_07_17_2020-1091956b9ee0_add_workflowsuspension_model_add_.py => 2025_07_21_1605-1091956b9ee0_add_workflowsuspension_model_add_.py} (98%) diff --git a/api/migrations/versions/2025_07_21_0935-1a83934ad6d1_update_models.py b/api/migrations/versions/2025_07_21_0935-1a83934ad6d1_adjust_mcp_string_fields_length.py similarity index 96% rename from api/migrations/versions/2025_07_21_0935-1a83934ad6d1_update_models.py rename to api/migrations/versions/2025_07_21_0935-1a83934ad6d1_adjust_mcp_string_fields_length.py index 3bdbafda7c..4e2c2169ea 100644 --- a/api/migrations/versions/2025_07_21_0935-1a83934ad6d1_update_models.py +++ b/api/migrations/versions/2025_07_21_0935-1a83934ad6d1_adjust_mcp_string_fields_length.py @@ -1,4 +1,4 @@ -"""update models +"""adjust length for mcp tool name and server identifiers Revision ID: 1a83934ad6d1 Revises: 71f5020c6470 diff --git 
a/api/migrations/versions/2025_07_17_2020-1091956b9ee0_add_workflowsuspension_model_add_.py b/api/migrations/versions/2025_07_21_1605-1091956b9ee0_add_workflowsuspension_model_add_.py similarity index 98% rename from api/migrations/versions/2025_07_17_2020-1091956b9ee0_add_workflowsuspension_model_add_.py rename to api/migrations/versions/2025_07_21_1605-1091956b9ee0_add_workflowsuspension_model_add_.py index b5945b4706..b6140ec358 100644 --- a/api/migrations/versions/2025_07_17_2020-1091956b9ee0_add_workflowsuspension_model_add_.py +++ b/api/migrations/versions/2025_07_21_1605-1091956b9ee0_add_workflowsuspension_model_add_.py @@ -12,7 +12,7 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. revision = '1091956b9ee0' -down_revision = '1c9ba48be8e4' +down_revision = '1a83934ad6d1' branch_labels = None depends_on = None From 71aa34dcce385bf94aad0c56df2cb8d8210951fd Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Mon, 21 Jul 2025 16:35:00 +0800 Subject: [PATCH 13/14] chore(api): fix mypy violations --- api/core/workflow/graph_engine/command_source.py | 12 +++++++----- api/core/workflow/graph_engine/graph_engine.py | 11 +++++++++-- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/api/core/workflow/graph_engine/command_source.py b/api/core/workflow/graph_engine/command_source.py index 1cafcf20f9..2d0d4b8211 100644 --- a/api/core/workflow/graph_engine/command_source.py +++ b/api/core/workflow/graph_engine/command_source.py @@ -21,7 +21,9 @@ class _CommandTag(StrEnum): CONTINUE = "continue" -class Command(BaseModel, abc.ABC): +# Note: Avoid using the `_Command` class directly. +# Instead, use `CommandTypes` for type annotations. 
+class _Command(BaseModel, abc.ABC): model_config = ConfigDict(frozen=True) tag: _CommandTag @@ -35,21 +37,21 @@ class Command(BaseModel, abc.ABC): @final -class StopCommand(Command): +class StopCommand(_Command): tag: _CommandTag = _CommandTag.STOP @final -class SuspendCommand(Command): +class SuspendCommand(_Command): tag: _CommandTag = _CommandTag.SUSPEND @final -class ContinueCommand(Command): +class ContinueCommand(_Command): tag: _CommandTag = _CommandTag.CONTINUE -def _get_command_tag(command: Command): +def _get_command_tag(command: _Command): return command.tag diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index b89aee7712..d930b6a923 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -57,7 +57,14 @@ from core.workflow.utils import variable_utils from libs.flask_utils import preserve_flask_contexts from models.workflow import WorkflowType -from .command_source import Command, CommandParams, CommandSource, ContinueCommand, StopCommand, SuspendCommand +from .command_source import ( + CommandParams, + CommandSource, + CommandTypes, + ContinueCommand, + StopCommand, + SuspendCommand, +) logger = logging.getLogger(__name__) @@ -89,7 +96,7 @@ class GraphEngineThreadPool(ThreadPoolExecutor): raise ValueError(f"Max submit count {self.max_submit_count} of workflow thread pool reached.") -def _default_source(params: CommandParams) -> Command: +def _default_source(_: CommandParams) -> CommandTypes: return ContinueCommand() From 96f749e16235641f23d8777def71d8ff9d8949fa Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Mon, 21 Jul 2025 20:42:19 +0800 Subject: [PATCH 14/14] Update api/core/workflow/graph_engine/_engine_utils.py Co-authored-by: -LAN- --- api/core/workflow/graph_engine/_engine_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/workflow/graph_engine/_engine_utils.py 
b/api/core/workflow/graph_engine/_engine_utils.py index 28898268fe..6e915cf256 100644 --- a/api/core/workflow/graph_engine/_engine_utils.py +++ b/api/core/workflow/graph_engine/_engine_utils.py @@ -1,7 +1,7 @@ import time -def get_timestamp() -> float: +def get_current_timestamp() -> float: """Retrieve a timestamp as a float point numer representing the number of seconds since the Unix epoch.