pull/18136/merge
hsiong 1 year ago committed by GitHub
commit 9b8104235b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -117,6 +117,12 @@ class MetadataFilteringCondition(BaseModel):
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
class MetadataFilteringComplexCondition(BaseModel):
logical_operator: Optional[Literal["and", "or"]] = "and"
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
sub_conditions: Optional[list["MetadataFilteringComplexCondition"]] = None
class KnowledgeRetrievalNodeData(BaseNodeData):
"""
Knowledge retrieval Node Data.
@ -128,7 +134,8 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
retrieval_mode: Literal["single", "multiple"]
multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None
single_retrieval_config: Optional[SingleRetrievalConfig] = None
metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual", "complex_conditions"]] = "disabled"
metadata_model_config: Optional[ModelConfig] = None
metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
metadata_filtering_complex_conditions: Optional[MetadataFilteringComplexCondition] = None
vision: VisionConfig = Field(default_factory=VisionConfig)

@ -6,7 +6,7 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from sqlalchemy import Integer, and_, func, or_, text
from sqlalchemy import ColumnElement, Integer, and_, func, or_, text
from sqlalchemy import cast as sqlalchemy_cast
from core.app.app_config.entities import DatasetRetrieveConfigEntity
@ -44,7 +44,11 @@ from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog
from models.workflow import WorkflowNodeExecutionStatus
from services.feature_service import FeatureService
from .entities import KnowledgeRetrievalNodeData, ModelConfig
from .entities import (
KnowledgeRetrievalNodeData,
MetadataFilteringComplexCondition,
ModelConfig,
)
from .exc import (
InvalidModelTypeError,
KnowledgeRetrievalNodeError,
@ -316,6 +320,74 @@ class KnowledgeRetrievalNode(LLMNode):
item["metadata"]["position"] = position
return retrieval_resource_list
def _recursive_metadata_filter(
self, metadata_filtering_complex_conditions: MetadataFilteringComplexCondition, filters
):
logical_operator = metadata_filtering_complex_conditions.logical_operator
conditions = metadata_filtering_complex_conditions.conditions
sub_conditions = metadata_filtering_complex_conditions.sub_conditions
sub_filters = []
if sub_conditions:
for sub_condition in sub_conditions:
sub_filter = self._recursive_metadata_filter(sub_condition, [])
sub_filters.extend(sub_filter)
temp_filters: list = []
if conditions:
for sequence, condition in enumerate(conditions):
metadata_name = condition.name
expected_value = condition.value
if expected_value is not None or condition.comparison_operator in ("empty", "not empty"):
if isinstance(expected_value, str):
expected_value = self.graph_runtime_state.variable_pool.convert_template(expected_value).value[
0
]
if expected_value.value_type == "number": # type: ignore
expected_value = expected_value.value # type: ignore
elif expected_value.value_type == "string": # type: ignore
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() # type: ignore
else:
raise ValueError("Invalid expected metadata value type")
temp_filters = self._process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
expected_value,
temp_filters,
)
sub_filters_result: ColumnElement
temp_filters_result: ColumnElement
if temp_filters and sub_filters:
if logical_operator == "and": # type: ignore
all_sub_filters = and_(*sub_filters)
all_temp_filters = and_(*temp_filters)
sub_filters_result = and_(all_temp_filters, all_sub_filters)
else:
all_sub_filters = or_(*sub_filters)
all_temp_filters = or_(*temp_filters)
sub_filters_result = or_(all_sub_filters, all_temp_filters)
filters.append(sub_filters_result)
return filters
if temp_filters: # text
if logical_operator == "and": # type: ignore
temp_filters_result = and_(*temp_filters)
else:
temp_filters_result = or_(*temp_filters)
filters.append(temp_filters_result)
return filters
if sub_filters: # Boolean
if logical_operator == "and": # type: ignore
sub_filters_result = and_(*sub_filters)
else:
sub_filters_result = or_(*sub_filters)
filters.append(sub_filters_result)
return filters
def _get_metadata_filter_condition(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
@ -329,6 +401,23 @@ class KnowledgeRetrievalNode(LLMNode):
metadata_condition = None
if node_data.metadata_filtering_mode == "disabled":
return None, None
elif node_data.metadata_filtering_mode == "complex_conditions":
# todo: do not support external_knowledge_retrieval
if node_data.metadata_filtering_complex_conditions:
# Enable forward references
MetadataFilteringComplexCondition.model_rebuild()
metadata_filtering_complex_conditions = MetadataFilteringComplexCondition(
**node_data.metadata_filtering_complex_conditions.model_dump()
)
filters = self._recursive_metadata_filter(metadata_filtering_complex_conditions, filters)
if filters:
document_query = document_query.filter(*filters)
documents = document_query.all()
# group by dataset_id
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
for document in documents:
metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore
return metadata_filter_document_ids, MetadataCondition()
elif node_data.metadata_filtering_mode == "automatic":
automatic_metadata_filters = self._automatic_metadata_filter_func(dataset_ids, query, node_data)
if automatic_metadata_filters:
@ -358,7 +447,7 @@ class KnowledgeRetrievalNode(LLMNode):
if node_data.metadata_filtering_conditions:
metadata_condition = MetadataCondition(**node_data.metadata_filtering_conditions.model_dump())
if node_data.metadata_filtering_conditions:
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
for sequence, condition in enumerate(metadata_condition.conditions): # type: ignore
metadata_name = condition.name
expected_value = condition.value
if expected_value is not None or condition.comparison_operator in ("empty", "not empty"):

@ -38,6 +38,11 @@ const MetadataFilterSelector = ({
value: t('workflow.nodes.knowledgeRetrieval.metadata.options.manual.title'),
desc: t('workflow.nodes.knowledgeRetrieval.metadata.options.manual.subTitle'),
},
{
key: MetadataFilteringModeEnum.complexConditions,
value: t('workflow.nodes.knowledgeRetrieval.metadata.options.complexConditions.title'),
desc: t('workflow.nodes.knowledgeRetrieval.metadata.options.complexConditions.subTitle'),
},
]
const selectedOption = options.find(option => option.key === value)!

@ -74,6 +74,7 @@ export enum MetadataFilteringModeEnum {
disabled = 'disabled',
automatic = 'automatic',
manual = 'manual',
complexConditions = 'complex_conditions',
}
export enum MetadataFilteringVariableType {

@ -480,6 +480,10 @@ const translation = {
title: 'Manual',
subTitle: 'Manually add metadata filtering conditions',
},
complexConditions: {
title: 'complexConditions',
subTitle: 'Manually add metadata filtering complex conditions',
},
},
panel: {
title: 'Metadata Filter Conditions',

@ -481,6 +481,10 @@ const translation = {
title: '手动',
subTitle: '手动添加元数据过滤条件',
},
complexConditions: {
title: '手动多重条件',
subTitle: '手动添加元数据多重过滤条件',
},
},
panel: {
title: '元数据过滤条件',

Loading…
Cancel
Save