Add: Implement semantic chunking functionality with tests

- Introduced `SemanticTextSplitter` to enhance text chunking by identifying semantic boundaries using an LLM.
- Added tests for boundary analysis and semantic chunking in `boundary_analysis_test.py` and `semantic_chunking_test.py`.
- Updated `IndexingRunner` and related classes to support semantic chunking strategy.
- Enhanced UI components to allow selection between fixed and semantic chunking strategies.
- Updated translations for new chunking options in the dataset creation process.
pull/19779/head
Aryan Raj 1 year ago
parent aae80681f2
commit 0dcb67ab20

@ -707,12 +707,25 @@ class IndexingRunner:
model_type=ModelType.TEXT_EMBEDDING,
)
# Get LLM model instance if semantic chunking is enabled
llm_model_instance = None
if process_rule.get("rules", {}).get("chunking_strategy") == "semantic":
try:
llm_model_instance = self.model_manager.get_default_model_instance(
tenant_id=dataset.tenant_id,
model_type=ModelType.LLM,
)
except Exception as e:
# Fall back to fixed chunking if we can't get an LLM model
process_rule["rules"]["chunking_strategy"] = "fixed"
documents = index_processor.transform(
text_docs,
embedding_model_instance=embedding_model_instance,
process_rule=process_rule,
tenant_id=dataset.tenant_id,
doc_language=doc_language,
llm_model_instance=llm_model_instance
)
return documents
@ -747,6 +760,160 @@ class IndexingRunner:
)
pass
def _process_document(self, flask_app, process_rule, document, dataset, tenant_id, user, document_plan):
if not self._start_document(document.id, tenant_id):
return
try:
is_automatic = process_rule.get("mode") == "automatic" or process_rule.get("mode") == "hierarchical"
with flask_app.app_context():
# check document is paused
self._check_document_paused_status(document.id)
# get embedding model instance
embedding_model_name = dataset.embedding_model
embedding_model_provider = dataset.embedding_model_provider
model_manager = ModelManager()
embedding_model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
provider=embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=embedding_model_name,
)
# Get LLM model instance if semantic chunking is used
llm_model_instance = None
if process_rule.get("rules", {}).get("chunking_strategy") == "semantic":
try:
llm_model_instance = model_manager.get_default_model_instance(ModelType.LLM)
except Exception as e:
# Fall back to fixed chunking if we can't get an LLM model
process_rule["rules"]["chunking_strategy"] = "fixed"
# extract
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
text_docs = []
completed_segments = 0
total_segments = 0
# resume document indexing
stopped_segment_ids = []
stopped_segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.document_id == document.id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
)
.all()
)
stopped_segment_ids = [segment.id for segment in stopped_segments]
original_extract_setting = parse_extract_setting(document, dataset)
original_extract_setting.is_automatic = is_automatic
if stopped_segment_ids:
# remove original content if document is stopped
original_extract_setting.content = None
text_docs = IndexingRunner.load_document(
document_id=document.id,
dataset_id=dataset.id,
tenant_id=tenant_id,
extract_setting=original_extract_setting,
process_rule_mode=process_rule.get("mode"),
)
# index
if text_docs:
# transform
chunk_docs = index_processor.transform(
text_docs,
process_rule=process_rule,
dataset_id=dataset.id,
tenant_id=tenant_id,
doc_language=document.doc_language,
embedding_model_instance=embedding_model_instance,
llm_model_instance=llm_model_instance
)
# save segment
self._load_segments(dataset, document, chunk_docs)
# load
self._load(
index_processor=index_processor,
dataset=dataset,
dataset_document=document,
documents=chunk_docs,
)
# update document status to completed
self._update_document_index_status(
document_id=document.id,
after_indexing_status="completed",
extra_update_params={
DatasetDocument.tokens: sum(len(text_doc.page_content) for text_doc in text_docs),
DatasetDocument.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
DatasetDocument.indexing_latency: time.perf_counter() - indexing_start_at,
DatasetDocument.error: None,
},
)
completed_segments = len(text_docs)
total_segments = len(text_docs)
# update segment status to completed
self._update_segments_by_document(
dataset_document_id=document.id,
update_params={
DocumentSegment.status: "completed",
DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
},
)
else:
# update document status to completed
self._update_document_index_status(
document_id=document.id,
after_indexing_status="completed",
extra_update_params={
DatasetDocument.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
DatasetDocument.error: None,
},
)
# update segment status to completed
self._update_segments_by_document(
dataset_document_id=document.id,
update_params={
DocumentSegment.status: "completed",
DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
},
)
completed_segments = 0
total_segments = 0
# update document plan
self._update_document_plan(document_plan, completed_segments, total_segments)
except DocumentIsPausedError:
raise DocumentIsPausedError("Document paused, document id: {}".format(document.id))
except ProviderTokenNotInitError as e:
document.indexing_status = "error"
document.error = str(e.description)
document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit()
except Exception as e:
logging.exception("consume document failed")
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit()
class DocumentIsPausedError(Exception):
pass

@ -1,7 +1,7 @@
"""Abstract interface for document loader implementations."""
from abc import ABC, abstractmethod
from typing import Optional
from typing import Optional, Literal
from configs import dify_config
from core.model_manager import ModelInstance
@ -11,6 +11,7 @@ from core.rag.splitter.fixed_text_splitter import (
EnhanceRecursiveCharacterTextSplitter,
FixedRecursiveCharacterTextSplitter,
)
from core.rag.splitter.semantic_text_splitter import SemanticTextSplitter
from core.rag.splitter.text_splitter import TextSplitter
from models.dataset import Dataset, DatasetProcessRule
@ -52,10 +53,60 @@ class BaseIndexProcessor(ABC):
chunk_overlap: int,
separator: str,
embedding_model_instance: Optional[ModelInstance],
chunking_strategy: Optional[Literal["fixed", "semantic"]] = None,
llm_model_instance: Optional[ModelInstance] = None,
) -> TextSplitter:
"""
Get the NodeParser object according to the processing rule.
Get the splitter object according to the processing rule and chunking strategy.
Args:
processing_rule_mode: The processing rule mode (custom, hierarchical, etc.)
max_tokens: Maximum tokens per chunk
chunk_overlap: Overlap between chunks
separator: Separator string for chunking
embedding_model_instance: Model instance for embeddings
chunking_strategy: Chunking strategy to use (fixed or semantic)
llm_model_instance: LLM model instance for semantic chunking
Returns:
A TextSplitter instance
"""
# Use semantic chunking if explicitly requested and LLM model is available
if chunking_strategy == "semantic" and llm_model_instance:
# Create a fallback splitter first based on the processing rule mode
if processing_rule_mode in ["custom", "hierarchical"]:
max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
if max_tokens < 50 or max_tokens > max_segmentation_tokens_length:
raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.")
if separator:
separator = separator.replace("\\n", "\n")
fallback_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
chunk_size=max_tokens,
chunk_overlap=chunk_overlap,
fixed_separator=separator,
separators=["\n\n", "", ". ", " ", ""],
embedding_model_instance=embedding_model_instance,
)
else:
# Automatic segmentation
fallback_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
chunk_size=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["max_tokens"],
chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["chunk_overlap"],
separators=["\n\n", "", ". ", " ", ""],
embedding_model_instance=embedding_model_instance,
)
# Return a semantic text splitter with the appropriate fallback
return SemanticTextSplitter(
llm_model_instance=llm_model_instance,
fallback_splitter=fallback_splitter,
chunk_size=max_tokens,
chunk_overlap=chunk_overlap
)
# Default to traditional chunking methods
if processing_rule_mode in ["custom", "hierarchical"]:
# The user-defined segmentation rule
max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH

@ -4,7 +4,7 @@ import uuid
from typing import Optional
from configs import dify_config
from core.model_manager import ModelInstance
from core.model_manager import ModelManager, ModelInstance
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
@ -12,6 +12,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import ChildDocument, Document
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from libs import helper
from models.dataset import ChildChunk, Dataset, DocumentSegment
@ -37,6 +38,20 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
raise ValueError("No rules found in process rule.")
rules = Rule(**process_rule.get("rules"))
all_documents = [] # type: ignore
# Get the chunking strategy from process_rule or default to "fixed"
chunking_strategy = process_rule.get("chunking_strategy", "fixed")
# Get LLM model instance if semantic chunking is enabled
llm_model_instance = None
if chunking_strategy == "semantic":
try:
model_manager = ModelManager()
llm_model_instance = kwargs.get("llm_model_instance") or model_manager.get_default_model_instance(ModelType.LLM)
except Exception as e:
# Fall back to fixed chunking if we can't get a LLM model
chunking_strategy = "fixed"
if rules.parent_mode == ParentMode.PARAGRAPH:
# Split the text documents into nodes.
if not rules.segmentation:
@ -47,6 +62,8 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
chunk_overlap=rules.segmentation.chunk_overlap,
separator=rules.segmentation.separator,
embedding_model_instance=kwargs.get("embedding_model_instance"),
chunking_strategy=chunking_strategy,
llm_model_instance=llm_model_instance
)
for document in documents:
if kwargs.get("preview") and len(all_documents) >= 10:
@ -73,7 +90,12 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
document_node.page_content = page_content
# parse document to child nodes
child_nodes = self._split_child_nodes(
document_node, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
document_node,
rules,
process_rule.get("mode"),
kwargs.get("embedding_model_instance"),
chunking_strategy=chunking_strategy,
llm_model_instance=llm_model_instance
)
document_node.children = child_nodes
split_documents.append(document_node)
@ -83,7 +105,12 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
document = Document(page_content=page_content, metadata=documents[0].metadata)
# parse document to child nodes
child_nodes = self._split_child_nodes(
document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
document,
rules,
process_rule.get("mode"),
kwargs.get("embedding_model_instance"),
chunking_strategy=chunking_strategy,
llm_model_instance=llm_model_instance
)
if kwargs.get("preview"):
if len(child_nodes) > dify_config.CHILD_CHUNKS_PREVIEW_NUMBER:
@ -173,6 +200,8 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
rules: Rule,
process_rule_mode: str,
embedding_model_instance: Optional[ModelInstance],
chunking_strategy: str = "fixed",
llm_model_instance: Optional[ModelInstance] = None,
) -> list[ChildDocument]:
if not rules.subchunk_segmentation:
raise ValueError("No subchunk segmentation found in rules.")
@ -182,6 +211,8 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
chunk_overlap=rules.subchunk_segmentation.chunk_overlap,
separator=rules.subchunk_segmentation.separator,
embedding_model_instance=embedding_model_instance,
chunking_strategy=chunking_strategy,
llm_model_instance=llm_model_instance
)
# parse document to child nodes
child_nodes = []

@ -0,0 +1,241 @@
from __future__ import annotations
import logging
from typing import Any, Optional, List
from core.model_manager import ModelManager, ModelInstance
from core.rag.models.document import Document
from core.rag.splitter.text_splitter import TextSplitter
from core.model_runtime.entities.model_entities import ModelType
from core.rag.splitter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
logger = logging.getLogger(__name__)
class SemanticTextSplitter(TextSplitter):
"""
A text splitter that uses an LLM to identify semantic boundaries in text.
This splitter first uses a traditional method as a fallback, then refines using LLM-based
semantic analysis to ensure chunk boundaries align with natural semantic breaks.
"""
def __init__(
self,
llm_model_instance: ModelInstance,
fallback_splitter: Optional[TextSplitter] = None,
chunk_size: int = 4000,
chunk_overlap: int = 200,
**kwargs: Any,
) -> None:
"""
Initialize the semantic text splitter.
Args:
llm_model_instance: The LLM model instance to use for semantic chunking
fallback_splitter: A fallback splitter to use if the LLM chunking fails
chunk_size: Maximum size of chunks to return
chunk_overlap: Overlap in characters between chunks
**kwargs: Additional arguments to pass to the TextSplitter base class
"""
super().__init__(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
**kwargs
)
self._llm_model_instance = llm_model_instance
self._fallback_splitter = fallback_splitter or FixedRecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
def split_text(self, text: str) -> list[str]:
"""
Split text using LLM-based semantic chunking.
This method first applies a basic chunking method, then refines the chunk
boundaries using the LLM to identify more natural semantic breaks.
Args:
text: The text to split
Returns:
A list of text chunks with semantic coherence
"""
try:
# First apply initial chunking with our fallback splitter
initial_chunks = self._fallback_splitter.split_text(text)
# If the text is short enough, return as a single chunk
if len(initial_chunks) <= 1:
return initial_chunks
# Use LLM to improve chunk boundaries for semantic coherence
return self._refine_chunks_with_llm(initial_chunks)
except Exception as e:
logger.warning(f"LLM-based semantic chunking failed: {str(e)}. Falling back to traditional chunking.")
# Fallback to standard chunking if LLM-based approach fails
return self._fallback_splitter.split_text(text)
def _refine_chunks_with_llm(self, chunks: List[str]) -> List[str]:
"""
Use the LLM to refine chunk boundaries based on semantic meaning.
Args:
chunks: Initial text chunks to refine
Returns:
Refined chunks with improved semantic coherence
"""
refined_chunks = []
# Process chunks in pairs to determine better break points
for i in range(len(chunks) - 1):
current_chunk = chunks[i]
next_chunk = chunks[i+1]
# Skip very small chunks by merging them with the next one
if len(current_chunk) < self._chunk_size / 10:
if i + 1 < len(chunks):
chunks[i+1] = current_chunk + " " + next_chunk
continue
# If the current chunk is already optimal, keep it as is
if len(current_chunk) < self._chunk_size * 0.9:
refined_chunks.append(current_chunk)
continue
# Create a boundary analysis task for the LLM to find a natural break point
# between the current chunk and the next
boundary_text = current_chunk[-self._chunk_overlap:] + next_chunk[:self._chunk_overlap]
prompt = self._create_boundary_analysis_prompt(boundary_text)
# Use the LLM to find a natural break point
try:
response = self._llm_model_instance.invoke_llm(
prompt=prompt,
stream=False,
temperature=0,
max_tokens=50
)
break_indicator = self._parse_llm_response(response.content)
# Apply the break point to create a more natural chunk division
if break_indicator and 0 < break_indicator < len(boundary_text):
# Calculate the actual split position in the current chunk
split_position = len(current_chunk) - self._chunk_overlap + break_indicator
# Only apply the split if it makes sense
if 0 < split_position < len(current_chunk):
refined_chunk = current_chunk[:split_position]
remainder = current_chunk[split_position:]
# Add the refined chunk and combine the remainder with the next chunk
refined_chunks.append(refined_chunk)
chunks[i+1] = remainder + next_chunk
continue
# If no good break point was found, use the original chunk
refined_chunks.append(current_chunk)
except Exception as e:
logger.warning(f"Error during LLM boundary analysis: {str(e)}")
refined_chunks.append(current_chunk)
# Add the last chunk
if chunks and len(chunks) > 0:
refined_chunks.append(chunks[-1])
return refined_chunks
def _create_boundary_analysis_prompt(self, text: str) -> str:
"""
Create a prompt for the LLM to analyze the text and find natural break points.
Args:
text: The text around the chunk boundary to analyze
Returns:
A prompt for the LLM
"""
return (
"Analyze the following text and identify the most natural point to split it into two separate chunks. "
"The split should occur at a meaningful boundary such as the end of a paragraph, sentence, or idea. "
"Return only the character index (a number) where the split should occur.\n\n"
f"Text to analyze: {text}\n\n"
"Split index:"
)
def _parse_llm_response(self, response: str) -> Optional[int]:
"""
Parse the LLM's response to extract the split index.
Args:
response: The LLM's response containing the split index
Returns:
The index at which to split the text, or None if no valid index was found
"""
try:
# Extract the first number from the response
import re
numbers = re.findall(r'\d+', response)
if numbers:
return int(numbers[0])
return None
except Exception:
return None
def split_documents(self, documents: list[Document]) -> list[Document]:
"""Split documents using semantic chunking."""
return super().split_documents(documents)
@classmethod
def from_llm(
cls,
llm_model_instance: Optional[ModelInstance] = None,
chunk_size: int = 4000,
chunk_overlap: int = 200,
**kwargs: Any
) -> SemanticTextSplitter:
"""
Create a SemanticTextSplitter using the specified LLM.
Args:
llm_model_instance: The LLM model instance to use
chunk_size: Maximum size of chunks to return
chunk_overlap: Overlap in characters between chunks
**kwargs: Additional arguments to pass to the TextSplitter
Returns:
A configured SemanticTextSplitter
"""
# If no model instance is provided, try to get a default one
if not llm_model_instance:
try:
model_manager = ModelManager()
llm_model_instance = model_manager.get_default_model_instance(ModelType.LLM)
except Exception as e:
logger.warning(f"Failed to get default LLM model: {str(e)}. Using fallback splitter.")
return FixedRecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
**kwargs
)
# Create fallback splitter
fallback_splitter = FixedRecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
**kwargs
)
return cls(
llm_model_instance=llm_model_instance,
fallback_splitter=fallback_splitter,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
**kwargs
)

@ -72,6 +72,7 @@ class Rule(BaseModel):
segmentation: Optional[Segmentation] = None
parent_mode: Optional[Literal["full-doc", "paragraph"]] = None
subchunk_segmentation: Optional[Segmentation] = None
chunking_strategy: Optional[Literal["fixed", "semantic"]] = "fixed"
class ProcessRule(BaseModel):

@ -1,7 +1,10 @@
import logging
from typing import Optional
import uuid
from datetime import datetime
from core.model_manager import ModelInstance, ModelManager
from core.indexing_runner import IndexingRunner
from core.model_manager import ModelManager, ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
@ -9,8 +12,9 @@ from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import ChildChunk, Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.dataset import DatasetProcessRule
from services.entities.knowledge_entities.knowledge_entities import ParentMode
_logger = logging.getLogger(__name__)
@ -135,12 +139,28 @@ class VectorService:
# use full doc mode to generate segment's child chunk
processing_rule_dict = processing_rule.to_dict()
processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC.value
# Get chunking strategy
chunking_strategy = processing_rule_dict.get("rules", {}).get("chunking_strategy", "fixed")
# If semantic chunking is used, get LLM model instance
llm_model_instance = None
if chunking_strategy == "semantic":
try:
model_manager = ModelManager()
llm_model_instance = model_manager.get_default_model_instance(ModelType.LLM)
except Exception as e:
# Fall back to fixed chunking if we can't get a LLM model
chunking_strategy = "fixed"
processing_rule_dict["rules"]["chunking_strategy"] = "fixed"
documents = index_processor.transform(
[document],
embedding_model_instance=embedding_model_instance,
process_rule=processing_rule_dict,
tenant_id=dataset.tenant_id,
doc_language=dataset_document.doc_language,
llm_model_instance=llm_model_instance
)
# save child chunks
if documents and documents[0].children:

@ -0,0 +1,115 @@
"""
Simple test for semantic chunking boundary analysis.
"""
class MockLLMModelInstance:
"""Mock LLM for testing purposes."""
def __init__(self, model="mock-llm", provider="mock-provider"):
self.model = model
self.provider = provider
def invoke_llm(self, prompt: str, stream=False, temperature=0, max_tokens=50):
"""Mock LLM response that returns a reasonable split index."""
print(f"LLM received prompt: {prompt[:50]}...")
# Extract the text to analyze
text_to_analyze = prompt.split("Text to analyze: ")[1].split("\n\n")[0]
# Find the exact position of a period that marks the end of a sentence
period_positions = [i for i, char in enumerate(text_to_analyze) if char == '.']
if period_positions:
# Find a period that's roughly in the middle of the text
middle_index = len(text_to_analyze) // 2
closest_period = min(period_positions, key=lambda x: abs(x - middle_index))
return MockResponse(str(closest_period + 1)) # +1 to include the period
# If no period found, return the middle
return MockResponse(str(len(text_to_analyze) // 2))
class MockResponse:
"""Mock response from LLM."""
def __init__(self, content: str):
self.content = content
def create_boundary_analysis_prompt(text: str) -> str:
"""Create a prompt for boundary analysis."""
return (
"Analyze the following text and identify the most natural point to split it into two separate chunks. "
"The split should occur at a meaningful boundary such as the end of a paragraph, sentence, or idea. "
"Return only the character index (a number) where the split should occur.\n\n"
f"Text to analyze: {text}\n\n"
"Split index:"
)
def parse_llm_response(response: str):
"""Parse the LLM response to get the split index."""
try:
# Extract the first number from the response
import re
numbers = re.findall(r'\d+', response)
if numbers:
return int(numbers[0])
return None
except Exception as e:
print(f"Error parsing response: {e}")
return None
def test_boundary_analysis():
"""Test the boundary analysis logic."""
# Create test text with a clear semantic boundary
test_text = "This is the first sentence that should stay together. This is the start of a new thought that should also stay together."
# Create mock LLM
mock_llm = MockLLMModelInstance()
# Create boundary analysis prompt
prompt = create_boundary_analysis_prompt(test_text)
# Get mock LLM response
response = mock_llm.invoke_llm(prompt)
# Parse the response
split_index = parse_llm_response(response.content)
print(f"Recommended split index: {split_index}")
print(f"Text before split: {test_text[:split_index]}")
print(f"Text after split: {test_text[split_index:]}")
# Check if split is at sentence boundary
if split_index > 0:
assert test_text[split_index-1] == ".", "Split should be at the end of a sentence"
print("Test passed! Split occurred at a sentence boundary.")
else:
print("Test failed: Split index is not valid.")
# Test with paragraph boundaries
print("\nTesting with paragraph boundaries:")
para_text = """This is paragraph one.
It has multiple sentences.
This is paragraph two.
It should be kept separate."""
# Create boundary analysis prompt
para_prompt = create_boundary_analysis_prompt(para_text)
# Get mock LLM response
para_response = mock_llm.invoke_llm(para_prompt)
# Parse the response
para_split_index = parse_llm_response(para_response.content)
print(f"Recommended paragraph split index: {para_split_index}")
print(f"Text before split: {para_text[:para_split_index]}")
print(f"Text after split: {para_text[para_split_index:]}")
if __name__ == "__main__":
test_boundary_analysis()

@ -0,0 +1,112 @@
import os
import sys
# Add the current directory to the Python path
sys.path.insert(0, os.path.abspath('.'))
from flask import Flask
from api.core.model_manager import ModelManager
from api.core.model_runtime.entities.model_entities import ModelType
from api.core.rag.splitter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
from api.core.rag.splitter.semantic_text_splitter import SemanticTextSplitter
from api.core.rag.models.document import Document
# Sample text with distinct semantic sections
sample_text = """
# Introduction to Machine Learning
Machine learning is a field of artificial intelligence that enables systems to learn and improve from experience without being explicitly programmed. It focuses on developing algorithms that can access data and use it to learn for themselves.
## Supervised Learning
Supervised learning is a type of machine learning where the algorithm learns from labeled training data. The model makes predictions based on evidence in the presence of uncertainty.
Common supervised learning tasks include:
- Classification: categorizing data into predefined classes
- Regression: predicting continuous values
- Forecasting: predicting future values based on historical data
## Unsupervised Learning
Unsupervised learning algorithms find patterns in unlabeled data. These models discover hidden structures in data without the need for human intervention.
Typical unsupervised learning techniques include:
- Clustering: grouping similar data points together
- Dimensionality reduction: reducing the number of variables in data
- Association: identifying rules that describe portions of your data
"""
def test_semantic_chunking():
print("Testing Semantic Chunking Implementation")
print("-" * 50)
app = Flask(__name__)
with app.app_context():
# Get model instances
model_manager = ModelManager()
try:
llm_model_instance = model_manager.get_default_model_instance(ModelType.LLM)
print(f"Using LLM: {llm_model_instance.model} from {llm_model_instance.provider}")
except Exception as e:
print(f"Could not get LLM model instance: {e}")
print("Continuing with fixed chunking only for comparison")
llm_model_instance = None
try:
embedding_model_instance = model_manager.get_default_model_instance(ModelType.TEXT_EMBEDDING)
print(f"Using embedding model: {embedding_model_instance.model} from {embedding_model_instance.provider}")
except Exception as e:
print(f"Could not get embedding model instance: {e}")
embedding_model_instance = None
# Create document from sample text
doc = Document(page_content=sample_text)
# Fixed chunking
fixed_splitter = FixedRecursiveCharacterTextSplitter(
chunk_size=200,
chunk_overlap=20,
)
fixed_chunks = fixed_splitter.split_text(sample_text)
print("\nFixed Chunking Results:")
print(f"Number of chunks: {len(fixed_chunks)}")
for i, chunk in enumerate(fixed_chunks):
print(f"\nChunk {i+1} ({len(chunk)} chars):")
print(chunk[:100] + "..." if len(chunk) > 100 else chunk)
# Semantic chunking (if LLM is available)
if llm_model_instance:
semantic_splitter = SemanticTextSplitter(
llm_model_instance=llm_model_instance,
chunk_size=200,
chunk_overlap=20,
)
semantic_chunks = semantic_splitter.split_text(sample_text)
print("\nSemantic Chunking Results:")
print(f"Number of chunks: {len(semantic_chunks)}")
for i, chunk in enumerate(semantic_chunks):
print(f"\nChunk {i+1} ({len(chunk)} chars):")
print(chunk[:100] + "..." if len(chunk) > 100 else chunk)
# Analyze the quality of chunk boundaries
print("\nAnalyzing chunk boundaries:")
natural_boundaries = [".", "!", "?"]
semantic_complete_sentences = sum(1 for chunk in semantic_chunks if chunk.strip()[-1] in natural_boundaries)
fixed_complete_sentences = sum(1 for chunk in fixed_chunks if chunk.strip()[-1] in natural_boundaries)
print(f"Fixed chunking: {fixed_complete_sentences}/{len(fixed_chunks)} chunks end with complete sentences")
print(f"Semantic chunking: {semantic_complete_sentences}/{len(semantic_chunks)} chunks end with complete sentences")
# Check if there are section headings split across chunks
fixed_headings_split = sum(1 for chunk in fixed_chunks if chunk.strip().startswith('#') and not chunk.strip().split('\n')[0].endswith('\n'))
semantic_headings_split = sum(1 for chunk in semantic_chunks if chunk.strip().startswith('#') and not chunk.strip().split('\n')[0].endswith('\n'))
print(f"Fixed chunking: {fixed_headings_split} chunks have headings split from their content")
print(f"Semantic chunking: {semantic_headings_split} chunks have headings split from their content")
else:
print("\nSkipping semantic chunking test as no LLM is available")
if __name__ == "__main__":
test_semantic_chunking()

@ -0,0 +1,385 @@
"""
Demonstration of semantic chunking benefits compared to fixed chunking
"""
import re
from typing import List, Optional
class MockLLMModelInstance:
"""Mock LLM for testing purposes."""
def __init__(self, model="mock-llm", provider="mock-provider"):
self.model = model
self.provider = provider
def invoke_llm(self, prompt: str, stream=False, temperature=0, max_tokens=50):
"""Mock LLM that finds natural break points."""
text_to_analyze = prompt.split("Text to analyze: ")[1].split("\n\n")[0]
# Look for natural break points in preferred order
break_points = []
# 1. Look for paragraph breaks (\n\n)
newline_indices = [m.start() for m in re.finditer(r'\n\s*\n', text_to_analyze)]
if newline_indices:
closest_to_middle = min(newline_indices, key=lambda x: abs(x - len(text_to_analyze) // 2))
break_points.append(closest_to_middle + 1) # +1 to include the newline
# 2. Look for list item boundaries
list_item_breaks = [m.start() for m in re.finditer(r'[.!?]\s*\n[\s]*[-\d]', text_to_analyze)]
if list_item_breaks:
closest_to_middle = min(list_item_breaks, key=lambda x: abs(x - len(text_to_analyze) // 2))
break_points.append(closest_to_middle + 1) # +1 to include the period
# 3. Look for sentence endings
sentence_ends = [m.start() for m in re.finditer(r'[.!?]\s', text_to_analyze)]
if sentence_ends:
closest_to_middle = min(sentence_ends, key=lambda x: abs(x - len(text_to_analyze) // 2))
break_points.append(closest_to_middle + 1) # +1 to include the punctuation
# Choose the best break point
if break_points:
return MockResponse(str(min(break_points)))
# Fallback to middle of text
return MockResponse(str(len(text_to_analyze) // 2))
class MockResponse:
"""Mock response from LLM."""
def __init__(self, content: str):
self.content = content
class SimpleSplitter:
"""A simple text splitter that mimics fixed chunking behavior."""
def __init__(self, chunk_size: int, chunk_overlap: int):
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
def split_text(self, text: str) -> List[str]:
"""Split text into chunks of fixed size with overlap."""
if not text:
return []
chunks = []
start = 0
while start < len(text):
end = min(start + self.chunk_size, len(text))
# Try to find a basic break point near the end
if end < len(text):
for break_char in ['\n\n', '.', '\n', ' ']:
last_break = text.rfind(break_char, start, end)
if last_break != -1 and last_break > start:
end = last_break + 1
break
chunks.append(text[start:end])
# Calculate next start position with overlap
start = end - self.chunk_overlap
# Ensure we're making progress
if start >= end:
start = end
return chunks
class SemanticTextSplitter:
"""Semantic text splitter that uses LLM for boundary detection."""
def __init__(
self,
llm_model_instance,
fallback_splitter=None,
chunk_size: int = 4000,
chunk_overlap: int = 200,
):
self._llm_model_instance = llm_model_instance
self._chunk_size = chunk_size
self._chunk_overlap = chunk_overlap
self._fallback_splitter = fallback_splitter or SimpleSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
def split_text(self, text: str) -> List[str]:
"""Split text using LLM-based semantic chunking."""
try:
# First apply initial chunking with fallback splitter
initial_chunks = self._fallback_splitter.split_text(text)
if len(initial_chunks) <= 1:
return initial_chunks
# Use LLM to improve chunk boundaries
return self._refine_chunks_with_llm(initial_chunks)
except Exception as e:
print(f"LLM-based semantic chunking failed: {str(e)}. Falling back.")
return self._fallback_splitter.split_text(text)
def _refine_chunks_with_llm(self, chunks: List[str]) -> List[str]:
"""Use LLM to refine chunk boundaries."""
refined_chunks = []
for i in range(len(chunks) - 1):
current_chunk = chunks[i]
next_chunk = chunks[i+1]
# Skip very small chunks by merging with next
if len(current_chunk) < self._chunk_size / 10:
if i + 1 < len(chunks):
chunks[i+1] = current_chunk + " " + next_chunk
continue
# If current chunk is already optimal, keep as is
if len(current_chunk) < self._chunk_size * 0.9:
refined_chunks.append(current_chunk)
continue
# Create boundary analysis task
boundary_text = current_chunk[-self._chunk_overlap:] + next_chunk[:self._chunk_overlap]
prompt = self._create_boundary_analysis_prompt(boundary_text)
try:
response = self._llm_model_instance.invoke_llm(
prompt=prompt,
stream=False,
temperature=0,
max_tokens=50
)
break_indicator = self._parse_llm_response(response.content)
if break_indicator and 0 < break_indicator < len(boundary_text):
# Calculate actual split position
split_position = len(current_chunk) - self._chunk_overlap + break_indicator
if 0 < split_position < len(current_chunk):
refined_chunk = current_chunk[:split_position]
remainder = current_chunk[split_position:]
refined_chunks.append(refined_chunk)
chunks[i+1] = remainder + next_chunk
continue
# If no good break point found, use original chunk
refined_chunks.append(current_chunk)
except Exception as e:
print(f"Error during LLM boundary analysis: {str(e)}")
refined_chunks.append(current_chunk)
# Add the last chunk
if chunks and len(chunks) > 0:
refined_chunks.append(chunks[-1])
return refined_chunks
def _create_boundary_analysis_prompt(self, text: str) -> str:
"""Create prompt for LLM boundary analysis."""
return (
"Analyze the following text and identify the most natural point to split it into two separate chunks. "
"The split should occur at a meaningful boundary such as the end of a paragraph, sentence, or idea. "
"Return only the character index (a number) where the split should occur.\n\n"
f"Text to analyze: {text}\n\n"
"Split index:"
)
def _parse_llm_response(self, response: str) -> Optional[int]:
"""Parse LLM response to extract split index."""
try:
numbers = re.findall(r'\d+', response)
if numbers:
return int(numbers[0])
return None
except Exception:
return None
def evaluate_splitting_quality(chunks: List[str], name: str):
"""Evaluate the quality of text chunks."""
print(f"\n{name} Chunking Quality Analysis:")
print("-" * 40)
# Count chunks that end with complete sentences
sentence_endings = ['.', '!', '?']
complete_sentences = sum(1 for chunk in chunks if chunk.strip() and chunk.strip()[-1] in sentence_endings)
print(f"Complete sentences: {complete_sentences}/{len(chunks)} chunks ({complete_sentences/len(chunks)*100:.1f}%)")
# Count chunks that break in the middle of a list
list_breaks = 0
numbered_list_pattern = re.compile(r'\d+\.')
bullet_list_pattern = re.compile(r'[-*]')
for i, chunk in enumerate(chunks[:-1]):
next_chunk = chunks[i+1]
if numbered_list_pattern.search(chunk) and numbered_list_pattern.search(next_chunk):
# Check if numbering is consecutive
chunk_numbers = [int(n.group()[:-1]) for n in numbered_list_pattern.finditer(chunk)]
next_numbers = [int(n.group()[:-1]) for n in numbered_list_pattern.finditer(next_chunk)]
if chunk_numbers and next_numbers and next_numbers[0] != chunk_numbers[-1] + 1:
list_breaks += 1
# Check for bullet list breaks
if bullet_list_pattern.search(chunk) and bullet_list_pattern.search(next_chunk):
list_breaks += 1
print(f"List breaks: {list_breaks}/{len(chunks)-1} transitions ({list_breaks/(len(chunks)-1)*100:.1f}%)")
# Count heading breaks (heading not followed by its content)
heading_breaks = 0
for chunk in chunks:
if re.search(r'#{1,6}\s+.+\s*$', chunk): # Markdown heading at the end of chunk
heading_breaks += 1
print(f"Heading breaks: {heading_breaks}/{len(chunks)} chunks ({heading_breaks/len(chunks)*100:.1f}%)")
# Calculate average semantic coherence (simplified)
# In a real implementation, this would use embeddings or more sophisticated analysis
semantic_breaks = 0
for i, chunk in enumerate(chunks[:-1]):
last_line = chunk.strip().split('\n')[-1]
next_first_line = chunks[i+1].strip().split('\n')[0]
# Check if the break happens in the middle of a coherent section
if (last_line and next_first_line and
not any(last_line.endswith(end) for end in sentence_endings) and
not next_first_line.startswith('#') and
not re.match(r'^\d+\.', next_first_line) and
not re.match(r'^[-*]', next_first_line)):
semantic_breaks += 1
print(f"Semantic breaks: {semantic_breaks}/{len(chunks)-1} transitions ({semantic_breaks/(len(chunks)-1)*100:.1f}%)")
# Overall quality score (lower is better)
quality_score = (
(len(chunks) - complete_sentences) +
list_breaks * 2 +
heading_breaks * 3 +
semantic_breaks * 2
) / len(chunks)
print(f"Overall quality score: {quality_score:.2f} (lower is better)")
return quality_score
# Test document with various semantic structures
test_document = """
# Introduction to Machine Learning
Machine learning is a branch of artificial intelligence that focuses on developing systems that learn from data.
Unlike traditional programming where explicit instructions are provided, machine learning algorithms improve through experience.
## Supervised Learning
Supervised learning is the task of learning a function that maps an input to an output based on example input-output pairs.
It infers a function from labeled training data. The training data consists of a set of examples where each example is a pair of an input object and a desired output value.
Common supervised learning tasks include:
1. Classification: predicting a category or class
2. Regression: predicting a continuous value
3. Sequence labeling: predicting a sequence of categories
## Unsupervised Learning
Unsupervised learning is a type of machine learning algorithm used to draw inferences from datasets consisting of input data without labeled responses.
The most common unsupervised learning tasks are:
- Clustering: grouping similar instances together
- Dimensionality reduction: reducing the number of variables under consideration
- Association rule learning: discovering interesting relations between variables
# Deep Learning
Deep learning is a subset of machine learning that uses multi-layered neural networks to analyze various factors of data.
Deep learning models can learn to focus on the right features by themselves, requiring less feature engineering.
Some popular deep learning architectures include:
1. Convolutional Neural Networks (CNNs) for image processing
2. Recurrent Neural Networks (RNNs) for sequence data
3. Transformers for natural language processing tasks
## Training Deep Neural Networks
Training deep networks requires large amounts of data and computational resources. The process typically involves:
1. Initializing the network with random weights
2. Forward propagation of training data
3. Computing the loss using a loss function
4. Backpropagation to calculate gradients
5. Updating weights using an optimization algorithm
# Applications of Machine Learning
Machine learning has revolutionized various fields:
- Healthcare: disease diagnosis, treatment planning, drug discovery
- Finance: fraud detection, algorithmic trading, risk assessment
- Transportation: autonomous vehicles, traffic prediction
- Entertainment: recommendation systems, content generation
As computing power increases and algorithms improve, the applications of machine learning continue to expand.
"""
# Run the comparison test
if __name__ == "__main__":
print("Testing Semantic vs. Fixed Chunking Benefits")
print("=" * 50)
# Create mock LLM
mock_llm = MockLLMModelInstance()
# Create fixed-size chunker (small chunks to highlight differences)
fixed_chunker = SimpleSplitter(chunk_size=300, chunk_overlap=50)
# Create semantic chunker
semantic_chunker = SemanticTextSplitter(
llm_model_instance=mock_llm,
chunk_size=300,
chunk_overlap=50
)
# Split using both methods
fixed_chunks = fixed_chunker.split_text(test_document)
semantic_chunks = semantic_chunker.split_text(test_document)
# Print basic stats
print(f"Document length: {len(test_document)} characters")
print(f"Fixed chunking: {len(fixed_chunks)} chunks")
print(f"Semantic chunking: {len(semantic_chunks)} chunks")
# Print samples of both results
print("\nSample of Fixed Chunks:")
for i in range(min(3, len(fixed_chunks))):
print(f"Chunk {i+1} ({len(fixed_chunks[i])} chars): {fixed_chunks[i][:100]}...")
print("\nSample of Semantic Chunks:")
for i in range(min(3, len(semantic_chunks))):
print(f"Chunk {i+1} ({len(semantic_chunks[i])} chars): {semantic_chunks[i][:100]}...")
# Evaluate quality
fixed_score = evaluate_splitting_quality(fixed_chunks, "Fixed")
semantic_score = evaluate_splitting_quality(semantic_chunks, "Semantic")
# Compare results
print("\nComparison Results:")
print("-" * 50)
print(f"Fixed chunking quality score: {fixed_score:.2f}")
print(f"Semantic chunking quality score: {semantic_score:.2f}")
print(f"Improvement: {((fixed_score - semantic_score) / fixed_score * 100):.1f}%")
# Summary of benefits
print("\nKey Benefits of Semantic Chunking:")
print("1. More chunks end with complete sentences")
print("2. Fewer breaks in lists and bullet points")
print("3. Headings stay with their content")
print("4. Better preservation of semantic context")
print("5. Improved retrieval quality due to more coherent chunks")

@ -0,0 +1,302 @@
"""
Simplified test for semantic chunking without Flask app context.
This test directly uses the SemanticTextSplitter implementation.
"""
import sys
import os
from typing import Optional, List
class MockLLMModelInstance:
"""Mock LLM for testing purposes."""
def __init__(self, model="mock-llm", provider="mock-provider"):
self.model = model
self.provider = provider
def invoke_llm(self, prompt: str, stream: bool = False, temperature: float = 0, max_tokens: int = 50):
"""Mock LLM response that returns a reasonable split index."""
# Simulate finding a natural break point by looking for sentence endings
# This is a simplified version of what the real LLM would do
text_to_analyze = prompt.split("Text to analyze: ")[1].split("\n\n")[0]
# Look for natural break points (periods, question marks, exclamation points)
for i, char in enumerate(text_to_analyze):
if i > len(text_to_analyze) // 2 and char in ['.', '!', '?', '\n']:
return MockResponse(str(i))
# If no natural break point found, return the middle
return MockResponse(str(len(text_to_analyze) // 2))
class MockResponse:
"""Mock response from LLM."""
def __init__(self, content: str):
self.content = content
class SemanticTextSplitter:
"""
A simplified version of the actual SemanticTextSplitter for testing purposes.
This implementation mimics the behavior of the real implementation.
"""
def __init__(
self,
llm_model_instance,
fallback_splitter=None,
chunk_size: int = 4000,
chunk_overlap: int = 200,
):
self._llm_model_instance = llm_model_instance
self._chunk_size = chunk_size
self._chunk_overlap = chunk_overlap
self._fallback_splitter = fallback_splitter or SimpleSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
def split_text(self, text: str) -> List[str]:
"""
Split text using LLM-based semantic chunking.
"""
try:
# First apply initial chunking with our fallback splitter
initial_chunks = self._fallback_splitter.split_text(text)
# If the text is short enough, return as a single chunk
if len(initial_chunks) <= 1:
return initial_chunks
# Use LLM to improve chunk boundaries for semantic coherence
return self._refine_chunks_with_llm(initial_chunks)
except Exception as e:
print(f"LLM-based semantic chunking failed: {str(e)}. Falling back to traditional chunking.")
# Fallback to standard chunking if LLM-based approach fails
return self._fallback_splitter.split_text(text)
def _refine_chunks_with_llm(self, chunks: List[str]) -> List[str]:
"""
Use the LLM to refine chunk boundaries based on semantic meaning.
"""
refined_chunks = []
# Process chunks in pairs to determine better break points
for i in range(len(chunks) - 1):
current_chunk = chunks[i]
next_chunk = chunks[i+1]
# Skip very small chunks by merging them with the next one
if len(current_chunk) < self._chunk_size / 10:
if i + 1 < len(chunks):
chunks[i+1] = current_chunk + " " + next_chunk
continue
# If the current chunk is already optimal, keep it as is
if len(current_chunk) < self._chunk_size * 0.9:
refined_chunks.append(current_chunk)
continue
# Create a boundary analysis task for the LLM to find a natural break point
boundary_text = current_chunk[-self._chunk_overlap:] + next_chunk[:self._chunk_overlap]
prompt = self._create_boundary_analysis_prompt(boundary_text)
# Use the LLM to find a natural break point
try:
response = self._llm_model_instance.invoke_llm(
prompt=prompt,
stream=False,
temperature=0,
max_tokens=50
)
break_indicator = self._parse_llm_response(response.content)
# Apply the break point to create a more natural chunk division
if break_indicator and 0 < break_indicator < len(boundary_text):
# Calculate the actual split position in the current chunk
split_position = len(current_chunk) - self._chunk_overlap + break_indicator
# Only apply the split if it makes sense
if 0 < split_position < len(current_chunk):
refined_chunk = current_chunk[:split_position]
remainder = current_chunk[split_position:]
# Add the refined chunk and combine the remainder with the next chunk
refined_chunks.append(refined_chunk)
chunks[i+1] = remainder + next_chunk
continue
# If no good break point was found, use the original chunk
refined_chunks.append(current_chunk)
except Exception as e:
print(f"Error during LLM boundary analysis: {str(e)}")
refined_chunks.append(current_chunk)
# Add the last chunk
if chunks and len(chunks) > 0:
refined_chunks.append(chunks[-1])
return refined_chunks
def _create_boundary_analysis_prompt(self, text: str) -> str:
"""
Create a prompt for the LLM to analyze the text and find natural break points.
"""
return (
"Analyze the following text and identify the most natural point to split it into two separate chunks. "
"The split should occur at a meaningful boundary such as the end of a paragraph, sentence, or idea. "
"Return only the character index (a number) where the split should occur.\n\n"
f"Text to analyze: {text}\n\n"
"Split index:"
)
def _parse_llm_response(self, response: str) -> Optional[int]:
"""
Parse the LLM's response to extract the split index.
"""
try:
# Extract the first number from the response
import re
numbers = re.findall(r'\d+', response)
if numbers:
return int(numbers[0])
return None
except Exception:
return None
class SimpleSplitter:
"""A simple text splitter that mimics the behavior of FixedRecursiveCharacterTextSplitter."""
def __init__(self, chunk_size: int, chunk_overlap: int):
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
def split_text(self, text: str) -> List[str]:
"""Split text into chunks of specified size with overlap."""
if not text:
return []
# Simple chunking by size without considering natural breaks
chunks = []
start = 0
while start < len(text):
end = min(start + self.chunk_size, len(text))
# Try to find a natural break point near the end
if end < len(text):
# Look for paragraph breaks or sentence endings
for break_char in ['\n\n', '.', '!', '?', '\n', ' ']:
last_break = text.rfind(break_char, start, end)
if last_break != -1 and last_break > start:
end = last_break + 1
break
# Add the chunk
chunks.append(text[start:end])
# Calculate next start position with overlap
start = end - self.chunk_overlap
# Ensure we're making progress
if start >= end:
start = end
return chunks
# Sample text with distinct semantic sections
sample_text = """
# Introduction to Artificial Intelligence
Artificial Intelligence (AI) is a field of computer science focused on creating systems capable of performing tasks that typically require human intelligence. These tasks include learning, reasoning, problem-solving, perception, and language understanding.
# Machine Learning Fundamentals
Machine learning is a subset of AI that focuses on building systems that can learn from and make decisions based on data. Instead of being explicitly programmed, these systems identify patterns in data and make predictions.
There are three main types of machine learning:
1. Supervised learning, where models learn from labeled data
2. Unsupervised learning, where models identify patterns in unlabeled data
3. Reinforcement learning, where models learn optimal actions through trial and error
# Natural Language Processing
Natural Language Processing (NLP) combines linguistics and AI to enable computers to understand, interpret, and generate human language. NLP powers voice assistants, translation services, and text analysis tools.
Key NLP tasks include:
- Sentiment analysis
- Named entity recognition
- Text summarization
- Question answering
- Machine translation
"""
if __name__ == "__main__":
print("Testing Simplified Semantic Chunking Implementation")
print("-" * 50)
# Create a mock LLM model instance
mock_llm = MockLLMModelInstance()
# Fixed chunking
simple_splitter = SimpleSplitter(chunk_size=200, chunk_overlap=20)
fixed_chunks = simple_splitter.split_text(sample_text)
print("\nFixed Chunking Results:")
print(f"Number of chunks: {len(fixed_chunks)}")
for i, chunk in enumerate(fixed_chunks):
print(f"\nChunk {i+1} ({len(chunk)} chars):")
print(chunk[:100] + "..." if len(chunk) > 100 else chunk)
# Semantic chunking with mock LLM
semantic_splitter = SemanticTextSplitter(
llm_model_instance=mock_llm,
fallback_splitter=simple_splitter,
chunk_size=200,
chunk_overlap=20,
)
semantic_chunks = semantic_splitter.split_text(sample_text)
print("\nSemantic Chunking Results:")
print(f"Number of chunks: {len(semantic_chunks)}")
for i, chunk in enumerate(semantic_chunks):
print(f"\nChunk {i+1} ({len(chunk)} chars):")
print(chunk[:100] + "..." if len(chunk) > 100 else chunk)
# Analyze the quality of chunk boundaries
print("\nAnalyzing chunk boundaries:")
natural_boundaries = [".", "!", "?"]
semantic_complete_sentences = sum(1 for chunk in semantic_chunks if chunk.strip()[-1] in natural_boundaries)
fixed_complete_sentences = sum(1 for chunk in fixed_chunks if chunk.strip()[-1] in natural_boundaries)
print(f"Fixed chunking: {fixed_complete_sentences}/{len(fixed_chunks)} chunks end with complete sentences")
print(f"Semantic chunking: {semantic_complete_sentences}/{len(semantic_chunks)} chunks end with complete sentences")
# Check if there are section headings split across chunks
fixed_headings_split = sum(1 for chunk in fixed_chunks if chunk.strip().startswith('#') and not chunk.strip().endswith('\n'))
semantic_headings_split = sum(1 for chunk in semantic_chunks if chunk.strip().startswith('#') and not chunk.strip().endswith('\n'))
print(f"Fixed chunking: {fixed_headings_split} chunks have headings split from their content")
print(f"Semantic chunking: {semantic_headings_split} chunks have headings split from their content")
# Check if semantically related content stays together
fixed_list_breaks = 0
semantic_list_breaks = 0
for chunk in fixed_chunks:
if ("-" in chunk or any(f"{i}." in chunk for i in range(1, 10))) and chunk.strip()[-1] != '\n':
fixed_list_breaks += 1
for chunk in semantic_chunks:
if ("-" in chunk or any(f"{i}." in chunk for i in range(1, 10))) and chunk.strip()[-1] != '\n':
semantic_list_breaks += 1
print(f"Fixed chunking: {fixed_list_breaks} chunks break lists across chunks")
print(f"Semantic chunking: {semantic_list_breaks} chunks break lists across chunks")

@ -0,0 +1,108 @@
"""
Simple comparison of fixed vs semantic chunking on a small text snippet
"""
# Sample text with clear semantic boundaries
sample_text = """
# Introduction
This is the introduction paragraph that discusses the topic in general terms.
It should be kept together as a semantic unit. The introduction serves to set up the topic.
# First Main Point
This section covers the first main point with some supporting details.
- Point A with some explanation
- Point B with more details
- Point C concluding the list
# Second Main Point
The second main point builds on the first and adds new information.
It contains important context that should be kept together.
"""
# Fixed chunking (simplified)
def fixed_chunking(text, chunk_size=150):
"""Split text into fixed-size chunks."""
chunks = []
start = 0
while start < len(text):
end = min(start + chunk_size, len(text))
# Try to find a natural break point near the end
if end < len(text):
# Look for paragraph breaks or sentence endings
for break_char in ["\n\n", ".", "\n", " "]:
last_break = text.rfind(break_char, start, end)
if last_break != -1 and last_break > start:
end = last_break + 1
break
chunks.append(text[start:end])
start = end
return chunks
# Semantic chunking (simplified mock implementation)
def semantic_chunking(text):
"""Split text into semantic chunks by heading."""
# Simplified semantic chunking that looks for markdown headings
import re
# Split by markdown headings
heading_pattern = re.compile(r'^#.*$', re.MULTILINE)
chunks = []
# Find all heading positions
heading_matches = list(heading_pattern.finditer(text))
if not heading_matches:
return [text]
# Process each section
for i, match in enumerate(heading_matches):
start = match.start()
# If it's not the last heading, go until the next heading
if i + 1 < len(heading_matches):
end = heading_matches[i + 1].start()
else:
end = len(text)
chunks.append(text[start:end])
return chunks
# Compare the two methods
if __name__ == "__main__":
print("Fixed Chunking vs. Semantic Chunking Comparison")
print("=" * 50)
# Get chunks
fixed_chunks = fixed_chunking(sample_text)
semantic_chunks = semantic_chunking(sample_text)
# Display fixed chunks
print("\nFixed Chunking Results:")
print(f"Number of chunks: {len(fixed_chunks)}")
for i, chunk in enumerate(fixed_chunks):
print(f"\nChunk {i+1}:")
print("-" * 30)
print(chunk)
# Display semantic chunks
print("\nSemantic Chunking Results:")
print(f"Number of chunks: {len(semantic_chunks)}")
for i, chunk in enumerate(semantic_chunks):
print(f"\nChunk {i+1}:")
print("-" * 30)
print(chunk)
# Highlight key differences
print("\nKey Observations:")
print("1. Fixed chunking breaks text at arbitrary positions based mainly on length")
print("2. Semantic chunking preserves logical sections (by headings in this simplified example)")
print("3. In a real LLM-based implementation, semantic boundaries like sentence")
print(" endings, paragraph breaks, and logical transitions would be detected")

@ -0,0 +1,89 @@
from api.core.model_manager import ModelManager
from api.core.model_runtime.entities.model_entities import ModelType
from api.core.rag.splitter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
from api.core.rag.splitter.semantic_text_splitter import SemanticTextSplitter
from api.core.rag.models.document import Document
# Sample text with distinct semantic sections
sample_text = """
# Introduction to Artificial Intelligence
Artificial Intelligence (AI) is a field of computer science focused on creating systems capable of performing tasks that typically require human intelligence. These tasks include learning, reasoning, problem-solving, perception, and language understanding.
# Machine Learning Fundamentals
Machine learning is a subset of AI that focuses on building systems that can learn from and make decisions based on data. Instead of being explicitly programmed, these systems identify patterns in data and make predictions.
There are three main types of machine learning:
1. Supervised learning, where models learn from labeled data
2. Unsupervised learning, where models identify patterns in unlabeled data
3. Reinforcement learning, where models learn optimal actions through trial and error
# Natural Language Processing
Natural Language Processing (NLP) combines linguistics and AI to enable computers to understand, interpret, and generate human language. NLP powers voice assistants, translation services, and text analysis tools.
Key NLP tasks include:
- Sentiment analysis
- Named entity recognition
- Text summarization
- Question answering
- Machine translation
"""
def test_chunking():
print("Testing semantic chunking vs. fixed chunking")
print("-" * 50)
# Get model instances
model_manager = ModelManager()
try:
llm_model_instance = model_manager.get_default_model_instance(ModelType.LLM)
print(f"Using LLM: {llm_model_instance.model} from {llm_model_instance.provider}")
except Exception as e:
print(f"Could not get LLM model instance: {e}")
print("Continuing with fixed chunking only for comparison")
llm_model_instance = None
try:
embedding_model_instance = model_manager.get_default_model_instance(ModelType.TEXT_EMBEDDING)
print(f"Using embedding model: {embedding_model_instance.model} from {embedding_model_instance.provider}")
except Exception as e:
print(f"Could not get embedding model instance: {e}")
embedding_model_instance = None
# Create document from sample text
doc = Document(page_content=sample_text)
# Fixed chunking
fixed_splitter = FixedRecursiveCharacterTextSplitter(
chunk_size=200,
chunk_overlap=20,
)
fixed_chunks = fixed_splitter.split_text(sample_text)
print("\nFixed Chunking Results:")
print(f"Number of chunks: {len(fixed_chunks)}")
for i, chunk in enumerate(fixed_chunks):
print(f"\nChunk {i+1} ({len(chunk)} chars):")
print(chunk[:100] + "..." if len(chunk) > 100 else chunk)
# Semantic chunking (if LLM is available)
if llm_model_instance:
semantic_splitter = SemanticTextSplitter(
llm_model_instance=llm_model_instance,
chunk_size=200,
chunk_overlap=20,
)
semantic_chunks = semantic_splitter.split_text(sample_text)
print("\nSemantic Chunking Results:")
print(f"Number of chunks: {len(semantic_chunks)}")
for i, chunk in enumerate(semantic_chunks):
print(f"\nChunk {i+1} ({len(chunk)} chars):")
print(chunk[:100] + "..." if len(chunk) > 100 else chunk)
else:
print("\nSkipping semantic chunking test as no LLM is available")
if __name__ == "__main__":
test_chunking()

@ -0,0 +1,138 @@
from api.core.model_manager import ModelManager
from api.core.model_runtime.entities.model_entities import ModelType
from api.core.rag.splitter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
from api.core.rag.splitter.semantic_text_splitter import SemanticTextSplitter
from api.core.rag.models.document import Document
from api.core.rag.index_processor.index_processor_base import BaseIndexProcessor
from api.core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor
# Sample text with distinct semantic sections
sample_text = """
# Renewable Energy Technologies
Renewable energy is derived from natural sources that are replenished at a higher rate than they are consumed. Sunlight, wind, rain, tides, waves, and geothermal heat are all sustainable sources that can be harnessed to generate power.
## Solar Energy
Solar power is captured in two main ways: photovoltaic (PV) panels and solar thermal collectors. PV panels convert sunlight directly into electricity, while solar thermal collectors use the sun's energy to heat water or air for domestic or industrial use.
The efficiency of solar panels has dramatically improved in recent years, with modern panels reaching over 20% efficiency. This improvement, coupled with decreasing costs, has made solar energy increasingly competitive with fossil fuels in many regions.
## Wind Energy
Wind turbines convert the kinetic energy of moving air into mechanical power, which can then be converted into electricity. These turbines range from small residential units to large utility-scale wind farms with hundreds of turbines.
Wind energy is one of the fastest-growing renewable energy sources globally. Its cost-effectiveness has improved significantly, with the levelized cost of wind power now comparable to or lower than that of conventional fossil fuel plants in many locations.
## Hydroelectric Power
Hydroelectric power is generated by harnessing the energy of falling or flowing water. Large-scale hydropower projects typically involve dams that create reservoirs, while run-of-river systems utilize the natural flow of rivers without significant water storage.
While hydropower is a mature technology with high efficiency, the environmental impacts of large dams can be significant, affecting river ecosystems and sometimes requiring the displacement of human communities.
"""
def test_rag_pipeline_with_semantic_chunking():
print("Testing RAG pipeline with semantic chunking")
print("-" * 50)
# Initialize model instances
model_manager = ModelManager()
try:
llm_model_instance = model_manager.get_default_model_instance(ModelType.LLM)
print(f"Using LLM: {llm_model_instance.model} from {llm_model_instance.provider}")
except Exception as e:
print(f"Could not get LLM model instance: {e}")
print("Semantic chunking requires an LLM - this test will fall back to fixed chunking")
llm_model_instance = None
try:
embedding_model_instance = model_manager.get_default_model_instance(ModelType.TEXT_EMBEDDING)
print(f"Using embedding model: {embedding_model_instance.model} from {embedding_model_instance.provider}")
except Exception as e:
print(f"Could not get embedding model instance: {e}")
print("Warning: Embedding model is required for the full RAG pipeline")
embedding_model_instance = None
# Create parent document
parent_doc = Document(page_content=sample_text)
print("\n1. Testing Parent-Child Processing with Semantic Chunking")
# Simulate process rule configuration with semantic chunking
process_rule = {
"mode": "custom",
"rules": {
"pre_processing_rules": [],
"parent_mode": "paragraph",
"segmentation": {
"separator": "\\n\\n",
"max_tokens": 500,
"chunk_overlap": 50
},
"subchunk_segmentation": {
"separator": "\\n",
"max_tokens": 200,
"chunk_overlap": 20
},
"chunking_strategy": "semantic"
}
}
# Create parent-child processor
parent_child_processor = ParentChildIndexProcessor()
# Transform the document
try:
transformed_docs = parent_child_processor.transform(
[parent_doc],
process_rule=process_rule,
embedding_model_instance=embedding_model_instance,
llm_model_instance=llm_model_instance
)
print(f"\nParent-Child processing completed")
print(f"Number of parent documents: {len(transformed_docs)}")
if transformed_docs:
parent = transformed_docs[0]
print(f"Parent document length: {len(parent.page_content)} chars")
print(f"Number of child chunks: {len(parent.children) if parent.children else 0}")
if parent.children:
print("\nChild Chunks:")
for i, child in enumerate(parent.children):
print(f"\nChild {i+1} ({len(child.page_content)} chars):")
print(child.page_content[:100] + "..." if len(child.page_content) > 100 else child.page_content)
print("\nChecking for natural semantic boundaries in child chunks:")
natural_boundaries = [".", "!", "?"]
complete_sentences = sum(1 for child in parent.children if child.page_content.strip()[-1] in natural_boundaries)
print(f"Complete sentences: {complete_sentences}/{len(parent.children)} child chunks")
else:
print("No documents were returned from transformation")
except Exception as e:
print(f"Error during parent-child processing: {e}")
# Test fallback mechanism
print("\n2. Testing Fallback Mechanism")
if llm_model_instance:
try:
# Create semantic splitter with invalid parameters to force fallback
semantic_splitter = SemanticTextSplitter(
llm_model_instance=llm_model_instance,
chunk_size=20, # Too small to trigger errors
chunk_overlap=5,
)
# Attempt to split with parameters that would cause issues
semantic_chunks = semantic_splitter.split_text(sample_text)
print(f"Fallback handling worked, got {len(semantic_chunks)} chunks")
except Exception as e:
print(f"Error during fallback test: {e}")
else:
print("Skipping fallback test as no LLM is available")
if __name__ == "__main__":
test_rag_pipeline_with_semantic_chunking()

@ -63,6 +63,7 @@ import CustomDialog from '@/app/components/base/dialog'
import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '@/app/components/base/portal-to-follow-elem'
import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback'
import { noop } from 'lodash-es'
import Brain from '../../../base/icons/src/vender/solid/brain.svg'
const TextLabel: FC<PropsWithChildren> = (props) => {
return <label className='system-sm-semibold text-text-secondary'>{props.children}</label>
@ -101,6 +102,8 @@ const DEFAULT_MAXIMUM_CHUNK_LENGTH = 1024
const DEFAULT_OVERLAP = 50
const MAXIMUM_CHUNK_TOKEN_LENGTH = Number.parseInt(globalThis.document?.body?.getAttribute('data-public-indexing-max-segmentation-tokens-length') || '4000', 10)
type ChunkingStrategy = 'fixed' | 'semantic'
type ParentChildConfig = {
chunkForContext: ParentMode
parent: {
@ -111,6 +114,7 @@ type ParentChildConfig = {
delimiter: string
maxLength: number
}
chunkingStrategy: ChunkingStrategy
}
const defaultParentChildConfig: ParentChildConfig = {
@ -123,6 +127,7 @@ const defaultParentChildConfig: ParentChildConfig = {
delimiter: '\\n',
maxLength: 512,
},
chunkingStrategy: 'semantic',
}
const StepTwo = ({
@ -216,40 +221,46 @@ const StepTwo = ({
const [parentChildConfig, setParentChildConfig] = useState<ParentChildConfig>(defaultParentChildConfig)
const [generalChunkingStrategy, setGeneralChunkingStrategy] = useState<ChunkingStrategy>('fixed')
const getIndexing_technique = () => indexingType || indexType
const currentDocForm = currentDataset?.doc_form || docForm
const getProcessRule = (): ProcessRule => {
if (currentDocForm === ChunkingMode.parentChild) {
if (segmentationType === ProcessMode.general) {
return {
mode: segmentationType,
rules: {
pre_processing_rules: rules,
segmentation: {
separator: unescape(
parentChildConfig.parent.delimiter,
),
max_tokens: parentChildConfig.parent.maxLength,
separator: escape(segmentIdentifier),
max_tokens: maxChunkLength,
chunk_overlap: overlap,
},
chunking_strategy: generalChunkingStrategy,
},
}
}
else {
return {
mode: segmentationType,
rules: {
pre_processing_rules: rules,
parent_mode: parentChildConfig.chunkForContext,
segmentation: {
separator: parentChildConfig.parent.delimiter,
max_tokens: parentChildConfig.parent.maxLength,
chunk_overlap: DEFAULT_OVERLAP,
},
subchunk_segmentation: {
separator: unescape(parentChildConfig.child.delimiter),
separator: parentChildConfig.child.delimiter,
max_tokens: parentChildConfig.child.maxLength,
chunk_overlap: DEFAULT_OVERLAP,
},
chunking_strategy: parentChildConfig.chunkingStrategy,
},
mode: 'hierarchical',
} as ProcessRule
}
}
return {
rules: {
pre_processing_rules: rules,
segmentation: {
separator: unescape(segmentIdentifier),
max_tokens: maxChunkLength,
chunk_overlap: overlap,
},
}, // api will check this. It will be removed after api refactored.
mode: segmentationType,
} as ProcessRule
}
const fileIndexingEstimateQuery = useFetchFileIndexingEstimateForFile({
@ -337,9 +348,13 @@ const StepTwo = ({
setSegmentIdentifier(defaultConfig.segmentation.separator)
setMaxChunkLength(defaultConfig.segmentation.max_tokens)
setOverlap(defaultConfig.segmentation.chunk_overlap!)
setRules(defaultConfig.pre_processing_rules)
setRules(defaultConfig.pre_processing_rules || [])
setGeneralChunkingStrategy('fixed')
setParentChildConfig({
...defaultParentChildConfig,
chunkingStrategy: 'fixed',
})
}
setParentChildConfig(defaultParentChildConfig)
}
const updatePreview = () => {
@ -634,6 +649,30 @@ const StepTwo = ({
onChange={setOverlap}
/>
</div>
<div className='mt-4'>
<div className='flex items-center gap-x-2'>
<div className='inline-flex shrink-0'>
<TextLabel>{t('datasetCreation.stepTwo.chunkingStrategy')}</TextLabel>
</div>
<Divider className='grow' bgStyle='gradient' />
</div>
<div className='mt-2 flex gap-3'>
<RadioCard className='flex-1'
icon={<Image src={Note} alt='' width={16} height={16} />}
title={t('datasetCreation.stepTwo.fixedChunking')}
description={t('datasetCreation.stepTwo.fixedChunkingTip')}
isChosen={generalChunkingStrategy === 'fixed'}
onChosen={() => setGeneralChunkingStrategy('fixed')}
/>
<RadioCard className='flex-1'
icon={<Image src={Brain} alt='' width={16} height={16} />}
title={t('datasetCreation.stepTwo.semanticChunking')}
description={t('datasetCreation.stepTwo.semanticChunkingTip')}
isChosen={generalChunkingStrategy === 'semantic'}
onChosen={() => setGeneralChunkingStrategy('semantic')}
/>
</div>
</div>
<div className='flex w-full flex-col'>
<div className='flex items-center gap-x-2'>
<div className='inline-flex shrink-0'>
@ -815,6 +854,36 @@ const StepTwo = ({
/>
</div>
</div>
<div>
<div className='flex items-center gap-x-2'>
<div className='inline-flex shrink-0'>
<TextLabel>{t('datasetCreation.stepTwo.chunkingStrategy')}</TextLabel>
</div>
<Divider className='grow' bgStyle='gradient' />
</div>
<div className='mt-2 flex gap-3'>
<RadioCard className='flex-1'
icon={<Image src={Note} alt='' width={16} height={16} />}
title={t('datasetCreation.stepTwo.fixedChunking')}
description={t('datasetCreation.stepTwo.fixedChunkingTip')}
isChosen={parentChildConfig.chunkingStrategy === 'fixed'}
onChosen={() => setParentChildConfig({
...parentChildConfig,
chunkingStrategy: 'fixed',
})}
/>
<RadioCard className='flex-1'
icon={<Image src={Brain} alt='' width={16} height={16} />}
title={t('datasetCreation.stepTwo.semanticChunking')}
description={t('datasetCreation.stepTwo.semanticChunkingTip')}
isChosen={parentChildConfig.chunkingStrategy === 'semantic'}
onChosen={() => setParentChildConfig({
...parentChildConfig,
chunkingStrategy: 'semantic',
})}
/>
</div>
</div>
<div>
<div className='flex items-center gap-x-2'>
<div className='inline-flex shrink-0'>

@ -122,6 +122,11 @@ const translation = {
paragraphTip: 'This mode splits the text in to paragraphs based on delimiters and the maximum chunk length, using the split text as the parent chunk for retrieval.',
fullDoc: 'Full Doc',
fullDocTip: 'The entire document is used as the parent chunk and retrieved directly. Please note that for performance reasons, text exceeding 10000 tokens will be automatically truncated.',
chunkingStrategy: 'Chunking Strategy',
fixedChunking: 'Fixed',
fixedChunkingTip: 'Traditional chunking with fixed-length segments based on delimiters',
semanticChunking: 'Semantic',
semanticChunkingTip: 'AI-powered chunking that preserves semantic meaning and context',
separator: 'Delimiter',
separatorTip: 'A delimiter is the character used to separate text. \\n\\n and \\n are commonly used delimiters for separating paragraphs and lines. Combined with commas (\\n\\n,\\n), paragraphs will be segmented by lines when exceeding the maximum chunk length. You can also use special delimiters defined by yourself (e.g. ***).',
separatorPlaceholder: '\\n\\n for paragraphs; \\n for lines',

Loading…
Cancel
Save