feat: add multi app mode's server support

pull/22036/head
Novice 11 months ago
parent 642693c79b
commit 0f668be415

@ -1,5 +1,6 @@
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
from pydantic import ValidationError from pydantic import ValidationError
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.mcp import api from controllers.mcp import api
@ -59,8 +60,9 @@ class MCPAppApi(Resource):
request = ClientRequest.model_validate(args) request = ClientRequest.model_validate(args)
except ValidationError as e: except ValidationError as e:
raise ValueError(f"Invalid MCP request: {str(e)}") raise ValueError(f"Invalid MCP request: {str(e)}")
mcp_server_handler = MCPServerReuqestHandler(app, request, user_input_form) with Session(db.engine) as session:
return helper.compact_generate_response(mcp_server_handler.handle()) mcp_server_handler = MCPServerReuqestHandler(app, request, user_input_form, session)
return helper.compact_generate_response(mcp_server_handler.handle())
api.add_resource(MCPAppApi, "/server/<string:server_code>/mcp") api.add_resource(MCPAppApi, "/server/<string:server_code>/mcp")

@ -108,94 +108,3 @@ def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
return make_request("HEAD", url, max_retries=max_retries, **kwargs) return make_request("HEAD", url, max_retries=max_retries, **kwargs)
def create_ssrf_proxy_mcp_http_client(
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
) -> httpx.Client:
"""Create an HTTPX client with SSRF proxy configuration for MCP connections.
Args:
headers: Optional headers to include in the client
timeout: Optional timeout configuration
Returns:
Configured httpx.Client with proxy settings
"""
if dify_config.SSRF_PROXY_ALL_URL:
return httpx.Client(
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
headers=headers or {},
timeout=timeout,
follow_redirects=True,
proxy=dify_config.SSRF_PROXY_ALL_URL,
)
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
proxy_mounts = {
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY),
"https://": httpx.HTTPTransport(
proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY
),
}
return httpx.Client(
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
headers=headers or {},
timeout=timeout,
follow_redirects=True,
mounts=proxy_mounts,
)
else:
return httpx.Client(
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
headers=headers or {},
timeout=timeout,
follow_redirects=True,
)
def ssrf_proxy_sse_connect(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
"""Connect to SSE endpoint with SSRF proxy protection.
This function creates an SSE connection using the configured proxy settings
to prevent SSRF attacks when connecting to external endpoints.
Args:
url: The SSE endpoint URL
max_retries: Maximum number of retry attempts
**kwargs: Additional arguments passed to the SSE connection
Returns:
EventSource object for SSE streaming
"""
from httpx_sse import connect_sse
# Extract client if provided, otherwise create one
client = kwargs.pop("client", None)
if client is None:
# Create client with SSRF proxy configuration
timeout = kwargs.pop(
"timeout",
httpx.Timeout(
timeout=dify_config.SSRF_DEFAULT_TIME_OUT,
connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT,
read=dify_config.SSRF_DEFAULT_READ_TIME_OUT,
write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
),
)
headers = kwargs.pop("headers", {})
client = create_ssrf_proxy_mcp_http_client(headers=headers, timeout=timeout)
client_provided = False
else:
client_provided = True
# Extract method if provided, default to GET
method = kwargs.pop("method", "GET")
try:
return connect_sse(client, method, url, **kwargs)
except Exception as e:
# If we created the client, we need to clean it up on error
if not client_provided:
client.close()
raise

@ -9,128 +9,276 @@ from urllib.parse import urljoin, urlparse
import httpx import httpx
from sseclient import SSEClient from sseclient import SSEClient
from core.helper.ssrf_proxy import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
from core.mcp import types from core.mcp import types
from core.mcp.error import MCPAuthError, MCPConnectionError from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.types import SessionMessage from core.mcp.types import SessionMessage
from core.mcp.utils import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_QUEUE_READ_TIMEOUT = 3 DEFAULT_QUEUE_READ_TIMEOUT = 3
# Type aliases for better readability
ReadQueue = queue.Queue[SessionMessage | Exception | None]
WriteQueue = queue.Queue[SessionMessage | Exception | None]
StatusQueue = queue.Queue[tuple[str, str | Exception]]
def remove_request_params(url: str) -> str: def remove_request_params(url: str) -> str:
"""Remove request parameters from URL, keeping only the path."""
return urljoin(url, urlparse(url).path) return urljoin(url, urlparse(url).path)
class SSETransport:
"""SSE client transport implementation."""
def __init__(
self,
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5.0,
sse_read_timeout: float = 5 * 60,
) -> None:
"""Initialize the SSE transport.
Args:
url: The SSE endpoint URL.
headers: Optional headers to include in requests.
timeout: HTTP timeout for regular operations.
sse_read_timeout: Timeout for SSE read operations.
"""
self.url = url
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
self.endpoint_url: str | None = None
def _validate_endpoint_url(self, endpoint_url: str) -> bool:
"""Validate that the endpoint URL matches the connection origin.
Args:
endpoint_url: The endpoint URL to validate.
Returns:
True if valid, False otherwise.
"""
url_parsed = urlparse(self.url)
endpoint_parsed = urlparse(endpoint_url)
return url_parsed.netloc == endpoint_parsed.netloc and url_parsed.scheme == endpoint_parsed.scheme
def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue) -> None:
"""Handle an 'endpoint' SSE event.
Args:
sse_data: The SSE event data.
status_queue: Queue to put status updates.
"""
endpoint_url = urljoin(self.url, sse_data)
logger.info(f"Received endpoint URL: {endpoint_url}")
if not self._validate_endpoint_url(endpoint_url):
error_msg = f"Endpoint origin does not match connection origin: {endpoint_url}"
logger.error(error_msg)
status_queue.put(("error", ValueError(error_msg)))
return
status_queue.put(("ready", endpoint_url))
def _handle_message_event(self, sse_data: str, read_queue: ReadQueue) -> None:
"""Handle a 'message' SSE event.
Args:
sse_data: The SSE event data.
read_queue: Queue to put parsed messages.
"""
try:
message = types.JSONRPCMessage.model_validate_json(sse_data)
logger.debug(f"Received server message: {message}")
session_message = SessionMessage(message)
read_queue.put(session_message)
except Exception as exc:
logger.exception("Error parsing server message")
read_queue.put(exc)
def _handle_sse_event(self, sse, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
"""Handle a single SSE event.
Args:
sse: The SSE event object.
read_queue: Queue for message events.
status_queue: Queue for status events.
"""
match sse.event:
case "endpoint":
self._handle_endpoint_event(sse.data, status_queue)
case "message":
self._handle_message_event(sse.data, read_queue)
case _:
logger.warning(f"Unknown SSE event: {sse.event}")
def sse_reader(self, event_source, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
"""Read and process SSE events.
Args:
event_source: The SSE event source.
read_queue: Queue to put received messages.
status_queue: Queue to put status updates.
"""
try:
for sse in event_source.iter_sse():
self._handle_sse_event(sse, read_queue, status_queue)
except httpx.ReadError as exc:
logger.debug(f"SSE reader shutting down normally: {exc}")
except Exception as exc:
read_queue.put(exc)
finally:
read_queue.put(None)
def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage) -> None:
"""Send a single message to the server.
Args:
client: HTTP client to use.
endpoint_url: The endpoint URL to send to.
message: The message to send.
"""
response = client.post(
endpoint_url,
json=message.message.model_dump(
by_alias=True,
mode="json",
exclude_none=True,
),
)
response.raise_for_status()
logger.debug(f"Client message sent successfully: {response.status_code}")
def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue) -> None:
"""Handle writing messages to the server.
Args:
client: HTTP client to use.
endpoint_url: The endpoint URL to send messages to.
write_queue: Queue to read messages from.
"""
try:
while True:
try:
message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
if message is None:
break
if isinstance(message, Exception):
write_queue.put(message)
continue
self._send_message(client, endpoint_url, message)
except queue.Empty:
continue
except httpx.ReadError as exc:
logger.debug(f"Post writer shutting down normally: {exc}")
except Exception as exc:
logger.exception("Error writing messages")
write_queue.put(exc)
finally:
write_queue.put(None)
def _wait_for_endpoint(self, status_queue: StatusQueue) -> str:
"""Wait for the endpoint URL from the status queue.
Args:
status_queue: Queue to read status from.
Returns:
The endpoint URL.
Raises:
ValueError: If endpoint URL is not received or there's an error.
"""
try:
status, endpoint_url_or_error = status_queue.get(timeout=1)
except queue.Empty:
raise ValueError("failed to get endpoint URL")
if status != "ready":
raise ValueError("failed to get endpoint URL")
if status == "error" and isinstance(endpoint_url_or_error, Exception):
raise endpoint_url_or_error
return cast(str, endpoint_url_or_error)
def connect(
self,
executor: ThreadPoolExecutor,
client: httpx.Client,
event_source,
) -> tuple[ReadQueue, WriteQueue]:
"""Establish connection and start worker threads.
Args:
executor: Thread pool executor.
client: HTTP client.
event_source: SSE event source.
Returns:
Tuple of (read_queue, write_queue).
"""
read_queue: ReadQueue = queue.Queue()
write_queue: WriteQueue = queue.Queue()
status_queue: StatusQueue = queue.Queue()
# Start SSE reader thread
executor.submit(self.sse_reader, event_source, read_queue, status_queue)
# Wait for endpoint URL
endpoint_url = self._wait_for_endpoint(status_queue)
self.endpoint_url = endpoint_url
# Start post writer thread
executor.submit(self.post_writer, client, endpoint_url, write_queue)
return read_queue, write_queue
@contextmanager @contextmanager
def sse_client( def sse_client(
url: str, url: str,
headers: dict[str, Any] | None = None, headers: dict[str, Any] | None = None,
timeout: float = 5.0, timeout: float = 5.0,
sse_read_timeout: float = 5 * 60, sse_read_timeout: float = 5 * 60,
) -> Generator[tuple[queue.Queue, queue.Queue], None, None]: ) -> Generator[tuple[ReadQueue, WriteQueue], None, None]:
""" """
Client transport for SSE. Client transport for SSE.
`sse_read_timeout` determines how long (in seconds) the client will wait for a new `sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`. event before disconnecting. All other HTTP operations are controlled by `timeout`.
Args:
url: The SSE endpoint URL.
headers: Optional headers to include in requests.
timeout: HTTP timeout for regular operations.
sse_read_timeout: Timeout for SSE read operations.
Yields:
Tuple of (read_queue, write_queue) for message communication.
""" """
if headers is None: transport = SSETransport(url, headers, timeout, sse_read_timeout)
headers = {}
read_queue: queue.Queue[SessionMessage | Exception | None] = queue.Queue() read_queue: ReadQueue | None = None
write_queue: queue.Queue[SessionMessage | Exception | None] = queue.Queue() write_queue: WriteQueue | None = None
status_queue: queue.Queue[tuple[str, str | Exception]] = queue.Queue()
with ThreadPoolExecutor() as executor: with ThreadPoolExecutor() as executor:
try: try:
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}") logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
with create_ssrf_proxy_mcp_http_client(headers=headers) as client:
with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
with ssrf_proxy_sse_connect( with ssrf_proxy_sse_connect(
url, 2, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client url, 2, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
) as event_source: ) as event_source:
event_source.response.raise_for_status() event_source.response.raise_for_status()
def sse_reader(status_queue: queue.Queue): read_queue, write_queue = transport.connect(executor, client, event_source)
try:
for sse in event_source.iter_sse():
match sse.event:
case "endpoint":
endpoint_url = urljoin(url, sse.data)
logger.info(f"Received endpoint URL: {endpoint_url}")
url_parsed = urlparse(url)
endpoint_parsed = urlparse(endpoint_url)
if (
url_parsed.netloc != endpoint_parsed.netloc
or url_parsed.scheme != endpoint_parsed.scheme
):
error_msg = (
f"Endpoint origin does not match connection origin: {endpoint_url}"
)
logger.error(error_msg)
status_queue.put(("error", ValueError(error_msg)))
status_queue.put(("ready", endpoint_url))
case "message":
try:
message = types.JSONRPCMessage.model_validate_json(sse.data)
logger.debug(f"Received server message: {message}")
except Exception as exc:
logger.exception("Error parsing server message")
read_queue.put(exc)
continue
session_message = SessionMessage(message)
read_queue.put(session_message)
case _:
logger.warning(f"Unknown SSE event: {sse.event}")
except httpx.ReadError as exc:
logger.debug(f"SSE reader shutting down normally: {exc}")
except Exception as exc:
read_queue.put(exc)
finally:
read_queue.put(None)
def post_writer(endpoint_url: str):
try:
while True:
try:
message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
if message is None:
break
if isinstance(message, Exception):
write_queue.put(message)
continue
response = client.post(
endpoint_url,
json=message.message.model_dump(
by_alias=True,
mode="json",
exclude_none=True,
),
)
response.raise_for_status()
logger.debug(f"Client message sent successfully: {response.status_code}")
except queue.Empty:
continue
except httpx.ReadError as exc:
logger.debug(f"SSE reader shutting down normally: {exc}")
except Exception as exc:
logger.exception("Error writing messages")
write_queue.put(exc)
finally:
write_queue.put(None)
executor.submit(sse_reader, status_queue)
try:
status, endpoint_url_or_error = status_queue.get(timeout=1)
except queue.Empty:
raise ValueError("failed to get endpoint URL")
if status != "ready":
raise ValueError("failed to get endpoint URL")
if status == "error" and isinstance(endpoint_url_or_error, Exception):
raise endpoint_url_or_error
endpoint_url = cast(str, endpoint_url_or_error)
executor.submit(post_writer, endpoint_url)
yield read_queue, write_queue yield read_queue, write_queue
@ -142,8 +290,11 @@ def sse_client(
logger.exception("Error connecting to SSE endpoint") logger.exception("Error connecting to SSE endpoint")
raise exc raise exc
finally: finally:
read_queue.put(None) # Clean up queues
write_queue.put(None) if read_queue:
read_queue.put(None)
if write_queue:
write_queue.put(None)
def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage) -> None: def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage) -> None:

@ -18,7 +18,6 @@ from typing import Any, cast
import httpx import httpx
from httpx_sse import EventSource, ServerSentEvent from httpx_sse import EventSource, ServerSentEvent
from core.helper.ssrf_proxy import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
from core.mcp.types import ( from core.mcp.types import (
ClientMessageMetadata, ClientMessageMetadata,
ErrorData, ErrorData,
@ -30,6 +29,7 @@ from core.mcp.types import (
RequestId, RequestId,
SessionMessage, SessionMessage,
) )
from core.mcp.utils import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

@ -2,6 +2,8 @@ import json
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, cast from typing import Any, cast
from sqlalchemy.orm import Session
from configs import dify_config from configs import dify_config
from controllers.web.passport import generate_session_id from controllers.web.passport import generate_session_id
from core.app.app_config.entities import VariableEntity, VariableEntityType from core.app.app_config.entities import VariableEntity, VariableEntityType
@ -9,8 +11,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.mcp import types from core.mcp import types
from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db from models.model import App, AppMCPServer, AppMode, EndUser
from models.model import App, AppMCPServer, EndUser
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
""" """
@ -19,12 +20,13 @@ Apply to MCP HTTP streamable server with stateless http
class MCPServerReuqestHandler: class MCPServerReuqestHandler:
def __init__(self, app: App, request: types.ClientRequest, user_input_form: list[VariableEntity]): def __init__(self, app: App, request: types.ClientRequest, user_input_form: list[VariableEntity], session: Session):
self.app = app self.app = app
self.request = request self.request = request
if not self.app.mcp_server: if not self.app.mcp_server:
raise ValueError("MCP server not found") raise ValueError("MCP server not found")
self.mcp_server: AppMCPServer = self.app.mcp_server self.mcp_server: AppMCPServer = self.app.mcp_server
self._session = session
self.end_user = self.retrieve_end_user() self.end_user = self.retrieve_end_user()
self.user_input_form = user_input_form self.user_input_form = user_input_form
@ -35,19 +37,19 @@ class MCPServerReuqestHandler:
@property @property
def parameter_schema(self): def parameter_schema(self):
parameters, required = self._convert_input_form_to_parameters(self.user_input_form) parameters, required = self._convert_input_form_to_parameters(self.user_input_form)
if self.app.mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
return {
"type": "object",
"properties": parameters,
"required": required,
}
return { return {
"type": "object", "type": "object",
"properties": { "properties": {
"query": {"type": "string", "description": "User Input/Question content"}, "query": {"type": "string", "description": "User Input/Question content"},
"inputs": { **parameters,
"type": "object",
"description": "Allows the entry of various variable values defined by the App. The `inputs` parameter contains multiple key/value pairs, with each key corresponding to a specific variable and each value being the specific value for that variable. If the variable is of file type, specify an object that has the keys described in `files`.", # noqa: E501
"default": {},
"properties": parameters,
"required": required,
},
}, },
"required": ["query", "inputs"], "required": ["query", *required],
} }
@property @property
@ -110,9 +112,8 @@ class MCPServerReuqestHandler:
session_id=generate_session_id(), session_id=generate_session_id(),
external_user_id=self.mcp_server.id, external_user_id=self.mcp_server.id,
) )
db.session.add(end_user) self._session.add(end_user)
db.session.commit() self._session.commit()
return types.InitializeResult( return types.InitializeResult(
protocolVersion=types.LATEST_PROTOCOL_VERSION, protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=self.capabilities, capabilities=self.capabilities,
@ -140,14 +141,31 @@ class MCPServerReuqestHandler:
args = request.params.arguments args = request.params.arguments
if not args: if not args:
raise ValueError("No arguments provided") raise ValueError("No arguments provided")
if self.app.mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
args = {"inputs": args}
else:
args = {"query": args["query"], "inputs": {k: v for k, v in args.items() if k != "query"}}
response = AppGenerateService.generate(self.app, self.end_user, args, InvokeFrom.MCP_SERVER, streaming=False) response = AppGenerateService.generate(self.app, self.end_user, args, InvokeFrom.MCP_SERVER, streaming=False)
if isinstance(response, Mapping): if isinstance(response, Mapping):
return types.CallToolResult(content=[types.TextContent(text=response["answer"], type="text")]) answer = ""
if self.app.mode in {
AppMode.ADVANCED_CHAT.value,
AppMode.COMPLETION.value,
AppMode.CHAT.value,
AppMode.AGENT_CHAT.value,
}:
answer = response["answer"]
elif self.app.mode in {AppMode.WORKFLOW.value}:
answer = json.dumps(response["data"]["outputs"], ensure_ascii=False)
else:
raise ValueError("Invalid app mode")
# Not support image yet
return types.CallToolResult(content=[types.TextContent(text=answer, type="text")])
return None return None
def retrieve_end_user(self): def retrieve_end_user(self):
return ( return (
db.session.query(EndUser) self._session.query(EndUser)
.filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp") .filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
.first() .first()
) )

@ -177,6 +177,15 @@ class BaseSession(
self._receiver_future = self._executor.submit(self._receive_loop) self._receiver_future = self._executor.submit(self._receive_loop)
return self return self
def check_receiver_status(self) -> None:
if self._receiver_future.done():
try:
# 如果Future已完成获取结果如果有异常会在这里抛出
self._receiver_future.result()
except Exception as e:
# 重新抛出线程中的异常
raise e
def __exit__( def __exit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
) -> None: ) -> None:
@ -199,6 +208,7 @@ class BaseSession(
Do not use this method to emit notifications! Use send_notification() Do not use this method to emit notifications! Use send_notification()
instead. instead.
""" """
self.check_receiver_status()
request_id = self._request_id request_id = self._request_id
self._request_id = request_id + 1 self._request_id = request_id + 1
@ -224,6 +234,8 @@ class BaseSession(
response_or_error = response_queue.get(timeout=timeout) response_or_error = response_queue.get(timeout=timeout)
break break
except queue.Empty: except queue.Empty:
# 在等待响应的过程中也检查接收线程状态
self.check_receiver_status()
continue continue
if response_or_error is None: if response_or_error is None:
@ -257,6 +269,8 @@ class BaseSession(
Emits a notification, which is a one-way message that does not expect Emits a notification, which is a one-way message that does not expect
a response. a response.
""" """
self.check_receiver_status()
# Some transport implementations may need to set the related_request_id # Some transport implementations may need to set the related_request_id
# to attribute to the notifications to the request that triggered them. # to attribute to the notifications to the request that triggered them.
jsonrpc_notification = JSONRPCNotification( jsonrpc_notification = JSONRPCNotification(
@ -353,6 +367,7 @@ class BaseSession(
continue continue
except Exception as e: except Exception as e:
logging.exception("Error in message processing loop") logging.exception("Error in message processing loop")
raise
def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None: def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
""" """

@ -0,0 +1,112 @@
import httpx
from configs import dify_config
SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
HTTP_REQUEST_NODE_SSL_VERIFY = True # Default value for HTTP_REQUEST_NODE_SSL_VERIFY is True
try:
HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
http_request_node_ssl_verify_lower = str(HTTP_REQUEST_NODE_SSL_VERIFY).lower()
if http_request_node_ssl_verify_lower == "true":
HTTP_REQUEST_NODE_SSL_VERIFY = True
elif http_request_node_ssl_verify_lower == "false":
HTTP_REQUEST_NODE_SSL_VERIFY = False
else:
raise ValueError("Invalid value. HTTP_REQUEST_NODE_SSL_VERIFY should be 'True' or 'False'")
except NameError:
HTTP_REQUEST_NODE_SSL_VERIFY = True
BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504]
def create_ssrf_proxy_mcp_http_client(
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
) -> httpx.Client:
"""Create an HTTPX client with SSRF proxy configuration for MCP connections.
Args:
headers: Optional headers to include in the client
timeout: Optional timeout configuration
Returns:
Configured httpx.Client with proxy settings
"""
if dify_config.SSRF_PROXY_ALL_URL:
return httpx.Client(
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
headers=headers or {},
timeout=timeout,
follow_redirects=True,
proxy=dify_config.SSRF_PROXY_ALL_URL,
)
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
proxy_mounts = {
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY),
"https://": httpx.HTTPTransport(
proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY
),
}
return httpx.Client(
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
headers=headers or {},
timeout=timeout,
follow_redirects=True,
mounts=proxy_mounts,
)
else:
return httpx.Client(
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
headers=headers or {},
timeout=timeout,
follow_redirects=True,
)
def ssrf_proxy_sse_connect(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
"""Connect to SSE endpoint with SSRF proxy protection.
This function creates an SSE connection using the configured proxy settings
to prevent SSRF attacks when connecting to external endpoints.
Args:
url: The SSE endpoint URL
max_retries: Maximum number of retry attempts
**kwargs: Additional arguments passed to the SSE connection
Returns:
EventSource object for SSE streaming
"""
from httpx_sse import connect_sse
# Extract client if provided, otherwise create one
client = kwargs.pop("client", None)
if client is None:
# Create client with SSRF proxy configuration
timeout = kwargs.pop(
"timeout",
httpx.Timeout(
timeout=dify_config.SSRF_DEFAULT_TIME_OUT,
connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT,
read=dify_config.SSRF_DEFAULT_READ_TIME_OUT,
write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
),
)
headers = kwargs.pop("headers", {})
client = create_ssrf_proxy_mcp_http_client(headers=headers, timeout=timeout)
client_provided = False
else:
client_provided = True
# Extract method if provided, default to GET
method = kwargs.pop("method", "GET")
try:
return connect_sse(client, method, url, **kwargs)
except Exception:
# If we created the client, we need to clean it up on error
if not client_provided:
client.close()
raise
Loading…
Cancel
Save