feat(api): add a `version` class method to BaseNode and subclasses

This ensures that we can get the version of node while executing.

Add `node_version` to `BaseNodeEvent` to ensure that all node
related events includes node version information.
pull/20699/head
QuantumGhost 1 year ago
parent 1fbeb8d9bf
commit 655d55f290

@ -65,6 +65,8 @@ class BaseNodeEvent(GraphEngineEvent):
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
"""loop id if node is in loop"""
# The version of the node, or "1" if not specified.
node_version: str = "1"
class NodeRunStartedEvent(BaseNodeEvent):

@ -313,6 +313,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
)
raise e
@ -643,6 +644,7 @@ class GraphEngine:
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
agent_strategy=agent_strategy,
node_version=node_instance.version(),
)
db.session.close()
@ -701,6 +703,7 @@ class GraphEngine:
error=run_result.error or "Unknown error",
retry_index=retries,
start_at=retry_start_at,
node_version=node_instance.version(),
)
time.sleep(retry_interval)
break
@ -736,6 +739,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
)
should_continue_retry = False
else:
@ -750,6 +754,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
)
should_continue_retry = False
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
@ -804,6 +809,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
)
should_continue_retry = False
@ -821,6 +827,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
)
elif isinstance(item, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent(
@ -835,6 +842,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
)
except GenerateTaskStoppedError:
# trigger node run failed event
@ -851,6 +859,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
)
return
except Exception as e:

@ -18,7 +18,11 @@ from models.workflow import WorkflowNodeExecutionStatus
class AnswerNode(BaseNode[AnswerNodeData]):
_node_data_cls = AnswerNodeData
_node_type: NodeType = NodeType.ANSWER
_node_type = NodeType.ANSWER
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
"""

@ -109,6 +109,7 @@ class AnswerStreamProcessor(StreamProcessor):
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
from_variable_selector=[answer_node_id, "answer"],
node_version=event.node_version,
)
else:
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
@ -134,6 +135,7 @@ class AnswerStreamProcessor(StreamProcessor):
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
node_version=event.node_version,
)
self.route_position[answer_node_id] += 1

@ -1,7 +1,7 @@
import logging
from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, cast
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
@ -23,7 +23,7 @@ GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData)
class BaseNode(Generic[GenericNodeData]):
_node_data_cls: type[GenericNodeData]
_node_type: NodeType
_node_type: ClassVar[NodeType]
def __init__(
self,
@ -101,9 +101,10 @@ class BaseNode(Generic[GenericNodeData]):
raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
node_data = cls._node_data_cls(**config.get("data", {}))
return cls._extract_variable_selector_to_variable_mapping(
data = cls._extract_variable_selector_to_variable_mapping(
graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data)
)
return data
@classmethod
def _extract_variable_selector_to_variable_mapping(
@ -139,6 +140,16 @@ class BaseNode(Generic[GenericNodeData]):
"""
return self._node_type
@classmethod
@abstractmethod
def version(cls) -> str:
"""`node_version` returns the version of current node type."""
# NOTE(QuantumGhost): This should be in sync with `NODE_TYPE_CLASSES_MAPPING`.
#
# If you have introduced a new node type, please add it to `NODE_TYPE_CLASSES_MAPPING`
# in `api/core/workflow/nodes/__init__.py`.
pass
@property
def should_continue_on_error(self) -> bool:
"""judge if should continue on error

@ -40,6 +40,10 @@ class CodeNode(BaseNode[CodeNodeData]):
return code_provider.get_default_config()
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
# Get code language
code_language = self.node_data.code_language

@ -45,6 +45,10 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
_node_data_cls = DocumentExtractorNodeData
_node_type = NodeType.DOCUMENT_EXTRACTOR
@classmethod
def version(cls) -> str:
return "1"
def _run(self):
variable_selector = self.node_data.variable_selector
variable = self.graph_runtime_state.variable_pool.get(variable_selector)

@ -9,6 +9,10 @@ class EndNode(BaseNode[EndNodeData]):
_node_data_cls = EndNodeData
_node_type = NodeType.END
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
"""
Run node

@ -139,6 +139,7 @@ class EndStreamProcessor(StreamProcessor):
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
node_version=event.node_version,
)
self.route_position[end_node_id] += 1

@ -60,6 +60,10 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
},
}
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
process_data = {}
try:

@ -16,6 +16,10 @@ class IfElseNode(BaseNode[IfElseNodeData]):
_node_data_cls = IfElseNodeData
_node_type = NodeType.IF_ELSE
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
"""
Run node

@ -72,6 +72,10 @@ class IterationNode(BaseNode[IterationNodeData]):
},
}
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""
Run the node.

@ -13,6 +13,10 @@ class IterationStartNode(BaseNode[IterationStartNodeData]):
_node_data_cls = IterationStartNodeData
_node_type = NodeType.ITERATION_START
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
"""
Run the node.

@ -16,6 +16,10 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
_node_data_cls = ListOperatorNodeData
_node_type = NodeType.LIST_OPERATOR
@classmethod
def version(cls) -> str:
return "1"
def _run(self):
inputs: dict[str, list] = {}
process_data: dict[str, list] = {}

@ -148,6 +148,10 @@ class LLMNode(BaseNode[LLMNodeData]):
)
self._llm_file_saver = llm_file_saver
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
def process_structured_output(text: str) -> Optional[dict[str, Any]]:
"""Process structured output if enabled"""

@ -13,6 +13,10 @@ class LoopEndNode(BaseNode[LoopEndNodeData]):
_node_data_cls = LoopEndNodeData
_node_type = NodeType.LOOP_END
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
"""
Run the node.

@ -54,6 +54,10 @@ class LoopNode(BaseNode[LoopNodeData]):
_node_data_cls = LoopNodeData
_node_type = NodeType.LOOP
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""Run the node."""
# Get inputs

@ -13,6 +13,10 @@ class LoopStartNode(BaseNode[LoopStartNodeData]):
_node_data_cls = LoopStartNodeData
_node_type = NodeType.LOOP_START
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
"""
Run the node.

@ -25,6 +25,11 @@ from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as Var
LATEST_VERSION = "latest"
# NOTE(QuantumGhost): This should be in sync with subclasses of BaseNode.
# Specifically, if you have introduced new node types, you should add them here.
#
# TODO(QuantumGhost): This could be automated with either metaclass or `__init_subclass__`
# hook. Try to avoid duplication of node information.
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
NodeType.START: {
LATEST_VERSION: StartNode,

@ -1,3 +1,4 @@
from core.file.constants import add_dummy_output
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
@ -10,6 +11,10 @@ class StartNode(BaseNode[StartNodeData]):
_node_data_cls = StartNodeData
_node_type = NodeType.START
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
system_inputs = self.graph_runtime_state.variable_pool.system_variables
@ -18,5 +23,9 @@ class StartNode(BaseNode[StartNodeData]):
# Set system variables as node outputs.
for var in system_inputs:
node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var]
outputs = dict(node_inputs)
# Need special handling for `Start` node, as all other output variables
# are treated as systemd variables.
add_dummy_output(outputs)
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=node_inputs)
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs)

@ -28,6 +28,10 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):
"config": {"variables": [{"variable": "arg1", "value_selector": []}], "template": "{{ arg1 }}"},
}
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
# Get variables
variables = {}

@ -44,6 +44,10 @@ class ToolNode(BaseNode[ToolNodeData]):
_node_data_cls = ToolNodeData
_node_type = NodeType.TOOL
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> Generator:
"""
Run the tool node

@ -9,6 +9,10 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
_node_data_cls = VariableAssignerNodeData
_node_type = NodeType.VARIABLE_AGGREGATOR
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
# Get variables
outputs = {}

@ -1,7 +1,11 @@
from collections.abc import Sequence
from typing import Any, TypedDict
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.variables import Variable
from core.variables import Segment, SegmentType, Variable
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from extensions.ext_database import db
from models import ConversationVariable
@ -17,3 +21,22 @@ def update_conversation_variable(conversation_id: str, variable: Variable):
raise VariableOperatorNodeError("conversation variable not found in the database")
row.data = variable.model_dump_json()
session.commit()
class VariableOutput(TypedDict):
name: str
selector: Sequence[str]
new_value: Any
type: SegmentType
def variable_to_output_mapping(selector: Sequence[str], seg: Segment) -> VariableOutput:
if len(selector) < MIN_SELECTORS_LENGTH:
raise Exception("selector too short")
node_id, var_name = selector[:2]
return {
"name": var_name,
"selector": selector[:2],
"new_value": seg.value,
"type": seg.value_type,
}

@ -14,9 +14,14 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
_node_data_cls = VariableAssignerData
_node_type = NodeType.VARIABLE_ASSIGNER
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
assigned_variable_selector = self.node_data.assigned_variable_selector
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(self.node_data.assigned_variable_selector)
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
if not isinstance(original_variable, Variable):
raise VariableOperatorNodeError("assigned variable not found")
@ -44,7 +49,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")
# Over write the variable.
self.graph_runtime_state.variable_pool.add(self.node_data.assigned_variable_selector, updated_variable)
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)
# TODO: Move database operation to the pipeline.
# Update conversation variable.
@ -58,6 +63,14 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
inputs={
"value": income_value.to_object(),
},
outputs={
# NOTE(QuantumGhost): although only one variable is updated in `v1.VariableAssignerNode`,
# we still set `output_variables` as a list to ensure the schema of output is
# compatible with `v2.VariableAssignerNode`.
"updated_variables": [
common_helpers.variable_to_output_mapping(assigned_variable_selector, updated_variable)
]
},
)

@ -29,6 +29,10 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
_node_data_cls = VariableAssignerNodeData
_node_type = NodeType.VARIABLE_ASSIGNER
@classmethod
def version(cls) -> str:
return "2"
def _run(self) -> NodeRunResult:
inputs = self.node_data.model_dump()
process_data: dict[str, Any] = {}
@ -137,6 +141,13 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs={
"updated_variables": [
common_helpers.variable_to_output_mapping(selector, seg)
for selector in updated_variable_selectors
if (seg := self.graph_runtime_state.variable_pool.get(selector)) is not None
],
},
)
def _handle_item(

Loading…
Cancel
Save