|
|
|
|
@ -6,7 +6,7 @@ from collections import defaultdict
|
|
|
|
|
from collections.abc import Mapping, Sequence
|
|
|
|
|
from typing import Any, Optional, cast
|
|
|
|
|
|
|
|
|
|
from sqlalchemy import Integer, and_, func, or_, text
|
|
|
|
|
from sqlalchemy import ColumnElement, Integer, and_, func, or_, text
|
|
|
|
|
from sqlalchemy import cast as sqlalchemy_cast
|
|
|
|
|
|
|
|
|
|
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
|
|
|
|
@ -334,16 +334,16 @@ class KnowledgeRetrievalNode(LLMNode):
|
|
|
|
|
sub_filter = self._recursive_metadata_filter(sub_condition, filters)
|
|
|
|
|
sub_filters.append(sub_filter)
|
|
|
|
|
|
|
|
|
|
temp_filters = []
|
|
|
|
|
temp_filters: list = []
|
|
|
|
|
if conditions:
|
|
|
|
|
for sequence, condition in enumerate(conditions):
|
|
|
|
|
metadata_name = condition.name
|
|
|
|
|
expected_value = condition.value
|
|
|
|
|
if expected_value is not None or condition.comparison_operator in ("empty", "not empty"):
|
|
|
|
|
if isinstance(expected_value, str):
|
|
|
|
|
expected_value = self.graph_runtime_state.variable_pool.convert_template(
|
|
|
|
|
expected_value
|
|
|
|
|
).value[0]
|
|
|
|
|
expected_value = self.graph_runtime_state.variable_pool.convert_template(expected_value).value[
|
|
|
|
|
0
|
|
|
|
|
]
|
|
|
|
|
if expected_value.value_type == "number": # type: ignore
|
|
|
|
|
expected_value = expected_value.value # type: ignore
|
|
|
|
|
elif expected_value.value_type == "string": # type: ignore
|
|
|
|
|
@ -357,20 +357,21 @@ class KnowledgeRetrievalNode(LLMNode):
|
|
|
|
|
expected_value,
|
|
|
|
|
temp_filters,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
temp_filters_result: ColumnElement[bool]
|
|
|
|
|
if temp_filters:
|
|
|
|
|
if logical_operator == "and": # type: ignore
|
|
|
|
|
temp_filters = and_(*temp_filters)
|
|
|
|
|
temp_filters_result = and_(*temp_filters)
|
|
|
|
|
else:
|
|
|
|
|
temp_filters = or_(*temp_filters)
|
|
|
|
|
filters.append(temp_filters)
|
|
|
|
|
temp_filters_result = or_(*temp_filters)
|
|
|
|
|
filters.append(temp_filters_result)
|
|
|
|
|
|
|
|
|
|
sub_filters_result: ColumnElement[bool]
|
|
|
|
|
if sub_filters:
|
|
|
|
|
if logical_operator == "and": # type: ignore
|
|
|
|
|
sub_filters = and_(*sub_filters)
|
|
|
|
|
sub_filters_result = and_(*sub_filters)
|
|
|
|
|
else:
|
|
|
|
|
sub_filters = or_(*sub_filters)
|
|
|
|
|
filters.append(sub_filters)
|
|
|
|
|
sub_filters_result = or_(*sub_filters)
|
|
|
|
|
filters.append(sub_filters_result)
|
|
|
|
|
|
|
|
|
|
return filters
|
|
|
|
|
|
|
|
|
|
@ -393,21 +394,21 @@ class KnowledgeRetrievalNode(LLMNode):
|
|
|
|
|
# Enable forward references
|
|
|
|
|
MetadataFilteringComplexSubCondition.model_rebuild()
|
|
|
|
|
metadata_filtering_complex_conditions = MetadataFilteringComplexCondition(
|
|
|
|
|
**node_data.metadata_filtering_complex_conditions.model_dump())
|
|
|
|
|
for sequence, condition in enumerate(metadata_filtering_complex_conditions.conditions): # type: ignore
|
|
|
|
|
**node_data.metadata_filtering_complex_conditions.model_dump()
|
|
|
|
|
)
|
|
|
|
|
for condition in metadata_filtering_complex_conditions.conditions: # type: ignore
|
|
|
|
|
filters = self._recursive_metadata_filter(condition, filters)
|
|
|
|
|
if filters:
|
|
|
|
|
if metadata_filtering_complex_conditions.logical_operator == "and": # type: ignore
|
|
|
|
|
document_query = document_query.filter(and_(*filters))
|
|
|
|
|
else:
|
|
|
|
|
document_query = document_query.filter(or_(*filters))
|
|
|
|
|
metadata_condition = metadata_filtering_complex_conditions
|
|
|
|
|
documents = document_query.all()
|
|
|
|
|
# group by dataset_id
|
|
|
|
|
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
|
|
|
|
|
for document in documents:
|
|
|
|
|
metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore
|
|
|
|
|
return metadata_filter_document_ids, metadata_condition
|
|
|
|
|
return metadata_filter_document_ids, MetadataCondition()
|
|
|
|
|
elif node_data.metadata_filtering_mode == "automatic":
|
|
|
|
|
automatic_metadata_filters = self._automatic_metadata_filter_func(dataset_ids, query, node_data)
|
|
|
|
|
if automatic_metadata_filters:
|
|
|
|
|
@ -435,7 +436,7 @@ class KnowledgeRetrievalNode(LLMNode):
|
|
|
|
|
if node_data.metadata_filtering_conditions:
|
|
|
|
|
metadata_condition = MetadataCondition(**node_data.metadata_filtering_conditions.model_dump())
|
|
|
|
|
if node_data.metadata_filtering_conditions:
|
|
|
|
|
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
|
|
|
|
|
for sequence, condition in enumerate(metadata_condition.conditions): # type: ignore
|
|
|
|
|
metadata_name = condition.name
|
|
|
|
|
expected_value = condition.value
|
|
|
|
|
if expected_value is not None or condition.comparison_operator in ("empty", "not empty"):
|
|
|
|
|
|