refactor(ndoes): Refactors node data management in workflow engine

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/22581/head
-LAN- 10 months ago
parent 3c631aa94c
commit fbfb7fa131
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -1,3 +1,4 @@
from .entities import Graph, GraphInitParams, GraphRuntimeState, RuntimeRouteState
from .graph_engine import GraphEngine
__all__ = ["Graph", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]
__all__ = ["Graph", "GraphEngine", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]

@ -265,7 +265,7 @@ class GraphEngine:
previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None
# init workflow run state
node_instance = node_cls( # type: ignore
node_instance = node_cls(
id=route_node_state.id,
config=node_config,
graph_init_params=self.init_params,
@ -274,7 +274,7 @@ class GraphEngine:
previous_node_id=previous_node_id,
thread_pool_id=self.thread_pool_id,
)
node_instance = cast(BaseNode[BaseNodeData], node_instance)
node_instance.from_dict(node_config.get("data", {}))
try:
# run node
generator = self._run_node(

@ -16,10 +16,14 @@ from core.workflow.nodes.enums import NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser
class AnswerNode(BaseNode[AnswerNodeData]):
_node_data_cls = AnswerNodeData
class AnswerNode(BaseNode):
_node_type = NodeType.ANSWER
node_data: AnswerNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = AnswerNodeData(**data)
@classmethod
def version(cls) -> str:
return "1"

@ -1,28 +1,21 @@
import logging
from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, cast
from collections.abc import Callable, Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from .entities import BaseNodeData
if TYPE_CHECKING:
from core.workflow.graph_engine import Graph, GraphEngine, GraphInitParams, GraphRuntimeState
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
logger = logging.getLogger(__name__)
GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData)
class BaseNode(Generic[GenericNodeData]):
_node_data_cls: type[GenericNodeData]
class BaseNode:
_node_type: ClassVar[NodeType]
def __init__(
@ -57,7 +50,8 @@ class BaseNode(Generic[GenericNodeData]):
self.node_id = node_id
node_data = self._node_data_cls.model_validate(config.get("data", {}))
self.node_data = node_data
@abstractmethod
def from_dict(self, data: Mapping[str, Any]) -> None: ...
@abstractmethod
def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]:

@ -21,10 +21,14 @@ from .exc import (
)
class CodeNode(BaseNode[CodeNodeData]):
_node_data_cls = CodeNodeData
class CodeNode(BaseNode):
_node_type = NodeType.CODE
node_data: CodeNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = CodeNodeData(**data)
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
"""

@ -36,15 +36,19 @@ from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError,
logger = logging.getLogger(__name__)
class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
class DocumentExtractorNode(BaseNode):
"""
Extracts text content from various file types.
Supports plain text, PDF, and DOC/DOCX files.
"""
_node_data_cls = DocumentExtractorNodeData
_node_type = NodeType.DOCUMENT_EXTRACTOR
node_data: DocumentExtractorNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = DocumentExtractorNodeData(**data)
@classmethod
def version(cls) -> str:
return "1"

@ -1,3 +1,6 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
@ -5,10 +8,14 @@ from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.enums import NodeType
class EndNode(BaseNode[EndNodeData]):
_node_data_cls = EndNodeData
class EndNode(BaseNode):
_node_type = NodeType.END
node_data: EndNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = EndNodeData(**data)
@classmethod
def version(cls) -> str:
return "1"

@ -32,10 +32,14 @@ HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
logger = logging.getLogger(__name__)
class HttpRequestNode(BaseNode[HttpRequestNodeData]):
_node_data_cls = HttpRequestNodeData
class HttpRequestNode(BaseNode):
_node_type = NodeType.HTTP_REQUEST
node_data: HttpRequestNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = HttpRequestNodeData(**data)
@classmethod
def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict:
return {

@ -13,10 +13,14 @@ from core.workflow.utils.condition.entities import Condition
from core.workflow.utils.condition.processor import ConditionProcessor
class IfElseNode(BaseNode[IfElseNodeData]):
_node_data_cls = IfElseNodeData
class IfElseNode(BaseNode):
_node_type = NodeType.IF_ELSE
node_data: IfElseNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = IfElseNodeData(**data)
@classmethod
def version(cls) -> str:
return "1"

@ -56,14 +56,18 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class IterationNode(BaseNode[IterationNodeData]):
class IterationNode(BaseNode):
"""
Iteration Node.
"""
_node_data_cls = IterationNodeData
_node_type = NodeType.ITERATION
node_data: IterationNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = IterationNodeData(**data)
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
return {

@ -1,3 +1,6 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
@ -5,14 +8,18 @@ from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.iteration.entities import IterationStartNodeData
class IterationStartNode(BaseNode[IterationStartNodeData]):
class IterationStartNode(BaseNode):
"""
Iteration Start Node.
"""
_node_data_cls = IterationStartNodeData
_node_type = NodeType.ITERATION_START
node_data: IterationStartNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = IterationStartNodeData(**data)
@classmethod
def version(cls) -> str:
return "1"

@ -1,4 +1,4 @@
from collections.abc import Callable, Sequence
from collections.abc import Callable, Mapping, Sequence
from typing import Any, Literal, Union
from core.file import File
@ -13,10 +13,14 @@ from .entities import ListOperatorNodeData
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError
class ListOperatorNode(BaseNode[ListOperatorNodeData]):
_node_data_cls = ListOperatorNodeData
class ListOperatorNode(BaseNode):
_node_type = NodeType.LIST_OPERATOR
node_data: ListOperatorNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = ListOperatorNodeData(**data)
@classmethod
def version(cls) -> str:
return "1"

@ -90,17 +90,16 @@ from .file_saver import FileSaverImpl, LLMFileSaver
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
logger = logging.getLogger(__name__)
class LLMNode(BaseNode[LLMNodeData]):
_node_data_cls = LLMNodeData
class LLMNode(BaseNode):
_node_type = NodeType.LLM
node_data: LLMNodeData
# Instance attributes specific to LLMNode.
# Output variable for file
_file_outputs: list["File"]
@ -138,6 +137,9 @@ class LLMNode(BaseNode[LLMNodeData]):
)
self._llm_file_saver = llm_file_saver
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = LLMNodeData(**data)
@classmethod
def version(cls) -> str:
return "1"

@ -1,3 +1,6 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
@ -5,14 +8,18 @@ from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.loop.entities import LoopEndNodeData
class LoopEndNode(BaseNode[LoopEndNodeData]):
class LoopEndNode(BaseNode):
"""
Loop End Node.
"""
_node_data_cls = LoopEndNodeData
_node_type = NodeType.LOOP_END
node_data: LoopEndNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = LoopEndNodeData(**data)
@classmethod
def version(cls) -> str:
return "1"

@ -43,14 +43,18 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class LoopNode(BaseNode[LoopNodeData]):
class LoopNode(BaseNode):
"""
Loop Node.
"""
_node_data_cls = LoopNodeData
_node_type = NodeType.LOOP
node_data: LoopNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = LoopNodeData(**data)
@classmethod
def version(cls) -> str:
return "1"

@ -1,3 +1,6 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
@ -5,14 +8,18 @@ from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.loop.entities import LoopStartNodeData
class LoopStartNode(BaseNode[LoopStartNodeData]):
class LoopStartNode(BaseNode):
"""
Loop Start Node.
"""
_node_data_cls = LoopStartNodeData
_node_type = NodeType.LOOP_START
node_data: LoopStartNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = LoopStartNodeData(**data)
@classmethod
def version(cls) -> str:
return "1"

@ -91,10 +91,13 @@ class ParameterExtractorNode(BaseNode):
Parameter Extractor Node.
"""
# FIXME: figure out why here is different from super class
_node_data_cls = ParameterExtractorNodeData # type: ignore
_node_type = NodeType.PARAMETER_EXTRACTOR
node_data: ParameterExtractorNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = ParameterExtractorNodeData(**data)
_model_instance: Optional[ModelInstance] = None
_model_config: Optional[ModelConfigWithCredentialsEntity] = None

@ -1,3 +1,6 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@ -6,10 +9,14 @@ from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.start.entities import StartNodeData
class StartNode(BaseNode[StartNodeData]):
_node_data_cls = StartNodeData
class StartNode(BaseNode):
_node_type = NodeType.START
node_data: StartNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = StartNodeData(**data)
@classmethod
def version(cls) -> str:
return "1"

@ -12,10 +12,14 @@ from core.workflow.nodes.template_transform.entities import TemplateTransformNod
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000"))
class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):
_node_data_cls = TemplateTransformNodeData
class TemplateTransformNode(BaseNode):
_node_type = NodeType.TEMPLATE_TRANSFORM
node_data: TemplateTransformNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = TemplateTransformNodeData(**data)
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
"""

@ -37,14 +37,18 @@ from .exc import (
)
class ToolNode(BaseNode[ToolNodeData]):
class ToolNode(BaseNode):
"""
Tool Node
"""
_node_data_cls = ToolNodeData
_node_type = NodeType.TOOL
node_data: ToolNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = ToolNodeData(**data)
@classmethod
def version(cls) -> str:
return "1"
@ -124,7 +128,11 @@ class ToolNode(BaseNode[ToolNodeData]):
try:
# convert tool messages
yield from self._transform_message(message_stream, tool_info, parameters_for_log)
yield from self._transform_message(
messages=message_stream,
tool_info=tool_info,
parameters_for_log=parameters_for_log,
)
except (PluginDaemonClientSideError, ToolInvokeError) as e:
yield RunCompletedEvent(
run_result=NodeRunResult(

@ -1,4 +1,5 @@
from collections.abc import Mapping
from typing import Any
from core.variables.segments import Segment
from core.workflow.entities.node_entities import NodeRunResult
@ -8,10 +9,14 @@ from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData
class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
_node_data_cls = VariableAssignerNodeData
class VariableAggregatorNode(BaseNode):
_node_type = NodeType.VARIABLE_AGGREGATOR
node_data: VariableAssignerNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = VariableAssignerNodeData(**data)
@classmethod
def version(cls) -> str:
return "1"

@ -22,11 +22,15 @@ if TYPE_CHECKING:
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
class VariableAssignerNode(BaseNode[VariableAssignerData]):
_node_data_cls = VariableAssignerData
class VariableAssignerNode(BaseNode):
_node_type = NodeType.VARIABLE_ASSIGNER
_conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY
node_data: VariableAssignerData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = VariableAssignerData(**data)
def __init__(
self,
id: str,

@ -54,10 +54,14 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
mapping[key] = selector
class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
_node_data_cls = VariableAssignerNodeData
class VariableAssignerNode(BaseNode):
_node_type = NodeType.VARIABLE_ASSIGNER
node_data: VariableAssignerNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
self.node_data = VariableAssignerNodeData(**data)
def _conv_var_updater_factory(self) -> ConversationVariableUpdater:
return conversation_variable_updater_factory()

Loading…
Cancel
Save