refactor(nodes): rename `from_dict` to `init_node_data`.

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/22581/head
-LAN- 10 months ago
parent 49c2714e47
commit 3f143aea59
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -276,7 +276,7 @@ class GraphEngine:
previous_node_id=previous_node_id,
thread_pool_id=self.thread_pool_id,
)
node_instance.from_dict(node_config.get("data", {}))
node_instance.init_node_data(node_config.get("data", {}))
try:
# run node
generator = self._run_node(

@ -65,7 +65,7 @@ class AgentNode(BaseNode):
_node_type = NodeType.AGENT
node_data: AgentNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = AgentNodeData(**data)
@classmethod

@ -21,7 +21,7 @@ class AnswerNode(BaseNode):
node_data: AnswerNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = AnswerNodeData(**data)
@classmethod

@ -50,7 +50,7 @@ class BaseNode:
self.node_id = node_id
@abstractmethod
def from_dict(self, data: Mapping[str, Any]) -> None: ...
def init_node_data(self, data: Mapping[str, Any]) -> None: ...
@abstractmethod
def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]:

@ -26,7 +26,7 @@ class CodeNode(BaseNode):
node_data: CodeNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = CodeNodeData(**data)
@classmethod

@ -46,7 +46,7 @@ class DocumentExtractorNode(BaseNode):
node_data: DocumentExtractorNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = DocumentExtractorNodeData(**data)
@classmethod

@ -13,7 +13,7 @@ class EndNode(BaseNode):
node_data: EndNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = EndNodeData(**data)
@classmethod

@ -37,7 +37,7 @@ class HttpRequestNode(BaseNode):
node_data: HttpRequestNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = HttpRequestNodeData(**data)
@classmethod

@ -18,7 +18,7 @@ class IfElseNode(BaseNode):
node_data: IfElseNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = IfElseNodeData(**data)
@classmethod

@ -65,7 +65,7 @@ class IterationNode(BaseNode):
node_data: IterationNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = IterationNodeData(**data)
@classmethod

@ -17,7 +17,7 @@ class IterationStartNode(BaseNode):
node_data: IterationStartNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = IterationStartNodeData(**data)
@classmethod

@ -124,7 +124,7 @@ class KnowledgeRetrievalNode(BaseNode):
)
self._llm_file_saver = llm_file_saver
def from_dict(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = KnowledgeRetrievalNodeData(**data)
@classmethod

@ -18,7 +18,7 @@ class ListOperatorNode(BaseNode):
node_data: ListOperatorNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = ListOperatorNodeData(**data)
@classmethod

@ -137,7 +137,7 @@ class LLMNode(BaseNode):
)
self._llm_file_saver = llm_file_saver
def from_dict(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = LLMNodeData(**data)
@classmethod
@ -788,10 +788,12 @@ class LLMNode(BaseNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: LLMNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
prompt_template = node_data.prompt_template
# Create typed NodeData from dict
typed_node_data = LLMNodeData(**node_data)
prompt_template = typed_node_data.prompt_template
variable_selectors = []
if isinstance(prompt_template, list) and all(
isinstance(prompt, LLMNodeChatModelMessage) for prompt in prompt_template
@ -811,7 +813,7 @@ class LLMNode(BaseNode):
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
memory = node_data.memory
memory = typed_node_data.memory
if memory and memory.query_prompt_template:
query_variable_selectors = VariableTemplateParser(
template=memory.query_prompt_template
@ -819,16 +821,16 @@ class LLMNode(BaseNode):
for variable_selector in query_variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
if node_data.context.enabled:
variable_mapping["#context#"] = node_data.context.variable_selector
if typed_node_data.context.enabled:
variable_mapping["#context#"] = typed_node_data.context.variable_selector
if node_data.vision.enabled:
variable_mapping["#files#"] = node_data.vision.configs.variable_selector
if typed_node_data.vision.enabled:
variable_mapping["#files#"] = typed_node_data.vision.configs.variable_selector
if node_data.memory:
if typed_node_data.memory:
variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value]
if node_data.prompt_config:
if typed_node_data.prompt_config:
enable_jinja = False
if isinstance(prompt_template, list):
@ -841,7 +843,7 @@ class LLMNode(BaseNode):
enable_jinja = True
if enable_jinja:
for variable_selector in node_data.prompt_config.jinja2_variables or []:
for variable_selector in typed_node_data.prompt_config.jinja2_variables or []:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}

@ -17,7 +17,7 @@ class LoopEndNode(BaseNode):
node_data: LoopEndNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = LoopEndNodeData(**data)
@classmethod

@ -52,7 +52,7 @@ class LoopNode(BaseNode):
node_data: LoopNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = LoopNodeData(**data)
@classmethod

@ -17,7 +17,7 @@ class LoopStartNode(BaseNode):
node_data: LoopStartNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = LoopStartNodeData(**data)
@classmethod

@ -95,7 +95,7 @@ class ParameterExtractorNode(BaseNode):
node_data: ParameterExtractorNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = ParameterExtractorNodeData(**data)
_model_instance: Optional[ModelInstance] = None

@ -81,7 +81,7 @@ class QuestionClassifierNode(BaseNode):
)
self._llm_file_saver = llm_file_saver
def from_dict(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = QuestionClassifierNodeData(**data)
@classmethod

@ -14,7 +14,7 @@ class StartNode(BaseNode):
node_data: StartNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = StartNodeData(**data)
@classmethod

@ -17,7 +17,7 @@ class TemplateTransformNode(BaseNode):
node_data: TemplateTransformNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = TemplateTransformNodeData(**data)
@classmethod

@ -44,7 +44,7 @@ class ToolNode(BaseNode):
node_data: ToolNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = ToolNodeData(**data)
@classmethod
@ -373,7 +373,7 @@ class ToolNode(BaseNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: ToolNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@ -382,9 +382,12 @@ class ToolNode(BaseNode):
:param node_data: node data
:return:
"""
# Create typed NodeData from dict
typed_node_data = ToolNodeData(**node_data)
result = {}
for parameter_name in node_data.tool_parameters:
input = node_data.tool_parameters[parameter_name]
for parameter_name in typed_node_data.tool_parameters:
input = typed_node_data.tool_parameters[parameter_name]
if input.type == "mixed":
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors()

@ -14,7 +14,7 @@ class VariableAggregatorNode(BaseNode):
node_data: VariableAssignerNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = VariableAssignerNodeData(**data)
@classmethod

@ -28,7 +28,7 @@ class VariableAssignerNode(BaseNode):
node_data: VariableAssignerData
def from_dict(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = VariableAssignerData(**data)
def __init__(

@ -59,7 +59,7 @@ class VariableAssignerNode(BaseNode):
node_data: VariableAssignerNodeData
def from_dict(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = VariableAssignerNodeData(**data)
def _conv_var_updater_factory(self) -> ConversationVariableUpdater:

Loading…
Cancel
Save