pull/22621/merge
QuantumGhost 7 months ago committed by GitHub
commit 9e0fc64e41

@@ -29,6 +29,7 @@ class QueueEvent(StrEnum):
WORKFLOW_SUCCEEDED = "workflow_succeeded"
WORKFLOW_FAILED = "workflow_failed"
WORKFLOW_PARTIAL_SUCCEEDED = "workflow_partial_succeeded"
+WORKFLOW_SUSPENDED = "workflow_suspended"
ITERATION_START = "iteration_start"
ITERATION_NEXT = "iteration_next"
ITERATION_COMPLETED = "iteration_completed"
@@ -326,6 +327,13 @@ class QueueWorkflowStartedEvent(AppQueueEvent):
graph_runtime_state: GraphRuntimeState
+class QueueWorkflowSuspendedEvent(AppQueueEvent):
+event: QueueEvent = QueueEvent.WORKFLOW_SUSPENDED
+# `next_node_id` records the next node to execute after the workflow resumes.
+next_node_id: str
class QueueWorkflowSucceededEvent(AppQueueEvent):
"""
QueueWorkflowSucceededEvent entity

@@ -23,12 +23,109 @@ class WorkflowType(StrEnum):
class WorkflowExecutionStatus(StrEnum):
+# State diagram for the workflow status (Mermaid):
+#
+# ---
+# title: State diagram for Workflow run state
+# ---
+# stateDiagram-v2
+# scheduled: Scheduled
+# running: Running
+# succeeded: Succeeded
+# failed: Failed
+# partial_succeeded: Partial Succeeded
+# suspended: Suspended
+# stopped: Stopped
+#
+# [*] --> scheduled
+# scheduled --> running: Start execution
+# running --> suspended: Human input required
+# suspended --> running: Human input added
+# suspended --> stopped: User stops execution
+# running --> succeeded: Execution finishes without any error
+# running --> failed: Execution finishes with errors
+# running --> stopped: User stops execution
+# running --> partial_succeeded: Some errors occurred and were handled during execution
+#
+# scheduled --> stopped: User stops execution
+#
+# succeeded --> [*]
+# failed --> [*]
+# partial_succeeded --> [*]
+# stopped --> [*]
+# `SCHEDULED` means that the workflow is scheduled to run, but has not
+# started running yet (e.g., due to worker saturation).
+SCHEDULED = "scheduled"
+# `RUNNING` means the workflow is executing.
RUNNING = "running"
+# `SUCCEEDED` means the execution of the workflow succeeded without any error.
SUCCEEDED = "succeeded"
+# `FAILED` means the execution of the workflow failed with one or more errors.
FAILED = "failed"
+# `STOPPED` means the execution of the workflow was stopped, either manually
+# by the user, or automatically by the Dify application (e.g., by the
+# moderation mechanism).
STOPPED = "stopped"
+# `PARTIAL_SUCCEEDED` indicates that some errors occurred during the workflow
+# execution, but they were successfully handled (e.g., by using an error
+# strategy such as "fail branch" or "default value").
PARTIAL_SUCCEEDED = "partial-succeeded"
+# `SUSPENDED` indicates that the workflow execution is temporarily paused
+# (e.g., awaiting human input) and is expected to resume later.
+SUSPENDED = "suspended"
+def is_ended(self) -> bool:
+return self in _END_STATE
+_END_STATE = frozenset(
+[
+WorkflowExecutionStatus.SUCCEEDED,
+WorkflowExecutionStatus.FAILED,
+WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
+WorkflowExecutionStatus.STOPPED,
+]
)
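As a quick sanity check (a sketch, not part of this diff): only the four states in `_END_STATE` report ended, while `SUSPENDED` may still resume.

    assert WorkflowExecutionStatus.STOPPED.is_ended()
    assert WorkflowExecutionStatus.FAILED.is_ended()
    assert not WorkflowExecutionStatus.SUSPENDED.is_ended()  # may still resume
    assert not WorkflowExecutionStatus.RUNNING.is_ended()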
class WorkflowExecution(BaseModel):
"""

@@ -0,0 +1,28 @@
from enum import StrEnum
from uuid import UUID
from pydantic import BaseModel, Field
from libs.uuid_utils import uuidv7
class StateVersion(StrEnum):
# `V1` is `GraphRuntimeState` serialized as JSON by dumping with Pydantic.
V1 = "v1"
class WorkflowSuspension(BaseModel):
id: UUID = Field(default_factory=uuidv7)
# Corresponds to `WorkflowExecution.id_`.
execution_id: str
workflow_id: str
next_node_id: str
state: str
state_version: StateVersion = StateVersion.V1
inputs: str
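A minimal construction sketch, assuming a `runtime_state: GraphRuntimeState` is at hand and using placeholder ids; per `StateVersion.V1`, `state` is simply the Pydantic JSON dump:

    suspension = WorkflowSuspension(
        execution_id="0198a7e2-...",  # placeholder; corresponds to WorkflowExecution.id_
        workflow_id="0198a7e2-...",   # placeholder
        next_node_id="human-review",  # node to run on resume
        state=runtime_state.model_dump_json(),  # V1 serialization
        inputs="{}",
    )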

@@ -0,0 +1,15 @@
import time
def get_timestamp() -> float:
"""Retrieve a timestamp as a floating-point number representing the number of seconds
since the Unix epoch.
This function is primarily used to measure the execution time of the workflow engine.
Since workflow execution may be paused and resumed on a different machine,
`time.perf_counter` cannot be used as it is inconsistent across machines.
To address this, the function uses the wall clock as the time source.
However, it assumes that the clocks of all servers are properly synchronized.
"""
return time.time()

@@ -0,0 +1,71 @@
import abc
from collections.abc import Callable
from dataclasses import dataclass
from enum import StrEnum
from typing import Annotated, TypeAlias, final
from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator
from core.workflow.nodes.base import BaseNode
@dataclass(frozen=True)
class CommandParams:
# `next_node` is the instance of the next node to run.
next_node: BaseNode
class _CommandTag(StrEnum):
SUSPEND = "suspend"
STOP = "stop"
CONTINUE = "continue"
# Note: Avoid using the `_Command` class directly.
# Instead, use `CommandTypes` for type annotations.
class _Command(BaseModel, abc.ABC):
model_config = ConfigDict(frozen=True)
tag: _CommandTag
@field_validator("tag")
@classmethod
def validate_value_type(cls, value):
if value != cls.model_fields["tag"].default:
raise ValueError("Cannot modify 'tag'")
return value
@final
class StopCommand(_Command):
tag: _CommandTag = _CommandTag.STOP
@final
class SuspendCommand(_Command):
tag: _CommandTag = _CommandTag.SUSPEND
@final
class ContinueCommand(_Command):
tag: _CommandTag = _CommandTag.CONTINUE
def _get_command_tag(command: _Command):
return command.tag
CommandTypes: TypeAlias = Annotated[
(
Annotated[StopCommand, Tag(_CommandTag.STOP)]
| Annotated[SuspendCommand, Tag(_CommandTag.SUSPEND)]
| Annotated[ContinueCommand, Tag(_CommandTag.CONTINUE)]
),
Discriminator(_get_command_tag),
]
# `CommandSource` is a callable that takes a single argument of type `CommandParams` and
# returns a `Command` object to the engine, indicating whether the graph engine should suspend, continue, or stop.
#
# It must not modify the data inside `CommandParams`, including any attributes within its fields.
CommandSource: TypeAlias = Callable[[CommandParams], CommandTypes]
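A minimal sketch of a conforming command source (`make_suspend_at` is an illustrative helper, not part of this diff): it suspends right before a given node runs and continues otherwise.

    def make_suspend_at(node_id: str) -> CommandSource:
        def source(params: CommandParams) -> CommandTypes:
            # Read-only inspection of the next node, per the contract above.
            if params.next_node.id == node_id:
                return SuspendCommand()
            return ContinueCommand()
        return source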

@@ -43,6 +43,10 @@ class GraphRunPartialSucceededEvent(BaseGraphEvent):
outputs: Optional[dict[str, Any]] = None
+class GraphRunSuspendedEvent(BaseGraphEvent):
+next_node_id: str = Field(..., description="the next node id to execute when resumed.")
###########################################
# Node Events
###########################################

@@ -1,14 +1,25 @@
from collections.abc import Mapping
from typing import Any
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, PositiveInt
+from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from models.enums import UserFrom
from models.workflow import WorkflowType
class GraphInitParams(BaseModel):
+"""GraphInitParams encapsulates the configuration and contextual information
+that remain constant throughout a single execution of the graph engine.
+A single execution lasts until the run reaches its conclusion. For instance,
+if a workflow is suspended and later resumed, it is still regarded as a single
+execution, not two.
+For the state diagram of workflow execution, refer to `WorkflowExecutionStatus`.
+"""
# init params
tenant_id: str = Field(..., description="tenant / workspace id")
app_id: str = Field(..., description="app id")
@@ -19,3 +30,10 @@ class GraphInitParams(BaseModel):
user_from: UserFrom = Field(..., description="user from, account or end-user")
invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger")
call_depth: int = Field(..., description="call depth")
+# `max_execution_steps` records the maximum number of steps allowed during the execution of a workflow.
+max_execution_steps: PositiveInt = Field(
+default=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, description="max execution steps"
+)
+# `max_execution_time` records the max execution time for the workflow, measured in seconds.
+max_execution_time: PositiveInt = Field(
+default=dify_config.WORKFLOW_MAX_EXECUTION_TIME, description="max execution time in seconds"
+)

@@ -1,18 +1,32 @@
from typing import Any
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, PrivateAttr
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.graph_engine._engine_utils import get_timestamp
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
+_SECOND_TO_US = 1_000_000
class GraphRuntimeState(BaseModel):
+"""`GraphRuntimeState` encapsulates the runtime state of workflow execution,
+including scheduling details, variable values, and timing information.
+Values that are initialized prior to workflow execution and remain constant
+throughout the execution should be part of `GraphInitParams` instead.
+"""
variable_pool: VariablePool = Field(..., description="variable pool")
"""variable pool"""
-start_at: float = Field(..., description="start time")
-"""start time"""
+# The `start_at` field records the execution start time of the workflow.
+#
+# This field is automatically generated, and its value or interpretation may evolve.
+# Avoid manually setting this field to ensure compatibility with future updates.
+start_at: float = Field(description="start time", default_factory=get_timestamp)
total_tokens: int = 0
"""total tokens"""
llm_usage: LLMUsage = LLMUsage.empty_usage()
@@ -29,3 +43,28 @@ class GraphRuntimeState(BaseModel):
node_run_state: RuntimeRouteState = RuntimeRouteState()
"""node run state"""
+# `execution_time_us` tracks the total execution time of the workflow in microseconds.
+# Time spent in suspension is excluded from this calculation.
+#
+# This field is used to persist the time already spent when suspending a workflow.
+execution_time_us: int = 0
+# `_last_execution_started_at` records the timestamp of the most recent execution start.
+# It is re-initialized when the workflow resumes from a suspended state.
+_last_execution_started_at: float = PrivateAttr(default_factory=get_timestamp)
+def is_timed_out(self, max_execution_time_seconds: int) -> bool:
+"""Checks if the workflow execution has exceeded the specified `max_execution_time_seconds`."""
+remaining_time_us = max_execution_time_seconds * _SECOND_TO_US - self.execution_time_us
+if remaining_time_us <= 0:
+return True
+return int((get_timestamp() - self._last_execution_started_at) * _SECOND_TO_US) > remaining_time_us
+def record_suspend_state(self, next_node_id: str):
+"""Record the time already spent executing the workflow.
+This function should be called when suspending the workflow.
+"""
+self.execution_time_us += int((get_timestamp() - self._last_execution_started_at) * _SECOND_TO_US)
+self.node_run_state.next_node_id = next_node_id
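To make the bookkeeping concrete, a sketch under assumed timings: a run executes for roughly 3 s, suspends, and accumulates wall-clock time only while actually executing.

    state = GraphRuntimeState(variable_pool=VariablePool.empty())
    # ... roughly 3 s of execution elapse ...
    state.record_suspend_state(next_node_id="human-review")
    # execution_time_us is now ~3_000_000; time spent suspended is not counted,
    # since a resumed instance re-initializes _last_execution_started_at.
    assert not state.is_timed_out(max_execution_time_seconds=600)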

@@ -1,9 +1,9 @@
import uuid
from datetime import UTC, datetime
from enum import Enum
-from typing import Optional
+from typing import Any, Optional
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, PrivateAttr
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@@ -44,6 +44,8 @@ class RouteNodeState(BaseModel):
paused_by: Optional[str] = None
"""paused by"""
+# The `index` is used to record the execution order for a given node.
+# Nodes executed earlier get smaller `index` values.
index: int = 1
def set_finished(self, run_result: NodeRunResult) -> None:
@@ -79,10 +81,27 @@ class RuntimeRouteState(BaseModel):
default_factory=dict, description="graph state routes (source_node_state_id: target_node_state_id)"
)
+# A mapping from route_node_state_id to its routing state.
node_state_mapping: dict[str, RouteNodeState] = Field(
default_factory=dict, description="node state mapping (route_node_state_id: route_node_state)"
)
+next_node_id: Optional[str] = Field(
+default=None, description="The next node id to run when resumed from suspension."
+)
+# If `previous_node_state_id` is not `None`, then the corresponding state is
+# present in the dict `node_state_mapping`.
+previous_node_state_id: Optional[str] = Field(default=None, description="The state id of the last executed node.")
+# `_state_by_id` serves as a mapping from the unique identifier (`id`) of each `RouteNodeState`
+# instance to the corresponding `RouteNodeState` object itself.
+_state_by_id: dict[str, RouteNodeState] = PrivateAttr(default_factory=dict)
+def model_post_init(self, context: Any) -> None:
+super().model_post_init(context)
+self._state_by_id = {v.id: v for v in self.node_state_mapping.values()}
def create_node_state(self, node_id: str) -> RouteNodeState:
"""
Create node state
@@ -91,6 +110,7 @@ class RuntimeRouteState(BaseModel):
"""
state = RouteNodeState(node_id=node_id, start_at=datetime.now(UTC).replace(tzinfo=None))
self.node_state_mapping[state.id] = state
+self._state_by_id[state.id] = state
return state
def add_route(self, source_node_state_id: str, target_node_state_id: str) -> None:
@@ -115,3 +135,18 @@
return [
self.node_state_mapping[target_state_id] for target_state_id in self.routes.get(source_node_state_id, [])
]
+def get_previous_route_node_state(self) -> RouteNodeState | None:
+if self.previous_node_state_id is None:
+return None
+return self._state_by_id[self.previous_node_state_id]
+@property
+def previous_node_id(self):
+if self.previous_node_state_id is None:
+return None
+state = self._state_by_id[self.previous_node_state_id]
+return state.node_id

@@ -3,17 +3,18 @@ import logging
import queue
import time
import uuid
-from collections.abc import Generator, Mapping
+from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor, wait
from copy import copy, deepcopy
+from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Any, Optional, cast
from flask import Flask, current_app
+from pydantic import BaseModel
from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
-from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@@ -27,6 +28,7 @@ from core.workflow.graph_engine.entities.event import (
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
+GraphRunSuspendedEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
@@ -53,9 +55,17 @@ from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.utils import variable_utils
from libs.flask_utils import preserve_flask_contexts
-from models.enums import UserFrom
from models.workflow import WorkflowType
+from .command_source import (
+CommandParams,
+CommandSource,
+CommandTypes,
+ContinueCommand,
+StopCommand,
+SuspendCommand,
+)
logger = logging.getLogger(__name__)
@@ -86,26 +96,25 @@ class GraphEngineThreadPool(ThreadPoolExecutor):
raise ValueError(f"Max submit count {self.max_submit_count} of workflow thread pool reached.")
+def _default_source(_: CommandParams) -> CommandTypes:
+return ContinueCommand()
class GraphEngine:
workflow_thread_pool_mapping: dict[str, GraphEngineThreadPool] = {}
def __init__(
self,
-tenant_id: str,
-app_id: str,
-workflow_type: WorkflowType,
-workflow_id: str,
-user_id: str,
-user_from: UserFrom,
-invoke_from: InvokeFrom,
-call_depth: int,
graph: Graph,
-graph_config: Mapping[str, Any],
+graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
-max_execution_steps: int,
-max_execution_time: int,
thread_pool_id: Optional[str] = None,
+command_source: CommandSource = _default_source,
) -> None:
+"""Create a graph engine from the given init params and runtime state."""
thread_pool_max_submit_count = dify_config.MAX_SUBMIT_COUNT
thread_pool_max_workers = 10
@@ -126,22 +135,11 @@ class GraphEngine:
GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] = self.thread_pool
self.graph = graph
-self.init_params = GraphInitParams(
-tenant_id=tenant_id,
-app_id=app_id,
-workflow_type=workflow_type,
-workflow_id=workflow_id,
-graph_config=graph_config,
-user_id=user_id,
-user_from=user_from,
-invoke_from=invoke_from,
-call_depth=call_depth,
-)
+self.init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
-self.max_execution_steps = max_execution_steps
-self.max_execution_time = max_execution_time
+self._command_source = command_source
def run(self) -> Generator[GraphEngineEvent, None, None]:
# trigger graph run start event
@@ -160,12 +158,18 @@
)
# run graph
+next_node_to_run = self.graph.root_node_id
+if (next_node_id := self.graph_runtime_state.node_run_state.next_node_id) is not None:
+next_node_to_run = next_node_id
generator = stream_processor.process(
-self._run(start_node_id=self.graph.root_node_id, handle_exceptions=handle_exceptions)
+self._run(start_node_id=next_node_to_run, handle_exceptions=handle_exceptions)
)
for item in generator:
try:
yield item
+if isinstance(item, GraphRunSuspendedEvent):
+return
if isinstance(item, NodeRunFailedEvent):
yield GraphRunFailedEvent(
error=item.route_node_state.failed_reason or "Unknown error.",
@@ -229,22 +233,22 @@
parent_parallel_start_node_id: Optional[str] = None,
handle_exceptions: list[str] = [],
) -> Generator[GraphEngineEvent, None, None]:
+# Hint: the `_run` method is used both when running the main graph
+# and when running parallel branches.
parallel_start_node_id = None
if in_parallel_id:
parallel_start_node_id = start_node_id
next_node_id = start_node_id
previous_route_node_state: Optional[RouteNodeState] = None
while True:
# max steps reached
-if self.graph_runtime_state.node_run_steps > self.max_execution_steps:
-raise GraphRunFailedError("Max steps {} reached.".format(self.max_execution_steps))
+if self.graph_runtime_state.node_run_steps > self.init_params.max_execution_steps:
+raise GraphRunFailedError("Max steps {} reached.".format(self.init_params.max_execution_steps))
-# or max execution time reached
-if self._is_timed_out(
-start_at=self.graph_runtime_state.start_at, max_execution_time=self.max_execution_time
-):
-raise GraphRunFailedError("Max execution time {}s reached.".format(self.max_execution_time))
+if self.graph_runtime_state.is_timed_out(self.init_params.max_execution_time):
+raise GraphRunFailedError("Max execution time {}s reached.".format(self.init_params.max_execution_time))
# init route node state
route_node_state = self.graph_runtime_state.node_run_state.create_node_state(node_id=next_node_id)
@@ -277,6 +281,23 @@
thread_pool_id=self.thread_pool_id,
)
node.init_node_data(node_config.get("data", {}))
+# Determine if the execution should be suspended or stopped at this point.
+# If so, yield the corresponding event.
+#
+# Note: Suspension is not allowed while the graph engine is running in parallel mode.
+if in_parallel_id is None:
+command = self._command_source(CommandParams(next_node=node))
+if isinstance(command, SuspendCommand):
+self.graph_runtime_state.record_suspend_state(next_node_id)
+yield GraphRunSuspendedEvent(next_node_id=next_node_id)
+return
+elif isinstance(command, StopCommand):
+# TODO: Stop the execution of the workflow.
+return
+elif isinstance(command, ContinueCommand):
+pass
+else:
+raise AssertionError("unreachable statement.")
try:
# run node
generator = self._run_node(
@@ -403,8 +424,8 @@
)
for parallel_result in parallel_generator:
-if isinstance(parallel_result, str):
-final_node_id = parallel_result
+if isinstance(parallel_result, _ParallelBranchResult):
+final_node_id = parallel_result.final_node_id
else:
yield parallel_result
@@ -429,8 +450,8 @@
)
for generated_item in parallel_generator:
-if isinstance(generated_item, str):
-final_node_id = generated_item
+if isinstance(generated_item, _ParallelBranchResult):
+final_node_id = generated_item.final_node_id
else:
yield generated_item
@@ -448,7 +469,7 @@
in_parallel_id: Optional[str] = None,
parallel_start_node_id: Optional[str] = None,
handle_exceptions: list[str] = [],
-) -> Generator[GraphEngineEvent | str, None, None]:
+) -> Generator["GraphEngineEvent | _ParallelBranchResult", None, None]:
# if nodes have no run conditions, run all nodes in parallel
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
if not parallel_id:
@@ -526,7 +547,7 @@
# get final node id
final_node_id = parallel.end_to_node_id
if final_node_id:
-yield final_node_id
+yield _ParallelBranchResult(final_node_id)
def _run_parallel_node(
self,
@@ -928,7 +949,41 @@
)
return error_result
+def save(self) -> str:
+"""`save` serializes the state inside this graph engine.
+This method should be called when suspension of the execution is necessary.
+"""
+state = _GraphEngineState(init_params=self.init_params, graph_runtime_state=self.graph_runtime_state)
+return state.model_dump_json()
+@classmethod
+def resume(
+cls,
+state: str,
+graph: Graph,
+command_source: CommandSource = _default_source,
+) -> "GraphEngine":
+"""`resume` continues a suspended execution."""
+state_ = _GraphEngineState.model_validate_json(state)
+return cls(
+graph=graph,
+graph_init_params=state_.init_params,
+graph_runtime_state=state_.graph_runtime_state,
+command_source=command_source,
+)
class GraphRunFailedError(Exception):
def __init__(self, error: str):
self.error = error
+@dataclass
+class _ParallelBranchResult:
+final_node_id: str
+class _GraphEngineState(BaseModel):
+init_params: GraphInitParams
+graph_runtime_state: GraphRuntimeState
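Putting `save`/`resume` together with a suspending command source gives the intended round trip. A minimal sketch, assuming `graph`, `init_params`, `pool`, and the `make_suspend_at` helper sketched earlier:

    engine = GraphEngine(
        graph=graph,
        graph_init_params=init_params,
        graph_runtime_state=GraphRuntimeState(variable_pool=pool),
        command_source=make_suspend_at("human-review"),
    )
    snapshot = None
    for event in engine.run():
        if isinstance(event, GraphRunSuspendedEvent):
            snapshot = engine.save()  # JSON: init params + runtime state
            # ... persist snapshot, e.g. in WorkflowSuspension.state ...
    # Later, possibly on another machine:
    resumed = GraphEngine.resume(snapshot, graph=graph)
    events = resumed.run()  # picks up at node_run_state.next_node_id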

@@ -17,6 +17,15 @@ logger = logging.getLogger(__name__)
class BaseNode:
+"""BaseNode serves as the foundational class for all node implementations.
+Nodes are allowed to maintain transient states (e.g., `LLMNode` uses the `_file_output`
+attribute to track files generated by the LLM). However, these states are not persisted
+when the workflow is suspended or resumed. If a node needs its state to be preserved
+across workflow suspension and resumption, it should include the relevant state data
+in its output.
+"""
_node_type: ClassVar[NodeType]
def __init__(
@@ -32,9 +41,6 @@ class BaseNode:
self.id = id
self.tenant_id = graph_init_params.tenant_id
self.app_id = graph_init_params.app_id
-self.workflow_type = graph_init_params.workflow_type
-self.workflow_id = graph_init_params.workflow_id
-self.graph_config = graph_init_params.graph_config
self.user_id = graph_init_params.user_id
self.user_from = graph_init_params.user_from
self.invoke_from = graph_init_params.invoke_from
@@ -43,6 +49,7 @@
self.graph_runtime_state = graph_runtime_state
self.previous_node_id = previous_node_id
self.thread_pool_id = thread_pool_id
+self._init_params = graph_init_params
node_id = config.get("id")
if not node_id:

@@ -1,6 +1,5 @@
import contextvars
import logging
-import time
import uuid
from collections.abc import Generator, Mapping, Sequence
from concurrent.futures import Future, wait
@@ -137,15 +136,13 @@ class IterationNode(BaseNode):
inputs = {"iterator_selector": iterator_list_value}
-graph_config = self.graph_config
if not self._node_data.start_node_id:
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self.node_id} not found")
root_node_id = self._node_data.start_node_id
# init graph
-iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id)
+iteration_graph = Graph.init(graph_config=self._init_params.graph_config, root_node_id=root_node_id)
if not iteration_graph:
raise IterationGraphNotFoundError("iteration graph not found")
@@ -160,22 +157,12 @@
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool
-graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
+graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool)
graph_engine = GraphEngine(
-tenant_id=self.tenant_id,
-app_id=self.app_id,
-workflow_type=self.workflow_type,
-workflow_id=self.workflow_id,
-user_id=self.user_id,
-user_from=self.user_from,
-invoke_from=self.invoke_from,
-call_depth=self.workflow_call_depth,
graph=iteration_graph,
-graph_config=graph_config,
graph_runtime_state=graph_runtime_state,
-max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
-max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
+graph_init_params=self._init_params,
thread_pool_id=self.thread_pool_id,
)

@@ -1,11 +1,9 @@
import json
import logging
-import time
from collections.abc import Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
-from configs import dify_config
from core.variables import (
IntegerSegment,
Segment,
@@ -91,7 +89,7 @@ class LoopNode(BaseNode):
raise ValueError(f"field start_node_id in loop {self.node_id} not found")
# Initialize graph
-loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self._node_data.start_node_id)
+loop_graph = Graph.init(graph_config=self._init_params.graph_config, root_node_id=self._node_data.start_node_id)
if not loop_graph:
raise ValueError("loop graph not found")
@@ -124,22 +122,12 @@
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine
-graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
+graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool)
graph_engine = GraphEngine(
-tenant_id=self.tenant_id,
-app_id=self.app_id,
-workflow_type=self.workflow_type,
-workflow_id=self.workflow_id,
-user_id=self.user_id,
-user_from=self.user_from,
-invoke_from=self.invoke_from,
-call_depth=self.workflow_call_depth,
graph=loop_graph,
-graph_config=self.graph_config,
+graph_init_params=self._init_params,
graph_runtime_state=graph_runtime_state,
-max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
-max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
thread_pool_id=self.thread_pool_id,
)

@@ -1,5 +1,4 @@
import logging
-import time
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
@@ -70,8 +69,9 @@ class WorkflowEntry:
raise ValueError("Max workflow call depth {} reached.".format(workflow_call_max_depth))
# init workflow run state
-graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
+graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool)
-self.graph_engine = GraphEngine(
+graph_init_params = GraphInitParams(
tenant_id=tenant_id,
app_id=app_id,
workflow_type=workflow_type,
@@ -80,11 +80,14 @@
user_from=user_from,
invoke_from=invoke_from,
call_depth=call_depth,
-graph=graph,
graph_config=graph_config,
-graph_runtime_state=graph_runtime_state,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
+)
+self.graph_engine = GraphEngine(
+graph=graph,
+graph_runtime_state=graph_runtime_state,
+graph_init_params=graph_init_params,
thread_pool_id=thread_pool_id,
)
@@ -146,7 +149,7 @@
graph = Graph.init(graph_config=workflow.graph_dict)
# init workflow run state
-node = node_cls(
+node_instance = node_cls(
id=str(uuid.uuid4()),
config=node_config,
graph_init_params=GraphInitParams(
@@ -161,7 +164,7 @@
call_depth=0,
),
graph=graph,
-graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
+graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool),
)
node_instance.init_node_data(node_config_data)
@@ -191,11 +194,17 @@
try:
# run node
-generator = node.run()
+generator = node_instance.run()
except Exception as e:
-logger.exception(f"error while running node, {workflow.id=}, {node.id=}, {node.type_=}, {node.version()=}")
-raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
-return node, generator
+logger.exception(
+"error while running node, workflow_id=%s, node_id=%s, type=%s, version=%s",
+workflow.id,
+node_instance.id,
+node_instance.type_,
+node_instance.version(),
+)
+raise WorkflowNodeRunFailedError(node=node_instance, err_msg=str(e))
+return node_instance, generator
@classmethod
def run_free_node(
@@ -257,7 +266,7 @@
node_cls = cast(type[BaseNode], node_cls)
# init workflow run state
-node: BaseNode = node_cls(
+node_instance: BaseNode = node_cls(
id=str(uuid.uuid4()),
config=node_config,
graph_init_params=GraphInitParams(
@@ -272,7 +281,7 @@
call_depth=0,
),
graph=graph,
-graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
+graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool),
)
node_instance.init_node_data(node_data)
@@ -293,12 +302,17 @@
)
# run node
-generator = node.run()
-return node, generator
+generator = node_instance.run()
+return node_instance, generator
except Exception as e:
-logger.exception(f"error while running node, {node.id=}, {node.type_=}, {node.version()=}")
-raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
+logger.exception(
+"error while running node, node_id=%s, type=%s, version=%s",
+node_instance.id,
+node_instance.type_,
+node_instance.version(),
+)
+raise WorkflowNodeRunFailedError(node=node_instance, err_msg=str(e))
@staticmethod
def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None:

@@ -181,6 +181,26 @@ def timezone(timezone_string):
def generate_string(n):
+"""
+Generates a cryptographically secure random string of the specified length.
+This function uses a cryptographically secure pseudorandom number generator (CSPRNG)
+to create a string composed of ASCII letters (both uppercase and lowercase) and digits.
+Each character in the generated string provides approximately 5.95 bits of entropy
+(log2(62)). To ensure a minimum of 128 bits of entropy for security purposes, the
+length of the string (`n`) should be at least 22 characters.
+Args:
+n (int): The length of the random string to generate. For secure usage,
+`n` should be 22 or greater.
+Returns:
+str: A random string of length `n` composed of ASCII letters and digits.
+Note:
+This function is suitable for generating credentials or other secure tokens.
+"""
letters_digits = string.ascii_letters + string.digits
result = ""
for i in range(n):
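The entropy arithmetic in the docstring can be checked directly (a sketch, not part of the diff):

    import math

    bits_per_char = math.log2(62)           # ~5.954 bits per [A-Za-z0-9] character
    print(bits_per_char * 22)               # ~130.99 bits, above the 128-bit target
    print(math.ceil(128 / bits_per_char))   # 22, the minimum length quoted above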

@@ -1,4 +1,4 @@
-"""update models
+"""adjust length for mcp tool name and server identifiers
Revision ID: 1a83934ad6d1
Revises: 71f5020c6470

@@ -0,0 +1,51 @@
"""Add WorkflowSuspension model, add suspension_id to WorkflowRun
Revision ID: 1091956b9ee0
Revises: 1c9ba48be8e4
Create Date: 2025-07-17 20:20:43.710683
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '1091956b9ee0'
down_revision = '1a83934ad6d1'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('workflow_suspensions',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('app_id', models.types.StringUUID(), nullable=False),
sa.Column('workflow_id', models.types.StringUUID(), nullable=False),
sa.Column('workflow_run_id', models.types.StringUUID(), nullable=False),
sa.Column('resumed_at', sa.DateTime(), nullable=True),
sa.Column('next_node_id', sa.String(length=255), nullable=False),
sa.Column('state_version', sa.String(length=20), nullable=False),
sa.Column('state', sa.Text(), nullable=False),
sa.Column('inputs', sa.Text(), nullable=True),
sa.Column('form_code', sa.String(length=32), nullable=False),
sa.PrimaryKeyConstraint('id', name=op.f('workflow_suspensions_pkey')),
sa.UniqueConstraint('form_code', name=op.f('workflow_suspensions_form_code_key'))
)
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
batch_op.add_column(sa.Column('suspension_id', models.types.StringUUID(), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
batch_op.drop_column('suspension_id')
op.drop_table('workflow_suspensions')
# ### end Alembic commands ###

@@ -14,10 +14,12 @@ from core.file.models import File
from core.variables import utils as variable_utils
from core.variables.variables import FloatVariable, IntegerVariable, StringVariable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
+from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
+from core.workflow.entities.workflow_suspension import StateVersion
from core.workflow.nodes.enums import NodeType
from factories.variable_factory import TypeMismatchError, build_segment_with_type
from libs.datetime_utils import naive_utc_now
-from libs.helper import extract_tenant_id
+from libs.helper import extract_tenant_id, generate_string
from ._workflow_exc import NodeNotFoundError, WorkflowDataError
@@ -508,7 +510,10 @@ class WorkflowRun(Base):
version: Mapped[str] = mapped_column(db.String(255))
graph: Mapped[Optional[str]] = mapped_column(db.Text)
inputs: Mapped[Optional[str]] = mapped_column(db.Text)
-status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded
+status: Mapped[str] = mapped_column(
+EnumText(WorkflowExecutionStatus, length=255),
+nullable=False,
+) # running, succeeded, failed, stopped, partial-succeeded, suspended
outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}")
error: Mapped[Optional[str]] = mapped_column(db.Text)
elapsed_time: Mapped[float] = mapped_column(db.Float, nullable=False, server_default=sa.text("0"))
@@ -520,6 +525,10 @@
finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
exceptions_count: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"), nullable=True)
+# Represents the suspension details of a suspended workflow.
+# This field is non-null when `status == SUSPENDED` and null otherwise.
+suspension_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True)
@property
def created_by_account(self):
created_by_role = CreatorUserRole(self.created_by_role)
@@ -907,10 +916,6 @@ class ConversationVariable(Base):
_EDITABLE_SYSTEM_VARIABLE = frozenset(["query", "files"])
-def _naive_utc_datetime():
-return naive_utc_now()
class WorkflowDraftVariable(Base):
"""`WorkflowDraftVariable` records variables and outputs generated during
debugging workflow or chatflow.
@@ -941,14 +946,14 @@ class WorkflowDraftVariable(Base):
created_at: Mapped[datetime] = mapped_column(
db.DateTime,
nullable=False,
-default=_naive_utc_datetime,
+default=naive_utc_now,
server_default=func.current_timestamp(),
)
updated_at: Mapped[datetime] = mapped_column(
db.DateTime,
nullable=False,
-default=_naive_utc_datetime,
+default=naive_utc_now,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
)
@@ -1173,8 +1178,8 @@
description: str = "",
) -> "WorkflowDraftVariable":
variable = WorkflowDraftVariable()
-variable.created_at = _naive_utc_datetime()
-variable.updated_at = _naive_utc_datetime()
+variable.created_at = naive_utc_now()
+variable.updated_at = naive_utc_now()
variable.description = description
variable.app_id = app_id
variable.node_id = node_id
@@ -1254,3 +1259,120 @@
def is_system_variable_editable(name: str) -> bool:
return name in _EDITABLE_SYSTEM_VARIABLE
_SUSPENSION_FORM_CODE_LENGTH = 22
def _generate_suspension_form_code():
return generate_string(_SUSPENSION_FORM_CODE_LENGTH)
class WorkflowSuspension(Base):
__tablename__ = "workflow_suspensions"
# id is the unique identifier of a suspension
id: Mapped[str] = mapped_column(
StringUUID,
primary_key=True,
# NOTE: The server default acts as a fallback mechanism.
# The application generates the ID for new `WorkflowSuspension` records
# to streamline the insertion process and minimize database roundtrips.
server_default=db.text("uuidv7()"),
)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime,
nullable=False,
default=naive_utc_now,
server_default=func.current_timestamp(),
)
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime,
nullable=False,
default=naive_utc_now,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
)
# `tenant_id` identifies the tenant associated with this suspension,
# corresponding to the `id` field in the `Tenant` model.
tenant_id: Mapped[str] = mapped_column(
StringUUID,
nullable=False,
)
# `app_id` represents the application identifier associated with this suspension.
# It corresponds to the `id` field in the `App` model.
#
# While this field is technically redundant (as the corresponding app can be
# determined by querying the `Workflow`), it is retained to simplify data
# cleanup and management processes.
app_id: Mapped[str] = mapped_column(
StringUUID,
nullable=False,
)
# `workflow_id` represents the unique identifier of the workflow associated with this suspension.
# It corresponds to the `id` field in the `Workflow` model.
#
# Since an application can have multiple versions of a workflow, each with its own unique ID,
# the `app_id` alone is insufficient to determine which workflow version should be loaded
# when resuming a suspended workflow.
workflow_id: Mapped[str] = mapped_column(
StringUUID,
nullable=False,
)
# `workflow_run_id` represents the identifier of the workflow execution,
# corresponding to the `id` field of the `WorkflowRun` model.
workflow_run_id: Mapped[str] = mapped_column(
StringUUID,
nullable=False,
)
# `resumed_at` records the timestamp when the suspended workflow was resumed.
# It is set to `NULL` if the workflow has not been resumed.
resumed_at: Mapped[Optional[datetime]] = mapped_column(
sa.DateTime,
nullable=True,
default=None,
)
# `next_node_id` specifies the next node to execute when the workflow resumes.
#
# Although this information is embedded within the `state` field, it is extracted
# into a separate field to facilitate debugging and data analysis.
next_node_id: Mapped[str] = mapped_column(
sa.String(length=255),
nullable=False,
)
# The version of the serialized execution state data. Currently, the only supported value is `v1`.
state_version: Mapped[StateVersion] = mapped_column(
EnumText(StateVersion),
nullable=False,
)
# `state` contains the serialized runtime state of the `GraphEngine`,
# capturing the workflow's execution context at the time of suspension.
#
# The value of `state` is a JSON-formatted string representing a JSON object (e.g., `{}`).
state: Mapped[str] = mapped_column(sa.Text, nullable=False)
# The inputs provided by the user when resuming the suspended workflow.
# These inputs are serialized as a JSON-formatted string (e.g., `{}`).
#
# This field is `NULL` if no inputs were submitted by the user.
inputs: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
form_code: Mapped[str] = mapped_column(
# A 32-character string can store a base64-encoded value with 192 bits of entropy
# or a base62-encoded value with over 180 bits of entropy, providing sufficient
# uniqueness for most use cases.
sa.String(32),
nullable=False,
unique=True,
default=_generate_suspension_form_code,
)
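Since `form_code` is unique and `resumed_at` stays NULL until resumption, looking up a pending suspension by code is a single filtered query. A sketch in plain SQLAlchemy 2.0 (session handling assumed, not shown in this diff):

    from sqlalchemy import select

    def find_pending_suspension(session, code: str) -> WorkflowSuspension | None:
        stmt = select(WorkflowSuspension).where(
            WorkflowSuspension.form_code == code,
            WorkflowSuspension.resumed_at.is_(None),  # not yet resumed
        )
        return session.scalars(stmt).first()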

@@ -1,4 +1,3 @@
-import time
import uuid
from os import getenv
from typing import cast
@@ -62,7 +61,9 @@ def init_code_node(code_config: dict):
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
-graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
+graph_runtime_state=GraphRuntimeState(
+variable_pool=variable_pool,
+),
config=code_config,
)

@@ -1,4 +1,3 @@
-import time
import uuid
from urllib.parse import urlencode
@@ -56,7 +55,9 @@ def init_http_node(config: dict):
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
-graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
+graph_runtime_state=GraphRuntimeState(
+variable_pool=variable_pool,
+),
config=config,
)

@@ -1,5 +1,4 @@
import json
-import time
import uuid
from collections.abc import Generator
from unittest.mock import MagicMock, patch
@@ -73,7 +72,9 @@ def init_llm_node(config: dict) -> LLMNode:
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
-graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
+graph_runtime_state=GraphRuntimeState(
+variable_pool=variable_pool,
+),
config=config,
)

@@ -1,5 +1,4 @@
import os
-import time
import uuid
from typing import Optional
from unittest.mock import MagicMock
@@ -78,7 +77,9 @@ def init_parameter_extractor_node(config: dict):
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
-graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
+graph_runtime_state=GraphRuntimeState(
+variable_pool=variable_pool,
+),
config=config,
)
node.init_node_data(config.get("data", {}))

@@ -1,4 +1,3 @@
-import time
import uuid
import pytest
@@ -73,7 +72,9 @@ def test_execute_code(setup_code_executor_mock):
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
-graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
+graph_runtime_state=GraphRuntimeState(
+variable_pool=variable_pool,
+),
config=config,
)
node.init_node_data(config.get("data", {}))

@@ -1,4 +1,3 @@
-import time
import uuid
from unittest.mock import MagicMock
@@ -54,7 +53,9 @@ def init_tool_node(config: dict):
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
-graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
+graph_runtime_state=GraphRuntimeState(
+variable_pool=variable_pool,
+),
config=config,
)
node.init_node_data(config.get("data", {}))

@@ -1,4 +1,3 @@
-import time
from decimal import Decimal
from core.model_runtime.entities.llm_entities import LLMUsage
@@ -49,7 +48,6 @@ def create_test_graph_runtime_state() -> GraphRuntimeState:
return GraphRuntimeState(
variable_pool=variable_pool,
-start_at=time.perf_counter(),
total_tokens=100,
llm_usage=llm_usage,
outputs={
@@ -106,7 +104,6 @@ def test_empty_outputs_round_trip():
variable_pool = VariablePool.empty()
original_state = GraphRuntimeState(
variable_pool=variable_pool,
-start_at=time.perf_counter(),
outputs={}, # Empty outputs
)

@@ -1,24 +1,27 @@
-import time
from unittest.mock import patch
import pytest
from flask import Flask
from core.app.entities.app_invoke_entities import InvokeFrom
+from core.variables.segments import ArrayFileSegment
from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
+from core.workflow.graph_engine.command_source import CommandParams, CommandTypes, ContinueCommand, SuspendCommand
from core.workflow.graph_engine.entities.event import (
BaseNodeEvent,
GraphRunFailedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
+GraphRunSuspendedEvent,
NodeRunFailedEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.graph_engine.graph_engine import GraphEngine
@@ -175,8 +178,10 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
         user_inputs={"query": "hi"},
     )
-    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
-    graph_engine = GraphEngine(
+    graph_runtime_state = GraphRuntimeState(
+        variable_pool=variable_pool,
+    )
+    init_params = GraphInitParams(
         tenant_id="111",
         app_id="222",
         workflow_type=WorkflowType.WORKFLOW,
@@ -186,11 +191,14 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
         user_from=UserFrom.ACCOUNT,
         invoke_from=InvokeFrom.WEB_APP,
         call_depth=0,
-        graph=graph,
-        graph_runtime_state=graph_runtime_state,
         max_execution_steps=500,
         max_execution_time=1200,
     )
+    graph_engine = GraphEngine(
+        graph=graph,
+        graph_runtime_state=graph_runtime_state,
+        graph_init_params=init_params,
+    )

     def llm_generator(self):
         contents = ["hi", "bye", "good morning"]
@@ -303,8 +311,10 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
         user_inputs={},
     )
-    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
-    graph_engine = GraphEngine(
+    graph_runtime_state = GraphRuntimeState(
+        variable_pool=variable_pool,
+    )
+    graph_init_params = GraphInitParams(
         tenant_id="111",
         app_id="222",
         workflow_type=WorkflowType.CHAT,
@@ -314,11 +324,14 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
         user_from=UserFrom.ACCOUNT,
         invoke_from=InvokeFrom.WEB_APP,
         call_depth=0,
-        graph=graph,
-        graph_runtime_state=graph_runtime_state,
         max_execution_steps=500,
         max_execution_time=1200,
     )
+    graph_engine = GraphEngine(
+        graph=graph,
+        graph_runtime_state=graph_runtime_state,
+        graph_init_params=graph_init_params,
+    )

     # print("")
@@ -484,8 +497,10 @@ def test_run_branch(mock_close, mock_remove):
         user_inputs={"uid": "takato"},
     )
-    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
-    graph_engine = GraphEngine(
+    graph_runtime_state = GraphRuntimeState(
+        variable_pool=variable_pool,
+    )
+    graph_init_params = GraphInitParams(
         tenant_id="111",
         app_id="222",
         workflow_type=WorkflowType.CHAT,
@@ -495,11 +510,14 @@ def test_run_branch(mock_close, mock_remove):
         user_from=UserFrom.ACCOUNT,
         invoke_from=InvokeFrom.WEB_APP,
         call_depth=0,
-        graph=graph,
-        graph_runtime_state=graph_runtime_state,
         max_execution_steps=500,
         max_execution_time=1200,
     )
+    graph_engine = GraphEngine(
+        graph=graph,
+        graph_runtime_state=graph_runtime_state,
+        graph_init_params=graph_init_params,
+    )

     # print("")
@@ -823,8 +841,10 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
         user_inputs={"query": "hi"},
     )
-    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
-    graph_engine = GraphEngine(
+    graph_runtime_state = GraphRuntimeState(
+        variable_pool=variable_pool,
+    )
+    graph_init_params = GraphInitParams(
         tenant_id="111",
         app_id="222",
         workflow_type=WorkflowType.CHAT,
@@ -834,11 +854,14 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
         user_from=UserFrom.ACCOUNT,
         invoke_from=InvokeFrom.WEB_APP,
         call_depth=0,
-        graph=graph,
-        graph_runtime_state=graph_runtime_state,
         max_execution_steps=500,
         max_execution_time=1200,
     )
+    graph_engine = GraphEngine(
+        graph=graph,
+        graph_runtime_state=graph_runtime_state,
+        graph_init_params=graph_init_params,
+    )

     def qc_generator(self):
         yield RunCompletedEvent(
@@ -884,3 +907,203 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
         assert item.outputs is not None
         answer = item.outputs["answer"]
         assert all(rc not in answer for rc in wrong_content)
+
+
+def test_suspend_and_resume():
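+    """Suspend a run right before the if-else node, then resume it.
+
+    Graph: Start -> IF/ELSE -> Answer A (true branch: sys.query contains "a")
+                            -> Answer Else (false branch).
+    The query below is "hello", so the resumed run takes the false branch.
+    """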
+    graph_config = {
+        "edges": [
+            {
+                "data": {"isInLoop": False, "sourceType": "start", "targetType": "if-else"},
+                "id": "1753041723554-source-1753041730748-target",
+                "source": "1753041723554",
+                "sourceHandle": "source",
+                "target": "1753041730748",
+                "targetHandle": "target",
+                "type": "custom",
+                "zIndex": 0,
+            },
+            {
+                "data": {"isInLoop": False, "sourceType": "if-else", "targetType": "answer"},
+                "id": "1753041730748-true-answer-target",
+                "source": "1753041730748",
+                "sourceHandle": "true",
+                "target": "answer",
+                "targetHandle": "target",
+                "type": "custom",
+                "zIndex": 0,
+            },
+            {
+                "data": {
+                    "isInIteration": False,
+                    "isInLoop": False,
+                    "sourceType": "if-else",
+                    "targetType": "answer",
+                },
+                "id": "1753041730748-false-1753041952799-target",
+                "source": "1753041730748",
+                "sourceHandle": "false",
+                "target": "1753041952799",
+                "targetHandle": "target",
+                "type": "custom",
+                "zIndex": 0,
+            },
+        ],
+        "nodes": [
+            {
+                "data": {"desc": "", "selected": False, "title": "Start", "type": "start", "variables": []},
+                "height": 54,
+                "id": "1753041723554",
+                "position": {"x": 32, "y": 282},
+                "positionAbsolute": {"x": 32, "y": 282},
+                "selected": False,
+                "sourcePosition": "right",
+                "targetPosition": "left",
+                "type": "custom",
+                "width": 244,
+            },
+            {
+                "data": {
+                    "cases": [
+                        {
+                            "case_id": "true",
+                            "conditions": [
+                                {
+                                    "comparison_operator": "contains",
+                                    "id": "5db4103a-7e62-4e71-a0a6-c45ac11c0b3d",
+                                    "value": "a",
+                                    "varType": "string",
+                                    "variable_selector": ["sys", "query"],
+                                }
+                            ],
+                            "id": "true",
+                            "logical_operator": "and",
+                        }
+                    ],
+                    "desc": "",
+                    "selected": False,
+                    "title": "IF/ELSE",
+                    "type": "if-else",
+                },
+                "height": 126,
+                "id": "1753041730748",
+                "position": {"x": 368, "y": 282},
+                "positionAbsolute": {"x": 368, "y": 282},
+                "selected": False,
+                "sourcePosition": "right",
+                "targetPosition": "left",
+                "type": "custom",
+                "width": 244,
+            },
+            {
+                "data": {
+                    "answer": "A",
+                    "desc": "",
+                    "selected": False,
+                    "title": "Answer A",
+                    "type": "answer",
+                    "variables": [],
+                },
+                "height": 102,
+                "id": "answer",
+                "position": {"x": 746, "y": 282},
+                "positionAbsolute": {"x": 746, "y": 282},
+                "selected": False,
+                "sourcePosition": "right",
+                "targetPosition": "left",
+                "type": "custom",
+                "width": 244,
+            },
+            {
+                "data": {
+                    "answer": "Else",
+                    "desc": "",
+                    "selected": False,
+                    "title": "Answer Else",
+                    "type": "answer",
+                    "variables": [],
+                },
+                "height": 102,
+                "id": "1753041952799",
+                "position": {"x": 746, "y": 426},
+                "positionAbsolute": {"x": 746, "y": 426},
+                "selected": True,
+                "sourcePosition": "right",
+                "targetPosition": "left",
+                "type": "custom",
+                "width": 244,
+            },
+        ],
+        "viewport": {"x": -420, "y": -76.5, "zoom": 1},
+    }
+    graph = Graph.init(graph_config)
+
+    variable_pool = VariablePool(
+        system_variables=SystemVariable(
+            user_id="aaa",
+            files=[],
+            query="hello",
+            conversation_id="abababa",
+        ),
+        user_inputs={"uid": "takato"},
+    )
+    graph_runtime_state = GraphRuntimeState(
+        variable_pool=variable_pool,
+    )
+    graph_init_params = GraphInitParams(
+        tenant_id="111",
+        app_id="222",
+        workflow_type=WorkflowType.CHAT,
+        workflow_id="333",
+        graph_config=graph_config,
+        user_id="444",
+        user_from=UserFrom.ACCOUNT,
+        invoke_from=InvokeFrom.WEB_APP,
+        call_depth=0,
+        max_execution_steps=500,
+        max_execution_time=1200,
+    )
+
+    _IF_ELSE_NODE_ID = "1753041730748"
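+
+    # The engine consults the command source before it executes the next
+    # node: returning SuspendCommand halts the run at that point, while
+    # ContinueCommand lets execution proceed.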
+    def command_source(params: CommandParams) -> CommandTypes:
+        # Suspend the engine right before it executes the if-else node;
+        # let every other node proceed normally.
+        if params.next_node.node_id == _IF_ELSE_NODE_ID:
+            return SuspendCommand()
+        else:
+            return ContinueCommand()
+
+    graph_engine = GraphEngine(
+        graph=graph,
+        graph_runtime_state=graph_runtime_state,
+        graph_init_params=graph_init_params,
+        command_source=command_source,
+    )
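+    # First run: stops when the if-else node is about to execute and
+    # emits GraphRunSuspendedEvent as its final event.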
+    events = list(graph_engine.run())
+    last_event = events[-1]
+    assert isinstance(last_event, GraphRunSuspendedEvent)
+    assert last_event.next_node_id == _IF_ELSE_NODE_ID
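+
+    # Serialize the suspended run so it can be resumed later, here by a
+    # separate engine instance.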
+    state = graph_engine.save()
+    assert state != ""
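+    # Rebuild an engine from the saved state; the graph is passed in again
+    # rather than being carried inside the serialized state.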
+    engine2 = GraphEngine.resume(
+        state=state,
+        graph=graph,
+    )
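+    # Second run: picks up at the recorded next node (the if-else node)
+    # and runs the workflow to completion.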
+    events = list(engine2.run())
+    assert isinstance(events[-1], GraphRunSucceededEvent)
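+
+    # Nodes that finished before the suspension (the start node) must not
+    # run again; only the if-else node and the false-branch answer should
+    # succeed in the resumed run.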
+    node_run_succeeded_events = [i for i in events if isinstance(i, NodeRunSucceededEvent)]
+    assert node_run_succeeded_events
+    start_events = [i for i in node_run_succeeded_events if i.node_id == "1753041723554"]
+    assert not start_events
+    ifelse_succeeded_events = [i for i in node_run_succeeded_events if i.node_id == _IF_ELSE_NODE_ID]
+    assert ifelse_succeeded_events
+    answer_else_events = [i for i in node_run_succeeded_events if i.node_id == "1753041952799"]
+    assert answer_else_events
+    assert answer_else_events[0].route_node_state.node_run_result.outputs == {
+        "answer": "Else",
+        "files": ArrayFileSegment(value=[]),
+    }
+    answer_a_events = [i for i in node_run_succeeded_events if i.node_id == "answer"]
+    assert not answer_a_events

@@ -1,4 +1,3 @@
-import time
 import uuid
 from unittest.mock import MagicMock

@@ -71,7 +70,9 @@ def test_execute_answer():
         id=str(uuid.uuid4()),
         graph_init_params=init_params,
         graph=graph,
-        graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
+        graph_runtime_state=GraphRuntimeState(
+            variable_pool=pool,
+        ),
         config=node_config,
     )

@@ -1,4 +1,3 @@
-import time
 import uuid
 from unittest.mock import patch

@@ -179,7 +178,9 @@ def test_run():
         id=str(uuid.uuid4()),
         graph_init_params=init_params,
         graph=graph,
-        graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
+        graph_runtime_state=GraphRuntimeState(
+            variable_pool=pool,
+        ),
         config=node_config,
     )
@@ -401,7 +402,9 @@ def test_run_parallel():
         id=str(uuid.uuid4()),
         graph_init_params=init_params,
         graph=graph,
-        graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
+        graph_runtime_state=GraphRuntimeState(
+            variable_pool=pool,
+        ),
         config=node_config,
     )
@@ -623,7 +626,9 @@ def test_iteration_run_in_parallel_mode():
         id=str(uuid.uuid4()),
         graph_init_params=init_params,
         graph=graph,
-        graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
+        graph_runtime_state=GraphRuntimeState(
+            variable_pool=pool,
+        ),
         config=parallel_node_config,
     )
@@ -647,7 +652,9 @@ def test_iteration_run_in_parallel_mode():
         id=str(uuid.uuid4()),
         graph_init_params=init_params,
         graph=graph,
-        graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
+        graph_runtime_state=GraphRuntimeState(
+            variable_pool=pool,
+        ),
         config=sequential_node_config,
     )
@@ -857,7 +864,9 @@ def test_iteration_run_error_handle():
         id=str(uuid.uuid4()),
         graph_init_params=init_params,
         graph=graph,
-        graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
+        graph_runtime_state=GraphRuntimeState(
+            variable_pool=pool,
+        ),
         config=error_node_config,
     )

@@ -1,4 +1,3 @@
-import time
 import uuid
 from unittest.mock import MagicMock

@@ -74,7 +73,9 @@ def test_execute_answer():
         id=str(uuid.uuid4()),
         graph_init_params=init_params,
         graph=graph,
-        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
+        graph_runtime_state=GraphRuntimeState(
+            variable_pool=variable_pool,
+        ),
         config=node_config,
     )

@@ -1,4 +1,3 @@
-import time
 from unittest.mock import patch

 from core.app.entities.app_invoke_entities import InvokeFrom
@@ -12,6 +11,7 @@ from core.workflow.graph_engine.entities.event import (
     NodeRunStreamChunkEvent,
 )
 from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
 from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
 from core.workflow.graph_engine.graph_engine import GraphEngine
 from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
@@ -175,9 +175,11 @@ class ContinueOnErrorTestHelper:
             ),
             user_inputs=user_inputs or {"uid": "takato"},
         )
-        graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
-        return GraphEngine(
+        graph_runtime_state = GraphRuntimeState(
+            variable_pool=variable_pool,
+        )
+        graph_init_params = GraphInitParams(
             tenant_id="111",
             app_id="222",
             workflow_type=WorkflowType.CHAT,
@@ -187,11 +189,14 @@ class ContinueOnErrorTestHelper:
             user_from=UserFrom.ACCOUNT,
             invoke_from=InvokeFrom.WEB_APP,
             call_depth=0,
-            graph=graph,
-            graph_runtime_state=graph_runtime_state,
             max_execution_steps=500,
             max_execution_time=1200,
         )
+        return GraphEngine(
+            graph=graph,
+            graph_runtime_state=graph_runtime_state,
+            graph_init_params=graph_init_params,
+        )

 DEFAULT_VALUE_EDGE = [

@@ -1,4 +1,3 @@
-import time
 import uuid
 from unittest.mock import MagicMock, Mock

@@ -106,7 +105,9 @@ def test_execute_if_else_result_true():
         id=str(uuid.uuid4()),
         graph_init_params=init_params,
         graph=graph,
-        graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
+        graph_runtime_state=GraphRuntimeState(
+            variable_pool=pool,
+        ),
         config=node_config,
     )
@@ -192,7 +193,9 @@ def test_execute_if_else_result_false():
         id=str(uuid.uuid4()),
         graph_init_params=init_params,
         graph=graph,
-        graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
+        graph_runtime_state=GraphRuntimeState(
+            variable_pool=pool,
+        ),
         config=node_config,
     )

@@ -1,4 +1,3 @@
-import time
 import uuid
 from unittest import mock
 from uuid import uuid4
@@ -96,7 +95,7 @@ def test_overwrite_string_variable():
         id=str(uuid.uuid4()),
         graph_init_params=init_params,
         graph=graph,
-        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
+        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool),
         config=node_config,
         conv_var_updater_factory=mock_conv_var_updater_factory,
     )
@@ -197,7 +196,7 @@ def test_append_variable_to_array():
         id=str(uuid.uuid4()),
         graph_init_params=init_params,
         graph=graph,
-        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
+        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool),
         config=node_config,
         conv_var_updater_factory=mock_conv_var_updater_factory,
     )
@@ -289,7 +288,7 @@ def test_clear_array():
         id=str(uuid.uuid4()),
         graph_init_params=init_params,
         graph=graph,
-        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
+        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool),
         config=node_config,
         conv_var_updater_factory=mock_conv_var_updater_factory,
     )

@@ -1,4 +1,3 @@
-import time
 import uuid
 from uuid import uuid4

@@ -135,7 +134,7 @@ def test_remove_first_from_array():
         id=str(uuid.uuid4()),
         graph_init_params=init_params,
         graph=graph,
-        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
+        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool),
         config=node_config,
     )
@@ -227,7 +226,7 @@ def test_remove_last_from_array():
         id=str(uuid.uuid4()),
         graph_init_params=init_params,
         graph=graph,
-        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
+        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool),
         config=node_config,
     )
@@ -311,7 +310,7 @@ def test_remove_first_from_empty_array():
         id=str(uuid.uuid4()),
         graph_init_params=init_params,
         graph=graph,
-        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
+        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool),
         config=node_config,
     )
@@ -395,7 +394,7 @@ def test_remove_last_from_empty_array():
         id=str(uuid.uuid4()),
         graph_init_params=init_params,
         graph=graph,
-        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
+        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool),
         config=node_config,
     )
