From 689cdd30f94fdcfb329a29f92ce72d7d1051cd62 Mon Sep 17 00:00:00 2001 From: neatguycoding <15627489+NeatGuyCoding@users.noreply.github.com> Date: Wed, 25 Jun 2025 19:37:57 +0800 Subject: [PATCH] feat: refactor: refine unit tests for save_document_with_dataset_id --- ...t_service_save_document_with_dataset_id.py | 1016 +++++++++-------- 1 file changed, 510 insertions(+), 506 deletions(-) diff --git a/api/tests/unit_tests/services/test_document_service_save_document_with_dataset_id.py b/api/tests/unit_tests/services/test_document_service_save_document_with_dataset_id.py index 77112dfaa7..8a71add0f2 100644 --- a/api/tests/unit_tests/services/test_document_service_save_document_with_dataset_id.py +++ b/api/tests/unit_tests/services/test_document_service_save_document_with_dataset_id.py @@ -1,4 +1,4 @@ -import unittest +from typing import Optional from unittest.mock import Mock, patch import pytest @@ -9,31 +9,21 @@ from services.dataset_service import DocumentService from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig -class TestDocumentServiceSaveDocumentWithDatasetId(unittest.TestCase): - """ - Comprehensive unit tests for DocumentService.save_document_with_dataset_id. - - This test suite covers all major code branches including: - - Billing and quota validation - - Different data source types (upload_file, notion_import, website_crawl) - - Duplicate document handling - - Process rule validation and error cases - - Exception handling and edge cases - - Database session operations (add and flush) - """ +class DocumentSaveTestDataFactory: + """Factory class for creating test data and mock objects for document save tests.""" - def setUp(self): - """Set up common test fixtures and mock objects.""" - self.dataset_id = "ds1" - self.tenant_id = "tenant1" - self.user_id = "user1" - self.batch_id = "batch1" - - def _create_mock_dataset(self, data_source_type=None, indexing_technique=None): - """Create a mock Dataset object with common attributes.""" + @staticmethod + def create_dataset_mock( + dataset_id: str = "ds1", + tenant_id: str = "tenant1", + data_source_type: Optional[str] = None, + indexing_technique: Optional[str] = None, + **kwargs, + ) -> Mock: + """Create a mock Dataset object with specified attributes.""" dataset = Mock(spec=Dataset) - dataset.id = self.dataset_id - dataset.tenant_id = self.tenant_id + dataset.id = dataset_id + dataset.tenant_id = tenant_id dataset.data_source_type = data_source_type dataset.indexing_technique = indexing_technique dataset.retrieval_model = None @@ -41,16 +31,22 @@ class TestDocumentServiceSaveDocumentWithDatasetId(unittest.TestCase): dataset.embedding_model_provider = None dataset.collection_binding_id = None dataset.latest_process_rule = None + for key, value in kwargs.items(): + setattr(dataset, key, value) return dataset - def _create_mock_account(self): + @staticmethod + def create_account_mock(user_id: str = "user1", name: str = "Test User") -> Mock: """Create a mock Account object.""" account = Mock(spec=Account) - account.id = self.user_id - account.name = "Test User" + account.id = user_id + account.name = name return account - def _create_mock_features(self, billing_enabled=True, plan="pro", quota_limit=100, quota_size=0): + @staticmethod + def create_features_mock( + billing_enabled: bool = True, plan: str = "pro", quota_limit: int = 100, quota_size: int = 0 + ) -> Mock: """Create a mock features object for billing tests.""" features = Mock() features.billing.enabled = billing_enabled @@ -60,16 +56,17 @@ class TestDocumentServiceSaveDocumentWithDatasetId(unittest.TestCase): features.documents_upload_quota.size = quota_size return features - def _create_mock_knowledge_config( - self, - data_source_type, - original_document_id=None, - file_ids=None, - notion_pages=None, - website_urls=None, - indexing_technique="high_quality", - duplicate=False, - ): + @staticmethod + def create_knowledge_config_mock( + data_source_type: str, + original_document_id: Optional[str] = None, + file_ids: Optional[list[str]] = None, + notion_pages: Optional[list[Mock]] = None, + website_urls: Optional[list[str]] = None, + indexing_technique: str = "high_quality", + duplicate: bool = False, + **kwargs, + ) -> Mock: """Create a mock KnowledgeConfig object with specified data source configuration.""" knowledge_config = Mock(spec=KnowledgeConfig) knowledge_config.original_document_id = original_document_id @@ -104,55 +101,151 @@ class TestDocumentServiceSaveDocumentWithDatasetId(unittest.TestCase): website_info.only_main_content = True knowledge_config.data_source.info_list.website_info_list = website_info + for key, value in kwargs.items(): + setattr(knowledge_config, key, value) return knowledge_config - def _setup_common_mocks(self, mock_current_user, mock_features, mock_redis=None, mock_db=None): - """Set up common mock objects used across multiple tests.""" - mock_current_user.current_tenant_id = self.tenant_id - - if mock_redis: - mock_lock = Mock() - mock_redis.lock.return_value.__enter__ = Mock(return_value=None) - mock_redis.lock.return_value.__exit__ = Mock(return_value=None) - - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.db.session") - @patch("services.dataset_service.redis_client") - @patch("services.dataset_service.time") - @patch("services.dataset_service.secrets.randbelow", return_value=123456) - @patch("services.dataset_service.DocumentService.build_document") - @patch("services.dataset_service.document_indexing_task.delay") - @patch("services.dataset_service.duplicate_document_indexing_task.delay") - @patch("services.dataset_service.current_user") - @patch("services.dataset_service.ModelManager") - @patch("services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding") - @patch("services.dataset_service.DocumentService.get_documents_position", return_value=0) + @staticmethod + def create_upload_file_mock(file_id: str, name: str) -> Mock: + """Create a mock upload file.""" + upload_file = Mock() + upload_file.id = file_id + upload_file.name = name + return upload_file + + @staticmethod + def create_document_mock(document_id: str) -> Mock: + """Create a mock Document.""" + document = Mock(spec=Document, id=document_id) + return document + + @staticmethod + def create_notion_page_mock(page_id: str, page_name: str, page_type: str = "page") -> Mock: + """Create a mock Notion page.""" + page = Mock() + page.page_id = page_id + page.page_name = page_name + page.page_icon = None + page.type = page_type + return page + + @staticmethod + def create_notion_info_mock(workspace_id: str, pages: list[Mock]) -> Mock: + """Create a mock Notion info.""" + notion_info = Mock() + notion_info.workspace_id = workspace_id + notion_info.pages = pages + return notion_info + + +class TestDocumentServiceSaveDocumentWithDatasetId: + """ + Comprehensive unit tests for DocumentService.save_document_with_dataset_id. + + This test suite covers all major code branches including: + - Billing and quota validation + - Different data source types (upload_file, notion_import, website_crawl) + - Duplicate document handling + - Process rule validation and error cases + - Exception handling and edge cases + - Database session operations (add and flush) + """ + + @pytest.fixture(autouse=True) + def setup_method(self): + """Set up common test fixtures and mock objects.""" + self.dataset_id = "ds1" + self.tenant_id = "tenant1" + self.user_id = "user1" + self.batch_id = "batch1" + + @pytest.fixture + def mock_document_service_dependencies(self): + """Common mock setup for document service dependencies.""" + with ( + patch("services.dataset_service.FeatureService.get_features") as mock_features, + patch("services.dataset_service.db.session") as mock_db, + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.time") as mock_time, + patch("services.dataset_service.secrets.randbelow", return_value=123456) as mock_rand, + patch("services.dataset_service.current_user") as mock_current_user, + ): + mock_current_user.current_tenant_id = self.tenant_id + mock_time.strftime.return_value = "20231201120000" + + yield { + "features": mock_features, + "db_session": mock_db, + "redis_client": mock_redis, + "time": mock_time, + "randbelow": mock_rand, + "current_user": mock_current_user, + } + + @pytest.fixture + def mock_async_task_dependencies(self): + """Mock setup for async task dependencies.""" + with ( + patch("services.dataset_service.DocumentService.build_document") as mock_build_doc, + patch("services.dataset_service.document_indexing_task.delay") as mock_doc_task, + patch("services.dataset_service.duplicate_document_indexing_task.delay") as mock_dup_task, + patch("services.dataset_service.clean_notion_document_task.delay") as mock_clean_task, + ): + yield { + "build_document": mock_build_doc, + "document_indexing_task": mock_doc_task, + "duplicate_document_indexing_task": mock_dup_task, + "clean_notion_document_task": mock_clean_task, + } + + @pytest.fixture + def mock_model_dependencies(self): + """Mock setup for model dependencies.""" + with ( + patch("services.dataset_service.ModelManager") as mock_model_manager, + patch( + "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" + ) as mock_collection_binding, + patch( + "services.dataset_service.DocumentService.get_documents_position", return_value=0 + ) as mock_get_position, + ): + yield { + "model_manager": mock_model_manager, + "collection_binding": mock_collection_binding, + "get_position": mock_get_position, + } + + def _setup_redis_lock(self, mock_redis): + """Helper method to set up Redis lock.""" + mock_lock = Mock() + mock_redis.lock.return_value.__enter__ = Mock(return_value=None) + mock_redis.lock.return_value.__exit__ = Mock(return_value=None) + + def _assert_document_created(self, mock_db, expected_documents: list[Mock]): + """Helper method to verify documents were created and added to session.""" + for doc in expected_documents: + mock_db.add.assert_any_call(doc) + + def _assert_async_task_called(self, mock_task, expected_calls: int = 1): + """Helper method to verify async task was called.""" + assert mock_task.call_count == expected_calls + + # ==================== Upload File Tests ==================== + def test_upload_file_success( - self, - mock_get_position, - mock_collection_binding, - mock_model_manager, - mock_current_user, - mock_dup_task, - mock_doc_task, - mock_build_doc, - mock_rand, - mock_time, - mock_redis, - mock_db, - mock_features, + self, mock_document_service_dependencies, mock_async_task_dependencies, mock_model_dependencies ): """Test successful upload_file document creation with multiple files.""" # Arrange - dataset = self._create_mock_dataset() - account = self._create_mock_account() - features = self._create_mock_features() - knowledge_config = self._create_mock_knowledge_config( + dataset = DocumentSaveTestDataFactory.create_dataset_mock() + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock() + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( data_source_type="upload_file", file_ids=["file1", "file2"] ) - self._setup_common_mocks(mock_current_user, mock_features, mock_redis, mock_db) - mock_features.return_value = features + mock_document_service_dependencies["features"].return_value = features # Mock ModelManager mock_model_manager_instance = Mock() @@ -160,272 +253,50 @@ class TestDocumentServiceSaveDocumentWithDatasetId(unittest.TestCase): mock_embedding_model.model = "embed-model" mock_embedding_model.provider = "openai" mock_model_manager_instance.get_default_model_instance.return_value = mock_embedding_model - mock_model_manager.return_value = mock_model_manager_instance + mock_model_dependencies["model_manager"].return_value = mock_model_manager_instance # Mock collection binding mock_collection_binding_instance = Mock() mock_collection_binding_instance.id = "binding-123" - mock_collection_binding.return_value = mock_collection_binding_instance + mock_model_dependencies["collection_binding"].return_value = mock_collection_binding_instance # Mock build_document - mock_doc1 = Mock(spec=Document, id="doc1") - mock_doc2 = Mock(spec=Document, id="doc2") - mock_build_doc.side_effect = [mock_doc1, mock_doc2] + mock_doc1 = DocumentSaveTestDataFactory.create_document_mock("doc1") + mock_doc2 = DocumentSaveTestDataFactory.create_document_mock("doc2") + mock_async_task_dependencies["build_document"].side_effect = [mock_doc1, mock_doc2] # Mock upload files - upload_file1 = Mock() - upload_file1.id = "file1" - upload_file1.name = "file1.pdf" - upload_file2 = Mock() - upload_file2.id = "file2" - upload_file2.name = "file2.pdf" - mock_db.query.return_value.filter.return_value.first.side_effect = [upload_file1, upload_file2] - - # Mock time - mock_time.strftime.return_value = "20231201120000" + upload_file1 = DocumentSaveTestDataFactory.create_upload_file_mock("file1", "file1.pdf") + upload_file2 = DocumentSaveTestDataFactory.create_upload_file_mock("file2", "file2.pdf") + mock_document_service_dependencies["db_session"].query.return_value.filter.return_value.first.side_effect = [ + upload_file1, + upload_file2, + ] # Act docs, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) # Assert assert len(docs) == 2 - mock_doc_task.assert_called_once() - mock_dup_task.assert_not_called() + self._assert_async_task_called(mock_async_task_dependencies["document_indexing_task"]) + mock_async_task_dependencies["duplicate_document_indexing_task"].assert_not_called() # Verify the documents were added to session - mock_db.add.assert_any_call(mock_doc1) - mock_db.add.assert_any_call(mock_doc2) - - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - def test_billing_batch_limit_exceeded(self, mock_current_user, mock_features): - """Test that batch upload limit exceeded raises appropriate error.""" - # Arrange - dataset = self._create_mock_dataset() - account = self._create_mock_account() - features = self._create_mock_features(billing_enabled=True, plan="sandbox") - knowledge_config = self._create_mock_knowledge_config( - data_source_type="upload_file", file_ids=["file1", "file2"] - ) - - self._setup_common_mocks(mock_current_user, mock_features) - mock_features.return_value = features - - # Act & Assert - with pytest.raises(ValueError, match="Your current plan does not support batch upload"): - DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) - - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - def test_billing_quota_limit_exceeded(self, mock_current_user, mock_features): - """Test that document upload quota exceeded raises appropriate error.""" - # Arrange - dataset = self._create_mock_dataset() - account = self._create_mock_account() - features = self._create_mock_features(billing_enabled=True, plan="pro", quota_limit=1, quota_size=1) - knowledge_config = self._create_mock_knowledge_config( - data_source_type="upload_file", file_ids=["file1", "file2"] - ) - - self._setup_common_mocks(mock_current_user, mock_features) - mock_features.return_value = features - - # Act & Assert - with pytest.raises(ValueError, match="You have reached the limit of your subscription"): - DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) - - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - def test_invalid_indexing_technique(self, mock_current_user, mock_features): - """Test that invalid indexing technique raises appropriate error.""" - # Arrange - dataset = self._create_mock_dataset() - account = self._create_mock_account() - features = self._create_mock_features(billing_enabled=False) - knowledge_config = self._create_mock_knowledge_config( - data_source_type="upload_file", indexing_technique="invalid" - ) - - self._setup_common_mocks(mock_current_user, mock_features) - mock_features.return_value = features - - # Act & Assert - with pytest.raises(ValueError, match="Indexing technique is invalid"): - DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) - - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - def test_no_process_rule_found(self, mock_current_user, mock_features): - """Test that missing process rule raises appropriate error.""" - # Arrange - dataset = self._create_mock_dataset(data_source_type="upload_file", indexing_technique="high_quality") - account = self._create_mock_account() - features = self._create_mock_features(billing_enabled=False) - knowledge_config = self._create_mock_knowledge_config( - data_source_type="upload_file", indexing_technique="high_quality" - ) - knowledge_config.process_rule.rules = None - - self._setup_common_mocks(mock_current_user, mock_features) - mock_features.return_value = features - - # Act & Assert - with pytest.raises(ValueError, match="No process rule found"): - DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) - - @patch("services.dataset_service.db.session") - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - def test_invalid_process_rule_mode(self, mock_current_user, mock_features, mock_db): - """Test that invalid process rule mode returns None without creating document.""" - # Arrange - dataset = self._create_mock_dataset(data_source_type="upload_file", indexing_technique="high_quality") - account = self._create_mock_account() - features = self._create_mock_features(billing_enabled=False) - knowledge_config = self._create_mock_knowledge_config( - data_source_type="upload_file", indexing_technique="high_quality" - ) - knowledge_config.process_rule.mode = "invalid" - - self._setup_common_mocks(mock_current_user, mock_features) - mock_features.return_value = features - - # Act - with patch("logging.warning") as mock_log: - result = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) - - # Assert - assert result is None - mock_log.assert_called() - - @patch("services.dataset_service.db.session") - @patch("services.dataset_service.redis_client") - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - def test_notion_import_no_info(self, mock_current_user, mock_features, mock_redis, mock_db): - """Test that notion_import with missing notion_info_list raises appropriate error.""" - # Arrange - dataset = self._create_mock_dataset(data_source_type="notion_import", indexing_technique="high_quality") - account = self._create_mock_account() - features = self._create_mock_features(billing_enabled=False) - knowledge_config = self._create_mock_knowledge_config( - data_source_type="notion_import", indexing_technique="high_quality" - ) - knowledge_config.data_source.info_list.notion_info_list = None + self._assert_document_created(mock_document_service_dependencies["db_session"], [mock_doc1, mock_doc2]) - self._setup_common_mocks(mock_current_user, mock_features, mock_redis, mock_db) - mock_features.return_value = features - - # Act & Assert - with pytest.raises(ValueError, match="No notion info list found"): - DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) - - @patch("services.dataset_service.db.session") - @patch("services.dataset_service.redis_client") - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - def test_website_crawl_no_info(self, mock_current_user, mock_features, mock_redis, mock_db): - """Test that website_crawl with missing website_info raises appropriate error.""" - # Arrange - dataset = self._create_mock_dataset(data_source_type="website_crawl", indexing_technique="high_quality") - account = self._create_mock_account() - features = self._create_mock_features(billing_enabled=False) - knowledge_config = self._create_mock_knowledge_config( - data_source_type="website_crawl", indexing_technique="high_quality" - ) - knowledge_config.data_source.info_list.website_info_list = None - - self._setup_common_mocks(mock_current_user, mock_features, mock_redis, mock_db) - mock_features.return_value = features - - # Act & Assert - with pytest.raises(ValueError, match="No website info list found"): - DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) - - @patch("services.dataset_service.db.session") - @patch("services.dataset_service.DocumentService.update_document_with_dataset_id") - def test_update_document_branch(self, mock_update_doc, mock_db): - """Test the update document flow when original_document_id is provided.""" - # Arrange - dataset = self._create_mock_dataset() - account = self._create_mock_account() - knowledge_config = self._create_mock_knowledge_config( - data_source_type="upload_file", original_document_id="docid" - ) - mock_update_doc.return_value = Mock(batch=self.batch_id) - - # Mock current_user - mock_current_user = Mock() - mock_current_user.current_tenant_id = self.tenant_id - - # Act - with patch("services.dataset_service.current_user", mock_current_user): - docs, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) - - # Assert - assert len(docs) == 1 - assert batch == self.batch_id - - @patch("services.dataset_service.db.session") - @patch("services.dataset_service.redis_client") - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - def test_upload_file_file_not_found(self, mock_current_user, mock_features, mock_redis, mock_db): - """Test that missing upload file raises FileNotExistsError.""" - # Arrange - from services.dataset_service import FileNotExistsError - - dataset = self._create_mock_dataset() - account = self._create_mock_account() - features = self._create_mock_features(billing_enabled=False) - knowledge_config = self._create_mock_knowledge_config(data_source_type="upload_file", file_ids=["file1"]) - - self._setup_common_mocks(mock_current_user, mock_features, mock_redis, mock_db) - mock_features.return_value = features - mock_db.query.return_value.filter.return_value.first.return_value = None - - # Act & Assert - with pytest.raises(FileNotExistsError): - DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) - - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.db.session") - @patch("services.dataset_service.redis_client") - @patch("services.dataset_service.time") - @patch("services.dataset_service.secrets.randbelow", return_value=123456) - @patch("services.dataset_service.DocumentService.build_document") - @patch("services.dataset_service.document_indexing_task.delay") - @patch("services.dataset_service.duplicate_document_indexing_task.delay") - @patch("services.dataset_service.current_user") - @patch("services.dataset_service.ModelManager") - @patch("services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding") - @patch("services.dataset_service.DocumentService.get_documents_position", return_value=0) def test_upload_file_duplicate( - self, - mock_get_position, - mock_collection_binding, - mock_model_manager, - mock_current_user, - mock_dup_task, - mock_doc_task, - mock_build_doc, - mock_rand, - mock_time, - mock_redis, - mock_db, - mock_features, + self, mock_document_service_dependencies, mock_async_task_dependencies, mock_model_dependencies ): """Test upload_file with duplicate=True when document already exists.""" # Arrange - dataset = self._create_mock_dataset() - account = self._create_mock_account() - features = self._create_mock_features() - knowledge_config = self._create_mock_knowledge_config( + dataset = DocumentSaveTestDataFactory.create_dataset_mock() + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock() + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( data_source_type="upload_file", file_ids=["file1"], duplicate=True ) - self._setup_common_mocks(mock_current_user, mock_features, mock_redis, mock_db) - mock_features.return_value = features + mock_document_service_dependencies["features"].return_value = features # Mock ModelManager mock_model_manager_instance = Mock() @@ -433,195 +304,140 @@ class TestDocumentServiceSaveDocumentWithDatasetId(unittest.TestCase): mock_embedding_model.model = "embed-model" mock_embedding_model.provider = "openai" mock_model_manager_instance.get_default_model_instance.return_value = mock_embedding_model - mock_model_manager.return_value = mock_model_manager_instance + mock_model_dependencies["model_manager"].return_value = mock_model_manager_instance # Mock collection binding mock_collection_binding_instance = Mock() mock_collection_binding_instance.id = "binding-123" - mock_collection_binding.return_value = mock_collection_binding_instance + mock_model_dependencies["collection_binding"].return_value = mock_collection_binding_instance # Mock upload file and existing document - upload_file = Mock() - upload_file.id = "file1" - upload_file.name = "file1.pdf" - existing_doc = Mock(id="docid", name="file1.pdf") - mock_db.query.return_value.filter.return_value.first.return_value = upload_file - mock_db.query.return_value.filter_by.return_value.first.return_value = existing_doc - - # Mock time - mock_time.strftime.return_value = "20231201120000" + upload_file = DocumentSaveTestDataFactory.create_upload_file_mock("file1", "file1.pdf") + existing_doc = DocumentSaveTestDataFactory.create_document_mock("docid") + existing_doc.name = "file1.pdf" + mock_document_service_dependencies[ + "db_session" + ].query.return_value.filter.return_value.first.return_value = upload_file + mock_document_service_dependencies[ + "db_session" + ].query.return_value.filter_by.return_value.first.return_value = existing_doc # Act docs, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) # Assert assert len(docs) == 1 - mock_dup_task.assert_called_once() + self._assert_async_task_called(mock_async_task_dependencies["duplicate_document_indexing_task"]) # Verify the existing document was added to session with updated properties - mock_db.add.assert_any_call(existing_doc) + mock_document_service_dependencies["db_session"].add.assert_any_call(existing_doc) # Verify the document properties were updated before being added - # These assertions verify that the duplicate document was properly updated assert existing_doc.batch == "20231201120000223456" assert existing_doc.indexing_status == "waiting" assert existing_doc.created_from == "web" assert existing_doc.doc_form == "pdf" assert existing_doc.doc_language == "en" - @patch("services.dataset_service.db.session") - @patch("services.dataset_service.redis_client") - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - def test_notion_import_data_source_binding_not_found(self, mock_current_user, mock_features, mock_redis, mock_db): - """Test that missing data source binding for notion_import raises appropriate error.""" + def test_upload_file_file_not_found(self, mock_document_service_dependencies): + """Test that missing upload file raises FileNotExistsError.""" # Arrange - dataset = self._create_mock_dataset(data_source_type="notion_import", indexing_technique="high_quality") - account = self._create_mock_account() - features = self._create_mock_features(billing_enabled=False) + from services.dataset_service import FileNotExistsError - notion_info = Mock() - notion_info.workspace_id = "ws1" - notion_info.pages = [] - knowledge_config = self._create_mock_knowledge_config( - data_source_type="notion_import", indexing_technique="high_quality" + dataset = DocumentSaveTestDataFactory.create_dataset_mock() + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock(billing_enabled=False) + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( + data_source_type="upload_file", file_ids=["file1"] ) - knowledge_config.data_source.info_list.notion_info_list = [notion_info] - self._setup_common_mocks(mock_current_user, mock_features, mock_redis, mock_db) - mock_features.return_value = features - mock_db.query.return_value.filter.return_value.first.return_value = None + mock_document_service_dependencies["features"].return_value = features + mock_document_service_dependencies[ + "db_session" + ].query.return_value.filter.return_value.first.return_value = None # Act & Assert - with pytest.raises(ValueError, match="Data source binding not found."): - DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) - - @patch("services.dataset_service.db.session") - @patch("services.dataset_service.redis_client") - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - @patch("services.dataset_service.document_indexing_task.delay") - def test_website_crawl_url_too_long( - self, mock_document_indexing_task, mock_current_user, mock_features, mock_redis, mock_db - ): - """Test that long URLs are properly truncated in website_crawl document names.""" - # Arrange - dataset = self._create_mock_dataset(data_source_type="website_crawl", indexing_technique="high_quality") - account = self._create_mock_account() - features = self._create_mock_features(billing_enabled=False) - knowledge_config = self._create_mock_knowledge_config( - data_source_type="website_crawl", website_urls=["http://" + "a" * 300] - ) - - self._setup_common_mocks(mock_current_user, mock_features, mock_redis, mock_db) - mock_features.return_value = features - mock_db.query.return_value.filter.return_value.first.return_value = True - - # Act - with patch("services.dataset_service.DocumentService.build_document") as mock_build_doc: + with pytest.raises(FileNotExistsError): DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) - # Assert - args, kwargs = mock_build_doc.call_args - assert args[9].startswith("http://") - assert len(args[9]) < 256 + # ==================== Notion Import Tests ==================== - @patch("services.dataset_service.db.session") - @patch("services.dataset_service.redis_client") - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - @patch("services.dataset_service.DocumentService.build_document") - @patch("services.dataset_service.document_indexing_task.delay") - @patch("services.dataset_service.clean_notion_document_task.delay") - def test_notion_import_success( - self, mock_clean_task, mock_doc_task, mock_build_doc, mock_current_user, mock_features, mock_redis, mock_db - ): + def test_notion_import_success(self, mock_document_service_dependencies, mock_async_task_dependencies): """Test successful notion_import document creation for new pages.""" # Arrange - dataset = self._create_mock_dataset(data_source_type="notion_import", indexing_technique="high_quality") - account = self._create_mock_account() - features = self._create_mock_features(billing_enabled=False) + dataset = DocumentSaveTestDataFactory.create_dataset_mock( + data_source_type="notion_import", indexing_technique="high_quality" + ) + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock(billing_enabled=False) - notion_info = Mock() - notion_info.workspace_id = "ws1" - page = Mock() - page.page_id = "page1" - page.page_name = "Test Page" - page.page_icon = None - page.type = "page" - notion_info.pages = [page] + page = DocumentSaveTestDataFactory.create_notion_page_mock("page1", "Test Page") + notion_info = DocumentSaveTestDataFactory.create_notion_info_mock("ws1", [page]) - knowledge_config = self._create_mock_knowledge_config( + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( data_source_type="notion_import", indexing_technique="high_quality" ) knowledge_config.data_source.info_list.notion_info_list = [notion_info] - self._setup_common_mocks(mock_current_user, mock_features, mock_redis, mock_db) - mock_features.return_value = features + mock_document_service_dependencies["features"].return_value = features # Mock existing documents query (empty) - mock_db.query.return_value.filter_by.return_value.all.return_value = [] + mock_document_service_dependencies["db_session"].query.return_value.filter_by.return_value.all.return_value = [] # Mock data source binding binding = Mock() binding.id = "binding1" - mock_db.query.return_value.filter.return_value.first.return_value = binding + mock_document_service_dependencies[ + "db_session" + ].query.return_value.filter.return_value.first.return_value = binding # Mock build_document - mock_doc = Mock(spec=Document, id="doc1") - mock_build_doc.return_value = mock_doc + mock_doc = DocumentSaveTestDataFactory.create_document_mock("doc1") + mock_async_task_dependencies["build_document"].return_value = mock_doc # Act docs, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) # Assert assert len(docs) == 1 - mock_doc_task.assert_called_once() + self._assert_async_task_called(mock_async_task_dependencies["document_indexing_task"]) # Verify the document was added to the database session - mock_db.add.assert_called_with(mock_doc) - - @patch("services.dataset_service.db.session") - @patch("services.dataset_service.redis_client") - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - @patch("services.dataset_service.DocumentService.build_document") - @patch("services.dataset_service.clean_notion_document_task.delay") - @patch("services.dataset_service.document_indexing_task.delay") - def test_notion_import_page_exists( - self, mock_doc_task, mock_clean_task, mock_build_doc, mock_current_user, mock_features, mock_redis, mock_db - ): + mock_document_service_dependencies["db_session"].add.assert_called_with(mock_doc) + + def test_notion_import_page_exists(self, mock_document_service_dependencies, mock_async_task_dependencies): """Test notion_import when page already exists - should skip creation.""" # Arrange - dataset = self._create_mock_dataset(data_source_type="notion_import", indexing_technique="high_quality") - account = self._create_mock_account() - features = self._create_mock_features(billing_enabled=False) + dataset = DocumentSaveTestDataFactory.create_dataset_mock( + data_source_type="notion_import", indexing_technique="high_quality" + ) + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock(billing_enabled=False) - notion_info = Mock() - notion_info.workspace_id = "ws1" - page = Mock() - page.page_id = "page1" - page.page_name = "Test Page" - notion_info.pages = [page] + page = DocumentSaveTestDataFactory.create_notion_page_mock("page1", "Test Page") + notion_info = DocumentSaveTestDataFactory.create_notion_info_mock("ws1", [page]) - knowledge_config = self._create_mock_knowledge_config( + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( data_source_type="notion_import", indexing_technique="high_quality" ) knowledge_config.data_source.info_list.notion_info_list = [notion_info] - self._setup_common_mocks(mock_current_user, mock_features, mock_redis, mock_db) - mock_features.return_value = features + mock_document_service_dependencies["features"].return_value = features # Mock existing document with same page_id existing_doc = Mock() existing_doc.data_source_info = '{"notion_page_id": "page1"}' existing_doc.id = "doc1" - mock_db.query.return_value.filter_by.return_value.all.return_value = [existing_doc] + mock_document_service_dependencies["db_session"].query.return_value.filter_by.return_value.all.return_value = [ + existing_doc + ] # Mock data source binding binding = Mock() binding.id = "binding1" - mock_db.query.return_value.filter.return_value.first.return_value = binding + mock_document_service_dependencies[ + "db_session" + ].query.return_value.filter.return_value.first.return_value = binding # Act docs, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) @@ -629,130 +445,318 @@ class TestDocumentServiceSaveDocumentWithDatasetId(unittest.TestCase): # Assert assert len(docs) == 0 # No document should be created since it already exists - for call in mock_db.add.call_args_list: + for call in mock_document_service_dependencies["db_session"].add.call_args_list: args, kwargs = call assert not any(isinstance(arg, Document) for arg in args), "Method was called with a Document!" - @patch("services.dataset_service.db.session") - @patch("services.dataset_service.redis_client") - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - @patch("services.dataset_service.DocumentService.build_document") - @patch("services.dataset_service.document_indexing_task.delay") - def test_website_crawl_success( - self, mock_doc_task, mock_build_doc, mock_current_user, mock_features, mock_redis, mock_db - ): + def test_notion_import_no_info(self, mock_document_service_dependencies): + """Test that notion_import with missing notion_info_list raises appropriate error.""" + # Arrange + dataset = DocumentSaveTestDataFactory.create_dataset_mock( + data_source_type="notion_import", indexing_technique="high_quality" + ) + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock(billing_enabled=False) + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( + data_source_type="notion_import", indexing_technique="high_quality" + ) + knowledge_config.data_source.info_list.notion_info_list = None + + mock_document_service_dependencies["features"].return_value = features + + # Act & Assert + with pytest.raises(ValueError, match="No notion info list found"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + def test_notion_import_data_source_binding_not_found(self, mock_document_service_dependencies): + """Test that missing data source binding for notion_import raises appropriate error.""" + # Arrange + dataset = DocumentSaveTestDataFactory.create_dataset_mock( + data_source_type="notion_import", indexing_technique="high_quality" + ) + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock(billing_enabled=False) + + notion_info = DocumentSaveTestDataFactory.create_notion_info_mock("ws1", []) + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( + data_source_type="notion_import", indexing_technique="high_quality" + ) + knowledge_config.data_source.info_list.notion_info_list = [notion_info] + + mock_document_service_dependencies["features"].return_value = features + mock_document_service_dependencies[ + "db_session" + ].query.return_value.filter.return_value.first.return_value = None + + # Act & Assert + with pytest.raises(ValueError, match="Data source binding not found."): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + # ==================== Website Crawl Tests ==================== + + def test_website_crawl_success(self, mock_document_service_dependencies, mock_async_task_dependencies): """Test successful website_crawl document creation for multiple URLs.""" # Arrange - dataset = self._create_mock_dataset(data_source_type="website_crawl", indexing_technique="high_quality") - account = self._create_mock_account() - features = self._create_mock_features(billing_enabled=False) - knowledge_config = self._create_mock_knowledge_config( + dataset = DocumentSaveTestDataFactory.create_dataset_mock( + data_source_type="website_crawl", indexing_technique="high_quality" + ) + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock(billing_enabled=False) + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( data_source_type="website_crawl", website_urls=["http://example1.com", "http://example2.com"] ) - self._setup_common_mocks(mock_current_user, mock_features, mock_redis, mock_db) - mock_features.return_value = features + mock_document_service_dependencies["features"].return_value = features # Mock build_document - mock_doc1 = Mock(spec=Document, id="doc1") - mock_doc2 = Mock(spec=Document, id="doc2") - mock_build_doc.side_effect = [mock_doc1, mock_doc2] + mock_doc1 = DocumentSaveTestDataFactory.create_document_mock("doc1") + mock_doc2 = DocumentSaveTestDataFactory.create_document_mock("doc2") + mock_async_task_dependencies["build_document"].side_effect = [mock_doc1, mock_doc2] # Act docs, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) # Assert assert len(docs) == 2 - assert mock_build_doc.call_count == 2 - mock_doc_task.assert_called_once() + assert mock_async_task_dependencies["build_document"].call_count == 2 + self._assert_async_task_called(mock_async_task_dependencies["document_indexing_task"]) # Verify database session operations - mock_db.add.assert_any_call(mock_doc1) - mock_db.add.assert_any_call(mock_doc2) - - @patch("services.dataset_service.db.session") - @patch("services.dataset_service.redis_client") - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - def test_unknown_data_source_type(self, mock_current_user, mock_features, mock_redis, mock_db): - """Test that unknown data source type is handled gracefully.""" + self._assert_document_created(mock_document_service_dependencies["db_session"], [mock_doc1, mock_doc2]) + + def test_website_crawl_no_info(self, mock_document_service_dependencies): + """Test that website_crawl with missing website_info raises appropriate error.""" # Arrange - dataset = self._create_mock_dataset(data_source_type="unknown_type", indexing_technique="high_quality") - account = self._create_mock_account() - features = self._create_mock_features(billing_enabled=False) - knowledge_config = self._create_mock_knowledge_config( - data_source_type="unknown_type", indexing_technique="high_quality" + dataset = DocumentSaveTestDataFactory.create_dataset_mock( + data_source_type="website_crawl", indexing_technique="high_quality" ) + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock(billing_enabled=False) + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( + data_source_type="website_crawl", indexing_technique="high_quality" + ) + knowledge_config.data_source.info_list.website_info_list = None + + mock_document_service_dependencies["features"].return_value = features - self._setup_common_mocks(mock_current_user, mock_features, mock_redis, mock_db) - mock_features.return_value = features + # Act & Assert + with pytest.raises(ValueError, match="No website info list found"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + def test_website_crawl_url_too_long(self, mock_document_service_dependencies, mock_async_task_dependencies): + """Test that long URLs are properly truncated in website_crawl document names.""" + # Arrange + dataset = DocumentSaveTestDataFactory.create_dataset_mock( + data_source_type="website_crawl", indexing_technique="high_quality" + ) + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock(billing_enabled=False) + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( + data_source_type="website_crawl", website_urls=["http://" + "a" * 300] + ) + + mock_document_service_dependencies["features"].return_value = features + mock_document_service_dependencies[ + "db_session" + ].query.return_value.filter.return_value.first.return_value = True # Act - result = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + with patch("services.dataset_service.DocumentService.build_document") as mock_build_doc: + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) - # Assert - assert result is None or len(result[0]) == 0 + # Assert + args, kwargs = mock_build_doc.call_args + assert args[9].startswith("http://") + assert len(args[9]) < 256 + + # ==================== Billing and Quota Tests ==================== + + def test_billing_batch_limit_exceeded(self, mock_document_service_dependencies): + """Test that batch upload limit exceeded raises appropriate error.""" + # Arrange + dataset = DocumentSaveTestDataFactory.create_dataset_mock() + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock(billing_enabled=True, plan="sandbox") + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( + data_source_type="upload_file", file_ids=["file1", "file2"] + ) + + mock_document_service_dependencies["features"].return_value = features - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - def test_upload_file_batch_limit_exceeded(self, mock_current_user, mock_features): + # Act & Assert + with pytest.raises(ValueError, match="Your current plan does not support batch upload"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + def test_billing_quota_limit_exceeded(self, mock_document_service_dependencies): + """Test that document upload quota exceeded raises appropriate error.""" + # Arrange + dataset = DocumentSaveTestDataFactory.create_dataset_mock() + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock( + billing_enabled=True, plan="pro", quota_limit=1, quota_size=1 + ) + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( + data_source_type="upload_file", file_ids=["file1", "file2"] + ) + + mock_document_service_dependencies["features"].return_value = features + + # Act & Assert + with pytest.raises(ValueError, match="You have reached the limit of your subscription"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + def test_upload_file_batch_limit_exceeded(self, mock_document_service_dependencies): """Test that upload_file batch limit exceeded raises appropriate error.""" # Arrange - dataset = self._create_mock_dataset() - account = self._create_mock_account() - features = self._create_mock_features() - knowledge_config = self._create_mock_knowledge_config( + dataset = DocumentSaveTestDataFactory.create_dataset_mock() + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock() + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( data_source_type="upload_file", file_ids=["file" + str(i) for i in range(100)] ) - self._setup_common_mocks(mock_current_user, mock_features) - mock_features.return_value = features + mock_document_service_dependencies["features"].return_value = features # Act & Assert with patch("services.dataset_service.dify_config.BATCH_UPLOAD_LIMIT", 50): with pytest.raises(ValueError, match="You have reached the batch upload limit"): DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - def test_notion_import_batch_limit_exceeded(self, mock_current_user, mock_features): + def test_notion_import_batch_limit_exceeded(self, mock_document_service_dependencies): """Test that notion_import batch limit exceeded raises appropriate error.""" # Arrange - dataset = self._create_mock_dataset() - account = self._create_mock_account() - features = self._create_mock_features() + dataset = DocumentSaveTestDataFactory.create_dataset_mock() + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock() - notion_info = Mock() - notion_info.pages = [Mock() for _ in range(100)] # 100 pages + notion_info = DocumentSaveTestDataFactory.create_notion_info_mock("ws1", [Mock() for _ in range(100)]) - knowledge_config = self._create_mock_knowledge_config(data_source_type="notion_import") + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock(data_source_type="notion_import") knowledge_config.data_source.info_list.notion_info_list = [notion_info] - self._setup_common_mocks(mock_current_user, mock_features) - mock_features.return_value = features + mock_document_service_dependencies["features"].return_value = features # Act & Assert with patch("services.dataset_service.dify_config.BATCH_UPLOAD_LIMIT", 50): with pytest.raises(ValueError, match="You have reached the batch upload limit"): DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - def test_website_crawl_batch_limit_exceeded(self, mock_current_user, mock_features): + def test_website_crawl_batch_limit_exceeded(self, mock_document_service_dependencies): """Test that website_crawl batch limit exceeded raises appropriate error.""" # Arrange - dataset = self._create_mock_dataset() - account = self._create_mock_account() - features = self._create_mock_features() - knowledge_config = self._create_mock_knowledge_config( + dataset = DocumentSaveTestDataFactory.create_dataset_mock() + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock() + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( data_source_type="website_crawl", website_urls=["http://example" + str(i) + ".com" for i in range(100)] ) - self._setup_common_mocks(mock_current_user, mock_features) - mock_features.return_value = features + mock_document_service_dependencies["features"].return_value = features # Act & Assert with patch("services.dataset_service.dify_config.BATCH_UPLOAD_LIMIT", 50): with pytest.raises(ValueError, match="You have reached the batch upload limit"): DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + # ==================== Process Rule Tests ==================== + + def test_invalid_indexing_technique(self, mock_document_service_dependencies): + """Test that invalid indexing technique raises appropriate error.""" + # Arrange + dataset = DocumentSaveTestDataFactory.create_dataset_mock() + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock(billing_enabled=False) + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( + data_source_type="upload_file", indexing_technique="invalid" + ) + + mock_document_service_dependencies["features"].return_value = features + + # Act & Assert + with pytest.raises(ValueError, match="Indexing technique is invalid"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + def test_no_process_rule_found(self, mock_document_service_dependencies): + """Test that missing process rule raises appropriate error.""" + # Arrange + dataset = DocumentSaveTestDataFactory.create_dataset_mock( + data_source_type="upload_file", indexing_technique="high_quality" + ) + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock(billing_enabled=False) + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( + data_source_type="upload_file", indexing_technique="high_quality" + ) + knowledge_config.process_rule.rules = None + + mock_document_service_dependencies["features"].return_value = features + + # Act & Assert + with pytest.raises(ValueError, match="No process rule found"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + def test_invalid_process_rule_mode(self, mock_document_service_dependencies): + """Test that invalid process rule mode returns None without creating document.""" + # Arrange + dataset = DocumentSaveTestDataFactory.create_dataset_mock( + data_source_type="upload_file", indexing_technique="high_quality" + ) + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock(billing_enabled=False) + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( + data_source_type="upload_file", indexing_technique="high_quality" + ) + knowledge_config.process_rule.mode = "invalid" + + mock_document_service_dependencies["features"].return_value = features + + # Act + with patch("logging.warning") as mock_log: + result = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + # Assert + assert result is None + mock_log.assert_called() + + # ==================== Update Document Tests ==================== + + def test_update_document_branch(self, mock_document_service_dependencies): + """Test the update document flow when original_document_id is provided.""" + # Arrange + dataset = DocumentSaveTestDataFactory.create_dataset_mock() + account = DocumentSaveTestDataFactory.create_account_mock() + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( + data_source_type="upload_file", original_document_id="docid" + ) + + with patch("services.dataset_service.DocumentService.update_document_with_dataset_id") as mock_update_doc: + mock_update_doc.return_value = Mock(batch=self.batch_id) + + # Act + docs, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + # Assert + assert len(docs) == 1 + assert batch == self.batch_id + + # ==================== Edge Case Tests ==================== + + def test_unknown_data_source_type(self, mock_document_service_dependencies): + """Test that unknown data source type is handled gracefully.""" + # Arrange + dataset = DocumentSaveTestDataFactory.create_dataset_mock( + data_source_type="unknown_type", indexing_technique="high_quality" + ) + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock(billing_enabled=False) + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( + data_source_type="unknown_type", indexing_technique="high_quality" + ) + + mock_document_service_dependencies["features"].return_value = features + + # Act + result = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + # Assert + assert result is None or len(result[0]) == 0