|
|
|
@ -17,8 +17,9 @@ from datetime import timedelta
|
|
|
|
from typing import Any, cast
|
|
|
|
from typing import Any, cast
|
|
|
|
|
|
|
|
|
|
|
|
import httpx
|
|
|
|
import httpx
|
|
|
|
from httpx_sse import EventSource, ServerSentEvent, connect_sse
|
|
|
|
from httpx_sse import EventSource, ServerSentEvent
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from core.helper.ssrf_proxy import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
|
|
|
|
from core.mcp.types import (
|
|
|
|
from core.mcp.types import (
|
|
|
|
ClientMessageMetadata,
|
|
|
|
ClientMessageMetadata,
|
|
|
|
ErrorData,
|
|
|
|
ErrorData,
|
|
|
|
@ -30,7 +31,6 @@ from core.mcp.types import (
|
|
|
|
RequestId,
|
|
|
|
RequestId,
|
|
|
|
SessionMessage,
|
|
|
|
SessionMessage,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
from core.mcp.utils import create_mcp_http_client
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
@ -50,6 +50,8 @@ ACCEPT = "Accept"
|
|
|
|
JSON = "application/json"
|
|
|
|
JSON = "application/json"
|
|
|
|
SSE = "text/event-stream"
|
|
|
|
SSE = "text/event-stream"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
DEFAULT_QUEUE_READ_TIMEOUT = 3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StreamableHTTPError(Exception):
|
|
|
|
class StreamableHTTPError(Exception):
|
|
|
|
"""Base exception for StreamableHTTP transport errors."""
|
|
|
|
"""Base exception for StreamableHTTP transport errors."""
|
|
|
|
@ -184,12 +186,13 @@ class StreamableHTTPTransport:
|
|
|
|
|
|
|
|
|
|
|
|
headers = self._update_headers_with_session(self.request_headers)
|
|
|
|
headers = self._update_headers_with_session(self.request_headers)
|
|
|
|
|
|
|
|
|
|
|
|
with connect_sse(
|
|
|
|
with ssrf_proxy_sse_connect(
|
|
|
|
client,
|
|
|
|
|
|
|
|
"GET",
|
|
|
|
|
|
|
|
self.url,
|
|
|
|
self.url,
|
|
|
|
|
|
|
|
2,
|
|
|
|
headers=headers,
|
|
|
|
headers=headers,
|
|
|
|
timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds),
|
|
|
|
timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds),
|
|
|
|
|
|
|
|
client=client,
|
|
|
|
|
|
|
|
method="GET",
|
|
|
|
) as event_source:
|
|
|
|
) as event_source:
|
|
|
|
event_source.response.raise_for_status()
|
|
|
|
event_source.response.raise_for_status()
|
|
|
|
logger.debug("GET SSE connection established")
|
|
|
|
logger.debug("GET SSE connection established")
|
|
|
|
@ -215,12 +218,13 @@ class StreamableHTTPTransport:
|
|
|
|
if isinstance(ctx.session_message.message.root, JSONRPCRequest):
|
|
|
|
if isinstance(ctx.session_message.message.root, JSONRPCRequest):
|
|
|
|
original_request_id = ctx.session_message.message.root.id
|
|
|
|
original_request_id = ctx.session_message.message.root.id
|
|
|
|
|
|
|
|
|
|
|
|
with connect_sse(
|
|
|
|
with ssrf_proxy_sse_connect(
|
|
|
|
ctx.client,
|
|
|
|
|
|
|
|
"GET",
|
|
|
|
|
|
|
|
self.url,
|
|
|
|
self.url,
|
|
|
|
|
|
|
|
2,
|
|
|
|
headers=headers,
|
|
|
|
headers=headers,
|
|
|
|
timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds),
|
|
|
|
timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds),
|
|
|
|
|
|
|
|
client=ctx.client,
|
|
|
|
|
|
|
|
method="GET",
|
|
|
|
) as event_source:
|
|
|
|
) as event_source:
|
|
|
|
event_source.response.raise_for_status()
|
|
|
|
event_source.response.raise_for_status()
|
|
|
|
logger.debug("Resumption GET SSE connection established")
|
|
|
|
logger.debug("Resumption GET SSE connection established")
|
|
|
|
@ -304,7 +308,6 @@ class StreamableHTTPTransport:
|
|
|
|
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
|
|
|
|
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
|
logger.exception("Error reading SSE stream:")
|
|
|
|
|
|
|
|
ctx.server_to_client_queue.put(e)
|
|
|
|
ctx.server_to_client_queue.put(e)
|
|
|
|
|
|
|
|
|
|
|
|
def _handle_unexpected_content_type(
|
|
|
|
def _handle_unexpected_content_type(
|
|
|
|
@ -346,7 +349,7 @@ class StreamableHTTPTransport:
|
|
|
|
while not self.stop_event.is_set():
|
|
|
|
while not self.stop_event.is_set():
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
# Read message from client queue with timeout to check stop_event periodically
|
|
|
|
# Read message from client queue with timeout to check stop_event periodically
|
|
|
|
session_message = client_to_server_queue.get(timeout=5)
|
|
|
|
session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
|
|
|
|
if session_message is None:
|
|
|
|
if session_message is None:
|
|
|
|
break
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
|
|
@ -444,7 +447,7 @@ def streamablehttp_client(
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
logger.info(f"Connecting to StreamableHTTP endpoint: {url}")
|
|
|
|
logger.info(f"Connecting to StreamableHTTP endpoint: {url}")
|
|
|
|
|
|
|
|
|
|
|
|
with create_mcp_http_client(
|
|
|
|
with create_ssrf_proxy_mcp_http_client(
|
|
|
|
headers=transport.request_headers,
|
|
|
|
headers=transport.request_headers,
|
|
|
|
timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds),
|
|
|
|
timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds),
|
|
|
|
) as client:
|
|
|
|
) as client:
|
|
|
|
|