diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 6f45af1a14..7fde733eb5 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -116,20 +116,10 @@ class MetadataFilteringCondition(BaseModel): logical_operator: Optional[Literal["and", "or"]] = "and" conditions: Optional[list[Condition]] = Field(default=None, deprecated=True) - -class MetadataFilteringComplexSubCondition(BaseModel): - logical_operator: Optional[Literal["and", "or"]] = "and" - conditions: Optional[list[Condition]] = Field(default=None, deprecated=True) - sub_conditions: Optional[list["MetadataFilteringComplexSubCondition"]] = None - - class MetadataFilteringComplexCondition(BaseModel): - """ - Complex Metadata Filtering Condition. - """ - logical_operator: Optional[Literal["and", "or"]] = "and" - conditions: Optional[list[MetadataFilteringComplexSubCondition]] = Field(default=None, deprecated=True) + conditions: Optional[list[Condition]] = Field(default=None, deprecated=True) + sub_conditions: Optional[list["MetadataFilteringComplexCondition"]] = None class KnowledgeRetrievalNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 4e5cc56ea6..2af78c72e0 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -47,7 +47,6 @@ from services.feature_service import FeatureService from .entities import ( KnowledgeRetrievalNodeData, MetadataFilteringComplexCondition, - MetadataFilteringComplexSubCondition, ModelConfig, ) from .exc import ( @@ -322,7 +321,7 @@ class KnowledgeRetrievalNode(LLMNode): return retrieval_resource_list def _recursive_metadata_filter( - self, metadata_filtering_complex_conditions: MetadataFilteringComplexSubCondition, filters + self, metadata_filtering_complex_conditions: MetadataFilteringComplexCondition, filters ): logical_operator = metadata_filtering_complex_conditions.logical_operator conditions = metadata_filtering_complex_conditions.conditions @@ -331,8 +330,8 @@ class KnowledgeRetrievalNode(LLMNode): sub_filters = [] if sub_conditions: for sub_condition in sub_conditions: - sub_filter = self._recursive_metadata_filter(sub_condition, filters) - sub_filters.append(sub_filter) + sub_filter = self._recursive_metadata_filter(sub_condition, []) + sub_filters.extend(sub_filter) temp_filters: list = [] if conditions: @@ -357,16 +356,27 @@ class KnowledgeRetrievalNode(LLMNode): expected_value, temp_filters, ) - temp_filters_result: ColumnElement[bool] - if temp_filters: + + sub_filters_result: ColumnElement + temp_filters_result: ColumnElement + if temp_filters and sub_filters: + temp_all_filters = sub_filters +temp_filters + if logical_operator == "and": # type: ignore + sub_filters_result = and_(*temp_all_filters) + else: + sub_filters_result = or_(*temp_all_filters) + filters.append(sub_filters_result) + return filters + + if temp_filters: # text if logical_operator == "and": # type: ignore temp_filters_result = and_(*temp_filters) else: temp_filters_result = or_(*temp_filters) filters.append(temp_filters_result) + return filters - sub_filters_result: ColumnElement[bool] - if sub_filters: + if sub_filters: # Boolean if logical_operator == "and": # type: ignore sub_filters_result = and_(*sub_filters) else: @@ -375,6 +385,7 @@ class KnowledgeRetrievalNode(LLMNode): return filters + def _get_metadata_filter_condition( self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData ) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]: @@ -392,17 +403,13 @@ class KnowledgeRetrievalNode(LLMNode): # todo: do not support external_knowledge_retrieval if node_data.metadata_filtering_complex_conditions: # Enable forward references - MetadataFilteringComplexSubCondition.model_rebuild() + MetadataFilteringComplexCondition.model_rebuild() metadata_filtering_complex_conditions = MetadataFilteringComplexCondition( **node_data.metadata_filtering_complex_conditions.model_dump() ) - for condition in metadata_filtering_complex_conditions.conditions: # type: ignore - filters = self._recursive_metadata_filter(condition, filters) + filters = self._recursive_metadata_filter(metadata_filtering_complex_conditions, filters) if filters: - if metadata_filtering_complex_conditions.logical_operator == "and": # type: ignore - document_query = document_query.filter(and_(*filters)) - else: - document_query = document_query.filter(or_(*filters)) + document_query = document_query.filter(*filters) documents = document_query.all() # group by dataset_id metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore