diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index 07aeec6fb8..0e3175811b 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -1,6 +1,6 @@ -from amqp import NotFound from flask_restful import Resource, reqparse from pydantic import ValidationError +from werkzeug.exceptions import NotFound from controllers.mcp import api from controllers.web.error import ( diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 51ddb343b3..9180922abb 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -123,16 +123,14 @@ def create_ssrf_proxy_mcp_http_client( Returns: Configured httpx.Client with proxy settings """ - client_kwargs = { - "verify": HTTP_REQUEST_NODE_SSL_VERIFY, - "headers": headers or {}, - "timeout": timeout, - "follow_redirects": True, # Enable redirect following for MCP connections - } - if dify_config.SSRF_PROXY_ALL_URL: - client_kwargs["proxy"] = dify_config.SSRF_PROXY_ALL_URL - return httpx.Client(**client_kwargs) + return httpx.Client( + verify=HTTP_REQUEST_NODE_SSL_VERIFY, + headers=headers or {}, + timeout=timeout, + follow_redirects=True, + proxy=dify_config.SSRF_PROXY_ALL_URL, + ) elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL: proxy_mounts = { "http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY), @@ -140,10 +138,20 @@ def create_ssrf_proxy_mcp_http_client( proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY ), } - client_kwargs["mounts"] = proxy_mounts - return httpx.Client(**client_kwargs) + return httpx.Client( + verify=HTTP_REQUEST_NODE_SSL_VERIFY, + headers=headers or {}, + timeout=timeout, + follow_redirects=True, + mounts=proxy_mounts, + ) else: - return httpx.Client(**client_kwargs) + return httpx.Client( + verify=HTTP_REQUEST_NODE_SSL_VERIFY, + headers=headers or {}, + timeout=timeout, + follow_redirects=True, + ) def ssrf_proxy_sse_connect(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py index 19438799d4..0a851a5da8 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -24,8 +24,8 @@ def generate_pkce_challenge() -> tuple[str, str]: code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8") code_verifier = code_verifier.replace("=", "").replace("+", "-").replace("/", "_") - code_challenge = hashlib.sha256(code_verifier.encode("utf-8")).digest() - code_challenge = base64.urlsafe_b64encode(code_challenge).decode("utf-8") + code_challenge_hash = hashlib.sha256(code_verifier.encode("utf-8")).digest() + code_challenge = base64.urlsafe_b64encode(code_challenge_hash).decode("utf-8") code_challenge = code_challenge.replace("=", "").replace("+", "-").replace("/", "_") return code_verifier, code_challenge @@ -213,12 +213,12 @@ def auth( provider.save_tokens(tokens) return {"result": "success"} - tokens = provider.tokens() + provider_tokens = provider.tokens() # Handle token refresh or new authorization - if tokens and tokens.refresh_token: + if provider_tokens and provider_tokens.refresh_token: try: - new_tokens = refresh_authorization(server_url, metadata, client_information, tokens.refresh_token) + new_tokens = refresh_authorization(server_url, metadata, client_information, provider_tokens.refresh_token) provider.save_tokens(new_tokens) return {"result": "success"} except Exception as e: diff --git a/api/core/mcp/auth/auth_provider.py b/api/core/mcp/auth/auth_provider.py index 5c7d9e4333..556f3d7e5b 100644 --- a/api/core/mcp/auth/auth_provider.py +++ b/api/core/mcp/auth/auth_provider.py @@ -90,4 +90,4 @@ class OAuthClientProvider: if not mcp_provider: return "" credentials = MCPToolManageService.get_mcp_provider_decrypted_credentials(self.tenant_id, self.provider_id) - return credentials.get("code_verifier", "") + return str(credentials.get("code_verifier", "")) diff --git a/api/core/mcp/client/sse_client.py b/api/core/mcp/client/sse_client.py index 955a695c29..86cb58dda8 100644 --- a/api/core/mcp/client/sse_client.py +++ b/api/core/mcp/client/sse_client.py @@ -3,7 +3,7 @@ import queue from collections.abc import Generator from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager -from typing import Any +from typing import Any, cast from urllib.parse import urljoin, urlparse import httpx @@ -38,9 +38,9 @@ def sse_client( if headers is None: headers = {} - read_queue = queue.Queue() - write_queue = queue.Queue() - status_queue = queue.Queue() + read_queue: queue.Queue[SessionMessage | Exception | None] = queue.Queue() + write_queue: queue.Queue[SessionMessage | Exception | None] = queue.Queue() + status_queue: queue.Queue[tuple[str, str | Exception]] = queue.Queue() with ThreadPoolExecutor() as executor: try: @@ -97,6 +97,9 @@ def sse_client( message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT) if message is None: break + if isinstance(message, Exception): + write_queue.put(message) + continue response = client.post( endpoint_url, json=message.message.model_dump( @@ -119,13 +122,14 @@ def sse_client( executor.submit(sse_reader, status_queue) try: - status, endpoint_url = status_queue.get(timeout=1) + status, endpoint_url_or_error = status_queue.get(timeout=1) except queue.Empty: raise ValueError("failed to get endpoint URL") if status != "ready": raise ValueError("failed to get endpoint URL") - if status == "error": - raise endpoint_url + if status == "error" and isinstance(endpoint_url_or_error, Exception): + raise endpoint_url_or_error + endpoint_url = cast(str, endpoint_url_or_error) executor.submit(post_writer, endpoint_url) yield read_queue, write_queue diff --git a/api/core/mcp/client/streamable_client.py b/api/core/mcp/client/streamable_client.py index 2e5b8cfcac..649eb32abb 100644 --- a/api/core/mcp/client/streamable_client.py +++ b/api/core/mcp/client/streamable_client.py @@ -431,8 +431,8 @@ def streamablehttp_client( transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout) # Create queues with clear directional meaning - server_to_client_queue = queue.Queue() # For messages FROM server TO client - client_to_server_queue = queue.Queue() # For messages FROM client TO server + server_to_client_queue: ServerToClientQueue = queue.Queue() # For messages FROM server TO client + client_to_server_queue: ClientToServerQueue = queue.Queue() # For messages FROM client TO server with ThreadPoolExecutor(max_workers=2) as executor: try: diff --git a/api/core/mcp/mcp_client.py b/api/core/mcp/mcp_client.py index 1c460b0721..c5976f646d 100644 --- a/api/core/mcp/mcp_client.py +++ b/api/core/mcp/mcp_client.py @@ -1,7 +1,8 @@ import logging from collections.abc import Callable -from contextlib import ExitStack -from typing import Optional, cast +from contextlib import AbstractContextManager, ExitStack +from types import TracebackType +from typing import Any, Optional, cast from urllib.parse import urlparse from core.mcp.client.sse_client import sse_client @@ -39,8 +40,8 @@ class MCPClient: # Initialize session and client objects self._session: Optional[ClientSession] = None - self._streams_context = None - self._session_context = None + self._streams_context: Optional[AbstractContextManager[Any]] = None + self._session_context: Optional[ClientSession] = None self.exit_stack = ExitStack() # Whether the client has been initialized @@ -51,14 +52,19 @@ class MCPClient: self._initialized = True return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, exc_type: Optional[type], exc_value: Optional[BaseException], traceback: Optional[TracebackType] + ): self.cleanup() def _initialize( self, ): """Initialize the client with fallback to SSE if streamable connection fails""" - connection_methods = {"mcp": streamablehttp_client, "sse": sse_client} + connection_methods: dict[str, Callable[..., AbstractContextManager[Any]]] = { + "mcp": streamablehttp_client, + "sse": sse_client, + } parsed_url = urlparse(self.server_url) path = parsed_url.path @@ -72,7 +78,9 @@ class MCPClient: except MCPConnectionError: self.connect_server(streamablehttp_client, "mcp") - def connect_server(self, client_factory: Callable, method_name: str, first_try: bool = True): + def connect_server( + self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str, first_try: bool = True + ): from core.mcp.auth.auth_flow import auth try: @@ -82,6 +90,9 @@ class MCPClient: else {} ) self._streams_context = client_factory(url=self.server_url, headers=headers) + if self._streams_context is None: + raise MCPConnectionError("Failed to create connection context") + if method_name == "mcp": read_stream, write_stream, _ = self._streams_context.__enter__() streams = (read_stream, write_stream) @@ -123,8 +134,8 @@ class MCPClient: def cleanup(self): """Clean up resources""" try: - if self._session: - self._session.__exit__(None, None, None) + if self._session_context: + self._session_context.__exit__(None, None, None) if self._streams_context: self._streams_context.__exit__(None, None, None) diff --git a/api/core/mcp/server/handler.py b/api/core/mcp/server/handler.py index e218a6ba6e..c382c08e28 100644 --- a/api/core/mcp/server/handler.py +++ b/api/core/mcp/server/handler.py @@ -1,6 +1,6 @@ import json from collections.abc import Mapping -from typing import cast +from typing import Any, cast from configs import dify_config from controllers.web.passport import generate_session_id @@ -153,7 +153,7 @@ class MCPServerReuqestHandler: ) def _convert_input_form_to_parameters(self, user_input_form: list[VariableEntity]): - parameters = {} + parameters: dict[str, dict[str, Any]] = {} required = [] for item in user_input_form: parameters[item.variable] = {} diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index 445de9e2a3..e78ce4f34d 100644 --- a/api/core/mcp/session/base_session.py +++ b/api/core/mcp/session/base_session.py @@ -38,7 +38,7 @@ SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotif ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest) ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel) ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification) -DEFAULT_RESPONSE_READ_TIMEOUT = 1 +DEFAULT_RESPONSE_READ_TIMEOUT = 1.0 class RequestResponder(Generic[ReceiveRequestT, SendResultT]): @@ -57,6 +57,10 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): 3. Cleanup of in-flight requests """ + request: ReceiveRequestT + _session: Any + _on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any] + def __init__( self, request_id: RequestId, @@ -146,6 +150,8 @@ class BaseSession( _response_streams: dict[RequestId, queue.Queue[JSONRPCResponse | JSONRPCError]] _request_id: int _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] + _receive_request_type: type[ReceiveRequestT] + _receive_notification_type: type[ReceiveNotificationT] def __init__( self, @@ -165,7 +171,6 @@ class BaseSession( self._session_read_timeout_seconds = read_timeout_seconds self._in_flight = {} self._exit_stack = ExitStack() - self._futures = [] def __enter__(self) -> Self: self._executor = ThreadPoolExecutor() @@ -198,7 +203,7 @@ class BaseSession( request_id = self._request_id self._request_id = request_id + 1 - response_queue = queue.Queue() + response_queue: queue.Queue[JSONRPCResponse | JSONRPCError] = queue.Queue() self._response_streams[request_id] = response_queue try: @@ -211,9 +216,9 @@ class BaseSession( self._write_stream.put(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata)) timeout = DEFAULT_RESPONSE_READ_TIMEOUT if request_read_timeout_seconds is not None: - timeout = request_read_timeout_seconds.total_seconds() + timeout = float(request_read_timeout_seconds.total_seconds()) elif self._session_read_timeout_seconds is not None: - timeout = self._session_read_timeout_seconds.total_seconds() + timeout = float(self._session_read_timeout_seconds.total_seconds()) while True: try: response_or_error = response_queue.get(timeout=timeout) diff --git a/api/core/mcp/session/client_session.py b/api/core/mcp/session/client_session.py index 7c8827a77f..518920c60b 100644 --- a/api/core/mcp/session/client_session.py +++ b/api/core/mcp/session/client_session.py @@ -340,8 +340,8 @@ class ClientSession( case types.ListRootsRequest(): with responder: - response = self._list_roots_callback(ctx) - client_response = ClientResponse.validate_python(response) + list_roots_response = self._list_roots_callback(ctx) + client_response = ClientResponse.validate_python(list_roots_response) responder.respond(client_response) case types.PingRequest(): diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 43040dd2fa..37c38507de 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -689,8 +689,8 @@ class ToolManager: result_providers[f"workflow_provider.{user_provider.name}"] = user_provider if "mcp" in filters: mcp_providers = MCPToolManageService.retrieve_mcp_tools(tenant_id) - for provider in mcp_providers: - result_providers[f"mcp_provider.{provider.name}"] = provider + for mcp_provider in mcp_providers: + result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider return BuiltinToolProviderSort.sort(list(result_providers.values())) diff --git a/api/models/model.py b/api/models/model.py index 61260cb8ad..05f0396424 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1467,7 +1467,7 @@ class AppMCPServer(Base): @property def parameters_dict(self) -> dict[str, Any]: - return json.loads(self.parameters) + return cast(dict[str, Any], json.loads(self.parameters)) class Site(Base): diff --git a/api/models/tools.py b/api/models/tools.py index 478de0db00..3f537eca5d 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -248,9 +248,8 @@ class MCPToolProvider(Base): return [Tool(**tool) for tool in json.loads(self.tools)] @property - def provider_icon(self) -> str: - icon_dict = json.loads(self.icon) - return icon_dict + def provider_icon(self) -> dict[str, str]: + return cast(dict[str, str], json.loads(self.icon)) class ToolModelInvoke(Base):