From b5b5d7493daca3eb9988b621f57316026db2c60d Mon Sep 17 00:00:00 2001 From: Novice Date: Mon, 23 Jun 2025 13:53:26 +0800 Subject: [PATCH] chore: change the oauth process --- .../console/workspace/tool_providers.py | 32 ++--- api/core/mcp/auth/auth_flow.py | 114 +++++++++++++++++- api/core/mcp/auth/auth_provider.py | 2 +- 3 files changed, 120 insertions(+), 28 deletions(-) diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 6e8b3ab603..194b01fc6a 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -1,7 +1,7 @@ import io import validators -from flask import send_file +from flask import redirect, send_file from flask_login import current_user from flask_restful import Resource, reqparse from sqlalchemy.orm import Session @@ -10,7 +10,7 @@ from werkzeug.exceptions import Forbidden from configs import dify_config from controllers.console import api from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required -from core.mcp.auth.auth_flow import auth +from core.mcp.auth.auth_flow import auth, handle_callback from core.mcp.auth.auth_provider import OAuthClientProvider from core.mcp.error import MCPAuthError from core.mcp.mcp_client import MCPClient @@ -759,28 +759,16 @@ class ToolMCPUpdateApi(Resource): return jsonable_encoder(tools) -class ToolMCPTokenApi(Resource): - @setup_required - @login_required - @account_initialization_required +class ToolMCPCallbackApi(Resource): def get(self): parser = reqparse.RequestParser() - parser.add_argument("provider_id", type=str, required=True, nullable=False, location="args") - parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="args") + parser.add_argument("code", type=str, required=True, nullable=False, location="args") + parser.add_argument("state", type=str, required=True, nullable=False, location="args") args = parser.parse_args() - server_url = MCPToolManageService.get_mcp_provider_server_url( - current_user.current_tenant_id, args["provider_id"] - ) - provider = MCPToolManageService.get_mcp_provider_by_provider_id( - args["provider_id"], current_user.current_tenant_id - ) - if not provider: - raise ValueError("provider not found") - return auth( - OAuthClientProvider(args["provider_id"], current_user.current_tenant_id), - server_url, - authorization_code=args["authorization_code"], - ) + state_key = args["state"] + authorization_code = args["code"] + full_state_data = handle_callback(state_key, authorization_code) + return redirect(f"{dify_config.CONSOLE_WEB_URL}/tools?mcp_provider_id={full_state_data.provider_id}") # tool provider @@ -822,7 +810,7 @@ api.add_resource(ToolMCPDetailApi, "/workspaces/current/tool-provider/mcp/tools/ api.add_resource(ToolProviderMCPApi, "/workspaces/current/tool-provider/mcp") api.add_resource(ToolMCPUpdateApi, "/workspaces/current/tool-provider/mcp/update/") api.add_resource(ToolMCPAuthApi, "/workspaces/current/tool-provider/mcp/auth") -api.add_resource(ToolMCPTokenApi, "/workspaces/current/tool-provider/mcp/token") +api.add_resource(ToolMCPCallbackApi, "/mcp/oauth/callback") api.add_resource(ToolBuiltinListApi, "/workspaces/current/tools/builtin") api.add_resource(ToolApiListApi, "/workspaces/current/tools/api") diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py index 0ad4b73acd..d9917a3fbb 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -1,11 +1,14 @@ import base64 import hashlib +import json import os +import secrets import urllib.parse from typing import Optional from urllib.parse import urljoin import requests +from pydantic import BaseModel from core.mcp.auth.auth_provider import OAuthClientProvider from core.mcp.types import ( @@ -15,8 +18,21 @@ from core.mcp.types import ( OAuthMetadata, OAuthTokens, ) +from extensions.ext_redis import redis_client LATEST_PROTOCOL_VERSION = "1.0" +OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry +OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:" + + +class OAuthCallbackState(BaseModel): + provider_id: str + tenant_id: str + server_url: str + metadata: OAuthMetadata | None = None + client_information: OAuthClientInformation + code_verifier: str + redirect_uri: str def generate_pkce_challenge() -> tuple[str, str]: @@ -31,6 +47,62 @@ def generate_pkce_challenge() -> tuple[str, str]: return code_verifier, code_challenge +def _create_secure_redis_state(state_data: OAuthCallbackState) -> str: + """Create a secure state parameter by storing state data in Redis and returning a random state key.""" + # Generate a secure random state key + state_key = secrets.token_urlsafe(32) + + # Store the state data in Redis with expiration + redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}" + redis_client.setex(redis_key, OAUTH_STATE_EXPIRY_SECONDS, state_data.model_dump_json()) + + return state_key + + +def _retrieve_redis_state(state_key: str) -> OAuthCallbackState: + """Retrieve and decode OAuth state data from Redis using the state key.""" + redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}" + + # Get state data from Redis + state_data = redis_client.get(redis_key) + + if not state_data: + raise ValueError("State parameter has expired or does not exist") + + try: + # Parse and validate the state data + if isinstance(state_data, bytes): + state_data = state_data.decode("utf-8") + + oauth_state = OAuthCallbackState.model_validate_json(state_data) + + return oauth_state + except Exception as e: + raise ValueError(f"Invalid state parameter: {str(e)}") + + +def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackState: + """Handle the callback from the OAuth provider.""" + # Retrieve state data from Redis + full_state_data = _retrieve_redis_state(state_key) + + # Clean up the state data from Redis after successful retrieval + redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}" + redis_client.delete(redis_key) + + tokens = exchange_authorization( + full_state_data.server_url, + full_state_data.metadata, + full_state_data.client_information, + authorization_code, + full_state_data.code_verifier, + full_state_data.redirect_uri, + ) + provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id) + provider.save_tokens(tokens) + return full_state_data + + def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = None) -> Optional[OAuthMetadata]: """Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata.""" url = urljoin(server_url, "/.well-known/oauth-authorization-server") @@ -60,8 +132,9 @@ def start_authorization( client_information: OAuthClientInformation, redirect_url: str, provider_id: str, + tenant_id: str, ) -> tuple[str, str]: - """Begins the authorization flow.""" + """Begins the authorization flow with secure Redis state storage.""" response_type = "code" code_challenge_method = "S256" @@ -81,13 +154,27 @@ def start_authorization( code_verifier, code_challenge = generate_pkce_challenge() + # Prepare state data with all necessary information + state_data = OAuthCallbackState( + provider_id=provider_id, + tenant_id=tenant_id, + server_url=server_url, + metadata=metadata, + client_information=client_information, + code_verifier=code_verifier, + redirect_uri=redirect_url, + ) + + # Store state data in Redis and generate secure state key + state_key = _create_secure_redis_state(state_data) + params = { "response_type": response_type, "client_id": client_information.client_id, "code_challenge": code_challenge, "code_challenge_method": code_challenge_method, "redirect_uri": redirect_url, - "state": provider_id, + "state": state_key, } authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}" @@ -187,8 +274,9 @@ def auth( provider: OAuthClientProvider, server_url: str, authorization_code: Optional[str] = None, + state_param: Optional[str] = None, ) -> dict[str, str]: - """Orchestrates the full auth flow with a server.""" + """Orchestrates the full auth flow with a server using secure Redis state storage.""" metadata = discover_oauth_metadata(server_url) # Handle client registration if needed @@ -205,14 +293,29 @@ def auth( # Exchange authorization code for tokens if authorization_code is not None: - code_verifier = provider.code_verifier() + if not state_param: + raise ValueError("State parameter is required when exchanging authorization code") + + try: + # Retrieve state data from Redis using state key + full_state_data = _retrieve_redis_state(state_param) + + code_verifier = full_state_data.code_verifier + redirect_uri = full_state_data.redirect_uri + + if not code_verifier or not redirect_uri: + raise ValueError("Missing code_verifier or redirect_uri in state data") + + except (json.JSONDecodeError, ValueError) as e: + raise ValueError(f"Invalid state parameter: {e}") + tokens = exchange_authorization( server_url, metadata, client_information, authorization_code, code_verifier, - provider.redirect_url, + redirect_uri, ) provider.save_tokens(tokens) return {"result": "success"} @@ -235,6 +338,7 @@ def auth( client_information, provider.redirect_url, provider.provider_id, + provider.tenant_id, ) provider.save_code_verifier(code_verifier) diff --git a/api/core/mcp/auth/auth_provider.py b/api/core/mcp/auth/auth_provider.py index 556f3d7e5b..80e165f10d 100644 --- a/api/core/mcp/auth/auth_provider.py +++ b/api/core/mcp/auth/auth_provider.py @@ -23,7 +23,7 @@ class OAuthClientProvider: @property def redirect_url(self) -> str: """The URL to redirect the user agent to after authorization.""" - return dify_config.CONSOLE_WEB_URL + "/tools" + return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback" @property def client_metadata(self) -> OAuthClientMetadata: