|
|
|
@ -1,11 +1,14 @@
|
|
|
|
import base64
|
|
|
|
import base64
|
|
|
|
import hashlib
|
|
|
|
import hashlib
|
|
|
|
|
|
|
|
import json
|
|
|
|
import os
|
|
|
|
import os
|
|
|
|
|
|
|
|
import secrets
|
|
|
|
import urllib.parse
|
|
|
|
import urllib.parse
|
|
|
|
from typing import Optional
|
|
|
|
from typing import Optional
|
|
|
|
from urllib.parse import urljoin
|
|
|
|
from urllib.parse import urljoin
|
|
|
|
|
|
|
|
|
|
|
|
import requests
|
|
|
|
import requests
|
|
|
|
|
|
|
|
from pydantic import BaseModel
|
|
|
|
|
|
|
|
|
|
|
|
from core.mcp.auth.auth_provider import OAuthClientProvider
|
|
|
|
from core.mcp.auth.auth_provider import OAuthClientProvider
|
|
|
|
from core.mcp.types import (
|
|
|
|
from core.mcp.types import (
|
|
|
|
@ -15,8 +18,21 @@ from core.mcp.types import (
|
|
|
|
OAuthMetadata,
|
|
|
|
OAuthMetadata,
|
|
|
|
OAuthTokens,
|
|
|
|
OAuthTokens,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
from extensions.ext_redis import redis_client
|
|
|
|
|
|
|
|
|
|
|
|
LATEST_PROTOCOL_VERSION = "1.0"
|
|
|
|
LATEST_PROTOCOL_VERSION = "1.0"
|
|
|
|
|
|
|
|
OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
|
|
|
|
|
|
|
|
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OAuthCallbackState(BaseModel):
|
|
|
|
|
|
|
|
provider_id: str
|
|
|
|
|
|
|
|
tenant_id: str
|
|
|
|
|
|
|
|
server_url: str
|
|
|
|
|
|
|
|
metadata: OAuthMetadata | None = None
|
|
|
|
|
|
|
|
client_information: OAuthClientInformation
|
|
|
|
|
|
|
|
code_verifier: str
|
|
|
|
|
|
|
|
redirect_uri: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_pkce_challenge() -> tuple[str, str]:
|
|
|
|
def generate_pkce_challenge() -> tuple[str, str]:
|
|
|
|
@ -31,6 +47,62 @@ def generate_pkce_challenge() -> tuple[str, str]:
|
|
|
|
return code_verifier, code_challenge
|
|
|
|
return code_verifier, code_challenge
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_secure_redis_state(state_data: OAuthCallbackState) -> str:
|
|
|
|
|
|
|
|
"""Create a secure state parameter by storing state data in Redis and returning a random state key."""
|
|
|
|
|
|
|
|
# Generate a secure random state key
|
|
|
|
|
|
|
|
state_key = secrets.token_urlsafe(32)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Store the state data in Redis with expiration
|
|
|
|
|
|
|
|
redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
|
|
|
|
|
|
|
|
redis_client.setex(redis_key, OAUTH_STATE_EXPIRY_SECONDS, state_data.model_dump_json())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return state_key
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
|
|
|
|
|
|
|
|
"""Retrieve and decode OAuth state data from Redis using the state key."""
|
|
|
|
|
|
|
|
redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Get state data from Redis
|
|
|
|
|
|
|
|
state_data = redis_client.get(redis_key)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not state_data:
|
|
|
|
|
|
|
|
raise ValueError("State parameter has expired or does not exist")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
# Parse and validate the state data
|
|
|
|
|
|
|
|
if isinstance(state_data, bytes):
|
|
|
|
|
|
|
|
state_data = state_data.decode("utf-8")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
oauth_state = OAuthCallbackState.model_validate_json(state_data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return oauth_state
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
|
|
raise ValueError(f"Invalid state parameter: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackState:
|
|
|
|
|
|
|
|
"""Handle the callback from the OAuth provider."""
|
|
|
|
|
|
|
|
# Retrieve state data from Redis
|
|
|
|
|
|
|
|
full_state_data = _retrieve_redis_state(state_key)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Clean up the state data from Redis after successful retrieval
|
|
|
|
|
|
|
|
redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
|
|
|
|
|
|
|
|
redis_client.delete(redis_key)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tokens = exchange_authorization(
|
|
|
|
|
|
|
|
full_state_data.server_url,
|
|
|
|
|
|
|
|
full_state_data.metadata,
|
|
|
|
|
|
|
|
full_state_data.client_information,
|
|
|
|
|
|
|
|
authorization_code,
|
|
|
|
|
|
|
|
full_state_data.code_verifier,
|
|
|
|
|
|
|
|
full_state_data.redirect_uri,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id)
|
|
|
|
|
|
|
|
provider.save_tokens(tokens)
|
|
|
|
|
|
|
|
return full_state_data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = None) -> Optional[OAuthMetadata]:
|
|
|
|
def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = None) -> Optional[OAuthMetadata]:
|
|
|
|
"""Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata."""
|
|
|
|
"""Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata."""
|
|
|
|
url = urljoin(server_url, "/.well-known/oauth-authorization-server")
|
|
|
|
url = urljoin(server_url, "/.well-known/oauth-authorization-server")
|
|
|
|
@ -60,8 +132,9 @@ def start_authorization(
|
|
|
|
client_information: OAuthClientInformation,
|
|
|
|
client_information: OAuthClientInformation,
|
|
|
|
redirect_url: str,
|
|
|
|
redirect_url: str,
|
|
|
|
provider_id: str,
|
|
|
|
provider_id: str,
|
|
|
|
|
|
|
|
tenant_id: str,
|
|
|
|
) -> tuple[str, str]:
|
|
|
|
) -> tuple[str, str]:
|
|
|
|
"""Begins the authorization flow."""
|
|
|
|
"""Begins the authorization flow with secure Redis state storage."""
|
|
|
|
response_type = "code"
|
|
|
|
response_type = "code"
|
|
|
|
code_challenge_method = "S256"
|
|
|
|
code_challenge_method = "S256"
|
|
|
|
|
|
|
|
|
|
|
|
@ -81,13 +154,27 @@ def start_authorization(
|
|
|
|
|
|
|
|
|
|
|
|
code_verifier, code_challenge = generate_pkce_challenge()
|
|
|
|
code_verifier, code_challenge = generate_pkce_challenge()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Prepare state data with all necessary information
|
|
|
|
|
|
|
|
state_data = OAuthCallbackState(
|
|
|
|
|
|
|
|
provider_id=provider_id,
|
|
|
|
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
|
|
|
|
server_url=server_url,
|
|
|
|
|
|
|
|
metadata=metadata,
|
|
|
|
|
|
|
|
client_information=client_information,
|
|
|
|
|
|
|
|
code_verifier=code_verifier,
|
|
|
|
|
|
|
|
redirect_uri=redirect_url,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Store state data in Redis and generate secure state key
|
|
|
|
|
|
|
|
state_key = _create_secure_redis_state(state_data)
|
|
|
|
|
|
|
|
|
|
|
|
params = {
|
|
|
|
params = {
|
|
|
|
"response_type": response_type,
|
|
|
|
"response_type": response_type,
|
|
|
|
"client_id": client_information.client_id,
|
|
|
|
"client_id": client_information.client_id,
|
|
|
|
"code_challenge": code_challenge,
|
|
|
|
"code_challenge": code_challenge,
|
|
|
|
"code_challenge_method": code_challenge_method,
|
|
|
|
"code_challenge_method": code_challenge_method,
|
|
|
|
"redirect_uri": redirect_url,
|
|
|
|
"redirect_uri": redirect_url,
|
|
|
|
"state": provider_id,
|
|
|
|
"state": state_key,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}"
|
|
|
|
authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}"
|
|
|
|
@ -187,8 +274,9 @@ def auth(
|
|
|
|
provider: OAuthClientProvider,
|
|
|
|
provider: OAuthClientProvider,
|
|
|
|
server_url: str,
|
|
|
|
server_url: str,
|
|
|
|
authorization_code: Optional[str] = None,
|
|
|
|
authorization_code: Optional[str] = None,
|
|
|
|
|
|
|
|
state_param: Optional[str] = None,
|
|
|
|
) -> dict[str, str]:
|
|
|
|
) -> dict[str, str]:
|
|
|
|
"""Orchestrates the full auth flow with a server."""
|
|
|
|
"""Orchestrates the full auth flow with a server using secure Redis state storage."""
|
|
|
|
metadata = discover_oauth_metadata(server_url)
|
|
|
|
metadata = discover_oauth_metadata(server_url)
|
|
|
|
|
|
|
|
|
|
|
|
# Handle client registration if needed
|
|
|
|
# Handle client registration if needed
|
|
|
|
@ -205,14 +293,29 @@ def auth(
|
|
|
|
|
|
|
|
|
|
|
|
# Exchange authorization code for tokens
|
|
|
|
# Exchange authorization code for tokens
|
|
|
|
if authorization_code is not None:
|
|
|
|
if authorization_code is not None:
|
|
|
|
code_verifier = provider.code_verifier()
|
|
|
|
if not state_param:
|
|
|
|
|
|
|
|
raise ValueError("State parameter is required when exchanging authorization code")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
# Retrieve state data from Redis using state key
|
|
|
|
|
|
|
|
full_state_data = _retrieve_redis_state(state_param)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
code_verifier = full_state_data.code_verifier
|
|
|
|
|
|
|
|
redirect_uri = full_state_data.redirect_uri
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not code_verifier or not redirect_uri:
|
|
|
|
|
|
|
|
raise ValueError("Missing code_verifier or redirect_uri in state data")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
except (json.JSONDecodeError, ValueError) as e:
|
|
|
|
|
|
|
|
raise ValueError(f"Invalid state parameter: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
tokens = exchange_authorization(
|
|
|
|
tokens = exchange_authorization(
|
|
|
|
server_url,
|
|
|
|
server_url,
|
|
|
|
metadata,
|
|
|
|
metadata,
|
|
|
|
client_information,
|
|
|
|
client_information,
|
|
|
|
authorization_code,
|
|
|
|
authorization_code,
|
|
|
|
code_verifier,
|
|
|
|
code_verifier,
|
|
|
|
provider.redirect_url,
|
|
|
|
redirect_uri,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
provider.save_tokens(tokens)
|
|
|
|
provider.save_tokens(tokens)
|
|
|
|
return {"result": "success"}
|
|
|
|
return {"result": "success"}
|
|
|
|
@ -235,6 +338,7 @@ def auth(
|
|
|
|
client_information,
|
|
|
|
client_information,
|
|
|
|
provider.redirect_url,
|
|
|
|
provider.redirect_url,
|
|
|
|
provider.provider_id,
|
|
|
|
provider.provider_id,
|
|
|
|
|
|
|
|
provider.tenant_id,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
provider.save_code_verifier(code_verifier)
|
|
|
|
provider.save_code_verifier(code_verifier)
|
|
|
|
|