|
|
|
|
@ -3,7 +3,7 @@ import queue
|
|
|
|
|
from collections.abc import Generator
|
|
|
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
|
from contextlib import contextmanager
|
|
|
|
|
from typing import Any, cast
|
|
|
|
|
from typing import Any, TypeAlias, final
|
|
|
|
|
from urllib.parse import urljoin, urlparse
|
|
|
|
|
|
|
|
|
|
import httpx
|
|
|
|
|
@ -18,10 +18,23 @@ logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
DEFAULT_QUEUE_READ_TIMEOUT = 3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@final
|
|
|
|
|
class _StatusReady:
|
|
|
|
|
def __init__(self, endpoint_url: str):
|
|
|
|
|
self._endpoint_url = endpoint_url
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@final
|
|
|
|
|
class _StatusError:
|
|
|
|
|
def __init__(self, exc: Exception):
|
|
|
|
|
self._exc = exc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Type aliases for better readability
|
|
|
|
|
ReadQueue = queue.Queue[SessionMessage | Exception | None]
|
|
|
|
|
WriteQueue = queue.Queue[SessionMessage | Exception | None]
|
|
|
|
|
StatusQueue = queue.Queue[tuple[str, str | Exception]]
|
|
|
|
|
ReadQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
|
|
|
|
|
WriteQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
|
|
|
|
|
StatusQueue: TypeAlias = queue.Queue[_StatusReady | _StatusError]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def remove_request_params(url: str) -> str:
|
|
|
|
|
@ -80,10 +93,10 @@ class SSETransport:
|
|
|
|
|
if not self._validate_endpoint_url(endpoint_url):
|
|
|
|
|
error_msg = f"Endpoint origin does not match connection origin: {endpoint_url}"
|
|
|
|
|
logger.error(error_msg)
|
|
|
|
|
status_queue.put(("error", ValueError(error_msg)))
|
|
|
|
|
status_queue.put(_StatusError(ValueError(error_msg)))
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
status_queue.put(("ready", endpoint_url))
|
|
|
|
|
status_queue.put(_StatusReady(endpoint_url))
|
|
|
|
|
|
|
|
|
|
def _handle_message_event(self, sse_data: str, read_queue: ReadQueue) -> None:
|
|
|
|
|
"""Handle a 'message' SSE event.
|
|
|
|
|
@ -197,18 +210,17 @@ class SSETransport:
|
|
|
|
|
ValueError: If endpoint URL is not received or there's an error.
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
status, endpoint_url_or_error = status_queue.get(timeout=1)
|
|
|
|
|
status = status_queue.get(timeout=1)
|
|
|
|
|
except queue.Empty:
|
|
|
|
|
raise ValueError("failed to get endpoint URL")
|
|
|
|
|
|
|
|
|
|
if status != "ready":
|
|
|
|
|
if isinstance(status, _StatusReady):
|
|
|
|
|
return status._endpoint_url
|
|
|
|
|
elif isinstance(status, _StatusError):
|
|
|
|
|
raise status._exc
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError("failed to get endpoint URL")
|
|
|
|
|
|
|
|
|
|
if status == "error" and isinstance(endpoint_url_or_error, Exception):
|
|
|
|
|
raise endpoint_url_or_error
|
|
|
|
|
|
|
|
|
|
return cast(str, endpoint_url_or_error)
|
|
|
|
|
|
|
|
|
|
def connect(
|
|
|
|
|
self,
|
|
|
|
|
executor: ThreadPoolExecutor,
|
|
|
|
|
@ -284,9 +296,9 @@ def sse_client(
|
|
|
|
|
if exc.response.status_code == 401:
|
|
|
|
|
raise MCPAuthError()
|
|
|
|
|
raise MCPConnectionError()
|
|
|
|
|
except Exception as exc:
|
|
|
|
|
except Exception:
|
|
|
|
|
logger.exception("Error connecting to SSE endpoint")
|
|
|
|
|
raise exc
|
|
|
|
|
raise
|
|
|
|
|
finally:
|
|
|
|
|
# Clean up queues
|
|
|
|
|
if read_queue:
|
|
|
|
|
|