feat: optimize metadata recursive strategy

pull/18136/head
JF.Hsiong 1 year ago
parent 1110c2ca2f
commit 96bbc1b682

@ -116,20 +116,10 @@ class MetadataFilteringCondition(BaseModel):
logical_operator: Optional[Literal["and", "or"]] = "and" logical_operator: Optional[Literal["and", "or"]] = "and"
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True) conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
class MetadataFilteringComplexSubCondition(BaseModel):
logical_operator: Optional[Literal["and", "or"]] = "and"
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
sub_conditions: Optional[list["MetadataFilteringComplexSubCondition"]] = None
class MetadataFilteringComplexCondition(BaseModel): class MetadataFilteringComplexCondition(BaseModel):
"""
Complex Metadata Filtering Condition.
"""
logical_operator: Optional[Literal["and", "or"]] = "and" logical_operator: Optional[Literal["and", "or"]] = "and"
conditions: Optional[list[MetadataFilteringComplexSubCondition]] = Field(default=None, deprecated=True) conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
sub_conditions: Optional[list["MetadataFilteringComplexCondition"]] = None
class KnowledgeRetrievalNodeData(BaseNodeData): class KnowledgeRetrievalNodeData(BaseNodeData):

@ -47,7 +47,6 @@ from services.feature_service import FeatureService
from .entities import ( from .entities import (
KnowledgeRetrievalNodeData, KnowledgeRetrievalNodeData,
MetadataFilteringComplexCondition, MetadataFilteringComplexCondition,
MetadataFilteringComplexSubCondition,
ModelConfig, ModelConfig,
) )
from .exc import ( from .exc import (
@ -322,7 +321,7 @@ class KnowledgeRetrievalNode(LLMNode):
return retrieval_resource_list return retrieval_resource_list
def _recursive_metadata_filter( def _recursive_metadata_filter(
self, metadata_filtering_complex_conditions: MetadataFilteringComplexSubCondition, filters self, metadata_filtering_complex_conditions: MetadataFilteringComplexCondition, filters
): ):
logical_operator = metadata_filtering_complex_conditions.logical_operator logical_operator = metadata_filtering_complex_conditions.logical_operator
conditions = metadata_filtering_complex_conditions.conditions conditions = metadata_filtering_complex_conditions.conditions
@ -331,8 +330,8 @@ class KnowledgeRetrievalNode(LLMNode):
sub_filters = [] sub_filters = []
if sub_conditions: if sub_conditions:
for sub_condition in sub_conditions: for sub_condition in sub_conditions:
sub_filter = self._recursive_metadata_filter(sub_condition, filters) sub_filter = self._recursive_metadata_filter(sub_condition, [])
sub_filters.append(sub_filter) sub_filters.extend(sub_filter)
temp_filters: list = [] temp_filters: list = []
if conditions: if conditions:
@ -357,16 +356,27 @@ class KnowledgeRetrievalNode(LLMNode):
expected_value, expected_value,
temp_filters, temp_filters,
) )
temp_filters_result: ColumnElement[bool]
if temp_filters: sub_filters_result: ColumnElement
temp_filters_result: ColumnElement
if temp_filters and sub_filters:
temp_all_filters = sub_filters +temp_filters
if logical_operator == "and": # type: ignore
sub_filters_result = and_(*temp_all_filters)
else:
sub_filters_result = or_(*temp_all_filters)
filters.append(sub_filters_result)
return filters
if temp_filters: # text
if logical_operator == "and": # type: ignore if logical_operator == "and": # type: ignore
temp_filters_result = and_(*temp_filters) temp_filters_result = and_(*temp_filters)
else: else:
temp_filters_result = or_(*temp_filters) temp_filters_result = or_(*temp_filters)
filters.append(temp_filters_result) filters.append(temp_filters_result)
return filters
sub_filters_result: ColumnElement[bool] if sub_filters: # Boolean
if sub_filters:
if logical_operator == "and": # type: ignore if logical_operator == "and": # type: ignore
sub_filters_result = and_(*sub_filters) sub_filters_result = and_(*sub_filters)
else: else:
@ -375,6 +385,7 @@ class KnowledgeRetrievalNode(LLMNode):
return filters return filters
def _get_metadata_filter_condition( def _get_metadata_filter_condition(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]: ) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
@ -392,17 +403,13 @@ class KnowledgeRetrievalNode(LLMNode):
# todo: do not support external_knowledge_retrieval # todo: do not support external_knowledge_retrieval
if node_data.metadata_filtering_complex_conditions: if node_data.metadata_filtering_complex_conditions:
# Enable forward references # Enable forward references
MetadataFilteringComplexSubCondition.model_rebuild() MetadataFilteringComplexCondition.model_rebuild()
metadata_filtering_complex_conditions = MetadataFilteringComplexCondition( metadata_filtering_complex_conditions = MetadataFilteringComplexCondition(
**node_data.metadata_filtering_complex_conditions.model_dump() **node_data.metadata_filtering_complex_conditions.model_dump()
) )
for condition in metadata_filtering_complex_conditions.conditions: # type: ignore filters = self._recursive_metadata_filter(metadata_filtering_complex_conditions, filters)
filters = self._recursive_metadata_filter(condition, filters)
if filters: if filters:
if metadata_filtering_complex_conditions.logical_operator == "and": # type: ignore document_query = document_query.filter(*filters)
document_query = document_query.filter(and_(*filters))
else:
document_query = document_query.filter(or_(*filters))
documents = document_query.all() documents = document_query.all()
# group by dataset_id # group by dataset_id
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore

Loading…
Cancel
Save