diff --git a/api/core/workflow/graph_engine/execution_decision.py b/api/core/workflow/graph_engine/execution_decision.py
new file mode 100644
index 0000000000..bd65819a40
--- /dev/null
+++ b/api/core/workflow/graph_engine/execution_decision.py
@@ -0,0 +1,25 @@
+from collections.abc import Callable
+from dataclasses import dataclass
+from enum import StrEnum
+from typing import TypeAlias
+
+from core.workflow.nodes.base import BaseNode
+
+
+class ExecutionDecision(StrEnum):
+    SUSPEND = "suspend"
+    STOP = "stop"
+    CONTINUE = "continue"
+
+
+@dataclass(frozen=True)
+class DecisionParams:
+    # `next_node_instance` is the instance of the next node to run.
+    next_node_instance: BaseNode
+
+
+# `ExecutionDecisionHook` is a callable that takes a single argument of type `DecisionParams` and
+# returns an `ExecutionDecision` indicating whether the graph engine should suspend, continue, or stop.
+#
+# It must not modify the data inside `DecisionParams`, including any attributes within its fields.
+ExecutionDecisionHook: TypeAlias = Callable[[DecisionParams], ExecutionDecision]
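+
+
+# Illustrative usage (a hypothetical helper, not part of this module's API):
+# a hook factory that suspends the run right before a specific node executes.
+#
+#     def suspend_before_node(node_id: str) -> ExecutionDecisionHook:
+#         def hook(params: DecisionParams) -> ExecutionDecision:
+#             if params.next_node_instance.node_id == node_id:
+#                 return ExecutionDecision.SUSPEND
+#             return ExecutionDecision.CONTINUE
+#
+#         return hook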
+ + The + """ thread_pool_max_submit_count = dify_config.MAX_SUBMIT_COUNT thread_pool_max_workers = 10 @@ -120,8 +132,7 @@ class GraphEngine: self.graph_runtime_state = graph_runtime_state - self.max_execution_steps = max_execution_steps - self.max_execution_time = max_execution_time + self._exec_decision_hook = execution_decision_hook def run(self) -> Generator[GraphEngineEvent, None, None]: # trigger graph run start event @@ -140,12 +151,18 @@ class GraphEngine: ) # run graph + + next_node_to_run = self.graph.root_node_id + if (next_node_id := self.graph_runtime_state.node_run_state.next_node_id) is not None: + next_node_to_run = next_node_id generator = stream_processor.process( - self._run(start_node_id=self.graph.root_node_id, handle_exceptions=handle_exceptions) + self._run(start_node_id=next_node_to_run, handle_exceptions=handle_exceptions) ) for item in generator: try: yield item + if isinstance(item, GraphRunSuspendedEvent): + return if isinstance(item, NodeRunFailedEvent): yield GraphRunFailedEvent( error=item.route_node_state.failed_reason or "Unknown error.", @@ -209,22 +226,22 @@ class GraphEngine: parent_parallel_start_node_id: Optional[str] = None, handle_exceptions: list[str] = [], ) -> Generator[GraphEngineEvent, None, None]: + # Hint: the `_run` method is used both when running a the main graph, + # and also running parallel branches. parallel_start_node_id = None if in_parallel_id: parallel_start_node_id = start_node_id next_node_id = start_node_id previous_route_node_state: Optional[RouteNodeState] = None + while True: # max steps reached - if self.graph_runtime_state.node_run_steps > self.max_execution_steps: - raise GraphRunFailedError("Max steps {} reached.".format(self.max_execution_steps)) + if self.graph_runtime_state.node_run_steps > self.init_params.max_execution_steps: + raise GraphRunFailedError("Max steps {} reached.".format(self.init_params.max_execution_steps)) - # or max execution time reached - if self._is_timed_out( - start_at=self.graph_runtime_state.start_at, max_execution_time=self.max_execution_time - ): - raise GraphRunFailedError("Max execution time {}s reached.".format(self.max_execution_time)) + if self.graph_runtime_state.is_timed_out(self.init_params.max_execution_time): + raise GraphRunFailedError("Max execution time {}s reached.".format(self.init_params.max_execution_time)) # init route node state route_node_state = self.graph_runtime_state.node_run_state.create_node_state(node_id=next_node_id) @@ -257,6 +274,23 @@ class GraphEngine: thread_pool_id=self.thread_pool_id, ) node.init_node_data(node_config.get("data", {})) + # Determine if the execution should be suspended or stopped at this point. + # If so, yield the corresponding event. + # + # Note: Suspension is not allowed while the graph engine is running in parallel mode. + if in_parallel_id is None: + hook_result = self._exec_decision_hook(DecisionParams(next_node_instance=node)) + if hook_result == ExecutionDecision.SUSPEND: + self.graph_runtime_state.record_suspend_state(next_node_id) + yield GraphRunSuspendedEvent(next_node_id=next_node_id) + return + elif hook_result == ExecutionDecision.STOP: + # TODO: STOP the execution of worklow. 
@@ -383,8 +417,8 @@ class GraphEngine:
                 )
 
                 for parallel_result in parallel_generator:
-                    if isinstance(parallel_result, str):
-                        final_node_id = parallel_result
+                    if isinstance(parallel_result, _ParallelBranchResult):
+                        final_node_id = parallel_result.final_node_id
                     else:
                         yield parallel_result
 
@@ -409,8 +443,8 @@ class GraphEngine:
                 )
 
                 for generated_item in parallel_generator:
-                    if isinstance(generated_item, str):
-                        final_node_id = generated_item
+                    if isinstance(generated_item, _ParallelBranchResult):
+                        final_node_id = generated_item.final_node_id
                     else:
                         yield generated_item
 
@@ -428,7 +462,7 @@ class GraphEngine:
         in_parallel_id: Optional[str] = None,
         parallel_start_node_id: Optional[str] = None,
         handle_exceptions: list[str] = [],
-    ) -> Generator[GraphEngineEvent | str, None, None]:
+    ) -> Generator["GraphEngineEvent | _ParallelBranchResult", None, None]:
         # if nodes has no run conditions, parallel run all nodes
         parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
         if not parallel_id:
@@ -506,7 +540,7 @@ class GraphEngine:
             # get final node id
             final_node_id = parallel.end_to_node_id
             if final_node_id:
-                yield final_node_id
+                yield _ParallelBranchResult(final_node_id)
 
     def _run_parallel_node(
         self,
@@ -908,7 +942,41 @@ class GraphEngine:
             )
             return error_result
 
+    def save(self) -> str:
+        """Serialize the state inside this graph engine to a JSON string.
+
+        Call this method when the execution has been suspended and its state must be persisted.
+        """
+        state = _GraphEngineState(init_params=self.init_params, graph_runtime_state=self.graph_runtime_state)
+        return state.model_dump_json()
+
+    @classmethod
+    def resume(
+        cls,
+        state: str,
+        graph: Graph,
+        execution_decision_hook: ExecutionDecisionHook = _default_hook,
+    ) -> "GraphEngine":
+        """Continue a suspended execution from a state string produced by `save`."""
+        state_ = _GraphEngineState.model_validate_json(state)
+        return cls(
+            graph=graph,
+            graph_init_params=state_.init_params,
+            graph_runtime_state=state_.graph_runtime_state,
+            execution_decision_hook=execution_decision_hook,
+        )
+
 
 class GraphRunFailedError(Exception):
     def __init__(self, error: str):
         self.error = error
+
+
+@dataclass
+class _ParallelBranchResult:
+    final_node_id: str
+
+
+class _GraphEngineState(BaseModel):
+    init_params: GraphInitParams
+    graph_runtime_state: GraphRuntimeState
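+
+
+# Illustrative save/resume round-trip (a sketch, not engine API; assumes `graph`,
+# `graph_init_params`, `graph_runtime_state`, and `my_hook` are built as in the unit tests):
+#
+#     engine = GraphEngine(
+#         graph=graph,
+#         graph_init_params=graph_init_params,
+#         graph_runtime_state=graph_runtime_state,
+#         execution_decision_hook=my_hook,
+#     )
+#     events = list(engine.run())
+#     if isinstance(events[-1], GraphRunSuspendedEvent):
+#         state = engine.save()  # a JSON string; persist it somewhere durable
+#         engine = GraphEngine.resume(state=state, graph=graph)
+#         events = list(engine.run())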
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py
index 408b09e30d..ecbf53cf80 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py
@@ -1,10 +1,10 @@
-import time
 from unittest.mock import patch
 
 import pytest
 from flask import Flask
 
 from core.app.entities.app_invoke_entities import InvokeFrom
+from core.variables.segments import ArrayFileSegment
 from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@@ -13,15 +13,18 @@ from core.workflow.graph_engine.entities.event import (
     GraphRunFailedEvent,
     GraphRunStartedEvent,
     GraphRunSucceededEvent,
+    GraphRunSuspendedEvent,
     NodeRunFailedEvent,
     NodeRunStartedEvent,
     NodeRunStreamChunkEvent,
     NodeRunSucceededEvent,
 )
 from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
 from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
 from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
-from core.workflow.graph_engine.graph_engine import GraphEngine
+from core.workflow.graph_engine.execution_decision import DecisionParams
+from core.workflow.graph_engine.graph_engine import ExecutionDecision, GraphEngine
 from core.workflow.nodes.code.code_node import CodeNode
 from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
 from core.workflow.nodes.llm.node import LLMNode
@@ -904,3 +907,203 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
         assert item.outputs is not None
         answer = item.outputs["answer"]
         assert all(rc not in answer for rc in wrong_content)
+
+
+def test_suspend_and_resume():
+    graph_config = {
+        "edges": [
+            {
+                "data": {"isInLoop": False, "sourceType": "start", "targetType": "if-else"},
+                "id": "1753041723554-source-1753041730748-target",
+                "source": "1753041723554",
+                "sourceHandle": "source",
+                "target": "1753041730748",
+                "targetHandle": "target",
+                "type": "custom",
+                "zIndex": 0,
+            },
+            {
+                "data": {"isInLoop": False, "sourceType": "if-else", "targetType": "answer"},
+                "id": "1753041730748-true-answer-target",
+                "source": "1753041730748",
+                "sourceHandle": "true",
+                "target": "answer",
+                "targetHandle": "target",
+                "type": "custom",
+                "zIndex": 0,
+            },
+            {
+                "data": {
+                    "isInIteration": False,
+                    "isInLoop": False,
+                    "sourceType": "if-else",
+                    "targetType": "answer",
+                },
+                "id": "1753041730748-false-1753041952799-target",
+                "source": "1753041730748",
+                "sourceHandle": "false",
+                "target": "1753041952799",
+                "targetHandle": "target",
+                "type": "custom",
+                "zIndex": 0,
+            },
+        ],
+        "nodes": [
+            {
+                "data": {"desc": "", "selected": False, "title": "Start", "type": "start", "variables": []},
+                "height": 54,
+                "id": "1753041723554",
+                "position": {"x": 32, "y": 282},
+                "positionAbsolute": {"x": 32, "y": 282},
+                "selected": False,
+                "sourcePosition": "right",
+                "targetPosition": "left",
+                "type": "custom",
+                "width": 244,
+            },
+            {
+                "data": {
+                    "cases": [
+                        {
+                            "case_id": "true",
+                            "conditions": [
+                                {
+                                    "comparison_operator": "contains",
+                                    "id": "5db4103a-7e62-4e71-a0a6-c45ac11c0b3d",
+                                    "value": "a",
+                                    "varType": "string",
+                                    "variable_selector": ["sys", "query"],
+                                }
+                            ],
+                            "id": "true",
+                            "logical_operator": "and",
+                        }
+                    ],
+                    "desc": "",
+                    "selected": False,
+                    "title": "IF/ELSE",
+                    "type": "if-else",
+                },
+                "height": 126,
+                "id": "1753041730748",
+                "position": {"x": 368, "y": 282},
+                "positionAbsolute": {"x": 368, "y": 282},
+                "selected": False,
+                "sourcePosition": "right",
+                "targetPosition": "left",
+                "type": "custom",
+                "width": 244,
+            },
+            {
+                "data": {
+                    "answer": "A",
+                    "desc": "",
+                    "selected": False,
+                    "title": "Answer A",
+                    "type": "answer",
+                    "variables": [],
+                },
+                "height": 102,
+                "id": "answer",
+                "position": {"x": 746, "y": 282},
+                "positionAbsolute": {"x": 746, "y": 282},
+                "selected": False,
+                "sourcePosition": "right",
+                "targetPosition": "left",
+                "type": "custom",
+                "width": 244,
+            },
+            {
+                "data": {
+                    "answer": "Else",
+                    "desc": "",
+                    "selected": False,
+                    "title": "Answer Else",
+                    "type": "answer",
+                    "variables": [],
+                },
+                "height": 102,
+                "id": "1753041952799",
+                "position": {"x": 746, "y": 426},
+                "positionAbsolute": {"x": 746, "y": 426},
+                "selected": True,
+                "sourcePosition": "right",
"targetPosition": "left", + "type": "custom", + "width": 244, + }, + ], + "viewport": {"x": -420, "y": -76.5, "zoom": 1}, + } + graph = Graph.init(graph_config) + variable_pool = VariablePool( + system_variables=SystemVariable( + user_id="aaa", + files=[], + query="hello", + conversation_id="abababa", + ), + user_inputs={"uid": "takato"}, + ) + + graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + ) + graph_init_params = GraphInitParams( + tenant_id="111", + app_id="222", + workflow_type=WorkflowType.CHAT, + workflow_id="333", + graph_config=graph_config, + user_id="444", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.WEB_APP, + call_depth=0, + max_execution_steps=500, + max_execution_time=1200, + ) + + _IF_ELSE_NODE_ID = "1753041730748" + + def exec_decision_hook(params: DecisionParams) -> ExecutionDecision: + # requires the engine to suspend before the execution + # of If-Else node. + if params.next_node_instance.node_id == _IF_ELSE_NODE_ID: + return ExecutionDecision.SUSPEND + else: + return ExecutionDecision.CONTINUE + + graph_engine = GraphEngine( + graph=graph, + graph_runtime_state=graph_runtime_state, + graph_init_params=graph_init_params, + execution_decision_hook=exec_decision_hook, + ) + events = list(graph_engine.run()) + last_event = events[-1] + assert isinstance(last_event, GraphRunSuspendedEvent) + assert last_event.next_node_id == _IF_ELSE_NODE_ID + state = graph_engine.save() + assert state != "" + + engine2 = GraphEngine.resume( + state=state, + graph=graph, + ) + events = list(engine2.run()) + assert isinstance(events[-1], GraphRunSucceededEvent) + node_run_succeeded_events = [i for i in events if isinstance(i, NodeRunSucceededEvent)] + assert node_run_succeeded_events + start_events = [i for i in node_run_succeeded_events if i.node_id == "1753041723554"] + assert not start_events + ifelse_succeeded_events = [i for i in node_run_succeeded_events if i.node_id == _IF_ELSE_NODE_ID] + assert ifelse_succeeded_events + answer_else_events = [i for i in node_run_succeeded_events if i.node_id == "1753041952799"] + assert answer_else_events + assert answer_else_events[0].route_node_state.node_run_result.outputs == { + "answer": "Else", + "files": ArrayFileSegment(value=[]), + } + + answer_a_events = [i for i in node_run_succeeded_events if i.node_id == "answer"] + assert not answer_a_events