refactor(api): rename ExecutionDecisionHook to CommandSource

Use structured types for commands.
pull/22621/head
QuantumGhost 7 months ago
parent 5de663d52a
commit 3af1a6d8c4

@ -0,0 +1,69 @@
import abc
from collections.abc import Callable
from dataclasses import dataclass
from enum import StrEnum
from typing import Annotated, TypeAlias, final
from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator
from core.workflow.nodes.base import BaseNode
@dataclass(frozen=True)
class CommandParams:
    """Read-only input handed to a `CommandSource` before each node executes."""

    # `next_node` is the instance of the next node to run. Callers must not
    # mutate it or anything reachable from it.
    next_node: BaseNode
class _CommandTag(StrEnum):
    """String discriminator value carried by each concrete `Command` subclass."""

    SUSPEND = "suspend"
    STOP = "stop"
    CONTINUE = "continue"
class Command(BaseModel, abc.ABC):
    """Abstract base for commands returned to the graph engine.

    Instances are frozen, and `tag` is pinned to the per-subclass default:
    constructing a command with any other tag value is rejected.
    """

    model_config = ConfigDict(frozen=True)

    tag: _CommandTag

    @field_validator("tag")
    @classmethod
    def validate_value_type(cls, value):
        # Only the tag declared as this subclass's default is acceptable, so
        # the discriminator value can never be overridden at construction time.
        expected = cls.model_fields["tag"].default
        if value == expected:
            return value
        raise ValueError("Cannot modify 'tag'")
@final
class StopCommand(Command):
    """Instructs the graph engine to stop execution."""

    tag: _CommandTag = _CommandTag.STOP
@final
class SuspendCommand(Command):
    """Instructs the graph engine to suspend execution (resumable later)."""

    tag: _CommandTag = _CommandTag.SUSPEND
@final
class ContinueCommand(Command):
    """Instructs the graph engine to continue with the next node."""

    tag: _CommandTag = _CommandTag.CONTINUE
def _get_command_tag(command: Command):
return command.tag
# Tagged union of every concrete command. `Discriminator(_get_command_tag)`
# lets Pydantic select the matching subclass directly from the `tag` value
# instead of attempting each union member in turn.
CommandTypes: TypeAlias = Annotated[
    (
        Annotated[StopCommand, Tag(_CommandTag.STOP)]
        | Annotated[SuspendCommand, Tag(_CommandTag.SUSPEND)]
        | Annotated[ContinueCommand, Tag(_CommandTag.CONTINUE)]
    ),
    Discriminator(_get_command_tag),
]
# `CommandSource` is a callable that takes a single argument of type
# `CommandParams` and returns a `Command` object to the engine, indicating
# whether the graph engine should suspend, continue, or stop.
#
# It must not modify the data inside `CommandParams`, including any
# attributes within its fields.
CommandSource: TypeAlias = Callable[[CommandParams], CommandTypes]

@ -1,25 +0,0 @@
from collections.abc import Callable
from dataclasses import dataclass
from enum import StrEnum
from typing import TypeAlias
from core.workflow.nodes.base import BaseNode
class ExecutionDecision(StrEnum):
SUSPEND = "suspend"
STOP = "stop"
CONTINUE = "continue"
@dataclass(frozen=True)
class DecisionParams:
# `next_node_instance` is the instance of the next node to run.
next_node_instance: BaseNode
# `ExecutionDecisionHook` is a callable that takes a single argument of type `DecisionParams` and
# returns an `ExecutionDecision` indicating whether the graph engine should suspend, continue, or stop.
#
# It must not modify the data inside `DecisionParams`, including any attributes within its fields.
ExecutionDecisionHook: TypeAlias = Callable[[DecisionParams], ExecutionDecision]

@ -57,7 +57,7 @@ from core.workflow.utils import variable_utils
from libs.flask_utils import preserve_flask_contexts
from models.workflow import WorkflowType
from .execution_decision import DecisionParams, ExecutionDecision, ExecutionDecisionHook
from .command_source import Command, CommandParams, CommandSource, ContinueCommand, StopCommand, SuspendCommand
logger = logging.getLogger(__name__)
@ -89,8 +89,8 @@ class GraphEngineThreadPool(ThreadPoolExecutor):
raise ValueError(f"Max submit count {self.max_submit_count} of workflow thread pool reached.")
def _default_hook(params: DecisionParams) -> ExecutionDecision:
return ExecutionDecision.CONTINUE
def _default_source(params: CommandParams) -> Command:
return ContinueCommand()
class GraphEngine:
@ -102,7 +102,7 @@ class GraphEngine:
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
thread_pool_id: Optional[str] = None,
execution_decision_hook: ExecutionDecisionHook = _default_hook,
command_source: CommandSource = _default_source,
) -> None:
"""Create a graph from the given state.
@ -132,7 +132,7 @@ class GraphEngine:
self.graph_runtime_state = graph_runtime_state
self._exec_decision_hook = execution_decision_hook
self._command_source = command_source
def run(self) -> Generator[GraphEngineEvent, None, None]:
# trigger graph run start event
@ -279,15 +279,15 @@ class GraphEngine:
#
# Note: Suspension is not allowed while the graph engine is running in parallel mode.
if in_parallel_id is None:
hook_result = self._exec_decision_hook(DecisionParams(next_node_instance=node))
if hook_result == ExecutionDecision.SUSPEND:
command = self._command_source(CommandParams(next_node=node))
if isinstance(command, SuspendCommand):
self.graph_runtime_state.record_suspend_state(next_node_id)
yield GraphRunSuspendedEvent(next_node_id=next_node_id)
return
elif hook_result == ExecutionDecision.STOP:
elif isinstance(command, StopCommand):
# TODO: Stop the execution of the workflow.
return
elif hook_result == ExecutionDecision.CONTINUE:
elif isinstance(command, ContinueCommand):
pass
else:
raise AssertionError("unreachable statement.")
@ -955,7 +955,7 @@ class GraphEngine:
cls,
state: str,
graph: Graph,
execution_decision_hook: ExecutionDecisionHook = _default_hook,
command_source: CommandSource = _default_source,
) -> "GraphEngine":
"""`resume` continues a suspended execution."""
state_ = _GraphEngineState.model_validate_json(state)
@ -963,7 +963,7 @@ class GraphEngine:
graph=graph,
graph_init_params=state_.init_params,
graph_runtime_state=state_.graph_runtime_state,
execution_decision_hook=execution_decision_hook,
command_source=command_source,
)

@ -8,6 +8,7 @@ from core.variables.segments import ArrayFileSegment
from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.graph_engine.command_source import CommandParams, CommandTypes, ContinueCommand, SuspendCommand
from core.workflow.graph_engine.entities.event import (
BaseNodeEvent,
GraphRunFailedEvent,
@ -23,8 +24,7 @@ from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.graph_engine.execution_decision import DecisionParams
from core.workflow.graph_engine.graph_engine import ExecutionDecision, GraphEngine
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.node import LLMNode
@ -1065,19 +1065,19 @@ def test_suspend_and_resume():
_IF_ELSE_NODE_ID = "1753041730748"
def exec_decision_hook(params: DecisionParams) -> ExecutionDecision:
def command_source(params: CommandParams) -> CommandTypes:
# requires the engine to suspend before the execution
# of If-Else node.
if params.next_node_instance.node_id == _IF_ELSE_NODE_ID:
return ExecutionDecision.SUSPEND
if params.next_node.node_id == _IF_ELSE_NODE_ID:
return SuspendCommand()
else:
return ExecutionDecision.CONTINUE
return ContinueCommand()
graph_engine = GraphEngine(
graph=graph,
graph_runtime_state=graph_runtime_state,
graph_init_params=graph_init_params,
execution_decision_hook=exec_decision_hook,
command_source=command_source,
)
events = list(graph_engine.run())
last_event = events[-1]

Loading…
Cancel
Save