From 96950fd8b50cfd61d2388a7ea5e9f2932de99666 Mon Sep 17 00:00:00 2001 From: "JF.Hsiong" Date: Mon, 14 Apr 2025 12:14:25 +0800 Subject: [PATCH] style: fix style issues --- .../nodes/knowledge_retrieval/entities.py | 3 ++ .../knowledge_retrieval_node.py | 49 ++++++++++--------- 2 files changed, 28 insertions(+), 24 deletions(-) diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 53aaadc7f2..6f45af1a14 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -122,13 +122,16 @@ class MetadataFilteringComplexSubCondition(BaseModel): conditions: Optional[list[Condition]] = Field(default=None, deprecated=True) sub_conditions: Optional[list["MetadataFilteringComplexSubCondition"]] = None + class MetadataFilteringComplexCondition(BaseModel): """ Complex Metadata Filtering Condition. """ + logical_operator: Optional[Literal["and", "or"]] = "and" conditions: Optional[list[MetadataFilteringComplexSubCondition]] = Field(default=None, deprecated=True) + class KnowledgeRetrievalNodeData(BaseNodeData): """ Knowledge retrieval Node Data. diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index ab5e414668..4e5cc56ea6 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -6,7 +6,7 @@ from collections import defaultdict from collections.abc import Mapping, Sequence from typing import Any, Optional, cast -from sqlalchemy import Integer, and_, func, or_, text +from sqlalchemy import ColumnElement, Integer, and_, func, or_, text from sqlalchemy import cast as sqlalchemy_cast from core.app.app_config.entities import DatasetRetrieveConfigEntity @@ -320,30 +320,30 @@ class KnowledgeRetrievalNode(LLMNode): for position, item in enumerate(retrieval_resource_list, start=1): item["metadata"]["position"] = position return retrieval_resource_list - + def _recursive_metadata_filter( - self, metadata_filtering_complex_conditions: MetadataFilteringComplexSubCondition, filters + self, metadata_filtering_complex_conditions: MetadataFilteringComplexSubCondition, filters ): logical_operator = metadata_filtering_complex_conditions.logical_operator conditions = metadata_filtering_complex_conditions.conditions sub_conditions = metadata_filtering_complex_conditions.sub_conditions - + sub_filters = [] if sub_conditions: for sub_condition in sub_conditions: sub_filter = self._recursive_metadata_filter(sub_condition, filters) sub_filters.append(sub_filter) - - temp_filters = [] + + temp_filters: list = [] if conditions: for sequence, condition in enumerate(conditions): metadata_name = condition.name expected_value = condition.value if expected_value is not None or condition.comparison_operator in ("empty", "not empty"): if isinstance(expected_value, str): - expected_value = self.graph_runtime_state.variable_pool.convert_template( - expected_value - ).value[0] + expected_value = self.graph_runtime_state.variable_pool.convert_template(expected_value).value[ + 0 + ] if expected_value.value_type == "number": # type: ignore expected_value = expected_value.value # type: ignore elif expected_value.value_type == "string": # type: ignore @@ -357,23 +357,24 @@ class KnowledgeRetrievalNode(LLMNode): expected_value, temp_filters, ) - + temp_filters_result: ColumnElement[bool] if temp_filters: if logical_operator == "and": # type: ignore - temp_filters = and_(*temp_filters) + temp_filters_result = and_(*temp_filters) else: - temp_filters = or_(*temp_filters) - filters.append(temp_filters) - + temp_filters_result = or_(*temp_filters) + filters.append(temp_filters_result) + + sub_filters_result: ColumnElement[bool] if sub_filters: if logical_operator == "and": # type: ignore - sub_filters = and_(*sub_filters) + sub_filters_result = and_(*sub_filters) else: - sub_filters = or_(*sub_filters) - filters.append(sub_filters) - + sub_filters_result = or_(*sub_filters) + filters.append(sub_filters_result) + return filters - + def _get_metadata_filter_condition( self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData ) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]: @@ -393,21 +394,21 @@ class KnowledgeRetrievalNode(LLMNode): # Enable forward references MetadataFilteringComplexSubCondition.model_rebuild() metadata_filtering_complex_conditions = MetadataFilteringComplexCondition( - **node_data.metadata_filtering_complex_conditions.model_dump()) - for sequence, condition in enumerate(metadata_filtering_complex_conditions.conditions): # type: ignore + **node_data.metadata_filtering_complex_conditions.model_dump() + ) + for condition in metadata_filtering_complex_conditions.conditions: # type: ignore filters = self._recursive_metadata_filter(condition, filters) if filters: if metadata_filtering_complex_conditions.logical_operator == "and": # type: ignore document_query = document_query.filter(and_(*filters)) else: document_query = document_query.filter(or_(*filters)) - metadata_condition = metadata_filtering_complex_conditions documents = document_query.all() # group by dataset_id metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore for document in documents: metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore - return metadata_filter_document_ids, metadata_condition + return metadata_filter_document_ids, MetadataCondition() elif node_data.metadata_filtering_mode == "automatic": automatic_metadata_filters = self._automatic_metadata_filter_func(dataset_ids, query, node_data) if automatic_metadata_filters: @@ -435,7 +436,7 @@ class KnowledgeRetrievalNode(LLMNode): if node_data.metadata_filtering_conditions: metadata_condition = MetadataCondition(**node_data.metadata_filtering_conditions.model_dump()) if node_data.metadata_filtering_conditions: - for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore + for sequence, condition in enumerate(metadata_condition.conditions): # type: ignore metadata_name = condition.name expected_value = condition.value if expected_value is not None or condition.comparison_operator in ("empty", "not empty"):