fix: Fix failing unit tests after ORM query refactoring

- Fix double dot issue in _build_from_tool_file extension handling
- Update test mocks to use .where() instead of .filter() to match refactored ORM queries
- Update test mocks to use db.session.scalar instead of db.session.query for single result queries
- Fix mock_tool_file fixture to properly patch db.session.scalar

All 367 tests in factories and services modules now pass.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
pull/22801/head
-LAN- 10 months ago
parent de0bd4ccc7
commit f7270d9549
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -273,7 +273,7 @@ def _build_from_tool_file(
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
detected_file_type = _standardize_file_type(extension="." + extension, mime_type=tool_file.mimetype) detected_file_type = _standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
specified_type = mapping.get("type") specified_type = mapping.get("type")

@ -54,12 +54,7 @@ def mock_tool_file():
mock.mimetype = "application/pdf" mock.mimetype = "application/pdf"
mock.original_url = "http://example.com/tool.pdf" mock.original_url = "http://example.com/tool.pdf"
mock.size = 2048 mock.size = 2048
with ( with patch("factories.file_factory.db.session.scalar", return_value=mock):
patch("factories.file_factory.db.session.query") as mock_query,
patch("factories.file_factory.db.session.scalar") as mock_query1,
):
mock_query.return_value.where.return_value.first.return_value = mock
mock_query1.return_value.where.return_value.first.return_value = mock
yield mock yield mock
@ -157,8 +152,7 @@ def test_build_from_remote_url(mock_http_head):
def test_tool_file_not_found(): def test_tool_file_not_found():
"""Test ToolFile not found in database.""" """Test ToolFile not found in database."""
with patch("factories.file_factory.db.session.query") as mock_query: with patch("factories.file_factory.db.session.scalar", return_value=None):
mock_query.return_value.where.return_value.first.return_value = None
mapping = tool_file_mapping() mapping = tool_file_mapping()
with pytest.raises(ValueError, match=f"ToolFile {TEST_TOOL_FILE_ID} not found"): with pytest.raises(ValueError, match=f"ToolFile {TEST_TOOL_FILE_ID} not found"):
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID) build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)

@ -52,9 +52,9 @@ class TestApiKeyAuthService:
ApiKeyAuthService.get_provider_auth_list(self.tenant_id) ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
# Verify filter conditions include disabled.is_(False) # Verify where conditions include disabled.is_(False)
filter_call = mock_session.query.return_value.filter.call_args[0] where_call = mock_session.query.return_value.where.call_args[0]
assert len(filter_call) == 2 # tenant_id and disabled filter conditions assert len(where_call) == 2 # tenant_id and disabled filter conditions
@patch("services.auth.api_key_auth_service.db.session") @patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
@ -162,9 +162,9 @@ class TestApiKeyAuthService:
ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider) ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
# Verify filter conditions are correct # Verify where conditions are correct
filter_call = mock_session.query.return_value.filter.call_args[0] where_call = mock_session.query.return_value.where.call_args[0]
assert len(filter_call) == 4 # tenant_id, category, provider, disabled assert len(where_call) == 4 # tenant_id, category, provider, disabled
@patch("services.auth.api_key_auth_service.db.session") @patch("services.auth.api_key_auth_service.db.session")
def test_get_auth_credentials_json_parsing(self, mock_session): def test_get_auth_credentials_json_parsing(self, mock_session):
@ -212,9 +212,9 @@ class TestApiKeyAuthService:
ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id) ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id)
# Verify filter conditions include tenant_id and binding_id # Verify where conditions include tenant_id and binding_id
filter_call = mock_session.query.return_value.filter.call_args[0] where_call = mock_session.query.return_value.where.call_args[0]
assert len(filter_call) == 2 assert len(where_call) == 2
def test_validate_api_key_auth_args_success(self): def test_validate_api_key_auth_args_success(self):
"""Test API key auth args validation - success scenario""" """Test API key auth args validation - success scenario"""

@ -63,10 +63,10 @@ class TestAuthIntegration:
tenant1_binding = self._create_mock_binding(self.tenant_id_1, AuthType.FIRECRAWL, self.firecrawl_credentials) tenant1_binding = self._create_mock_binding(self.tenant_id_1, AuthType.FIRECRAWL, self.firecrawl_credentials)
tenant2_binding = self._create_mock_binding(self.tenant_id_2, AuthType.JINA, self.jina_credentials) tenant2_binding = self._create_mock_binding(self.tenant_id_2, AuthType.JINA, self.jina_credentials)
mock_session.query.return_value.filter.return_value.all.return_value = [tenant1_binding] mock_session.query.return_value.where.return_value.all.return_value = [tenant1_binding]
result1 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_1) result1 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_1)
mock_session.query.return_value.filter.return_value.all.return_value = [tenant2_binding] mock_session.query.return_value.where.return_value.all.return_value = [tenant2_binding]
result2 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_2) result2 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_2)
assert len(result1) == 1 assert len(result1) == 1
@ -77,7 +77,7 @@ class TestAuthIntegration:
@patch("services.auth.api_key_auth_service.db.session") @patch("services.auth.api_key_auth_service.db.session")
def test_cross_tenant_access_prevention(self, mock_session): def test_cross_tenant_access_prevention(self, mock_session):
"""Test prevention of cross-tenant credential access""" """Test prevention of cross-tenant credential access"""
mock_session.query.return_value.filter.return_value.first.return_value = None mock_session.query.return_value.where.return_value.first.return_value = None
result = ApiKeyAuthService.get_auth_credentials(self.tenant_id_2, self.category, AuthType.FIRECRAWL) result = ApiKeyAuthService.get_auth_credentials(self.tenant_id_2, self.category, AuthType.FIRECRAWL)

@ -708,9 +708,9 @@ class TestTenantService:
with patch("services.account_service.db") as mock_db: with patch("services.account_service.db") as mock_db:
# Mock the join query that returns the tenant_account_join # Mock the join query that returns the tenant_account_join
mock_query = MagicMock() mock_query = MagicMock()
mock_filter = MagicMock() mock_where = MagicMock()
mock_filter.first.return_value = mock_tenant_join mock_where.first.return_value = mock_tenant_join
mock_query.filter.return_value = mock_filter mock_query.where.return_value = mock_where
mock_query.join.return_value = mock_query mock_query.join.return_value = mock_query
mock_db.session.query.return_value = mock_query mock_db.session.query.return_value = mock_query
@ -1381,7 +1381,7 @@ class TestRegisterService:
# Mock database queries - complex query mocking # Mock database queries - complex query mocking
mock_query1 = MagicMock() mock_query1 = MagicMock()
mock_query1.filter.return_value.first.return_value = mock_tenant mock_query1.where.return_value.first.return_value = mock_tenant
mock_query2 = MagicMock() mock_query2 = MagicMock()
mock_query2.join.return_value.where.return_value.first.return_value = (mock_account, "normal") mock_query2.join.return_value.where.return_value.first.return_value = (mock_account, "normal")

Loading…
Cancel
Save