feat(api): support the suspension of graph engine

Add a simple test case
pull/22621/head
QuantumGhost 10 months ago
parent f900a92ee7
commit e0343febde

@ -0,0 +1,25 @@
from collections.abc import Callable
from dataclasses import dataclass
from enum import StrEnum
from typing import TypeAlias
from core.workflow.nodes.base import BaseNode
class ExecutionDecision(StrEnum):
SUSPEND = "suspend"
STOP = "stop"
CONTINUE = "continue"
@dataclass(frozen=True)
class DecisionParams:
    """Immutable argument passed to an execution-decision hook.

    The dataclass is frozen, so fields cannot be reassigned after creation.
    """

    # `next_node_instance` is the instance of the next node to run.
    next_node_instance: BaseNode
# `ExecutionDecisionHook` is the type of a callable that takes a single argument of type
# `DecisionParams` and returns an `ExecutionDecision` indicating whether the graph engine
# should suspend, continue, or stop.
#
# A hook must not modify the data inside `DecisionParams`, including any attributes within its fields.
ExecutionDecisionHook: TypeAlias = Callable[[DecisionParams], ExecutionDecision]

@ -6,14 +6,15 @@ import uuid
from collections.abc import Generator from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor, wait from concurrent.futures import ThreadPoolExecutor, wait
from copy import copy, deepcopy from copy import copy, deepcopy
from dataclasses import dataclass
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any, Optional, cast from typing import Any, Optional, cast
from flask import Flask, current_app from flask import Flask, current_app
from pydantic import BaseModel
from configs import dify_config from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@ -27,6 +28,7 @@ from core.workflow.graph_engine.entities.event import (
GraphRunPartialSucceededEvent, GraphRunPartialSucceededEvent,
GraphRunStartedEvent, GraphRunStartedEvent,
GraphRunSucceededEvent, GraphRunSucceededEvent,
GraphRunSuspendedEvent,
NodeRunExceptionEvent, NodeRunExceptionEvent,
NodeRunFailedEvent, NodeRunFailedEvent,
NodeRunRetrieverResourceEvent, NodeRunRetrieverResourceEvent,
@ -53,9 +55,10 @@ from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.utils import variable_utils from core.workflow.utils import variable_utils
from libs.flask_utils import preserve_flask_contexts from libs.flask_utils import preserve_flask_contexts
from models.enums import UserFrom
from models.workflow import WorkflowType from models.workflow import WorkflowType
from .execution_decision import DecisionParams, ExecutionDecision, ExecutionDecisionHook
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -86,6 +89,10 @@ class GraphEngineThreadPool(ThreadPoolExecutor):
raise ValueError(f"Max submit count {self.max_submit_count} of workflow thread pool reached.") raise ValueError(f"Max submit count {self.max_submit_count} of workflow thread pool reached.")
def _default_hook(params: DecisionParams) -> ExecutionDecision:
    """Default execution-decision hook: ignore `params` and always continue."""
    return ExecutionDecision.CONTINUE
class GraphEngine: class GraphEngine:
workflow_thread_pool_mapping: dict[str, GraphEngineThreadPool] = {} workflow_thread_pool_mapping: dict[str, GraphEngineThreadPool] = {}
@ -95,7 +102,12 @@ class GraphEngine:
graph_init_params: GraphInitParams, graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState, graph_runtime_state: GraphRuntimeState,
thread_pool_id: Optional[str] = None, thread_pool_id: Optional[str] = None,
execution_decision_hook: ExecutionDecisionHook = _default_hook,
) -> None: ) -> None:
"""Create a graph from the given state.
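`execution_decision_hook` is called before a node runs (outside of parallel branches)
and decides whether the engine should continue, suspend, or stop.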
"""
thread_pool_max_submit_count = dify_config.MAX_SUBMIT_COUNT thread_pool_max_submit_count = dify_config.MAX_SUBMIT_COUNT
thread_pool_max_workers = 10 thread_pool_max_workers = 10
@ -120,8 +132,7 @@ class GraphEngine:
self.graph_runtime_state = graph_runtime_state self.graph_runtime_state = graph_runtime_state
self.max_execution_steps = max_execution_steps self._exec_decision_hook = execution_decision_hook
self.max_execution_time = max_execution_time
def run(self) -> Generator[GraphEngineEvent, None, None]: def run(self) -> Generator[GraphEngineEvent, None, None]:
# trigger graph run start event # trigger graph run start event
@ -140,12 +151,18 @@ class GraphEngine:
) )
# run graph # run graph
next_node_to_run = self.graph.root_node_id
if (next_node_id := self.graph_runtime_state.node_run_state.next_node_id) is not None:
next_node_to_run = next_node_id
generator = stream_processor.process( generator = stream_processor.process(
self._run(start_node_id=self.graph.root_node_id, handle_exceptions=handle_exceptions) self._run(start_node_id=next_node_to_run, handle_exceptions=handle_exceptions)
) )
for item in generator: for item in generator:
try: try:
yield item yield item
if isinstance(item, GraphRunSuspendedEvent):
return
if isinstance(item, NodeRunFailedEvent): if isinstance(item, NodeRunFailedEvent):
yield GraphRunFailedEvent( yield GraphRunFailedEvent(
error=item.route_node_state.failed_reason or "Unknown error.", error=item.route_node_state.failed_reason or "Unknown error.",
@ -209,22 +226,22 @@ class GraphEngine:
parent_parallel_start_node_id: Optional[str] = None, parent_parallel_start_node_id: Optional[str] = None,
handle_exceptions: list[str] = [], handle_exceptions: list[str] = [],
) -> Generator[GraphEngineEvent, None, None]: ) -> Generator[GraphEngineEvent, None, None]:
# Hint: the `_run` method is used both when running the main graph
# and when running parallel branches.
parallel_start_node_id = None parallel_start_node_id = None
if in_parallel_id: if in_parallel_id:
parallel_start_node_id = start_node_id parallel_start_node_id = start_node_id
next_node_id = start_node_id next_node_id = start_node_id
previous_route_node_state: Optional[RouteNodeState] = None previous_route_node_state: Optional[RouteNodeState] = None
while True: while True:
# max steps reached # max steps reached
if self.graph_runtime_state.node_run_steps > self.max_execution_steps: if self.graph_runtime_state.node_run_steps > self.init_params.max_execution_steps:
raise GraphRunFailedError("Max steps {} reached.".format(self.max_execution_steps)) raise GraphRunFailedError("Max steps {} reached.".format(self.init_params.max_execution_steps))
# or max execution time reached if self.graph_runtime_state.is_timed_out(self.init_params.max_execution_time):
if self._is_timed_out( raise GraphRunFailedError("Max execution time {}s reached.".format(self.init_params.max_execution_time))
start_at=self.graph_runtime_state.start_at, max_execution_time=self.max_execution_time
):
raise GraphRunFailedError("Max execution time {}s reached.".format(self.max_execution_time))
# init route node state # init route node state
route_node_state = self.graph_runtime_state.node_run_state.create_node_state(node_id=next_node_id) route_node_state = self.graph_runtime_state.node_run_state.create_node_state(node_id=next_node_id)
@ -257,6 +274,23 @@ class GraphEngine:
thread_pool_id=self.thread_pool_id, thread_pool_id=self.thread_pool_id,
) )
node.init_node_data(node_config.get("data", {})) node.init_node_data(node_config.get("data", {}))
# Determine if the execution should be suspended or stopped at this point.
# If so, yield the corresponding event.
#
# Note: Suspension is not allowed while the graph engine is running in parallel mode.
if in_parallel_id is None:
hook_result = self._exec_decision_hook(DecisionParams(next_node_instance=node))
if hook_result == ExecutionDecision.SUSPEND:
self.graph_runtime_state.record_suspend_state(next_node_id)
yield GraphRunSuspendedEvent(next_node_id=next_node_id)
return
elif hook_result == ExecutionDecision.STOP:
# TODO: stop the execution of the workflow.
return
elif hook_result == ExecutionDecision.CONTINUE:
pass
else:
raise AssertionError("unreachable statement.")
try: try:
# run node # run node
generator = self._run_node( generator = self._run_node(
@ -383,8 +417,8 @@ class GraphEngine:
) )
for parallel_result in parallel_generator: for parallel_result in parallel_generator:
if isinstance(parallel_result, str): if isinstance(parallel_result, _ParallelBranchResult):
final_node_id = parallel_result final_node_id = parallel_result.final_node_id
else: else:
yield parallel_result yield parallel_result
@ -409,8 +443,8 @@ class GraphEngine:
) )
for generated_item in parallel_generator: for generated_item in parallel_generator:
if isinstance(generated_item, str): if isinstance(generated_item, _ParallelBranchResult):
final_node_id = generated_item final_node_id = generated_item.final_node_id
else: else:
yield generated_item yield generated_item
@ -428,7 +462,7 @@ class GraphEngine:
in_parallel_id: Optional[str] = None, in_parallel_id: Optional[str] = None,
parallel_start_node_id: Optional[str] = None, parallel_start_node_id: Optional[str] = None,
handle_exceptions: list[str] = [], handle_exceptions: list[str] = [],
) -> Generator[GraphEngineEvent | str, None, None]: ) -> Generator["GraphEngineEvent | _ParallelBranchResult", None, None]:
# if nodes has no run conditions, parallel run all nodes # if nodes has no run conditions, parallel run all nodes
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id) parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
if not parallel_id: if not parallel_id:
@ -506,7 +540,7 @@ class GraphEngine:
# get final node id # get final node id
final_node_id = parallel.end_to_node_id final_node_id = parallel.end_to_node_id
if final_node_id: if final_node_id:
yield final_node_id yield _ParallelBranchResult(final_node_id)
def _run_parallel_node( def _run_parallel_node(
self, self,
@ -908,7 +942,41 @@ class GraphEngine:
) )
return error_result return error_result
def save(self) -> str:
    """Serialize the state held by this graph engine into a JSON string.

    Call this when the execution has to be suspended. The result bundles
    the init params and the graph runtime state.
    """
    snapshot = _GraphEngineState(
        init_params=self.init_params,
        graph_runtime_state=self.graph_runtime_state,
    )
    return snapshot.model_dump_json()
@classmethod
def resume(
    cls,
    state: str,
    graph: Graph,
    execution_decision_hook: ExecutionDecisionHook = _default_hook,
) -> "GraphEngine":
    """Build an engine that continues a suspended execution.

    `state` is parsed as the JSON form of `_GraphEngineState`; its init
    params and runtime state are handed to the constructor together with
    `graph` and `execution_decision_hook`.
    """
    restored = _GraphEngineState.model_validate_json(state)
    return cls(
        graph=graph,
        graph_init_params=restored.init_params,
        graph_runtime_state=restored.graph_runtime_state,
        execution_decision_hook=execution_decision_hook,
    )
class GraphRunFailedError(Exception):
    """Raised when a graph run fails; the reason is kept in `error`."""

    def __init__(self, error: str):
        # Forward the message to `Exception` so `str(exc)` and `exc.args`
        # carry the reason instead of being empty.
        super().__init__(error)
        self.error = error
@dataclass
class _ParallelBranchResult:
    """Result item yielded by the parallel-run generator alongside engine events.

    Using a dedicated type instead of a bare `str` lets consumers pick it out
    with `isinstance`.
    """

    # Id of the node the engine continues with after the parallel branches.
    final_node_id: str
class _GraphEngineState(BaseModel):
    """Serializable snapshot of a graph engine, used by `GraphEngine.save` and `GraphEngine.resume`."""

    # Init params of the engine being saved.
    init_params: GraphInitParams
    # Runtime state of the engine being saved.
    graph_runtime_state: GraphRuntimeState

@ -1,10 +1,10 @@
import time
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from flask import Flask from flask import Flask
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables.segments import ArrayFileSegment
from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@ -13,15 +13,18 @@ from core.workflow.graph_engine.entities.event import (
GraphRunFailedEvent, GraphRunFailedEvent,
GraphRunStartedEvent, GraphRunStartedEvent,
GraphRunSucceededEvent, GraphRunSucceededEvent,
GraphRunSuspendedEvent,
NodeRunFailedEvent, NodeRunFailedEvent,
NodeRunStartedEvent, NodeRunStartedEvent,
NodeRunStreamChunkEvent, NodeRunStreamChunkEvent,
NodeRunSucceededEvent, NodeRunSucceededEvent,
) )
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.graph_engine.graph_engine import GraphEngine from core.workflow.graph_engine.execution_decision import DecisionParams
from core.workflow.graph_engine.graph_engine import ExecutionDecision, GraphEngine
from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.node import LLMNode from core.workflow.nodes.llm.node import LLMNode
@ -904,3 +907,203 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
assert item.outputs is not None assert item.outputs is not None
answer = item.outputs["answer"] answer = item.outputs["answer"]
assert all(rc not in answer for rc in wrong_content) assert all(rc not in answer for rc in wrong_content)
def test_suspend_and_resume():
    """Suspend a run right before the If-Else node, then resume it from the saved state."""
    graph_config = {
        "edges": [
            {
                "data": {"isInLoop": False, "sourceType": "start", "targetType": "if-else"},
                "id": "1753041723554-source-1753041730748-target",
                "source": "1753041723554",
                "sourceHandle": "source",
                "target": "1753041730748",
                "targetHandle": "target",
                "type": "custom",
                "zIndex": 0,
            },
            {
                "data": {"isInLoop": False, "sourceType": "if-else", "targetType": "answer"},
                "id": "1753041730748-true-answer-target",
                "source": "1753041730748",
                "sourceHandle": "true",
                "target": "answer",
                "targetHandle": "target",
                "type": "custom",
                "zIndex": 0,
            },
            {
                "data": {
                    "isInIteration": False,
                    "isInLoop": False,
                    "sourceType": "if-else",
                    "targetType": "answer",
                },
                "id": "1753041730748-false-1753041952799-target",
                "source": "1753041730748",
                "sourceHandle": "false",
                "target": "1753041952799",
                "targetHandle": "target",
                "type": "custom",
                "zIndex": 0,
            },
        ],
        "nodes": [
            {
                "data": {"desc": "", "selected": False, "title": "Start", "type": "start", "variables": []},
                "height": 54,
                "id": "1753041723554",
                "position": {"x": 32, "y": 282},
                "positionAbsolute": {"x": 32, "y": 282},
                "selected": False,
                "sourcePosition": "right",
                "targetPosition": "left",
                "type": "custom",
                "width": 244,
            },
            {
                "data": {
                    "cases": [
                        {
                            "case_id": "true",
                            "conditions": [
                                {
                                    "comparison_operator": "contains",
                                    "id": "5db4103a-7e62-4e71-a0a6-c45ac11c0b3d",
                                    "value": "a",
                                    "varType": "string",
                                    "variable_selector": ["sys", "query"],
                                }
                            ],
                            "id": "true",
                            "logical_operator": "and",
                        }
                    ],
                    "desc": "",
                    "selected": False,
                    "title": "IF/ELSE",
                    "type": "if-else",
                },
                "height": 126,
                "id": "1753041730748",
                "position": {"x": 368, "y": 282},
                "positionAbsolute": {"x": 368, "y": 282},
                "selected": False,
                "sourcePosition": "right",
                "targetPosition": "left",
                "type": "custom",
                "width": 244,
            },
            {
                "data": {
                    "answer": "A",
                    "desc": "",
                    "selected": False,
                    "title": "Answer A",
                    "type": "answer",
                    "variables": [],
                },
                "height": 102,
                "id": "answer",
                "position": {"x": 746, "y": 282},
                "positionAbsolute": {"x": 746, "y": 282},
                "selected": False,
                "sourcePosition": "right",
                "targetPosition": "left",
                "type": "custom",
                "width": 244,
            },
            {
                "data": {
                    "answer": "Else",
                    "desc": "",
                    "selected": False,
                    "title": "Answer Else",
                    "type": "answer",
                    "variables": [],
                },
                "height": 102,
                "id": "1753041952799",
                "position": {"x": 746, "y": 426},
                "positionAbsolute": {"x": 746, "y": 426},
                "selected": True,
                "sourcePosition": "right",
                "targetPosition": "left",
                "type": "custom",
                "width": 244,
            },
        ],
        "viewport": {"x": -420, "y": -76.5, "zoom": 1},
    }
    graph = Graph.init(graph_config)

    # The query "hello" does not contain "a", so the If-Else node is expected
    # to take its "false" branch once it runs.
    system_variables = SystemVariable(
        user_id="aaa",
        files=[],
        query="hello",
        conversation_id="abababa",
    )
    variable_pool = VariablePool(
        system_variables=system_variables,
        user_inputs={"uid": "takato"},
    )
    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool)
    graph_init_params = GraphInitParams(
        tenant_id="111",
        app_id="222",
        workflow_type=WorkflowType.CHAT,
        workflow_id="333",
        graph_config=graph_config,
        user_id="444",
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.WEB_APP,
        call_depth=0,
        max_execution_steps=500,
        max_execution_time=1200,
    )

    if_else_node_id = "1753041730748"

    def exec_decision_hook(params: DecisionParams) -> ExecutionDecision:
        # Ask the engine to suspend right before the If-Else node executes.
        about_to_run_if_else = params.next_node_instance.node_id == if_else_node_id
        return ExecutionDecision.SUSPEND if about_to_run_if_else else ExecutionDecision.CONTINUE

    # First run: the engine must stop with a suspension event at the If-Else node.
    first_engine = GraphEngine(
        graph=graph,
        graph_runtime_state=graph_runtime_state,
        graph_init_params=graph_init_params,
        execution_decision_hook=exec_decision_hook,
    )
    first_run_events = list(first_engine.run())
    suspended_event = first_run_events[-1]
    assert isinstance(suspended_event, GraphRunSuspendedEvent)
    assert suspended_event.next_node_id == if_else_node_id

    saved_state = first_engine.save()
    assert saved_state != ""

    # Second run: resume from the saved state with the default hook.
    resumed_engine = GraphEngine.resume(
        state=saved_state,
        graph=graph,
    )
    resumed_events = list(resumed_engine.run())
    assert isinstance(resumed_events[-1], GraphRunSucceededEvent)

    succeeded_events = [event for event in resumed_events if isinstance(event, NodeRunSucceededEvent)]
    assert succeeded_events

    def succeeded_for(node_id: str) -> list:
        # Succeeded-node events of the resumed run that belong to `node_id`.
        return [event for event in succeeded_events if event.node_id == node_id]

    # The start node already ran before the suspension and must not run again.
    assert not succeeded_for("1753041723554")

    assert succeeded_for(if_else_node_id)

    answer_else_events = succeeded_for("1753041952799")
    assert answer_else_events
    assert answer_else_events[0].route_node_state.node_run_result.outputs == {
        "answer": "Else",
        "files": ArrayFileSegment(value=[]),
    }

    # The "true" branch must not have been taken.
    assert not succeeded_for("answer")

Loading…
Cancel
Save