feat: refactor: batch_update_document_status with unit tests #21324

pull/21325/head
neatguycoding 11 months ago
parent 034ee3747d
commit d239c39af9

@ -1623,85 +1623,177 @@ class DocumentService:
Raises: Raises:
DocumentIndexingError: If document is being indexed or not in correct state DocumentIndexingError: If document is being indexed or not in correct state
ValueError: If action is invalid
""" """
if not document_ids: if not document_ids:
return return
# Early validation of action parameter
valid_actions = ["enable", "disable", "archive", "un_archive"]
if action not in valid_actions:
raise ValueError(f"Invalid action: {action}. Must be one of {valid_actions}")
documents_to_update = []
# First pass: validate all documents and prepare updates
for document_id in document_ids: for document_id in document_ids:
document = DocumentService.get_document(dataset.id, document_id) document = DocumentService.get_document(dataset.id, document_id)
if not document: if not document:
continue continue
# Check if document is being indexed
indexing_cache_key = f"document_{document.id}_indexing" indexing_cache_key = f"document_{document.id}_indexing"
cache_result = redis_client.get(indexing_cache_key) cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None: if cache_result is not None:
raise DocumentIndexingError(f"Document:{document.name} is being indexed, please try again later") raise DocumentIndexingError(f"Document:{document.name} is being indexed, please try again later")
if action == "enable": # Prepare update based on action
if document.enabled: update_info = DocumentService._prepare_document_status_update(document, action, user)
continue if update_info:
document.enabled = True documents_to_update.append(update_info)
document.disabled_at = None
document.disabled_by = None
document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit()
# Set cache to prevent indexing the same document multiple times # Second pass: apply all updates in a single transaction
redis_client.setex(indexing_cache_key, 600, 1) if documents_to_update:
try:
for update_info in documents_to_update:
document = update_info["document"]
updates = update_info["updates"]
add_document_to_index_task.delay(document_id) # Apply updates to the document
for field, value in updates.items():
setattr(document, field, value)
elif action == "disable": db.session.add(document)
if not document.completed_at or document.indexing_status != "completed":
raise DocumentIndexingError(f"Document: {document.name} is not completed.")
if not document.enabled:
continue
document.enabled = False # Batch commit all changes
document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
document.disabled_by = user.id
document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
except Exception as e:
# Rollback on any error
db.session.rollback()
raise e
# Execute async tasks and set Redis cache after successful commit
# propagation_error captures an error raised while submitting an async task so it can be re-raised after all updates
propagation_error = None
for update_info in documents_to_update:
try:
# Execute async tasks after successful commit
if update_info["async_task"]:
task_info = update_info["async_task"]
task_func = task_info["function"]
task_args = task_info["args"]
task_func.delay(*task_args)
except Exception as e:
# Log the error but do not rollback the transaction
logging.exception(f"Error executing async task for document {update_info['document'].id}")
# don't raise the error immediately, but capture it for later
propagation_error = e
try:
# Set Redis cache if needed after successful commit
if update_info["set_cache"]:
document = update_info["document"]
indexing_cache_key = f"document_{document.id}_indexing"
redis_client.setex(indexing_cache_key, 600, 1)
except Exception as e:
# Log the error but do not rollback the transaction
logging.exception(f"Error setting cache for document {update_info['document'].id}")
# Raise any propagation error after all updates
if propagation_error:
raise propagation_error
# Set cache to prevent indexing the same document multiple times @staticmethod
redis_client.setex(indexing_cache_key, 600, 1) def _prepare_document_status_update(document, action: str, user):
"""
Prepare document status update information.
remove_document_from_index_task.delay(document_id) Args:
document: Document object to update
action: Action to perform
user: Current user
elif action == "archive": Returns:
if document.archived: dict: Update information or None if no update needed
continue """
now = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
document.archived = True if action == "enable":
document.archived_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) return DocumentService._prepare_enable_update(document, now)
document.archived_by = user.id elif action == "disable":
document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) return DocumentService._prepare_disable_update(document, user, now)
db.session.commit() elif action == "archive":
return DocumentService._prepare_archive_update(document, user, now)
elif action == "un_archive":
return DocumentService._prepare_unarchive_update(document, now)
if document.enabled: return None
# Set cache to prevent indexing the same document multiple times
redis_client.setex(indexing_cache_key, 600, 1)
remove_document_from_index_task.delay(document_id) @staticmethod
def _prepare_enable_update(document, now):
"""Prepare updates for enabling a document."""
if document.enabled:
return None
elif action == "un_archive": return {
if not document.archived: "document": document,
continue "updates": {"enabled": True, "disabled_at": None, "disabled_by": None, "updated_at": now},
document.archived = False "async_task": {"function": add_document_to_index_task, "args": [document.id]},
document.archived_at = None "set_cache": True,
document.archived_by = None }
document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit()
# Only re-index if the document is currently enabled @staticmethod
if document.enabled: def _prepare_disable_update(document, user, now):
# Set cache to prevent indexing the same document multiple times """Prepare updates for disabling a document."""
redis_client.setex(indexing_cache_key, 600, 1) if not document.completed_at or document.indexing_status != "completed":
add_document_to_index_task.delay(document_id) raise DocumentIndexingError(f"Document: {document.name} is not completed.")
else: if not document.enabled:
raise ValueError(f"Invalid action: {action}") return None
return {
"document": document,
"updates": {"enabled": False, "disabled_at": now, "disabled_by": user.id, "updated_at": now},
"async_task": {"function": remove_document_from_index_task, "args": [document.id]},
"set_cache": True,
}
@staticmethod
def _prepare_archive_update(document, user, now):
    """Prepare updates for archiving a document.

    Args:
        document: Document to archive; its ``archived``, ``enabled`` and
            ``id`` attributes are read.
        user: Acting user; ``user.id`` is recorded as ``archived_by``.
        now: Timestamp stored in both ``archived_at`` and ``updated_at``.

    Returns:
        ``None`` when the document is already archived (nothing to do),
        otherwise a dict with the keys ``document``, ``updates``,
        ``async_task`` and ``set_cache``.
    """
    if document.archived:
        return None

    update_info = {
        "document": document,
        "updates": {"archived": True, "archived_at": now, "archived_by": user.id, "updated_at": now},
        "async_task": None,
        "set_cache": False,
    }

    # Only set async task and cache if document is currently enabled
    if document.enabled:
        update_info["async_task"] = {"function": remove_document_from_index_task, "args": [document.id]}
        update_info["set_cache"] = True

    return update_info
@staticmethod
def _prepare_unarchive_update(document, now):
    """Prepare updates for unarchiving a document.

    Args:
        document: Document to unarchive; its ``archived``, ``enabled`` and
            ``id`` attributes are read.
        now: Timestamp stored in ``updated_at``.

    Returns:
        ``None`` when the document is not archived (nothing to do),
        otherwise a dict with the keys ``document``, ``updates``,
        ``async_task`` and ``set_cache``.
    """
    if not document.archived:
        return None

    update_info = {
        "document": document,
        "updates": {"archived": False, "archived_at": None, "archived_by": None, "updated_at": now},
        "async_task": None,
        "set_cache": False,
    }

    # Only re-index if the document is currently enabled
    if document.enabled:
        update_info["async_task"] = {"function": add_document_to_index_task, "args": [document.id]}
        update_info["set_cache"] = True

    return update_info
class SegmentService: class SegmentService:

@ -110,8 +110,10 @@ class TestDatasetServiceBatchUpdateDocumentStatus(unittest.TestCase):
expected_task_calls = [call("doc-1"), call("doc-2")] expected_task_calls = [call("doc-1"), call("doc-2")]
mock_add_task.delay.assert_has_calls(expected_task_calls) mock_add_task.delay.assert_has_calls(expected_task_calls)
# Verify database commits (one per document) # Verify database add calls (one add per document)
assert mock_db.commit.call_count == 2 assert mock_db.add.call_count == 2
# Verify database commits (one commit for the batch operation)
assert mock_db.commit.call_count == 1
@patch("extensions.ext_database.db.session") @patch("extensions.ext_database.db.session")
@patch("services.dataset_service.remove_document_from_index_task") @patch("services.dataset_service.remove_document_from_index_task")
@ -190,8 +192,10 @@ class TestDatasetServiceBatchUpdateDocumentStatus(unittest.TestCase):
expected_task_calls = [call("doc-1"), call("doc-2")] expected_task_calls = [call("doc-1"), call("doc-2")]
mock_remove_task.delay.assert_has_calls(expected_task_calls) mock_remove_task.delay.assert_has_calls(expected_task_calls)
# Verify database commits (one per document) # Verify database add calls (one add per document)
assert mock_db.commit.call_count == 2 assert mock_db.add.call_count == 2
# Verify database commits (one commit for the batch operation)
assert mock_db.commit.call_count == 1
@patch("extensions.ext_database.db.session") @patch("extensions.ext_database.db.session")
@patch("services.dataset_service.remove_document_from_index_task") @patch("services.dataset_service.remove_document_from_index_task")
@ -254,6 +258,8 @@ class TestDatasetServiceBatchUpdateDocumentStatus(unittest.TestCase):
# Verify async task was triggered to remove from index (because enabled) # Verify async task was triggered to remove from index (because enabled)
mock_remove_task.delay.assert_called_once_with("doc-1") mock_remove_task.delay.assert_called_once_with("doc-1")
# Verify database add
mock_db.add.assert_called_once()
# Verify database commit # Verify database commit
mock_db.commit.assert_called_once() mock_db.commit.assert_called_once()
@ -318,6 +324,8 @@ class TestDatasetServiceBatchUpdateDocumentStatus(unittest.TestCase):
# Verify async task was triggered to add back to index (because enabled) # Verify async task was triggered to add back to index (because enabled)
mock_add_task.delay.assert_called_once_with("doc-3") mock_add_task.delay.assert_called_once_with("doc-3")
# Verify database add
mock_db.add.assert_called_once()
# Verify database commit # Verify database commit
mock_db.commit.assert_called_once() mock_db.commit.assert_called_once()
@ -651,7 +659,9 @@ class TestDatasetServiceBatchUpdateDocumentStatus(unittest.TestCase):
# Verify only the disabled document was processed # Verify only the disabled document was processed
# (enabled and archived documents should be skipped for enable action) # (enabled and archived documents should be skipped for enable action)
# Only one commit should occur (for the disabled document that was enabled) # Only one add should occur (for the disabled document that was enabled)
mock_db.add.assert_called_once()
# Only one commit should occur
mock_db.commit.assert_called_once() mock_db.commit.assert_called_once()
# Only one Redis setex should occur (for the document that was enabled) # Only one Redis setex should occur (for the document that was enabled)
@ -719,6 +729,8 @@ class TestDatasetServiceBatchUpdateDocumentStatus(unittest.TestCase):
# Verify no index removal task was triggered (document is disabled) # Verify no index removal task was triggered (document is disabled)
mock_remove_task.delay.assert_not_called() mock_remove_task.delay.assert_not_called()
# Verify database add still occurred
mock_db.add.assert_called_once()
# Verify database commit still occurred # Verify database commit still occurred
mock_db.commit.assert_called_once() mock_db.commit.assert_called_once()
@ -944,6 +956,8 @@ class TestDatasetServiceBatchUpdateDocumentStatus(unittest.TestCase):
# Verify no index addition task was triggered (document is disabled) # Verify no index addition task was triggered (document is disabled)
mock_add_task.delay.assert_not_called() mock_add_task.delay.assert_not_called()
# Verify database add still occurred
mock_db.add.assert_called_once()
# Verify database commit still occurred # Verify database commit still occurred
mock_db.commit.assert_called_once() mock_db.commit.assert_called_once()
@ -1003,6 +1017,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus(unittest.TestCase):
assert "Celery task error" in str(exc_info.value) assert "Celery task error" in str(exc_info.value)
# Verify database operations completed successfully # Verify database operations completed successfully
mock_db.add.assert_called_once()
mock_db.commit.assert_called_once() mock_db.commit.assert_called_once()
# Verify Redis cache was set successfully # Verify Redis cache was set successfully
@ -1079,8 +1094,10 @@ class TestDatasetServiceBatchUpdateDocumentStatus(unittest.TestCase):
assert mock_doc.disabled_by is None assert mock_doc.disabled_by is None
assert mock_doc.updated_at == current_time.replace(tzinfo=None) assert mock_doc.updated_at == current_time.replace(tzinfo=None)
# Verify database commits occurred for each document # Verify database adds, one add per document
assert mock_db.commit.call_count == 100 assert mock_db.add.call_count == 100
# Verify database commits, one commit for the batch operation
assert mock_db.commit.call_count == 1
# Verify Redis cache operations occurred for each document # Verify Redis cache operations occurred for each document
assert redis_mock.setex.call_count == 100 assert redis_mock.setex.call_count == 100
@ -1208,7 +1225,8 @@ class TestDatasetServiceBatchUpdateDocumentStatus(unittest.TestCase):
assert doc5.enabled == True # No change assert doc5.enabled == True # No change
# Verify database commits occurred for processed documents # Verify database commits occurred for processed documents
# Only doc1 should be committed (doc2, doc3, doc4, doc5 were skipped, doc6 doesn't exist) # Only doc1 should be added (doc2, doc3, doc4, doc5 were skipped, doc6 doesn't exist)
assert mock_db.add.call_count == 1
assert mock_db.commit.call_count == 1 assert mock_db.commit.call_count == 1
# Verify Redis cache operations occurred for processed documents # Verify Redis cache operations occurred for processed documents

Loading…
Cancel
Save