|
|
|
|
@ -1,6 +1,8 @@
|
|
|
|
|
import logging
|
|
|
|
|
from collections.abc import Callable
|
|
|
|
|
from contextlib import ExitStack
|
|
|
|
|
from typing import Optional, cast
|
|
|
|
|
from urllib.parse import urlparse
|
|
|
|
|
|
|
|
|
|
from core.mcp.client.sse_client import sse_client
|
|
|
|
|
from core.mcp.client.streamable_client import streamablehttp_client
|
|
|
|
|
@ -59,42 +61,53 @@ class MCPClient:
|
|
|
|
|
first_try: bool = True,
|
|
|
|
|
):
|
|
|
|
|
"""Initialize the client with fallback to SSE if streamable connection fails"""
|
|
|
|
|
connection_methods = [("streamablehttp_client", streamablehttp_client), ("sse_client", sse_client)]
|
|
|
|
|
from core.mcp.auth.auth_flow import auth
|
|
|
|
|
connection_methods = {"mcp": streamablehttp_client, "sse": sse_client}
|
|
|
|
|
|
|
|
|
|
for method_name, client_factory in connection_methods:
|
|
|
|
|
parsed_url = urlparse(self.server_url)
|
|
|
|
|
path = parsed_url.path
|
|
|
|
|
method_name = path.rstrip("/").split("/")[-1] if path else ""
|
|
|
|
|
try:
|
|
|
|
|
client_factory = connection_methods[method_name]
|
|
|
|
|
self.connect_server(client_factory, method_name)
|
|
|
|
|
except KeyError:
|
|
|
|
|
try:
|
|
|
|
|
headers = (
|
|
|
|
|
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
|
|
|
|
|
if self.authed and self.token
|
|
|
|
|
else {}
|
|
|
|
|
)
|
|
|
|
|
self._streams_context = client_factory(url=self.server_url, headers=headers)
|
|
|
|
|
if method_name == "streamablehttp_client":
|
|
|
|
|
read_stream, write_stream, _ = self._streams_context.__enter__()
|
|
|
|
|
streams = (read_stream, write_stream)
|
|
|
|
|
else: # sse_client
|
|
|
|
|
streams = self._streams_context.__enter__()
|
|
|
|
|
|
|
|
|
|
self._session_context = ClientSession(*streams)
|
|
|
|
|
self._session = self._session_context.__enter__()
|
|
|
|
|
session = cast(ClientSession, self._session)
|
|
|
|
|
session.initialize()
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
except MCPAuthError:
|
|
|
|
|
if not self.authed:
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
auth(self.provider, self.server_url, self.authorization_code, self.scope)
|
|
|
|
|
if first_try:
|
|
|
|
|
return self._initialize(first_try=False)
|
|
|
|
|
|
|
|
|
|
self.connect_server(streamablehttp_client, "sse")
|
|
|
|
|
except MCPConnectionError:
|
|
|
|
|
if method_name == "streamablehttp_client":
|
|
|
|
|
continue
|
|
|
|
|
self.connect_server(sse_client, "mcp")
|
|
|
|
|
|
|
|
|
|
def connect_server(self, client_factory: Callable, method_name: str, first_try: bool = True):
|
|
|
|
|
from core.mcp.auth.auth_flow import auth
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
headers = (
|
|
|
|
|
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
|
|
|
|
|
if self.authed and self.token
|
|
|
|
|
else {}
|
|
|
|
|
)
|
|
|
|
|
self._streams_context = client_factory(url=self.server_url, headers=headers)
|
|
|
|
|
if method_name == "mcp":
|
|
|
|
|
read_stream, write_stream, _ = self._streams_context.__enter__()
|
|
|
|
|
streams = (read_stream, write_stream)
|
|
|
|
|
else: # sse_client
|
|
|
|
|
streams = self._streams_context.__enter__()
|
|
|
|
|
|
|
|
|
|
self._session_context = ClientSession(*streams)
|
|
|
|
|
self._session = self._session_context.__enter__()
|
|
|
|
|
session = cast(ClientSession, self._session)
|
|
|
|
|
session.initialize()
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
except MCPAuthError:
|
|
|
|
|
if not self.authed:
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
auth(self.provider, self.server_url, self.authorization_code, self.scope)
|
|
|
|
|
if first_try:
|
|
|
|
|
return self.connect_server(client_factory, method_name, first_try=False)
|
|
|
|
|
|
|
|
|
|
except MCPConnectionError:
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
def list_tools(self) -> list[Tool]:
|
|
|
|
|
"""Connect to an MCP server running with SSE transport"""
|
|
|
|
|
# List available tools to verify connection
|
|
|
|
|
|