feat: implement serveless streamable server

pull/22036/head
Novice 11 months ago
parent ac3438e187
commit b4317cd0dc

@ -1,8 +1,5 @@
import logging
from flask_restful import Resource, reqparse
from pydantic import ValidationError
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.mcp import api
@ -11,13 +8,11 @@ from controllers.web.error import (
)
from core.app.app_config.entities import VariableEntity
from core.mcp.server.handler import MCPServerReuqestHandler
from core.mcp.types import ClientRequest
from core.mcp.types import ClientNotification, ClientRequest
from extensions.ext_database import db
from libs import helper
from models.model import App, AppMCPServer, AppMode
logger = logging.getLogger(__name__)
class MCPAppApi(Resource):
def post(self, server_code):
@ -27,15 +22,14 @@ class MCPAppApi(Resource):
elif isinstance(value, str):
return int(value)
else:
raise ValueError("Invalid id")
return None
parser = reqparse.RequestParser()
parser.add_argument("jsonrpc", type=str, required=True, location="json")
parser.add_argument("method", type=str, required=True, location="json")
parser.add_argument("params", type=dict, required=True, location="json")
parser.add_argument("id", type=int_or_str, required=True, location="json")
parser.add_argument("params", type=dict, required=False, location="json")
parser.add_argument("id", type=int_or_str, required=False, location="json")
args = parser.parse_args()
logger.info(f"MCP request: {args}")
server = db.session.query(AppMCPServer).filter(AppMCPServer.server_code == server_code).first()
if not server:
raise NotFound("Server Not Found")
@ -62,12 +56,17 @@ class MCPAppApi(Resource):
except ValidationError as e:
raise ValueError(f"Invalid user_input_form: {str(e)}")
try:
request = ClientRequest.model_validate(args)
request: ClientRequest | ClientNotification = ClientRequest.model_validate(args)
except ValidationError as e:
raise ValueError(f"Invalid MCP request: {str(e)}")
with Session(db.engine) as session:
mcp_server_handler = MCPServerReuqestHandler(app, request, user_input_form, session)
return helper.compact_generate_response(mcp_server_handler.handle())
try:
notification = ClientNotification.model_validate(args)
request = notification
except ValidationError as e:
raise ValueError(f"Invalid MCP request: {str(e)}")
mcp_server_handler = MCPServerReuqestHandler(app, request, user_input_form)
response = mcp_server_handler.handle()
return helper.compact_generate_response(response)
api.add_resource(MCPAppApi, "/server/<string:server_code>/mcp")

@ -1,10 +1,7 @@
import json
import logging
from collections.abc import Mapping
from typing import Any, cast
from sqlalchemy.orm import Session
from configs import dify_config
from controllers.web.passport import generate_session_id
from core.app.app_config.entities import VariableEntity, VariableEntityType
@ -12,6 +9,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.mcp import types
from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from models.model import App, AppMCPServer, AppMode, EndUser
from services.app_generate_service import AppGenerateService
@ -19,17 +17,16 @@ from services.app_generate_service import AppGenerateService
Apply to MCP HTTP streamable server with stateless http
"""
logger = logging.getLogger(__name__)
class MCPServerReuqestHandler:
def __init__(self, app: App, request: types.ClientRequest, user_input_form: list[VariableEntity], session: Session):
def __init__(
self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity]
):
self.app = app
self.request = request
if not self.app.mcp_server:
raise ValueError("MCP server not found")
self.mcp_server: AppMCPServer = self.app.mcp_server
self._session = session
self.end_user = self.retrieve_end_user()
self.user_input_form = user_input_form
@ -61,7 +58,11 @@ class MCPServerReuqestHandler:
tools=types.ToolsCapability(listChanged=False),
)
def response(self, response: types.Result):
def response(self, response: types.Result | str):
if isinstance(response, str):
sse_content = f"event: ping\ndata: {response}\n\n".encode()
yield sse_content
return
json_response = types.JSONRPCResponse(
jsonrpc="2.0",
id=(self.request.root.model_extra or {}).get("id", 1),
@ -77,7 +78,7 @@ class MCPServerReuqestHandler:
error_data = types.ErrorData(code=code, message=message, data=data)
json_response = types.JSONRPCError(
jsonrpc="2.0",
id=(self.request.root.model_extra or {}).get("id", 1),
id=(self.request.root.model_extra or {}).get("id", 1) or 1,
error=error_data,
)
json_data = json.dumps(jsonable_encoder(json_response))
@ -91,6 +92,7 @@ class MCPServerReuqestHandler:
types.InitializeRequest: self.initialize,
types.ListToolsRequest: self.list_tools,
types.CallToolRequest: self.invoke_tool,
types.InitializedNotification: self.handle_notification,
}
try:
if self.request_type in handle_map:
@ -102,8 +104,10 @@ class MCPServerReuqestHandler:
except Exception as e:
return self.error_response(INTERNAL_ERROR, f"Internal server error: {str(e)}")
def handle_notification(self):
return "ping"
def initialize(self):
logger.info(f"Initialize: {self.request}")
request = cast(types.InitializeRequest, self.request.root)
client_info = request.params.clientInfo
clinet_name = f"{client_info.name}@{client_info.version}"
@ -116,8 +120,8 @@ class MCPServerReuqestHandler:
session_id=generate_session_id(),
external_user_id=self.mcp_server.id,
)
self._session.add(end_user)
self._session.commit()
db.session.add(end_user)
db.session.commit()
return types.InitializeResult(
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=self.capabilities,
@ -126,7 +130,6 @@ class MCPServerReuqestHandler:
)
def list_tools(self):
logger.info(f"List tools: {self.request}")
if not self.end_user:
raise ValueError("User not found")
return types.ListToolsResult(
@ -170,7 +173,7 @@ class MCPServerReuqestHandler:
def retrieve_end_user(self):
return (
self._session.query(EndUser)
db.session.query(EndUser)
.filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
.first()
)

@ -1161,6 +1161,7 @@ class ServerMessageMetadata:
"""Metadata specific to server messages."""
related_request_id: RequestId | None = None
request_context: object | None = None
MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None

@ -367,7 +367,7 @@ class ToolTransformService:
def process_properties(props: dict, required: list, prefix: str = "") -> list[ToolParameter]:
"""Process properties recursively"""
TYPE_MAPPING = {"integer": "number"}
TYPE_MAPPING = {"integer": "number", "float": "number"}
COMPLEX_TYPES = ["array", "object"]
parameters = []

Loading…
Cancel
Save