refactor(api): rename ExecutionDecisionHook to CommandSource
Use structured types for commands.pull/22621/head
parent
5de663d52a
commit
3af1a6d8c4
@ -0,0 +1,69 @@
|
|||||||
|
import abc
|
||||||
|
from collections.abc import Callable
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import Annotated, TypeAlias, final
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator
|
||||||
|
|
||||||
|
from core.workflow.nodes.base import BaseNode
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class CommandParams:
|
||||||
|
# `next_node_instance` is the instance of the next node to run.
|
||||||
|
next_node: BaseNode
|
||||||
|
|
||||||
|
|
||||||
|
class _CommandTag(StrEnum):
|
||||||
|
SUSPEND = "suspend"
|
||||||
|
STOP = "stop"
|
||||||
|
CONTINUE = "continue"
|
||||||
|
|
||||||
|
|
||||||
|
class Command(BaseModel, abc.ABC):
|
||||||
|
model_config = ConfigDict(frozen=True)
|
||||||
|
|
||||||
|
tag: _CommandTag
|
||||||
|
|
||||||
|
@field_validator("tag")
|
||||||
|
@classmethod
|
||||||
|
def validate_value_type(cls, value):
|
||||||
|
if value != cls.model_fields["tag"].default:
|
||||||
|
raise ValueError("Cannot modify 'tag'")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
|
class StopCommand(Command):
|
||||||
|
tag: _CommandTag = _CommandTag.STOP
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
|
class SuspendCommand(Command):
|
||||||
|
tag: _CommandTag = _CommandTag.SUSPEND
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
|
class ContinueCommand(Command):
|
||||||
|
tag: _CommandTag = _CommandTag.CONTINUE
|
||||||
|
|
||||||
|
|
||||||
|
def _get_command_tag(command: Command):
|
||||||
|
return command.tag
|
||||||
|
|
||||||
|
|
||||||
|
CommandTypes: TypeAlias = Annotated[
|
||||||
|
(
|
||||||
|
Annotated[StopCommand, Tag(_CommandTag.STOP)]
|
||||||
|
| Annotated[SuspendCommand, Tag(_CommandTag.SUSPEND)]
|
||||||
|
| Annotated[ContinueCommand, Tag(_CommandTag.CONTINUE)]
|
||||||
|
),
|
||||||
|
Discriminator(_get_command_tag),
|
||||||
|
]
|
||||||
|
|
||||||
|
# `CommandSource` is a callable that takes a single argument of type `CommandParams` and
|
||||||
|
# returns a `Command` object to the engine, indicating whether the graph engine should suspend, continue, or stop.
|
||||||
|
#
|
||||||
|
# It must not modify the data inside `CommandParams`, including any attributes within its fields.
|
||||||
|
CommandSource: TypeAlias = Callable[[CommandParams], CommandTypes]
|
||||||
@ -1,25 +0,0 @@
|
|||||||
from collections.abc import Callable
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import StrEnum
|
|
||||||
from typing import TypeAlias
|
|
||||||
|
|
||||||
from core.workflow.nodes.base import BaseNode
|
|
||||||
|
|
||||||
|
|
||||||
class ExecutionDecision(StrEnum):
|
|
||||||
SUSPEND = "suspend"
|
|
||||||
STOP = "stop"
|
|
||||||
CONTINUE = "continue"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class DecisionParams:
|
|
||||||
# `next_node_instance` is the instance of the next node to run.
|
|
||||||
next_node_instance: BaseNode
|
|
||||||
|
|
||||||
|
|
||||||
# `ExecutionDecisionHook` is a callable that takes a single argument of type `DecisionParams` and
|
|
||||||
# returns an `ExecutionDecision` indicating whether the graph engine should suspend, continue, or stop.
|
|
||||||
#
|
|
||||||
# It must not modify the data inside `DecisionParams`, including any attributes within its fields.
|
|
||||||
ExecutionDecisionHook: TypeAlias = Callable[[DecisionParams], ExecutionDecision]
|
|
||||||
Loading…
Reference in New Issue