diff --git a/api/core/workflow/graph_engine/command_source.py b/api/core/workflow/graph_engine/command_source.py index 1cafcf20f9..2d0d4b8211 100644 --- a/api/core/workflow/graph_engine/command_source.py +++ b/api/core/workflow/graph_engine/command_source.py @@ -21,7 +21,9 @@ class _CommandTag(StrEnum): CONTINUE = "continue" -class Command(BaseModel, abc.ABC): +# Note: Avoid using the `_Command` class directly. +# Instead, use `CommandTypes` for type annotations. +class _Command(BaseModel, abc.ABC): model_config = ConfigDict(frozen=True) tag: _CommandTag @@ -35,21 +37,21 @@ class Command(BaseModel, abc.ABC): @final -class StopCommand(Command): +class StopCommand(_Command): tag: _CommandTag = _CommandTag.STOP @final -class SuspendCommand(Command): +class SuspendCommand(_Command): tag: _CommandTag = _CommandTag.SUSPEND @final -class ContinueCommand(Command): +class ContinueCommand(_Command): tag: _CommandTag = _CommandTag.CONTINUE -def _get_command_tag(command: Command): +def _get_command_tag(command: _Command): return command.tag diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index b89aee7712..d930b6a923 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -57,7 +57,14 @@ from core.workflow.utils import variable_utils from libs.flask_utils import preserve_flask_contexts from models.workflow import WorkflowType -from .command_source import Command, CommandParams, CommandSource, ContinueCommand, StopCommand, SuspendCommand +from .command_source import ( + CommandParams, + CommandSource, + CommandTypes, + ContinueCommand, + StopCommand, + SuspendCommand, +) logger = logging.getLogger(__name__) @@ -89,7 +96,7 @@ class GraphEngineThreadPool(ThreadPoolExecutor): raise ValueError(f"Max submit count {self.max_submit_count} of workflow thread pool reached.") -def _default_source(params: CommandParams) -> Command: +def _default_source(_: CommandParams) -> CommandTypes: return ContinueCommand()