feat(api): add a `version` class method to BaseNode and subclasses

This ensures that we can get the version of node while executing.

Add `node_version` to `BaseNodeEvent` to ensure that all node
related events includes node version information.
pull/20699/head
QuantumGhost 1 year ago
parent 1fbeb8d9bf
commit 655d55f290

@ -65,6 +65,8 @@ class BaseNodeEvent(GraphEngineEvent):
"""iteration id if node is in iteration""" """iteration id if node is in iteration"""
in_loop_id: Optional[str] = None in_loop_id: Optional[str] = None
"""loop id if node is in loop""" """loop id if node is in loop"""
# The version of the node, or "1" if not specified.
node_version: str = "1"
class NodeRunStartedEvent(BaseNodeEvent): class NodeRunStartedEvent(BaseNodeEvent):

@ -313,6 +313,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
) )
raise e raise e
@ -643,6 +644,7 @@ class GraphEngine:
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
agent_strategy=agent_strategy, agent_strategy=agent_strategy,
node_version=node_instance.version(),
) )
db.session.close() db.session.close()
@ -701,6 +703,7 @@ class GraphEngine:
error=run_result.error or "Unknown error", error=run_result.error or "Unknown error",
retry_index=retries, retry_index=retries,
start_at=retry_start_at, start_at=retry_start_at,
node_version=node_instance.version(),
) )
time.sleep(retry_interval) time.sleep(retry_interval)
break break
@ -736,6 +739,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
) )
should_continue_retry = False should_continue_retry = False
else: else:
@ -750,6 +754,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
) )
should_continue_retry = False should_continue_retry = False
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
@ -804,6 +809,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
) )
should_continue_retry = False should_continue_retry = False
@ -821,6 +827,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
) )
elif isinstance(item, RunRetrieverResourceEvent): elif isinstance(item, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent( yield NodeRunRetrieverResourceEvent(
@ -835,6 +842,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
) )
except GenerateTaskStoppedError: except GenerateTaskStoppedError:
# trigger node run failed event # trigger node run failed event
@ -851,6 +859,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
) )
return return
except Exception as e: except Exception as e:

@ -18,7 +18,11 @@ from models.workflow import WorkflowNodeExecutionStatus
class AnswerNode(BaseNode[AnswerNodeData]): class AnswerNode(BaseNode[AnswerNodeData]):
_node_data_cls = AnswerNodeData _node_data_cls = AnswerNodeData
_node_type: NodeType = NodeType.ANSWER _node_type = NodeType.ANSWER
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
""" """

@ -109,6 +109,7 @@ class AnswerStreamProcessor(StreamProcessor):
parallel_id=event.parallel_id, parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id, parallel_start_node_id=event.parallel_start_node_id,
from_variable_selector=[answer_node_id, "answer"], from_variable_selector=[answer_node_id, "answer"],
node_version=event.node_version,
) )
else: else:
route_chunk = cast(VarGenerateRouteChunk, route_chunk) route_chunk = cast(VarGenerateRouteChunk, route_chunk)
@ -134,6 +135,7 @@ class AnswerStreamProcessor(StreamProcessor):
route_node_state=event.route_node_state, route_node_state=event.route_node_state,
parallel_id=event.parallel_id, parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id, parallel_start_node_id=event.parallel_start_node_id,
node_version=event.node_version,
) )
self.route_position[answer_node_id] += 1 self.route_position[answer_node_id] += 1

@ -1,7 +1,7 @@
import logging import logging
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, cast
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
@ -23,7 +23,7 @@ GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData)
class BaseNode(Generic[GenericNodeData]): class BaseNode(Generic[GenericNodeData]):
_node_data_cls: type[GenericNodeData] _node_data_cls: type[GenericNodeData]
_node_type: NodeType _node_type: ClassVar[NodeType]
def __init__( def __init__(
self, self,
@ -101,9 +101,10 @@ class BaseNode(Generic[GenericNodeData]):
raise ValueError("Node ID is required when extracting variable selector to variable mapping.") raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
node_data = cls._node_data_cls(**config.get("data", {})) node_data = cls._node_data_cls(**config.get("data", {}))
return cls._extract_variable_selector_to_variable_mapping( data = cls._extract_variable_selector_to_variable_mapping(
graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data) graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data)
) )
return data
@classmethod @classmethod
def _extract_variable_selector_to_variable_mapping( def _extract_variable_selector_to_variable_mapping(
@ -139,6 +140,16 @@ class BaseNode(Generic[GenericNodeData]):
""" """
return self._node_type return self._node_type
@classmethod
@abstractmethod
def version(cls) -> str:
"""`node_version` returns the version of current node type."""
# NOTE(QuantumGhost): This should be in sync with `NODE_TYPE_CLASSES_MAPPING`.
#
# If you have introduced a new node type, please add it to `NODE_TYPE_CLASSES_MAPPING`
# in `api/core/workflow/nodes/__init__.py`.
pass
@property @property
def should_continue_on_error(self) -> bool: def should_continue_on_error(self) -> bool:
"""judge if should continue on error """judge if should continue on error

@ -40,6 +40,10 @@ class CodeNode(BaseNode[CodeNodeData]):
return code_provider.get_default_config() return code_provider.get_default_config()
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
# Get code language # Get code language
code_language = self.node_data.code_language code_language = self.node_data.code_language

@ -45,6 +45,10 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
_node_data_cls = DocumentExtractorNodeData _node_data_cls = DocumentExtractorNodeData
_node_type = NodeType.DOCUMENT_EXTRACTOR _node_type = NodeType.DOCUMENT_EXTRACTOR
@classmethod
def version(cls) -> str:
return "1"
def _run(self): def _run(self):
variable_selector = self.node_data.variable_selector variable_selector = self.node_data.variable_selector
variable = self.graph_runtime_state.variable_pool.get(variable_selector) variable = self.graph_runtime_state.variable_pool.get(variable_selector)

@ -9,6 +9,10 @@ class EndNode(BaseNode[EndNodeData]):
_node_data_cls = EndNodeData _node_data_cls = EndNodeData
_node_type = NodeType.END _node_type = NodeType.END
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
""" """
Run node Run node

@ -139,6 +139,7 @@ class EndStreamProcessor(StreamProcessor):
route_node_state=event.route_node_state, route_node_state=event.route_node_state,
parallel_id=event.parallel_id, parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id, parallel_start_node_id=event.parallel_start_node_id,
node_version=event.node_version,
) )
self.route_position[end_node_id] += 1 self.route_position[end_node_id] += 1

@ -60,6 +60,10 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
}, },
} }
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
process_data = {} process_data = {}
try: try:

@ -16,6 +16,10 @@ class IfElseNode(BaseNode[IfElseNodeData]):
_node_data_cls = IfElseNodeData _node_data_cls = IfElseNodeData
_node_type = NodeType.IF_ELSE _node_type = NodeType.IF_ELSE
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
""" """
Run node Run node

@ -72,6 +72,10 @@ class IterationNode(BaseNode[IterationNodeData]):
}, },
} }
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
""" """
Run the node. Run the node.

@ -13,6 +13,10 @@ class IterationStartNode(BaseNode[IterationStartNodeData]):
_node_data_cls = IterationStartNodeData _node_data_cls = IterationStartNodeData
_node_type = NodeType.ITERATION_START _node_type = NodeType.ITERATION_START
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
""" """
Run the node. Run the node.

@ -16,6 +16,10 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
_node_data_cls = ListOperatorNodeData _node_data_cls = ListOperatorNodeData
_node_type = NodeType.LIST_OPERATOR _node_type = NodeType.LIST_OPERATOR
@classmethod
def version(cls) -> str:
return "1"
def _run(self): def _run(self):
inputs: dict[str, list] = {} inputs: dict[str, list] = {}
process_data: dict[str, list] = {} process_data: dict[str, list] = {}

@ -148,6 +148,10 @@ class LLMNode(BaseNode[LLMNodeData]):
) )
self._llm_file_saver = llm_file_saver self._llm_file_saver = llm_file_saver
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
def process_structured_output(text: str) -> Optional[dict[str, Any]]: def process_structured_output(text: str) -> Optional[dict[str, Any]]:
"""Process structured output if enabled""" """Process structured output if enabled"""

@ -13,6 +13,10 @@ class LoopEndNode(BaseNode[LoopEndNodeData]):
_node_data_cls = LoopEndNodeData _node_data_cls = LoopEndNodeData
_node_type = NodeType.LOOP_END _node_type = NodeType.LOOP_END
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
""" """
Run the node. Run the node.

@ -54,6 +54,10 @@ class LoopNode(BaseNode[LoopNodeData]):
_node_data_cls = LoopNodeData _node_data_cls = LoopNodeData
_node_type = NodeType.LOOP _node_type = NodeType.LOOP
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""Run the node.""" """Run the node."""
# Get inputs # Get inputs

@ -13,6 +13,10 @@ class LoopStartNode(BaseNode[LoopStartNodeData]):
_node_data_cls = LoopStartNodeData _node_data_cls = LoopStartNodeData
_node_type = NodeType.LOOP_START _node_type = NodeType.LOOP_START
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
""" """
Run the node. Run the node.

@ -25,6 +25,11 @@ from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as Var
LATEST_VERSION = "latest" LATEST_VERSION = "latest"
# NOTE(QuantumGhost): This should be in sync with subclasses of BaseNode.
# Specifically, if you have introduced new node types, you should add them here.
#
# TODO(QuantumGhost): This could be automated with either metaclass or `__init_subclass__`
# hook. Try to avoid duplication of node information.
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = { NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
NodeType.START: { NodeType.START: {
LATEST_VERSION: StartNode, LATEST_VERSION: StartNode,

@ -1,3 +1,4 @@
from core.file.constants import add_dummy_output
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
@ -10,6 +11,10 @@ class StartNode(BaseNode[StartNodeData]):
_node_data_cls = StartNodeData _node_data_cls = StartNodeData
_node_type = NodeType.START _node_type = NodeType.START
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
system_inputs = self.graph_runtime_state.variable_pool.system_variables system_inputs = self.graph_runtime_state.variable_pool.system_variables
@ -18,5 +23,9 @@ class StartNode(BaseNode[StartNodeData]):
# Set system variables as node outputs. # Set system variables as node outputs.
for var in system_inputs: for var in system_inputs:
node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var] node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var]
outputs = dict(node_inputs)
# Need special handling for `Start` node, as all other output variables
# are treated as systemd variables.
add_dummy_output(outputs)
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=node_inputs) return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs)

@ -28,6 +28,10 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):
"config": {"variables": [{"variable": "arg1", "value_selector": []}], "template": "{{ arg1 }}"}, "config": {"variables": [{"variable": "arg1", "value_selector": []}], "template": "{{ arg1 }}"},
} }
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
# Get variables # Get variables
variables = {} variables = {}

@ -44,6 +44,10 @@ class ToolNode(BaseNode[ToolNodeData]):
_node_data_cls = ToolNodeData _node_data_cls = ToolNodeData
_node_type = NodeType.TOOL _node_type = NodeType.TOOL
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> Generator: def _run(self) -> Generator:
""" """
Run the tool node Run the tool node

@ -9,6 +9,10 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
_node_data_cls = VariableAssignerNodeData _node_data_cls = VariableAssignerNodeData
_node_type = NodeType.VARIABLE_AGGREGATOR _node_type = NodeType.VARIABLE_AGGREGATOR
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
# Get variables # Get variables
outputs = {} outputs = {}

@ -1,7 +1,11 @@
from collections.abc import Sequence
from typing import Any, TypedDict
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from core.variables import Variable from core.variables import Segment, SegmentType, Variable
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from extensions.ext_database import db from extensions.ext_database import db
from models import ConversationVariable from models import ConversationVariable
@ -17,3 +21,22 @@ def update_conversation_variable(conversation_id: str, variable: Variable):
raise VariableOperatorNodeError("conversation variable not found in the database") raise VariableOperatorNodeError("conversation variable not found in the database")
row.data = variable.model_dump_json() row.data = variable.model_dump_json()
session.commit() session.commit()
class VariableOutput(TypedDict):
name: str
selector: Sequence[str]
new_value: Any
type: SegmentType
def variable_to_output_mapping(selector: Sequence[str], seg: Segment) -> VariableOutput:
if len(selector) < MIN_SELECTORS_LENGTH:
raise Exception("selector too short")
node_id, var_name = selector[:2]
return {
"name": var_name,
"selector": selector[:2],
"new_value": seg.value,
"type": seg.value_type,
}

@ -14,9 +14,14 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
_node_data_cls = VariableAssignerData _node_data_cls = VariableAssignerData
_node_type = NodeType.VARIABLE_ASSIGNER _node_type = NodeType.VARIABLE_ASSIGNER
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
assigned_variable_selector = self.node_data.assigned_variable_selector
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(self.node_data.assigned_variable_selector) original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
if not isinstance(original_variable, Variable): if not isinstance(original_variable, Variable):
raise VariableOperatorNodeError("assigned variable not found") raise VariableOperatorNodeError("assigned variable not found")
@ -44,7 +49,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}") raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")
# Over write the variable. # Over write the variable.
self.graph_runtime_state.variable_pool.add(self.node_data.assigned_variable_selector, updated_variable) self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)
# TODO: Move database operation to the pipeline. # TODO: Move database operation to the pipeline.
# Update conversation variable. # Update conversation variable.
@ -58,6 +63,14 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
inputs={ inputs={
"value": income_value.to_object(), "value": income_value.to_object(),
}, },
outputs={
# NOTE(QuantumGhost): although only one variable is updated in `v1.VariableAssignerNode`,
# we still set `output_variables` as a list to ensure the schema of output is
# compatible with `v2.VariableAssignerNode`.
"updated_variables": [
common_helpers.variable_to_output_mapping(assigned_variable_selector, updated_variable)
]
},
) )

@ -29,6 +29,10 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
_node_data_cls = VariableAssignerNodeData _node_data_cls = VariableAssignerNodeData
_node_type = NodeType.VARIABLE_ASSIGNER _node_type = NodeType.VARIABLE_ASSIGNER
@classmethod
def version(cls) -> str:
return "2"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
inputs = self.node_data.model_dump() inputs = self.node_data.model_dump()
process_data: dict[str, Any] = {} process_data: dict[str, Any] = {}
@ -137,6 +141,13 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs, inputs=inputs,
process_data=process_data, process_data=process_data,
outputs={
"updated_variables": [
common_helpers.variable_to_output_mapping(selector, seg)
for selector in updated_variable_selectors
if (seg := self.graph_runtime_state.variable_pool.get(selector)) is not None
],
},
) )
def _handle_item( def _handle_item(

Loading…
Cancel
Save