diff --git a/api/core/workflow/errors.py b/api/core/workflow/errors.py index 28b6a5342e..d57a0b8183 100644 --- a/api/core/workflow/errors.py +++ b/api/core/workflow/errors.py @@ -2,7 +2,7 @@ from core.workflow.nodes.base import BaseNode class WorkflowNodeRunFailedError(Exception): - def __init__(self, node_instance: BaseNode, error: str): - self.node_instance = node_instance + def __init__(self, node: BaseNode, error: str): + self.node = node self.error = error - super().__init__(f"Node {node_instance.node_title} run failed: {error}") + super().__init__(f"Node {node.title} run failed: {error}") diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 26e5e89d9a..b5bdd20236 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -267,7 +267,7 @@ class GraphEngine: previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None # init workflow run state - node_instance = node_cls( + node = node_cls( id=route_node_state.id, config=node_config, graph_init_params=self.init_params, @@ -276,11 +276,11 @@ class GraphEngine: previous_node_id=previous_node_id, thread_pool_id=self.thread_pool_id, ) - node_instance.init_node_data(node_config.get("data", {})) + node.init_node_data(node_config.get("data", {})) try: # run node generator = self._run_node( - node_instance=node_instance, + node=node, route_node_state=route_node_state, parallel_id=in_parallel_id, parallel_start_node_id=parallel_start_node_id, @@ -308,16 +308,16 @@ class GraphEngine: route_node_state.failed_reason = str(e) yield NodeRunFailedEvent( error=str(e), - id=node_instance.id, + id=node.id, node_id=next_node_id, node_type=node_type, - node_data=node_instance.get_base_node_data(), + node_data=node._get_base_node_data(), route_node_state=route_node_state, parallel_id=in_parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, - node_version=node_instance.version(), + node_version=node.version(), ) raise e @@ -339,7 +339,7 @@ class GraphEngine: edge = edge_mappings[0] if ( previous_route_node_state.status == RouteNodeState.Status.EXCEPTION - and node_instance.error_strategy == ErrorStrategy.FAIL_BRANCH + and node.error_strategy == ErrorStrategy.FAIL_BRANCH and edge.run_condition is None ): break @@ -415,8 +415,8 @@ class GraphEngine: next_node_id = final_node_id elif ( - node_instance.error_strategy == ErrorStrategy.FAIL_BRANCH - and node_instance.continue_on_error + node.continue_on_error + and node.error_strategy == ErrorStrategy.FAIL_BRANCH and previous_route_node_state.status == RouteNodeState.Status.EXCEPTION ): break @@ -599,7 +599,7 @@ class GraphEngine: def _run_node( self, - node_instance: BaseNode, + node: BaseNode, route_node_state: RouteNodeState, parallel_id: Optional[str] = None, parallel_start_node_id: Optional[str] = None, @@ -613,29 +613,29 @@ class GraphEngine: # trigger node run start event agent_strategy = ( AgentNodeStrategyInit( - name=cast(AgentNodeData, node_instance.get_base_node_data()).agent_strategy_name, - icon=cast(AgentNode, node_instance).agent_strategy_icon, + name=cast(AgentNodeData, node._get_base_node_data()).agent_strategy_name, + icon=cast(AgentNode, node).agent_strategy_icon, ) - if node_instance.node_type == NodeType.AGENT + if node.type_ == NodeType.AGENT else None ) yield NodeRunStartedEvent( - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.get_base_node_data(), + id=node.id, + node_id=node.node_id, + node_type=node.type_, + node_data=node._get_base_node_data(), route_node_state=route_node_state, - predecessor_node_id=node_instance.previous_node_id, + predecessor_node_id=node.previous_node_id, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, agent_strategy=agent_strategy, - node_version=node_instance.version(), + node_version=node.version(), ) - max_retries = node_instance.node_retry_config.max_retries - retry_interval = node_instance.node_retry_config.retry_interval_seconds + max_retries = node.retry_config.max_retries + retry_interval = node.retry_config.retry_interval_seconds retries = 0 should_continue_retry = True while should_continue_retry and retries <= max_retries: @@ -644,7 +644,7 @@ class GraphEngine: retry_start_at = datetime.now(UTC).replace(tzinfo=None) # yield control to other threads time.sleep(0.001) - event_stream = node_instance.run() + event_stream = node.run() for event in event_stream: if isinstance(event, GraphEngineEvent): # add parallel info to iteration event @@ -660,21 +660,21 @@ class GraphEngine: if run_result.status == WorkflowNodeExecutionStatus.FAILED: if ( retries == max_retries - and node_instance.node_type == NodeType.HTTP_REQUEST + and node.type_ == NodeType.HTTP_REQUEST and run_result.outputs - and not node_instance.continue_on_error + and not node.continue_on_error ): run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED - if node_instance.retry and retries < max_retries: + if node.retry and retries < max_retries: retries += 1 route_node_state.node_run_result = run_result yield NodeRunRetryEvent( id=str(uuid.uuid4()), - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.get_base_node_data(), + node_id=node.node_id, + node_type=node.type_, + node_data=node._get_base_node_data(), route_node_state=route_node_state, - predecessor_node_id=node_instance.previous_node_id, + predecessor_node_id=node.previous_node_id, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, @@ -682,17 +682,17 @@ class GraphEngine: error=run_result.error or "Unknown error", retry_index=retries, start_at=retry_start_at, - node_version=node_instance.version(), + node_version=node.version(), ) time.sleep(retry_interval) break route_node_state.set_finished(run_result=run_result) if run_result.status == WorkflowNodeExecutionStatus.FAILED: - if node_instance.continue_on_error: + if node.continue_on_error: # if run failed, handle error run_result = self._handle_continue_on_error( - node_instance, + node, event.run_result, self.graph_runtime_state.variable_pool, handle_exceptions=handle_exceptions, @@ -703,44 +703,44 @@ class GraphEngine: for variable_key, variable_value in run_result.outputs.items(): # append variables to variable pool recursively self._append_variables_recursively( - node_id=node_instance.node_id, + node_id=node.node_id, variable_key_list=[variable_key], variable_value=variable_value, ) yield NodeRunExceptionEvent( error=run_result.error or "System Error", - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.get_base_node_data(), + id=node.id, + node_id=node.node_id, + node_type=node.type_, + node_data=node._get_base_node_data(), route_node_state=route_node_state, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, - node_version=node_instance.version(), + node_version=node.version(), ) should_continue_retry = False else: yield NodeRunFailedEvent( error=route_node_state.failed_reason or "Unknown error.", - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.get_base_node_data(), + id=node.id, + node_id=node.node_id, + node_type=node.type_, + node_data=node._get_base_node_data(), route_node_state=route_node_state, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, - node_version=node_instance.version(), + node_version=node.version(), ) should_continue_retry = False elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: if ( - node_instance.continue_on_error - and self.graph.edge_mapping.get(node_instance.node_id) - and node_instance.error_strategy is ErrorStrategy.FAIL_BRANCH + node.continue_on_error + and self.graph.edge_mapping.get(node.node_id) + and node.error_strategy is ErrorStrategy.FAIL_BRANCH ): run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS if run_result.metadata and run_result.metadata.get( @@ -760,7 +760,7 @@ class GraphEngine: for variable_key, variable_value in run_result.outputs.items(): # append variables to variable pool recursively self._append_variables_recursively( - node_id=node_instance.node_id, + node_id=node.node_id, variable_key_list=[variable_key], variable_value=variable_value, ) @@ -785,26 +785,26 @@ class GraphEngine: run_result.metadata = metadata_dict yield NodeRunSucceededEvent( - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.get_base_node_data(), + id=node.id, + node_id=node.node_id, + node_type=node.type_, + node_data=node._get_base_node_data(), route_node_state=route_node_state, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, - node_version=node_instance.version(), + node_version=node.version(), ) should_continue_retry = False break elif isinstance(event, RunStreamChunkEvent): yield NodeRunStreamChunkEvent( - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.get_base_node_data(), + id=node.id, + node_id=node.node_id, + node_type=node.type_, + node_data=node._get_base_node_data(), chunk_content=event.chunk_content, from_variable_selector=event.from_variable_selector, route_node_state=route_node_state, @@ -812,14 +812,14 @@ class GraphEngine: parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, - node_version=node_instance.version(), + node_version=node.version(), ) elif isinstance(event, RunRetrieverResourceEvent): yield NodeRunRetrieverResourceEvent( - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.get_base_node_data(), + id=node.id, + node_id=node.node_id, + node_type=node.type_, + node_data=node._get_base_node_data(), retriever_resources=event.retriever_resources, context=event.context, route_node_state=route_node_state, @@ -827,7 +827,7 @@ class GraphEngine: parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, - node_version=node_instance.version(), + node_version=node.version(), ) except GenerateTaskStoppedError: # trigger node run failed event @@ -835,20 +835,20 @@ class GraphEngine: route_node_state.failed_reason = "Workflow stopped." yield NodeRunFailedEvent( error="Workflow stopped.", - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.get_base_node_data(), + id=node.id, + node_id=node.node_id, + node_type=node.type_, + node_data=node._get_base_node_data(), route_node_state=route_node_state, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, - node_version=node_instance.version(), + node_version=node.version(), ) return except Exception as e: - logger.exception(f"Node {node_instance.node_title} run failed") + logger.exception(f"Node {node.title} run failed") raise e def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue): @@ -888,22 +888,14 @@ class GraphEngine: def _handle_continue_on_error( self, - node_instance: BaseNode, + node: BaseNode, error_result: NodeRunResult, variable_pool: VariablePool, handle_exceptions: list[str] = [], ) -> NodeRunResult: - """ - handle continue on error when self._should_continue_on_error is True - - - :param error_result (NodeRunResult): error run result - :param variable_pool (VariablePool): variable pool - :return: excption run result - """ # add error message and error type to variable pool - variable_pool.add([node_instance.node_id, "error_message"], error_result.error) - variable_pool.add([node_instance.node_id, "error_type"], error_result.error_type) + variable_pool.add([node.node_id, "error_message"], error_result.error) + variable_pool.add([node.node_id, "error_type"], error_result.error_type) # add error message to handle_exceptions handle_exceptions.append(error_result.error or "") node_error_args: dict[str, Any] = { @@ -911,21 +903,21 @@ class GraphEngine: "error": error_result.error, "inputs": error_result.inputs, "metadata": { - WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node_instance.error_strategy, + WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node.error_strategy, }, } - if node_instance.error_strategy is ErrorStrategy.DEFAULT_VALUE: + if node.error_strategy is ErrorStrategy.DEFAULT_VALUE: return NodeRunResult( **node_error_args, outputs={ - **node_instance.default_value_dict, + **node.default_value_dict, "error_message": error_result.error, "error_type": error_result.error_type, }, ) - elif node_instance.error_strategy is ErrorStrategy.FAIL_BRANCH: - if self.graph.edge_mapping.get(node_instance.node_id): + elif node.error_strategy is ErrorStrategy.FAIL_BRANCH: + if self.graph.edge_mapping.get(node.node_id): node_error_args["edge_source_handle"] = FailBranchSourceHandle.FAILED return NodeRunResult( **node_error_args, diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 118cc4b657..c899736b63 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -68,22 +68,22 @@ class AgentNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self._node_data = AgentNodeData.model_validate(data) - def get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> Optional[ErrorStrategy]: return self._node_data.error_strategy - def get_retry_config(self) -> RetryConfig: + def _get_retry_config(self) -> RetryConfig: return self._node_data.retry_config - def get_title(self) -> str: + def _get_title(self) -> str: return self._node_data.title - def get_description(self) -> Optional[str]: + def _get_description(self) -> Optional[str]: return self._node_data.desc - def get_default_value_dict(self) -> dict[str, Any]: + def _get_default_value_dict(self) -> dict[str, Any]: return self._node_data.default_value_dict - def get_base_node_data(self) -> BaseNodeData: + def _get_base_node_data(self) -> BaseNodeData: return self._node_data @classmethod @@ -185,7 +185,7 @@ class AgentNode(BaseNode): parameters_for_log=parameters_for_log, user_id=self.user_id, tenant_id=self.tenant_id, - node_type=self.node_type, + node_type=self.type_, node_id=self.node_id, node_execution_id=self.id, ) diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index 754b0121cb..72bb6d687f 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -25,22 +25,22 @@ class AnswerNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self._node_data = AnswerNodeData.model_validate(data) - def get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> Optional[ErrorStrategy]: return self._node_data.error_strategy - def get_retry_config(self) -> RetryConfig: + def _get_retry_config(self) -> RetryConfig: return self._node_data.retry_config - def get_title(self) -> str: + def _get_title(self) -> str: return self._node_data.title - def get_description(self) -> Optional[str]: + def _get_description(self) -> Optional[str]: return self._node_data.desc - def get_default_value_dict(self) -> dict[str, Any]: + def _get_default_value_dict(self) -> dict[str, Any]: return self._node_data.default_value_dict - def get_base_node_data(self) -> BaseNodeData: + def _get_base_node_data(self) -> BaseNodeData: return self._node_data @classmethod diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index 7d84e0e212..dcfed5eed2 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -122,9 +122,9 @@ class RetryConfig(BaseModel): class BaseNodeData(ABC, BaseModel): title: str desc: Optional[str] = None + version: str = "1" error_strategy: Optional[ErrorStrategy] = None default_value: Optional[list[DefaultValue]] = None - version: str = "1" retry_config: RetryConfig = RetryConfig() @property diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index 861b07d1f9..06c0ff2ca8 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -145,7 +145,7 @@ class BaseNode: return {} @property - def node_type(self) -> NodeType: + def type_(self) -> NodeType: return self._node_type @classmethod @@ -170,32 +170,32 @@ class BaseNode: # to BaseNodeData properties in a type-safe way @abstractmethod - def get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> Optional[ErrorStrategy]: """Get the error strategy for this node.""" ... @abstractmethod - def get_retry_config(self) -> RetryConfig: + def _get_retry_config(self) -> RetryConfig: """Get the retry configuration for this node.""" ... @abstractmethod - def get_title(self) -> str: + def _get_title(self) -> str: """Get the node title.""" ... @abstractmethod - def get_description(self) -> Optional[str]: + def _get_description(self) -> Optional[str]: """Get the node description.""" ... @abstractmethod - def get_default_value_dict(self) -> dict[str, Any]: + def _get_default_value_dict(self) -> dict[str, Any]: """Get the default values dictionary for this node.""" ... @abstractmethod - def get_base_node_data(self) -> BaseNodeData: + def _get_base_node_data(self) -> BaseNodeData: """Get the BaseNodeData object for this node.""" ... @@ -203,24 +203,24 @@ class BaseNode: @property def error_strategy(self) -> Optional[ErrorStrategy]: """Get the error strategy for this node.""" - return self.get_error_strategy() + return self._get_error_strategy() @property - def node_retry_config(self) -> RetryConfig: + def retry_config(self) -> RetryConfig: """Get the retry configuration for this node.""" - return self.get_retry_config() + return self._get_retry_config() @property - def node_title(self) -> str: + def title(self) -> str: """Get the node title.""" - return self.get_title() + return self._get_title() @property - def node_description(self) -> Optional[str]: + def description(self) -> Optional[str]: """Get the node description.""" - return self.get_description() + return self._get_description() @property def default_value_dict(self) -> dict[str, Any]: """Get the default values dictionary for this node.""" - return self.get_default_value_dict() + return self._get_default_value_dict() diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 41d6202d38..c20e019a1e 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -30,22 +30,22 @@ class CodeNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self._node_data = CodeNodeData.model_validate(data) - def get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> Optional[ErrorStrategy]: return self._node_data.error_strategy - def get_retry_config(self) -> RetryConfig: + def _get_retry_config(self) -> RetryConfig: return self._node_data.retry_config - def get_title(self) -> str: + def _get_title(self) -> str: return self._node_data.title - def get_description(self) -> Optional[str]: + def _get_description(self) -> Optional[str]: return self._node_data.desc - def get_default_value_dict(self) -> dict[str, Any]: + def _get_default_value_dict(self) -> dict[str, Any]: return self._node_data.default_value_dict - def get_base_node_data(self) -> BaseNodeData: + def _get_base_node_data(self) -> BaseNodeData: return self._node_data @classmethod diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index 53f7e1969e..4c7b7e52d5 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -50,22 +50,22 @@ class DocumentExtractorNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self._node_data = DocumentExtractorNodeData.model_validate(data) - def get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> Optional[ErrorStrategy]: return self._node_data.error_strategy - def get_retry_config(self) -> RetryConfig: + def _get_retry_config(self) -> RetryConfig: return self._node_data.retry_config - def get_title(self) -> str: + def _get_title(self) -> str: return self._node_data.title - def get_description(self) -> Optional[str]: + def _get_description(self) -> Optional[str]: return self._node_data.desc - def get_default_value_dict(self) -> dict[str, Any]: + def _get_default_value_dict(self) -> dict[str, Any]: return self._node_data.default_value_dict - def get_base_node_data(self) -> BaseNodeData: + def _get_base_node_data(self) -> BaseNodeData: return self._node_data @classmethod diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 06bf393cc5..17771811b2 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -17,22 +17,22 @@ class EndNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self._node_data = EndNodeData(**data) - def get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> Optional[ErrorStrategy]: return self._node_data.error_strategy - def get_retry_config(self) -> RetryConfig: + def _get_retry_config(self) -> RetryConfig: return self._node_data.retry_config - def get_title(self) -> str: + def _get_title(self) -> str: return self._node_data.title - def get_description(self) -> Optional[str]: + def _get_description(self) -> Optional[str]: return self._node_data.desc - def get_default_value_dict(self) -> dict[str, Any]: + def _get_default_value_dict(self) -> dict[str, Any]: return self._node_data.default_value_dict - def get_base_node_data(self) -> BaseNodeData: + def _get_base_node_data(self) -> BaseNodeData: return self._node_data @classmethod diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index 4bd4dde03f..3182240f6c 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -41,22 +41,22 @@ class HttpRequestNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self._node_data = HttpRequestNodeData.model_validate(data) - def get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> Optional[ErrorStrategy]: return self._node_data.error_strategy - def get_retry_config(self) -> RetryConfig: + def _get_retry_config(self) -> RetryConfig: return self._node_data.retry_config - def get_title(self) -> str: + def _get_title(self) -> str: return self._node_data.title - def get_description(self) -> Optional[str]: + def _get_description(self) -> Optional[str]: return self._node_data.desc - def get_default_value_dict(self) -> dict[str, Any]: + def _get_default_value_dict(self) -> dict[str, Any]: return self._node_data.default_value_dict - def get_base_node_data(self) -> BaseNodeData: + def _get_base_node_data(self) -> BaseNodeData: return self._node_data @classmethod diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index 03769c866e..7bb1565558 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -22,22 +22,22 @@ class IfElseNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self._node_data = IfElseNodeData.model_validate(data) - def get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> Optional[ErrorStrategy]: return self._node_data.error_strategy - def get_retry_config(self) -> RetryConfig: + def _get_retry_config(self) -> RetryConfig: return self._node_data.retry_config - def get_title(self) -> str: + def _get_title(self) -> str: return self._node_data.title - def get_description(self) -> Optional[str]: + def _get_description(self) -> Optional[str]: return self._node_data.desc - def get_default_value_dict(self) -> dict[str, Any]: + def _get_default_value_dict(self) -> dict[str, Any]: return self._node_data.default_value_dict - def get_base_node_data(self) -> BaseNodeData: + def _get_base_node_data(self) -> BaseNodeData: return self._node_data @classmethod diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index cf6b3eaeb7..a71970af3f 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -69,22 +69,22 @@ class IterationNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self._node_data = IterationNodeData.model_validate(data) - def get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> Optional[ErrorStrategy]: return self._node_data.error_strategy - def get_retry_config(self) -> RetryConfig: + def _get_retry_config(self) -> RetryConfig: return self._node_data.retry_config - def get_title(self) -> str: + def _get_title(self) -> str: return self._node_data.title - def get_description(self) -> Optional[str]: + def _get_description(self) -> Optional[str]: return self._node_data.desc - def get_default_value_dict(self) -> dict[str, Any]: + def _get_default_value_dict(self) -> dict[str, Any]: return self._node_data.default_value_dict - def get_base_node_data(self) -> BaseNodeData: + def _get_base_node_data(self) -> BaseNodeData: return self._node_data @classmethod @@ -184,7 +184,7 @@ class IterationNode(BaseNode): yield IterationRunStartedEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, + iteration_node_type=self.type_, iteration_node_data=self._node_data, start_at=start_at, inputs=inputs, @@ -195,7 +195,7 @@ class IterationNode(BaseNode): yield IterationRunNextEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, + iteration_node_type=self.type_, iteration_node_data=self._node_data, index=0, pre_iteration_output=None, @@ -276,7 +276,7 @@ class IterationNode(BaseNode): yield IterationRunSucceededEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, + iteration_node_type=self.type_, iteration_node_data=self._node_data, start_at=start_at, inputs=inputs, @@ -301,7 +301,7 @@ class IterationNode(BaseNode): yield IterationRunFailedEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, + iteration_node_type=self.type_, iteration_node_data=self._node_data, start_at=start_at, inputs=inputs, @@ -461,7 +461,7 @@ class IterationNode(BaseNode): yield IterationRunFailedEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, + iteration_node_type=self.type_, iteration_node_data=self._node_data, parallel_mode_run_id=parallel_mode_run_id, start_at=start_at, @@ -475,7 +475,7 @@ class IterationNode(BaseNode): yield IterationRunFailedEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, + iteration_node_type=self.type_, iteration_node_data=self._node_data, start_at=start_at, inputs=inputs, @@ -510,7 +510,7 @@ class IterationNode(BaseNode): yield IterationRunNextEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, + iteration_node_type=self.type_, iteration_node_data=self._node_data, index=next_index, parallel_mode_run_id=parallel_mode_run_id, @@ -531,7 +531,7 @@ class IterationNode(BaseNode): yield IterationRunNextEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, + iteration_node_type=self.type_, iteration_node_data=self._node_data, index=next_index, parallel_mode_run_id=parallel_mode_run_id, @@ -554,7 +554,7 @@ class IterationNode(BaseNode): yield IterationRunFailedEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, + iteration_node_type=self.type_, iteration_node_data=self._node_data, parallel_mode_run_id=parallel_mode_run_id, start_at=start_at, @@ -568,7 +568,7 @@ class IterationNode(BaseNode): yield IterationRunFailedEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, + iteration_node_type=self.type_, iteration_node_data=self._node_data, start_at=start_at, inputs=inputs, @@ -607,7 +607,7 @@ class IterationNode(BaseNode): yield IterationRunNextEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, + iteration_node_type=self.type_, iteration_node_data=self._node_data, index=next_index, parallel_mode_run_id=parallel_mode_run_id, @@ -620,7 +620,7 @@ class IterationNode(BaseNode): yield IterationRunFailedEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, + iteration_node_type=self.type_, iteration_node_data=self._node_data, start_at=start_at, inputs=inputs, diff --git a/api/core/workflow/nodes/iteration/iteration_start_node.py b/api/core/workflow/nodes/iteration/iteration_start_node.py index 2fe07594c6..a83ecaf335 100644 --- a/api/core/workflow/nodes/iteration/iteration_start_node.py +++ b/api/core/workflow/nodes/iteration/iteration_start_node.py @@ -16,28 +16,28 @@ class IterationStartNode(BaseNode): _node_type = NodeType.ITERATION_START - node_data: IterationStartNodeData + _node_data: IterationStartNodeData def init_node_data(self, data: Mapping[str, Any]) -> None: - self.node_data = IterationStartNodeData(**data) + self._node_data = IterationStartNodeData(**data) - def get_error_strategy(self) -> Optional[ErrorStrategy]: - return self.node_data.error_strategy + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy - def get_retry_config(self) -> RetryConfig: - return self.node_data.retry_config + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config - def get_title(self) -> str: - return self.node_data.title + def _get_title(self) -> str: + return self._node_data.title - def get_description(self) -> Optional[str]: - return self.node_data.desc + def _get_description(self) -> Optional[str]: + return self._node_data.desc - def get_default_value_dict(self) -> dict[str, Any]: - return self.node_data.default_value_dict + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict - def get_base_node_data(self) -> BaseNodeData: - return self.node_data + def _get_base_node_data(self) -> BaseNodeData: + return self._node_data @classmethod def version(cls) -> str: diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 323b47cd40..09f44be099 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -128,22 +128,22 @@ class KnowledgeRetrievalNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self._node_data = KnowledgeRetrievalNodeData.model_validate(data) - def get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> Optional[ErrorStrategy]: return self._node_data.error_strategy - def get_retry_config(self) -> RetryConfig: + def _get_retry_config(self) -> RetryConfig: return self._node_data.retry_config - def get_title(self) -> str: + def _get_title(self) -> str: return self._node_data.title - def get_description(self) -> Optional[str]: + def _get_description(self) -> Optional[str]: return self._node_data.desc - def get_default_value_dict(self) -> dict[str, Any]: + def _get_default_value_dict(self) -> dict[str, Any]: return self._node_data.default_value_dict - def get_base_node_data(self) -> BaseNodeData: + def _get_base_node_data(self) -> BaseNodeData: return self._node_data @classmethod diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index e64f5dc047..4eacb3c6c9 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -22,22 +22,22 @@ class ListOperatorNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self._node_data = ListOperatorNodeData(**data) - def get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> Optional[ErrorStrategy]: return self._node_data.error_strategy - def get_retry_config(self) -> RetryConfig: + def _get_retry_config(self) -> RetryConfig: return self._node_data.retry_config - def get_title(self) -> str: + def _get_title(self) -> str: return self._node_data.title - def get_description(self) -> Optional[str]: + def _get_description(self) -> Optional[str]: return self._node_data.desc - def get_default_value_dict(self) -> dict[str, Any]: + def _get_default_value_dict(self) -> dict[str, Any]: return self._node_data.default_value_dict - def get_base_node_data(self) -> BaseNodeData: + def _get_base_node_data(self) -> BaseNodeData: return self._node_data @classmethod diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 3559e3a00b..e6a5796e1e 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -141,22 +141,22 @@ class LLMNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self._node_data = LLMNodeData.model_validate(data) - def get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> Optional[ErrorStrategy]: return self._node_data.error_strategy - def get_retry_config(self) -> RetryConfig: + def _get_retry_config(self) -> RetryConfig: return self._node_data.retry_config - def get_title(self) -> str: + def _get_title(self) -> str: return self._node_data.title - def get_description(self) -> Optional[str]: + def _get_description(self) -> Optional[str]: return self._node_data.desc - def get_default_value_dict(self) -> dict[str, Any]: + def _get_default_value_dict(self) -> dict[str, Any]: return self._node_data.default_value_dict - def get_base_node_data(self) -> BaseNodeData: + def _get_base_node_data(self) -> BaseNodeData: return self._node_data @classmethod diff --git a/api/core/workflow/nodes/loop/loop_end_node.py b/api/core/workflow/nodes/loop/loop_end_node.py index 3a93621aa4..3c6b9ec2ed 100644 --- a/api/core/workflow/nodes/loop/loop_end_node.py +++ b/api/core/workflow/nodes/loop/loop_end_node.py @@ -21,22 +21,22 @@ class LoopEndNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self._node_data = LoopEndNodeData(**data) - def get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> Optional[ErrorStrategy]: return self._node_data.error_strategy - def get_retry_config(self) -> RetryConfig: + def _get_retry_config(self) -> RetryConfig: return self._node_data.retry_config - def get_title(self) -> str: + def _get_title(self) -> str: return self._node_data.title - def get_description(self) -> Optional[str]: + def _get_description(self) -> Optional[str]: return self._node_data.desc - def get_default_value_dict(self) -> dict[str, Any]: + def _get_default_value_dict(self) -> dict[str, Any]: return self._node_data.default_value_dict - def get_base_node_data(self) -> BaseNodeData: + def _get_base_node_data(self) -> BaseNodeData: return self._node_data @classmethod diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 2fa8924b07..35c7bd1c24 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -56,22 +56,22 @@ class LoopNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self._node_data = LoopNodeData.model_validate(data) - def get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> Optional[ErrorStrategy]: return self._node_data.error_strategy - def get_retry_config(self) -> RetryConfig: + def _get_retry_config(self) -> RetryConfig: return self._node_data.retry_config - def get_title(self) -> str: + def _get_title(self) -> str: return self._node_data.title - def get_description(self) -> Optional[str]: + def _get_description(self) -> Optional[str]: return self._node_data.desc - def get_default_value_dict(self) -> dict[str, Any]: + def _get_default_value_dict(self) -> dict[str, Any]: return self._node_data.default_value_dict - def get_base_node_data(self) -> BaseNodeData: + def _get_base_node_data(self) -> BaseNodeData: return self._node_data @classmethod @@ -150,7 +150,7 @@ class LoopNode(BaseNode): yield LoopRunStartedEvent( loop_id=self.id, loop_node_id=self.node_id, - loop_node_type=self.node_type, + loop_node_type=self.type_, loop_node_data=self._node_data, start_at=start_at, inputs=inputs, @@ -207,7 +207,7 @@ class LoopNode(BaseNode): yield LoopRunSucceededEvent( loop_id=self.id, loop_node_id=self.node_id, - loop_node_type=self.node_type, + loop_node_type=self.type_, loop_node_data=self._node_data, start_at=start_at, inputs=inputs, @@ -240,7 +240,7 @@ class LoopNode(BaseNode): yield LoopRunFailedEvent( loop_id=self.id, loop_node_id=self.node_id, - loop_node_type=self.node_type, + loop_node_type=self.type_, loop_node_data=self._node_data, start_at=start_at, inputs=inputs, @@ -343,7 +343,7 @@ class LoopNode(BaseNode): yield LoopRunFailedEvent( loop_id=self.id, loop_node_id=self.node_id, - loop_node_type=self.node_type, + loop_node_type=self.type_, loop_node_data=self._node_data, start_at=start_at, inputs=inputs, @@ -374,7 +374,7 @@ class LoopNode(BaseNode): yield LoopRunFailedEvent( loop_id=self.id, loop_node_id=self.node_id, - loop_node_type=self.node_type, + loop_node_type=self.type_, loop_node_data=self._node_data, start_at=start_at, inputs=inputs, @@ -423,7 +423,7 @@ class LoopNode(BaseNode): yield LoopRunNextEvent( loop_id=self.id, loop_node_id=self.node_id, - loop_node_type=self.node_type, + loop_node_type=self.type_, loop_node_data=self._node_data, index=next_index, pre_loop_output=self._node_data.outputs, diff --git a/api/core/workflow/nodes/loop/loop_start_node.py b/api/core/workflow/nodes/loop/loop_start_node.py index 52afff0c02..d570daaab5 100644 --- a/api/core/workflow/nodes/loop/loop_start_node.py +++ b/api/core/workflow/nodes/loop/loop_start_node.py @@ -21,22 +21,22 @@ class LoopStartNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self._node_data = LoopStartNodeData(**data) - def get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> Optional[ErrorStrategy]: return self._node_data.error_strategy - def get_retry_config(self) -> RetryConfig: + def _get_retry_config(self) -> RetryConfig: return self._node_data.retry_config - def get_title(self) -> str: + def _get_title(self) -> str: return self._node_data.title - def get_description(self) -> Optional[str]: + def _get_description(self) -> Optional[str]: return self._node_data.desc - def get_default_value_dict(self) -> dict[str, Any]: + def _get_default_value_dict(self) -> dict[str, Any]: return self._node_data.default_value_dict - def get_base_node_data(self) -> BaseNodeData: + def _get_base_node_data(self) -> BaseNodeData: return self._node_data @classmethod diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index cbb9738833..d436c7e706 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -99,22 +99,22 @@ class ParameterExtractorNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self._node_data = ParameterExtractorNodeData.model_validate(data) - def get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> Optional[ErrorStrategy]: return self._node_data.error_strategy - def get_retry_config(self) -> RetryConfig: + def _get_retry_config(self) -> RetryConfig: return self._node_data.retry_config - def get_title(self) -> str: + def _get_title(self) -> str: return self._node_data.title - def get_description(self) -> Optional[str]: + def _get_description(self) -> Optional[str]: return self._node_data.desc - def get_default_value_dict(self) -> dict[str, Any]: + def _get_default_value_dict(self) -> dict[str, Any]: return self._node_data.default_value_dict - def get_base_node_data(self) -> BaseNodeData: + def _get_base_node_data(self) -> BaseNodeData: return self._node_data _model_instance: Optional[ModelInstance] = None diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index ec8e6c10fd..0388065c25 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -86,22 +86,22 @@ class QuestionClassifierNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self._node_data = QuestionClassifierNodeData.model_validate(data) - def get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> Optional[ErrorStrategy]: return self._node_data.error_strategy - def get_retry_config(self) -> RetryConfig: + def _get_retry_config(self) -> RetryConfig: return self._node_data.retry_config - def get_title(self) -> str: + def _get_title(self) -> str: return self._node_data.title - def get_description(self) -> Optional[str]: + def _get_description(self) -> Optional[str]: return self._node_data.desc - def get_default_value_dict(self) -> dict[str, Any]: + def _get_default_value_dict(self) -> dict[str, Any]: return self._node_data.default_value_dict - def get_base_node_data(self) -> BaseNodeData: + def _get_base_node_data(self) -> BaseNodeData: return self._node_data @classmethod diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 7de764ae9a..1c9c6dcc0b 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -18,22 +18,22 @@ class StartNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self._node_data = StartNodeData(**data) - def get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> Optional[ErrorStrategy]: return self._node_data.error_strategy - def get_retry_config(self) -> RetryConfig: + def _get_retry_config(self) -> RetryConfig: return self._node_data.retry_config - def get_title(self) -> str: + def _get_title(self) -> str: return self._node_data.title - def get_description(self) -> Optional[str]: + def _get_description(self) -> Optional[str]: return self._node_data.desc - def get_default_value_dict(self) -> dict[str, Any]: + def _get_default_value_dict(self) -> dict[str, Any]: return self._node_data.default_value_dict - def get_base_node_data(self) -> BaseNodeData: + def _get_base_node_data(self) -> BaseNodeData: return self._node_data @classmethod diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index c7fbef02c5..366153ca0c 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -21,22 +21,22 @@ class TemplateTransformNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self._node_data = TemplateTransformNodeData.model_validate(data) - def get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> Optional[ErrorStrategy]: return self._node_data.error_strategy - def get_retry_config(self) -> RetryConfig: + def _get_retry_config(self) -> RetryConfig: return self._node_data.retry_config - def get_title(self) -> str: + def _get_title(self) -> str: return self._node_data.title - def get_description(self) -> Optional[str]: + def _get_description(self) -> Optional[str]: return self._node_data.desc - def get_default_value_dict(self) -> dict[str, Any]: + def _get_default_value_dict(self) -> dict[str, Any]: return self._node_data.default_value_dict - def get_base_node_data(self) -> BaseNodeData: + def _get_base_node_data(self) -> BaseNodeData: return self._node_data @classmethod diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 8942686510..1d82ebcbd2 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -403,22 +403,22 @@ class ToolNode(BaseNode): return result - def get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> Optional[ErrorStrategy]: return self._node_data.error_strategy - def get_retry_config(self) -> RetryConfig: + def _get_retry_config(self) -> RetryConfig: return self._node_data.retry_config - def get_title(self) -> str: + def _get_title(self) -> str: return self._node_data.title - def get_description(self) -> Optional[str]: + def _get_description(self) -> Optional[str]: return self._node_data.desc - def get_default_value_dict(self) -> dict[str, Any]: + def _get_default_value_dict(self) -> dict[str, Any]: return self._node_data.default_value_dict - def get_base_node_data(self) -> BaseNodeData: + def _get_base_node_data(self) -> BaseNodeData: return self._node_data @property diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index a98dd329e1..23e849c4df 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -18,22 +18,22 @@ class VariableAggregatorNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self._node_data = VariableAssignerNodeData(**data) - def get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> Optional[ErrorStrategy]: return self._node_data.error_strategy - def get_retry_config(self) -> RetryConfig: + def _get_retry_config(self) -> RetryConfig: return self._node_data.retry_config - def get_title(self) -> str: + def _get_title(self) -> str: return self._node_data.title - def get_description(self) -> Optional[str]: + def _get_description(self) -> Optional[str]: return self._node_data.desc - def get_default_value_dict(self) -> dict[str, Any]: + def _get_default_value_dict(self) -> dict[str, Any]: return self._node_data.default_value_dict - def get_base_node_data(self) -> BaseNodeData: + def _get_base_node_data(self) -> BaseNodeData: return self._node_data @classmethod diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index e6ecc3f936..3343c3d27c 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -32,22 +32,22 @@ class VariableAssignerNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self._node_data = VariableAssignerData.model_validate(data) - def get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> Optional[ErrorStrategy]: return self._node_data.error_strategy - def get_retry_config(self) -> RetryConfig: + def _get_retry_config(self) -> RetryConfig: return self._node_data.retry_config - def get_title(self) -> str: + def _get_title(self) -> str: return self._node_data.title - def get_description(self) -> Optional[str]: + def _get_description(self) -> Optional[str]: return self._node_data.desc - def get_default_value_dict(self) -> dict[str, Any]: + def _get_default_value_dict(self) -> dict[str, Any]: return self._node_data.default_value_dict - def get_base_node_data(self) -> BaseNodeData: + def _get_base_node_data(self) -> BaseNodeData: return self._node_data def __init__( diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py index 4c4cc8080d..7104eb874a 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -61,22 +61,22 @@ class VariableAssignerNode(BaseNode): def init_node_data(self, data: Mapping[str, Any]) -> None: self._node_data = VariableAssignerNodeData.model_validate(data) - def get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> Optional[ErrorStrategy]: return self._node_data.error_strategy - def get_retry_config(self) -> RetryConfig: + def _get_retry_config(self) -> RetryConfig: return self._node_data.retry_config - def get_title(self) -> str: + def _get_title(self) -> str: return self._node_data.title - def get_description(self) -> Optional[str]: + def _get_description(self) -> Optional[str]: return self._node_data.desc - def get_default_value_dict(self) -> dict[str, Any]: + def _get_default_value_dict(self) -> dict[str, Any]: return self._node_data.default_value_dict - def get_base_node_data(self) -> BaseNodeData: + def _get_base_node_data(self) -> BaseNodeData: return self._node_data def _conv_var_updater_factory(self) -> ConversationVariableUpdater: diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 89061c746b..9af5181626 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -146,7 +146,7 @@ class WorkflowEntry: graph = Graph.init(graph_config=workflow.graph_dict) # init workflow run state - node_instance = node_cls( + node = node_cls( id=str(uuid.uuid4()), config=node_config, graph_init_params=GraphInitParams( @@ -190,17 +190,11 @@ class WorkflowEntry: try: # run node - generator = node_instance.run() + generator = node.run() except Exception as e: - logger.exception( - "error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s", - workflow.id, - node_instance.id, - node_instance.node_type, - node_instance.version(), - ) - raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) - return node_instance, generator + logger.exception(f"error while running node, {workflow.id=}, {node.id=}, {node.type_=}, {node.version()=}") + raise WorkflowNodeRunFailedError(node=node, error=str(e)) + return node, generator @classmethod def run_free_node( @@ -262,7 +256,7 @@ class WorkflowEntry: node_cls = cast(type[BaseNode], node_cls) # init workflow run state - node_instance: BaseNode = node_cls( + node: BaseNode = node_cls( id=str(uuid.uuid4()), config=node_config, graph_init_params=GraphInitParams( @@ -297,17 +291,12 @@ class WorkflowEntry: ) # run node - generator = node_instance.run() + generator = node.run() - return node_instance, generator + return node, generator except Exception as e: - logger.exception( - "error while running node_instance, node_id=%s, type=%s, version=%s", - node_instance.id, - node_instance.node_type, - node_instance.version(), - ) - raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) + logger.exception(f"error while running node, {node.id=}, {node.type_=}, {node.version()=}") + raise WorkflowNodeRunFailedError(node=node, error=str(e)) @staticmethod def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None: diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 0d5e5fd6aa..a3de1c981f 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -465,10 +465,10 @@ class WorkflowService: node_id: str, ) -> WorkflowNodeExecution: try: - node_instance, generator = invoke_node_fn() + node, node_events = invoke_node_fn() node_run_result: NodeRunResult | None = None - for event in generator: + for event in node_events: if isinstance(event, RunCompletedEvent): node_run_result = event.run_result @@ -479,18 +479,18 @@ class WorkflowService: if not node_run_result: raise ValueError("Node run failed with no run result") # single step debug mode error handling return - if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.continue_on_error: + if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node.continue_on_error: node_error_args: dict[str, Any] = { "status": WorkflowNodeExecutionStatus.EXCEPTION, "error": node_run_result.error, "inputs": node_run_result.inputs, - "metadata": {"error_strategy": node_instance.error_strategy}, + "metadata": {"error_strategy": node.error_strategy}, } - if node_instance.error_strategy is ErrorStrategy.DEFAULT_VALUE: + if node.error_strategy is ErrorStrategy.DEFAULT_VALUE: node_run_result = NodeRunResult( **node_error_args, outputs={ - **node_instance.default_value_dict, + **node.default_value_dict, "error_message": node_run_result.error, "error_type": node_run_result.error_type, }, @@ -509,7 +509,7 @@ class WorkflowService: ) error = node_run_result.error if not run_succeeded else None except WorkflowNodeRunFailedError as e: - node_instance = e.node_instance + node = e.node run_succeeded = False node_run_result = None error = e.error @@ -520,8 +520,8 @@ class WorkflowService: workflow_id="", # This is a single-step execution, so no workflow ID index=1, node_id=node_id, - node_type=node_instance.node_type, - title=node_instance.node_title, + node_type=node.type_, + title=node.title, elapsed_time=time.perf_counter() - start_at, created_at=datetime.now(UTC).replace(tzinfo=None), finished_at=datetime.now(UTC).replace(tzinfo=None),