You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
gcgj-dify-1.7.0/api/core/rag/datasource/retrieval_service.py

524 lines
22 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
from flask import Flask, current_app
from sqlalchemy.orm import load_only
from configs import dify_config
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.embedding.retrieval import RetrievalSegments
from core.rag.entities.metadata_entities import MetadataCondition
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from models.dataset import ChildChunk, Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService
# Default retrieval settings: semantic search, reranking disabled (empty
# provider/model names), top 2 results, and score-threshold filtering off.
# NOTE(review): not referenced by the code visible in this file; presumably
# imported by callers as a fallback configuration — confirm before changing.
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"score_threshold_enabled": False,
}
class RetrievalService:
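# NOTE(review): an earlier comment here mentioned caching precompiled regular expressions, but no regex is defined in the visible class and `re` is not imported — the remark appears stale.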
@classmethod
def retrieve(
    cls,
    retrieval_method: str,
    dataset_id: str,
    query: str,
    top_k: int,
    score_threshold: Optional[float] = 0.0,
    reranking_model: Optional[dict] = None,
    reranking_mode: str = "reranking_model",
    weights: Optional[dict] = None,
    document_ids_filter: Optional[list[str]] = None,
):
    """Run the searches selected by ``retrieval_method`` and return the hits.

    Keyword, embedding and full-text searches are submitted to a thread
    pool; each worker appends its hits to a shared list and records any
    failure as a string in a shared error list.

    Returns:
        An empty list when ``query`` is empty or the dataset is not found;
        otherwise the collected documents, passed through
        ``DataPostProcessor`` when the method is hybrid search.

    Raises:
        ValueError: if any worker recorded an error; the message is the
            recorded errors joined with ``";\\n"``.
    """
    if not query:
        return []
    dataset = cls._get_dataset(dataset_id)
    if not dataset:
        return []

    all_documents: list[Document] = []
    exceptions: list[str] = []

    # Arguments only the vector-backed searches accept.
    rerank_kwargs = {
        "score_threshold": score_threshold,
        "reranking_model": reranking_model,
        "retrieval_method": retrieval_method,
    }

    # (worker, extra keyword arguments) for every search that applies.
    searches = []
    if retrieval_method == "keyword_search":
        searches.append((cls.keyword_search, {}))
    if RetrievalMethod.is_support_semantic_search(retrieval_method):
        searches.append((cls.embedding_search, rerank_kwargs))
    if RetrievalMethod.is_support_fulltext_search(retrieval_method):
        searches.append((cls.full_text_index_search, rerank_kwargs))

    with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor:  # type: ignore
        pending = [
            executor.submit(
                search,
                # Workers need the real app object to open their own app context.
                flask_app=current_app._get_current_object(),  # type: ignore
                dataset_id=dataset_id,
                query=query,
                top_k=top_k,
                all_documents=all_documents,
                exceptions=exceptions,
                document_ids_filter=document_ids_filter,
                **extra_kwargs,
            )
            for search, extra_kwargs in searches
        ]
        # The result of wait() is not inspected; leaving the ``with`` block
        # still joins every submitted worker.
        concurrent.futures.wait(pending, timeout=30, return_when=concurrent.futures.ALL_COMPLETED)

    if exceptions:
        raise ValueError(";\n".join(exceptions))

    if retrieval_method != RetrievalMethod.HYBRID_SEARCH.value:
        return all_documents

    # Hybrid search: post-process the combined hits from both searches.
    post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_mode, reranking_model, weights, False)
    return post_processor.invoke(
        query=query,
        documents=all_documents,
        score_threshold=score_threshold,
        top_n=top_k,
    )
@classmethod
def external_retrieve(
    cls,
    dataset_id: str,
    query: str,
    external_retrieval_model: Optional[dict] = None,
    metadata_filtering_conditions: Optional[dict] = None,
):
    """Fetch documents for ``query`` through ``ExternalDatasetService``.

    Returns an empty list when the dataset does not exist. When
    ``metadata_filtering_conditions`` is truthy it is expanded into a
    ``MetadataCondition``; otherwise no metadata condition is passed.
    """
    dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
    if not dataset:
        return []

    metadata_condition = None
    if metadata_filtering_conditions:
        metadata_condition = MetadataCondition(**metadata_filtering_conditions)

    return ExternalDatasetService.fetch_external_knowledge_retrieval(
        dataset.tenant_id,
        dataset_id,
        query,
        external_retrieval_model or {},
        metadata_condition=metadata_condition,
    )
@classmethod
def _get_dataset(cls, dataset_id: str) -> Optional[Dataset]:
    """Return the first ``Dataset`` row whose id equals ``dataset_id``, or ``None``."""
    matching = db.session.query(Dataset).filter(Dataset.id == dataset_id)
    return matching.first()
@classmethod
def keyword_search(
    cls,
    flask_app: Flask,
    dataset_id: str,
    query: str,
    top_k: int,
    all_documents: list,
    exceptions: list,
    document_ids_filter: Optional[list[str]] = None,
):
    """Keyword-search the dataset and append the hits to ``all_documents``.

    Runs inside ``flask_app``'s application context. Nothing is raised to
    the caller: any failure, including a missing dataset, is recorded as a
    string in ``exceptions``.
    """
    with flask_app.app_context():
        try:
            dataset = cls._get_dataset(dataset_id)
            if not dataset:
                # Same text the caller would see from a raised ValueError.
                exceptions.append("dataset not found")
                return

            keyword_index = Keyword(dataset=dataset)
            hits = keyword_index.search(
                cls.escape_query_for_search(query),
                top_k=top_k,
                document_ids_filter=document_ids_filter,
            )
            all_documents.extend(hits)
        except Exception as err:
            exceptions.append(str(err))
@classmethod
def embedding_search(
    cls,
    flask_app: Flask,
    dataset_id: str,
    query: str,
    top_k: int,
    score_threshold: Optional[float],
    reranking_model: Optional[dict],
    all_documents: list,
    retrieval_method: str,
    exceptions: list,
    document_ids_filter: Optional[list[str]] = None,
):
    """Vector-search the dataset and append the hits to ``all_documents``.

    Hits are reranked through ``DataPostProcessor`` only when a reranking
    model name and provider name are both configured and the retrieval
    method is plain semantic search. Runs inside ``flask_app``'s
    application context; failures are recorded as strings in
    ``exceptions`` instead of being raised.
    """
    with flask_app.app_context():
        try:
            dataset = cls._get_dataset(dataset_id)
            if not dataset:
                # Same text the caller would see from a raised ValueError.
                exceptions.append("dataset not found")
                return

            hits = Vector(dataset=dataset).search_by_vector(
                query,
                search_type="similarity_score_threshold",
                top_k=top_k,
                score_threshold=score_threshold,
                filter={"group_id": [dataset.id]},
                document_ids_filter=document_ids_filter,
            )
            if not hits:
                return

            should_rerank = bool(
                reranking_model
                and reranking_model.get("reranking_model_name")
                and reranking_model.get("reranking_provider_name")
                and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value
            )
            if should_rerank:
                reranker = DataPostProcessor(
                    str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False
                )
                # Rerank every hit; top_n keeps the result count unchanged.
                hits = reranker.invoke(
                    query=query,
                    documents=hits,
                    score_threshold=score_threshold,
                    top_n=len(hits),
                )
            all_documents.extend(hits)
        except Exception as err:
            exceptions.append(str(err))
@classmethod
def full_text_index_search(
    cls,
    flask_app: Flask,
    dataset_id: str,
    query: str,
    top_k: int,
    score_threshold: Optional[float],
    reranking_model: Optional[dict],
    all_documents: list,
    retrieval_method: str,
    exceptions: list,
    document_ids_filter: Optional[list[str]] = None,
):
    """Full-text-search the dataset and append the hits to ``all_documents``.

    The query is escaped with ``escape_query_for_search`` first. Hits are
    reranked through ``DataPostProcessor`` only when a reranking model name
    and provider name are both configured and the retrieval method is plain
    full-text search. Runs inside ``flask_app``'s application context;
    failures are recorded as strings in ``exceptions`` instead of being
    raised.
    """
    with flask_app.app_context():
        try:
            dataset = cls._get_dataset(dataset_id)
            if not dataset:
                # Same text the caller would see from a raised ValueError.
                exceptions.append("dataset not found")
                return

            index = Vector(dataset=dataset)
            hits = index.search_by_full_text(
                cls.escape_query_for_search(query),
                top_k=top_k,
                document_ids_filter=document_ids_filter,
            )
            if not hits:
                return

            should_rerank = bool(
                reranking_model
                and reranking_model.get("reranking_model_name")
                and reranking_model.get("reranking_provider_name")
                and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value
            )
            if should_rerank:
                reranker = DataPostProcessor(
                    str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False
                )
                # Rerank every hit; top_n keeps the result count unchanged.
                hits = reranker.invoke(
                    query=query,
                    documents=hits,
                    score_threshold=score_threshold,
                    top_n=len(hits),
                )
            all_documents.extend(hits)
        except Exception as err:
            exceptions.append(str(err))
@staticmethod
def escape_query_for_search(query: str) -> str:
    """Return ``query`` with every double quote prefixed by a backslash."""
    quote_escapes = str.maketrans({'"': '\\"'})
    return query.translate(quote_escapes)
@classmethod
def format_retrieval_documents(cls, documents: list[Document]) -> list[RetrievalSegments]:
    """Resolve retrieved index documents into ``RetrievalSegments``.

    For each document the owning dataset document is looked up (one batched
    query). Parent-child documents are resolved through their ``ChildChunk``
    to the parent segment, and all child hits of one segment are grouped
    under it with the segment scored by the best child score. Other
    documents are resolved directly by ``index_node_id``. Only segments
    that are enabled and have status ``"completed"`` are kept.

    ``append_next_segments`` is then run once over the collected records;
    for non-parent-child records it may rewrite ``segment.content`` in
    place with the following segment's text merged in.

    Returns:
        One ``RetrievalSegments`` per kept segment, in first-hit order; an
        empty list when ``documents`` is empty or carries no document ids.

    Raises:
        Any exception from the lookups is re-raised after the database
        session has been rolled back.
    """
    if not documents:
        return []

    try:
        # Collect document IDs
        document_ids = {doc.metadata.get("document_id") for doc in documents if "document_id" in doc.metadata}
        if not document_ids:
            return []

        # Batch query dataset documents
        dataset_documents = {
            doc.id: doc
            for doc in db.session.query(DatasetDocument)
            .filter(DatasetDocument.id.in_(document_ids))
            .options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id))
            .all()
        }

        records = []
        include_segment_ids = set()
        segment_child_map = {}

        for document in documents:
            dataset_document = dataset_documents.get(document.metadata.get("document_id"))
            if not dataset_document:
                continue

            if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
                # Parent-child: the hit is a child chunk; resolve its parent segment.
                child_index_node_id = document.metadata.get("doc_id")
                child_chunk = (
                    db.session.query(ChildChunk).filter(ChildChunk.index_node_id == child_index_node_id).first()
                )
                if not child_chunk:
                    continue

                segment = (
                    db.session.query(DocumentSegment)
                    .filter(
                        DocumentSegment.dataset_id == dataset_document.dataset_id,
                        DocumentSegment.enabled == True,  # noqa: E712 - SQL expression, not a Python comparison
                        DocumentSegment.status == "completed",
                        DocumentSegment.id == child_chunk.segment_id,
                    )
                    .options(
                        load_only(
                            DocumentSegment.id,
                            DocumentSegment.content,
                            DocumentSegment.answer,
                        )
                    )
                    .first()
                )
                if not segment:
                    continue

                child_score = document.metadata.get("score", 0.0)
                child_chunk_detail = {
                    "id": child_chunk.id,
                    "content": child_chunk.content,
                    "position": child_chunk.position,
                    "score": child_score,
                }
                if segment.id not in include_segment_ids:
                    # First hit for this segment: start its child group and record it.
                    include_segment_ids.add(segment.id)
                    segment_child_map[segment.id] = {
                        "max_score": child_score,
                        "child_chunks": [child_chunk_detail],
                    }
                    records.append({"segment": segment})
                else:
                    # Later hit: add the child and keep the best score.
                    child_group = segment_child_map[segment.id]
                    child_group["child_chunks"].append(child_chunk_detail)
                    child_group["max_score"] = max(child_group["max_score"], child_score)
            else:
                # Normal documents: the hit maps directly to a segment.
                index_node_id = document.metadata.get("doc_id")
                if not index_node_id:
                    continue

                segment = (
                    db.session.query(DocumentSegment)
                    .filter(
                        DocumentSegment.dataset_id == dataset_document.dataset_id,
                        DocumentSegment.enabled == True,  # noqa: E712 - SQL expression, not a Python comparison
                        DocumentSegment.status == "completed",
                        DocumentSegment.index_node_id == index_node_id,
                    )
                    .first()
                )
                if not segment:
                    continue

                include_segment_ids.add(segment.id)
                records.append(
                    {
                        "segment": segment,
                        "score": document.metadata.get("score"),  # type: ignore
                    }
                )

        # Add child chunks information to records
        for record in records:
            child_group = segment_child_map.get(record["segment"].id)
            if child_group is not None:
                record["child_chunks"] = child_group.get("child_chunks")  # type: ignore
                record["score"] = child_group["max_score"]

        # Merge following-segment text once for the whole batch. This used to
        # run inside the loop below, repeating the same queries and the same
        # (idempotent) merge once per record.
        if records:
            cls.append_next_segments(records=records, dataset_documents=dataset_documents)

        result = []
        for record in records:
            # Child chunks must be a list or None.
            child_chunks = record.get("child_chunks")
            if not isinstance(child_chunks, list):
                child_chunks = None

            # Score must be a float or None.
            score_value = record.get("score")
            score = (
                float(score_value)
                if score_value is not None and isinstance(score_value, int | float | str)
                else None
            )

            result.append(RetrievalSegments(segment=record["segment"], child_chunks=child_chunks, score=score))

        return result
    except Exception:
        db.session.rollback()
        # Bare raise keeps the original traceback intact.
        raise
@classmethod
def append_next_segments(cls, records: list[dict], dataset_documents: dict):
    """Pass the non-parent-child records on to ``set_next_segments``.

    A record is kept when its segment's ``document_id`` maps to a truthy
    entry in ``dataset_documents`` whose ``doc_form`` is not
    ``IndexType.PARENT_CHILD_INDEX``. ``set_next_segments`` is always
    called, even when no record qualifies.
    """
    eligible = []
    for record in records:
        owner = dataset_documents.get(record["segment"].document_id)
        if owner and owner.doc_form != IndexType.PARENT_CHILD_INDEX:
            eligible.append(record)
    cls.set_next_segments(records=eligible)
# Load sibling segments for the given records and merge following-segment text.
@classmethod
def set_next_segments(cls, records: list[dict]):
    """Load every segment of the records' documents and hand off to the merge step.

    Does nothing when ``records`` is empty. Otherwise all ``DocumentSegment``
    rows belonging to the records' documents are queried, grouped by
    ``document_id``, and passed to ``merged_next_segment_content`` together
    with the ids of the segments already present in ``records``.
    """
    # Read (document_id, segment_id) per record, in record order.
    id_pairs = [(record["segment"].document_id, record["segment"].id) for record in records]
    if not id_pairs:
        return

    document_ids = [document_id for document_id, _ in id_pairs]
    doc_segment_ids = [segment_id for _, segment_id in id_pairs]

    # Fetch all segments of the involved documents.
    sibling_segments = (
        db.session.query(DocumentSegment).filter(DocumentSegment.document_id.in_(document_ids)).all()
    )

    segments_by_document = {}
    for sibling in sibling_segments:
        segments_by_document.setdefault(sibling.document_id, []).append(sibling)

    cls.merged_next_segment_content(
        records=records,
        document_segment_data=segments_by_document,
        doc_segment_ids=doc_segment_ids,
    )
@classmethod
def merged_next_segment_content(cls,records: list[dict],document_segment_data: dict,doc_segment_ids: list) :
# Extend the content of the highest-scoring records with the text of the
# segment that follows each of them in the same document.
#
# records: dicts read via record["segment"] and record["score"].
# document_segment_data: indexed by document_id; each value is passed to
#   get_next_segment as the list of that document's segments.
# doc_segment_ids: ids treated as already present; next-segment ids are
#   appended to this list in place.
#
# NOTE(review): indentation was lost in this copy of the file, so the exact
# nesting of the trailing `index -= 1` (per successful merge, per candidate,
# or per record) cannot be confirmed here — verify against the original.
# NOTE(review): sorted() compares the "score" values directly; a None score
# would raise TypeError — confirm callers always supply comparable scores.
# NOTE(review): record["segment"].content is assigned in place; if the
# segment is a session-tracked ORM object this may be persisted on the next
# flush/commit — confirm this is intended.
#
# Sort by score in descending order.
sorted_records = sorted(records, key=lambda r: r["score"], reverse=True)
# Only the top three by score are processed; when a record's next segment is
# already present, move on to the following record until three are done.
index = 3
for record in sorted_records:
if index == 0:
break
document_id = record["segment"].document_id
doc_segment_id = record["segment"].id
content = record["segment"].content
document_segments = document_segment_data[document_id]
# Get the next segment.
next_segment = cls.get_next_segment(doc_segment_id=doc_segment_id,document_segments=document_segments)
if next_segment and next_segment.id not in doc_segment_ids:
merged_string, merged = cls.merged_text(content, next_segment.content)
doc_segment_ids.append(next_segment.id)
if merged:
record["segment"].content = merged_string
index -= 1
@classmethod
def merged_text(cls, text, target_text, min_overlap: int = 10) -> tuple[str, bool]:
    """Join ``text`` and ``target_text`` on their overlapping boundary.

    Looks for the longest suffix of ``text`` that equals a prefix of
    ``target_text``. When that overlap is strictly longer than
    ``min_overlap`` characters, the two strings are joined with the
    overlap kept once.

    Args:
        text: Leading string.
        target_text: String expected to continue ``text``.
        min_overlap: The overlap must exceed this length to merge.

    Returns:
        ``(merged_string, True)`` when a merge happened, otherwise
        ``(text, False)``. Empty or ``None`` inputs are never merged.
    """
    if not text or not target_text:
        return text, False

    longest_possible = min(len(text), len(target_text))
    # Try the longest overlap first so the first match is the maximum;
    # lengths of ``min_overlap`` or less are never tried.
    for overlap in range(longest_possible, min_overlap, -1):
        if text.endswith(target_text[:overlap]):
            return text + target_text[overlap:], True
    return text, False
@classmethod
def get_next_segment(cls, doc_segment_id, document_segments: "list[DocumentSegment]") -> "Optional[DocumentSegment]":
    """Return the segment positioned directly after ``doc_segment_id``.

    Args:
        doc_segment_id: Id of the current segment.
        document_segments: Candidate segments; each is read via ``.id``
            and ``.position``.

    Returns:
        The first segment whose ``position`` is the current segment's
        ``position + 1``, or ``None`` when the list is empty, the current
        segment is not in it, it has no position, or no follower exists.
    """
    if not document_segments:
        return None

    current_position = None
    for segment in document_segments:
        if segment.id == doc_segment_id:
            current_position = segment.position

    # An unknown id must not fall back to a sentinel position: the old -1
    # default made a segment at position 0 look like the "next" one.
    if current_position is None:
        return None

    wanted_position = current_position + 1
    return next(
        (segment for segment in document_segments if segment.position == wanted_position),
        None,
    )