|
|
|
|
@ -39,6 +39,12 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
|
|
|
|
|
else:
|
|
|
|
|
return [GPT2Tokenizer.get_num_tokens(text) for text in texts]
|
|
|
|
|
|
|
|
|
|
def _character_encoder(texts: list[str]) -> list[int]:
|
|
|
|
|
if not texts:
|
|
|
|
|
return []
|
|
|
|
|
|
|
|
|
|
return [len(text) for text in texts]
|
|
|
|
|
|
|
|
|
|
if issubclass(cls, TokenTextSplitter):
|
|
|
|
|
extra_kwargs = {
|
|
|
|
|
"model_name": embedding_model_instance.model if embedding_model_instance else "gpt2",
|
|
|
|
|
@ -47,7 +53,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
|
|
|
|
|
}
|
|
|
|
|
kwargs = {**kwargs, **extra_kwargs}
|
|
|
|
|
|
|
|
|
|
return cls(length_function=_token_encoder, **kwargs)
|
|
|
|
|
return cls(length_function=_character_encoder, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
|
|
|
|
|
@ -103,7 +109,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
|
|
|
|
|
_good_splits_lengths = [] # cache the lengths of the splits
|
|
|
|
|
_separator = "" if self._keep_separator else separator
|
|
|
|
|
s_lens = self._length_function(splits)
|
|
|
|
|
if _separator != "":
|
|
|
|
|
if separator != "":
|
|
|
|
|
for s, s_len in zip(splits, s_lens):
|
|
|
|
|
if s_len < self._chunk_size:
|
|
|
|
|
_good_splits.append(s)
|
|
|
|
|
|