diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index a974c63e35..8292161275 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -56,6 +56,7 @@ from .app import ( conversation, conversation_variables, generator, + mcp_server, message, model_config, ops_trace, diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py new file mode 100644 index 0000000000..a7a276edb4 --- /dev/null +++ b/api/controllers/console/app/mcp_server.py @@ -0,0 +1,83 @@ +import json +from enum import Enum + +from flask_login import current_user +from flask_restful import Resource, marshal_with, reqparse +from werkzeug.exceptions import Forbidden + +from controllers.console import api +from controllers.console.app.wraps import get_app_model +from controllers.console.wraps import account_initialization_required, setup_required +from extensions.ext_database import db +from fields.app_fields import app_server_fields +from libs.login import login_required +from models.model import AppMCPServer + + +class AppMCPServerStatus(str, Enum): + ACTIVE = "active" + INACTIVE = "inactive" + + +class AppMCPServerController(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model + @marshal_with(app_server_fields) + def get(self, app_model): + server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == app_model.id).first() + return server + + @setup_required + @login_required + @account_initialization_required + @get_app_model + @marshal_with(app_server_fields) + def post(self, app_model): + # The role of the current user in the ta table must be editor, admin, or owner + if not current_user.is_editor: + raise Forbidden() + parser = reqparse.RequestParser() + parser.add_argument("description", type=str, required=True, location="json") + parser.add_argument("parameters", type=dict, required=True, location="json") + args = parser.parse_args() + server = AppMCPServer( + name=app_model.name, + description=args["description"], + parameters=json.dumps(args["parameters"], ensure_ascii=False), + status=AppMCPServerStatus.ACTIVE, + app_id=app_model.id, + tenant_id=current_user.current_tenant_id, + server_code=AppMCPServer.generate_server_code(16), + ) + db.session.add(server) + db.session.commit() + + return server + + @setup_required + @login_required + @account_initialization_required + @get_app_model + @marshal_with(app_server_fields) + def put(self, app_model): + if not current_user.is_editor: + raise Forbidden() + parser = reqparse.RequestParser() + parser.add_argument("id", type=str, required=True, location="json") + parser.add_argument("description", type=str, required=True, location="json") + parser.add_argument("parameters", type=dict, required=True, location="json") + parser.add_argument("status", type=str, required=True, location="json") + args = parser.parse_args() + server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first() + if not server: + raise Forbidden() + server.description = args["description"] + server.parameters = json.dumps(args["parameters"], ensure_ascii=False) + server.status = AppMCPServerStatus(args["status"]) + db.session.commit() + return server + + +api.add_resource(AppMCPServerController, "/apps//server") diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index fd3b9aa804..c8645d5ebe 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -1,6 +1,7 @@ import logging -from flask_restful import reqparse +from flask_restful import Resource, reqparse +from pydantic import ValidationError from werkzeug.exceptions import InternalServerError, NotFound import services @@ -24,10 +25,13 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) +from core.mcp.server.handler import MCPServerReuqestHandler +from core.mcp.types import ClientRequest from core.model_runtime.errors.invoke import InvokeError +from extensions.ext_database import db from libs import helper from libs.helper import uuid_value -from models.model import AppMode +from models.model import App, AppMCPServer, AppMode from services.app_generate_service import AppGenerateService from services.errors.llm import InvokeRateLimitError @@ -149,7 +153,38 @@ class ChatStopApi(WebApiResource): return {"result": "success"}, 200 +class ChatMCPApi(Resource): + def post(self, server_code): + def int_or_str(value): + if isinstance(value, int): + return value + elif isinstance(value, str): + return int(value) + else: + raise ValueError("Invalid id") + + parser = reqparse.RequestParser() + parser.add_argument("jsonrpc", type=str, required=True, location="json") + parser.add_argument("method", type=str, required=True, location="json") + parser.add_argument("params", type=dict, required=True, location="json") + parser.add_argument("id", type=int_or_str, required=True, location="json") + args = parser.parse_args() + server = db.session.query(AppMCPServer).filter(AppMCPServer.server_code == server_code).first() + if not server: + raise NotFound("Server Not Found") + app = db.session.query(App).filter(App.id == server.app_id).first() + if not app: + raise NotFound("App Not Found") + try: + request = ClientRequest.model_validate(args) + except ValidationError as e: + raise ValueError(f"Invalid MCP request: {str(e)}") + mcp_server_handler = MCPServerReuqestHandler(app, request) + return helper.compact_generate_response(mcp_server_handler.handle()) + + api.add_resource(CompletionApi, "/completion-messages") api.add_resource(CompletionStopApi, "/completion-messages//stop") api.add_resource(ChatApi, "/chat-messages") +api.add_resource(ChatMCPApi, "/server//mcp") api.add_resource(ChatStopApi, "/chat-messages//stop") diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 56e6b46a60..d8a0f45492 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -21,6 +21,7 @@ class InvokeFrom(Enum): WEB_APP = "web-app" EXPLORE = "explore" DEBUGGER = "debugger" + MCP_SERVER = "mcp-server" @classmethod def value_of(cls, value: str): @@ -49,6 +50,8 @@ class InvokeFrom(Enum): return "explore_app" elif self == InvokeFrom.SERVICE_API: return "api" + elif self == InvokeFrom.MCP_SERVER: + return "mcp_server" return "dev" diff --git a/api/core/mcp/server/handler.py b/api/core/mcp/server/handler.py new file mode 100644 index 0000000000..579e25862b --- /dev/null +++ b/api/core/mcp/server/handler.py @@ -0,0 +1,154 @@ +import json +from collections.abc import Mapping +from typing import cast + +from configs.app_config import DifyConfig +from controllers.web.passport import generate_session_id +from core.app.entities.app_invoke_entities import InvokeFrom +from core.mcp import types +from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND +from core.model_runtime.utils.encoders import jsonable_encoder +from extensions.ext_database import db +from models.model import App, EndUser +from services.app_generate_service import AppGenerateService + +""" +Apply to MCP HTTP streamable server with stateless http +""" +dify_config = DifyConfig() + + +class MCPServerReuqestHandler: + def __init__(self, app: App, request: types.ClientRequest): + self.app = app + self.request = request + if not self.app.mcp_server: + raise ValueError("MCP server not found") + self.mcp_server = self.app.mcp_server + self.end_user = self.retrieve_end_user() + + @property + def request_type(self): + return type(self.request.root) + + @property + def parameter_schema(self): + return { + "type": "object", + "properties": { + "query": {"type": "string", "description": "User Input/Question content"}, + "inputs": { + "type": "object", + "description": "Allows the entry of various variable values defined by the App. The `inputs` parameter contains multiple key/value pairs, with each key corresponding to a specific variable and each value being the specific value for that variable. If the variable is of file type, specify an object that has the keys described in `files`.", # noqa: E501 + "default": {}, + # TODO: add input parameters + }, + }, + "required": ["query"], + } + + @property + def output_parameters(self): + return self.app.output_schema + + @property + def capabilities(self): + return types.ServerCapabilities( + tools=types.ToolsCapability(listChanged=False), + ) + + def response(self, response: types.Result): + json_response = types.JSONRPCResponse( + jsonrpc="2.0", + id=(self.request.root.model_extra or {}).get("id", 1), + result=response.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + json_data = json.dumps(jsonable_encoder(json_response)) + + sse_content = f"event: message\ndata: {json_data}\n\n".encode() + + yield sse_content + + def error_response(self, code: int, message: str, data=None): + error_data = types.ErrorData(code=code, message=message, data=data) + json_response = types.JSONRPCError( + jsonrpc="2.0", + id=(self.request.root.model_extra or {}).get("id", 1), + error=error_data, + ) + json_data = json.dumps(jsonable_encoder(json_response)) + + sse_content = f"event: message\ndata: {json_data}\n\n".encode() + + yield sse_content + + def handle(self): + handle_map = { + types.InitializeRequest: self.initialize, + types.ListToolsRequest: self.list_tools, + types.CallToolRequest: self.invoke_tool, + } + try: + if self.request_type in handle_map: + return self.response(handle_map[self.request_type]()) + else: + return self.error_response(METHOD_NOT_FOUND, f"Method not found: {self.request_type}") + except ValueError as e: + return self.error_response(INVALID_PARAMS, str(e)) + except Exception as e: + return self.error_response(INTERNAL_ERROR, f"Internal server error: {str(e)}") + + def initialize(self): + request = cast(types.InitializeRequest, self.request.root) + client_info = request.params.clientInfo + clinet_name = f"{client_info.name}@{client_info.version}" + if not self.end_user: + end_user = EndUser( + tenant_id=self.app.tenant_id, + app_id=self.app.id, + type="mcp", + name=clinet_name, + session_id=generate_session_id(), + external_user_id=self.mcp_server.id, + ) + db.session.add(end_user) + db.session.commit() + + return types.InitializeResult( + protocolVersion=types.LATEST_PROTOCOL_VERSION, + capabilities=self.capabilities, + serverInfo=types.Implementation(name="Dify", version=dify_config.CURRENT_VERSION), + instructions=self.mcp_server.description, + ) + + def list_tools(self): + if not self.end_user: + raise ValueError("User not found") + return types.ListToolsResult( + tools=[ + types.Tool( + name=self.mcp_server.name, + description=self.mcp_server.description, + inputSchema=self.parameter_schema, + ) + ], + ) + + def invoke_tool(self): + if not self.end_user: + raise ValueError("User not found") + request = cast(types.CallToolRequest, self.request.root) + args = request.params.arguments + if not args: + raise ValueError("No arguments provided") + response = AppGenerateService.generate(self.app, self.end_user, args, InvokeFrom.MCP_SERVER, streaming=False) + if isinstance(response, Mapping): + return types.CallToolResult(content=[types.TextContent(text=response["answer"], type="text")]) + return None + + def retrieve_end_user(self): + return ( + db.session.query(EndUser) + .filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp") + .first() + ) diff --git a/api/core/mcp/session/client_session.py b/api/core/mcp/session/client_session.py index 50d33d2dac..6805f4a039 100644 --- a/api/core/mcp/session/client_session.py +++ b/api/core/mcp/session/client_session.py @@ -3,11 +3,13 @@ from typing import Any, Protocol from pydantic import AnyUrl, TypeAdapter +from configs.app_config import DifyConfig from core.mcp import types from core.mcp.entities import SUPPORTED_PROTOCOL_VERSIONS, RequestContext from core.mcp.session.base_session import BaseSession, RequestResponder -DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") +dify_config = DifyConfig() +DEFAULT_CLIENT_INFO = types.Implementation(name="Dify", version=dify_config.CURRENT_VERSION) class SamplingFnT(Protocol): diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index 0b0e2a2f54..0f8408d443 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -213,3 +213,14 @@ app_import_fields = { app_import_check_dependencies_fields = { "leaked_dependencies": fields.List(fields.Nested(leaked_dependency_fields)), } + +app_server_fields = { + "id": fields.String, + "name": fields.String, + "server_code": fields.String, + "description": fields.String, + "status": fields.String, + "parameters": fields.Raw, + "created_at": TimestampField, + "updated_at": TimestampField, +} diff --git a/api/migrations/versions/2025_05_22_1623-ca4c4abcc347_add_app_mcp_server.py b/api/migrations/versions/2025_05_22_1623-ca4c4abcc347_add_app_mcp_server.py new file mode 100644 index 0000000000..cb3508afed --- /dev/null +++ b/api/migrations/versions/2025_05_22_1623-ca4c4abcc347_add_app_mcp_server.py @@ -0,0 +1,41 @@ +"""add app mcp server + +Revision ID: ca4c4abcc347 +Revises: 1e67f2654a08 +Create Date: 2025-05-22 16:23:44.206102 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'ca4c4abcc347' +down_revision = '1e67f2654a08' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('app_mcp_servers', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('description', sa.String(length=255), nullable=False), + sa.Column('server_code', sa.String(length=255), nullable=False), + sa.Column('status', sa.String(length=255), server_default=sa.text("'normal'::character varying"), nullable=False), + sa.Column('parameters', sa.Text(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='app_mcp_server_pkey') + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('app_mcp_servers') + # ### end Alembic commands ### diff --git a/api/models/model.py b/api/models/model.py index fd05d67e9a..c8df44bfec 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -294,6 +294,10 @@ class App(Base): return tags or [] + @property + def mcp_server(self): + return db.session.query(AppMCPServer).filter(AppMCPServer.app_id == self.id).first() + class AppModelConfig(Base): __tablename__ = "app_model_configs" @@ -1433,6 +1437,31 @@ class EndUser(Base, UserMixin): updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) +class AppMCPServer(Base): + __tablename__ = "app_mcp_servers" + __table_args__ = (db.PrimaryKeyConstraint("id", name="app_mcp_server_pkey"),) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) + app_id = db.Column(StringUUID, nullable=False) + name = db.Column(db.String(255), nullable=False) + description = db.Column(db.String(255), nullable=False) + server_code = db.Column(db.String(255), nullable=False) + status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) + parameters = db.Column(db.Text, nullable=False) + + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + @staticmethod + def generate_server_code(n): + while True: + result = generate_string(n) + while db.session.query(AppMCPServer).filter(AppMCPServer.server_code == result).count() > 0: + result = generate_string(n) + + return result + + class Site(Base): __tablename__ = "sites" __table_args__ = ( diff --git a/api/models/tools.py b/api/models/tools.py index 3a994b4e21..5c7507759d 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -217,7 +217,7 @@ class MCPToolProvider(Base): # who created this tool user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # encrypted credentials - encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=False) + encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True) # authed authed: Mapped[bool] = mapped_column(db.Boolean, nullable=False, default=False) # tools diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index dedf1c5334..f8b804d93e 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -14,6 +14,7 @@ from models.model import ( ApiToken, AppAnnotationHitHistory, AppAnnotationSetting, + AppMCPServer, AppModelConfig, Conversation, EndUser, @@ -42,6 +43,7 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str): # Delete related data _delete_app_model_configs(tenant_id, app_id) _delete_app_site(tenant_id, app_id) + _delete_app_mcp_servers(tenant_id, app_id) _delete_app_api_tokens(tenant_id, app_id) _delete_installed_apps(tenant_id, app_id) _delete_recommended_apps(tenant_id, app_id) @@ -90,6 +92,18 @@ def _delete_app_site(tenant_id: str, app_id: str): _delete_records("""select id from sites where app_id=:app_id limit 1000""", {"app_id": app_id}, del_site, "site") +def _delete_app_mcp_servers(tenant_id: str, app_id: str): + def del_mcp_server(mcp_server_id: str): + db.session.query(AppMCPServer).filter(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False) + + _delete_records( + """select id from app_mcp_servers where app_id=:app_id limit 1000""", + {"app_id": app_id}, + del_mcp_server, + "app mcp server", + ) + + def _delete_app_api_tokens(tenant_id: str, app_id: str): def del_api_token(api_token_id: str): db.session.query(ApiToken).filter(ApiToken.id == api_token_id).delete(synchronize_session=False)