refactor(nodes): Update `_extract_variable_selector_to_variable_mapping` to receive a mapping node data.

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/22581/head
-LAN- 10 months ago
parent 3f143aea59
commit dc83c9822f
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -36,7 +36,6 @@ from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import AgentLogEvent
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.utils.variable_template_parser import VariableTemplateParser
@ -363,19 +362,14 @@ class AgentNode(BaseNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: BaseNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
node_data = cast(AgentNodeData, node_data)
# Create typed NodeData from dict
typed_node_data = AgentNodeData(**node_data)
result: dict[str, Any] = {}
for parameter_name in node_data.agent_parameters:
input = node_data.agent_parameters[parameter_name]
for parameter_name in typed_node_data.agent_parameters:
input = typed_node_data.agent_parameters[parameter_name]
if input.type in ["mixed", "constant"]:
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
for selector in selectors:

@ -64,8 +64,11 @@ class AnswerNode(BaseNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: AnswerNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = AnswerNodeData(**node_data)
"""
Extract variable selector to variable mapping
:param graph_config: graph config
@ -73,7 +76,7 @@ class AnswerNode(BaseNode):
:param node_data: node data
:return:
"""
variable_template_parser = VariableTemplateParser(template=node_data.answer)
variable_template_parser = VariableTemplateParser(template=typed_node_data.answer)
variable_selectors = variable_template_parser.extract_variable_selectors()
variable_mapping = {}

@ -1,7 +1,7 @@
import logging
from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@ -123,9 +123,9 @@ class BaseNode:
if not node_id:
raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
node_data = cls._node_data_cls(**config.get("data", {}))
# Pass raw dict data instead of creating NodeData instance
data = cls._extract_variable_selector_to_variable_mapping(
graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data)
graph_config=graph_config, node_id=node_id, node_data=config.get("data", {})
)
return data
@ -135,7 +135,7 @@ class BaseNode:
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Any,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping

@ -338,8 +338,11 @@ class CodeNode(BaseNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: CodeNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = CodeNodeData(**node_data)
"""
Extract variable selector to variable mapping
:param graph_config: graph config
@ -349,7 +352,7 @@ class CodeNode(BaseNode):
"""
return {
node_id + "." + variable_selector.variable: variable_selector.value_selector
for variable_selector in node_data.variables
for variable_selector in typed_node_data.variables
}
@property

@ -101,16 +101,12 @@ class DocumentExtractorNode(BaseNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: DocumentExtractorNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
return {node_id + ".files": node_data.variable_selector}
# Create typed NodeData from dict
typed_node_data = DocumentExtractorNodeData(**node_data)
return {node_id + ".files": typed_node_data.variable_selector}
def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:

@ -135,15 +135,18 @@ class HttpRequestNode(BaseNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: HttpRequestNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = HttpRequestNodeData(**node_data)
selectors: list[VariableSelector] = []
selectors += variable_template_parser.extract_selectors_from_template(node_data.url)
selectors += variable_template_parser.extract_selectors_from_template(node_data.headers)
selectors += variable_template_parser.extract_selectors_from_template(node_data.params)
if node_data.body:
body_type = node_data.body.type
data = node_data.body.data
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.url)
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.headers)
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.params)
if typed_node_data.body:
body_type = typed_node_data.body.type
data = typed_node_data.body.data
match body_type:
case "binary":
if len(data) != 1:

@ -102,10 +102,13 @@ class IfElseNode(BaseNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: IfElseNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = IfElseNodeData(**node_data)
var_mapping: dict[str, list[str]] = {}
for case in node_data.cases or []:
for case in typed_node_data.cases or []:
for condition in case.conditions:
key = "{}.#{}#".format(node_id, ".".join(condition.variable_selector))
var_mapping[key] = condition.variable_selector

@ -309,21 +309,17 @@ class IterationNode(BaseNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: IterationNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
# Create typed NodeData from dict
typed_node_data = IterationNodeData(**node_data)
variable_mapping: dict[str, Sequence[str]] = {
f"{node_id}.input_selector": node_data.iterator_selector,
f"{node_id}.input_selector": typed_node_data.iterator_selector,
}
# init graph
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id)
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id)
if not iteration_graph:
raise IterationGraphNotFoundError("iteration graph not found")

@ -619,17 +619,13 @@ class KnowledgeRetrievalNode(BaseNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: KnowledgeRetrievalNodeData, # type: ignore
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
# Create typed NodeData from dict
typed_node_data = KnowledgeRetrievalNodeData(**node_data)
variable_mapping = {}
variable_mapping[node_id + ".query"] = node_data.query_variable_selector
variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector
return variable_mapping
def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:

@ -792,7 +792,7 @@ class LLMNode(BaseNode):
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = LLMNodeData(**node_data)
prompt_template = typed_node_data.prompt_template
variable_selectors = []
if isinstance(prompt_template, list) and all(

@ -442,19 +442,15 @@ class LoopNode(BaseNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: LoopNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
# Create typed NodeData from dict
typed_node_data = LoopNodeData(**node_data)
variable_mapping = {}
# init graph
loop_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id)
loop_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id)
if not loop_graph:
raise ValueError("loop graph not found")
@ -490,7 +486,7 @@ class LoopNode(BaseNode):
variable_mapping.update(sub_node_variable_mapping)
for loop_variable in node_data.loop_variables or []:
for loop_variable in typed_node_data.loop_variables or []:
if loop_variable.value_type == "variable":
assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
# add loop variable to variable mapping

@ -830,19 +830,15 @@ class ParameterExtractorNode(BaseNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: ParameterExtractorNodeData, # type: ignore
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query}
# Create typed NodeData from dict
typed_node_data = ParameterExtractorNodeData(**node_data)
variable_mapping: dict[str, Sequence[str]] = {"query": typed_node_data.query}
if node_data.instruction:
selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction)
if typed_node_data.instruction:
selectors = variable_template_parser.extract_selectors_from_template(typed_node_data.instruction)
for selector in selectors:
variable_mapping[selector.variable] = selector.value_selector

@ -236,20 +236,15 @@ class QuestionClassifierNode(BaseNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Any,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
node_data = cast(QuestionClassifierNodeData, node_data)
variable_mapping = {"query": node_data.query_variable_selector}
# Create typed NodeData from dict
typed_node_data = QuestionClassifierNodeData(**node_data)
variable_mapping = {"query": typed_node_data.query_variable_selector}
variable_selectors = []
if node_data.instruction:
variable_template_parser = VariableTemplateParser(template=node_data.instruction)
if typed_node_data.instruction:
variable_template_parser = VariableTemplateParser(template=typed_node_data.instruction)
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector

@ -64,16 +64,12 @@ class TemplateTransformNode(BaseNode):
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData
cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: Mapping[str, Any]
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
# Create typed NodeData from dict
typed_node_data = TemplateTransformNodeData(**node_data)
return {
node_id + "." + variable_selector.variable: variable_selector.value_selector
for variable_selector in node_data.variables
for variable_selector in typed_node_data.variables
}

@ -384,7 +384,7 @@ class ToolNode(BaseNode):
"""
# Create typed NodeData from dict
typed_node_data = ToolNodeData(**node_data)
result = {}
for parameter_name in typed_node_data.tool_parameters:
input = typed_node_data.tool_parameters[parameter_name]

@ -63,18 +63,21 @@ class VariableAssignerNode(BaseNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: VariableAssignerData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = VariableAssignerData(**node_data)
mapping = {}
assigned_variable_node_id = node_data.assigned_variable_selector[0]
assigned_variable_node_id = typed_node_data.assigned_variable_selector[0]
if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID:
selector_key = ".".join(node_data.assigned_variable_selector)
selector_key = ".".join(typed_node_data.assigned_variable_selector)
key = f"{node_id}.#{selector_key}#"
mapping[key] = node_data.assigned_variable_selector
mapping[key] = typed_node_data.assigned_variable_selector
selector_key = ".".join(node_data.input_variable_selector)
selector_key = ".".join(typed_node_data.input_variable_selector)
key = f"{node_id}.#{selector_key}#"
mapping[key] = node_data.input_variable_selector
mapping[key] = typed_node_data.input_variable_selector
return mapping
def _run(self) -> NodeRunResult:

@ -75,10 +75,13 @@ class VariableAssignerNode(BaseNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: VariableAssignerNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = VariableAssignerNodeData(**node_data)
var_mapping: dict[str, Sequence[str]] = {}
for item in node_data.items:
for item in typed_node_data.items:
_target_mapping_from_item(var_mapping, node_id, item)
_source_mapping_from_item(var_mapping, node_id, item)
return var_mapping

Loading…
Cancel
Save