fix: Fix failing unit tests after ORM query refactoring

- Fix double dot issue in _build_from_tool_file extension handling
- Update test mocks to use .where() instead of .filter() to match refactored ORM queries
- Update test mocks to use db.session.scalar instead of db.session.query for single result queries
- Fix mock_tool_file fixture to properly patch db.session.scalar

All 367 tests in factories and services modules now pass.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
pull/22801/head
-LAN- 7 months ago
parent de0bd4ccc7
commit f7270d9549
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -273,7 +273,7 @@ def _build_from_tool_file(
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
detected_file_type = _standardize_file_type(extension="." + extension, mime_type=tool_file.mimetype)
detected_file_type = _standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
specified_type = mapping.get("type")

@ -54,12 +54,7 @@ def mock_tool_file():
mock.mimetype = "application/pdf"
mock.original_url = "http://example.com/tool.pdf"
mock.size = 2048
with (
patch("factories.file_factory.db.session.query") as mock_query,
patch("factories.file_factory.db.session.scalar") as mock_query1,
):
mock_query.return_value.where.return_value.first.return_value = mock
mock_query1.return_value.where.return_value.first.return_value = mock
with patch("factories.file_factory.db.session.scalar", return_value=mock):
yield mock
@ -157,8 +152,7 @@ def test_build_from_remote_url(mock_http_head):
def test_tool_file_not_found():
"""Test ToolFile not found in database."""
with patch("factories.file_factory.db.session.query") as mock_query:
mock_query.return_value.where.return_value.first.return_value = None
with patch("factories.file_factory.db.session.scalar", return_value=None):
mapping = tool_file_mapping()
with pytest.raises(ValueError, match=f"ToolFile {TEST_TOOL_FILE_ID} not found"):
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)

@ -52,9 +52,9 @@ class TestApiKeyAuthService:
ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
# Verify filter conditions include disabled.is_(False)
filter_call = mock_session.query.return_value.filter.call_args[0]
assert len(filter_call) == 2 # tenant_id and disabled filter conditions
# Verify where conditions include disabled.is_(False)
where_call = mock_session.query.return_value.where.call_args[0]
assert len(where_call) == 2 # tenant_id and disabled filter conditions
@patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
@ -162,9 +162,9 @@ class TestApiKeyAuthService:
ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
# Verify filter conditions are correct
filter_call = mock_session.query.return_value.filter.call_args[0]
assert len(filter_call) == 4 # tenant_id, category, provider, disabled
# Verify where conditions are correct
where_call = mock_session.query.return_value.where.call_args[0]
assert len(where_call) == 4 # tenant_id, category, provider, disabled
@patch("services.auth.api_key_auth_service.db.session")
def test_get_auth_credentials_json_parsing(self, mock_session):
@ -212,9 +212,9 @@ class TestApiKeyAuthService:
ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id)
# Verify filter conditions include tenant_id and binding_id
filter_call = mock_session.query.return_value.filter.call_args[0]
assert len(filter_call) == 2
# Verify where conditions include tenant_id and binding_id
where_call = mock_session.query.return_value.where.call_args[0]
assert len(where_call) == 2
def test_validate_api_key_auth_args_success(self):
"""Test API key auth args validation - success scenario"""

@ -63,10 +63,10 @@ class TestAuthIntegration:
tenant1_binding = self._create_mock_binding(self.tenant_id_1, AuthType.FIRECRAWL, self.firecrawl_credentials)
tenant2_binding = self._create_mock_binding(self.tenant_id_2, AuthType.JINA, self.jina_credentials)
mock_session.query.return_value.filter.return_value.all.return_value = [tenant1_binding]
mock_session.query.return_value.where.return_value.all.return_value = [tenant1_binding]
result1 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_1)
mock_session.query.return_value.filter.return_value.all.return_value = [tenant2_binding]
mock_session.query.return_value.where.return_value.all.return_value = [tenant2_binding]
result2 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_2)
assert len(result1) == 1
@ -77,7 +77,7 @@ class TestAuthIntegration:
@patch("services.auth.api_key_auth_service.db.session")
def test_cross_tenant_access_prevention(self, mock_session):
"""Test prevention of cross-tenant credential access"""
mock_session.query.return_value.filter.return_value.first.return_value = None
mock_session.query.return_value.where.return_value.first.return_value = None
result = ApiKeyAuthService.get_auth_credentials(self.tenant_id_2, self.category, AuthType.FIRECRAWL)

@ -708,9 +708,9 @@ class TestTenantService:
with patch("services.account_service.db") as mock_db:
# Mock the join query that returns the tenant_account_join
mock_query = MagicMock()
mock_filter = MagicMock()
mock_filter.first.return_value = mock_tenant_join
mock_query.filter.return_value = mock_filter
mock_where = MagicMock()
mock_where.first.return_value = mock_tenant_join
mock_query.where.return_value = mock_where
mock_query.join.return_value = mock_query
mock_db.session.query.return_value = mock_query
@ -1381,7 +1381,7 @@ class TestRegisterService:
# Mock database queries - complex query mocking
mock_query1 = MagicMock()
mock_query1.filter.return_value.first.return_value = mock_tenant
mock_query1.where.return_value.first.return_value = mock_tenant
mock_query2 = MagicMock()
mock_query2.join.return_value.where.return_value.first.return_value = (mock_account, "normal")

Loading…
Cancel
Save