feat(api): support the suspension of graph engine

Add a simple test case
pull/22621/head
QuantumGhost 7 months ago
parent f900a92ee7
commit e0343febde

@ -0,0 +1,25 @@
from collections.abc import Callable
from dataclasses import dataclass
from enum import StrEnum
from typing import TypeAlias
from core.workflow.nodes.base import BaseNode
class ExecutionDecision(StrEnum):
    """Decision returned by an `ExecutionDecisionHook` before the next node runs.

    The graph engine inspects this value to choose between running the next
    node, suspending the run, or stopping it.
    """

    # Suspend the run: the engine records the suspend state for the next node,
    # yields a `GraphRunSuspendedEvent`, and returns.
    SUSPEND = "suspend"
    # Stop the run: the engine returns without running the next node.
    STOP = "stop"
    # Keep going: the engine proceeds to run the next node.
    CONTINUE = "continue"
@dataclass(frozen=True)
class DecisionParams:
    """Read-only input passed to an `ExecutionDecisionHook`.

    The dataclass is frozen, so its fields cannot be reassigned. Hooks are
    also expected not to mutate the objects the fields refer to.
    """

    # `next_node_instance` is the instance of the next node to run.
    next_node_instance: BaseNode
# `ExecutionDecisionHook` is a callable that takes a single argument of type `DecisionParams` and
# returns an `ExecutionDecision` indicating whether the graph engine should suspend, continue, or stop.
#
# The graph engine calls the hook with the next node instance before running that node, and only
# while it is not executing inside a parallel branch.
#
# It must not modify the data inside `DecisionParams`, including any attributes within its fields.
ExecutionDecisionHook: TypeAlias = Callable[[DecisionParams], ExecutionDecision]

@ -6,14 +6,15 @@ import uuid
from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor, wait
from copy import copy, deepcopy
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Any, Optional, cast
from flask import Flask, current_app
from pydantic import BaseModel
from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@ -27,6 +28,7 @@ from core.workflow.graph_engine.entities.event import (
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
GraphRunSuspendedEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
@ -53,9 +55,10 @@ from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.utils import variable_utils
from libs.flask_utils import preserve_flask_contexts
from models.enums import UserFrom
from models.workflow import WorkflowType
from .execution_decision import DecisionParams, ExecutionDecision, ExecutionDecisionHook
logger = logging.getLogger(__name__)
@ -86,6 +89,10 @@ class GraphEngineThreadPool(ThreadPoolExecutor):
raise ValueError(f"Max submit count {self.max_submit_count} of workflow thread pool reached.")
def _default_hook(params: DecisionParams) -> ExecutionDecision:
    """Default execution-decision hook: always let the engine continue.

    Used when the caller does not supply its own hook, so a run is never
    suspended or stopped by the hook mechanism.

    Args:
        params: Information about the next node to run; ignored here.

    Returns:
        Always `ExecutionDecision.CONTINUE`.
    """
    del params  # The default policy does not depend on the next node.
    return ExecutionDecision.CONTINUE
class GraphEngine:
workflow_thread_pool_mapping: dict[str, GraphEngineThreadPool] = {}
@ -95,7 +102,12 @@ class GraphEngine:
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
thread_pool_id: Optional[str] = None,
execution_decision_hook: ExecutionDecisionHook = _default_hook,
) -> None:
"""Create a graph from the given state.
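The `execution_decision_hook` is called before each node is run (outside of parallel
branches) and decides whether the engine continues, suspends, or stops.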
"""
thread_pool_max_submit_count = dify_config.MAX_SUBMIT_COUNT
thread_pool_max_workers = 10
@ -120,8 +132,7 @@ class GraphEngine:
self.graph_runtime_state = graph_runtime_state
self.max_execution_steps = max_execution_steps
self.max_execution_time = max_execution_time
self._exec_decision_hook = execution_decision_hook
def run(self) -> Generator[GraphEngineEvent, None, None]:
# trigger graph run start event
@ -140,12 +151,18 @@ class GraphEngine:
)
# run graph
next_node_to_run = self.graph.root_node_id
if (next_node_id := self.graph_runtime_state.node_run_state.next_node_id) is not None:
next_node_to_run = next_node_id
generator = stream_processor.process(
self._run(start_node_id=self.graph.root_node_id, handle_exceptions=handle_exceptions)
self._run(start_node_id=next_node_to_run, handle_exceptions=handle_exceptions)
)
for item in generator:
try:
yield item
if isinstance(item, GraphRunSuspendedEvent):
return
if isinstance(item, NodeRunFailedEvent):
yield GraphRunFailedEvent(
error=item.route_node_state.failed_reason or "Unknown error.",
@ -209,22 +226,22 @@ class GraphEngine:
parent_parallel_start_node_id: Optional[str] = None,
handle_exceptions: list[str] = [],
) -> Generator[GraphEngineEvent, None, None]:
# Hint: the `_run` method is used both when running the main graph
# and when running parallel branches.
parallel_start_node_id = None
if in_parallel_id:
parallel_start_node_id = start_node_id
next_node_id = start_node_id
previous_route_node_state: Optional[RouteNodeState] = None
while True:
# max steps reached
if self.graph_runtime_state.node_run_steps > self.max_execution_steps:
raise GraphRunFailedError("Max steps {} reached.".format(self.max_execution_steps))
if self.graph_runtime_state.node_run_steps > self.init_params.max_execution_steps:
raise GraphRunFailedError("Max steps {} reached.".format(self.init_params.max_execution_steps))
# or max execution time reached
if self._is_timed_out(
start_at=self.graph_runtime_state.start_at, max_execution_time=self.max_execution_time
):
raise GraphRunFailedError("Max execution time {}s reached.".format(self.max_execution_time))
if self.graph_runtime_state.is_timed_out(self.init_params.max_execution_time):
raise GraphRunFailedError("Max execution time {}s reached.".format(self.init_params.max_execution_time))
# init route node state
route_node_state = self.graph_runtime_state.node_run_state.create_node_state(node_id=next_node_id)
@ -257,6 +274,23 @@ class GraphEngine:
thread_pool_id=self.thread_pool_id,
)
node.init_node_data(node_config.get("data", {}))
# Determine if the execution should be suspended or stopped at this point.
# If so, yield the corresponding event.
#
# Note: Suspension is not allowed while the graph engine is running in parallel mode.
if in_parallel_id is None:
hook_result = self._exec_decision_hook(DecisionParams(next_node_instance=node))
if hook_result == ExecutionDecision.SUSPEND:
self.graph_runtime_state.record_suspend_state(next_node_id)
yield GraphRunSuspendedEvent(next_node_id=next_node_id)
return
elif hook_result == ExecutionDecision.STOP:
# TODO: STOP the execution of the workflow.
return
elif hook_result == ExecutionDecision.CONTINUE:
pass
else:
raise AssertionError("unreachable statement.")
try:
# run node
generator = self._run_node(
@ -383,8 +417,8 @@ class GraphEngine:
)
for parallel_result in parallel_generator:
if isinstance(parallel_result, str):
final_node_id = parallel_result
if isinstance(parallel_result, _ParallelBranchResult):
final_node_id = parallel_result.final_node_id
else:
yield parallel_result
@ -409,8 +443,8 @@ class GraphEngine:
)
for generated_item in parallel_generator:
if isinstance(generated_item, str):
final_node_id = generated_item
if isinstance(generated_item, _ParallelBranchResult):
final_node_id = generated_item.final_node_id
else:
yield generated_item
@ -428,7 +462,7 @@ class GraphEngine:
in_parallel_id: Optional[str] = None,
parallel_start_node_id: Optional[str] = None,
handle_exceptions: list[str] = [],
) -> Generator[GraphEngineEvent | str, None, None]:
) -> Generator["GraphEngineEvent | _ParallelBranchResult", None, None]:
# if nodes has no run conditions, parallel run all nodes
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
if not parallel_id:
@ -506,7 +540,7 @@ class GraphEngine:
# get final node id
final_node_id = parallel.end_to_node_id
if final_node_id:
yield final_node_id
yield _ParallelBranchResult(final_node_id)
def _run_parallel_node(
self,
@ -908,7 +942,41 @@ class GraphEngine:
)
return error_result
def save(self) -> str:
    """Serialize the state held by this graph engine to a JSON string.

    Call this when the execution has to be suspended. The result bundles the
    engine's init params and runtime state and can be handed to `resume`.

    Returns:
        The JSON representation of the engine state.
    """
    return _GraphEngineState(
        init_params=self.init_params,
        graph_runtime_state=self.graph_runtime_state,
    ).model_dump_json()
@classmethod
def resume(
    cls,
    state: str,
    graph: Graph,
    execution_decision_hook: ExecutionDecisionHook = _default_hook,
) -> "GraphEngine":
    """Rebuild a graph engine from a previously saved state.

    Args:
        state: JSON string produced by `save`.
        graph: The graph to execute. It is not part of the saved state and
            must be supplied by the caller.
        execution_decision_hook: Hook consulted before each node runs;
            defaults to a hook that always continues.

    Returns:
        A new engine initialised with the saved init params and runtime state.
    """
    restored = _GraphEngineState.model_validate_json(state)
    init_params = restored.init_params
    runtime_state = restored.graph_runtime_state
    return cls(
        graph=graph,
        graph_init_params=init_params,
        graph_runtime_state=runtime_state,
        execution_decision_hook=execution_decision_hook,
    )
class GraphRunFailedError(Exception):
    """Raised by the graph engine when a run cannot continue.

    For example, the engine raises it when the maximum number of steps or the
    maximum execution time is exceeded.

    Attributes:
        error: Human-readable description of the failure.
    """

    def __init__(self, error: str):
        # Forward the message to `Exception` so that `str(exc)` and `exc.args`
        # are populated even when `error` is passed as a keyword argument.
        super().__init__(error)
        self.error = error
@dataclass
class _ParallelBranchResult:
    """Marker yielded by the parallel-branch runner once a branch set has finished.

    Wrapping the id in a dedicated type lets consumers tell it apart from
    regular graph engine events with an `isinstance` check.
    """

    # Id of the node at which execution continues after the parallel branches.
    final_node_id: str
class _GraphEngineState(BaseModel):
    """Serializable snapshot of a graph engine, used by `GraphEngine.save` and `GraphEngine.resume`."""

    # Parameters the engine was initialised with.
    init_params: GraphInitParams
    # Runtime state of the run at the time of the snapshot.
    graph_runtime_state: GraphRuntimeState

@ -1,10 +1,10 @@
import time
from unittest.mock import patch
import pytest
from flask import Flask
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables.segments import ArrayFileSegment
from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@ -13,15 +13,18 @@ from core.workflow.graph_engine.entities.event import (
GraphRunFailedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
GraphRunSuspendedEvent,
NodeRunFailedEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.graph_engine.execution_decision import DecisionParams
from core.workflow.graph_engine.graph_engine import ExecutionDecision, GraphEngine
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.node import LLMNode
@ -904,3 +907,203 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
assert item.outputs is not None
answer = item.outputs["answer"]
assert all(rc not in answer for rc in wrong_content)
def test_suspend_and_resume():
    """Suspend a run before the If-Else node, then resume it from the saved state.

    The first engine uses a hook that requests suspension right before the
    If-Else node. Its state is saved, a second engine is created from that
    state via `GraphEngine.resume`, and the test checks that the second run
    starts at the If-Else node and finishes through the "false" branch.
    """
    # Graph layout: Start -> IF/ELSE -> ("true": Answer A | "false": Answer Else).
    graph_config = {
        "edges": [
            {
                "data": {"isInLoop": False, "sourceType": "start", "targetType": "if-else"},
                "id": "1753041723554-source-1753041730748-target",
                "source": "1753041723554",
                "sourceHandle": "source",
                "target": "1753041730748",
                "targetHandle": "target",
                "type": "custom",
                "zIndex": 0,
            },
            {
                "data": {"isInLoop": False, "sourceType": "if-else", "targetType": "answer"},
                "id": "1753041730748-true-answer-target",
                "source": "1753041730748",
                "sourceHandle": "true",
                "target": "answer",
                "targetHandle": "target",
                "type": "custom",
                "zIndex": 0,
            },
            {
                "data": {
                    "isInIteration": False,
                    "isInLoop": False,
                    "sourceType": "if-else",
                    "targetType": "answer",
                },
                "id": "1753041730748-false-1753041952799-target",
                "source": "1753041730748",
                "sourceHandle": "false",
                "target": "1753041952799",
                "targetHandle": "target",
                "type": "custom",
                "zIndex": 0,
            },
        ],
        "nodes": [
            {
                "data": {"desc": "", "selected": False, "title": "Start", "type": "start", "variables": []},
                "height": 54,
                "id": "1753041723554",
                "position": {"x": 32, "y": 282},
                "positionAbsolute": {"x": 32, "y": 282},
                "selected": False,
                "sourcePosition": "right",
                "targetPosition": "left",
                "type": "custom",
                "width": 244,
            },
            {
                "data": {
                    "cases": [
                        {
                            "case_id": "true",
                            "conditions": [
                                {
                                    "comparison_operator": "contains",
                                    "id": "5db4103a-7e62-4e71-a0a6-c45ac11c0b3d",
                                    "value": "a",
                                    "varType": "string",
                                    "variable_selector": ["sys", "query"],
                                }
                            ],
                            "id": "true",
                            "logical_operator": "and",
                        }
                    ],
                    "desc": "",
                    "selected": False,
                    "title": "IF/ELSE",
                    "type": "if-else",
                },
                "height": 126,
                "id": "1753041730748",
                "position": {"x": 368, "y": 282},
                "positionAbsolute": {"x": 368, "y": 282},
                "selected": False,
                "sourcePosition": "right",
                "targetPosition": "left",
                "type": "custom",
                "width": 244,
            },
            {
                "data": {
                    "answer": "A",
                    "desc": "",
                    "selected": False,
                    "title": "Answer A",
                    "type": "answer",
                    "variables": [],
                },
                "height": 102,
                "id": "answer",
                "position": {"x": 746, "y": 282},
                "positionAbsolute": {"x": 746, "y": 282},
                "selected": False,
                "sourcePosition": "right",
                "targetPosition": "left",
                "type": "custom",
                "width": 244,
            },
            {
                "data": {
                    "answer": "Else",
                    "desc": "",
                    "selected": False,
                    "title": "Answer Else",
                    "type": "answer",
                    "variables": [],
                },
                "height": 102,
                "id": "1753041952799",
                "position": {"x": 746, "y": 426},
                "positionAbsolute": {"x": 746, "y": 426},
                "selected": True,
                "sourcePosition": "right",
                "targetPosition": "left",
                "type": "custom",
                "width": 244,
            },
        ],
        "viewport": {"x": -420, "y": -76.5, "zoom": 1},
    }
    graph = Graph.init(graph_config)
    # The query "hello" does not contain "a", so the If-Else condition is
    # expected to take the "false" branch.
    variable_pool = VariablePool(
        system_variables=SystemVariable(
            user_id="aaa",
            files=[],
            query="hello",
            conversation_id="abababa",
        ),
        user_inputs={"uid": "takato"},
    )
    graph_runtime_state = GraphRuntimeState(
        variable_pool=variable_pool,
    )
    graph_init_params = GraphInitParams(
        tenant_id="111",
        app_id="222",
        workflow_type=WorkflowType.CHAT,
        workflow_id="333",
        graph_config=graph_config,
        user_id="444",
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.WEB_APP,
        call_depth=0,
        max_execution_steps=500,
        max_execution_time=1200,
    )

    _IF_ELSE_NODE_ID = "1753041730748"

    def exec_decision_hook(params: DecisionParams) -> ExecutionDecision:
        # requires the engine to suspend before the execution
        # of If-Else node.
        if params.next_node_instance.node_id == _IF_ELSE_NODE_ID:
            return ExecutionDecision.SUSPEND
        else:
            return ExecutionDecision.CONTINUE

    # First run: the hook suspends the engine right before the If-Else node.
    graph_engine = GraphEngine(
        graph=graph,
        graph_runtime_state=graph_runtime_state,
        graph_init_params=graph_init_params,
        execution_decision_hook=exec_decision_hook,
    )
    events = list(graph_engine.run())
    # The suspension event must be the last event and must point at the
    # node the run should continue from.
    last_event = events[-1]
    assert isinstance(last_event, GraphRunSuspendedEvent)
    assert last_event.next_node_id == _IF_ELSE_NODE_ID
    state = graph_engine.save()
    assert state != ""

    # Second run: restore from the saved state. No hook is passed, so the
    # default hook is used.
    engine2 = GraphEngine.resume(
        state=state,
        graph=graph,
    )
    events = list(engine2.run())
    assert isinstance(events[-1], GraphRunSucceededEvent)
    node_run_succeeded_events = [i for i in events if isinstance(i, NodeRunSucceededEvent)]
    assert node_run_succeeded_events
    # The Start node already ran before the suspension and must not run again.
    start_events = [i for i in node_run_succeeded_events if i.node_id == "1753041723554"]
    assert not start_events
    # The resumed run picks up at the If-Else node.
    ifelse_succeeded_events = [i for i in node_run_succeeded_events if i.node_id == _IF_ELSE_NODE_ID]
    assert ifelse_succeeded_events
    # Only the "false" branch answer node runs.
    answer_else_events = [i for i in node_run_succeeded_events if i.node_id == "1753041952799"]
    assert answer_else_events
    assert answer_else_events[0].route_node_state.node_run_result.outputs == {
        "answer": "Else",
        "files": ArrayFileSegment(value=[]),
    }
    answer_a_events = [i for i in node_run_succeeded_events if i.node_id == "answer"]
    assert not answer_a_events

Loading…
Cancel
Save