From 0f668be41551d200ff9b7f7b1265d385b85f134d Mon Sep 17 00:00:00 2001 From: Novice Date: Thu, 12 Jun 2025 16:22:11 +0800 Subject: [PATCH] feat: add multi app mode's server support --- api/controllers/mcp/mcp.py | 6 +- api/core/helper/ssrf_proxy.py | 91 ------- api/core/mcp/client/sse_client.py | 331 +++++++++++++++++------ api/core/mcp/client/streamable_client.py | 2 +- api/core/mcp/server/handler.py | 50 ++-- api/core/mcp/session/base_session.py | 15 + api/core/mcp/utils.py | 112 ++++++++ 7 files changed, 407 insertions(+), 200 deletions(-) create mode 100644 api/core/mcp/utils.py diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index 0e3175811b..f990f08680 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -1,5 +1,6 @@ from flask_restful import Resource, reqparse from pydantic import ValidationError +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound from controllers.mcp import api @@ -59,8 +60,9 @@ class MCPAppApi(Resource): request = ClientRequest.model_validate(args) except ValidationError as e: raise ValueError(f"Invalid MCP request: {str(e)}") - mcp_server_handler = MCPServerReuqestHandler(app, request, user_input_form) - return helper.compact_generate_response(mcp_server_handler.handle()) + with Session(db.engine) as session: + mcp_server_handler = MCPServerReuqestHandler(app, request, user_input_form, session) + return helper.compact_generate_response(mcp_server_handler.handle()) api.add_resource(MCPAppApi, "/server//mcp") diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 9180922abb..11f245812e 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -108,94 +108,3 @@ def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): return make_request("HEAD", url, max_retries=max_retries, **kwargs) - - -def create_ssrf_proxy_mcp_http_client( - headers: dict[str, str] | None = None, - timeout: httpx.Timeout | None = None, -) -> httpx.Client: - """Create an HTTPX client with SSRF proxy configuration for MCP connections. - - Args: - headers: Optional headers to include in the client - timeout: Optional timeout configuration - - Returns: - Configured httpx.Client with proxy settings - """ - if dify_config.SSRF_PROXY_ALL_URL: - return httpx.Client( - verify=HTTP_REQUEST_NODE_SSL_VERIFY, - headers=headers or {}, - timeout=timeout, - follow_redirects=True, - proxy=dify_config.SSRF_PROXY_ALL_URL, - ) - elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL: - proxy_mounts = { - "http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY), - "https://": httpx.HTTPTransport( - proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY - ), - } - return httpx.Client( - verify=HTTP_REQUEST_NODE_SSL_VERIFY, - headers=headers or {}, - timeout=timeout, - follow_redirects=True, - mounts=proxy_mounts, - ) - else: - return httpx.Client( - verify=HTTP_REQUEST_NODE_SSL_VERIFY, - headers=headers or {}, - timeout=timeout, - follow_redirects=True, - ) - - -def ssrf_proxy_sse_connect(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): - """Connect to SSE endpoint with SSRF proxy protection. - - This function creates an SSE connection using the configured proxy settings - to prevent SSRF attacks when connecting to external endpoints. - - Args: - url: The SSE endpoint URL - max_retries: Maximum number of retry attempts - **kwargs: Additional arguments passed to the SSE connection - - Returns: - EventSource object for SSE streaming - """ - from httpx_sse import connect_sse - - # Extract client if provided, otherwise create one - client = kwargs.pop("client", None) - if client is None: - # Create client with SSRF proxy configuration - timeout = kwargs.pop( - "timeout", - httpx.Timeout( - timeout=dify_config.SSRF_DEFAULT_TIME_OUT, - connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT, - read=dify_config.SSRF_DEFAULT_READ_TIME_OUT, - write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT, - ), - ) - headers = kwargs.pop("headers", {}) - client = create_ssrf_proxy_mcp_http_client(headers=headers, timeout=timeout) - client_provided = False - else: - client_provided = True - - # Extract method if provided, default to GET - method = kwargs.pop("method", "GET") - - try: - return connect_sse(client, method, url, **kwargs) - except Exception as e: - # If we created the client, we need to clean it up on error - if not client_provided: - client.close() - raise diff --git a/api/core/mcp/client/sse_client.py b/api/core/mcp/client/sse_client.py index 86cb58dda8..35744d6a8f 100644 --- a/api/core/mcp/client/sse_client.py +++ b/api/core/mcp/client/sse_client.py @@ -9,128 +9,276 @@ from urllib.parse import urljoin, urlparse import httpx from sseclient import SSEClient -from core.helper.ssrf_proxy import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect from core.mcp import types from core.mcp.error import MCPAuthError, MCPConnectionError from core.mcp.types import SessionMessage +from core.mcp.utils import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect logger = logging.getLogger(__name__) DEFAULT_QUEUE_READ_TIMEOUT = 3 +# Type aliases for better readability +ReadQueue = queue.Queue[SessionMessage | Exception | None] +WriteQueue = queue.Queue[SessionMessage | Exception | None] +StatusQueue = queue.Queue[tuple[str, str | Exception]] + def remove_request_params(url: str) -> str: + """Remove request parameters from URL, keeping only the path.""" return urljoin(url, urlparse(url).path) +class SSETransport: + """SSE client transport implementation.""" + + def __init__( + self, + url: str, + headers: dict[str, Any] | None = None, + timeout: float = 5.0, + sse_read_timeout: float = 5 * 60, + ) -> None: + """Initialize the SSE transport. + + Args: + url: The SSE endpoint URL. + headers: Optional headers to include in requests. + timeout: HTTP timeout for regular operations. + sse_read_timeout: Timeout for SSE read operations. + """ + self.url = url + self.headers = headers or {} + self.timeout = timeout + self.sse_read_timeout = sse_read_timeout + self.endpoint_url: str | None = None + + def _validate_endpoint_url(self, endpoint_url: str) -> bool: + """Validate that the endpoint URL matches the connection origin. + + Args: + endpoint_url: The endpoint URL to validate. + + Returns: + True if valid, False otherwise. + """ + url_parsed = urlparse(self.url) + endpoint_parsed = urlparse(endpoint_url) + + return url_parsed.netloc == endpoint_parsed.netloc and url_parsed.scheme == endpoint_parsed.scheme + + def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue) -> None: + """Handle an 'endpoint' SSE event. + + Args: + sse_data: The SSE event data. + status_queue: Queue to put status updates. + """ + endpoint_url = urljoin(self.url, sse_data) + logger.info(f"Received endpoint URL: {endpoint_url}") + + if not self._validate_endpoint_url(endpoint_url): + error_msg = f"Endpoint origin does not match connection origin: {endpoint_url}" + logger.error(error_msg) + status_queue.put(("error", ValueError(error_msg))) + return + + status_queue.put(("ready", endpoint_url)) + + def _handle_message_event(self, sse_data: str, read_queue: ReadQueue) -> None: + """Handle a 'message' SSE event. + + Args: + sse_data: The SSE event data. + read_queue: Queue to put parsed messages. + """ + try: + message = types.JSONRPCMessage.model_validate_json(sse_data) + logger.debug(f"Received server message: {message}") + session_message = SessionMessage(message) + read_queue.put(session_message) + except Exception as exc: + logger.exception("Error parsing server message") + read_queue.put(exc) + + def _handle_sse_event(self, sse, read_queue: ReadQueue, status_queue: StatusQueue) -> None: + """Handle a single SSE event. + + Args: + sse: The SSE event object. + read_queue: Queue for message events. + status_queue: Queue for status events. + """ + match sse.event: + case "endpoint": + self._handle_endpoint_event(sse.data, status_queue) + case "message": + self._handle_message_event(sse.data, read_queue) + case _: + logger.warning(f"Unknown SSE event: {sse.event}") + + def sse_reader(self, event_source, read_queue: ReadQueue, status_queue: StatusQueue) -> None: + """Read and process SSE events. + + Args: + event_source: The SSE event source. + read_queue: Queue to put received messages. + status_queue: Queue to put status updates. + """ + try: + for sse in event_source.iter_sse(): + self._handle_sse_event(sse, read_queue, status_queue) + except httpx.ReadError as exc: + logger.debug(f"SSE reader shutting down normally: {exc}") + except Exception as exc: + read_queue.put(exc) + finally: + read_queue.put(None) + + def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage) -> None: + """Send a single message to the server. + + Args: + client: HTTP client to use. + endpoint_url: The endpoint URL to send to. + message: The message to send. + """ + response = client.post( + endpoint_url, + json=message.message.model_dump( + by_alias=True, + mode="json", + exclude_none=True, + ), + ) + response.raise_for_status() + logger.debug(f"Client message sent successfully: {response.status_code}") + + def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue) -> None: + """Handle writing messages to the server. + + Args: + client: HTTP client to use. + endpoint_url: The endpoint URL to send messages to. + write_queue: Queue to read messages from. + """ + try: + while True: + try: + message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT) + if message is None: + break + if isinstance(message, Exception): + write_queue.put(message) + continue + + self._send_message(client, endpoint_url, message) + + except queue.Empty: + continue + except httpx.ReadError as exc: + logger.debug(f"Post writer shutting down normally: {exc}") + except Exception as exc: + logger.exception("Error writing messages") + write_queue.put(exc) + finally: + write_queue.put(None) + + def _wait_for_endpoint(self, status_queue: StatusQueue) -> str: + """Wait for the endpoint URL from the status queue. + + Args: + status_queue: Queue to read status from. + + Returns: + The endpoint URL. + + Raises: + ValueError: If endpoint URL is not received or there's an error. + """ + try: + status, endpoint_url_or_error = status_queue.get(timeout=1) + except queue.Empty: + raise ValueError("failed to get endpoint URL") + + if status != "ready": + raise ValueError("failed to get endpoint URL") + + if status == "error" and isinstance(endpoint_url_or_error, Exception): + raise endpoint_url_or_error + + return cast(str, endpoint_url_or_error) + + def connect( + self, + executor: ThreadPoolExecutor, + client: httpx.Client, + event_source, + ) -> tuple[ReadQueue, WriteQueue]: + """Establish connection and start worker threads. + + Args: + executor: Thread pool executor. + client: HTTP client. + event_source: SSE event source. + + Returns: + Tuple of (read_queue, write_queue). + """ + read_queue: ReadQueue = queue.Queue() + write_queue: WriteQueue = queue.Queue() + status_queue: StatusQueue = queue.Queue() + + # Start SSE reader thread + executor.submit(self.sse_reader, event_source, read_queue, status_queue) + + # Wait for endpoint URL + endpoint_url = self._wait_for_endpoint(status_queue) + self.endpoint_url = endpoint_url + + # Start post writer thread + executor.submit(self.post_writer, client, endpoint_url, write_queue) + + return read_queue, write_queue + + @contextmanager def sse_client( url: str, headers: dict[str, Any] | None = None, timeout: float = 5.0, sse_read_timeout: float = 5 * 60, -) -> Generator[tuple[queue.Queue, queue.Queue], None, None]: +) -> Generator[tuple[ReadQueue, WriteQueue], None, None]: """ Client transport for SSE. `sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`. + + Args: + url: The SSE endpoint URL. + headers: Optional headers to include in requests. + timeout: HTTP timeout for regular operations. + sse_read_timeout: Timeout for SSE read operations. + + Yields: + Tuple of (read_queue, write_queue) for message communication. """ - if headers is None: - headers = {} + transport = SSETransport(url, headers, timeout, sse_read_timeout) - read_queue: queue.Queue[SessionMessage | Exception | None] = queue.Queue() - write_queue: queue.Queue[SessionMessage | Exception | None] = queue.Queue() - status_queue: queue.Queue[tuple[str, str | Exception]] = queue.Queue() + read_queue: ReadQueue | None = None + write_queue: WriteQueue | None = None with ThreadPoolExecutor() as executor: try: logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}") - with create_ssrf_proxy_mcp_http_client(headers=headers) as client: + + with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client: with ssrf_proxy_sse_connect( url, 2, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client ) as event_source: event_source.response.raise_for_status() - def sse_reader(status_queue: queue.Queue): - try: - for sse in event_source.iter_sse(): - match sse.event: - case "endpoint": - endpoint_url = urljoin(url, sse.data) - logger.info(f"Received endpoint URL: {endpoint_url}") - url_parsed = urlparse(url) - endpoint_parsed = urlparse(endpoint_url) - - if ( - url_parsed.netloc != endpoint_parsed.netloc - or url_parsed.scheme != endpoint_parsed.scheme - ): - error_msg = ( - f"Endpoint origin does not match connection origin: {endpoint_url}" - ) - logger.error(error_msg) - status_queue.put(("error", ValueError(error_msg))) - status_queue.put(("ready", endpoint_url)) - case "message": - try: - message = types.JSONRPCMessage.model_validate_json(sse.data) - logger.debug(f"Received server message: {message}") - except Exception as exc: - logger.exception("Error parsing server message") - read_queue.put(exc) - continue - session_message = SessionMessage(message) - read_queue.put(session_message) - case _: - logger.warning(f"Unknown SSE event: {sse.event}") - except httpx.ReadError as exc: - logger.debug(f"SSE reader shutting down normally: {exc}") - except Exception as exc: - read_queue.put(exc) - finally: - read_queue.put(None) - - def post_writer(endpoint_url: str): - try: - while True: - try: - message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT) - if message is None: - break - if isinstance(message, Exception): - write_queue.put(message) - continue - response = client.post( - endpoint_url, - json=message.message.model_dump( - by_alias=True, - mode="json", - exclude_none=True, - ), - ) - response.raise_for_status() - logger.debug(f"Client message sent successfully: {response.status_code}") - except queue.Empty: - continue - except httpx.ReadError as exc: - logger.debug(f"SSE reader shutting down normally: {exc}") - except Exception as exc: - logger.exception("Error writing messages") - write_queue.put(exc) - finally: - write_queue.put(None) - - executor.submit(sse_reader, status_queue) - try: - status, endpoint_url_or_error = status_queue.get(timeout=1) - except queue.Empty: - raise ValueError("failed to get endpoint URL") - if status != "ready": - raise ValueError("failed to get endpoint URL") - if status == "error" and isinstance(endpoint_url_or_error, Exception): - raise endpoint_url_or_error - endpoint_url = cast(str, endpoint_url_or_error) - executor.submit(post_writer, endpoint_url) + read_queue, write_queue = transport.connect(executor, client, event_source) yield read_queue, write_queue @@ -142,8 +290,11 @@ def sse_client( logger.exception("Error connecting to SSE endpoint") raise exc finally: - read_queue.put(None) - write_queue.put(None) + # Clean up queues + if read_queue: + read_queue.put(None) + if write_queue: + write_queue.put(None) def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage) -> None: diff --git a/api/core/mcp/client/streamable_client.py b/api/core/mcp/client/streamable_client.py index 649eb32abb..bdbba6922f 100644 --- a/api/core/mcp/client/streamable_client.py +++ b/api/core/mcp/client/streamable_client.py @@ -18,7 +18,6 @@ from typing import Any, cast import httpx from httpx_sse import EventSource, ServerSentEvent -from core.helper.ssrf_proxy import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect from core.mcp.types import ( ClientMessageMetadata, ErrorData, @@ -30,6 +29,7 @@ from core.mcp.types import ( RequestId, SessionMessage, ) +from core.mcp.utils import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect logger = logging.getLogger(__name__) diff --git a/api/core/mcp/server/handler.py b/api/core/mcp/server/handler.py index c382c08e28..8ee29055c1 100644 --- a/api/core/mcp/server/handler.py +++ b/api/core/mcp/server/handler.py @@ -2,6 +2,8 @@ import json from collections.abc import Mapping from typing import Any, cast +from sqlalchemy.orm import Session + from configs import dify_config from controllers.web.passport import generate_session_id from core.app.app_config.entities import VariableEntity, VariableEntityType @@ -9,8 +11,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.mcp import types from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND from core.model_runtime.utils.encoders import jsonable_encoder -from extensions.ext_database import db -from models.model import App, AppMCPServer, EndUser +from models.model import App, AppMCPServer, AppMode, EndUser from services.app_generate_service import AppGenerateService """ @@ -19,12 +20,13 @@ Apply to MCP HTTP streamable server with stateless http class MCPServerReuqestHandler: - def __init__(self, app: App, request: types.ClientRequest, user_input_form: list[VariableEntity]): + def __init__(self, app: App, request: types.ClientRequest, user_input_form: list[VariableEntity], session: Session): self.app = app self.request = request if not self.app.mcp_server: raise ValueError("MCP server not found") self.mcp_server: AppMCPServer = self.app.mcp_server + self._session = session self.end_user = self.retrieve_end_user() self.user_input_form = user_input_form @@ -35,19 +37,19 @@ class MCPServerReuqestHandler: @property def parameter_schema(self): parameters, required = self._convert_input_form_to_parameters(self.user_input_form) + if self.app.mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}: + return { + "type": "object", + "properties": parameters, + "required": required, + } return { "type": "object", "properties": { "query": {"type": "string", "description": "User Input/Question content"}, - "inputs": { - "type": "object", - "description": "Allows the entry of various variable values defined by the App. The `inputs` parameter contains multiple key/value pairs, with each key corresponding to a specific variable and each value being the specific value for that variable. If the variable is of file type, specify an object that has the keys described in `files`.", # noqa: E501 - "default": {}, - "properties": parameters, - "required": required, - }, + **parameters, }, - "required": ["query", "inputs"], + "required": ["query", *required], } @property @@ -110,9 +112,8 @@ class MCPServerReuqestHandler: session_id=generate_session_id(), external_user_id=self.mcp_server.id, ) - db.session.add(end_user) - db.session.commit() - + self._session.add(end_user) + self._session.commit() return types.InitializeResult( protocolVersion=types.LATEST_PROTOCOL_VERSION, capabilities=self.capabilities, @@ -140,14 +141,31 @@ class MCPServerReuqestHandler: args = request.params.arguments if not args: raise ValueError("No arguments provided") + if self.app.mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}: + args = {"inputs": args} + else: + args = {"query": args["query"], "inputs": {k: v for k, v in args.items() if k != "query"}} response = AppGenerateService.generate(self.app, self.end_user, args, InvokeFrom.MCP_SERVER, streaming=False) if isinstance(response, Mapping): - return types.CallToolResult(content=[types.TextContent(text=response["answer"], type="text")]) + answer = "" + if self.app.mode in { + AppMode.ADVANCED_CHAT.value, + AppMode.COMPLETION.value, + AppMode.CHAT.value, + AppMode.AGENT_CHAT.value, + }: + answer = response["answer"] + elif self.app.mode in {AppMode.WORKFLOW.value}: + answer = json.dumps(response["data"]["outputs"], ensure_ascii=False) + else: + raise ValueError("Invalid app mode") + # Not support image yet + return types.CallToolResult(content=[types.TextContent(text=answer, type="text")]) return None def retrieve_end_user(self): return ( - db.session.query(EndUser) + self._session.query(EndUser) .filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp") .first() ) diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index e78ce4f34d..7aa780d507 100644 --- a/api/core/mcp/session/base_session.py +++ b/api/core/mcp/session/base_session.py @@ -177,6 +177,15 @@ class BaseSession( self._receiver_future = self._executor.submit(self._receive_loop) return self + def check_receiver_status(self) -> None: + if self._receiver_future.done(): + try: + # 如果Future已完成,获取结果(如果有异常会在这里抛出) + self._receiver_future.result() + except Exception as e: + # 重新抛出线程中的异常 + raise e + def __exit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None ) -> None: @@ -199,6 +208,7 @@ class BaseSession( Do not use this method to emit notifications! Use send_notification() instead. """ + self.check_receiver_status() request_id = self._request_id self._request_id = request_id + 1 @@ -224,6 +234,8 @@ class BaseSession( response_or_error = response_queue.get(timeout=timeout) break except queue.Empty: + # 在等待响应的过程中也检查接收线程状态 + self.check_receiver_status() continue if response_or_error is None: @@ -257,6 +269,8 @@ class BaseSession( Emits a notification, which is a one-way message that does not expect a response. """ + self.check_receiver_status() + # Some transport implementations may need to set the related_request_id # to attribute to the notifications to the request that triggered them. jsonrpc_notification = JSONRPCNotification( @@ -353,6 +367,7 @@ class BaseSession( continue except Exception as e: logging.exception("Error in message processing loop") + raise def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None: """ diff --git a/api/core/mcp/utils.py b/api/core/mcp/utils.py new file mode 100644 index 0000000000..140665edf8 --- /dev/null +++ b/api/core/mcp/utils.py @@ -0,0 +1,112 @@ +import httpx + +from configs import dify_config + +SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES + +HTTP_REQUEST_NODE_SSL_VERIFY = True # Default value for HTTP_REQUEST_NODE_SSL_VERIFY is True +try: + HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY + http_request_node_ssl_verify_lower = str(HTTP_REQUEST_NODE_SSL_VERIFY).lower() + if http_request_node_ssl_verify_lower == "true": + HTTP_REQUEST_NODE_SSL_VERIFY = True + elif http_request_node_ssl_verify_lower == "false": + HTTP_REQUEST_NODE_SSL_VERIFY = False + else: + raise ValueError("Invalid value. HTTP_REQUEST_NODE_SSL_VERIFY should be 'True' or 'False'") +except NameError: + HTTP_REQUEST_NODE_SSL_VERIFY = True + +BACKOFF_FACTOR = 0.5 +STATUS_FORCELIST = [429, 500, 502, 503, 504] + + +def create_ssrf_proxy_mcp_http_client( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, +) -> httpx.Client: + """Create an HTTPX client with SSRF proxy configuration for MCP connections. + + Args: + headers: Optional headers to include in the client + timeout: Optional timeout configuration + + Returns: + Configured httpx.Client with proxy settings + """ + if dify_config.SSRF_PROXY_ALL_URL: + return httpx.Client( + verify=HTTP_REQUEST_NODE_SSL_VERIFY, + headers=headers or {}, + timeout=timeout, + follow_redirects=True, + proxy=dify_config.SSRF_PROXY_ALL_URL, + ) + elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL: + proxy_mounts = { + "http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY), + "https://": httpx.HTTPTransport( + proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY + ), + } + return httpx.Client( + verify=HTTP_REQUEST_NODE_SSL_VERIFY, + headers=headers or {}, + timeout=timeout, + follow_redirects=True, + mounts=proxy_mounts, + ) + else: + return httpx.Client( + verify=HTTP_REQUEST_NODE_SSL_VERIFY, + headers=headers or {}, + timeout=timeout, + follow_redirects=True, + ) + + +def ssrf_proxy_sse_connect(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): + """Connect to SSE endpoint with SSRF proxy protection. + + This function creates an SSE connection using the configured proxy settings + to prevent SSRF attacks when connecting to external endpoints. + + Args: + url: The SSE endpoint URL + max_retries: Maximum number of retry attempts + **kwargs: Additional arguments passed to the SSE connection + + Returns: + EventSource object for SSE streaming + """ + from httpx_sse import connect_sse + + # Extract client if provided, otherwise create one + client = kwargs.pop("client", None) + if client is None: + # Create client with SSRF proxy configuration + timeout = kwargs.pop( + "timeout", + httpx.Timeout( + timeout=dify_config.SSRF_DEFAULT_TIME_OUT, + connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT, + read=dify_config.SSRF_DEFAULT_READ_TIME_OUT, + write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT, + ), + ) + headers = kwargs.pop("headers", {}) + client = create_ssrf_proxy_mcp_http_client(headers=headers, timeout=timeout) + client_provided = False + else: + client_provided = True + + # Extract method if provided, default to GET + method = kwargs.pop("method", "GET") + + try: + return connect_sse(client, method, url, **kwargs) + except Exception: + # If we created the client, we need to clean it up on error + if not client_provided: + client.close() + raise