diff --git a/api/libs/passport.py b/api/libs/passport.py index 8df4f529bc..fe8fc33b5f 100644 --- a/api/libs/passport.py +++ b/api/libs/passport.py @@ -14,9 +14,11 @@ class PassportService: def verify(self, token): try: return jwt.decode(token, self.sk, algorithms=["HS256"]) + except jwt.exceptions.ExpiredSignatureError: + raise Unauthorized("Token has expired.") except jwt.exceptions.InvalidSignatureError: raise Unauthorized("Invalid token signature.") except jwt.exceptions.DecodeError: raise Unauthorized("Invalid token.") - except jwt.exceptions.ExpiredSignatureError: - raise Unauthorized("Token has expired.") + except jwt.exceptions.PyJWTError: # Catch-all for other JWT errors + raise Unauthorized("Invalid token.") diff --git a/api/tests/unit_tests/libs/test_passport.py b/api/tests/unit_tests/libs/test_passport.py index f125611612..f33484c18d 100644 --- a/api/tests/unit_tests/libs/test_passport.py +++ b/api/tests/unit_tests/libs/test_passport.py @@ -1,4 +1,3 @@ -import time from datetime import UTC, datetime, timedelta from unittest.mock import patch @@ -105,9 +104,10 @@ class TestPassportService: wrong_alg_token = jwt.encode(payload, mock_config.SECRET_KEY, algorithm="HS512") # Should fail because service expects HS256 - # JWT library raises InvalidAlgorithmError which is not caught by PassportService - with pytest.raises(jwt.exceptions.InvalidAlgorithmError): + # InvalidAlgorithmError is now caught by PyJWTError handler + with pytest.raises(Unauthorized) as exc_info: passport_service.verify(wrong_alg_token) + assert str(exc_info.value) == "401 Unauthorized: Invalid token." # Exception handling tests def test_should_handle_invalid_tokens(self, passport_service): @@ -194,3 +194,12 @@ class TestPassportService: decoded = passport_service.verify(token) assert decoded == payload + def test_should_catch_generic_pyjwt_errors(self, passport_service): + """Test that generic PyJWTError exceptions are caught and converted to Unauthorized""" + # Mock jwt.decode to raise a generic PyJWTError + with patch("libs.passport.jwt.decode") as mock_decode: + mock_decode.side_effect = jwt.exceptions.PyJWTError("Generic JWT error") + + with pytest.raises(Unauthorized) as exc_info: + passport_service.verify("some-token") + assert str(exc_info.value) == "401 Unauthorized: Invalid token."