chore: change the oauth process

pull/22036/head
Novice 11 months ago
parent a467612b2b
commit b5b5d7493d

@ -1,7 +1,7 @@
import io import io
import validators import validators
from flask import send_file from flask import redirect, send_file
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -10,7 +10,7 @@ from werkzeug.exceptions import Forbidden
from configs import dify_config from configs import dify_config
from controllers.console import api from controllers.console import api
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from core.mcp.auth.auth_flow import auth from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.auth.auth_provider import OAuthClientProvider from core.mcp.auth.auth_provider import OAuthClientProvider
from core.mcp.error import MCPAuthError from core.mcp.error import MCPAuthError
from core.mcp.mcp_client import MCPClient from core.mcp.mcp_client import MCPClient
@ -759,28 +759,16 @@ class ToolMCPUpdateApi(Resource):
return jsonable_encoder(tools) return jsonable_encoder(tools)
class ToolMCPTokenApi(Resource): class ToolMCPCallbackApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self): def get(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="args") parser.add_argument("code", type=str, required=True, nullable=False, location="args")
parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="args") parser.add_argument("state", type=str, required=True, nullable=False, location="args")
args = parser.parse_args() args = parser.parse_args()
server_url = MCPToolManageService.get_mcp_provider_server_url( state_key = args["state"]
current_user.current_tenant_id, args["provider_id"] authorization_code = args["code"]
) full_state_data = handle_callback(state_key, authorization_code)
provider = MCPToolManageService.get_mcp_provider_by_provider_id( return redirect(f"{dify_config.CONSOLE_WEB_URL}/tools?mcp_provider_id={full_state_data.provider_id}")
args["provider_id"], current_user.current_tenant_id
)
if not provider:
raise ValueError("provider not found")
return auth(
OAuthClientProvider(args["provider_id"], current_user.current_tenant_id),
server_url,
authorization_code=args["authorization_code"],
)
# tool provider # tool provider
@ -822,7 +810,7 @@ api.add_resource(ToolMCPDetailApi, "/workspaces/current/tool-provider/mcp/tools/
api.add_resource(ToolProviderMCPApi, "/workspaces/current/tool-provider/mcp") api.add_resource(ToolProviderMCPApi, "/workspaces/current/tool-provider/mcp")
api.add_resource(ToolMCPUpdateApi, "/workspaces/current/tool-provider/mcp/update/<path:provider_id>") api.add_resource(ToolMCPUpdateApi, "/workspaces/current/tool-provider/mcp/update/<path:provider_id>")
api.add_resource(ToolMCPAuthApi, "/workspaces/current/tool-provider/mcp/auth") api.add_resource(ToolMCPAuthApi, "/workspaces/current/tool-provider/mcp/auth")
api.add_resource(ToolMCPTokenApi, "/workspaces/current/tool-provider/mcp/token") api.add_resource(ToolMCPCallbackApi, "/mcp/oauth/callback")
api.add_resource(ToolBuiltinListApi, "/workspaces/current/tools/builtin") api.add_resource(ToolBuiltinListApi, "/workspaces/current/tools/builtin")
api.add_resource(ToolApiListApi, "/workspaces/current/tools/api") api.add_resource(ToolApiListApi, "/workspaces/current/tools/api")

@ -1,11 +1,14 @@
import base64 import base64
import hashlib import hashlib
import json
import os import os
import secrets
import urllib.parse import urllib.parse
from typing import Optional from typing import Optional
from urllib.parse import urljoin from urllib.parse import urljoin
import requests import requests
from pydantic import BaseModel
from core.mcp.auth.auth_provider import OAuthClientProvider from core.mcp.auth.auth_provider import OAuthClientProvider
from core.mcp.types import ( from core.mcp.types import (
@ -15,8 +18,21 @@ from core.mcp.types import (
OAuthMetadata, OAuthMetadata,
OAuthTokens, OAuthTokens,
) )
from extensions.ext_redis import redis_client
LATEST_PROTOCOL_VERSION = "1.0" LATEST_PROTOCOL_VERSION = "1.0"
OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
class OAuthCallbackState(BaseModel):
provider_id: str
tenant_id: str
server_url: str
metadata: OAuthMetadata | None = None
client_information: OAuthClientInformation
code_verifier: str
redirect_uri: str
def generate_pkce_challenge() -> tuple[str, str]: def generate_pkce_challenge() -> tuple[str, str]:
@ -31,6 +47,62 @@ def generate_pkce_challenge() -> tuple[str, str]:
return code_verifier, code_challenge return code_verifier, code_challenge
def _create_secure_redis_state(state_data: OAuthCallbackState) -> str:
"""Create a secure state parameter by storing state data in Redis and returning a random state key."""
# Generate a secure random state key
state_key = secrets.token_urlsafe(32)
# Store the state data in Redis with expiration
redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
redis_client.setex(redis_key, OAUTH_STATE_EXPIRY_SECONDS, state_data.model_dump_json())
return state_key
def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
"""Retrieve and decode OAuth state data from Redis using the state key."""
redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
# Get state data from Redis
state_data = redis_client.get(redis_key)
if not state_data:
raise ValueError("State parameter has expired or does not exist")
try:
# Parse and validate the state data
if isinstance(state_data, bytes):
state_data = state_data.decode("utf-8")
oauth_state = OAuthCallbackState.model_validate_json(state_data)
return oauth_state
except Exception as e:
raise ValueError(f"Invalid state parameter: {str(e)}")
def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackState:
"""Handle the callback from the OAuth provider."""
# Retrieve state data from Redis
full_state_data = _retrieve_redis_state(state_key)
# Clean up the state data from Redis after successful retrieval
redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
redis_client.delete(redis_key)
tokens = exchange_authorization(
full_state_data.server_url,
full_state_data.metadata,
full_state_data.client_information,
authorization_code,
full_state_data.code_verifier,
full_state_data.redirect_uri,
)
provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id)
provider.save_tokens(tokens)
return full_state_data
def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = None) -> Optional[OAuthMetadata]: def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = None) -> Optional[OAuthMetadata]:
"""Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata.""" """Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata."""
url = urljoin(server_url, "/.well-known/oauth-authorization-server") url = urljoin(server_url, "/.well-known/oauth-authorization-server")
@ -60,8 +132,9 @@ def start_authorization(
client_information: OAuthClientInformation, client_information: OAuthClientInformation,
redirect_url: str, redirect_url: str,
provider_id: str, provider_id: str,
tenant_id: str,
) -> tuple[str, str]: ) -> tuple[str, str]:
"""Begins the authorization flow.""" """Begins the authorization flow with secure Redis state storage."""
response_type = "code" response_type = "code"
code_challenge_method = "S256" code_challenge_method = "S256"
@ -81,13 +154,27 @@ def start_authorization(
code_verifier, code_challenge = generate_pkce_challenge() code_verifier, code_challenge = generate_pkce_challenge()
# Prepare state data with all necessary information
state_data = OAuthCallbackState(
provider_id=provider_id,
tenant_id=tenant_id,
server_url=server_url,
metadata=metadata,
client_information=client_information,
code_verifier=code_verifier,
redirect_uri=redirect_url,
)
# Store state data in Redis and generate secure state key
state_key = _create_secure_redis_state(state_data)
params = { params = {
"response_type": response_type, "response_type": response_type,
"client_id": client_information.client_id, "client_id": client_information.client_id,
"code_challenge": code_challenge, "code_challenge": code_challenge,
"code_challenge_method": code_challenge_method, "code_challenge_method": code_challenge_method,
"redirect_uri": redirect_url, "redirect_uri": redirect_url,
"state": provider_id, "state": state_key,
} }
authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}" authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}"
@ -187,8 +274,9 @@ def auth(
provider: OAuthClientProvider, provider: OAuthClientProvider,
server_url: str, server_url: str,
authorization_code: Optional[str] = None, authorization_code: Optional[str] = None,
state_param: Optional[str] = None,
) -> dict[str, str]: ) -> dict[str, str]:
"""Orchestrates the full auth flow with a server.""" """Orchestrates the full auth flow with a server using secure Redis state storage."""
metadata = discover_oauth_metadata(server_url) metadata = discover_oauth_metadata(server_url)
# Handle client registration if needed # Handle client registration if needed
@ -205,14 +293,29 @@ def auth(
# Exchange authorization code for tokens # Exchange authorization code for tokens
if authorization_code is not None: if authorization_code is not None:
code_verifier = provider.code_verifier() if not state_param:
raise ValueError("State parameter is required when exchanging authorization code")
try:
# Retrieve state data from Redis using state key
full_state_data = _retrieve_redis_state(state_param)
code_verifier = full_state_data.code_verifier
redirect_uri = full_state_data.redirect_uri
if not code_verifier or not redirect_uri:
raise ValueError("Missing code_verifier or redirect_uri in state data")
except (json.JSONDecodeError, ValueError) as e:
raise ValueError(f"Invalid state parameter: {e}")
tokens = exchange_authorization( tokens = exchange_authorization(
server_url, server_url,
metadata, metadata,
client_information, client_information,
authorization_code, authorization_code,
code_verifier, code_verifier,
provider.redirect_url, redirect_uri,
) )
provider.save_tokens(tokens) provider.save_tokens(tokens)
return {"result": "success"} return {"result": "success"}
@ -235,6 +338,7 @@ def auth(
client_information, client_information,
provider.redirect_url, provider.redirect_url,
provider.provider_id, provider.provider_id,
provider.tenant_id,
) )
provider.save_code_verifier(code_verifier) provider.save_code_verifier(code_verifier)

@ -23,7 +23,7 @@ class OAuthClientProvider:
@property @property
def redirect_url(self) -> str: def redirect_url(self) -> str:
"""The URL to redirect the user agent to after authorization.""" """The URL to redirect the user agent to after authorization."""
return dify_config.CONSOLE_WEB_URL + "/tools" return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback"
@property @property
def client_metadata(self) -> OAuthClientMetadata: def client_metadata(self) -> OAuthClientMetadata:

Loading…
Cancel
Save