refactor(nodes): Use `model_validate` to create NodeData.

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/22581/head
-LAN- 10 months ago
parent dc83c9822f
commit 8cce31548d
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -65,7 +65,7 @@ class AgentNode(BaseNode):
node_data: AgentNodeData node_data: AgentNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = AgentNodeData(**data) self.node_data = AgentNodeData.model_validate(data)
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
@ -365,7 +365,7 @@ class AgentNode(BaseNode):
node_data: Mapping[str, Any], node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict # Create typed NodeData from dict
typed_node_data = AgentNodeData(**node_data) typed_node_data = AgentNodeData.model_validate(node_data)
result: dict[str, Any] = {} result: dict[str, Any] = {}
for parameter_name in typed_node_data.agent_parameters: for parameter_name in typed_node_data.agent_parameters:

@ -22,7 +22,7 @@ class AnswerNode(BaseNode):
node_data: AnswerNodeData node_data: AnswerNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = AnswerNodeData(**data) self.node_data = AnswerNodeData.model_validate(data)
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
@ -67,15 +67,8 @@ class AnswerNode(BaseNode):
node_data: Mapping[str, Any], node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict # Create typed NodeData from dict
typed_node_data = AnswerNodeData(**node_data) typed_node_data = AnswerNodeData.model_validate(node_data)
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
variable_template_parser = VariableTemplateParser(template=typed_node_data.answer) variable_template_parser = VariableTemplateParser(template=typed_node_data.answer)
variable_selectors = variable_template_parser.extract_variable_selectors() variable_selectors = variable_template_parser.extract_variable_selectors()

@ -27,7 +27,7 @@ class CodeNode(BaseNode):
node_data: CodeNodeData node_data: CodeNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = CodeNodeData(**data) self.node_data = CodeNodeData.model_validate(data)
@classmethod @classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict: def get_default_config(cls, filters: Optional[dict] = None) -> dict:
@ -341,15 +341,8 @@ class CodeNode(BaseNode):
node_data: Mapping[str, Any], node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict # Create typed NodeData from dict
typed_node_data = CodeNodeData(**node_data) typed_node_data = CodeNodeData.model_validate(node_data)
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
return { return {
node_id + "." + variable_selector.variable: variable_selector.value_selector node_id + "." + variable_selector.variable: variable_selector.value_selector
for variable_selector in typed_node_data.variables for variable_selector in typed_node_data.variables

@ -47,7 +47,7 @@ class DocumentExtractorNode(BaseNode):
node_data: DocumentExtractorNodeData node_data: DocumentExtractorNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = DocumentExtractorNodeData(**data) self.node_data = DocumentExtractorNodeData.model_validate(data)
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
@ -104,7 +104,7 @@ class DocumentExtractorNode(BaseNode):
node_data: Mapping[str, Any], node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict # Create typed NodeData from dict
typed_node_data = DocumentExtractorNodeData(**node_data) typed_node_data = DocumentExtractorNodeData.model_validate(node_data)
return {node_id + ".files": typed_node_data.variable_selector} return {node_id + ".files": typed_node_data.variable_selector}

@ -38,7 +38,7 @@ class HttpRequestNode(BaseNode):
node_data: HttpRequestNodeData node_data: HttpRequestNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = HttpRequestNodeData(**data) self.node_data = HttpRequestNodeData.model_validate(data)
@classmethod @classmethod
def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict: def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict:
@ -138,7 +138,7 @@ class HttpRequestNode(BaseNode):
node_data: Mapping[str, Any], node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict # Create typed NodeData from dict
typed_node_data = HttpRequestNodeData(**node_data) typed_node_data = HttpRequestNodeData.model_validate(node_data)
selectors: list[VariableSelector] = [] selectors: list[VariableSelector] = []
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.url) selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.url)

@ -19,7 +19,7 @@ class IfElseNode(BaseNode):
node_data: IfElseNodeData node_data: IfElseNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = IfElseNodeData(**data) self.node_data = IfElseNodeData.model_validate(data)
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
@ -105,7 +105,7 @@ class IfElseNode(BaseNode):
node_data: Mapping[str, Any], node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict # Create typed NodeData from dict
typed_node_data = IfElseNodeData(**node_data) typed_node_data = IfElseNodeData.model_validate(node_data)
var_mapping: dict[str, list[str]] = {} var_mapping: dict[str, list[str]] = {}
for case in typed_node_data.cases or []: for case in typed_node_data.cases or []:

@ -66,7 +66,7 @@ class IterationNode(BaseNode):
node_data: IterationNodeData node_data: IterationNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = IterationNodeData(**data) self.node_data = IterationNodeData.model_validate(data)
@classmethod @classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict: def get_default_config(cls, filters: Optional[dict] = None) -> dict:
@ -312,7 +312,7 @@ class IterationNode(BaseNode):
node_data: Mapping[str, Any], node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict # Create typed NodeData from dict
typed_node_data = IterationNodeData(**node_data) typed_node_data = IterationNodeData.model_validate(node_data)
variable_mapping: dict[str, Sequence[str]] = { variable_mapping: dict[str, Sequence[str]] = {
f"{node_id}.input_selector": typed_node_data.iterator_selector, f"{node_id}.input_selector": typed_node_data.iterator_selector,

@ -125,7 +125,7 @@ class KnowledgeRetrievalNode(BaseNode):
self._llm_file_saver = llm_file_saver self._llm_file_saver = llm_file_saver
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = KnowledgeRetrievalNodeData(**data) self.node_data = KnowledgeRetrievalNodeData.model_validate(data)
@classmethod @classmethod
def version(cls): def version(cls):
@ -622,7 +622,7 @@ class KnowledgeRetrievalNode(BaseNode):
node_data: Mapping[str, Any], node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict # Create typed NodeData from dict
typed_node_data = KnowledgeRetrievalNodeData(**node_data) typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data)
variable_mapping = {} variable_mapping = {}
variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector

@ -138,7 +138,7 @@ class LLMNode(BaseNode):
self._llm_file_saver = llm_file_saver self._llm_file_saver = llm_file_saver
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = LLMNodeData(**data) self.node_data = LLMNodeData.model_validate(data)
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
@ -791,7 +791,7 @@ class LLMNode(BaseNode):
node_data: Mapping[str, Any], node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict # Create typed NodeData from dict
typed_node_data = LLMNodeData(**node_data) typed_node_data = LLMNodeData.model_validate(node_data)
prompt_template = typed_node_data.prompt_template prompt_template = typed_node_data.prompt_template
variable_selectors = [] variable_selectors = []

@ -53,7 +53,7 @@ class LoopNode(BaseNode):
node_data: LoopNodeData node_data: LoopNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = LoopNodeData(**data) self.node_data = LoopNodeData.model_validate(data)
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
@ -445,7 +445,7 @@ class LoopNode(BaseNode):
node_data: Mapping[str, Any], node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict # Create typed NodeData from dict
typed_node_data = LoopNodeData(**node_data) typed_node_data = LoopNodeData.model_validate(node_data)
variable_mapping = {} variable_mapping = {}

@ -96,7 +96,7 @@ class ParameterExtractorNode(BaseNode):
node_data: ParameterExtractorNodeData node_data: ParameterExtractorNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = ParameterExtractorNodeData(**data) self.node_data = ParameterExtractorNodeData.model_validate(data)
_model_instance: Optional[ModelInstance] = None _model_instance: Optional[ModelInstance] = None
_model_config: Optional[ModelConfigWithCredentialsEntity] = None _model_config: Optional[ModelConfigWithCredentialsEntity] = None
@ -833,7 +833,7 @@ class ParameterExtractorNode(BaseNode):
node_data: Mapping[str, Any], node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict # Create typed NodeData from dict
typed_node_data = ParameterExtractorNodeData(**node_data) typed_node_data = ParameterExtractorNodeData.model_validate(node_data)
variable_mapping: dict[str, Sequence[str]] = {"query": typed_node_data.query} variable_mapping: dict[str, Sequence[str]] = {"query": typed_node_data.query}

@ -82,7 +82,7 @@ class QuestionClassifierNode(BaseNode):
self._llm_file_saver = llm_file_saver self._llm_file_saver = llm_file_saver
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = QuestionClassifierNodeData(**data) self.node_data = QuestionClassifierNodeData.model_validate(data)
@classmethod @classmethod
def version(cls): def version(cls):
@ -239,7 +239,7 @@ class QuestionClassifierNode(BaseNode):
node_data: Mapping[str, Any], node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict # Create typed NodeData from dict
typed_node_data = QuestionClassifierNodeData(**node_data) typed_node_data = QuestionClassifierNodeData.model_validate(node_data)
variable_mapping = {"query": typed_node_data.query_variable_selector} variable_mapping = {"query": typed_node_data.query_variable_selector}
variable_selectors = [] variable_selectors = []

@ -18,7 +18,7 @@ class TemplateTransformNode(BaseNode):
node_data: TemplateTransformNodeData node_data: TemplateTransformNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = TemplateTransformNodeData(**data) self.node_data = TemplateTransformNodeData.model_validate(data)
@classmethod @classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict: def get_default_config(cls, filters: Optional[dict] = None) -> dict:
@ -67,7 +67,7 @@ class TemplateTransformNode(BaseNode):
cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: Mapping[str, Any] cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: Mapping[str, Any]
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict # Create typed NodeData from dict
typed_node_data = TemplateTransformNodeData(**node_data) typed_node_data = TemplateTransformNodeData.model_validate(node_data)
return { return {
node_id + "." + variable_selector.variable: variable_selector.value_selector node_id + "." + variable_selector.variable: variable_selector.value_selector

@ -45,7 +45,7 @@ class ToolNode(BaseNode):
node_data: ToolNodeData node_data: ToolNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = ToolNodeData(**data) self.node_data = ToolNodeData.model_validate(data)
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
@ -383,7 +383,7 @@ class ToolNode(BaseNode):
:return: :return:
""" """
# Create typed NodeData from dict # Create typed NodeData from dict
typed_node_data = ToolNodeData(**node_data) typed_node_data = ToolNodeData.model_validate(node_data)
result = {} result = {}
for parameter_name in typed_node_data.tool_parameters: for parameter_name in typed_node_data.tool_parameters:

@ -29,7 +29,7 @@ class VariableAssignerNode(BaseNode):
node_data: VariableAssignerData node_data: VariableAssignerData
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = VariableAssignerData(**data) self.node_data = VariableAssignerData.model_validate(data)
def __init__( def __init__(
self, self,
@ -66,7 +66,7 @@ class VariableAssignerNode(BaseNode):
node_data: Mapping[str, Any], node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict # Create typed NodeData from dict
typed_node_data = VariableAssignerData(**node_data) typed_node_data = VariableAssignerData.model_validate(node_data)
mapping = {} mapping = {}
assigned_variable_node_id = typed_node_data.assigned_variable_selector[0] assigned_variable_node_id = typed_node_data.assigned_variable_selector[0]

@ -60,7 +60,7 @@ class VariableAssignerNode(BaseNode):
node_data: VariableAssignerNodeData node_data: VariableAssignerNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None: def init_node_data(self, data: Mapping[str, Any]) -> None:
self.node_data = VariableAssignerNodeData(**data) self.node_data = VariableAssignerNodeData.model_validate(data)
def _conv_var_updater_factory(self) -> ConversationVariableUpdater: def _conv_var_updater_factory(self) -> ConversationVariableUpdater:
return conversation_variable_updater_factory() return conversation_variable_updater_factory()
@ -78,7 +78,7 @@ class VariableAssignerNode(BaseNode):
node_data: Mapping[str, Any], node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict # Create typed NodeData from dict
typed_node_data = VariableAssignerNodeData(**node_data) typed_node_data = VariableAssignerNodeData.model_validate(node_data)
var_mapping: dict[str, Sequence[str]] = {} var_mapping: dict[str, Sequence[str]] = {}
for item in typed_node_data.items: for item in typed_node_data.items:

Loading…
Cancel
Save