chore: change the oauth process

pull/22036/head
Novice 11 months ago
parent a467612b2b
commit b5b5d7493d

@ -1,7 +1,7 @@
import io
import validators
from flask import send_file
from flask import redirect, send_file
from flask_login import current_user
from flask_restful import Resource, reqparse
from sqlalchemy.orm import Session
@ -10,7 +10,7 @@ from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from core.mcp.auth.auth_flow import auth
from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.auth.auth_provider import OAuthClientProvider
from core.mcp.error import MCPAuthError
from core.mcp.mcp_client import MCPClient
@ -759,28 +759,16 @@ class ToolMCPUpdateApi(Resource):
return jsonable_encoder(tools)
class ToolMCPTokenApi(Resource):
@setup_required
@login_required
@account_initialization_required
class ToolMCPCallbackApi(Resource):
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="args")
parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="args")
parser.add_argument("code", type=str, required=True, nullable=False, location="args")
parser.add_argument("state", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
server_url = MCPToolManageService.get_mcp_provider_server_url(
current_user.current_tenant_id, args["provider_id"]
)
provider = MCPToolManageService.get_mcp_provider_by_provider_id(
args["provider_id"], current_user.current_tenant_id
)
if not provider:
raise ValueError("provider not found")
return auth(
OAuthClientProvider(args["provider_id"], current_user.current_tenant_id),
server_url,
authorization_code=args["authorization_code"],
)
state_key = args["state"]
authorization_code = args["code"]
full_state_data = handle_callback(state_key, authorization_code)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/tools?mcp_provider_id={full_state_data.provider_id}")
# tool provider
@ -822,7 +810,7 @@ api.add_resource(ToolMCPDetailApi, "/workspaces/current/tool-provider/mcp/tools/
api.add_resource(ToolProviderMCPApi, "/workspaces/current/tool-provider/mcp")
api.add_resource(ToolMCPUpdateApi, "/workspaces/current/tool-provider/mcp/update/<path:provider_id>")
api.add_resource(ToolMCPAuthApi, "/workspaces/current/tool-provider/mcp/auth")
api.add_resource(ToolMCPTokenApi, "/workspaces/current/tool-provider/mcp/token")
api.add_resource(ToolMCPCallbackApi, "/mcp/oauth/callback")
api.add_resource(ToolBuiltinListApi, "/workspaces/current/tools/builtin")
api.add_resource(ToolApiListApi, "/workspaces/current/tools/api")

@ -1,11 +1,14 @@
import base64
import hashlib
import json
import os
import secrets
import urllib.parse
from typing import Optional
from urllib.parse import urljoin
import requests
from pydantic import BaseModel
from core.mcp.auth.auth_provider import OAuthClientProvider
from core.mcp.types import (
@ -15,8 +18,21 @@ from core.mcp.types import (
OAuthMetadata,
OAuthTokens,
)
from extensions.ext_redis import redis_client
LATEST_PROTOCOL_VERSION = "1.0"
OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
class OAuthCallbackState(BaseModel):
provider_id: str
tenant_id: str
server_url: str
metadata: OAuthMetadata | None = None
client_information: OAuthClientInformation
code_verifier: str
redirect_uri: str
def generate_pkce_challenge() -> tuple[str, str]:
@ -31,6 +47,62 @@ def generate_pkce_challenge() -> tuple[str, str]:
return code_verifier, code_challenge
def _create_secure_redis_state(state_data: OAuthCallbackState) -> str:
"""Create a secure state parameter by storing state data in Redis and returning a random state key."""
# Generate a secure random state key
state_key = secrets.token_urlsafe(32)
# Store the state data in Redis with expiration
redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
redis_client.setex(redis_key, OAUTH_STATE_EXPIRY_SECONDS, state_data.model_dump_json())
return state_key
def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
"""Retrieve and decode OAuth state data from Redis using the state key."""
redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
# Get state data from Redis
state_data = redis_client.get(redis_key)
if not state_data:
raise ValueError("State parameter has expired or does not exist")
try:
# Parse and validate the state data
if isinstance(state_data, bytes):
state_data = state_data.decode("utf-8")
oauth_state = OAuthCallbackState.model_validate_json(state_data)
return oauth_state
except Exception as e:
raise ValueError(f"Invalid state parameter: {str(e)}")
def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackState:
"""Handle the callback from the OAuth provider."""
# Retrieve state data from Redis
full_state_data = _retrieve_redis_state(state_key)
# Clean up the state data from Redis after successful retrieval
redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
redis_client.delete(redis_key)
tokens = exchange_authorization(
full_state_data.server_url,
full_state_data.metadata,
full_state_data.client_information,
authorization_code,
full_state_data.code_verifier,
full_state_data.redirect_uri,
)
provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id)
provider.save_tokens(tokens)
return full_state_data
def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = None) -> Optional[OAuthMetadata]:
"""Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata."""
url = urljoin(server_url, "/.well-known/oauth-authorization-server")
@ -60,8 +132,9 @@ def start_authorization(
client_information: OAuthClientInformation,
redirect_url: str,
provider_id: str,
tenant_id: str,
) -> tuple[str, str]:
"""Begins the authorization flow."""
"""Begins the authorization flow with secure Redis state storage."""
response_type = "code"
code_challenge_method = "S256"
@ -81,13 +154,27 @@ def start_authorization(
code_verifier, code_challenge = generate_pkce_challenge()
# Prepare state data with all necessary information
state_data = OAuthCallbackState(
provider_id=provider_id,
tenant_id=tenant_id,
server_url=server_url,
metadata=metadata,
client_information=client_information,
code_verifier=code_verifier,
redirect_uri=redirect_url,
)
# Store state data in Redis and generate secure state key
state_key = _create_secure_redis_state(state_data)
params = {
"response_type": response_type,
"client_id": client_information.client_id,
"code_challenge": code_challenge,
"code_challenge_method": code_challenge_method,
"redirect_uri": redirect_url,
"state": provider_id,
"state": state_key,
}
authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}"
@ -187,8 +274,9 @@ def auth(
provider: OAuthClientProvider,
server_url: str,
authorization_code: Optional[str] = None,
state_param: Optional[str] = None,
) -> dict[str, str]:
"""Orchestrates the full auth flow with a server."""
"""Orchestrates the full auth flow with a server using secure Redis state storage."""
metadata = discover_oauth_metadata(server_url)
# Handle client registration if needed
@ -205,14 +293,29 @@ def auth(
# Exchange authorization code for tokens
if authorization_code is not None:
code_verifier = provider.code_verifier()
if not state_param:
raise ValueError("State parameter is required when exchanging authorization code")
try:
# Retrieve state data from Redis using state key
full_state_data = _retrieve_redis_state(state_param)
code_verifier = full_state_data.code_verifier
redirect_uri = full_state_data.redirect_uri
if not code_verifier or not redirect_uri:
raise ValueError("Missing code_verifier or redirect_uri in state data")
except (json.JSONDecodeError, ValueError) as e:
raise ValueError(f"Invalid state parameter: {e}")
tokens = exchange_authorization(
server_url,
metadata,
client_information,
authorization_code,
code_verifier,
provider.redirect_url,
redirect_uri,
)
provider.save_tokens(tokens)
return {"result": "success"}
@ -235,6 +338,7 @@ def auth(
client_information,
provider.redirect_url,
provider.provider_id,
provider.tenant_id,
)
provider.save_code_verifier(code_verifier)

@ -23,7 +23,7 @@ class OAuthClientProvider:
@property
def redirect_url(self) -> str:
"""The URL to redirect the user agent to after authorization."""
return dify_config.CONSOLE_WEB_URL + "/tools"
return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback"
@property
def client_metadata(self) -> OAuthClientMetadata:

Loading…
Cancel
Save