diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index 1c0f582501..dd2ab25527 100644 --- a/api/core/mcp/session/base_session.py +++ b/api/core/mcp/session/base_session.py @@ -171,6 +171,9 @@ class BaseSession( self._session_read_timeout_seconds = read_timeout_seconds self._in_flight = {} self._exit_stack = ExitStack() + # Initialize executor and future to None for proper cleanup checks + self._executor = None + self._receiver_future = None def __enter__(self) -> Self: self._executor = ThreadPoolExecutor() @@ -184,10 +187,21 @@ class BaseSession( def __exit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None ) -> None: - self._exit_stack.close() self._read_stream.put(None) self._write_stream.put(None) + # Wait for the receiver loop to finish + if self._receiver_future: + try: + self._receiver_future.result(timeout=5.0) # Wait up to 5 seconds + except Exception: + # If the receiver loop is still running after timeout, we'll force shutdown + pass + + # Shutdown the executor + if self._executor: + self._executor.shutdown(wait=True, timeout=5.0) + def send_request( self, request: SendRequestT,