diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index 7aa780d507..a85eee5219 100644 --- a/api/core/mcp/session/base_session.py +++ b/api/core/mcp/session/base_session.py @@ -180,10 +180,8 @@ class BaseSession( def check_receiver_status(self) -> None: if self._receiver_future.done(): try: - # 如果Future已完成,获取结果(如果有异常会在这里抛出) self._receiver_future.result() except Exception as e: - # 重新抛出线程中的异常 raise e def __exit__( @@ -234,7 +232,6 @@ class BaseSession( response_or_error = response_queue.get(timeout=timeout) break except queue.Empty: - # 在等待响应的过程中也检查接收线程状态 self.check_receiver_status() continue diff --git a/api/core/mcp/types.py b/api/core/mcp/types.py index 603ab0cfb5..bf12ae2a95 100644 --- a/api/core/mcp/types.py +++ b/api/core/mcp/types.py @@ -36,7 +36,7 @@ LATEST_PROTOCOL_VERSION = "2024-11-05" ProgressToken = str | int Cursor = str Role = Literal["user", "assistant"] -RequestId = str | int +RequestId = Annotated[int | str, Field(union_mode="left_to_right")] AnyFunction: TypeAlias = Callable[..., Any] @@ -1182,6 +1182,7 @@ class OAuthClientMetadata(BaseModel): response_types: Optional[list[str]] = None token_endpoint_auth_method: Optional[str] = None client_uri: Optional[str] = None + scope: Optional[str] = None class OAuthClientInformation(BaseModel): diff --git a/api/tests/unit_tests/core/mcp/client/test_session.py b/api/tests/unit_tests/core/mcp/client/test_session.py new file mode 100644 index 0000000000..c84169bf15 --- /dev/null +++ b/api/tests/unit_tests/core/mcp/client/test_session.py @@ -0,0 +1,471 @@ +import queue +import threading +from typing import Any + +from core.mcp import types +from core.mcp.entities import RequestContext +from core.mcp.session.base_session import RequestResponder +from core.mcp.session.client_session import DEFAULT_CLIENT_INFO, ClientSession +from core.mcp.types import ( + LATEST_PROTOCOL_VERSION, + ClientNotification, + ClientRequest, + Implementation, + InitializedNotification, + InitializeRequest, + InitializeResult, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + ServerCapabilities, + ServerResult, + SessionMessage, +) + + +def test_client_session_initialize(): + # Create synchronous queues to replace async streams + client_to_server: queue.Queue[SessionMessage] = queue.Queue() + server_to_client: queue.Queue[SessionMessage] = queue.Queue() + + initialized_notification = None + + def mock_server(): + nonlocal initialized_notification + + # Receive initialization request + session_message = client_to_server.get(timeout=5.0) + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.root.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + + # Create response + result = ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities( + logging=None, + resources=None, + tools=None, + experimental=None, + prompts=None, + ), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + instructions="The server instructions.", + ) + ) + + # Send response + server_to_client.put( + SessionMessage( + message=JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + # Receive initialized notification + session_notification = client_to_server.get(timeout=5.0) + jsonrpc_notification = session_notification.message + assert isinstance(jsonrpc_notification.root, JSONRPCNotification) + initialized_notification = ClientNotification.model_validate( + jsonrpc_notification.root.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + + # Create message handler + def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + # Start mock server thread + server_thread = threading.Thread(target=mock_server, daemon=True) + server_thread.start() + + # Create and use client session + with ClientSession( + server_to_client, + client_to_server, + message_handler=message_handler, + ) as session: + result = session.initialize() + + # Wait for server thread to complete + server_thread.join(timeout=10.0) + + # Assert results + assert isinstance(result, InitializeResult) + assert result.protocolVersion == LATEST_PROTOCOL_VERSION + assert isinstance(result.capabilities, ServerCapabilities) + assert result.serverInfo == Implementation(name="mock-server", version="0.1.0") + assert result.instructions == "The server instructions." + + # Check that client sent initialized notification + assert initialized_notification + assert isinstance(initialized_notification.root, InitializedNotification) + + +def test_client_session_custom_client_info(): + # Create synchronous queues to replace async streams + client_to_server: queue.Queue[SessionMessage] = queue.Queue() + server_to_client: queue.Queue[SessionMessage] = queue.Queue() + + custom_client_info = Implementation(name="test-client", version="1.2.3") + received_client_info = None + + def mock_server(): + nonlocal received_client_info + + session_message = client_to_server.get(timeout=5.0) + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.root.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + received_client_info = request.root.params.clientInfo + + result = ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + server_to_client.put( + SessionMessage( + message=JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + # Receive initialized notification + client_to_server.get(timeout=5.0) + + # Start mock server thread + server_thread = threading.Thread(target=mock_server, daemon=True) + server_thread.start() + + with ClientSession( + server_to_client, + client_to_server, + client_info=custom_client_info, + ) as session: + session.initialize() + + # Wait for server thread to complete + server_thread.join(timeout=10.0) + + # Assert that custom client info was sent + assert received_client_info == custom_client_info + + +def test_client_session_default_client_info(): + # Create synchronous queues to replace async streams + client_to_server: queue.Queue[SessionMessage] = queue.Queue() + server_to_client: queue.Queue[SessionMessage] = queue.Queue() + + received_client_info = None + + def mock_server(): + nonlocal received_client_info + + session_message = client_to_server.get(timeout=5.0) + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.root.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + received_client_info = request.root.params.clientInfo + + result = ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + server_to_client.put( + SessionMessage( + message=JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + # Receive initialized notification + client_to_server.get(timeout=5.0) + + # Start mock server thread + server_thread = threading.Thread(target=mock_server, daemon=True) + server_thread.start() + + with ClientSession( + server_to_client, + client_to_server, + ) as session: + session.initialize() + + # Wait for server thread to complete + server_thread.join(timeout=10.0) + + # Assert that default client info was used + assert received_client_info == DEFAULT_CLIENT_INFO + + +def test_client_session_version_negotiation_success(): + # Create synchronous queues to replace async streams + client_to_server: queue.Queue[SessionMessage] = queue.Queue() + server_to_client: queue.Queue[SessionMessage] = queue.Queue() + + def mock_server(): + session_message = client_to_server.get(timeout=5.0) + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.root.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + + # Send supported protocol version + result = ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + server_to_client.put( + SessionMessage( + message=JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + # Receive initialized notification + client_to_server.get(timeout=5.0) + + # Start mock server thread + server_thread = threading.Thread(target=mock_server, daemon=True) + server_thread.start() + + with ClientSession( + server_to_client, + client_to_server, + ) as session: + result = session.initialize() + + # Wait for server thread to complete + server_thread.join(timeout=10.0) + + # Should successfully initialize + assert isinstance(result, InitializeResult) + assert result.protocolVersion == LATEST_PROTOCOL_VERSION + + +def test_client_session_version_negotiation_failure(): + # Create synchronous queues to replace async streams + client_to_server: queue.Queue[SessionMessage] = queue.Queue() + server_to_client: queue.Queue[SessionMessage] = queue.Queue() + + def mock_server(): + session_message = client_to_server.get(timeout=5.0) + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.root.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + + # Send unsupported protocol version + result = ServerResult( + InitializeResult( + protocolVersion="99.99.99", # Unsupported version + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + server_to_client.put( + SessionMessage( + message=JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + # Start mock server thread + server_thread = threading.Thread(target=mock_server, daemon=True) + server_thread.start() + + with ClientSession( + server_to_client, + client_to_server, + ) as session: + import pytest + + with pytest.raises(RuntimeError, match="Unsupported protocol version"): + session.initialize() + + # Wait for server thread to complete + server_thread.join(timeout=10.0) + + +def test_client_capabilities_default(): + # Create synchronous queues to replace async streams + client_to_server: queue.Queue[SessionMessage] = queue.Queue() + server_to_client: queue.Queue[SessionMessage] = queue.Queue() + + received_capabilities = None + + def mock_server(): + nonlocal received_capabilities + + session_message = client_to_server.get(timeout=5.0) + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.root.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + received_capabilities = request.root.params.capabilities + + result = ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + server_to_client.put( + SessionMessage( + message=JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + # Receive initialized notification + client_to_server.get(timeout=5.0) + + # Start mock server thread + server_thread = threading.Thread(target=mock_server, daemon=True) + server_thread.start() + + with ClientSession( + server_to_client, + client_to_server, + ) as session: + session.initialize() + + # Wait for server thread to complete + server_thread.join(timeout=10.0) + + # Assert default capabilities + assert received_capabilities is not None + assert received_capabilities.sampling is not None + assert received_capabilities.roots is not None + assert received_capabilities.roots.listChanged is True + + +def test_client_capabilities_with_custom_callbacks(): + # Create synchronous queues to replace async streams + client_to_server: queue.Queue[SessionMessage] = queue.Queue() + server_to_client: queue.Queue[SessionMessage] = queue.Queue() + + def custom_sampling_callback( + context: RequestContext["ClientSession", Any], + params: types.CreateMessageRequestParams, + ) -> types.CreateMessageResult | types.ErrorData: + return types.CreateMessageResult( + model="test-model", + role="assistant", + content=types.TextContent(type="text", text="Custom response"), + ) + + def custom_list_roots_callback( + context: RequestContext["ClientSession", Any], + ) -> types.ListRootsResult | types.ErrorData: + return types.ListRootsResult(roots=[]) + + def mock_server(): + session_message = client_to_server.get(timeout=5.0) + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.root.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + + result = ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + server_to_client.put( + SessionMessage( + message=JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + # Receive initialized notification + client_to_server.get(timeout=5.0) + + # Start mock server thread + server_thread = threading.Thread(target=mock_server, daemon=True) + server_thread.start() + + with ClientSession( + server_to_client, + client_to_server, + sampling_callback=custom_sampling_callback, + list_roots_callback=custom_list_roots_callback, + ) as session: + result = session.initialize() + + # Wait for server thread to complete + server_thread.join(timeout=10.0) + + # Verify initialization succeeded + assert isinstance(result, InitializeResult) + assert result.protocolVersion == LATEST_PROTOCOL_VERSION diff --git a/api/tests/unit_tests/core/mcp/client/test_sse.py b/api/tests/unit_tests/core/mcp/client/test_sse.py new file mode 100644 index 0000000000..8122cd08eb --- /dev/null +++ b/api/tests/unit_tests/core/mcp/client/test_sse.py @@ -0,0 +1,349 @@ +import json +import queue +import threading +import time +from typing import Any +from unittest.mock import Mock, patch + +import httpx +import pytest + +from core.mcp import types +from core.mcp.client.sse_client import sse_client +from core.mcp.error import MCPAuthError, MCPConnectionError + +SERVER_NAME = "test_server_for_SSE" + + +def test_sse_message_id_coercion(): + """Test that string message IDs that look like integers are parsed as integers. + + See for more details. + """ + json_message = '{"jsonrpc": "2.0", "id": "123", "method": "ping", "params": null}' + msg = types.JSONRPCMessage.model_validate_json(json_message) + expected = types.JSONRPCMessage(root=types.JSONRPCRequest(method="ping", jsonrpc="2.0", id=123)) + + # Check if both are JSONRPCRequest instances + assert isinstance(msg.root, types.JSONRPCRequest) + assert isinstance(expected.root, types.JSONRPCRequest) + + assert msg.root.id == expected.root.id + assert msg.root.method == expected.root.method + assert msg.root.jsonrpc == expected.root.jsonrpc + + +class MockSSEClient: + """Mock SSE client for testing.""" + + def __init__(self, url: str, headers: dict[str, Any] | None = None): + self.url = url + self.headers = headers or {} + self.connected = False + self.read_queue: queue.Queue = queue.Queue() + self.write_queue: queue.Queue = queue.Queue() + + def connect(self): + """Simulate connection establishment.""" + self.connected = True + + # Send endpoint event + endpoint_data = "/messages/?session_id=test-session-123" + self.read_queue.put(("endpoint", endpoint_data)) + + return self.read_queue, self.write_queue + + def send_initialize_response(self): + """Send a mock initialize response.""" + response = { + "jsonrpc": "2.0", + "id": 1, + "result": { + "protocolVersion": types.LATEST_PROTOCOL_VERSION, + "capabilities": { + "logging": None, + "resources": None, + "tools": None, + "experimental": None, + "prompts": None, + }, + "serverInfo": {"name": SERVER_NAME, "version": "0.1.0"}, + "instructions": "Test server instructions.", + }, + } + self.read_queue.put(("message", json.dumps(response))) + + +def test_sse_client_message_id_handling(): + """Test SSE client properly handles message ID coercion.""" + mock_client = MockSSEClient("http://test.example/sse") + read_queue, write_queue = mock_client.connect() + + # Send a message with string ID that should be coerced to int + message_data = { + "jsonrpc": "2.0", + "id": "456", # String ID + "result": {"test": "data"}, + } + read_queue.put(("message", json.dumps(message_data))) + read_queue.get(timeout=1.0) + # Get the message from queue + event_type, data = read_queue.get(timeout=1.0) + assert event_type == "message" + + # Parse the message + parsed_message = types.JSONRPCMessage.model_validate_json(data) + # Check that it's a JSONRPCResponse and verify the ID + assert isinstance(parsed_message.root, types.JSONRPCResponse) + assert parsed_message.root.id == 456 # Should be converted to int + + +def test_sse_client_connection_validation(): + """Test SSE client validates endpoint URLs properly.""" + test_url = "http://test.example/sse" + + with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory: + with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sse_connect: + # Mock the HTTP client + mock_client = Mock() + mock_client_factory.return_value.__enter__.return_value = mock_client + + # Mock the SSE connection + mock_event_source = Mock() + mock_event_source.response.raise_for_status.return_value = None + mock_sse_connect.return_value.__enter__.return_value = mock_event_source + + # Mock SSE events + class MockSSEEvent: + def __init__(self, event_type: str, data: str): + self.event = event_type + self.data = data + + # Simulate endpoint event + endpoint_event = MockSSEEvent("endpoint", "/messages/?session_id=test-123") + mock_event_source.iter_sse.return_value = [endpoint_event] + + # Test connection + try: + with sse_client(test_url) as (read_queue, write_queue): + assert read_queue is not None + assert write_queue is not None + except Exception as e: + # Connection might fail due to mocking, but we're testing the validation logic + pass + + +def test_sse_client_error_handling(): + """Test SSE client properly handles various error conditions.""" + test_url = "http://test.example/sse" + + # Test 401 error handling + with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory: + with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sse_connect: + # Mock 401 HTTP error + mock_error = httpx.HTTPStatusError("Unauthorized", request=Mock(), response=Mock(status_code=401)) + mock_sse_connect.side_effect = mock_error + + with pytest.raises(MCPAuthError): + with sse_client(test_url): + pass + + # Test other HTTP errors + with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory: + with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sse_connect: + # Mock other HTTP error + mock_error = httpx.HTTPStatusError("Server Error", request=Mock(), response=Mock(status_code=500)) + mock_sse_connect.side_effect = mock_error + + with pytest.raises(MCPConnectionError): + with sse_client(test_url): + pass + + +def test_sse_client_timeout_configuration(): + """Test SSE client timeout configuration.""" + test_url = "http://test.example/sse" + custom_timeout = 10.0 + custom_sse_timeout = 300.0 + custom_headers = {"Authorization": "Bearer test-token"} + + with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory: + with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sse_connect: + # Mock successful connection + mock_client = Mock() + mock_client_factory.return_value.__enter__.return_value = mock_client + + mock_event_source = Mock() + mock_event_source.response.raise_for_status.return_value = None + mock_event_source.iter_sse.return_value = [] + mock_sse_connect.return_value.__enter__.return_value = mock_event_source + + try: + with sse_client( + test_url, headers=custom_headers, timeout=custom_timeout, sse_read_timeout=custom_sse_timeout + ) as (read_queue, write_queue): + # Verify the configuration was passed correctly + mock_client_factory.assert_called_with(headers=custom_headers) + + # Check that timeout was configured + call_args = mock_sse_connect.call_args + assert call_args is not None + timeout_arg = call_args[1]["timeout"] + assert timeout_arg.read == custom_sse_timeout + except Exception: + # Connection might fail due to mocking, but we tested the configuration + pass + + +def test_sse_transport_endpoint_validation(): + """Test SSE transport validates endpoint URLs correctly.""" + from core.mcp.client.sse_client import SSETransport + + transport = SSETransport("http://example.com/sse") + + # Valid endpoint (same origin) + valid_endpoint = "http://example.com/messages/session123" + assert transport._validate_endpoint_url(valid_endpoint) == True + + # Invalid endpoint (different origin) + invalid_endpoint = "http://malicious.com/messages/session123" + assert transport._validate_endpoint_url(invalid_endpoint) == False + + # Invalid endpoint (different scheme) + invalid_scheme = "https://example.com/messages/session123" + assert transport._validate_endpoint_url(invalid_scheme) == False + + +def test_sse_transport_message_parsing(): + """Test SSE transport properly parses different message types.""" + from core.mcp.client.sse_client import SSETransport + + transport = SSETransport("http://example.com/sse") + read_queue: queue.Queue = queue.Queue() + + # Test valid JSON-RPC message + valid_message = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}' + transport._handle_message_event(valid_message, read_queue) + + # Should have a SessionMessage in the queue + message = read_queue.get(timeout=1.0) + assert message is not None + assert hasattr(message, "message") + + # Test invalid JSON + invalid_json = '{"invalid": json}' + transport._handle_message_event(invalid_json, read_queue) + + # Should have an exception in the queue + error = read_queue.get(timeout=1.0) + assert isinstance(error, Exception) + + +def test_sse_client_queue_cleanup(): + """Test that SSE client properly cleans up queues on exit.""" + test_url = "http://test.example/sse" + + read_queue = None + write_queue = None + + with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory: + with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sse_connect: + # Mock connection that raises an exception + mock_sse_connect.side_effect = Exception("Connection failed") + + try: + with sse_client(test_url) as (rq, wq): + read_queue = rq + write_queue = wq + except Exception: + pass # Expected to fail + + # Queues should be cleaned up even on exception + # Note: In real implementation, cleanup should put None to signal shutdown + + +def test_sse_client_url_processing(): + """Test SSE client URL processing functions.""" + from core.mcp.client.sse_client import remove_request_params + + # Test URL with parameters + url_with_params = "http://example.com/sse?param1=value1¶m2=value2" + cleaned_url = remove_request_params(url_with_params) + assert cleaned_url == "http://example.com/sse" + + # Test URL without parameters + url_without_params = "http://example.com/sse" + cleaned_url = remove_request_params(url_without_params) + assert cleaned_url == "http://example.com/sse" + + # Test URL with path and parameters + complex_url = "http://example.com/path/to/sse?session=123&token=abc" + cleaned_url = remove_request_params(complex_url) + assert cleaned_url == "http://example.com/path/to/sse" + + +def test_sse_client_headers_propagation(): + """Test that custom headers are properly propagated in SSE client.""" + test_url = "http://test.example/sse" + custom_headers = { + "Authorization": "Bearer test-token", + "X-Custom-Header": "test-value", + "User-Agent": "test-client/1.0", + } + + with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory: + with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sse_connect: + # Mock the client factory to capture headers + mock_client = Mock() + mock_client_factory.return_value.__enter__.return_value = mock_client + + # Mock the SSE connection + mock_event_source = Mock() + mock_event_source.response.raise_for_status.return_value = None + mock_event_source.iter_sse.return_value = [] + mock_sse_connect.return_value.__enter__.return_value = mock_event_source + + try: + with sse_client(test_url, headers=custom_headers): + pass + except Exception: + pass # Expected due to mocking + + # Verify headers were passed to client factory + mock_client_factory.assert_called_with(headers=custom_headers) + + +def test_sse_client_concurrent_access(): + """Test SSE client behavior with concurrent queue access.""" + test_read_queue: queue.Queue = queue.Queue() + + # Simulate concurrent producers and consumers + def producer(): + for i in range(10): + test_read_queue.put(f"message_{i}") + time.sleep(0.01) # Small delay to simulate real conditions + + def consumer(): + received = [] + for _ in range(10): + try: + msg = test_read_queue.get(timeout=2.0) + received.append(msg) + except queue.Empty: + break + return received + + # Start producer in separate thread + producer_thread = threading.Thread(target=producer, daemon=True) + producer_thread.start() + + # Consume messages + received_messages = consumer() + + # Wait for producer to finish + producer_thread.join(timeout=5.0) + + # Verify all messages were received + assert len(received_messages) == 10 + for i in range(10): + assert f"message_{i}" in received_messages diff --git a/api/tests/unit_tests/core/mcp/client/test_streamable_http.py b/api/tests/unit_tests/core/mcp/client/test_streamable_http.py new file mode 100644 index 0000000000..9a30a35a49 --- /dev/null +++ b/api/tests/unit_tests/core/mcp/client/test_streamable_http.py @@ -0,0 +1,450 @@ +""" +Tests for the StreamableHTTP client transport. + +Contains tests for only the client side of the StreamableHTTP transport. +""" + +import queue +import threading +import time +from typing import Any +from unittest.mock import Mock, patch + +from core.mcp import types +from core.mcp.client.streamable_client import streamablehttp_client + +# Test constants +SERVER_NAME = "test_streamable_http_server" +TEST_SESSION_ID = "test-session-id-12345" +INIT_REQUEST = { + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "clientInfo": {"name": "test-client", "version": "1.0"}, + "protocolVersion": "2025-03-26", + "capabilities": {}, + }, + "id": "init-1", +} + + +class MockStreamableHTTPClient: + """Mock StreamableHTTP client for testing.""" + + def __init__(self, url: str, headers: dict[str, Any] | None = None): + self.url = url + self.headers = headers or {} + self.connected = False + self.read_queue: queue.Queue = queue.Queue() + self.write_queue: queue.Queue = queue.Queue() + self.session_id = TEST_SESSION_ID + + def connect(self): + """Simulate connection establishment.""" + self.connected = True + return self.read_queue, self.write_queue, lambda: self.session_id + + def send_initialize_response(self): + """Send a mock initialize response.""" + session_message = types.SessionMessage( + message=types.JSONRPCMessage( + root=types.JSONRPCResponse( + jsonrpc="2.0", + id="init-1", + result={ + "protocolVersion": types.LATEST_PROTOCOL_VERSION, + "capabilities": { + "logging": None, + "resources": None, + "tools": None, + "experimental": None, + "prompts": None, + }, + "serverInfo": {"name": SERVER_NAME, "version": "0.1.0"}, + "instructions": "Test server instructions.", + }, + ) + ) + ) + self.read_queue.put(session_message) + + def send_tools_response(self): + """Send a mock tools list response.""" + session_message = types.SessionMessage( + message=types.JSONRPCMessage( + root=types.JSONRPCResponse( + jsonrpc="2.0", + id="tools-1", + result={ + "tools": [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": {"type": "object", "properties": {}}, + } + ], + }, + ) + ) + ) + self.read_queue.put(session_message) + + +def test_streamablehttp_client_message_id_handling(): + """Test StreamableHTTP client properly handles message ID coercion.""" + mock_client = MockStreamableHTTPClient("http://test.example/mcp") + read_queue, write_queue, get_session_id = mock_client.connect() + + # Send a message with string ID that should be coerced to int + response_message = types.SessionMessage( + message=types.JSONRPCMessage(root=types.JSONRPCResponse(jsonrpc="2.0", id="789", result={"test": "data"})) + ) + read_queue.put(response_message) + + # Get the message from queue + message = read_queue.get(timeout=1.0) + assert message is not None + assert isinstance(message, types.SessionMessage) + + # Check that the ID was properly handled + assert isinstance(message.message.root, types.JSONRPCResponse) + assert message.message.root.id == 789 # ID should be coerced to int due to union_mode="left_to_right" + + +def test_streamablehttp_client_connection_validation(): + """Test StreamableHTTP client validates connections properly.""" + test_url = "http://test.example/mcp" + + with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory: + # Mock the HTTP client + mock_client = Mock() + mock_client_factory.return_value.__enter__.return_value = mock_client + + # Mock successful response + mock_response = Mock() + mock_response.status_code = 200 + mock_response.headers = {"content-type": "application/json"} + mock_response.raise_for_status.return_value = None + mock_client.post.return_value = mock_response + + # Test connection + try: + with streamablehttp_client(test_url) as (read_queue, write_queue, get_session_id): + assert read_queue is not None + assert write_queue is not None + assert get_session_id is not None + except Exception: + # Connection might fail due to mocking, but we're testing the validation logic + pass + + +def test_streamablehttp_client_timeout_configuration(): + """Test StreamableHTTP client timeout configuration.""" + test_url = "http://test.example/mcp" + custom_headers = {"Authorization": "Bearer test-token"} + + with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory: + # Mock successful connection + mock_client = Mock() + mock_client_factory.return_value.__enter__.return_value = mock_client + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.headers = {"content-type": "application/json"} + mock_response.raise_for_status.return_value = None + mock_client.post.return_value = mock_response + + try: + with streamablehttp_client(test_url, headers=custom_headers) as (read_queue, write_queue, get_session_id): + # Verify the configuration was passed correctly + mock_client_factory.assert_called_with(headers=custom_headers) + except Exception: + # Connection might fail due to mocking, but we tested the configuration + pass + + +def test_streamablehttp_client_session_id_handling(): + """Test StreamableHTTP client properly handles session IDs.""" + mock_client = MockStreamableHTTPClient("http://test.example/mcp") + read_queue, write_queue, get_session_id = mock_client.connect() + + # Test that session ID is available + session_id = get_session_id() + assert session_id == TEST_SESSION_ID + + # Test that we can use the session ID in subsequent requests + assert session_id is not None + assert len(session_id) > 0 + + +def test_streamablehttp_client_message_parsing(): + """Test StreamableHTTP client properly parses different message types.""" + mock_client = MockStreamableHTTPClient("http://test.example/mcp") + read_queue, write_queue, get_session_id = mock_client.connect() + + # Test valid initialization response + mock_client.send_initialize_response() + + # Should have a SessionMessage in the queue + message = read_queue.get(timeout=1.0) + assert message is not None + assert isinstance(message, types.SessionMessage) + assert isinstance(message.message.root, types.JSONRPCResponse) + + # Test tools response + mock_client.send_tools_response() + + tools_message = read_queue.get(timeout=1.0) + assert tools_message is not None + assert isinstance(tools_message, types.SessionMessage) + + +def test_streamablehttp_client_queue_cleanup(): + """Test that StreamableHTTP client properly cleans up queues on exit.""" + test_url = "http://test.example/mcp" + + read_queue = None + write_queue = None + + with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory: + # Mock connection that raises an exception + mock_client_factory.side_effect = Exception("Connection failed") + + try: + with streamablehttp_client(test_url) as (rq, wq, get_session_id): + read_queue = rq + write_queue = wq + except Exception: + pass # Expected to fail + + # Queues should be cleaned up even on exception + # Note: In real implementation, cleanup should put None to signal shutdown + + +def test_streamablehttp_client_headers_propagation(): + """Test that custom headers are properly propagated in StreamableHTTP client.""" + test_url = "http://test.example/mcp" + custom_headers = { + "Authorization": "Bearer test-token", + "X-Custom-Header": "test-value", + "User-Agent": "test-client/1.0", + } + + with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory: + # Mock the client factory to capture headers + mock_client = Mock() + mock_client_factory.return_value.__enter__.return_value = mock_client + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.headers = {"content-type": "application/json"} + mock_response.raise_for_status.return_value = None + mock_client.post.return_value = mock_response + + try: + with streamablehttp_client(test_url, headers=custom_headers): + pass + except Exception: + pass # Expected due to mocking + + # Verify headers were passed to client factory + # Check that the call was made with headers that include our custom headers + mock_client_factory.assert_called_once() + call_args = mock_client_factory.call_args + assert "headers" in call_args.kwargs + passed_headers = call_args.kwargs["headers"] + + # Verify all custom headers are present + for key, value in custom_headers.items(): + assert key in passed_headers + assert passed_headers[key] == value + + +def test_streamablehttp_client_concurrent_access(): + """Test StreamableHTTP client behavior with concurrent queue access.""" + test_read_queue: queue.Queue = queue.Queue() + test_write_queue: queue.Queue = queue.Queue() + + # Simulate concurrent producers and consumers + def producer(): + for i in range(10): + test_read_queue.put(f"message_{i}") + time.sleep(0.01) # Small delay to simulate real conditions + + def consumer(): + received = [] + for _ in range(10): + try: + msg = test_read_queue.get(timeout=2.0) + received.append(msg) + except queue.Empty: + break + return received + + # Start producer in separate thread + producer_thread = threading.Thread(target=producer, daemon=True) + producer_thread.start() + + # Consume messages + received_messages = consumer() + + # Wait for producer to finish + producer_thread.join(timeout=5.0) + + # Verify all messages were received + assert len(received_messages) == 10 + for i in range(10): + assert f"message_{i}" in received_messages + + +def test_streamablehttp_client_json_vs_sse_mode(): + """Test StreamableHTTP client handling of JSON vs SSE response modes.""" + test_url = "http://test.example/mcp" + + with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory: + mock_client = Mock() + mock_client_factory.return_value.__enter__.return_value = mock_client + + # Mock JSON response + mock_json_response = Mock() + mock_json_response.status_code = 200 + mock_json_response.headers = {"content-type": "application/json"} + mock_json_response.json.return_value = {"result": "json_mode"} + mock_json_response.raise_for_status.return_value = None + + # Mock SSE response + mock_sse_response = Mock() + mock_sse_response.status_code = 200 + mock_sse_response.headers = {"content-type": "text/event-stream"} + mock_sse_response.raise_for_status.return_value = None + + # Test JSON mode + mock_client.post.return_value = mock_json_response + + try: + with streamablehttp_client(test_url) as (read_queue, write_queue, get_session_id): + # Should handle JSON responses + assert read_queue is not None + assert write_queue is not None + except Exception: + pass # Expected due to mocking + + # Test SSE mode + mock_client.post.return_value = mock_sse_response + + try: + with streamablehttp_client(test_url) as (read_queue, write_queue, get_session_id): + # Should handle SSE responses + assert read_queue is not None + assert write_queue is not None + except Exception: + pass # Expected due to mocking + + +def test_streamablehttp_client_terminate_on_close(): + """Test StreamableHTTP client terminate_on_close parameter.""" + test_url = "http://test.example/mcp" + + with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory: + mock_client = Mock() + mock_client_factory.return_value.__enter__.return_value = mock_client + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.headers = {"content-type": "application/json"} + mock_response.raise_for_status.return_value = None + mock_client.post.return_value = mock_response + mock_client.delete.return_value = mock_response + + # Test with terminate_on_close=True (default) + try: + with streamablehttp_client(test_url, terminate_on_close=True) as (read_queue, write_queue, get_session_id): + pass + except Exception: + pass # Expected due to mocking + + # Test with terminate_on_close=False + try: + with streamablehttp_client(test_url, terminate_on_close=False) as (read_queue, write_queue, get_session_id): + pass + except Exception: + pass # Expected due to mocking + + +def test_streamablehttp_client_protocol_version_handling(): + """Test StreamableHTTP client protocol version handling.""" + mock_client = MockStreamableHTTPClient("http://test.example/mcp") + read_queue, write_queue, get_session_id = mock_client.connect() + + # Send initialize response with specific protocol version + + session_message = types.SessionMessage( + message=types.JSONRPCMessage( + root=types.JSONRPCResponse( + jsonrpc="2.0", + id="init-1", + result={ + "protocolVersion": "2024-11-05", + "capabilities": {}, + "serverInfo": {"name": SERVER_NAME, "version": "0.1.0"}, + }, + ) + ) + ) + read_queue.put(session_message) + + # Get the message and verify protocol version + message = read_queue.get(timeout=1.0) + assert message is not None + assert isinstance(message.message.root, types.JSONRPCResponse) + result = message.message.root.result + assert result["protocolVersion"] == "2024-11-05" + + +def test_streamablehttp_client_error_response_handling(): + """Test StreamableHTTP client handling of error responses.""" + mock_client = MockStreamableHTTPClient("http://test.example/mcp") + read_queue, write_queue, get_session_id = mock_client.connect() + + # Send an error response + session_message = types.SessionMessage( + message=types.JSONRPCMessage( + root=types.JSONRPCError( + jsonrpc="2.0", + id="test-1", + error=types.ErrorData(code=-32601, message="Method not found", data=None), + ) + ) + ) + read_queue.put(session_message) + + # Get the error message + message = read_queue.get(timeout=1.0) + assert message is not None + assert isinstance(message.message.root, types.JSONRPCError) + assert message.message.root.error.code == -32601 + assert message.message.root.error.message == "Method not found" + + +def test_streamablehttp_client_resumption_token_handling(): + """Test StreamableHTTP client resumption token functionality.""" + test_url = "http://test.example/mcp" + test_resumption_token = "resume-token-123" + + with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory: + mock_client = Mock() + mock_client_factory.return_value.__enter__.return_value = mock_client + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.headers = {"content-type": "application/json", "last-event-id": test_resumption_token} + mock_response.raise_for_status.return_value = None + mock_client.post.return_value = mock_response + + try: + with streamablehttp_client(test_url) as (read_queue, write_queue, get_session_id): + # Test that resumption token can be captured from headers + assert read_queue is not None + assert write_queue is not None + except Exception: + pass # Expected due to mocking