style: fix style issues

pull/18136/head
JF.Hsiong 1 year ago
parent d6b95e0702
commit 96950fd8b5

@ -122,13 +122,16 @@ class MetadataFilteringComplexSubCondition(BaseModel):
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True) conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
sub_conditions: Optional[list["MetadataFilteringComplexSubCondition"]] = None sub_conditions: Optional[list["MetadataFilteringComplexSubCondition"]] = None
class MetadataFilteringComplexCondition(BaseModel): class MetadataFilteringComplexCondition(BaseModel):
""" """
Complex Metadata Filtering Condition. Complex Metadata Filtering Condition.
""" """
logical_operator: Optional[Literal["and", "or"]] = "and" logical_operator: Optional[Literal["and", "or"]] = "and"
conditions: Optional[list[MetadataFilteringComplexSubCondition]] = Field(default=None, deprecated=True) conditions: Optional[list[MetadataFilteringComplexSubCondition]] = Field(default=None, deprecated=True)
class KnowledgeRetrievalNodeData(BaseNodeData): class KnowledgeRetrievalNodeData(BaseNodeData):
""" """
Knowledge retrieval Node Data. Knowledge retrieval Node Data.

@ -6,7 +6,7 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast from typing import Any, Optional, cast
from sqlalchemy import Integer, and_, func, or_, text from sqlalchemy import ColumnElement, Integer, and_, func, or_, text
from sqlalchemy import cast as sqlalchemy_cast from sqlalchemy import cast as sqlalchemy_cast
from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.app.app_config.entities import DatasetRetrieveConfigEntity
@ -322,7 +322,7 @@ class KnowledgeRetrievalNode(LLMNode):
return retrieval_resource_list return retrieval_resource_list
def _recursive_metadata_filter( def _recursive_metadata_filter(
self, metadata_filtering_complex_conditions: MetadataFilteringComplexSubCondition, filters self, metadata_filtering_complex_conditions: MetadataFilteringComplexSubCondition, filters
): ):
logical_operator = metadata_filtering_complex_conditions.logical_operator logical_operator = metadata_filtering_complex_conditions.logical_operator
conditions = metadata_filtering_complex_conditions.conditions conditions = metadata_filtering_complex_conditions.conditions
@ -334,16 +334,16 @@ class KnowledgeRetrievalNode(LLMNode):
sub_filter = self._recursive_metadata_filter(sub_condition, filters) sub_filter = self._recursive_metadata_filter(sub_condition, filters)
sub_filters.append(sub_filter) sub_filters.append(sub_filter)
temp_filters = [] temp_filters: list = []
if conditions: if conditions:
for sequence, condition in enumerate(conditions): for sequence, condition in enumerate(conditions):
metadata_name = condition.name metadata_name = condition.name
expected_value = condition.value expected_value = condition.value
if expected_value is not None or condition.comparison_operator in ("empty", "not empty"): if expected_value is not None or condition.comparison_operator in ("empty", "not empty"):
if isinstance(expected_value, str): if isinstance(expected_value, str):
expected_value = self.graph_runtime_state.variable_pool.convert_template( expected_value = self.graph_runtime_state.variable_pool.convert_template(expected_value).value[
expected_value 0
).value[0] ]
if expected_value.value_type == "number": # type: ignore if expected_value.value_type == "number": # type: ignore
expected_value = expected_value.value # type: ignore expected_value = expected_value.value # type: ignore
elif expected_value.value_type == "string": # type: ignore elif expected_value.value_type == "string": # type: ignore
@ -357,20 +357,21 @@ class KnowledgeRetrievalNode(LLMNode):
expected_value, expected_value,
temp_filters, temp_filters,
) )
temp_filters_result: ColumnElement[bool]
if temp_filters: if temp_filters:
if logical_operator == "and": # type: ignore if logical_operator == "and": # type: ignore
temp_filters = and_(*temp_filters) temp_filters_result = and_(*temp_filters)
else: else:
temp_filters = or_(*temp_filters) temp_filters_result = or_(*temp_filters)
filters.append(temp_filters) filters.append(temp_filters_result)
sub_filters_result: ColumnElement[bool]
if sub_filters: if sub_filters:
if logical_operator == "and": # type: ignore if logical_operator == "and": # type: ignore
sub_filters = and_(*sub_filters) sub_filters_result = and_(*sub_filters)
else: else:
sub_filters = or_(*sub_filters) sub_filters_result = or_(*sub_filters)
filters.append(sub_filters) filters.append(sub_filters_result)
return filters return filters
@ -393,21 +394,21 @@ class KnowledgeRetrievalNode(LLMNode):
# Enable forward references # Enable forward references
MetadataFilteringComplexSubCondition.model_rebuild() MetadataFilteringComplexSubCondition.model_rebuild()
metadata_filtering_complex_conditions = MetadataFilteringComplexCondition( metadata_filtering_complex_conditions = MetadataFilteringComplexCondition(
**node_data.metadata_filtering_complex_conditions.model_dump()) **node_data.metadata_filtering_complex_conditions.model_dump()
for sequence, condition in enumerate(metadata_filtering_complex_conditions.conditions): # type: ignore )
for condition in metadata_filtering_complex_conditions.conditions: # type: ignore
filters = self._recursive_metadata_filter(condition, filters) filters = self._recursive_metadata_filter(condition, filters)
if filters: if filters:
if metadata_filtering_complex_conditions.logical_operator == "and": # type: ignore if metadata_filtering_complex_conditions.logical_operator == "and": # type: ignore
document_query = document_query.filter(and_(*filters)) document_query = document_query.filter(and_(*filters))
else: else:
document_query = document_query.filter(or_(*filters)) document_query = document_query.filter(or_(*filters))
metadata_condition = metadata_filtering_complex_conditions
documents = document_query.all() documents = document_query.all()
# group by dataset_id # group by dataset_id
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
for document in documents: for document in documents:
metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore
return metadata_filter_document_ids, metadata_condition return metadata_filter_document_ids, MetadataCondition()
elif node_data.metadata_filtering_mode == "automatic": elif node_data.metadata_filtering_mode == "automatic":
automatic_metadata_filters = self._automatic_metadata_filter_func(dataset_ids, query, node_data) automatic_metadata_filters = self._automatic_metadata_filter_func(dataset_ids, query, node_data)
if automatic_metadata_filters: if automatic_metadata_filters:
@ -435,7 +436,7 @@ class KnowledgeRetrievalNode(LLMNode):
if node_data.metadata_filtering_conditions: if node_data.metadata_filtering_conditions:
metadata_condition = MetadataCondition(**node_data.metadata_filtering_conditions.model_dump()) metadata_condition = MetadataCondition(**node_data.metadata_filtering_conditions.model_dump())
if node_data.metadata_filtering_conditions: if node_data.metadata_filtering_conditions:
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore for sequence, condition in enumerate(metadata_condition.conditions): # type: ignore
metadata_name = condition.name metadata_name = condition.name
expected_value = condition.value expected_value = condition.value
if expected_value is not None or condition.comparison_operator in ("empty", "not empty"): if expected_value is not None or condition.comparison_operator in ("empty", "not empty"):

Loading…
Cancel
Save