fix: update mcp tool auth

pull/22036/head
Novice 11 months ago
parent 094727a16a
commit 9dd1cd9df8

@ -715,6 +715,7 @@ class ToolMCPAuthApi(Resource):
except MCPAuthError: except MCPAuthError:
auth_provider = OAuthClientProvider(provider_id, tenant_id) auth_provider = OAuthClientProvider(provider_id, tenant_id)
return auth(auth_provider, server_url, args["authorization_code"]) return auth(auth_provider, server_url, args["authorization_code"])

@ -41,7 +41,7 @@ def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = N
if response.status_code == 404: if response.status_code == 404:
return None return None
if not response.ok: if not response.ok:
raise Exception(f"HTTP {response.status_code} trying to load well-known OAuth metadata") raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
return OAuthMetadata.model_validate(response.json()) return OAuthMetadata.model_validate(response.json())
except requests.RequestException as e: except requests.RequestException as e:
if isinstance(e, requests.ConnectionError): if isinstance(e, requests.ConnectionError):
@ -49,7 +49,7 @@ def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = N
if response.status_code == 404: if response.status_code == 404:
return None return None
if not response.ok: if not response.ok:
raise Exception(f"HTTP {response.status_code} trying to load well-known OAuth metadata") raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
return OAuthMetadata.model_validate(response.json()) return OAuthMetadata.model_validate(response.json())
raise raise
@ -68,12 +68,14 @@ def start_authorization(
if metadata: if metadata:
authorization_url = metadata.authorization_endpoint authorization_url = metadata.authorization_endpoint
if response_type not in metadata.response_types_supported: if response_type not in metadata.response_types_supported:
raise Exception(f"Incompatible auth server: does not support response type {response_type}") raise ValueError(f"Incompatible auth server: does not support response type {response_type}")
if ( if (
not metadata.code_challenge_methods_supported not metadata.code_challenge_methods_supported
or code_challenge_method not in metadata.code_challenge_methods_supported or code_challenge_method not in metadata.code_challenge_methods_supported
): ):
raise Exception(f"Incompatible auth server: does not support code challenge method {code_challenge_method}") raise ValueError(
f"Incompatible auth server: does not support code challenge method {code_challenge_method}"
)
else: else:
authorization_url = urljoin(server_url, "/authorize") authorization_url = urljoin(server_url, "/authorize")
@ -106,7 +108,7 @@ def exchange_authorization(
if metadata: if metadata:
token_url = metadata.token_endpoint token_url = metadata.token_endpoint
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported: if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
raise Exception(f"Incompatible auth server: does not support grant type {grant_type}") raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
else: else:
token_url = urljoin(server_url, "/token") token_url = urljoin(server_url, "/token")
@ -123,7 +125,7 @@ def exchange_authorization(
response = requests.post(token_url, data=params) response = requests.post(token_url, data=params)
if not response.ok: if not response.ok:
raise Exception(f"Token exchange failed: HTTP {response.status_code}") raise ValueError(f"Token exchange failed: HTTP {response.status_code}")
return OAuthTokens.model_validate(response.json()) return OAuthTokens.model_validate(response.json())
@ -139,7 +141,7 @@ def refresh_authorization(
if metadata: if metadata:
token_url = metadata.token_endpoint token_url = metadata.token_endpoint
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported: if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
raise Exception(f"Incompatible auth server: does not support grant type {grant_type}") raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
else: else:
token_url = urljoin(server_url, "/token") token_url = urljoin(server_url, "/token")
@ -154,7 +156,7 @@ def refresh_authorization(
response = requests.post(token_url, data=params) response = requests.post(token_url, data=params)
if not response.ok: if not response.ok:
raise Exception(f"Token refresh failed: HTTP {response.status_code}") raise ValueError(f"Token refresh failed: HTTP {response.status_code}")
return OAuthTokens.parse_obj(response.json()) return OAuthTokens.parse_obj(response.json())
@ -166,7 +168,7 @@ def register_client(
"""Performs OAuth 2.0 Dynamic Client Registration.""" """Performs OAuth 2.0 Dynamic Client Registration."""
if metadata: if metadata:
if not metadata.registration_endpoint: if not metadata.registration_endpoint:
raise Exception("Incompatible auth server: does not support dynamic client registration") raise ValueError("Incompatible auth server: does not support dynamic client registration")
registration_url = metadata.registration_endpoint registration_url = metadata.registration_endpoint
else: else:
registration_url = urljoin(server_url, "/register") registration_url = urljoin(server_url, "/register")

@ -270,8 +270,6 @@ def sse_client(
with ThreadPoolExecutor() as executor: with ThreadPoolExecutor() as executor:
try: try:
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client: with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
with ssrf_proxy_sse_connect( with ssrf_proxy_sse_connect(
url, 2, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client url, 2, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client

@ -436,8 +436,6 @@ def streamablehttp_client(
with ThreadPoolExecutor(max_workers=2) as executor: with ThreadPoolExecutor(max_workers=2) as executor:
try: try:
logger.info(f"Connecting to StreamableHTTP endpoint: {url}")
with create_ssrf_proxy_mcp_http_client( with create_ssrf_proxy_mcp_http_client(
headers=transport.request_headers, headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds), timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds),

@ -1,8 +1,10 @@
import hashlib import hashlib
import json import json
from datetime import datetime
from urllib.parse import urlparse from urllib.parse import urlparse
from sqlalchemy import or_ from sqlalchemy import or_
from sqlalchemy.exc import IntegrityError
from core.helper import encrypter from core.helper import encrypter
from core.mcp.error import MCPAuthError, MCPConnectionError from core.mcp.error import MCPAuthError, MCPConnectionError
@ -103,6 +105,7 @@ class MCPToolManageService:
raise ValueError(f"Failed to connect to MCP server: {e}") raise ValueError(f"Failed to connect to MCP server: {e}")
mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools]) mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
mcp_provider.authed = True mcp_provider.authed = True
mcp_provider.updated_at = datetime.now()
db.session.commit() db.session.commit()
return ToolProviderApiEntity( return ToolProviderApiEntity(
id=mcp_provider.id, id=mcp_provider.id,
@ -149,11 +152,41 @@ class MCPToolManageService:
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url) encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
mcp_provider.name = name mcp_provider.name = name
mcp_provider.server_url = encrypted_server_url mcp_provider.server_url = encrypted_server_url
mcp_provider.server_url_hash = hashlib.sha256(server_url.encode()).hexdigest() server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
# if the server url is changed, we need to re-auth the tool
try:
if server_url_hash != mcp_provider.server_url_hash:
try:
with MCPClient(
server_url,
provider_id,
tenant_id,
authed=False,
) as mcp_client:
tools = mcp_client.list_tools()
mcp_provider.authed = True
mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
except MCPAuthError:
mcp_provider.authed = False
mcp_provider.tools = "[]"
mcp_provider.encrypted_credentials = "{}"
mcp_provider.server_url_hash = server_url_hash
mcp_provider.icon = ( mcp_provider.icon = (
json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon
) )
db.session.commit() db.session.commit()
except IntegrityError as e:
db.session.rollback()
# Check if the error message contains the constraint name
if "unique_mcp_provider_name" in str(e.orig):
# Raise your custom exception
raise ValueError(f"A provider with name '{name}' already exists.")
elif "unique_mcp_provider_server_url" in str(e.orig):
# You can define another custom exception for the other constraint
raise ValueError(f"A provider for server URL '{server_url}' already exists.")
else:
# Re-raise the original exception if it's not the one you're handling
raise
@classmethod @classmethod
def update_mcp_provider_credentials(cls, tenant_id: str, provider_id: str, credentials: dict, authed: bool = False): def update_mcp_provider_credentials(cls, tenant_id: str, provider_id: str, credentials: dict, authed: bool = False):

Loading…
Cancel
Save