From 4e1b17332ca52eb3ff11e745d08baab8e299ad4f Mon Sep 17 00:00:00 2001 From: neatguycoding <15627489+NeatGuyCoding@users.noreply.github.com> Date: Tue, 24 Jun 2025 17:11:20 +0800 Subject: [PATCH 1/7] feat: refactor: save_document_with_dataset_id, fix error shadowing --- api/services/dataset_service.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 49ca98624a..bd2a38f271 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -1075,7 +1075,7 @@ class DocumentService: raise ValueError("No notion info list found.") exist_page_ids = [] exist_document = {} - documents = ( + documents_from_db = ( db.session.query(Document) .filter_by( dataset_id=dataset.id, @@ -1085,8 +1085,8 @@ class DocumentService: ) .all() ) - if documents: - for document in documents: + if documents_from_db: + for document in documents_from_db: data_source_info = json.loads(document.data_source_info) exist_page_ids.append(data_source_info["notion_page_id"]) exist_document[data_source_info["notion_page_id"]] = document.id From a955616a21f33387b968906502f94f43de1e72b3 Mon Sep 17 00:00:00 2001 From: neatguycoding <15627489+NeatGuyCoding@users.noreply.github.com> Date: Tue, 24 Jun 2025 17:12:04 +0800 Subject: [PATCH 2/7] feat: refactor: add unit test for original save_document_with_dataset_id --- ...t_service_save_document_with_dataset_id.py | 824 ++++++++++++++++++ 1 file changed, 824 insertions(+) create mode 100644 api/tests/unit_tests/services/test_document_service_save_document_with_dataset_id.py diff --git a/api/tests/unit_tests/services/test_document_service_save_document_with_dataset_id.py b/api/tests/unit_tests/services/test_document_service_save_document_with_dataset_id.py new file mode 100644 index 0000000000..4bfaf42c2d --- /dev/null +++ b/api/tests/unit_tests/services/test_document_service_save_document_with_dataset_id.py @@ -0,0 +1,824 @@ +from unittest.mock import Mock, patch + +import pytest + +from models.account import Account +from models.dataset import Dataset, Document +from services.dataset_service import DocumentService +from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig + + +class TestDocumentServiceSaveDocumentWithDatasetId: + """ + Full branch unit tests for DocumentService.save_document_with_dataset_id. + This suite covers all main branches, including: + - Billing and quota checks + - Data source types: upload_file, notion_import, website_crawl + - Duplicate document handling + - Process rule creation and error cases + - Exception and edge cases + """ + + @patch("services.dataset_service.FeatureService.get_features") + @patch("services.dataset_service.db.session") + @patch("services.dataset_service.redis_client") + @patch("services.dataset_service.time") + @patch("services.dataset_service.secrets.randbelow", return_value=123456) + @patch("services.dataset_service.DocumentService.build_document") + @patch("services.dataset_service.document_indexing_task.delay") + @patch("services.dataset_service.duplicate_document_indexing_task.delay") + @patch("services.dataset_service.current_user") + @patch("services.dataset_service.ModelManager") + @patch("services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding") + @patch("services.dataset_service.DocumentService.get_documents_position", return_value=0) + def test_upload_file_success( + self, + mock_get_position, + mock_collection_binding, + mock_model_manager, + mock_current_user, + mock_dup_task, + mock_doc_task, + mock_build_doc, + mock_rand, + mock_time, + mock_redis, + mock_db, + mock_features, + ): + """ + Test successful upload_file document creation, including duplicate and non-duplicate cases. + """ + # Setup mocks and input + dataset = Mock(spec=Dataset) + dataset.id = "ds1" + dataset.tenant_id = "tenant1" + dataset.data_source_type = None + dataset.indexing_technique = None + dataset.retrieval_model = None + dataset.embedding_model = None + dataset.embedding_model_provider = None + dataset.collection_binding_id = None + dataset.latest_process_rule = None + + account = Mock(spec=Account) + account.id = "user1" + account.name = "User One" + + # Mock current_user + mock_current_user.current_tenant_id = "tenant1" + + # Mock features + features = Mock() + features.billing.enabled = True + features.billing.subscription.plan = "pro" + features.documents_upload_quota.limit = 100 + features.documents_upload_quota.size = 0 + mock_features.return_value = features + + # Mock knowledge_config for upload_file with proper nested structure + knowledge_config = Mock(spec=KnowledgeConfig) + knowledge_config.original_document_id = None + knowledge_config.data_source = Mock() + knowledge_config.data_source.info_list = Mock() + knowledge_config.data_source.info_list.data_source_type = "upload_file" + knowledge_config.data_source.info_list.file_info_list = Mock() + knowledge_config.data_source.info_list.file_info_list.file_ids = ["file1", "file2"] + knowledge_config.indexing_technique = "high_quality" + knowledge_config.embedding_model = "embed-model" + knowledge_config.embedding_model_provider = "openai" + knowledge_config.retrieval_model = None + knowledge_config.process_rule = Mock() + knowledge_config.process_rule.mode = "custom" + knowledge_config.process_rule.rules = Mock() + knowledge_config.doc_form = "pdf" + knowledge_config.doc_language = "en" + knowledge_config.duplicate = False + + # Mock ModelManager + mock_model_manager_instance = Mock() + mock_embedding_model = Mock() + mock_embedding_model.model = "embed-model" + mock_embedding_model.provider = "openai" + mock_model_manager_instance.get_default_model_instance.return_value = mock_embedding_model + mock_model_manager.return_value = mock_model_manager_instance + + # Mock collection binding + mock_collection_binding_instance = Mock() + mock_collection_binding_instance.id = "binding-123" + mock_collection_binding.return_value = mock_collection_binding_instance + + # Mock build_document + mock_doc1 = Mock(spec=Document, id="doc1") + mock_doc2 = Mock(spec=Document, id="doc2") + mock_build_doc.side_effect = [mock_doc1, mock_doc2] + + # Mock db.session.query(UploadFile) + upload_file1 = Mock() + upload_file1.id = "file1" + upload_file1.name = "file1.pdf" + upload_file2 = Mock() + upload_file2.id = "file2" + upload_file2.name = "file2.pdf" + mock_db.query.return_value.filter.return_value.first.side_effect = [upload_file1, upload_file2] + + # Mock redis lock + mock_lock = Mock() + mock_redis.lock.return_value.__enter__ = Mock(return_value=None) + mock_redis.lock.return_value.__exit__ = Mock(return_value=None) + + # Mock time.strftime + mock_time.strftime.return_value = "20231201120000" + + # Run + docs, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + # Assert + assert len(docs) == 2 + mock_doc_task.assert_called_once() + mock_dup_task.assert_not_called() + + @patch("services.dataset_service.FeatureService.get_features") + @patch("services.dataset_service.current_user") + def test_billing_batch_limit(self, mock_current_user, mock_features): + """ + Test batch upload limit exceeded raises ValueError. + """ + dataset = Mock(spec=Dataset) + dataset.id = "ds1" + dataset.tenant_id = "tenant1" + account = Mock(spec=Account) + account.id = "user1" + mock_current_user.current_tenant_id = "tenant1" + features = Mock() + features.billing.enabled = True + features.billing.subscription.plan = "sandbox" + mock_features.return_value = features + knowledge_config = Mock() + knowledge_config.original_document_id = None + knowledge_config.data_source = Mock() + knowledge_config.data_source.info_list = Mock() + knowledge_config.data_source.info_list.data_source_type = "upload_file" + knowledge_config.data_source.info_list.file_info_list = Mock() + knowledge_config.data_source.info_list.file_info_list.file_ids = ["file1", "file2"] + with pytest.raises(ValueError, match="Your current plan does not support batch upload"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + @patch("services.dataset_service.FeatureService.get_features") + @patch("services.dataset_service.current_user") + def test_billing_quota_limit(self, mock_current_user, mock_features): + """ + Test document upload quota exceeded raises ValueError. + """ + dataset = Mock(spec=Dataset) + dataset.id = "ds1" + dataset.tenant_id = "tenant1" + account = Mock(spec=Account) + account.id = "user1" + mock_current_user.current_tenant_id = "tenant1" + features = Mock() + features.billing.enabled = True + features.billing.subscription.plan = "pro" + features.documents_upload_quota.limit = 1 + features.documents_upload_quota.size = 1 + mock_features.return_value = features + knowledge_config = Mock() + knowledge_config.original_document_id = None + knowledge_config.data_source = Mock() + knowledge_config.data_source.info_list = Mock() + knowledge_config.data_source.info_list.data_source_type = "upload_file" + knowledge_config.data_source.info_list.file_info_list = Mock() + knowledge_config.data_source.info_list.file_info_list.file_ids = ["file1", "file2"] + with pytest.raises(ValueError, match="You have reached the limit of your subscription"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + @patch("services.dataset_service.FeatureService.get_features") + @patch("services.dataset_service.current_user") + def test_invalid_indexing_technique(self, mock_current_user, mock_features): + """ + Test invalid indexing technique raises ValueError. + """ + dataset = Mock(spec=Dataset) + dataset.id = "ds1" + dataset.tenant_id = "tenant1" + dataset.data_source_type = None + dataset.indexing_technique = None + account = Mock(spec=Account) + account.id = "user1" + mock_current_user.current_tenant_id = "tenant1" + features = Mock() + features.billing.enabled = False + mock_features.return_value = features + knowledge_config = Mock() + knowledge_config.original_document_id = None + knowledge_config.data_source = Mock() + knowledge_config.data_source.info_list = Mock() + knowledge_config.data_source.info_list.data_source_type = "upload_file" + knowledge_config.indexing_technique = "invalid" + with pytest.raises(ValueError, match="Indexing technique is invalid"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + @patch("services.dataset_service.FeatureService.get_features") + @patch("services.dataset_service.current_user") + def test_no_process_rule_found(self, mock_current_user, mock_features): + """ + Test no process rule found raises ValueError. + """ + dataset = Mock(spec=Dataset) + dataset.id = "ds1" + dataset.tenant_id = "tenant1" + dataset.latest_process_rule = None + dataset.data_source_type = "upload_file" + dataset.indexing_technique = "high_quality" + account = Mock(spec=Account) + account.id = "user1" + mock_current_user.current_tenant_id = "tenant1" + features = Mock() + features.billing.enabled = False + mock_features.return_value = features + knowledge_config = Mock() + knowledge_config.original_document_id = None + knowledge_config.data_source = Mock() + knowledge_config.data_source.info_list = Mock() + knowledge_config.data_source.info_list.data_source_type = "upload_file" + knowledge_config.indexing_technique = "high_quality" + knowledge_config.process_rule = Mock() + knowledge_config.process_rule.mode = "custom" + knowledge_config.process_rule.rules = None + with pytest.raises(ValueError, match="No process rule found"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + @patch("services.dataset_service.db.session") + @patch("services.dataset_service.FeatureService.get_features") + @patch("services.dataset_service.current_user") + def test_invalid_process_rule_mode(self, mock_current_user, mock_features, mock_db): + """ + Test invalid process rule mode returns None (no document created). + """ + dataset = Mock(spec=Dataset) + dataset.id = "ds1" + dataset.tenant_id = "tenant1" + dataset.latest_process_rule = None + dataset.data_source_type = "upload_file" + dataset.indexing_technique = "high_quality" + account = Mock(spec=Account) + account.id = "user1" + mock_current_user.current_tenant_id = "tenant1" + features = Mock() + features.billing.enabled = False + mock_features.return_value = features + knowledge_config = Mock() + knowledge_config.original_document_id = None + knowledge_config.data_source = Mock() + knowledge_config.data_source.info_list = Mock() + knowledge_config.data_source.info_list.data_source_type = "upload_file" + knowledge_config.indexing_technique = "high_quality" + knowledge_config.process_rule = Mock() + knowledge_config.process_rule.mode = "invalid" + with patch("logging.warning") as mock_log: + result = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + assert result is None + mock_log.assert_called() + + @patch("services.dataset_service.db.session") + @patch("services.dataset_service.redis_client") + @patch("services.dataset_service.FeatureService.get_features") + @patch("services.dataset_service.current_user") + def test_notion_import_no_info(self, mock_current_user, mock_features, mock_redis, mock_db): + """ + Test notion_import with no notion_info_list raises ValueError. + """ + dataset = Mock(spec=Dataset) + dataset.id = "ds1" + dataset.tenant_id = "tenant1" + dataset.data_source_type = "notion_import" + dataset.indexing_technique = "high_quality" + account = Mock(spec=Account) + account.id = "user1" + mock_current_user.current_tenant_id = "tenant1" + features = Mock() + features.billing.enabled = False + mock_features.return_value = features + knowledge_config = Mock() + knowledge_config.original_document_id = None + knowledge_config.process_rule = Mock() + knowledge_config.process_rule.mode = "automatic" + knowledge_config.data_source = Mock() + knowledge_config.data_source.info_list = Mock() + knowledge_config.data_source.info_list.data_source_type = "notion_import" + knowledge_config.data_source.info_list.notion_info_list = None + with pytest.raises(ValueError, match="No notion info list found"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + @patch("services.dataset_service.db.session") + @patch("services.dataset_service.redis_client") + @patch("services.dataset_service.FeatureService.get_features") + @patch("services.dataset_service.current_user") + def test_website_crawl_no_info(self, mock_current_user, mock_features, mock_redis, mock_db): + """ + Test website_crawl with no website_info raises ValueError. + """ + dataset = Mock(spec=Dataset) + dataset.id = "ds1" + dataset.tenant_id = "tenant1" + dataset.data_source_type = "website_crawl" + dataset.indexing_technique = "high_quality" + account = Mock(spec=Account) + account.id = "user1" + mock_current_user.current_tenant_id = "tenant1" + features = Mock() + features.billing.enabled = False + mock_features.return_value = features + knowledge_config = Mock() + knowledge_config.original_document_id = None + knowledge_config.process_rule = Mock() + knowledge_config.process_rule.mode = "automatic" + knowledge_config.data_source = Mock() + knowledge_config.data_source.info_list = Mock() + knowledge_config.data_source.info_list.data_source_type = "website_crawl" + knowledge_config.data_source.info_list.website_info_list = None + with pytest.raises(ValueError, match="No website info list found"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + @patch("services.dataset_service.DocumentService.update_document_with_dataset_id") + def test_update_document_branch(self, mock_update_doc): + """ + Test the branch where original_document_id is provided (update flow). + """ + dataset = Mock(spec=Dataset) + account = Mock(spec=Account) + knowledge_config = Mock() + knowledge_config.original_document_id = "docid" + mock_update_doc.return_value = Mock(batch="batch1") + # Mock current_user + mock_current_user = Mock() + mock_current_user.current_tenant_id = "tenant-123" + # Patch current_user to return the mock + with patch("services.dataset_service.current_user", mock_current_user): + docs, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + assert len(docs) == 1 + assert batch == "batch1" + + @patch("services.dataset_service.db.session") + @patch("services.dataset_service.redis_client") + @patch("services.dataset_service.FeatureService.get_features") + @patch("services.dataset_service.current_user") + def test_upload_file_file_not_found(self, mock_current_user, mock_features, mock_redis, mock_db): + """ + Test upload_file: should raise FileNotExistsError if file not found in db. + """ + from services.dataset_service import FileNotExistsError + + dataset = Mock(spec=Dataset) + dataset.id = "ds1" + dataset.tenant_id = "tenant1" + account = Mock(spec=Account) + account.id = "user1" + mock_current_user.current_tenant_id = "tenant1" + features = Mock() + features.billing.enabled = False + mock_features.return_value = features + knowledge_config = Mock() + knowledge_config.original_document_id = None + knowledge_config.process_rule = Mock() + knowledge_config.process_rule.mode = "automatic" + knowledge_config.data_source = Mock() + knowledge_config.data_source.info_list = Mock() + knowledge_config.data_source.info_list.data_source_type = "upload_file" + knowledge_config.data_source.info_list.file_info_list = Mock() + knowledge_config.data_source.info_list.file_info_list.file_ids = ["file1"] + mock_db.query.return_value.filter.return_value.first.return_value = None + with pytest.raises(FileNotExistsError): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + @patch("services.dataset_service.FeatureService.get_features") + @patch("services.dataset_service.db.session") + @patch("services.dataset_service.redis_client") + @patch("services.dataset_service.time") + @patch("services.dataset_service.secrets.randbelow", return_value=123456) + @patch("services.dataset_service.DocumentService.build_document") + @patch("services.dataset_service.document_indexing_task.delay") + @patch("services.dataset_service.duplicate_document_indexing_task.delay") + @patch("services.dataset_service.current_user") + @patch("services.dataset_service.ModelManager") + @patch("services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding") + @patch("services.dataset_service.DocumentService.get_documents_position", return_value=0) + def test_upload_file_duplicate( + self, + mock_get_position, + mock_collection_binding, + mock_model_manager, + mock_current_user, + mock_dup_task, + mock_doc_task, + mock_build_doc, + mock_rand, + mock_time, + mock_redis, + mock_db, + mock_features, + ): + """ + Test upload_file: duplicate=True and document already exists, should update and append to documents. + """ + dataset = Mock(spec=Dataset) + dataset.id = "ds1" + dataset.tenant_id = "tenant1" + dataset.data_source_type = None + dataset.indexing_technique = None + dataset.retrieval_model = None + dataset.embedding_model = None + dataset.embedding_model_provider = None + dataset.collection_binding_id = None + dataset.latest_process_rule = None + account = Mock(spec=Account) + account.id = "user1" + account.name = "User One" + mock_current_user.current_tenant_id = "tenant1" + features = Mock() + features.billing.enabled = True + features.billing.subscription.plan = "pro" + features.documents_upload_quota.limit = 100 + features.documents_upload_quota.size = 0 + mock_features.return_value = features + knowledge_config = Mock() + knowledge_config.original_document_id = None + knowledge_config.data_source = Mock() + knowledge_config.data_source.info_list = Mock() + knowledge_config.data_source.info_list.data_source_type = "upload_file" + knowledge_config.data_source.info_list.file_info_list = Mock() + knowledge_config.data_source.info_list.file_info_list.file_ids = ["file1"] + knowledge_config.indexing_technique = "high_quality" + knowledge_config.embedding_model = "embed-model" + knowledge_config.embedding_model_provider = "openai" + knowledge_config.process_rule = Mock() + knowledge_config.process_rule.mode = "custom" + knowledge_config.process_rule.rules = Mock() + knowledge_config.doc_form = "pdf" + knowledge_config.doc_language = "en" + knowledge_config.duplicate = True + mock_model_manager_instance = Mock() + mock_embedding_model = Mock() + mock_embedding_model.model = "embed-model" + mock_embedding_model.provider = "openai" + mock_model_manager_instance.get_default_model_instance.return_value = mock_embedding_model + mock_model_manager.return_value = mock_model_manager_instance + mock_collection_binding_instance = Mock() + mock_collection_binding_instance.id = "binding-123" + mock_collection_binding.return_value = mock_collection_binding_instance + upload_file = Mock() + upload_file.id = "file1" + upload_file.name = "file1.pdf" + mock_db.query.return_value.filter.return_value.first.side_effect = [ + upload_file, + Mock(id="docid", name="file1.pdf"), + ] # file, then document + mock_redis.lock.return_value.__enter__ = Mock(return_value=None) + mock_redis.lock.return_value.__exit__ = Mock(return_value=None) + mock_time.strftime.return_value = "20231201120000" + docs, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + assert len(docs) == 1 + mock_dup_task.assert_called_once() + + @patch("services.dataset_service.db.session") + @patch("services.dataset_service.redis_client") + @patch("services.dataset_service.FeatureService.get_features") + @patch("services.dataset_service.current_user") + def test_notion_import_data_source_binding_not_found(self, mock_current_user, mock_features, mock_redis, mock_db): + """ + Test notion_import: should raise ValueError if data source binding not found. + """ + dataset = Mock(spec=Dataset) + dataset.id = "ds1" + dataset.tenant_id = "tenant1" + dataset.data_source_type = "notion_import" + dataset.indexing_technique = "high_quality" + account = Mock(spec=Account) + account.id = "user1" + mock_current_user.current_tenant_id = "tenant1" + features = Mock() + features.billing.enabled = False + mock_features.return_value = features + knowledge_config = Mock() + knowledge_config.original_document_id = None + knowledge_config.process_rule = Mock() + knowledge_config.process_rule.mode = "automatic" + knowledge_config.data_source = Mock() + knowledge_config.data_source.info_list = Mock() + knowledge_config.data_source.info_list.data_source_type = "notion_import" + notion_info = Mock() + notion_info.workspace_id = "ws1" + notion_info.pages = [] + knowledge_config.data_source.info_list.notion_info_list = [notion_info] + mock_db.query.return_value.filter.return_value.first.return_value = None + with pytest.raises(ValueError, match="Data source binding not found."): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + @patch("services.dataset_service.db.session") + @patch("services.dataset_service.redis_client") + @patch("services.dataset_service.FeatureService.get_features") + @patch("services.dataset_service.current_user") + @patch("services.dataset_service.document_indexing_task.delay") + def test_website_crawl_url_too_long( + self, mock_document_indexing_task, mock_current_user, mock_features, mock_redis, mock_db + ): + """ + Test website_crawl: url longer than 255 chars should be truncated in document name. + """ + dataset = Mock(spec=Dataset) + dataset.id = "ds1" + dataset.tenant_id = "tenant1" + dataset.data_source_type = "website_crawl" + dataset.indexing_technique = "high_quality" + account = Mock(spec=Account) + account.id = "user1" + mock_current_user.current_tenant_id = "tenant1" + features = Mock() + features.billing.enabled = False + mock_features.return_value = features + knowledge_config = Mock() + knowledge_config.original_document_id = None + knowledge_config.process_rule = Mock() + knowledge_config.process_rule.mode = "automatic" + knowledge_config.data_source = Mock() + knowledge_config.data_source.info_list = Mock() + knowledge_config.data_source.info_list.data_source_type = "website_crawl" + website_info = Mock() + website_info.urls = ["http://" + "a" * 300] + website_info.provider = "test" + website_info.job_id = "job1" + website_info.only_main_content = True + knowledge_config.data_source.info_list.website_info_list = website_info + mock_db.query.return_value.filter.return_value.first.return_value = True + # Patch build_document to check name truncation + with patch("services.dataset_service.DocumentService.build_document") as mock_build_doc: + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + args, kwargs = mock_build_doc.call_args + assert args[9].startswith("http://") + assert len(args[9]) < 256 + + @patch("services.dataset_service.db.session") + @patch("services.dataset_service.redis_client") + @patch("services.dataset_service.FeatureService.get_features") + @patch("services.dataset_service.current_user") + @patch("services.dataset_service.DocumentService.build_document") + @patch("services.dataset_service.document_indexing_task.delay") + @patch("services.dataset_service.clean_notion_document_task.delay") + def test_notion_import_success( + self, mock_clean_task, mock_doc_task, mock_build_doc, mock_current_user, mock_features, mock_redis, mock_db + ): + """ + Test notion_import: successful document creation for new pages. + """ + dataset = Mock(spec=Dataset) + dataset.id = "ds1" + dataset.tenant_id = "tenant1" + dataset.data_source_type = "notion_import" + dataset.indexing_technique = "high_quality" + account = Mock(spec=Account) + account.id = "user1" + account.name = "User One" + mock_current_user.current_tenant_id = "tenant1" + features = Mock() + features.billing.enabled = False + mock_features.return_value = features + knowledge_config = Mock() + knowledge_config.original_document_id = None + knowledge_config.process_rule = Mock() + knowledge_config.process_rule.mode = "automatic" + knowledge_config.data_source = Mock() + knowledge_config.data_source.info_list = Mock() + knowledge_config.data_source.info_list.data_source_type = "notion_import" + notion_info = Mock() + notion_info.workspace_id = "ws1" + page = Mock() + page.page_id = "page1" + page.page_name = "Test Page" + page.page_icon = None + page.type = "page" + notion_info.pages = [page] + knowledge_config.data_source.info_list.notion_info_list = [notion_info] + # Mock existing documents query (empty) + mock_db.query.return_value.filter_by.return_value.all.return_value = [] + # Mock data source binding + binding = Mock() + binding.id = "binding1" + mock_db.query.return_value.filter.return_value.first.return_value = binding + # Mock build_document + mock_doc = Mock(spec=Document, id="doc1") + mock_build_doc.return_value = mock_doc + docs, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + assert len(docs) == 1 + mock_doc_task.assert_called_once() + + @patch("services.dataset_service.db.session") + @patch("services.dataset_service.redis_client") + @patch("services.dataset_service.FeatureService.get_features") + @patch("services.dataset_service.current_user") + @patch("services.dataset_service.DocumentService.build_document") + @patch("services.dataset_service.clean_notion_document_task.delay") + @patch("services.dataset_service.document_indexing_task.delay") + def test_notion_import_page_exists( + self, mock_doc_task, mock_clean_task, mock_build_doc, mock_current_user, mock_features, mock_redis, mock_db + ): + """ + Test notion_import: page already exists, should skip creation and clean old documents. + """ + dataset = Mock(spec=Dataset) + dataset.id = "ds1" + dataset.tenant_id = "tenant1" + dataset.data_source_type = "notion_import" + dataset.indexing_technique = "high_quality" + account = Mock(spec=Account) + account.id = "user1" + mock_current_user.current_tenant_id = "tenant1" + features = Mock() + features.billing.enabled = False + mock_features.return_value = features + knowledge_config = Mock() + knowledge_config.original_document_id = None + knowledge_config.process_rule = Mock() + knowledge_config.process_rule.mode = "automatic" + knowledge_config.data_source = Mock() + knowledge_config.data_source.info_list = Mock() + knowledge_config.data_source.info_list.data_source_type = "notion_import" + notion_info = Mock() + notion_info.workspace_id = "ws1" + page = Mock() + page.page_id = "page1" + page.page_name = "Test Page" + notion_info.pages = [page] + knowledge_config.data_source.info_list.notion_info_list = [notion_info] + # Mock existing document with same page_id + existing_doc = Mock() + existing_doc.data_source_info = '{"notion_page_id": "page1"}' + existing_doc.id = "doc1" + mock_db.query.return_value.filter_by.return_value.all.return_value = [existing_doc] + # Mock data source binding + binding = Mock() + binding.id = "binding1" + mock_db.query.return_value.filter.return_value.first.return_value = binding + docs, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + assert len(docs) == 0 + mock_clean_task.assert_not_called() + + @patch("services.dataset_service.db.session") + @patch("services.dataset_service.redis_client") + @patch("services.dataset_service.FeatureService.get_features") + @patch("services.dataset_service.current_user") + @patch("services.dataset_service.DocumentService.build_document") + @patch("services.dataset_service.document_indexing_task.delay") + def test_website_crawl_success( + self, mock_doc_task, mock_build_doc, mock_current_user, mock_features, mock_redis, mock_db + ): + """ + Test website_crawl: successful document creation for multiple URLs. + """ + dataset = Mock(spec=Dataset) + dataset.id = "ds1" + dataset.tenant_id = "tenant1" + dataset.data_source_type = "website_crawl" + dataset.indexing_technique = "high_quality" + account = Mock(spec=Account) + account.id = "user1" + account.name = "User One" + mock_current_user.current_tenant_id = "tenant1" + features = Mock() + features.billing.enabled = False + mock_features.return_value = features + knowledge_config = Mock() + knowledge_config.original_document_id = None + knowledge_config.process_rule = Mock() + knowledge_config.process_rule.mode = "automatic" + knowledge_config.data_source = Mock() + knowledge_config.data_source.info_list = Mock() + knowledge_config.data_source.info_list.data_source_type = "website_crawl" + website_info = Mock() + website_info.urls = ["http://example1.com", "http://example2.com"] + website_info.provider = "test" + website_info.job_id = "job1" + website_info.only_main_content = True + knowledge_config.data_source.info_list.website_info_list = website_info + # Mock build_document + mock_doc1 = Mock(spec=Document, id="doc1") + mock_doc2 = Mock(spec=Document, id="doc2") + mock_build_doc.side_effect = [mock_doc1, mock_doc2] + docs, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + assert len(docs) == 2 + assert mock_build_doc.call_count == 2 + mock_doc_task.assert_called_once() + + @patch("services.dataset_service.db.session") + @patch("services.dataset_service.redis_client") + @patch("services.dataset_service.FeatureService.get_features") + @patch("services.dataset_service.current_user") + def test_unknown_data_source_type(self, mock_current_user, mock_features, mock_redis, mock_db): + """ + Test unknown data_source_type: should not raise error but return None when no matching branch. + """ + dataset = Mock(spec=Dataset) + dataset.id = "ds1" + dataset.tenant_id = "tenant1" + dataset.data_source_type = "unknown_type" + dataset.indexing_technique = "high_quality" + account = Mock(spec=Account) + account.id = "user1" + mock_current_user.current_tenant_id = "tenant1" + features = Mock() + features.billing.enabled = False + mock_features.return_value = features + knowledge_config = Mock() + knowledge_config.original_document_id = None + knowledge_config.process_rule = Mock() + knowledge_config.process_rule.mode = "automatic" + knowledge_config.data_source = Mock() + knowledge_config.data_source.info_list = Mock() + knowledge_config.data_source.info_list.data_source_type = "unknown_type" + # This should not raise an error but return None due to no matching data source type + result = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + # The method should handle unknown data source types gracefully + assert result is None or len(result[0]) == 0 + + @patch("services.dataset_service.FeatureService.get_features") + @patch("services.dataset_service.current_user") + def test_upload_file_batch_limit_exceeded(self, mock_current_user, mock_features): + """ + Test upload_file: batch upload limit exceeded raises ValueError. + """ + dataset = Mock(spec=Dataset) + dataset.id = "ds1" + dataset.tenant_id = "tenant1" + account = Mock(spec=Account) + account.id = "user1" + mock_current_user.current_tenant_id = "tenant1" + features = Mock() + features.billing.enabled = True + features.billing.subscription.plan = "pro" + mock_features.return_value = features + knowledge_config = Mock() + knowledge_config.original_document_id = None + knowledge_config.data_source = Mock() + knowledge_config.data_source.info_list = Mock() + knowledge_config.data_source.info_list.data_source_type = "upload_file" + knowledge_config.data_source.info_list.file_info_list = Mock() + # Create a list with more than BATCH_UPLOAD_LIMIT files + knowledge_config.data_source.info_list.file_info_list.file_ids = ["file" + str(i) for i in range(100)] + with patch("services.dataset_service.dify_config.BATCH_UPLOAD_LIMIT", 50): + with pytest.raises(ValueError, match="You have reached the batch upload limit"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + @patch("services.dataset_service.FeatureService.get_features") + @patch("services.dataset_service.current_user") + def test_notion_import_batch_limit_exceeded(self, mock_current_user, mock_features): + """ + Test notion_import: batch upload limit exceeded raises ValueError. + """ + dataset = Mock(spec=Dataset) + dataset.id = "ds1" + dataset.tenant_id = "tenant1" + account = Mock(spec=Account) + account.id = "user1" + mock_current_user.current_tenant_id = "tenant1" + features = Mock() + features.billing.enabled = True + features.billing.subscription.plan = "pro" + mock_features.return_value = features + knowledge_config = Mock() + knowledge_config.original_document_id = None + knowledge_config.data_source = Mock() + knowledge_config.data_source.info_list = Mock() + knowledge_config.data_source.info_list.data_source_type = "notion_import" + notion_info = Mock() + notion_info.pages = [Mock() for _ in range(100)] # 100 pages + knowledge_config.data_source.info_list.notion_info_list = [notion_info] + with patch("services.dataset_service.dify_config.BATCH_UPLOAD_LIMIT", 50): + with pytest.raises(ValueError, match="You have reached the batch upload limit"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + @patch("services.dataset_service.FeatureService.get_features") + @patch("services.dataset_service.current_user") + def test_website_crawl_batch_limit_exceeded(self, mock_current_user, mock_features): + """ + Test website_crawl: batch upload limit exceeded raises ValueError. + """ + dataset = Mock(spec=Dataset) + dataset.id = "ds1" + dataset.tenant_id = "tenant1" + account = Mock(spec=Account) + account.id = "user1" + mock_current_user.current_tenant_id = "tenant1" + features = Mock() + features.billing.enabled = True + features.billing.subscription.plan = "pro" + mock_features.return_value = features + knowledge_config = Mock() + knowledge_config.original_document_id = None + knowledge_config.data_source = Mock() + knowledge_config.data_source.info_list = Mock() + knowledge_config.data_source.info_list.data_source_type = "website_crawl" + website_info = Mock() + website_info.urls = ["http://example" + str(i) + ".com" for i in range(100)] # 100 URLs + knowledge_config.data_source.info_list.website_info_list = website_info + with patch("services.dataset_service.dify_config.BATCH_UPLOAD_LIMIT", 50): + with pytest.raises(ValueError, match="You have reached the batch upload limit"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) From 92bc10de96fa9ba3b052ec034c08fa467c76a039 Mon Sep 17 00:00:00 2001 From: neatguycoding <15627489+NeatGuyCoding@users.noreply.github.com> Date: Wed, 25 Jun 2025 09:37:26 +0800 Subject: [PATCH 3/7] feat: refactor: improve the readability of the unit test for original save_document_with_dataset_id --- ...t_service_save_document_with_dataset_id.py | 875 ++++++++---------- 1 file changed, 391 insertions(+), 484 deletions(-) diff --git a/api/tests/unit_tests/services/test_document_service_save_document_with_dataset_id.py b/api/tests/unit_tests/services/test_document_service_save_document_with_dataset_id.py index 4bfaf42c2d..c59b423e74 100644 --- a/api/tests/unit_tests/services/test_document_service_save_document_with_dataset_id.py +++ b/api/tests/unit_tests/services/test_document_service_save_document_with_dataset_id.py @@ -1,3 +1,4 @@ +import unittest from unittest.mock import Mock, patch import pytest @@ -8,17 +9,111 @@ from services.dataset_service import DocumentService from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig -class TestDocumentServiceSaveDocumentWithDatasetId: +class TestDocumentServiceSaveDocumentWithDatasetId(unittest.TestCase): """ - Full branch unit tests for DocumentService.save_document_with_dataset_id. - This suite covers all main branches, including: - - Billing and quota checks - - Data source types: upload_file, notion_import, website_crawl + Comprehensive unit tests for DocumentService.save_document_with_dataset_id. + + This test suite covers all major code branches including: + - Billing and quota validation + - Different data source types (upload_file, notion_import, website_crawl) - Duplicate document handling - - Process rule creation and error cases - - Exception and edge cases + - Process rule validation and error cases + - Exception handling and edge cases """ + def setUp(self): + """Set up common test fixtures and mock objects.""" + self.dataset_id = "ds1" + self.tenant_id = "tenant1" + self.user_id = "user1" + self.batch_id = "batch1" + + def _create_mock_dataset(self, data_source_type=None, indexing_technique=None): + """Create a mock Dataset object with common attributes.""" + dataset = Mock(spec=Dataset) + dataset.id = self.dataset_id + dataset.tenant_id = self.tenant_id + dataset.data_source_type = data_source_type + dataset.indexing_technique = indexing_technique + dataset.retrieval_model = None + dataset.embedding_model = None + dataset.embedding_model_provider = None + dataset.collection_binding_id = None + dataset.latest_process_rule = None + return dataset + + def _create_mock_account(self): + """Create a mock Account object.""" + account = Mock(spec=Account) + account.id = self.user_id + account.name = "Test User" + return account + + def _create_mock_features(self, billing_enabled=True, plan="pro", quota_limit=100, quota_size=0): + """Create a mock features object for billing tests.""" + features = Mock() + features.billing.enabled = billing_enabled + if billing_enabled: + features.billing.subscription.plan = plan + features.documents_upload_quota.limit = quota_limit + features.documents_upload_quota.size = quota_size + return features + + def _create_mock_knowledge_config( + self, + data_source_type, + original_document_id=None, + file_ids=None, + notion_pages=None, + website_urls=None, + indexing_technique="high_quality", + duplicate=False, + ): + """Create a mock KnowledgeConfig object with specified data source configuration.""" + knowledge_config = Mock(spec=KnowledgeConfig) + knowledge_config.original_document_id = original_document_id + knowledge_config.indexing_technique = indexing_technique + knowledge_config.embedding_model = "embed-model" + knowledge_config.embedding_model_provider = "openai" + knowledge_config.retrieval_model = None + knowledge_config.doc_form = "pdf" + knowledge_config.doc_language = "en" + knowledge_config.duplicate = duplicate + + # Set up process rule + knowledge_config.process_rule = Mock() + knowledge_config.process_rule.mode = "custom" if data_source_type == "upload_file" else "automatic" + knowledge_config.process_rule.rules = Mock() + + # Set up data source info + knowledge_config.data_source = Mock() + knowledge_config.data_source.info_list = Mock() + knowledge_config.data_source.info_list.data_source_type = data_source_type + + if data_source_type == "upload_file" and file_ids: + knowledge_config.data_source.info_list.file_info_list = Mock() + knowledge_config.data_source.info_list.file_info_list.file_ids = file_ids + elif data_source_type == "notion_import" and notion_pages: + knowledge_config.data_source.info_list.notion_info_list = notion_pages + elif data_source_type == "website_crawl" and website_urls: + website_info = Mock() + website_info.urls = website_urls + website_info.provider = "test" + website_info.job_id = "job1" + website_info.only_main_content = True + knowledge_config.data_source.info_list.website_info_list = website_info + + return knowledge_config + + def _setup_common_mocks(self, mock_current_user, mock_features, mock_redis=None, mock_db=None): + """Set up common mock objects used across multiple tests.""" + mock_current_user.current_tenant_id = self.tenant_id + + if mock_redis: + mock_lock = Mock() + mock_redis.lock.return_value.__enter__ = Mock(return_value=None) + mock_redis.lock.return_value.__exit__ = Mock(return_value=None) + @patch("services.dataset_service.FeatureService.get_features") @patch("services.dataset_service.db.session") @patch("services.dataset_service.redis_client") @@ -46,55 +141,18 @@ class TestDocumentServiceSaveDocumentWithDatasetId: mock_db, mock_features, ): - """ - Test successful upload_file document creation, including duplicate and non-duplicate cases. - """ - # Setup mocks and input - dataset = Mock(spec=Dataset) - dataset.id = "ds1" - dataset.tenant_id = "tenant1" - dataset.data_source_type = None - dataset.indexing_technique = None - dataset.retrieval_model = None - dataset.embedding_model = None - dataset.embedding_model_provider = None - dataset.collection_binding_id = None - dataset.latest_process_rule = None - - account = Mock(spec=Account) - account.id = "user1" - account.name = "User One" - - # Mock current_user - mock_current_user.current_tenant_id = "tenant1" - - # Mock features - features = Mock() - features.billing.enabled = True - features.billing.subscription.plan = "pro" - features.documents_upload_quota.limit = 100 - features.documents_upload_quota.size = 0 + """Test successful upload_file document creation with multiple files.""" + # Arrange + dataset = self._create_mock_dataset() + account = self._create_mock_account() + features = self._create_mock_features() + knowledge_config = self._create_mock_knowledge_config( + data_source_type="upload_file", file_ids=["file1", "file2"] + ) + + self._setup_common_mocks(mock_current_user, mock_features, mock_redis, mock_db) mock_features.return_value = features - # Mock knowledge_config for upload_file with proper nested structure - knowledge_config = Mock(spec=KnowledgeConfig) - knowledge_config.original_document_id = None - knowledge_config.data_source = Mock() - knowledge_config.data_source.info_list = Mock() - knowledge_config.data_source.info_list.data_source_type = "upload_file" - knowledge_config.data_source.info_list.file_info_list = Mock() - knowledge_config.data_source.info_list.file_info_list.file_ids = ["file1", "file2"] - knowledge_config.indexing_technique = "high_quality" - knowledge_config.embedding_model = "embed-model" - knowledge_config.embedding_model_provider = "openai" - knowledge_config.retrieval_model = None - knowledge_config.process_rule = Mock() - knowledge_config.process_rule.mode = "custom" - knowledge_config.process_rule.rules = Mock() - knowledge_config.doc_form = "pdf" - knowledge_config.doc_language = "en" - knowledge_config.duplicate = False - # Mock ModelManager mock_model_manager_instance = Mock() mock_embedding_model = Mock() @@ -113,7 +171,7 @@ class TestDocumentServiceSaveDocumentWithDatasetId: mock_doc2 = Mock(spec=Document, id="doc2") mock_build_doc.side_effect = [mock_doc1, mock_doc2] - # Mock db.session.query(UploadFile) + # Mock upload files upload_file1 = Mock() upload_file1.id = "file1" upload_file1.name = "file1.pdf" @@ -122,15 +180,10 @@ class TestDocumentServiceSaveDocumentWithDatasetId: upload_file2.name = "file2.pdf" mock_db.query.return_value.filter.return_value.first.side_effect = [upload_file1, upload_file2] - # Mock redis lock - mock_lock = Mock() - mock_redis.lock.return_value.__enter__ = Mock(return_value=None) - mock_redis.lock.return_value.__exit__ = Mock(return_value=None) - - # Mock time.strftime + # Mock time mock_time.strftime.return_value = "20231201120000" - # Run + # Act docs, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) # Assert @@ -140,111 +193,78 @@ class TestDocumentServiceSaveDocumentWithDatasetId: @patch("services.dataset_service.FeatureService.get_features") @patch("services.dataset_service.current_user") - def test_billing_batch_limit(self, mock_current_user, mock_features): - """ - Test batch upload limit exceeded raises ValueError. - """ - dataset = Mock(spec=Dataset) - dataset.id = "ds1" - dataset.tenant_id = "tenant1" - account = Mock(spec=Account) - account.id = "user1" - mock_current_user.current_tenant_id = "tenant1" - features = Mock() - features.billing.enabled = True - features.billing.subscription.plan = "sandbox" + def test_billing_batch_limit_exceeded(self, mock_current_user, mock_features): + """Test that batch upload limit exceeded raises appropriate error.""" + # Arrange + dataset = self._create_mock_dataset() + account = self._create_mock_account() + features = self._create_mock_features(billing_enabled=True, plan="sandbox") + knowledge_config = self._create_mock_knowledge_config( + data_source_type="upload_file", file_ids=["file1", "file2"] + ) + + self._setup_common_mocks(mock_current_user, mock_features) mock_features.return_value = features - knowledge_config = Mock() - knowledge_config.original_document_id = None - knowledge_config.data_source = Mock() - knowledge_config.data_source.info_list = Mock() - knowledge_config.data_source.info_list.data_source_type = "upload_file" - knowledge_config.data_source.info_list.file_info_list = Mock() - knowledge_config.data_source.info_list.file_info_list.file_ids = ["file1", "file2"] + + # Act & Assert with pytest.raises(ValueError, match="Your current plan does not support batch upload"): DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) @patch("services.dataset_service.FeatureService.get_features") @patch("services.dataset_service.current_user") - def test_billing_quota_limit(self, mock_current_user, mock_features): - """ - Test document upload quota exceeded raises ValueError. - """ - dataset = Mock(spec=Dataset) - dataset.id = "ds1" - dataset.tenant_id = "tenant1" - account = Mock(spec=Account) - account.id = "user1" - mock_current_user.current_tenant_id = "tenant1" - features = Mock() - features.billing.enabled = True - features.billing.subscription.plan = "pro" - features.documents_upload_quota.limit = 1 - features.documents_upload_quota.size = 1 + def test_billing_quota_limit_exceeded(self, mock_current_user, mock_features): + """Test that document upload quota exceeded raises appropriate error.""" + # Arrange + dataset = self._create_mock_dataset() + account = self._create_mock_account() + features = self._create_mock_features(billing_enabled=True, plan="pro", quota_limit=1, quota_size=1) + knowledge_config = self._create_mock_knowledge_config( + data_source_type="upload_file", file_ids=["file1", "file2"] + ) + + self._setup_common_mocks(mock_current_user, mock_features) mock_features.return_value = features - knowledge_config = Mock() - knowledge_config.original_document_id = None - knowledge_config.data_source = Mock() - knowledge_config.data_source.info_list = Mock() - knowledge_config.data_source.info_list.data_source_type = "upload_file" - knowledge_config.data_source.info_list.file_info_list = Mock() - knowledge_config.data_source.info_list.file_info_list.file_ids = ["file1", "file2"] + + # Act & Assert with pytest.raises(ValueError, match="You have reached the limit of your subscription"): DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) @patch("services.dataset_service.FeatureService.get_features") @patch("services.dataset_service.current_user") def test_invalid_indexing_technique(self, mock_current_user, mock_features): - """ - Test invalid indexing technique raises ValueError. - """ - dataset = Mock(spec=Dataset) - dataset.id = "ds1" - dataset.tenant_id = "tenant1" - dataset.data_source_type = None - dataset.indexing_technique = None - account = Mock(spec=Account) - account.id = "user1" - mock_current_user.current_tenant_id = "tenant1" - features = Mock() - features.billing.enabled = False + """Test that invalid indexing technique raises appropriate error.""" + # Arrange + dataset = self._create_mock_dataset() + account = self._create_mock_account() + features = self._create_mock_features(billing_enabled=False) + knowledge_config = self._create_mock_knowledge_config( + data_source_type="upload_file", indexing_technique="invalid" + ) + + self._setup_common_mocks(mock_current_user, mock_features) mock_features.return_value = features - knowledge_config = Mock() - knowledge_config.original_document_id = None - knowledge_config.data_source = Mock() - knowledge_config.data_source.info_list = Mock() - knowledge_config.data_source.info_list.data_source_type = "upload_file" - knowledge_config.indexing_technique = "invalid" + + # Act & Assert with pytest.raises(ValueError, match="Indexing technique is invalid"): DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) @patch("services.dataset_service.FeatureService.get_features") @patch("services.dataset_service.current_user") def test_no_process_rule_found(self, mock_current_user, mock_features): - """ - Test no process rule found raises ValueError. - """ - dataset = Mock(spec=Dataset) - dataset.id = "ds1" - dataset.tenant_id = "tenant1" - dataset.latest_process_rule = None - dataset.data_source_type = "upload_file" - dataset.indexing_technique = "high_quality" - account = Mock(spec=Account) - account.id = "user1" - mock_current_user.current_tenant_id = "tenant1" - features = Mock() - features.billing.enabled = False - mock_features.return_value = features - knowledge_config = Mock() - knowledge_config.original_document_id = None - knowledge_config.data_source = Mock() - knowledge_config.data_source.info_list = Mock() - knowledge_config.data_source.info_list.data_source_type = "upload_file" - knowledge_config.indexing_technique = "high_quality" - knowledge_config.process_rule = Mock() - knowledge_config.process_rule.mode = "custom" + """Test that missing process rule raises appropriate error.""" + # Arrange + dataset = self._create_mock_dataset(data_source_type="upload_file", indexing_technique="high_quality") + account = self._create_mock_account() + features = self._create_mock_features(billing_enabled=False) + knowledge_config = self._create_mock_knowledge_config( + data_source_type="upload_file", indexing_technique="high_quality" + ) knowledge_config.process_rule.rules = None + + self._setup_common_mocks(mock_current_user, mock_features) + mock_features.return_value = features + + # Act & Assert with pytest.raises(ValueError, match="No process rule found"): DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) @@ -252,31 +272,24 @@ class TestDocumentServiceSaveDocumentWithDatasetId: @patch("services.dataset_service.FeatureService.get_features") @patch("services.dataset_service.current_user") def test_invalid_process_rule_mode(self, mock_current_user, mock_features, mock_db): - """ - Test invalid process rule mode returns None (no document created). - """ - dataset = Mock(spec=Dataset) - dataset.id = "ds1" - dataset.tenant_id = "tenant1" - dataset.latest_process_rule = None - dataset.data_source_type = "upload_file" - dataset.indexing_technique = "high_quality" - account = Mock(spec=Account) - account.id = "user1" - mock_current_user.current_tenant_id = "tenant1" - features = Mock() - features.billing.enabled = False - mock_features.return_value = features - knowledge_config = Mock() - knowledge_config.original_document_id = None - knowledge_config.data_source = Mock() - knowledge_config.data_source.info_list = Mock() - knowledge_config.data_source.info_list.data_source_type = "upload_file" - knowledge_config.indexing_technique = "high_quality" - knowledge_config.process_rule = Mock() + """Test that invalid process rule mode returns None without creating document.""" + # Arrange + dataset = self._create_mock_dataset(data_source_type="upload_file", indexing_technique="high_quality") + account = self._create_mock_account() + features = self._create_mock_features(billing_enabled=False) + knowledge_config = self._create_mock_knowledge_config( + data_source_type="upload_file", indexing_technique="high_quality" + ) knowledge_config.process_rule.mode = "invalid" + + self._setup_common_mocks(mock_current_user, mock_features) + mock_features.return_value = features + + # Act with patch("logging.warning") as mock_log: result = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + # Assert assert result is None mock_log.assert_called() @@ -285,28 +298,20 @@ class TestDocumentServiceSaveDocumentWithDatasetId: @patch("services.dataset_service.FeatureService.get_features") @patch("services.dataset_service.current_user") def test_notion_import_no_info(self, mock_current_user, mock_features, mock_redis, mock_db): - """ - Test notion_import with no notion_info_list raises ValueError. - """ - dataset = Mock(spec=Dataset) - dataset.id = "ds1" - dataset.tenant_id = "tenant1" - dataset.data_source_type = "notion_import" - dataset.indexing_technique = "high_quality" - account = Mock(spec=Account) - account.id = "user1" - mock_current_user.current_tenant_id = "tenant1" - features = Mock() - features.billing.enabled = False - mock_features.return_value = features - knowledge_config = Mock() - knowledge_config.original_document_id = None - knowledge_config.process_rule = Mock() - knowledge_config.process_rule.mode = "automatic" - knowledge_config.data_source = Mock() - knowledge_config.data_source.info_list = Mock() - knowledge_config.data_source.info_list.data_source_type = "notion_import" + """Test that notion_import with missing notion_info_list raises appropriate error.""" + # Arrange + dataset = self._create_mock_dataset(data_source_type="notion_import", indexing_technique="high_quality") + account = self._create_mock_account() + features = self._create_mock_features(billing_enabled=False) + knowledge_config = self._create_mock_knowledge_config( + data_source_type="notion_import", indexing_technique="high_quality" + ) knowledge_config.data_source.info_list.notion_info_list = None + + self._setup_common_mocks(mock_current_user, mock_features, mock_redis, mock_db) + mock_features.return_value = features + + # Act & Assert with pytest.raises(ValueError, match="No notion info list found"): DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) @@ -315,79 +320,66 @@ class TestDocumentServiceSaveDocumentWithDatasetId: @patch("services.dataset_service.FeatureService.get_features") @patch("services.dataset_service.current_user") def test_website_crawl_no_info(self, mock_current_user, mock_features, mock_redis, mock_db): - """ - Test website_crawl with no website_info raises ValueError. - """ - dataset = Mock(spec=Dataset) - dataset.id = "ds1" - dataset.tenant_id = "tenant1" - dataset.data_source_type = "website_crawl" - dataset.indexing_technique = "high_quality" - account = Mock(spec=Account) - account.id = "user1" - mock_current_user.current_tenant_id = "tenant1" - features = Mock() - features.billing.enabled = False - mock_features.return_value = features - knowledge_config = Mock() - knowledge_config.original_document_id = None - knowledge_config.process_rule = Mock() - knowledge_config.process_rule.mode = "automatic" - knowledge_config.data_source = Mock() - knowledge_config.data_source.info_list = Mock() - knowledge_config.data_source.info_list.data_source_type = "website_crawl" + """Test that website_crawl with missing website_info raises appropriate error.""" + # Arrange + dataset = self._create_mock_dataset(data_source_type="website_crawl", indexing_technique="high_quality") + account = self._create_mock_account() + features = self._create_mock_features(billing_enabled=False) + knowledge_config = self._create_mock_knowledge_config( + data_source_type="website_crawl", indexing_technique="high_quality" + ) knowledge_config.data_source.info_list.website_info_list = None + + self._setup_common_mocks(mock_current_user, mock_features, mock_redis, mock_db) + mock_features.return_value = features + + # Act & Assert with pytest.raises(ValueError, match="No website info list found"): DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + @patch("services.dataset_service.db.session") @patch("services.dataset_service.DocumentService.update_document_with_dataset_id") - def test_update_document_branch(self, mock_update_doc): - """ - Test the branch where original_document_id is provided (update flow). - """ - dataset = Mock(spec=Dataset) - account = Mock(spec=Account) - knowledge_config = Mock() - knowledge_config.original_document_id = "docid" - mock_update_doc.return_value = Mock(batch="batch1") + def test_update_document_branch(self, mock_update_doc, mock_db): + """Test the update document flow when original_document_id is provided.""" + # Arrange + dataset = self._create_mock_dataset() + account = self._create_mock_account() + knowledge_config = self._create_mock_knowledge_config( + data_source_type="upload_file", original_document_id="docid" + ) + mock_update_doc.return_value = Mock(batch=self.batch_id) + # Mock current_user mock_current_user = Mock() - mock_current_user.current_tenant_id = "tenant-123" - # Patch current_user to return the mock + mock_current_user.current_tenant_id = self.tenant_id + + # Act with patch("services.dataset_service.current_user", mock_current_user): docs, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + # Assert assert len(docs) == 1 - assert batch == "batch1" + assert batch == self.batch_id @patch("services.dataset_service.db.session") @patch("services.dataset_service.redis_client") @patch("services.dataset_service.FeatureService.get_features") @patch("services.dataset_service.current_user") def test_upload_file_file_not_found(self, mock_current_user, mock_features, mock_redis, mock_db): - """ - Test upload_file: should raise FileNotExistsError if file not found in db. - """ + """Test that missing upload file raises FileNotExistsError.""" + # Arrange from services.dataset_service import FileNotExistsError - dataset = Mock(spec=Dataset) - dataset.id = "ds1" - dataset.tenant_id = "tenant1" - account = Mock(spec=Account) - account.id = "user1" - mock_current_user.current_tenant_id = "tenant1" - features = Mock() - features.billing.enabled = False + dataset = self._create_mock_dataset() + account = self._create_mock_account() + features = self._create_mock_features(billing_enabled=False) + knowledge_config = self._create_mock_knowledge_config(data_source_type="upload_file", file_ids=["file1"]) + + self._setup_common_mocks(mock_current_user, mock_features, mock_redis, mock_db) mock_features.return_value = features - knowledge_config = Mock() - knowledge_config.original_document_id = None - knowledge_config.process_rule = Mock() - knowledge_config.process_rule.mode = "automatic" - knowledge_config.data_source = Mock() - knowledge_config.data_source.info_list = Mock() - knowledge_config.data_source.info_list.data_source_type = "upload_file" - knowledge_config.data_source.info_list.file_info_list = Mock() - knowledge_config.data_source.info_list.file_info_list.file_ids = ["file1"] mock_db.query.return_value.filter.return_value.first.return_value = None + + # Act & Assert with pytest.raises(FileNotExistsError): DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) @@ -418,65 +410,45 @@ class TestDocumentServiceSaveDocumentWithDatasetId: mock_db, mock_features, ): - """ - Test upload_file: duplicate=True and document already exists, should update and append to documents. - """ - dataset = Mock(spec=Dataset) - dataset.id = "ds1" - dataset.tenant_id = "tenant1" - dataset.data_source_type = None - dataset.indexing_technique = None - dataset.retrieval_model = None - dataset.embedding_model = None - dataset.embedding_model_provider = None - dataset.collection_binding_id = None - dataset.latest_process_rule = None - account = Mock(spec=Account) - account.id = "user1" - account.name = "User One" - mock_current_user.current_tenant_id = "tenant1" - features = Mock() - features.billing.enabled = True - features.billing.subscription.plan = "pro" - features.documents_upload_quota.limit = 100 - features.documents_upload_quota.size = 0 + """Test upload_file with duplicate=True when document already exists.""" + # Arrange + dataset = self._create_mock_dataset() + account = self._create_mock_account() + features = self._create_mock_features() + knowledge_config = self._create_mock_knowledge_config( + data_source_type="upload_file", file_ids=["file1"], duplicate=True + ) + + self._setup_common_mocks(mock_current_user, mock_features, mock_redis, mock_db) mock_features.return_value = features - knowledge_config = Mock() - knowledge_config.original_document_id = None - knowledge_config.data_source = Mock() - knowledge_config.data_source.info_list = Mock() - knowledge_config.data_source.info_list.data_source_type = "upload_file" - knowledge_config.data_source.info_list.file_info_list = Mock() - knowledge_config.data_source.info_list.file_info_list.file_ids = ["file1"] - knowledge_config.indexing_technique = "high_quality" - knowledge_config.embedding_model = "embed-model" - knowledge_config.embedding_model_provider = "openai" - knowledge_config.process_rule = Mock() - knowledge_config.process_rule.mode = "custom" - knowledge_config.process_rule.rules = Mock() - knowledge_config.doc_form = "pdf" - knowledge_config.doc_language = "en" - knowledge_config.duplicate = True + + # Mock ModelManager mock_model_manager_instance = Mock() mock_embedding_model = Mock() mock_embedding_model.model = "embed-model" mock_embedding_model.provider = "openai" mock_model_manager_instance.get_default_model_instance.return_value = mock_embedding_model mock_model_manager.return_value = mock_model_manager_instance + + # Mock collection binding mock_collection_binding_instance = Mock() mock_collection_binding_instance.id = "binding-123" mock_collection_binding.return_value = mock_collection_binding_instance + + # Mock upload file and existing document upload_file = Mock() upload_file.id = "file1" upload_file.name = "file1.pdf" - mock_db.query.return_value.filter.return_value.first.side_effect = [ - upload_file, - Mock(id="docid", name="file1.pdf"), - ] # file, then document - mock_redis.lock.return_value.__enter__ = Mock(return_value=None) - mock_redis.lock.return_value.__exit__ = Mock(return_value=None) + existing_doc = Mock(id="docid", name="file1.pdf") + mock_db.query.return_value.filter.return_value.first.side_effect = [upload_file, existing_doc] + + # Mock time mock_time.strftime.return_value = "20231201120000" + + # Act docs, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + # Assert assert len(docs) == 1 mock_dup_task.assert_called_once() @@ -485,32 +457,25 @@ class TestDocumentServiceSaveDocumentWithDatasetId: @patch("services.dataset_service.FeatureService.get_features") @patch("services.dataset_service.current_user") def test_notion_import_data_source_binding_not_found(self, mock_current_user, mock_features, mock_redis, mock_db): - """ - Test notion_import: should raise ValueError if data source binding not found. - """ - dataset = Mock(spec=Dataset) - dataset.id = "ds1" - dataset.tenant_id = "tenant1" - dataset.data_source_type = "notion_import" - dataset.indexing_technique = "high_quality" - account = Mock(spec=Account) - account.id = "user1" - mock_current_user.current_tenant_id = "tenant1" - features = Mock() - features.billing.enabled = False - mock_features.return_value = features - knowledge_config = Mock() - knowledge_config.original_document_id = None - knowledge_config.process_rule = Mock() - knowledge_config.process_rule.mode = "automatic" - knowledge_config.data_source = Mock() - knowledge_config.data_source.info_list = Mock() - knowledge_config.data_source.info_list.data_source_type = "notion_import" + """Test that missing data source binding for notion_import raises appropriate error.""" + # Arrange + dataset = self._create_mock_dataset(data_source_type="notion_import", indexing_technique="high_quality") + account = self._create_mock_account() + features = self._create_mock_features(billing_enabled=False) + notion_info = Mock() notion_info.workspace_id = "ws1" notion_info.pages = [] + knowledge_config = self._create_mock_knowledge_config( + data_source_type="notion_import", indexing_technique="high_quality" + ) knowledge_config.data_source.info_list.notion_info_list = [notion_info] + + self._setup_common_mocks(mock_current_user, mock_features, mock_redis, mock_db) + mock_features.return_value = features mock_db.query.return_value.filter.return_value.first.return_value = None + + # Act & Assert with pytest.raises(ValueError, match="Data source binding not found."): DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) @@ -522,37 +487,24 @@ class TestDocumentServiceSaveDocumentWithDatasetId: def test_website_crawl_url_too_long( self, mock_document_indexing_task, mock_current_user, mock_features, mock_redis, mock_db ): - """ - Test website_crawl: url longer than 255 chars should be truncated in document name. - """ - dataset = Mock(spec=Dataset) - dataset.id = "ds1" - dataset.tenant_id = "tenant1" - dataset.data_source_type = "website_crawl" - dataset.indexing_technique = "high_quality" - account = Mock(spec=Account) - account.id = "user1" - mock_current_user.current_tenant_id = "tenant1" - features = Mock() - features.billing.enabled = False + """Test that long URLs are properly truncated in website_crawl document names.""" + # Arrange + dataset = self._create_mock_dataset(data_source_type="website_crawl", indexing_technique="high_quality") + account = self._create_mock_account() + features = self._create_mock_features(billing_enabled=False) + knowledge_config = self._create_mock_knowledge_config( + data_source_type="website_crawl", website_urls=["http://" + "a" * 300] + ) + + self._setup_common_mocks(mock_current_user, mock_features, mock_redis, mock_db) mock_features.return_value = features - knowledge_config = Mock() - knowledge_config.original_document_id = None - knowledge_config.process_rule = Mock() - knowledge_config.process_rule.mode = "automatic" - knowledge_config.data_source = Mock() - knowledge_config.data_source.info_list = Mock() - knowledge_config.data_source.info_list.data_source_type = "website_crawl" - website_info = Mock() - website_info.urls = ["http://" + "a" * 300] - website_info.provider = "test" - website_info.job_id = "job1" - website_info.only_main_content = True - knowledge_config.data_source.info_list.website_info_list = website_info mock_db.query.return_value.filter.return_value.first.return_value = True - # Patch build_document to check name truncation + + # Act with patch("services.dataset_service.DocumentService.build_document") as mock_build_doc: DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + # Assert args, kwargs = mock_build_doc.call_args assert args[9].startswith("http://") assert len(args[9]) < 256 @@ -567,28 +519,12 @@ class TestDocumentServiceSaveDocumentWithDatasetId: def test_notion_import_success( self, mock_clean_task, mock_doc_task, mock_build_doc, mock_current_user, mock_features, mock_redis, mock_db ): - """ - Test notion_import: successful document creation for new pages. - """ - dataset = Mock(spec=Dataset) - dataset.id = "ds1" - dataset.tenant_id = "tenant1" - dataset.data_source_type = "notion_import" - dataset.indexing_technique = "high_quality" - account = Mock(spec=Account) - account.id = "user1" - account.name = "User One" - mock_current_user.current_tenant_id = "tenant1" - features = Mock() - features.billing.enabled = False - mock_features.return_value = features - knowledge_config = Mock() - knowledge_config.original_document_id = None - knowledge_config.process_rule = Mock() - knowledge_config.process_rule.mode = "automatic" - knowledge_config.data_source = Mock() - knowledge_config.data_source.info_list = Mock() - knowledge_config.data_source.info_list.data_source_type = "notion_import" + """Test successful notion_import document creation for new pages.""" + # Arrange + dataset = self._create_mock_dataset(data_source_type="notion_import", indexing_technique="high_quality") + account = self._create_mock_account() + features = self._create_mock_features(billing_enabled=False) + notion_info = Mock() notion_info.workspace_id = "ws1" page = Mock() @@ -597,17 +533,31 @@ class TestDocumentServiceSaveDocumentWithDatasetId: page.page_icon = None page.type = "page" notion_info.pages = [page] + + knowledge_config = self._create_mock_knowledge_config( + data_source_type="notion_import", indexing_technique="high_quality" + ) knowledge_config.data_source.info_list.notion_info_list = [notion_info] + + self._setup_common_mocks(mock_current_user, mock_features, mock_redis, mock_db) + mock_features.return_value = features + # Mock existing documents query (empty) mock_db.query.return_value.filter_by.return_value.all.return_value = [] + # Mock data source binding binding = Mock() binding.id = "binding1" mock_db.query.return_value.filter.return_value.first.return_value = binding + # Mock build_document mock_doc = Mock(spec=Document, id="doc1") mock_build_doc.return_value = mock_doc + + # Act docs, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + # Assert assert len(docs) == 1 mock_doc_task.assert_called_once() @@ -621,44 +571,42 @@ class TestDocumentServiceSaveDocumentWithDatasetId: def test_notion_import_page_exists( self, mock_doc_task, mock_clean_task, mock_build_doc, mock_current_user, mock_features, mock_redis, mock_db ): - """ - Test notion_import: page already exists, should skip creation and clean old documents. - """ - dataset = Mock(spec=Dataset) - dataset.id = "ds1" - dataset.tenant_id = "tenant1" - dataset.data_source_type = "notion_import" - dataset.indexing_technique = "high_quality" - account = Mock(spec=Account) - account.id = "user1" - mock_current_user.current_tenant_id = "tenant1" - features = Mock() - features.billing.enabled = False - mock_features.return_value = features - knowledge_config = Mock() - knowledge_config.original_document_id = None - knowledge_config.process_rule = Mock() - knowledge_config.process_rule.mode = "automatic" - knowledge_config.data_source = Mock() - knowledge_config.data_source.info_list = Mock() - knowledge_config.data_source.info_list.data_source_type = "notion_import" + """Test notion_import when page already exists - should skip creation.""" + # Arrange + dataset = self._create_mock_dataset(data_source_type="notion_import", indexing_technique="high_quality") + account = self._create_mock_account() + features = self._create_mock_features(billing_enabled=False) + notion_info = Mock() notion_info.workspace_id = "ws1" page = Mock() page.page_id = "page1" page.page_name = "Test Page" notion_info.pages = [page] + + knowledge_config = self._create_mock_knowledge_config( + data_source_type="notion_import", indexing_technique="high_quality" + ) knowledge_config.data_source.info_list.notion_info_list = [notion_info] + + self._setup_common_mocks(mock_current_user, mock_features, mock_redis, mock_db) + mock_features.return_value = features + # Mock existing document with same page_id existing_doc = Mock() existing_doc.data_source_info = '{"notion_page_id": "page1"}' existing_doc.id = "doc1" mock_db.query.return_value.filter_by.return_value.all.return_value = [existing_doc] + # Mock data source binding binding = Mock() binding.id = "binding1" mock_db.query.return_value.filter.return_value.first.return_value = binding + + # Act docs, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + # Assert assert len(docs) == 0 mock_clean_task.assert_not_called() @@ -671,39 +619,27 @@ class TestDocumentServiceSaveDocumentWithDatasetId: def test_website_crawl_success( self, mock_doc_task, mock_build_doc, mock_current_user, mock_features, mock_redis, mock_db ): - """ - Test website_crawl: successful document creation for multiple URLs. - """ - dataset = Mock(spec=Dataset) - dataset.id = "ds1" - dataset.tenant_id = "tenant1" - dataset.data_source_type = "website_crawl" - dataset.indexing_technique = "high_quality" - account = Mock(spec=Account) - account.id = "user1" - account.name = "User One" - mock_current_user.current_tenant_id = "tenant1" - features = Mock() - features.billing.enabled = False + """Test successful website_crawl document creation for multiple URLs.""" + # Arrange + dataset = self._create_mock_dataset(data_source_type="website_crawl", indexing_technique="high_quality") + account = self._create_mock_account() + features = self._create_mock_features(billing_enabled=False) + knowledge_config = self._create_mock_knowledge_config( + data_source_type="website_crawl", website_urls=["http://example1.com", "http://example2.com"] + ) + + self._setup_common_mocks(mock_current_user, mock_features, mock_redis, mock_db) mock_features.return_value = features - knowledge_config = Mock() - knowledge_config.original_document_id = None - knowledge_config.process_rule = Mock() - knowledge_config.process_rule.mode = "automatic" - knowledge_config.data_source = Mock() - knowledge_config.data_source.info_list = Mock() - knowledge_config.data_source.info_list.data_source_type = "website_crawl" - website_info = Mock() - website_info.urls = ["http://example1.com", "http://example2.com"] - website_info.provider = "test" - website_info.job_id = "job1" - website_info.only_main_content = True - knowledge_config.data_source.info_list.website_info_list = website_info + # Mock build_document mock_doc1 = Mock(spec=Document, id="doc1") mock_doc2 = Mock(spec=Document, id="doc2") mock_build_doc.side_effect = [mock_doc1, mock_doc2] + + # Act docs, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + # Assert assert len(docs) == 2 assert mock_build_doc.call_count == 2 mock_doc_task.assert_called_once() @@ -713,56 +649,40 @@ class TestDocumentServiceSaveDocumentWithDatasetId: @patch("services.dataset_service.FeatureService.get_features") @patch("services.dataset_service.current_user") def test_unknown_data_source_type(self, mock_current_user, mock_features, mock_redis, mock_db): - """ - Test unknown data_source_type: should not raise error but return None when no matching branch. - """ - dataset = Mock(spec=Dataset) - dataset.id = "ds1" - dataset.tenant_id = "tenant1" - dataset.data_source_type = "unknown_type" - dataset.indexing_technique = "high_quality" - account = Mock(spec=Account) - account.id = "user1" - mock_current_user.current_tenant_id = "tenant1" - features = Mock() - features.billing.enabled = False + """Test that unknown data source type is handled gracefully.""" + # Arrange + dataset = self._create_mock_dataset(data_source_type="unknown_type", indexing_technique="high_quality") + account = self._create_mock_account() + features = self._create_mock_features(billing_enabled=False) + knowledge_config = self._create_mock_knowledge_config( + data_source_type="unknown_type", indexing_technique="high_quality" + ) + + self._setup_common_mocks(mock_current_user, mock_features, mock_redis, mock_db) mock_features.return_value = features - knowledge_config = Mock() - knowledge_config.original_document_id = None - knowledge_config.process_rule = Mock() - knowledge_config.process_rule.mode = "automatic" - knowledge_config.data_source = Mock() - knowledge_config.data_source.info_list = Mock() - knowledge_config.data_source.info_list.data_source_type = "unknown_type" - # This should not raise an error but return None due to no matching data source type + + # Act result = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) - # The method should handle unknown data source types gracefully + + # Assert assert result is None or len(result[0]) == 0 @patch("services.dataset_service.FeatureService.get_features") @patch("services.dataset_service.current_user") def test_upload_file_batch_limit_exceeded(self, mock_current_user, mock_features): - """ - Test upload_file: batch upload limit exceeded raises ValueError. - """ - dataset = Mock(spec=Dataset) - dataset.id = "ds1" - dataset.tenant_id = "tenant1" - account = Mock(spec=Account) - account.id = "user1" - mock_current_user.current_tenant_id = "tenant1" - features = Mock() - features.billing.enabled = True - features.billing.subscription.plan = "pro" + """Test that upload_file batch limit exceeded raises appropriate error.""" + # Arrange + dataset = self._create_mock_dataset() + account = self._create_mock_account() + features = self._create_mock_features() + knowledge_config = self._create_mock_knowledge_config( + data_source_type="upload_file", file_ids=["file" + str(i) for i in range(100)] + ) + + self._setup_common_mocks(mock_current_user, mock_features) mock_features.return_value = features - knowledge_config = Mock() - knowledge_config.original_document_id = None - knowledge_config.data_source = Mock() - knowledge_config.data_source.info_list = Mock() - knowledge_config.data_source.info_list.data_source_type = "upload_file" - knowledge_config.data_source.info_list.file_info_list = Mock() - # Create a list with more than BATCH_UPLOAD_LIMIT files - knowledge_config.data_source.info_list.file_info_list.file_ids = ["file" + str(i) for i in range(100)] + + # Act & Assert with patch("services.dataset_service.dify_config.BATCH_UPLOAD_LIMIT", 50): with pytest.raises(ValueError, match="You have reached the batch upload limit"): DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) @@ -770,27 +690,22 @@ class TestDocumentServiceSaveDocumentWithDatasetId: @patch("services.dataset_service.FeatureService.get_features") @patch("services.dataset_service.current_user") def test_notion_import_batch_limit_exceeded(self, mock_current_user, mock_features): - """ - Test notion_import: batch upload limit exceeded raises ValueError. - """ - dataset = Mock(spec=Dataset) - dataset.id = "ds1" - dataset.tenant_id = "tenant1" - account = Mock(spec=Account) - account.id = "user1" - mock_current_user.current_tenant_id = "tenant1" - features = Mock() - features.billing.enabled = True - features.billing.subscription.plan = "pro" - mock_features.return_value = features - knowledge_config = Mock() - knowledge_config.original_document_id = None - knowledge_config.data_source = Mock() - knowledge_config.data_source.info_list = Mock() - knowledge_config.data_source.info_list.data_source_type = "notion_import" + """Test that notion_import batch limit exceeded raises appropriate error.""" + # Arrange + dataset = self._create_mock_dataset() + account = self._create_mock_account() + features = self._create_mock_features() + notion_info = Mock() notion_info.pages = [Mock() for _ in range(100)] # 100 pages + + knowledge_config = self._create_mock_knowledge_config(data_source_type="notion_import") knowledge_config.data_source.info_list.notion_info_list = [notion_info] + + self._setup_common_mocks(mock_current_user, mock_features) + mock_features.return_value = features + + # Act & Assert with patch("services.dataset_service.dify_config.BATCH_UPLOAD_LIMIT", 50): with pytest.raises(ValueError, match="You have reached the batch upload limit"): DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) @@ -798,27 +713,19 @@ class TestDocumentServiceSaveDocumentWithDatasetId: @patch("services.dataset_service.FeatureService.get_features") @patch("services.dataset_service.current_user") def test_website_crawl_batch_limit_exceeded(self, mock_current_user, mock_features): - """ - Test website_crawl: batch upload limit exceeded raises ValueError. - """ - dataset = Mock(spec=Dataset) - dataset.id = "ds1" - dataset.tenant_id = "tenant1" - account = Mock(spec=Account) - account.id = "user1" - mock_current_user.current_tenant_id = "tenant1" - features = Mock() - features.billing.enabled = True - features.billing.subscription.plan = "pro" + """Test that website_crawl batch limit exceeded raises appropriate error.""" + # Arrange + dataset = self._create_mock_dataset() + account = self._create_mock_account() + features = self._create_mock_features() + knowledge_config = self._create_mock_knowledge_config( + data_source_type="website_crawl", website_urls=["http://example" + str(i) + ".com" for i in range(100)] + ) + + self._setup_common_mocks(mock_current_user, mock_features) mock_features.return_value = features - knowledge_config = Mock() - knowledge_config.original_document_id = None - knowledge_config.data_source = Mock() - knowledge_config.data_source.info_list = Mock() - knowledge_config.data_source.info_list.data_source_type = "website_crawl" - website_info = Mock() - website_info.urls = ["http://example" + str(i) + ".com" for i in range(100)] # 100 URLs - knowledge_config.data_source.info_list.website_info_list = website_info + + # Act & Assert with patch("services.dataset_service.dify_config.BATCH_UPLOAD_LIMIT", 50): with pytest.raises(ValueError, match="You have reached the batch upload limit"): DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) From c3c776d41cb01175a0dc7b358186676cda53a5d6 Mon Sep 17 00:00:00 2001 From: neatguycoding <15627489+NeatGuyCoding@users.noreply.github.com> Date: Wed, 25 Jun 2025 12:08:45 +0800 Subject: [PATCH 4/7] feat: refactor: add db add assert to unit test for original save_document_with_dataset_id --- ...t_service_save_document_with_dataset_id.py | 32 +++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/api/tests/unit_tests/services/test_document_service_save_document_with_dataset_id.py b/api/tests/unit_tests/services/test_document_service_save_document_with_dataset_id.py index c59b423e74..1832e8699f 100644 --- a/api/tests/unit_tests/services/test_document_service_save_document_with_dataset_id.py +++ b/api/tests/unit_tests/services/test_document_service_save_document_with_dataset_id.py @@ -3,6 +3,7 @@ from unittest.mock import Mock, patch import pytest +from models import DatasetProcessRule from models.account import Account from models.dataset import Dataset, Document from services.dataset_service import DocumentService @@ -19,6 +20,7 @@ class TestDocumentServiceSaveDocumentWithDatasetId(unittest.TestCase): - Duplicate document handling - Process rule validation and error cases - Exception handling and edge cases + - Database session operations (add and flush) """ def setUp(self): @@ -191,6 +193,10 @@ class TestDocumentServiceSaveDocumentWithDatasetId(unittest.TestCase): mock_doc_task.assert_called_once() mock_dup_task.assert_not_called() + # Verify the documents were added to session + mock_db.add.assert_any_call(mock_doc1) + mock_db.add.assert_any_call(mock_doc2) + @patch("services.dataset_service.FeatureService.get_features") @patch("services.dataset_service.current_user") def test_billing_batch_limit_exceeded(self, mock_current_user, mock_features): @@ -440,7 +446,8 @@ class TestDocumentServiceSaveDocumentWithDatasetId(unittest.TestCase): upload_file.id = "file1" upload_file.name = "file1.pdf" existing_doc = Mock(id="docid", name="file1.pdf") - mock_db.query.return_value.filter.return_value.first.side_effect = [upload_file, existing_doc] + mock_db.query.return_value.filter.return_value.first.return_value = upload_file + mock_db.query.return_value.filter_by.return_value.first.return_value = existing_doc # Mock time mock_time.strftime.return_value = "20231201120000" @@ -452,6 +459,17 @@ class TestDocumentServiceSaveDocumentWithDatasetId(unittest.TestCase): assert len(docs) == 1 mock_dup_task.assert_called_once() + # Verify the existing document was added to session with updated properties + mock_db.add.assert_any_call(existing_doc) + + # Verify the document properties were updated before being added + # These assertions verify that the duplicate document was properly updated + assert existing_doc.batch == "20231201120000223456" + assert existing_doc.indexing_status == "waiting" + assert existing_doc.created_from == "web" + assert existing_doc.doc_form == "pdf" + assert existing_doc.doc_language == "en" + @patch("services.dataset_service.db.session") @patch("services.dataset_service.redis_client") @patch("services.dataset_service.FeatureService.get_features") @@ -561,6 +579,9 @@ class TestDocumentServiceSaveDocumentWithDatasetId(unittest.TestCase): assert len(docs) == 1 mock_doc_task.assert_called_once() + # Verify the document was added to the database session + mock_db.add.assert_called_with(mock_doc) + @patch("services.dataset_service.db.session") @patch("services.dataset_service.redis_client") @patch("services.dataset_service.FeatureService.get_features") @@ -608,7 +629,10 @@ class TestDocumentServiceSaveDocumentWithDatasetId(unittest.TestCase): # Assert assert len(docs) == 0 - mock_clean_task.assert_not_called() + # No document should be created since it already exists + for call in mock_db.add.call_args_list: + args, kwargs = call + assert not any(isinstance(arg, Document) for arg in args), "Method was called with a Document!" @patch("services.dataset_service.db.session") @patch("services.dataset_service.redis_client") @@ -644,6 +668,10 @@ class TestDocumentServiceSaveDocumentWithDatasetId(unittest.TestCase): assert mock_build_doc.call_count == 2 mock_doc_task.assert_called_once() + # Verify database session operations + mock_db.add.assert_any_call(mock_doc1) + mock_db.add.assert_any_call(mock_doc2) + @patch("services.dataset_service.db.session") @patch("services.dataset_service.redis_client") @patch("services.dataset_service.FeatureService.get_features") From 07e9ee68d1daf9883e460e5462d865902d50887d Mon Sep 17 00:00:00 2001 From: neatguycoding <15627489+NeatGuyCoding@users.noreply.github.com> Date: Wed, 25 Jun 2025 12:10:06 +0800 Subject: [PATCH 5/7] feat: refactor: add db add assert to unit test for original save_document_with_dataset_id --- .../test_document_service_save_document_with_dataset_id.py | 1 - 1 file changed, 1 deletion(-) diff --git a/api/tests/unit_tests/services/test_document_service_save_document_with_dataset_id.py b/api/tests/unit_tests/services/test_document_service_save_document_with_dataset_id.py index 1832e8699f..77112dfaa7 100644 --- a/api/tests/unit_tests/services/test_document_service_save_document_with_dataset_id.py +++ b/api/tests/unit_tests/services/test_document_service_save_document_with_dataset_id.py @@ -3,7 +3,6 @@ from unittest.mock import Mock, patch import pytest -from models import DatasetProcessRule from models.account import Account from models.dataset import Dataset, Document from services.dataset_service import DocumentService From 2036d951fc57a5b71b39576b197b8dc1955e6bde Mon Sep 17 00:00:00 2001 From: neatguycoding <15627489+NeatGuyCoding@users.noreply.github.com> Date: Wed, 25 Jun 2025 16:51:33 +0800 Subject: [PATCH 6/7] feat: refactor: save_document_with_dataset_id for readability --- api/services/dataset_service.py | 912 +++++++++++++++++++++++--------- 1 file changed, 648 insertions(+), 264 deletions(-) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index b36d4bdabc..5359c3711f 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -1053,290 +1053,674 @@ class DocumentService: dataset_process_rule: Optional[DatasetProcessRule] = None, created_from: str = "web", ): - # check document limit + """ + Save documents to a dataset with comprehensive validation and processing. + + This method handles document creation and updates for various data sources including + file uploads, Notion imports, and website crawling. It performs billing checks, + dataset configuration, and triggers indexing tasks. + + Args: + dataset: The target dataset for document storage + knowledge_config: Configuration containing document metadata and processing rules + account: User account performing the operation + dataset_process_rule: Optional pre-existing process rule for the dataset + created_from: Source identifier for document creation (default: "web") + + Returns: + tuple: (list of created/updated documents, batch identifier) + + Raises: + ValueError: When billing limits are exceeded or configuration is invalid + FileNotExistsError: When referenced files are not found + """ + # Validate billing and upload limits for new documents features = FeatureService.get_features(current_user.current_tenant_id) - if features.billing.enabled: - if not knowledge_config.original_document_id: - count = 0 - if knowledge_config.data_source: - if knowledge_config.data_source.info_list.data_source_type == "upload_file": - upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore - count = len(upload_file_list) - elif knowledge_config.data_source.info_list.data_source_type == "notion_import": - notion_info_list = knowledge_config.data_source.info_list.notion_info_list - for notion_info in notion_info_list: # type: ignore - count = count + len(notion_info.pages) - elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": - website_info = knowledge_config.data_source.info_list.website_info_list - count = len(website_info.urls) # type: ignore - batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) - - if features.billing.subscription.plan == "sandbox" and count > 1: - raise ValueError("Your current plan does not support batch upload, please upgrade your plan.") - if count > batch_upload_limit: - raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") - - DocumentService.check_documents_upload_quota(count, features) - - # if dataset is empty, update dataset data_source_type + if features.billing.enabled and not knowledge_config.original_document_id: + document_count = DocumentService._calculate_document_count(knowledge_config) + if document_count > 0: + DocumentService._validate_upload_limits(document_count, features) + + # Initialize dataset configuration if not already set + DocumentService._initialize_dataset_configuration(dataset, knowledge_config) + + documents = [] + + # Handle document update scenario + if knowledge_config.original_document_id: + document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account) + documents.append(document) + batch = document.batch + else: + # Handle new document creation scenario + batch = DocumentService._generate_batch_identifier() + + # Create or retrieve process rule for document processing + dataset_process_rule = DocumentService._prepare_process_rule( + dataset, knowledge_config, account, dataset_process_rule + ) + if not dataset_process_rule: + # keep the original process rule if no valid rule found, but it seems not completed + return + + # Process documents with distributed lock to prevent race conditions + documents, batch = DocumentService._process_new_documents( + dataset, knowledge_config, account, dataset_process_rule, created_from, batch + ) + + return documents, batch + + @staticmethod + def _calculate_document_count(knowledge_config: KnowledgeConfig) -> int: + """ + Calculate the total number of documents to be processed based on data source type. + + Args: + knowledge_config: Configuration containing data source information + + Returns: + int: Total count of documents to be processed + """ + count = 0 + if knowledge_config.data_source: + data_source_info = knowledge_config.data_source.info_list + + if data_source_info.data_source_type == "upload_file": + upload_file_list = data_source_info.file_info_list.file_ids # type: ignore + count = len(upload_file_list) + elif data_source_info.data_source_type == "notion_import": + notion_info_list = data_source_info.notion_info_list + for notion_info in notion_info_list: # type: ignore + count += len(notion_info.pages) + elif data_source_info.data_source_type == "website_crawl": + website_info = data_source_info.website_info_list + count = len(website_info.urls) # type: ignore + + return count + + @staticmethod + def _validate_upload_limits(document_count: int, features: FeatureModel) -> None: + """ + Validate that the document upload operation complies with billing and subscription limits. + + Args: + document_count: Number of documents to be uploaded + features: Feature model containing billing and subscription information + + Raises: + ValueError: When upload limits are exceeded + """ + batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) + + # Check sandbox plan restrictions + if features.billing.subscription.plan == "sandbox" and document_count > 1: + raise ValueError("Your current plan does not support batch upload, please upgrade your plan.") + + # Check batch upload limit + if document_count > batch_upload_limit: + raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") + + # Check document quota + DocumentService.check_documents_upload_quota(document_count, features) + + @staticmethod + def _initialize_dataset_configuration(dataset: Dataset, knowledge_config: KnowledgeConfig) -> None: + """ + Initialize dataset configuration settings if not already configured. + + Args: + dataset: Dataset to be configured + knowledge_config: Configuration containing dataset settings + + Raises: + ValueError: When indexing technique is invalid + """ + # Set data source type if not already configured if not dataset.data_source_type: dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore + # Configure indexing technique and related settings if not dataset.indexing_technique: if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: raise ValueError("Indexing technique is invalid") dataset.indexing_technique = knowledge_config.indexing_technique + + # Configure high-quality indexing settings if knowledge_config.indexing_technique == "high_quality": - model_manager = ModelManager() - if knowledge_config.embedding_model and knowledge_config.embedding_model_provider: - dataset_embedding_model = knowledge_config.embedding_model - dataset_embedding_model_provider = knowledge_config.embedding_model_provider - else: - embedding_model = model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING - ) - dataset_embedding_model = embedding_model.model - dataset_embedding_model_provider = embedding_model.provider - dataset.embedding_model = dataset_embedding_model - dataset.embedding_model_provider = dataset_embedding_model_provider - dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - dataset_embedding_model_provider, dataset_embedding_model + DocumentService._configure_high_quality_indexing(dataset, knowledge_config) + + @staticmethod + def _configure_high_quality_indexing(dataset: Dataset, knowledge_config: KnowledgeConfig) -> None: + """ + Configure embedding model and retrieval settings for high-quality indexing. + + Args: + dataset: Dataset to be configured + knowledge_config: Configuration containing model settings + """ + model_manager = ModelManager() + + # Set embedding model configuration + if knowledge_config.embedding_model and knowledge_config.embedding_model_provider: + dataset_embedding_model = knowledge_config.embedding_model + dataset_embedding_model_provider = knowledge_config.embedding_model_provider + else: + embedding_model = model_manager.get_default_model_instance( + tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING + ) + dataset_embedding_model = embedding_model.model + dataset_embedding_model_provider = embedding_model.provider + + dataset.embedding_model = dataset_embedding_model + dataset.embedding_model_provider = dataset_embedding_model_provider + + # Configure collection binding + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + dataset_embedding_model_provider, dataset_embedding_model + ) + dataset.collection_binding_id = dataset_collection_binding.id + + # Configure retrieval model if not set + if not dataset.retrieval_model: + default_retrieval_model = { + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, + } + + dataset.retrieval_model = ( + knowledge_config.retrieval_model.model_dump() + if knowledge_config.retrieval_model + else default_retrieval_model + ) # type: ignore + + @staticmethod + def _generate_batch_identifier() -> str: + """ + Generate a unique batch identifier for document grouping. + + Returns: + str: Unique batch identifier combining timestamp and random number + """ + return time.strftime("%Y%m%d%H%M%S") + str(100000 + secrets.randbelow(exclusive_upper_bound=900000)) + + @staticmethod + def _prepare_process_rule( + dataset: Dataset, + knowledge_config: KnowledgeConfig, + account: Account, + dataset_process_rule: Optional[DatasetProcessRule], + ) -> Optional[DatasetProcessRule]: + """ + Prepare or create dataset process rule for document processing. + + Args: + dataset: Target dataset + knowledge_config: Configuration containing process rules + account: User account + dataset_process_rule: Optional existing process rule + + Returns: + DatasetProcessRule: Prepared process rule for document processing + + Raises: + ValueError: When no valid process rule can be found + """ + if dataset_process_rule: + return dataset_process_rule + + process_rule = knowledge_config.process_rule + if not process_rule: + return None + + if process_rule.mode in ("custom", "hierarchical"): + if process_rule.rules: + dataset_process_rule = DatasetProcessRule( + dataset_id=dataset.id, + mode=process_rule.mode, + rules=process_rule.rules.model_dump_json() if process_rule.rules else None, + created_by=account.id, ) - dataset.collection_binding_id = dataset_collection_binding.id - if not dataset.retrieval_model: - default_retrieval_model = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, - "reranking_enable": False, - "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, - "top_k": 2, - "score_threshold_enabled": False, - } + else: + dataset_process_rule = dataset.latest_process_rule + if not dataset_process_rule: + raise ValueError("No process rule found.") + elif process_rule.mode == "automatic": + dataset_process_rule = DatasetProcessRule( + dataset_id=dataset.id, + mode=process_rule.mode, + rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), + created_by=account.id, + ) + else: + logging.warning(f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule") + return None - dataset.retrieval_model = ( - knowledge_config.retrieval_model.model_dump() - if knowledge_config.retrieval_model - else default_retrieval_model - ) # type: ignore + db.session.add(dataset_process_rule) + db.session.commit() + return dataset_process_rule - documents = [] - if knowledge_config.original_document_id: - document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account) + @staticmethod + def _process_new_documents( + dataset: Dataset, + knowledge_config: KnowledgeConfig, + account: Account, + dataset_process_rule: Optional[DatasetProcessRule], + created_from: str, + batch: str, + ) -> tuple[list, str]: + """ + Process new documents with distributed locking to prevent race conditions. + + Args: + dataset: Target dataset + knowledge_config: Document configuration + account: User account + dataset_process_rule: Process rule for document processing + created_from: Source identifier + batch: Batch identifier + + Returns: + tuple: (list of created documents, batch identifier) + """ + lock_name = f"add_document_lock_dataset_id_{dataset.id}" + + with redis_client.lock(lock_name, timeout=600): + position = DocumentService.get_documents_position(dataset.id) + document_ids: list[str] = [] + duplicate_document_ids: list[str] = [] + documents: list[Document] = [] + + data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore + + # Process documents based on data source type + if data_source_type == "upload_file": + documents, document_ids, duplicate_document_ids = DocumentService._process_upload_file_documents( + dataset, knowledge_config, account, dataset_process_rule, created_from, batch, position + ) + elif data_source_type == "notion_import": + documents, document_ids = DocumentService._process_notion_documents( + dataset, knowledge_config, account, dataset_process_rule, created_from, batch, position + ) + elif data_source_type == "website_crawl": + documents, document_ids = DocumentService._process_website_documents( + dataset, knowledge_config, account, dataset_process_rule, created_from, batch, position + ) + + db.session.commit() + + # Trigger asynchronous indexing tasks + if document_ids: + document_indexing_task.delay(dataset.id, document_ids) + if duplicate_document_ids: + duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids) + + return documents, batch + + @staticmethod + def _process_upload_file_documents( + dataset: Dataset, + knowledge_config: KnowledgeConfig, + account: Account, + dataset_process_rule: Optional[DatasetProcessRule], + created_from: str, + batch: str, + position: int, + ) -> tuple[list[Document], list[str], list[str]]: + """ + Process uploaded file documents with duplicate checking and validation. + + Args: + dataset: Target dataset + knowledge_config: Document configuration + account: User account + dataset_process_rule: Process rule + created_from: Source identifier + batch: Batch identifier + position: Starting position for document ordering + + Returns: + tuple: (list of documents, list of document IDs, list of duplicate document IDs) + + Raises: + FileNotExistsError: When referenced files are not found + """ + documents: list[Document] = [] + document_ids: list[str] = [] + duplicate_document_ids: list[str] = [] + current_position = position + + upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore + + for file_id in upload_file_list: + # Validate file existence + file = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) + .first() + ) + + if not file: + raise FileNotExistsError() + + file_name = file.name + data_source_info = {"upload_file_id": file_id} + + # Handle duplicate document processing + if knowledge_config.duplicate: + document = DocumentService._find_duplicate_document(dataset, file_name) + if document: + document = DocumentService._update_duplicate_document( + document, dataset_process_rule, knowledge_config, created_from, data_source_info, batch + ) + documents.append(document) + duplicate_document_ids.append(document.id) + continue + + # Create new document + document = DocumentService.build_document( + dataset, + dataset_process_rule.id, # type: ignore + knowledge_config.data_source.info_list.data_source_type, # type: ignore + knowledge_config.doc_form, + knowledge_config.doc_language, + data_source_info, + created_from, + current_position, + account, + file_name, + batch, + ) + db.session.add(document) + db.session.flush() + document_ids.append(document.id) documents.append(document) - batch = document.batch - else: - batch = time.strftime("%Y%m%d%H%M%S") + str(100000 + secrets.randbelow(exclusive_upper_bound=900000)) - # save process rule - if not dataset_process_rule: - process_rule = knowledge_config.process_rule - if process_rule: - if process_rule.mode in ("custom", "hierarchical"): - if process_rule.rules: - dataset_process_rule = DatasetProcessRule( - dataset_id=dataset.id, - mode=process_rule.mode, - rules=process_rule.rules.model_dump_json() if process_rule.rules else None, - created_by=account.id, - ) - else: - dataset_process_rule = dataset.latest_process_rule - if not dataset_process_rule: - raise ValueError("No process rule found.") - elif process_rule.mode == "automatic": - dataset_process_rule = DatasetProcessRule( - dataset_id=dataset.id, - mode=process_rule.mode, - rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), - created_by=account.id, - ) - else: - logging.warning( - f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule" - ) - return - db.session.add(dataset_process_rule) - db.session.commit() - lock_name = "add_document_lock_dataset_id_{}".format(dataset.id) - with redis_client.lock(lock_name, timeout=600): - position = DocumentService.get_documents_position(dataset.id) - document_ids = [] - duplicate_document_ids = [] - if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore - upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore - for file_id in upload_file_list: - file = ( - db.session.query(UploadFile) - .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) - .first() - ) + current_position += 1 - # raise error if file not found - if not file: - raise FileNotExistsError() + return documents, document_ids, duplicate_document_ids - file_name = file.name - data_source_info = { - "upload_file_id": file_id, - } - # check duplicate - if knowledge_config.duplicate: - document = ( - db.session.query(Document) - .filter_by( - dataset_id=dataset.id, - tenant_id=current_user.current_tenant_id, - data_source_type="upload_file", - enabled=True, - name=file_name, - ) - .first() - ) - if document: - document.dataset_process_rule_id = dataset_process_rule.id # type: ignore - document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - document.created_from = created_from - document.doc_form = knowledge_config.doc_form - document.doc_language = knowledge_config.doc_language - document.data_source_info = json.dumps(data_source_info) - document.batch = batch - document.indexing_status = "waiting" - db.session.add(document) - documents.append(document) - duplicate_document_ids.append(document.id) - continue - document = DocumentService.build_document( - dataset, - dataset_process_rule.id, # type: ignore - knowledge_config.data_source.info_list.data_source_type, # type: ignore - knowledge_config.doc_form, - knowledge_config.doc_language, - data_source_info, - created_from, - position, - account, - file_name, - batch, - ) - db.session.add(document) - db.session.flush() - document_ids.append(document.id) - documents.append(document) - position += 1 - elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore - notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore - if not notion_info_list: - raise ValueError("No notion info list found.") - exist_page_ids = [] - exist_document = {} - documents_from_db = ( - db.session.query(Document) - .filter_by( - dataset_id=dataset.id, - tenant_id=current_user.current_tenant_id, - data_source_type="notion_import", - enabled=True, - ) - .all() + @staticmethod + def _find_duplicate_document(dataset: Dataset, file_name: str) -> Optional[Document]: + """ + Find existing document with the same name in the dataset. + + Args: + dataset: Target dataset + file_name: Name of the file to check for duplicates + + Returns: + Document: Existing document if found, None otherwise + """ + return ( + db.session.query(Document) + .filter_by( + dataset_id=dataset.id, + tenant_id=current_user.current_tenant_id, + data_source_type="upload_file", + enabled=True, + name=file_name, + ) + .first() + ) + + @staticmethod + def _update_duplicate_document( + document: Document, + dataset_process_rule: Optional[DatasetProcessRule], + knowledge_config: KnowledgeConfig, + created_from: str, + data_source_info: dict, + batch: str, + ) -> Document: + """ + Update existing document with new configuration and mark for re-indexing. + + Args: + document: Document to be updated + dataset_process_rule: Process rule to apply + knowledge_config: New configuration + created_from: Source identifier + data_source_info: Data source information + batch: Batch identifier + + Returns: + Document: Updated document + """ + document.dataset_process_rule_id = dataset_process_rule.id # type: ignore + document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.created_from = created_from + document.doc_form = knowledge_config.doc_form + document.doc_language = knowledge_config.doc_language + document.data_source_info = json.dumps(data_source_info) + document.batch = batch + document.indexing_status = "waiting" + db.session.add(document) + return document + + @staticmethod + def _process_notion_documents( + dataset: Dataset, + knowledge_config: KnowledgeConfig, + account: Account, + dataset_process_rule: Optional[DatasetProcessRule], + created_from: str, + batch: str, + position: int, + ) -> tuple[list[Document], list[str]]: + """ + Process Notion import documents with workspace validation and page deduplication. + + Args: + dataset: Target dataset + knowledge_config: Document configuration + account: User account + dataset_process_rule: Process rule + created_from: Source identifier + batch: Batch identifier + position: Starting position for document ordering + + Returns: + tuple: (list of documents, list of document IDs) + + Raises: + ValueError: When no notion info list is found or data source binding is missing + """ + documents: list[Document] = [] + document_ids: list[str] = [] + current_position = position + + notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore + if not notion_info_list: + raise ValueError("No notion info list found.") + + # Get existing Notion documents for deduplication + exist_page_ids, exist_document = DocumentService._get_existing_notion_documents(dataset) + + for notion_info in notion_info_list: + workspace_id = notion_info.workspace_id + + # Validate data source binding + data_source_binding = DocumentService._validate_notion_binding(workspace_id) + + for page in notion_info.pages: + if page.page_id not in exist_page_ids: + # Create new Notion document + data_source_info = { + "notion_workspace_id": workspace_id, + "notion_page_id": page.page_id, + "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, + "type": page.type, + } + + # Truncate page name to prevent DB field length errors + truncated_page_name = page.page_name[:255] if page.page_name else "nopagename" + + document = DocumentService.build_document( + dataset, + dataset_process_rule.id, # type: ignore + knowledge_config.data_source.info_list.data_source_type, # type: ignore + knowledge_config.doc_form, + knowledge_config.doc_language, + data_source_info, + created_from, + current_position, + account, + truncated_page_name, + batch, ) - if documents_from_db: - for document in documents_from_db: - data_source_info = json.loads(document.data_source_info) - exist_page_ids.append(data_source_info["notion_page_id"]) - exist_document[data_source_info["notion_page_id"]] = document.id - for notion_info in notion_info_list: - workspace_id = notion_info.workspace_id - data_source_binding = ( - db.session.query(DataSourceOauthBinding) - .filter( - db.and_( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', - ) - ) - .first() - ) - if not data_source_binding: - raise ValueError("Data source binding not found.") - for page in notion_info.pages: - if page.page_id not in exist_page_ids: - data_source_info = { - "notion_workspace_id": workspace_id, - "notion_page_id": page.page_id, - "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, - "type": page.type, - } - # Truncate page name to 255 characters to prevent DB field length errors - truncated_page_name = page.page_name[:255] if page.page_name else "nopagename" - document = DocumentService.build_document( - dataset, - dataset_process_rule.id, # type: ignore - knowledge_config.data_source.info_list.data_source_type, # type: ignore - knowledge_config.doc_form, - knowledge_config.doc_language, - data_source_info, - created_from, - position, - account, - truncated_page_name, - batch, - ) - db.session.add(document) - db.session.flush() - document_ids.append(document.id) - documents.append(document) - position += 1 - else: - exist_document.pop(page.page_id) - # delete not selected documents - if len(exist_document) > 0: - clean_notion_document_task.delay(list(exist_document.values()), dataset.id) - elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore - website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore - if not website_info: - raise ValueError("No website info list found.") - urls = website_info.urls - for url in urls: - data_source_info = { - "url": url, - "provider": website_info.provider, - "job_id": website_info.job_id, - "only_main_content": website_info.only_main_content, - "mode": "crawl", - } - if len(url) > 255: - document_name = url[:200] + "..." - else: - document_name = url - document = DocumentService.build_document( - dataset, - dataset_process_rule.id, # type: ignore - knowledge_config.data_source.info_list.data_source_type, # type: ignore - knowledge_config.doc_form, - knowledge_config.doc_language, - data_source_info, - created_from, - position, - account, - document_name, - batch, - ) - db.session.add(document) - db.session.flush() - document_ids.append(document.id) - documents.append(document) - position += 1 - db.session.commit() + db.session.add(document) + db.session.flush() + document_ids.append(document.id) + documents.append(document) + current_position += 1 + else: + exist_document.pop(page.page_id) - # trigger async task - if document_ids: - document_indexing_task.delay(dataset.id, document_ids) - if duplicate_document_ids: - duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids) + # Clean up unselected documents + if len(exist_document) > 0: + clean_notion_document_task.delay(list(exist_document.values()), dataset.id) - return documents, batch + return documents, document_ids + + @staticmethod + def _get_existing_notion_documents(dataset: Dataset) -> tuple[list[str], dict[str, str]]: + """ + Retrieve existing Notion documents for deduplication purposes. + + Args: + dataset: Target dataset + + Returns: + tuple: (list of existing page IDs, dict mapping page IDs to document IDs) + """ + exist_page_ids: list[str] = [] + exist_document: dict[str, str] = {} + + documents = ( + db.session.query(Document) + .filter_by( + dataset_id=dataset.id, + tenant_id=current_user.current_tenant_id, + data_source_type="notion_import", + enabled=True, + ) + .all() + ) + + if documents: + for document in documents: + data_source_info = json.loads(document.data_source_info) + exist_page_ids.append(data_source_info["notion_page_id"]) + exist_document[data_source_info["notion_page_id"]] = document.id + + return exist_page_ids, exist_document + + @staticmethod + def _validate_notion_binding(workspace_id: str) -> DataSourceOauthBinding: + """ + Validate Notion data source binding for the specified workspace. + + Args: + workspace_id: Notion workspace identifier + + Returns: + DataSourceOauthBinding: Valid binding for the workspace + + Raises: + ValueError: When data source binding is not found + """ + data_source_binding = ( + db.session.query(DataSourceOauthBinding) + .filter( + db.and_( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.disabled == False, + DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', + ) + ) + .first() + ) + + if not data_source_binding: + raise ValueError("Data source binding not found.") + + return data_source_binding + + @staticmethod + def _process_website_documents( + dataset: Dataset, + knowledge_config: KnowledgeConfig, + account: Account, + dataset_process_rule: Optional[DatasetProcessRule], + created_from: str, + batch: str, + position: int, + ) -> tuple[list[Document], list[str]]: + """ + Process website crawl documents with URL validation and naming. + + Args: + dataset: Target dataset + knowledge_config: Document configuration + account: User account + dataset_process_rule: Process rule + created_from: Source identifier + batch: Batch identifier + position: Starting position for document ordering + + Returns: + tuple: (list of documents, list of document IDs) + + Raises: + ValueError: When no website info list is found + """ + documents: list[Document] = [] + document_ids: list[str] = [] + current_position = position + + website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore + if not website_info: + raise ValueError("No website info list found.") + + urls = website_info.urls + + for url in urls: + data_source_info = { + "url": url, + "provider": website_info.provider, + "job_id": website_info.job_id, + "only_main_content": website_info.only_main_content, + "mode": "crawl", + } + + # Truncate URL for document naming if too long + document_name = url[:200] + "..." if len(url) > 255 else url + + document = DocumentService.build_document( + dataset, + dataset_process_rule.id, # type: ignore + knowledge_config.data_source.info_list.data_source_type, # type: ignore + knowledge_config.doc_form, + knowledge_config.doc_language, + data_source_info, + created_from, + current_position, + account, + document_name, + batch, + ) + db.session.add(document) + db.session.flush() + document_ids.append(document.id) + documents.append(document) + current_position += 1 + + return documents, document_ids @staticmethod def check_documents_upload_quota(count: int, features: FeatureModel): From 689cdd30f94fdcfb329a29f92ce72d7d1051cd62 Mon Sep 17 00:00:00 2001 From: neatguycoding <15627489+NeatGuyCoding@users.noreply.github.com> Date: Wed, 25 Jun 2025 19:37:57 +0800 Subject: [PATCH 7/7] feat: refactor: refine unit tests for save_document_with_dataset_id --- ...t_service_save_document_with_dataset_id.py | 1016 +++++++++-------- 1 file changed, 510 insertions(+), 506 deletions(-) diff --git a/api/tests/unit_tests/services/test_document_service_save_document_with_dataset_id.py b/api/tests/unit_tests/services/test_document_service_save_document_with_dataset_id.py index 77112dfaa7..8a71add0f2 100644 --- a/api/tests/unit_tests/services/test_document_service_save_document_with_dataset_id.py +++ b/api/tests/unit_tests/services/test_document_service_save_document_with_dataset_id.py @@ -1,4 +1,4 @@ -import unittest +from typing import Optional from unittest.mock import Mock, patch import pytest @@ -9,31 +9,21 @@ from services.dataset_service import DocumentService from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig -class TestDocumentServiceSaveDocumentWithDatasetId(unittest.TestCase): - """ - Comprehensive unit tests for DocumentService.save_document_with_dataset_id. - - This test suite covers all major code branches including: - - Billing and quota validation - - Different data source types (upload_file, notion_import, website_crawl) - - Duplicate document handling - - Process rule validation and error cases - - Exception handling and edge cases - - Database session operations (add and flush) - """ +class DocumentSaveTestDataFactory: + """Factory class for creating test data and mock objects for document save tests.""" - def setUp(self): - """Set up common test fixtures and mock objects.""" - self.dataset_id = "ds1" - self.tenant_id = "tenant1" - self.user_id = "user1" - self.batch_id = "batch1" - - def _create_mock_dataset(self, data_source_type=None, indexing_technique=None): - """Create a mock Dataset object with common attributes.""" + @staticmethod + def create_dataset_mock( + dataset_id: str = "ds1", + tenant_id: str = "tenant1", + data_source_type: Optional[str] = None, + indexing_technique: Optional[str] = None, + **kwargs, + ) -> Mock: + """Create a mock Dataset object with specified attributes.""" dataset = Mock(spec=Dataset) - dataset.id = self.dataset_id - dataset.tenant_id = self.tenant_id + dataset.id = dataset_id + dataset.tenant_id = tenant_id dataset.data_source_type = data_source_type dataset.indexing_technique = indexing_technique dataset.retrieval_model = None @@ -41,16 +31,22 @@ class TestDocumentServiceSaveDocumentWithDatasetId(unittest.TestCase): dataset.embedding_model_provider = None dataset.collection_binding_id = None dataset.latest_process_rule = None + for key, value in kwargs.items(): + setattr(dataset, key, value) return dataset - def _create_mock_account(self): + @staticmethod + def create_account_mock(user_id: str = "user1", name: str = "Test User") -> Mock: """Create a mock Account object.""" account = Mock(spec=Account) - account.id = self.user_id - account.name = "Test User" + account.id = user_id + account.name = name return account - def _create_mock_features(self, billing_enabled=True, plan="pro", quota_limit=100, quota_size=0): + @staticmethod + def create_features_mock( + billing_enabled: bool = True, plan: str = "pro", quota_limit: int = 100, quota_size: int = 0 + ) -> Mock: """Create a mock features object for billing tests.""" features = Mock() features.billing.enabled = billing_enabled @@ -60,16 +56,17 @@ class TestDocumentServiceSaveDocumentWithDatasetId(unittest.TestCase): features.documents_upload_quota.size = quota_size return features - def _create_mock_knowledge_config( - self, - data_source_type, - original_document_id=None, - file_ids=None, - notion_pages=None, - website_urls=None, - indexing_technique="high_quality", - duplicate=False, - ): + @staticmethod + def create_knowledge_config_mock( + data_source_type: str, + original_document_id: Optional[str] = None, + file_ids: Optional[list[str]] = None, + notion_pages: Optional[list[Mock]] = None, + website_urls: Optional[list[str]] = None, + indexing_technique: str = "high_quality", + duplicate: bool = False, + **kwargs, + ) -> Mock: """Create a mock KnowledgeConfig object with specified data source configuration.""" knowledge_config = Mock(spec=KnowledgeConfig) knowledge_config.original_document_id = original_document_id @@ -104,55 +101,151 @@ class TestDocumentServiceSaveDocumentWithDatasetId(unittest.TestCase): website_info.only_main_content = True knowledge_config.data_source.info_list.website_info_list = website_info + for key, value in kwargs.items(): + setattr(knowledge_config, key, value) return knowledge_config - def _setup_common_mocks(self, mock_current_user, mock_features, mock_redis=None, mock_db=None): - """Set up common mock objects used across multiple tests.""" - mock_current_user.current_tenant_id = self.tenant_id - - if mock_redis: - mock_lock = Mock() - mock_redis.lock.return_value.__enter__ = Mock(return_value=None) - mock_redis.lock.return_value.__exit__ = Mock(return_value=None) - - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.db.session") - @patch("services.dataset_service.redis_client") - @patch("services.dataset_service.time") - @patch("services.dataset_service.secrets.randbelow", return_value=123456) - @patch("services.dataset_service.DocumentService.build_document") - @patch("services.dataset_service.document_indexing_task.delay") - @patch("services.dataset_service.duplicate_document_indexing_task.delay") - @patch("services.dataset_service.current_user") - @patch("services.dataset_service.ModelManager") - @patch("services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding") - @patch("services.dataset_service.DocumentService.get_documents_position", return_value=0) + @staticmethod + def create_upload_file_mock(file_id: str, name: str) -> Mock: + """Create a mock upload file.""" + upload_file = Mock() + upload_file.id = file_id + upload_file.name = name + return upload_file + + @staticmethod + def create_document_mock(document_id: str) -> Mock: + """Create a mock Document.""" + document = Mock(spec=Document, id=document_id) + return document + + @staticmethod + def create_notion_page_mock(page_id: str, page_name: str, page_type: str = "page") -> Mock: + """Create a mock Notion page.""" + page = Mock() + page.page_id = page_id + page.page_name = page_name + page.page_icon = None + page.type = page_type + return page + + @staticmethod + def create_notion_info_mock(workspace_id: str, pages: list[Mock]) -> Mock: + """Create a mock Notion info.""" + notion_info = Mock() + notion_info.workspace_id = workspace_id + notion_info.pages = pages + return notion_info + + +class TestDocumentServiceSaveDocumentWithDatasetId: + """ + Comprehensive unit tests for DocumentService.save_document_with_dataset_id. + + This test suite covers all major code branches including: + - Billing and quota validation + - Different data source types (upload_file, notion_import, website_crawl) + - Duplicate document handling + - Process rule validation and error cases + - Exception handling and edge cases + - Database session operations (add and flush) + """ + + @pytest.fixture(autouse=True) + def setup_method(self): + """Set up common test fixtures and mock objects.""" + self.dataset_id = "ds1" + self.tenant_id = "tenant1" + self.user_id = "user1" + self.batch_id = "batch1" + + @pytest.fixture + def mock_document_service_dependencies(self): + """Common mock setup for document service dependencies.""" + with ( + patch("services.dataset_service.FeatureService.get_features") as mock_features, + patch("services.dataset_service.db.session") as mock_db, + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.time") as mock_time, + patch("services.dataset_service.secrets.randbelow", return_value=123456) as mock_rand, + patch("services.dataset_service.current_user") as mock_current_user, + ): + mock_current_user.current_tenant_id = self.tenant_id + mock_time.strftime.return_value = "20231201120000" + + yield { + "features": mock_features, + "db_session": mock_db, + "redis_client": mock_redis, + "time": mock_time, + "randbelow": mock_rand, + "current_user": mock_current_user, + } + + @pytest.fixture + def mock_async_task_dependencies(self): + """Mock setup for async task dependencies.""" + with ( + patch("services.dataset_service.DocumentService.build_document") as mock_build_doc, + patch("services.dataset_service.document_indexing_task.delay") as mock_doc_task, + patch("services.dataset_service.duplicate_document_indexing_task.delay") as mock_dup_task, + patch("services.dataset_service.clean_notion_document_task.delay") as mock_clean_task, + ): + yield { + "build_document": mock_build_doc, + "document_indexing_task": mock_doc_task, + "duplicate_document_indexing_task": mock_dup_task, + "clean_notion_document_task": mock_clean_task, + } + + @pytest.fixture + def mock_model_dependencies(self): + """Mock setup for model dependencies.""" + with ( + patch("services.dataset_service.ModelManager") as mock_model_manager, + patch( + "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" + ) as mock_collection_binding, + patch( + "services.dataset_service.DocumentService.get_documents_position", return_value=0 + ) as mock_get_position, + ): + yield { + "model_manager": mock_model_manager, + "collection_binding": mock_collection_binding, + "get_position": mock_get_position, + } + + def _setup_redis_lock(self, mock_redis): + """Helper method to set up Redis lock.""" + mock_lock = Mock() + mock_redis.lock.return_value.__enter__ = Mock(return_value=None) + mock_redis.lock.return_value.__exit__ = Mock(return_value=None) + + def _assert_document_created(self, mock_db, expected_documents: list[Mock]): + """Helper method to verify documents were created and added to session.""" + for doc in expected_documents: + mock_db.add.assert_any_call(doc) + + def _assert_async_task_called(self, mock_task, expected_calls: int = 1): + """Helper method to verify async task was called.""" + assert mock_task.call_count == expected_calls + + # ==================== Upload File Tests ==================== + def test_upload_file_success( - self, - mock_get_position, - mock_collection_binding, - mock_model_manager, - mock_current_user, - mock_dup_task, - mock_doc_task, - mock_build_doc, - mock_rand, - mock_time, - mock_redis, - mock_db, - mock_features, + self, mock_document_service_dependencies, mock_async_task_dependencies, mock_model_dependencies ): """Test successful upload_file document creation with multiple files.""" # Arrange - dataset = self._create_mock_dataset() - account = self._create_mock_account() - features = self._create_mock_features() - knowledge_config = self._create_mock_knowledge_config( + dataset = DocumentSaveTestDataFactory.create_dataset_mock() + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock() + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( data_source_type="upload_file", file_ids=["file1", "file2"] ) - self._setup_common_mocks(mock_current_user, mock_features, mock_redis, mock_db) - mock_features.return_value = features + mock_document_service_dependencies["features"].return_value = features # Mock ModelManager mock_model_manager_instance = Mock() @@ -160,272 +253,50 @@ class TestDocumentServiceSaveDocumentWithDatasetId(unittest.TestCase): mock_embedding_model.model = "embed-model" mock_embedding_model.provider = "openai" mock_model_manager_instance.get_default_model_instance.return_value = mock_embedding_model - mock_model_manager.return_value = mock_model_manager_instance + mock_model_dependencies["model_manager"].return_value = mock_model_manager_instance # Mock collection binding mock_collection_binding_instance = Mock() mock_collection_binding_instance.id = "binding-123" - mock_collection_binding.return_value = mock_collection_binding_instance + mock_model_dependencies["collection_binding"].return_value = mock_collection_binding_instance # Mock build_document - mock_doc1 = Mock(spec=Document, id="doc1") - mock_doc2 = Mock(spec=Document, id="doc2") - mock_build_doc.side_effect = [mock_doc1, mock_doc2] + mock_doc1 = DocumentSaveTestDataFactory.create_document_mock("doc1") + mock_doc2 = DocumentSaveTestDataFactory.create_document_mock("doc2") + mock_async_task_dependencies["build_document"].side_effect = [mock_doc1, mock_doc2] # Mock upload files - upload_file1 = Mock() - upload_file1.id = "file1" - upload_file1.name = "file1.pdf" - upload_file2 = Mock() - upload_file2.id = "file2" - upload_file2.name = "file2.pdf" - mock_db.query.return_value.filter.return_value.first.side_effect = [upload_file1, upload_file2] - - # Mock time - mock_time.strftime.return_value = "20231201120000" + upload_file1 = DocumentSaveTestDataFactory.create_upload_file_mock("file1", "file1.pdf") + upload_file2 = DocumentSaveTestDataFactory.create_upload_file_mock("file2", "file2.pdf") + mock_document_service_dependencies["db_session"].query.return_value.filter.return_value.first.side_effect = [ + upload_file1, + upload_file2, + ] # Act docs, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) # Assert assert len(docs) == 2 - mock_doc_task.assert_called_once() - mock_dup_task.assert_not_called() + self._assert_async_task_called(mock_async_task_dependencies["document_indexing_task"]) + mock_async_task_dependencies["duplicate_document_indexing_task"].assert_not_called() # Verify the documents were added to session - mock_db.add.assert_any_call(mock_doc1) - mock_db.add.assert_any_call(mock_doc2) - - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - def test_billing_batch_limit_exceeded(self, mock_current_user, mock_features): - """Test that batch upload limit exceeded raises appropriate error.""" - # Arrange - dataset = self._create_mock_dataset() - account = self._create_mock_account() - features = self._create_mock_features(billing_enabled=True, plan="sandbox") - knowledge_config = self._create_mock_knowledge_config( - data_source_type="upload_file", file_ids=["file1", "file2"] - ) - - self._setup_common_mocks(mock_current_user, mock_features) - mock_features.return_value = features - - # Act & Assert - with pytest.raises(ValueError, match="Your current plan does not support batch upload"): - DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) - - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - def test_billing_quota_limit_exceeded(self, mock_current_user, mock_features): - """Test that document upload quota exceeded raises appropriate error.""" - # Arrange - dataset = self._create_mock_dataset() - account = self._create_mock_account() - features = self._create_mock_features(billing_enabled=True, plan="pro", quota_limit=1, quota_size=1) - knowledge_config = self._create_mock_knowledge_config( - data_source_type="upload_file", file_ids=["file1", "file2"] - ) - - self._setup_common_mocks(mock_current_user, mock_features) - mock_features.return_value = features - - # Act & Assert - with pytest.raises(ValueError, match="You have reached the limit of your subscription"): - DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) - - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - def test_invalid_indexing_technique(self, mock_current_user, mock_features): - """Test that invalid indexing technique raises appropriate error.""" - # Arrange - dataset = self._create_mock_dataset() - account = self._create_mock_account() - features = self._create_mock_features(billing_enabled=False) - knowledge_config = self._create_mock_knowledge_config( - data_source_type="upload_file", indexing_technique="invalid" - ) - - self._setup_common_mocks(mock_current_user, mock_features) - mock_features.return_value = features - - # Act & Assert - with pytest.raises(ValueError, match="Indexing technique is invalid"): - DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) - - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - def test_no_process_rule_found(self, mock_current_user, mock_features): - """Test that missing process rule raises appropriate error.""" - # Arrange - dataset = self._create_mock_dataset(data_source_type="upload_file", indexing_technique="high_quality") - account = self._create_mock_account() - features = self._create_mock_features(billing_enabled=False) - knowledge_config = self._create_mock_knowledge_config( - data_source_type="upload_file", indexing_technique="high_quality" - ) - knowledge_config.process_rule.rules = None - - self._setup_common_mocks(mock_current_user, mock_features) - mock_features.return_value = features - - # Act & Assert - with pytest.raises(ValueError, match="No process rule found"): - DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) - - @patch("services.dataset_service.db.session") - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - def test_invalid_process_rule_mode(self, mock_current_user, mock_features, mock_db): - """Test that invalid process rule mode returns None without creating document.""" - # Arrange - dataset = self._create_mock_dataset(data_source_type="upload_file", indexing_technique="high_quality") - account = self._create_mock_account() - features = self._create_mock_features(billing_enabled=False) - knowledge_config = self._create_mock_knowledge_config( - data_source_type="upload_file", indexing_technique="high_quality" - ) - knowledge_config.process_rule.mode = "invalid" - - self._setup_common_mocks(mock_current_user, mock_features) - mock_features.return_value = features - - # Act - with patch("logging.warning") as mock_log: - result = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) - - # Assert - assert result is None - mock_log.assert_called() - - @patch("services.dataset_service.db.session") - @patch("services.dataset_service.redis_client") - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - def test_notion_import_no_info(self, mock_current_user, mock_features, mock_redis, mock_db): - """Test that notion_import with missing notion_info_list raises appropriate error.""" - # Arrange - dataset = self._create_mock_dataset(data_source_type="notion_import", indexing_technique="high_quality") - account = self._create_mock_account() - features = self._create_mock_features(billing_enabled=False) - knowledge_config = self._create_mock_knowledge_config( - data_source_type="notion_import", indexing_technique="high_quality" - ) - knowledge_config.data_source.info_list.notion_info_list = None + self._assert_document_created(mock_document_service_dependencies["db_session"], [mock_doc1, mock_doc2]) - self._setup_common_mocks(mock_current_user, mock_features, mock_redis, mock_db) - mock_features.return_value = features - - # Act & Assert - with pytest.raises(ValueError, match="No notion info list found"): - DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) - - @patch("services.dataset_service.db.session") - @patch("services.dataset_service.redis_client") - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - def test_website_crawl_no_info(self, mock_current_user, mock_features, mock_redis, mock_db): - """Test that website_crawl with missing website_info raises appropriate error.""" - # Arrange - dataset = self._create_mock_dataset(data_source_type="website_crawl", indexing_technique="high_quality") - account = self._create_mock_account() - features = self._create_mock_features(billing_enabled=False) - knowledge_config = self._create_mock_knowledge_config( - data_source_type="website_crawl", indexing_technique="high_quality" - ) - knowledge_config.data_source.info_list.website_info_list = None - - self._setup_common_mocks(mock_current_user, mock_features, mock_redis, mock_db) - mock_features.return_value = features - - # Act & Assert - with pytest.raises(ValueError, match="No website info list found"): - DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) - - @patch("services.dataset_service.db.session") - @patch("services.dataset_service.DocumentService.update_document_with_dataset_id") - def test_update_document_branch(self, mock_update_doc, mock_db): - """Test the update document flow when original_document_id is provided.""" - # Arrange - dataset = self._create_mock_dataset() - account = self._create_mock_account() - knowledge_config = self._create_mock_knowledge_config( - data_source_type="upload_file", original_document_id="docid" - ) - mock_update_doc.return_value = Mock(batch=self.batch_id) - - # Mock current_user - mock_current_user = Mock() - mock_current_user.current_tenant_id = self.tenant_id - - # Act - with patch("services.dataset_service.current_user", mock_current_user): - docs, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) - - # Assert - assert len(docs) == 1 - assert batch == self.batch_id - - @patch("services.dataset_service.db.session") - @patch("services.dataset_service.redis_client") - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - def test_upload_file_file_not_found(self, mock_current_user, mock_features, mock_redis, mock_db): - """Test that missing upload file raises FileNotExistsError.""" - # Arrange - from services.dataset_service import FileNotExistsError - - dataset = self._create_mock_dataset() - account = self._create_mock_account() - features = self._create_mock_features(billing_enabled=False) - knowledge_config = self._create_mock_knowledge_config(data_source_type="upload_file", file_ids=["file1"]) - - self._setup_common_mocks(mock_current_user, mock_features, mock_redis, mock_db) - mock_features.return_value = features - mock_db.query.return_value.filter.return_value.first.return_value = None - - # Act & Assert - with pytest.raises(FileNotExistsError): - DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) - - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.db.session") - @patch("services.dataset_service.redis_client") - @patch("services.dataset_service.time") - @patch("services.dataset_service.secrets.randbelow", return_value=123456) - @patch("services.dataset_service.DocumentService.build_document") - @patch("services.dataset_service.document_indexing_task.delay") - @patch("services.dataset_service.duplicate_document_indexing_task.delay") - @patch("services.dataset_service.current_user") - @patch("services.dataset_service.ModelManager") - @patch("services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding") - @patch("services.dataset_service.DocumentService.get_documents_position", return_value=0) def test_upload_file_duplicate( - self, - mock_get_position, - mock_collection_binding, - mock_model_manager, - mock_current_user, - mock_dup_task, - mock_doc_task, - mock_build_doc, - mock_rand, - mock_time, - mock_redis, - mock_db, - mock_features, + self, mock_document_service_dependencies, mock_async_task_dependencies, mock_model_dependencies ): """Test upload_file with duplicate=True when document already exists.""" # Arrange - dataset = self._create_mock_dataset() - account = self._create_mock_account() - features = self._create_mock_features() - knowledge_config = self._create_mock_knowledge_config( + dataset = DocumentSaveTestDataFactory.create_dataset_mock() + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock() + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( data_source_type="upload_file", file_ids=["file1"], duplicate=True ) - self._setup_common_mocks(mock_current_user, mock_features, mock_redis, mock_db) - mock_features.return_value = features + mock_document_service_dependencies["features"].return_value = features # Mock ModelManager mock_model_manager_instance = Mock() @@ -433,195 +304,140 @@ class TestDocumentServiceSaveDocumentWithDatasetId(unittest.TestCase): mock_embedding_model.model = "embed-model" mock_embedding_model.provider = "openai" mock_model_manager_instance.get_default_model_instance.return_value = mock_embedding_model - mock_model_manager.return_value = mock_model_manager_instance + mock_model_dependencies["model_manager"].return_value = mock_model_manager_instance # Mock collection binding mock_collection_binding_instance = Mock() mock_collection_binding_instance.id = "binding-123" - mock_collection_binding.return_value = mock_collection_binding_instance + mock_model_dependencies["collection_binding"].return_value = mock_collection_binding_instance # Mock upload file and existing document - upload_file = Mock() - upload_file.id = "file1" - upload_file.name = "file1.pdf" - existing_doc = Mock(id="docid", name="file1.pdf") - mock_db.query.return_value.filter.return_value.first.return_value = upload_file - mock_db.query.return_value.filter_by.return_value.first.return_value = existing_doc - - # Mock time - mock_time.strftime.return_value = "20231201120000" + upload_file = DocumentSaveTestDataFactory.create_upload_file_mock("file1", "file1.pdf") + existing_doc = DocumentSaveTestDataFactory.create_document_mock("docid") + existing_doc.name = "file1.pdf" + mock_document_service_dependencies[ + "db_session" + ].query.return_value.filter.return_value.first.return_value = upload_file + mock_document_service_dependencies[ + "db_session" + ].query.return_value.filter_by.return_value.first.return_value = existing_doc # Act docs, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) # Assert assert len(docs) == 1 - mock_dup_task.assert_called_once() + self._assert_async_task_called(mock_async_task_dependencies["duplicate_document_indexing_task"]) # Verify the existing document was added to session with updated properties - mock_db.add.assert_any_call(existing_doc) + mock_document_service_dependencies["db_session"].add.assert_any_call(existing_doc) # Verify the document properties were updated before being added - # These assertions verify that the duplicate document was properly updated assert existing_doc.batch == "20231201120000223456" assert existing_doc.indexing_status == "waiting" assert existing_doc.created_from == "web" assert existing_doc.doc_form == "pdf" assert existing_doc.doc_language == "en" - @patch("services.dataset_service.db.session") - @patch("services.dataset_service.redis_client") - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - def test_notion_import_data_source_binding_not_found(self, mock_current_user, mock_features, mock_redis, mock_db): - """Test that missing data source binding for notion_import raises appropriate error.""" + def test_upload_file_file_not_found(self, mock_document_service_dependencies): + """Test that missing upload file raises FileNotExistsError.""" # Arrange - dataset = self._create_mock_dataset(data_source_type="notion_import", indexing_technique="high_quality") - account = self._create_mock_account() - features = self._create_mock_features(billing_enabled=False) + from services.dataset_service import FileNotExistsError - notion_info = Mock() - notion_info.workspace_id = "ws1" - notion_info.pages = [] - knowledge_config = self._create_mock_knowledge_config( - data_source_type="notion_import", indexing_technique="high_quality" + dataset = DocumentSaveTestDataFactory.create_dataset_mock() + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock(billing_enabled=False) + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( + data_source_type="upload_file", file_ids=["file1"] ) - knowledge_config.data_source.info_list.notion_info_list = [notion_info] - self._setup_common_mocks(mock_current_user, mock_features, mock_redis, mock_db) - mock_features.return_value = features - mock_db.query.return_value.filter.return_value.first.return_value = None + mock_document_service_dependencies["features"].return_value = features + mock_document_service_dependencies[ + "db_session" + ].query.return_value.filter.return_value.first.return_value = None # Act & Assert - with pytest.raises(ValueError, match="Data source binding not found."): - DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) - - @patch("services.dataset_service.db.session") - @patch("services.dataset_service.redis_client") - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - @patch("services.dataset_service.document_indexing_task.delay") - def test_website_crawl_url_too_long( - self, mock_document_indexing_task, mock_current_user, mock_features, mock_redis, mock_db - ): - """Test that long URLs are properly truncated in website_crawl document names.""" - # Arrange - dataset = self._create_mock_dataset(data_source_type="website_crawl", indexing_technique="high_quality") - account = self._create_mock_account() - features = self._create_mock_features(billing_enabled=False) - knowledge_config = self._create_mock_knowledge_config( - data_source_type="website_crawl", website_urls=["http://" + "a" * 300] - ) - - self._setup_common_mocks(mock_current_user, mock_features, mock_redis, mock_db) - mock_features.return_value = features - mock_db.query.return_value.filter.return_value.first.return_value = True - - # Act - with patch("services.dataset_service.DocumentService.build_document") as mock_build_doc: + with pytest.raises(FileNotExistsError): DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) - # Assert - args, kwargs = mock_build_doc.call_args - assert args[9].startswith("http://") - assert len(args[9]) < 256 + # ==================== Notion Import Tests ==================== - @patch("services.dataset_service.db.session") - @patch("services.dataset_service.redis_client") - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - @patch("services.dataset_service.DocumentService.build_document") - @patch("services.dataset_service.document_indexing_task.delay") - @patch("services.dataset_service.clean_notion_document_task.delay") - def test_notion_import_success( - self, mock_clean_task, mock_doc_task, mock_build_doc, mock_current_user, mock_features, mock_redis, mock_db - ): + def test_notion_import_success(self, mock_document_service_dependencies, mock_async_task_dependencies): """Test successful notion_import document creation for new pages.""" # Arrange - dataset = self._create_mock_dataset(data_source_type="notion_import", indexing_technique="high_quality") - account = self._create_mock_account() - features = self._create_mock_features(billing_enabled=False) + dataset = DocumentSaveTestDataFactory.create_dataset_mock( + data_source_type="notion_import", indexing_technique="high_quality" + ) + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock(billing_enabled=False) - notion_info = Mock() - notion_info.workspace_id = "ws1" - page = Mock() - page.page_id = "page1" - page.page_name = "Test Page" - page.page_icon = None - page.type = "page" - notion_info.pages = [page] + page = DocumentSaveTestDataFactory.create_notion_page_mock("page1", "Test Page") + notion_info = DocumentSaveTestDataFactory.create_notion_info_mock("ws1", [page]) - knowledge_config = self._create_mock_knowledge_config( + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( data_source_type="notion_import", indexing_technique="high_quality" ) knowledge_config.data_source.info_list.notion_info_list = [notion_info] - self._setup_common_mocks(mock_current_user, mock_features, mock_redis, mock_db) - mock_features.return_value = features + mock_document_service_dependencies["features"].return_value = features # Mock existing documents query (empty) - mock_db.query.return_value.filter_by.return_value.all.return_value = [] + mock_document_service_dependencies["db_session"].query.return_value.filter_by.return_value.all.return_value = [] # Mock data source binding binding = Mock() binding.id = "binding1" - mock_db.query.return_value.filter.return_value.first.return_value = binding + mock_document_service_dependencies[ + "db_session" + ].query.return_value.filter.return_value.first.return_value = binding # Mock build_document - mock_doc = Mock(spec=Document, id="doc1") - mock_build_doc.return_value = mock_doc + mock_doc = DocumentSaveTestDataFactory.create_document_mock("doc1") + mock_async_task_dependencies["build_document"].return_value = mock_doc # Act docs, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) # Assert assert len(docs) == 1 - mock_doc_task.assert_called_once() + self._assert_async_task_called(mock_async_task_dependencies["document_indexing_task"]) # Verify the document was added to the database session - mock_db.add.assert_called_with(mock_doc) - - @patch("services.dataset_service.db.session") - @patch("services.dataset_service.redis_client") - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - @patch("services.dataset_service.DocumentService.build_document") - @patch("services.dataset_service.clean_notion_document_task.delay") - @patch("services.dataset_service.document_indexing_task.delay") - def test_notion_import_page_exists( - self, mock_doc_task, mock_clean_task, mock_build_doc, mock_current_user, mock_features, mock_redis, mock_db - ): + mock_document_service_dependencies["db_session"].add.assert_called_with(mock_doc) + + def test_notion_import_page_exists(self, mock_document_service_dependencies, mock_async_task_dependencies): """Test notion_import when page already exists - should skip creation.""" # Arrange - dataset = self._create_mock_dataset(data_source_type="notion_import", indexing_technique="high_quality") - account = self._create_mock_account() - features = self._create_mock_features(billing_enabled=False) + dataset = DocumentSaveTestDataFactory.create_dataset_mock( + data_source_type="notion_import", indexing_technique="high_quality" + ) + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock(billing_enabled=False) - notion_info = Mock() - notion_info.workspace_id = "ws1" - page = Mock() - page.page_id = "page1" - page.page_name = "Test Page" - notion_info.pages = [page] + page = DocumentSaveTestDataFactory.create_notion_page_mock("page1", "Test Page") + notion_info = DocumentSaveTestDataFactory.create_notion_info_mock("ws1", [page]) - knowledge_config = self._create_mock_knowledge_config( + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( data_source_type="notion_import", indexing_technique="high_quality" ) knowledge_config.data_source.info_list.notion_info_list = [notion_info] - self._setup_common_mocks(mock_current_user, mock_features, mock_redis, mock_db) - mock_features.return_value = features + mock_document_service_dependencies["features"].return_value = features # Mock existing document with same page_id existing_doc = Mock() existing_doc.data_source_info = '{"notion_page_id": "page1"}' existing_doc.id = "doc1" - mock_db.query.return_value.filter_by.return_value.all.return_value = [existing_doc] + mock_document_service_dependencies["db_session"].query.return_value.filter_by.return_value.all.return_value = [ + existing_doc + ] # Mock data source binding binding = Mock() binding.id = "binding1" - mock_db.query.return_value.filter.return_value.first.return_value = binding + mock_document_service_dependencies[ + "db_session" + ].query.return_value.filter.return_value.first.return_value = binding # Act docs, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) @@ -629,130 +445,318 @@ class TestDocumentServiceSaveDocumentWithDatasetId(unittest.TestCase): # Assert assert len(docs) == 0 # No document should be created since it already exists - for call in mock_db.add.call_args_list: + for call in mock_document_service_dependencies["db_session"].add.call_args_list: args, kwargs = call assert not any(isinstance(arg, Document) for arg in args), "Method was called with a Document!" - @patch("services.dataset_service.db.session") - @patch("services.dataset_service.redis_client") - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - @patch("services.dataset_service.DocumentService.build_document") - @patch("services.dataset_service.document_indexing_task.delay") - def test_website_crawl_success( - self, mock_doc_task, mock_build_doc, mock_current_user, mock_features, mock_redis, mock_db - ): + def test_notion_import_no_info(self, mock_document_service_dependencies): + """Test that notion_import with missing notion_info_list raises appropriate error.""" + # Arrange + dataset = DocumentSaveTestDataFactory.create_dataset_mock( + data_source_type="notion_import", indexing_technique="high_quality" + ) + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock(billing_enabled=False) + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( + data_source_type="notion_import", indexing_technique="high_quality" + ) + knowledge_config.data_source.info_list.notion_info_list = None + + mock_document_service_dependencies["features"].return_value = features + + # Act & Assert + with pytest.raises(ValueError, match="No notion info list found"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + def test_notion_import_data_source_binding_not_found(self, mock_document_service_dependencies): + """Test that missing data source binding for notion_import raises appropriate error.""" + # Arrange + dataset = DocumentSaveTestDataFactory.create_dataset_mock( + data_source_type="notion_import", indexing_technique="high_quality" + ) + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock(billing_enabled=False) + + notion_info = DocumentSaveTestDataFactory.create_notion_info_mock("ws1", []) + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( + data_source_type="notion_import", indexing_technique="high_quality" + ) + knowledge_config.data_source.info_list.notion_info_list = [notion_info] + + mock_document_service_dependencies["features"].return_value = features + mock_document_service_dependencies[ + "db_session" + ].query.return_value.filter.return_value.first.return_value = None + + # Act & Assert + with pytest.raises(ValueError, match="Data source binding not found."): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + # ==================== Website Crawl Tests ==================== + + def test_website_crawl_success(self, mock_document_service_dependencies, mock_async_task_dependencies): """Test successful website_crawl document creation for multiple URLs.""" # Arrange - dataset = self._create_mock_dataset(data_source_type="website_crawl", indexing_technique="high_quality") - account = self._create_mock_account() - features = self._create_mock_features(billing_enabled=False) - knowledge_config = self._create_mock_knowledge_config( + dataset = DocumentSaveTestDataFactory.create_dataset_mock( + data_source_type="website_crawl", indexing_technique="high_quality" + ) + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock(billing_enabled=False) + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( data_source_type="website_crawl", website_urls=["http://example1.com", "http://example2.com"] ) - self._setup_common_mocks(mock_current_user, mock_features, mock_redis, mock_db) - mock_features.return_value = features + mock_document_service_dependencies["features"].return_value = features # Mock build_document - mock_doc1 = Mock(spec=Document, id="doc1") - mock_doc2 = Mock(spec=Document, id="doc2") - mock_build_doc.side_effect = [mock_doc1, mock_doc2] + mock_doc1 = DocumentSaveTestDataFactory.create_document_mock("doc1") + mock_doc2 = DocumentSaveTestDataFactory.create_document_mock("doc2") + mock_async_task_dependencies["build_document"].side_effect = [mock_doc1, mock_doc2] # Act docs, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) # Assert assert len(docs) == 2 - assert mock_build_doc.call_count == 2 - mock_doc_task.assert_called_once() + assert mock_async_task_dependencies["build_document"].call_count == 2 + self._assert_async_task_called(mock_async_task_dependencies["document_indexing_task"]) # Verify database session operations - mock_db.add.assert_any_call(mock_doc1) - mock_db.add.assert_any_call(mock_doc2) - - @patch("services.dataset_service.db.session") - @patch("services.dataset_service.redis_client") - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - def test_unknown_data_source_type(self, mock_current_user, mock_features, mock_redis, mock_db): - """Test that unknown data source type is handled gracefully.""" + self._assert_document_created(mock_document_service_dependencies["db_session"], [mock_doc1, mock_doc2]) + + def test_website_crawl_no_info(self, mock_document_service_dependencies): + """Test that website_crawl with missing website_info raises appropriate error.""" # Arrange - dataset = self._create_mock_dataset(data_source_type="unknown_type", indexing_technique="high_quality") - account = self._create_mock_account() - features = self._create_mock_features(billing_enabled=False) - knowledge_config = self._create_mock_knowledge_config( - data_source_type="unknown_type", indexing_technique="high_quality" + dataset = DocumentSaveTestDataFactory.create_dataset_mock( + data_source_type="website_crawl", indexing_technique="high_quality" ) + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock(billing_enabled=False) + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( + data_source_type="website_crawl", indexing_technique="high_quality" + ) + knowledge_config.data_source.info_list.website_info_list = None + + mock_document_service_dependencies["features"].return_value = features - self._setup_common_mocks(mock_current_user, mock_features, mock_redis, mock_db) - mock_features.return_value = features + # Act & Assert + with pytest.raises(ValueError, match="No website info list found"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + def test_website_crawl_url_too_long(self, mock_document_service_dependencies, mock_async_task_dependencies): + """Test that long URLs are properly truncated in website_crawl document names.""" + # Arrange + dataset = DocumentSaveTestDataFactory.create_dataset_mock( + data_source_type="website_crawl", indexing_technique="high_quality" + ) + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock(billing_enabled=False) + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( + data_source_type="website_crawl", website_urls=["http://" + "a" * 300] + ) + + mock_document_service_dependencies["features"].return_value = features + mock_document_service_dependencies[ + "db_session" + ].query.return_value.filter.return_value.first.return_value = True # Act - result = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + with patch("services.dataset_service.DocumentService.build_document") as mock_build_doc: + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) - # Assert - assert result is None or len(result[0]) == 0 + # Assert + args, kwargs = mock_build_doc.call_args + assert args[9].startswith("http://") + assert len(args[9]) < 256 + + # ==================== Billing and Quota Tests ==================== + + def test_billing_batch_limit_exceeded(self, mock_document_service_dependencies): + """Test that batch upload limit exceeded raises appropriate error.""" + # Arrange + dataset = DocumentSaveTestDataFactory.create_dataset_mock() + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock(billing_enabled=True, plan="sandbox") + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( + data_source_type="upload_file", file_ids=["file1", "file2"] + ) + + mock_document_service_dependencies["features"].return_value = features - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - def test_upload_file_batch_limit_exceeded(self, mock_current_user, mock_features): + # Act & Assert + with pytest.raises(ValueError, match="Your current plan does not support batch upload"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + def test_billing_quota_limit_exceeded(self, mock_document_service_dependencies): + """Test that document upload quota exceeded raises appropriate error.""" + # Arrange + dataset = DocumentSaveTestDataFactory.create_dataset_mock() + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock( + billing_enabled=True, plan="pro", quota_limit=1, quota_size=1 + ) + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( + data_source_type="upload_file", file_ids=["file1", "file2"] + ) + + mock_document_service_dependencies["features"].return_value = features + + # Act & Assert + with pytest.raises(ValueError, match="You have reached the limit of your subscription"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + def test_upload_file_batch_limit_exceeded(self, mock_document_service_dependencies): """Test that upload_file batch limit exceeded raises appropriate error.""" # Arrange - dataset = self._create_mock_dataset() - account = self._create_mock_account() - features = self._create_mock_features() - knowledge_config = self._create_mock_knowledge_config( + dataset = DocumentSaveTestDataFactory.create_dataset_mock() + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock() + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( data_source_type="upload_file", file_ids=["file" + str(i) for i in range(100)] ) - self._setup_common_mocks(mock_current_user, mock_features) - mock_features.return_value = features + mock_document_service_dependencies["features"].return_value = features # Act & Assert with patch("services.dataset_service.dify_config.BATCH_UPLOAD_LIMIT", 50): with pytest.raises(ValueError, match="You have reached the batch upload limit"): DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - def test_notion_import_batch_limit_exceeded(self, mock_current_user, mock_features): + def test_notion_import_batch_limit_exceeded(self, mock_document_service_dependencies): """Test that notion_import batch limit exceeded raises appropriate error.""" # Arrange - dataset = self._create_mock_dataset() - account = self._create_mock_account() - features = self._create_mock_features() + dataset = DocumentSaveTestDataFactory.create_dataset_mock() + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock() - notion_info = Mock() - notion_info.pages = [Mock() for _ in range(100)] # 100 pages + notion_info = DocumentSaveTestDataFactory.create_notion_info_mock("ws1", [Mock() for _ in range(100)]) - knowledge_config = self._create_mock_knowledge_config(data_source_type="notion_import") + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock(data_source_type="notion_import") knowledge_config.data_source.info_list.notion_info_list = [notion_info] - self._setup_common_mocks(mock_current_user, mock_features) - mock_features.return_value = features + mock_document_service_dependencies["features"].return_value = features # Act & Assert with patch("services.dataset_service.dify_config.BATCH_UPLOAD_LIMIT", 50): with pytest.raises(ValueError, match="You have reached the batch upload limit"): DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) - @patch("services.dataset_service.FeatureService.get_features") - @patch("services.dataset_service.current_user") - def test_website_crawl_batch_limit_exceeded(self, mock_current_user, mock_features): + def test_website_crawl_batch_limit_exceeded(self, mock_document_service_dependencies): """Test that website_crawl batch limit exceeded raises appropriate error.""" # Arrange - dataset = self._create_mock_dataset() - account = self._create_mock_account() - features = self._create_mock_features() - knowledge_config = self._create_mock_knowledge_config( + dataset = DocumentSaveTestDataFactory.create_dataset_mock() + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock() + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( data_source_type="website_crawl", website_urls=["http://example" + str(i) + ".com" for i in range(100)] ) - self._setup_common_mocks(mock_current_user, mock_features) - mock_features.return_value = features + mock_document_service_dependencies["features"].return_value = features # Act & Assert with patch("services.dataset_service.dify_config.BATCH_UPLOAD_LIMIT", 50): with pytest.raises(ValueError, match="You have reached the batch upload limit"): DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + # ==================== Process Rule Tests ==================== + + def test_invalid_indexing_technique(self, mock_document_service_dependencies): + """Test that invalid indexing technique raises appropriate error.""" + # Arrange + dataset = DocumentSaveTestDataFactory.create_dataset_mock() + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock(billing_enabled=False) + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( + data_source_type="upload_file", indexing_technique="invalid" + ) + + mock_document_service_dependencies["features"].return_value = features + + # Act & Assert + with pytest.raises(ValueError, match="Indexing technique is invalid"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + def test_no_process_rule_found(self, mock_document_service_dependencies): + """Test that missing process rule raises appropriate error.""" + # Arrange + dataset = DocumentSaveTestDataFactory.create_dataset_mock( + data_source_type="upload_file", indexing_technique="high_quality" + ) + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock(billing_enabled=False) + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( + data_source_type="upload_file", indexing_technique="high_quality" + ) + knowledge_config.process_rule.rules = None + + mock_document_service_dependencies["features"].return_value = features + + # Act & Assert + with pytest.raises(ValueError, match="No process rule found"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + def test_invalid_process_rule_mode(self, mock_document_service_dependencies): + """Test that invalid process rule mode returns None without creating document.""" + # Arrange + dataset = DocumentSaveTestDataFactory.create_dataset_mock( + data_source_type="upload_file", indexing_technique="high_quality" + ) + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock(billing_enabled=False) + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( + data_source_type="upload_file", indexing_technique="high_quality" + ) + knowledge_config.process_rule.mode = "invalid" + + mock_document_service_dependencies["features"].return_value = features + + # Act + with patch("logging.warning") as mock_log: + result = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + # Assert + assert result is None + mock_log.assert_called() + + # ==================== Update Document Tests ==================== + + def test_update_document_branch(self, mock_document_service_dependencies): + """Test the update document flow when original_document_id is provided.""" + # Arrange + dataset = DocumentSaveTestDataFactory.create_dataset_mock() + account = DocumentSaveTestDataFactory.create_account_mock() + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( + data_source_type="upload_file", original_document_id="docid" + ) + + with patch("services.dataset_service.DocumentService.update_document_with_dataset_id") as mock_update_doc: + mock_update_doc.return_value = Mock(batch=self.batch_id) + + # Act + docs, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + # Assert + assert len(docs) == 1 + assert batch == self.batch_id + + # ==================== Edge Case Tests ==================== + + def test_unknown_data_source_type(self, mock_document_service_dependencies): + """Test that unknown data source type is handled gracefully.""" + # Arrange + dataset = DocumentSaveTestDataFactory.create_dataset_mock( + data_source_type="unknown_type", indexing_technique="high_quality" + ) + account = DocumentSaveTestDataFactory.create_account_mock() + features = DocumentSaveTestDataFactory.create_features_mock(billing_enabled=False) + knowledge_config = DocumentSaveTestDataFactory.create_knowledge_config_mock( + data_source_type="unknown_type", indexing_technique="high_quality" + ) + + mock_document_service_dependencies["features"].return_value = features + + # Act + result = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + + # Assert + assert result is None or len(result[0]) == 0