diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 91b1efb3d7..49ca98624a 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -1623,85 +1623,177 @@ class DocumentService: Raises: DocumentIndexingError: If document is being indexed or not in correct state + ValueError: If action is invalid """ if not document_ids: return + # Early validation of action parameter + valid_actions = ["enable", "disable", "archive", "un_archive"] + if action not in valid_actions: + raise ValueError(f"Invalid action: {action}. Must be one of {valid_actions}") + + documents_to_update = [] + + # First pass: validate all documents and prepare updates for document_id in document_ids: document = DocumentService.get_document(dataset.id, document_id) - if not document: continue + # Check if document is being indexed indexing_cache_key = f"document_{document.id}_indexing" cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: raise DocumentIndexingError(f"Document:{document.name} is being indexed, please try again later") - if action == "enable": - if document.enabled: - continue - document.enabled = True - document.disabled_at = None - document.disabled_by = None - document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - db.session.commit() + # Prepare update based on action + update_info = DocumentService._prepare_document_status_update(document, action, user) + if update_info: + documents_to_update.append(update_info) - # Set cache to prevent indexing the same document multiple times - redis_client.setex(indexing_cache_key, 600, 1) + # Second pass: apply all updates in a single transaction + if documents_to_update: + try: + for update_info in documents_to_update: + document = update_info["document"] + updates = update_info["updates"] - add_document_to_index_task.delay(document_id) + # Apply updates to the document + for field, value in updates.items(): + setattr(document, 
field, value) - elif action == "disable": - if not document.completed_at or document.indexing_status != "completed": - raise DocumentIndexingError(f"Document: {document.name} is not completed.") - if not document.enabled: - continue + db.session.add(document) - document.enabled = False - document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - document.disabled_by = user.id - document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + # Batch commit all changes db.session.commit() + except Exception as e: + # Rollback on any error + db.session.rollback() + raise e + # Execute async tasks and set Redis cache after successful commit + # propagation_error is used to capture any errors for submitting async task execution + propagation_error = None + for update_info in documents_to_update: + try: + # Execute async tasks after successful commit + if update_info["async_task"]: + task_info = update_info["async_task"] + task_func = task_info["function"] + task_args = task_info["args"] + task_func.delay(*task_args) + except Exception as e: + # Log the error but do not rollback the transaction + logging.exception(f"Error executing async task for document {update_info['document'].id}") + # don't raise the error immediately, but capture it for later + propagation_error = e + try: + # Set Redis cache if needed after successful commit + if update_info["set_cache"]: + document = update_info["document"] + indexing_cache_key = f"document_{document.id}_indexing" + redis_client.setex(indexing_cache_key, 600, 1) + except Exception as e: + # Log the error but do not rollback the transaction + logging.exception(f"Error setting cache for document {update_info['document'].id}") + # Raise any propagation error after all updates + if propagation_error: + raise propagation_error - # Set cache to prevent indexing the same document multiple times - redis_client.setex(indexing_cache_key, 600, 1) + @staticmethod + def 
_prepare_document_status_update(document, action: str, user): + """ + Prepare document status update information. - remove_document_from_index_task.delay(document_id) + Args: + document: Document object to update + action: Action to perform + user: Current user - elif action == "archive": - if document.archived: - continue + Returns: + dict: Update information or None if no update needed + """ + now = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - document.archived = True - document.archived_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - document.archived_by = user.id - document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - db.session.commit() + if action == "enable": + return DocumentService._prepare_enable_update(document, now) + elif action == "disable": + return DocumentService._prepare_disable_update(document, user, now) + elif action == "archive": + return DocumentService._prepare_archive_update(document, user, now) + elif action == "un_archive": + return DocumentService._prepare_unarchive_update(document, now) - if document.enabled: - # Set cache to prevent indexing the same document multiple times - redis_client.setex(indexing_cache_key, 600, 1) + return None - remove_document_from_index_task.delay(document_id) + @staticmethod + def _prepare_enable_update(document, now): + """Prepare updates for enabling a document.""" + if document.enabled: + return None - elif action == "un_archive": - if not document.archived: - continue - document.archived = False - document.archived_at = None - document.archived_by = None - document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - db.session.commit() + return { + "document": document, + "updates": {"enabled": True, "disabled_at": None, "disabled_by": None, "updated_at": now}, + "async_task": {"function": add_document_to_index_task, "args": [document.id]}, + "set_cache": True, + } - # Only re-index if the document is currently enabled - if 
document.enabled: - # Set cache to prevent indexing the same document multiple times - redis_client.setex(indexing_cache_key, 600, 1) - add_document_to_index_task.delay(document_id) + @staticmethod + def _prepare_disable_update(document, user, now): + """Prepare updates for disabling a document.""" + if not document.completed_at or document.indexing_status != "completed": + raise DocumentIndexingError(f"Document: {document.name} is not completed.") - else: - raise ValueError(f"Invalid action: {action}") + if not document.enabled: + return None + + return { + "document": document, + "updates": {"enabled": False, "disabled_at": now, "disabled_by": user.id, "updated_at": now}, + "async_task": {"function": remove_document_from_index_task, "args": [document.id]}, + "set_cache": True, + } + + @staticmethod + def _prepare_archive_update(document, user, now): + """Prepare updates for archiving a document.""" + if document.archived: + return None + + update_info = { + "document": document, + "updates": {"archived": True, "archived_at": now, "archived_by": user.id, "updated_at": now}, + "async_task": None, + "set_cache": False, + } + + # Only set async task and cache if document is currently enabled + if document.enabled: + update_info["async_task"] = {"function": remove_document_from_index_task, "args": [document.id]} + update_info["set_cache"] = True + + return update_info + + @staticmethod + def _prepare_unarchive_update(document, now): + """Prepare updates for unarchiving a document.""" + if not document.archived: + return None + + update_info = { + "document": document, + "updates": {"archived": False, "archived_at": None, "archived_by": None, "updated_at": now}, + "async_task": None, + "set_cache": False, + } + + # Only re-index if the document is currently enabled + if document.enabled: + update_info["async_task"] = {"function": add_document_to_index_task, "args": [document.id]} + update_info["set_cache"] = True + + return update_info class SegmentService: diff --git 
a/api/tests/unit_tests/services/test_dataset_service.py b/api/tests/unit_tests/services/test_dataset_service.py index 294529580d..65bd616293 100644 --- a/api/tests/unit_tests/services/test_dataset_service.py +++ b/api/tests/unit_tests/services/test_dataset_service.py @@ -110,8 +110,10 @@ class TestDatasetServiceBatchUpdateDocumentStatus(unittest.TestCase): expected_task_calls = [call("doc-1"), call("doc-2")] mock_add_task.delay.assert_has_calls(expected_task_calls) - # Verify database commits (one per document) - assert mock_db.commit.call_count == 2 + # Verify database add count (one add per document) + assert mock_db.add.call_count == 2 + # Verify database commits (one commit for the batch operation) + assert mock_db.commit.call_count == 1 @patch("extensions.ext_database.db.session") @patch("services.dataset_service.remove_document_from_index_task") @@ -190,8 +192,10 @@ class TestDatasetServiceBatchUpdateDocumentStatus(unittest.TestCase): expected_task_calls = [call("doc-1"), call("doc-2")] mock_remove_task.delay.assert_has_calls(expected_task_calls) - # Verify database commits (one per document) - assert mock_db.commit.call_count == 2 + # Verify database add count (one add per document) + assert mock_db.add.call_count == 2 + # Verify database commits (one commit for the batch operation) + assert mock_db.commit.call_count == 1 @patch("extensions.ext_database.db.session") @patch("services.dataset_service.remove_document_from_index_task") @@ -254,6 +258,8 @@ class TestDatasetServiceBatchUpdateDocumentStatus(unittest.TestCase): # Verify async task was triggered to remove from index (because enabled) mock_remove_task.delay.assert_called_once_with("doc-1") + # Verify database add + mock_db.add.assert_called_once() # Verify database commit mock_db.commit.assert_called_once() @@ -318,6 +324,8 @@ class TestDatasetServiceBatchUpdateDocumentStatus(unittest.TestCase): # Verify async task was triggered to add back to index (because enabled) 
mock_add_task.delay.assert_called_once_with("doc-3") + # Verify database add + mock_db.add.assert_called_once() # Verify database commit mock_db.commit.assert_called_once() @@ -651,7 +659,9 @@ class TestDatasetServiceBatchUpdateDocumentStatus(unittest.TestCase): # Verify only the disabled document was processed # (enabled and archived documents should be skipped for enable action) - # Only one commit should occur (for the disabled document that was enabled) + # Only one add should occur (for the disabled document that was enabled) + mock_db.add.assert_called_once() + # Only one commit should occur mock_db.commit.assert_called_once() # Only one Redis setex should occur (for the document that was enabled) @@ -719,6 +729,8 @@ class TestDatasetServiceBatchUpdateDocumentStatus(unittest.TestCase): # Verify no index removal task was triggered (document is disabled) mock_remove_task.delay.assert_not_called() + # Verify database add still occurred + mock_db.add.assert_called_once() # Verify database commit still occurred mock_db.commit.assert_called_once() @@ -944,6 +956,8 @@ class TestDatasetServiceBatchUpdateDocumentStatus(unittest.TestCase): # Verify no index addition task was triggered (document is disabled) mock_add_task.delay.assert_not_called() + # Verify database add still occurred + mock_db.add.assert_called_once() # Verify database commit still occurred mock_db.commit.assert_called_once() @@ -1003,6 +1017,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus(unittest.TestCase): assert "Celery task error" in str(exc_info.value) # Verify database operations completed successfully + mock_db.add.assert_called_once() mock_db.commit.assert_called_once() # Verify Redis cache was set successfully @@ -1079,8 +1094,10 @@ class TestDatasetServiceBatchUpdateDocumentStatus(unittest.TestCase): assert mock_doc.disabled_by is None assert mock_doc.updated_at == current_time.replace(tzinfo=None) - # Verify database commits occurred for each document - assert 
mock_db.commit.call_count == 100 + # Verify database adds, one add per document + assert mock_db.add.call_count == 100 + # Verify database commits, one commit for the batch operation + assert mock_db.commit.call_count == 1 # Verify Redis cache operations occurred for each document assert redis_mock.setex.call_count == 100 @@ -1208,7 +1225,8 @@ class TestDatasetServiceBatchUpdateDocumentStatus(unittest.TestCase): assert doc5.enabled == True # No change # Verify database commits occurred for processed documents - # Only doc1 should be committed (doc2, doc3, doc4, doc5 were skipped, doc6 doesn't exist) + # Only doc1 should be added (doc2, doc3, doc4, doc5 were skipped, doc6 doesn't exist) + assert mock_db.add.call_count == 1 assert mock_db.commit.call_count == 1 # Verify Redis cache operations occurred for processed documents