diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 0db77309de..310e3c048f 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -4,6 +4,8 @@ from __future__ import annotations from typing import Any, Optional +from sqlalchemy.dialects.postgresql import JSONB + from core.model_manager import ModelInstance from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer from core.rag.splitter.text_splitter import ( @@ -63,7 +65,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) self._fixed_separator = fixed_separator self._separators = separators or ["\n\n", "\n", " ", ""] - def split_text(self, text: str) -> list[str]: + def split_text(self, text: str, metadata:Optional[dict] = None) -> list[str]: """Split incoming text and return chunks.""" if self._fixed_separator: chunks = text.split(self._fixed_separator) @@ -75,7 +77,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) for chunk, chunk_length in zip(chunks, chunks_lengths): if chunk_length > self._chunk_size: if self._keep_separator : - final_chunks.extend(self.recursive_split_text_keep_separator_(chunk)) # 调用递归分割方法进一步拆分。 + final_chunks.extend(self.recursive_split_text_keep_separator_(chunk,metadata)) # 调用递归分割方法进一步拆分。 continue final_chunks.extend(self.recursive_split_text(chunk)) else: @@ -159,20 +161,37 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) - def recursive_split_text_keep_separator_(self, text: str) -> list[str]: # 定义递归分割方法。 + def recursive_split_text_keep_separator_(self, text: str,metadata:Optional[dict] = None) -> list[str]: # 定义递归分割方法。 """Split incoming text and return chunks.""" # 文档字符串,说明该方法的作用是递归地分割文本并返回块。 + char_split = False + full_last_text = False + if metadata is not None: + if "char_split" in metadata: + # 分片未达阈值,是否按照char分片继续合并 + char_split = metadata["char_split"] + if "char_split" in metadata: + # 是否补全最后一个未达到阈值的分片 + full_last_text = metadata["full_last_text"] + final_chunks = [] # 初始化最终的块列表。 current_part_list = [] self.append_next_split_text(current_part_list=current_part_list, current_length_list=[], text=text, final_chunks = final_chunks, - separators = self._separators) + separators = self._separators, + char_split=char_split, + ) if len(current_part_list): # 如果还有剩余的当前块。 final_chunks.append("".join(current_part_list)) # 将其加入最终块列表。 + # 是否补全最后一个未达到阈值的分片 + if full_last_text: + # 补全 + self.set_full_last_text_chunks(final_chunks=final_chunks) + return final_chunks # 返回最终的块列表。 @classmethod @@ -201,23 +220,34 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) current_length_list:list[int], text: str, final_chunks: list[str], - separators : list[str]): # 定义递归分割方法。 + separators : list[str], + char_split : bool, + ): # 定义递归分割方法。 if text: # 需要判断是否可以再拼接 splits, new_separators_ = self.get_splits_(text, separators) s_lens = self._length_function(splits) # 计算每个分割部分的长度。 - for s, s_len in zip(splits, s_lens): # 遍历每个分割部分及其长度。 - + split_len = len(splits) + for idx,s in enumerate(splits): # 遍历每个分割部分及其长度。 + s_len = s_lens[idx] current_length = sum(current_length_list) if "制定综合主进度" in s: # import pdb; pdb.post_mortem() print(s) - if current_length + s_len <= self._chunk_size: # 如果当前块可以容纳更多内容。 current_part_list.append(s) # 将当前部分加入当前块。 current_length_list.append(s_len) else: if len(new_separators_) == 0: + # 判断是否启用字符拆分 + if char_split: + # 按照char拆分和拼接,直到长度达到阈值 + s,s_len = self.char_splits( + current_part_list=current_part_list, + current_length_list=current_length_list, + text=s, + s_len=s_len + ) # 将片段加入到列表中 final_chunks.append("".join(current_part_list)) # 计算出重叠部分的内容 @@ -237,7 +267,8 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) current_length_list=current_length_list, text=s, final_chunks=final_chunks, - separators=new_separators_) + separators=new_separators_, + char_split=char_split) def get_overlap_part(self,current_part_list:list[str], current_length_list:list[int]) -> (int,str): # 定义递归分割方法。 @@ -265,3 +296,49 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) overlap_part_list[0:0] = current_part_list_reversed[index] # overlap_part_list.append(current_part_list_reversed[index]) return overlap_part_length_, "".join(overlap_part_list) + + # 按照char 继续拼接,直到长度达到阈值 + def char_splits(self, + current_part_list:list[str], + current_length_list:list[int], + text: str, + s_len: int) -> (str,int): # 定义递归分割方法。 + + char_splits = list(text) + char_s_lens = self._length_function(char_splits) # 计算每个分割部分的长度。 + for char_idx, char_s in enumerate(char_splits): # 遍历每个分割部分及其长度。 + char_s_len = char_s_lens[char_idx] + char_current_length = sum(current_length_list) + if char_current_length + char_s_len <= self._chunk_size: # 如果当前块可以容纳更多内容。 + current_part_list.append(char_s) # 将当前部分加入当前块。 + current_length_list.append(char_s_len) + else: + last_s = char_splits[char_idx:] + text = "".join(last_s) + last_s_lens = self._length_function([text]) + s_len = last_s_lens[0] + break + return text,s_len + + # 按照char 继续拼接,直到长度达到阈值 + def set_full_last_text_chunks(self, + final_chunks: list[str]): # 定义递归分割方法。 + + if final_chunks: + # 取最后一个片段 + final_chunk = final_chunks[-1] + # 计算最后一个分片的长度 + final_chunk_lens = self._length_function([final_chunk]) + # 是否达到阈值,如果未达到,计算空格的长度,使用空格补全 + if final_chunk_lens[0] < self._chunk_size: + # 计算空格的长度 + space_len = self._length_function(["-"])[0] + # 未达阈值,补充空格 + sum_len = self._chunk_size - final_chunk_lens[0] + # 整除 + num = sum_len // space_len + # 重新合并空格 + space_s = [final_chunk] + for i in range(num): + space_s.append("-") + final_chunks[-1] = "".join(space_s) diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py index 34b4056cf5..b45518e667 100644 --- a/api/core/rag/splitter/text_splitter.py +++ b/api/core/rag/splitter/text_splitter.py @@ -69,7 +69,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): self._add_start_index = add_start_index @abstractmethod - def split_text(self, text: str) -> list[str]: + def split_text(self, text: str, metadata:Optional[dict] = None) -> list[str]: """Split text into multiple components.""" def create_documents(self, texts: list[str], metadatas: Optional[list[dict]] = None) -> list[Document]: @@ -78,7 +78,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): documents = [] for i, text in enumerate(texts): index = -1 - for chunk in self.split_text(text): + for chunk in self.split_text(text,_metadatas[i]): metadata = copy.deepcopy(_metadatas[i]) if self._add_start_index: index = text.find(chunk, index + 1)