From 1cd77a9a17e4bf8b3e06462bcbe4610d26afe3f1 Mon Sep 17 00:00:00 2001
From: neatguycoding <15627489+NeatGuyCoding@users.noreply.github.com>
Date: Mon, 30 Jun 2025 09:56:36 +0800
Subject: [PATCH] refactor: update_segment, extract common methods

---
 api/services/dataset_service.py | 481 ++++++++++++++++++++++----------
 1 file changed, 337 insertions(+), 144 deletions(-)

diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index e42b5ace75..548ac2fbd7 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -2127,174 +2127,367 @@ class SegmentService:
     @classmethod
     def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset):
+        """
+        Update a document segment with new content, keywords, and metadata.
+
+        This method handles both simple updates (content unchanged) and complex updates (content changed)
+        with proper indexing, vector updates, and error handling.
+
+        Args:
+            args: Update arguments containing new content, keywords, and settings
+            segment: The document segment to update
+            document: The parent document containing the segment
+            dataset: The dataset containing the document
+
+        Returns:
+            DocumentSegment: The updated segment object
+
+        Raises:
+            ValueError: If the segment is currently being indexed, or if the segment is disabled
+                and the update does not re-enable it
+        """
+        # Check if segment is currently being indexed to prevent concurrent operations
         indexing_cache_key = "segment_{}_indexing".format(segment.id)
         cache_result = redis_client.get(indexing_cache_key)
         if cache_result is not None:
             raise ValueError("Segment is indexing, please try again later")
+
+        # Handle segment enable/disable state changes
         if args.enabled is not None:
             action = args.enabled
             if segment.enabled != action:
                 if not action:
-                    segment.enabled = action
-                    segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
-                    segment.disabled_by = current_user.id
-                    db.session.add(segment)
-                    db.session.commit()
-                    # Set cache to prevent indexing the same segment multiple times
-                    redis_client.setex(indexing_cache_key, 600, 1)
-                    disable_segment_from_index_task.delay(segment.id)
+                    # Disable segment and trigger index removal
+                    cls._disable_segment(segment, indexing_cache_key)
                     return segment
+
+        # Validate that disabled segments cannot be updated unless being enabled
         if not segment.enabled:
             if args.enabled is not None:
                 if not args.enabled:
                     raise ValueError("Can't update disabled segment")
             else:
                 raise ValueError("Can't update disabled segment")
+
         try:
-            word_count_change = segment.word_count
+            # Track word count changes for document update
+            original_word_count = segment.word_count
             content = args.content or segment.content
+
+            # Handle simple update case: content unchanged, only metadata updates
             if segment.content == content:
-                segment.word_count = len(content)
-                if document.doc_form == "qa_model":
-                    segment.answer = args.answer
-                    segment.word_count += len(args.answer) if args.answer else 0
-                word_count_change = segment.word_count - word_count_change
-                keyword_changed = False
-                if args.keywords:
-                    if Counter(segment.keywords) != Counter(args.keywords):
-                        segment.keywords = args.keywords
-                        keyword_changed = True
-                segment.enabled = True
-                segment.disabled_at = None
-                segment.disabled_by = None
-                db.session.add(segment)
-                db.session.commit()
-                # update document word count
-                if word_count_change != 0:
-                    document.word_count = max(0, document.word_count + word_count_change)
-                    db.session.add(document)
-                # update segment index task
-                if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
-                    # regenerate child chunks
-                    # get embedding model instance
-                    if dataset.indexing_technique == "high_quality":
-                        # check embedding model setting
-                        model_manager = ModelManager()
-
-                        if dataset.embedding_model_provider:
-                            embedding_model_instance = model_manager.get_model_instance(
-                                tenant_id=dataset.tenant_id,
-                                provider=dataset.embedding_model_provider,
-                                model_type=ModelType.TEXT_EMBEDDING,
-                                model=dataset.embedding_model,
-                            )
-                        else:
-                            embedding_model_instance = model_manager.get_default_model_instance(
-                                tenant_id=dataset.tenant_id,
-                                model_type=ModelType.TEXT_EMBEDDING,
-                            )
-                    else:
-                        raise ValueError("The knowledge base index technique is not high quality!")
-                    # get the process rule
-                    processing_rule = (
-                        db.session.query(DatasetProcessRule)
-                        .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
-                        .first()
-                    )
-                    if not processing_rule:
-                        raise ValueError("No processing rule found.")
-                    VectorService.generate_child_chunks(
-                        segment, document, dataset, embedding_model_instance, processing_rule, True
-                    )
-                elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX):
-                    if args.enabled or keyword_changed:
-                        # update segment vector index
-                        VectorService.update_segment_vector(args.keywords, segment, dataset)
+                cls._handle_simple_segment_update(args, segment, document, dataset, original_word_count)
             else:
-                segment_hash = helper.generate_text_hash(content)
-                tokens = 0
-                if dataset.indexing_technique == "high_quality":
-                    model_manager = ModelManager()
-                    embedding_model = model_manager.get_model_instance(
-                        tenant_id=current_user.current_tenant_id,
-                        provider=dataset.embedding_model_provider,
-                        model_type=ModelType.TEXT_EMBEDDING,
-                        model=dataset.embedding_model,
-                    )
-
-                    # calc embedding use tokens
-                    if document.doc_form == "qa_model":
-                        segment.answer = args.answer
-                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])[0]
-                    else:
-                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
-                segment.content = content
-                segment.index_node_hash = segment_hash
-                segment.word_count = len(content)
-                segment.tokens = tokens
-                segment.status = "completed"
-                segment.indexing_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
-                segment.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
-                segment.updated_by = current_user.id
-                segment.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
-                segment.enabled = True
-                segment.disabled_at = None
-                segment.disabled_by = None
-                if document.doc_form == "qa_model":
-                    segment.answer = args.answer
-                    segment.word_count += len(args.answer) if args.answer else 0
-                word_count_change = segment.word_count - word_count_change
-                # update document word count
-                if word_count_change != 0:
-                    document.word_count = max(0, document.word_count + word_count_change)
-                    db.session.add(document)
-                db.session.add(segment)
-                db.session.commit()
-                if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
-                    # get embedding model instance
-                    if dataset.indexing_technique == "high_quality":
-                        # check embedding model setting
-                        model_manager = ModelManager()
-
-                        if dataset.embedding_model_provider:
-                            embedding_model_instance = model_manager.get_model_instance(
-                                tenant_id=dataset.tenant_id,
-                                provider=dataset.embedding_model_provider,
-                                model_type=ModelType.TEXT_EMBEDDING,
-                                model=dataset.embedding_model,
-                            )
-                        else:
-                            embedding_model_instance = model_manager.get_default_model_instance(
-                                tenant_id=dataset.tenant_id,
-                                model_type=ModelType.TEXT_EMBEDDING,
-                            )
-                    else:
-                        raise ValueError("The knowledge base index technique is not high quality!")
-                    # get the process rule
-                    processing_rule = (
-                        db.session.query(DatasetProcessRule)
-                        .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
-                        .first()
-                    )
-                    if not processing_rule:
-                        raise ValueError("No processing rule found.")
-                    VectorService.generate_child_chunks(
-                        segment, document, dataset, embedding_model_instance, processing_rule, True
-                    )
-                elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX):
-                    # update segment vector index
-                    VectorService.update_segment_vector(args.keywords, segment, dataset)
+                # Handle complex update case: content changed, requires re-indexing
+                cls._handle_complex_segment_update(args, segment, document, dataset, content, original_word_count)
         except Exception as e:
-            logging.exception("update segment index failed")
-            segment.enabled = False
-            segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
-            segment.status = "error"
-            segment.error = str(e)
-            db.session.commit()
+            # Handle update failures by marking segment as error state
+            cls._handle_segment_update_error(segment, e)
+
+        # Return fresh segment object from database
         new_segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first()
         return new_segment
 
+    @classmethod
+    def _disable_segment(cls, segment: DocumentSegment, indexing_cache_key: str):
+        """
+        Disable a segment and trigger index removal.
+
+        Args:
+            segment: The segment to disable
+            indexing_cache_key: Redis cache key for indexing lock
+        """
+        segment.enabled = False
+        segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+        segment.disabled_by = current_user.id
+        db.session.add(segment)
+        db.session.commit()
+
+        # Set cache to prevent indexing the same segment multiple times
+        redis_client.setex(indexing_cache_key, 600, 1)
+        disable_segment_from_index_task.delay(segment.id)
+
+    @classmethod
+    def _handle_simple_segment_update(
+        cls,
+        args: SegmentUpdateArgs,
+        segment: DocumentSegment,
+        document: Document,
+        dataset: Dataset,
+        original_word_count: int,
+    ):
+        """
+        Handle segment update when content remains unchanged.
+        Only updates metadata like keywords, answer, and word count.
+ + Args: + args: Update arguments + segment: The segment to update + document: Parent document + dataset: Parent dataset + original_word_count: Original word count before update + """ + # Update word count for main content + segment.word_count = len(segment.content) + + # Handle QA model specific updates + if document.doc_form == "qa_model": + segment.answer = args.answer + segment.word_count += len(args.answer) if args.answer else 0 + + # Calculate word count change for document update + word_count_change = segment.word_count - original_word_count + + # Update keywords if provided and changed + keyword_changed = False + if args.keywords: + if Counter(segment.keywords) != Counter(args.keywords): + segment.keywords = args.keywords + keyword_changed = True + + # Reset segment state to enabled + segment.enabled = True + segment.disabled_at = None + segment.disabled_by = None + + # Persist changes to database + db.session.add(segment) + db.session.commit() + + # Update document word count if changed + if word_count_change != 0: + document.word_count = max(0, document.word_count + word_count_change) + db.session.add(document) + + # Handle vector index updates based on document type + cls._handle_vector_index_updates(args, segment, document, dataset, keyword_changed, False) + + @classmethod + def _handle_complex_segment_update( + cls, + args: SegmentUpdateArgs, + segment: DocumentSegment, + document: Document, + dataset: Dataset, + content: str, + original_word_count: int, + ): + """ + Handle segment update when content has changed. + Requires re-indexing and vector updates. + + Args: + args: Update arguments + segment: The segment to update + document: Parent document + dataset: Parent dataset + content: New content for the segment + original_word_count: Original word count before update + """ + # Generate new content hash for change detection + segment_hash = helper.generate_text_hash(content) + + # Calculate tokens for high quality indexing + tokens = cls._calculate_segment_tokens(args, content, document, dataset) + + # Update segment with new content and metadata + cls._update_segment_metadata(document, segment, content, segment_hash, tokens, args) + + # Calculate word count change for document update + word_count_change = segment.word_count - original_word_count + + # Update document word count if changed + if word_count_change != 0: + document.word_count = max(0, document.word_count + word_count_change) + db.session.add(document) + + # Persist changes to database + db.session.add(segment) + db.session.commit() + + # Handle vector index updates based on document type + cls._handle_vector_index_updates(args, segment, document, dataset, True, True) + + @classmethod + def _calculate_segment_tokens( + cls, args: SegmentUpdateArgs, content: str, document: Document, dataset: Dataset + ) -> int: + """ + Calculate token count for segment content based on indexing technique. 
+
+        Args:
+            args: Update arguments
+            content: Segment content
+            document: Parent document
+            dataset: Parent dataset
+
+        Returns:
+            int: Token count for the content
+        """
+        tokens = 0
+        if dataset.indexing_technique == "high_quality":
+            model_manager = ModelManager()
+            embedding_model = model_manager.get_model_instance(
+                tenant_id=current_user.current_tenant_id,
+                provider=dataset.embedding_model_provider,
+                model_type=ModelType.TEXT_EMBEDDING,
+                model=dataset.embedding_model,
+            )
+
+            # Calculate embedding tokens for QA model or regular content
+            if document.doc_form == "qa_model":
+                answer = args.answer or ""
+                tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + answer])[0]
+            else:
+                tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
+
+        return tokens
+
+    @classmethod
+    def _update_segment_metadata(
+        cls,
+        document: Document,
+        segment: DocumentSegment,
+        content: str,
+        segment_hash: str,
+        tokens: int,
+        args: SegmentUpdateArgs,
+    ):
+        """
+        Update segment metadata with new content and timestamps.
+
+        Args:
+            document: Parent document (used to detect QA-model segments)
+            segment: The segment to update
+            content: New content
+            segment_hash: Hash of the new content, stored as the segment's index_node_hash
+            tokens: Token count for the content
+            args: Update arguments
+        """
+        segment.content = content
+        segment.index_node_hash = segment_hash
+        segment.word_count = len(content)
+        segment.tokens = tokens
+        segment.status = "completed"
+
+        # Update timestamps
+        current_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+        segment.indexing_at = current_time
+        segment.completed_at = current_time
+        segment.updated_by = current_user.id
+        segment.updated_at = current_time
+
+        # Reset segment state to enabled
+        segment.enabled = True
+        segment.disabled_at = None
+        segment.disabled_by = None
+
+        # Handle QA model specific metadata
+        if document.doc_form == "qa_model":
+            segment.answer = args.answer
+            segment.word_count += len(args.answer) if args.answer else 0
+
+    @classmethod
+    def _handle_vector_index_updates(
+        cls,
+        args: SegmentUpdateArgs,
+        segment: DocumentSegment,
+        document: Document,
+        dataset: Dataset,
+        keyword_changed: bool,
+        content_changed: bool,
+    ):
+        """
+        Handle vector index updates based on document type and update conditions.
+
+        Args:
+            args: Update arguments
+            segment: The segment to update
+            document: Parent document
+            dataset: Parent dataset
+            keyword_changed: Whether keywords were changed in this update
+            content_changed: Whether content was changed in this update
+        """
+        if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
+            # Regenerate child chunks for parent-child indexing
+            cls._regenerate_child_chunks(segment, document, dataset)
+        elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX):
+            # Update segment vector for paragraph/QA indexing
+            if content_changed or args.enabled or keyword_changed:
+                VectorService.update_segment_vector(args.keywords, segment, dataset)
+
+    @classmethod
+    def _regenerate_child_chunks(cls, segment: DocumentSegment, document: Document, dataset: Dataset):
+        """
+        Regenerate child chunks for parent-child indexing.
+
+        Args:
+            segment: The segment to regenerate chunks for
+            document: Parent document
+            dataset: Parent dataset
+
+        Raises:
+            ValueError: If indexing technique is not high quality or processing rule not found
+        """
+        if dataset.indexing_technique != "high_quality":
+            raise ValueError("The knowledge base index technique is not high quality!")
+
+        # Get embedding model instance
+        embedding_model_instance = cls._get_embedding_model_instance(dataset)
+
+        # Get processing rule for chunk generation
+        processing_rule = (
+            db.session.query(DatasetProcessRule)
+            .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
+            .first()
+        )
+        if not processing_rule:
+            raise ValueError("No processing rule found.")
+
+        # Generate new child chunks
+        VectorService.generate_child_chunks(
+            segment, document, dataset, embedding_model_instance, processing_rule, True
+        )
+
+    @classmethod
+    def _get_embedding_model_instance(cls, dataset: Dataset):
+        """
+        Get embedding model instance for the dataset.
+
+        Args:
+            dataset: The dataset to get model for
+
+        Returns:
+            Model instance for text embedding
+        """
+        model_manager = ModelManager()
+
+        if dataset.embedding_model_provider:
+            return model_manager.get_model_instance(
+                tenant_id=dataset.tenant_id,
+                provider=dataset.embedding_model_provider,
+                model_type=ModelType.TEXT_EMBEDDING,
+                model=dataset.embedding_model,
+            )
+        else:
+            return model_manager.get_default_model_instance(
+                tenant_id=dataset.tenant_id,
+                model_type=ModelType.TEXT_EMBEDDING,
+            )
+
+    @classmethod
+    def _handle_segment_update_error(cls, segment: DocumentSegment, error: Exception):
+        """
+        Handle errors during segment update by marking segment as error state.
+
+        Args:
+            segment: The segment that encountered an error
+            error: The exception that occurred
+        """
+        logging.exception("update segment index failed")
+        segment.enabled = False
+        segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+        segment.status = "error"
+        segment.error = str(error)
+        db.session.commit()
+
     @classmethod
     def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: Dataset):
        indexing_cache_key = "segment_{}_delete_indexing".format(segment.id)
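-- 
Reviewer note (not part of the patch): the invariant both extracted update paths
share is the word-count delta that update_segment propagates to the parent
document. The sketch below is a minimal, self-contained illustration of that
bookkeeping; Seg, Doc, and apply_word_count are hypothetical stand-ins written
for this note, not the real SQLAlchemy models or patch code, and they merely
mirror what _handle_simple_segment_update / _handle_complex_segment_update do.

    from dataclasses import dataclass

    @dataclass
    class Seg:  # simplified stand-in for DocumentSegment
        content: str
        word_count: int = 0
        answer: str | None = None

    @dataclass
    class Doc:  # simplified stand-in for Document
        word_count: int = 0

    def apply_word_count(doc: Doc, seg: Seg, original_word_count: int, qa_model: bool) -> None:
        # Recompute the segment's word count from its (possibly new) content;
        # for "qa_model" documents the answer text counts as well.
        seg.word_count = len(seg.content)
        if qa_model and seg.answer:
            seg.word_count += len(seg.answer)
        # Propagate only the delta to the document, clamped at zero, matching
        # document.word_count = max(0, document.word_count + word_count_change).
        delta = seg.word_count - original_word_count
        if delta != 0:
            doc.word_count = max(0, doc.word_count + delta)

    doc, seg = Doc(word_count=100), Seg(content="hello world", word_count=20)
    apply_word_count(doc, seg, original_word_count=seg.word_count, qa_model=False)
    assert doc.word_count == 91  # 100 + (11 - 20)

Keeping this arithmetic in the two handlers (rather than in a third shared
helper) preserves the original commit boundaries: the simple path commits the
segment before touching the document, while the complex path updates the
document before its single commit, exactly as the pre-refactor code did.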