diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 0869a29add..6e8b3ab603 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -715,6 +715,7 @@ class ToolMCPAuthApi(Resource): except MCPAuthError: auth_provider = OAuthClientProvider(provider_id, tenant_id) + return auth(auth_provider, server_url, args["authorization_code"]) diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py index 03c607fc18..0ad4b73acd 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -41,7 +41,7 @@ def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = N if response.status_code == 404: return None if not response.ok: - raise Exception(f"HTTP {response.status_code} trying to load well-known OAuth metadata") + raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata") return OAuthMetadata.model_validate(response.json()) except requests.RequestException as e: if isinstance(e, requests.ConnectionError): @@ -49,7 +49,7 @@ def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = N if response.status_code == 404: return None if not response.ok: - raise Exception(f"HTTP {response.status_code} trying to load well-known OAuth metadata") + raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata") return OAuthMetadata.model_validate(response.json()) raise @@ -68,12 +68,14 @@ def start_authorization( if metadata: authorization_url = metadata.authorization_endpoint if response_type not in metadata.response_types_supported: - raise Exception(f"Incompatible auth server: does not support response type {response_type}") + raise ValueError(f"Incompatible auth server: does not support response type {response_type}") if ( not metadata.code_challenge_methods_supported or code_challenge_method not in metadata.code_challenge_methods_supported ): - raise Exception(f"Incompatible auth server: does not support code challenge method {code_challenge_method}") + raise ValueError( + f"Incompatible auth server: does not support code challenge method {code_challenge_method}" + ) else: authorization_url = urljoin(server_url, "/authorize") @@ -106,7 +108,7 @@ def exchange_authorization( if metadata: token_url = metadata.token_endpoint if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported: - raise Exception(f"Incompatible auth server: does not support grant type {grant_type}") + raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}") else: token_url = urljoin(server_url, "/token") @@ -123,7 +125,7 @@ def exchange_authorization( response = requests.post(token_url, data=params) if not response.ok: - raise Exception(f"Token exchange failed: HTTP {response.status_code}") + raise ValueError(f"Token exchange failed: HTTP {response.status_code}") return OAuthTokens.model_validate(response.json()) @@ -139,7 +141,7 @@ def refresh_authorization( if metadata: token_url = metadata.token_endpoint if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported: - raise Exception(f"Incompatible auth server: does not support grant type {grant_type}") + raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}") else: token_url = urljoin(server_url, "/token") @@ -154,7 +156,7 @@ def refresh_authorization( response = requests.post(token_url, data=params) if not response.ok: - raise Exception(f"Token refresh failed: HTTP {response.status_code}") + raise ValueError(f"Token refresh failed: HTTP {response.status_code}") return OAuthTokens.parse_obj(response.json()) @@ -166,7 +168,7 @@ def register_client( """Performs OAuth 2.0 Dynamic Client Registration.""" if metadata: if not metadata.registration_endpoint: - raise Exception("Incompatible auth server: does not support dynamic client registration") + raise ValueError("Incompatible auth server: does not support dynamic client registration") registration_url = metadata.registration_endpoint else: registration_url = urljoin(server_url, "/register") diff --git a/api/core/mcp/client/sse_client.py b/api/core/mcp/client/sse_client.py index 35744d6a8f..3e745d34a4 100644 --- a/api/core/mcp/client/sse_client.py +++ b/api/core/mcp/client/sse_client.py @@ -270,8 +270,6 @@ def sse_client( with ThreadPoolExecutor() as executor: try: - logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}") - with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client: with ssrf_proxy_sse_connect( url, 2, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client diff --git a/api/core/mcp/client/streamable_client.py b/api/core/mcp/client/streamable_client.py index bdbba6922f..495b56201b 100644 --- a/api/core/mcp/client/streamable_client.py +++ b/api/core/mcp/client/streamable_client.py @@ -436,8 +436,6 @@ def streamablehttp_client( with ThreadPoolExecutor(max_workers=2) as executor: try: - logger.info(f"Connecting to StreamableHTTP endpoint: {url}") - with create_ssrf_proxy_mcp_http_client( headers=transport.request_headers, timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds), diff --git a/api/services/tools/mcp_tools_mange_service.py b/api/services/tools/mcp_tools_mange_service.py index 578cfea95e..7c4362913d 100644 --- a/api/services/tools/mcp_tools_mange_service.py +++ b/api/services/tools/mcp_tools_mange_service.py @@ -1,8 +1,10 @@ import hashlib import json +from datetime import datetime from urllib.parse import urlparse from sqlalchemy import or_ +from sqlalchemy.exc import IntegrityError from core.helper import encrypter from core.mcp.error import MCPAuthError, MCPConnectionError @@ -103,6 +105,7 @@ class MCPToolManageService: raise ValueError(f"Failed to connect to MCP server: {e}") mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools]) mcp_provider.authed = True + mcp_provider.updated_at = datetime.now() db.session.commit() return ToolProviderApiEntity( id=mcp_provider.id, @@ -149,11 +152,41 @@ class MCPToolManageService: encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url) mcp_provider.name = name mcp_provider.server_url = encrypted_server_url - mcp_provider.server_url_hash = hashlib.sha256(server_url.encode()).hexdigest() - mcp_provider.icon = ( - json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon - ) - db.session.commit() + server_url_hash = hashlib.sha256(server_url.encode()).hexdigest() + # if the server url is changed, we need to re-auth the tool + try: + if server_url_hash != mcp_provider.server_url_hash: + try: + with MCPClient( + server_url, + provider_id, + tenant_id, + authed=False, + ) as mcp_client: + tools = mcp_client.list_tools() + mcp_provider.authed = True + mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools]) + except MCPAuthError: + mcp_provider.authed = False + mcp_provider.tools = "[]" + mcp_provider.encrypted_credentials = "{}" + mcp_provider.server_url_hash = server_url_hash + mcp_provider.icon = ( + json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon + ) + db.session.commit() + except IntegrityError as e: + db.session.rollback() + # Check if the error message contains the constraint name + if "unique_mcp_provider_name" in str(e.orig): + # Raise your custom exception + raise ValueError(f"A provider with name '{name}' already exists.") + elif "unique_mcp_provider_server_url" in str(e.orig): + # You can define another custom exception for the other constraint + raise ValueError(f"A provider for server URL '{server_url}' already exists.") + else: + # Re-raise the original exception if it's not the one you're handling + raise @classmethod def update_mcp_provider_credentials(cls, tenant_id: str, provider_id: str, credentials: dict, authed: bool = False):