From ba40054f4476da180e77d164958bccb7bce19562 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 17 Jul 2025 23:29:11 +0800 Subject: [PATCH] refactor(nodes): rename `node_data` to `_node_data` Signed-off-by: -LAN- --- api/core/workflow/nodes/agent/agent_node.py | 30 ++++---- api/core/workflow/nodes/answer/answer_node.py | 18 ++--- api/core/workflow/nodes/code/code_node.py | 28 ++++---- .../workflow/nodes/document_extractor/node.py | 18 ++--- api/core/workflow/nodes/end/end_node.py | 18 ++--- api/core/workflow/nodes/http_request/node.py | 24 +++---- .../workflow/nodes/if_else/if_else_node.py | 24 +++---- .../nodes/iteration/iteration_node.py | 68 +++++++++---------- .../knowledge_retrieval_node.py | 20 +++--- api/core/workflow/nodes/list_operator/node.py | 42 ++++++------ api/core/workflow/nodes/llm/node.py | 58 ++++++++-------- api/core/workflow/nodes/loop/loop_end_node.py | 16 ++--- api/core/workflow/nodes/loop/loop_node.py | 50 +++++++------- .../workflow/nodes/loop/loop_start_node.py | 16 ++--- .../parameter_extractor_node.py | 18 ++--- .../question_classifier_node.py | 18 ++--- api/core/workflow/nodes/start/start_node.py | 16 ++--- .../template_transform_node.py | 20 +++--- api/core/workflow/nodes/tool/tool_node.py | 30 ++++---- .../variable_aggregator_node.py | 22 +++--- .../nodes/variable_assigner/v1/node.py | 26 +++---- .../nodes/variable_assigner/v2/node.py | 26 ++++--- .../workflow/nodes/test_code.py | 16 ++--- .../nodes/iteration/test_iteration.py | 6 +- 24 files changed, 313 insertions(+), 315 deletions(-) diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 5b0ff24df6..118cc4b657 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -63,28 +63,28 @@ class AgentNode(BaseNode): """ _node_type = NodeType.AGENT - node_data: AgentNodeData + _node_data: AgentNodeData def init_node_data(self, data: Mapping[str, Any]) -> None: - self.node_data = AgentNodeData.model_validate(data) + self._node_data = AgentNodeData.model_validate(data) def get_error_strategy(self) -> Optional[ErrorStrategy]: - return self.node_data.error_strategy + return self._node_data.error_strategy def get_retry_config(self) -> RetryConfig: - return self.node_data.retry_config + return self._node_data.retry_config def get_title(self) -> str: - return self.node_data.title + return self._node_data.title def get_description(self) -> Optional[str]: - return self.node_data.desc + return self._node_data.desc def get_default_value_dict(self) -> dict[str, Any]: - return self.node_data.default_value_dict + return self._node_data.default_value_dict def get_base_node_data(self) -> BaseNodeData: - return self.node_data + return self._node_data @classmethod def version(cls) -> str: @@ -94,7 +94,7 @@ class AgentNode(BaseNode): """ Run the agent node """ - node_data = cast(AgentNodeData, self.node_data) + node_data = cast(AgentNodeData, self._node_data) try: strategy = get_plugin_agent_strategy( @@ -160,18 +160,18 @@ class AgentNode(BaseNode): type=ToolInvokeMessage.MessageType.LOG, message=ToolInvokeMessage.LogMessage( id=str(uuid.uuid4()), - label=f"Agent Strategy: {cast(AgentNodeData, self.node_data).agent_strategy_name}", + label=f"Agent Strategy: {cast(AgentNodeData, self._node_data).agent_strategy_name}", parent_id=None, error=None, status=ToolInvokeMessage.LogMessage.LogStatus.START, data={ - "strategy": cast(AgentNodeData, self.node_data).agent_strategy_name, + "strategy": cast(AgentNodeData, self._node_data).agent_strategy_name, "parameters": parameters_for_log, "thought_process": "Agent strategy execution started", }, metadata={ "icon": self.agent_strategy_icon, - "agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name, + "agent_strategy": cast(AgentNodeData, self._node_data).agent_strategy_name, }, ), ) @@ -180,7 +180,7 @@ class AgentNode(BaseNode): messages=message_stream, tool_info={ "icon": self.agent_strategy_icon, - "agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name, + "agent_strategy": cast(AgentNodeData, self._node_data).agent_strategy_name, }, parameters_for_log=parameters_for_log, user_id=self.user_id, @@ -299,7 +299,7 @@ class AgentNode(BaseNode): ) extra = tool.get("extra", {}) - runtime_variable_pool = variable_pool if self.node_data.version != "1" else None + runtime_variable_pool = variable_pool if self._node_data.version != "1" else None tool_runtime = ToolManager.get_agent_tool_runtime( self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool ) @@ -415,7 +415,7 @@ class AgentNode(BaseNode): plugin for plugin in plugins if f"{plugin.plugin_id}/{plugin.name}" - == cast(AgentNodeData, self.node_data).agent_strategy_provider_name + == cast(AgentNodeData, self._node_data).agent_strategy_provider_name ) icon = current_plugin.declaration.icon except StopIteration: diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index 2eb6d34495..754b0121cb 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -20,28 +20,28 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser class AnswerNode(BaseNode): _node_type = NodeType.ANSWER - node_data: AnswerNodeData + _node_data: AnswerNodeData def init_node_data(self, data: Mapping[str, Any]) -> None: - self.node_data = AnswerNodeData.model_validate(data) + self._node_data = AnswerNodeData.model_validate(data) def get_error_strategy(self) -> Optional[ErrorStrategy]: - return self.node_data.error_strategy + return self._node_data.error_strategy def get_retry_config(self) -> RetryConfig: - return self.node_data.retry_config + return self._node_data.retry_config def get_title(self) -> str: - return self.node_data.title + return self._node_data.title def get_description(self) -> Optional[str]: - return self.node_data.desc + return self._node_data.desc def get_default_value_dict(self) -> dict[str, Any]: - return self.node_data.default_value_dict + return self._node_data.default_value_dict def get_base_node_data(self) -> BaseNodeData: - return self.node_data + return self._node_data @classmethod def version(cls) -> str: @@ -53,7 +53,7 @@ class AnswerNode(BaseNode): :return: """ # generate routes - generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self.node_data) + generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self._node_data) answer = "" files = [] diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 4953017a84..41d6202d38 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -25,28 +25,28 @@ from .exc import ( class CodeNode(BaseNode): _node_type = NodeType.CODE - node_data: CodeNodeData + _node_data: CodeNodeData def init_node_data(self, data: Mapping[str, Any]) -> None: - self.node_data = CodeNodeData.model_validate(data) + self._node_data = CodeNodeData.model_validate(data) def get_error_strategy(self) -> Optional[ErrorStrategy]: - return self.node_data.error_strategy + return self._node_data.error_strategy def get_retry_config(self) -> RetryConfig: - return self.node_data.retry_config + return self._node_data.retry_config def get_title(self) -> str: - return self.node_data.title + return self._node_data.title def get_description(self) -> Optional[str]: - return self.node_data.desc + return self._node_data.desc def get_default_value_dict(self) -> dict[str, Any]: - return self.node_data.default_value_dict + return self._node_data.default_value_dict def get_base_node_data(self) -> BaseNodeData: - return self.node_data + return self._node_data @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: @@ -70,12 +70,12 @@ class CodeNode(BaseNode): def _run(self) -> NodeRunResult: # Get code language - code_language = self.node_data.code_language - code = self.node_data.code + code_language = self._node_data.code_language + code = self._node_data.code # Get variables variables = {} - for variable_selector in self.node_data.variables: + for variable_selector in self._node_data.variables: variable_name = variable_selector.variable variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) if isinstance(variable, ArrayFileSegment): @@ -91,7 +91,7 @@ class CodeNode(BaseNode): ) # Transform result - result = self._transform_result(result=result, output_schema=self.node_data.outputs) + result = self._transform_result(result=result, output_schema=self._node_data.outputs) except (CodeExecutionError, CodeNodeError) as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__ @@ -369,8 +369,8 @@ class CodeNode(BaseNode): @property def continue_on_error(self) -> bool: - return self.node_data.error_strategy is not None + return self._node_data.error_strategy is not None @property def retry(self) -> bool: - return self.node_data.retry_config.retry_enabled + return self._node_data.retry_config.retry_enabled diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index fe530dbb47..53f7e1969e 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -45,35 +45,35 @@ class DocumentExtractorNode(BaseNode): _node_type = NodeType.DOCUMENT_EXTRACTOR - node_data: DocumentExtractorNodeData + _node_data: DocumentExtractorNodeData def init_node_data(self, data: Mapping[str, Any]) -> None: - self.node_data = DocumentExtractorNodeData.model_validate(data) + self._node_data = DocumentExtractorNodeData.model_validate(data) def get_error_strategy(self) -> Optional[ErrorStrategy]: - return self.node_data.error_strategy + return self._node_data.error_strategy def get_retry_config(self) -> RetryConfig: - return self.node_data.retry_config + return self._node_data.retry_config def get_title(self) -> str: - return self.node_data.title + return self._node_data.title def get_description(self) -> Optional[str]: - return self.node_data.desc + return self._node_data.desc def get_default_value_dict(self) -> dict[str, Any]: - return self.node_data.default_value_dict + return self._node_data.default_value_dict def get_base_node_data(self) -> BaseNodeData: - return self.node_data + return self._node_data @classmethod def version(cls) -> str: return "1" def _run(self): - variable_selector = self.node_data.variable_selector + variable_selector = self._node_data.variable_selector variable = self.graph_runtime_state.variable_pool.get(variable_selector) if variable is None: diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 0a84716b3d..06bf393cc5 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -12,28 +12,28 @@ from core.workflow.nodes.enums import ErrorStrategy, NodeType class EndNode(BaseNode): _node_type = NodeType.END - node_data: EndNodeData + _node_data: EndNodeData def init_node_data(self, data: Mapping[str, Any]) -> None: - self.node_data = EndNodeData(**data) + self._node_data = EndNodeData(**data) def get_error_strategy(self) -> Optional[ErrorStrategy]: - return self.node_data.error_strategy + return self._node_data.error_strategy def get_retry_config(self) -> RetryConfig: - return self.node_data.retry_config + return self._node_data.retry_config def get_title(self) -> str: - return self.node_data.title + return self._node_data.title def get_description(self) -> Optional[str]: - return self.node_data.desc + return self._node_data.desc def get_default_value_dict(self) -> dict[str, Any]: - return self.node_data.default_value_dict + return self._node_data.default_value_dict def get_base_node_data(self) -> BaseNodeData: - return self.node_data + return self._node_data @classmethod def version(cls) -> str: @@ -44,7 +44,7 @@ class EndNode(BaseNode): Run node :return: """ - output_variables = self.node_data.outputs + output_variables = self._node_data.outputs outputs = {} for variable_selector in output_variables: diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index 95f4ad60fc..4bd4dde03f 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -36,28 +36,28 @@ logger = logging.getLogger(__name__) class HttpRequestNode(BaseNode): _node_type = NodeType.HTTP_REQUEST - node_data: HttpRequestNodeData + _node_data: HttpRequestNodeData def init_node_data(self, data: Mapping[str, Any]) -> None: - self.node_data = HttpRequestNodeData.model_validate(data) + self._node_data = HttpRequestNodeData.model_validate(data) def get_error_strategy(self) -> Optional[ErrorStrategy]: - return self.node_data.error_strategy + return self._node_data.error_strategy def get_retry_config(self) -> RetryConfig: - return self.node_data.retry_config + return self._node_data.retry_config def get_title(self) -> str: - return self.node_data.title + return self._node_data.title def get_description(self) -> Optional[str]: - return self.node_data.desc + return self._node_data.desc def get_default_value_dict(self) -> dict[str, Any]: - return self.node_data.default_value_dict + return self._node_data.default_value_dict def get_base_node_data(self) -> BaseNodeData: - return self.node_data + return self._node_data @classmethod def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict: @@ -92,8 +92,8 @@ class HttpRequestNode(BaseNode): process_data = {} try: http_executor = Executor( - node_data=self.node_data, - timeout=self._get_request_timeout(self.node_data), + node_data=self._node_data, + timeout=self._get_request_timeout(self._node_data), variable_pool=self.graph_runtime_state.variable_pool, max_retries=0, ) @@ -246,8 +246,8 @@ class HttpRequestNode(BaseNode): @property def continue_on_error(self) -> bool: - return self.node_data.error_strategy is not None + return self._node_data.error_strategy is not None @property def retry(self) -> bool: - return self.node_data.retry_config.retry_enabled + return self._node_data.retry_config.retry_enabled diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index d8bde610f3..03769c866e 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -17,28 +17,28 @@ from core.workflow.utils.condition.processor import ConditionProcessor class IfElseNode(BaseNode): _node_type = NodeType.IF_ELSE - node_data: IfElseNodeData + _node_data: IfElseNodeData def init_node_data(self, data: Mapping[str, Any]) -> None: - self.node_data = IfElseNodeData.model_validate(data) + self._node_data = IfElseNodeData.model_validate(data) def get_error_strategy(self) -> Optional[ErrorStrategy]: - return self.node_data.error_strategy + return self._node_data.error_strategy def get_retry_config(self) -> RetryConfig: - return self.node_data.retry_config + return self._node_data.retry_config def get_title(self) -> str: - return self.node_data.title + return self._node_data.title def get_description(self) -> Optional[str]: - return self.node_data.desc + return self._node_data.desc def get_default_value_dict(self) -> dict[str, Any]: - return self.node_data.default_value_dict + return self._node_data.default_value_dict def get_base_node_data(self) -> BaseNodeData: - return self.node_data + return self._node_data @classmethod def version(cls) -> str: @@ -59,8 +59,8 @@ class IfElseNode(BaseNode): condition_processor = ConditionProcessor() try: # Check if the new cases structure is used - if self.node_data.cases: - for case in self.node_data.cases: + if self._node_data.cases: + for case in self._node_data.cases: input_conditions, group_result, final_result = condition_processor.process_conditions( variable_pool=self.graph_runtime_state.variable_pool, conditions=case.conditions, @@ -86,8 +86,8 @@ class IfElseNode(BaseNode): input_conditions, group_result, final_result = _should_not_use_old_function( condition_processor=condition_processor, variable_pool=self.graph_runtime_state.variable_pool, - conditions=self.node_data.conditions or [], - operator=self.node_data.logical_operator or "and", + conditions=self._node_data.conditions or [], + operator=self._node_data.logical_operator or "and", ) selected_case_id = "true" if final_result else "false" diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 59e561de4f..cf6b3eaeb7 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -64,28 +64,28 @@ class IterationNode(BaseNode): _node_type = NodeType.ITERATION - node_data: IterationNodeData + _node_data: IterationNodeData def init_node_data(self, data: Mapping[str, Any]) -> None: - self.node_data = IterationNodeData.model_validate(data) + self._node_data = IterationNodeData.model_validate(data) def get_error_strategy(self) -> Optional[ErrorStrategy]: - return self.node_data.error_strategy + return self._node_data.error_strategy def get_retry_config(self) -> RetryConfig: - return self.node_data.retry_config + return self._node_data.retry_config def get_title(self) -> str: - return self.node_data.title + return self._node_data.title def get_description(self) -> Optional[str]: - return self.node_data.desc + return self._node_data.desc def get_default_value_dict(self) -> dict[str, Any]: - return self.node_data.default_value_dict + return self._node_data.default_value_dict def get_base_node_data(self) -> BaseNodeData: - return self.node_data + return self._node_data @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: @@ -106,10 +106,10 @@ class IterationNode(BaseNode): """ Run the node. """ - variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector) + variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector) if not variable: - raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found") + raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found") if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable): raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.") @@ -139,10 +139,10 @@ class IterationNode(BaseNode): graph_config = self.graph_config - if not self.node_data.start_node_id: + if not self._node_data.start_node_id: raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self.node_id} not found") - root_node_id = self.node_data.start_node_id + root_node_id = self._node_data.start_node_id # init graph iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id) @@ -185,7 +185,7 @@ class IterationNode(BaseNode): iteration_id=self.id, iteration_node_id=self.node_id, iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_data=self._node_data, start_at=start_at, inputs=inputs, metadata={"iterator_length": len(iterator_list_value)}, @@ -196,7 +196,7 @@ class IterationNode(BaseNode): iteration_id=self.id, iteration_node_id=self.node_id, iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_data=self._node_data, index=0, pre_iteration_output=None, duration=None, @@ -204,11 +204,11 @@ class IterationNode(BaseNode): iter_run_map: dict[str, float] = {} outputs: list[Any] = [None] * len(iterator_list_value) try: - if self.node_data.is_parallel: + if self._node_data.is_parallel: futures: list[Future] = [] q: Queue = Queue() thread_pool = GraphEngineThreadPool( - max_workers=self.node_data.parallel_nums, max_submit_count=dify_config.MAX_SUBMIT_COUNT + max_workers=self._node_data.parallel_nums, max_submit_count=dify_config.MAX_SUBMIT_COUNT ) for index, item in enumerate(iterator_list_value): future: Future = thread_pool.submit( @@ -265,7 +265,7 @@ class IterationNode(BaseNode): iteration_graph=iteration_graph, iter_run_map=iter_run_map, ) - if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: + if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: outputs = [output for output in outputs if output is not None] # Flatten the list of lists @@ -277,7 +277,7 @@ class IterationNode(BaseNode): iteration_id=self.id, iteration_node_id=self.node_id, iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_data=self._node_data, start_at=start_at, inputs=inputs, outputs={"output": outputs}, @@ -302,7 +302,7 @@ class IterationNode(BaseNode): iteration_id=self.id, iteration_node_id=self.node_id, iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_data=self._node_data, start_at=start_at, inputs=inputs, outputs={"output": outputs}, @@ -394,7 +394,7 @@ class IterationNode(BaseNode): """ if not isinstance(event, BaseNodeEvent): return event - if self.node_data.is_parallel and isinstance(event, NodeRunStartedEvent): + if self._node_data.is_parallel and isinstance(event, NodeRunStartedEvent): event.parallel_mode_run_id = parallel_mode_run_id iter_metadata = { @@ -457,12 +457,12 @@ class IterationNode(BaseNode): elif isinstance(event, BaseGraphEvent): if isinstance(event, GraphRunFailedEvent): # iteration run failed - if self.node_data.is_parallel: + if self._node_data.is_parallel: yield IterationRunFailedEvent( iteration_id=self.id, iteration_node_id=self.node_id, iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_data=self._node_data, parallel_mode_run_id=parallel_mode_run_id, start_at=start_at, inputs=inputs, @@ -476,7 +476,7 @@ class IterationNode(BaseNode): iteration_id=self.id, iteration_node_id=self.node_id, iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_data=self._node_data, start_at=start_at, inputs=inputs, outputs={"output": outputs}, @@ -497,7 +497,7 @@ class IterationNode(BaseNode): event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id ) if isinstance(event, NodeRunFailedEvent): - if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR: + if self._node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR: yield NodeInIterationFailedEvent( **metadata_event.model_dump(), ) @@ -511,14 +511,14 @@ class IterationNode(BaseNode): iteration_id=self.id, iteration_node_id=self.node_id, iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_data=self._node_data, index=next_index, parallel_mode_run_id=parallel_mode_run_id, pre_iteration_output=None, duration=duration, ) return - elif self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: + elif self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: yield NodeInIterationFailedEvent( **metadata_event.model_dump(), ) @@ -532,14 +532,14 @@ class IterationNode(BaseNode): iteration_id=self.id, iteration_node_id=self.node_id, iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_data=self._node_data, index=next_index, parallel_mode_run_id=parallel_mode_run_id, pre_iteration_output=None, duration=duration, ) return - elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED: + elif self._node_data.error_handle_mode == ErrorHandleMode.TERMINATED: yield NodeInIterationFailedEvent( **metadata_event.model_dump(), ) @@ -550,12 +550,12 @@ class IterationNode(BaseNode): variable_pool.remove([node_id]) # iteration run failed - if self.node_data.is_parallel: + if self._node_data.is_parallel: yield IterationRunFailedEvent( iteration_id=self.id, iteration_node_id=self.node_id, iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_data=self._node_data, parallel_mode_run_id=parallel_mode_run_id, start_at=start_at, inputs=inputs, @@ -569,7 +569,7 @@ class IterationNode(BaseNode): iteration_id=self.id, iteration_node_id=self.node_id, iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_data=self._node_data, start_at=start_at, inputs=inputs, outputs={"output": outputs}, @@ -588,7 +588,7 @@ class IterationNode(BaseNode): return yield metadata_event - current_output_segment = variable_pool.get(self.node_data.output_selector) + current_output_segment = variable_pool.get(self._node_data.output_selector) if current_output_segment is None: raise IterationNodeError("iteration output selector not found") current_iteration_output = current_output_segment.value @@ -608,7 +608,7 @@ class IterationNode(BaseNode): iteration_id=self.id, iteration_node_id=self.node_id, iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_data=self._node_data, index=next_index, parallel_mode_run_id=parallel_mode_run_id, pre_iteration_output=current_iteration_output or None, @@ -621,7 +621,7 @@ class IterationNode(BaseNode): iteration_id=self.id, iteration_node_id=self.node_id, iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_data=self._node_data, start_at=start_at, inputs=inputs, outputs={"output": None}, diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 5bb6804ef5..323b47cd40 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -86,7 +86,7 @@ default_retrieval_model = { class KnowledgeRetrievalNode(BaseNode): _node_type = NodeType.KNOWLEDGE_RETRIEVAL - node_data: KnowledgeRetrievalNodeData + _node_data: KnowledgeRetrievalNodeData # Instance attributes specific to LLMNode. # Output variable for file @@ -126,32 +126,32 @@ class KnowledgeRetrievalNode(BaseNode): self._llm_file_saver = llm_file_saver def init_node_data(self, data: Mapping[str, Any]) -> None: - self.node_data = KnowledgeRetrievalNodeData.model_validate(data) + self._node_data = KnowledgeRetrievalNodeData.model_validate(data) def get_error_strategy(self) -> Optional[ErrorStrategy]: - return self.node_data.error_strategy + return self._node_data.error_strategy def get_retry_config(self) -> RetryConfig: - return self.node_data.retry_config + return self._node_data.retry_config def get_title(self) -> str: - return self.node_data.title + return self._node_data.title def get_description(self) -> Optional[str]: - return self.node_data.desc + return self._node_data.desc def get_default_value_dict(self) -> dict[str, Any]: - return self.node_data.default_value_dict + return self._node_data.default_value_dict def get_base_node_data(self) -> BaseNodeData: - return self.node_data + return self._node_data @classmethod def version(cls): return "1" def _run(self) -> NodeRunResult: # type: ignore - node_data = cast(KnowledgeRetrievalNodeData, self.node_data) + node_data = cast(KnowledgeRetrievalNodeData, self._node_data) # extract variables variable = self.graph_runtime_state.variable_pool.get(node_data.query_variable_selector) if not isinstance(variable, StringSegment): @@ -545,7 +545,7 @@ class KnowledgeRetrievalNode(BaseNode): prompt_messages=prompt_messages, stop=stop, user_id=self.user_id, - structured_output_enabled=self.node_data.structured_output_enabled, + structured_output_enabled=self._node_data.structured_output_enabled, structured_output=None, file_saver=self._llm_file_saver, file_outputs=self._file_outputs, diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index 1d28d01c2f..e64f5dc047 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -17,28 +17,28 @@ from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError class ListOperatorNode(BaseNode): _node_type = NodeType.LIST_OPERATOR - node_data: ListOperatorNodeData + _node_data: ListOperatorNodeData def init_node_data(self, data: Mapping[str, Any]) -> None: - self.node_data = ListOperatorNodeData(**data) + self._node_data = ListOperatorNodeData(**data) def get_error_strategy(self) -> Optional[ErrorStrategy]: - return self.node_data.error_strategy + return self._node_data.error_strategy def get_retry_config(self) -> RetryConfig: - return self.node_data.retry_config + return self._node_data.retry_config def get_title(self) -> str: - return self.node_data.title + return self._node_data.title def get_description(self) -> Optional[str]: - return self.node_data.desc + return self._node_data.desc def get_default_value_dict(self) -> dict[str, Any]: - return self.node_data.default_value_dict + return self._node_data.default_value_dict def get_base_node_data(self) -> BaseNodeData: - return self.node_data + return self._node_data @classmethod def version(cls) -> str: @@ -49,9 +49,9 @@ class ListOperatorNode(BaseNode): process_data: dict[str, list] = {} outputs: dict[str, Any] = {} - variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable) + variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable) if variable is None: - error_message = f"Variable not found for selector: {self.node_data.variable}" + error_message = f"Variable not found for selector: {self._node_data.variable}" return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs ) @@ -71,7 +71,7 @@ class ListOperatorNode(BaseNode): ) if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment): error_message = ( - f"Variable {self.node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment " + f"Variable {self._node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment " "or ArrayStringSegment" ) return NodeRunResult( @@ -87,19 +87,19 @@ class ListOperatorNode(BaseNode): try: # Filter - if self.node_data.filter_by.enabled: + if self._node_data.filter_by.enabled: variable = self._apply_filter(variable) # Extract - if self.node_data.extract_by.enabled: + if self._node_data.extract_by.enabled: variable = self._extract_slice(variable) # Order - if self.node_data.order_by.enabled: + if self._node_data.order_by.enabled: variable = self._apply_order(variable) # Slice - if self.node_data.limit.enabled: + if self._node_data.limit.enabled: variable = self._apply_slice(variable) outputs = { @@ -127,7 +127,7 @@ class ListOperatorNode(BaseNode): ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: filter_func: Callable[[Any], bool] result: list[Any] = [] - for condition in self.node_data.filter_by.conditions: + for condition in self._node_data.filter_by.conditions: if isinstance(variable, ArrayStringSegment): if not isinstance(condition.value, str): raise InvalidFilterValueError(f"Invalid filter value: {condition.value}") @@ -160,14 +160,14 @@ class ListOperatorNode(BaseNode): self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: if isinstance(variable, ArrayStringSegment): - result = _order_string(order=self.node_data.order_by.value, array=variable.value) + result = _order_string(order=self._node_data.order_by.value, array=variable.value) variable = variable.model_copy(update={"value": result}) elif isinstance(variable, ArrayNumberSegment): - result = _order_number(order=self.node_data.order_by.value, array=variable.value) + result = _order_number(order=self._node_data.order_by.value, array=variable.value) variable = variable.model_copy(update={"value": result}) elif isinstance(variable, ArrayFileSegment): result = _order_file( - order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value + order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value ) variable = variable.model_copy(update={"value": result}) return variable @@ -175,13 +175,13 @@ class ListOperatorNode(BaseNode): def _apply_slice( self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: - result = variable.value[: self.node_data.limit.size] + result = variable.value[: self._node_data.limit.size] return variable.model_copy(update={"value": result}) def _extract_slice( self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: - value = int(self.graph_runtime_state.variable_pool.convert_template(self.node_data.extract_by.serial).text) + value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text) if value < 1: raise ValueError(f"Invalid serial index: must be >= 1, got {value}") value -= 1 diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index f78328501c..3559e3a00b 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -99,7 +99,7 @@ logger = logging.getLogger(__name__) class LLMNode(BaseNode): _node_type = NodeType.LLM - node_data: LLMNodeData + _node_data: LLMNodeData # Instance attributes specific to LLMNode. # Output variable for file @@ -139,25 +139,25 @@ class LLMNode(BaseNode): self._llm_file_saver = llm_file_saver def init_node_data(self, data: Mapping[str, Any]) -> None: - self.node_data = LLMNodeData.model_validate(data) + self._node_data = LLMNodeData.model_validate(data) def get_error_strategy(self) -> Optional[ErrorStrategy]: - return self.node_data.error_strategy + return self._node_data.error_strategy def get_retry_config(self) -> RetryConfig: - return self.node_data.retry_config + return self._node_data.retry_config def get_title(self) -> str: - return self.node_data.title + return self._node_data.title def get_description(self) -> Optional[str]: - return self.node_data.desc + return self._node_data.desc def get_default_value_dict(self) -> dict[str, Any]: - return self.node_data.default_value_dict + return self._node_data.default_value_dict def get_base_node_data(self) -> BaseNodeData: - return self.node_data + return self._node_data @classmethod def version(cls) -> str: @@ -173,13 +173,13 @@ class LLMNode(BaseNode): try: # init messages template - self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template) + self._node_data.prompt_template = self._transform_chat_messages(self._node_data.prompt_template) # fetch variables and fetch values from variable pool - inputs = self._fetch_inputs(node_data=self.node_data) + inputs = self._fetch_inputs(node_data=self._node_data) # fetch jinja2 inputs - jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data) + jinja_inputs = self._fetch_jinja_inputs(node_data=self._node_data) # merge inputs inputs.update(jinja_inputs) @@ -190,9 +190,9 @@ class LLMNode(BaseNode): files = ( llm_utils.fetch_files( variable_pool=variable_pool, - selector=self.node_data.vision.configs.variable_selector, + selector=self._node_data.vision.configs.variable_selector, ) - if self.node_data.vision.enabled + if self._node_data.vision.enabled else [] ) @@ -200,7 +200,7 @@ class LLMNode(BaseNode): node_inputs["#files#"] = [file.to_dict() for file in files] # fetch context value - generator = self._fetch_context(node_data=self.node_data) + generator = self._fetch_context(node_data=self._node_data) context = None for event in generator: if isinstance(event, RunRetrieverResourceEvent): @@ -211,7 +211,7 @@ class LLMNode(BaseNode): # fetch model config model_instance, model_config = LLMNode._fetch_model_config( - node_data_model=self.node_data.model, + node_data_model=self._node_data.model, tenant_id=self.tenant_id, ) @@ -219,13 +219,13 @@ class LLMNode(BaseNode): memory = llm_utils.fetch_memory( variable_pool=variable_pool, app_id=self.app_id, - node_data_memory=self.node_data.memory, + node_data_memory=self._node_data.memory, model_instance=model_instance, ) query = None - if self.node_data.memory: - query = self.node_data.memory.query_prompt_template + if self._node_data.memory: + query = self._node_data.memory.query_prompt_template if not query and ( query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) ): @@ -237,24 +237,24 @@ class LLMNode(BaseNode): context=context, memory=memory, model_config=model_config, - prompt_template=self.node_data.prompt_template, - memory_config=self.node_data.memory, - vision_enabled=self.node_data.vision.enabled, - vision_detail=self.node_data.vision.configs.detail, + prompt_template=self._node_data.prompt_template, + memory_config=self._node_data.memory, + vision_enabled=self._node_data.vision.enabled, + vision_detail=self._node_data.vision.configs.detail, variable_pool=variable_pool, - jinja2_variables=self.node_data.prompt_config.jinja2_variables, + jinja2_variables=self._node_data.prompt_config.jinja2_variables, tenant_id=self.tenant_id, ) # handle invoke result generator = LLMNode.invoke_llm( - node_data_model=self.node_data.model, + node_data_model=self._node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, user_id=self.user_id, - structured_output_enabled=self.node_data.structured_output_enabled, - structured_output=self.node_data.structured_output, + structured_output_enabled=self._node_data.structured_output_enabled, + structured_output=self._node_data.structured_output, file_saver=self._llm_file_saver, file_outputs=self._file_outputs, node_id=self.node_id, @@ -1010,7 +1010,7 @@ class LLMNode(BaseNode): """ Fetch model schema """ - model_name = self.node_data.model.name + model_name = self._node_data.model.name model_manager = ModelManager() model_instance = model_manager.get_model_instance( tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name @@ -1089,11 +1089,11 @@ class LLMNode(BaseNode): @property def continue_on_error(self) -> bool: - return self.node_data.error_strategy is not None + return self._node_data.error_strategy is not None @property def retry(self) -> bool: - return self.node_data.retry_config.retry_enabled + return self._node_data.retry_config.retry_enabled def _combine_message_content_with_role( diff --git a/api/core/workflow/nodes/loop/loop_end_node.py b/api/core/workflow/nodes/loop/loop_end_node.py index 1190d3ec2d..3a93621aa4 100644 --- a/api/core/workflow/nodes/loop/loop_end_node.py +++ b/api/core/workflow/nodes/loop/loop_end_node.py @@ -16,28 +16,28 @@ class LoopEndNode(BaseNode): _node_type = NodeType.LOOP_END - node_data: LoopEndNodeData + _node_data: LoopEndNodeData def init_node_data(self, data: Mapping[str, Any]) -> None: - self.node_data = LoopEndNodeData(**data) + self._node_data = LoopEndNodeData(**data) def get_error_strategy(self) -> Optional[ErrorStrategy]: - return self.node_data.error_strategy + return self._node_data.error_strategy def get_retry_config(self) -> RetryConfig: - return self.node_data.retry_config + return self._node_data.retry_config def get_title(self) -> str: - return self.node_data.title + return self._node_data.title def get_description(self) -> Optional[str]: - return self.node_data.desc + return self._node_data.desc def get_default_value_dict(self) -> dict[str, Any]: - return self.node_data.default_value_dict + return self._node_data.default_value_dict def get_base_node_data(self) -> BaseNodeData: - return self.node_data + return self._node_data @classmethod def version(cls) -> str: diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index e6b1866fdf..2fa8924b07 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -51,28 +51,28 @@ class LoopNode(BaseNode): _node_type = NodeType.LOOP - node_data: LoopNodeData + _node_data: LoopNodeData def init_node_data(self, data: Mapping[str, Any]) -> None: - self.node_data = LoopNodeData.model_validate(data) + self._node_data = LoopNodeData.model_validate(data) def get_error_strategy(self) -> Optional[ErrorStrategy]: - return self.node_data.error_strategy + return self._node_data.error_strategy def get_retry_config(self) -> RetryConfig: - return self.node_data.retry_config + return self._node_data.retry_config def get_title(self) -> str: - return self.node_data.title + return self._node_data.title def get_description(self) -> Optional[str]: - return self.node_data.desc + return self._node_data.desc def get_default_value_dict(self) -> dict[str, Any]: - return self.node_data.default_value_dict + return self._node_data.default_value_dict def get_base_node_data(self) -> BaseNodeData: - return self.node_data + return self._node_data @classmethod def version(cls) -> str: @@ -81,17 +81,17 @@ class LoopNode(BaseNode): def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: """Run the node.""" # Get inputs - loop_count = self.node_data.loop_count - break_conditions = self.node_data.break_conditions - logical_operator = self.node_data.logical_operator + loop_count = self._node_data.loop_count + break_conditions = self._node_data.break_conditions + logical_operator = self._node_data.logical_operator inputs = {"loop_count": loop_count} - if not self.node_data.start_node_id: + if not self._node_data.start_node_id: raise ValueError(f"field start_node_id in loop {self.node_id} not found") # Initialize graph - loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self.node_data.start_node_id) + loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self._node_data.start_node_id) if not loop_graph: raise ValueError("loop graph not found") @@ -101,8 +101,8 @@ class LoopNode(BaseNode): # Initialize loop variables loop_variable_selectors = {} - if self.node_data.loop_variables: - for loop_variable in self.node_data.loop_variables: + if self._node_data.loop_variables: + for loop_variable in self._node_data.loop_variables: value_processor = { "constant": lambda var=loop_variable: self._get_segment_for_constant(var.var_type, var.value), "variable": lambda var=loop_variable: variable_pool.get(var.value), @@ -151,7 +151,7 @@ class LoopNode(BaseNode): loop_id=self.id, loop_node_id=self.node_id, loop_node_type=self.node_type, - loop_node_data=self.node_data, + loop_node_data=self._node_data, start_at=start_at, inputs=inputs, metadata={"loop_length": loop_count}, @@ -208,10 +208,10 @@ class LoopNode(BaseNode): loop_id=self.id, loop_node_id=self.node_id, loop_node_type=self.node_type, - loop_node_data=self.node_data, + loop_node_data=self._node_data, start_at=start_at, inputs=inputs, - outputs=self.node_data.outputs, + outputs=self._node_data.outputs, steps=loop_count, metadata={ WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, @@ -229,7 +229,7 @@ class LoopNode(BaseNode): WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, }, - outputs=self.node_data.outputs, + outputs=self._node_data.outputs, inputs=inputs, ) ) @@ -241,7 +241,7 @@ class LoopNode(BaseNode): loop_id=self.id, loop_node_id=self.node_id, loop_node_type=self.node_type, - loop_node_data=self.node_data, + loop_node_data=self._node_data, start_at=start_at, inputs=inputs, steps=loop_count, @@ -344,7 +344,7 @@ class LoopNode(BaseNode): loop_id=self.id, loop_node_id=self.node_id, loop_node_type=self.node_type, - loop_node_data=self.node_data, + loop_node_data=self._node_data, start_at=start_at, inputs=inputs, steps=current_index, @@ -375,7 +375,7 @@ class LoopNode(BaseNode): loop_id=self.id, loop_node_id=self.node_id, loop_node_type=self.node_type, - loop_node_data=self.node_data, + loop_node_data=self._node_data, start_at=start_at, inputs=inputs, steps=current_index, @@ -411,7 +411,7 @@ class LoopNode(BaseNode): _outputs[loop_variable_key] = None _outputs["loop_round"] = current_index + 1 - self.node_data.outputs = _outputs + self._node_data.outputs = _outputs if check_break_result: return {"check_break_result": True} @@ -424,9 +424,9 @@ class LoopNode(BaseNode): loop_id=self.id, loop_node_id=self.node_id, loop_node_type=self.node_type, - loop_node_data=self.node_data, + loop_node_data=self._node_data, index=next_index, - pre_loop_output=self.node_data.outputs, + pre_loop_output=self._node_data.outputs, ) return {"check_break_result": False} diff --git a/api/core/workflow/nodes/loop/loop_start_node.py b/api/core/workflow/nodes/loop/loop_start_node.py index 0268dcf543..52afff0c02 100644 --- a/api/core/workflow/nodes/loop/loop_start_node.py +++ b/api/core/workflow/nodes/loop/loop_start_node.py @@ -16,28 +16,28 @@ class LoopStartNode(BaseNode): _node_type = NodeType.LOOP_START - node_data: LoopStartNodeData + _node_data: LoopStartNodeData def init_node_data(self, data: Mapping[str, Any]) -> None: - self.node_data = LoopStartNodeData(**data) + self._node_data = LoopStartNodeData(**data) def get_error_strategy(self) -> Optional[ErrorStrategy]: - return self.node_data.error_strategy + return self._node_data.error_strategy def get_retry_config(self) -> RetryConfig: - return self.node_data.retry_config + return self._node_data.retry_config def get_title(self) -> str: - return self.node_data.title + return self._node_data.title def get_description(self) -> Optional[str]: - return self.node_data.desc + return self._node_data.desc def get_default_value_dict(self) -> dict[str, Any]: - return self.node_data.default_value_dict + return self._node_data.default_value_dict def get_base_node_data(self) -> BaseNodeData: - return self.node_data + return self._node_data @classmethod def version(cls) -> str: diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 98041121fb..cbb9738833 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -94,28 +94,28 @@ class ParameterExtractorNode(BaseNode): _node_type = NodeType.PARAMETER_EXTRACTOR - node_data: ParameterExtractorNodeData + _node_data: ParameterExtractorNodeData def init_node_data(self, data: Mapping[str, Any]) -> None: - self.node_data = ParameterExtractorNodeData.model_validate(data) + self._node_data = ParameterExtractorNodeData.model_validate(data) def get_error_strategy(self) -> Optional[ErrorStrategy]: - return self.node_data.error_strategy + return self._node_data.error_strategy def get_retry_config(self) -> RetryConfig: - return self.node_data.retry_config + return self._node_data.retry_config def get_title(self) -> str: - return self.node_data.title + return self._node_data.title def get_description(self) -> Optional[str]: - return self.node_data.desc + return self._node_data.desc def get_default_value_dict(self) -> dict[str, Any]: - return self.node_data.default_value_dict + return self._node_data.default_value_dict def get_base_node_data(self) -> BaseNodeData: - return self.node_data + return self._node_data _model_instance: Optional[ModelInstance] = None _model_config: Optional[ModelConfigWithCredentialsEntity] = None @@ -141,7 +141,7 @@ class ParameterExtractorNode(BaseNode): """ Run the node. """ - node_data = cast(ParameterExtractorNodeData, self.node_data) + node_data = cast(ParameterExtractorNodeData, self._node_data) variable = self.graph_runtime_state.variable_pool.get(node_data.query) query = variable.text if variable else "" diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 54e10ba966..ec8e6c10fd 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -47,7 +47,7 @@ if TYPE_CHECKING: class QuestionClassifierNode(BaseNode): _node_type = NodeType.QUESTION_CLASSIFIER - node_data: QuestionClassifierNodeData + _node_data: QuestionClassifierNodeData _file_outputs: list["File"] _llm_file_saver: LLMFileSaver @@ -84,32 +84,32 @@ class QuestionClassifierNode(BaseNode): self._llm_file_saver = llm_file_saver def init_node_data(self, data: Mapping[str, Any]) -> None: - self.node_data = QuestionClassifierNodeData.model_validate(data) + self._node_data = QuestionClassifierNodeData.model_validate(data) def get_error_strategy(self) -> Optional[ErrorStrategy]: - return self.node_data.error_strategy + return self._node_data.error_strategy def get_retry_config(self) -> RetryConfig: - return self.node_data.retry_config + return self._node_data.retry_config def get_title(self) -> str: - return self.node_data.title + return self._node_data.title def get_description(self) -> Optional[str]: - return self.node_data.desc + return self._node_data.desc def get_default_value_dict(self) -> dict[str, Any]: - return self.node_data.default_value_dict + return self._node_data.default_value_dict def get_base_node_data(self) -> BaseNodeData: - return self.node_data + return self._node_data @classmethod def version(cls): return "1" def _run(self): - node_data = cast(QuestionClassifierNodeData, self.node_data) + node_data = cast(QuestionClassifierNodeData, self._node_data) variable_pool = self.graph_runtime_state.variable_pool # extract variables diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 068950c112..7de764ae9a 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -13,28 +13,28 @@ from core.workflow.nodes.start.entities import StartNodeData class StartNode(BaseNode): _node_type = NodeType.START - node_data: StartNodeData + _node_data: StartNodeData def init_node_data(self, data: Mapping[str, Any]) -> None: - self.node_data = StartNodeData(**data) + self._node_data = StartNodeData(**data) def get_error_strategy(self) -> Optional[ErrorStrategy]: - return self.node_data.error_strategy + return self._node_data.error_strategy def get_retry_config(self) -> RetryConfig: - return self.node_data.retry_config + return self._node_data.retry_config def get_title(self) -> str: - return self.node_data.title + return self._node_data.title def get_description(self) -> Optional[str]: - return self.node_data.desc + return self._node_data.desc def get_default_value_dict(self) -> dict[str, Any]: - return self.node_data.default_value_dict + return self._node_data.default_value_dict def get_base_node_data(self) -> BaseNodeData: - return self.node_data + return self._node_data @classmethod def version(cls) -> str: diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index a75b84ac01..c7fbef02c5 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -16,28 +16,28 @@ MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MA class TemplateTransformNode(BaseNode): _node_type = NodeType.TEMPLATE_TRANSFORM - node_data: TemplateTransformNodeData + _node_data: TemplateTransformNodeData def init_node_data(self, data: Mapping[str, Any]) -> None: - self.node_data = TemplateTransformNodeData.model_validate(data) + self._node_data = TemplateTransformNodeData.model_validate(data) def get_error_strategy(self) -> Optional[ErrorStrategy]: - return self.node_data.error_strategy + return self._node_data.error_strategy def get_retry_config(self) -> RetryConfig: - return self.node_data.retry_config + return self._node_data.retry_config def get_title(self) -> str: - return self.node_data.title + return self._node_data.title def get_description(self) -> Optional[str]: - return self.node_data.desc + return self._node_data.desc def get_default_value_dict(self) -> dict[str, Any]: - return self.node_data.default_value_dict + return self._node_data.default_value_dict def get_base_node_data(self) -> BaseNodeData: - return self.node_data + return self._node_data @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: @@ -58,14 +58,14 @@ class TemplateTransformNode(BaseNode): def _run(self) -> NodeRunResult: # Get variables variables = {} - for variable_selector in self.node_data.variables: + for variable_selector in self._node_data.variables: variable_name = variable_selector.variable value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) variables[variable_name] = value.to_object() if value else None # Run code try: result = CodeExecutor.execute_workflow_code_template( - language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables + language=CodeLanguage.JINJA2, code=self._node_data.template, inputs=variables ) except CodeExecutionError as e: return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e)) diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index ad5df08b53..8942686510 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -43,10 +43,10 @@ class ToolNode(BaseNode): _node_type = NodeType.TOOL - node_data: ToolNodeData + _node_data: ToolNodeData def init_node_data(self, data: Mapping[str, Any]) -> None: - self.node_data = ToolNodeData.model_validate(data) + self._node_data = ToolNodeData.model_validate(data) @classmethod def version(cls) -> str: @@ -57,7 +57,7 @@ class ToolNode(BaseNode): Run the tool node """ - node_data = cast(ToolNodeData, self.node_data) + node_data = cast(ToolNodeData, self._node_data) # fetch tool icon tool_info = { @@ -70,9 +70,9 @@ class ToolNode(BaseNode): try: from core.tools.tool_manager import ToolManager - variable_pool = self.graph_runtime_state.variable_pool if self.node_data.version != "1" else None + variable_pool = self.graph_runtime_state.variable_pool if self._node_data.version != "1" else None tool_runtime = ToolManager.get_workflow_tool_runtime( - self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from, variable_pool + self.tenant_id, self.app_id, self.node_id, self._node_data, self.invoke_from, variable_pool ) except ToolNodeError as e: yield RunCompletedEvent( @@ -91,12 +91,12 @@ class ToolNode(BaseNode): parameters = self._generate_parameters( tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, - node_data=self.node_data, + node_data=self._node_data, ) parameters_for_log = self._generate_parameters( tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, - node_data=self.node_data, + node_data=self._node_data, for_log=True, ) # get conversation id @@ -404,27 +404,27 @@ class ToolNode(BaseNode): return result def get_error_strategy(self) -> Optional[ErrorStrategy]: - return self.node_data.error_strategy + return self._node_data.error_strategy def get_retry_config(self) -> RetryConfig: - return self.node_data.retry_config + return self._node_data.retry_config def get_title(self) -> str: - return self.node_data.title + return self._node_data.title def get_description(self) -> Optional[str]: - return self.node_data.desc + return self._node_data.desc def get_default_value_dict(self) -> dict[str, Any]: - return self.node_data.default_value_dict + return self._node_data.default_value_dict def get_base_node_data(self) -> BaseNodeData: - return self.node_data + return self._node_data @property def continue_on_error(self) -> bool: - return self.node_data.error_strategy is not None + return self._node_data.error_strategy is not None @property def retry(self) -> bool: - return self.node_data.retry_config.retry_enabled + return self._node_data.retry_config.retry_enabled diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index a8ade81020..a98dd329e1 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -13,28 +13,28 @@ from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNod class VariableAggregatorNode(BaseNode): _node_type = NodeType.VARIABLE_AGGREGATOR - node_data: VariableAssignerNodeData + _node_data: VariableAssignerNodeData def init_node_data(self, data: Mapping[str, Any]) -> None: - self.node_data = VariableAssignerNodeData(**data) + self._node_data = VariableAssignerNodeData(**data) def get_error_strategy(self) -> Optional[ErrorStrategy]: - return self.node_data.error_strategy + return self._node_data.error_strategy def get_retry_config(self) -> RetryConfig: - return self.node_data.retry_config + return self._node_data.retry_config def get_title(self) -> str: - return self.node_data.title + return self._node_data.title def get_description(self) -> Optional[str]: - return self.node_data.desc + return self._node_data.desc def get_default_value_dict(self) -> dict[str, Any]: - return self.node_data.default_value_dict + return self._node_data.default_value_dict def get_base_node_data(self) -> BaseNodeData: - return self.node_data + return self._node_data @classmethod def version(cls) -> str: @@ -45,8 +45,8 @@ class VariableAggregatorNode(BaseNode): outputs: dict[str, Segment | Mapping[str, Segment]] = {} inputs = {} - if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled: - for selector in self.node_data.variables: + if not self._node_data.advanced_settings or not self._node_data.advanced_settings.group_enabled: + for selector in self._node_data.variables: variable = self.graph_runtime_state.variable_pool.get(selector) if variable is not None: outputs = {"output": variable} @@ -54,7 +54,7 @@ class VariableAggregatorNode(BaseNode): inputs = {".".join(selector[1:]): variable.to_object()} break else: - for group in self.node_data.advanced_settings.groups: + for group in self._node_data.advanced_settings.groups: for selector in group.variables: variable = self.graph_runtime_state.variable_pool.get(selector) diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index 74a3e6ec42..e6ecc3f936 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -27,28 +27,28 @@ class VariableAssignerNode(BaseNode): _node_type = NodeType.VARIABLE_ASSIGNER _conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY - node_data: VariableAssignerData + _node_data: VariableAssignerData def init_node_data(self, data: Mapping[str, Any]) -> None: - self.node_data = VariableAssignerData.model_validate(data) + self._node_data = VariableAssignerData.model_validate(data) def get_error_strategy(self) -> Optional[ErrorStrategy]: - return self.node_data.error_strategy + return self._node_data.error_strategy def get_retry_config(self) -> RetryConfig: - return self.node_data.retry_config + return self._node_data.retry_config def get_title(self) -> str: - return self.node_data.title + return self._node_data.title def get_description(self) -> Optional[str]: - return self.node_data.desc + return self._node_data.desc def get_default_value_dict(self) -> dict[str, Any]: - return self.node_data.default_value_dict + return self._node_data.default_value_dict def get_base_node_data(self) -> BaseNodeData: - return self.node_data + return self._node_data def __init__( self, @@ -100,21 +100,21 @@ class VariableAssignerNode(BaseNode): return mapping def _run(self) -> NodeRunResult: - assigned_variable_selector = self.node_data.assigned_variable_selector + assigned_variable_selector = self._node_data.assigned_variable_selector # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector) if not isinstance(original_variable, Variable): raise VariableOperatorNodeError("assigned variable not found") - match self.node_data.write_mode: + match self._node_data.write_mode: case WriteMode.OVER_WRITE: - income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) + income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector) if not income_value: raise VariableOperatorNodeError("input value not found") updated_variable = original_variable.model_copy(update={"value": income_value.value}) case WriteMode.APPEND: - income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) + income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector) if not income_value: raise VariableOperatorNodeError("input value not found") updated_value = original_variable.value + [income_value.value] @@ -127,7 +127,7 @@ class VariableAssignerNode(BaseNode): updated_variable = original_variable.model_copy(update={"value": income_value.to_object()}) case _: - raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}") + raise VariableOperatorNodeError(f"unsupported write mode: {self._node_data.write_mode}") # Over write the variable. self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable) diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py index 167079db28..4c4cc8080d 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -1,6 +1,6 @@ import json -from collections.abc import Callable, Mapping, MutableMapping, Sequence -from typing import Any, Optional, TypeAlias, cast +from collections.abc import Mapping, MutableMapping, Sequence +from typing import Any, Optional, cast from core.app.entities.app_invoke_entities import InvokeFrom from core.variables import SegmentType, Variable @@ -29,8 +29,6 @@ from .exc import ( VariableNotFoundError, ) -_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater] - def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): selector_node_id = item.variable_selector[0] @@ -58,28 +56,28 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_ class VariableAssignerNode(BaseNode): _node_type = NodeType.VARIABLE_ASSIGNER - node_data: VariableAssignerNodeData + _node_data: VariableAssignerNodeData def init_node_data(self, data: Mapping[str, Any]) -> None: - self.node_data = VariableAssignerNodeData.model_validate(data) + self._node_data = VariableAssignerNodeData.model_validate(data) def get_error_strategy(self) -> Optional[ErrorStrategy]: - return self.node_data.error_strategy + return self._node_data.error_strategy def get_retry_config(self) -> RetryConfig: - return self.node_data.retry_config + return self._node_data.retry_config def get_title(self) -> str: - return self.node_data.title + return self._node_data.title def get_description(self) -> Optional[str]: - return self.node_data.desc + return self._node_data.desc def get_default_value_dict(self) -> dict[str, Any]: - return self.node_data.default_value_dict + return self._node_data.default_value_dict def get_base_node_data(self) -> BaseNodeData: - return self.node_data + return self._node_data def _conv_var_updater_factory(self) -> ConversationVariableUpdater: return conversation_variable_updater_factory() @@ -106,13 +104,13 @@ class VariableAssignerNode(BaseNode): return var_mapping def _run(self) -> NodeRunResult: - inputs = self.node_data.model_dump() + inputs = self._node_data.model_dump() process_data: dict[str, Any] = {} # NOTE: This node has no outputs updated_variable_selectors: list[Sequence[str]] = [] try: - for item in self.node_data.items: + for item in self._node_data.items: variable = self.graph_runtime_state.variable_pool.get(item.variable_selector) # ==================== Validation Part diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index daab974775..9d17307ecf 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -234,10 +234,10 @@ def test_execute_code_output_validator_depth(): "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, } - node.node_data = cast(CodeNodeData, node.node_data) + node._node_data = cast(CodeNodeData, node._node_data) # validate - node._transform_result(result, node.node_data.outputs) + node._transform_result(result, node._node_data.outputs) # construct result result = { @@ -250,7 +250,7 @@ def test_execute_code_output_validator_depth(): # validate with pytest.raises(ValueError): - node._transform_result(result, node.node_data.outputs) + node._transform_result(result, node._node_data.outputs) # construct result result = { @@ -263,7 +263,7 @@ def test_execute_code_output_validator_depth(): # validate with pytest.raises(ValueError): - node._transform_result(result, node.node_data.outputs) + node._transform_result(result, node._node_data.outputs) # construct result result = { @@ -276,7 +276,7 @@ def test_execute_code_output_validator_depth(): # validate with pytest.raises(ValueError): - node._transform_result(result, node.node_data.outputs) + node._transform_result(result, node._node_data.outputs) def test_execute_code_output_object_list(): @@ -330,10 +330,10 @@ def test_execute_code_output_object_list(): ] } - node.node_data = cast(CodeNodeData, node.node_data) + node._node_data = cast(CodeNodeData, node._node_data) # validate - node._transform_result(result, node.node_data.outputs) + node._transform_result(result, node._node_data.outputs) # construct result result = { @@ -353,7 +353,7 @@ def test_execute_code_output_object_list(): # validate with pytest.raises(ValueError): - node._transform_result(result, node.node_data.outputs) + node._transform_result(result, node._node_data.outputs) def test_execute_code_scientific_notation(): diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py index 787d4cb3ee..f53f391433 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py @@ -665,8 +665,8 @@ def test_iteration_run_in_parallel_mode(): # execute node parallel_result = parallel_iteration_node._run() sequential_result = sequential_iteration_node._run() - assert parallel_iteration_node.node_data.parallel_nums == 10 - assert parallel_iteration_node.node_data.error_handle_mode == ErrorHandleMode.TERMINATED + assert parallel_iteration_node._node_data.parallel_nums == 10 + assert parallel_iteration_node._node_data.error_handle_mode == ErrorHandleMode.TERMINATED count = 0 parallel_arr = [] sequential_arr = [] @@ -876,7 +876,7 @@ def test_iteration_run_error_handle(): assert count == 14 # execute remove abnormal output - iteration_node.node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT + iteration_node._node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT result = iteration_node._run() count = 0 for item in result: