feat: implement serveless streamable server

pull/22036/head
Novice 11 months ago
parent ac3438e187
commit b4317cd0dc

@ -1,8 +1,5 @@
import logging
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
from pydantic import ValidationError from pydantic import ValidationError
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.mcp import api from controllers.mcp import api
@ -11,13 +8,11 @@ from controllers.web.error import (
) )
from core.app.app_config.entities import VariableEntity from core.app.app_config.entities import VariableEntity
from core.mcp.server.handler import MCPServerReuqestHandler from core.mcp.server.handler import MCPServerReuqestHandler
from core.mcp.types import ClientRequest from core.mcp.types import ClientNotification, ClientRequest
from extensions.ext_database import db from extensions.ext_database import db
from libs import helper from libs import helper
from models.model import App, AppMCPServer, AppMode from models.model import App, AppMCPServer, AppMode
logger = logging.getLogger(__name__)
class MCPAppApi(Resource): class MCPAppApi(Resource):
def post(self, server_code): def post(self, server_code):
@ -27,15 +22,14 @@ class MCPAppApi(Resource):
elif isinstance(value, str): elif isinstance(value, str):
return int(value) return int(value)
else: else:
raise ValueError("Invalid id") return None
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("jsonrpc", type=str, required=True, location="json") parser.add_argument("jsonrpc", type=str, required=True, location="json")
parser.add_argument("method", type=str, required=True, location="json") parser.add_argument("method", type=str, required=True, location="json")
parser.add_argument("params", type=dict, required=True, location="json") parser.add_argument("params", type=dict, required=False, location="json")
parser.add_argument("id", type=int_or_str, required=True, location="json") parser.add_argument("id", type=int_or_str, required=False, location="json")
args = parser.parse_args() args = parser.parse_args()
logger.info(f"MCP request: {args}")
server = db.session.query(AppMCPServer).filter(AppMCPServer.server_code == server_code).first() server = db.session.query(AppMCPServer).filter(AppMCPServer.server_code == server_code).first()
if not server: if not server:
raise NotFound("Server Not Found") raise NotFound("Server Not Found")
@ -62,12 +56,17 @@ class MCPAppApi(Resource):
except ValidationError as e: except ValidationError as e:
raise ValueError(f"Invalid user_input_form: {str(e)}") raise ValueError(f"Invalid user_input_form: {str(e)}")
try: try:
request = ClientRequest.model_validate(args) request: ClientRequest | ClientNotification = ClientRequest.model_validate(args)
except ValidationError as e: except ValidationError as e:
raise ValueError(f"Invalid MCP request: {str(e)}") try:
with Session(db.engine) as session: notification = ClientNotification.model_validate(args)
mcp_server_handler = MCPServerReuqestHandler(app, request, user_input_form, session) request = notification
return helper.compact_generate_response(mcp_server_handler.handle()) except ValidationError as e:
raise ValueError(f"Invalid MCP request: {str(e)}")
mcp_server_handler = MCPServerReuqestHandler(app, request, user_input_form)
response = mcp_server_handler.handle()
return helper.compact_generate_response(response)
api.add_resource(MCPAppApi, "/server/<string:server_code>/mcp") api.add_resource(MCPAppApi, "/server/<string:server_code>/mcp")

@ -1,10 +1,7 @@
import json import json
import logging
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, cast from typing import Any, cast
from sqlalchemy.orm import Session
from configs import dify_config from configs import dify_config
from controllers.web.passport import generate_session_id from controllers.web.passport import generate_session_id
from core.app.app_config.entities import VariableEntity, VariableEntityType from core.app.app_config.entities import VariableEntity, VariableEntityType
@ -12,6 +9,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.mcp import types from core.mcp import types
from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from models.model import App, AppMCPServer, AppMode, EndUser from models.model import App, AppMCPServer, AppMode, EndUser
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
@ -19,17 +17,16 @@ from services.app_generate_service import AppGenerateService
Apply to MCP HTTP streamable server with stateless http Apply to MCP HTTP streamable server with stateless http
""" """
logger = logging.getLogger(__name__)
class MCPServerReuqestHandler: class MCPServerReuqestHandler:
def __init__(self, app: App, request: types.ClientRequest, user_input_form: list[VariableEntity], session: Session): def __init__(
self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity]
):
self.app = app self.app = app
self.request = request self.request = request
if not self.app.mcp_server: if not self.app.mcp_server:
raise ValueError("MCP server not found") raise ValueError("MCP server not found")
self.mcp_server: AppMCPServer = self.app.mcp_server self.mcp_server: AppMCPServer = self.app.mcp_server
self._session = session
self.end_user = self.retrieve_end_user() self.end_user = self.retrieve_end_user()
self.user_input_form = user_input_form self.user_input_form = user_input_form
@ -61,7 +58,11 @@ class MCPServerReuqestHandler:
tools=types.ToolsCapability(listChanged=False), tools=types.ToolsCapability(listChanged=False),
) )
def response(self, response: types.Result): def response(self, response: types.Result | str):
if isinstance(response, str):
sse_content = f"event: ping\ndata: {response}\n\n".encode()
yield sse_content
return
json_response = types.JSONRPCResponse( json_response = types.JSONRPCResponse(
jsonrpc="2.0", jsonrpc="2.0",
id=(self.request.root.model_extra or {}).get("id", 1), id=(self.request.root.model_extra or {}).get("id", 1),
@ -77,7 +78,7 @@ class MCPServerReuqestHandler:
error_data = types.ErrorData(code=code, message=message, data=data) error_data = types.ErrorData(code=code, message=message, data=data)
json_response = types.JSONRPCError( json_response = types.JSONRPCError(
jsonrpc="2.0", jsonrpc="2.0",
id=(self.request.root.model_extra or {}).get("id", 1), id=(self.request.root.model_extra or {}).get("id", 1) or 1,
error=error_data, error=error_data,
) )
json_data = json.dumps(jsonable_encoder(json_response)) json_data = json.dumps(jsonable_encoder(json_response))
@ -91,6 +92,7 @@ class MCPServerReuqestHandler:
types.InitializeRequest: self.initialize, types.InitializeRequest: self.initialize,
types.ListToolsRequest: self.list_tools, types.ListToolsRequest: self.list_tools,
types.CallToolRequest: self.invoke_tool, types.CallToolRequest: self.invoke_tool,
types.InitializedNotification: self.handle_notification,
} }
try: try:
if self.request_type in handle_map: if self.request_type in handle_map:
@ -102,8 +104,10 @@ class MCPServerReuqestHandler:
except Exception as e: except Exception as e:
return self.error_response(INTERNAL_ERROR, f"Internal server error: {str(e)}") return self.error_response(INTERNAL_ERROR, f"Internal server error: {str(e)}")
def handle_notification(self):
return "ping"
def initialize(self): def initialize(self):
logger.info(f"Initialize: {self.request}")
request = cast(types.InitializeRequest, self.request.root) request = cast(types.InitializeRequest, self.request.root)
client_info = request.params.clientInfo client_info = request.params.clientInfo
clinet_name = f"{client_info.name}@{client_info.version}" clinet_name = f"{client_info.name}@{client_info.version}"
@ -116,8 +120,8 @@ class MCPServerReuqestHandler:
session_id=generate_session_id(), session_id=generate_session_id(),
external_user_id=self.mcp_server.id, external_user_id=self.mcp_server.id,
) )
self._session.add(end_user) db.session.add(end_user)
self._session.commit() db.session.commit()
return types.InitializeResult( return types.InitializeResult(
protocolVersion=types.LATEST_PROTOCOL_VERSION, protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=self.capabilities, capabilities=self.capabilities,
@ -126,7 +130,6 @@ class MCPServerReuqestHandler:
) )
def list_tools(self): def list_tools(self):
logger.info(f"List tools: {self.request}")
if not self.end_user: if not self.end_user:
raise ValueError("User not found") raise ValueError("User not found")
return types.ListToolsResult( return types.ListToolsResult(
@ -170,7 +173,7 @@ class MCPServerReuqestHandler:
def retrieve_end_user(self): def retrieve_end_user(self):
return ( return (
self._session.query(EndUser) db.session.query(EndUser)
.filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp") .filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
.first() .first()
) )

@ -1161,6 +1161,7 @@ class ServerMessageMetadata:
"""Metadata specific to server messages.""" """Metadata specific to server messages."""
related_request_id: RequestId | None = None related_request_id: RequestId | None = None
request_context: object | None = None
MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None

@ -367,7 +367,7 @@ class ToolTransformService:
def process_properties(props: dict, required: list, prefix: str = "") -> list[ToolParameter]: def process_properties(props: dict, required: list, prefix: str = "") -> list[ToolParameter]:
"""Process properties recursively""" """Process properties recursively"""
TYPE_MAPPING = {"integer": "number"} TYPE_MAPPING = {"integer": "number", "float": "number"}
COMPLEX_TYPES = ["array", "object"] COMPLEX_TYPES = ["array", "object"]
parameters = [] parameters = []

Loading…
Cancel
Save