diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 09cdd66e04..5732d023eb 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -2128,174 +2128,367 @@ class SegmentService: @classmethod def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset): + """ + Update a document segment with new content, keywords, and metadata. + + This method handles both simple updates (content unchanged) and complex updates (content changed) + with proper indexing, vector updates, and error handling. + + Args: + args: Update arguments containing new content, keywords, and settings + segment: The document segment to update + document: The parent document containing the segment + dataset: The dataset containing the document + + Returns: + DocumentSegment: The updated segment object + + Raises: + ValueError: If segment is currently being indexed or disabled segment cannot be updated + """ + # Check if segment is currently being indexed to prevent concurrent operations indexing_cache_key = "segment_{}_indexing".format(segment.id) cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: raise ValueError("Segment is indexing, please try again later") + + # Handle segment enable/disable state changes if args.enabled is not None: action = args.enabled if segment.enabled != action: if not action: - segment.enabled = action - segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - segment.disabled_by = current_user.id - db.session.add(segment) - db.session.commit() - # Set cache to prevent indexing the same segment multiple times - redis_client.setex(indexing_cache_key, 600, 1) - disable_segment_from_index_task.delay(segment.id) + # Disable segment and trigger index removal + cls._disable_segment(segment, indexing_cache_key) return segment + + # Validate that disabled segments cannot be updated unless being enabled if not segment.enabled: if args.enabled is not None: if not args.enabled: raise ValueError("Can't update disabled segment") else: raise ValueError("Can't update disabled segment") + try: - word_count_change = segment.word_count + # Track word count changes for document update + original_word_count = segment.word_count content = args.content or segment.content + + # Handle simple update case: content unchanged, only metadata updates if segment.content == content: - segment.word_count = len(content) - if document.doc_form == "qa_model": - segment.answer = args.answer - segment.word_count += len(args.answer) if args.answer else 0 - word_count_change = segment.word_count - word_count_change - keyword_changed = False - if args.keywords: - if Counter(segment.keywords) != Counter(args.keywords): - segment.keywords = args.keywords - keyword_changed = True - segment.enabled = True - segment.disabled_at = None - segment.disabled_by = None - db.session.add(segment) - db.session.commit() - # update document word count - if word_count_change != 0: - document.word_count = max(0, document.word_count + word_count_change) - db.session.add(document) - # update segment index task - if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: - # regenerate child chunks - # get embedding model instance - if dataset.indexing_technique == "high_quality": - # check embedding model setting - model_manager = ModelManager() - - if dataset.embedding_model_provider: - embedding_model_instance = model_manager.get_model_instance( - 
tenant_id=dataset.tenant_id, - provider=dataset.embedding_model_provider, - model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model, - ) - else: - embedding_model_instance = model_manager.get_default_model_instance( - tenant_id=dataset.tenant_id, - model_type=ModelType.TEXT_EMBEDDING, - ) - else: - raise ValueError("The knowledge base index technique is not high quality!") - # get the process rule - processing_rule = ( - db.session.query(DatasetProcessRule) - .filter(DatasetProcessRule.id == document.dataset_process_rule_id) - .first() - ) - if not processing_rule: - raise ValueError("No processing rule found.") - VectorService.generate_child_chunks( - segment, document, dataset, embedding_model_instance, processing_rule, True - ) - elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX): - if args.enabled or keyword_changed: - # update segment vector index - VectorService.update_segment_vector(args.keywords, segment, dataset) + cls._handle_simple_segment_update(args, segment, document, dataset, original_word_count) else: - segment_hash = helper.generate_text_hash(content) - tokens = 0 - if dataset.indexing_technique == "high_quality": - model_manager = ModelManager() - embedding_model = model_manager.get_model_instance( - tenant_id=current_user.current_tenant_id, - provider=dataset.embedding_model_provider, - model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model, - ) - - # calc embedding use tokens - if document.doc_form == "qa_model": - segment.answer = args.answer - tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])[0] - else: - tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0] - segment.content = content - segment.index_node_hash = segment_hash - segment.word_count = len(content) - segment.tokens = tokens - segment.status = "completed" - segment.indexing_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - segment.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - segment.updated_by = current_user.id - segment.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - segment.enabled = True - segment.disabled_at = None - segment.disabled_by = None - if document.doc_form == "qa_model": - segment.answer = args.answer - segment.word_count += len(args.answer) if args.answer else 0 - word_count_change = segment.word_count - word_count_change - # update document word count - if word_count_change != 0: - document.word_count = max(0, document.word_count + word_count_change) - db.session.add(document) - db.session.add(segment) - db.session.commit() - if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: - # get embedding model instance - if dataset.indexing_technique == "high_quality": - # check embedding model setting - model_manager = ModelManager() - - if dataset.embedding_model_provider: - embedding_model_instance = model_manager.get_model_instance( - tenant_id=dataset.tenant_id, - provider=dataset.embedding_model_provider, - model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model, - ) - else: - embedding_model_instance = model_manager.get_default_model_instance( - tenant_id=dataset.tenant_id, - model_type=ModelType.TEXT_EMBEDDING, - ) - else: - raise ValueError("The knowledge base index technique is not high quality!") - # get the process rule - processing_rule = ( - db.session.query(DatasetProcessRule) - .filter(DatasetProcessRule.id == document.dataset_process_rule_id) - .first() - ) - if 
not processing_rule: - raise ValueError("No processing rule found.") - VectorService.generate_child_chunks( - segment, document, dataset, embedding_model_instance, processing_rule, True - ) - elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX): - # update segment vector index - VectorService.update_segment_vector(args.keywords, segment, dataset) + # Handle complex update case: content changed, requires re-indexing + cls._handle_complex_segment_update(args, segment, document, dataset, content, original_word_count) except Exception as e: - logging.exception("update segment index failed") - segment.enabled = False - segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - segment.status = "error" - segment.error = str(e) - db.session.commit() + # Handle update failures by marking segment as error state + cls._handle_segment_update_error(segment, e) + + # Return fresh segment object from database new_segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first() return new_segment + @classmethod + def _disable_segment(cls, segment: DocumentSegment, indexing_cache_key: str): + """ + Disable a segment and trigger index removal. + + Args: + segment: The segment to disable + indexing_cache_key: Redis cache key for indexing lock + """ + segment.enabled = False + segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + segment.disabled_by = current_user.id + db.session.add(segment) + db.session.commit() + + # Set cache to prevent indexing the same segment multiple times + redis_client.setex(indexing_cache_key, 600, 1) + disable_segment_from_index_task.delay(segment.id) + + @classmethod + def _handle_simple_segment_update( + cls, + args: SegmentUpdateArgs, + segment: DocumentSegment, + document: Document, + dataset: Dataset, + original_word_count: int, + ): + """ + Handle segment update when content remains unchanged. + Only updates metadata like keywords, answer, and word count. 
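+        For paragraph/QA documents the segment vector is refreshed only when keywords changed or the segment is being enabled.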
+ + Args: + args: Update arguments + segment: The segment to update + document: Parent document + dataset: Parent dataset + original_word_count: Original word count before update + """ + # Update word count for main content + segment.word_count = len(segment.content) + + # Handle QA model specific updates + if document.doc_form == "qa_model": + segment.answer = args.answer + segment.word_count += len(args.answer) if args.answer else 0 + + # Calculate word count change for document update + word_count_change = segment.word_count - original_word_count + + # Update keywords if provided and changed + keyword_changed = False + if args.keywords: + if Counter(segment.keywords) != Counter(args.keywords): + segment.keywords = args.keywords + keyword_changed = True + + # Reset segment state to enabled + segment.enabled = True + segment.disabled_at = None + segment.disabled_by = None + + # Persist changes to database + db.session.add(segment) + db.session.commit() + + # Update document word count if changed + if word_count_change != 0: + document.word_count = max(0, document.word_count + word_count_change) + db.session.add(document) + + # Handle vector index updates based on document type + cls._handle_vector_index_updates(args, segment, document, dataset, keyword_changed, False) + + @classmethod + def _handle_complex_segment_update( + cls, + args: SegmentUpdateArgs, + segment: DocumentSegment, + document: Document, + dataset: Dataset, + content: str, + original_word_count: int, + ): + """ + Handle segment update when content has changed. + Requires re-indexing and vector updates. + + Args: + args: Update arguments + segment: The segment to update + document: Parent document + dataset: Parent dataset + content: New content for the segment + original_word_count: Original word count before update + """ + # Generate new content hash for change detection + segment_hash = helper.generate_text_hash(content) + + # Calculate tokens for high quality indexing + tokens = cls._calculate_segment_tokens(args, content, document, dataset) + + # Update segment with new content and metadata + cls._update_segment_metadata(document, segment, content, segment_hash, tokens, args) + + # Calculate word count change for document update + word_count_change = segment.word_count - original_word_count + + # Update document word count if changed + if word_count_change != 0: + document.word_count = max(0, document.word_count + word_count_change) + db.session.add(document) + + # Persist changes to database + db.session.add(segment) + db.session.commit() + + # Handle vector index updates based on document type + cls._handle_vector_index_updates(args, segment, document, dataset, True, True) + + @classmethod + def _calculate_segment_tokens( + cls, args: SegmentUpdateArgs, content: str, document: Document, dataset: Dataset + ) -> int: + """ + Calculate token count for segment content based on indexing technique. 
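+        Returns 0 without calling an embedding model when the indexing technique is not "high_quality".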
+
+        Args:
+            args: Update arguments
+            content: Segment content
+            document: Parent document
+            dataset: Parent dataset
+
+        Returns:
+            int: Token count for the content
+        """
+        tokens = 0
+        if dataset.indexing_technique == "high_quality":
+            model_manager = ModelManager()
+            embedding_model = model_manager.get_model_instance(
+                tenant_id=current_user.current_tenant_id,
+                provider=dataset.embedding_model_provider,
+                model_type=ModelType.TEXT_EMBEDDING,
+                model=dataset.embedding_model,
+            )
+
+            # Calculate embedding tokens for QA model or regular content
+            if document.doc_form == "qa_model":
+                answer = args.answer or ""
+                tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + answer])[0]
+            else:
+                tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
+
+        return tokens
+
+    @classmethod
+    def _update_segment_metadata(
+        cls,
+        document: Document,
+        segment: DocumentSegment,
+        content: str,
+        segment_hash: str,
+        tokens: int,
+        args: SegmentUpdateArgs,
+    ):
+        """
+        Update segment metadata with new content and timestamps.
+
+        Args:
+            document: Parent document
+            segment: The segment to update
+            content: New content
+            segment_hash: Content hash for change detection
+            tokens: Token count for the content
+            args: Update arguments
+        """
+        segment.content = content
+        segment.index_node_hash = segment_hash
+        segment.word_count = len(content)
+        segment.tokens = tokens
+        segment.status = "completed"
+
+        # Update timestamps
+        current_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+        segment.indexing_at = current_time
+        segment.completed_at = current_time
+        segment.updated_by = current_user.id
+        segment.updated_at = current_time
+
+        # Reset segment state to enabled
+        segment.enabled = True
+        segment.disabled_at = None
+        segment.disabled_by = None
+
+        # Handle QA model specific metadata
+        if document.doc_form == "qa_model":
+            segment.answer = args.answer
+            segment.word_count += len(args.answer) if args.answer else 0
+
+    @classmethod
+    def _handle_vector_index_updates(
+        cls,
+        args: SegmentUpdateArgs,
+        segment: DocumentSegment,
+        document: Document,
+        dataset: Dataset,
+        keyword_changed: bool,
+        content_changed: bool,
+    ):
+        """
+        Handle vector index updates based on document type and update conditions.
+
+        Args:
+            args: Update arguments
+            segment: The segment to update
+            document: Parent document
+            dataset: Parent dataset
+            keyword_changed: Whether keywords were changed in this update
+            content_changed: Whether content was changed in this update
+        """
+        if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
+            # Regenerate child chunks for parent-child indexing
+            cls._regenerate_child_chunks(segment, document, dataset)
+        elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX):
+            # Update segment vector for paragraph/QA indexing
+            if content_changed or args.enabled or keyword_changed:
+                VectorService.update_segment_vector(args.keywords, segment, dataset)
+
+    @classmethod
+    def _regenerate_child_chunks(cls, segment: DocumentSegment, document: Document, dataset: Dataset):
+        """
+        Regenerate child chunks for parent-child indexing.
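+        Fetches the dataset's embedding model and processing rule before delegating to VectorService.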
+ + Args: + segment: The segment to regenerate chunks for + document: Parent document + dataset: Parent dataset + + Raises: + ValueError: If indexing technique is not high quality or processing rule not found + """ + if dataset.indexing_technique != "high_quality": + raise ValueError("The knowledge base index technique is not high quality!") + + # Get embedding model instance + embedding_model_instance = cls._get_embedding_model_instance(dataset) + + # Get processing rule for chunk generation + processing_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.id == document.dataset_process_rule_id) + .first() + ) + if not processing_rule: + raise ValueError("No processing rule found.") + + # Generate new child chunks + VectorService.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, True) + + @classmethod + def _get_embedding_model_instance(cls, dataset: Dataset): + """ + Get embedding model instance for the dataset. + + Args: + dataset: The dataset to get model for + + Returns: + Model instance for text embedding + """ + model_manager = ModelManager() + + if dataset.embedding_model_provider: + return model_manager.get_model_instance( + tenant_id=dataset.tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + else: + return model_manager.get_default_model_instance( + tenant_id=dataset.tenant_id, + model_type=ModelType.TEXT_EMBEDDING, + ) + + @classmethod + def _handle_segment_update_error(cls, segment: DocumentSegment, error: Exception): + """ + Handle errors during segment update by marking segment as error state. + + Args: + segment: The segment that encountered an error + error: The exception that occurred + """ + logging.exception("update segment index failed") + segment.enabled = False + segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + segment.status = "error" + segment.error = str(error) + db.session.commit() + @classmethod def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: Dataset): indexing_cache_key = "segment_{}_delete_indexing".format(segment.id) diff --git a/api/tests/unit_tests/services/test_segment_service_update_segment.py b/api/tests/unit_tests/services/test_segment_service_update_segment.py new file mode 100644 index 0000000000..7758cf78d3 --- /dev/null +++ b/api/tests/unit_tests/services/test_segment_service_update_segment.py @@ -0,0 +1,872 @@ +import datetime +from typing import Any, Optional + +# Mock redis_client before importing segment_service +from unittest.mock import Mock, patch + +import pytest + +from core.rag.index_processor.constant.index_type import IndexType +from models.dataset import Dataset, DatasetProcessRule, Document, DocumentSegment +from services.dataset_service import SegmentService +from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs +from tests.unit_tests.conftest import redis_mock + + +class SegmentUpdateTestDataFactory: + """Factory class for creating test data and mock objects for segment update tests.""" + + @staticmethod + def create_segment_mock( + segment_id: str = "segment-123", + content: str = "old_content", + answer: Optional[str] = None, + keywords: Optional[list[str]] = None, + enabled: bool = True, + status: str = "completed", + word_count: int = 10, + tokens: int = 15, + position: int = 1, + **kwargs, + ) -> Mock: + """Create a mock segment with specified attributes.""" + segment = 
Mock(spec=DocumentSegment) + segment.id = segment_id + segment.content = content + segment.answer = answer + segment.keywords = keywords or [] + segment.enabled = enabled + segment.status = status + segment.word_count = word_count + segment.tokens = tokens + segment.position = position + segment.index_node_id = f"node-{segment_id}" + segment.index_node_hash = f"hash-{segment_id}" + segment.tenant_id = "tenant-123" + segment.dataset_id = "dataset-123" + segment.document_id = "document-123" + segment.created_by = "user-789" + segment.created_at = datetime.datetime(2023, 1, 1, 12, 0, 0) + segment.updated_by = None + segment.updated_at = datetime.datetime(2023, 1, 1, 12, 0, 0) + segment.indexing_at = datetime.datetime(2023, 1, 1, 12, 0, 0) + segment.completed_at = datetime.datetime(2023, 1, 1, 12, 0, 0) + segment.disabled_at = None + segment.disabled_by = None + segment.error = None + for key, value in kwargs.items(): + setattr(segment, key, value) + return segment + + @staticmethod + def create_document_mock( + document_id: str = "document-123", + doc_form: str = IndexType.PARAGRAPH_INDEX, + word_count: int = 100, + dataset_process_rule_id: str = "rule-123", + **kwargs, + ) -> Mock: + """Create a mock document with specified attributes.""" + document = Mock(spec=Document) + document.id = document_id + document.doc_form = doc_form + document.word_count = word_count + document.dataset_process_rule_id = dataset_process_rule_id + document.tenant_id = "tenant-123" + document.dataset_id = "dataset-123" + document.created_by = "user-789" + document.created_at = datetime.datetime(2023, 1, 1, 12, 0, 0) + for key, value in kwargs.items(): + setattr(document, key, value) + return document + + @staticmethod + def create_dataset_mock( + dataset_id: str = "dataset-123", + indexing_technique: str = "high_quality", + embedding_model_provider: str = "openai", + embedding_model: str = "text-embedding-ada-002", + tenant_id: str = "tenant-123", + **kwargs, + ) -> Mock: + """Create a mock dataset with specified attributes.""" + dataset = Mock(spec=Dataset) + dataset.id = dataset_id + dataset.indexing_technique = indexing_technique + dataset.embedding_model_provider = embedding_model_provider + dataset.embedding_model = embedding_model + dataset.tenant_id = tenant_id + for key, value in kwargs.items(): + setattr(dataset, key, value) + return dataset + + @staticmethod + def create_embedding_model_mock(model: str = "text-embedding-ada-002", provider: str = "openai") -> Mock: + """Create a mock embedding model.""" + embedding_model = Mock() + embedding_model.model = model + embedding_model.provider = provider + embedding_model.get_text_embedding_num_tokens.return_value = [20] + return embedding_model + + @staticmethod + def create_processing_rule_mock(rule_id: str = "rule-123") -> Mock: + """Create a mock processing rule.""" + processing_rule = Mock(spec=DatasetProcessRule) + processing_rule.id = rule_id + processing_rule.to_dict.return_value = {"rules": {"parent_mode": "full_doc"}} + return processing_rule + + @staticmethod + def create_current_user_mock(user_id: str = "user-789", tenant_id: str = "tenant-123") -> Mock: + """Create a mock current user.""" + current_user = Mock() + current_user.id = user_id + current_user.current_tenant_id = tenant_id + return current_user + + +class TestSegmentServiceUpdateSegment: + """ + Comprehensive unit tests for SegmentService.update_segment method. 
+
+    This test suite covers all supported scenarios including:
+    - Segment enable/disable operations
+    - Content updates with same and different content
+    - QA model updates with answer field
+    - Keyword updates
+    - Different document forms (paragraph, QA, parent-child)
+    - High quality vs economy indexing techniques
+    - Child chunk regeneration
+    - Error handling and edge cases
+    - Redis cache management
+    - Vector index updates
+    """
+
+    @pytest.fixture
+    def mock_segment_service_dependencies(self):
+        """Common mock setup for segment service dependencies."""
+        with (
+            patch("services.dataset_service.redis_client") as mock_redis,
+            patch("extensions.ext_database.db.session") as mock_db,
+            patch("services.dataset_service.datetime") as mock_datetime,
+            patch("services.dataset_service.current_user") as mock_current_user,
+            patch("services.dataset_service.helper") as mock_helper,
+        ):
+            current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
+            mock_datetime.datetime.now.return_value = current_time
+            mock_datetime.UTC = datetime.UTC
+            mock_current_user.id = "user-789"
+            mock_current_user.current_tenant_id = "tenant-123"
+            mock_helper.generate_text_hash.return_value = "new_hash_123"
+
+            yield {
+                "redis_client": mock_redis,
+                "db_session": mock_db,
+                "datetime": mock_datetime,
+                "current_user": mock_current_user,
+                "helper": mock_helper,
+                "current_time": current_time,
+            }
+
+    @pytest.fixture
+    def mock_model_manager_dependencies(self):
+        """Mock setup for model manager tests."""
+        with patch("services.dataset_service.ModelManager") as mock_model_manager:
+            yield mock_model_manager
+
+    @pytest.fixture
+    def mock_vector_service_dependencies(self):
+        """Mock setup for vector service tests."""
+        with (
+            patch("services.dataset_service.VectorService") as mock_vector_service,
+            patch("services.dataset_service.disable_segment_from_index_task") as mock_disable_task,
+        ):
+            yield {
+                "vector_service": mock_vector_service,
+                "disable_task": mock_disable_task,
+            }
+
+    @pytest.fixture
+    def mock_processing_rule_dependencies(self):
+        """Mock setup for processing rule tests."""
+        with patch("services.dataset_service.DatasetProcessRule") as mock_processing_rule:
+            yield mock_processing_rule
+
+    def _assert_redis_cache_operations(self, mock_redis, segment_id: str, should_set_cache: bool = False):
+        """Helper method to verify Redis cache operations."""
+        mock_redis.get.assert_called_once_with(f"segment_{segment_id}_indexing")
+        if should_set_cache:
+            mock_redis.setex.assert_called_once_with(f"segment_{segment_id}_indexing", 600, 1)
+
+    def _assert_database_operations(self, mock_db, expected_objects: Optional[list[Any]]):
+        """Helper method to verify database operations.
+
+        Args:
+            mock_db: Mock database session
+            expected_objects: List of objects that should have been added to the session.
+                If None, only verifies that add() was called at least once.
+ """ + if expected_objects is None: + # Just verify that add was called at least once + assert mock_db.add.call_count >= 1 + else: + # Get all the objects that were passed to add() + added_objects = [call.args[0] for call in mock_db.add.call_args_list] + + # Verify each expected object was added + for expected_obj in expected_objects: + assert expected_obj in added_objects, f"Expected object {expected_obj} was not added to session" + + def _assert_segment_attributes_updated(self, segment, expected_updates: dict[str, Any]): + """Helper method to verify segment attribute updates.""" + for key, value in expected_updates.items(): + assert getattr(segment, key) == value + + # ==================== Enable/Disable Segment Tests ==================== + + def test_disable_segment_success(self, mock_segment_service_dependencies, mock_vector_service_dependencies): + """Test successful disable of an enabled segment.""" + segment = SegmentUpdateTestDataFactory.create_segment_mock(enabled=True) + document = SegmentUpdateTestDataFactory.create_document_mock() + dataset = SegmentUpdateTestDataFactory.create_dataset_mock() + + # Mock Redis cache as not indexing + mock_segment_service_dependencies["redis_client"].get.return_value = None + + args = SegmentUpdateArgs(enabled=False) + + result = SegmentService.update_segment(args, segment, document, dataset) + + # Verify Redis cache check + self._assert_redis_cache_operations( + mock_segment_service_dependencies["redis_client"], segment.id, should_set_cache=True + ) + + # Verify segment was disabled + self._assert_segment_attributes_updated( + segment, + { + "enabled": False, + "disabled_at": mock_segment_service_dependencies["current_time"].replace(tzinfo=None), + "disabled_by": mock_segment_service_dependencies["current_user"].id, + }, + ) + + # Verify database operations + self._assert_database_operations(mock_segment_service_dependencies["db_session"], [segment]) + + # Verify disable task was triggered + mock_vector_service_dependencies["disable_task"].delay.assert_called_once_with(segment.id) + + # Verify return value + assert result == segment + + def test_disable_segment_already_disabled(self, mock_segment_service_dependencies): + """Test disable operation on already disabled segment.""" + segment = SegmentUpdateTestDataFactory.create_segment_mock(enabled=False) + document = SegmentUpdateTestDataFactory.create_document_mock() + dataset = SegmentUpdateTestDataFactory.create_dataset_mock() + + # Mock Redis cache as not indexing + mock_segment_service_dependencies["redis_client"].get.return_value = None + + args = SegmentUpdateArgs(enabled=False) + + with pytest.raises(ValueError) as context: + SegmentService.update_segment(args, segment, document, dataset) + + assert "Can't update disabled segment" in str(context.value) + + def test_enable_segment_no_change(self, mock_segment_service_dependencies, mock_vector_service_dependencies): + """Test enable operation on already enabled segment (no change).""" + segment = SegmentUpdateTestDataFactory.create_segment_mock(enabled=True) + document = SegmentUpdateTestDataFactory.create_document_mock() + dataset = SegmentUpdateTestDataFactory.create_dataset_mock() + + # Mock Redis cache as not indexing + mock_segment_service_dependencies["redis_client"].get.return_value = None + + args = SegmentUpdateArgs(enabled=True) + + result = SegmentService.update_segment(args, segment, document, dataset) + + # Verify segment remains enabled (no change) + assert segment.enabled is True + assert segment.disabled_at is None + assert 
segment.disabled_by is None + + def test_update_disabled_segment_without_enable(self, mock_segment_service_dependencies): + """Test updating disabled segment without enabling it.""" + segment = SegmentUpdateTestDataFactory.create_segment_mock(enabled=False) + document = SegmentUpdateTestDataFactory.create_document_mock() + dataset = SegmentUpdateTestDataFactory.create_dataset_mock() + + # Mock Redis cache as not indexing + mock_segment_service_dependencies["redis_client"].get.return_value = None + + args = SegmentUpdateArgs(content="new_content") + + with pytest.raises(ValueError) as context: + SegmentService.update_segment(args, segment, document, dataset) + + assert "Can't update disabled segment" in str(context.value) + + def test_segment_indexing_in_progress_error(self, mock_segment_service_dependencies): + """Test error when segment is currently being indexed.""" + segment = SegmentUpdateTestDataFactory.create_segment_mock() + document = SegmentUpdateTestDataFactory.create_document_mock() + dataset = SegmentUpdateTestDataFactory.create_dataset_mock() + + # Mock Redis cache as indexing in progress + mock_segment_service_dependencies["redis_client"].get.return_value = "1" + + args = SegmentUpdateArgs(content="new_content") + + with pytest.raises(ValueError) as context: + SegmentService.update_segment(args, segment, document, dataset) + + assert "Segment is indexing, please try again later" in str(context.value) + + # ==================== Content Update Tests (Same Content) ==================== + + def test_update_segment_same_content_success( + self, mock_segment_service_dependencies, mock_vector_service_dependencies + ): + """Test updating segment with same content (only keywords/answer change).""" + segment = SegmentUpdateTestDataFactory.create_segment_mock( + content="test_content", keywords=["old_keyword"], answer="old_answer" + ) + document = SegmentUpdateTestDataFactory.create_document_mock(doc_form=IndexType.QA_INDEX) + dataset = SegmentUpdateTestDataFactory.create_dataset_mock() + + # Mock Redis cache as not indexing + mock_segment_service_dependencies["redis_client"].get.return_value = None + + args = SegmentUpdateArgs( + content="test_content", # Same content + keywords=["new_keyword"], + answer="new_answer", + ) + + SegmentService.update_segment(args, segment, document, dataset) + + # Verify segment attributes were updated + self._assert_segment_attributes_updated( + segment, + { + "keywords": ["new_keyword"], + "answer": "new_answer", + "enabled": True, + "disabled_at": None, + "disabled_by": None, + "word_count": len("test_content") + len("new_answer"), + }, + ) + + # Verify database operations + self._assert_database_operations(mock_segment_service_dependencies["db_session"], [segment]) + + def test_update_segment_same_content_keywords_unchanged( + self, mock_segment_service_dependencies, mock_vector_service_dependencies + ): + """Test updating segment with same content and unchanged keywords.""" + segment = SegmentUpdateTestDataFactory.create_segment_mock( + content="test_content", keywords=["keyword1", "keyword2"] + ) + document = SegmentUpdateTestDataFactory.create_document_mock() + dataset = SegmentUpdateTestDataFactory.create_dataset_mock() + + # Mock Redis cache as not indexing + mock_segment_service_dependencies["redis_client"].get.return_value = None + + args = SegmentUpdateArgs( + content="test_content", # Same content + keywords=["keyword1", "keyword2"], # Same keywords + ) + + SegmentService.update_segment(args, segment, document, dataset) + + # Verify keywords 
remain unchanged + assert segment.keywords == ["keyword1", "keyword2"] + + def test_update_segment_same_content_no_keywords_provided( + self, mock_segment_service_dependencies, mock_vector_service_dependencies + ): + """Test updating segment with same content and no keywords provided.""" + segment = SegmentUpdateTestDataFactory.create_segment_mock( + content="test_content", keywords=["existing_keyword"] + ) + document = SegmentUpdateTestDataFactory.create_document_mock() + dataset = SegmentUpdateTestDataFactory.create_dataset_mock() + + # Mock Redis cache as not indexing + mock_segment_service_dependencies["redis_client"].get.return_value = None + + args = SegmentUpdateArgs(content="test_content") # No keywords provided + + SegmentService.update_segment(args, segment, document, dataset) + + # Verify keywords remain unchanged + assert segment.keywords == ["existing_keyword"] + + # ==================== Content Update Tests (Different Content) ==================== + + def test_update_segment_different_content_paragraph_index( + self, mock_segment_service_dependencies, mock_model_manager_dependencies, mock_vector_service_dependencies + ): + """Test updating segment with different content for paragraph index.""" + segment = SegmentUpdateTestDataFactory.create_segment_mock(content="old_content", word_count=10, tokens=15) + document = SegmentUpdateTestDataFactory.create_document_mock(doc_form=IndexType.PARAGRAPH_INDEX) + dataset = SegmentUpdateTestDataFactory.create_dataset_mock() + + # Mock Redis cache as not indexing + mock_segment_service_dependencies["redis_client"].get.return_value = None + + # Mock embedding model + embedding_model = SegmentUpdateTestDataFactory.create_embedding_model_mock() + mock_model_manager_dependencies.return_value.get_model_instance.return_value = embedding_model + + args = SegmentUpdateArgs(content="new_content", keywords=["new_keyword"]) + + SegmentService.update_segment(args, segment, document, dataset) + + # Verify segment attributes were updated + self._assert_segment_attributes_updated( + segment, + { + "content": "new_content", + "index_node_hash": "new_hash_123", + "word_count": len("new_content"), + "tokens": 20, # From mock embedding model + "status": "completed", + "keywords": [], # keywords are not updated for content update + "enabled": True, + "disabled_at": None, + "disabled_by": None, + }, + ) + + # Verify database operations + self._assert_database_operations(mock_segment_service_dependencies["db_session"], [segment]) + + # Verify vector service was called + mock_vector_service_dependencies["vector_service"].update_segment_vector.assert_called_once_with( + ["new_keyword"], segment, dataset + ) + + def test_update_segment_different_content_qa_index( + self, mock_segment_service_dependencies, mock_model_manager_dependencies, mock_vector_service_dependencies + ): + """Test updating segment with different content for QA index.""" + segment = SegmentUpdateTestDataFactory.create_segment_mock( + content="old_content", answer="old_answer", word_count=10, tokens=15 + ) + document = SegmentUpdateTestDataFactory.create_document_mock(doc_form=IndexType.QA_INDEX) + dataset = SegmentUpdateTestDataFactory.create_dataset_mock() + + # Mock Redis cache as not indexing + mock_segment_service_dependencies["redis_client"].get.return_value = None + + # Mock embedding model + embedding_model = SegmentUpdateTestDataFactory.create_embedding_model_mock() + mock_model_manager_dependencies.return_value.get_model_instance.return_value = embedding_model + + args = 
SegmentUpdateArgs(content="new_content", answer="new_answer", keywords=["new_keyword"]) + + SegmentService.update_segment(args, segment, document, dataset) + + # Verify segment attributes were updated + expected_word_count = len("new_content") + len("new_answer") + self._assert_segment_attributes_updated( + segment, + { + "content": "new_content", + "answer": "new_answer", + "index_node_hash": "new_hash_123", + "word_count": expected_word_count, + "tokens": 20, # From mock embedding model + "status": "completed", + "keywords": [], # keywords are not updated for content update + "enabled": True, + "disabled_at": None, + "disabled_by": None, + }, + ) + + # Verify embedding model was called with combined content + embedding_model.get_text_embedding_num_tokens.assert_called_once_with(texts=["new_contentnew_answer"]) + + # Verify vector service was called + mock_vector_service_dependencies["vector_service"].update_segment_vector.assert_called_once_with( + ["new_keyword"], segment, dataset + ) + + # Verify database operations - segment should be added + self._assert_database_operations(mock_segment_service_dependencies["db_session"], [segment]) + + def test_update_segment_different_content_economy_indexing( + self, mock_segment_service_dependencies, mock_vector_service_dependencies + ): + """Test updating segment with different content for economy indexing technique.""" + segment = SegmentUpdateTestDataFactory.create_segment_mock(content="old_content", word_count=10, tokens=15) + document = SegmentUpdateTestDataFactory.create_document_mock(doc_form=IndexType.PARAGRAPH_INDEX) + dataset = SegmentUpdateTestDataFactory.create_dataset_mock(indexing_technique="economy") + + # Mock Redis cache as not indexing + mock_segment_service_dependencies["redis_client"].get.return_value = None + + args = SegmentUpdateArgs(content="new_content", keywords=["new_keyword"]) + + SegmentService.update_segment(args, segment, document, dataset) + + # Verify segment attributes were updated (no tokens calculation for economy) + self._assert_segment_attributes_updated( + segment, + { + "content": "new_content", + "index_node_hash": "new_hash_123", + "word_count": len("new_content"), + "tokens": 0, # No tokens calculation for economy + "status": "completed", + "keywords": [], # keywords are not updated for content update + "enabled": True, + "disabled_at": None, + "disabled_by": None, + }, + ) + + # Verify vector service was called + mock_vector_service_dependencies["vector_service"].update_segment_vector.assert_called_once_with( + ["new_keyword"], segment, dataset + ) + + # Verify database operations - segment should be added + self._assert_database_operations(mock_segment_service_dependencies["db_session"], [segment]) + + # ==================== Parent-Child Index Tests ==================== + + def test_update_segment_parent_child_index_regenerate_child_chunks( + self, + mock_segment_service_dependencies, + mock_model_manager_dependencies, + mock_vector_service_dependencies, + mock_processing_rule_dependencies, + ): + """Test updating segment with parent-child index and child chunk regeneration.""" + segment = SegmentUpdateTestDataFactory.create_segment_mock(content="new_content") + document = SegmentUpdateTestDataFactory.create_document_mock( + doc_form=IndexType.PARENT_CHILD_INDEX, dataset_process_rule_id="rule-123" + ) + dataset = SegmentUpdateTestDataFactory.create_dataset_mock() + + # Mock Redis cache as not indexing + mock_segment_service_dependencies["redis_client"].get.return_value = None + + # Mock embedding model 
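+        # (its token counter is stubbed to return 20, so no real provider call happens)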
+ embedding_model = SegmentUpdateTestDataFactory.create_embedding_model_mock() + mock_model_manager_dependencies.return_value.get_model_instance.return_value = embedding_model + + # Mock processing rule query + processing_rule = SegmentUpdateTestDataFactory.create_processing_rule_mock() + mock_segment_service_dependencies[ + "db_session" + ].query.return_value.filter.return_value.first.return_value = processing_rule + + args = SegmentUpdateArgs(content="new_content", regenerate_child_chunks=True, keywords=["new_keyword"]) + + SegmentService.update_segment(args, segment, document, dataset) + + # Verify child chunks were regenerated + mock_vector_service_dependencies["vector_service"].generate_child_chunks.assert_called_once_with( + segment, document, dataset, embedding_model, processing_rule, True + ) + + # Verify database operations - segment should be added + self._assert_database_operations(mock_segment_service_dependencies["db_session"], [segment]) + + def test_update_segment_parent_child_index_no_regenerate( + self, mock_segment_service_dependencies, mock_model_manager_dependencies, mock_vector_service_dependencies + ): + """Test updating segment with parent-child index without child chunk regeneration.""" + segment = SegmentUpdateTestDataFactory.create_segment_mock(content="new_content") + document = SegmentUpdateTestDataFactory.create_document_mock(doc_form=IndexType.PARENT_CHILD_INDEX) + dataset = SegmentUpdateTestDataFactory.create_dataset_mock() + + # Mock Redis cache as not indexing + mock_segment_service_dependencies["redis_client"].get.return_value = None + + # Mock embedding model + embedding_model = SegmentUpdateTestDataFactory.create_embedding_model_mock() + mock_model_manager_dependencies.return_value.get_model_instance.return_value = embedding_model + + args = SegmentUpdateArgs(content="new_content", regenerate_child_chunks=False) + + SegmentService.update_segment(args, segment, document, dataset) + + # Verify child chunks were not regenerated + mock_vector_service_dependencies["vector_service"].generate_child_chunks.assert_not_called() + + # Verify database operations - segment should be added + self._assert_database_operations(mock_segment_service_dependencies["db_session"], [segment]) + + def test_update_segment_parent_child_index_economy_technique_error( + self, mock_segment_service_dependencies, mock_model_manager_dependencies + ): + """Test error when trying to regenerate child chunks with economy indexing technique.""" + segment = SegmentUpdateTestDataFactory.create_segment_mock(content="new_content") + document = SegmentUpdateTestDataFactory.create_document_mock(doc_form=IndexType.PARENT_CHILD_INDEX) + dataset = SegmentUpdateTestDataFactory.create_dataset_mock(indexing_technique="economy") + + # Mock Redis cache as not indexing + mock_segment_service_dependencies["redis_client"].get.return_value = None + + args = SegmentUpdateArgs(content="new_content", regenerate_child_chunks=True) + + SegmentService.update_segment(args, segment, document, dataset) + + assert "error" in segment.status + assert "The knowledge base index technique is not high quality!" 
in str(segment.error) + + def test_update_segment_parent_child_index_no_processing_rule_error( + self, mock_segment_service_dependencies, mock_model_manager_dependencies + ): + """Test error when processing rule is not found for parent-child index.""" + segment = SegmentUpdateTestDataFactory.create_segment_mock(content="new_content") + document = SegmentUpdateTestDataFactory.create_document_mock( + doc_form=IndexType.PARENT_CHILD_INDEX, dataset_process_rule_id="rule-123" + ) + dataset = SegmentUpdateTestDataFactory.create_dataset_mock() + + # Mock Redis cache as not indexing + mock_segment_service_dependencies["redis_client"].get.return_value = None + + # Mock embedding model + embedding_model = SegmentUpdateTestDataFactory.create_embedding_model_mock() + mock_model_manager_dependencies.return_value.get_model_instance.return_value = embedding_model + + # Mock processing rule query returning None + mock_segment_service_dependencies["db_session"].query.return_value.filter.return_value.first.return_value = None + + args = SegmentUpdateArgs(content="new_content", regenerate_child_chunks=True) + + SegmentService.update_segment(args, segment, document, dataset) + + assert "error" in segment.status + assert "No processing rule found." in str(segment.error) + + # ==================== Document Word Count Update Tests ==================== + + def test_update_segment_word_count_increase( + self, mock_segment_service_dependencies, mock_model_manager_dependencies + ): + """Test that document word count is updated when segment word count increases.""" + segment = SegmentUpdateTestDataFactory.create_segment_mock(content="old_content", word_count=10) + document = SegmentUpdateTestDataFactory.create_document_mock(word_count=100) + dataset = SegmentUpdateTestDataFactory.create_dataset_mock() + + # Mock Redis cache as not indexing + mock_segment_service_dependencies["redis_client"].get.return_value = None + + args = SegmentUpdateArgs(content="new_content_much_longer_than_old") + + SegmentService.update_segment(args, segment, document, dataset) + + # Verify document word count was updated + expected_word_count_change = len("new_content_much_longer_than_old") - 10 + expected_document_word_count = 100 + expected_word_count_change + assert document.word_count == expected_document_word_count + + # Verify database operations - both segment and document should be added + self._assert_database_operations(mock_segment_service_dependencies["db_session"], [segment, document]) + + def test_update_segment_word_count_decrease( + self, mock_segment_service_dependencies, mock_model_manager_dependencies + ): + """Test that document word count is updated when segment word count decreases.""" + segment = SegmentUpdateTestDataFactory.create_segment_mock(content="very_long_old_content", word_count=25) + document = SegmentUpdateTestDataFactory.create_document_mock(word_count=100) + dataset = SegmentUpdateTestDataFactory.create_dataset_mock() + + # Mock Redis cache as not indexing + mock_segment_service_dependencies["redis_client"].get.return_value = None + + args = SegmentUpdateArgs(content="short") + + SegmentService.update_segment(args, segment, document, dataset) + + # Verify document word count was updated + expected_word_count_change = len("short") - 25 + expected_document_word_count = 100 + expected_word_count_change + assert document.word_count == expected_document_word_count + + # Verify database operations - both segment and document should be added + 
self._assert_database_operations(mock_segment_service_dependencies["db_session"], [segment, document]) + + def test_update_segment_word_count_no_change(self, mock_segment_service_dependencies): + """Test that document word count is not updated when segment word count doesn't change.""" + segment = SegmentUpdateTestDataFactory.create_segment_mock(content="same_length", word_count=11) + document = SegmentUpdateTestDataFactory.create_document_mock(word_count=100) + dataset = SegmentUpdateTestDataFactory.create_dataset_mock() + + # Mock Redis cache as not indexing + mock_segment_service_dependencies["redis_client"].get.return_value = None + + args = SegmentUpdateArgs(content="same_length") # Same length content + + SegmentService.update_segment(args, segment, document, dataset) + + # Verify document word count was not changed + assert document.word_count == 100 + + # Verify database operations - only segment should be added (no word count change for document) + self._assert_database_operations(mock_segment_service_dependencies["db_session"], [segment]) + + def test_update_segment_word_count_qa_model_with_answer( + self, mock_segment_service_dependencies, mock_model_manager_dependencies + ): + """Test word count update for QA model with answer field.""" + segment = SegmentUpdateTestDataFactory.create_segment_mock( + content="question", answer="old_answer", word_count=7 + ) + document = SegmentUpdateTestDataFactory.create_document_mock(doc_form=IndexType.QA_INDEX, word_count=100) + dataset = SegmentUpdateTestDataFactory.create_dataset_mock() + + # Mock Redis cache as not indexing + mock_segment_service_dependencies["redis_client"].get.return_value = None + + # Mock embedding model + embedding_model = SegmentUpdateTestDataFactory.create_embedding_model_mock() + mock_model_manager_dependencies.return_value.get_model_instance.return_value = embedding_model + + args = SegmentUpdateArgs(content="question", answer="new_longer_answer") + + SegmentService.update_segment(args, segment, document, dataset) + + # Verify segment word count includes answer + expected_segment_word_count = len("question") + len("new_longer_answer") + assert segment.word_count == expected_segment_word_count + + # Verify document word count was updated + expected_word_count_change = expected_segment_word_count - 7 + expected_document_word_count = 100 + expected_word_count_change + assert document.word_count == expected_document_word_count + + # Verify database operations - both segment and document should be added + self._assert_database_operations(mock_segment_service_dependencies["db_session"], [segment, document]) + + # ==================== Error Handling Tests ==================== + + def test_update_segment_vector_service_error( + self, mock_segment_service_dependencies, mock_model_manager_dependencies, mock_vector_service_dependencies + ): + """Test error handling when vector service fails.""" + segment = SegmentUpdateTestDataFactory.create_segment_mock(content="old_content") + document = SegmentUpdateTestDataFactory.create_document_mock() + dataset = SegmentUpdateTestDataFactory.create_dataset_mock() + + # Mock Redis cache as not indexing + mock_segment_service_dependencies["redis_client"].get.return_value = None + + # Mock embedding model + embedding_model = SegmentUpdateTestDataFactory.create_embedding_model_mock() + mock_model_manager_dependencies.return_value.get_model_instance.return_value = embedding_model + + # Mock vector service to raise error + 
mock_vector_service_dependencies["vector_service"].update_segment_vector.side_effect = Exception( + "Vector service error" + ) + + args = SegmentUpdateArgs(content="new_content", keywords=["keyword"]) + + SegmentService.update_segment(args, segment, document, dataset) + + # Verify segment was marked as error + self._assert_segment_attributes_updated( + segment, + { + "enabled": False, + "status": "error", + "error": "Vector service error", + "disabled_at": mock_segment_service_dependencies["current_time"].replace(tzinfo=None), + }, + ) + + # Verify database operations - segment should be added with error state + self._assert_database_operations(mock_segment_service_dependencies["db_session"], [segment]) + + # ==================== Edge Cases and Integration Tests ==================== + + def test_update_segment_with_none_content_uses_existing(self, mock_segment_service_dependencies): + """Test that None content uses existing segment content.""" + segment = SegmentUpdateTestDataFactory.create_segment_mock(content="existing_content") + document = SegmentUpdateTestDataFactory.create_document_mock() + dataset = SegmentUpdateTestDataFactory.create_dataset_mock() + + # Mock Redis cache as not indexing + mock_segment_service_dependencies["redis_client"].get.return_value = None + + args = SegmentUpdateArgs(content=None, keywords=["new_keyword"]) + + SegmentService.update_segment(args, segment, document, dataset) + + # Verify content remains unchanged + assert segment.content == "existing_content" + + # Verify keywords were updated + assert segment.keywords == ["new_keyword"] + + # Verify database operations - segment should be added + self._assert_database_operations(mock_segment_service_dependencies["db_session"], [segment]) + + def test_update_segment_with_none_answer_qa_model(self, mock_segment_service_dependencies): + """Test updating QA model segment with None answer.""" + segment = SegmentUpdateTestDataFactory.create_segment_mock( + content="question", answer="old_answer", word_count=15 + ) + document = SegmentUpdateTestDataFactory.create_document_mock(doc_form=IndexType.QA_INDEX) + dataset = SegmentUpdateTestDataFactory.create_dataset_mock() + + # Mock Redis cache as not indexing + mock_segment_service_dependencies["redis_client"].get.return_value = None + + args = SegmentUpdateArgs(content="question", answer=None) + + SegmentService.update_segment(args, segment, document, dataset) + + # Verify answer was updated to None + assert segment.answer is None + + # Verify word count only includes content (no answer) + assert segment.word_count == len("question") + + # Verify database operations - segment should be added + self._assert_database_operations(mock_segment_service_dependencies["db_session"], [segment]) + + def test_update_segment_final_query_returns_updated_segment(self, mock_segment_service_dependencies): + """Test that the final database query returns the updated segment.""" + segment = SegmentUpdateTestDataFactory.create_segment_mock(content="old_content") + document = SegmentUpdateTestDataFactory.create_document_mock() + dataset = SegmentUpdateTestDataFactory.create_dataset_mock() + + # Mock Redis cache as not indexing + mock_segment_service_dependencies["redis_client"].get.return_value = None + + # Mock final query to return updated segment + updated_segment = SegmentUpdateTestDataFactory.create_segment_mock( + segment_id="segment-123", content="new_content" + ) + mock_segment_service_dependencies[ + "db_session" + ].query.return_value.filter.return_value.first.return_value = 
updated_segment
+
+        args = SegmentUpdateArgs(content="new_content")
+
+        result = SegmentService.update_segment(args, segment, document, dataset)
+
+        # Verify result is the updated segment
+        assert result == updated_segment
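
For reference, the per-segment concurrency guard exercised throughout these tests is a plain Redis flag: update_segment reads "segment_{id}_indexing" and bails out if it is set, and _disable_segment arms it with setex(key, 600, 1). A minimal standalone sketch of that pattern, with illustrative names and assuming a redis-py client (not code from this patch):

    import redis

    r = redis.Redis()

    def acquire_indexing_flag(segment_id: str) -> bool:
        """Return False if another worker is already indexing this segment."""
        key = f"segment_{segment_id}_indexing"
        if r.get(key) is not None:
            # Flag already armed: mirror the service's "try again later" behavior.
            return False
        # Arm the flag for 600 seconds, matching the TTL used above, so a
        # crashed worker cannot leave a segment locked forever.
        r.setex(key, 600, 1)
        return True

Note that get followed by setex is not atomic; a stricter guard would use r.set(key, 1, nx=True, ex=600), but this flag only needs to suppress duplicate indexing work, not act as a hard mutex.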