feat: add mcp server
parent
c1a58ac160
commit
bdb4395319
@ -0,0 +1,83 @@
|
||||
import json
|
||||
from enum import Enum
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_server_fields
|
||||
from libs.login import login_required
|
||||
from models.model import AppMCPServer
|
||||
|
||||
|
||||
class AppMCPServerStatus(str, Enum):
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
|
||||
|
||||
class AppMCPServerController(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_server_fields)
|
||||
def get(self, app_model):
|
||||
server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == app_model.id).first()
|
||||
return server
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_server_fields)
|
||||
def post(self, app_model):
|
||||
# The role of the current user in the ta table must be editor, admin, or owner
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("description", type=str, required=True, location="json")
|
||||
parser.add_argument("parameters", type=dict, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
server = AppMCPServer(
|
||||
name=app_model.name,
|
||||
description=args["description"],
|
||||
parameters=json.dumps(args["parameters"], ensure_ascii=False),
|
||||
status=AppMCPServerStatus.ACTIVE,
|
||||
app_id=app_model.id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
server_code=AppMCPServer.generate_server_code(16),
|
||||
)
|
||||
db.session.add(server)
|
||||
db.session.commit()
|
||||
|
||||
return server
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_server_fields)
|
||||
def put(self, app_model):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("id", type=str, required=True, location="json")
|
||||
parser.add_argument("description", type=str, required=True, location="json")
|
||||
parser.add_argument("parameters", type=dict, required=True, location="json")
|
||||
parser.add_argument("status", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first()
|
||||
if not server:
|
||||
raise Forbidden()
|
||||
server.description = args["description"]
|
||||
server.parameters = json.dumps(args["parameters"], ensure_ascii=False)
|
||||
server.status = AppMCPServerStatus(args["status"])
|
||||
db.session.commit()
|
||||
return server
|
||||
|
||||
|
||||
api.add_resource(AppMCPServerController, "/apps/<uuid:app_id>/server")
|
||||
@ -0,0 +1,154 @@
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from typing import cast
|
||||
|
||||
from configs.app_config import DifyConfig
|
||||
from controllers.web.passport import generate_session_id
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.mcp import types
|
||||
from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
||||
"""
|
||||
Apply to MCP HTTP streamable server with stateless http
|
||||
"""
|
||||
dify_config = DifyConfig()
|
||||
|
||||
|
||||
class MCPServerReuqestHandler:
|
||||
def __init__(self, app: App, request: types.ClientRequest):
|
||||
self.app = app
|
||||
self.request = request
|
||||
if not self.app.mcp_server:
|
||||
raise ValueError("MCP server not found")
|
||||
self.mcp_server = self.app.mcp_server
|
||||
self.end_user = self.retrieve_end_user()
|
||||
|
||||
@property
|
||||
def request_type(self):
|
||||
return type(self.request.root)
|
||||
|
||||
@property
|
||||
def parameter_schema(self):
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "User Input/Question content"},
|
||||
"inputs": {
|
||||
"type": "object",
|
||||
"description": "Allows the entry of various variable values defined by the App. The `inputs` parameter contains multiple key/value pairs, with each key corresponding to a specific variable and each value being the specific value for that variable. If the variable is of file type, specify an object that has the keys described in `files`.", # noqa: E501
|
||||
"default": {},
|
||||
# TODO: add input parameters
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
@property
|
||||
def output_parameters(self):
|
||||
return self.app.output_schema
|
||||
|
||||
@property
|
||||
def capabilities(self):
|
||||
return types.ServerCapabilities(
|
||||
tools=types.ToolsCapability(listChanged=False),
|
||||
)
|
||||
|
||||
def response(self, response: types.Result):
|
||||
json_response = types.JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=(self.request.root.model_extra or {}).get("id", 1),
|
||||
result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
json_data = json.dumps(jsonable_encoder(json_response))
|
||||
|
||||
sse_content = f"event: message\ndata: {json_data}\n\n".encode()
|
||||
|
||||
yield sse_content
|
||||
|
||||
def error_response(self, code: int, message: str, data=None):
|
||||
error_data = types.ErrorData(code=code, message=message, data=data)
|
||||
json_response = types.JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id=(self.request.root.model_extra or {}).get("id", 1),
|
||||
error=error_data,
|
||||
)
|
||||
json_data = json.dumps(jsonable_encoder(json_response))
|
||||
|
||||
sse_content = f"event: message\ndata: {json_data}\n\n".encode()
|
||||
|
||||
yield sse_content
|
||||
|
||||
def handle(self):
|
||||
handle_map = {
|
||||
types.InitializeRequest: self.initialize,
|
||||
types.ListToolsRequest: self.list_tools,
|
||||
types.CallToolRequest: self.invoke_tool,
|
||||
}
|
||||
try:
|
||||
if self.request_type in handle_map:
|
||||
return self.response(handle_map[self.request_type]())
|
||||
else:
|
||||
return self.error_response(METHOD_NOT_FOUND, f"Method not found: {self.request_type}")
|
||||
except ValueError as e:
|
||||
return self.error_response(INVALID_PARAMS, str(e))
|
||||
except Exception as e:
|
||||
return self.error_response(INTERNAL_ERROR, f"Internal server error: {str(e)}")
|
||||
|
||||
def initialize(self):
|
||||
request = cast(types.InitializeRequest, self.request.root)
|
||||
client_info = request.params.clientInfo
|
||||
clinet_name = f"{client_info.name}@{client_info.version}"
|
||||
if not self.end_user:
|
||||
end_user = EndUser(
|
||||
tenant_id=self.app.tenant_id,
|
||||
app_id=self.app.id,
|
||||
type="mcp",
|
||||
name=clinet_name,
|
||||
session_id=generate_session_id(),
|
||||
external_user_id=self.mcp_server.id,
|
||||
)
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
|
||||
return types.InitializeResult(
|
||||
protocolVersion=types.LATEST_PROTOCOL_VERSION,
|
||||
capabilities=self.capabilities,
|
||||
serverInfo=types.Implementation(name="Dify", version=dify_config.CURRENT_VERSION),
|
||||
instructions=self.mcp_server.description,
|
||||
)
|
||||
|
||||
def list_tools(self):
|
||||
if not self.end_user:
|
||||
raise ValueError("User not found")
|
||||
return types.ListToolsResult(
|
||||
tools=[
|
||||
types.Tool(
|
||||
name=self.mcp_server.name,
|
||||
description=self.mcp_server.description,
|
||||
inputSchema=self.parameter_schema,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
def invoke_tool(self):
|
||||
if not self.end_user:
|
||||
raise ValueError("User not found")
|
||||
request = cast(types.CallToolRequest, self.request.root)
|
||||
args = request.params.arguments
|
||||
if not args:
|
||||
raise ValueError("No arguments provided")
|
||||
response = AppGenerateService.generate(self.app, self.end_user, args, InvokeFrom.MCP_SERVER, streaming=False)
|
||||
if isinstance(response, Mapping):
|
||||
return types.CallToolResult(content=[types.TextContent(text=response["answer"], type="text")])
|
||||
return None
|
||||
|
||||
def retrieve_end_user(self):
|
||||
return (
|
||||
db.session.query(EndUser)
|
||||
.filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
|
||||
.first()
|
||||
)
|
||||
@ -0,0 +1,41 @@
|
||||
"""add app mcp server
|
||||
|
||||
Revision ID: ca4c4abcc347
|
||||
Revises: 1e67f2654a08
|
||||
Create Date: 2025-05-22 16:23:44.206102
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'ca4c4abcc347'
|
||||
down_revision = '1e67f2654a08'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('app_mcp_servers',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('app_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('name', sa.String(length=255), nullable=False),
|
||||
sa.Column('description', sa.String(length=255), nullable=False),
|
||||
sa.Column('server_code', sa.String(length=255), nullable=False),
|
||||
sa.Column('status', sa.String(length=255), server_default=sa.text("'normal'::character varying"), nullable=False),
|
||||
sa.Column('parameters', sa.Text(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='app_mcp_server_pkey')
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('app_mcp_servers')
|
||||
# ### end Alembic commands ###
|
||||
Loading…
Reference in New Issue