merge main
commit
90f800408d
@ -0,0 +1,102 @@
|
||||
import json
|
||||
from enum import StrEnum
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_server_fields
|
||||
from libs.login import login_required
|
||||
from models.model import AppMCPServer
|
||||
|
||||
|
||||
class AppMCPServerStatus(StrEnum):
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
|
||||
|
||||
class AppMCPServerController(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_server_fields)
|
||||
def get(self, app_model):
|
||||
server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == app_model.id).first()
|
||||
return server
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_server_fields)
|
||||
def post(self, app_model):
|
||||
# The role of the current user in the ta table must be editor, admin, or owner
|
||||
if not current_user.is_editor:
|
||||
raise NotFound()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("description", type=str, required=True, location="json")
|
||||
parser.add_argument("parameters", type=dict, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
server = AppMCPServer(
|
||||
name=app_model.name,
|
||||
description=args["description"],
|
||||
parameters=json.dumps(args["parameters"], ensure_ascii=False),
|
||||
status=AppMCPServerStatus.ACTIVE,
|
||||
app_id=app_model.id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
server_code=AppMCPServer.generate_server_code(16),
|
||||
)
|
||||
db.session.add(server)
|
||||
db.session.commit()
|
||||
return server
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_server_fields)
|
||||
def put(self, app_model):
|
||||
if not current_user.is_editor:
|
||||
raise NotFound()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("id", type=str, required=True, location="json")
|
||||
parser.add_argument("description", type=str, required=True, location="json")
|
||||
parser.add_argument("parameters", type=dict, required=True, location="json")
|
||||
parser.add_argument("status", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first()
|
||||
if not server:
|
||||
raise NotFound()
|
||||
server.description = args["description"]
|
||||
server.parameters = json.dumps(args["parameters"], ensure_ascii=False)
|
||||
if args["status"]:
|
||||
if args["status"] not in [status.value for status in AppMCPServerStatus]:
|
||||
raise ValueError("Invalid status")
|
||||
server.status = args["status"]
|
||||
db.session.commit()
|
||||
return server
|
||||
|
||||
|
||||
class AppMCPServerRefreshController(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_server_fields)
|
||||
def get(self, server_id):
|
||||
if not current_user.is_editor:
|
||||
raise NotFound()
|
||||
server = db.session.query(AppMCPServer).filter(AppMCPServer.id == server_id).first()
|
||||
if not server:
|
||||
raise NotFound()
|
||||
server.server_code = AppMCPServer.generate_server_code(16)
|
||||
db.session.commit()
|
||||
return server
|
||||
|
||||
|
||||
api.add_resource(AppMCPServerController, "/apps/<uuid:app_id>/server")
|
||||
api.add_resource(AppMCPServerRefreshController, "/apps/<uuid:server_id>/server/refresh")
|
||||
@ -0,0 +1,8 @@
|
||||
from flask import Blueprint
|
||||
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
bp = Blueprint("mcp", __name__, url_prefix="/mcp")
|
||||
api = ExternalApi(bp)
|
||||
|
||||
from . import mcp
|
||||
@ -0,0 +1,104 @@
|
||||
from flask_restful import Resource, reqparse
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.console.app.mcp_server import AppMCPServerStatus
|
||||
from controllers.mcp import api
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.mcp import types
|
||||
from core.mcp.server.streamable_http import MCPServerStreamableHTTPRequestHandler
|
||||
from core.mcp.types import ClientNotification, ClientRequest
|
||||
from core.mcp.utils import create_mcp_error_response
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models.model import App, AppMCPServer, AppMode
|
||||
|
||||
|
||||
class MCPAppApi(Resource):
|
||||
def post(self, server_code):
|
||||
def int_or_str(value):
|
||||
if isinstance(value, (int, str)):
|
||||
return value
|
||||
else:
|
||||
return None
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("jsonrpc", type=str, required=True, location="json")
|
||||
parser.add_argument("method", type=str, required=True, location="json")
|
||||
parser.add_argument("params", type=dict, required=False, location="json")
|
||||
parser.add_argument("id", type=int_or_str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
request_id = args.get("id")
|
||||
|
||||
server = db.session.query(AppMCPServer).filter(AppMCPServer.server_code == server_code).first()
|
||||
if not server:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server Not Found")
|
||||
)
|
||||
|
||||
if server.status != AppMCPServerStatus.ACTIVE:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server is not active")
|
||||
)
|
||||
|
||||
app = db.session.query(App).filter(App.id == server.app_id).first()
|
||||
if not app:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App Not Found")
|
||||
)
|
||||
|
||||
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
workflow = app.workflow
|
||||
if workflow is None:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable")
|
||||
)
|
||||
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app.app_model_config
|
||||
if app_model_config is None:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable")
|
||||
)
|
||||
|
||||
features_dict = app_model_config.to_dict()
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
converted_user_input_form: list[VariableEntity] = []
|
||||
try:
|
||||
for item in user_input_form:
|
||||
variable_type = item.get("type", "") or list(item.keys())[0]
|
||||
variable = item[variable_type]
|
||||
converted_user_input_form.append(
|
||||
VariableEntity(
|
||||
type=variable_type,
|
||||
variable=variable.get("variable"),
|
||||
description=variable.get("description") or "",
|
||||
label=variable.get("label"),
|
||||
required=variable.get("required", False),
|
||||
max_length=variable.get("max_length"),
|
||||
options=variable.get("options") or [],
|
||||
)
|
||||
)
|
||||
except ValidationError as e:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
|
||||
)
|
||||
|
||||
try:
|
||||
request: ClientRequest | ClientNotification = ClientRequest.model_validate(args)
|
||||
except ValidationError as e:
|
||||
try:
|
||||
notification = ClientNotification.model_validate(args)
|
||||
request = notification
|
||||
except ValidationError as e:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")
|
||||
)
|
||||
|
||||
mcp_server_handler = MCPServerStreamableHTTPRequestHandler(app, request, converted_user_input_form)
|
||||
response = mcp_server_handler.handle()
|
||||
return helper.compact_generate_response(response)
|
||||
|
||||
|
||||
api.add_resource(MCPAppApi, "/server/<string:server_code>/mcp")
|
||||
@ -0,0 +1,342 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import secrets
|
||||
import urllib.parse
|
||||
from typing import Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from core.mcp.auth.auth_provider import OAuthClientProvider
|
||||
from core.mcp.types import (
|
||||
OAuthClientInformation,
|
||||
OAuthClientInformationFull,
|
||||
OAuthClientMetadata,
|
||||
OAuthMetadata,
|
||||
OAuthTokens,
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
LATEST_PROTOCOL_VERSION = "1.0"
|
||||
OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
|
||||
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
|
||||
|
||||
|
||||
class OAuthCallbackState(BaseModel):
|
||||
provider_id: str
|
||||
tenant_id: str
|
||||
server_url: str
|
||||
metadata: OAuthMetadata | None = None
|
||||
client_information: OAuthClientInformation
|
||||
code_verifier: str
|
||||
redirect_uri: str
|
||||
|
||||
|
||||
def generate_pkce_challenge() -> tuple[str, str]:
|
||||
"""Generate PKCE challenge and verifier."""
|
||||
code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8")
|
||||
code_verifier = code_verifier.replace("=", "").replace("+", "-").replace("/", "_")
|
||||
|
||||
code_challenge_hash = hashlib.sha256(code_verifier.encode("utf-8")).digest()
|
||||
code_challenge = base64.urlsafe_b64encode(code_challenge_hash).decode("utf-8")
|
||||
code_challenge = code_challenge.replace("=", "").replace("+", "-").replace("/", "_")
|
||||
|
||||
return code_verifier, code_challenge
|
||||
|
||||
|
||||
def _create_secure_redis_state(state_data: OAuthCallbackState) -> str:
|
||||
"""Create a secure state parameter by storing state data in Redis and returning a random state key."""
|
||||
# Generate a secure random state key
|
||||
state_key = secrets.token_urlsafe(32)
|
||||
|
||||
# Store the state data in Redis with expiration
|
||||
redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
|
||||
redis_client.setex(redis_key, OAUTH_STATE_EXPIRY_SECONDS, state_data.model_dump_json())
|
||||
|
||||
return state_key
|
||||
|
||||
|
||||
def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
|
||||
"""Retrieve and decode OAuth state data from Redis using the state key, then delete it."""
|
||||
redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
|
||||
|
||||
# Get state data from Redis
|
||||
state_data = redis_client.get(redis_key)
|
||||
|
||||
if not state_data:
|
||||
raise ValueError("State parameter has expired or does not exist")
|
||||
|
||||
# Delete the state data from Redis immediately after retrieval to prevent reuse
|
||||
redis_client.delete(redis_key)
|
||||
|
||||
try:
|
||||
# Parse and validate the state data
|
||||
oauth_state = OAuthCallbackState.model_validate_json(state_data)
|
||||
|
||||
return oauth_state
|
||||
except ValidationError as e:
|
||||
raise ValueError(f"Invalid state parameter: {str(e)}")
|
||||
|
||||
|
||||
def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackState:
|
||||
"""Handle the callback from the OAuth provider."""
|
||||
# Retrieve state data from Redis (state is automatically deleted after retrieval)
|
||||
full_state_data = _retrieve_redis_state(state_key)
|
||||
|
||||
tokens = exchange_authorization(
|
||||
full_state_data.server_url,
|
||||
full_state_data.metadata,
|
||||
full_state_data.client_information,
|
||||
authorization_code,
|
||||
full_state_data.code_verifier,
|
||||
full_state_data.redirect_uri,
|
||||
)
|
||||
provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id, for_list=True)
|
||||
provider.save_tokens(tokens)
|
||||
return full_state_data
|
||||
|
||||
|
||||
def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = None) -> Optional[OAuthMetadata]:
|
||||
"""Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata."""
|
||||
url = urljoin(server_url, "/.well-known/oauth-authorization-server")
|
||||
|
||||
try:
|
||||
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
|
||||
response = requests.get(url, headers=headers)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
if not response.ok:
|
||||
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
|
||||
return OAuthMetadata.model_validate(response.json())
|
||||
except requests.RequestException as e:
|
||||
if isinstance(e, requests.ConnectionError):
|
||||
response = requests.get(url)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
if not response.ok:
|
||||
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
|
||||
return OAuthMetadata.model_validate(response.json())
|
||||
raise
|
||||
|
||||
|
||||
def start_authorization(
|
||||
server_url: str,
|
||||
metadata: Optional[OAuthMetadata],
|
||||
client_information: OAuthClientInformation,
|
||||
redirect_url: str,
|
||||
provider_id: str,
|
||||
tenant_id: str,
|
||||
) -> tuple[str, str]:
|
||||
"""Begins the authorization flow with secure Redis state storage."""
|
||||
response_type = "code"
|
||||
code_challenge_method = "S256"
|
||||
|
||||
if metadata:
|
||||
authorization_url = metadata.authorization_endpoint
|
||||
if response_type not in metadata.response_types_supported:
|
||||
raise ValueError(f"Incompatible auth server: does not support response type {response_type}")
|
||||
if (
|
||||
not metadata.code_challenge_methods_supported
|
||||
or code_challenge_method not in metadata.code_challenge_methods_supported
|
||||
):
|
||||
raise ValueError(
|
||||
f"Incompatible auth server: does not support code challenge method {code_challenge_method}"
|
||||
)
|
||||
else:
|
||||
authorization_url = urljoin(server_url, "/authorize")
|
||||
|
||||
code_verifier, code_challenge = generate_pkce_challenge()
|
||||
|
||||
# Prepare state data with all necessary information
|
||||
state_data = OAuthCallbackState(
|
||||
provider_id=provider_id,
|
||||
tenant_id=tenant_id,
|
||||
server_url=server_url,
|
||||
metadata=metadata,
|
||||
client_information=client_information,
|
||||
code_verifier=code_verifier,
|
||||
redirect_uri=redirect_url,
|
||||
)
|
||||
|
||||
# Store state data in Redis and generate secure state key
|
||||
state_key = _create_secure_redis_state(state_data)
|
||||
|
||||
params = {
|
||||
"response_type": response_type,
|
||||
"client_id": client_information.client_id,
|
||||
"code_challenge": code_challenge,
|
||||
"code_challenge_method": code_challenge_method,
|
||||
"redirect_uri": redirect_url,
|
||||
"state": state_key,
|
||||
}
|
||||
|
||||
authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}"
|
||||
return authorization_url, code_verifier
|
||||
|
||||
|
||||
def exchange_authorization(
|
||||
server_url: str,
|
||||
metadata: Optional[OAuthMetadata],
|
||||
client_information: OAuthClientInformation,
|
||||
authorization_code: str,
|
||||
code_verifier: str,
|
||||
redirect_uri: str,
|
||||
) -> OAuthTokens:
|
||||
"""Exchanges an authorization code for an access token."""
|
||||
grant_type = "authorization_code"
|
||||
|
||||
if metadata:
|
||||
token_url = metadata.token_endpoint
|
||||
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
|
||||
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
|
||||
else:
|
||||
token_url = urljoin(server_url, "/token")
|
||||
|
||||
params = {
|
||||
"grant_type": grant_type,
|
||||
"client_id": client_information.client_id,
|
||||
"code": authorization_code,
|
||||
"code_verifier": code_verifier,
|
||||
"redirect_uri": redirect_uri,
|
||||
}
|
||||
|
||||
if client_information.client_secret:
|
||||
params["client_secret"] = client_information.client_secret
|
||||
|
||||
response = requests.post(token_url, data=params)
|
||||
if not response.ok:
|
||||
raise ValueError(f"Token exchange failed: HTTP {response.status_code}")
|
||||
return OAuthTokens.model_validate(response.json())
|
||||
|
||||
|
||||
def refresh_authorization(
|
||||
server_url: str,
|
||||
metadata: Optional[OAuthMetadata],
|
||||
client_information: OAuthClientInformation,
|
||||
refresh_token: str,
|
||||
) -> OAuthTokens:
|
||||
"""Exchange a refresh token for an updated access token."""
|
||||
grant_type = "refresh_token"
|
||||
|
||||
if metadata:
|
||||
token_url = metadata.token_endpoint
|
||||
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
|
||||
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
|
||||
else:
|
||||
token_url = urljoin(server_url, "/token")
|
||||
|
||||
params = {
|
||||
"grant_type": grant_type,
|
||||
"client_id": client_information.client_id,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
|
||||
if client_information.client_secret:
|
||||
params["client_secret"] = client_information.client_secret
|
||||
|
||||
response = requests.post(token_url, data=params)
|
||||
if not response.ok:
|
||||
raise ValueError(f"Token refresh failed: HTTP {response.status_code}")
|
||||
return OAuthTokens.parse_obj(response.json())
|
||||
|
||||
|
||||
def register_client(
|
||||
server_url: str,
|
||||
metadata: Optional[OAuthMetadata],
|
||||
client_metadata: OAuthClientMetadata,
|
||||
) -> OAuthClientInformationFull:
|
||||
"""Performs OAuth 2.0 Dynamic Client Registration."""
|
||||
if metadata:
|
||||
if not metadata.registration_endpoint:
|
||||
raise ValueError("Incompatible auth server: does not support dynamic client registration")
|
||||
registration_url = metadata.registration_endpoint
|
||||
else:
|
||||
registration_url = urljoin(server_url, "/register")
|
||||
|
||||
response = requests.post(
|
||||
registration_url,
|
||||
json=client_metadata.model_dump(),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
if not response.ok:
|
||||
response.raise_for_status()
|
||||
return OAuthClientInformationFull.model_validate(response.json())
|
||||
|
||||
|
||||
def auth(
|
||||
provider: OAuthClientProvider,
|
||||
server_url: str,
|
||||
authorization_code: Optional[str] = None,
|
||||
state_param: Optional[str] = None,
|
||||
for_list: bool = False,
|
||||
) -> dict[str, str]:
|
||||
"""Orchestrates the full auth flow with a server using secure Redis state storage."""
|
||||
metadata = discover_oauth_metadata(server_url)
|
||||
|
||||
# Handle client registration if needed
|
||||
client_information = provider.client_information()
|
||||
if not client_information:
|
||||
if authorization_code is not None:
|
||||
raise ValueError("Existing OAuth client information is required when exchanging an authorization code")
|
||||
try:
|
||||
full_information = register_client(server_url, metadata, provider.client_metadata)
|
||||
except requests.RequestException as e:
|
||||
raise ValueError(f"Could not register OAuth client: {e}")
|
||||
provider.save_client_information(full_information)
|
||||
client_information = full_information
|
||||
|
||||
# Exchange authorization code for tokens
|
||||
if authorization_code is not None:
|
||||
if not state_param:
|
||||
raise ValueError("State parameter is required when exchanging authorization code")
|
||||
|
||||
try:
|
||||
# Retrieve state data from Redis using state key
|
||||
full_state_data = _retrieve_redis_state(state_param)
|
||||
|
||||
code_verifier = full_state_data.code_verifier
|
||||
redirect_uri = full_state_data.redirect_uri
|
||||
|
||||
if not code_verifier or not redirect_uri:
|
||||
raise ValueError("Missing code_verifier or redirect_uri in state data")
|
||||
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
raise ValueError(f"Invalid state parameter: {e}")
|
||||
|
||||
tokens = exchange_authorization(
|
||||
server_url,
|
||||
metadata,
|
||||
client_information,
|
||||
authorization_code,
|
||||
code_verifier,
|
||||
redirect_uri,
|
||||
)
|
||||
provider.save_tokens(tokens)
|
||||
return {"result": "success"}
|
||||
|
||||
provider_tokens = provider.tokens()
|
||||
|
||||
# Handle token refresh or new authorization
|
||||
if provider_tokens and provider_tokens.refresh_token:
|
||||
try:
|
||||
new_tokens = refresh_authorization(server_url, metadata, client_information, provider_tokens.refresh_token)
|
||||
provider.save_tokens(new_tokens)
|
||||
return {"result": "success"}
|
||||
except Exception as e:
|
||||
raise ValueError(f"Could not refresh OAuth tokens: {e}")
|
||||
|
||||
# Start new authorization flow
|
||||
authorization_url, code_verifier = start_authorization(
|
||||
server_url,
|
||||
metadata,
|
||||
client_information,
|
||||
provider.redirect_url,
|
||||
provider.mcp_provider.id,
|
||||
provider.mcp_provider.tenant_id,
|
||||
)
|
||||
|
||||
provider.save_code_verifier(code_verifier)
|
||||
return {"authorization_url": authorization_url}
|
||||
@ -0,0 +1,81 @@
|
||||
from typing import Optional
|
||||
|
||||
from configs import dify_config
|
||||
from core.mcp.types import (
|
||||
OAuthClientInformation,
|
||||
OAuthClientInformationFull,
|
||||
OAuthClientMetadata,
|
||||
OAuthTokens,
|
||||
)
|
||||
from models.tools import MCPToolProvider
|
||||
from services.tools.mcp_tools_mange_service import MCPToolManageService
|
||||
|
||||
LATEST_PROTOCOL_VERSION = "1.0"
|
||||
|
||||
|
||||
class OAuthClientProvider:
|
||||
mcp_provider: MCPToolProvider
|
||||
|
||||
def __init__(self, provider_id: str, tenant_id: str, for_list: bool = False):
|
||||
if for_list:
|
||||
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
else:
|
||||
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_server_identifier(provider_id, tenant_id)
|
||||
|
||||
@property
|
||||
def redirect_url(self) -> str:
|
||||
"""The URL to redirect the user agent to after authorization."""
|
||||
return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback"
|
||||
|
||||
@property
|
||||
def client_metadata(self) -> OAuthClientMetadata:
|
||||
"""Metadata about this OAuth client."""
|
||||
return OAuthClientMetadata(
|
||||
redirect_uris=[self.redirect_url],
|
||||
token_endpoint_auth_method="none",
|
||||
grant_types=["authorization_code", "refresh_token"],
|
||||
response_types=["code"],
|
||||
client_name="Dify",
|
||||
client_uri="https://github.com/langgenius/dify",
|
||||
)
|
||||
|
||||
def client_information(self) -> Optional[OAuthClientInformation]:
|
||||
"""Loads information about this OAuth client."""
|
||||
client_information = self.mcp_provider.decrypted_credentials.get("client_information", {})
|
||||
if not client_information:
|
||||
return None
|
||||
return OAuthClientInformation.model_validate(client_information)
|
||||
|
||||
def save_client_information(self, client_information: OAuthClientInformationFull) -> None:
|
||||
"""Saves client information after dynamic registration."""
|
||||
MCPToolManageService.update_mcp_provider_credentials(
|
||||
self.mcp_provider,
|
||||
{"client_information": client_information.model_dump()},
|
||||
)
|
||||
|
||||
def tokens(self) -> Optional[OAuthTokens]:
|
||||
"""Loads any existing OAuth tokens for the current session."""
|
||||
credentials = self.mcp_provider.decrypted_credentials
|
||||
if not credentials:
|
||||
return None
|
||||
return OAuthTokens(
|
||||
access_token=credentials.get("access_token", ""),
|
||||
token_type=credentials.get("token_type", "Bearer"),
|
||||
expires_in=int(credentials.get("expires_in", "3600") or 3600),
|
||||
refresh_token=credentials.get("refresh_token", ""),
|
||||
)
|
||||
|
||||
def save_tokens(self, tokens: OAuthTokens) -> None:
|
||||
"""Stores new OAuth tokens for the current session."""
|
||||
# update mcp provider credentials
|
||||
token_dict = tokens.model_dump()
|
||||
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True)
|
||||
|
||||
def save_code_verifier(self, code_verifier: str) -> None:
|
||||
"""Saves a PKCE code verifier for the current session."""
|
||||
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier})
|
||||
|
||||
def code_verifier(self) -> str:
|
||||
"""Loads the PKCE code verifier for the current session."""
|
||||
# get code verifier from mcp provider credentials
|
||||
return str(self.mcp_provider.decrypted_credentials.get("code_verifier", ""))
|
||||
@ -0,0 +1,361 @@
|
||||
import logging
|
||||
import queue
|
||||
from collections.abc import Generator
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, TypeAlias, final
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import httpx
|
||||
from sseclient import SSEClient
|
||||
|
||||
from core.mcp import types
|
||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
||||
from core.mcp.types import SessionMessage
|
||||
from core.mcp.utils import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_QUEUE_READ_TIMEOUT = 3
|
||||
|
||||
|
||||
@final
|
||||
class _StatusReady:
|
||||
def __init__(self, endpoint_url: str):
|
||||
self._endpoint_url = endpoint_url
|
||||
|
||||
|
||||
@final
|
||||
class _StatusError:
|
||||
def __init__(self, exc: Exception):
|
||||
self._exc = exc
|
||||
|
||||
|
||||
# Type aliases for better readability
|
||||
ReadQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
|
||||
WriteQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
|
||||
StatusQueue: TypeAlias = queue.Queue[_StatusReady | _StatusError]
|
||||
|
||||
|
||||
def remove_request_params(url: str) -> str:
|
||||
"""Remove request parameters from URL, keeping only the path."""
|
||||
return urljoin(url, urlparse(url).path)
|
||||
|
||||
|
||||
class SSETransport:
|
||||
"""SSE client transport implementation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: float = 5.0,
|
||||
sse_read_timeout: float = 5 * 60,
|
||||
) -> None:
|
||||
"""Initialize the SSE transport.
|
||||
|
||||
Args:
|
||||
url: The SSE endpoint URL.
|
||||
headers: Optional headers to include in requests.
|
||||
timeout: HTTP timeout for regular operations.
|
||||
sse_read_timeout: Timeout for SSE read operations.
|
||||
"""
|
||||
self.url = url
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
self.endpoint_url: str | None = None
|
||||
|
||||
def _validate_endpoint_url(self, endpoint_url: str) -> bool:
|
||||
"""Validate that the endpoint URL matches the connection origin.
|
||||
|
||||
Args:
|
||||
endpoint_url: The endpoint URL to validate.
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise.
|
||||
"""
|
||||
url_parsed = urlparse(self.url)
|
||||
endpoint_parsed = urlparse(endpoint_url)
|
||||
|
||||
return url_parsed.netloc == endpoint_parsed.netloc and url_parsed.scheme == endpoint_parsed.scheme
|
||||
|
||||
def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue) -> None:
|
||||
"""Handle an 'endpoint' SSE event.
|
||||
|
||||
Args:
|
||||
sse_data: The SSE event data.
|
||||
status_queue: Queue to put status updates.
|
||||
"""
|
||||
endpoint_url = urljoin(self.url, sse_data)
|
||||
logger.info(f"Received endpoint URL: {endpoint_url}")
|
||||
|
||||
if not self._validate_endpoint_url(endpoint_url):
|
||||
error_msg = f"Endpoint origin does not match connection origin: {endpoint_url}"
|
||||
logger.error(error_msg)
|
||||
status_queue.put(_StatusError(ValueError(error_msg)))
|
||||
return
|
||||
|
||||
status_queue.put(_StatusReady(endpoint_url))
|
||||
|
||||
def _handle_message_event(self, sse_data: str, read_queue: ReadQueue) -> None:
|
||||
"""Handle a 'message' SSE event.
|
||||
|
||||
Args:
|
||||
sse_data: The SSE event data.
|
||||
read_queue: Queue to put parsed messages.
|
||||
"""
|
||||
try:
|
||||
message = types.JSONRPCMessage.model_validate_json(sse_data)
|
||||
logger.debug(f"Received server message: {message}")
|
||||
session_message = SessionMessage(message)
|
||||
read_queue.put(session_message)
|
||||
except Exception as exc:
|
||||
logger.exception("Error parsing server message")
|
||||
read_queue.put(exc)
|
||||
|
||||
def _handle_sse_event(self, sse, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
|
||||
"""Handle a single SSE event.
|
||||
|
||||
Args:
|
||||
sse: The SSE event object.
|
||||
read_queue: Queue for message events.
|
||||
status_queue: Queue for status events.
|
||||
"""
|
||||
match sse.event:
|
||||
case "endpoint":
|
||||
self._handle_endpoint_event(sse.data, status_queue)
|
||||
case "message":
|
||||
self._handle_message_event(sse.data, read_queue)
|
||||
case _:
|
||||
logger.warning(f"Unknown SSE event: {sse.event}")
|
||||
|
||||
def sse_reader(self, event_source, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
|
||||
"""Read and process SSE events.
|
||||
|
||||
Args:
|
||||
event_source: The SSE event source.
|
||||
read_queue: Queue to put received messages.
|
||||
status_queue: Queue to put status updates.
|
||||
"""
|
||||
try:
|
||||
for sse in event_source.iter_sse():
|
||||
self._handle_sse_event(sse, read_queue, status_queue)
|
||||
except httpx.ReadError as exc:
|
||||
logger.debug(f"SSE reader shutting down normally: {exc}")
|
||||
except Exception as exc:
|
||||
read_queue.put(exc)
|
||||
finally:
|
||||
read_queue.put(None)
|
||||
|
||||
def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage) -> None:
|
||||
"""Send a single message to the server.
|
||||
|
||||
Args:
|
||||
client: HTTP client to use.
|
||||
endpoint_url: The endpoint URL to send to.
|
||||
message: The message to send.
|
||||
"""
|
||||
response = client.post(
|
||||
endpoint_url,
|
||||
json=message.message.model_dump(
|
||||
by_alias=True,
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
),
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.debug(f"Client message sent successfully: {response.status_code}")
|
||||
|
||||
def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue) -> None:
|
||||
"""Handle writing messages to the server.
|
||||
|
||||
Args:
|
||||
client: HTTP client to use.
|
||||
endpoint_url: The endpoint URL to send messages to.
|
||||
write_queue: Queue to read messages from.
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
|
||||
if message is None:
|
||||
break
|
||||
if isinstance(message, Exception):
|
||||
write_queue.put(message)
|
||||
continue
|
||||
|
||||
self._send_message(client, endpoint_url, message)
|
||||
|
||||
except queue.Empty:
|
||||
continue
|
||||
except httpx.ReadError as exc:
|
||||
logger.debug(f"Post writer shutting down normally: {exc}")
|
||||
except Exception as exc:
|
||||
logger.exception("Error writing messages")
|
||||
write_queue.put(exc)
|
||||
finally:
|
||||
write_queue.put(None)
|
||||
|
||||
def _wait_for_endpoint(self, status_queue: StatusQueue) -> str:
|
||||
"""Wait for the endpoint URL from the status queue.
|
||||
|
||||
Args:
|
||||
status_queue: Queue to read status from.
|
||||
|
||||
Returns:
|
||||
The endpoint URL.
|
||||
|
||||
Raises:
|
||||
ValueError: If endpoint URL is not received or there's an error.
|
||||
"""
|
||||
try:
|
||||
status = status_queue.get(timeout=1)
|
||||
except queue.Empty:
|
||||
raise ValueError("failed to get endpoint URL")
|
||||
|
||||
if isinstance(status, _StatusReady):
|
||||
return status._endpoint_url
|
||||
elif isinstance(status, _StatusError):
|
||||
raise status._exc
|
||||
else:
|
||||
raise ValueError("failed to get endpoint URL")
|
||||
|
||||
def connect(
|
||||
self,
|
||||
executor: ThreadPoolExecutor,
|
||||
client: httpx.Client,
|
||||
event_source,
|
||||
) -> tuple[ReadQueue, WriteQueue]:
|
||||
"""Establish connection and start worker threads.
|
||||
|
||||
Args:
|
||||
executor: Thread pool executor.
|
||||
client: HTTP client.
|
||||
event_source: SSE event source.
|
||||
|
||||
Returns:
|
||||
Tuple of (read_queue, write_queue).
|
||||
"""
|
||||
read_queue: ReadQueue = queue.Queue()
|
||||
write_queue: WriteQueue = queue.Queue()
|
||||
status_queue: StatusQueue = queue.Queue()
|
||||
|
||||
# Start SSE reader thread
|
||||
executor.submit(self.sse_reader, event_source, read_queue, status_queue)
|
||||
|
||||
# Wait for endpoint URL
|
||||
endpoint_url = self._wait_for_endpoint(status_queue)
|
||||
self.endpoint_url = endpoint_url
|
||||
|
||||
# Start post writer thread
|
||||
executor.submit(self.post_writer, client, endpoint_url, write_queue)
|
||||
|
||||
return read_queue, write_queue
|
||||
|
||||
|
||||
@contextmanager
|
||||
def sse_client(
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: float = 5.0,
|
||||
sse_read_timeout: float = 5 * 60,
|
||||
) -> Generator[tuple[ReadQueue, WriteQueue], None, None]:
|
||||
"""
|
||||
Client transport for SSE.
|
||||
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
|
||||
event before disconnecting. All other HTTP operations are controlled by `timeout`.
|
||||
|
||||
Args:
|
||||
url: The SSE endpoint URL.
|
||||
headers: Optional headers to include in requests.
|
||||
timeout: HTTP timeout for regular operations.
|
||||
sse_read_timeout: Timeout for SSE read operations.
|
||||
|
||||
Yields:
|
||||
Tuple of (read_queue, write_queue) for message communication.
|
||||
"""
|
||||
transport = SSETransport(url, headers, timeout, sse_read_timeout)
|
||||
|
||||
read_queue: ReadQueue | None = None
|
||||
write_queue: WriteQueue | None = None
|
||||
|
||||
with ThreadPoolExecutor() as executor:
|
||||
try:
|
||||
with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
|
||||
with ssrf_proxy_sse_connect(
|
||||
url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
|
||||
read_queue, write_queue = transport.connect(executor, client, event_source)
|
||||
|
||||
yield read_queue, write_queue
|
||||
|
||||
except httpx.HTTPStatusError as exc:
|
||||
if exc.response.status_code == 401:
|
||||
raise MCPAuthError()
|
||||
raise MCPConnectionError()
|
||||
except Exception:
|
||||
logger.exception("Error connecting to SSE endpoint")
|
||||
raise
|
||||
finally:
|
||||
# Clean up queues
|
||||
if read_queue:
|
||||
read_queue.put(None)
|
||||
if write_queue:
|
||||
write_queue.put(None)
|
||||
|
||||
|
||||
def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage) -> None:
|
||||
"""
|
||||
Send a message to the server using the provided HTTP client.
|
||||
|
||||
Args:
|
||||
http_client: The HTTP client to use for sending
|
||||
endpoint_url: The endpoint URL to send the message to
|
||||
session_message: The message to send
|
||||
"""
|
||||
try:
|
||||
response = http_client.post(
|
||||
endpoint_url,
|
||||
json=session_message.message.model_dump(
|
||||
by_alias=True,
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
),
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.debug(f"Client message sent successfully: {response.status_code}")
|
||||
except Exception as exc:
|
||||
logger.exception("Error sending message")
|
||||
raise
|
||||
|
||||
|
||||
def read_messages(
|
||||
sse_client: SSEClient,
|
||||
) -> Generator[SessionMessage | Exception, None, None]:
|
||||
"""
|
||||
Read messages from the SSE client.
|
||||
|
||||
Args:
|
||||
sse_client: The SSE client to read from
|
||||
|
||||
Yields:
|
||||
SessionMessage or Exception for each event received
|
||||
"""
|
||||
try:
|
||||
for sse in sse_client.events():
|
||||
if sse.event == "message":
|
||||
try:
|
||||
message = types.JSONRPCMessage.model_validate_json(sse.data)
|
||||
logger.debug(f"Received server message: {message}")
|
||||
yield SessionMessage(message)
|
||||
except Exception as exc:
|
||||
logger.exception("Error parsing server message")
|
||||
yield exc
|
||||
else:
|
||||
logger.warning(f"Unknown SSE event: {sse.event}")
|
||||
except Exception as exc:
|
||||
logger.exception("Error reading SSE messages")
|
||||
yield exc
|
||||
@ -0,0 +1,476 @@
|
||||
"""
|
||||
StreamableHTTP Client Transport Module
|
||||
|
||||
This module implements the StreamableHTTP transport for MCP clients,
|
||||
providing support for HTTP POST requests with optional SSE streaming responses
|
||||
and session management.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import queue
|
||||
from collections.abc import Callable, Generator
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
from httpx_sse import EventSource, ServerSentEvent
|
||||
|
||||
from core.mcp.types import (
|
||||
ClientMessageMetadata,
|
||||
ErrorData,
|
||||
JSONRPCError,
|
||||
JSONRPCMessage,
|
||||
JSONRPCNotification,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
RequestId,
|
||||
SessionMessage,
|
||||
)
|
||||
from core.mcp.utils import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SessionMessageOrError = SessionMessage | Exception | None
|
||||
# Queue types with clearer names for their roles
|
||||
ServerToClientQueue = queue.Queue[SessionMessageOrError] # Server to client messages
|
||||
ClientToServerQueue = queue.Queue[SessionMessage | None] # Client to server messages
|
||||
GetSessionIdCallback = Callable[[], str | None]
|
||||
|
||||
MCP_SESSION_ID = "mcp-session-id"
|
||||
LAST_EVENT_ID = "last-event-id"
|
||||
CONTENT_TYPE = "content-type"
|
||||
ACCEPT = "Accept"
|
||||
|
||||
|
||||
JSON = "application/json"
|
||||
SSE = "text/event-stream"
|
||||
|
||||
DEFAULT_QUEUE_READ_TIMEOUT = 3
|
||||
|
||||
|
||||
class StreamableHTTPError(Exception):
|
||||
"""Base exception for StreamableHTTP transport errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ResumptionError(StreamableHTTPError):
|
||||
"""Raised when resumption request is invalid."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestContext:
|
||||
"""Context for a request operation."""
|
||||
|
||||
client: httpx.Client
|
||||
headers: dict[str, str]
|
||||
session_id: str | None
|
||||
session_message: SessionMessage
|
||||
metadata: ClientMessageMetadata | None
|
||||
server_to_client_queue: ServerToClientQueue # Renamed for clarity
|
||||
sse_read_timeout: timedelta
|
||||
|
||||
|
||||
class StreamableHTTPTransport:
|
||||
"""StreamableHTTP client transport implementation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: timedelta = timedelta(seconds=30),
|
||||
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
|
||||
) -> None:
|
||||
"""Initialize the StreamableHTTP transport.
|
||||
|
||||
Args:
|
||||
url: The endpoint URL.
|
||||
headers: Optional headers to include in requests.
|
||||
timeout: HTTP timeout for regular operations.
|
||||
sse_read_timeout: Timeout for SSE read operations.
|
||||
"""
|
||||
self.url = url
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
self.session_id: str | None = None
|
||||
self.request_headers = {
|
||||
ACCEPT: f"{JSON}, {SSE}",
|
||||
CONTENT_TYPE: JSON,
|
||||
**self.headers,
|
||||
}
|
||||
|
||||
def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
|
||||
"""Update headers with session ID if available."""
|
||||
headers = base_headers.copy()
|
||||
if self.session_id:
|
||||
headers[MCP_SESSION_ID] = self.session_id
|
||||
return headers
|
||||
|
||||
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
|
||||
"""Check if the message is an initialization request."""
|
||||
return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
|
||||
|
||||
def _is_initialized_notification(self, message: JSONRPCMessage) -> bool:
|
||||
"""Check if the message is an initialized notification."""
|
||||
return isinstance(message.root, JSONRPCNotification) and message.root.method == "notifications/initialized"
|
||||
|
||||
def _maybe_extract_session_id_from_response(
|
||||
self,
|
||||
response: httpx.Response,
|
||||
) -> None:
|
||||
"""Extract and store session ID from response headers."""
|
||||
new_session_id = response.headers.get(MCP_SESSION_ID)
|
||||
if new_session_id:
|
||||
self.session_id = new_session_id
|
||||
logger.info(f"Received session ID: {self.session_id}")
|
||||
|
||||
def _handle_sse_event(
|
||||
self,
|
||||
sse: ServerSentEvent,
|
||||
server_to_client_queue: ServerToClientQueue,
|
||||
original_request_id: RequestId | None = None,
|
||||
resumption_callback: Callable[[str], None] | None = None,
|
||||
) -> bool:
|
||||
"""Handle an SSE event, returning True if the response is complete."""
|
||||
if sse.event == "message":
|
||||
try:
|
||||
message = JSONRPCMessage.model_validate_json(sse.data)
|
||||
logger.debug(f"SSE message: {message}")
|
||||
|
||||
# If this is a response and we have original_request_id, replace it
|
||||
if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
|
||||
message.root.id = original_request_id
|
||||
|
||||
session_message = SessionMessage(message)
|
||||
# Put message in queue that goes to client
|
||||
server_to_client_queue.put(session_message)
|
||||
|
||||
# Call resumption token callback if we have an ID
|
||||
if sse.id and resumption_callback:
|
||||
resumption_callback(sse.id)
|
||||
|
||||
# If this is a response or error return True indicating completion
|
||||
# Otherwise, return False to continue listening
|
||||
return isinstance(message.root, JSONRPCResponse | JSONRPCError)
|
||||
|
||||
except Exception as exc:
|
||||
# Put exception in queue that goes to client
|
||||
server_to_client_queue.put(exc)
|
||||
return False
|
||||
elif sse.event == "ping":
|
||||
logger.debug("Received ping event")
|
||||
return False
|
||||
else:
|
||||
logger.warning(f"Unknown SSE event: {sse.event}")
|
||||
return False
|
||||
|
||||
def handle_get_stream(
|
||||
self,
|
||||
client: httpx.Client,
|
||||
server_to_client_queue: ServerToClientQueue,
|
||||
) -> None:
|
||||
"""Handle GET stream for server-initiated messages."""
|
||||
try:
|
||||
if not self.session_id:
|
||||
return
|
||||
|
||||
headers = self._update_headers_with_session(self.request_headers)
|
||||
|
||||
with ssrf_proxy_sse_connect(
|
||||
self.url,
|
||||
headers=headers,
|
||||
timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds),
|
||||
client=client,
|
||||
method="GET",
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
logger.debug("GET SSE connection established")
|
||||
|
||||
for sse in event_source.iter_sse():
|
||||
self._handle_sse_event(sse, server_to_client_queue)
|
||||
|
||||
except Exception as exc:
|
||||
logger.debug(f"GET stream error (non-fatal): {exc}")
|
||||
|
||||
def _handle_resumption_request(self, ctx: RequestContext) -> None:
|
||||
"""Handle a resumption request using GET with SSE."""
|
||||
headers = self._update_headers_with_session(ctx.headers)
|
||||
if ctx.metadata and ctx.metadata.resumption_token:
|
||||
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
|
||||
else:
|
||||
raise ResumptionError("Resumption request requires a resumption token")
|
||||
|
||||
# Extract original request ID to map responses
|
||||
original_request_id = None
|
||||
if isinstance(ctx.session_message.message.root, JSONRPCRequest):
|
||||
original_request_id = ctx.session_message.message.root.id
|
||||
|
||||
with ssrf_proxy_sse_connect(
|
||||
self.url,
|
||||
headers=headers,
|
||||
timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds),
|
||||
client=ctx.client,
|
||||
method="GET",
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
logger.debug("Resumption GET SSE connection established")
|
||||
|
||||
for sse in event_source.iter_sse():
|
||||
is_complete = self._handle_sse_event(
|
||||
sse,
|
||||
ctx.server_to_client_queue,
|
||||
original_request_id,
|
||||
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
|
||||
)
|
||||
if is_complete:
|
||||
break
|
||||
|
||||
def _handle_post_request(self, ctx: RequestContext) -> None:
|
||||
"""Handle a POST request with response processing."""
|
||||
headers = self._update_headers_with_session(ctx.headers)
|
||||
message = ctx.session_message.message
|
||||
is_initialization = self._is_initialization_request(message)
|
||||
|
||||
with ctx.client.stream(
|
||||
"POST",
|
||||
self.url,
|
||||
json=message.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
headers=headers,
|
||||
) as response:
|
||||
if response.status_code == 202:
|
||||
logger.debug("Received 202 Accepted")
|
||||
return
|
||||
|
||||
if response.status_code == 404:
|
||||
if isinstance(message.root, JSONRPCRequest):
|
||||
self._send_session_terminated_error(
|
||||
ctx.server_to_client_queue,
|
||||
message.root.id,
|
||||
)
|
||||
return
|
||||
|
||||
response.raise_for_status()
|
||||
if is_initialization:
|
||||
self._maybe_extract_session_id_from_response(response)
|
||||
|
||||
content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower())
|
||||
|
||||
if content_type.startswith(JSON):
|
||||
self._handle_json_response(response, ctx.server_to_client_queue)
|
||||
elif content_type.startswith(SSE):
|
||||
self._handle_sse_response(response, ctx)
|
||||
else:
|
||||
self._handle_unexpected_content_type(
|
||||
content_type,
|
||||
ctx.server_to_client_queue,
|
||||
)
|
||||
|
||||
def _handle_json_response(
|
||||
self,
|
||||
response: httpx.Response,
|
||||
server_to_client_queue: ServerToClientQueue,
|
||||
) -> None:
|
||||
"""Handle JSON response from the server."""
|
||||
try:
|
||||
content = response.read()
|
||||
message = JSONRPCMessage.model_validate_json(content)
|
||||
session_message = SessionMessage(message)
|
||||
server_to_client_queue.put(session_message)
|
||||
except Exception as exc:
|
||||
server_to_client_queue.put(exc)
|
||||
|
||||
def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext) -> None:
|
||||
"""Handle SSE response from the server."""
|
||||
try:
|
||||
event_source = EventSource(response)
|
||||
for sse in event_source.iter_sse():
|
||||
is_complete = self._handle_sse_event(
|
||||
sse,
|
||||
ctx.server_to_client_queue,
|
||||
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
|
||||
)
|
||||
if is_complete:
|
||||
break
|
||||
except Exception as e:
|
||||
ctx.server_to_client_queue.put(e)
|
||||
|
||||
def _handle_unexpected_content_type(
|
||||
self,
|
||||
content_type: str,
|
||||
server_to_client_queue: ServerToClientQueue,
|
||||
) -> None:
|
||||
"""Handle unexpected content type in response."""
|
||||
error_msg = f"Unexpected content type: {content_type}"
|
||||
logger.error(error_msg)
|
||||
server_to_client_queue.put(ValueError(error_msg))
|
||||
|
||||
def _send_session_terminated_error(
|
||||
self,
|
||||
server_to_client_queue: ServerToClientQueue,
|
||||
request_id: RequestId,
|
||||
) -> None:
|
||||
"""Send a session terminated error response."""
|
||||
jsonrpc_error = JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
error=ErrorData(code=32600, message="Session terminated by server"),
|
||||
)
|
||||
session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
|
||||
server_to_client_queue.put(session_message)
|
||||
|
||||
def post_writer(
|
||||
self,
|
||||
client: httpx.Client,
|
||||
client_to_server_queue: ClientToServerQueue,
|
||||
server_to_client_queue: ServerToClientQueue,
|
||||
start_get_stream: Callable[[], None],
|
||||
) -> None:
|
||||
"""Handle writing requests to the server.
|
||||
|
||||
This method processes messages from the client_to_server_queue and sends them to the server.
|
||||
Responses are written to the server_to_client_queue.
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
# Read message from client queue with timeout to check stop_event periodically
|
||||
session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
|
||||
if session_message is None:
|
||||
break
|
||||
|
||||
message = session_message.message
|
||||
metadata = (
|
||||
session_message.metadata if isinstance(session_message.metadata, ClientMessageMetadata) else None
|
||||
)
|
||||
|
||||
# Check if this is a resumption request
|
||||
is_resumption = bool(metadata and metadata.resumption_token)
|
||||
|
||||
logger.debug(f"Sending client message: {message}")
|
||||
|
||||
# Handle initialized notification
|
||||
if self._is_initialized_notification(message):
|
||||
start_get_stream()
|
||||
|
||||
ctx = RequestContext(
|
||||
client=client,
|
||||
headers=self.request_headers,
|
||||
session_id=self.session_id,
|
||||
session_message=session_message,
|
||||
metadata=metadata,
|
||||
server_to_client_queue=server_to_client_queue, # Queue to write responses to client
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
)
|
||||
|
||||
if is_resumption:
|
||||
self._handle_resumption_request(ctx)
|
||||
else:
|
||||
self._handle_post_request(ctx)
|
||||
except queue.Empty:
|
||||
continue
|
||||
except Exception as exc:
|
||||
server_to_client_queue.put(exc)
|
||||
|
||||
def terminate_session(self, client: httpx.Client) -> None:
|
||||
"""Terminate the session by sending a DELETE request."""
|
||||
if not self.session_id:
|
||||
return
|
||||
|
||||
try:
|
||||
headers = self._update_headers_with_session(self.request_headers)
|
||||
response = client.delete(self.url, headers=headers)
|
||||
|
||||
if response.status_code == 405:
|
||||
logger.debug("Server does not allow session termination")
|
||||
elif response.status_code != 200:
|
||||
logger.warning(f"Session termination failed: {response.status_code}")
|
||||
except Exception as exc:
|
||||
logger.warning(f"Session termination failed: {exc}")
|
||||
|
||||
def get_session_id(self) -> str | None:
|
||||
"""Get the current session ID."""
|
||||
return self.session_id
|
||||
|
||||
|
||||
@contextmanager
|
||||
def streamablehttp_client(
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: timedelta = timedelta(seconds=30),
|
||||
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
|
||||
terminate_on_close: bool = True,
|
||||
) -> Generator[
|
||||
tuple[
|
||||
ServerToClientQueue, # Queue for receiving messages FROM server
|
||||
ClientToServerQueue, # Queue for sending messages TO server
|
||||
GetSessionIdCallback,
|
||||
],
|
||||
None,
|
||||
None,
|
||||
]:
|
||||
"""
|
||||
Client transport for StreamableHTTP.
|
||||
|
||||
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
|
||||
event before disconnecting. All other HTTP operations are controlled by `timeout`.
|
||||
|
||||
Yields:
|
||||
Tuple containing:
|
||||
- server_to_client_queue: Queue for reading messages FROM the server
|
||||
- client_to_server_queue: Queue for sending messages TO the server
|
||||
- get_session_id_callback: Function to retrieve the current session ID
|
||||
"""
|
||||
transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout)
|
||||
|
||||
# Create queues with clear directional meaning
|
||||
server_to_client_queue: ServerToClientQueue = queue.Queue() # For messages FROM server TO client
|
||||
client_to_server_queue: ClientToServerQueue = queue.Queue() # For messages FROM client TO server
|
||||
|
||||
with ThreadPoolExecutor(max_workers=2) as executor:
|
||||
try:
|
||||
with create_ssrf_proxy_mcp_http_client(
|
||||
headers=transport.request_headers,
|
||||
timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds),
|
||||
) as client:
|
||||
# Define callbacks that need access to thread pool
|
||||
def start_get_stream() -> None:
|
||||
"""Start a worker thread to handle server-initiated messages."""
|
||||
executor.submit(transport.handle_get_stream, client, server_to_client_queue)
|
||||
|
||||
# Start the post_writer worker thread
|
||||
executor.submit(
|
||||
transport.post_writer,
|
||||
client,
|
||||
client_to_server_queue, # Queue for messages FROM client TO server
|
||||
server_to_client_queue, # Queue for messages FROM server TO client
|
||||
start_get_stream,
|
||||
)
|
||||
|
||||
try:
|
||||
yield (
|
||||
server_to_client_queue, # Queue for receiving messages FROM server
|
||||
client_to_server_queue, # Queue for sending messages TO server
|
||||
transport.get_session_id,
|
||||
)
|
||||
finally:
|
||||
if transport.session_id and terminate_on_close:
|
||||
transport.terminate_session(client)
|
||||
|
||||
# Signal threads to stop
|
||||
client_to_server_queue.put(None)
|
||||
finally:
|
||||
# Clear any remaining items and add None sentinel to unblock any waiting threads
|
||||
try:
|
||||
while not client_to_server_queue.empty():
|
||||
client_to_server_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
pass
|
||||
|
||||
client_to_server_queue.put(None)
|
||||
server_to_client_queue.put(None)
|
||||
@ -0,0 +1,19 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from core.mcp.session.base_session import BaseSession
|
||||
from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestId, RequestParams
|
||||
|
||||
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", LATEST_PROTOCOL_VERSION]
|
||||
|
||||
|
||||
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
|
||||
LifespanContextT = TypeVar("LifespanContextT")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestContext(Generic[SessionT, LifespanContextT]):
|
||||
request_id: RequestId
|
||||
meta: RequestParams.Meta | None
|
||||
session: SessionT
|
||||
lifespan_context: LifespanContextT
|
||||
@ -0,0 +1,10 @@
|
||||
class MCPError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class MCPConnectionError(MCPError):
|
||||
pass
|
||||
|
||||
|
||||
class MCPAuthError(MCPConnectionError):
|
||||
pass
|
||||
@ -0,0 +1,150 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from contextlib import AbstractContextManager, ExitStack
|
||||
from types import TracebackType
|
||||
from typing import Any, Optional, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from core.mcp.client.sse_client import sse_client
|
||||
from core.mcp.client.streamable_client import streamablehttp_client
|
||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
||||
from core.mcp.session.client_session import ClientSession
|
||||
from core.mcp.types import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPClient:
|
||||
def __init__(
|
||||
self,
|
||||
server_url: str,
|
||||
provider_id: str,
|
||||
tenant_id: str,
|
||||
authed: bool = True,
|
||||
authorization_code: Optional[str] = None,
|
||||
for_list: bool = False,
|
||||
):
|
||||
# Initialize info
|
||||
self.provider_id = provider_id
|
||||
self.tenant_id = tenant_id
|
||||
self.client_type = "streamable"
|
||||
self.server_url = server_url
|
||||
|
||||
# Authentication info
|
||||
self.authed = authed
|
||||
self.authorization_code = authorization_code
|
||||
if authed:
|
||||
from core.mcp.auth.auth_provider import OAuthClientProvider
|
||||
|
||||
self.provider = OAuthClientProvider(self.provider_id, self.tenant_id, for_list=for_list)
|
||||
self.token = self.provider.tokens()
|
||||
|
||||
# Initialize session and client objects
|
||||
self._session: Optional[ClientSession] = None
|
||||
self._streams_context: Optional[AbstractContextManager[Any]] = None
|
||||
self._session_context: Optional[ClientSession] = None
|
||||
self.exit_stack = ExitStack()
|
||||
|
||||
# Whether the client has been initialized
|
||||
self._initialized = False
|
||||
|
||||
def __enter__(self):
|
||||
self._initialize()
|
||||
self._initialized = True
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self, exc_type: Optional[type], exc_value: Optional[BaseException], traceback: Optional[TracebackType]
|
||||
):
|
||||
self.cleanup()
|
||||
|
||||
def _initialize(
|
||||
self,
|
||||
):
|
||||
"""Initialize the client with fallback to SSE if streamable connection fails"""
|
||||
connection_methods: dict[str, Callable[..., AbstractContextManager[Any]]] = {
|
||||
"mcp": streamablehttp_client,
|
||||
"sse": sse_client,
|
||||
}
|
||||
|
||||
parsed_url = urlparse(self.server_url)
|
||||
path = parsed_url.path
|
||||
method_name = path.rstrip("/").split("/")[-1] if path else ""
|
||||
try:
|
||||
client_factory = connection_methods[method_name]
|
||||
self.connect_server(client_factory, method_name)
|
||||
except KeyError:
|
||||
try:
|
||||
self.connect_server(sse_client, "sse")
|
||||
except MCPConnectionError:
|
||||
self.connect_server(streamablehttp_client, "mcp")
|
||||
|
||||
def connect_server(
|
||||
self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str, first_try: bool = True
|
||||
):
|
||||
from core.mcp.auth.auth_flow import auth
|
||||
|
||||
try:
|
||||
headers = (
|
||||
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
|
||||
if self.authed and self.token
|
||||
else {}
|
||||
)
|
||||
self._streams_context = client_factory(url=self.server_url, headers=headers)
|
||||
if self._streams_context is None:
|
||||
raise MCPConnectionError("Failed to create connection context")
|
||||
|
||||
# Use exit_stack to manage context managers properly
|
||||
if method_name == "mcp":
|
||||
read_stream, write_stream, _ = self.exit_stack.enter_context(self._streams_context)
|
||||
streams = (read_stream, write_stream)
|
||||
else: # sse_client
|
||||
streams = self.exit_stack.enter_context(self._streams_context)
|
||||
|
||||
self._session_context = ClientSession(*streams)
|
||||
self._session = self.exit_stack.enter_context(self._session_context)
|
||||
session = cast(ClientSession, self._session)
|
||||
session.initialize()
|
||||
return
|
||||
|
||||
except MCPAuthError:
|
||||
if not self.authed:
|
||||
raise
|
||||
try:
|
||||
auth(self.provider, self.server_url, self.authorization_code)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to authenticate: {e}")
|
||||
self.token = self.provider.tokens()
|
||||
if first_try:
|
||||
return self.connect_server(client_factory, method_name, first_try=False)
|
||||
|
||||
except MCPConnectionError:
|
||||
raise
|
||||
|
||||
def list_tools(self) -> list[Tool]:
|
||||
"""Connect to an MCP server running with SSE transport"""
|
||||
# List available tools to verify connection
|
||||
if not self._initialized or not self._session:
|
||||
raise ValueError("Session not initialized.")
|
||||
response = self._session.list_tools()
|
||||
tools = response.tools
|
||||
return tools
|
||||
|
||||
def invoke_tool(self, tool_name: str, tool_args: dict):
|
||||
"""Call a tool"""
|
||||
if not self._initialized or not self._session:
|
||||
raise ValueError("Session not initialized.")
|
||||
return self._session.call_tool(tool_name, tool_args)
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up resources"""
|
||||
try:
|
||||
# ExitStack will handle proper cleanup of all managed context managers
|
||||
self.exit_stack.close()
|
||||
self._session = None
|
||||
self._session_context = None
|
||||
self._streams_context = None
|
||||
self._initialized = False
|
||||
except Exception as e:
|
||||
logging.exception("Error during cleanup")
|
||||
raise ValueError(f"Error during cleanup: {e}")
|
||||
@ -0,0 +1,228 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.web.passport import generate_session_id
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
|
||||
from core.mcp import types
|
||||
from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND
|
||||
from core.mcp.utils import create_mcp_error_response
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, AppMCPServer, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
||||
"""
|
||||
Apply to MCP HTTP streamable server with stateless http
|
||||
"""
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPServerStreamableHTTPRequestHandler:
|
||||
def __init__(
|
||||
self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity]
|
||||
):
|
||||
self.app = app
|
||||
self.request = request
|
||||
mcp_server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == self.app.id).first()
|
||||
if not mcp_server:
|
||||
raise ValueError("MCP server not found")
|
||||
self.mcp_server: AppMCPServer = mcp_server
|
||||
self.end_user = self.retrieve_end_user()
|
||||
self.user_input_form = user_input_form
|
||||
|
||||
@property
|
||||
def request_type(self):
|
||||
return type(self.request.root)
|
||||
|
||||
@property
|
||||
def parameter_schema(self):
|
||||
parameters, required = self._convert_input_form_to_parameters(self.user_input_form)
|
||||
if self.app.mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": parameters,
|
||||
"required": required,
|
||||
}
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "User Input/Question content"},
|
||||
**parameters,
|
||||
},
|
||||
"required": ["query", *required],
|
||||
}
|
||||
|
||||
@property
|
||||
def capabilities(self):
|
||||
return types.ServerCapabilities(
|
||||
tools=types.ToolsCapability(listChanged=False),
|
||||
)
|
||||
|
||||
def response(self, response: types.Result | str):
|
||||
if isinstance(response, str):
|
||||
sse_content = f"event: ping\ndata: {response}\n\n".encode()
|
||||
yield sse_content
|
||||
return
|
||||
json_response = types.JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=(self.request.root.model_extra or {}).get("id", 1),
|
||||
result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
json_data = json.dumps(jsonable_encoder(json_response))
|
||||
|
||||
sse_content = f"event: message\ndata: {json_data}\n\n".encode()
|
||||
|
||||
yield sse_content
|
||||
|
||||
def error_response(self, code: int, message: str, data=None):
|
||||
request_id = (self.request.root.model_extra or {}).get("id", 1) or 1
|
||||
return create_mcp_error_response(request_id, code, message, data)
|
||||
|
||||
def handle(self):
|
||||
handle_map = {
|
||||
types.InitializeRequest: self.initialize,
|
||||
types.ListToolsRequest: self.list_tools,
|
||||
types.CallToolRequest: self.invoke_tool,
|
||||
types.InitializedNotification: self.handle_notification,
|
||||
types.PingRequest: self.handle_ping,
|
||||
}
|
||||
try:
|
||||
if self.request_type in handle_map:
|
||||
return self.response(handle_map[self.request_type]())
|
||||
else:
|
||||
return self.error_response(METHOD_NOT_FOUND, f"Method not found: {self.request_type}")
|
||||
except ValueError as e:
|
||||
logger.exception("Invalid params")
|
||||
return self.error_response(INVALID_PARAMS, str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Internal server error")
|
||||
return self.error_response(INTERNAL_ERROR, f"Internal server error: {str(e)}")
|
||||
|
||||
def handle_notification(self):
|
||||
return "ping"
|
||||
|
||||
def handle_ping(self):
|
||||
return types.EmptyResult()
|
||||
|
||||
def initialize(self):
|
||||
request = cast(types.InitializeRequest, self.request.root)
|
||||
client_info = request.params.clientInfo
|
||||
clinet_name = f"{client_info.name}@{client_info.version}"
|
||||
if not self.end_user:
|
||||
end_user = EndUser(
|
||||
tenant_id=self.app.tenant_id,
|
||||
app_id=self.app.id,
|
||||
type="mcp",
|
||||
name=clinet_name,
|
||||
session_id=generate_session_id(),
|
||||
external_user_id=self.mcp_server.id,
|
||||
)
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
return types.InitializeResult(
|
||||
protocolVersion=types.SERVER_LATEST_PROTOCOL_VERSION,
|
||||
capabilities=self.capabilities,
|
||||
serverInfo=types.Implementation(name="Dify", version=dify_config.project.version),
|
||||
instructions=self.mcp_server.description,
|
||||
)
|
||||
|
||||
def list_tools(self):
|
||||
if not self.end_user:
|
||||
raise ValueError("User not found")
|
||||
return types.ListToolsResult(
|
||||
tools=[
|
||||
types.Tool(
|
||||
name=self.app.name,
|
||||
description=self.mcp_server.description,
|
||||
inputSchema=self.parameter_schema,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
def invoke_tool(self):
|
||||
if not self.end_user:
|
||||
raise ValueError("User not found")
|
||||
request = cast(types.CallToolRequest, self.request.root)
|
||||
args = request.params.arguments
|
||||
if not args:
|
||||
raise ValueError("No arguments provided")
|
||||
if self.app.mode in {AppMode.WORKFLOW.value}:
|
||||
args = {"inputs": args}
|
||||
elif self.app.mode in {AppMode.COMPLETION.value}:
|
||||
args = {"query": "", "inputs": args}
|
||||
else:
|
||||
args = {"query": args["query"], "inputs": {k: v for k, v in args.items() if k != "query"}}
|
||||
response = AppGenerateService.generate(
|
||||
self.app,
|
||||
self.end_user,
|
||||
args,
|
||||
InvokeFrom.SERVICE_API,
|
||||
streaming=self.app.mode == AppMode.AGENT_CHAT.value,
|
||||
)
|
||||
answer = ""
|
||||
if isinstance(response, RateLimitGenerator):
|
||||
for item in response.generator:
|
||||
data = item
|
||||
if isinstance(data, str) and data.startswith("data: "):
|
||||
try:
|
||||
json_str = data[6:].strip()
|
||||
parsed_data = json.loads(json_str)
|
||||
if parsed_data.get("event") == "agent_thought":
|
||||
answer += parsed_data.get("thought", "")
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if isinstance(response, Mapping):
|
||||
if self.app.mode in {
|
||||
AppMode.ADVANCED_CHAT.value,
|
||||
AppMode.COMPLETION.value,
|
||||
AppMode.CHAT.value,
|
||||
AppMode.AGENT_CHAT.value,
|
||||
}:
|
||||
answer = response["answer"]
|
||||
elif self.app.mode in {AppMode.WORKFLOW.value}:
|
||||
answer = json.dumps(response["data"]["outputs"], ensure_ascii=False)
|
||||
else:
|
||||
raise ValueError("Invalid app mode")
|
||||
# Not support image yet
|
||||
return types.CallToolResult(content=[types.TextContent(text=answer, type="text")])
|
||||
|
||||
def retrieve_end_user(self):
|
||||
return (
|
||||
db.session.query(EndUser)
|
||||
.filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
|
||||
.first()
|
||||
)
|
||||
|
||||
def _convert_input_form_to_parameters(self, user_input_form: list[VariableEntity]):
|
||||
parameters: dict[str, dict[str, Any]] = {}
|
||||
required = []
|
||||
for item in user_input_form:
|
||||
parameters[item.variable] = {}
|
||||
if item.type in (
|
||||
VariableEntityType.FILE,
|
||||
VariableEntityType.FILE_LIST,
|
||||
VariableEntityType.EXTERNAL_DATA_TOOL,
|
||||
):
|
||||
continue
|
||||
if item.required:
|
||||
required.append(item.variable)
|
||||
# if the workflow republished, the parameters not changed
|
||||
# we should not raise error here
|
||||
try:
|
||||
description = self.mcp_server.parameters_dict[item.variable]
|
||||
except KeyError:
|
||||
description = ""
|
||||
parameters[item.variable]["description"] = description
|
||||
if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
|
||||
parameters[item.variable]["type"] = "string"
|
||||
elif item.type == VariableEntityType.SELECT:
|
||||
parameters[item.variable]["type"] = "string"
|
||||
parameters[item.variable]["enum"] = item.options
|
||||
elif item.type == VariableEntityType.NUMBER:
|
||||
parameters[item.variable]["type"] = "float"
|
||||
return parameters, required
|
||||
@ -0,0 +1,397 @@
|
||||
import logging
|
||||
import queue
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import ExitStack
|
||||
from datetime import timedelta
|
||||
from types import TracebackType
|
||||
from typing import Any, Generic, Self, TypeVar
|
||||
|
||||
from httpx import HTTPStatusError
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
||||
from core.mcp.types import (
|
||||
CancelledNotification,
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
ClientResult,
|
||||
ErrorData,
|
||||
JSONRPCError,
|
||||
JSONRPCMessage,
|
||||
JSONRPCNotification,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
MessageMetadata,
|
||||
RequestId,
|
||||
RequestParams,
|
||||
ServerMessageMetadata,
|
||||
ServerNotification,
|
||||
ServerRequest,
|
||||
ServerResult,
|
||||
SessionMessage,
|
||||
)
|
||||
|
||||
SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
|
||||
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
|
||||
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
|
||||
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
|
||||
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
|
||||
ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)
|
||||
DEFAULT_RESPONSE_READ_TIMEOUT = 1.0
|
||||
|
||||
|
||||
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
||||
"""Handles responding to MCP requests and manages request lifecycle.
|
||||
|
||||
This class MUST be used as a context manager to ensure proper cleanup and
|
||||
cancellation handling:
|
||||
|
||||
Example:
|
||||
with request_responder as resp:
|
||||
resp.respond(result)
|
||||
|
||||
The context manager ensures:
|
||||
1. Proper cancellation scope setup and cleanup
|
||||
2. Request completion tracking
|
||||
3. Cleanup of in-flight requests
|
||||
"""
|
||||
|
||||
request: ReceiveRequestT
|
||||
_session: Any
|
||||
_on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
request_id: RequestId,
|
||||
request_meta: RequestParams.Meta | None,
|
||||
request: ReceiveRequestT,
|
||||
session: """BaseSession[
|
||||
SendRequestT,
|
||||
SendNotificationT,
|
||||
SendResultT,
|
||||
ReceiveRequestT,
|
||||
ReceiveNotificationT
|
||||
]""",
|
||||
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
|
||||
) -> None:
|
||||
self.request_id = request_id
|
||||
self.request_meta = request_meta
|
||||
self.request = request
|
||||
self._session = session
|
||||
self._completed = False
|
||||
self._on_complete = on_complete
|
||||
self._entered = False # Track if we're in a context manager
|
||||
|
||||
def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
|
||||
"""Enter the context manager, enabling request cancellation tracking."""
|
||||
self._entered = True
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
"""Exit the context manager, performing cleanup and notifying completion."""
|
||||
try:
|
||||
if self._completed:
|
||||
self._on_complete(self)
|
||||
finally:
|
||||
self._entered = False
|
||||
|
||||
def respond(self, response: SendResultT | ErrorData) -> None:
|
||||
"""Send a response for this request.
|
||||
|
||||
Must be called within a context manager block.
|
||||
Raises:
|
||||
RuntimeError: If not used within a context manager
|
||||
AssertionError: If request was already responded to
|
||||
"""
|
||||
if not self._entered:
|
||||
raise RuntimeError("RequestResponder must be used as a context manager")
|
||||
assert not self._completed, "Request already responded to"
|
||||
|
||||
self._completed = True
|
||||
|
||||
self._session._send_response(request_id=self.request_id, response=response)
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""Cancel this request and mark it as completed."""
|
||||
if not self._entered:
|
||||
raise RuntimeError("RequestResponder must be used as a context manager")
|
||||
|
||||
self._completed = True # Mark as completed so it's removed from in_flight
|
||||
# Send an error response to indicate cancellation
|
||||
self._session._send_response(
|
||||
request_id=self.request_id,
|
||||
response=ErrorData(code=0, message="Request cancelled", data=None),
|
||||
)
|
||||
|
||||
|
||||
class BaseSession(
|
||||
Generic[
|
||||
SendRequestT,
|
||||
SendNotificationT,
|
||||
SendResultT,
|
||||
ReceiveRequestT,
|
||||
ReceiveNotificationT,
|
||||
],
|
||||
):
|
||||
"""
|
||||
Implements an MCP "session" on top of read/write streams, including features
|
||||
like request/response linking, notifications, and progress.
|
||||
|
||||
This class is a context manager that automatically starts processing
|
||||
messages when entered.
|
||||
"""
|
||||
|
||||
_response_streams: dict[RequestId, queue.Queue[JSONRPCResponse | JSONRPCError]]
|
||||
_request_id: int
|
||||
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
|
||||
_receive_request_type: type[ReceiveRequestT]
|
||||
_receive_notification_type: type[ReceiveNotificationT]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
read_stream: queue.Queue,
|
||||
write_stream: queue.Queue,
|
||||
receive_request_type: type[ReceiveRequestT],
|
||||
receive_notification_type: type[ReceiveNotificationT],
|
||||
# If none, reading will never time out
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
) -> None:
|
||||
self._read_stream = read_stream
|
||||
self._write_stream = write_stream
|
||||
self._response_streams = {}
|
||||
self._request_id = 0
|
||||
self._receive_request_type = receive_request_type
|
||||
self._receive_notification_type = receive_notification_type
|
||||
self._session_read_timeout_seconds = read_timeout_seconds
|
||||
self._in_flight = {}
|
||||
self._exit_stack = ExitStack()
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
self._executor = ThreadPoolExecutor()
|
||||
self._receiver_future = self._executor.submit(self._receive_loop)
|
||||
return self
|
||||
|
||||
def check_receiver_status(self) -> None:
|
||||
if self._receiver_future.done():
|
||||
self._receiver_future.result()
|
||||
|
||||
def __exit__(
|
||||
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
|
||||
) -> None:
|
||||
self._exit_stack.close()
|
||||
self._read_stream.put(None)
|
||||
self._write_stream.put(None)
|
||||
|
||||
def send_request(
|
||||
self,
|
||||
request: SendRequestT,
|
||||
result_type: type[ReceiveResultT],
|
||||
request_read_timeout_seconds: timedelta | None = None,
|
||||
metadata: MessageMetadata = None,
|
||||
) -> ReceiveResultT:
|
||||
"""
|
||||
Sends a request and wait for a response. Raises an McpError if the
|
||||
response contains an error. If a request read timeout is provided, it
|
||||
will take precedence over the session read timeout.
|
||||
|
||||
Do not use this method to emit notifications! Use send_notification()
|
||||
instead.
|
||||
"""
|
||||
self.check_receiver_status()
|
||||
|
||||
request_id = self._request_id
|
||||
self._request_id = request_id + 1
|
||||
|
||||
response_queue: queue.Queue[JSONRPCResponse | JSONRPCError] = queue.Queue()
|
||||
self._response_streams[request_id] = response_queue
|
||||
|
||||
try:
|
||||
jsonrpc_request = JSONRPCRequest(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
**request.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
|
||||
self._write_stream.put(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata))
|
||||
timeout = DEFAULT_RESPONSE_READ_TIMEOUT
|
||||
if request_read_timeout_seconds is not None:
|
||||
timeout = float(request_read_timeout_seconds.total_seconds())
|
||||
elif self._session_read_timeout_seconds is not None:
|
||||
timeout = float(self._session_read_timeout_seconds.total_seconds())
|
||||
while True:
|
||||
try:
|
||||
response_or_error = response_queue.get(timeout=timeout)
|
||||
break
|
||||
except queue.Empty:
|
||||
self.check_receiver_status()
|
||||
continue
|
||||
|
||||
if response_or_error is None:
|
||||
raise MCPConnectionError(
|
||||
ErrorData(
|
||||
code=500,
|
||||
message="No response received",
|
||||
)
|
||||
)
|
||||
elif isinstance(response_or_error, JSONRPCError):
|
||||
if response_or_error.error.code == 401:
|
||||
raise MCPAuthError(
|
||||
ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
|
||||
)
|
||||
else:
|
||||
raise MCPConnectionError(
|
||||
ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
|
||||
)
|
||||
else:
|
||||
return result_type.model_validate(response_or_error.result)
|
||||
|
||||
finally:
|
||||
self._response_streams.pop(request_id, None)
|
||||
|
||||
def send_notification(
|
||||
self,
|
||||
notification: SendNotificationT,
|
||||
related_request_id: RequestId | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Emits a notification, which is a one-way message that does not expect
|
||||
a response.
|
||||
"""
|
||||
self.check_receiver_status()
|
||||
|
||||
# Some transport implementations may need to set the related_request_id
|
||||
# to attribute to the notifications to the request that triggered them.
|
||||
jsonrpc_notification = JSONRPCNotification(
|
||||
jsonrpc="2.0",
|
||||
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
session_message = SessionMessage(
|
||||
message=JSONRPCMessage(jsonrpc_notification),
|
||||
metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None,
|
||||
)
|
||||
self._write_stream.put(session_message)
|
||||
|
||||
def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None:
|
||||
if isinstance(response, ErrorData):
|
||||
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
|
||||
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
|
||||
self._write_stream.put(session_message)
|
||||
else:
|
||||
jsonrpc_response = JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response))
|
||||
self._write_stream.put(session_message)
|
||||
|
||||
def _receive_loop(self) -> None:
|
||||
"""
|
||||
Main message processing loop.
|
||||
In a real synchronous implementation, this would likely run in a separate thread.
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
# Attempt to receive a message (this would be blocking in a synchronous context)
|
||||
message = self._read_stream.get(timeout=DEFAULT_RESPONSE_READ_TIMEOUT)
|
||||
if message is None:
|
||||
break
|
||||
if isinstance(message, HTTPStatusError):
|
||||
response_queue = self._response_streams.get(self._request_id - 1)
|
||||
if response_queue is not None:
|
||||
response_queue.put(
|
||||
JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id=self._request_id - 1,
|
||||
error=ErrorData(code=message.response.status_code, message=message.args[0]),
|
||||
)
|
||||
)
|
||||
else:
|
||||
self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}"))
|
||||
elif isinstance(message, Exception):
|
||||
self._handle_incoming(message)
|
||||
elif isinstance(message.message.root, JSONRPCRequest):
|
||||
validated_request = self._receive_request_type.model_validate(
|
||||
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
)
|
||||
|
||||
responder = RequestResponder(
|
||||
request_id=message.message.root.id,
|
||||
request_meta=validated_request.root.params.meta if validated_request.root.params else None,
|
||||
request=validated_request,
|
||||
session=self,
|
||||
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
|
||||
)
|
||||
|
||||
self._in_flight[responder.request_id] = responder
|
||||
self._received_request(responder)
|
||||
|
||||
if not responder._completed:
|
||||
self._handle_incoming(responder)
|
||||
|
||||
elif isinstance(message.message.root, JSONRPCNotification):
|
||||
try:
|
||||
notification = self._receive_notification_type.model_validate(
|
||||
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
)
|
||||
# Handle cancellation notifications
|
||||
if isinstance(notification.root, CancelledNotification):
|
||||
cancelled_id = notification.root.params.requestId
|
||||
if cancelled_id in self._in_flight:
|
||||
self._in_flight[cancelled_id].cancel()
|
||||
else:
|
||||
self._received_notification(notification)
|
||||
self._handle_incoming(notification)
|
||||
except Exception as e:
|
||||
# For other validation errors, log and continue
|
||||
logging.warning(f"Failed to validate notification: {e}. Message was: {message.message.root}")
|
||||
else: # Response or error
|
||||
response_queue = self._response_streams.get(message.message.root.id)
|
||||
if response_queue is not None:
|
||||
response_queue.put(message.message.root)
|
||||
else:
|
||||
self._handle_incoming(RuntimeError(f"Server Error: {message}"))
|
||||
except queue.Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
logging.exception("Error in message processing loop")
|
||||
raise
|
||||
|
||||
def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
|
||||
"""
|
||||
Can be overridden by subclasses to handle a request without needing to
|
||||
listen on the message stream.
|
||||
|
||||
If the request is responded to within this method, it will not be
|
||||
forwarded on to the message stream.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _received_notification(self, notification: ReceiveNotificationT) -> None:
|
||||
"""
|
||||
Can be overridden by subclasses to handle a notification without needing
|
||||
to listen on the message stream.
|
||||
"""
|
||||
pass
|
||||
|
||||
def send_progress_notification(
|
||||
self, progress_token: str | int, progress: float, total: float | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Sends a progress notification for a request that is currently being
|
||||
processed.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _handle_incoming(
|
||||
self,
|
||||
req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
|
||||
) -> None:
|
||||
"""A generic handler for incoming messages. Overwritten by subclasses."""
|
||||
pass
|
||||
@ -0,0 +1,365 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any, Protocol
|
||||
|
||||
from pydantic import AnyUrl, TypeAdapter
|
||||
|
||||
from configs import dify_config
|
||||
from core.mcp import types
|
||||
from core.mcp.entities import SUPPORTED_PROTOCOL_VERSIONS, RequestContext
|
||||
from core.mcp.session.base_session import BaseSession, RequestResponder
|
||||
|
||||
DEFAULT_CLIENT_INFO = types.Implementation(name="Dify", version=dify_config.project.version)
|
||||
|
||||
|
||||
class SamplingFnT(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.CreateMessageRequestParams,
|
||||
) -> types.CreateMessageResult | types.ErrorData: ...
|
||||
|
||||
|
||||
class ListRootsFnT(Protocol):
|
||||
def __call__(self, context: RequestContext["ClientSession", Any]) -> types.ListRootsResult | types.ErrorData: ...
|
||||
|
||||
|
||||
class LoggingFnT(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
params: types.LoggingMessageNotificationParams,
|
||||
) -> None: ...
|
||||
|
||||
|
||||
class MessageHandlerFnT(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
) -> None: ...
|
||||
|
||||
|
||||
def _default_message_handler(
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
) -> None:
|
||||
if isinstance(message, Exception):
|
||||
raise ValueError(str(message))
|
||||
elif isinstance(message, (types.ServerNotification | RequestResponder)):
|
||||
pass
|
||||
|
||||
|
||||
def _default_sampling_callback(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.CreateMessageRequestParams,
|
||||
) -> types.CreateMessageResult | types.ErrorData:
|
||||
return types.ErrorData(
|
||||
code=types.INVALID_REQUEST,
|
||||
message="Sampling not supported",
|
||||
)
|
||||
|
||||
|
||||
def _default_list_roots_callback(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
) -> types.ListRootsResult | types.ErrorData:
|
||||
return types.ErrorData(
|
||||
code=types.INVALID_REQUEST,
|
||||
message="List roots not supported",
|
||||
)
|
||||
|
||||
|
||||
def _default_logging_callback(
|
||||
params: types.LoggingMessageNotificationParams,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData)
|
||||
|
||||
|
||||
class ClientSession(
|
||||
BaseSession[
|
||||
types.ClientRequest,
|
||||
types.ClientNotification,
|
||||
types.ClientResult,
|
||||
types.ServerRequest,
|
||||
types.ServerNotification,
|
||||
]
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
read_stream,
|
||||
write_stream,
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
sampling_callback: SamplingFnT | None = None,
|
||||
list_roots_callback: ListRootsFnT | None = None,
|
||||
logging_callback: LoggingFnT | None = None,
|
||||
message_handler: MessageHandlerFnT | None = None,
|
||||
client_info: types.Implementation | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
read_stream,
|
||||
write_stream,
|
||||
types.ServerRequest,
|
||||
types.ServerNotification,
|
||||
read_timeout_seconds=read_timeout_seconds,
|
||||
)
|
||||
self._client_info = client_info or DEFAULT_CLIENT_INFO
|
||||
self._sampling_callback = sampling_callback or _default_sampling_callback
|
||||
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
|
||||
self._logging_callback = logging_callback or _default_logging_callback
|
||||
self._message_handler = message_handler or _default_message_handler
|
||||
|
||||
def initialize(self) -> types.InitializeResult:
|
||||
sampling = types.SamplingCapability()
|
||||
roots = types.RootsCapability(
|
||||
# TODO: Should this be based on whether we
|
||||
# _will_ send notifications, or only whether
|
||||
# they're supported?
|
||||
listChanged=True,
|
||||
)
|
||||
|
||||
result = self.send_request(
|
||||
types.ClientRequest(
|
||||
types.InitializeRequest(
|
||||
method="initialize",
|
||||
params=types.InitializeRequestParams(
|
||||
protocolVersion=types.LATEST_PROTOCOL_VERSION,
|
||||
capabilities=types.ClientCapabilities(
|
||||
sampling=sampling,
|
||||
experimental=None,
|
||||
roots=roots,
|
||||
),
|
||||
clientInfo=self._client_info,
|
||||
),
|
||||
)
|
||||
),
|
||||
types.InitializeResult,
|
||||
)
|
||||
|
||||
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
|
||||
raise RuntimeError(f"Unsupported protocol version from the server: {result.protocolVersion}")
|
||||
|
||||
self.send_notification(
|
||||
types.ClientNotification(types.InitializedNotification(method="notifications/initialized"))
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def send_ping(self) -> types.EmptyResult:
|
||||
"""Send a ping request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.PingRequest(
|
||||
method="ping",
|
||||
)
|
||||
),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
def send_progress_notification(
|
||||
self, progress_token: str | int, progress: float, total: float | None = None
|
||||
) -> None:
|
||||
"""Send a progress notification."""
|
||||
self.send_notification(
|
||||
types.ClientNotification(
|
||||
types.ProgressNotification(
|
||||
method="notifications/progress",
|
||||
params=types.ProgressNotificationParams(
|
||||
progressToken=progress_token,
|
||||
progress=progress,
|
||||
total=total,
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult:
|
||||
"""Send a logging/setLevel request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.SetLevelRequest(
|
||||
method="logging/setLevel",
|
||||
params=types.SetLevelRequestParams(level=level),
|
||||
)
|
||||
),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
def list_resources(self) -> types.ListResourcesResult:
|
||||
"""Send a resources/list request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.ListResourcesRequest(
|
||||
method="resources/list",
|
||||
)
|
||||
),
|
||||
types.ListResourcesResult,
|
||||
)
|
||||
|
||||
def list_resource_templates(self) -> types.ListResourceTemplatesResult:
|
||||
"""Send a resources/templates/list request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.ListResourceTemplatesRequest(
|
||||
method="resources/templates/list",
|
||||
)
|
||||
),
|
||||
types.ListResourceTemplatesResult,
|
||||
)
|
||||
|
||||
def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
|
||||
"""Send a resources/read request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.ReadResourceRequest(
|
||||
method="resources/read",
|
||||
params=types.ReadResourceRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
types.ReadResourceResult,
|
||||
)
|
||||
|
||||
def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
|
||||
"""Send a resources/subscribe request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.SubscribeRequest(
|
||||
method="resources/subscribe",
|
||||
params=types.SubscribeRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
|
||||
"""Send a resources/unsubscribe request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.UnsubscribeRequest(
|
||||
method="resources/unsubscribe",
|
||||
params=types.UnsubscribeRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
def call_tool(
|
||||
self,
|
||||
name: str,
|
||||
arguments: dict[str, Any] | None = None,
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
) -> types.CallToolResult:
|
||||
"""Send a tools/call request."""
|
||||
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.CallToolRequest(
|
||||
method="tools/call",
|
||||
params=types.CallToolRequestParams(name=name, arguments=arguments),
|
||||
)
|
||||
),
|
||||
types.CallToolResult,
|
||||
request_read_timeout_seconds=read_timeout_seconds,
|
||||
)
|
||||
|
||||
def list_prompts(self) -> types.ListPromptsResult:
|
||||
"""Send a prompts/list request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.ListPromptsRequest(
|
||||
method="prompts/list",
|
||||
)
|
||||
),
|
||||
types.ListPromptsResult,
|
||||
)
|
||||
|
||||
def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult:
|
||||
"""Send a prompts/get request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.GetPromptRequest(
|
||||
method="prompts/get",
|
||||
params=types.GetPromptRequestParams(name=name, arguments=arguments),
|
||||
)
|
||||
),
|
||||
types.GetPromptResult,
|
||||
)
|
||||
|
||||
def complete(
|
||||
self,
|
||||
ref: types.ResourceReference | types.PromptReference,
|
||||
argument: dict[str, str],
|
||||
) -> types.CompleteResult:
|
||||
"""Send a completion/complete request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.CompleteRequest(
|
||||
method="completion/complete",
|
||||
params=types.CompleteRequestParams(
|
||||
ref=ref,
|
||||
argument=types.CompletionArgument(**argument),
|
||||
),
|
||||
)
|
||||
),
|
||||
types.CompleteResult,
|
||||
)
|
||||
|
||||
def list_tools(self) -> types.ListToolsResult:
|
||||
"""Send a tools/list request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.ListToolsRequest(
|
||||
method="tools/list",
|
||||
)
|
||||
),
|
||||
types.ListToolsResult,
|
||||
)
|
||||
|
||||
def send_roots_list_changed(self) -> None:
|
||||
"""Send a roots/list_changed notification."""
|
||||
self.send_notification(
|
||||
types.ClientNotification(
|
||||
types.RootsListChangedNotification(
|
||||
method="notifications/roots/list_changed",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None:
|
||||
ctx = RequestContext[ClientSession, Any](
|
||||
request_id=responder.request_id,
|
||||
meta=responder.request_meta,
|
||||
session=self,
|
||||
lifespan_context=None,
|
||||
)
|
||||
|
||||
match responder.request.root:
|
||||
case types.CreateMessageRequest(params=params):
|
||||
with responder:
|
||||
response = self._sampling_callback(ctx, params)
|
||||
client_response = ClientResponse.validate_python(response)
|
||||
responder.respond(client_response)
|
||||
|
||||
case types.ListRootsRequest():
|
||||
with responder:
|
||||
list_roots_response = self._list_roots_callback(ctx)
|
||||
client_response = ClientResponse.validate_python(list_roots_response)
|
||||
responder.respond(client_response)
|
||||
|
||||
case types.PingRequest():
|
||||
with responder:
|
||||
return responder.respond(types.ClientResult(root=types.EmptyResult()))
|
||||
|
||||
def _handle_incoming(
|
||||
self,
|
||||
req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
) -> None:
|
||||
"""Handle incoming messages by forwarding to the message handler."""
|
||||
self._message_handler(req)
|
||||
|
||||
def _received_notification(self, notification: types.ServerNotification) -> None:
|
||||
"""Handle notifications from the server."""
|
||||
# Process specific notification types
|
||||
match notification.root:
|
||||
case types.LoggingMessageNotification(params=params):
|
||||
self._logging_callback(params)
|
||||
case _:
|
||||
pass
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,114 @@
|
||||
import json
|
||||
|
||||
import httpx
|
||||
|
||||
from configs import dify_config
|
||||
from core.mcp.types import ErrorData, JSONRPCError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
||||
HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
|
||||
|
||||
STATUS_FORCELIST = [429, 500, 502, 503, 504]
|
||||
|
||||
|
||||
def create_ssrf_proxy_mcp_http_client(
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: httpx.Timeout | None = None,
|
||||
) -> httpx.Client:
|
||||
"""Create an HTTPX client with SSRF proxy configuration for MCP connections.
|
||||
|
||||
Args:
|
||||
headers: Optional headers to include in the client
|
||||
timeout: Optional timeout configuration
|
||||
|
||||
Returns:
|
||||
Configured httpx.Client with proxy settings
|
||||
"""
|
||||
if dify_config.SSRF_PROXY_ALL_URL:
|
||||
return httpx.Client(
|
||||
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
|
||||
headers=headers or {},
|
||||
timeout=timeout,
|
||||
follow_redirects=True,
|
||||
proxy=dify_config.SSRF_PROXY_ALL_URL,
|
||||
)
|
||||
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
|
||||
proxy_mounts = {
|
||||
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY),
|
||||
"https://": httpx.HTTPTransport(
|
||||
proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY
|
||||
),
|
||||
}
|
||||
return httpx.Client(
|
||||
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
|
||||
headers=headers or {},
|
||||
timeout=timeout,
|
||||
follow_redirects=True,
|
||||
mounts=proxy_mounts,
|
||||
)
|
||||
else:
|
||||
return httpx.Client(
|
||||
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
|
||||
headers=headers or {},
|
||||
timeout=timeout,
|
||||
follow_redirects=True,
|
||||
)
|
||||
|
||||
|
||||
def ssrf_proxy_sse_connect(url, **kwargs):
|
||||
"""Connect to SSE endpoint with SSRF proxy protection.
|
||||
|
||||
This function creates an SSE connection using the configured proxy settings
|
||||
to prevent SSRF attacks when connecting to external endpoints.
|
||||
|
||||
Args:
|
||||
url: The SSE endpoint URL
|
||||
**kwargs: Additional arguments passed to the SSE connection
|
||||
|
||||
Returns:
|
||||
EventSource object for SSE streaming
|
||||
"""
|
||||
from httpx_sse import connect_sse
|
||||
|
||||
# Extract client if provided, otherwise create one
|
||||
client = kwargs.pop("client", None)
|
||||
if client is None:
|
||||
# Create client with SSRF proxy configuration
|
||||
timeout = kwargs.pop(
|
||||
"timeout",
|
||||
httpx.Timeout(
|
||||
timeout=dify_config.SSRF_DEFAULT_TIME_OUT,
|
||||
connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT,
|
||||
read=dify_config.SSRF_DEFAULT_READ_TIME_OUT,
|
||||
write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
|
||||
),
|
||||
)
|
||||
headers = kwargs.pop("headers", {})
|
||||
client = create_ssrf_proxy_mcp_http_client(headers=headers, timeout=timeout)
|
||||
client_provided = False
|
||||
else:
|
||||
client_provided = True
|
||||
|
||||
# Extract method if provided, default to GET
|
||||
method = kwargs.pop("method", "GET")
|
||||
|
||||
try:
|
||||
return connect_sse(client, method, url, **kwargs)
|
||||
except Exception:
|
||||
# If we created the client, we need to clean it up on error
|
||||
if not client_provided:
|
||||
client.close()
|
||||
raise
|
||||
|
||||
|
||||
def create_mcp_error_response(request_id: int | str | None, code: int, message: str, data=None):
|
||||
"""Create MCP error response"""
|
||||
error_data = ErrorData(code=code, message=message, data=data)
|
||||
json_response = JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id=request_id or 1,
|
||||
error=error_data,
|
||||
)
|
||||
json_data = json.dumps(jsonable_encoder(json_response))
|
||||
sse_content = f"event: message\ndata: {json_data}\n\n".encode()
|
||||
yield sse_content
|
||||
@ -0,0 +1,130 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from core.mcp.types import Tool as RemoteMCPTool
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolDescription,
|
||||
ToolEntity,
|
||||
ToolIdentity,
|
||||
ToolProviderEntityWithPlugin,
|
||||
ToolProviderIdentity,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.mcp_tool.tool import MCPTool
|
||||
from models.tools import MCPToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
|
||||
class MCPToolProviderController(ToolProviderController):
|
||||
provider_id: str
|
||||
entity: ToolProviderEntityWithPlugin
|
||||
|
||||
def __init__(self, entity: ToolProviderEntityWithPlugin, provider_id: str, tenant_id: str, server_url: str) -> None:
|
||||
super().__init__(entity)
|
||||
self.entity = entity
|
||||
self.tenant_id = tenant_id
|
||||
self.provider_id = provider_id
|
||||
self.server_url = server_url
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
:return: type of the provider
|
||||
"""
|
||||
return ToolProviderType.MCP
|
||||
|
||||
@classmethod
|
||||
def _from_db(cls, db_provider: MCPToolProvider) -> "MCPToolProviderController":
|
||||
"""
|
||||
from db provider
|
||||
"""
|
||||
tools = []
|
||||
tools_data = json.loads(db_provider.tools)
|
||||
remote_mcp_tools = [RemoteMCPTool(**tool) for tool in tools_data]
|
||||
user = db_provider.load_user()
|
||||
tools = [
|
||||
ToolEntity(
|
||||
identity=ToolIdentity(
|
||||
author=user.name if user else "Anonymous",
|
||||
name=remote_mcp_tool.name,
|
||||
label=I18nObject(en_US=remote_mcp_tool.name, zh_Hans=remote_mcp_tool.name),
|
||||
provider=db_provider.server_identifier,
|
||||
icon=db_provider.icon,
|
||||
),
|
||||
parameters=ToolTransformService.convert_mcp_schema_to_parameter(remote_mcp_tool.inputSchema),
|
||||
description=ToolDescription(
|
||||
human=I18nObject(
|
||||
en_US=remote_mcp_tool.description or "", zh_Hans=remote_mcp_tool.description or ""
|
||||
),
|
||||
llm=remote_mcp_tool.description or "",
|
||||
),
|
||||
output_schema=None,
|
||||
has_runtime_parameters=len(remote_mcp_tool.inputSchema) > 0,
|
||||
)
|
||||
for remote_mcp_tool in remote_mcp_tools
|
||||
]
|
||||
|
||||
return cls(
|
||||
entity=ToolProviderEntityWithPlugin(
|
||||
identity=ToolProviderIdentity(
|
||||
author=user.name if user else "Anonymous",
|
||||
name=db_provider.name,
|
||||
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
|
||||
description=I18nObject(en_US="", zh_Hans=""),
|
||||
icon=db_provider.icon,
|
||||
),
|
||||
plugin_id=None,
|
||||
credentials_schema=[],
|
||||
tools=tools,
|
||||
),
|
||||
provider_id=db_provider.server_identifier or "",
|
||||
tenant_id=db_provider.tenant_id or "",
|
||||
server_url=db_provider.decrypted_server_url,
|
||||
)
|
||||
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_tool(self, tool_name: str) -> MCPTool: # type: ignore
|
||||
"""
|
||||
return tool with given name
|
||||
"""
|
||||
tool_entity = next(
|
||||
(tool_entity for tool_entity in self.entity.tools if tool_entity.identity.name == tool_name), None
|
||||
)
|
||||
|
||||
if not tool_entity:
|
||||
raise ValueError(f"Tool with name {tool_name} not found")
|
||||
|
||||
return MCPTool(
|
||||
entity=tool_entity,
|
||||
runtime=ToolRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
server_url=self.server_url,
|
||||
provider_id=self.provider_id,
|
||||
)
|
||||
|
||||
def get_tools(self) -> list[MCPTool]: # type: ignore
|
||||
"""
|
||||
get all tools
|
||||
"""
|
||||
return [
|
||||
MCPTool(
|
||||
entity=tool_entity,
|
||||
runtime=ToolRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
server_url=self.server_url,
|
||||
provider_id=self.provider_id,
|
||||
)
|
||||
for tool_entity in self.entity.tools
|
||||
]
|
||||
@ -0,0 +1,92 @@
|
||||
import base64
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
from core.mcp.types import ImageContent, TextContent
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
|
||||
|
||||
|
||||
class MCPTool(Tool):
|
||||
tenant_id: str
|
||||
icon: str
|
||||
runtime_parameters: Optional[list[ToolParameter]]
|
||||
server_url: str
|
||||
provider_id: str
|
||||
|
||||
def __init__(
|
||||
self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, server_url: str, provider_id: str
|
||||
) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.runtime_parameters = None
|
||||
self.server_url = server_url
|
||||
self.provider_id = provider_id
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.MCP
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
from core.tools.errors import ToolInvokeError
|
||||
|
||||
try:
|
||||
with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client:
|
||||
tool_parameters = self._handle_none_parameter(tool_parameters)
|
||||
result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
||||
except MCPAuthError as e:
|
||||
raise ToolInvokeError("Please auth the tool first") from e
|
||||
except MCPConnectionError as e:
|
||||
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
|
||||
|
||||
for content in result.content:
|
||||
if isinstance(content, TextContent):
|
||||
try:
|
||||
content_json = json.loads(content.text)
|
||||
if isinstance(content_json, dict):
|
||||
yield self.create_json_message(content_json)
|
||||
elif isinstance(content_json, list):
|
||||
for item in content_json:
|
||||
yield self.create_json_message(item)
|
||||
else:
|
||||
yield self.create_text_message(content.text)
|
||||
except json.JSONDecodeError:
|
||||
yield self.create_text_message(content.text)
|
||||
|
||||
elif isinstance(content, ImageContent):
|
||||
yield self.create_blob_message(
|
||||
blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType}
|
||||
)
|
||||
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
|
||||
return MCPTool(
|
||||
entity=self.entity,
|
||||
runtime=runtime,
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.icon,
|
||||
server_url=self.server_url,
|
||||
provider_id=self.provider_id,
|
||||
)
|
||||
|
||||
def _handle_none_parameter(self, parameter: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
in mcp tool invoke, if the parameter is empty, it will be set to None
|
||||
"""
|
||||
return {
|
||||
key: value
|
||||
for key, value in parameter.items()
|
||||
if value is not None and not (isinstance(value, str) and value.strip() == "")
|
||||
}
|
||||
@ -0,0 +1,64 @@
|
||||
"""add mcp server tool and app server
|
||||
|
||||
Revision ID: 58eb7bdb93fe
|
||||
Revises: 0ab65e1cc7fa
|
||||
Create Date: 2025-06-25 09:36:07.510570
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '58eb7bdb93fe'
|
||||
down_revision = '0ab65e1cc7fa'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('app_mcp_servers',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('app_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('name', sa.String(length=255), nullable=False),
|
||||
sa.Column('description', sa.String(length=255), nullable=False),
|
||||
sa.Column('server_code', sa.String(length=255), nullable=False),
|
||||
sa.Column('status', sa.String(length=255), server_default=sa.text("'normal'::character varying"), nullable=False),
|
||||
sa.Column('parameters', sa.Text(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='app_mcp_server_pkey'),
|
||||
sa.UniqueConstraint('tenant_id', 'app_id', name='unique_app_mcp_server_tenant_app_id'),
|
||||
sa.UniqueConstraint('server_code', name='unique_app_mcp_server_server_code')
|
||||
)
|
||||
op.create_table('tool_mcp_providers',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('name', sa.String(length=40), nullable=False),
|
||||
sa.Column('server_identifier', sa.String(length=24), nullable=False),
|
||||
sa.Column('server_url', sa.Text(), nullable=False),
|
||||
sa.Column('server_url_hash', sa.String(length=64), nullable=False),
|
||||
sa.Column('icon', sa.String(length=255), nullable=True),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('user_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('encrypted_credentials', sa.Text(), nullable=True),
|
||||
sa.Column('authed', sa.Boolean(), nullable=False),
|
||||
sa.Column('tools', sa.Text(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='tool_mcp_provider_pkey'),
|
||||
sa.UniqueConstraint('tenant_id', 'name', name='unique_mcp_provider_name'),
|
||||
sa.UniqueConstraint('tenant_id', 'server_identifier', name='unique_mcp_provider_server_identifier'),
|
||||
sa.UniqueConstraint('tenant_id', 'server_url_hash', name='unique_mcp_provider_server_url')
|
||||
)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('tool_mcp_providers')
|
||||
op.drop_table('app_mcp_servers')
|
||||
# ### end Alembic commands ###
|
||||
@ -0,0 +1,232 @@
|
||||
import hashlib
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.mcp.error import MCPAuthError, MCPError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
from core.tools.entities.api_entities import ToolProviderApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.mcp_tool.provider import MCPToolProviderController
|
||||
from core.tools.utils.configuration import ProviderConfigEncrypter
|
||||
from extensions.ext_database import db
|
||||
from models.tools import MCPToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]"
|
||||
|
||||
|
||||
class MCPToolManageService:
|
||||
"""
|
||||
Service class for managing mcp tools.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider:
|
||||
res = (
|
||||
db.session.query(MCPToolProvider)
|
||||
.filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id)
|
||||
.first()
|
||||
)
|
||||
if not res:
|
||||
raise ValueError("MCP tool not found")
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def get_mcp_provider_by_server_identifier(server_identifier: str, tenant_id: str) -> MCPToolProvider:
|
||||
res = (
|
||||
db.session.query(MCPToolProvider)
|
||||
.filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier)
|
||||
.first()
|
||||
)
|
||||
if not res:
|
||||
raise ValueError("MCP tool not found")
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def create_mcp_provider(
|
||||
tenant_id: str,
|
||||
name: str,
|
||||
server_url: str,
|
||||
user_id: str,
|
||||
icon: str,
|
||||
icon_type: str,
|
||||
icon_background: str,
|
||||
server_identifier: str,
|
||||
) -> ToolProviderApiEntity:
|
||||
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
||||
existing_provider = (
|
||||
db.session.query(MCPToolProvider)
|
||||
.filter(
|
||||
MCPToolProvider.tenant_id == tenant_id,
|
||||
or_(
|
||||
MCPToolProvider.name == name,
|
||||
MCPToolProvider.server_url_hash == server_url_hash,
|
||||
MCPToolProvider.server_identifier == server_identifier,
|
||||
),
|
||||
MCPToolProvider.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing_provider:
|
||||
if existing_provider.name == name:
|
||||
raise ValueError(f"MCP tool {name} already exists")
|
||||
elif existing_provider.server_url_hash == server_url_hash:
|
||||
raise ValueError(f"MCP tool {server_url} already exists")
|
||||
elif existing_provider.server_identifier == server_identifier:
|
||||
raise ValueError(f"MCP tool {server_identifier} already exists")
|
||||
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
|
||||
mcp_tool = MCPToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
server_url=encrypted_server_url,
|
||||
server_url_hash=server_url_hash,
|
||||
user_id=user_id,
|
||||
authed=False,
|
||||
tools="[]",
|
||||
icon=json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon,
|
||||
server_identifier=server_identifier,
|
||||
)
|
||||
db.session.add(mcp_tool)
|
||||
db.session.commit()
|
||||
return ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
|
||||
|
||||
@staticmethod
|
||||
def retrieve_mcp_tools(tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]:
|
||||
mcp_providers = (
|
||||
db.session.query(MCPToolProvider)
|
||||
.filter(MCPToolProvider.tenant_id == tenant_id)
|
||||
.order_by(MCPToolProvider.name)
|
||||
.all()
|
||||
)
|
||||
return [
|
||||
ToolTransformService.mcp_provider_to_user_provider(mcp_provider, for_list=for_list)
|
||||
for mcp_provider in mcp_providers
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str):
|
||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
|
||||
try:
|
||||
with MCPClient(
|
||||
mcp_provider.decrypted_server_url, provider_id, tenant_id, authed=mcp_provider.authed, for_list=True
|
||||
) as mcp_client:
|
||||
tools = mcp_client.list_tools()
|
||||
except MCPAuthError as e:
|
||||
raise ValueError("Please auth the tool first")
|
||||
except MCPError as e:
|
||||
raise ValueError(f"Failed to connect to MCP server: {e}")
|
||||
mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
|
||||
mcp_provider.authed = True
|
||||
mcp_provider.updated_at = datetime.now()
|
||||
db.session.commit()
|
||||
user = mcp_provider.load_user()
|
||||
return ToolProviderApiEntity(
|
||||
id=mcp_provider.id,
|
||||
name=mcp_provider.name,
|
||||
tools=ToolTransformService.mcp_tool_to_user_tool(mcp_provider, tools),
|
||||
type=ToolProviderType.MCP,
|
||||
icon=mcp_provider.icon,
|
||||
author=user.name if user else "Anonymous",
|
||||
server_url=mcp_provider.masked_server_url,
|
||||
updated_at=int(mcp_provider.updated_at.timestamp()),
|
||||
description=I18nObject(en_US="", zh_Hans=""),
|
||||
label=I18nObject(en_US=mcp_provider.name, zh_Hans=mcp_provider.name),
|
||||
plugin_unique_identifier=mcp_provider.server_identifier,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def delete_mcp_tool(cls, tenant_id: str, provider_id: str):
|
||||
mcp_tool = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
|
||||
db.session.delete(mcp_tool)
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def update_mcp_provider(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
provider_id: str,
|
||||
name: str,
|
||||
server_url: str,
|
||||
icon: str,
|
||||
icon_type: str,
|
||||
icon_background: str,
|
||||
server_identifier: str,
|
||||
):
|
||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
mcp_provider.updated_at = datetime.now()
|
||||
mcp_provider.name = name
|
||||
mcp_provider.icon = (
|
||||
json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon
|
||||
)
|
||||
mcp_provider.server_identifier = server_identifier
|
||||
|
||||
if UNCHANGED_SERVER_URL_PLACEHOLDER not in server_url:
|
||||
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
|
||||
mcp_provider.server_url = encrypted_server_url
|
||||
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
||||
|
||||
if server_url_hash != mcp_provider.server_url_hash:
|
||||
cls._re_connect_mcp_provider(mcp_provider, provider_id, tenant_id)
|
||||
mcp_provider.server_url_hash = server_url_hash
|
||||
try:
|
||||
db.session.commit()
|
||||
except IntegrityError as e:
|
||||
db.session.rollback()
|
||||
error_msg = str(e.orig)
|
||||
if "unique_mcp_provider_name" in error_msg:
|
||||
raise ValueError(f"MCP tool {name} already exists")
|
||||
elif "unique_mcp_provider_server_url" in error_msg:
|
||||
raise ValueError(f"MCP tool {server_url} already exists")
|
||||
elif "unique_mcp_provider_server_identifier" in error_msg:
|
||||
raise ValueError(f"MCP tool {server_identifier} already exists")
|
||||
else:
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def update_mcp_provider_credentials(
|
||||
cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False
|
||||
):
|
||||
provider_controller = MCPToolProviderController._from_db(mcp_provider)
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=mcp_provider.tenant_id,
|
||||
config=list(provider_controller.get_credentials_schema()),
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.provider_id,
|
||||
)
|
||||
credentials = tool_configuration.encrypt(credentials)
|
||||
mcp_provider.updated_at = datetime.now()
|
||||
mcp_provider.encrypted_credentials = json.dumps({**mcp_provider.credentials, **credentials})
|
||||
mcp_provider.authed = authed
|
||||
if not authed:
|
||||
mcp_provider.tools = "[]"
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def _re_connect_mcp_provider(cls, mcp_provider: MCPToolProvider, provider_id: str, tenant_id: str):
|
||||
"""re-connect mcp provider"""
|
||||
try:
|
||||
with MCPClient(
|
||||
mcp_provider.decrypted_server_url,
|
||||
provider_id,
|
||||
tenant_id,
|
||||
authed=False,
|
||||
for_list=True,
|
||||
) as mcp_client:
|
||||
tools = mcp_client.list_tools()
|
||||
mcp_provider.authed = True
|
||||
mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
|
||||
except MCPAuthError:
|
||||
mcp_provider.authed = False
|
||||
mcp_provider.tools = "[]"
|
||||
except MCPError as e:
|
||||
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
|
||||
# reset credentials
|
||||
mcp_provider.encrypted_credentials = "{}"
|
||||
@ -0,0 +1,280 @@
|
||||
import base64
|
||||
import binascii
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.helper.encrypter import (
|
||||
batch_decrypt_token,
|
||||
decrypt_token,
|
||||
encrypt_token,
|
||||
get_decrypt_decoding,
|
||||
obfuscated_token,
|
||||
)
|
||||
from libs.rsa import PrivkeyNotFoundError
|
||||
|
||||
|
||||
class TestObfuscatedToken:
|
||||
@pytest.mark.parametrize(
|
||||
("token", "expected"),
|
||||
[
|
||||
("", ""), # Empty token
|
||||
("1234567", "*" * 20), # Short token (<8 chars)
|
||||
("12345678", "*" * 20), # Boundary case (8 chars)
|
||||
("123456789abcdef", "123456" + "*" * 12 + "ef"), # Long token
|
||||
("abc!@#$%^&*()def", "abc!@#" + "*" * 12 + "ef"), # Special chars
|
||||
],
|
||||
)
|
||||
def test_obfuscation_logic(self, token, expected):
|
||||
"""Test core obfuscation logic for various token lengths"""
|
||||
assert obfuscated_token(token) == expected
|
||||
|
||||
def test_sensitive_data_protection(self):
|
||||
"""Ensure obfuscation never reveals full sensitive data"""
|
||||
token = "api_key_secret_12345"
|
||||
obfuscated = obfuscated_token(token)
|
||||
assert token not in obfuscated
|
||||
assert "*" * 12 in obfuscated
|
||||
|
||||
|
||||
class TestEncryptToken:
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_successful_encryption(self, mock_encrypt, mock_query):
|
||||
"""Test successful token encryption"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_data"
|
||||
|
||||
result = encrypt_token("tenant-123", "test_token")
|
||||
|
||||
assert result == base64.b64encode(b"encrypted_data").decode()
|
||||
mock_encrypt.assert_called_with("test_token", "mock_public_key")
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
def test_tenant_not_found(self, mock_query):
|
||||
"""Test error when tenant doesn't exist"""
|
||||
mock_query.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
encrypt_token("invalid-tenant", "test_token")
|
||||
|
||||
assert "Tenant with id invalid-tenant not found" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestDecryptToken:
|
||||
@patch("libs.rsa.decrypt")
|
||||
def test_successful_decryption(self, mock_decrypt):
|
||||
"""Test successful token decryption"""
|
||||
mock_decrypt.return_value = "decrypted_token"
|
||||
encrypted_data = base64.b64encode(b"encrypted_data").decode()
|
||||
|
||||
result = decrypt_token("tenant-123", encrypted_data)
|
||||
|
||||
assert result == "decrypted_token"
|
||||
mock_decrypt.assert_called_once_with(b"encrypted_data", "tenant-123")
|
||||
|
||||
def test_invalid_base64(self):
|
||||
"""Test handling of invalid base64 input"""
|
||||
with pytest.raises(binascii.Error):
|
||||
decrypt_token("tenant-123", "invalid_base64!!!")
|
||||
|
||||
|
||||
class TestBatchDecryptToken:
|
||||
@patch("libs.rsa.get_decrypt_decoding")
|
||||
@patch("libs.rsa.decrypt_token_with_decoding")
|
||||
def test_batch_decryption(self, mock_decrypt_with_decoding, mock_get_decoding):
|
||||
"""Test batch decryption functionality"""
|
||||
mock_rsa_key = MagicMock()
|
||||
mock_cipher_rsa = MagicMock()
|
||||
mock_get_decoding.return_value = (mock_rsa_key, mock_cipher_rsa)
|
||||
|
||||
# Test multiple tokens
|
||||
mock_decrypt_with_decoding.side_effect = ["token1", "token2", "token3"]
|
||||
tokens = [
|
||||
base64.b64encode(b"encrypted1").decode(),
|
||||
base64.b64encode(b"encrypted2").decode(),
|
||||
base64.b64encode(b"encrypted3").decode(),
|
||||
]
|
||||
result = batch_decrypt_token("tenant-123", tokens)
|
||||
|
||||
assert result == ["token1", "token2", "token3"]
|
||||
# Key should only be loaded once
|
||||
mock_get_decoding.assert_called_once_with("tenant-123")
|
||||
|
||||
|
||||
class TestGetDecryptDecoding:
|
||||
@patch("extensions.ext_redis.redis_client.get")
|
||||
@patch("extensions.ext_storage.storage.load")
|
||||
def test_private_key_not_found(self, mock_storage_load, mock_redis_get):
|
||||
"""Test error when private key file doesn't exist"""
|
||||
mock_redis_get.return_value = None
|
||||
mock_storage_load.side_effect = FileNotFoundError()
|
||||
|
||||
with pytest.raises(PrivkeyNotFoundError) as exc_info:
|
||||
get_decrypt_decoding("tenant-123")
|
||||
|
||||
assert "Private key not found, tenant_id: tenant-123" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestEncryptDecryptIntegration:
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("libs.rsa.encrypt")
|
||||
@patch("libs.rsa.decrypt")
|
||||
def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_query):
|
||||
"""Test that encryption and decryption are consistent"""
|
||||
# Setup mock tenant
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
|
||||
|
||||
# Setup mock encryption/decryption
|
||||
original_token = "test_token_123"
|
||||
mock_encrypt.return_value = b"encrypted_data"
|
||||
mock_decrypt.return_value = original_token
|
||||
|
||||
# Test encryption
|
||||
encrypted = encrypt_token("tenant-123", original_token)
|
||||
|
||||
# Test decryption
|
||||
decrypted = decrypt_token("tenant-123", encrypted)
|
||||
|
||||
assert decrypted == original_token
|
||||
|
||||
|
||||
class TestSecurity:
|
||||
"""Critical security tests for encryption system"""
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_cross_tenant_isolation(self, mock_encrypt, mock_query):
|
||||
"""Ensure tokens encrypted for one tenant cannot be used by another"""
|
||||
# Setup mock tenant
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "tenant1_public_key"
|
||||
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_for_tenant1"
|
||||
|
||||
# Encrypt token for tenant1
|
||||
encrypted = encrypt_token("tenant-123", "sensitive_data")
|
||||
|
||||
# Attempt to decrypt with different tenant should fail
|
||||
with patch("libs.rsa.decrypt") as mock_decrypt:
|
||||
mock_decrypt.side_effect = Exception("Invalid tenant key")
|
||||
|
||||
with pytest.raises(Exception, match="Invalid tenant key"):
|
||||
decrypt_token("different-tenant", encrypted)
|
||||
|
||||
@patch("libs.rsa.decrypt")
|
||||
def test_tampered_ciphertext_rejection(self, mock_decrypt):
|
||||
"""Detect and reject tampered ciphertext"""
|
||||
valid_encrypted = base64.b64encode(b"valid_data").decode()
|
||||
|
||||
# Tamper with ciphertext
|
||||
tampered_bytes = bytearray(base64.b64decode(valid_encrypted))
|
||||
tampered_bytes[0] ^= 0xFF
|
||||
tampered = base64.b64encode(bytes(tampered_bytes)).decode()
|
||||
|
||||
mock_decrypt.side_effect = Exception("Decryption error")
|
||||
|
||||
with pytest.raises(Exception, match="Decryption error"):
|
||||
decrypt_token("tenant-123", tampered)
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_encryption_randomness(self, mock_encrypt, mock_query):
|
||||
"""Ensure same plaintext produces different ciphertext"""
|
||||
mock_tenant = MagicMock(encrypt_public_key="key")
|
||||
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
|
||||
|
||||
# Different outputs for same input
|
||||
mock_encrypt.side_effect = [b"enc1", b"enc2", b"enc3"]
|
||||
|
||||
results = [encrypt_token("tenant-123", "token") for _ in range(3)]
|
||||
|
||||
# All results should be different
|
||||
assert len(set(results)) == 3
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Additional security-focused edge case tests"""
|
||||
|
||||
def test_should_handle_empty_string_in_obfuscation(self):
|
||||
"""Test handling of empty string in obfuscation"""
|
||||
# Test empty string (which is a valid str type)
|
||||
assert obfuscated_token("") == ""
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_should_handle_empty_token_encryption(self, mock_encrypt, mock_query):
|
||||
"""Test encryption of empty token"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_empty"
|
||||
|
||||
result = encrypt_token("tenant-123", "")
|
||||
|
||||
assert result == base64.b64encode(b"encrypted_empty").decode()
|
||||
mock_encrypt.assert_called_with("", "mock_public_key")
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_query):
|
||||
"""Test tokens containing special/unicode characters"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_special"
|
||||
|
||||
# Test various special characters
|
||||
special_tokens = [
|
||||
"token\x00with\x00null", # Null bytes
|
||||
"token_with_emoji_😀🎉", # Unicode emoji
|
||||
"token\nwith\nnewlines", # Newlines
|
||||
"token\twith\ttabs", # Tabs
|
||||
"token_with_中文字符", # Chinese characters
|
||||
]
|
||||
|
||||
for token in special_tokens:
|
||||
result = encrypt_token("tenant-123", token)
|
||||
assert result == base64.b64encode(b"encrypted_special").decode()
|
||||
mock_encrypt.assert_called_with(token, "mock_public_key")
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_query):
|
||||
"""Test behavior when token exceeds RSA encryption limits"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
|
||||
|
||||
# RSA 2048-bit can only encrypt ~245 bytes
|
||||
# The actual limit depends on padding scheme
|
||||
mock_encrypt.side_effect = ValueError("Message too long for RSA key size")
|
||||
|
||||
# Create a token that would exceed RSA limits
|
||||
long_token = "x" * 300
|
||||
|
||||
with pytest.raises(ValueError, match="Message too long for RSA key size"):
|
||||
encrypt_token("tenant-123", long_token)
|
||||
|
||||
@patch("libs.rsa.get_decrypt_decoding")
|
||||
@patch("libs.rsa.decrypt_token_with_decoding")
|
||||
def test_batch_decrypt_loads_key_only_once(self, mock_decrypt_with_decoding, mock_get_decoding):
|
||||
"""Verify batch decryption optimization - loads key only once"""
|
||||
mock_rsa_key = MagicMock()
|
||||
mock_cipher_rsa = MagicMock()
|
||||
mock_get_decoding.return_value = (mock_rsa_key, mock_cipher_rsa)
|
||||
|
||||
# Test with multiple tokens
|
||||
mock_decrypt_with_decoding.side_effect = ["token1", "token2", "token3", "token4", "token5"]
|
||||
tokens = [base64.b64encode(f"encrypted{i}".encode()).decode() for i in range(5)]
|
||||
|
||||
result = batch_decrypt_token("tenant-123", tokens)
|
||||
|
||||
assert result == ["token1", "token2", "token3", "token4", "token5"]
|
||||
# Key should only be loaded once regardless of token count
|
||||
mock_get_decoding.assert_called_once_with("tenant-123")
|
||||
assert mock_decrypt_with_decoding.call_count == 5
|
||||
@ -0,0 +1,471 @@
|
||||
import queue
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from core.mcp import types
|
||||
from core.mcp.entities import RequestContext
|
||||
from core.mcp.session.base_session import RequestResponder
|
||||
from core.mcp.session.client_session import DEFAULT_CLIENT_INFO, ClientSession
|
||||
from core.mcp.types import (
|
||||
LATEST_PROTOCOL_VERSION,
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
Implementation,
|
||||
InitializedNotification,
|
||||
InitializeRequest,
|
||||
InitializeResult,
|
||||
JSONRPCMessage,
|
||||
JSONRPCNotification,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
ServerCapabilities,
|
||||
ServerResult,
|
||||
SessionMessage,
|
||||
)
|
||||
|
||||
|
||||
def test_client_session_initialize():
|
||||
# Create synchronous queues to replace async streams
|
||||
client_to_server: queue.Queue[SessionMessage] = queue.Queue()
|
||||
server_to_client: queue.Queue[SessionMessage] = queue.Queue()
|
||||
|
||||
initialized_notification = None
|
||||
|
||||
def mock_server():
|
||||
nonlocal initialized_notification
|
||||
|
||||
# Receive initialization request
|
||||
session_message = client_to_server.get(timeout=5.0)
|
||||
jsonrpc_request = session_message.message
|
||||
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
|
||||
request = ClientRequest.model_validate(
|
||||
jsonrpc_request.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
)
|
||||
assert isinstance(request.root, InitializeRequest)
|
||||
|
||||
# Create response
|
||||
result = ServerResult(
|
||||
InitializeResult(
|
||||
protocolVersion=LATEST_PROTOCOL_VERSION,
|
||||
capabilities=ServerCapabilities(
|
||||
logging=None,
|
||||
resources=None,
|
||||
tools=None,
|
||||
experimental=None,
|
||||
prompts=None,
|
||||
),
|
||||
serverInfo=Implementation(name="mock-server", version="0.1.0"),
|
||||
instructions="The server instructions.",
|
||||
)
|
||||
)
|
||||
|
||||
# Send response
|
||||
server_to_client.put(
|
||||
SessionMessage(
|
||||
message=JSONRPCMessage(
|
||||
JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=jsonrpc_request.root.id,
|
||||
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Receive initialized notification
|
||||
session_notification = client_to_server.get(timeout=5.0)
|
||||
jsonrpc_notification = session_notification.message
|
||||
assert isinstance(jsonrpc_notification.root, JSONRPCNotification)
|
||||
initialized_notification = ClientNotification.model_validate(
|
||||
jsonrpc_notification.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
)
|
||||
|
||||
# Create message handler
|
||||
def message_handler(
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
) -> None:
|
||||
if isinstance(message, Exception):
|
||||
raise message
|
||||
|
||||
# Start mock server thread
|
||||
server_thread = threading.Thread(target=mock_server, daemon=True)
|
||||
server_thread.start()
|
||||
|
||||
# Create and use client session
|
||||
with ClientSession(
|
||||
server_to_client,
|
||||
client_to_server,
|
||||
message_handler=message_handler,
|
||||
) as session:
|
||||
result = session.initialize()
|
||||
|
||||
# Wait for server thread to complete
|
||||
server_thread.join(timeout=10.0)
|
||||
|
||||
# Assert results
|
||||
assert isinstance(result, InitializeResult)
|
||||
assert result.protocolVersion == LATEST_PROTOCOL_VERSION
|
||||
assert isinstance(result.capabilities, ServerCapabilities)
|
||||
assert result.serverInfo == Implementation(name="mock-server", version="0.1.0")
|
||||
assert result.instructions == "The server instructions."
|
||||
|
||||
# Check that client sent initialized notification
|
||||
assert initialized_notification
|
||||
assert isinstance(initialized_notification.root, InitializedNotification)
|
||||
|
||||
|
||||
def test_client_session_custom_client_info():
|
||||
# Create synchronous queues to replace async streams
|
||||
client_to_server: queue.Queue[SessionMessage] = queue.Queue()
|
||||
server_to_client: queue.Queue[SessionMessage] = queue.Queue()
|
||||
|
||||
custom_client_info = Implementation(name="test-client", version="1.2.3")
|
||||
received_client_info = None
|
||||
|
||||
def mock_server():
|
||||
nonlocal received_client_info
|
||||
|
||||
session_message = client_to_server.get(timeout=5.0)
|
||||
jsonrpc_request = session_message.message
|
||||
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
|
||||
request = ClientRequest.model_validate(
|
||||
jsonrpc_request.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
)
|
||||
assert isinstance(request.root, InitializeRequest)
|
||||
received_client_info = request.root.params.clientInfo
|
||||
|
||||
result = ServerResult(
|
||||
InitializeResult(
|
||||
protocolVersion=LATEST_PROTOCOL_VERSION,
|
||||
capabilities=ServerCapabilities(),
|
||||
serverInfo=Implementation(name="mock-server", version="0.1.0"),
|
||||
)
|
||||
)
|
||||
|
||||
server_to_client.put(
|
||||
SessionMessage(
|
||||
message=JSONRPCMessage(
|
||||
JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=jsonrpc_request.root.id,
|
||||
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
# Receive initialized notification
|
||||
client_to_server.get(timeout=5.0)
|
||||
|
||||
# Start mock server thread
|
||||
server_thread = threading.Thread(target=mock_server, daemon=True)
|
||||
server_thread.start()
|
||||
|
||||
with ClientSession(
|
||||
server_to_client,
|
||||
client_to_server,
|
||||
client_info=custom_client_info,
|
||||
) as session:
|
||||
session.initialize()
|
||||
|
||||
# Wait for server thread to complete
|
||||
server_thread.join(timeout=10.0)
|
||||
|
||||
# Assert that custom client info was sent
|
||||
assert received_client_info == custom_client_info
|
||||
|
||||
|
||||
def test_client_session_default_client_info():
|
||||
# Create synchronous queues to replace async streams
|
||||
client_to_server: queue.Queue[SessionMessage] = queue.Queue()
|
||||
server_to_client: queue.Queue[SessionMessage] = queue.Queue()
|
||||
|
||||
received_client_info = None
|
||||
|
||||
def mock_server():
|
||||
nonlocal received_client_info
|
||||
|
||||
session_message = client_to_server.get(timeout=5.0)
|
||||
jsonrpc_request = session_message.message
|
||||
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
|
||||
request = ClientRequest.model_validate(
|
||||
jsonrpc_request.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
)
|
||||
assert isinstance(request.root, InitializeRequest)
|
||||
received_client_info = request.root.params.clientInfo
|
||||
|
||||
result = ServerResult(
|
||||
InitializeResult(
|
||||
protocolVersion=LATEST_PROTOCOL_VERSION,
|
||||
capabilities=ServerCapabilities(),
|
||||
serverInfo=Implementation(name="mock-server", version="0.1.0"),
|
||||
)
|
||||
)
|
||||
|
||||
server_to_client.put(
|
||||
SessionMessage(
|
||||
message=JSONRPCMessage(
|
||||
JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=jsonrpc_request.root.id,
|
||||
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
# Receive initialized notification
|
||||
client_to_server.get(timeout=5.0)
|
||||
|
||||
# Start mock server thread
|
||||
server_thread = threading.Thread(target=mock_server, daemon=True)
|
||||
server_thread.start()
|
||||
|
||||
with ClientSession(
|
||||
server_to_client,
|
||||
client_to_server,
|
||||
) as session:
|
||||
session.initialize()
|
||||
|
||||
# Wait for server thread to complete
|
||||
server_thread.join(timeout=10.0)
|
||||
|
||||
# Assert that default client info was used
|
||||
assert received_client_info == DEFAULT_CLIENT_INFO
|
||||
|
||||
|
||||
def test_client_session_version_negotiation_success():
|
||||
# Create synchronous queues to replace async streams
|
||||
client_to_server: queue.Queue[SessionMessage] = queue.Queue()
|
||||
server_to_client: queue.Queue[SessionMessage] = queue.Queue()
|
||||
|
||||
def mock_server():
|
||||
session_message = client_to_server.get(timeout=5.0)
|
||||
jsonrpc_request = session_message.message
|
||||
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
|
||||
request = ClientRequest.model_validate(
|
||||
jsonrpc_request.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
)
|
||||
assert isinstance(request.root, InitializeRequest)
|
||||
|
||||
# Send supported protocol version
|
||||
result = ServerResult(
|
||||
InitializeResult(
|
||||
protocolVersion=LATEST_PROTOCOL_VERSION,
|
||||
capabilities=ServerCapabilities(),
|
||||
serverInfo=Implementation(name="mock-server", version="0.1.0"),
|
||||
)
|
||||
)
|
||||
|
||||
server_to_client.put(
|
||||
SessionMessage(
|
||||
message=JSONRPCMessage(
|
||||
JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=jsonrpc_request.root.id,
|
||||
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
# Receive initialized notification
|
||||
client_to_server.get(timeout=5.0)
|
||||
|
||||
# Start mock server thread
|
||||
server_thread = threading.Thread(target=mock_server, daemon=True)
|
||||
server_thread.start()
|
||||
|
||||
with ClientSession(
|
||||
server_to_client,
|
||||
client_to_server,
|
||||
) as session:
|
||||
result = session.initialize()
|
||||
|
||||
# Wait for server thread to complete
|
||||
server_thread.join(timeout=10.0)
|
||||
|
||||
# Should successfully initialize
|
||||
assert isinstance(result, InitializeResult)
|
||||
assert result.protocolVersion == LATEST_PROTOCOL_VERSION
|
||||
|
||||
|
||||
def test_client_session_version_negotiation_failure():
|
||||
# Create synchronous queues to replace async streams
|
||||
client_to_server: queue.Queue[SessionMessage] = queue.Queue()
|
||||
server_to_client: queue.Queue[SessionMessage] = queue.Queue()
|
||||
|
||||
def mock_server():
|
||||
session_message = client_to_server.get(timeout=5.0)
|
||||
jsonrpc_request = session_message.message
|
||||
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
|
||||
request = ClientRequest.model_validate(
|
||||
jsonrpc_request.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
)
|
||||
assert isinstance(request.root, InitializeRequest)
|
||||
|
||||
# Send unsupported protocol version
|
||||
result = ServerResult(
|
||||
InitializeResult(
|
||||
protocolVersion="99.99.99", # Unsupported version
|
||||
capabilities=ServerCapabilities(),
|
||||
serverInfo=Implementation(name="mock-server", version="0.1.0"),
|
||||
)
|
||||
)
|
||||
|
||||
server_to_client.put(
|
||||
SessionMessage(
|
||||
message=JSONRPCMessage(
|
||||
JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=jsonrpc_request.root.id,
|
||||
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Start mock server thread
|
||||
server_thread = threading.Thread(target=mock_server, daemon=True)
|
||||
server_thread.start()
|
||||
|
||||
with ClientSession(
|
||||
server_to_client,
|
||||
client_to_server,
|
||||
) as session:
|
||||
import pytest
|
||||
|
||||
with pytest.raises(RuntimeError, match="Unsupported protocol version"):
|
||||
session.initialize()
|
||||
|
||||
# Wait for server thread to complete
|
||||
server_thread.join(timeout=10.0)
|
||||
|
||||
|
||||
def test_client_capabilities_default():
|
||||
# Create synchronous queues to replace async streams
|
||||
client_to_server: queue.Queue[SessionMessage] = queue.Queue()
|
||||
server_to_client: queue.Queue[SessionMessage] = queue.Queue()
|
||||
|
||||
received_capabilities = None
|
||||
|
||||
def mock_server():
|
||||
nonlocal received_capabilities
|
||||
|
||||
session_message = client_to_server.get(timeout=5.0)
|
||||
jsonrpc_request = session_message.message
|
||||
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
|
||||
request = ClientRequest.model_validate(
|
||||
jsonrpc_request.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
)
|
||||
assert isinstance(request.root, InitializeRequest)
|
||||
received_capabilities = request.root.params.capabilities
|
||||
|
||||
result = ServerResult(
|
||||
InitializeResult(
|
||||
protocolVersion=LATEST_PROTOCOL_VERSION,
|
||||
capabilities=ServerCapabilities(),
|
||||
serverInfo=Implementation(name="mock-server", version="0.1.0"),
|
||||
)
|
||||
)
|
||||
|
||||
server_to_client.put(
|
||||
SessionMessage(
|
||||
message=JSONRPCMessage(
|
||||
JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=jsonrpc_request.root.id,
|
||||
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
# Receive initialized notification
|
||||
client_to_server.get(timeout=5.0)
|
||||
|
||||
# Start mock server thread
|
||||
server_thread = threading.Thread(target=mock_server, daemon=True)
|
||||
server_thread.start()
|
||||
|
||||
with ClientSession(
|
||||
server_to_client,
|
||||
client_to_server,
|
||||
) as session:
|
||||
session.initialize()
|
||||
|
||||
# Wait for server thread to complete
|
||||
server_thread.join(timeout=10.0)
|
||||
|
||||
# Assert default capabilities
|
||||
assert received_capabilities is not None
|
||||
assert received_capabilities.sampling is not None
|
||||
assert received_capabilities.roots is not None
|
||||
assert received_capabilities.roots.listChanged is True
|
||||
|
||||
|
||||
def test_client_capabilities_with_custom_callbacks():
|
||||
# Create synchronous queues to replace async streams
|
||||
client_to_server: queue.Queue[SessionMessage] = queue.Queue()
|
||||
server_to_client: queue.Queue[SessionMessage] = queue.Queue()
|
||||
|
||||
def custom_sampling_callback(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.CreateMessageRequestParams,
|
||||
) -> types.CreateMessageResult | types.ErrorData:
|
||||
return types.CreateMessageResult(
|
||||
model="test-model",
|
||||
role="assistant",
|
||||
content=types.TextContent(type="text", text="Custom response"),
|
||||
)
|
||||
|
||||
def custom_list_roots_callback(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
) -> types.ListRootsResult | types.ErrorData:
|
||||
return types.ListRootsResult(roots=[])
|
||||
|
||||
def mock_server():
|
||||
session_message = client_to_server.get(timeout=5.0)
|
||||
jsonrpc_request = session_message.message
|
||||
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
|
||||
request = ClientRequest.model_validate(
|
||||
jsonrpc_request.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
)
|
||||
assert isinstance(request.root, InitializeRequest)
|
||||
|
||||
result = ServerResult(
|
||||
InitializeResult(
|
||||
protocolVersion=LATEST_PROTOCOL_VERSION,
|
||||
capabilities=ServerCapabilities(),
|
||||
serverInfo=Implementation(name="mock-server", version="0.1.0"),
|
||||
)
|
||||
)
|
||||
|
||||
server_to_client.put(
|
||||
SessionMessage(
|
||||
message=JSONRPCMessage(
|
||||
JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=jsonrpc_request.root.id,
|
||||
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
# Receive initialized notification
|
||||
client_to_server.get(timeout=5.0)
|
||||
|
||||
# Start mock server thread
|
||||
server_thread = threading.Thread(target=mock_server, daemon=True)
|
||||
server_thread.start()
|
||||
|
||||
with ClientSession(
|
||||
server_to_client,
|
||||
client_to_server,
|
||||
sampling_callback=custom_sampling_callback,
|
||||
list_roots_callback=custom_list_roots_callback,
|
||||
) as session:
|
||||
result = session.initialize()
|
||||
|
||||
# Wait for server thread to complete
|
||||
server_thread.join(timeout=10.0)
|
||||
|
||||
# Verify initialization succeeded
|
||||
assert isinstance(result, InitializeResult)
|
||||
assert result.protocolVersion == LATEST_PROTOCOL_VERSION
|
||||
@ -0,0 +1,349 @@
|
||||
import json
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from core.mcp import types
|
||||
from core.mcp.client.sse_client import sse_client
|
||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
||||
|
||||
SERVER_NAME = "test_server_for_SSE"
|
||||
|
||||
|
||||
def test_sse_message_id_coercion():
|
||||
"""Test that string message IDs that look like integers are parsed as integers.
|
||||
|
||||
See <https://github.com/modelcontextprotocol/python-sdk/pull/851> for more details.
|
||||
"""
|
||||
json_message = '{"jsonrpc": "2.0", "id": "123", "method": "ping", "params": null}'
|
||||
msg = types.JSONRPCMessage.model_validate_json(json_message)
|
||||
expected = types.JSONRPCMessage(root=types.JSONRPCRequest(method="ping", jsonrpc="2.0", id=123))
|
||||
|
||||
# Check if both are JSONRPCRequest instances
|
||||
assert isinstance(msg.root, types.JSONRPCRequest)
|
||||
assert isinstance(expected.root, types.JSONRPCRequest)
|
||||
|
||||
assert msg.root.id == expected.root.id
|
||||
assert msg.root.method == expected.root.method
|
||||
assert msg.root.jsonrpc == expected.root.jsonrpc
|
||||
|
||||
|
||||
class MockSSEClient:
|
||||
"""Mock SSE client for testing."""
|
||||
|
||||
def __init__(self, url: str, headers: dict[str, Any] | None = None):
|
||||
self.url = url
|
||||
self.headers = headers or {}
|
||||
self.connected = False
|
||||
self.read_queue: queue.Queue = queue.Queue()
|
||||
self.write_queue: queue.Queue = queue.Queue()
|
||||
|
||||
def connect(self):
|
||||
"""Simulate connection establishment."""
|
||||
self.connected = True
|
||||
|
||||
# Send endpoint event
|
||||
endpoint_data = "/messages/?session_id=test-session-123"
|
||||
self.read_queue.put(("endpoint", endpoint_data))
|
||||
|
||||
return self.read_queue, self.write_queue
|
||||
|
||||
def send_initialize_response(self):
|
||||
"""Send a mock initialize response."""
|
||||
response = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"result": {
|
||||
"protocolVersion": types.LATEST_PROTOCOL_VERSION,
|
||||
"capabilities": {
|
||||
"logging": None,
|
||||
"resources": None,
|
||||
"tools": None,
|
||||
"experimental": None,
|
||||
"prompts": None,
|
||||
},
|
||||
"serverInfo": {"name": SERVER_NAME, "version": "0.1.0"},
|
||||
"instructions": "Test server instructions.",
|
||||
},
|
||||
}
|
||||
self.read_queue.put(("message", json.dumps(response)))
|
||||
|
||||
|
||||
def test_sse_client_message_id_handling():
|
||||
"""Test SSE client properly handles message ID coercion."""
|
||||
mock_client = MockSSEClient("http://test.example/sse")
|
||||
read_queue, write_queue = mock_client.connect()
|
||||
|
||||
# Send a message with string ID that should be coerced to int
|
||||
message_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": "456", # String ID
|
||||
"result": {"test": "data"},
|
||||
}
|
||||
read_queue.put(("message", json.dumps(message_data)))
|
||||
read_queue.get(timeout=1.0)
|
||||
# Get the message from queue
|
||||
event_type, data = read_queue.get(timeout=1.0)
|
||||
assert event_type == "message"
|
||||
|
||||
# Parse the message
|
||||
parsed_message = types.JSONRPCMessage.model_validate_json(data)
|
||||
# Check that it's a JSONRPCResponse and verify the ID
|
||||
assert isinstance(parsed_message.root, types.JSONRPCResponse)
|
||||
assert parsed_message.root.id == 456 # Should be converted to int
|
||||
|
||||
|
||||
def test_sse_client_connection_validation():
|
||||
"""Test SSE client validates endpoint URLs properly."""
|
||||
test_url = "http://test.example/sse"
|
||||
|
||||
with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
|
||||
with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sse_connect:
|
||||
# Mock the HTTP client
|
||||
mock_client = Mock()
|
||||
mock_client_factory.return_value.__enter__.return_value = mock_client
|
||||
|
||||
# Mock the SSE connection
|
||||
mock_event_source = Mock()
|
||||
mock_event_source.response.raise_for_status.return_value = None
|
||||
mock_sse_connect.return_value.__enter__.return_value = mock_event_source
|
||||
|
||||
# Mock SSE events
|
||||
class MockSSEEvent:
|
||||
def __init__(self, event_type: str, data: str):
|
||||
self.event = event_type
|
||||
self.data = data
|
||||
|
||||
# Simulate endpoint event
|
||||
endpoint_event = MockSSEEvent("endpoint", "/messages/?session_id=test-123")
|
||||
mock_event_source.iter_sse.return_value = [endpoint_event]
|
||||
|
||||
# Test connection
|
||||
try:
|
||||
with sse_client(test_url) as (read_queue, write_queue):
|
||||
assert read_queue is not None
|
||||
assert write_queue is not None
|
||||
except Exception as e:
|
||||
# Connection might fail due to mocking, but we're testing the validation logic
|
||||
pass
|
||||
|
||||
|
||||
def test_sse_client_error_handling():
|
||||
"""Test SSE client properly handles various error conditions."""
|
||||
test_url = "http://test.example/sse"
|
||||
|
||||
# Test 401 error handling
|
||||
with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
|
||||
with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sse_connect:
|
||||
# Mock 401 HTTP error
|
||||
mock_error = httpx.HTTPStatusError("Unauthorized", request=Mock(), response=Mock(status_code=401))
|
||||
mock_sse_connect.side_effect = mock_error
|
||||
|
||||
with pytest.raises(MCPAuthError):
|
||||
with sse_client(test_url):
|
||||
pass
|
||||
|
||||
# Test other HTTP errors
|
||||
with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
|
||||
with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sse_connect:
|
||||
# Mock other HTTP error
|
||||
mock_error = httpx.HTTPStatusError("Server Error", request=Mock(), response=Mock(status_code=500))
|
||||
mock_sse_connect.side_effect = mock_error
|
||||
|
||||
with pytest.raises(MCPConnectionError):
|
||||
with sse_client(test_url):
|
||||
pass
|
||||
|
||||
|
||||
def test_sse_client_timeout_configuration():
|
||||
"""Test SSE client timeout configuration."""
|
||||
test_url = "http://test.example/sse"
|
||||
custom_timeout = 10.0
|
||||
custom_sse_timeout = 300.0
|
||||
custom_headers = {"Authorization": "Bearer test-token"}
|
||||
|
||||
with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
|
||||
with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sse_connect:
|
||||
# Mock successful connection
|
||||
mock_client = Mock()
|
||||
mock_client_factory.return_value.__enter__.return_value = mock_client
|
||||
|
||||
mock_event_source = Mock()
|
||||
mock_event_source.response.raise_for_status.return_value = None
|
||||
mock_event_source.iter_sse.return_value = []
|
||||
mock_sse_connect.return_value.__enter__.return_value = mock_event_source
|
||||
|
||||
try:
|
||||
with sse_client(
|
||||
test_url, headers=custom_headers, timeout=custom_timeout, sse_read_timeout=custom_sse_timeout
|
||||
) as (read_queue, write_queue):
|
||||
# Verify the configuration was passed correctly
|
||||
mock_client_factory.assert_called_with(headers=custom_headers)
|
||||
|
||||
# Check that timeout was configured
|
||||
call_args = mock_sse_connect.call_args
|
||||
assert call_args is not None
|
||||
timeout_arg = call_args[1]["timeout"]
|
||||
assert timeout_arg.read == custom_sse_timeout
|
||||
except Exception:
|
||||
# Connection might fail due to mocking, but we tested the configuration
|
||||
pass
|
||||
|
||||
|
||||
def test_sse_transport_endpoint_validation():
|
||||
"""Test SSE transport validates endpoint URLs correctly."""
|
||||
from core.mcp.client.sse_client import SSETransport
|
||||
|
||||
transport = SSETransport("http://example.com/sse")
|
||||
|
||||
# Valid endpoint (same origin)
|
||||
valid_endpoint = "http://example.com/messages/session123"
|
||||
assert transport._validate_endpoint_url(valid_endpoint) == True
|
||||
|
||||
# Invalid endpoint (different origin)
|
||||
invalid_endpoint = "http://malicious.com/messages/session123"
|
||||
assert transport._validate_endpoint_url(invalid_endpoint) == False
|
||||
|
||||
# Invalid endpoint (different scheme)
|
||||
invalid_scheme = "https://example.com/messages/session123"
|
||||
assert transport._validate_endpoint_url(invalid_scheme) == False
|
||||
|
||||
|
||||
def test_sse_transport_message_parsing():
|
||||
"""Test SSE transport properly parses different message types."""
|
||||
from core.mcp.client.sse_client import SSETransport
|
||||
|
||||
transport = SSETransport("http://example.com/sse")
|
||||
read_queue: queue.Queue = queue.Queue()
|
||||
|
||||
# Test valid JSON-RPC message
|
||||
valid_message = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}'
|
||||
transport._handle_message_event(valid_message, read_queue)
|
||||
|
||||
# Should have a SessionMessage in the queue
|
||||
message = read_queue.get(timeout=1.0)
|
||||
assert message is not None
|
||||
assert hasattr(message, "message")
|
||||
|
||||
# Test invalid JSON
|
||||
invalid_json = '{"invalid": json}'
|
||||
transport._handle_message_event(invalid_json, read_queue)
|
||||
|
||||
# Should have an exception in the queue
|
||||
error = read_queue.get(timeout=1.0)
|
||||
assert isinstance(error, Exception)
|
||||
|
||||
|
||||
def test_sse_client_queue_cleanup():
|
||||
"""Test that SSE client properly cleans up queues on exit."""
|
||||
test_url = "http://test.example/sse"
|
||||
|
||||
read_queue = None
|
||||
write_queue = None
|
||||
|
||||
with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
|
||||
with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sse_connect:
|
||||
# Mock connection that raises an exception
|
||||
mock_sse_connect.side_effect = Exception("Connection failed")
|
||||
|
||||
try:
|
||||
with sse_client(test_url) as (rq, wq):
|
||||
read_queue = rq
|
||||
write_queue = wq
|
||||
except Exception:
|
||||
pass # Expected to fail
|
||||
|
||||
# Queues should be cleaned up even on exception
|
||||
# Note: In real implementation, cleanup should put None to signal shutdown
|
||||
|
||||
|
||||
def test_sse_client_url_processing():
|
||||
"""Test SSE client URL processing functions."""
|
||||
from core.mcp.client.sse_client import remove_request_params
|
||||
|
||||
# Test URL with parameters
|
||||
url_with_params = "http://example.com/sse?param1=value1¶m2=value2"
|
||||
cleaned_url = remove_request_params(url_with_params)
|
||||
assert cleaned_url == "http://example.com/sse"
|
||||
|
||||
# Test URL without parameters
|
||||
url_without_params = "http://example.com/sse"
|
||||
cleaned_url = remove_request_params(url_without_params)
|
||||
assert cleaned_url == "http://example.com/sse"
|
||||
|
||||
# Test URL with path and parameters
|
||||
complex_url = "http://example.com/path/to/sse?session=123&token=abc"
|
||||
cleaned_url = remove_request_params(complex_url)
|
||||
assert cleaned_url == "http://example.com/path/to/sse"
|
||||
|
||||
|
||||
def test_sse_client_headers_propagation():
|
||||
"""Test that custom headers are properly propagated in SSE client."""
|
||||
test_url = "http://test.example/sse"
|
||||
custom_headers = {
|
||||
"Authorization": "Bearer test-token",
|
||||
"X-Custom-Header": "test-value",
|
||||
"User-Agent": "test-client/1.0",
|
||||
}
|
||||
|
||||
with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
|
||||
with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sse_connect:
|
||||
# Mock the client factory to capture headers
|
||||
mock_client = Mock()
|
||||
mock_client_factory.return_value.__enter__.return_value = mock_client
|
||||
|
||||
# Mock the SSE connection
|
||||
mock_event_source = Mock()
|
||||
mock_event_source.response.raise_for_status.return_value = None
|
||||
mock_event_source.iter_sse.return_value = []
|
||||
mock_sse_connect.return_value.__enter__.return_value = mock_event_source
|
||||
|
||||
try:
|
||||
with sse_client(test_url, headers=custom_headers):
|
||||
pass
|
||||
except Exception:
|
||||
pass # Expected due to mocking
|
||||
|
||||
# Verify headers were passed to client factory
|
||||
mock_client_factory.assert_called_with(headers=custom_headers)
|
||||
|
||||
|
||||
def test_sse_client_concurrent_access():
|
||||
"""Test SSE client behavior with concurrent queue access."""
|
||||
test_read_queue: queue.Queue = queue.Queue()
|
||||
|
||||
# Simulate concurrent producers and consumers
|
||||
def producer():
|
||||
for i in range(10):
|
||||
test_read_queue.put(f"message_{i}")
|
||||
time.sleep(0.01) # Small delay to simulate real conditions
|
||||
|
||||
def consumer():
|
||||
received = []
|
||||
for _ in range(10):
|
||||
try:
|
||||
msg = test_read_queue.get(timeout=2.0)
|
||||
received.append(msg)
|
||||
except queue.Empty:
|
||||
break
|
||||
return received
|
||||
|
||||
# Start producer in separate thread
|
||||
producer_thread = threading.Thread(target=producer, daemon=True)
|
||||
producer_thread.start()
|
||||
|
||||
# Consume messages
|
||||
received_messages = consumer()
|
||||
|
||||
# Wait for producer to finish
|
||||
producer_thread.join(timeout=5.0)
|
||||
|
||||
# Verify all messages were received
|
||||
assert len(received_messages) == 10
|
||||
for i in range(10):
|
||||
assert f"message_{i}" in received_messages
|
||||
@ -0,0 +1,450 @@
|
||||
"""
|
||||
Tests for the StreamableHTTP client transport.
|
||||
|
||||
Contains tests for only the client side of the StreamableHTTP transport.
|
||||
"""
|
||||
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from core.mcp import types
|
||||
from core.mcp.client.streamable_client import streamablehttp_client
|
||||
|
||||
# Test constants
|
||||
SERVER_NAME = "test_streamable_http_server"
|
||||
TEST_SESSION_ID = "test-session-id-12345"
|
||||
INIT_REQUEST = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"clientInfo": {"name": "test-client", "version": "1.0"},
|
||||
"protocolVersion": "2025-03-26",
|
||||
"capabilities": {},
|
||||
},
|
||||
"id": "init-1",
|
||||
}
|
||||
|
||||
|
||||
class MockStreamableHTTPClient:
|
||||
"""Mock StreamableHTTP client for testing."""
|
||||
|
||||
def __init__(self, url: str, headers: dict[str, Any] | None = None):
|
||||
self.url = url
|
||||
self.headers = headers or {}
|
||||
self.connected = False
|
||||
self.read_queue: queue.Queue = queue.Queue()
|
||||
self.write_queue: queue.Queue = queue.Queue()
|
||||
self.session_id = TEST_SESSION_ID
|
||||
|
||||
def connect(self):
|
||||
"""Simulate connection establishment."""
|
||||
self.connected = True
|
||||
return self.read_queue, self.write_queue, lambda: self.session_id
|
||||
|
||||
def send_initialize_response(self):
|
||||
"""Send a mock initialize response."""
|
||||
session_message = types.SessionMessage(
|
||||
message=types.JSONRPCMessage(
|
||||
root=types.JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id="init-1",
|
||||
result={
|
||||
"protocolVersion": types.LATEST_PROTOCOL_VERSION,
|
||||
"capabilities": {
|
||||
"logging": None,
|
||||
"resources": None,
|
||||
"tools": None,
|
||||
"experimental": None,
|
||||
"prompts": None,
|
||||
},
|
||||
"serverInfo": {"name": SERVER_NAME, "version": "0.1.0"},
|
||||
"instructions": "Test server instructions.",
|
||||
},
|
||||
)
|
||||
)
|
||||
)
|
||||
self.read_queue.put(session_message)
|
||||
|
||||
def send_tools_response(self):
|
||||
"""Send a mock tools list response."""
|
||||
session_message = types.SessionMessage(
|
||||
message=types.JSONRPCMessage(
|
||||
root=types.JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id="tools-1",
|
||||
result={
|
||||
"tools": [
|
||||
{
|
||||
"name": "test_tool",
|
||||
"description": "A test tool",
|
||||
"inputSchema": {"type": "object", "properties": {}},
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
)
|
||||
)
|
||||
self.read_queue.put(session_message)
|
||||
|
||||
|
||||
def test_streamablehttp_client_message_id_handling():
|
||||
"""Test StreamableHTTP client properly handles message ID coercion."""
|
||||
mock_client = MockStreamableHTTPClient("http://test.example/mcp")
|
||||
read_queue, write_queue, get_session_id = mock_client.connect()
|
||||
|
||||
# Send a message with string ID that should be coerced to int
|
||||
response_message = types.SessionMessage(
|
||||
message=types.JSONRPCMessage(root=types.JSONRPCResponse(jsonrpc="2.0", id="789", result={"test": "data"}))
|
||||
)
|
||||
read_queue.put(response_message)
|
||||
|
||||
# Get the message from queue
|
||||
message = read_queue.get(timeout=1.0)
|
||||
assert message is not None
|
||||
assert isinstance(message, types.SessionMessage)
|
||||
|
||||
# Check that the ID was properly handled
|
||||
assert isinstance(message.message.root, types.JSONRPCResponse)
|
||||
assert message.message.root.id == 789 # ID should be coerced to int due to union_mode="left_to_right"
|
||||
|
||||
|
||||
def test_streamablehttp_client_connection_validation():
|
||||
"""Test StreamableHTTP client validates connections properly."""
|
||||
test_url = "http://test.example/mcp"
|
||||
|
||||
with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
|
||||
# Mock the HTTP client
|
||||
mock_client = Mock()
|
||||
mock_client_factory.return_value.__enter__.return_value = mock_client
|
||||
|
||||
# Mock successful response
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.headers = {"content-type": "application/json"}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
# Test connection
|
||||
try:
|
||||
with streamablehttp_client(test_url) as (read_queue, write_queue, get_session_id):
|
||||
assert read_queue is not None
|
||||
assert write_queue is not None
|
||||
assert get_session_id is not None
|
||||
except Exception:
|
||||
# Connection might fail due to mocking, but we're testing the validation logic
|
||||
pass
|
||||
|
||||
|
||||
def test_streamablehttp_client_timeout_configuration():
|
||||
"""Test StreamableHTTP client timeout configuration."""
|
||||
test_url = "http://test.example/mcp"
|
||||
custom_headers = {"Authorization": "Bearer test-token"}
|
||||
|
||||
with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
|
||||
# Mock successful connection
|
||||
mock_client = Mock()
|
||||
mock_client_factory.return_value.__enter__.return_value = mock_client
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.headers = {"content-type": "application/json"}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
try:
|
||||
with streamablehttp_client(test_url, headers=custom_headers) as (read_queue, write_queue, get_session_id):
|
||||
# Verify the configuration was passed correctly
|
||||
mock_client_factory.assert_called_with(headers=custom_headers)
|
||||
except Exception:
|
||||
# Connection might fail due to mocking, but we tested the configuration
|
||||
pass
|
||||
|
||||
|
||||
def test_streamablehttp_client_session_id_handling():
|
||||
"""Test StreamableHTTP client properly handles session IDs."""
|
||||
mock_client = MockStreamableHTTPClient("http://test.example/mcp")
|
||||
read_queue, write_queue, get_session_id = mock_client.connect()
|
||||
|
||||
# Test that session ID is available
|
||||
session_id = get_session_id()
|
||||
assert session_id == TEST_SESSION_ID
|
||||
|
||||
# Test that we can use the session ID in subsequent requests
|
||||
assert session_id is not None
|
||||
assert len(session_id) > 0
|
||||
|
||||
|
||||
def test_streamablehttp_client_message_parsing():
|
||||
"""Test StreamableHTTP client properly parses different message types."""
|
||||
mock_client = MockStreamableHTTPClient("http://test.example/mcp")
|
||||
read_queue, write_queue, get_session_id = mock_client.connect()
|
||||
|
||||
# Test valid initialization response
|
||||
mock_client.send_initialize_response()
|
||||
|
||||
# Should have a SessionMessage in the queue
|
||||
message = read_queue.get(timeout=1.0)
|
||||
assert message is not None
|
||||
assert isinstance(message, types.SessionMessage)
|
||||
assert isinstance(message.message.root, types.JSONRPCResponse)
|
||||
|
||||
# Test tools response
|
||||
mock_client.send_tools_response()
|
||||
|
||||
tools_message = read_queue.get(timeout=1.0)
|
||||
assert tools_message is not None
|
||||
assert isinstance(tools_message, types.SessionMessage)
|
||||
|
||||
|
||||
def test_streamablehttp_client_queue_cleanup():
|
||||
"""Test that StreamableHTTP client properly cleans up queues on exit."""
|
||||
test_url = "http://test.example/mcp"
|
||||
|
||||
read_queue = None
|
||||
write_queue = None
|
||||
|
||||
with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
|
||||
# Mock connection that raises an exception
|
||||
mock_client_factory.side_effect = Exception("Connection failed")
|
||||
|
||||
try:
|
||||
with streamablehttp_client(test_url) as (rq, wq, get_session_id):
|
||||
read_queue = rq
|
||||
write_queue = wq
|
||||
except Exception:
|
||||
pass # Expected to fail
|
||||
|
||||
# Queues should be cleaned up even on exception
|
||||
# Note: In real implementation, cleanup should put None to signal shutdown
|
||||
|
||||
|
||||
def test_streamablehttp_client_headers_propagation():
|
||||
"""Test that custom headers are properly propagated in StreamableHTTP client."""
|
||||
test_url = "http://test.example/mcp"
|
||||
custom_headers = {
|
||||
"Authorization": "Bearer test-token",
|
||||
"X-Custom-Header": "test-value",
|
||||
"User-Agent": "test-client/1.0",
|
||||
}
|
||||
|
||||
with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
|
||||
# Mock the client factory to capture headers
|
||||
mock_client = Mock()
|
||||
mock_client_factory.return_value.__enter__.return_value = mock_client
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.headers = {"content-type": "application/json"}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
try:
|
||||
with streamablehttp_client(test_url, headers=custom_headers):
|
||||
pass
|
||||
except Exception:
|
||||
pass # Expected due to mocking
|
||||
|
||||
# Verify headers were passed to client factory
|
||||
# Check that the call was made with headers that include our custom headers
|
||||
mock_client_factory.assert_called_once()
|
||||
call_args = mock_client_factory.call_args
|
||||
assert "headers" in call_args.kwargs
|
||||
passed_headers = call_args.kwargs["headers"]
|
||||
|
||||
# Verify all custom headers are present
|
||||
for key, value in custom_headers.items():
|
||||
assert key in passed_headers
|
||||
assert passed_headers[key] == value
|
||||
|
||||
|
||||
def test_streamablehttp_client_concurrent_access():
|
||||
"""Test StreamableHTTP client behavior with concurrent queue access."""
|
||||
test_read_queue: queue.Queue = queue.Queue()
|
||||
test_write_queue: queue.Queue = queue.Queue()
|
||||
|
||||
# Simulate concurrent producers and consumers
|
||||
def producer():
|
||||
for i in range(10):
|
||||
test_read_queue.put(f"message_{i}")
|
||||
time.sleep(0.01) # Small delay to simulate real conditions
|
||||
|
||||
def consumer():
|
||||
received = []
|
||||
for _ in range(10):
|
||||
try:
|
||||
msg = test_read_queue.get(timeout=2.0)
|
||||
received.append(msg)
|
||||
except queue.Empty:
|
||||
break
|
||||
return received
|
||||
|
||||
# Start producer in separate thread
|
||||
producer_thread = threading.Thread(target=producer, daemon=True)
|
||||
producer_thread.start()
|
||||
|
||||
# Consume messages
|
||||
received_messages = consumer()
|
||||
|
||||
# Wait for producer to finish
|
||||
producer_thread.join(timeout=5.0)
|
||||
|
||||
# Verify all messages were received
|
||||
assert len(received_messages) == 10
|
||||
for i in range(10):
|
||||
assert f"message_{i}" in received_messages
|
||||
|
||||
|
||||
def test_streamablehttp_client_json_vs_sse_mode():
|
||||
"""Test StreamableHTTP client handling of JSON vs SSE response modes."""
|
||||
test_url = "http://test.example/mcp"
|
||||
|
||||
with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
|
||||
mock_client = Mock()
|
||||
mock_client_factory.return_value.__enter__.return_value = mock_client
|
||||
|
||||
# Mock JSON response
|
||||
mock_json_response = Mock()
|
||||
mock_json_response.status_code = 200
|
||||
mock_json_response.headers = {"content-type": "application/json"}
|
||||
mock_json_response.json.return_value = {"result": "json_mode"}
|
||||
mock_json_response.raise_for_status.return_value = None
|
||||
|
||||
# Mock SSE response
|
||||
mock_sse_response = Mock()
|
||||
mock_sse_response.status_code = 200
|
||||
mock_sse_response.headers = {"content-type": "text/event-stream"}
|
||||
mock_sse_response.raise_for_status.return_value = None
|
||||
|
||||
# Test JSON mode
|
||||
mock_client.post.return_value = mock_json_response
|
||||
|
||||
try:
|
||||
with streamablehttp_client(test_url) as (read_queue, write_queue, get_session_id):
|
||||
# Should handle JSON responses
|
||||
assert read_queue is not None
|
||||
assert write_queue is not None
|
||||
except Exception:
|
||||
pass # Expected due to mocking
|
||||
|
||||
# Test SSE mode
|
||||
mock_client.post.return_value = mock_sse_response
|
||||
|
||||
try:
|
||||
with streamablehttp_client(test_url) as (read_queue, write_queue, get_session_id):
|
||||
# Should handle SSE responses
|
||||
assert read_queue is not None
|
||||
assert write_queue is not None
|
||||
except Exception:
|
||||
pass # Expected due to mocking
|
||||
|
||||
|
||||
def test_streamablehttp_client_terminate_on_close():
|
||||
"""Test StreamableHTTP client terminate_on_close parameter."""
|
||||
test_url = "http://test.example/mcp"
|
||||
|
||||
with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
|
||||
mock_client = Mock()
|
||||
mock_client_factory.return_value.__enter__.return_value = mock_client
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.headers = {"content-type": "application/json"}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_client.post.return_value = mock_response
|
||||
mock_client.delete.return_value = mock_response
|
||||
|
||||
# Test with terminate_on_close=True (default)
|
||||
try:
|
||||
with streamablehttp_client(test_url, terminate_on_close=True) as (read_queue, write_queue, get_session_id):
|
||||
pass
|
||||
except Exception:
|
||||
pass # Expected due to mocking
|
||||
|
||||
# Test with terminate_on_close=False
|
||||
try:
|
||||
with streamablehttp_client(test_url, terminate_on_close=False) as (read_queue, write_queue, get_session_id):
|
||||
pass
|
||||
except Exception:
|
||||
pass # Expected due to mocking
|
||||
|
||||
|
||||
def test_streamablehttp_client_protocol_version_handling():
|
||||
"""Test StreamableHTTP client protocol version handling."""
|
||||
mock_client = MockStreamableHTTPClient("http://test.example/mcp")
|
||||
read_queue, write_queue, get_session_id = mock_client.connect()
|
||||
|
||||
# Send initialize response with specific protocol version
|
||||
|
||||
session_message = types.SessionMessage(
|
||||
message=types.JSONRPCMessage(
|
||||
root=types.JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id="init-1",
|
||||
result={
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"serverInfo": {"name": SERVER_NAME, "version": "0.1.0"},
|
||||
},
|
||||
)
|
||||
)
|
||||
)
|
||||
read_queue.put(session_message)
|
||||
|
||||
# Get the message and verify protocol version
|
||||
message = read_queue.get(timeout=1.0)
|
||||
assert message is not None
|
||||
assert isinstance(message.message.root, types.JSONRPCResponse)
|
||||
result = message.message.root.result
|
||||
assert result["protocolVersion"] == "2024-11-05"
|
||||
|
||||
|
||||
def test_streamablehttp_client_error_response_handling():
|
||||
"""Test StreamableHTTP client handling of error responses."""
|
||||
mock_client = MockStreamableHTTPClient("http://test.example/mcp")
|
||||
read_queue, write_queue, get_session_id = mock_client.connect()
|
||||
|
||||
# Send an error response
|
||||
session_message = types.SessionMessage(
|
||||
message=types.JSONRPCMessage(
|
||||
root=types.JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id="test-1",
|
||||
error=types.ErrorData(code=-32601, message="Method not found", data=None),
|
||||
)
|
||||
)
|
||||
)
|
||||
read_queue.put(session_message)
|
||||
|
||||
# Get the error message
|
||||
message = read_queue.get(timeout=1.0)
|
||||
assert message is not None
|
||||
assert isinstance(message.message.root, types.JSONRPCError)
|
||||
assert message.message.root.error.code == -32601
|
||||
assert message.message.root.error.message == "Method not found"
|
||||
|
||||
|
||||
def test_streamablehttp_client_resumption_token_handling():
|
||||
"""Test StreamableHTTP client resumption token functionality."""
|
||||
test_url = "http://test.example/mcp"
|
||||
test_resumption_token = "resume-token-123"
|
||||
|
||||
with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
|
||||
mock_client = Mock()
|
||||
mock_client_factory.return_value.__enter__.return_value = mock_client
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.headers = {"content-type": "application/json", "last-event-id": test_resumption_token}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
try:
|
||||
with streamablehttp_client(test_url) as (read_queue, write_queue, get_session_id):
|
||||
# Test that resumption token can be captured from headers
|
||||
assert read_queue is not None
|
||||
assert write_queue is not None
|
||||
except Exception:
|
||||
pass # Expected due to mocking
|
||||
@ -0,0 +1,53 @@
|
||||
from redis import RedisError
|
||||
|
||||
from extensions.ext_redis import redis_fallback
|
||||
|
||||
|
||||
def test_redis_fallback_success():
|
||||
@redis_fallback(default_return=None)
|
||||
def test_func():
|
||||
return "success"
|
||||
|
||||
assert test_func() == "success"
|
||||
|
||||
|
||||
def test_redis_fallback_error():
|
||||
@redis_fallback(default_return="fallback")
|
||||
def test_func():
|
||||
raise RedisError("Redis error")
|
||||
|
||||
assert test_func() == "fallback"
|
||||
|
||||
|
||||
def test_redis_fallback_none_default():
|
||||
@redis_fallback()
|
||||
def test_func():
|
||||
raise RedisError("Redis error")
|
||||
|
||||
assert test_func() is None
|
||||
|
||||
|
||||
def test_redis_fallback_with_args():
|
||||
@redis_fallback(default_return=0)
|
||||
def test_func(x, y):
|
||||
raise RedisError("Redis error")
|
||||
|
||||
assert test_func(1, 2) == 0
|
||||
|
||||
|
||||
def test_redis_fallback_with_kwargs():
|
||||
@redis_fallback(default_return={})
|
||||
def test_func(x=None, y=None):
|
||||
raise RedisError("Redis error")
|
||||
|
||||
assert test_func(x=1, y=2) == {}
|
||||
|
||||
|
||||
def test_redis_fallback_preserves_function_metadata():
|
||||
@redis_fallback(default_return=None)
|
||||
def test_func():
|
||||
"""Test function docstring"""
|
||||
pass
|
||||
|
||||
assert test_func.__name__ == "test_func"
|
||||
assert test_func.__doc__ == "Test function docstring"
|
||||
@ -0,0 +1,7 @@
|
||||
<svg width="32" height="32" viewBox="0 0 32 32" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M28.0049 16C28.0049 20.4183 24.4231 24 20.0049 24C15.5866 24 12.0049 20.4183 12.0049 16C12.0049 11.5817 15.5866 8 20.0049 8C24.4231 8 28.0049 11.5817 28.0049 16Z" stroke="#676F83" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M4.00488 16H6.67155" stroke="#676F83" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M4.00488 9.33334H8.00488" stroke="#676F83" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M4.00488 22.6667H8.00488" stroke="#676F83" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M26 22L29.3333 25.3333" stroke="#676F83" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 823 B |
@ -0,0 +1,4 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M9.20626 1.68651C9.61828 1.68651 10.014 1.8473 10.3093 2.13466C10.4536 2.27516 10.5684 2.44313 10.6468 2.62868C10.7252 2.81422 10.7657 3.01358 10.7659 3.21501C10.7661 3.41645 10.7259 3.61588 10.6478 3.80156C10.5697 3.98723 10.4552 4.1554 10.3111 4.29614L5.86656 8.65516C5.81837 8.70203 5.78006 8.75808 5.7539 8.82001C5.72775 8.88194 5.71427 8.94848 5.71427 9.01571C5.71427 9.08294 5.72775 9.14948 5.7539 9.21141C5.78006 9.27334 5.81837 9.32939 5.86656 9.37626C5.96503 9.47212 6.09703 9.52576 6.23445 9.52576C6.37187 9.52576 6.50387 9.47212 6.60234 9.37626L6.66222 9.31698L6.66345 9.31576L11.0463 5.01725C11.3417 4.73067 11.7372 4.57056 12.1488 4.5709C12.5604 4.57124 12.9556 4.73202 13.2506 5.01908L13.2811 5.04902C13.4256 5.18967 13.5405 5.35786 13.6189 5.54363C13.6973 5.72941 13.7377 5.92903 13.7377 6.13068C13.7377 6.33233 13.6973 6.53195 13.6189 6.71773C13.5405 6.9035 13.4256 7.07169 13.2811 7.21234L7.96082 12.43C7.84828 12.5393 7.75882 12.6701 7.69773 12.8147C7.63664 12.9592 7.60517 13.1145 7.60517 13.2714C7.60517 13.4284 7.63664 13.5837 7.69773 13.7282C7.75882 13.8728 7.84828 14.0036 7.96082 14.1129L9.05348 15.1842C9.15192 15.2799 9.28378 15.3334 9.42106 15.3334C9.55834 15.3334 9.6902 15.2799 9.78864 15.1842C9.83683 15.1373 9.87514 15.0813 9.9013 15.0194C9.92746 14.9574 9.94094 14.8909 9.94094 14.8237C9.94094 14.7564 9.92746 14.6899 9.9013 14.628C9.87514 14.566 9.83683 14.51 9.78864 14.4631L8.69598 13.3912C8.67992 13.3756 8.66716 13.357 8.65844 13.3363C8.64973 13.3157 8.64523 13.2935 8.64523 13.2711C8.64523 13.2488 8.64973 13.2266 8.65844 13.206C8.66716 13.1853 8.67992 13.1667 8.69598 13.1511L14.0163 7.93405C14.2572 7.69971 14.4488 7.41943 14.5796 7.10979C14.7104 6.80014 14.7778 6.46742 14.7778 6.13129C14.7778 5.79516 14.7104 5.46244 14.5796 5.1528C14.4488 4.84315 14.2572 4.56288 14.0163 4.32853L13.9857 4.29797C13.6978 4.01697 13.3493 3.80582 12.9669 3.6808C12.5845 3.55578 12.1785 3.52022 11.7802 3.57687C11.8371 3.1838 11.8001 2.78285 11.6722 2.40684C11.5443 2.03083 11.3292 1.69045 11.0445 1.41356C10.5524 0.93469 9.89288 0.666748 9.20626 0.666748C8.51964 0.666748 7.86012 0.93469 7.36805 1.41356L1.48555 7.18239C1.43735 7.22926 1.39905 7.28532 1.37289 7.34725C1.34673 7.40917 1.33325 7.47572 1.33325 7.54294C1.33325 7.61017 1.34673 7.67672 1.37289 7.73864C1.39905 7.80057 1.43735 7.85663 1.48555 7.9035C1.58399 7.99918 1.71585 8.0527 1.85313 8.0527C1.9904 8.0527 2.12227 7.99918 2.22071 7.9035L8.10321 2.13466C8.39848 1.8473 8.79424 1.68651 9.20626 1.68651Z" fill="white"/>
|
||||
<path d="M9.68688 3.41201C9.66072 3.47394 9.62241 3.52999 9.57422 3.57686L5.22314 7.8436C5.07864 7.98425 4.96378 8.15243 4.88535 8.33821C4.80693 8.52399 4.76652 8.7236 4.76652 8.92526C4.76652 9.12691 4.80693 9.32652 4.88535 9.5123C4.96378 9.69808 5.07864 9.86626 5.22314 10.0069C5.51841 10.2943 5.91417 10.4551 6.32619 10.4551C6.73821 10.4551 7.13397 10.2943 7.42924 10.0069L11.7797 5.74017C11.8782 5.64431 12.0102 5.59067 12.1476 5.59067C12.285 5.59067 12.417 5.64431 12.5155 5.74017C12.5637 5.78704 12.602 5.8431 12.6281 5.90503C12.6543 5.96696 12.6678 6.0335 12.6678 6.10073C12.6678 6.16795 12.6543 6.2345 12.6281 6.29643C12.602 6.35835 12.5637 6.41441 12.5155 6.46128L8.1644 10.728C7.67225 11.2067 7.01276 11.4746 6.32619 11.4746C5.63962 11.4746 4.98013 11.2067 4.48798 10.728C4.24701 10.4937 4.05547 10.2134 3.92468 9.90375C3.79389 9.59411 3.7265 9.26139 3.7265 8.92526C3.7265 8.58912 3.79389 8.2564 3.92468 7.94676C4.05547 7.63712 4.24701 7.35684 4.48798 7.1225L8.83845 2.85576C8.93691 2.75989 9.06891 2.70625 9.20633 2.70625C9.34375 2.70625 9.47575 2.75989 9.57422 2.85576C9.62241 2.90263 9.66072 2.95868 9.68688 3.02061C9.71304 3.08254 9.72651 3.14908 9.72651 3.21631C9.72651 3.28353 9.71304 3.35008 9.68688 3.41201Z" fill="white"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 3.8 KiB |
@ -0,0 +1,36 @@
|
||||
<svg width="204" height="36" viewBox="0 0 204 36" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<g opacity="0.1">
|
||||
<path d="M3.33333 18.3333C3.33333 18.5423 3.64067 18.9056 4.35365 19.2621C5.27603 19.7233 6.58451 20 8 20C9.41547 20 10.7239 19.7233 11.6463 19.2621C12.3593 18.9056 12.6667 18.5423 12.6667 18.3333V16.8858C11.5667 17.5655 9.88487 18 8 18C6.11515 18 4.43331 17.5655 3.33333 16.8858V18.3333ZM12.6667 20.2191C11.5667 20.8988 9.88487 21.3333 8 21.3333C6.11515 21.3333 4.43331 20.8988 3.33333 20.2191V21.6667C3.33333 21.8756 3.64067 22.2389 4.35365 22.5954C5.27603 23.0566 6.58451 23.3333 8 23.3333C9.41547 23.3333 10.7239 23.0566 11.6463 22.5954C12.3593 22.2389 12.6667 21.8756 12.6667 21.6667V20.2191ZM2 21.6667V15C2 13.3431 4.68629 12 8 12C11.3137 12 14 13.3431 14 15V21.6667C14 23.3235 11.3137 24.6667 8 24.6667C4.68629 24.6667 2 23.3235 2 21.6667ZM8 16.6667C9.41547 16.6667 10.7239 16.3899 11.6463 15.9288C12.3593 15.5723 12.6667 15.2089 12.6667 15C12.6667 14.7911 12.3593 14.4277 11.6463 14.0712C10.7239 13.6101 9.41547 13.3333 8 13.3333C6.58451 13.3333 5.27603 13.6101 4.35365 14.0712C3.64067 14.4277 3.33333 14.7911 3.33333 15C3.33333 15.2089 3.64067 15.5723 4.35365 15.9288C5.27603 16.3899 6.58451 16.6667 8 16.6667Z" fill="#101828" fill-opacity="0.3"/>
|
||||
</g>
|
||||
<g opacity="0.3">
|
||||
<path d="M41.3337 11.3333C41.7019 11.3333 42.0003 11.6318 42.0003 12V15.3333C42.0003 15.7015 41.7019 16 41.3337 16H38.0003V24.6667C38.0003 25.0349 37.7019 25.3333 37.3337 25.3333H34.667C34.2988 25.3333 34.0003 25.0349 34.0003 24.6667V16H30.3337C29.9655 16 29.667 15.7015 29.667 15.3333V13.7454C29.667 13.4929 29.8097 13.262 30.0355 13.1491L33.667 11.3333H41.3337ZM38.0003 12.6667H33.9818L31.0003 14.1574V14.6667H35.3337V24H36.667V14.6667H38.0003V12.6667ZM40.667 12.6667H39.3337V14.6667H40.667V12.6667Z" fill="#101828" fill-opacity="0.3"/>
|
||||
</g>
|
||||
<g opacity="0.6">
|
||||
<path d="M60.6667 13.3333C60.6667 11.8606 61.8606 10.6667 63.3333 10.6667C64.8061 10.6667 66 11.8606 66 13.3333H69.3333C69.7015 13.3333 70 13.6318 70 14V16.7805C70 16.9969 69.8949 17.1998 69.7183 17.3248C69.5415 17.4497 69.3152 17.4811 69.1112 17.409C68.973 17.3602 68.8237 17.3333 68.6667 17.3333C67.9303 17.3333 67.3333 17.9303 67.3333 18.6667C67.3333 19.4031 67.9303 20 68.6667 20C68.8237 20 68.973 19.9731 69.1112 19.9243C69.3152 19.8522 69.5415 19.8836 69.7183 20.0085C69.8949 20.1335 70 20.3365 70 20.5529V23.3333C70 23.7015 69.7015 24 69.3333 24H58.6667C58.2985 24 58 23.7015 58 23.3333V14C58 13.6318 58.2985 13.3333 58.6667 13.3333H60.6667ZM63.3333 12C62.597 12 62 12.5969 62 13.3333C62 13.4903 62.0269 13.6397 62.0757 13.7778C62.1478 13.9819 62.1164 14.2082 61.9915 14.3849C61.8665 14.5616 61.6635 14.6667 61.4471 14.6667H59.3333V22.6667H68.6667V21.3333C67.1939 21.3333 66 20.1394 66 18.6667C66 17.1939 67.1939 16 68.6667 16V14.6667H65.2195C65.0031 14.6667 64.8002 14.5616 64.6752 14.3849C64.5503 14.2082 64.5189 13.9819 64.591 13.7778C64.6398 13.6397 64.6667 13.4904 64.6667 13.3333C64.6667 12.5969 64.0697 12 63.3333 12Z" fill="#101828" fill-opacity="0.3"/>
|
||||
</g>
|
||||
<rect x="84.5" y="0.5" width="35" height="35" rx="9.5" stroke="#101828" stroke-opacity="0.04"/>
|
||||
<path d="M96.167 16.3333H107.834V25.5H96.167V16.3333Z" stroke="#676F83" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M100.333 19.6667H103.666" stroke="#676F83" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M107.833 12.1667L105.583 15.9167" stroke="#676F83" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M97.4888 11.9517C97.5447 11.9238 97.5901 11.8784 97.6181 11.8224L97.9911 11.0765C98.0976 10.8634 98.4017 10.8634 98.5083 11.0765L98.8813 11.8224C98.9092 11.8784 98.9546 11.9238 99.0106 11.9517L99.7565 12.3247C99.9696 12.4313 99.9696 12.7354 99.7565 12.842L99.0106 13.2149C98.9546 13.2429 98.9092 13.2883 98.8813 13.3442L98.5083 14.0902C98.4017 14.3033 98.0976 14.3033 97.9911 14.0902L97.6181 13.3442C97.5901 13.2883 97.5447 13.2429 97.4888 13.2149L96.7429 12.842C96.5297 12.7354 96.5297 12.4313 96.7429 12.3247L97.4888 11.9517Z" fill="#676F83"/>
|
||||
<path d="M101.882 10.5438C101.952 10.5089 102.009 10.4521 102.044 10.3822L102.51 9.4498C102.643 9.1834 103.023 9.1834 103.157 9.4498L103.623 10.3822C103.658 10.4521 103.714 10.5089 103.784 10.5438L104.717 11.0101C104.983 11.1432 104.983 11.5234 104.717 11.6566L103.784 12.1228C103.714 12.1578 103.658 12.2145 103.623 12.2845L103.157 13.2169C103.023 13.4833 102.643 13.4833 102.51 13.2169L102.044 12.2845C102.009 12.2145 101.952 12.1578 101.882 12.1228L100.95 11.6566C100.683 11.5234 100.683 11.1432 100.95 11.0101L101.882 10.5438Z" fill="#676F83"/>
|
||||
<g opacity="0.6" clip-path="url(#clip0_9296_51042)">
|
||||
<path d="M145.809 14.7521L145.645 15.1292C145.525 15.4053 145.143 15.4053 145.022 15.1292L144.858 14.7521C144.565 14.0796 144.037 13.5443 143.379 13.2514L142.872 13.0261C142.599 12.9044 142.599 12.5059 142.872 12.3841L143.35 12.1714C144.026 11.871 144.563 11.3158 144.851 10.6206L145.02 10.213C145.138 9.92899 145.53 9.92899 145.647 10.213L145.816 10.6206C146.104 11.3158 146.641 11.871 147.317 12.1714L147.795 12.3841C148.069 12.5059 148.069 12.9044 147.795 13.0261L147.289 13.2514C146.63 13.5443 146.102 14.0796 145.809 14.7521ZM138 11.3333C140.712 11.3333 142.951 13.3571 143.289 15.9766L144.79 18.3358C144.889 18.4911 144.869 18.7231 144.64 18.8211L143.334 19.3807V21.3333C143.334 22.0697 142.737 22.6667 142 22.6667H140.668L140.667 24.6667H134.667L134.667 22.2041C134.667 21.4168 134.376 20.6725 133.837 20.0007C133.105 19.0875 132.667 17.9283 132.667 16.6667C132.667 13.7211 135.055 11.3333 138 11.3333ZM138 12.6667C135.791 12.6667 134 14.4575 134 16.6667C134 17.5899 134.312 18.4619 134.878 19.1666C135.607 20.076 136.001 21.1115 136 22.2042L136 23.3333H139.334L139.335 21.3333H142V18.5013L143.033 18.0587L142.005 16.4417L141.967 16.1475C141.711 14.1676 140.017 12.6667 138 12.6667ZM144.993 21.3286L146.103 22.0683C146.88 20.9041 147.334 19.5051 147.334 18.0001C147.334 17.5447 147.292 17.0991 147.213 16.6667L145.917 17C145.972 17.3252 146 17.6593 146 18.0001C146 19.2314 145.629 20.3761 144.993 21.3286Z" fill="#101828" fill-opacity="0.3"/>
|
||||
</g>
|
||||
<g opacity="0.3">
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M167.149 16.3522L167.251 16.3822L174.224 19.0391L174.317 19.082C174.722 19.308 174.777 19.8792 174.424 20.179L174.341 20.2389L171.817 21.8171L170.239 24.3418C169.962 24.784 169.324 24.7511 169.082 24.3171L169.039 24.2246L166.382 17.2513C166.188 16.742 166.644 16.2421 167.149 16.3522ZM169.812 22.5085L170.767 20.9811L170.811 20.9186C170.858 20.859 170.916 20.8076 170.981 20.7669L172.508 19.8119L168.152 18.1524L169.812 22.5085Z" fill="#101828" fill-opacity="0.3"/>
|
||||
<path d="M165.212 20.3978L163.562 22.0475L162.619 21.1048L164.269 19.4551L165.212 20.3978Z" fill="#101828" fill-opacity="0.3"/>
|
||||
<path d="M163.666 18H161.333V16.6667H163.666V18Z" fill="#101828" fill-opacity="0.3"/>
|
||||
<path d="M165.212 14.2689L164.269 15.2116L162.619 13.5619L163.562 12.6192L165.212 14.2689Z" fill="#101828" fill-opacity="0.3"/>
|
||||
<path d="M172.047 13.5619L170.397 15.2116L169.455 14.2689L171.104 12.6192L172.047 13.5619Z" fill="#101828" fill-opacity="0.3"/>
|
||||
<path d="M168 13.6667H166.666V11.3333H168V13.6667Z" fill="#101828" fill-opacity="0.3"/>
|
||||
</g>
|
||||
<g opacity="0.1">
|
||||
<path d="M202.666 23.3333V14.6667L201.333 12H190.666L189.333 14.669V23.3333C189.333 23.7015 189.631 24 190 24H202C202.368 24 202.666 23.7015 202.666 23.3333ZM190.666 16H201.333V22.6667H190.666V16ZM191.49 13.3333H200.509L201.176 14.6667H190.824L191.49 13.3333ZM198 17.3333H194V18.6667H198V17.3333Z" fill="#101828" fill-opacity="0.3"/>
|
||||
</g>
|
||||
<defs>
|
||||
<clipPath id="clip0_9296_51042">
|
||||
<rect width="16" height="16" fill="white" transform="translate(132 10)"/>
|
||||
</clipPath>
|
||||
</defs>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 7.5 KiB |
@ -0,0 +1,7 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M1.33325 4.66663C1.33325 3.56206 2.22869 2.66663 3.33325 2.66663H12.6666C13.7712 2.66663 14.6666 3.56206 14.6666 4.66663V8.16663C14.6666 8.53483 14.3681 8.83329 13.9999 8.83329C13.6317 8.83329 13.3333 8.53483 13.3333 8.16663V4.66663C13.3333 4.29844 13.0348 3.99996 12.6666 3.99996H3.33325C2.96507 3.99996 2.66659 4.29844 2.66659 4.66663V12C2.66659 12.3682 2.96507 12.6666 3.33325 12.6666H7.99992C8.36812 12.6666 8.66658 12.9651 8.66658 13.3333C8.66658 13.7015 8.36812 14 7.99992 14H3.33325C2.22869 14 1.33325 13.1046 1.33325 12V4.66663Z" fill="white"/>
|
||||
<path d="M3.66659 5.83329C3.66659 6.29353 4.03968 6.66663 4.49992 6.66663C4.96016 6.66663 5.33325 6.29353 5.33325 5.83329C5.33325 5.37305 4.96016 4.99996 4.49992 4.99996C4.03968 4.99996 3.66659 5.37305 3.66659 5.83329Z" fill="white"/>
|
||||
<path d="M5.99992 5.83329C5.99992 6.29353 6.37301 6.66663 6.83325 6.66663C7.29352 6.66663 7.66658 6.29353 7.66658 5.83329C7.66658 5.37305 7.29352 4.99996 6.83325 4.99996C6.37301 4.99996 5.99992 5.37305 5.99992 5.83329Z" fill="white"/>
|
||||
<path d="M8.33325 5.83329C8.33325 6.29353 8.70632 6.66663 9.16658 6.66663C9.62685 6.66663 9.99992 6.29353 9.99992 5.83329C9.99992 5.37305 9.62685 4.99996 9.16658 4.99996C8.70632 4.99996 8.33325 5.37305 8.33325 5.83329Z" fill="white"/>
|
||||
<path d="M10.5293 9.69609C10.2933 9.62349 10.0365 9.68729 9.86185 9.86189C9.68725 10.0365 9.62345 10.2934 9.69605 10.5294L11.0294 14.8627C11.1095 15.1231 11.3401 15.3086 11.6116 15.331C11.8832 15.3535 12.1411 15.2085 12.2629 14.9648L13.1635 13.1636L14.9647 12.263C15.2085 12.1411 15.3535 11.8832 15.331 11.6116C15.3085 11.3401 15.1231 11.1096 14.8627 11.0294L10.5293 9.69609Z" fill="white"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.7 KiB |
@ -0,0 +1,77 @@
|
||||
{
|
||||
"icon": {
|
||||
"type": "element",
|
||||
"isRootNode": true,
|
||||
"name": "svg",
|
||||
"attributes": {
|
||||
"width": "32",
|
||||
"height": "32",
|
||||
"viewBox": "0 0 32 32",
|
||||
"fill": "none",
|
||||
"xmlns": "http://www.w3.org/2000/svg"
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M28.0049 16C28.0049 20.4183 24.4231 24 20.0049 24C15.5866 24 12.0049 20.4183 12.0049 16C12.0049 11.5817 15.5866 8 20.0049 8C24.4231 8 28.0049 11.5817 28.0049 16Z",
|
||||
"stroke": "currentColor",
|
||||
"stroke-width": "2",
|
||||
"stroke-linecap": "round",
|
||||
"stroke-linejoin": "round"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M4.00488 16H6.67155",
|
||||
"stroke": "currentColor",
|
||||
"stroke-width": "2",
|
||||
"stroke-linecap": "round",
|
||||
"stroke-linejoin": "round"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M4.00488 9.33334H8.00488",
|
||||
"stroke": "currentColor",
|
||||
"stroke-width": "2",
|
||||
"stroke-linecap": "round",
|
||||
"stroke-linejoin": "round"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M4.00488 22.6667H8.00488",
|
||||
"stroke": "currentColor",
|
||||
"stroke-width": "2",
|
||||
"stroke-linecap": "round",
|
||||
"stroke-linejoin": "round"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M26 22L29.3333 25.3333",
|
||||
"stroke": "currentColor",
|
||||
"stroke-width": "2",
|
||||
"stroke-linecap": "round",
|
||||
"stroke-linejoin": "round"
|
||||
},
|
||||
"children": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"name": "SearchMenu"
|
||||
}
|
||||
@ -0,0 +1,20 @@
|
||||
// GENERATE BY script
|
||||
// DON NOT EDIT IT MANUALLY
|
||||
|
||||
import * as React from 'react'
|
||||
import data from './SearchMenu.json'
|
||||
import IconBase from '@/app/components/base/icons/IconBase'
|
||||
import type { IconData } from '@/app/components/base/icons/IconBase'
|
||||
|
||||
const Icon = (
|
||||
{
|
||||
ref,
|
||||
...props
|
||||
}: React.SVGProps<SVGSVGElement> & {
|
||||
ref?: React.RefObject<React.MutableRefObject<HTMLOrSVGElement>>;
|
||||
},
|
||||
) => <IconBase {...props} ref={ref} data={data as IconData} />
|
||||
|
||||
Icon.displayName = 'SearchMenu'
|
||||
|
||||
export default Icon
|
||||
@ -0,0 +1,35 @@
|
||||
{
|
||||
"icon": {
|
||||
"type": "element",
|
||||
"isRootNode": true,
|
||||
"name": "svg",
|
||||
"attributes": {
|
||||
"width": "16",
|
||||
"height": "16",
|
||||
"viewBox": "0 0 16 16",
|
||||
"fill": "none",
|
||||
"xmlns": "http://www.w3.org/2000/svg"
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M9.20626 1.68651C9.61828 1.68651 10.014 1.8473 10.3093 2.13466C10.4536 2.27516 10.5684 2.44313 10.6468 2.62868C10.7252 2.81422 10.7657 3.01358 10.7659 3.21501C10.7661 3.41645 10.7259 3.61588 10.6478 3.80156C10.5697 3.98723 10.4552 4.1554 10.3111 4.29614L5.86656 8.65516C5.81837 8.70203 5.78006 8.75808 5.7539 8.82001C5.72775 8.88194 5.71427 8.94848 5.71427 9.01571C5.71427 9.08294 5.72775 9.14948 5.7539 9.21141C5.78006 9.27334 5.81837 9.32939 5.86656 9.37626C5.96503 9.47212 6.09703 9.52576 6.23445 9.52576C6.37187 9.52576 6.50387 9.47212 6.60234 9.37626L6.66222 9.31698L6.66345 9.31576L11.0463 5.01725C11.3417 4.73067 11.7372 4.57056 12.1488 4.5709C12.5604 4.57124 12.9556 4.73202 13.2506 5.01908L13.2811 5.04902C13.4256 5.18967 13.5405 5.35786 13.6189 5.54363C13.6973 5.72941 13.7377 5.92903 13.7377 6.13068C13.7377 6.33233 13.6973 6.53195 13.6189 6.71773C13.5405 6.9035 13.4256 7.07169 13.2811 7.21234L7.96082 12.43C7.84828 12.5393 7.75882 12.6701 7.69773 12.8147C7.63664 12.9592 7.60517 13.1145 7.60517 13.2714C7.60517 13.4284 7.63664 13.5837 7.69773 13.7282C7.75882 13.8728 7.84828 14.0036 7.96082 14.1129L9.05348 15.1842C9.15192 15.2799 9.28378 15.3334 9.42106 15.3334C9.55834 15.3334 9.6902 15.2799 9.78864 15.1842C9.83683 15.1373 9.87514 15.0813 9.9013 15.0194C9.92746 14.9574 9.94094 14.8909 9.94094 14.8237C9.94094 14.7564 9.92746 14.6899 9.9013 14.628C9.87514 14.566 9.83683 14.51 9.78864 14.4631L8.69598 13.3912C8.67992 13.3756 8.66716 13.357 8.65844 13.3363C8.64973 13.3157 8.64523 13.2935 8.64523 13.2711C8.64523 13.2488 8.64973 13.2266 8.65844 13.206C8.66716 13.1853 8.67992 13.1667 8.69598 13.1511L14.0163 7.93405C14.2572 7.69971 14.4488 7.41943 14.5796 7.10979C14.7104 6.80014 14.7778 6.46742 14.7778 6.13129C14.7778 5.79516 14.7104 5.46244 14.5796 5.1528C14.4488 4.84315 14.2572 4.56288 14.0163 4.32853L13.9857 4.29797C13.6978 4.01697 13.3493 3.80582 12.9669 3.6808C12.5845 3.55578 12.1785 3.52022 11.7802 3.57687C11.8371 3.1838 11.8001 2.78285 11.6722 2.40684C11.5443 2.03083 11.3292 1.69045 11.0445 1.41356C10.5524 0.93469 9.89288 0.666748 9.20626 0.666748C8.51964 0.666748 7.86012 0.93469 7.36805 1.41356L1.48555 7.18239C1.43735 7.22926 1.39905 7.28532 1.37289 7.34725C1.34673 7.40917 1.33325 7.47572 1.33325 7.54294C1.33325 7.61017 1.34673 7.67672 1.37289 7.73864C1.39905 7.80057 1.43735 7.85663 1.48555 7.9035C1.58399 7.99918 1.71585 8.0527 1.85313 8.0527C1.9904 8.0527 2.12227 7.99918 2.22071 7.9035L8.10321 2.13466C8.39848 1.8473 8.79424 1.68651 9.20626 1.68651Z",
|
||||
"fill": "currentColor"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M9.68688 3.41201C9.66072 3.47394 9.62241 3.52999 9.57422 3.57686L5.22314 7.8436C5.07864 7.98425 4.96378 8.15243 4.88535 8.33821C4.80693 8.52399 4.76652 8.7236 4.76652 8.92526C4.76652 9.12691 4.80693 9.32652 4.88535 9.5123C4.96378 9.69808 5.07864 9.86626 5.22314 10.0069C5.51841 10.2943 5.91417 10.4551 6.32619 10.4551C6.73821 10.4551 7.13397 10.2943 7.42924 10.0069L11.7797 5.74017C11.8782 5.64431 12.0102 5.59067 12.1476 5.59067C12.285 5.59067 12.417 5.64431 12.5155 5.74017C12.5637 5.78704 12.602 5.8431 12.6281 5.90503C12.6543 5.96696 12.6678 6.0335 12.6678 6.10073C12.6678 6.16795 12.6543 6.2345 12.6281 6.29643C12.602 6.35835 12.5637 6.41441 12.5155 6.46128L8.1644 10.728C7.67225 11.2067 7.01276 11.4746 6.32619 11.4746C5.63962 11.4746 4.98013 11.2067 4.48798 10.728C4.24701 10.4937 4.05547 10.2134 3.92468 9.90375C3.79389 9.59411 3.7265 9.26139 3.7265 8.92526C3.7265 8.58912 3.79389 8.2564 3.92468 7.94676C4.05547 7.63712 4.24701 7.35684 4.48798 7.1225L8.83845 2.85576C8.93691 2.75989 9.06891 2.70625 9.20633 2.70625C9.34375 2.70625 9.47575 2.75989 9.57422 2.85576C9.62241 2.90263 9.66072 2.95868 9.68688 3.02061C9.71304 3.08254 9.72651 3.14908 9.72651 3.21631C9.72651 3.28353 9.71304 3.35008 9.68688 3.41201Z",
|
||||
"fill": "currentColor"
|
||||
},
|
||||
"children": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"name": "Mcp"
|
||||
}
|
||||
@ -0,0 +1,20 @@
|
||||
// GENERATE BY script
|
||||
// DON NOT EDIT IT MANUALLY
|
||||
|
||||
import * as React from 'react'
|
||||
import data from './Mcp.json'
|
||||
import IconBase from '@/app/components/base/icons/IconBase'
|
||||
import type { IconData } from '@/app/components/base/icons/IconBase'
|
||||
|
||||
const Icon = (
|
||||
{
|
||||
ref,
|
||||
...props
|
||||
}: React.SVGProps<SVGSVGElement> & {
|
||||
ref?: React.RefObject<React.MutableRefObject<HTMLOrSVGElement>>;
|
||||
},
|
||||
) => <IconBase {...props} ref={ref} data={data as IconData} />
|
||||
|
||||
Icon.displayName = 'Mcp'
|
||||
|
||||
export default Icon
|
||||
@ -0,0 +1,279 @@
|
||||
{
|
||||
"icon": {
|
||||
"type": "element",
|
||||
"isRootNode": true,
|
||||
"name": "svg",
|
||||
"attributes": {
|
||||
"width": "204",
|
||||
"height": "36",
|
||||
"viewBox": "0 0 204 36",
|
||||
"fill": "none",
|
||||
"xmlns": "http://www.w3.org/2000/svg"
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"type": "element",
|
||||
"name": "g",
|
||||
"attributes": {
|
||||
"opacity": "0.1"
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M3.33333 18.3333C3.33333 18.5423 3.64067 18.9056 4.35365 19.2621C5.27603 19.7233 6.58451 20 8 20C9.41547 20 10.7239 19.7233 11.6463 19.2621C12.3593 18.9056 12.6667 18.5423 12.6667 18.3333V16.8858C11.5667 17.5655 9.88487 18 8 18C6.11515 18 4.43331 17.5655 3.33333 16.8858V18.3333ZM12.6667 20.2191C11.5667 20.8988 9.88487 21.3333 8 21.3333C6.11515 21.3333 4.43331 20.8988 3.33333 20.2191V21.6667C3.33333 21.8756 3.64067 22.2389 4.35365 22.5954C5.27603 23.0566 6.58451 23.3333 8 23.3333C9.41547 23.3333 10.7239 23.0566 11.6463 22.5954C12.3593 22.2389 12.6667 21.8756 12.6667 21.6667V20.2191ZM2 21.6667V15C2 13.3431 4.68629 12 8 12C11.3137 12 14 13.3431 14 15V21.6667C14 23.3235 11.3137 24.6667 8 24.6667C4.68629 24.6667 2 23.3235 2 21.6667ZM8 16.6667C9.41547 16.6667 10.7239 16.3899 11.6463 15.9288C12.3593 15.5723 12.6667 15.2089 12.6667 15C12.6667 14.7911 12.3593 14.4277 11.6463 14.0712C10.7239 13.6101 9.41547 13.3333 8 13.3333C6.58451 13.3333 5.27603 13.6101 4.35365 14.0712C3.64067 14.4277 3.33333 14.7911 3.33333 15C3.33333 15.2089 3.64067 15.5723 4.35365 15.9288C5.27603 16.3899 6.58451 16.6667 8 16.6667Z",
|
||||
"fill": "currentColor",
|
||||
"fill-opacity": "0.3"
|
||||
},
|
||||
"children": []
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "g",
|
||||
"attributes": {
|
||||
"opacity": "0.3"
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M41.3337 11.3333C41.7019 11.3333 42.0003 11.6318 42.0003 12V15.3333C42.0003 15.7015 41.7019 16 41.3337 16H38.0003V24.6667C38.0003 25.0349 37.7019 25.3333 37.3337 25.3333H34.667C34.2988 25.3333 34.0003 25.0349 34.0003 24.6667V16H30.3337C29.9655 16 29.667 15.7015 29.667 15.3333V13.7454C29.667 13.4929 29.8097 13.262 30.0355 13.1491L33.667 11.3333H41.3337ZM38.0003 12.6667H33.9818L31.0003 14.1574V14.6667H35.3337V24H36.667V14.6667H38.0003V12.6667ZM40.667 12.6667H39.3337V14.6667H40.667V12.6667Z",
|
||||
"fill": "currentColor",
|
||||
"fill-opacity": "0.3"
|
||||
},
|
||||
"children": []
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "g",
|
||||
"attributes": {
|
||||
"opacity": "0.6"
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M60.6667 13.3333C60.6667 11.8606 61.8606 10.6667 63.3333 10.6667C64.8061 10.6667 66 11.8606 66 13.3333H69.3333C69.7015 13.3333 70 13.6318 70 14V16.7805C70 16.9969 69.8949 17.1998 69.7183 17.3248C69.5415 17.4497 69.3152 17.4811 69.1112 17.409C68.973 17.3602 68.8237 17.3333 68.6667 17.3333C67.9303 17.3333 67.3333 17.9303 67.3333 18.6667C67.3333 19.4031 67.9303 20 68.6667 20C68.8237 20 68.973 19.9731 69.1112 19.9243C69.3152 19.8522 69.5415 19.8836 69.7183 20.0085C69.8949 20.1335 70 20.3365 70 20.5529V23.3333C70 23.7015 69.7015 24 69.3333 24H58.6667C58.2985 24 58 23.7015 58 23.3333V14C58 13.6318 58.2985 13.3333 58.6667 13.3333H60.6667ZM63.3333 12C62.597 12 62 12.5969 62 13.3333C62 13.4903 62.0269 13.6397 62.0757 13.7778C62.1478 13.9819 62.1164 14.2082 61.9915 14.3849C61.8665 14.5616 61.6635 14.6667 61.4471 14.6667H59.3333V22.6667H68.6667V21.3333C67.1939 21.3333 66 20.1394 66 18.6667C66 17.1939 67.1939 16 68.6667 16V14.6667H65.2195C65.0031 14.6667 64.8002 14.5616 64.6752 14.3849C64.5503 14.2082 64.5189 13.9819 64.591 13.7778C64.6398 13.6397 64.6667 13.4904 64.6667 13.3333C64.6667 12.5969 64.0697 12 63.3333 12Z",
|
||||
"fill": "currentColor",
|
||||
"fill-opacity": "0.3"
|
||||
},
|
||||
"children": []
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "rect",
|
||||
"attributes": {
|
||||
"x": "84.5",
|
||||
"y": "0.5",
|
||||
"width": "35",
|
||||
"height": "35",
|
||||
"rx": "9.5",
|
||||
"stroke": "#101828",
|
||||
"stroke-opacity": "0.04"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M96.167 16.3333H107.834V25.5H96.167V16.3333Z",
|
||||
"stroke": "currentColor",
|
||||
"stroke-width": "1.5",
|
||||
"stroke-linecap": "round",
|
||||
"stroke-linejoin": "round"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M100.333 19.6667H103.666",
|
||||
"stroke": "currentColor",
|
||||
"stroke-width": "1.5",
|
||||
"stroke-linecap": "round",
|
||||
"stroke-linejoin": "round"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M107.833 12.1667L105.583 15.9167",
|
||||
"stroke": "currentColor",
|
||||
"stroke-width": "1.5",
|
||||
"stroke-linecap": "round",
|
||||
"stroke-linejoin": "round"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M97.4888 11.9517C97.5447 11.9238 97.5901 11.8784 97.6181 11.8224L97.9911 11.0765C98.0976 10.8634 98.4017 10.8634 98.5083 11.0765L98.8813 11.8224C98.9092 11.8784 98.9546 11.9238 99.0106 11.9517L99.7565 12.3247C99.9696 12.4313 99.9696 12.7354 99.7565 12.842L99.0106 13.2149C98.9546 13.2429 98.9092 13.2883 98.8813 13.3442L98.5083 14.0902C98.4017 14.3033 98.0976 14.3033 97.9911 14.0902L97.6181 13.3442C97.5901 13.2883 97.5447 13.2429 97.4888 13.2149L96.7429 12.842C96.5297 12.7354 96.5297 12.4313 96.7429 12.3247L97.4888 11.9517Z",
|
||||
"fill": "currentColor"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M101.882 10.5438C101.952 10.5089 102.009 10.4521 102.044 10.3822L102.51 9.4498C102.643 9.1834 103.023 9.1834 103.157 9.4498L103.623 10.3822C103.658 10.4521 103.714 10.5089 103.784 10.5438L104.717 11.0101C104.983 11.1432 104.983 11.5234 104.717 11.6566L103.784 12.1228C103.714 12.1578 103.658 12.2145 103.623 12.2845L103.157 13.2169C103.023 13.4833 102.643 13.4833 102.51 13.2169L102.044 12.2845C102.009 12.2145 101.952 12.1578 101.882 12.1228L100.95 11.6566C100.683 11.5234 100.683 11.1432 100.95 11.0101L101.882 10.5438Z",
|
||||
"fill": "currentColor"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "g",
|
||||
"attributes": {
|
||||
"opacity": "0.6",
|
||||
"clip-path": "url(#clip0_9296_51042)"
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M145.809 14.7521L145.645 15.1292C145.525 15.4053 145.143 15.4053 145.022 15.1292L144.858 14.7521C144.565 14.0796 144.037 13.5443 143.379 13.2514L142.872 13.0261C142.599 12.9044 142.599 12.5059 142.872 12.3841L143.35 12.1714C144.026 11.871 144.563 11.3158 144.851 10.6206L145.02 10.213C145.138 9.92899 145.53 9.92899 145.647 10.213L145.816 10.6206C146.104 11.3158 146.641 11.871 147.317 12.1714L147.795 12.3841C148.069 12.5059 148.069 12.9044 147.795 13.0261L147.289 13.2514C146.63 13.5443 146.102 14.0796 145.809 14.7521ZM138 11.3333C140.712 11.3333 142.951 13.3571 143.289 15.9766L144.79 18.3358C144.889 18.4911 144.869 18.7231 144.64 18.8211L143.334 19.3807V21.3333C143.334 22.0697 142.737 22.6667 142 22.6667H140.668L140.667 24.6667H134.667L134.667 22.2041C134.667 21.4168 134.376 20.6725 133.837 20.0007C133.105 19.0875 132.667 17.9283 132.667 16.6667C132.667 13.7211 135.055 11.3333 138 11.3333ZM138 12.6667C135.791 12.6667 134 14.4575 134 16.6667C134 17.5899 134.312 18.4619 134.878 19.1666C135.607 20.076 136.001 21.1115 136 22.2042L136 23.3333H139.334L139.335 21.3333H142V18.5013L143.033 18.0587L142.005 16.4417L141.967 16.1475C141.711 14.1676 140.017 12.6667 138 12.6667ZM144.993 21.3286L146.103 22.0683C146.88 20.9041 147.334 19.5051 147.334 18.0001C147.334 17.5447 147.292 17.0991 147.213 16.6667L145.917 17C145.972 17.3252 146 17.6593 146 18.0001C146 19.2314 145.629 20.3761 144.993 21.3286Z",
|
||||
"fill": "currentColor",
|
||||
"fill-opacity": "0.3"
|
||||
},
|
||||
"children": []
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "g",
|
||||
"attributes": {
|
||||
"opacity": "0.3"
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"fill-rule": "evenodd",
|
||||
"clip-rule": "evenodd",
|
||||
"d": "M167.149 16.3522L167.251 16.3822L174.224 19.0391L174.317 19.082C174.722 19.308 174.777 19.8792 174.424 20.179L174.341 20.2389L171.817 21.8171L170.239 24.3418C169.962 24.784 169.324 24.7511 169.082 24.3171L169.039 24.2246L166.382 17.2513C166.188 16.742 166.644 16.2421 167.149 16.3522ZM169.812 22.5085L170.767 20.9811L170.811 20.9186C170.858 20.859 170.916 20.8076 170.981 20.7669L172.508 19.8119L168.152 18.1524L169.812 22.5085Z",
|
||||
"fill": "currentColor",
|
||||
"fill-opacity": "0.3"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M165.212 20.3978L163.562 22.0475L162.619 21.1048L164.269 19.4551L165.212 20.3978Z",
|
||||
"fill": "currentColor",
|
||||
"fill-opacity": "0.3"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M163.666 18H161.333V16.6667H163.666V18Z",
|
||||
"fill": "currentColor",
|
||||
"fill-opacity": "0.3"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M165.212 14.2689L164.269 15.2116L162.619 13.5619L163.562 12.6192L165.212 14.2689Z",
|
||||
"fill": "currentColor",
|
||||
"fill-opacity": "0.3"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M172.047 13.5619L170.397 15.2116L169.455 14.2689L171.104 12.6192L172.047 13.5619Z",
|
||||
"fill": "currentColor",
|
||||
"fill-opacity": "0.3"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M168 13.6667H166.666V11.3333H168V13.6667Z",
|
||||
"fill": "currentColor",
|
||||
"fill-opacity": "0.3"
|
||||
},
|
||||
"children": []
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "g",
|
||||
"attributes": {
|
||||
"opacity": "0.1"
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M202.666 23.3333V14.6667L201.333 12H190.666L189.333 14.669V23.3333C189.333 23.7015 189.631 24 190 24H202C202.368 24 202.666 23.7015 202.666 23.3333ZM190.666 16H201.333V22.6667H190.666V16ZM191.49 13.3333H200.509L201.176 14.6667H190.824L191.49 13.3333ZM198 17.3333H194V18.6667H198V17.3333Z",
|
||||
"fill": "currentColor",
|
||||
"fill-opacity": "0.3"
|
||||
},
|
||||
"children": []
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "defs",
|
||||
"attributes": {},
|
||||
"children": [
|
||||
{
|
||||
"type": "element",
|
||||
"name": "clipPath",
|
||||
"attributes": {
|
||||
"id": "clip0_9296_51042"
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"type": "element",
|
||||
"name": "rect",
|
||||
"attributes": {
|
||||
"width": "16",
|
||||
"height": "16",
|
||||
"fill": "white",
|
||||
"transform": "translate(132 10)"
|
||||
},
|
||||
"children": []
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
"name": "NoToolPlaceholder"
|
||||
}
|
||||
@ -0,0 +1,20 @@
|
||||
// GENERATE BY script
|
||||
// DON NOT EDIT IT MANUALLY
|
||||
|
||||
import * as React from 'react'
|
||||
import data from './NoToolPlaceholder.json'
|
||||
import IconBase from '@/app/components/base/icons/IconBase'
|
||||
import type { IconData } from '@/app/components/base/icons/IconBase'
|
||||
|
||||
const Icon = (
|
||||
{
|
||||
ref,
|
||||
...props
|
||||
}: React.SVGProps<SVGSVGElement> & {
|
||||
ref?: React.RefObject<React.MutableRefObject<HTMLOrSVGElement>>;
|
||||
},
|
||||
) => <IconBase {...props} ref={ref} data={data as IconData} />
|
||||
|
||||
Icon.displayName = 'NoToolPlaceholder'
|
||||
|
||||
export default Icon
|
||||
@ -1,5 +1,7 @@
|
||||
export { default as AnthropicText } from './AnthropicText'
|
||||
export { default as Generator } from './Generator'
|
||||
export { default as Group } from './Group'
|
||||
export { default as Mcp } from './Mcp'
|
||||
export { default as NoToolPlaceholder } from './NoToolPlaceholder'
|
||||
export { default as Openai } from './Openai'
|
||||
export { default as ReplayLine } from './ReplayLine'
|
||||
|
||||
@ -0,0 +1,62 @@
|
||||
{
|
||||
"icon": {
|
||||
"type": "element",
|
||||
"isRootNode": true,
|
||||
"name": "svg",
|
||||
"attributes": {
|
||||
"width": "16",
|
||||
"height": "16",
|
||||
"viewBox": "0 0 16 16",
|
||||
"fill": "none",
|
||||
"xmlns": "http://www.w3.org/2000/svg"
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M1.33325 4.66663C1.33325 3.56206 2.22869 2.66663 3.33325 2.66663H12.6666C13.7712 2.66663 14.6666 3.56206 14.6666 4.66663V8.16663C14.6666 8.53483 14.3681 8.83329 13.9999 8.83329C13.6317 8.83329 13.3333 8.53483 13.3333 8.16663V4.66663C13.3333 4.29844 13.0348 3.99996 12.6666 3.99996H3.33325C2.96507 3.99996 2.66659 4.29844 2.66659 4.66663V12C2.66659 12.3682 2.96507 12.6666 3.33325 12.6666H7.99992C8.36812 12.6666 8.66658 12.9651 8.66658 13.3333C8.66658 13.7015 8.36812 14 7.99992 14H3.33325C2.22869 14 1.33325 13.1046 1.33325 12V4.66663Z",
|
||||
"fill": "currentColor"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M3.66659 5.83329C3.66659 6.29353 4.03968 6.66663 4.49992 6.66663C4.96016 6.66663 5.33325 6.29353 5.33325 5.83329C5.33325 5.37305 4.96016 4.99996 4.49992 4.99996C4.03968 4.99996 3.66659 5.37305 3.66659 5.83329Z",
|
||||
"fill": "currentColor"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M5.99992 5.83329C5.99992 6.29353 6.37301 6.66663 6.83325 6.66663C7.29352 6.66663 7.66658 6.29353 7.66658 5.83329C7.66658 5.37305 7.29352 4.99996 6.83325 4.99996C6.37301 4.99996 5.99992 5.37305 5.99992 5.83329Z",
|
||||
"fill": "currentColor"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M8.33325 5.83329C8.33325 6.29353 8.70632 6.66663 9.16658 6.66663C9.62685 6.66663 9.99992 6.29353 9.99992 5.83329C9.99992 5.37305 9.62685 4.99996 9.16658 4.99996C8.70632 4.99996 8.33325 5.37305 8.33325 5.83329Z",
|
||||
"fill": "currentColor"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M10.5293 9.69609C10.2933 9.62349 10.0365 9.68729 9.86185 9.86189C9.68725 10.0365 9.62345 10.2934 9.69605 10.5294L11.0294 14.8627C11.1095 15.1231 11.3401 15.3086 11.6116 15.331C11.8832 15.3535 12.1411 15.2085 12.2629 14.9648L13.1635 13.1636L14.9647 12.263C15.2085 12.1411 15.3535 11.8832 15.331 11.6116C15.3085 11.3401 15.1231 11.1096 14.8627 11.0294L10.5293 9.69609Z",
|
||||
"fill": "currentColor"
|
||||
},
|
||||
"children": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"name": "WindowCursor"
|
||||
}
|
||||
@ -0,0 +1,20 @@
|
||||
// GENERATE BY script
|
||||
// DON NOT EDIT IT MANUALLY
|
||||
|
||||
import * as React from 'react'
|
||||
import data from './WindowCursor.json'
|
||||
import IconBase from '@/app/components/base/icons/IconBase'
|
||||
import type { IconData } from '@/app/components/base/icons/IconBase'
|
||||
|
||||
const Icon = (
|
||||
{
|
||||
ref,
|
||||
...props
|
||||
}: React.SVGProps<SVGSVGElement> & {
|
||||
ref?: React.RefObject<React.MutableRefObject<HTMLOrSVGElement>>;
|
||||
},
|
||||
) => <IconBase {...props} ref={ref} data={data as IconData} />
|
||||
|
||||
Icon.displayName = 'WindowCursor'
|
||||
|
||||
export default Icon
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue