|
|
|
@ -1,10 +1,7 @@
|
|
|
|
import json
|
|
|
|
import json
|
|
|
|
import logging
|
|
|
|
|
|
|
|
from collections.abc import Mapping
|
|
|
|
from collections.abc import Mapping
|
|
|
|
from typing import Any, cast
|
|
|
|
from typing import Any, cast
|
|
|
|
|
|
|
|
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from configs import dify_config
|
|
|
|
from configs import dify_config
|
|
|
|
from controllers.web.passport import generate_session_id
|
|
|
|
from controllers.web.passport import generate_session_id
|
|
|
|
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
|
|
|
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
|
|
|
@ -12,6 +9,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
|
|
|
from core.mcp import types
|
|
|
|
from core.mcp import types
|
|
|
|
from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND
|
|
|
|
from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND
|
|
|
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
|
|
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
|
|
|
|
|
|
|
from extensions.ext_database import db
|
|
|
|
from models.model import App, AppMCPServer, AppMode, EndUser
|
|
|
|
from models.model import App, AppMCPServer, AppMode, EndUser
|
|
|
|
from services.app_generate_service import AppGenerateService
|
|
|
|
from services.app_generate_service import AppGenerateService
|
|
|
|
|
|
|
|
|
|
|
|
@ -19,17 +17,16 @@ from services.app_generate_service import AppGenerateService
|
|
|
|
Apply to MCP HTTP streamable server with stateless http
|
|
|
|
Apply to MCP HTTP streamable server with stateless http
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MCPServerReuqestHandler:
|
|
|
|
class MCPServerReuqestHandler:
|
|
|
|
def __init__(self, app: App, request: types.ClientRequest, user_input_form: list[VariableEntity], session: Session):
|
|
|
|
def __init__(
|
|
|
|
|
|
|
|
self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity]
|
|
|
|
|
|
|
|
):
|
|
|
|
self.app = app
|
|
|
|
self.app = app
|
|
|
|
self.request = request
|
|
|
|
self.request = request
|
|
|
|
if not self.app.mcp_server:
|
|
|
|
if not self.app.mcp_server:
|
|
|
|
raise ValueError("MCP server not found")
|
|
|
|
raise ValueError("MCP server not found")
|
|
|
|
self.mcp_server: AppMCPServer = self.app.mcp_server
|
|
|
|
self.mcp_server: AppMCPServer = self.app.mcp_server
|
|
|
|
self._session = session
|
|
|
|
|
|
|
|
self.end_user = self.retrieve_end_user()
|
|
|
|
self.end_user = self.retrieve_end_user()
|
|
|
|
self.user_input_form = user_input_form
|
|
|
|
self.user_input_form = user_input_form
|
|
|
|
|
|
|
|
|
|
|
|
@ -61,7 +58,11 @@ class MCPServerReuqestHandler:
|
|
|
|
tools=types.ToolsCapability(listChanged=False),
|
|
|
|
tools=types.ToolsCapability(listChanged=False),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
def response(self, response: types.Result):
|
|
|
|
def response(self, response: types.Result | str):
|
|
|
|
|
|
|
|
if isinstance(response, str):
|
|
|
|
|
|
|
|
sse_content = f"event: ping\ndata: {response}\n\n".encode()
|
|
|
|
|
|
|
|
yield sse_content
|
|
|
|
|
|
|
|
return
|
|
|
|
json_response = types.JSONRPCResponse(
|
|
|
|
json_response = types.JSONRPCResponse(
|
|
|
|
jsonrpc="2.0",
|
|
|
|
jsonrpc="2.0",
|
|
|
|
id=(self.request.root.model_extra or {}).get("id", 1),
|
|
|
|
id=(self.request.root.model_extra or {}).get("id", 1),
|
|
|
|
@ -77,7 +78,7 @@ class MCPServerReuqestHandler:
|
|
|
|
error_data = types.ErrorData(code=code, message=message, data=data)
|
|
|
|
error_data = types.ErrorData(code=code, message=message, data=data)
|
|
|
|
json_response = types.JSONRPCError(
|
|
|
|
json_response = types.JSONRPCError(
|
|
|
|
jsonrpc="2.0",
|
|
|
|
jsonrpc="2.0",
|
|
|
|
id=(self.request.root.model_extra or {}).get("id", 1),
|
|
|
|
id=(self.request.root.model_extra or {}).get("id", 1) or 1,
|
|
|
|
error=error_data,
|
|
|
|
error=error_data,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
json_data = json.dumps(jsonable_encoder(json_response))
|
|
|
|
json_data = json.dumps(jsonable_encoder(json_response))
|
|
|
|
@ -91,6 +92,7 @@ class MCPServerReuqestHandler:
|
|
|
|
types.InitializeRequest: self.initialize,
|
|
|
|
types.InitializeRequest: self.initialize,
|
|
|
|
types.ListToolsRequest: self.list_tools,
|
|
|
|
types.ListToolsRequest: self.list_tools,
|
|
|
|
types.CallToolRequest: self.invoke_tool,
|
|
|
|
types.CallToolRequest: self.invoke_tool,
|
|
|
|
|
|
|
|
types.InitializedNotification: self.handle_notification,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
if self.request_type in handle_map:
|
|
|
|
if self.request_type in handle_map:
|
|
|
|
@ -102,8 +104,10 @@ class MCPServerReuqestHandler:
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
|
return self.error_response(INTERNAL_ERROR, f"Internal server error: {str(e)}")
|
|
|
|
return self.error_response(INTERNAL_ERROR, f"Internal server error: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def handle_notification(self):
|
|
|
|
|
|
|
|
return "ping"
|
|
|
|
|
|
|
|
|
|
|
|
def initialize(self):
|
|
|
|
def initialize(self):
|
|
|
|
logger.info(f"Initialize: {self.request}")
|
|
|
|
|
|
|
|
request = cast(types.InitializeRequest, self.request.root)
|
|
|
|
request = cast(types.InitializeRequest, self.request.root)
|
|
|
|
client_info = request.params.clientInfo
|
|
|
|
client_info = request.params.clientInfo
|
|
|
|
clinet_name = f"{client_info.name}@{client_info.version}"
|
|
|
|
clinet_name = f"{client_info.name}@{client_info.version}"
|
|
|
|
@ -116,8 +120,8 @@ class MCPServerReuqestHandler:
|
|
|
|
session_id=generate_session_id(),
|
|
|
|
session_id=generate_session_id(),
|
|
|
|
external_user_id=self.mcp_server.id,
|
|
|
|
external_user_id=self.mcp_server.id,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
self._session.add(end_user)
|
|
|
|
db.session.add(end_user)
|
|
|
|
self._session.commit()
|
|
|
|
db.session.commit()
|
|
|
|
return types.InitializeResult(
|
|
|
|
return types.InitializeResult(
|
|
|
|
protocolVersion=types.LATEST_PROTOCOL_VERSION,
|
|
|
|
protocolVersion=types.LATEST_PROTOCOL_VERSION,
|
|
|
|
capabilities=self.capabilities,
|
|
|
|
capabilities=self.capabilities,
|
|
|
|
@ -126,7 +130,6 @@ class MCPServerReuqestHandler:
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
def list_tools(self):
|
|
|
|
def list_tools(self):
|
|
|
|
logger.info(f"List tools: {self.request}")
|
|
|
|
|
|
|
|
if not self.end_user:
|
|
|
|
if not self.end_user:
|
|
|
|
raise ValueError("User not found")
|
|
|
|
raise ValueError("User not found")
|
|
|
|
return types.ListToolsResult(
|
|
|
|
return types.ListToolsResult(
|
|
|
|
@ -170,7 +173,7 @@ class MCPServerReuqestHandler:
|
|
|
|
|
|
|
|
|
|
|
|
def retrieve_end_user(self):
|
|
|
|
def retrieve_end_user(self):
|
|
|
|
return (
|
|
|
|
return (
|
|
|
|
self._session.query(EndUser)
|
|
|
|
db.session.query(EndUser)
|
|
|
|
.filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
|
|
|
|
.filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
|
|
|
|
.first()
|
|
|
|
.first()
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|