You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
gcgj-dify-1.7.0/api/services/ext/dataset_ext_service.py

178 lines
7.8 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from models import ApiToken, Account, Tenant
from models.dataset import (
Dataset,DocumentSegment
)
from core.rag.models.document import Document as DocumentModel
from core.errors.error import (
LLMBadRequestError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from controllers.console.app.error import (
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from extensions.ext_database import db
from services.dataset_service import DatasetService, DocumentService
from configs.ext_config import get_init_knowledge_config
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
class DatasetExtService:
    """Extension helpers for datasets (knowledge bases) tied to a target tenant.

    Covers looking up datasets by tenant / target tenant, creating the two
    default datasets, saving a default knowledge configuration for a dataset
    and fetching (or creating) a dataset API token.
    """

    # Token type and key prefix used by get_or_add_datasets_api_token.
    resource_type = "dataset"
    token_prefix = "dataset-"
    # Not referenced inside this class.
    max_keys = 10

    @staticmethod
    def get_datasets(tenant_id=None, target_tenant_id=None) -> list[Dataset]:
        """Return all datasets matching both ``tenant_id`` and ``target_tenant_id``."""
        return (
            Dataset.query
            .filter(
                Dataset.tenant_id == tenant_id,
                Dataset.target_tenant_id == target_tenant_id,
            )
            .all()
        )

    @staticmethod
    def init_dataset(
        tenant: Tenant = None,
        target_tenant_id: str = None,
        target_tenant_name: str = None,
        account: Account = None,
    ) -> list[Dataset]:
        """Return the target tenant's datasets, creating the defaults if none exist.

        When no dataset is found for ``tenant.id`` / ``target_tenant_id``, two
        empty datasets are created ("PUBLIC_KNOWLEDGE" and "COMPANY_KNOWLEDGE"),
        tagged with ``target_tenant_id`` and committed.

        Args:
            tenant: Tenant that owns the datasets; its ``id`` is read, so it
                must not be ``None`` despite the default.
            target_tenant_id: Identifier stored on each created dataset.
            target_tenant_name: Name interpolated into the dataset descriptions.
            account: Account passed through to ``DatasetService.create_empty_dataset``.

        Returns:
            The existing datasets, or the two newly created ones
            (public first, company second).
        """
        # Check whether knowledge bases already exist; create them only if not.
        datasets = DatasetExtService.get_datasets(
            tenant_id=tenant.id, target_tenant_id=target_tenant_id
        )
        if datasets:
            return datasets

        # (name, description) of the datasets to create, in creation order.
        default_datasets = (
            ("PUBLIC_KNOWLEDGE", f"{target_tenant_name}的公共知识库"),
            ("COMPANY_KNOWLEDGE", f"{target_tenant_name}的企业知识库"),
        )
        created = []
        for name, description in default_datasets:
            dataset = DatasetService.create_empty_dataset(
                tenant_id=tenant.id,
                name=name,
                description=description,
                indexing_technique="",
                account=account,
            )
            dataset.target_tenant_id = target_tenant_id
            created.append(dataset)
        # Persist the target_tenant_id assignments made above.
        db.session.commit()
        return created

    @staticmethod
    def set_dataset_config(dataset=None, current_user=None):
        """Save the default knowledge configuration for ``dataset``.

        The defaults come from ``get_init_knowledge_config({})`` and are wrapped
        in a ``KnowledgeConfig`` before being handed to
        ``DocumentService.save_document_with_dataset_id``.

        Raises:
            ProviderNotInitializeError: On ``ProviderTokenNotInitError``.
            ProviderQuotaExceededError: On ``QuotaExceededError``.
            ProviderModelCurrentlyNotSupportError: On ``ModelCurrentlyNotSupportError``.
        """
        # Start from the default values.
        args = get_init_knowledge_config({})
        # Validate the arguments by building the config object.
        knowledge_config = KnowledgeConfig(**args)
        try:
            DocumentService.save_document_with_dataset_id(dataset, knowledge_config, current_user)
        except ProviderTokenNotInitError as ex:
            # Chain the cause so the original core error stays in the traceback.
            raise ProviderNotInitializeError(ex.description) from ex
        except QuotaExceededError as ex:
            raise ProviderQuotaExceededError() from ex
        except ModelCurrentlyNotSupportError as ex:
            raise ProviderModelCurrentlyNotSupportError() from ex

    # Kept as an instance method: it reads resource_type / token_prefix via self.
    def get_or_add_datasets_api_token(self, tenant_id: str):
        """Return a dataset API token for ``tenant_id``, creating one if none exists.

        Args:
            tenant_id: Tenant whose tokens of type ``resource_type`` are looked up.

        Returns:
            The last ``ApiToken`` returned by the query, or a newly created and
            committed one when the tenant has none.
        """
        api_tokens = (
            db.session.query(ApiToken)
            .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == tenant_id)
            .all()
        )
        if api_tokens:
            # NOTE(review): the query has no ORDER BY, so which row comes last
            # is up to the database - confirm this is acceptable.
            return api_tokens[-1]

        key = ApiToken.generate_api_key(self.token_prefix, 24)
        api_token = ApiToken()
        api_token.tenant_id = tenant_id
        api_token.token = key
        api_token.type = self.resource_type
        db.session.add(api_token)
        db.session.commit()
        return api_token
class DocumentExtService:
    """Helpers that extend retrieved documents with the segment that follows them."""

    @staticmethod
    def set_next_segments(all_documents: list[DocumentModel]):
        """Append the following segment's content to each child-less document.

        Collects ``metadata["doc_id"]`` and ``metadata["document_id"]`` from
        every document whose ``children`` is ``None``, loads all
        ``DocumentSegment`` rows of those documents in one query, groups them
        by ``document_id`` and delegates to ``merged_next_segment_content``.
        Documents are modified in place; nothing is returned.

        Args:
            all_documents: Retrieved documents; may be empty or ``None``.
        """
        # Nothing to do for an empty result.
        if not all_documents:
            return

        document_ids = []
        doc_segment_ids = []
        for document in all_documents:
            if document.children is None:
                doc_segment_id = document.metadata["doc_id"]
                document_id = document.metadata["document_id"]
                doc_segment_ids.append(doc_segment_id)
                document_ids.append(document_id)
        if not document_ids:
            return

        # Load every segment of the involved documents in a single query.
        document_segments = (
            db.session.query(DocumentSegment)
            .filter(DocumentSegment.document_id.in_(document_ids))
            .all()
        )
        # Group the segments by the document they belong to.
        document_segment_data = {}
        for document_segment in document_segments:
            document_segment_data.setdefault(document_segment.document_id, []).append(document_segment)

        DocumentExtService.merged_next_segment_content(
            all_documents=all_documents,
            document_segment_data=document_segment_data,
            doc_segment_ids=doc_segment_ids,
        )

    @staticmethod
    def merged_next_segment_content(all_documents: list[DocumentModel], document_segment_data: dict, doc_segment_ids: list):
        """Merge each child-less document with the segment positioned right after it.

        Args:
            all_documents: Documents to update in place (``page_content``).
            document_segment_data: Mapping of ``document_id`` to that document's
                segments. A document missing from the mapping is left untouched.
            doc_segment_ids: Segment ids already present in the result. A next
                segment whose ``index_node_id`` is listed is skipped; ids of
                merged segments are appended to this list.
        """
        if not all_documents:
            return

        for document in all_documents:
            if document.children is not None:
                continue
            doc_segment_id = document.metadata["doc_id"]
            document_id = document.metadata["document_id"]
            # .get(): a document without any loaded segments must not raise KeyError.
            document_segments = document_segment_data.get(document_id)
            next_segment = DocumentExtService.get_next_segment(
                doc_segment_id=doc_segment_id, document_segments=document_segments
            )
            if next_segment and next_segment.index_node_id not in doc_segment_ids:
                merged_content = DocumentExtService.merged_text(document.page_content, next_segment.content)
                # Record the segment so it is not merged into another document too.
                doc_segment_ids.append(next_segment.index_node_id)
                document.page_content = merged_content

    @staticmethod
    def merged_text(text, target_text, min_overlap: int = 10) -> str:
        """Join ``target_text`` onto ``text`` when they overlap sufficiently.

        Looks for the longest suffix of ``text`` that equals a prefix of
        ``target_text``. If that overlap is strictly longer than
        ``min_overlap`` characters, returns ``text`` followed by the
        non-overlapping remainder of ``target_text``; otherwise returns
        ``text`` unchanged.

        NOTE(review): without a long enough overlap ``target_text`` is
        discarded entirely - confirm this is intended.

        Args:
            text: Leading text.
            target_text: Text expected to continue ``text``.
            min_overlap: Overlap length that must be exceeded (default 10).
        """
        longest_candidate = min(len(text), len(target_text))
        # Try the longest overlap first; shorter ones than the threshold never merge.
        for overlap_length in range(longest_candidate, min_overlap, -1):
            if text.endswith(target_text[:overlap_length]):
                return text + target_text[overlap_length:]
        return text

    @staticmethod
    def get_next_segment(doc_segment_id, document_segments: list[DocumentSegment]) -> DocumentSegment:
        """Return the segment whose position directly follows ``doc_segment_id``.

        Args:
            doc_segment_id: ``index_node_id`` of the current segment.
            document_segments: Segments of one document; may be empty or ``None``.

        Returns:
            The segment with ``position == current position + 1``, or ``None``
            when the list is empty, the current segment is not in it, or no
            such successor exists.
        """
        if not document_segments:
            return None

        current_position = None
        for document_segment in document_segments:
            if document_segment.index_node_id == doc_segment_id:
                current_position = document_segment.position
        # Unknown segment id: there is no position to look after.
        if current_position is None:
            return None

        next_position = current_position + 1
        for document_segment in document_segments:
            if document_segment.position == next_position:
                return document_segment
        return None