|
|
|
|
@ -1,11 +1,12 @@
|
|
|
|
|
from typing import Any
|
|
|
|
|
from typing import Any, Optional, cast
|
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
|
|
|
|
|
|
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
|
|
|
|
|
from core.rag.datasource.retrieval_service import RetrievalService
|
|
|
|
|
from core.rag.entities.context_entities import DocumentContext
|
|
|
|
|
from core.rag.entities.metadata_entities import MetadataCondition
|
|
|
|
|
from core.rag.models.document import Document as RetrievalDocument
|
|
|
|
|
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
|
|
|
|
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
|
|
|
|
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
|
|
|
|
from extensions.ext_database import db
|
|
|
|
|
@ -34,7 +35,9 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|
|
|
|
args_schema: type[BaseModel] = DatasetRetrieverToolInput
|
|
|
|
|
description: str = "use this to retrieve a dataset. "
|
|
|
|
|
dataset_id: str
|
|
|
|
|
metadata_filtering_conditions: MetadataCondition
|
|
|
|
|
user_id: Optional[str] = None
|
|
|
|
|
retrieve_config: DatasetRetrieveConfigEntity
|
|
|
|
|
inputs: dict
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_dataset(cls, dataset: Dataset, **kwargs):
|
|
|
|
|
@ -48,7 +51,6 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|
|
|
|
tenant_id=dataset.tenant_id,
|
|
|
|
|
dataset_id=dataset.id,
|
|
|
|
|
description=description,
|
|
|
|
|
metadata_filtering_conditions=MetadataCondition(),
|
|
|
|
|
**kwargs,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@ -61,6 +63,21 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|
|
|
|
return ""
|
|
|
|
|
for hit_callback in self.hit_callbacks:
|
|
|
|
|
hit_callback.on_query(query, dataset.id)
|
|
|
|
|
dataset_retrieval = DatasetRetrieval()
|
|
|
|
|
metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition(
|
|
|
|
|
[dataset.id],
|
|
|
|
|
query,
|
|
|
|
|
self.tenant_id,
|
|
|
|
|
self.user_id or "unknown",
|
|
|
|
|
cast(str, self.retrieve_config.metadata_filtering_mode),
|
|
|
|
|
cast(ModelConfig, self.retrieve_config.metadata_model_config),
|
|
|
|
|
self.retrieve_config.metadata_filtering_conditions,
|
|
|
|
|
self.inputs,
|
|
|
|
|
)
|
|
|
|
|
if metadata_filter_document_ids:
|
|
|
|
|
document_ids_filter = metadata_filter_document_ids.get(dataset.id, [])
|
|
|
|
|
else:
|
|
|
|
|
document_ids_filter = None
|
|
|
|
|
if dataset.provider == "external":
|
|
|
|
|
results = []
|
|
|
|
|
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
|
|
|
|
@ -68,7 +85,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|
|
|
|
dataset_id=dataset.id,
|
|
|
|
|
query=query,
|
|
|
|
|
external_retrieval_parameters=dataset.retrieval_model,
|
|
|
|
|
metadata_condition=self.metadata_filtering_conditions,
|
|
|
|
|
metadata_condition=metadata_condition,
|
|
|
|
|
)
|
|
|
|
|
for external_document in external_documents:
|
|
|
|
|
document = RetrievalDocument(
|
|
|
|
|
@ -104,12 +121,18 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|
|
|
|
|
|
|
|
|
return str("\n".join([item.page_content for item in results]))
|
|
|
|
|
else:
|
|
|
|
|
if metadata_condition and not document_ids_filter:
|
|
|
|
|
return ""
|
|
|
|
|
# get retrieval model , if the model is not setting , using default
|
|
|
|
|
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
|
|
|
|
|
if dataset.indexing_technique == "economy":
|
|
|
|
|
# use keyword table query
|
|
|
|
|
documents = RetrievalService.retrieve(
|
|
|
|
|
retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k
|
|
|
|
|
retrieval_method="keyword_search",
|
|
|
|
|
dataset_id=dataset.id,
|
|
|
|
|
query=query,
|
|
|
|
|
top_k=self.top_k,
|
|
|
|
|
document_ids_filter=document_ids_filter,
|
|
|
|
|
)
|
|
|
|
|
return str("\n".join([document.page_content for document in documents]))
|
|
|
|
|
else:
|
|
|
|
|
@ -128,6 +151,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|
|
|
|
else None,
|
|
|
|
|
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
|
|
|
|
|
weights=retrieval_model.get("weights"),
|
|
|
|
|
document_ids_filter=document_ids_filter,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
documents = []
|
|
|
|
|
|