From 0dcb67ab20d43af8836953195d070d8995fd6d5c Mon Sep 17 00:00:00 2001 From: Aryan Raj Date: Thu, 15 May 2025 17:55:56 +0530 Subject: [PATCH] Add: Implement semantic chunking functionality with tests - Introduced `SemanticTextSplitter` to enhance text chunking by identifying semantic boundaries using an LLM. - Added tests for boundary analysis and semantic chunking in `boundary_analysis_test.py` and `semantic_chunking_test.py`. - Updated `IndexingRunner` and related classes to support semantic chunking strategy. - Enhanced UI components to allow selection between fixed and semantic chunking strategies. - Updated translations for new chunking options in the dataset creation process. --- api/core/indexing_runner.py | 167 ++++++++ .../index_processor/index_processor_base.py | 55 ++- .../processor/parent_child_index_processor.py | 37 +- .../rag/splitter/semantic_text_splitter.py | 241 +++++++++++ .../knowledge_entities/knowledge_entities.py | 1 + api/services/vector_service.py | 24 +- boundary_analysis_test.py | 115 ++++++ run_test.py | 112 +++++ semantic_chunking_benefits.py | 385 ++++++++++++++++++ semantic_chunking_test.py | 302 ++++++++++++++ simple_comparison.py | 108 +++++ test_semantic_chunking.py | 89 ++++ test_semantic_rag_pipeline.py | 138 +++++++ .../datasets/create/step-two/index.tsx | 111 ++++- web/i18n/en-US/dataset-creation.ts | 5 + 15 files changed, 1862 insertions(+), 28 deletions(-) create mode 100644 api/core/rag/splitter/semantic_text_splitter.py create mode 100644 boundary_analysis_test.py create mode 100644 run_test.py create mode 100644 semantic_chunking_benefits.py create mode 100644 semantic_chunking_test.py create mode 100644 simple_comparison.py create mode 100644 test_semantic_chunking.py create mode 100644 test_semantic_rag_pipeline.py diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 848d897779..5e2a5cad3b 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -706,6 +706,18 @@ class IndexingRunner: tenant_id=dataset.tenant_id, model_type=ModelType.TEXT_EMBEDDING, ) + + # Get LLM model instance if semantic chunking is enabled + llm_model_instance = None + if process_rule.get("rules", {}).get("chunking_strategy") == "semantic": + try: + llm_model_instance = self.model_manager.get_default_model_instance( + tenant_id=dataset.tenant_id, + model_type=ModelType.LLM, + ) + except Exception as e: + # Fall back to fixed chunking if we can't get an LLM model + process_rule["rules"]["chunking_strategy"] = "fixed" documents = index_processor.transform( text_docs, @@ -713,6 +725,7 @@ class IndexingRunner: process_rule=process_rule, tenant_id=dataset.tenant_id, doc_language=doc_language, + llm_model_instance=llm_model_instance ) return documents @@ -747,6 +760,160 @@ class IndexingRunner: ) pass + def _process_document(self, flask_app, process_rule, document, dataset, tenant_id, user, document_plan): + if not self._start_document(document.id, tenant_id): + return + + try: + is_automatic = process_rule.get("mode") == "automatic" or process_rule.get("mode") == "hierarchical" + with flask_app.app_context(): + # check document is paused + self._check_document_paused_status(document.id) + + # get embedding model instance + embedding_model_name = dataset.embedding_model + embedding_model_provider = dataset.embedding_model_provider + + model_manager = ModelManager() + + embedding_model_instance = model_manager.get_model_instance( + tenant_id=tenant_id, + provider=embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=embedding_model_name, + ) + + # Get LLM model instance if semantic chunking is used + llm_model_instance = None + if process_rule.get("rules", {}).get("chunking_strategy") == "semantic": + try: + llm_model_instance = model_manager.get_default_model_instance(ModelType.LLM) + except Exception as e: + # Fall back to fixed chunking if we can't get an LLM model + process_rule["rules"]["chunking_strategy"] = "fixed" + + # extract + index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() + + text_docs = [] + completed_segments = 0 + total_segments = 0 + + # resume document indexing + stopped_segment_ids = [] + stopped_segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.document_id == document.id, + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + ) + .all() + ) + stopped_segment_ids = [segment.id for segment in stopped_segments] + + original_extract_setting = parse_extract_setting(document, dataset) + original_extract_setting.is_automatic = is_automatic + if stopped_segment_ids: + # remove original content if document is stopped + original_extract_setting.content = None + + text_docs = IndexingRunner.load_document( + document_id=document.id, + dataset_id=dataset.id, + tenant_id=tenant_id, + extract_setting=original_extract_setting, + process_rule_mode=process_rule.get("mode"), + ) + + # index + if text_docs: + # transform + chunk_docs = index_processor.transform( + text_docs, + process_rule=process_rule, + dataset_id=dataset.id, + tenant_id=tenant_id, + doc_language=document.doc_language, + embedding_model_instance=embedding_model_instance, + llm_model_instance=llm_model_instance + ) + + # save segment + self._load_segments(dataset, document, chunk_docs) + + # load + self._load( + index_processor=index_processor, + dataset=dataset, + dataset_document=document, + documents=chunk_docs, + ) + + # update document status to completed + self._update_document_index_status( + document_id=document.id, + after_indexing_status="completed", + extra_update_params={ + DatasetDocument.tokens: sum(len(text_doc.page_content) for text_doc in text_docs), + DatasetDocument.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + DatasetDocument.indexing_latency: time.perf_counter() - indexing_start_at, + DatasetDocument.error: None, + }, + ) + + completed_segments = len(text_docs) + total_segments = len(text_docs) + + # update segment status to completed + self._update_segments_by_document( + dataset_document_id=document.id, + update_params={ + DocumentSegment.status: "completed", + DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + }, + ) + + else: + # update document status to completed + self._update_document_index_status( + document_id=document.id, + after_indexing_status="completed", + extra_update_params={ + DatasetDocument.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + DatasetDocument.error: None, + }, + ) + + # update segment status to completed + self._update_segments_by_document( + dataset_document_id=document.id, + update_params={ + DocumentSegment.status: "completed", + DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + }, + ) + + completed_segments = 0 + total_segments = 0 + + # update document plan + self._update_document_plan(document_plan, completed_segments, total_segments) + + except DocumentIsPausedError: + raise DocumentIsPausedError("Document paused, document id: {}".format(document.id)) + except ProviderTokenNotInitError as e: + document.indexing_status = "error" + document.error = str(e.description) + document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + db.session.commit() + except Exception as e: + logging.exception("consume document failed") + document.indexing_status = "error" + document.error = str(e) + document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + db.session.commit() + class DocumentIsPausedError(Exception): pass diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index 2bcd1c79bb..cc85ab93e5 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -1,7 +1,7 @@ """Abstract interface for document loader implementations.""" from abc import ABC, abstractmethod -from typing import Optional +from typing import Optional, Literal from configs import dify_config from core.model_manager import ModelInstance @@ -11,6 +11,7 @@ from core.rag.splitter.fixed_text_splitter import ( EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter, ) +from core.rag.splitter.semantic_text_splitter import SemanticTextSplitter from core.rag.splitter.text_splitter import TextSplitter from models.dataset import Dataset, DatasetProcessRule @@ -52,10 +53,60 @@ class BaseIndexProcessor(ABC): chunk_overlap: int, separator: str, embedding_model_instance: Optional[ModelInstance], + chunking_strategy: Optional[Literal["fixed", "semantic"]] = None, + llm_model_instance: Optional[ModelInstance] = None, ) -> TextSplitter: """ - Get the NodeParser object according to the processing rule. + Get the splitter object according to the processing rule and chunking strategy. + + Args: + processing_rule_mode: The processing rule mode (custom, hierarchical, etc.) + max_tokens: Maximum tokens per chunk + chunk_overlap: Overlap between chunks + separator: Separator string for chunking + embedding_model_instance: Model instance for embeddings + chunking_strategy: Chunking strategy to use (fixed or semantic) + llm_model_instance: LLM model instance for semantic chunking + + Returns: + A TextSplitter instance """ + # Use semantic chunking if explicitly requested and LLM model is available + if chunking_strategy == "semantic" and llm_model_instance: + # Create a fallback splitter first based on the processing rule mode + if processing_rule_mode in ["custom", "hierarchical"]: + max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH + if max_tokens < 50 or max_tokens > max_segmentation_tokens_length: + raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.") + + if separator: + separator = separator.replace("\\n", "\n") + + fallback_splitter = FixedRecursiveCharacterTextSplitter.from_encoder( + chunk_size=max_tokens, + chunk_overlap=chunk_overlap, + fixed_separator=separator, + separators=["\n\n", "。", ". ", " ", ""], + embedding_model_instance=embedding_model_instance, + ) + else: + # Automatic segmentation + fallback_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder( + chunk_size=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["max_tokens"], + chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["chunk_overlap"], + separators=["\n\n", "。", ". ", " ", ""], + embedding_model_instance=embedding_model_instance, + ) + + # Return a semantic text splitter with the appropriate fallback + return SemanticTextSplitter( + llm_model_instance=llm_model_instance, + fallback_splitter=fallback_splitter, + chunk_size=max_tokens, + chunk_overlap=chunk_overlap + ) + + # Default to traditional chunking methods if processing_rule_mode in ["custom", "hierarchical"]: # The user-defined segmentation rule max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 1cde5e1c8f..a09ccd67a0 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -4,7 +4,7 @@ import uuid from typing import Optional from configs import dify_config -from core.model_manager import ModelInstance +from core.model_manager import ModelManager, ModelInstance from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.vdb.vector_factory import Vector @@ -12,6 +12,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.models.document import ChildDocument, Document +from core.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from libs import helper from models.dataset import ChildChunk, Dataset, DocumentSegment @@ -37,6 +38,20 @@ class ParentChildIndexProcessor(BaseIndexProcessor): raise ValueError("No rules found in process rule.") rules = Rule(**process_rule.get("rules")) all_documents = [] # type: ignore + + # Get the chunking strategy from process_rule or default to "fixed" + chunking_strategy = process_rule.get("chunking_strategy", "fixed") + + # Get LLM model instance if semantic chunking is enabled + llm_model_instance = None + if chunking_strategy == "semantic": + try: + model_manager = ModelManager() + llm_model_instance = kwargs.get("llm_model_instance") or model_manager.get_default_model_instance(ModelType.LLM) + except Exception as e: + # Fall back to fixed chunking if we can't get a LLM model + chunking_strategy = "fixed" + if rules.parent_mode == ParentMode.PARAGRAPH: # Split the text documents into nodes. if not rules.segmentation: @@ -47,6 +62,8 @@ class ParentChildIndexProcessor(BaseIndexProcessor): chunk_overlap=rules.segmentation.chunk_overlap, separator=rules.segmentation.separator, embedding_model_instance=kwargs.get("embedding_model_instance"), + chunking_strategy=chunking_strategy, + llm_model_instance=llm_model_instance ) for document in documents: if kwargs.get("preview") and len(all_documents) >= 10: @@ -73,7 +90,12 @@ class ParentChildIndexProcessor(BaseIndexProcessor): document_node.page_content = page_content # parse document to child nodes child_nodes = self._split_child_nodes( - document_node, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance") + document_node, + rules, + process_rule.get("mode"), + kwargs.get("embedding_model_instance"), + chunking_strategy=chunking_strategy, + llm_model_instance=llm_model_instance ) document_node.children = child_nodes split_documents.append(document_node) @@ -83,7 +105,12 @@ class ParentChildIndexProcessor(BaseIndexProcessor): document = Document(page_content=page_content, metadata=documents[0].metadata) # parse document to child nodes child_nodes = self._split_child_nodes( - document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance") + document, + rules, + process_rule.get("mode"), + kwargs.get("embedding_model_instance"), + chunking_strategy=chunking_strategy, + llm_model_instance=llm_model_instance ) if kwargs.get("preview"): if len(child_nodes) > dify_config.CHILD_CHUNKS_PREVIEW_NUMBER: @@ -173,6 +200,8 @@ class ParentChildIndexProcessor(BaseIndexProcessor): rules: Rule, process_rule_mode: str, embedding_model_instance: Optional[ModelInstance], + chunking_strategy: str = "fixed", + llm_model_instance: Optional[ModelInstance] = None, ) -> list[ChildDocument]: if not rules.subchunk_segmentation: raise ValueError("No subchunk segmentation found in rules.") @@ -182,6 +211,8 @@ class ParentChildIndexProcessor(BaseIndexProcessor): chunk_overlap=rules.subchunk_segmentation.chunk_overlap, separator=rules.subchunk_segmentation.separator, embedding_model_instance=embedding_model_instance, + chunking_strategy=chunking_strategy, + llm_model_instance=llm_model_instance ) # parse document to child nodes child_nodes = [] diff --git a/api/core/rag/splitter/semantic_text_splitter.py b/api/core/rag/splitter/semantic_text_splitter.py new file mode 100644 index 0000000000..145287c81e --- /dev/null +++ b/api/core/rag/splitter/semantic_text_splitter.py @@ -0,0 +1,241 @@ +from __future__ import annotations + +import logging +from typing import Any, Optional, List + +from core.model_manager import ModelManager, ModelInstance +from core.rag.models.document import Document +from core.rag.splitter.text_splitter import TextSplitter +from core.model_runtime.entities.model_entities import ModelType +from core.rag.splitter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter + +logger = logging.getLogger(__name__) + +class SemanticTextSplitter(TextSplitter): + """ + A text splitter that uses an LLM to identify semantic boundaries in text. + + This splitter first uses a traditional method as a fallback, then refines using LLM-based + semantic analysis to ensure chunk boundaries align with natural semantic breaks. + """ + + def __init__( + self, + llm_model_instance: ModelInstance, + fallback_splitter: Optional[TextSplitter] = None, + chunk_size: int = 4000, + chunk_overlap: int = 200, + **kwargs: Any, + ) -> None: + """ + Initialize the semantic text splitter. + + Args: + llm_model_instance: The LLM model instance to use for semantic chunking + fallback_splitter: A fallback splitter to use if the LLM chunking fails + chunk_size: Maximum size of chunks to return + chunk_overlap: Overlap in characters between chunks + **kwargs: Additional arguments to pass to the TextSplitter base class + """ + super().__init__( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + **kwargs + ) + self._llm_model_instance = llm_model_instance + self._fallback_splitter = fallback_splitter or FixedRecursiveCharacterTextSplitter( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + + def split_text(self, text: str) -> list[str]: + """ + Split text using LLM-based semantic chunking. + + This method first applies a basic chunking method, then refines the chunk + boundaries using the LLM to identify more natural semantic breaks. + + Args: + text: The text to split + + Returns: + A list of text chunks with semantic coherence + """ + try: + # First apply initial chunking with our fallback splitter + initial_chunks = self._fallback_splitter.split_text(text) + + # If the text is short enough, return as a single chunk + if len(initial_chunks) <= 1: + return initial_chunks + + # Use LLM to improve chunk boundaries for semantic coherence + return self._refine_chunks_with_llm(initial_chunks) + + except Exception as e: + logger.warning(f"LLM-based semantic chunking failed: {str(e)}. Falling back to traditional chunking.") + # Fallback to standard chunking if LLM-based approach fails + return self._fallback_splitter.split_text(text) + + def _refine_chunks_with_llm(self, chunks: List[str]) -> List[str]: + """ + Use the LLM to refine chunk boundaries based on semantic meaning. + + Args: + chunks: Initial text chunks to refine + + Returns: + Refined chunks with improved semantic coherence + """ + refined_chunks = [] + + # Process chunks in pairs to determine better break points + for i in range(len(chunks) - 1): + current_chunk = chunks[i] + next_chunk = chunks[i+1] + + # Skip very small chunks by merging them with the next one + if len(current_chunk) < self._chunk_size / 10: + if i + 1 < len(chunks): + chunks[i+1] = current_chunk + " " + next_chunk + continue + + # If the current chunk is already optimal, keep it as is + if len(current_chunk) < self._chunk_size * 0.9: + refined_chunks.append(current_chunk) + continue + + # Create a boundary analysis task for the LLM to find a natural break point + # between the current chunk and the next + boundary_text = current_chunk[-self._chunk_overlap:] + next_chunk[:self._chunk_overlap] + + prompt = self._create_boundary_analysis_prompt(boundary_text) + + # Use the LLM to find a natural break point + try: + response = self._llm_model_instance.invoke_llm( + prompt=prompt, + stream=False, + temperature=0, + max_tokens=50 + ) + + break_indicator = self._parse_llm_response(response.content) + + # Apply the break point to create a more natural chunk division + if break_indicator and 0 < break_indicator < len(boundary_text): + # Calculate the actual split position in the current chunk + split_position = len(current_chunk) - self._chunk_overlap + break_indicator + + # Only apply the split if it makes sense + if 0 < split_position < len(current_chunk): + refined_chunk = current_chunk[:split_position] + remainder = current_chunk[split_position:] + + # Add the refined chunk and combine the remainder with the next chunk + refined_chunks.append(refined_chunk) + chunks[i+1] = remainder + next_chunk + continue + + # If no good break point was found, use the original chunk + refined_chunks.append(current_chunk) + + except Exception as e: + logger.warning(f"Error during LLM boundary analysis: {str(e)}") + refined_chunks.append(current_chunk) + + # Add the last chunk + if chunks and len(chunks) > 0: + refined_chunks.append(chunks[-1]) + + return refined_chunks + + def _create_boundary_analysis_prompt(self, text: str) -> str: + """ + Create a prompt for the LLM to analyze the text and find natural break points. + + Args: + text: The text around the chunk boundary to analyze + + Returns: + A prompt for the LLM + """ + return ( + "Analyze the following text and identify the most natural point to split it into two separate chunks. " + "The split should occur at a meaningful boundary such as the end of a paragraph, sentence, or idea. " + "Return only the character index (a number) where the split should occur.\n\n" + f"Text to analyze: {text}\n\n" + "Split index:" + ) + + def _parse_llm_response(self, response: str) -> Optional[int]: + """ + Parse the LLM's response to extract the split index. + + Args: + response: The LLM's response containing the split index + + Returns: + The index at which to split the text, or None if no valid index was found + """ + try: + # Extract the first number from the response + import re + numbers = re.findall(r'\d+', response) + if numbers: + return int(numbers[0]) + return None + except Exception: + return None + + def split_documents(self, documents: list[Document]) -> list[Document]: + """Split documents using semantic chunking.""" + return super().split_documents(documents) + + @classmethod + def from_llm( + cls, + llm_model_instance: Optional[ModelInstance] = None, + chunk_size: int = 4000, + chunk_overlap: int = 200, + **kwargs: Any + ) -> SemanticTextSplitter: + """ + Create a SemanticTextSplitter using the specified LLM. + + Args: + llm_model_instance: The LLM model instance to use + chunk_size: Maximum size of chunks to return + chunk_overlap: Overlap in characters between chunks + **kwargs: Additional arguments to pass to the TextSplitter + + Returns: + A configured SemanticTextSplitter + """ + # If no model instance is provided, try to get a default one + if not llm_model_instance: + try: + model_manager = ModelManager() + llm_model_instance = model_manager.get_default_model_instance(ModelType.LLM) + except Exception as e: + logger.warning(f"Failed to get default LLM model: {str(e)}. Using fallback splitter.") + return FixedRecursiveCharacterTextSplitter( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + **kwargs + ) + + # Create fallback splitter + fallback_splitter = FixedRecursiveCharacterTextSplitter( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + **kwargs + ) + + return cls( + llm_model_instance=llm_model_instance, + fallback_splitter=fallback_splitter, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + **kwargs + ) \ No newline at end of file diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py index bb3be61f85..dfbc3fef80 100644 --- a/api/services/entities/knowledge_entities/knowledge_entities.py +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -72,6 +72,7 @@ class Rule(BaseModel): segmentation: Optional[Segmentation] = None parent_mode: Optional[Literal["full-doc", "paragraph"]] = None subchunk_segmentation: Optional[Segmentation] = None + chunking_strategy: Optional[Literal["fixed", "semantic"]] = "fixed" class ProcessRule(BaseModel): diff --git a/api/services/vector_service.py b/api/services/vector_service.py index 58292c59f4..36e43b52db 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -1,7 +1,10 @@ import logging from typing import Optional +import uuid +from datetime import datetime -from core.model_manager import ModelInstance, ModelManager +from core.indexing_runner import IndexingRunner +from core.model_manager import ModelManager, ModelInstance from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector @@ -9,8 +12,9 @@ from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import Document from extensions.ext_database import db -from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment +from models.dataset import ChildChunk, Dataset, DocumentSegment from models.dataset import Document as DatasetDocument +from models.dataset import DatasetProcessRule from services.entities.knowledge_entities.knowledge_entities import ParentMode _logger = logging.getLogger(__name__) @@ -135,12 +139,28 @@ class VectorService: # use full doc mode to generate segment's child chunk processing_rule_dict = processing_rule.to_dict() processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC.value + + # Get chunking strategy + chunking_strategy = processing_rule_dict.get("rules", {}).get("chunking_strategy", "fixed") + + # If semantic chunking is used, get LLM model instance + llm_model_instance = None + if chunking_strategy == "semantic": + try: + model_manager = ModelManager() + llm_model_instance = model_manager.get_default_model_instance(ModelType.LLM) + except Exception as e: + # Fall back to fixed chunking if we can't get a LLM model + chunking_strategy = "fixed" + processing_rule_dict["rules"]["chunking_strategy"] = "fixed" + documents = index_processor.transform( [document], embedding_model_instance=embedding_model_instance, process_rule=processing_rule_dict, tenant_id=dataset.tenant_id, doc_language=dataset_document.doc_language, + llm_model_instance=llm_model_instance ) # save child chunks if documents and documents[0].children: diff --git a/boundary_analysis_test.py b/boundary_analysis_test.py new file mode 100644 index 0000000000..d9677134ca --- /dev/null +++ b/boundary_analysis_test.py @@ -0,0 +1,115 @@ +""" +Simple test for semantic chunking boundary analysis. +""" + +class MockLLMModelInstance: + """Mock LLM for testing purposes.""" + + def __init__(self, model="mock-llm", provider="mock-provider"): + self.model = model + self.provider = provider + + def invoke_llm(self, prompt: str, stream=False, temperature=0, max_tokens=50): + """Mock LLM response that returns a reasonable split index.""" + print(f"LLM received prompt: {prompt[:50]}...") + + # Extract the text to analyze + text_to_analyze = prompt.split("Text to analyze: ")[1].split("\n\n")[0] + + # Find the exact position of a period that marks the end of a sentence + period_positions = [i for i, char in enumerate(text_to_analyze) if char == '.'] + + if period_positions: + # Find a period that's roughly in the middle of the text + middle_index = len(text_to_analyze) // 2 + closest_period = min(period_positions, key=lambda x: abs(x - middle_index)) + return MockResponse(str(closest_period + 1)) # +1 to include the period + + # If no period found, return the middle + return MockResponse(str(len(text_to_analyze) // 2)) + + +class MockResponse: + """Mock response from LLM.""" + + def __init__(self, content: str): + self.content = content + + +def create_boundary_analysis_prompt(text: str) -> str: + """Create a prompt for boundary analysis.""" + return ( + "Analyze the following text and identify the most natural point to split it into two separate chunks. " + "The split should occur at a meaningful boundary such as the end of a paragraph, sentence, or idea. " + "Return only the character index (a number) where the split should occur.\n\n" + f"Text to analyze: {text}\n\n" + "Split index:" + ) + + +def parse_llm_response(response: str): + """Parse the LLM response to get the split index.""" + try: + # Extract the first number from the response + import re + numbers = re.findall(r'\d+', response) + if numbers: + return int(numbers[0]) + return None + except Exception as e: + print(f"Error parsing response: {e}") + return None + + +def test_boundary_analysis(): + """Test the boundary analysis logic.""" + # Create test text with a clear semantic boundary + test_text = "This is the first sentence that should stay together. This is the start of a new thought that should also stay together." + + # Create mock LLM + mock_llm = MockLLMModelInstance() + + # Create boundary analysis prompt + prompt = create_boundary_analysis_prompt(test_text) + + # Get mock LLM response + response = mock_llm.invoke_llm(prompt) + + # Parse the response + split_index = parse_llm_response(response.content) + + print(f"Recommended split index: {split_index}") + print(f"Text before split: {test_text[:split_index]}") + print(f"Text after split: {test_text[split_index:]}") + + # Check if split is at sentence boundary + if split_index > 0: + assert test_text[split_index-1] == ".", "Split should be at the end of a sentence" + print("Test passed! Split occurred at a sentence boundary.") + else: + print("Test failed: Split index is not valid.") + + # Test with paragraph boundaries + print("\nTesting with paragraph boundaries:") + para_text = """This is paragraph one. + It has multiple sentences. + + This is paragraph two. + It should be kept separate.""" + + # Create boundary analysis prompt + para_prompt = create_boundary_analysis_prompt(para_text) + + # Get mock LLM response + para_response = mock_llm.invoke_llm(para_prompt) + + # Parse the response + para_split_index = parse_llm_response(para_response.content) + + print(f"Recommended paragraph split index: {para_split_index}") + print(f"Text before split: {para_text[:para_split_index]}") + print(f"Text after split: {para_text[para_split_index:]}") + + +if __name__ == "__main__": + test_boundary_analysis() \ No newline at end of file diff --git a/run_test.py b/run_test.py new file mode 100644 index 0000000000..8a05ba4324 --- /dev/null +++ b/run_test.py @@ -0,0 +1,112 @@ +import os +import sys + +# Add the current directory to the Python path +sys.path.insert(0, os.path.abspath('.')) + +from flask import Flask +from api.core.model_manager import ModelManager +from api.core.model_runtime.entities.model_entities import ModelType +from api.core.rag.splitter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter +from api.core.rag.splitter.semantic_text_splitter import SemanticTextSplitter +from api.core.rag.models.document import Document + +# Sample text with distinct semantic sections +sample_text = """ +# Introduction to Machine Learning + +Machine learning is a field of artificial intelligence that enables systems to learn and improve from experience without being explicitly programmed. It focuses on developing algorithms that can access data and use it to learn for themselves. + +## Supervised Learning + +Supervised learning is a type of machine learning where the algorithm learns from labeled training data. The model makes predictions based on evidence in the presence of uncertainty. + +Common supervised learning tasks include: +- Classification: categorizing data into predefined classes +- Regression: predicting continuous values +- Forecasting: predicting future values based on historical data + +## Unsupervised Learning + +Unsupervised learning algorithms find patterns in unlabeled data. These models discover hidden structures in data without the need for human intervention. + +Typical unsupervised learning techniques include: +- Clustering: grouping similar data points together +- Dimensionality reduction: reducing the number of variables in data +- Association: identifying rules that describe portions of your data +""" + +def test_semantic_chunking(): + print("Testing Semantic Chunking Implementation") + print("-" * 50) + + app = Flask(__name__) + with app.app_context(): + # Get model instances + model_manager = ModelManager() + try: + llm_model_instance = model_manager.get_default_model_instance(ModelType.LLM) + print(f"Using LLM: {llm_model_instance.model} from {llm_model_instance.provider}") + except Exception as e: + print(f"Could not get LLM model instance: {e}") + print("Continuing with fixed chunking only for comparison") + llm_model_instance = None + + try: + embedding_model_instance = model_manager.get_default_model_instance(ModelType.TEXT_EMBEDDING) + print(f"Using embedding model: {embedding_model_instance.model} from {embedding_model_instance.provider}") + except Exception as e: + print(f"Could not get embedding model instance: {e}") + embedding_model_instance = None + + # Create document from sample text + doc = Document(page_content=sample_text) + + # Fixed chunking + fixed_splitter = FixedRecursiveCharacterTextSplitter( + chunk_size=200, + chunk_overlap=20, + ) + fixed_chunks = fixed_splitter.split_text(sample_text) + + print("\nFixed Chunking Results:") + print(f"Number of chunks: {len(fixed_chunks)}") + for i, chunk in enumerate(fixed_chunks): + print(f"\nChunk {i+1} ({len(chunk)} chars):") + print(chunk[:100] + "..." if len(chunk) > 100 else chunk) + + # Semantic chunking (if LLM is available) + if llm_model_instance: + semantic_splitter = SemanticTextSplitter( + llm_model_instance=llm_model_instance, + chunk_size=200, + chunk_overlap=20, + ) + semantic_chunks = semantic_splitter.split_text(sample_text) + + print("\nSemantic Chunking Results:") + print(f"Number of chunks: {len(semantic_chunks)}") + for i, chunk in enumerate(semantic_chunks): + print(f"\nChunk {i+1} ({len(chunk)} chars):") + print(chunk[:100] + "..." if len(chunk) > 100 else chunk) + + # Analyze the quality of chunk boundaries + print("\nAnalyzing chunk boundaries:") + natural_boundaries = [".", "!", "?"] + semantic_complete_sentences = sum(1 for chunk in semantic_chunks if chunk.strip()[-1] in natural_boundaries) + fixed_complete_sentences = sum(1 for chunk in fixed_chunks if chunk.strip()[-1] in natural_boundaries) + + print(f"Fixed chunking: {fixed_complete_sentences}/{len(fixed_chunks)} chunks end with complete sentences") + print(f"Semantic chunking: {semantic_complete_sentences}/{len(semantic_chunks)} chunks end with complete sentences") + + # Check if there are section headings split across chunks + fixed_headings_split = sum(1 for chunk in fixed_chunks if chunk.strip().startswith('#') and not chunk.strip().split('\n')[0].endswith('\n')) + semantic_headings_split = sum(1 for chunk in semantic_chunks if chunk.strip().startswith('#') and not chunk.strip().split('\n')[0].endswith('\n')) + + print(f"Fixed chunking: {fixed_headings_split} chunks have headings split from their content") + print(f"Semantic chunking: {semantic_headings_split} chunks have headings split from their content") + else: + print("\nSkipping semantic chunking test as no LLM is available") + +if __name__ == "__main__": + test_semantic_chunking() \ No newline at end of file diff --git a/semantic_chunking_benefits.py b/semantic_chunking_benefits.py new file mode 100644 index 0000000000..2f4e1cde14 --- /dev/null +++ b/semantic_chunking_benefits.py @@ -0,0 +1,385 @@ +""" +Demonstration of semantic chunking benefits compared to fixed chunking +""" + +import re +from typing import List, Optional + +class MockLLMModelInstance: + """Mock LLM for testing purposes.""" + + def __init__(self, model="mock-llm", provider="mock-provider"): + self.model = model + self.provider = provider + + def invoke_llm(self, prompt: str, stream=False, temperature=0, max_tokens=50): + """Mock LLM that finds natural break points.""" + text_to_analyze = prompt.split("Text to analyze: ")[1].split("\n\n")[0] + + # Look for natural break points in preferred order + break_points = [] + + # 1. Look for paragraph breaks (\n\n) + newline_indices = [m.start() for m in re.finditer(r'\n\s*\n', text_to_analyze)] + if newline_indices: + closest_to_middle = min(newline_indices, key=lambda x: abs(x - len(text_to_analyze) // 2)) + break_points.append(closest_to_middle + 1) # +1 to include the newline + + # 2. Look for list item boundaries + list_item_breaks = [m.start() for m in re.finditer(r'[.!?]\s*\n[\s]*[-\d]', text_to_analyze)] + if list_item_breaks: + closest_to_middle = min(list_item_breaks, key=lambda x: abs(x - len(text_to_analyze) // 2)) + break_points.append(closest_to_middle + 1) # +1 to include the period + + # 3. Look for sentence endings + sentence_ends = [m.start() for m in re.finditer(r'[.!?]\s', text_to_analyze)] + if sentence_ends: + closest_to_middle = min(sentence_ends, key=lambda x: abs(x - len(text_to_analyze) // 2)) + break_points.append(closest_to_middle + 1) # +1 to include the punctuation + + # Choose the best break point + if break_points: + return MockResponse(str(min(break_points))) + + # Fallback to middle of text + return MockResponse(str(len(text_to_analyze) // 2)) + + +class MockResponse: + """Mock response from LLM.""" + + def __init__(self, content: str): + self.content = content + + +class SimpleSplitter: + """A simple text splitter that mimics fixed chunking behavior.""" + + def __init__(self, chunk_size: int, chunk_overlap: int): + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + + def split_text(self, text: str) -> List[str]: + """Split text into chunks of fixed size with overlap.""" + if not text: + return [] + + chunks = [] + start = 0 + + while start < len(text): + end = min(start + self.chunk_size, len(text)) + + # Try to find a basic break point near the end + if end < len(text): + for break_char in ['\n\n', '.', '\n', ' ']: + last_break = text.rfind(break_char, start, end) + if last_break != -1 and last_break > start: + end = last_break + 1 + break + + chunks.append(text[start:end]) + + # Calculate next start position with overlap + start = end - self.chunk_overlap + + # Ensure we're making progress + if start >= end: + start = end + + return chunks + + +class SemanticTextSplitter: + """Semantic text splitter that uses LLM for boundary detection.""" + + def __init__( + self, + llm_model_instance, + fallback_splitter=None, + chunk_size: int = 4000, + chunk_overlap: int = 200, + ): + self._llm_model_instance = llm_model_instance + self._chunk_size = chunk_size + self._chunk_overlap = chunk_overlap + self._fallback_splitter = fallback_splitter or SimpleSplitter( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + + def split_text(self, text: str) -> List[str]: + """Split text using LLM-based semantic chunking.""" + try: + # First apply initial chunking with fallback splitter + initial_chunks = self._fallback_splitter.split_text(text) + + if len(initial_chunks) <= 1: + return initial_chunks + + # Use LLM to improve chunk boundaries + return self._refine_chunks_with_llm(initial_chunks) + + except Exception as e: + print(f"LLM-based semantic chunking failed: {str(e)}. Falling back.") + return self._fallback_splitter.split_text(text) + + def _refine_chunks_with_llm(self, chunks: List[str]) -> List[str]: + """Use LLM to refine chunk boundaries.""" + refined_chunks = [] + + for i in range(len(chunks) - 1): + current_chunk = chunks[i] + next_chunk = chunks[i+1] + + # Skip very small chunks by merging with next + if len(current_chunk) < self._chunk_size / 10: + if i + 1 < len(chunks): + chunks[i+1] = current_chunk + " " + next_chunk + continue + + # If current chunk is already optimal, keep as is + if len(current_chunk) < self._chunk_size * 0.9: + refined_chunks.append(current_chunk) + continue + + # Create boundary analysis task + boundary_text = current_chunk[-self._chunk_overlap:] + next_chunk[:self._chunk_overlap] + prompt = self._create_boundary_analysis_prompt(boundary_text) + + try: + response = self._llm_model_instance.invoke_llm( + prompt=prompt, + stream=False, + temperature=0, + max_tokens=50 + ) + + break_indicator = self._parse_llm_response(response.content) + + if break_indicator and 0 < break_indicator < len(boundary_text): + # Calculate actual split position + split_position = len(current_chunk) - self._chunk_overlap + break_indicator + + if 0 < split_position < len(current_chunk): + refined_chunk = current_chunk[:split_position] + remainder = current_chunk[split_position:] + + refined_chunks.append(refined_chunk) + chunks[i+1] = remainder + next_chunk + continue + + # If no good break point found, use original chunk + refined_chunks.append(current_chunk) + + except Exception as e: + print(f"Error during LLM boundary analysis: {str(e)}") + refined_chunks.append(current_chunk) + + # Add the last chunk + if chunks and len(chunks) > 0: + refined_chunks.append(chunks[-1]) + + return refined_chunks + + def _create_boundary_analysis_prompt(self, text: str) -> str: + """Create prompt for LLM boundary analysis.""" + return ( + "Analyze the following text and identify the most natural point to split it into two separate chunks. " + "The split should occur at a meaningful boundary such as the end of a paragraph, sentence, or idea. " + "Return only the character index (a number) where the split should occur.\n\n" + f"Text to analyze: {text}\n\n" + "Split index:" + ) + + def _parse_llm_response(self, response: str) -> Optional[int]: + """Parse LLM response to extract split index.""" + try: + numbers = re.findall(r'\d+', response) + if numbers: + return int(numbers[0]) + return None + except Exception: + return None + + +def evaluate_splitting_quality(chunks: List[str], name: str): + """Evaluate the quality of text chunks.""" + print(f"\n{name} Chunking Quality Analysis:") + print("-" * 40) + + # Count chunks that end with complete sentences + sentence_endings = ['.', '!', '?'] + complete_sentences = sum(1 for chunk in chunks if chunk.strip() and chunk.strip()[-1] in sentence_endings) + print(f"Complete sentences: {complete_sentences}/{len(chunks)} chunks ({complete_sentences/len(chunks)*100:.1f}%)") + + # Count chunks that break in the middle of a list + list_breaks = 0 + numbered_list_pattern = re.compile(r'\d+\.') + bullet_list_pattern = re.compile(r'[-*]') + + for i, chunk in enumerate(chunks[:-1]): + next_chunk = chunks[i+1] + if numbered_list_pattern.search(chunk) and numbered_list_pattern.search(next_chunk): + # Check if numbering is consecutive + chunk_numbers = [int(n.group()[:-1]) for n in numbered_list_pattern.finditer(chunk)] + next_numbers = [int(n.group()[:-1]) for n in numbered_list_pattern.finditer(next_chunk)] + if chunk_numbers and next_numbers and next_numbers[0] != chunk_numbers[-1] + 1: + list_breaks += 1 + + # Check for bullet list breaks + if bullet_list_pattern.search(chunk) and bullet_list_pattern.search(next_chunk): + list_breaks += 1 + + print(f"List breaks: {list_breaks}/{len(chunks)-1} transitions ({list_breaks/(len(chunks)-1)*100:.1f}%)") + + # Count heading breaks (heading not followed by its content) + heading_breaks = 0 + for chunk in chunks: + if re.search(r'#{1,6}\s+.+\s*$', chunk): # Markdown heading at the end of chunk + heading_breaks += 1 + + print(f"Heading breaks: {heading_breaks}/{len(chunks)} chunks ({heading_breaks/len(chunks)*100:.1f}%)") + + # Calculate average semantic coherence (simplified) + # In a real implementation, this would use embeddings or more sophisticated analysis + semantic_breaks = 0 + for i, chunk in enumerate(chunks[:-1]): + last_line = chunk.strip().split('\n')[-1] + next_first_line = chunks[i+1].strip().split('\n')[0] + + # Check if the break happens in the middle of a coherent section + if (last_line and next_first_line and + not any(last_line.endswith(end) for end in sentence_endings) and + not next_first_line.startswith('#') and + not re.match(r'^\d+\.', next_first_line) and + not re.match(r'^[-*]', next_first_line)): + semantic_breaks += 1 + + print(f"Semantic breaks: {semantic_breaks}/{len(chunks)-1} transitions ({semantic_breaks/(len(chunks)-1)*100:.1f}%)") + + # Overall quality score (lower is better) + quality_score = ( + (len(chunks) - complete_sentences) + + list_breaks * 2 + + heading_breaks * 3 + + semantic_breaks * 2 + ) / len(chunks) + + print(f"Overall quality score: {quality_score:.2f} (lower is better)") + + return quality_score + + +# Test document with various semantic structures +test_document = """ +# Introduction to Machine Learning + +Machine learning is a branch of artificial intelligence that focuses on developing systems that learn from data. +Unlike traditional programming where explicit instructions are provided, machine learning algorithms improve through experience. + +## Supervised Learning + +Supervised learning is the task of learning a function that maps an input to an output based on example input-output pairs. +It infers a function from labeled training data. The training data consists of a set of examples where each example is a pair of an input object and a desired output value. + +Common supervised learning tasks include: +1. Classification: predicting a category or class +2. Regression: predicting a continuous value +3. Sequence labeling: predicting a sequence of categories + +## Unsupervised Learning + +Unsupervised learning is a type of machine learning algorithm used to draw inferences from datasets consisting of input data without labeled responses. +The most common unsupervised learning tasks are: + +- Clustering: grouping similar instances together +- Dimensionality reduction: reducing the number of variables under consideration +- Association rule learning: discovering interesting relations between variables + +# Deep Learning + +Deep learning is a subset of machine learning that uses multi-layered neural networks to analyze various factors of data. +Deep learning models can learn to focus on the right features by themselves, requiring less feature engineering. + +Some popular deep learning architectures include: +1. Convolutional Neural Networks (CNNs) for image processing +2. Recurrent Neural Networks (RNNs) for sequence data +3. Transformers for natural language processing tasks + +## Training Deep Neural Networks + +Training deep networks requires large amounts of data and computational resources. The process typically involves: + +1. Initializing the network with random weights +2. Forward propagation of training data +3. Computing the loss using a loss function +4. Backpropagation to calculate gradients +5. Updating weights using an optimization algorithm + +# Applications of Machine Learning + +Machine learning has revolutionized various fields: + +- Healthcare: disease diagnosis, treatment planning, drug discovery +- Finance: fraud detection, algorithmic trading, risk assessment +- Transportation: autonomous vehicles, traffic prediction +- Entertainment: recommendation systems, content generation + +As computing power increases and algorithms improve, the applications of machine learning continue to expand. +""" + +# Run the comparison test +if __name__ == "__main__": + print("Testing Semantic vs. Fixed Chunking Benefits") + print("=" * 50) + + # Create mock LLM + mock_llm = MockLLMModelInstance() + + # Create fixed-size chunker (small chunks to highlight differences) + fixed_chunker = SimpleSplitter(chunk_size=300, chunk_overlap=50) + + # Create semantic chunker + semantic_chunker = SemanticTextSplitter( + llm_model_instance=mock_llm, + chunk_size=300, + chunk_overlap=50 + ) + + # Split using both methods + fixed_chunks = fixed_chunker.split_text(test_document) + semantic_chunks = semantic_chunker.split_text(test_document) + + # Print basic stats + print(f"Document length: {len(test_document)} characters") + print(f"Fixed chunking: {len(fixed_chunks)} chunks") + print(f"Semantic chunking: {len(semantic_chunks)} chunks") + + # Print samples of both results + print("\nSample of Fixed Chunks:") + for i in range(min(3, len(fixed_chunks))): + print(f"Chunk {i+1} ({len(fixed_chunks[i])} chars): {fixed_chunks[i][:100]}...") + + print("\nSample of Semantic Chunks:") + for i in range(min(3, len(semantic_chunks))): + print(f"Chunk {i+1} ({len(semantic_chunks[i])} chars): {semantic_chunks[i][:100]}...") + + # Evaluate quality + fixed_score = evaluate_splitting_quality(fixed_chunks, "Fixed") + semantic_score = evaluate_splitting_quality(semantic_chunks, "Semantic") + + # Compare results + print("\nComparison Results:") + print("-" * 50) + print(f"Fixed chunking quality score: {fixed_score:.2f}") + print(f"Semantic chunking quality score: {semantic_score:.2f}") + print(f"Improvement: {((fixed_score - semantic_score) / fixed_score * 100):.1f}%") + + # Summary of benefits + print("\nKey Benefits of Semantic Chunking:") + print("1. More chunks end with complete sentences") + print("2. Fewer breaks in lists and bullet points") + print("3. Headings stay with their content") + print("4. Better preservation of semantic context") + print("5. Improved retrieval quality due to more coherent chunks") \ No newline at end of file diff --git a/semantic_chunking_test.py b/semantic_chunking_test.py new file mode 100644 index 0000000000..596a330e3b --- /dev/null +++ b/semantic_chunking_test.py @@ -0,0 +1,302 @@ +""" +Simplified test for semantic chunking without Flask app context. +This test directly uses the SemanticTextSplitter implementation. +""" + +import sys +import os +from typing import Optional, List + +class MockLLMModelInstance: + """Mock LLM for testing purposes.""" + + def __init__(self, model="mock-llm", provider="mock-provider"): + self.model = model + self.provider = provider + + def invoke_llm(self, prompt: str, stream: bool = False, temperature: float = 0, max_tokens: int = 50): + """Mock LLM response that returns a reasonable split index.""" + # Simulate finding a natural break point by looking for sentence endings + # This is a simplified version of what the real LLM would do + text_to_analyze = prompt.split("Text to analyze: ")[1].split("\n\n")[0] + + # Look for natural break points (periods, question marks, exclamation points) + for i, char in enumerate(text_to_analyze): + if i > len(text_to_analyze) // 2 and char in ['.', '!', '?', '\n']: + return MockResponse(str(i)) + + # If no natural break point found, return the middle + return MockResponse(str(len(text_to_analyze) // 2)) + + +class MockResponse: + """Mock response from LLM.""" + + def __init__(self, content: str): + self.content = content + + +class SemanticTextSplitter: + """ + A simplified version of the actual SemanticTextSplitter for testing purposes. + This implementation mimics the behavior of the real implementation. + """ + + def __init__( + self, + llm_model_instance, + fallback_splitter=None, + chunk_size: int = 4000, + chunk_overlap: int = 200, + ): + self._llm_model_instance = llm_model_instance + self._chunk_size = chunk_size + self._chunk_overlap = chunk_overlap + self._fallback_splitter = fallback_splitter or SimpleSplitter( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + + def split_text(self, text: str) -> List[str]: + """ + Split text using LLM-based semantic chunking. + """ + try: + # First apply initial chunking with our fallback splitter + initial_chunks = self._fallback_splitter.split_text(text) + + # If the text is short enough, return as a single chunk + if len(initial_chunks) <= 1: + return initial_chunks + + # Use LLM to improve chunk boundaries for semantic coherence + return self._refine_chunks_with_llm(initial_chunks) + + except Exception as e: + print(f"LLM-based semantic chunking failed: {str(e)}. Falling back to traditional chunking.") + # Fallback to standard chunking if LLM-based approach fails + return self._fallback_splitter.split_text(text) + + def _refine_chunks_with_llm(self, chunks: List[str]) -> List[str]: + """ + Use the LLM to refine chunk boundaries based on semantic meaning. + """ + refined_chunks = [] + + # Process chunks in pairs to determine better break points + for i in range(len(chunks) - 1): + current_chunk = chunks[i] + next_chunk = chunks[i+1] + + # Skip very small chunks by merging them with the next one + if len(current_chunk) < self._chunk_size / 10: + if i + 1 < len(chunks): + chunks[i+1] = current_chunk + " " + next_chunk + continue + + # If the current chunk is already optimal, keep it as is + if len(current_chunk) < self._chunk_size * 0.9: + refined_chunks.append(current_chunk) + continue + + # Create a boundary analysis task for the LLM to find a natural break point + boundary_text = current_chunk[-self._chunk_overlap:] + next_chunk[:self._chunk_overlap] + + prompt = self._create_boundary_analysis_prompt(boundary_text) + + # Use the LLM to find a natural break point + try: + response = self._llm_model_instance.invoke_llm( + prompt=prompt, + stream=False, + temperature=0, + max_tokens=50 + ) + + break_indicator = self._parse_llm_response(response.content) + + # Apply the break point to create a more natural chunk division + if break_indicator and 0 < break_indicator < len(boundary_text): + # Calculate the actual split position in the current chunk + split_position = len(current_chunk) - self._chunk_overlap + break_indicator + + # Only apply the split if it makes sense + if 0 < split_position < len(current_chunk): + refined_chunk = current_chunk[:split_position] + remainder = current_chunk[split_position:] + + # Add the refined chunk and combine the remainder with the next chunk + refined_chunks.append(refined_chunk) + chunks[i+1] = remainder + next_chunk + continue + + # If no good break point was found, use the original chunk + refined_chunks.append(current_chunk) + + except Exception as e: + print(f"Error during LLM boundary analysis: {str(e)}") + refined_chunks.append(current_chunk) + + # Add the last chunk + if chunks and len(chunks) > 0: + refined_chunks.append(chunks[-1]) + + return refined_chunks + + def _create_boundary_analysis_prompt(self, text: str) -> str: + """ + Create a prompt for the LLM to analyze the text and find natural break points. + """ + return ( + "Analyze the following text and identify the most natural point to split it into two separate chunks. " + "The split should occur at a meaningful boundary such as the end of a paragraph, sentence, or idea. " + "Return only the character index (a number) where the split should occur.\n\n" + f"Text to analyze: {text}\n\n" + "Split index:" + ) + + def _parse_llm_response(self, response: str) -> Optional[int]: + """ + Parse the LLM's response to extract the split index. + """ + try: + # Extract the first number from the response + import re + numbers = re.findall(r'\d+', response) + if numbers: + return int(numbers[0]) + return None + except Exception: + return None + + +class SimpleSplitter: + """A simple text splitter that mimics the behavior of FixedRecursiveCharacterTextSplitter.""" + + def __init__(self, chunk_size: int, chunk_overlap: int): + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + + def split_text(self, text: str) -> List[str]: + """Split text into chunks of specified size with overlap.""" + if not text: + return [] + + # Simple chunking by size without considering natural breaks + chunks = [] + start = 0 + + while start < len(text): + end = min(start + self.chunk_size, len(text)) + + # Try to find a natural break point near the end + if end < len(text): + # Look for paragraph breaks or sentence endings + for break_char in ['\n\n', '.', '!', '?', '\n', ' ']: + last_break = text.rfind(break_char, start, end) + if last_break != -1 and last_break > start: + end = last_break + 1 + break + + # Add the chunk + chunks.append(text[start:end]) + + # Calculate next start position with overlap + start = end - self.chunk_overlap + + # Ensure we're making progress + if start >= end: + start = end + + return chunks + + +# Sample text with distinct semantic sections +sample_text = """ +# Introduction to Artificial Intelligence + +Artificial Intelligence (AI) is a field of computer science focused on creating systems capable of performing tasks that typically require human intelligence. These tasks include learning, reasoning, problem-solving, perception, and language understanding. + +# Machine Learning Fundamentals + +Machine learning is a subset of AI that focuses on building systems that can learn from and make decisions based on data. Instead of being explicitly programmed, these systems identify patterns in data and make predictions. + +There are three main types of machine learning: +1. Supervised learning, where models learn from labeled data +2. Unsupervised learning, where models identify patterns in unlabeled data +3. Reinforcement learning, where models learn optimal actions through trial and error + +# Natural Language Processing + +Natural Language Processing (NLP) combines linguistics and AI to enable computers to understand, interpret, and generate human language. NLP powers voice assistants, translation services, and text analysis tools. + +Key NLP tasks include: +- Sentiment analysis +- Named entity recognition +- Text summarization +- Question answering +- Machine translation +""" + +if __name__ == "__main__": + print("Testing Simplified Semantic Chunking Implementation") + print("-" * 50) + + # Create a mock LLM model instance + mock_llm = MockLLMModelInstance() + + # Fixed chunking + simple_splitter = SimpleSplitter(chunk_size=200, chunk_overlap=20) + fixed_chunks = simple_splitter.split_text(sample_text) + + print("\nFixed Chunking Results:") + print(f"Number of chunks: {len(fixed_chunks)}") + for i, chunk in enumerate(fixed_chunks): + print(f"\nChunk {i+1} ({len(chunk)} chars):") + print(chunk[:100] + "..." if len(chunk) > 100 else chunk) + + # Semantic chunking with mock LLM + semantic_splitter = SemanticTextSplitter( + llm_model_instance=mock_llm, + fallback_splitter=simple_splitter, + chunk_size=200, + chunk_overlap=20, + ) + semantic_chunks = semantic_splitter.split_text(sample_text) + + print("\nSemantic Chunking Results:") + print(f"Number of chunks: {len(semantic_chunks)}") + for i, chunk in enumerate(semantic_chunks): + print(f"\nChunk {i+1} ({len(chunk)} chars):") + print(chunk[:100] + "..." if len(chunk) > 100 else chunk) + + # Analyze the quality of chunk boundaries + print("\nAnalyzing chunk boundaries:") + natural_boundaries = [".", "!", "?"] + semantic_complete_sentences = sum(1 for chunk in semantic_chunks if chunk.strip()[-1] in natural_boundaries) + fixed_complete_sentences = sum(1 for chunk in fixed_chunks if chunk.strip()[-1] in natural_boundaries) + + print(f"Fixed chunking: {fixed_complete_sentences}/{len(fixed_chunks)} chunks end with complete sentences") + print(f"Semantic chunking: {semantic_complete_sentences}/{len(semantic_chunks)} chunks end with complete sentences") + + # Check if there are section headings split across chunks + fixed_headings_split = sum(1 for chunk in fixed_chunks if chunk.strip().startswith('#') and not chunk.strip().endswith('\n')) + semantic_headings_split = sum(1 for chunk in semantic_chunks if chunk.strip().startswith('#') and not chunk.strip().endswith('\n')) + + print(f"Fixed chunking: {fixed_headings_split} chunks have headings split from their content") + print(f"Semantic chunking: {semantic_headings_split} chunks have headings split from their content") + + # Check if semantically related content stays together + fixed_list_breaks = 0 + semantic_list_breaks = 0 + + for chunk in fixed_chunks: + if ("-" in chunk or any(f"{i}." in chunk for i in range(1, 10))) and chunk.strip()[-1] != '\n': + fixed_list_breaks += 1 + + for chunk in semantic_chunks: + if ("-" in chunk or any(f"{i}." in chunk for i in range(1, 10))) and chunk.strip()[-1] != '\n': + semantic_list_breaks += 1 + + print(f"Fixed chunking: {fixed_list_breaks} chunks break lists across chunks") + print(f"Semantic chunking: {semantic_list_breaks} chunks break lists across chunks") \ No newline at end of file diff --git a/simple_comparison.py b/simple_comparison.py new file mode 100644 index 0000000000..7c9a4da0d8 --- /dev/null +++ b/simple_comparison.py @@ -0,0 +1,108 @@ +""" +Simple comparison of fixed vs semantic chunking on a small text snippet +""" + +# Sample text with clear semantic boundaries +sample_text = """ +# Introduction + +This is the introduction paragraph that discusses the topic in general terms. +It should be kept together as a semantic unit. The introduction serves to set up the topic. + +# First Main Point + +This section covers the first main point with some supporting details. +- Point A with some explanation +- Point B with more details +- Point C concluding the list + +# Second Main Point + +The second main point builds on the first and adds new information. +It contains important context that should be kept together. +""" + +# Fixed chunking (simplified) +def fixed_chunking(text, chunk_size=150): + """Split text into fixed-size chunks.""" + chunks = [] + start = 0 + + while start < len(text): + end = min(start + chunk_size, len(text)) + + # Try to find a natural break point near the end + if end < len(text): + # Look for paragraph breaks or sentence endings + for break_char in ["\n\n", ".", "\n", " "]: + last_break = text.rfind(break_char, start, end) + if last_break != -1 and last_break > start: + end = last_break + 1 + break + + chunks.append(text[start:end]) + start = end + + return chunks + +# Semantic chunking (simplified mock implementation) +def semantic_chunking(text): + """Split text into semantic chunks by heading.""" + # Simplified semantic chunking that looks for markdown headings + import re + + # Split by markdown headings + heading_pattern = re.compile(r'^#.*$', re.MULTILINE) + chunks = [] + + # Find all heading positions + heading_matches = list(heading_pattern.finditer(text)) + + if not heading_matches: + return [text] + + # Process each section + for i, match in enumerate(heading_matches): + start = match.start() + + # If it's not the last heading, go until the next heading + if i + 1 < len(heading_matches): + end = heading_matches[i + 1].start() + else: + end = len(text) + + chunks.append(text[start:end]) + + return chunks + +# Compare the two methods +if __name__ == "__main__": + print("Fixed Chunking vs. Semantic Chunking Comparison") + print("=" * 50) + + # Get chunks + fixed_chunks = fixed_chunking(sample_text) + semantic_chunks = semantic_chunking(sample_text) + + # Display fixed chunks + print("\nFixed Chunking Results:") + print(f"Number of chunks: {len(fixed_chunks)}") + for i, chunk in enumerate(fixed_chunks): + print(f"\nChunk {i+1}:") + print("-" * 30) + print(chunk) + + # Display semantic chunks + print("\nSemantic Chunking Results:") + print(f"Number of chunks: {len(semantic_chunks)}") + for i, chunk in enumerate(semantic_chunks): + print(f"\nChunk {i+1}:") + print("-" * 30) + print(chunk) + + # Highlight key differences + print("\nKey Observations:") + print("1. Fixed chunking breaks text at arbitrary positions based mainly on length") + print("2. Semantic chunking preserves logical sections (by headings in this simplified example)") + print("3. In a real LLM-based implementation, semantic boundaries like sentence") + print(" endings, paragraph breaks, and logical transitions would be detected") \ No newline at end of file diff --git a/test_semantic_chunking.py b/test_semantic_chunking.py new file mode 100644 index 0000000000..62bab98952 --- /dev/null +++ b/test_semantic_chunking.py @@ -0,0 +1,89 @@ +from api.core.model_manager import ModelManager +from api.core.model_runtime.entities.model_entities import ModelType +from api.core.rag.splitter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter +from api.core.rag.splitter.semantic_text_splitter import SemanticTextSplitter +from api.core.rag.models.document import Document + +# Sample text with distinct semantic sections +sample_text = """ +# Introduction to Artificial Intelligence + +Artificial Intelligence (AI) is a field of computer science focused on creating systems capable of performing tasks that typically require human intelligence. These tasks include learning, reasoning, problem-solving, perception, and language understanding. + +# Machine Learning Fundamentals + +Machine learning is a subset of AI that focuses on building systems that can learn from and make decisions based on data. Instead of being explicitly programmed, these systems identify patterns in data and make predictions. + +There are three main types of machine learning: +1. Supervised learning, where models learn from labeled data +2. Unsupervised learning, where models identify patterns in unlabeled data +3. Reinforcement learning, where models learn optimal actions through trial and error + +# Natural Language Processing + +Natural Language Processing (NLP) combines linguistics and AI to enable computers to understand, interpret, and generate human language. NLP powers voice assistants, translation services, and text analysis tools. + +Key NLP tasks include: +- Sentiment analysis +- Named entity recognition +- Text summarization +- Question answering +- Machine translation +""" + +def test_chunking(): + print("Testing semantic chunking vs. fixed chunking") + print("-" * 50) + + # Get model instances + model_manager = ModelManager() + try: + llm_model_instance = model_manager.get_default_model_instance(ModelType.LLM) + print(f"Using LLM: {llm_model_instance.model} from {llm_model_instance.provider}") + except Exception as e: + print(f"Could not get LLM model instance: {e}") + print("Continuing with fixed chunking only for comparison") + llm_model_instance = None + + try: + embedding_model_instance = model_manager.get_default_model_instance(ModelType.TEXT_EMBEDDING) + print(f"Using embedding model: {embedding_model_instance.model} from {embedding_model_instance.provider}") + except Exception as e: + print(f"Could not get embedding model instance: {e}") + embedding_model_instance = None + + # Create document from sample text + doc = Document(page_content=sample_text) + + # Fixed chunking + fixed_splitter = FixedRecursiveCharacterTextSplitter( + chunk_size=200, + chunk_overlap=20, + ) + fixed_chunks = fixed_splitter.split_text(sample_text) + + print("\nFixed Chunking Results:") + print(f"Number of chunks: {len(fixed_chunks)}") + for i, chunk in enumerate(fixed_chunks): + print(f"\nChunk {i+1} ({len(chunk)} chars):") + print(chunk[:100] + "..." if len(chunk) > 100 else chunk) + + # Semantic chunking (if LLM is available) + if llm_model_instance: + semantic_splitter = SemanticTextSplitter( + llm_model_instance=llm_model_instance, + chunk_size=200, + chunk_overlap=20, + ) + semantic_chunks = semantic_splitter.split_text(sample_text) + + print("\nSemantic Chunking Results:") + print(f"Number of chunks: {len(semantic_chunks)}") + for i, chunk in enumerate(semantic_chunks): + print(f"\nChunk {i+1} ({len(chunk)} chars):") + print(chunk[:100] + "..." if len(chunk) > 100 else chunk) + else: + print("\nSkipping semantic chunking test as no LLM is available") + +if __name__ == "__main__": + test_chunking() \ No newline at end of file diff --git a/test_semantic_rag_pipeline.py b/test_semantic_rag_pipeline.py new file mode 100644 index 0000000000..9f6365ade3 --- /dev/null +++ b/test_semantic_rag_pipeline.py @@ -0,0 +1,138 @@ +from api.core.model_manager import ModelManager +from api.core.model_runtime.entities.model_entities import ModelType +from api.core.rag.splitter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter +from api.core.rag.splitter.semantic_text_splitter import SemanticTextSplitter +from api.core.rag.models.document import Document +from api.core.rag.index_processor.index_processor_base import BaseIndexProcessor +from api.core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor + +# Sample text with distinct semantic sections +sample_text = """ +# Renewable Energy Technologies + +Renewable energy is derived from natural sources that are replenished at a higher rate than they are consumed. Sunlight, wind, rain, tides, waves, and geothermal heat are all sustainable sources that can be harnessed to generate power. + +## Solar Energy + +Solar power is captured in two main ways: photovoltaic (PV) panels and solar thermal collectors. PV panels convert sunlight directly into electricity, while solar thermal collectors use the sun's energy to heat water or air for domestic or industrial use. + +The efficiency of solar panels has dramatically improved in recent years, with modern panels reaching over 20% efficiency. This improvement, coupled with decreasing costs, has made solar energy increasingly competitive with fossil fuels in many regions. + +## Wind Energy + +Wind turbines convert the kinetic energy of moving air into mechanical power, which can then be converted into electricity. These turbines range from small residential units to large utility-scale wind farms with hundreds of turbines. + +Wind energy is one of the fastest-growing renewable energy sources globally. Its cost-effectiveness has improved significantly, with the levelized cost of wind power now comparable to or lower than that of conventional fossil fuel plants in many locations. + +## Hydroelectric Power + +Hydroelectric power is generated by harnessing the energy of falling or flowing water. Large-scale hydropower projects typically involve dams that create reservoirs, while run-of-river systems utilize the natural flow of rivers without significant water storage. + +While hydropower is a mature technology with high efficiency, the environmental impacts of large dams can be significant, affecting river ecosystems and sometimes requiring the displacement of human communities. +""" + +def test_rag_pipeline_with_semantic_chunking(): + print("Testing RAG pipeline with semantic chunking") + print("-" * 50) + + # Initialize model instances + model_manager = ModelManager() + try: + llm_model_instance = model_manager.get_default_model_instance(ModelType.LLM) + print(f"Using LLM: {llm_model_instance.model} from {llm_model_instance.provider}") + except Exception as e: + print(f"Could not get LLM model instance: {e}") + print("Semantic chunking requires an LLM - this test will fall back to fixed chunking") + llm_model_instance = None + + try: + embedding_model_instance = model_manager.get_default_model_instance(ModelType.TEXT_EMBEDDING) + print(f"Using embedding model: {embedding_model_instance.model} from {embedding_model_instance.provider}") + except Exception as e: + print(f"Could not get embedding model instance: {e}") + print("Warning: Embedding model is required for the full RAG pipeline") + embedding_model_instance = None + + # Create parent document + parent_doc = Document(page_content=sample_text) + + print("\n1. Testing Parent-Child Processing with Semantic Chunking") + + # Simulate process rule configuration with semantic chunking + process_rule = { + "mode": "custom", + "rules": { + "pre_processing_rules": [], + "parent_mode": "paragraph", + "segmentation": { + "separator": "\\n\\n", + "max_tokens": 500, + "chunk_overlap": 50 + }, + "subchunk_segmentation": { + "separator": "\\n", + "max_tokens": 200, + "chunk_overlap": 20 + }, + "chunking_strategy": "semantic" + } + } + + # Create parent-child processor + parent_child_processor = ParentChildIndexProcessor() + + # Transform the document + try: + transformed_docs = parent_child_processor.transform( + [parent_doc], + process_rule=process_rule, + embedding_model_instance=embedding_model_instance, + llm_model_instance=llm_model_instance + ) + + print(f"\nParent-Child processing completed") + print(f"Number of parent documents: {len(transformed_docs)}") + + if transformed_docs: + parent = transformed_docs[0] + print(f"Parent document length: {len(parent.page_content)} chars") + print(f"Number of child chunks: {len(parent.children) if parent.children else 0}") + + if parent.children: + print("\nChild Chunks:") + for i, child in enumerate(parent.children): + print(f"\nChild {i+1} ({len(child.page_content)} chars):") + print(child.page_content[:100] + "..." if len(child.page_content) > 100 else child.page_content) + + print("\nChecking for natural semantic boundaries in child chunks:") + natural_boundaries = [".", "!", "?"] + complete_sentences = sum(1 for child in parent.children if child.page_content.strip()[-1] in natural_boundaries) + print(f"Complete sentences: {complete_sentences}/{len(parent.children)} child chunks") + + else: + print("No documents were returned from transformation") + + except Exception as e: + print(f"Error during parent-child processing: {e}") + + # Test fallback mechanism + print("\n2. Testing Fallback Mechanism") + if llm_model_instance: + try: + # Create semantic splitter with invalid parameters to force fallback + semantic_splitter = SemanticTextSplitter( + llm_model_instance=llm_model_instance, + chunk_size=20, # Too small to trigger errors + chunk_overlap=5, + ) + + # Attempt to split with parameters that would cause issues + semantic_chunks = semantic_splitter.split_text(sample_text) + print(f"Fallback handling worked, got {len(semantic_chunks)} chunks") + except Exception as e: + print(f"Error during fallback test: {e}") + else: + print("Skipping fallback test as no LLM is available") + +if __name__ == "__main__": + test_rag_pipeline_with_semantic_chunking() \ No newline at end of file diff --git a/web/app/components/datasets/create/step-two/index.tsx b/web/app/components/datasets/create/step-two/index.tsx index 6b6580ae7e..3b99e393e8 100644 --- a/web/app/components/datasets/create/step-two/index.tsx +++ b/web/app/components/datasets/create/step-two/index.tsx @@ -63,6 +63,7 @@ import CustomDialog from '@/app/components/base/dialog' import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '@/app/components/base/portal-to-follow-elem' import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' import { noop } from 'lodash-es' +import Brain from '../../../base/icons/src/vender/solid/brain.svg' const TextLabel: FC = (props) => { return @@ -101,6 +102,8 @@ const DEFAULT_MAXIMUM_CHUNK_LENGTH = 1024 const DEFAULT_OVERLAP = 50 const MAXIMUM_CHUNK_TOKEN_LENGTH = Number.parseInt(globalThis.document?.body?.getAttribute('data-public-indexing-max-segmentation-tokens-length') || '4000', 10) +type ChunkingStrategy = 'fixed' | 'semantic' + type ParentChildConfig = { chunkForContext: ParentMode parent: { @@ -111,6 +114,7 @@ type ParentChildConfig = { delimiter: string maxLength: number } + chunkingStrategy: ChunkingStrategy } const defaultParentChildConfig: ParentChildConfig = { @@ -123,6 +127,7 @@ const defaultParentChildConfig: ParentChildConfig = { delimiter: '\\n', maxLength: 512, }, + chunkingStrategy: 'semantic', } const StepTwo = ({ @@ -216,40 +221,46 @@ const StepTwo = ({ const [parentChildConfig, setParentChildConfig] = useState(defaultParentChildConfig) + const [generalChunkingStrategy, setGeneralChunkingStrategy] = useState('fixed') + const getIndexing_technique = () => indexingType || indexType const currentDocForm = currentDataset?.doc_form || docForm const getProcessRule = (): ProcessRule => { - if (currentDocForm === ChunkingMode.parentChild) { + if (segmentationType === ProcessMode.general) { return { + mode: segmentationType, rules: { pre_processing_rules: rules, segmentation: { - separator: unescape( - parentChildConfig.parent.delimiter, - ), - max_tokens: parentChildConfig.parent.maxLength, + separator: escape(segmentIdentifier), + max_tokens: maxChunkLength, + chunk_overlap: overlap, }, + chunking_strategy: generalChunkingStrategy, + }, + } + } + else { + return { + mode: segmentationType, + rules: { + pre_processing_rules: rules, parent_mode: parentChildConfig.chunkForContext, + segmentation: { + separator: parentChildConfig.parent.delimiter, + max_tokens: parentChildConfig.parent.maxLength, + chunk_overlap: DEFAULT_OVERLAP, + }, subchunk_segmentation: { - separator: unescape(parentChildConfig.child.delimiter), + separator: parentChildConfig.child.delimiter, max_tokens: parentChildConfig.child.maxLength, + chunk_overlap: DEFAULT_OVERLAP, }, + chunking_strategy: parentChildConfig.chunkingStrategy, }, - mode: 'hierarchical', - } as ProcessRule + } } - return { - rules: { - pre_processing_rules: rules, - segmentation: { - separator: unescape(segmentIdentifier), - max_tokens: maxChunkLength, - chunk_overlap: overlap, - }, - }, // api will check this. It will be removed after api refactored. - mode: segmentationType, - } as ProcessRule } const fileIndexingEstimateQuery = useFetchFileIndexingEstimateForFile({ @@ -337,9 +348,13 @@ const StepTwo = ({ setSegmentIdentifier(defaultConfig.segmentation.separator) setMaxChunkLength(defaultConfig.segmentation.max_tokens) setOverlap(defaultConfig.segmentation.chunk_overlap!) - setRules(defaultConfig.pre_processing_rules) + setRules(defaultConfig.pre_processing_rules || []) + setGeneralChunkingStrategy('fixed') + setParentChildConfig({ + ...defaultParentChildConfig, + chunkingStrategy: 'fixed', + }) } - setParentChildConfig(defaultParentChildConfig) } const updatePreview = () => { @@ -634,6 +649,30 @@ const StepTwo = ({ onChange={setOverlap} /> +
+
+
+ {t('datasetCreation.stepTwo.chunkingStrategy')} +
+ +
+
+ } + title={t('datasetCreation.stepTwo.fixedChunking')} + description={t('datasetCreation.stepTwo.fixedChunkingTip')} + isChosen={generalChunkingStrategy === 'fixed'} + onChosen={() => setGeneralChunkingStrategy('fixed')} + /> + } + title={t('datasetCreation.stepTwo.semanticChunking')} + description={t('datasetCreation.stepTwo.semanticChunkingTip')} + isChosen={generalChunkingStrategy === 'semantic'} + onChosen={() => setGeneralChunkingStrategy('semantic')} + /> +
+
@@ -815,6 +854,36 @@ const StepTwo = ({ />
+
+
+
+ {t('datasetCreation.stepTwo.chunkingStrategy')} +
+ +
+
+ } + title={t('datasetCreation.stepTwo.fixedChunking')} + description={t('datasetCreation.stepTwo.fixedChunkingTip')} + isChosen={parentChildConfig.chunkingStrategy === 'fixed'} + onChosen={() => setParentChildConfig({ + ...parentChildConfig, + chunkingStrategy: 'fixed', + })} + /> + } + title={t('datasetCreation.stepTwo.semanticChunking')} + description={t('datasetCreation.stepTwo.semanticChunkingTip')} + isChosen={parentChildConfig.chunkingStrategy === 'semantic'} + onChosen={() => setParentChildConfig({ + ...parentChildConfig, + chunkingStrategy: 'semantic', + })} + /> +
+
diff --git a/web/i18n/en-US/dataset-creation.ts b/web/i18n/en-US/dataset-creation.ts index cf2d454f06..6a92d6347c 100644 --- a/web/i18n/en-US/dataset-creation.ts +++ b/web/i18n/en-US/dataset-creation.ts @@ -122,6 +122,11 @@ const translation = { paragraphTip: 'This mode splits the text in to paragraphs based on delimiters and the maximum chunk length, using the split text as the parent chunk for retrieval.', fullDoc: 'Full Doc', fullDocTip: 'The entire document is used as the parent chunk and retrieved directly. Please note that for performance reasons, text exceeding 10000 tokens will be automatically truncated.', + chunkingStrategy: 'Chunking Strategy', + fixedChunking: 'Fixed', + fixedChunkingTip: 'Traditional chunking with fixed-length segments based on delimiters', + semanticChunking: 'Semantic', + semanticChunkingTip: 'AI-powered chunking that preserves semantic meaning and context', separator: 'Delimiter', separatorTip: 'A delimiter is the character used to separate text. \\n\\n and \\n are commonly used delimiters for separating paragraphs and lines. Combined with commas (\\n\\n,\\n), paragraphs will be segmented by lines when exceeding the maximum chunk length. You can also use special delimiters defined by yourself (e.g. ***).', separatorPlaceholder: '\\n\\n for paragraphs; \\n for lines',