refactor(nodes): Update `_extract_variable_selector_to_variable_mapping` to receive a mapping node data.

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/22581/head
-LAN- 10 months ago
parent 3f143aea59
commit dc83c9822f
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -36,7 +36,6 @@ from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import AgentLogEvent from core.workflow.graph_engine.entities.event import AgentLogEvent
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.utils.variable_template_parser import VariableTemplateParser from core.workflow.utils.variable_template_parser import VariableTemplateParser
@ -363,19 +362,14 @@ class AgentNode(BaseNode):
*, *,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
node_id: str, node_id: str,
node_data: BaseNodeData, node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
""" # Create typed NodeData from dict
Extract variable selector to variable mapping typed_node_data = AgentNodeData(**node_data)
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
node_data = cast(AgentNodeData, node_data)
result: dict[str, Any] = {} result: dict[str, Any] = {}
for parameter_name in node_data.agent_parameters: for parameter_name in typed_node_data.agent_parameters:
input = node_data.agent_parameters[parameter_name] input = typed_node_data.agent_parameters[parameter_name]
if input.type in ["mixed", "constant"]: if input.type in ["mixed", "constant"]:
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors() selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
for selector in selectors: for selector in selectors:

@ -64,8 +64,11 @@ class AnswerNode(BaseNode):
*, *,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
node_id: str, node_id: str,
node_data: AnswerNodeData, node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = AnswerNodeData(**node_data)
""" """
Extract variable selector to variable mapping Extract variable selector to variable mapping
:param graph_config: graph config :param graph_config: graph config
@ -73,7 +76,7 @@ class AnswerNode(BaseNode):
:param node_data: node data :param node_data: node data
:return: :return:
""" """
variable_template_parser = VariableTemplateParser(template=node_data.answer) variable_template_parser = VariableTemplateParser(template=typed_node_data.answer)
variable_selectors = variable_template_parser.extract_variable_selectors() variable_selectors = variable_template_parser.extract_variable_selectors()
variable_mapping = {} variable_mapping = {}

@ -1,7 +1,7 @@
import logging import logging
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@ -123,9 +123,9 @@ class BaseNode:
if not node_id: if not node_id:
raise ValueError("Node ID is required when extracting variable selector to variable mapping.") raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
node_data = cls._node_data_cls(**config.get("data", {})) # Pass raw dict data instead of creating NodeData instance
data = cls._extract_variable_selector_to_variable_mapping( data = cls._extract_variable_selector_to_variable_mapping(
graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data) graph_config=graph_config, node_id=node_id, node_data=config.get("data", {})
) )
return data return data
@ -135,7 +135,7 @@ class BaseNode:
*, *,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
node_id: str, node_id: str,
node_data: Any, node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
""" """
Extract variable selector to variable mapping Extract variable selector to variable mapping

@ -338,8 +338,11 @@ class CodeNode(BaseNode):
*, *,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
node_id: str, node_id: str,
node_data: CodeNodeData, node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = CodeNodeData(**node_data)
""" """
Extract variable selector to variable mapping Extract variable selector to variable mapping
:param graph_config: graph config :param graph_config: graph config
@ -349,7 +352,7 @@ class CodeNode(BaseNode):
""" """
return { return {
node_id + "." + variable_selector.variable: variable_selector.value_selector node_id + "." + variable_selector.variable: variable_selector.value_selector
for variable_selector in node_data.variables for variable_selector in typed_node_data.variables
} }
@property @property

@ -101,16 +101,12 @@ class DocumentExtractorNode(BaseNode):
*, *,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
node_id: str, node_id: str,
node_data: DocumentExtractorNodeData, node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
""" # Create typed NodeData from dict
Extract variable selector to variable mapping typed_node_data = DocumentExtractorNodeData(**node_data)
:param graph_config: graph config
:param node_id: node id return {node_id + ".files": typed_node_data.variable_selector}
:param node_data: node data
:return:
"""
return {node_id + ".files": node_data.variable_selector}
def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str: def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:

@ -135,15 +135,18 @@ class HttpRequestNode(BaseNode):
*, *,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
node_id: str, node_id: str,
node_data: HttpRequestNodeData, node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = HttpRequestNodeData(**node_data)
selectors: list[VariableSelector] = [] selectors: list[VariableSelector] = []
selectors += variable_template_parser.extract_selectors_from_template(node_data.url) selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.url)
selectors += variable_template_parser.extract_selectors_from_template(node_data.headers) selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.headers)
selectors += variable_template_parser.extract_selectors_from_template(node_data.params) selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.params)
if node_data.body: if typed_node_data.body:
body_type = node_data.body.type body_type = typed_node_data.body.type
data = node_data.body.data data = typed_node_data.body.data
match body_type: match body_type:
case "binary": case "binary":
if len(data) != 1: if len(data) != 1:

@ -102,10 +102,13 @@ class IfElseNode(BaseNode):
*, *,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
node_id: str, node_id: str,
node_data: IfElseNodeData, node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = IfElseNodeData(**node_data)
var_mapping: dict[str, list[str]] = {} var_mapping: dict[str, list[str]] = {}
for case in node_data.cases or []: for case in typed_node_data.cases or []:
for condition in case.conditions: for condition in case.conditions:
key = "{}.#{}#".format(node_id, ".".join(condition.variable_selector)) key = "{}.#{}#".format(node_id, ".".join(condition.variable_selector))
var_mapping[key] = condition.variable_selector var_mapping[key] = condition.variable_selector

@ -309,21 +309,17 @@ class IterationNode(BaseNode):
*, *,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
node_id: str, node_id: str,
node_data: IterationNodeData, node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
""" # Create typed NodeData from dict
Extract variable selector to variable mapping typed_node_data = IterationNodeData(**node_data)
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
variable_mapping: dict[str, Sequence[str]] = { variable_mapping: dict[str, Sequence[str]] = {
f"{node_id}.input_selector": node_data.iterator_selector, f"{node_id}.input_selector": typed_node_data.iterator_selector,
} }
# init graph # init graph
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id) iteration_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id)
if not iteration_graph: if not iteration_graph:
raise IterationGraphNotFoundError("iteration graph not found") raise IterationGraphNotFoundError("iteration graph not found")

@ -619,17 +619,13 @@ class KnowledgeRetrievalNode(BaseNode):
*, *,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
node_id: str, node_id: str,
node_data: KnowledgeRetrievalNodeData, # type: ignore node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
""" # Create typed NodeData from dict
Extract variable selector to variable mapping typed_node_data = KnowledgeRetrievalNodeData(**node_data)
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
variable_mapping = {} variable_mapping = {}
variable_mapping[node_id + ".query"] = node_data.query_variable_selector variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector
return variable_mapping return variable_mapping
def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:

@ -792,7 +792,7 @@ class LLMNode(BaseNode):
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict # Create typed NodeData from dict
typed_node_data = LLMNodeData(**node_data) typed_node_data = LLMNodeData(**node_data)
prompt_template = typed_node_data.prompt_template prompt_template = typed_node_data.prompt_template
variable_selectors = [] variable_selectors = []
if isinstance(prompt_template, list) and all( if isinstance(prompt_template, list) and all(

@ -442,19 +442,15 @@ class LoopNode(BaseNode):
*, *,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
node_id: str, node_id: str,
node_data: LoopNodeData, node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
""" # Create typed NodeData from dict
Extract variable selector to variable mapping typed_node_data = LoopNodeData(**node_data)
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
variable_mapping = {} variable_mapping = {}
# init graph # init graph
loop_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id) loop_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id)
if not loop_graph: if not loop_graph:
raise ValueError("loop graph not found") raise ValueError("loop graph not found")
@ -490,7 +486,7 @@ class LoopNode(BaseNode):
variable_mapping.update(sub_node_variable_mapping) variable_mapping.update(sub_node_variable_mapping)
for loop_variable in node_data.loop_variables or []: for loop_variable in typed_node_data.loop_variables or []:
if loop_variable.value_type == "variable": if loop_variable.value_type == "variable":
assert loop_variable.value is not None, "Loop variable value must be provided for variable type" assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
# add loop variable to variable mapping # add loop variable to variable mapping

@ -830,19 +830,15 @@ class ParameterExtractorNode(BaseNode):
*, *,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
node_id: str, node_id: str,
node_data: ParameterExtractorNodeData, # type: ignore node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
""" # Create typed NodeData from dict
Extract variable selector to variable mapping typed_node_data = ParameterExtractorNodeData(**node_data)
:param graph_config: graph config
:param node_id: node id variable_mapping: dict[str, Sequence[str]] = {"query": typed_node_data.query}
:param node_data: node data
:return:
"""
variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query}
if node_data.instruction: if typed_node_data.instruction:
selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction) selectors = variable_template_parser.extract_selectors_from_template(typed_node_data.instruction)
for selector in selectors: for selector in selectors:
variable_mapping[selector.variable] = selector.value_selector variable_mapping[selector.variable] = selector.value_selector

@ -236,20 +236,15 @@ class QuestionClassifierNode(BaseNode):
*, *,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
node_id: str, node_id: str,
node_data: Any, node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
""" # Create typed NodeData from dict
Extract variable selector to variable mapping typed_node_data = QuestionClassifierNodeData(**node_data)
:param graph_config: graph config
:param node_id: node id variable_mapping = {"query": typed_node_data.query_variable_selector}
:param node_data: node data
:return:
"""
node_data = cast(QuestionClassifierNodeData, node_data)
variable_mapping = {"query": node_data.query_variable_selector}
variable_selectors = [] variable_selectors = []
if node_data.instruction: if typed_node_data.instruction:
variable_template_parser = VariableTemplateParser(template=node_data.instruction) variable_template_parser = VariableTemplateParser(template=typed_node_data.instruction)
variable_selectors.extend(variable_template_parser.extract_variable_selectors()) variable_selectors.extend(variable_template_parser.extract_variable_selectors())
for variable_selector in variable_selectors: for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector variable_mapping[variable_selector.variable] = variable_selector.value_selector

@ -64,16 +64,12 @@ class TemplateTransformNode(BaseNode):
@classmethod @classmethod
def _extract_variable_selector_to_variable_mapping( def _extract_variable_selector_to_variable_mapping(
cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: Mapping[str, Any]
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
""" # Create typed NodeData from dict
Extract variable selector to variable mapping typed_node_data = TemplateTransformNodeData(**node_data)
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
return { return {
node_id + "." + variable_selector.variable: variable_selector.value_selector node_id + "." + variable_selector.variable: variable_selector.value_selector
for variable_selector in node_data.variables for variable_selector in typed_node_data.variables
} }

@ -384,7 +384,7 @@ class ToolNode(BaseNode):
""" """
# Create typed NodeData from dict # Create typed NodeData from dict
typed_node_data = ToolNodeData(**node_data) typed_node_data = ToolNodeData(**node_data)
result = {} result = {}
for parameter_name in typed_node_data.tool_parameters: for parameter_name in typed_node_data.tool_parameters:
input = typed_node_data.tool_parameters[parameter_name] input = typed_node_data.tool_parameters[parameter_name]

@ -63,18 +63,21 @@ class VariableAssignerNode(BaseNode):
*, *,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
node_id: str, node_id: str,
node_data: VariableAssignerData, node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = VariableAssignerData(**node_data)
mapping = {} mapping = {}
assigned_variable_node_id = node_data.assigned_variable_selector[0] assigned_variable_node_id = typed_node_data.assigned_variable_selector[0]
if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID: if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID:
selector_key = ".".join(node_data.assigned_variable_selector) selector_key = ".".join(typed_node_data.assigned_variable_selector)
key = f"{node_id}.#{selector_key}#" key = f"{node_id}.#{selector_key}#"
mapping[key] = node_data.assigned_variable_selector mapping[key] = typed_node_data.assigned_variable_selector
selector_key = ".".join(node_data.input_variable_selector) selector_key = ".".join(typed_node_data.input_variable_selector)
key = f"{node_id}.#{selector_key}#" key = f"{node_id}.#{selector_key}#"
mapping[key] = node_data.input_variable_selector mapping[key] = typed_node_data.input_variable_selector
return mapping return mapping
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:

@ -75,10 +75,13 @@ class VariableAssignerNode(BaseNode):
*, *,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
node_id: str, node_id: str,
node_data: VariableAssignerNodeData, node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = VariableAssignerNodeData(**node_data)
var_mapping: dict[str, Sequence[str]] = {} var_mapping: dict[str, Sequence[str]] = {}
for item in node_data.items: for item in typed_node_data.items:
_target_mapping_from_item(var_mapping, node_id, item) _target_mapping_from_item(var_mapping, node_id, item)
_source_mapping_from_item(var_mapping, node_id, item) _source_mapping_from_item(var_mapping, node_id, item)
return var_mapping return var_mapping

Loading…
Cancel
Save