refactor(api): rename ExecutionDecisionHook to CommandSource
Use structured types for commands.pull/22621/head
parent
5de663d52a
commit
3af1a6d8c4
@ -0,0 +1,69 @@
|
||||
import abc
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, TypeAlias, final
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator
|
||||
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CommandParams:
|
||||
# `next_node_instance` is the instance of the next node to run.
|
||||
next_node: BaseNode
|
||||
|
||||
|
||||
class _CommandTag(StrEnum):
|
||||
SUSPEND = "suspend"
|
||||
STOP = "stop"
|
||||
CONTINUE = "continue"
|
||||
|
||||
|
||||
class Command(BaseModel, abc.ABC):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
tag: _CommandTag
|
||||
|
||||
@field_validator("tag")
|
||||
@classmethod
|
||||
def validate_value_type(cls, value):
|
||||
if value != cls.model_fields["tag"].default:
|
||||
raise ValueError("Cannot modify 'tag'")
|
||||
return value
|
||||
|
||||
|
||||
@final
|
||||
class StopCommand(Command):
|
||||
tag: _CommandTag = _CommandTag.STOP
|
||||
|
||||
|
||||
@final
|
||||
class SuspendCommand(Command):
|
||||
tag: _CommandTag = _CommandTag.SUSPEND
|
||||
|
||||
|
||||
@final
|
||||
class ContinueCommand(Command):
|
||||
tag: _CommandTag = _CommandTag.CONTINUE
|
||||
|
||||
|
||||
def _get_command_tag(command: Command):
|
||||
return command.tag
|
||||
|
||||
|
||||
CommandTypes: TypeAlias = Annotated[
|
||||
(
|
||||
Annotated[StopCommand, Tag(_CommandTag.STOP)]
|
||||
| Annotated[SuspendCommand, Tag(_CommandTag.SUSPEND)]
|
||||
| Annotated[ContinueCommand, Tag(_CommandTag.CONTINUE)]
|
||||
),
|
||||
Discriminator(_get_command_tag),
|
||||
]
|
||||
|
||||
# `CommandSource` is a callable that takes a single argument of type `CommandParams` and
|
||||
# returns a `Command` object to the engine, indicating whether the graph engine should suspend, continue, or stop.
|
||||
#
|
||||
# It must not modify the data inside `CommandParams`, including any attributes within its fields.
|
||||
CommandSource: TypeAlias = Callable[[CommandParams], CommandTypes]
|
||||
@ -1,25 +0,0 @@
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from typing import TypeAlias
|
||||
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
|
||||
|
||||
class ExecutionDecision(StrEnum):
|
||||
SUSPEND = "suspend"
|
||||
STOP = "stop"
|
||||
CONTINUE = "continue"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DecisionParams:
|
||||
# `next_node_instance` is the instance of the next node to run.
|
||||
next_node_instance: BaseNode
|
||||
|
||||
|
||||
# `ExecutionDecisionHook` is a callable that takes a single argument of type `DecisionParams` and
|
||||
# returns an `ExecutionDecision` indicating whether the graph engine should suspend, continue, or stop.
|
||||
#
|
||||
# It must not modify the data inside `DecisionParams`, including any attributes within its fields.
|
||||
ExecutionDecisionHook: TypeAlias = Callable[[DecisionParams], ExecutionDecision]
|
||||
Loading…
Reference in New Issue