feat: optimize metadata recursive strategy

pull/18136/head
JF.Hsiong 1 year ago
parent 1110c2ca2f
commit 96bbc1b682

@ -116,20 +116,10 @@ class MetadataFilteringCondition(BaseModel):
logical_operator: Optional[Literal["and", "or"]] = "and"
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
class MetadataFilteringComplexSubCondition(BaseModel):
logical_operator: Optional[Literal["and", "or"]] = "and"
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
sub_conditions: Optional[list["MetadataFilteringComplexSubCondition"]] = None
class MetadataFilteringComplexCondition(BaseModel):
"""
Complex Metadata Filtering Condition.
"""
logical_operator: Optional[Literal["and", "or"]] = "and"
conditions: Optional[list[MetadataFilteringComplexSubCondition]] = Field(default=None, deprecated=True)
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
sub_conditions: Optional[list["MetadataFilteringComplexCondition"]] = None
class KnowledgeRetrievalNodeData(BaseNodeData):

@ -47,7 +47,6 @@ from services.feature_service import FeatureService
from .entities import (
KnowledgeRetrievalNodeData,
MetadataFilteringComplexCondition,
MetadataFilteringComplexSubCondition,
ModelConfig,
)
from .exc import (
@ -322,7 +321,7 @@ class KnowledgeRetrievalNode(LLMNode):
return retrieval_resource_list
def _recursive_metadata_filter(
self, metadata_filtering_complex_conditions: MetadataFilteringComplexSubCondition, filters
self, metadata_filtering_complex_conditions: MetadataFilteringComplexCondition, filters
):
logical_operator = metadata_filtering_complex_conditions.logical_operator
conditions = metadata_filtering_complex_conditions.conditions
@ -331,8 +330,8 @@ class KnowledgeRetrievalNode(LLMNode):
sub_filters = []
if sub_conditions:
for sub_condition in sub_conditions:
sub_filter = self._recursive_metadata_filter(sub_condition, filters)
sub_filters.append(sub_filter)
sub_filter = self._recursive_metadata_filter(sub_condition, [])
sub_filters.extend(sub_filter)
temp_filters: list = []
if conditions:
@ -357,16 +356,27 @@ class KnowledgeRetrievalNode(LLMNode):
expected_value,
temp_filters,
)
temp_filters_result: ColumnElement[bool]
if temp_filters:
sub_filters_result: ColumnElement
temp_filters_result: ColumnElement
if temp_filters and sub_filters:
temp_all_filters = sub_filters +temp_filters
if logical_operator == "and": # type: ignore
sub_filters_result = and_(*temp_all_filters)
else:
sub_filters_result = or_(*temp_all_filters)
filters.append(sub_filters_result)
return filters
if temp_filters: # text
if logical_operator == "and": # type: ignore
temp_filters_result = and_(*temp_filters)
else:
temp_filters_result = or_(*temp_filters)
filters.append(temp_filters_result)
return filters
sub_filters_result: ColumnElement[bool]
if sub_filters:
if sub_filters: # Boolean
if logical_operator == "and": # type: ignore
sub_filters_result = and_(*sub_filters)
else:
@ -375,6 +385,7 @@ class KnowledgeRetrievalNode(LLMNode):
return filters
def _get_metadata_filter_condition(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
@ -392,17 +403,13 @@ class KnowledgeRetrievalNode(LLMNode):
# todo: do not support external_knowledge_retrieval
if node_data.metadata_filtering_complex_conditions:
# Enable forward references
MetadataFilteringComplexSubCondition.model_rebuild()
MetadataFilteringComplexCondition.model_rebuild()
metadata_filtering_complex_conditions = MetadataFilteringComplexCondition(
**node_data.metadata_filtering_complex_conditions.model_dump()
)
for condition in metadata_filtering_complex_conditions.conditions: # type: ignore
filters = self._recursive_metadata_filter(condition, filters)
filters = self._recursive_metadata_filter(metadata_filtering_complex_conditions, filters)
if filters:
if metadata_filtering_complex_conditions.logical_operator == "and": # type: ignore
document_query = document_query.filter(and_(*filters))
else:
document_query = document_query.filter(or_(*filters))
document_query = document_query.filter(*filters)
documents = document_query.all()
# group by dataset_id
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore

Loading…
Cancel
Save