diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index 8c691abffb..4f9e75c0d3 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -1,9 +1,9 @@ import json -from enum import Enum +from enum import StrEnum from flask_login import current_user from flask_restful import Resource, marshal_with, reqparse -from werkzeug.exceptions import Forbidden +from werkzeug.exceptions import NotFound from controllers.console import api from controllers.console.app.wraps import get_app_model @@ -14,7 +14,7 @@ from libs.login import login_required from models.model import AppMCPServer -class AppMCPServerStatus(str, Enum): +class AppMCPServerStatus(StrEnum): ACTIVE = "active" INACTIVE = "inactive" @@ -37,7 +37,7 @@ class AppMCPServerController(Resource): def post(self, app_model): # The role of the current user in the ta table must be editor, admin, or owner if not current_user.is_editor: - raise Forbidden() + raise NotFound() parser = reqparse.RequestParser() parser.add_argument("description", type=str, required=True, location="json") parser.add_argument("parameters", type=dict, required=True, location="json") @@ -62,7 +62,7 @@ class AppMCPServerController(Resource): @marshal_with(app_server_fields) def put(self, app_model): if not current_user.is_editor: - raise Forbidden() + raise NotFound() parser = reqparse.RequestParser() parser.add_argument("id", type=str, required=True, location="json") parser.add_argument("description", type=str, required=True, location="json") @@ -71,7 +71,7 @@ class AppMCPServerController(Resource): args = parser.parse_args() server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first() if not server: - raise Forbidden() + raise NotFound() server.description = args["description"] server.parameters = json.dumps(args["parameters"], ensure_ascii=False) if args["status"]: @@ -89,10 +89,10 @@ class AppMCPServerRefreshController(Resource): @marshal_with(app_server_fields) def get(self, server_id): if not current_user.is_editor: - raise Forbidden() + raise NotFound() server = db.session.query(AppMCPServer).filter(AppMCPServer.id == server_id).first() if not server: - raise Forbidden() + raise NotFound() server.server_code = AppMCPServer.generate_server_code(16) db.session.commit() return server diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 846c5ab419..6ac3c4b20b 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -1,6 +1,6 @@ import io +from urllib.parse import urlparse -import validators from flask import redirect, send_file from flask_login import current_user from flask_restful import Resource, reqparse @@ -27,6 +27,17 @@ from services.tools.tools_transform_service import ToolTransformService from services.tools.workflow_tools_manage_service import WorkflowToolManageService +def is_valid_url(url: str) -> bool: + if not url: + return False + + try: + parsed = urlparse(url) + return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"] + except Exception: + return False + + class ToolProviderListApi(Resource): @setup_required @login_required @@ -634,7 +645,7 @@ class ToolProviderMCPApi(Resource): parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json") args = parser.parse_args() user = current_user - if not validators.url(args["server_url"]): + if not is_valid_url(args["server_url"]): raise ValueError("Server URL is not valid.") return jsonable_encoder( MCPToolManageService.create_mcp_provider( @@ -662,7 +673,7 @@ class ToolProviderMCPApi(Resource): parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json") parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json") args = parser.parse_args() - if not validators.url(args["server_url"]): + if not is_valid_url(args["server_url"]): if "[__HIDDEN__]" in args["server_url"]: pass else: diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index 7245767c9b..22b313dbfc 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -8,7 +8,7 @@ from controllers.web.error import ( AppUnavailableError, ) from core.app.app_config.entities import VariableEntity -from core.mcp.server.handler import MCPServerReuqestHandler +from core.mcp.server.handler import MCPServerRequestHandler from core.mcp.types import ClientNotification, ClientRequest from extensions.ext_database import db from libs import helper @@ -66,7 +66,7 @@ class MCPAppApi(Resource): except ValidationError as e: raise ValueError(f"Invalid MCP request: {str(e)}") - mcp_server_handler = MCPServerReuqestHandler(app, request, user_input_form) + mcp_server_handler = MCPServerRequestHandler(app, request, user_input_form) response = mcp_server_handler.handle() return helper.compact_generate_response(response) diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py index 1b6afd0b06..b63478e822 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -8,7 +8,7 @@ from typing import Optional from urllib.parse import urljoin import requests -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError from core.mcp.auth.auth_provider import OAuthClientProvider from core.mcp.types import ( @@ -60,7 +60,7 @@ def _create_secure_redis_state(state_data: OAuthCallbackState) -> str: def _retrieve_redis_state(state_key: str) -> OAuthCallbackState: - """Retrieve and decode OAuth state data from Redis using the state key.""" + """Retrieve and decode OAuth state data from Redis using the state key, then delete it.""" redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}" # Get state data from Redis @@ -69,27 +69,23 @@ def _retrieve_redis_state(state_key: str) -> OAuthCallbackState: if not state_data: raise ValueError("State parameter has expired or does not exist") + # Delete the state data from Redis immediately after retrieval to prevent reuse + redis_client.delete(redis_key) + try: # Parse and validate the state data - if isinstance(state_data, bytes): - state_data = state_data.decode("utf-8") - oauth_state = OAuthCallbackState.model_validate_json(state_data) return oauth_state - except Exception as e: + except ValidationError as e: raise ValueError(f"Invalid state parameter: {str(e)}") def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackState: """Handle the callback from the OAuth provider.""" - # Retrieve state data from Redis + # Retrieve state data from Redis (state is automatically deleted after retrieval) full_state_data = _retrieve_redis_state(state_key) - # Clean up the state data from Redis after successful retrieval - redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}" - redis_client.delete(redis_key) - tokens = exchange_authorization( full_state_data.server_url, full_state_data.metadata, diff --git a/api/core/mcp/client/sse_client.py b/api/core/mcp/client/sse_client.py index 3e745d34a4..a142a3ce48 100644 --- a/api/core/mcp/client/sse_client.py +++ b/api/core/mcp/client/sse_client.py @@ -3,7 +3,7 @@ import queue from collections.abc import Generator from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager -from typing import Any, cast +from typing import Any, TypeAlias, final from urllib.parse import urljoin, urlparse import httpx @@ -18,10 +18,23 @@ logger = logging.getLogger(__name__) DEFAULT_QUEUE_READ_TIMEOUT = 3 + +@final +class _StatusReady: + def __init__(self, endpoint_url: str): + self._endpoint_url = endpoint_url + + +@final +class _StatusError: + def __init__(self, exc: Exception): + self._exc = exc + + # Type aliases for better readability -ReadQueue = queue.Queue[SessionMessage | Exception | None] -WriteQueue = queue.Queue[SessionMessage | Exception | None] -StatusQueue = queue.Queue[tuple[str, str | Exception]] +ReadQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None] +WriteQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None] +StatusQueue: TypeAlias = queue.Queue[_StatusReady | _StatusError] def remove_request_params(url: str) -> str: @@ -80,10 +93,10 @@ class SSETransport: if not self._validate_endpoint_url(endpoint_url): error_msg = f"Endpoint origin does not match connection origin: {endpoint_url}" logger.error(error_msg) - status_queue.put(("error", ValueError(error_msg))) + status_queue.put(_StatusError(ValueError(error_msg))) return - status_queue.put(("ready", endpoint_url)) + status_queue.put(_StatusReady(endpoint_url)) def _handle_message_event(self, sse_data: str, read_queue: ReadQueue) -> None: """Handle a 'message' SSE event. @@ -197,18 +210,17 @@ class SSETransport: ValueError: If endpoint URL is not received or there's an error. """ try: - status, endpoint_url_or_error = status_queue.get(timeout=1) + status = status_queue.get(timeout=1) except queue.Empty: raise ValueError("failed to get endpoint URL") - if status != "ready": + if isinstance(status, _StatusReady): + return status._endpoint_url + elif isinstance(status, _StatusError): + raise status._exc + else: raise ValueError("failed to get endpoint URL") - if status == "error" and isinstance(endpoint_url_or_error, Exception): - raise endpoint_url_or_error - - return cast(str, endpoint_url_or_error) - def connect( self, executor: ThreadPoolExecutor, @@ -284,9 +296,9 @@ def sse_client( if exc.response.status_code == 401: raise MCPAuthError() raise MCPConnectionError() - except Exception as exc: + except Exception: logger.exception("Error connecting to SSE endpoint") - raise exc + raise finally: # Clean up queues if read_queue: diff --git a/api/core/mcp/mcp_client.py b/api/core/mcp/mcp_client.py index 3a036a0278..e9036de8c6 100644 --- a/api/core/mcp/mcp_client.py +++ b/api/core/mcp/mcp_client.py @@ -94,14 +94,15 @@ class MCPClient: if self._streams_context is None: raise MCPConnectionError("Failed to create connection context") + # Use exit_stack to manage context managers properly if method_name == "mcp": - read_stream, write_stream, _ = self._streams_context.__enter__() + read_stream, write_stream, _ = self.exit_stack.enter_context(self._streams_context) streams = (read_stream, write_stream) else: # sse_client - streams = self._streams_context.__enter__() + streams = self.exit_stack.enter_context(self._streams_context) self._session_context = ClientSession(*streams) - self._session = self._session_context.__enter__() + self._session = self.exit_stack.enter_context(self._session_context) session = cast(ClientSession, self._session) session.initialize() return @@ -138,14 +139,12 @@ class MCPClient: def cleanup(self): """Clean up resources""" try: - if self._session_context: - self._session_context.__exit__(None, None, None) - - if self._streams_context: - self._streams_context.__exit__(None, None, None) + # ExitStack will handle proper cleanup of all managed context managers + self.exit_stack.close() self._session = None + self._session_context = None + self._streams_context = None self._initialized = False - self.exit_stack.close() except Exception as e: logging.exception("Error during cleanup") raise ValueError(f"Error during cleanup: {e}") diff --git a/api/core/mcp/server/handler.py b/api/core/mcp/server/handler.py index 2e9e7718e1..f23dd8adae 100644 --- a/api/core/mcp/server/handler.py +++ b/api/core/mcp/server/handler.py @@ -18,15 +18,16 @@ Apply to MCP HTTP streamable server with stateless http """ -class MCPServerReuqestHandler: +class MCPServerRequestHandler: def __init__( self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity] ): self.app = app self.request = request - if not self.app.mcp_server: + mcp_server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == self.app.id).first() + if not mcp_server: raise ValueError("MCP server not found") - self.mcp_server: AppMCPServer = self.app.mcp_server + self.mcp_server: AppMCPServer = mcp_server self.end_user = self.retrieve_end_user() self.user_input_form = user_input_form diff --git a/api/core/mcp/utils.py b/api/core/mcp/utils.py index 140665edf8..a8a603b3f2 100644 --- a/api/core/mcp/utils.py +++ b/api/core/mcp/utils.py @@ -4,20 +4,8 @@ from configs import dify_config SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES -HTTP_REQUEST_NODE_SSL_VERIFY = True # Default value for HTTP_REQUEST_NODE_SSL_VERIFY is True -try: - HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY - http_request_node_ssl_verify_lower = str(HTTP_REQUEST_NODE_SSL_VERIFY).lower() - if http_request_node_ssl_verify_lower == "true": - HTTP_REQUEST_NODE_SSL_VERIFY = True - elif http_request_node_ssl_verify_lower == "false": - HTTP_REQUEST_NODE_SSL_VERIFY = False - else: - raise ValueError("Invalid value. HTTP_REQUEST_NODE_SSL_VERIFY should be 'True' or 'False'") -except NameError: - HTTP_REQUEST_NODE_SSL_VERIFY = True +HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY -BACKOFF_FACTOR = 0.5 STATUS_FORCELIST = [429, 500, 502, 503, 504] diff --git a/api/core/tools/mcp_tool/provider.py b/api/core/tools/mcp_tool/provider.py index 77ae6a70e6..93f003effe 100644 --- a/api/core/tools/mcp_tool/provider.py +++ b/api/core/tools/mcp_tool/provider.py @@ -46,11 +46,11 @@ class MCPToolProviderController(ToolProviderController): tools = [] tools_data = json.loads(db_provider.tools) remote_mcp_tools = [RemoteMCPTool(**tool) for tool in tools_data] - + user = db_provider.load_user() tools = [ ToolEntity( identity=ToolIdentity( - author=db_provider.user.name if db_provider.user else "Anonymous", + author=user.name if user else "Anonymous", name=remote_mcp_tool.name, label=I18nObject(en_US=remote_mcp_tool.name, zh_Hans=remote_mcp_tool.name), provider=db_provider.server_identifier, @@ -72,7 +72,7 @@ class MCPToolProviderController(ToolProviderController): return cls( entity=ToolProviderEntityWithPlugin( identity=ToolProviderIdentity( - author=db_provider.user.name if db_provider.user else "Anonymous", + author=user.name if user else "Anonymous", name=db_provider.name, label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name), description=I18nObject(en_US="", zh_Hans=""), diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index c7e627a998..175e0133fb 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -1,4 +1,5 @@ import base64 +import json from collections.abc import Generator from typing import Any, Optional @@ -49,6 +50,11 @@ class MCPTool(Tool): for content in result.content: if isinstance(content, TextContent): yield self.create_text_message(content.text) + try: + yield self.create_json_message(json.loads(content.text)) + except json.JSONDecodeError: + pass + elif isinstance(content, ImageContent): yield self.create_blob_message( blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType} diff --git a/api/migrations/versions/2025_06_25_0936-58eb7bdb93fe_add_mcp_server_tool_and_app_server.py b/api/migrations/versions/2025_06_25_0936-58eb7bdb93fe_add_mcp_server_tool_and_app_server.py index 0c98614dd6..0548bf05ef 100644 --- a/api/migrations/versions/2025_06_25_0936-58eb7bdb93fe_add_mcp_server_tool_and_app_server.py +++ b/api/migrations/versions/2025_06_25_0936-58eb7bdb93fe_add_mcp_server_tool_and_app_server.py @@ -32,7 +32,7 @@ def upgrade(): sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), sa.PrimaryKeyConstraint('id', name='app_mcp_server_pkey'), sa.UniqueConstraint('tenant_id', 'app_id', name='unique_app_mcp_server_tenant_app_id'), - sa.UniqueConstraint('tenant_id', 'server_code', name='unique_app_mcp_server_tenant_server_code') + sa.UniqueConstraint('server_code', name='unique_app_mcp_server_server_code') ) op.create_table('tool_mcp_providers', sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), diff --git a/api/models/model.py b/api/models/model.py index 15d8e011c8..b1007c4a79 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -294,10 +294,6 @@ class App(Base): return tags or [] - @property - def mcp_server(self): - return db.session.query(AppMCPServer).filter(AppMCPServer.app_id == self.id).first() - @property def author_name(self): if self.created_by: @@ -1465,7 +1461,7 @@ class AppMCPServer(Base): __table_args__ = ( db.PrimaryKeyConstraint("id", name="app_mcp_server_pkey"), db.UniqueConstraint("tenant_id", "app_id", name="unique_app_mcp_server_tenant_app_id"), - db.UniqueConstraint("tenant_id", "server_code", name="unique_app_mcp_server_tenant_server_code"), + db.UniqueConstraint("server_code", name="unique_app_mcp_server_server_code"), ) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) diff --git a/api/models/tools.py b/api/models/tools.py index 3357d6455a..9d2c3baea5 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -234,8 +234,7 @@ class MCPToolProvider(Base): db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) - @property - def user(self) -> Account | None: + def load_user(self) -> Account | None: return db.session.query(Account).filter(Account.id == self.user_id).first() @property diff --git a/api/services/tools/mcp_tools_mange_service.py b/api/services/tools/mcp_tools_mange_service.py index 14b06c1988..b2a88738f6 100644 --- a/api/services/tools/mcp_tools_mange_service.py +++ b/api/services/tools/mcp_tools_mange_service.py @@ -125,13 +125,14 @@ class MCPToolManageService: mcp_provider.authed = True mcp_provider.updated_at = datetime.now() db.session.commit() + user = mcp_provider.load_user() return ToolProviderApiEntity( id=mcp_provider.id, name=mcp_provider.name, tools=ToolTransformService.mcp_tool_to_user_tool(mcp_provider, tools), type=ToolProviderType.MCP, icon=mcp_provider.icon, - author=mcp_provider.user.name if mcp_provider.user else "Anonymous", + author=user.name if user else "Anonymous", server_url=mcp_provider.masked_server_url, updated_at=int(mcp_provider.updated_at.timestamp()), description=I18nObject(en_US="", zh_Hans=""), diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index b78d4581e5..8009c384b7 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -191,9 +191,10 @@ class ToolTransformService: @staticmethod def mcp_provider_to_user_provider(db_provider: MCPToolProvider, for_list: bool = False) -> ToolProviderApiEntity: + user = db_provider.load_user() return ToolProviderApiEntity( id=db_provider.server_identifier if not for_list else db_provider.id, - author=db_provider.user.name if db_provider.user else "Anonymous", + author=user.name if user else "Anonymous", name=db_provider.name, icon=db_provider.provider_icon, type=ToolProviderType.MCP, @@ -210,9 +211,10 @@ class ToolTransformService: @staticmethod def mcp_tool_to_user_tool(mcp_provider: MCPToolProvider, tools: list[MCPTool]) -> list[ToolApiEntity]: + user = mcp_provider.load_user() return [ ToolApiEntity( - author=mcp_provider.user.name if mcp_provider.user else "Anonymous", + author=user.name if user else "Anonymous", name=tool.name, label=I18nObject(en_US=tool.name, zh_Hans=tool.name), description=I18nObject(en_US=tool.description, zh_Hans=tool.description),