style: fix style issues

pull/18136/head
JF.Hsiong 1 year ago
parent d6b95e0702
commit 96950fd8b5

@ -122,13 +122,16 @@ class MetadataFilteringComplexSubCondition(BaseModel):
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
sub_conditions: Optional[list["MetadataFilteringComplexSubCondition"]] = None
class MetadataFilteringComplexCondition(BaseModel):
"""
Complex Metadata Filtering Condition.
"""
logical_operator: Optional[Literal["and", "or"]] = "and"
conditions: Optional[list[MetadataFilteringComplexSubCondition]] = Field(default=None, deprecated=True)
class KnowledgeRetrievalNodeData(BaseNodeData):
"""
Knowledge retrieval Node Data.

@ -6,7 +6,7 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from sqlalchemy import Integer, and_, func, or_, text
from sqlalchemy import ColumnElement, Integer, and_, func, or_, text
from sqlalchemy import cast as sqlalchemy_cast
from core.app.app_config.entities import DatasetRetrieveConfigEntity
@ -322,7 +322,7 @@ class KnowledgeRetrievalNode(LLMNode):
return retrieval_resource_list
def _recursive_metadata_filter(
self, metadata_filtering_complex_conditions: MetadataFilteringComplexSubCondition, filters
self, metadata_filtering_complex_conditions: MetadataFilteringComplexSubCondition, filters
):
logical_operator = metadata_filtering_complex_conditions.logical_operator
conditions = metadata_filtering_complex_conditions.conditions
@ -334,16 +334,16 @@ class KnowledgeRetrievalNode(LLMNode):
sub_filter = self._recursive_metadata_filter(sub_condition, filters)
sub_filters.append(sub_filter)
temp_filters = []
temp_filters: list = []
if conditions:
for sequence, condition in enumerate(conditions):
metadata_name = condition.name
expected_value = condition.value
if expected_value is not None or condition.comparison_operator in ("empty", "not empty"):
if isinstance(expected_value, str):
expected_value = self.graph_runtime_state.variable_pool.convert_template(
expected_value
).value[0]
expected_value = self.graph_runtime_state.variable_pool.convert_template(expected_value).value[
0
]
if expected_value.value_type == "number": # type: ignore
expected_value = expected_value.value # type: ignore
elif expected_value.value_type == "string": # type: ignore
@ -357,20 +357,21 @@ class KnowledgeRetrievalNode(LLMNode):
expected_value,
temp_filters,
)
temp_filters_result: ColumnElement[bool]
if temp_filters:
if logical_operator == "and": # type: ignore
temp_filters = and_(*temp_filters)
temp_filters_result = and_(*temp_filters)
else:
temp_filters = or_(*temp_filters)
filters.append(temp_filters)
temp_filters_result = or_(*temp_filters)
filters.append(temp_filters_result)
sub_filters_result: ColumnElement[bool]
if sub_filters:
if logical_operator == "and": # type: ignore
sub_filters = and_(*sub_filters)
sub_filters_result = and_(*sub_filters)
else:
sub_filters = or_(*sub_filters)
filters.append(sub_filters)
sub_filters_result = or_(*sub_filters)
filters.append(sub_filters_result)
return filters
@ -393,21 +394,21 @@ class KnowledgeRetrievalNode(LLMNode):
# Enable forward references
MetadataFilteringComplexSubCondition.model_rebuild()
metadata_filtering_complex_conditions = MetadataFilteringComplexCondition(
**node_data.metadata_filtering_complex_conditions.model_dump())
for sequence, condition in enumerate(metadata_filtering_complex_conditions.conditions): # type: ignore
**node_data.metadata_filtering_complex_conditions.model_dump()
)
for condition in metadata_filtering_complex_conditions.conditions: # type: ignore
filters = self._recursive_metadata_filter(condition, filters)
if filters:
if metadata_filtering_complex_conditions.logical_operator == "and": # type: ignore
document_query = document_query.filter(and_(*filters))
else:
document_query = document_query.filter(or_(*filters))
metadata_condition = metadata_filtering_complex_conditions
documents = document_query.all()
# group by dataset_id
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
for document in documents:
metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore
return metadata_filter_document_ids, metadata_condition
return metadata_filter_document_ids, MetadataCondition()
elif node_data.metadata_filtering_mode == "automatic":
automatic_metadata_filters = self._automatic_metadata_filter_func(dataset_ids, query, node_data)
if automatic_metadata_filters:
@ -435,7 +436,7 @@ class KnowledgeRetrievalNode(LLMNode):
if node_data.metadata_filtering_conditions:
metadata_condition = MetadataCondition(**node_data.metadata_filtering_conditions.model_dump())
if node_data.metadata_filtering_conditions:
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
for sequence, condition in enumerate(metadata_condition.conditions): # type: ignore
metadata_name = condition.name
expected_value = condition.value
if expected_value is not None or condition.comparison_operator in ("empty", "not empty"):

Loading…
Cancel
Save