From f7270d954919e9f4bbc71546e10475648df910da Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 24 Jul 2025 00:32:42 +0800 Subject: [PATCH] fix: Fix failing unit tests after ORM query refactoring MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix double dot issue in _build_from_tool_file extension handling - Update test mocks to use .where() instead of .filter() to match refactored ORM queries - Update test mocks to use db.session.scalar instead of db.session.query for single result queries - Fix mock_tool_file fixture to properly patch db.session.scalar All 367 tests in factories and services modules now pass. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- api/factories/file_factory.py | 2 +- .../factories/test_build_from_mapping.py | 10 ++-------- .../services/auth/test_api_key_auth_service.py | 18 +++++++++--------- .../services/auth/test_auth_integration.py | 6 +++--- .../services/test_account_service.py | 8 ++++---- 5 files changed, 19 insertions(+), 25 deletions(-) diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index f4b2f2c490..512a9cb608 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -273,7 +273,7 @@ def _build_from_tool_file( extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" - detected_file_type = _standardize_file_type(extension="." + extension, mime_type=tool_file.mimetype) + detected_file_type = _standardize_file_type(extension=extension, mime_type=tool_file.mimetype) specified_type = mapping.get("type") diff --git a/api/tests/unit_tests/factories/test_build_from_mapping.py b/api/tests/unit_tests/factories/test_build_from_mapping.py index 1952be0938..d42c4412f5 100644 --- a/api/tests/unit_tests/factories/test_build_from_mapping.py +++ b/api/tests/unit_tests/factories/test_build_from_mapping.py @@ -54,12 +54,7 @@ def mock_tool_file(): mock.mimetype = "application/pdf" mock.original_url = "http://example.com/tool.pdf" mock.size = 2048 - with ( - patch("factories.file_factory.db.session.query") as mock_query, - patch("factories.file_factory.db.session.scalar") as mock_query1, - ): - mock_query.return_value.where.return_value.first.return_value = mock - mock_query1.return_value.where.return_value.first.return_value = mock + with patch("factories.file_factory.db.session.scalar", return_value=mock): yield mock @@ -157,8 +152,7 @@ def test_build_from_remote_url(mock_http_head): def test_tool_file_not_found(): """Test ToolFile not found in database.""" - with patch("factories.file_factory.db.session.query") as mock_query: - mock_query.return_value.where.return_value.first.return_value = None + with patch("factories.file_factory.db.session.scalar", return_value=None): mapping = tool_file_mapping() with pytest.raises(ValueError, match=f"ToolFile {TEST_TOOL_FILE_ID} not found"): build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID) diff --git a/api/tests/unit_tests/services/auth/test_api_key_auth_service.py b/api/tests/unit_tests/services/auth/test_api_key_auth_service.py index 5dd67b4c7a..dc42a04cf3 100644 --- a/api/tests/unit_tests/services/auth/test_api_key_auth_service.py +++ b/api/tests/unit_tests/services/auth/test_api_key_auth_service.py @@ -52,9 +52,9 @@ class TestApiKeyAuthService: ApiKeyAuthService.get_provider_auth_list(self.tenant_id) - # Verify filter conditions include disabled.is_(False) - filter_call = mock_session.query.return_value.filter.call_args[0] - assert len(filter_call) == 2 # tenant_id and disabled filter conditions + # Verify where conditions include disabled.is_(False) + where_call = mock_session.query.return_value.where.call_args[0] + assert len(where_call) == 2 # tenant_id and disabled filter conditions @patch("services.auth.api_key_auth_service.db.session") @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") @@ -162,9 +162,9 @@ class TestApiKeyAuthService: ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider) - # Verify filter conditions are correct - filter_call = mock_session.query.return_value.filter.call_args[0] - assert len(filter_call) == 4 # tenant_id, category, provider, disabled + # Verify where conditions are correct + where_call = mock_session.query.return_value.where.call_args[0] + assert len(where_call) == 4 # tenant_id, category, provider, disabled @patch("services.auth.api_key_auth_service.db.session") def test_get_auth_credentials_json_parsing(self, mock_session): @@ -212,9 +212,9 @@ class TestApiKeyAuthService: ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id) - # Verify filter conditions include tenant_id and binding_id - filter_call = mock_session.query.return_value.filter.call_args[0] - assert len(filter_call) == 2 + # Verify where conditions include tenant_id and binding_id + where_call = mock_session.query.return_value.where.call_args[0] + assert len(where_call) == 2 def test_validate_api_key_auth_args_success(self): """Test API key auth args validation - success scenario""" diff --git a/api/tests/unit_tests/services/auth/test_auth_integration.py b/api/tests/unit_tests/services/auth/test_auth_integration.py index 31a617345d..4ce5525942 100644 --- a/api/tests/unit_tests/services/auth/test_auth_integration.py +++ b/api/tests/unit_tests/services/auth/test_auth_integration.py @@ -63,10 +63,10 @@ class TestAuthIntegration: tenant1_binding = self._create_mock_binding(self.tenant_id_1, AuthType.FIRECRAWL, self.firecrawl_credentials) tenant2_binding = self._create_mock_binding(self.tenant_id_2, AuthType.JINA, self.jina_credentials) - mock_session.query.return_value.filter.return_value.all.return_value = [tenant1_binding] + mock_session.query.return_value.where.return_value.all.return_value = [tenant1_binding] result1 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_1) - mock_session.query.return_value.filter.return_value.all.return_value = [tenant2_binding] + mock_session.query.return_value.where.return_value.all.return_value = [tenant2_binding] result2 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_2) assert len(result1) == 1 @@ -77,7 +77,7 @@ class TestAuthIntegration: @patch("services.auth.api_key_auth_service.db.session") def test_cross_tenant_access_prevention(self, mock_session): """Test prevention of cross-tenant credential access""" - mock_session.query.return_value.filter.return_value.first.return_value = None + mock_session.query.return_value.where.return_value.first.return_value = None result = ApiKeyAuthService.get_auth_credentials(self.tenant_id_2, self.category, AuthType.FIRECRAWL) diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index 71123e8bdb..442839e44e 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -708,9 +708,9 @@ class TestTenantService: with patch("services.account_service.db") as mock_db: # Mock the join query that returns the tenant_account_join mock_query = MagicMock() - mock_filter = MagicMock() - mock_filter.first.return_value = mock_tenant_join - mock_query.filter.return_value = mock_filter + mock_where = MagicMock() + mock_where.first.return_value = mock_tenant_join + mock_query.where.return_value = mock_where mock_query.join.return_value = mock_query mock_db.session.query.return_value = mock_query @@ -1381,7 +1381,7 @@ class TestRegisterService: # Mock database queries - complex query mocking mock_query1 = MagicMock() - mock_query1.filter.return_value.first.return_value = mock_tenant + mock_query1.where.return_value.first.return_value = mock_tenant mock_query2 = MagicMock() mock_query2.join.return_value.where.return_value.first.return_value = (mock_account, "normal")