From dc83c9822fb99e9d2d2a2a0cfa758992b4a414f3 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 17 Jul 2025 19:59:07 +0800 Subject: [PATCH] refactor(nodes): Update `_extract_variable_selector_to_variable_mapping` to receive a mapping node data. Signed-off-by: -LAN- --- api/core/workflow/nodes/agent/agent_node.py | 18 ++++++------------ api/core/workflow/nodes/answer/answer_node.py | 7 +++++-- api/core/workflow/nodes/base/node.py | 8 ++++---- api/core/workflow/nodes/code/code_node.py | 7 +++++-- .../workflow/nodes/document_extractor/node.py | 14 +++++--------- api/core/workflow/nodes/http_request/node.py | 17 ++++++++++------- .../workflow/nodes/if_else/if_else_node.py | 7 +++++-- .../nodes/iteration/iteration_node.py | 16 ++++++---------- .../knowledge_retrieval_node.py | 14 +++++--------- api/core/workflow/nodes/llm/node.py | 2 +- api/core/workflow/nodes/loop/loop_node.py | 16 ++++++---------- .../parameter_extractor_node.py | 18 +++++++----------- .../question_classifier_node.py | 19 +++++++------------ .../template_transform_node.py | 14 +++++--------- api/core/workflow/nodes/tool/tool_node.py | 2 +- .../nodes/variable_assigner/v1/node.py | 15 +++++++++------ .../nodes/variable_assigner/v2/node.py | 7 +++++-- 17 files changed, 92 insertions(+), 109 deletions(-) diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index fc8e69f9f1..e5527b7294 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -36,7 +36,6 @@ from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import AgentLogEvent from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.base.entities import BaseNodeData from core.workflow.nodes.enums import NodeType from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent from core.workflow.utils.variable_template_parser import VariableTemplateParser @@ -363,19 +362,14 @@ class AgentNode(BaseNode): *, graph_config: Mapping[str, Any], node_id: str, - node_data: BaseNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ - node_data = cast(AgentNodeData, node_data) + # Create typed NodeData from dict + typed_node_data = AgentNodeData(**node_data) + result: dict[str, Any] = {} - for parameter_name in node_data.agent_parameters: - input = node_data.agent_parameters[parameter_name] + for parameter_name in typed_node_data.agent_parameters: + input = typed_node_data.agent_parameters[parameter_name] if input.type in ["mixed", "constant"]: selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors() for selector in selectors: diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index 1ae5fddbc7..63dfb2d5db 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -64,8 +64,11 @@ class AnswerNode(BaseNode): *, graph_config: Mapping[str, Any], node_id: str, - node_data: AnswerNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: + # Create typed NodeData from dict + typed_node_data = AnswerNodeData(**node_data) + """ Extract variable selector to variable mapping :param graph_config: graph config @@ -73,7 +76,7 @@ class AnswerNode(BaseNode): :param node_data: node data :return: """ - variable_template_parser = VariableTemplateParser(template=node_data.answer) + variable_template_parser = VariableTemplateParser(template=typed_node_data.answer) variable_selectors = variable_template_parser.extract_variable_selectors() variable_mapping = {} diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index 8bb67f9840..b039accba4 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -1,7 +1,7 @@ import logging from abc import abstractmethod from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast +from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus @@ -123,9 +123,9 @@ class BaseNode: if not node_id: raise ValueError("Node ID is required when extracting variable selector to variable mapping.") - node_data = cls._node_data_cls(**config.get("data", {})) + # Pass raw dict data instead of creating NodeData instance data = cls._extract_variable_selector_to_variable_mapping( - graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data) + graph_config=graph_config, node_id=node_id, node_data=config.get("data", {}) ) return data @@ -135,7 +135,7 @@ class BaseNode: *, graph_config: Mapping[str, Any], node_id: str, - node_data: Any, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 17458d21f3..ad57a13bf8 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -338,8 +338,11 @@ class CodeNode(BaseNode): *, graph_config: Mapping[str, Any], node_id: str, - node_data: CodeNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: + # Create typed NodeData from dict + typed_node_data = CodeNodeData(**node_data) + """ Extract variable selector to variable mapping :param graph_config: graph config @@ -349,7 +352,7 @@ class CodeNode(BaseNode): """ return { node_id + "." + variable_selector.variable: variable_selector.value_selector - for variable_selector in node_data.variables + for variable_selector in typed_node_data.variables } @property diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index f02ea5c033..f0e188c528 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -101,16 +101,12 @@ class DocumentExtractorNode(BaseNode): *, graph_config: Mapping[str, Any], node_id: str, - node_data: DocumentExtractorNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ - return {node_id + ".files": node_data.variable_selector} + # Create typed NodeData from dict + typed_node_data = DocumentExtractorNodeData(**node_data) + + return {node_id + ".files": typed_node_data.variable_selector} def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str: diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index a5e1ef22d7..d4917b79c1 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -135,15 +135,18 @@ class HttpRequestNode(BaseNode): *, graph_config: Mapping[str, Any], node_id: str, - node_data: HttpRequestNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: + # Create typed NodeData from dict + typed_node_data = HttpRequestNodeData(**node_data) + selectors: list[VariableSelector] = [] - selectors += variable_template_parser.extract_selectors_from_template(node_data.url) - selectors += variable_template_parser.extract_selectors_from_template(node_data.headers) - selectors += variable_template_parser.extract_selectors_from_template(node_data.params) - if node_data.body: - body_type = node_data.body.type - data = node_data.body.data + selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.url) + selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.headers) + selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.params) + if typed_node_data.body: + body_type = typed_node_data.body.type + data = typed_node_data.body.data match body_type: case "binary": if len(data) != 1: diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index 6cbe7aeb84..91d98188ab 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -102,10 +102,13 @@ class IfElseNode(BaseNode): *, graph_config: Mapping[str, Any], node_id: str, - node_data: IfElseNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: + # Create typed NodeData from dict + typed_node_data = IfElseNodeData(**node_data) + var_mapping: dict[str, list[str]] = {} - for case in node_data.cases or []: + for case in typed_node_data.cases or []: for condition in case.conditions: key = "{}.#{}#".format(node_id, ".".join(condition.variable_selector)) var_mapping[key] = condition.variable_selector diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 7942316816..b41b6c5de3 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -309,21 +309,17 @@ class IterationNode(BaseNode): *, graph_config: Mapping[str, Any], node_id: str, - node_data: IterationNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ + # Create typed NodeData from dict + typed_node_data = IterationNodeData(**node_data) + variable_mapping: dict[str, Sequence[str]] = { - f"{node_id}.input_selector": node_data.iterator_selector, + f"{node_id}.input_selector": typed_node_data.iterator_selector, } # init graph - iteration_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id) + iteration_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id) if not iteration_graph: raise IterationGraphNotFoundError("iteration graph not found") diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index bb19005f57..e977314b34 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -619,17 +619,13 @@ class KnowledgeRetrievalNode(BaseNode): *, graph_config: Mapping[str, Any], node_id: str, - node_data: KnowledgeRetrievalNodeData, # type: ignore + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ + # Create typed NodeData from dict + typed_node_data = KnowledgeRetrievalNodeData(**node_data) + variable_mapping = {} - variable_mapping[node_id + ".query"] = node_data.query_variable_selector + variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector return variable_mapping def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index e08a207a61..8175e73525 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -792,7 +792,7 @@ class LLMNode(BaseNode): ) -> Mapping[str, Sequence[str]]: # Create typed NodeData from dict typed_node_data = LLMNodeData(**node_data) - + prompt_template = typed_node_data.prompt_template variable_selectors = [] if isinstance(prompt_template, list) and all( diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 269f402b5f..e12e4a3a3e 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -442,19 +442,15 @@ class LoopNode(BaseNode): *, graph_config: Mapping[str, Any], node_id: str, - node_data: LoopNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ + # Create typed NodeData from dict + typed_node_data = LoopNodeData(**node_data) + variable_mapping = {} # init graph - loop_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id) + loop_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id) if not loop_graph: raise ValueError("loop graph not found") @@ -490,7 +486,7 @@ class LoopNode(BaseNode): variable_mapping.update(sub_node_variable_mapping) - for loop_variable in node_data.loop_variables or []: + for loop_variable in typed_node_data.loop_variables or []: if loop_variable.value_type == "variable": assert loop_variable.value is not None, "Loop variable value must be provided for variable type" # add loop variable to variable mapping diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 6229b1bad0..240c7c4d76 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -830,19 +830,15 @@ class ParameterExtractorNode(BaseNode): *, graph_config: Mapping[str, Any], node_id: str, - node_data: ParameterExtractorNodeData, # type: ignore + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ - variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query} + # Create typed NodeData from dict + typed_node_data = ParameterExtractorNodeData(**node_data) + + variable_mapping: dict[str, Sequence[str]] = {"query": typed_node_data.query} - if node_data.instruction: - selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction) + if typed_node_data.instruction: + selectors = variable_template_parser.extract_selectors_from_template(typed_node_data.instruction) for selector in selectors: variable_mapping[selector.variable] = selector.value_selector diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index b59187f279..318bb5969d 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -236,20 +236,15 @@ class QuestionClassifierNode(BaseNode): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Any, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ - node_data = cast(QuestionClassifierNodeData, node_data) - variable_mapping = {"query": node_data.query_variable_selector} + # Create typed NodeData from dict + typed_node_data = QuestionClassifierNodeData(**node_data) + + variable_mapping = {"query": typed_node_data.query_variable_selector} variable_selectors = [] - if node_data.instruction: - variable_template_parser = VariableTemplateParser(template=node_data.instruction) + if typed_node_data.instruction: + variable_template_parser = VariableTemplateParser(template=typed_node_data.instruction) variable_selectors.extend(variable_template_parser.extract_variable_selectors()) for variable_selector in variable_selectors: variable_mapping[variable_selector.variable] = variable_selector.value_selector diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index df0692ca3f..7a25c8a588 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -64,16 +64,12 @@ class TemplateTransformNode(BaseNode): @classmethod def _extract_variable_selector_to_variable_mapping( - cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData + cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: Mapping[str, Any] ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ + # Create typed NodeData from dict + typed_node_data = TemplateTransformNodeData(**node_data) + return { node_id + "." + variable_selector.variable: variable_selector.value_selector - for variable_selector in node_data.variables + for variable_selector in typed_node_data.variables } diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index a43267b0fd..ce10ec6452 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -384,7 +384,7 @@ class ToolNode(BaseNode): """ # Create typed NodeData from dict typed_node_data = ToolNodeData(**node_data) - + result = {} for parameter_name in typed_node_data.tool_parameters: input = typed_node_data.tool_parameters[parameter_name] diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index 31db41497b..809018519e 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -63,18 +63,21 @@ class VariableAssignerNode(BaseNode): *, graph_config: Mapping[str, Any], node_id: str, - node_data: VariableAssignerData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: + # Create typed NodeData from dict + typed_node_data = VariableAssignerData(**node_data) + mapping = {} - assigned_variable_node_id = node_data.assigned_variable_selector[0] + assigned_variable_node_id = typed_node_data.assigned_variable_selector[0] if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID: - selector_key = ".".join(node_data.assigned_variable_selector) + selector_key = ".".join(typed_node_data.assigned_variable_selector) key = f"{node_id}.#{selector_key}#" - mapping[key] = node_data.assigned_variable_selector + mapping[key] = typed_node_data.assigned_variable_selector - selector_key = ".".join(node_data.input_variable_selector) + selector_key = ".".join(typed_node_data.input_variable_selector) key = f"{node_id}.#{selector_key}#" - mapping[key] = node_data.input_variable_selector + mapping[key] = typed_node_data.input_variable_selector return mapping def _run(self) -> NodeRunResult: diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py index f367ae664b..8f18924f88 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -75,10 +75,13 @@ class VariableAssignerNode(BaseNode): *, graph_config: Mapping[str, Any], node_id: str, - node_data: VariableAssignerNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: + # Create typed NodeData from dict + typed_node_data = VariableAssignerNodeData(**node_data) + var_mapping: dict[str, Sequence[str]] = {} - for item in node_data.items: + for item in typed_node_data.items: _target_mapping_from_item(var_mapping, node_id, item) _source_mapping_from_item(var_mapping, node_id, item) return var_mapping