diff --git a/api/core/workflow/errors.py b/api/core/workflow/errors.py index bd4ccc1072..28b6a5342e 100644 --- a/api/core/workflow/errors.py +++ b/api/core/workflow/errors.py @@ -5,4 +5,4 @@ class WorkflowNodeRunFailedError(Exception): def __init__(self, node_instance: BaseNode, error: str): self.node_instance = node_instance self.error = error - super().__init__(f"Node {node_instance.node_data.title} run failed: {error}") + super().__init__(f"Node {node_instance.node_title} run failed: {error}") diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index a97443bd6f..26e5e89d9a 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -311,7 +311,7 @@ class GraphEngine: id=node_instance.id, node_id=next_node_id, node_type=node_type, - node_data=node_instance.node_data, + node_data=node_instance.get_base_node_data(), route_node_state=route_node_state, parallel_id=in_parallel_id, parallel_start_node_id=parallel_start_node_id, @@ -339,7 +339,7 @@ class GraphEngine: edge = edge_mappings[0] if ( previous_route_node_state.status == RouteNodeState.Status.EXCEPTION - and node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH + and node_instance.error_strategy == ErrorStrategy.FAIL_BRANCH and edge.run_condition is None ): break @@ -415,7 +415,7 @@ class GraphEngine: next_node_id = final_node_id elif ( - node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH + node_instance.error_strategy == ErrorStrategy.FAIL_BRANCH and node_instance.continue_on_error and previous_route_node_state.status == RouteNodeState.Status.EXCEPTION ): @@ -613,7 +613,7 @@ class GraphEngine: # trigger node run start event agent_strategy = ( AgentNodeStrategyInit( - name=cast(AgentNodeData, node_instance.node_data).agent_strategy_name, + name=cast(AgentNodeData, node_instance.get_base_node_data()).agent_strategy_name, icon=cast(AgentNode, node_instance).agent_strategy_icon, ) if node_instance.node_type == NodeType.AGENT @@ -623,7 +623,7 @@ class GraphEngine: id=node_instance.id, node_id=node_instance.node_id, node_type=node_instance.node_type, - node_data=node_instance.node_data, + node_data=node_instance.get_base_node_data(), route_node_state=route_node_state, predecessor_node_id=node_instance.previous_node_id, parallel_id=parallel_id, @@ -634,8 +634,8 @@ class GraphEngine: node_version=node_instance.version(), ) - max_retries = node_instance.node_data.retry_config.max_retries - retry_interval = node_instance.node_data.retry_config.retry_interval_seconds + max_retries = node_instance.node_retry_config.max_retries + retry_interval = node_instance.node_retry_config.retry_interval_seconds retries = 0 should_continue_retry = True while should_continue_retry and retries <= max_retries: @@ -672,7 +672,7 @@ class GraphEngine: id=str(uuid.uuid4()), node_id=node_instance.node_id, node_type=node_instance.node_type, - node_data=node_instance.node_data, + node_data=node_instance.get_base_node_data(), route_node_state=route_node_state, predecessor_node_id=node_instance.previous_node_id, parallel_id=parallel_id, @@ -712,7 +712,7 @@ class GraphEngine: id=node_instance.id, node_id=node_instance.node_id, node_type=node_instance.node_type, - node_data=node_instance.node_data, + node_data=node_instance.get_base_node_data(), route_node_state=route_node_state, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, @@ -727,7 +727,7 @@ class GraphEngine: id=node_instance.id, node_id=node_instance.node_id, node_type=node_instance.node_type, - node_data=node_instance.node_data, + node_data=node_instance.get_base_node_data(), route_node_state=route_node_state, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, @@ -740,7 +740,7 @@ class GraphEngine: if ( node_instance.continue_on_error and self.graph.edge_mapping.get(node_instance.node_id) - and node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH + and node_instance.error_strategy is ErrorStrategy.FAIL_BRANCH ): run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS if run_result.metadata and run_result.metadata.get( @@ -788,7 +788,7 @@ class GraphEngine: id=node_instance.id, node_id=node_instance.node_id, node_type=node_instance.node_type, - node_data=node_instance.node_data, + node_data=node_instance.get_base_node_data(), route_node_state=route_node_state, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, @@ -804,7 +804,7 @@ class GraphEngine: id=node_instance.id, node_id=node_instance.node_id, node_type=node_instance.node_type, - node_data=node_instance.node_data, + node_data=node_instance.get_base_node_data(), chunk_content=event.chunk_content, from_variable_selector=event.from_variable_selector, route_node_state=route_node_state, @@ -819,7 +819,7 @@ class GraphEngine: id=node_instance.id, node_id=node_instance.node_id, node_type=node_instance.node_type, - node_data=node_instance.node_data, + node_data=node_instance.get_base_node_data(), retriever_resources=event.retriever_resources, context=event.context, route_node_state=route_node_state, @@ -838,7 +838,7 @@ class GraphEngine: id=node_instance.id, node_id=node_instance.node_id, node_type=node_instance.node_type, - node_data=node_instance.node_data, + node_data=node_instance.get_base_node_data(), route_node_state=route_node_state, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, @@ -848,7 +848,7 @@ class GraphEngine: ) return except Exception as e: - logger.exception(f"Node {node_instance.node_data.title} run failed") + logger.exception(f"Node {node_instance.node_title} run failed") raise e def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue): @@ -911,20 +911,20 @@ class GraphEngine: "error": error_result.error, "inputs": error_result.inputs, "metadata": { - WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy, + WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node_instance.error_strategy, }, } - if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE: + if node_instance.error_strategy is ErrorStrategy.DEFAULT_VALUE: return NodeRunResult( **node_error_args, outputs={ - **node_instance.node_data.default_value_dict, + **node_instance.default_value_dict, "error_message": error_result.error, "error_type": error_result.error_type, }, ) - elif node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH: + elif node_instance.error_strategy is ErrorStrategy.FAIL_BRANCH: if self.graph.edge_mapping.get(node_instance.node_id): node_error_args["edge_source_handle"] = FailBranchSourceHandle.FAILED return NodeRunResult( diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 7cfcc30e7a..5b0ff24df6 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -36,7 +36,8 @@ from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import AgentLogEvent from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db @@ -67,6 +68,24 @@ class AgentNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self.node_data = AgentNodeData.model_validate(data) + def get_error_strategy(self) -> Optional[ErrorStrategy]: + return self.node_data.error_strategy + + def get_retry_config(self) -> RetryConfig: + return self.node_data.retry_config + + def get_title(self) -> str: + return self.node_data.title + + def get_description(self) -> Optional[str]: + return self.node_data.desc + + def get_default_value_dict(self) -> dict[str, Any]: + return self.node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self.node_data + @classmethod def version(cls) -> str: return "1" @@ -171,12 +190,14 @@ class AgentNode(BaseNode): node_execution_id=self.id, ) except PluginDaemonClientSideError as e: - error = AgentMessageTransformError(f"Failed to transform agent message: {str(e)}", original_error=e) + transform_error = AgentMessageTransformError( + f"Failed to transform agent message: {str(e)}", original_error=e + ) yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, - error=str(error), + error=str(transform_error), ) ) diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index 1e510befa6..2eb6d34495 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any, cast +from typing import Any, Optional, cast from core.variables import ArrayFileSegment, FileSegment from core.workflow.entities.node_entities import NodeRunResult @@ -12,7 +12,8 @@ from core.workflow.nodes.answer.entities import ( VarGenerateRouteChunk, ) from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.utils.variable_template_parser import VariableTemplateParser @@ -24,6 +25,24 @@ class AnswerNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self.node_data = AnswerNodeData.model_validate(data) + def get_error_strategy(self) -> Optional[ErrorStrategy]: + return self.node_data.error_strategy + + def get_retry_config(self) -> RetryConfig: + return self.node_data.retry_config + + def get_title(self) -> str: + return self.node_data.title + + def get_description(self) -> Optional[str]: + return self.node_data.desc + + def get_default_value_dict(self) -> dict[str, Any]: + return self.node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self.node_data + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index d853eb71be..7d84e0e212 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -128,7 +128,7 @@ class BaseNodeData(ABC, BaseModel): retry_config: RetryConfig = RetryConfig() @property - def default_value_dict(self): + def default_value_dict(self) -> dict[str, Any]: if self.default_value: return {item.key: item.value for item in self.default_value} return {} diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index b039accba4..7cb64f27cc 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -5,7 +5,8 @@ from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent if TYPE_CHECKING: @@ -17,6 +18,7 @@ logger = logging.getLogger(__name__) class BaseNode: _node_type: ClassVar[NodeType] + # Each subclass will declare: node_data: SpecificNodeData def __init__( self, @@ -185,3 +187,62 @@ class BaseNode: bool: if should retry """ return False + + # Abstract methods that subclasses must implement to provide access + # to BaseNodeData properties in a type-safe way + + @abstractmethod + def get_error_strategy(self) -> Optional[ErrorStrategy]: + """Get the error strategy for this node.""" + ... + + @abstractmethod + def get_retry_config(self) -> RetryConfig: + """Get the retry configuration for this node.""" + ... + + @abstractmethod + def get_title(self) -> str: + """Get the node title.""" + ... + + @abstractmethod + def get_description(self) -> Optional[str]: + """Get the node description.""" + ... + + @abstractmethod + def get_default_value_dict(self) -> dict[str, Any]: + """Get the default values dictionary for this node.""" + ... + + @abstractmethod + def get_base_node_data(self) -> BaseNodeData: + """Get the BaseNodeData object for this node.""" + ... + + # Public interface properties that delegate to abstract methods + @property + def error_strategy(self) -> Optional[ErrorStrategy]: + """Get the error strategy for this node.""" + return self.get_error_strategy() + + @property + def node_retry_config(self) -> RetryConfig: + """Get the retry configuration for this node.""" + return self.get_retry_config() + + @property + def node_title(self) -> str: + """Get the node title.""" + return self.get_title() + + @property + def node_description(self) -> Optional[str]: + """Get the node description.""" + return self.get_description() + + @property + def default_value_dict(self) -> dict[str, Any]: + """Get the default values dictionary for this node.""" + return self.get_default_value_dict() diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index c9eac98d39..4953017a84 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -11,8 +11,9 @@ from core.variables.segments import ArrayFileSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.code.entities import CodeNodeData -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.enums import ErrorStrategy, NodeType from .exc import ( CodeNodeError, @@ -29,6 +30,24 @@ class CodeNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self.node_data = CodeNodeData.model_validate(data) + def get_error_strategy(self) -> Optional[ErrorStrategy]: + return self.node_data.error_strategy + + def get_retry_config(self) -> RetryConfig: + return self.node_data.retry_config + + def get_title(self) -> str: + return self.node_data.title + + def get_description(self) -> Optional[str]: + return self.node_data.desc + + def get_default_value_dict(self) -> dict[str, Any]: + return self.node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self.node_data + @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: """ diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index bb79b6b914..fe530dbb47 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -5,7 +5,7 @@ import logging import os import tempfile from collections.abc import Mapping, Sequence -from typing import Any, cast +from typing import Any, Optional, cast import chardet import docx @@ -28,7 +28,8 @@ from core.variables.segments import ArrayStringSegment, FileSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from .entities import DocumentExtractorNodeData from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError @@ -49,6 +50,24 @@ class DocumentExtractorNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self.node_data = DocumentExtractorNodeData.model_validate(data) + def get_error_strategy(self) -> Optional[ErrorStrategy]: + return self.node_data.error_strategy + + def get_retry_config(self) -> RetryConfig: + return self.node_data.retry_config + + def get_title(self) -> str: + return self.node_data.title + + def get_description(self) -> Optional[str]: + return self.node_data.desc + + def get_default_value_dict(self) -> dict[str, Any]: + return self.node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self.node_data + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 745d8f898a..0a84716b3d 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -1,11 +1,12 @@ from collections.abc import Mapping -from typing import Any +from typing import Any, Optional from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.enums import ErrorStrategy, NodeType class EndNode(BaseNode): @@ -16,6 +17,24 @@ class EndNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self.node_data = EndNodeData(**data) + def get_error_strategy(self) -> Optional[ErrorStrategy]: + return self.node_data.error_strategy + + def get_retry_config(self) -> RetryConfig: + return self.node_data.retry_config + + def get_title(self) -> str: + return self.node_data.title + + def get_description(self) -> Optional[str]: + return self.node_data.desc + + def get_default_value_dict(self) -> dict[str, Any]: + return self.node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self.node_data + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index 85320b58f2..95f4ad60fc 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -11,7 +11,8 @@ from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.http_request.executor import Executor from core.workflow.utils import variable_template_parser from factories import file_factory @@ -40,6 +41,24 @@ class HttpRequestNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self.node_data = HttpRequestNodeData.model_validate(data) + def get_error_strategy(self) -> Optional[ErrorStrategy]: + return self.node_data.error_strategy + + def get_retry_config(self) -> RetryConfig: + return self.node_data.retry_config + + def get_title(self) -> str: + return self.node_data.title + + def get_description(self) -> Optional[str]: + return self.node_data.desc + + def get_default_value_dict(self) -> dict[str, Any]: + return self.node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self.node_data + @classmethod def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict: return { diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index c417b88982..d8bde610f3 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any, Literal +from typing import Any, Literal, Optional from typing_extensions import deprecated @@ -7,7 +7,8 @@ from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.if_else.entities import IfElseNodeData from core.workflow.utils.condition.entities import Condition from core.workflow.utils.condition.processor import ConditionProcessor @@ -21,6 +22,24 @@ class IfElseNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self.node_data = IfElseNodeData.model_validate(data) + def get_error_strategy(self) -> Optional[ErrorStrategy]: + return self.node_data.error_strategy + + def get_retry_config(self) -> RetryConfig: + return self.node_data.retry_config + + def get_title(self) -> str: + return self.node_data.title + + def get_description(self) -> Optional[str]: + return self.node_data.desc + + def get_default_value_dict(self) -> dict[str, Any]: + return self.node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self.node_data + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index f9479d559d..59e561de4f 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -36,7 +36,8 @@ from core.workflow.graph_engine.entities.event import ( ) from core.workflow.graph_engine.entities.graph import Graph from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData from factories.variable_factory import build_segment @@ -68,6 +69,24 @@ class IterationNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self.node_data = IterationNodeData.model_validate(data) + def get_error_strategy(self) -> Optional[ErrorStrategy]: + return self.node_data.error_strategy + + def get_retry_config(self) -> RetryConfig: + return self.node_data.retry_config + + def get_title(self) -> str: + return self.node_data.title + + def get_description(self) -> Optional[str]: + return self.node_data.desc + + def get_default_value_dict(self) -> dict[str, Any]: + return self.node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self.node_data + @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: return { diff --git a/api/core/workflow/nodes/iteration/iteration_start_node.py b/api/core/workflow/nodes/iteration/iteration_start_node.py index ffafc6dcc2..2fe07594c6 100644 --- a/api/core/workflow/nodes/iteration/iteration_start_node.py +++ b/api/core/workflow/nodes/iteration/iteration_start_node.py @@ -1,10 +1,11 @@ from collections.abc import Mapping -from typing import Any +from typing import Any, Optional from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.iteration.entities import IterationStartNodeData @@ -20,6 +21,24 @@ class IterationStartNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self.node_data = IterationStartNodeData(**data) + def get_error_strategy(self) -> Optional[ErrorStrategy]: + return self.node_data.error_strategy + + def get_retry_config(self) -> RetryConfig: + return self.node_data.retry_config + + def get_title(self) -> str: + return self.node_data.title + + def get_description(self) -> Optional[str]: + return self.node_data.desc + + def get_default_value_dict(self) -> dict[str, Any]: + return self.node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self.node_data + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 197afd2252..c7c368e509 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -35,7 +35,8 @@ from core.variables.segments import ArrayObjectSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.event import ( ModelInvokeCompletedEvent, ) @@ -127,6 +128,24 @@ class KnowledgeRetrievalNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self.node_data = KnowledgeRetrievalNodeData.model_validate(data) + def get_error_strategy(self) -> Optional[ErrorStrategy]: + return self.node_data.error_strategy + + def get_retry_config(self) -> RetryConfig: + return self.node_data.retry_config + + def get_title(self) -> str: + return self.node_data.title + + def get_description(self) -> Optional[str]: + return self.node_data.desc + + def get_default_value_dict(self) -> dict[str, Any]: + return self.node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self.node_data + @classmethod def version(cls): return "1" diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index 00c9a070b9..1d28d01c2f 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -1,5 +1,5 @@ from collections.abc import Callable, Mapping, Sequence -from typing import Any, Literal, Union +from typing import Any, Literal, Optional, Union from core.file import File from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment @@ -7,7 +7,8 @@ from core.variables.segments import ArrayAnySegment, ArraySegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from .entities import ListOperatorNodeData from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError @@ -21,6 +22,24 @@ class ListOperatorNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self.node_data = ListOperatorNodeData(**data) + def get_error_strategy(self) -> Optional[ErrorStrategy]: + return self.node_data.error_strategy + + def get_retry_config(self) -> RetryConfig: + return self.node_data.retry_config + + def get_title(self) -> str: + return self.node_data.title + + def get_description(self) -> Optional[str]: + return self.node_data.desc + + def get_default_value_dict(self) -> dict[str, Any]: + return self.node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self.node_data + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 4e59c33a2e..1d10df5db6 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -59,7 +59,8 @@ from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.event import ( ModelInvokeCompletedEvent, NodeEvent, @@ -140,6 +141,24 @@ class LLMNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self.node_data = LLMNodeData.model_validate(data) + def get_error_strategy(self) -> Optional[ErrorStrategy]: + return self.node_data.error_strategy + + def get_retry_config(self) -> RetryConfig: + return self.node_data.retry_config + + def get_title(self) -> str: + return self.node_data.title + + def get_description(self) -> Optional[str]: + return self.node_data.desc + + def get_default_value_dict(self) -> dict[str, Any]: + return self.node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self.node_data + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/loop/loop_end_node.py b/api/core/workflow/nodes/loop/loop_end_node.py index 53856c17c4..1190d3ec2d 100644 --- a/api/core/workflow/nodes/loop/loop_end_node.py +++ b/api/core/workflow/nodes/loop/loop_end_node.py @@ -1,10 +1,11 @@ from collections.abc import Mapping -from typing import Any +from typing import Any, Optional from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.loop.entities import LoopEndNodeData @@ -20,6 +21,24 @@ class LoopEndNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self.node_data = LoopEndNodeData(**data) + def get_error_strategy(self) -> Optional[ErrorStrategy]: + return self.node_data.error_strategy + + def get_retry_config(self) -> RetryConfig: + return self.node_data.retry_config + + def get_title(self) -> str: + return self.node_data.title + + def get_description(self) -> Optional[str]: + return self.node_data.desc + + def get_default_value_dict(self) -> dict[str, Any]: + return self.node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self.node_data + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 4d76eb0d66..e6b1866fdf 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -3,7 +3,7 @@ import logging import time from collections.abc import Generator, Mapping, Sequence from datetime import UTC, datetime -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import TYPE_CHECKING, Any, Literal, Optional, cast from configs import dify_config from core.variables import ( @@ -30,7 +30,8 @@ from core.workflow.graph_engine.entities.event import ( ) from core.workflow.graph_engine.entities.graph import Graph from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.loop.entities import LoopNodeData from core.workflow.utils.condition.processor import ConditionProcessor @@ -55,6 +56,24 @@ class LoopNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self.node_data = LoopNodeData.model_validate(data) + def get_error_strategy(self) -> Optional[ErrorStrategy]: + return self.node_data.error_strategy + + def get_retry_config(self) -> RetryConfig: + return self.node_data.retry_config + + def get_title(self) -> str: + return self.node_data.title + + def get_description(self) -> Optional[str]: + return self.node_data.desc + + def get_default_value_dict(self) -> dict[str, Any]: + return self.node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self.node_data + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/loop/loop_start_node.py b/api/core/workflow/nodes/loop/loop_start_node.py index 9362b4a38b..0268dcf543 100644 --- a/api/core/workflow/nodes/loop/loop_start_node.py +++ b/api/core/workflow/nodes/loop/loop_start_node.py @@ -1,10 +1,11 @@ from collections.abc import Mapping -from typing import Any +from typing import Any, Optional from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.loop.entities import LoopStartNodeData @@ -20,6 +21,24 @@ class LoopStartNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self.node_data = LoopStartNodeData(**data) + def get_error_strategy(self) -> Optional[ErrorStrategy]: + return self.node_data.error_strategy + + def get_retry_config(self) -> RetryConfig: + return self.node_data.retry_config + + def get_title(self) -> str: + return self.node_data.title + + def get_description(self) -> Optional[str]: + return self.node_data.desc + + def get_default_value_dict(self) -> dict[str, Any]: + return self.node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self.node_data + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 0b18e5e4f7..98041121fb 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -29,8 +29,9 @@ from core.variables.types import SegmentType from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.llm import ModelConfig, llm_utils from core.workflow.utils import variable_template_parser from factories.variable_factory import build_segment_with_type @@ -98,6 +99,24 @@ class ParameterExtractorNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self.node_data = ParameterExtractorNodeData.model_validate(data) + def get_error_strategy(self) -> Optional[ErrorStrategy]: + return self.node_data.error_strategy + + def get_retry_config(self) -> RetryConfig: + return self.node_data.retry_config + + def get_title(self) -> str: + return self.node_data.title + + def get_description(self) -> Optional[str]: + return self.node_data.desc + + def get_default_value_dict(self) -> dict[str, Any]: + return self.node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self.node_data + _model_instance: Optional[ModelInstance] = None _model_config: Optional[ModelConfigWithCredentialsEntity] = None diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 318b3c0421..ed3323972b 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -11,9 +11,11 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.event import ModelInvokeCompletedEvent from core.workflow.nodes.llm import ( LLMNode, @@ -84,6 +86,24 @@ class QuestionClassifierNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self.node_data = QuestionClassifierNodeData.model_validate(data) + def get_error_strategy(self) -> Optional[ErrorStrategy]: + return self.node_data.error_strategy + + def get_retry_config(self) -> RetryConfig: + return self.node_data.retry_config + + def get_title(self) -> str: + return self.node_data.title + + def get_description(self) -> Optional[str]: + return self.node_data.desc + + def get_default_value_dict(self) -> dict[str, Any]: + return self.node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self.node_data + @classmethod def version(cls): return "1" @@ -242,12 +262,12 @@ class QuestionClassifierNode(BaseNode): typed_node_data = QuestionClassifierNodeData.model_validate(node_data) variable_mapping = {"query": typed_node_data.query_variable_selector} - variable_selectors = [] + variable_selectors: list[VariableSelector] = [] if typed_node_data.instruction: variable_template_parser = VariableTemplateParser(template=typed_node_data.instruction) variable_selectors.extend(variable_template_parser.extract_variable_selectors()) for variable_selector in variable_selectors: - variable_mapping[variable_selector.variable] = variable_selector.value_selector + variable_mapping[variable_selector.variable] = list(variable_selector.value_selector) variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 2fc3dcb363..068950c112 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,11 +1,12 @@ from collections.abc import Mapping -from typing import Any +from typing import Any, Optional from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.start.entities import StartNodeData @@ -17,6 +18,24 @@ class StartNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self.node_data = StartNodeData(**data) + def get_error_strategy(self) -> Optional[ErrorStrategy]: + return self.node_data.error_strategy + + def get_retry_config(self) -> RetryConfig: + return self.node_data.retry_config + + def get_title(self) -> str: + return self.node_data.title + + def get_description(self) -> Optional[str]: + return self.node_data.desc + + def get_default_value_dict(self) -> dict[str, Any]: + return self.node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self.node_data + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index f2ebf33d9a..a75b84ac01 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -6,7 +6,8 @@ from core.helper.code_executor.code_executor import CodeExecutionError, CodeExec from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000")) @@ -20,6 +21,24 @@ class TemplateTransformNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self.node_data = TemplateTransformNodeData.model_validate(data) + def get_error_strategy(self) -> Optional[ErrorStrategy]: + return self.node_data.error_strategy + + def get_retry_config(self) -> RetryConfig: + return self.node_data.retry_config + + def get_title(self) -> str: + return self.node_data.title + + def get_description(self) -> Optional[str]: + return self.node_data.desc + + def get_default_value_dict(self) -> dict[str, Any]: + return self.node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self.node_data + @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: """ diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 3dc70d909b..ad5df08b53 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -1,5 +1,5 @@ from collections.abc import Generator, Mapping, Sequence -from typing import Any, cast +from typing import Any, Optional, cast from sqlalchemy import select from sqlalchemy.orm import Session @@ -19,8 +19,9 @@ from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType -from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType +from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db from factories import file_factory @@ -402,6 +403,24 @@ class ToolNode(BaseNode): return result + def get_error_strategy(self) -> Optional[ErrorStrategy]: + return self.node_data.error_strategy + + def get_retry_config(self) -> RetryConfig: + return self.node_data.retry_config + + def get_title(self) -> str: + return self.node_data.title + + def get_description(self) -> Optional[str]: + return self.node_data.desc + + def get_default_value_dict(self) -> dict[str, Any]: + return self.node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self.node_data + @property def continue_on_error(self) -> bool: return self.node_data.error_strategy is not None diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index b8448d2333..a8ade81020 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -1,11 +1,12 @@ from collections.abc import Mapping -from typing import Any +from typing import Any, Optional from core.variables.segments import Segment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData @@ -17,6 +18,24 @@ class VariableAggregatorNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self.node_data = VariableAssignerNodeData(**data) + def get_error_strategy(self) -> Optional[ErrorStrategy]: + return self.node_data.error_strategy + + def get_retry_config(self) -> RetryConfig: + return self.node_data.retry_config + + def get_title(self) -> str: + return self.node_data.title + + def get_description(self) -> Optional[str]: + return self.node_data.desc + + def get_default_value_dict(self) -> dict[str, Any]: + return self.node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self.node_data + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index 5433ff939e..74a3e6ec42 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -7,7 +7,8 @@ from core.workflow.conversation_variable_updater import ConversationVariableUpda from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError from factories import variable_factory @@ -31,6 +32,24 @@ class VariableAssignerNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self.node_data = VariableAssignerData.model_validate(data) + def get_error_strategy(self) -> Optional[ErrorStrategy]: + return self.node_data.error_strategy + + def get_retry_config(self) -> RetryConfig: + return self.node_data.retry_config + + def get_title(self) -> str: + return self.node_data.title + + def get_description(self) -> Optional[str]: + return self.node_data.desc + + def get_default_value_dict(self) -> dict[str, Any]: + return self.node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self.node_data + def __init__( self, id: str, diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py index 82c39d13dc..167079db28 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -1,6 +1,6 @@ import json from collections.abc import Callable, Mapping, MutableMapping, Sequence -from typing import Any, TypeAlias, cast +from typing import Any, Optional, TypeAlias, cast from core.app.entities.app_invoke_entities import InvokeFrom from core.variables import SegmentType, Variable @@ -10,7 +10,8 @@ from core.workflow.conversation_variable_updater import ConversationVariableUpda from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory @@ -62,6 +63,24 @@ class VariableAssignerNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self.node_data = VariableAssignerNodeData.model_validate(data) + def get_error_strategy(self) -> Optional[ErrorStrategy]: + return self.node_data.error_strategy + + def get_retry_config(self) -> RetryConfig: + return self.node_data.retry_config + + def get_title(self) -> str: + return self.node_data.title + + def get_description(self) -> Optional[str]: + return self.node_data.desc + + def get_default_value_dict(self) -> dict[str, Any]: + return self.node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self.node_data + def _conv_var_updater_factory(self) -> ConversationVariableUpdater: return conversation_variable_updater_factory() diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 934453f87d..0d5e5fd6aa 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -484,13 +484,13 @@ class WorkflowService: "status": WorkflowNodeExecutionStatus.EXCEPTION, "error": node_run_result.error, "inputs": node_run_result.inputs, - "metadata": {"error_strategy": node_instance.node_data.error_strategy}, + "metadata": {"error_strategy": node_instance.error_strategy}, } - if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE: + if node_instance.error_strategy is ErrorStrategy.DEFAULT_VALUE: node_run_result = NodeRunResult( **node_error_args, outputs={ - **node_instance.node_data.default_value_dict, + **node_instance.default_value_dict, "error_message": node_run_result.error, "error_type": node_run_result.error_type, }, @@ -521,7 +521,7 @@ class WorkflowService: index=1, node_id=node_id, node_type=node_instance.node_type, - title=node_instance.node_data.title, + title=node_instance.node_title, elapsed_time=time.perf_counter() - start_at, created_at=datetime.now(UTC).replace(tzinfo=None), finished_at=datetime.now(UTC).replace(tzinfo=None),