refactor: optimize exception handling order in PassportService

- Reorder exception handling to catch ExpiredSignatureError first for better error clarity
- Add PyJWTError as catch-all for unhandled JWT exceptions
- Update tests to reflect new exception handling behavior
- Add test case for generic PyJWTError handling

This improvement was discovered while writing comprehensive unit tests
for PassportService. The change follows Python best practices by catching
more specific exceptions before general ones and ensures all JWT-related
errors are properly handled.
pull/22268/head
Jason Young 10 months ago
parent d7d99feb4f
commit 05659d536e

@ -14,9 +14,11 @@ class PassportService:
def verify(self, token):
try:
return jwt.decode(token, self.sk, algorithms=["HS256"])
except jwt.exceptions.ExpiredSignatureError:
raise Unauthorized("Token has expired.")
except jwt.exceptions.InvalidSignatureError:
raise Unauthorized("Invalid token signature.")
except jwt.exceptions.DecodeError:
raise Unauthorized("Invalid token.")
except jwt.exceptions.ExpiredSignatureError:
raise Unauthorized("Token has expired.")
except jwt.exceptions.PyJWTError: # Catch-all for other JWT errors
raise Unauthorized("Invalid token.")

@ -1,4 +1,3 @@
import time
from datetime import UTC, datetime, timedelta
from unittest.mock import patch
@ -105,9 +104,10 @@ class TestPassportService:
wrong_alg_token = jwt.encode(payload, mock_config.SECRET_KEY, algorithm="HS512")
# Should fail because service expects HS256
# JWT library raises InvalidAlgorithmError which is not caught by PassportService
with pytest.raises(jwt.exceptions.InvalidAlgorithmError):
# InvalidAlgorithmError is now caught by PyJWTError handler
with pytest.raises(Unauthorized) as exc_info:
passport_service.verify(wrong_alg_token)
assert str(exc_info.value) == "401 Unauthorized: Invalid token."
# Exception handling tests
def test_should_handle_invalid_tokens(self, passport_service):
@ -194,3 +194,12 @@ class TestPassportService:
decoded = passport_service.verify(token)
assert decoded == payload
def test_should_catch_generic_pyjwt_errors(self, passport_service):
"""Test that generic PyJWTError exceptions are caught and converted to Unauthorized"""
# Mock jwt.decode to raise a generic PyJWTError
with patch("libs.passport.jwt.decode") as mock_decode:
mock_decode.side_effect = jwt.exceptions.PyJWTError("Generic JWT error")
with pytest.raises(Unauthorized) as exc_info:
passport_service.verify("some-token")
assert str(exc_info.value) == "401 Unauthorized: Invalid token."

Loading…
Cancel
Save