|
|
|
|
@ -4,6 +4,7 @@ from pydantic import BaseModel, Field
|
|
|
|
|
|
|
|
|
|
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
|
|
|
|
|
from core.rag.datasource.retrieval_service import RetrievalService
|
|
|
|
|
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
|
|
|
|
from core.rag.entities.context_entities import DocumentContext
|
|
|
|
|
from core.rag.models.document import Document as RetrievalDocument
|
|
|
|
|
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
|
|
|
|
@ -14,7 +15,7 @@ from models.dataset import Dataset
|
|
|
|
|
from models.dataset import Document as DatasetDocument
|
|
|
|
|
from services.external_knowledge_service import ExternalDatasetService
|
|
|
|
|
|
|
|
|
|
default_retrieval_model = {
|
|
|
|
|
default_retrieval_model: dict[str, Any] = {
|
|
|
|
|
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
|
|
|
|
"reranking_enable": False,
|
|
|
|
|
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
|
|
|
|
@ -79,7 +80,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|
|
|
|
else:
|
|
|
|
|
document_ids_filter = None
|
|
|
|
|
if dataset.provider == "external":
|
|
|
|
|
results = []
|
|
|
|
|
results: list[RetrievalDocument] = []
|
|
|
|
|
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
|
|
|
|
tenant_id=dataset.tenant_id,
|
|
|
|
|
dataset_id=dataset.id,
|
|
|
|
|
@ -100,21 +101,21 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|
|
|
|
document.metadata["dataset_name"] = dataset.name
|
|
|
|
|
results.append(document)
|
|
|
|
|
# deal with external documents
|
|
|
|
|
context_list = []
|
|
|
|
|
context_list: list[RetrievalSourceMetadata] = []
|
|
|
|
|
for position, item in enumerate(results, start=1):
|
|
|
|
|
if item.metadata is not None:
|
|
|
|
|
source = {
|
|
|
|
|
"position": position,
|
|
|
|
|
"dataset_id": item.metadata.get("dataset_id"),
|
|
|
|
|
"dataset_name": item.metadata.get("dataset_name"),
|
|
|
|
|
"document_id": item.metadata.get("document_id") or item.metadata.get("title"),
|
|
|
|
|
"document_name": item.metadata.get("title"),
|
|
|
|
|
"data_source_type": "external",
|
|
|
|
|
"retriever_from": self.retriever_from,
|
|
|
|
|
"score": item.metadata.get("score"),
|
|
|
|
|
"title": item.metadata.get("title"),
|
|
|
|
|
"content": item.page_content,
|
|
|
|
|
}
|
|
|
|
|
source = RetrievalSourceMetadata(
|
|
|
|
|
position=position,
|
|
|
|
|
dataset_id=item.metadata.get("dataset_id"),
|
|
|
|
|
dataset_name=item.metadata.get("dataset_name"),
|
|
|
|
|
document_id=item.metadata.get("document_id") or item.metadata.get("title"),
|
|
|
|
|
document_name=item.metadata.get("title"),
|
|
|
|
|
data_source_type="external",
|
|
|
|
|
retriever_from=self.retriever_from,
|
|
|
|
|
score=item.metadata.get("score"),
|
|
|
|
|
title=item.metadata.get("title"),
|
|
|
|
|
content=item.page_content,
|
|
|
|
|
)
|
|
|
|
|
context_list.append(source)
|
|
|
|
|
for hit_callback in self.hit_callbacks:
|
|
|
|
|
hit_callback.return_retriever_resource_info(context_list)
|
|
|
|
|
@ -125,7 +126,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|
|
|
|
return ""
|
|
|
|
|
# get retrieval model , if the model is not setting , using default
|
|
|
|
|
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
|
|
|
|
|
retrieval_resource_list = []
|
|
|
|
|
retrieval_resource_list: list[RetrievalSourceMetadata] = []
|
|
|
|
|
if dataset.indexing_technique == "economy":
|
|
|
|
|
# use keyword table query
|
|
|
|
|
documents = RetrievalService.retrieve(
|
|
|
|
|
@ -163,7 +164,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|
|
|
|
for item in documents:
|
|
|
|
|
if item.metadata is not None and item.metadata.get("score"):
|
|
|
|
|
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
|
|
|
|
|
document_context_list = []
|
|
|
|
|
document_context_list: list[DocumentContext] = []
|
|
|
|
|
records = RetrievalService.format_retrieval_documents(documents)
|
|
|
|
|
if records:
|
|
|
|
|
for record in records:
|
|
|
|
|
@ -197,37 +198,37 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|
|
|
|
.first()
|
|
|
|
|
)
|
|
|
|
|
if dataset and document:
|
|
|
|
|
source = {
|
|
|
|
|
"dataset_id": dataset.id,
|
|
|
|
|
"dataset_name": dataset.name,
|
|
|
|
|
"document_id": document.id, # type: ignore
|
|
|
|
|
"document_name": document.name, # type: ignore
|
|
|
|
|
"data_source_type": document.data_source_type, # type: ignore
|
|
|
|
|
"segment_id": segment.id,
|
|
|
|
|
"retriever_from": self.retriever_from,
|
|
|
|
|
"score": record.score or 0.0,
|
|
|
|
|
"doc_metadata": document.doc_metadata, # type: ignore
|
|
|
|
|
}
|
|
|
|
|
source = RetrievalSourceMetadata(
|
|
|
|
|
dataset_id=dataset.id,
|
|
|
|
|
dataset_name=dataset.name,
|
|
|
|
|
document_id=document.id, # type: ignore
|
|
|
|
|
document_name=document.name, # type: ignore
|
|
|
|
|
data_source_type=document.data_source_type, # type: ignore
|
|
|
|
|
segment_id=segment.id,
|
|
|
|
|
retriever_from=self.retriever_from,
|
|
|
|
|
score=record.score or 0.0,
|
|
|
|
|
doc_metadata=document.doc_metadata, # type: ignore
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if self.retriever_from == "dev":
|
|
|
|
|
source["hit_count"] = segment.hit_count
|
|
|
|
|
source["word_count"] = segment.word_count
|
|
|
|
|
source["segment_position"] = segment.position
|
|
|
|
|
source["index_node_hash"] = segment.index_node_hash
|
|
|
|
|
source.hit_count = segment.hit_count
|
|
|
|
|
source.word_count = segment.word_count
|
|
|
|
|
source.segment_position = segment.position
|
|
|
|
|
source.index_node_hash = segment.index_node_hash
|
|
|
|
|
if segment.answer:
|
|
|
|
|
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
|
|
|
|
|
source.content = f"question:{segment.content} \nanswer:{segment.answer}"
|
|
|
|
|
else:
|
|
|
|
|
source["content"] = segment.content
|
|
|
|
|
source.content = segment.content
|
|
|
|
|
retrieval_resource_list.append(source)
|
|
|
|
|
|
|
|
|
|
if self.return_resource and retrieval_resource_list:
|
|
|
|
|
retrieval_resource_list = sorted(
|
|
|
|
|
retrieval_resource_list,
|
|
|
|
|
key=lambda x: x.get("score") or 0.0,
|
|
|
|
|
key=lambda x: x.score or 0.0,
|
|
|
|
|
reverse=True,
|
|
|
|
|
)
|
|
|
|
|
for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore
|
|
|
|
|
item["position"] = position # type: ignore
|
|
|
|
|
item.position = position # type: ignore
|
|
|
|
|
for hit_callback in self.hit_callbacks:
|
|
|
|
|
hit_callback.return_retriever_resource_info(retrieval_resource_list)
|
|
|
|
|
if document_context_list:
|
|
|
|
|
|