|
|
|
@ -4,6 +4,8 @@ from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Optional
|
|
|
|
from typing import Any, Optional
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from sqlalchemy.dialects.postgresql import JSONB
|
|
|
|
|
|
|
|
|
|
|
|
from core.model_manager import ModelInstance
|
|
|
|
from core.model_manager import ModelInstance
|
|
|
|
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
|
|
|
|
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
|
|
|
|
from core.rag.splitter.text_splitter import (
|
|
|
|
from core.rag.splitter.text_splitter import (
|
|
|
|
@ -63,7 +65,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
|
|
|
|
self._fixed_separator = fixed_separator
|
|
|
|
self._fixed_separator = fixed_separator
|
|
|
|
self._separators = separators or ["\n\n", "\n", " ", ""]
|
|
|
|
self._separators = separators or ["\n\n", "\n", " ", ""]
|
|
|
|
|
|
|
|
|
|
|
|
def split_text(self, text: str) -> list[str]:
|
|
|
|
def split_text(self, text: str, metadata:Optional[dict] = None) -> list[str]:
|
|
|
|
"""Split incoming text and return chunks."""
|
|
|
|
"""Split incoming text and return chunks."""
|
|
|
|
if self._fixed_separator:
|
|
|
|
if self._fixed_separator:
|
|
|
|
chunks = text.split(self._fixed_separator)
|
|
|
|
chunks = text.split(self._fixed_separator)
|
|
|
|
@ -75,7 +77,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
|
|
|
|
for chunk, chunk_length in zip(chunks, chunks_lengths):
|
|
|
|
for chunk, chunk_length in zip(chunks, chunks_lengths):
|
|
|
|
if chunk_length > self._chunk_size:
|
|
|
|
if chunk_length > self._chunk_size:
|
|
|
|
if self._keep_separator :
|
|
|
|
if self._keep_separator :
|
|
|
|
final_chunks.extend(self.recursive_split_text_keep_separator_(chunk)) # 调用递归分割方法进一步拆分。
|
|
|
|
final_chunks.extend(self.recursive_split_text_keep_separator_(chunk,metadata)) # 调用递归分割方法进一步拆分。
|
|
|
|
continue
|
|
|
|
continue
|
|
|
|
final_chunks.extend(self.recursive_split_text(chunk))
|
|
|
|
final_chunks.extend(self.recursive_split_text(chunk))
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
@ -159,20 +161,37 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def recursive_split_text_keep_separator_(self, text: str) -> list[str]: # 定义递归分割方法。
|
|
|
|
def recursive_split_text_keep_separator_(self, text: str,metadata:Optional[dict] = None) -> list[str]: # 定义递归分割方法。
|
|
|
|
"""Split incoming text and return chunks.""" # 文档字符串,说明该方法的作用是递归地分割文本并返回块。
|
|
|
|
"""Split incoming text and return chunks.""" # 文档字符串,说明该方法的作用是递归地分割文本并返回块。
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
char_split = False
|
|
|
|
|
|
|
|
full_last_text = False
|
|
|
|
|
|
|
|
if metadata is not None:
|
|
|
|
|
|
|
|
if "char_split" in metadata:
|
|
|
|
|
|
|
|
# 分片未达阈值,是否按照char分片继续合并
|
|
|
|
|
|
|
|
char_split = metadata["char_split"]
|
|
|
|
|
|
|
|
if "char_split" in metadata:
|
|
|
|
|
|
|
|
# 是否补全最后一个未达到阈值的分片
|
|
|
|
|
|
|
|
full_last_text = metadata["full_last_text"]
|
|
|
|
|
|
|
|
|
|
|
|
final_chunks = [] # 初始化最终的块列表。
|
|
|
|
final_chunks = [] # 初始化最终的块列表。
|
|
|
|
current_part_list = []
|
|
|
|
current_part_list = []
|
|
|
|
self.append_next_split_text(current_part_list=current_part_list,
|
|
|
|
self.append_next_split_text(current_part_list=current_part_list,
|
|
|
|
current_length_list=[],
|
|
|
|
current_length_list=[],
|
|
|
|
text=text,
|
|
|
|
text=text,
|
|
|
|
final_chunks = final_chunks,
|
|
|
|
final_chunks = final_chunks,
|
|
|
|
separators = self._separators)
|
|
|
|
separators = self._separators,
|
|
|
|
|
|
|
|
char_split=char_split,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
if len(current_part_list): # 如果还有剩余的当前块。
|
|
|
|
if len(current_part_list): # 如果还有剩余的当前块。
|
|
|
|
final_chunks.append("".join(current_part_list)) # 将其加入最终块列表。
|
|
|
|
final_chunks.append("".join(current_part_list)) # 将其加入最终块列表。
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 是否补全最后一个未达到阈值的分片
|
|
|
|
|
|
|
|
if full_last_text:
|
|
|
|
|
|
|
|
# 补全
|
|
|
|
|
|
|
|
self.set_full_last_text_chunks(final_chunks=final_chunks)
|
|
|
|
|
|
|
|
|
|
|
|
return final_chunks # 返回最终的块列表。
|
|
|
|
return final_chunks # 返回最终的块列表。
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
@ -201,23 +220,34 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
|
|
|
|
current_length_list:list[int],
|
|
|
|
current_length_list:list[int],
|
|
|
|
text: str,
|
|
|
|
text: str,
|
|
|
|
final_chunks: list[str],
|
|
|
|
final_chunks: list[str],
|
|
|
|
separators : list[str]): # 定义递归分割方法。
|
|
|
|
separators : list[str],
|
|
|
|
|
|
|
|
char_split : bool,
|
|
|
|
|
|
|
|
): # 定义递归分割方法。
|
|
|
|
if text:
|
|
|
|
if text:
|
|
|
|
# 需要判断是否可以再拼接
|
|
|
|
# 需要判断是否可以再拼接
|
|
|
|
splits, new_separators_ = self.get_splits_(text, separators)
|
|
|
|
splits, new_separators_ = self.get_splits_(text, separators)
|
|
|
|
s_lens = self._length_function(splits) # 计算每个分割部分的长度。
|
|
|
|
s_lens = self._length_function(splits) # 计算每个分割部分的长度。
|
|
|
|
for s, s_len in zip(splits, s_lens): # 遍历每个分割部分及其长度。
|
|
|
|
split_len = len(splits)
|
|
|
|
|
|
|
|
for idx,s in enumerate(splits): # 遍历每个分割部分及其长度。
|
|
|
|
|
|
|
|
s_len = s_lens[idx]
|
|
|
|
current_length = sum(current_length_list)
|
|
|
|
current_length = sum(current_length_list)
|
|
|
|
if "制定综合主进度" in s:
|
|
|
|
if "制定综合主进度" in s:
|
|
|
|
# import pdb; pdb.post_mortem()
|
|
|
|
# import pdb; pdb.post_mortem()
|
|
|
|
print(s)
|
|
|
|
print(s)
|
|
|
|
|
|
|
|
|
|
|
|
if current_length + s_len <= self._chunk_size: # 如果当前块可以容纳更多内容。
|
|
|
|
if current_length + s_len <= self._chunk_size: # 如果当前块可以容纳更多内容。
|
|
|
|
current_part_list.append(s) # 将当前部分加入当前块。
|
|
|
|
current_part_list.append(s) # 将当前部分加入当前块。
|
|
|
|
current_length_list.append(s_len)
|
|
|
|
current_length_list.append(s_len)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
if len(new_separators_) == 0:
|
|
|
|
if len(new_separators_) == 0:
|
|
|
|
|
|
|
|
# 判断是否启用字符拆分
|
|
|
|
|
|
|
|
if char_split:
|
|
|
|
|
|
|
|
# 按照char拆分和拼接,直到长度达到阈值
|
|
|
|
|
|
|
|
s,s_len = self.char_splits(
|
|
|
|
|
|
|
|
current_part_list=current_part_list,
|
|
|
|
|
|
|
|
current_length_list=current_length_list,
|
|
|
|
|
|
|
|
text=s,
|
|
|
|
|
|
|
|
s_len=s_len
|
|
|
|
|
|
|
|
)
|
|
|
|
# 将片段加入到列表中
|
|
|
|
# 将片段加入到列表中
|
|
|
|
final_chunks.append("".join(current_part_list))
|
|
|
|
final_chunks.append("".join(current_part_list))
|
|
|
|
# 计算出重叠部分的内容
|
|
|
|
# 计算出重叠部分的内容
|
|
|
|
@ -237,7 +267,8 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
|
|
|
|
current_length_list=current_length_list,
|
|
|
|
current_length_list=current_length_list,
|
|
|
|
text=s,
|
|
|
|
text=s,
|
|
|
|
final_chunks=final_chunks,
|
|
|
|
final_chunks=final_chunks,
|
|
|
|
separators=new_separators_)
|
|
|
|
separators=new_separators_,
|
|
|
|
|
|
|
|
char_split=char_split)
|
|
|
|
|
|
|
|
|
|
|
|
def get_overlap_part(self,current_part_list:list[str],
|
|
|
|
def get_overlap_part(self,current_part_list:list[str],
|
|
|
|
current_length_list:list[int]) -> (int,str): # 定义递归分割方法。
|
|
|
|
current_length_list:list[int]) -> (int,str): # 定义递归分割方法。
|
|
|
|
@ -265,3 +296,49 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
|
|
|
|
overlap_part_list[0:0] = current_part_list_reversed[index]
|
|
|
|
overlap_part_list[0:0] = current_part_list_reversed[index]
|
|
|
|
# overlap_part_list.append(current_part_list_reversed[index])
|
|
|
|
# overlap_part_list.append(current_part_list_reversed[index])
|
|
|
|
return overlap_part_length_, "".join(overlap_part_list)
|
|
|
|
return overlap_part_length_, "".join(overlap_part_list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 按照char 继续拼接,直到长度达到阈值
|
|
|
|
|
|
|
|
def char_splits(self,
|
|
|
|
|
|
|
|
current_part_list:list[str],
|
|
|
|
|
|
|
|
current_length_list:list[int],
|
|
|
|
|
|
|
|
text: str,
|
|
|
|
|
|
|
|
s_len: int) -> (str,int): # 定义递归分割方法。
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
char_splits = list(text)
|
|
|
|
|
|
|
|
char_s_lens = self._length_function(char_splits) # 计算每个分割部分的长度。
|
|
|
|
|
|
|
|
for char_idx, char_s in enumerate(char_splits): # 遍历每个分割部分及其长度。
|
|
|
|
|
|
|
|
char_s_len = char_s_lens[char_idx]
|
|
|
|
|
|
|
|
char_current_length = sum(current_length_list)
|
|
|
|
|
|
|
|
if char_current_length + char_s_len <= self._chunk_size: # 如果当前块可以容纳更多内容。
|
|
|
|
|
|
|
|
current_part_list.append(char_s) # 将当前部分加入当前块。
|
|
|
|
|
|
|
|
current_length_list.append(char_s_len)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
last_s = char_splits[char_idx:]
|
|
|
|
|
|
|
|
text = "".join(last_s)
|
|
|
|
|
|
|
|
last_s_lens = self._length_function([text])
|
|
|
|
|
|
|
|
s_len = last_s_lens[0]
|
|
|
|
|
|
|
|
break
|
|
|
|
|
|
|
|
return text,s_len
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 按照char 继续拼接,直到长度达到阈值
|
|
|
|
|
|
|
|
def set_full_last_text_chunks(self,
|
|
|
|
|
|
|
|
final_chunks: list[str]): # 定义递归分割方法。
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if final_chunks:
|
|
|
|
|
|
|
|
# 取最后一个片段
|
|
|
|
|
|
|
|
final_chunk = final_chunks[-1]
|
|
|
|
|
|
|
|
# 计算最后一个分片的长度
|
|
|
|
|
|
|
|
final_chunk_lens = self._length_function([final_chunk])
|
|
|
|
|
|
|
|
# 是否达到阈值,如果未达到,计算空格的长度,使用空格补全
|
|
|
|
|
|
|
|
if final_chunk_lens[0] < self._chunk_size:
|
|
|
|
|
|
|
|
# 计算空格的长度
|
|
|
|
|
|
|
|
space_len = self._length_function(["-"])[0]
|
|
|
|
|
|
|
|
# 未达阈值,补充空格
|
|
|
|
|
|
|
|
sum_len = self._chunk_size - final_chunk_lens[0]
|
|
|
|
|
|
|
|
# 整除
|
|
|
|
|
|
|
|
num = sum_len // space_len
|
|
|
|
|
|
|
|
# 重新合并空格
|
|
|
|
|
|
|
|
space_s = [final_chunk]
|
|
|
|
|
|
|
|
for i in range(num):
|
|
|
|
|
|
|
|
space_s.append("-")
|
|
|
|
|
|
|
|
final_chunks[-1] = "".join(space_s)
|
|
|
|
|