feat(api): track routing information in RouteNodeState

pull/22621/head
QuantumGhost 10 months ago
parent 9d6774c87b
commit 55c2c4a6b6

@ -1,7 +1,7 @@
import uuid import uuid
from datetime import UTC, datetime from datetime import UTC, datetime
from enum import Enum from enum import Enum
from typing import Optional from typing import Any, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -44,6 +44,8 @@ class RouteNodeState(BaseModel):
paused_by: Optional[str] = None paused_by: Optional[str] = None
"""paused by""" """paused by"""
# The `index` is used used to record the execution order for a given node.
# Nodes executed ealier get smaller `index` values.
index: int = 1 index: int = 1
def set_finished(self, run_result: NodeRunResult) -> None: def set_finished(self, run_result: NodeRunResult) -> None:
@ -79,10 +81,25 @@ class RuntimeRouteState(BaseModel):
default_factory=dict, description="graph state routes (source_node_state_id: target_node_state_id)" default_factory=dict, description="graph state routes (source_node_state_id: target_node_state_id)"
) )
# A mapping from node_id to its routing state.
node_state_mapping: dict[str, RouteNodeState] = Field( node_state_mapping: dict[str, RouteNodeState] = Field(
default_factory=dict, description="node state mapping (route_node_state_id: route_node_state)" default_factory=dict, description="node state mapping (route_node_state_id: route_node_state)"
) )
next_node_id: Optional[str] = Field(
default=None, description="The next node id to run when resumed from suspension."
)
# If `previous_node_id` is not `None`, then the correspond node has state in the dict
# `node_state_mapping`.
previous_node_state_id: Optional[str] = Field(None, description="The state of last executed node.")
_state_by_id: dict[str, RouteNodeState]
def model_post_init(self, context: Any) -> None:
super().model_post_init(context)
self._state_by_id = {v.id: v for v in self.node_state_mapping.values()}
def create_node_state(self, node_id: str) -> RouteNodeState: def create_node_state(self, node_id: str) -> RouteNodeState:
""" """
Create node state Create node state
@ -91,6 +108,7 @@ class RuntimeRouteState(BaseModel):
""" """
state = RouteNodeState(node_id=node_id, start_at=datetime.now(UTC).replace(tzinfo=None)) state = RouteNodeState(node_id=node_id, start_at=datetime.now(UTC).replace(tzinfo=None))
self.node_state_mapping[state.id] = state self.node_state_mapping[state.id] = state
self._state_by_id[state.id] = state
return state return state
def add_route(self, source_node_state_id: str, target_node_state_id: str) -> None: def add_route(self, source_node_state_id: str, target_node_state_id: str) -> None:
@ -115,3 +133,18 @@ class RuntimeRouteState(BaseModel):
return [ return [
self.node_state_mapping[target_state_id] for target_state_id in self.routes.get(source_node_state_id, []) self.node_state_mapping[target_state_id] for target_state_id in self.routes.get(source_node_state_id, [])
] ]
# def get_node_state(self, node_id: str) -> RouteNodeState | None:
# return self.node_state_mapping.get(node_id)
def get_previous_route_node_state(self) -> RouteNodeState | None:
if self.previous_node_state_id is None:
return None
return self._state_by_id[self.previous_node_state_id]
@property
def previous_node_id(self):
if self.previous_node_state_id is None:
return None
state = self._state_by_id[self.previous_node_state_id]
return state.node_id

Loading…
Cancel
Save