|
|
|
|
@ -1,7 +1,8 @@
|
|
|
|
|
import logging
|
|
|
|
|
from collections.abc import Callable
|
|
|
|
|
from contextlib import ExitStack
|
|
|
|
|
from typing import Optional, cast
|
|
|
|
|
from contextlib import AbstractContextManager, ExitStack
|
|
|
|
|
from types import TracebackType
|
|
|
|
|
from typing import Any, Optional, cast
|
|
|
|
|
from urllib.parse import urlparse
|
|
|
|
|
|
|
|
|
|
from core.mcp.client.sse_client import sse_client
|
|
|
|
|
@ -39,8 +40,8 @@ class MCPClient:
|
|
|
|
|
|
|
|
|
|
# Initialize session and client objects
|
|
|
|
|
self._session: Optional[ClientSession] = None
|
|
|
|
|
self._streams_context = None
|
|
|
|
|
self._session_context = None
|
|
|
|
|
self._streams_context: Optional[AbstractContextManager[Any]] = None
|
|
|
|
|
self._session_context: Optional[ClientSession] = None
|
|
|
|
|
self.exit_stack = ExitStack()
|
|
|
|
|
|
|
|
|
|
# Whether the client has been initialized
|
|
|
|
|
@ -51,14 +52,19 @@ class MCPClient:
|
|
|
|
|
self._initialized = True
|
|
|
|
|
return self
|
|
|
|
|
|
|
|
|
|
def __exit__(self, exc_type, exc_value, traceback):
|
|
|
|
|
def __exit__(
|
|
|
|
|
self, exc_type: Optional[type], exc_value: Optional[BaseException], traceback: Optional[TracebackType]
|
|
|
|
|
):
|
|
|
|
|
self.cleanup()
|
|
|
|
|
|
|
|
|
|
def _initialize(
|
|
|
|
|
self,
|
|
|
|
|
):
|
|
|
|
|
"""Initialize the client with fallback to SSE if streamable connection fails"""
|
|
|
|
|
connection_methods = {"mcp": streamablehttp_client, "sse": sse_client}
|
|
|
|
|
connection_methods: dict[str, Callable[..., AbstractContextManager[Any]]] = {
|
|
|
|
|
"mcp": streamablehttp_client,
|
|
|
|
|
"sse": sse_client,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
parsed_url = urlparse(self.server_url)
|
|
|
|
|
path = parsed_url.path
|
|
|
|
|
@ -72,7 +78,9 @@ class MCPClient:
|
|
|
|
|
except MCPConnectionError:
|
|
|
|
|
self.connect_server(streamablehttp_client, "mcp")
|
|
|
|
|
|
|
|
|
|
def connect_server(self, client_factory: Callable, method_name: str, first_try: bool = True):
|
|
|
|
|
def connect_server(
|
|
|
|
|
self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str, first_try: bool = True
|
|
|
|
|
):
|
|
|
|
|
from core.mcp.auth.auth_flow import auth
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
@ -82,6 +90,9 @@ class MCPClient:
|
|
|
|
|
else {}
|
|
|
|
|
)
|
|
|
|
|
self._streams_context = client_factory(url=self.server_url, headers=headers)
|
|
|
|
|
if self._streams_context is None:
|
|
|
|
|
raise MCPConnectionError("Failed to create connection context")
|
|
|
|
|
|
|
|
|
|
if method_name == "mcp":
|
|
|
|
|
read_stream, write_stream, _ = self._streams_context.__enter__()
|
|
|
|
|
streams = (read_stream, write_stream)
|
|
|
|
|
@ -123,8 +134,8 @@ class MCPClient:
|
|
|
|
|
def cleanup(self):
|
|
|
|
|
"""Clean up resources"""
|
|
|
|
|
try:
|
|
|
|
|
if self._session:
|
|
|
|
|
self._session.__exit__(None, None, None)
|
|
|
|
|
if self._session_context:
|
|
|
|
|
self._session_context.__exit__(None, None, None)
|
|
|
|
|
|
|
|
|
|
if self._streams_context:
|
|
|
|
|
self._streams_context.__exit__(None, None, None)
|
|
|
|
|
|