【Dify】 增加全文检索的接口

pull/22121/head
liuchangsheng@wisdomidata.com 11 months ago
parent d988bf7ec8
commit 76753ce0d1

@ -6,6 +6,7 @@ from unstructured.utils import first
from controllers.console import api from controllers.console import api
from controllers.console.wraps import setup_required from controllers.console.wraps import setup_required
from services.ext.account_ext_service import AccountExtService, TenantExtService from services.ext.account_ext_service import AccountExtService, TenantExtService
from services.ext.dataset_ext_service import DocumentExtService
from models.account import ( from models.account import (
Account, Account,
Tenant, Tenant,
@ -81,7 +82,35 @@ class TenantInitApi(Resource):
tenant_data = TenantExtService.init_tenant(target_tenant_id=target_tenant_id,target_tenant_name=target_tenant_name) tenant_data = TenantExtService.init_tenant(target_tenant_id=target_tenant_id,target_tenant_name=target_tenant_name)
return tenant_data.to_dict(),200 return tenant_data.to_dict(),200
class FullSearchTextApi(Resource):
    """REST resource that runs a text search over a tenant's datasets."""

    @setup_required
    def post(self):
        """Parse the search request and delegate to ``DocumentExtService``.

        Request arguments:
            dataset_names: repeated/list argument collected with ``append``.
            query_text: required string read from the JSON body.

        Returns:
            Whatever ``DocumentExtService.get_full_search_data`` returns for
            the current user's tenant.
        """
        request_parser = reqparse.RequestParser()
        request_parser.add_argument("dataset_names", action="append", help="List of names")
        request_parser.add_argument("query_text", type=str, required=True, location="json")
        payload = request_parser.parse_args()

        # The search is always scoped to the tenant of the logged-in user.
        workspace = flask_login.current_user.current_tenant
        return DocumentExtService.get_full_search_data(
            dataset_names=payload.dataset_names,
            tenant_id=workspace.id,
            query_text=payload.query_text,
        )
api.add_resource(AccountsApi, "/accounts/update") api.add_resource(AccountsApi, "/accounts/update")
api.add_resource(TenantEnableApi, "/tenant/enable") api.add_resource(TenantEnableApi, "/tenant/enable")
api.add_resource(TenantInitApi, "/tenant/init") api.add_resource(TenantInitApi, "/tenant/init")
api.add_resource(LoginAccountsApi, "/login/account/info") api.add_resource(LoginAccountsApi, "/login/account/info")
api.add_resource(FullSearchTextApi, "/full/search")

@ -1,6 +1,6 @@
from models import ApiToken, Account, Tenant from models import ApiToken, Account, Tenant
from models.dataset import ( from models.dataset import (
Dataset,DocumentSegment Dataset,DocumentSegment,Document
) )
from core.rag.models.document import Document as DocumentModel from core.rag.models.document import Document as DocumentModel
from core.errors.error import ( from core.errors.error import (
@ -18,6 +18,9 @@ from extensions.ext_database import db
from services.dataset_service import DatasetService, DocumentService from services.dataset_service import DatasetService, DocumentService
from configs.ext_config import get_init_knowledge_config,get_init_full_text_knowledge_config from configs.ext_config import get_init_knowledge_config,get_init_full_text_knowledge_config
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
from sqlalchemy import text, bindparam,select,func
from collections import defaultdict
from sqlalchemy.dialects.postgresql import ARRAY
class DatasetExtService: class DatasetExtService:
resource_type = "dataset" resource_type = "dataset"
@ -198,3 +201,71 @@ class DocumentExtService:
break break
return next_segment return next_segment
def get_full_search_data(dataset_names: list[str], tenant_id: str, query_text: str):
    """Search the named datasets of a tenant and format the hits.

    Args:
        dataset_names: Names of the datasets to search. ``None`` or an empty
            list produces an empty result.
        tenant_id: Only datasets belonging to this tenant are considered.
        query_text: Text to search for; echoed back in every hit.

    Returns:
        A list of dicts with the keys ``"title"`` (document name),
        ``"content"`` (segment content) and ``"query"`` (the search text).
    """
    # Callers may omit the dataset names entirely; with nothing to filter on
    # no dataset can match, so skip the database round-trips.
    if not dataset_names:
        return []

    datasets = (
        db.session.query(Dataset)
        .filter(Dataset.name.in_(dataset_names), Dataset.tenant_id == tenant_id)
        .all()
    )
    dataset_ids = [dataset.id for dataset in datasets]
    if not dataset_ids:
        # None of the requested names exist for this tenant.
        return []

    # Segments (one per document) that match the query text.
    fetch_segments = DocumentExtService.get_full_search_segments(
        dataset_ids=dataset_ids,
        query_text=query_text,
    )
    return [
        {
            "title": segment.document_name,
            "content": segment.segment_content,
            "query": query_text,
        }
        for segment in fetch_segments
    ]
def get_full_search_segments(dataset_ids: list[str], query_text: str):
    """Return at most one result row per document for a substring search.

    Two case-insensitive searches are run over the given datasets:

    * segments whose content contains ``query_text``;
    * documents whose name contains ``query_text``, represented by their
      lowest-position segment.

    For each document the first content-matching segment is preferred; the
    name match is used only when no segment content matched.

    Args:
        dataset_ids: Ids of the datasets to search. An empty list yields ``[]``.
        query_text: Text to look for. It is matched literally.

    Returns:
        A list of result rows exposing ``document_id``, ``document_name``,
        ``segment_id`` and ``segment_content``. Documents matched by name come
        first, followed by documents matched only by content.
    """
    if not dataset_ids:
        return []

    # Escape LIKE metacharacters so user input such as "100%" or "a_b" is
    # matched literally. Backslash is PostgreSQL's default LIKE escape
    # character, so it must be escaped first.
    escaped = query_text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
    params = {"keyword": f"%{escaped}%", "dataset_ids": dataset_ids}

    # Ordering by position makes "first matching segment" deterministic.
    content_sql = text("""
        SELECT s.id AS segment_id,
               s.document_id,
               s.content AS segment_content,
               d.name AS document_name
        FROM document_segments s
        JOIN documents d ON d.id = s.document_id
        WHERE s.content ILIKE :keyword AND d.dataset_id::text = ANY(:dataset_ids)
        ORDER BY s.document_id, s.position
    """)
    segments_rows = db.session.execute(content_sql, params).fetchall()

    name_sql = text("""
        SELECT d.id AS document_id,
               d.name AS document_name,
               s.id AS segment_id,
               s.content AS segment_content
        FROM documents d
        JOIN (
            SELECT s1.*
            FROM document_segments s1
            INNER JOIN (
                SELECT document_id, MIN(position) AS first_position
                FROM document_segments
                GROUP BY document_id
            ) s2 ON s1.document_id = s2.document_id AND s1.position = s2.first_position
        ) s ON d.id = s.document_id
        WHERE d.name ILIKE :keyword AND d.dataset_id::text = ANY(:dataset_ids)
    """)
    document_rows = db.session.execute(name_sql, params).fetchall()

    # First content hit per document.
    content_hits = {}
    for row in segments_rows:
        content_hits.setdefault(row.document_id, row)

    # One row per document: name-matched documents first (upgraded to their
    # content hit when one exists), then documents matched only by content.
    # Selecting by source rather than by list index avoids skipping the first
    # content hit when a document has several matching segments.
    selected = {}
    for row in document_rows:
        selected.setdefault(row.document_id, content_hits.get(row.document_id, row))
    for document_id, row in content_hits.items():
        selected.setdefault(document_id, row)
    return list(selected.values())

Loading…
Cancel
Save