From 3af1a6d8c44f111d1ad7de7f0dd4f1dd8374ae94 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Mon, 21 Jul 2025 15:55:49 +0800 Subject: [PATCH] refactor(api): rename ExecutionDecisionHook to CommandSource Use structured types for commands. --- .../workflow/graph_engine/command_source.py | 69 +++++++++++++++++++ .../graph_engine/execution_decision.py | 25 ------- .../workflow/graph_engine/graph_engine.py | 22 +++--- .../graph_engine/test_graph_engine.py | 14 ++-- 4 files changed, 87 insertions(+), 43 deletions(-) create mode 100644 api/core/workflow/graph_engine/command_source.py delete mode 100644 api/core/workflow/graph_engine/execution_decision.py diff --git a/api/core/workflow/graph_engine/command_source.py b/api/core/workflow/graph_engine/command_source.py new file mode 100644 index 0000000000..1cafcf20f9 --- /dev/null +++ b/api/core/workflow/graph_engine/command_source.py @@ -0,0 +1,69 @@ +import abc +from collections.abc import Callable +from dataclasses import dataclass +from enum import StrEnum +from typing import Annotated, TypeAlias, final + +from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator + +from core.workflow.nodes.base import BaseNode + + +@dataclass(frozen=True) +class CommandParams: + # `next_node_instance` is the instance of the next node to run. 
+ next_node: BaseNode + + +class _CommandTag(StrEnum): + SUSPEND = "suspend" + STOP = "stop" + CONTINUE = "continue" + + +class Command(BaseModel, abc.ABC): + model_config = ConfigDict(frozen=True) + + tag: _CommandTag + + @field_validator("tag") + @classmethod + def validate_value_type(cls, value): + if value != cls.model_fields["tag"].default: + raise ValueError("Cannot modify 'tag'") + return value + + +@final +class StopCommand(Command): + tag: _CommandTag = _CommandTag.STOP + + +@final +class SuspendCommand(Command): + tag: _CommandTag = _CommandTag.SUSPEND + + +@final +class ContinueCommand(Command): + tag: _CommandTag = _CommandTag.CONTINUE + + +def _get_command_tag(command: Command): + return command.tag + + +CommandTypes: TypeAlias = Annotated[ + ( + Annotated[StopCommand, Tag(_CommandTag.STOP)] + | Annotated[SuspendCommand, Tag(_CommandTag.SUSPEND)] + | Annotated[ContinueCommand, Tag(_CommandTag.CONTINUE)] + ), + Discriminator(_get_command_tag), +] + +# `CommandSource` is a callable that takes a single argument of type `CommandParams` and +# returns a `Command` object to the engine, indicating whether the graph engine should suspend, continue, or stop. +# +# It must not modify the data inside `CommandParams`, including any attributes within its fields. 
+CommandSource: TypeAlias = Callable[[CommandParams], CommandTypes] diff --git a/api/core/workflow/graph_engine/execution_decision.py b/api/core/workflow/graph_engine/execution_decision.py deleted file mode 100644 index bd65819a40..0000000000 --- a/api/core/workflow/graph_engine/execution_decision.py +++ /dev/null @@ -1,25 +0,0 @@ -from collections.abc import Callable -from dataclasses import dataclass -from enum import StrEnum -from typing import TypeAlias - -from core.workflow.nodes.base import BaseNode - - -class ExecutionDecision(StrEnum): - SUSPEND = "suspend" - STOP = "stop" - CONTINUE = "continue" - - -@dataclass(frozen=True) -class DecisionParams: - # `next_node_instance` is the instance of the next node to run. - next_node_instance: BaseNode - - -# `ExecutionDecisionHook` is a callable that takes a single argument of type `DecisionParams` and -# returns an `ExecutionDecision` indicating whether the graph engine should suspend, continue, or stop. -# -# It must not modify the data inside `DecisionParams`, including any attributes within its fields. 
-ExecutionDecisionHook: TypeAlias = Callable[[DecisionParams], ExecutionDecision] diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index f3a11c9a98..b89aee7712 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -57,7 +57,7 @@ from core.workflow.utils import variable_utils from libs.flask_utils import preserve_flask_contexts from models.workflow import WorkflowType -from .execution_decision import DecisionParams, ExecutionDecision, ExecutionDecisionHook +from .command_source import Command, CommandParams, CommandSource, ContinueCommand, StopCommand, SuspendCommand logger = logging.getLogger(__name__) @@ -89,8 +89,8 @@ class GraphEngineThreadPool(ThreadPoolExecutor): raise ValueError(f"Max submit count {self.max_submit_count} of workflow thread pool reached.") -def _default_hook(params: DecisionParams) -> ExecutionDecision: - return ExecutionDecision.CONTINUE +def _default_source(params: CommandParams) -> Command: + return ContinueCommand() class GraphEngine: @@ -102,7 +102,7 @@ class GraphEngine: graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState, thread_pool_id: Optional[str] = None, - execution_decision_hook: ExecutionDecisionHook = _default_hook, + command_source: CommandSource = _default_source, ) -> None: """Create a graph from the given state. @@ -132,7 +132,7 @@ class GraphEngine: self.graph_runtime_state = graph_runtime_state - self._exec_decision_hook = execution_decision_hook + self._command_source = command_source def run(self) -> Generator[GraphEngineEvent, None, None]: # trigger graph run start event @@ -279,15 +279,15 @@ class GraphEngine: # # Note: Suspension is not allowed while the graph engine is running in parallel mode. 
if in_parallel_id is None:
-            hook_result = self._exec_decision_hook(DecisionParams(next_node_instance=node))
-            if hook_result == ExecutionDecision.SUSPEND:
+            command = self._command_source(CommandParams(next_node=node))
+            if isinstance(command, SuspendCommand):
                 self.graph_runtime_state.record_suspend_state(next_node_id)
                 yield GraphRunSuspendedEvent(next_node_id=next_node_id)
                 return
-            elif hook_result == ExecutionDecision.STOP:
+            elif isinstance(command, StopCommand):
                 # TODO: STOP the execution of workflow.
                 return
-            elif hook_result == ExecutionDecision.CONTINUE:
+            elif isinstance(command, ContinueCommand):
                 pass
             else:
                 raise AssertionError("unreachable statement.")
@@ -955,7 +955,7 @@ class GraphEngine:
         cls,
         state: str,
         graph: Graph,
-        execution_decision_hook: ExecutionDecisionHook = _default_hook,
+        command_source: CommandSource = _default_source,
     ) -> "GraphEngine":
         """`resume` continues a suspended execution."""
         state_ = _GraphEngineState.model_validate_json(state)
@@ -963,7 +963,7 @@ class GraphEngine:
             graph=graph,
             graph_init_params=state_.init_params,
             graph_runtime_state=state_.graph_runtime_state,
-            execution_decision_hook=execution_decision_hook,
+            command_source=command_source,
         )
 
 
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py
index ecbf53cf80..92e46d4abd 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py
@@ -8,6 +8,7 @@ from core.variables.segments import ArrayFileSegment
 from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
+from core.workflow.graph_engine.command_source import CommandParams, CommandTypes, ContinueCommand, SuspendCommand
 from 
core.workflow.graph_engine.entities.event import ( BaseNodeEvent, GraphRunFailedEvent, @@ -23,8 +24,7 @@ from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState -from core.workflow.graph_engine.execution_decision import DecisionParams -from core.workflow.graph_engine.graph_engine import ExecutionDecision, GraphEngine +from core.workflow.graph_engine.graph_engine import GraphEngine from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent from core.workflow.nodes.llm.node import LLMNode @@ -1065,19 +1065,19 @@ def test_suspend_and_resume(): _IF_ELSE_NODE_ID = "1753041730748" - def exec_decision_hook(params: DecisionParams) -> ExecutionDecision: + def command_source(params: CommandParams) -> CommandTypes: # requires the engine to suspend before the execution # of If-Else node. - if params.next_node_instance.node_id == _IF_ELSE_NODE_ID: - return ExecutionDecision.SUSPEND + if params.next_node.node_id == _IF_ELSE_NODE_ID: + return SuspendCommand() else: - return ExecutionDecision.CONTINUE + return ContinueCommand() graph_engine = GraphEngine( graph=graph, graph_runtime_state=graph_runtime_state, graph_init_params=graph_init_params, - execution_decision_hook=exec_decision_hook, + command_source=command_source, ) events = list(graph_engine.run()) last_event = events[-1]