【Dify】 全文检索的分片,1:保持每片的长度都达到阈值,最后一段自动补齐并达到阈值。 2,将标题加入到第一个分片内

pull/22121/head
liuchangsheng@wisdomidata.com 11 months ago
parent 4d14e5d2bd
commit e51dad7639

@ -4,6 +4,8 @@ from __future__ import annotations
from typing import Any, Optional
from sqlalchemy.dialects.postgresql import JSONB
from core.model_manager import ModelInstance
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.rag.splitter.text_splitter import (
@ -63,7 +65,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
self._fixed_separator = fixed_separator
self._separators = separators or ["\n\n", "\n", " ", ""]
def split_text(self, text: str) -> list[str]:
def split_text(self, text: str, metadata:Optional[dict] = None) -> list[str]:
"""Split incoming text and return chunks."""
if self._fixed_separator:
chunks = text.split(self._fixed_separator)
@ -75,7 +77,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
for chunk, chunk_length in zip(chunks, chunks_lengths):
if chunk_length > self._chunk_size:
if self._keep_separator :
final_chunks.extend(self.recursive_split_text_keep_separator_(chunk)) # 调用递归分割方法进一步拆分。
final_chunks.extend(self.recursive_split_text_keep_separator_(chunk,metadata)) # 调用递归分割方法进一步拆分。
continue
final_chunks.extend(self.recursive_split_text(chunk))
else:
@ -159,20 +161,37 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
def recursive_split_text_keep_separator_(self, text: str, metadata: Optional[dict] = None) -> list[str]:
    """Split ``text`` into chunks via ``append_next_split_text``.

    Args:
        text: The text to split.
        metadata: Optional per-document options. Two keys are read, each
            defaulting to ``False`` when absent:
            ``"char_split"`` is forwarded unchanged to
            ``append_next_split_text``; ``"full_last_text"``, when truthy,
            makes this method call ``set_full_last_text_chunks`` on the
            result before returning.

    Returns:
        The list of chunks produced for ``text``.
    """
    options = metadata if metadata is not None else {}
    # Each flag is looked up under its own key. The previous code tested for
    # "char_split" before reading "full_last_text", which raised KeyError when
    # only "char_split" was given and ignored "full_last_text" given alone.
    char_split = options.get("char_split", False)
    full_last_text = options.get("full_last_text", False)

    final_chunks: list[str] = []
    current_part_list: list[str] = []
    self.append_next_split_text(
        current_part_list=current_part_list,
        current_length_list=[],
        text=text,
        final_chunks=final_chunks,
        separators=self._separators,
        char_split=char_split,
    )
    # Whatever is still buffered after the split becomes the last chunk.
    if current_part_list:
        final_chunks.append("".join(current_part_list))
    if full_last_text:
        self.set_full_last_text_chunks(final_chunks=final_chunks)
    return final_chunks
@classmethod
@ -201,23 +220,34 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
current_length_list:list[int],
text: str,
final_chunks: list[str],
separators : list[str]): # 定义递归分割方法。
separators : list[str],
char_split : bool,
): # 定义递归分割方法。
if text:
# 需要判断是否可以再拼接
splits, new_separators_ = self.get_splits_(text, separators)
s_lens = self._length_function(splits) # 计算每个分割部分的长度。
for s, s_len in zip(splits, s_lens): # 遍历每个分割部分及其长度。
split_len = len(splits)
for idx,s in enumerate(splits): # 遍历每个分割部分及其长度。
s_len = s_lens[idx]
current_length = sum(current_length_list)
if "制定综合主进度" in s:
# import pdb; pdb.post_mortem()
print(s)
if current_length + s_len <= self._chunk_size: # 如果当前块可以容纳更多内容。
current_part_list.append(s) # 将当前部分加入当前块。
current_length_list.append(s_len)
else:
if len(new_separators_) == 0:
# 判断是否启用字符拆分
if char_split:
# 按照char拆分和拼接直到长度达到阈值
s,s_len = self.char_splits(
current_part_list=current_part_list,
current_length_list=current_length_list,
text=s,
s_len=s_len
)
# 将片段加入到列表中
final_chunks.append("".join(current_part_list))
# 计算出重叠部分的内容
@ -237,7 +267,8 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
current_length_list=current_length_list,
text=s,
final_chunks=final_chunks,
separators=new_separators_)
separators=new_separators_,
char_split=char_split)
def get_overlap_part(self,current_part_list:list[str],
current_length_list:list[int]) -> (int,str): # 定义递归分割方法。
@ -265,3 +296,49 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
overlap_part_list[0:0] = current_part_list_reversed[index]
# overlap_part_list.append(current_part_list_reversed[index])
return overlap_part_length_, "".join(overlap_part_list)
# 按照char 继续拼接,直到长度达到阈值
def char_splits(self,
                current_part_list: list[str],
                current_length_list: list[int],
                text: str,
                s_len: int) -> tuple[str, int]:
    """Top up the current chunk character by character from ``text``.

    Characters of ``text`` are appended to ``current_part_list`` (and their
    lengths to ``current_length_list``) for as long as the accumulated length
    stays within ``self._chunk_size``. Both lists are mutated in place.

    Args:
        current_part_list: Pieces of the chunk being built.
        current_length_list: Lengths of those pieces, as reported by
            ``self._length_function``.
        text: The piece that did not fit as a whole.
        s_len: Length of ``text``; kept for interface compatibility, the
            remainder length is always re-measured.

    Returns:
        ``(remainder, remainder_length)``: the part of ``text`` that was not
        consumed and its length, or ``("", 0)`` when all of ``text`` fitted.
    """
    chars = list(text)
    char_lens = self._length_function(chars)
    # Running total instead of re-summing the whole list for every character.
    used = sum(current_length_list)
    for idx, (char, char_len) in enumerate(zip(chars, char_lens)):
        if used + char_len > self._chunk_size:
            remainder = "".join(chars[idx:])
            return remainder, self._length_function([remainder])[0]
        current_part_list.append(char)
        current_length_list.append(char_len)
        used += char_len
    # Everything was consumed: report an empty remainder so the caller does
    # not process the same text a second time.
    return "", 0
# Pad the last chunk with "-" until its length reaches the chunk-size threshold
def set_full_last_text_chunks(self,
                              final_chunks: list[str]) -> None:
    """Pad the last chunk in place with ``"-"`` up to ``self._chunk_size``.

    Lengths are measured with ``self._length_function``. The number of filler
    characters is the missing length floor-divided by the measured length of
    one ``"-"``, so the padded chunk never exceeds the chunk size.

    Args:
        final_chunks: Chunk list to modify; only the last element is touched.
            An empty list is left unchanged.
    """
    if not final_chunks:
        return
    last_chunk = final_chunks[-1]
    missing = self._chunk_size - self._length_function([last_chunk])[0]
    if missing <= 0:
        # Already at (or above) the threshold: nothing to pad.
        return
    filler_len = self._length_function(["-"])[0]
    if filler_len <= 0:
        # A filler that measures as zero length can never close the gap;
        # bail out rather than divide by zero.
        return
    final_chunks[-1] = last_chunk + "-" * (missing // filler_len)

@ -69,7 +69,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
self._add_start_index = add_start_index
@abstractmethod
def split_text(self, text: str) -> list[str]:
def split_text(self, text: str, metadata:Optional[dict] = None) -> list[str]:
"""Split text into multiple components."""
def create_documents(self, texts: list[str], metadatas: Optional[list[dict]] = None) -> list[Document]:
@ -78,7 +78,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
documents = []
for i, text in enumerate(texts):
index = -1
for chunk in self.split_text(text):
for chunk in self.split_text(text,_metadatas[i]):
metadata = copy.deepcopy(_metadatas[i])
if self._add_start_index:
index = text.find(chunk, index + 1)

Loading…
Cancel
Save