pull/22621/merge
QuantumGhost 7 months ago committed by GitHub
commit 9e0fc64e41
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -29,6 +29,7 @@ class QueueEvent(StrEnum):
WORKFLOW_SUCCEEDED = "workflow_succeeded"
WORKFLOW_FAILED = "workflow_failed"
WORKFLOW_PARTIAL_SUCCEEDED = "workflow_partial_succeeded"
WORKFLOW_SUSPENDED = "workflow_suspended"
ITERATION_START = "iteration_start"
ITERATION_NEXT = "iteration_next"
ITERATION_COMPLETED = "iteration_completed"
@ -326,6 +327,13 @@ class QueueWorkflowStartedEvent(AppQueueEvent):
graph_runtime_state: GraphRuntimeState
class QueueWorkflowSuspendedEvent(AppQueueEvent):
event: QueueEvent = QueueEvent.WORKFLOW_SUSPENDED
# next_node_id records the next node to execute after resuming
# workflow.
next_node_id: str
class QueueWorkflowSucceededEvent(AppQueueEvent):
"""
QueueWorkflowSucceededEvent entity

@ -23,12 +23,109 @@ class WorkflowType(StrEnum):
class WorkflowExecutionStatus(StrEnum):
# State diagram for the workflw status:
# (@) means start, (*) means end
#
# ┌------------------>------------------------->------------------->--------------┐
# | |
# | ┌-----------------------<--------------------┐ |
# ^ | | |
# | | ^ |
# | V | |
# ┌-----------┐ ┌-----------------------┐ ┌-----------┐ V
# | Scheduled |------->| Running |---------------------->| Suspended | |
# └-----------┘ └-----------------------┘ └-----------┘ |
# | | | | | | |
# | | | | | | |
# ^ | | | V V |
# | | | | | ┌---------┐ |
# (@) | | | └------------------------>| Stopped |<----┘
# | | | └---------┘
# | | | |
# | | V V
# | | ┌-----------┐ |
# | | | Succeeded |------------->--------------┤
# | | └-----------┘ |
# | V V
# | +--------┐ |
# | | Failed |---------------------->----------------┤
# | └--------┘ |
# V V
# ┌---------------------┐ |
# | Partially Succeeded |---------------------->-----------------┘--------> (*)
# └---------------------┘
#
# Mermaid diagram:
#
# ---
# title: State diagram for Workflow run state
# ---
# stateDiagram-v2
# scheduled: Scheduled
# running: Running
# succeeded: Succeeded
# failed: Failed
# partial_succeeded: Partial Succeeded
# suspended: Suspended
# stopped: Stopped
#
# [*] --> scheduled:
# scheduled --> running: Start Execution
# running --> suspended: Human input required
# suspended --> running: human input added
# suspended --> stopped: User stops execution
# running --> succeeded: Execution finishes without any error
# running --> failed: Execution finishes with errors
# running --> stopped: User stops execution
# running --> partial_succeeded: some execution occurred and handled during execution
#
# scheduled --> stopped: User stops execution
#
# succeeded --> [*]
# failed --> [*]
# partial_succeeded --> [*]
# stopped --> [*]
# `SCHEDULED` means that the workflow is scheduled to run, but has not
# started running yet. (maybe due to possible worker saturation.)
SCHEDULED = "scheduled"
# `RUNNING` means the workflow is exeuting.
RUNNING = "running"
# `SUCCEEDED` means the execution of workflow succeed without any error.
SUCCEEDED = "succeeded"
# `FAILED` means the execution of workflow failed without some errors.
FAILED = "failed"
# `STOPPED` means the execution of workflow was stopped, either manually
# by the user, or automatically by the Dify application (E.G. the moderation
# mechanism.)
STOPPED = "stopped"
# `PARTIAL_SUCCEEDED` indicates that some errors occurred during the workflow
# execution, but they were successfully handled (e.g., by using an error
# strategy such as "fail branch" or "default value").
PARTIAL_SUCCEEDED = "partial-succeeded"
# `SUSPENDED` indicates that the workflow execution is temporarily paused
# (e.g., awaiting human input) and is expected to resume later.
SUSPENDED = "suspended"
def is_ended(self) -> bool:
return self in _END_STATE
_END_STATE = frozenset(
[
WorkflowExecutionStatus.SUCCEEDED,
WorkflowExecutionStatus.FAILED,
WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
WorkflowExecutionStatus.STOPPED,
]
)
class WorkflowExecution(BaseModel):
"""

@ -0,0 +1,28 @@
from enum import StrEnum
from uuid import UUID
from pydantic import BaseModel, Field
from libs.uuid_utils import uuidv7
class StateVersion(StrEnum):
# `V1` is `GraphRuntimeState` serialized as JSON by dumping with Pydantic.
V1 = "v1"
class WorkflowSuspension(BaseModel):
id: UUID = Field(default_factory=uuidv7)
# Correspond to WorkflowExecution.id_
execution_id: str
workflow_id: str
next_node_id: str
state: str
state_version: StateVersion = StateVersion.V1
inputs: str

@ -0,0 +1,15 @@
import time
def get_current_timestamp() -> float:
"""Retrieve a timestamp as a float point numer representing the number of seconds
since the Unix epoch.
This function is primarily used to measure the execution time of the workflow engine.
Since workflow execution may be paused and resumed on a different machine,
`time.perf_counter` cannot be used as it is inconsistent across machines.
To address this, the function uses the wall clock as the time source.
However, it assumes that the clocks of all servers are properly synchronized.
"""
return round(time.time())

@ -0,0 +1,71 @@
import abc
from collections.abc import Callable
from dataclasses import dataclass
from enum import StrEnum
from typing import Annotated, TypeAlias, final
from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator
from core.workflow.nodes.base import BaseNode
@dataclass(frozen=True)
class CommandParams:
# `next_node_instance` is the instance of the next node to run.
next_node: BaseNode
class _CommandTag(StrEnum):
SUSPEND = "suspend"
STOP = "stop"
CONTINUE = "continue"
# Note: Avoid using the `_Command` class directly.
# Instead, use `CommandTypes` for type annotations.
class _Command(BaseModel, abc.ABC):
model_config = ConfigDict(frozen=True)
tag: _CommandTag
@field_validator("tag")
@classmethod
def validate_value_type(cls, value):
if value != cls.model_fields["tag"].default:
raise ValueError("Cannot modify 'tag'")
return value
@final
class StopCommand(_Command):
tag: _CommandTag = _CommandTag.STOP
@final
class SuspendCommand(_Command):
tag: _CommandTag = _CommandTag.SUSPEND
@final
class ContinueCommand(_Command):
tag: _CommandTag = _CommandTag.CONTINUE
def _get_command_tag(command: _Command):
return command.tag
CommandTypes: TypeAlias = Annotated[
(
Annotated[StopCommand, Tag(_CommandTag.STOP)]
| Annotated[SuspendCommand, Tag(_CommandTag.SUSPEND)]
| Annotated[ContinueCommand, Tag(_CommandTag.CONTINUE)]
),
Discriminator(_get_command_tag),
]
# `CommandSource` is a callable that takes a single argument of type `CommandParams` and
# returns a `Command` object to the engine, indicating whether the graph engine should suspend, continue, or stop.
#
# It must not modify the data inside `CommandParams`, including any attributes within its fields.
CommandSource: TypeAlias = Callable[[CommandParams], CommandTypes]

@ -43,6 +43,10 @@ class GraphRunPartialSucceededEvent(BaseGraphEvent):
outputs: Optional[dict[str, Any]] = None
class GraphRunSuspendedEvent(BaseGraphEvent):
next_node_id: str = Field(..., description="the next node id to execute while resumed.")
###########################################
# Node Events
###########################################

@ -1,14 +1,25 @@
from collections.abc import Mapping
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, PositiveInt
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from models.enums import UserFrom
from models.workflow import WorkflowType
class GraphInitParams(BaseModel):
"""GraphInitParams encapsulates the configurations and contextual information
that remain constant throughout a single execution of the graph engine.
A single execution is defined as follows: as long as the execution has not reached
its conclusion, it is considered one execution. For instance, if a workflow is suspended
and later resumed, it is still regarded as a single execution, not two.
For the state diagram of workflow execution, refer to `WorkflowExecutionStatus`.
"""
# init params
tenant_id: str = Field(..., description="tenant / workspace id")
app_id: str = Field(..., description="app id")
@ -19,3 +30,10 @@ class GraphInitParams(BaseModel):
user_from: UserFrom = Field(..., description="user from, account or end-user")
invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger")
call_depth: int = Field(..., description="call depth")
# max_execution_steps records the maximum steps allowed during the execution of a workflow.
max_execution_steps: PositiveInt = Field(
default=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, description="max_execution_steps"
)
# max_execution_time records the max execution time for the workflow, measured in seconds
max_execution_time: PositiveInt = Field(default=dify_config.WORKFLOW_MAX_EXECUTION_TIME, description="")

@ -1,18 +1,32 @@
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, PrivateAttr
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine._engine_utils import get_timestamp
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
_SECOND_TO_US = 1_000_000
class GraphRuntimeState(BaseModel):
"""`GraphRuntimeState` encapsulates the runtime state of workflow execution,
including scheduling details, variable values, and timing information.
Values that are initialized prior to workflow execution and remain constant
throughout the execution should be part of `GraphInitParams` instead.
"""
variable_pool: VariablePool = Field(..., description="variable pool")
"""variable pool"""
start_at: float = Field(..., description="start time")
"""start time"""
# The `start_at` field records the execution start time of the workflow.
#
# This field is automatically generated, and its value or interpretation may evolve.
# Avoid manually setting this field to ensure compatibility with future updates.
start_at: float = Field(description="start time", default_factory=get_timestamp)
total_tokens: int = 0
"""total tokens"""
llm_usage: LLMUsage = LLMUsage.empty_usage()
@ -29,3 +43,28 @@ class GraphRuntimeState(BaseModel):
node_run_state: RuntimeRouteState = RuntimeRouteState()
"""node run state"""
# `execution_time_us` tracks the total execution time of the workflow in microseconds.
# Time spent in suspension is excluded from this calculation.
#
# This field is used to persist the time already spent while suspending a workflow.
execution_time_us: int = 0
# `_last_execution_started_at` records the timestamp of the most recent resume start.
# It is updated when the workflow resumes from a suspended state.
_last_execution_started_at: float = PrivateAttr(default_factory=get_timestamp)
def is_timed_out(self, max_execution_time_seconds: int) -> bool:
"""Checks if the workflow execution has exceeded the specified `max_execution_time_seconds`."""
remaining_time_us = max_execution_time_seconds * _SECOND_TO_US - self.execution_time_us
if remaining_time_us <= 0:
return False
return int(get_timestamp() - self._last_execution_started_at) * _SECOND_TO_US > remaining_time_us
def record_suspend_state(self, next_node_id: str):
"""Record the time already spent in executing workflow.
This function should be called when suspending the workflow.
"""
self.execution_time_us = int(get_timestamp() - self._last_execution_started_at) * _SECOND_TO_US
self.node_run_state.next_node_id = next_node_id

@ -1,9 +1,9 @@
import uuid
from datetime import UTC, datetime
from enum import Enum
from typing import Optional
from typing import Any, Optional
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, PrivateAttr
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@ -44,6 +44,8 @@ class RouteNodeState(BaseModel):
paused_by: Optional[str] = None
"""paused by"""
# The `index` is used used to record the execution order for a given node.
# Nodes executed ealier get smaller `index` values.
index: int = 1
def set_finished(self, run_result: NodeRunResult) -> None:
@ -79,10 +81,27 @@ class RuntimeRouteState(BaseModel):
default_factory=dict, description="graph state routes (source_node_state_id: target_node_state_id)"
)
# A mapping from node_id to its routing state.
node_state_mapping: dict[str, RouteNodeState] = Field(
default_factory=dict, description="node state mapping (route_node_state_id: route_node_state)"
)
next_node_id: Optional[str] = Field(
default=None, description="The next node id to run when resumed from suspension."
)
# If `previous_node_id` is not `None`, then the correspond node has state in the dict
# `node_state_mapping`.
previous_node_state_id: Optional[str] = Field(default=None, description="The state of last executed node.")
# `_state_by_id` serves as a mapping from the unique identifier (`id`) of each `RouteNodeState`
# instance to the corresponding `RouteNodeState` object itself.
_state_by_id: dict[str, RouteNodeState] = PrivateAttr(default={})
def model_post_init(self, context: Any) -> None:
super().model_post_init(context)
self._state_by_id = {v.id: v for v in self.node_state_mapping.values()}
def create_node_state(self, node_id: str) -> RouteNodeState:
"""
Create node state
@ -91,6 +110,7 @@ class RuntimeRouteState(BaseModel):
"""
state = RouteNodeState(node_id=node_id, start_at=datetime.now(UTC).replace(tzinfo=None))
self.node_state_mapping[state.id] = state
self._state_by_id[state.id] = state
return state
def add_route(self, source_node_state_id: str, target_node_state_id: str) -> None:
@ -115,3 +135,18 @@ class RuntimeRouteState(BaseModel):
return [
self.node_state_mapping[target_state_id] for target_state_id in self.routes.get(source_node_state_id, [])
]
# def get_node_state(self, node_id: str) -> RouteNodeState | None:
# return self.node_state_mapping.get(node_id)
def get_previous_route_node_state(self) -> RouteNodeState | None:
if self.previous_node_state_id is None:
return None
return self._state_by_id[self.previous_node_state_id]
@property
def previous_node_id(self):
if self.previous_node_state_id is None:
return None
state = self._state_by_id[self.previous_node_state_id]
return state.node_id

@ -3,17 +3,18 @@ import logging
import queue
import time
import uuid
from collections.abc import Generator, Mapping
from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor, wait
from copy import copy, deepcopy
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Any, Optional, cast
from flask import Flask, current_app
from pydantic import BaseModel
from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@ -27,6 +28,7 @@ from core.workflow.graph_engine.entities.event import (
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
GraphRunSuspendedEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
@ -53,9 +55,17 @@ from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.utils import variable_utils
from libs.flask_utils import preserve_flask_contexts
from models.enums import UserFrom
from models.workflow import WorkflowType
from .command_source import (
CommandParams,
CommandSource,
CommandTypes,
ContinueCommand,
StopCommand,
SuspendCommand,
)
logger = logging.getLogger(__name__)
@ -86,26 +96,25 @@ class GraphEngineThreadPool(ThreadPoolExecutor):
raise ValueError(f"Max submit count {self.max_submit_count} of workflow thread pool reached.")
def _default_source(_: CommandParams) -> CommandTypes:
return ContinueCommand()
class GraphEngine:
workflow_thread_pool_mapping: dict[str, GraphEngineThreadPool] = {}
def __init__(
self,
tenant_id: str,
app_id: str,
workflow_type: WorkflowType,
workflow_id: str,
user_id: str,
user_from: UserFrom,
invoke_from: InvokeFrom,
call_depth: int,
graph: Graph,
graph_config: Mapping[str, Any],
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
max_execution_steps: int,
max_execution_time: int,
thread_pool_id: Optional[str] = None,
command_source: CommandSource = _default_source,
) -> None:
"""Create a graph from the given state.
The
"""
thread_pool_max_submit_count = dify_config.MAX_SUBMIT_COUNT
thread_pool_max_workers = 10
@ -126,22 +135,11 @@ class GraphEngine:
GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] = self.thread_pool
self.graph = graph
self.init_params = GraphInitParams(
tenant_id=tenant_id,
app_id=app_id,
workflow_type=workflow_type,
workflow_id=workflow_id,
graph_config=graph_config,
user_id=user_id,
user_from=user_from,
invoke_from=invoke_from,
call_depth=call_depth,
)
self.init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
self.max_execution_steps = max_execution_steps
self.max_execution_time = max_execution_time
self._command_source = command_source
def run(self) -> Generator[GraphEngineEvent, None, None]:
# trigger graph run start event
@ -160,12 +158,18 @@ class GraphEngine:
)
# run graph
next_node_to_run = self.graph.root_node_id
if (next_node_id := self.graph_runtime_state.node_run_state.next_node_id) is not None:
next_node_to_run = next_node_id
generator = stream_processor.process(
self._run(start_node_id=self.graph.root_node_id, handle_exceptions=handle_exceptions)
self._run(start_node_id=next_node_to_run, handle_exceptions=handle_exceptions)
)
for item in generator:
try:
yield item
if isinstance(item, GraphRunSuspendedEvent):
return
if isinstance(item, NodeRunFailedEvent):
yield GraphRunFailedEvent(
error=item.route_node_state.failed_reason or "Unknown error.",
@ -229,22 +233,22 @@ class GraphEngine:
parent_parallel_start_node_id: Optional[str] = None,
handle_exceptions: list[str] = [],
) -> Generator[GraphEngineEvent, None, None]:
# Hint: the `_run` method is used both when running a the main graph,
# and also running parallel branches.
parallel_start_node_id = None
if in_parallel_id:
parallel_start_node_id = start_node_id
next_node_id = start_node_id
previous_route_node_state: Optional[RouteNodeState] = None
while True:
# max steps reached
if self.graph_runtime_state.node_run_steps > self.max_execution_steps:
raise GraphRunFailedError("Max steps {} reached.".format(self.max_execution_steps))
if self.graph_runtime_state.node_run_steps > self.init_params.max_execution_steps:
raise GraphRunFailedError("Max steps {} reached.".format(self.init_params.max_execution_steps))
# or max execution time reached
if self._is_timed_out(
start_at=self.graph_runtime_state.start_at, max_execution_time=self.max_execution_time
):
raise GraphRunFailedError("Max execution time {}s reached.".format(self.max_execution_time))
if self.graph_runtime_state.is_timed_out(self.init_params.max_execution_time):
raise GraphRunFailedError("Max execution time {}s reached.".format(self.init_params.max_execution_time))
# init route node state
route_node_state = self.graph_runtime_state.node_run_state.create_node_state(node_id=next_node_id)
@ -277,6 +281,23 @@ class GraphEngine:
thread_pool_id=self.thread_pool_id,
)
node.init_node_data(node_config.get("data", {}))
# Determine if the execution should be suspended or stopped at this point.
# If so, yield the corresponding event.
#
# Note: Suspension is not allowed while the graph engine is running in parallel mode.
if in_parallel_id is None:
command = self._command_source(CommandParams(next_node=node))
if isinstance(command, SuspendCommand):
self.graph_runtime_state.record_suspend_state(next_node_id)
yield GraphRunSuspendedEvent(next_node_id=next_node_id)
return
elif isinstance(command, StopCommand):
# TODO: STOP the execution of worklow.
return
elif isinstance(command, ContinueCommand):
pass
else:
raise AssertionError("unreachable statement.")
try:
# run node
generator = self._run_node(
@ -403,8 +424,8 @@ class GraphEngine:
)
for parallel_result in parallel_generator:
if isinstance(parallel_result, str):
final_node_id = parallel_result
if isinstance(parallel_result, _ParallelBranchResult):
final_node_id = parallel_result.final_node_id
else:
yield parallel_result
@ -429,8 +450,8 @@ class GraphEngine:
)
for generated_item in parallel_generator:
if isinstance(generated_item, str):
final_node_id = generated_item
if isinstance(generated_item, _ParallelBranchResult):
final_node_id = generated_item.final_node_id
else:
yield generated_item
@ -448,7 +469,7 @@ class GraphEngine:
in_parallel_id: Optional[str] = None,
parallel_start_node_id: Optional[str] = None,
handle_exceptions: list[str] = [],
) -> Generator[GraphEngineEvent | str, None, None]:
) -> Generator["GraphEngineEvent | _ParallelBranchResult", None, None]:
# if nodes has no run conditions, parallel run all nodes
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
if not parallel_id:
@ -526,7 +547,7 @@ class GraphEngine:
# get final node id
final_node_id = parallel.end_to_node_id
if final_node_id:
yield final_node_id
yield _ParallelBranchResult(final_node_id)
def _run_parallel_node(
self,
@ -928,7 +949,41 @@ class GraphEngine:
)
return error_result
def save(self) -> str:
"""save serializes the state inside this graph engine.
This method should be called when suspension of the execution is necessary.
"""
state = _GraphEngineState(init_params=self.init_params, graph_runtime_state=self.graph_runtime_state)
return state.model_dump_json()
@classmethod
def resume(
cls,
state: str,
graph: Graph,
command_source: CommandSource = _default_source,
) -> "GraphEngine":
"""`resume` continues a suspended execution."""
state_ = _GraphEngineState.model_validate_json(state)
return cls(
graph=graph,
graph_init_params=state_.init_params,
graph_runtime_state=state_.graph_runtime_state,
command_source=command_source,
)
class GraphRunFailedError(Exception):
def __init__(self, error: str):
self.error = error
@dataclass
class _ParallelBranchResult:
final_node_id: str
class _GraphEngineState(BaseModel):
init_params: GraphInitParams
graph_runtime_state: GraphRuntimeState

@ -17,6 +17,15 @@ logger = logging.getLogger(__name__)
class BaseNode:
"""BaseNode serves as the foundational class for all node implementations.
Nodes are allowed to maintain transient states (e.g., `LLMNode` uses the `_file_output`
attribute to track files generated by the LLM). However, these states are not persisted
when the workflow is suspended or resumed. If a node needs its state to be preserved
across workflow suspension and resumption, it should include the relevant state data
in its output.
"""
_node_type: ClassVar[NodeType]
def __init__(
@ -32,9 +41,6 @@ class BaseNode:
self.id = id
self.tenant_id = graph_init_params.tenant_id
self.app_id = graph_init_params.app_id
self.workflow_type = graph_init_params.workflow_type
self.workflow_id = graph_init_params.workflow_id
self.graph_config = graph_init_params.graph_config
self.user_id = graph_init_params.user_id
self.user_from = graph_init_params.user_from
self.invoke_from = graph_init_params.invoke_from
@ -43,6 +49,7 @@ class BaseNode:
self.graph_runtime_state = graph_runtime_state
self.previous_node_id = previous_node_id
self.thread_pool_id = thread_pool_id
self._init_params = graph_init_params
node_id = config.get("id")
if not node_id:

@ -1,6 +1,5 @@
import contextvars
import logging
import time
import uuid
from collections.abc import Generator, Mapping, Sequence
from concurrent.futures import Future, wait
@ -137,15 +136,13 @@ class IterationNode(BaseNode):
inputs = {"iterator_selector": iterator_list_value}
graph_config = self.graph_config
if not self._node_data.start_node_id:
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self.node_id} not found")
root_node_id = self._node_data.start_node_id
# init graph
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id)
iteration_graph = Graph.init(graph_config=self._init_params.graph_config, root_node_id=root_node_id)
if not iteration_graph:
raise IterationGraphNotFoundError("iteration graph not found")
@ -160,22 +157,12 @@ class IterationNode(BaseNode):
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool)
graph_engine = GraphEngine(
tenant_id=self.tenant_id,
app_id=self.app_id,
workflow_type=self.workflow_type,
workflow_id=self.workflow_id,
user_id=self.user_id,
user_from=self.user_from,
invoke_from=self.invoke_from,
call_depth=self.workflow_call_depth,
graph=iteration_graph,
graph_config=graph_config,
graph_runtime_state=graph_runtime_state,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
graph_init_params=self._init_params,
thread_pool_id=self.thread_pool_id,
)

@ -1,11 +1,9 @@
import json
import logging
import time
from collections.abc import Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
from configs import dify_config
from core.variables import (
IntegerSegment,
Segment,
@ -91,7 +89,7 @@ class LoopNode(BaseNode):
raise ValueError(f"field start_node_id in loop {self.node_id} not found")
# Initialize graph
loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self._node_data.start_node_id)
loop_graph = Graph.init(graph_config=self._init_params.graph_config, root_node_id=self._node_data.start_node_id)
if not loop_graph:
raise ValueError("loop graph not found")
@ -124,22 +122,12 @@ class LoopNode(BaseNode):
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool)
graph_engine = GraphEngine(
tenant_id=self.tenant_id,
app_id=self.app_id,
workflow_type=self.workflow_type,
workflow_id=self.workflow_id,
user_id=self.user_id,
user_from=self.user_from,
invoke_from=self.invoke_from,
call_depth=self.workflow_call_depth,
graph=loop_graph,
graph_config=self.graph_config,
graph_init_params=self._init_params,
graph_runtime_state=graph_runtime_state,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
thread_pool_id=self.thread_pool_id,
)

@ -1,5 +1,4 @@
import logging
import time
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
@ -70,8 +69,9 @@ class WorkflowEntry:
raise ValueError("Max workflow call depth {} reached.".format(workflow_call_max_depth))
# init workflow run state
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
self.graph_engine = GraphEngine(
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool)
graph_init_params = GraphInitParams(
tenant_id=tenant_id,
app_id=app_id,
workflow_type=workflow_type,
@ -80,11 +80,14 @@ class WorkflowEntry:
user_from=user_from,
invoke_from=invoke_from,
call_depth=call_depth,
graph=graph,
graph_config=graph_config,
graph_runtime_state=graph_runtime_state,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
)
self.graph_engine = GraphEngine(
graph=graph,
graph_runtime_state=graph_runtime_state,
graph_init_params=graph_init_params,
thread_pool_id=thread_pool_id,
)
@ -146,7 +149,7 @@ class WorkflowEntry:
graph = Graph.init(graph_config=workflow.graph_dict)
# init workflow run state
node = node_cls(
node_instance = node_cls(
id=str(uuid.uuid4()),
config=node_config,
graph_init_params=GraphInitParams(
@ -161,7 +164,7 @@ class WorkflowEntry:
call_depth=0,
),
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool),
)
node.init_node_data(node_config_data)
@ -191,11 +194,17 @@ class WorkflowEntry:
try:
# run node
generator = node.run()
generator = node_instance.run()
except Exception as e:
logger.exception(f"error while running node, {workflow.id=}, {node.id=}, {node.type_=}, {node.version()=}")
raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
return node, generator
logger.exception(
"error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s",
workflow.id,
node_instance.id,
node_instance.type_,
node_instance.version(),
)
raise WorkflowNodeRunFailedError(node=node_instance, err_msg=str(e))
return node_instance, generator
@classmethod
def run_free_node(
@ -257,7 +266,7 @@ class WorkflowEntry:
node_cls = cast(type[BaseNode], node_cls)
# init workflow run state
node: BaseNode = node_cls(
node_instance: BaseNode = node_cls(
id=str(uuid.uuid4()),
config=node_config,
graph_init_params=GraphInitParams(
@ -272,7 +281,7 @@ class WorkflowEntry:
call_depth=0,
),
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool),
)
node.init_node_data(node_data)
@ -293,12 +302,17 @@ class WorkflowEntry:
)
# run node
generator = node.run()
generator = node_instance.run()
return node, generator
return node_instance, generator
except Exception as e:
logger.exception(f"error while running node, {node.id=}, {node.type_=}, {node.version()=}")
raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
logger.exception(
"error while running node_instance, node_id=%s, type=%s, version=%s",
node_instance.id,
node_instance.type_,
node_instance.version(),
)
raise WorkflowNodeRunFailedError(node=node_instance, err_msg=str(e))
@staticmethod
def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None:

@ -181,6 +181,26 @@ def timezone(timezone_string):
def generate_string(n):
"""
Generates a cryptographically secure random string of the specified length.
This function uses a cryptographically secure pseudorandom number generator (CSPRNG)
to create a string composed of ASCII letters (both uppercase and lowercase) and digits.
Each character in the generated string provides approximately 5.95 bits of entropy
(log2(62)). To ensure a minimum of 128 bits of entropy for security purposes, the
length of the string (`n`) should be at least 22 characters.
Args:
n (int): The length of the random string to generate. For secure usage,
`n` should be 22 or greater.
Returns:
str: A random string of length `n` composed of ASCII letters and digits.
Note:
This function is suitable for generating credentials or other secure tokens.
"""
letters_digits = string.ascii_letters + string.digits
result = ""
for i in range(n):

@ -1,4 +1,4 @@
"""update models
"""adjust length for mcp tool name and server identifiers
Revision ID: 1a83934ad6d1
Revises: 71f5020c6470

@ -0,0 +1,51 @@
"""Add WorkflowSuspension model, add suspension_id to WorkflowRun
Revision ID: 1091956b9ee0
Revises: 1c9ba48be8e4
Create Date: 2025-07-17 20:20:43.710683
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '1091956b9ee0'
down_revision = '1a83934ad6d1'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('workflow_suspensions',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('app_id', models.types.StringUUID(), nullable=False),
sa.Column('workflow_id', models.types.StringUUID(), nullable=False),
sa.Column('workflow_run_id', models.types.StringUUID(), nullable=False),
sa.Column('resumed_at', sa.DateTime(), nullable=True),
sa.Column('next_node_id', sa.String(length=255), nullable=False),
sa.Column('state_version', sa.String(length=20), nullable=False),
sa.Column('state', sa.Text(), nullable=False),
sa.Column('inputs', sa.Text(), nullable=True),
sa.Column('form_code', sa.String(length=32), nullable=False),
sa.PrimaryKeyConstraint('id', name=op.f('workflow_suspensions_pkey')),
sa.UniqueConstraint('form_code', name=op.f('workflow_suspensions_form_code_key'))
)
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
batch_op.add_column(sa.Column('suspension_id', models.types.StringUUID(), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
batch_op.drop_column('suspension_id')
op.drop_table('workflow_suspensions')
# ### end Alembic commands ###

@ -14,10 +14,12 @@ from core.file.models import File
from core.variables import utils as variable_utils
from core.variables.variables import FloatVariable, IntegerVariable, StringVariable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
from core.workflow.entities.workflow_suspension import StateVersion
from core.workflow.nodes.enums import NodeType
from factories.variable_factory import TypeMismatchError, build_segment_with_type
from libs.datetime_utils import naive_utc_now
from libs.helper import extract_tenant_id
from libs.helper import extract_tenant_id, generate_string
from ._workflow_exc import NodeNotFoundError, WorkflowDataError
@ -508,7 +510,10 @@ class WorkflowRun(Base):
version: Mapped[str] = mapped_column(db.String(255))
graph: Mapped[Optional[str]] = mapped_column(db.Text)
inputs: Mapped[Optional[str]] = mapped_column(db.Text)
status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded
status: Mapped[str] = mapped_column(
EnumText(WorkflowExecutionStatus, length=255),
nullable=False,
) # running, succeeded, failed, stopped, partial-succeeded
outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}")
error: Mapped[Optional[str]] = mapped_column(db.Text)
elapsed_time: Mapped[float] = mapped_column(db.Float, nullable=False, server_default=sa.text("0"))
@ -520,6 +525,10 @@ class WorkflowRun(Base):
finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
exceptions_count: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"), nullable=True)
# Represents the suspension details of a suspended workflow.
# This field is non-null when `status == SUSPENDED` and null otherwise.
suspension_id: Mapped[StringUUID] = mapped_column(StringUUID, nullable=True)
@property
def created_by_account(self):
created_by_role = CreatorUserRole(self.created_by_role)
@ -907,10 +916,6 @@ class ConversationVariable(Base):
_EDITABLE_SYSTEM_VARIABLE = frozenset(["query", "files"])
def _naive_utc_datetime():
return naive_utc_now()
class WorkflowDraftVariable(Base):
"""`WorkflowDraftVariable` record variables and outputs generated during
debugging worfklow or chatflow.
@ -941,14 +946,14 @@ class WorkflowDraftVariable(Base):
created_at: Mapped[datetime] = mapped_column(
db.DateTime,
nullable=False,
default=_naive_utc_datetime,
default=naive_utc_now,
server_default=func.current_timestamp(),
)
updated_at: Mapped[datetime] = mapped_column(
db.DateTime,
nullable=False,
default=_naive_utc_datetime,
default=naive_utc_now,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
)
@ -1173,8 +1178,8 @@ class WorkflowDraftVariable(Base):
description: str = "",
) -> "WorkflowDraftVariable":
variable = WorkflowDraftVariable()
variable.created_at = _naive_utc_datetime()
variable.updated_at = _naive_utc_datetime()
variable.created_at = naive_utc_now()
variable.updated_at = naive_utc_now()
variable.description = description
variable.app_id = app_id
variable.node_id = node_id
@ -1254,3 +1259,120 @@ class WorkflowDraftVariable(Base):
def is_system_variable_editable(name: str) -> bool:
return name in _EDITABLE_SYSTEM_VARIABLE
_SUSPENSION_FORM_CODE_LENGTH = 22
def _generate_suspension_form_code():
return generate_string(_SUSPENSION_FORM_CODE_LENGTH)
class WorkflowSuspension(Base):
__tablename__ = "workflow_suspensions"
# id is the unique identifier of a suspension
id: Mapped[str] = mapped_column(
StringUUID,
primary_key=True,
# NOTE: The server default acts as a fallback mechanism.
# The application generates the ID for new `WorkflowSuspension` records
# to streamline the insertion process and minimize database roundtrips.
server_default=db.text("uuidv7()"),
)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime,
nullable=False,
default=naive_utc_now,
server_default=func.current_timestamp(),
)
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime,
nullable=False,
default=naive_utc_now,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
)
# `tenant_id` identifies the tenant associated with this suspension,
# corresponding to the `id` field in the `Tenant` model.
tenant_id: Mapped[str] = mapped_column(
StringUUID,
nullable=False,
)
# `app_id` represents the application identifier associated with this state.
# It corresponds to the `id` field in the `App` model.
#
# While this field is technically redundant (as the corresponding app can be
# determined by querying the `Workflow`), it is retained to simplify data
# cleanup and management processes.
app_id: Mapped[str] = mapped_column(
StringUUID,
nullable=False,
)
# `workflow_id` represents the unique identifier of the workflow associated with this suspension.
# It corresponds to the `id` field in the `Workflow` model.
#
# Since an application can have multiple versions of a workflow, each with its own unique ID,
# the `app_id` alone is insufficient to determine which workflow version should be loaded
# when resuming a suspended workflow.
workflow_id: Mapped[str] = mapped_column(
StringUUID,
nullable=False,
)
# `workflow_run_id` represents the identifier of the execution of workflow,
# correspond to the `id` field of `WorkflowNodeExecutionModel`.
workflow_run_id: Mapped[str] = mapped_column(
StringUUID,
nullable=False,
)
# `resumed_at` records the timestamp when the suspended workflow was resumed.
# It is set to `NULL` if the workflow has not been resumed.
resumed_at: Mapped[Optional[datetime]] = mapped_column(
sa.DateTime,
nullable=True,
default=sa.null,
)
# `next_node_id` specifies the next node to execute when the workflow resumes.
#
# Although this information is embedded within the `state` field, it is extracted
# into a separate field to facilitate debugging and data analysis.
next_node_id: Mapped[str] = mapped_column(
__name_pos=sa.String(length=255),
nullable=False,
)
# The version of the serialized execution state data. Currently, the only supported value is `v1`.
state_version: Mapped[StateVersion] = mapped_column(
EnumText(StateVersion),
nullable=False,
)
# `state` contains the serialized runtime state of the `GraphEngine`,
# capturing the workflow's execution context at the time of suspension.
#
# The value of `state` is a JSON-formatted string representing a JSON object (e.g., `{}`).
state: Mapped[str] = mapped_column(sa.Text, nullable=False)
# The inputs provided by the user when resuming the suspended workflow.
# These inputs are serialized as a JSON-formatted string (e.g., `{}`).
#
# This field is `NULL` if no inputs were submitted by the user.
inputs: Mapped[str] = mapped_column(sa.Text, nullable=True)
form_code: Mapped[str] = mapped_column(
# A 32-character string can store a base64-encoded value with 192 bits of entropy
# or a base62-encoded value with over 180 bits of entropy, providing sufficient
# uniqueness for most use cases.
sa.String(32),
nullable=False,
unique=True,
default=_generate_suspension_form_code,
)

@ -1,4 +1,3 @@
import time
import uuid
from os import getenv
from typing import cast
@ -62,7 +61,9 @@ def init_code_node(code_config: dict):
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
),
config=code_config,
)

@ -1,4 +1,3 @@
import time
import uuid
from urllib.parse import urlencode
@ -56,7 +55,9 @@ def init_http_node(config: dict):
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
),
config=config,
)

@ -1,5 +1,4 @@
import json
import time
import uuid
from collections.abc import Generator
from unittest.mock import MagicMock, patch
@ -73,7 +72,9 @@ def init_llm_node(config: dict) -> LLMNode:
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
),
config=config,
)

@ -1,5 +1,4 @@
import os
import time
import uuid
from typing import Optional
from unittest.mock import MagicMock
@ -78,7 +77,9 @@ def init_parameter_extractor_node(config: dict):
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
),
config=config,
)
node.init_node_data(config.get("data", {}))

@ -1,4 +1,3 @@
import time
import uuid
import pytest
@ -73,7 +72,9 @@ def test_execute_code(setup_code_executor_mock):
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
),
config=config,
)
node.init_node_data(config.get("data", {}))

@ -1,4 +1,3 @@
import time
import uuid
from unittest.mock import MagicMock
@ -54,7 +53,9 @@ def init_tool_node(config: dict):
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
),
config=config,
)
node.init_node_data(config.get("data", {}))

@ -1,4 +1,3 @@
import time
from decimal import Decimal
from core.model_runtime.entities.llm_entities import LLMUsage
@ -49,7 +48,6 @@ def create_test_graph_runtime_state() -> GraphRuntimeState:
return GraphRuntimeState(
variable_pool=variable_pool,
start_at=time.perf_counter(),
total_tokens=100,
llm_usage=llm_usage,
outputs={
@ -106,7 +104,6 @@ def test_empty_outputs_round_trip():
variable_pool = VariablePool.empty()
original_state = GraphRuntimeState(
variable_pool=variable_pool,
start_at=time.perf_counter(),
outputs={}, # Empty outputs
)

@ -1,24 +1,27 @@
import time
from unittest.mock import patch
import pytest
from flask import Flask
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables.segments import ArrayFileSegment
from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.graph_engine.command_source import CommandParams, CommandTypes, ContinueCommand, SuspendCommand
from core.workflow.graph_engine.entities.event import (
BaseNodeEvent,
GraphRunFailedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
GraphRunSuspendedEvent,
NodeRunFailedEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.graph_engine.graph_engine import GraphEngine
@ -175,8 +178,10 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
user_inputs={"query": "hi"},
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
)
init_params = GraphInitParams(
tenant_id="111",
app_id="222",
workflow_type=WorkflowType.WORKFLOW,
@ -186,11 +191,14 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
graph_runtime_state=graph_runtime_state,
max_execution_steps=500,
max_execution_time=1200,
)
graph_engine = GraphEngine(
graph=graph,
graph_runtime_state=graph_runtime_state,
graph_init_params=init_params,
)
def llm_generator(self):
contents = ["hi", "bye", "good morning"]
@ -303,8 +311,10 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
user_inputs={},
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
)
graph_init_params = GraphInitParams(
tenant_id="111",
app_id="222",
workflow_type=WorkflowType.CHAT,
@ -314,11 +324,14 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
graph_runtime_state=graph_runtime_state,
max_execution_steps=500,
max_execution_time=1200,
)
graph_engine = GraphEngine(
graph=graph,
graph_runtime_state=graph_runtime_state,
graph_init_params=graph_init_params,
)
# print("")
@ -484,8 +497,10 @@ def test_run_branch(mock_close, mock_remove):
user_inputs={"uid": "takato"},
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
)
graph_init_params = GraphInitParams(
tenant_id="111",
app_id="222",
workflow_type=WorkflowType.CHAT,
@ -495,11 +510,14 @@ def test_run_branch(mock_close, mock_remove):
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
graph_runtime_state=graph_runtime_state,
max_execution_steps=500,
max_execution_time=1200,
)
graph_engine = GraphEngine(
graph=graph,
graph_runtime_state=graph_runtime_state,
graph_init_params=graph_init_params,
)
# print("")
@ -823,8 +841,10 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
user_inputs={"query": "hi"},
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
)
graph_init_params = GraphInitParams(
tenant_id="111",
app_id="222",
workflow_type=WorkflowType.CHAT,
@ -834,11 +854,14 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
graph_runtime_state=graph_runtime_state,
max_execution_steps=500,
max_execution_time=1200,
)
graph_engine = GraphEngine(
graph=graph,
graph_runtime_state=graph_runtime_state,
graph_init_params=graph_init_params,
)
def qc_generator(self):
yield RunCompletedEvent(
@ -884,3 +907,203 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
assert item.outputs is not None
answer = item.outputs["answer"]
assert all(rc not in answer for rc in wrong_content)
def test_suspend_and_resume():
graph_config = {
"edges": [
{
"data": {"isInLoop": False, "sourceType": "start", "targetType": "if-else"},
"id": "1753041723554-source-1753041730748-target",
"source": "1753041723554",
"sourceHandle": "source",
"target": "1753041730748",
"targetHandle": "target",
"type": "custom",
"zIndex": 0,
},
{
"data": {"isInLoop": False, "sourceType": "if-else", "targetType": "answer"},
"id": "1753041730748-true-answer-target",
"source": "1753041730748",
"sourceHandle": "true",
"target": "answer",
"targetHandle": "target",
"type": "custom",
"zIndex": 0,
},
{
"data": {
"isInIteration": False,
"isInLoop": False,
"sourceType": "if-else",
"targetType": "answer",
},
"id": "1753041730748-false-1753041952799-target",
"source": "1753041730748",
"sourceHandle": "false",
"target": "1753041952799",
"targetHandle": "target",
"type": "custom",
"zIndex": 0,
},
],
"nodes": [
{
"data": {"desc": "", "selected": False, "title": "Start", "type": "start", "variables": []},
"height": 54,
"id": "1753041723554",
"position": {"x": 32, "y": 282},
"positionAbsolute": {"x": 32, "y": 282},
"selected": False,
"sourcePosition": "right",
"targetPosition": "left",
"type": "custom",
"width": 244,
},
{
"data": {
"cases": [
{
"case_id": "true",
"conditions": [
{
"comparison_operator": "contains",
"id": "5db4103a-7e62-4e71-a0a6-c45ac11c0b3d",
"value": "a",
"varType": "string",
"variable_selector": ["sys", "query"],
}
],
"id": "true",
"logical_operator": "and",
}
],
"desc": "",
"selected": False,
"title": "IF/ELSE",
"type": "if-else",
},
"height": 126,
"id": "1753041730748",
"position": {"x": 368, "y": 282},
"positionAbsolute": {"x": 368, "y": 282},
"selected": False,
"sourcePosition": "right",
"targetPosition": "left",
"type": "custom",
"width": 244,
},
{
"data": {
"answer": "A",
"desc": "",
"selected": False,
"title": "Answer A",
"type": "answer",
"variables": [],
},
"height": 102,
"id": "answer",
"position": {"x": 746, "y": 282},
"positionAbsolute": {"x": 746, "y": 282},
"selected": False,
"sourcePosition": "right",
"targetPosition": "left",
"type": "custom",
"width": 244,
},
{
"data": {
"answer": "Else",
"desc": "",
"selected": False,
"title": "Answer Else",
"type": "answer",
"variables": [],
},
"height": 102,
"id": "1753041952799",
"position": {"x": 746, "y": 426},
"positionAbsolute": {"x": 746, "y": 426},
"selected": True,
"sourcePosition": "right",
"targetPosition": "left",
"type": "custom",
"width": 244,
},
],
"viewport": {"x": -420, "y": -76.5, "zoom": 1},
}
graph = Graph.init(graph_config)
variable_pool = VariablePool(
system_variables=SystemVariable(
user_id="aaa",
files=[],
query="hello",
conversation_id="abababa",
),
user_inputs={"uid": "takato"},
)
graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
)
graph_init_params = GraphInitParams(
tenant_id="111",
app_id="222",
workflow_type=WorkflowType.CHAT,
workflow_id="333",
graph_config=graph_config,
user_id="444",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
max_execution_steps=500,
max_execution_time=1200,
)
_IF_ELSE_NODE_ID = "1753041730748"
def command_source(params: CommandParams) -> CommandTypes:
# requires the engine to suspend before the execution
# of If-Else node.
if params.next_node.node_id == _IF_ELSE_NODE_ID:
return SuspendCommand()
else:
return ContinueCommand()
graph_engine = GraphEngine(
graph=graph,
graph_runtime_state=graph_runtime_state,
graph_init_params=graph_init_params,
command_source=command_source,
)
events = list(graph_engine.run())
last_event = events[-1]
assert isinstance(last_event, GraphRunSuspendedEvent)
assert last_event.next_node_id == _IF_ELSE_NODE_ID
state = graph_engine.save()
assert state != ""
engine2 = GraphEngine.resume(
state=state,
graph=graph,
)
events = list(engine2.run())
assert isinstance(events[-1], GraphRunSucceededEvent)
node_run_succeeded_events = [i for i in events if isinstance(i, NodeRunSucceededEvent)]
assert node_run_succeeded_events
start_events = [i for i in node_run_succeeded_events if i.node_id == "1753041723554"]
assert not start_events
ifelse_succeeded_events = [i for i in node_run_succeeded_events if i.node_id == _IF_ELSE_NODE_ID]
assert ifelse_succeeded_events
answer_else_events = [i for i in node_run_succeeded_events if i.node_id == "1753041952799"]
assert answer_else_events
assert answer_else_events[0].route_node_state.node_run_result.outputs == {
"answer": "Else",
"files": ArrayFileSegment(value=[]),
}
answer_a_events = [i for i in node_run_succeeded_events if i.node_id == "answer"]
assert not answer_a_events

@ -1,4 +1,3 @@
import time
import uuid
from unittest.mock import MagicMock
@ -71,7 +70,9 @@ def test_execute_answer():
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(
variable_pool=pool,
),
config=node_config,
)

@ -1,4 +1,3 @@
import time
import uuid
from unittest.mock import patch
@ -179,7 +178,9 @@ def test_run():
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(
variable_pool=pool,
),
config=node_config,
)
@ -401,7 +402,9 @@ def test_run_parallel():
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(
variable_pool=pool,
),
config=node_config,
)
@ -623,7 +626,9 @@ def test_iteration_run_in_parallel_mode():
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(
variable_pool=pool,
),
config=parallel_node_config,
)
@ -647,7 +652,9 @@ def test_iteration_run_in_parallel_mode():
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(
variable_pool=pool,
),
config=sequential_node_config,
)
@ -857,7 +864,9 @@ def test_iteration_run_error_handle():
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(
variable_pool=pool,
),
config=error_node_config,
)

@ -1,4 +1,3 @@
import time
import uuid
from unittest.mock import MagicMock
@ -74,7 +73,9 @@ def test_execute_answer():
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
),
config=node_config,
)

@ -1,4 +1,3 @@
import time
from unittest.mock import patch
from core.app.entities.app_invoke_entities import InvokeFrom
@ -12,6 +11,7 @@ from core.workflow.graph_engine.entities.event import (
NodeRunStreamChunkEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
@ -175,9 +175,11 @@ class ContinueOnErrorTestHelper:
),
user_inputs=user_inputs or {"uid": "takato"},
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
)
return GraphEngine(
graph_init_params = GraphInitParams(
tenant_id="111",
app_id="222",
workflow_type=WorkflowType.CHAT,
@ -187,11 +189,14 @@ class ContinueOnErrorTestHelper:
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
graph_runtime_state=graph_runtime_state,
max_execution_steps=500,
max_execution_time=1200,
)
return GraphEngine(
graph=graph,
graph_runtime_state=graph_runtime_state,
graph_init_params=graph_init_params,
)
DEFAULT_VALUE_EDGE = [

@ -1,4 +1,3 @@
import time
import uuid
from unittest.mock import MagicMock, Mock
@ -106,7 +105,9 @@ def test_execute_if_else_result_true():
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(
variable_pool=pool,
),
config=node_config,
)
@ -192,7 +193,9 @@ def test_execute_if_else_result_false():
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(
variable_pool=pool,
),
config=node_config,
)

@ -1,4 +1,3 @@
import time
import uuid
from unittest import mock
from uuid import uuid4
@ -96,7 +95,7 @@ def test_overwrite_string_variable():
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool),
config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)
@ -197,7 +196,7 @@ def test_append_variable_to_array():
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool),
config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)
@ -289,7 +288,7 @@ def test_clear_array():
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool),
config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)

@ -1,4 +1,3 @@
import time
import uuid
from uuid import uuid4
@ -135,7 +134,7 @@ def test_remove_first_from_array():
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool),
config=node_config,
)
@ -227,7 +226,7 @@ def test_remove_last_from_array():
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool),
config=node_config,
)
@ -311,7 +310,7 @@ def test_remove_first_from_empty_array():
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool),
config=node_config,
)
@ -395,7 +394,7 @@ def test_remove_last_from_empty_array():
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool),
config=node_config,
)

Loading…
Cancel
Save