From b4317cd0dcf926ca6f50739031c3d4676b20805c Mon Sep 17 00:00:00 2001 From: Novice Date: Fri, 13 Jun 2025 16:53:18 +0800 Subject: [PATCH] feat: implement serveless streamable server --- api/controllers/mcp/mcp.py | 29 +++++++++-------- api/core/mcp/server/handler.py | 31 ++++++++++--------- api/core/mcp/types.py | 1 + api/services/tools/tools_transform_service.py | 2 +- 4 files changed, 33 insertions(+), 30 deletions(-) diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index cacd8197de..78cd4f66ac 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -1,8 +1,5 @@ -import logging - from flask_restful import Resource, reqparse from pydantic import ValidationError -from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound from controllers.mcp import api @@ -11,13 +8,11 @@ from controllers.web.error import ( ) from core.app.app_config.entities import VariableEntity from core.mcp.server.handler import MCPServerReuqestHandler -from core.mcp.types import ClientRequest +from core.mcp.types import ClientNotification, ClientRequest from extensions.ext_database import db from libs import helper from models.model import App, AppMCPServer, AppMode -logger = logging.getLogger(__name__) - class MCPAppApi(Resource): def post(self, server_code): @@ -27,15 +22,14 @@ class MCPAppApi(Resource): elif isinstance(value, str): return int(value) else: - raise ValueError("Invalid id") + return None parser = reqparse.RequestParser() parser.add_argument("jsonrpc", type=str, required=True, location="json") parser.add_argument("method", type=str, required=True, location="json") - parser.add_argument("params", type=dict, required=True, location="json") - parser.add_argument("id", type=int_or_str, required=True, location="json") + parser.add_argument("params", type=dict, required=False, location="json") + parser.add_argument("id", type=int_or_str, required=False, location="json") args = parser.parse_args() - logger.info(f"MCP request: {args}") server = db.session.query(AppMCPServer).filter(AppMCPServer.server_code == server_code).first() if not server: raise NotFound("Server Not Found") @@ -62,12 +56,17 @@ class MCPAppApi(Resource): except ValidationError as e: raise ValueError(f"Invalid user_input_form: {str(e)}") try: - request = ClientRequest.model_validate(args) + request: ClientRequest | ClientNotification = ClientRequest.model_validate(args) except ValidationError as e: - raise ValueError(f"Invalid MCP request: {str(e)}") - with Session(db.engine) as session: - mcp_server_handler = MCPServerReuqestHandler(app, request, user_input_form, session) - return helper.compact_generate_response(mcp_server_handler.handle()) + try: + notification = ClientNotification.model_validate(args) + request = notification + except ValidationError as e: + raise ValueError(f"Invalid MCP request: {str(e)}") + + mcp_server_handler = MCPServerReuqestHandler(app, request, user_input_form) + response = mcp_server_handler.handle() + return helper.compact_generate_response(response) api.add_resource(MCPAppApi, "/server//mcp") diff --git a/api/core/mcp/server/handler.py b/api/core/mcp/server/handler.py index 065a7da2b0..6f4ee9adf1 100644 --- a/api/core/mcp/server/handler.py +++ b/api/core/mcp/server/handler.py @@ -1,10 +1,7 @@ import json -import logging from collections.abc import Mapping from typing import Any, cast -from sqlalchemy.orm import Session - from configs import dify_config from controllers.web.passport import generate_session_id from core.app.app_config.entities import VariableEntity, VariableEntityType @@ -12,6 +9,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.mcp import types from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND from core.model_runtime.utils.encoders import jsonable_encoder +from extensions.ext_database import db from models.model import App, AppMCPServer, AppMode, EndUser from services.app_generate_service import AppGenerateService @@ -19,17 +17,16 @@ from services.app_generate_service import AppGenerateService Apply to MCP HTTP streamable server with stateless http """ -logger = logging.getLogger(__name__) - class MCPServerReuqestHandler: - def __init__(self, app: App, request: types.ClientRequest, user_input_form: list[VariableEntity], session: Session): + def __init__( + self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity] + ): self.app = app self.request = request if not self.app.mcp_server: raise ValueError("MCP server not found") self.mcp_server: AppMCPServer = self.app.mcp_server - self._session = session self.end_user = self.retrieve_end_user() self.user_input_form = user_input_form @@ -61,7 +58,11 @@ class MCPServerReuqestHandler: tools=types.ToolsCapability(listChanged=False), ) - def response(self, response: types.Result): + def response(self, response: types.Result | str): + if isinstance(response, str): + sse_content = f"event: ping\ndata: {response}\n\n".encode() + yield sse_content + return json_response = types.JSONRPCResponse( jsonrpc="2.0", id=(self.request.root.model_extra or {}).get("id", 1), @@ -77,7 +78,7 @@ class MCPServerReuqestHandler: error_data = types.ErrorData(code=code, message=message, data=data) json_response = types.JSONRPCError( jsonrpc="2.0", - id=(self.request.root.model_extra or {}).get("id", 1), + id=(self.request.root.model_extra or {}).get("id", 1) or 1, error=error_data, ) json_data = json.dumps(jsonable_encoder(json_response)) @@ -91,6 +92,7 @@ class MCPServerReuqestHandler: types.InitializeRequest: self.initialize, types.ListToolsRequest: self.list_tools, types.CallToolRequest: self.invoke_tool, + types.InitializedNotification: self.handle_notification, } try: if self.request_type in handle_map: @@ -102,8 +104,10 @@ class MCPServerReuqestHandler: except Exception as e: return self.error_response(INTERNAL_ERROR, f"Internal server error: {str(e)}") + def handle_notification(self): + return "ping" + def initialize(self): - logger.info(f"Initialize: {self.request}") request = cast(types.InitializeRequest, self.request.root) client_info = request.params.clientInfo clinet_name = f"{client_info.name}@{client_info.version}" @@ -116,8 +120,8 @@ class MCPServerReuqestHandler: session_id=generate_session_id(), external_user_id=self.mcp_server.id, ) - self._session.add(end_user) - self._session.commit() + db.session.add(end_user) + db.session.commit() return types.InitializeResult( protocolVersion=types.LATEST_PROTOCOL_VERSION, capabilities=self.capabilities, @@ -126,7 +130,6 @@ class MCPServerReuqestHandler: ) def list_tools(self): - logger.info(f"List tools: {self.request}") if not self.end_user: raise ValueError("User not found") return types.ListToolsResult( @@ -170,7 +173,7 @@ class MCPServerReuqestHandler: def retrieve_end_user(self): return ( - self._session.query(EndUser) + db.session.query(EndUser) .filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp") .first() ) diff --git a/api/core/mcp/types.py b/api/core/mcp/types.py index 7f26133211..603ab0cfb5 100644 --- a/api/core/mcp/types.py +++ b/api/core/mcp/types.py @@ -1161,6 +1161,7 @@ class ServerMessageMetadata: """Metadata specific to server messages.""" related_request_id: RequestId | None = None + request_context: object | None = None MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 06e042f139..dbd042cb9d 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -367,7 +367,7 @@ class ToolTransformService: def process_properties(props: dict, required: list, prefix: str = "") -> list[ToolParameter]: """Process properties recursively""" - TYPE_MAPPING = {"integer": "number"} + TYPE_MAPPING = {"integer": "number", "float": "number"} COMPLEX_TYPES = ["array", "object"] parameters = []