refactor(ndoes): Refactors node data management in workflow engine

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/22581/head
-LAN- 10 months ago
parent 3c631aa94c
commit fbfb7fa131
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -1,3 +1,4 @@
from .entities import Graph, GraphInitParams, GraphRuntimeState, RuntimeRouteState from .entities import Graph, GraphInitParams, GraphRuntimeState, RuntimeRouteState
from .graph_engine import GraphEngine
__all__ = ["Graph", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"] __all__ = ["Graph", "GraphEngine", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]

@ -265,7 +265,7 @@ class GraphEngine:
previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None
# init workflow run state # init workflow run state
node_instance = node_cls( # type: ignore node_instance = node_cls(
id=route_node_state.id, id=route_node_state.id,
config=node_config, config=node_config,
graph_init_params=self.init_params, graph_init_params=self.init_params,
@ -274,7 +274,7 @@ class GraphEngine:
previous_node_id=previous_node_id, previous_node_id=previous_node_id,
thread_pool_id=self.thread_pool_id, thread_pool_id=self.thread_pool_id,
) )
node_instance = cast(BaseNode[BaseNodeData], node_instance) node_instance.from_dict(node_config.get("data", {}))
try: try:
# run node # run node
generator = self._run_node( generator = self._run_node(

@ -16,10 +16,14 @@ from core.workflow.nodes.enums import NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser from core.workflow.utils.variable_template_parser import VariableTemplateParser
class AnswerNode(BaseNode[AnswerNodeData]): class AnswerNode(BaseNode):
_node_data_cls = AnswerNodeData
_node_type = NodeType.ANSWER _node_type = NodeType.ANSWER
node_data: AnswerNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = AnswerNodeData(**data)
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"

@ -1,28 +1,21 @@
import logging import logging
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence from collections.abc import Callable, Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, cast from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from .entities import BaseNodeData
if TYPE_CHECKING: if TYPE_CHECKING:
from core.workflow.graph_engine import Graph, GraphEngine, GraphInitParams, GraphRuntimeState
from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData)
class BaseNode(Generic[GenericNodeData]): class BaseNode:
_node_data_cls: type[GenericNodeData]
_node_type: ClassVar[NodeType] _node_type: ClassVar[NodeType]
def __init__( def __init__(
@ -57,7 +50,8 @@ class BaseNode(Generic[GenericNodeData]):
self.node_id = node_id self.node_id = node_id
node_data = self._node_data_cls.model_validate(config.get("data", {})) node_data = self._node_data_cls.model_validate(config.get("data", {}))
self.node_data = node_data @abstractmethod
def from_dict(self, data: Mapping[str, Any]) -> None: ...
@abstractmethod @abstractmethod
def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]: def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]:

@ -21,10 +21,14 @@ from .exc import (
) )
class CodeNode(BaseNode[CodeNodeData]): class CodeNode(BaseNode):
_node_data_cls = CodeNodeData
_node_type = NodeType.CODE _node_type = NodeType.CODE
node_data: CodeNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = CodeNodeData(**data)
@classmethod @classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict: def get_default_config(cls, filters: Optional[dict] = None) -> dict:
""" """

@ -36,15 +36,19 @@ from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError,
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]): class DocumentExtractorNode(BaseNode):
""" """
Extracts text content from various file types. Extracts text content from various file types.
Supports plain text, PDF, and DOC/DOCX files. Supports plain text, PDF, and DOC/DOCX files.
""" """
_node_data_cls = DocumentExtractorNodeData
_node_type = NodeType.DOCUMENT_EXTRACTOR _node_type = NodeType.DOCUMENT_EXTRACTOR
node_data: DocumentExtractorNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = DocumentExtractorNodeData(**data)
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"

@ -1,3 +1,6 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
@ -5,10 +8,14 @@ from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
class EndNode(BaseNode[EndNodeData]): class EndNode(BaseNode):
_node_data_cls = EndNodeData
_node_type = NodeType.END _node_type = NodeType.END
node_data: EndNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = EndNodeData(**data)
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"

@ -32,10 +32,14 @@ HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class HttpRequestNode(BaseNode[HttpRequestNodeData]): class HttpRequestNode(BaseNode):
_node_data_cls = HttpRequestNodeData
_node_type = NodeType.HTTP_REQUEST _node_type = NodeType.HTTP_REQUEST
node_data: HttpRequestNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = HttpRequestNodeData(**data)
@classmethod @classmethod
def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict: def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict:
return { return {

@ -13,10 +13,14 @@ from core.workflow.utils.condition.entities import Condition
from core.workflow.utils.condition.processor import ConditionProcessor from core.workflow.utils.condition.processor import ConditionProcessor
class IfElseNode(BaseNode[IfElseNodeData]): class IfElseNode(BaseNode):
_node_data_cls = IfElseNodeData
_node_type = NodeType.IF_ELSE _node_type = NodeType.IF_ELSE
node_data: IfElseNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = IfElseNodeData(**data)
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"

@ -56,14 +56,18 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class IterationNode(BaseNode[IterationNodeData]): class IterationNode(BaseNode):
""" """
Iteration Node. Iteration Node.
""" """
_node_data_cls = IterationNodeData
_node_type = NodeType.ITERATION _node_type = NodeType.ITERATION
node_data: IterationNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = IterationNodeData(**data)
@classmethod @classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict: def get_default_config(cls, filters: Optional[dict] = None) -> dict:
return { return {

@ -1,3 +1,6 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
@ -5,14 +8,18 @@ from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.iteration.entities import IterationStartNodeData from core.workflow.nodes.iteration.entities import IterationStartNodeData
class IterationStartNode(BaseNode[IterationStartNodeData]): class IterationStartNode(BaseNode):
""" """
Iteration Start Node. Iteration Start Node.
""" """
_node_data_cls = IterationStartNodeData
_node_type = NodeType.ITERATION_START _node_type = NodeType.ITERATION_START
node_data: IterationStartNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = IterationStartNodeData(**data)
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"

@ -1,4 +1,4 @@
from collections.abc import Callable, Sequence from collections.abc import Callable, Mapping, Sequence
from typing import Any, Literal, Union from typing import Any, Literal, Union
from core.file import File from core.file import File
@ -13,10 +13,14 @@ from .entities import ListOperatorNodeData
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError
class ListOperatorNode(BaseNode[ListOperatorNodeData]): class ListOperatorNode(BaseNode):
_node_data_cls = ListOperatorNodeData
_node_type = NodeType.LIST_OPERATOR _node_type = NodeType.LIST_OPERATOR
node_data: ListOperatorNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = ListOperatorNodeData(**data)
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"

@ -90,17 +90,16 @@ from .file_saver import FileSaverImpl, LLMFileSaver
if TYPE_CHECKING: if TYPE_CHECKING:
from core.file.models import File from core.file.models import File
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class LLMNode(BaseNode[LLMNodeData]): class LLMNode(BaseNode):
_node_data_cls = LLMNodeData
_node_type = NodeType.LLM _node_type = NodeType.LLM
node_data: LLMNodeData
# Instance attributes specific to LLMNode. # Instance attributes specific to LLMNode.
# Output variable for file # Output variable for file
_file_outputs: list["File"] _file_outputs: list["File"]
@ -138,6 +137,9 @@ class LLMNode(BaseNode[LLMNodeData]):
) )
self._llm_file_saver = llm_file_saver self._llm_file_saver = llm_file_saver
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = LLMNodeData(**data)
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"

@ -1,3 +1,6 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
@ -5,14 +8,18 @@ from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.loop.entities import LoopEndNodeData from core.workflow.nodes.loop.entities import LoopEndNodeData
class LoopEndNode(BaseNode[LoopEndNodeData]): class LoopEndNode(BaseNode):
""" """
Loop End Node. Loop End Node.
""" """
_node_data_cls = LoopEndNodeData
_node_type = NodeType.LOOP_END _node_type = NodeType.LOOP_END
node_data: LoopEndNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = LoopEndNodeData(**data)
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"

@ -43,14 +43,18 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class LoopNode(BaseNode[LoopNodeData]): class LoopNode(BaseNode):
""" """
Loop Node. Loop Node.
""" """
_node_data_cls = LoopNodeData
_node_type = NodeType.LOOP _node_type = NodeType.LOOP
node_data: LoopNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = LoopNodeData(**data)
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"

@ -1,3 +1,6 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
@ -5,14 +8,18 @@ from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.loop.entities import LoopStartNodeData from core.workflow.nodes.loop.entities import LoopStartNodeData
class LoopStartNode(BaseNode[LoopStartNodeData]): class LoopStartNode(BaseNode):
""" """
Loop Start Node. Loop Start Node.
""" """
_node_data_cls = LoopStartNodeData
_node_type = NodeType.LOOP_START _node_type = NodeType.LOOP_START
node_data: LoopStartNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = LoopStartNodeData(**data)
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"

@ -91,10 +91,13 @@ class ParameterExtractorNode(BaseNode):
Parameter Extractor Node. Parameter Extractor Node.
""" """
# FIXME: figure out why here is different from super class
_node_data_cls = ParameterExtractorNodeData # type: ignore
_node_type = NodeType.PARAMETER_EXTRACTOR _node_type = NodeType.PARAMETER_EXTRACTOR
node_data: ParameterExtractorNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = ParameterExtractorNodeData(**data)
_model_instance: Optional[ModelInstance] = None _model_instance: Optional[ModelInstance] = None
_model_config: Optional[ModelConfigWithCredentialsEntity] = None _model_config: Optional[ModelConfigWithCredentialsEntity] = None

@ -1,3 +1,6 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@ -6,10 +9,14 @@ from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.start.entities import StartNodeData from core.workflow.nodes.start.entities import StartNodeData
class StartNode(BaseNode[StartNodeData]): class StartNode(BaseNode):
_node_data_cls = StartNodeData
_node_type = NodeType.START _node_type = NodeType.START
node_data: StartNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = StartNodeData(**data)
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"

@ -12,10 +12,14 @@ from core.workflow.nodes.template_transform.entities import TemplateTransformNod
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000")) MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000"))
class TemplateTransformNode(BaseNode[TemplateTransformNodeData]): class TemplateTransformNode(BaseNode):
_node_data_cls = TemplateTransformNodeData
_node_type = NodeType.TEMPLATE_TRANSFORM _node_type = NodeType.TEMPLATE_TRANSFORM
node_data: TemplateTransformNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = TemplateTransformNodeData(**data)
@classmethod @classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict: def get_default_config(cls, filters: Optional[dict] = None) -> dict:
""" """

@ -37,14 +37,18 @@ from .exc import (
) )
class ToolNode(BaseNode[ToolNodeData]): class ToolNode(BaseNode):
""" """
Tool Node Tool Node
""" """
_node_data_cls = ToolNodeData
_node_type = NodeType.TOOL _node_type = NodeType.TOOL
node_data: ToolNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = ToolNodeData(**data)
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"
@ -124,7 +128,11 @@ class ToolNode(BaseNode[ToolNodeData]):
try: try:
# convert tool messages # convert tool messages
yield from self._transform_message(message_stream, tool_info, parameters_for_log) yield from self._transform_message(
messages=message_stream,
tool_info=tool_info,
parameters_for_log=parameters_for_log,
)
except (PluginDaemonClientSideError, ToolInvokeError) as e: except (PluginDaemonClientSideError, ToolInvokeError) as e:
yield RunCompletedEvent( yield RunCompletedEvent(
run_result=NodeRunResult( run_result=NodeRunResult(

@ -1,4 +1,5 @@
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any
from core.variables.segments import Segment from core.variables.segments import Segment
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
@ -8,10 +9,14 @@ from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData
class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]): class VariableAggregatorNode(BaseNode):
_node_data_cls = VariableAssignerNodeData
_node_type = NodeType.VARIABLE_AGGREGATOR _node_type = NodeType.VARIABLE_AGGREGATOR
node_data: VariableAssignerNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = VariableAssignerNodeData(**data)
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"

@ -22,11 +22,15 @@ if TYPE_CHECKING:
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater] _CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
class VariableAssignerNode(BaseNode[VariableAssignerData]): class VariableAssignerNode(BaseNode):
_node_data_cls = VariableAssignerData
_node_type = NodeType.VARIABLE_ASSIGNER _node_type = NodeType.VARIABLE_ASSIGNER
_conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY _conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY
node_data: VariableAssignerData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = VariableAssignerData(**data)
def __init__( def __init__(
self, self,
id: str, id: str,

@ -54,10 +54,14 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
mapping[key] = selector mapping[key] = selector
class VariableAssignerNode(BaseNode[VariableAssignerNodeData]): class VariableAssignerNode(BaseNode):
_node_data_cls = VariableAssignerNodeData
_node_type = NodeType.VARIABLE_ASSIGNER _node_type = NodeType.VARIABLE_ASSIGNER
node_data: VariableAssignerNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = VariableAssignerNodeData(**data)
def _conv_var_updater_factory(self) -> ConversationVariableUpdater: def _conv_var_updater_factory(self) -> ConversationVariableUpdater:
return conversation_variable_updater_factory() return conversation_variable_updater_factory()

Loading…
Cancel
Save