fix(api): Fix potential thread leak in MCP `BaseSession` (#22169)

The `BaseSession` class in the `core/mcp/session` package uses `ThreadPoolExecutor` 
to run the receive loop but fails to properly clean up the executor and receiver 
future, leading to potential thread leaks.

This PR addresses this issue by:
- Initializing `_executor` and `_receiver_future` attributes to `None` for proper cleanup checks
- Adding graceful shutdown with a 5-second timeout in the `__exit__` method
- Ensuring the ThreadPoolExecutor is properly shut down to prevent resource leaks

This fix prevents memory leaks and hanging threads in long-running scenarios where 
multiple MCP sessions are created and destroyed.

Signed-off-by: neatguycoding <15627489+NeatGuyCoding@users.noreply.github.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
pull/22468/head
NeatGuyCoding 9 months ago committed by GitHub
parent da53bf511f
commit 7bf3d2c8bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,7 +1,7 @@
import logging import logging
import queue import queue
from collections.abc import Callable from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
from contextlib import ExitStack from contextlib import ExitStack
from datetime import timedelta from datetime import timedelta
from types import TracebackType from types import TracebackType
@ -171,23 +171,41 @@ class BaseSession(
self._session_read_timeout_seconds = read_timeout_seconds self._session_read_timeout_seconds = read_timeout_seconds
self._in_flight = {} self._in_flight = {}
self._exit_stack = ExitStack() self._exit_stack = ExitStack()
# Initialize executor and future to None for proper cleanup checks
self._executor: ThreadPoolExecutor | None = None
self._receiver_future: Future | None = None
def __enter__(self) -> Self: def __enter__(self) -> Self:
self._executor = ThreadPoolExecutor() # The thread pool is dedicated to running `_receive_loop`. Setting `max_workers` to 1
# ensures no unnecessary threads are created.
self._executor = ThreadPoolExecutor(max_workers=1)
self._receiver_future = self._executor.submit(self._receive_loop) self._receiver_future = self._executor.submit(self._receive_loop)
return self return self
def check_receiver_status(self) -> None: def check_receiver_status(self) -> None:
if self._receiver_future.done(): """`check_receiver_status` ensures that any exceptions raised during the
execution of `_receive_loop` are retrieved and propagated."""
if self._receiver_future and self._receiver_future.done():
self._receiver_future.result() self._receiver_future.result()
def __exit__( def __exit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
) -> None: ) -> None:
self._exit_stack.close()
self._read_stream.put(None) self._read_stream.put(None)
self._write_stream.put(None) self._write_stream.put(None)
# Wait for the receiver loop to finish
if self._receiver_future:
try:
self._receiver_future.result(timeout=5.0) # Wait up to 5 seconds
except TimeoutError:
# If the receiver loop is still running after timeout, we'll force shutdown
pass
# Shutdown the executor
if self._executor:
self._executor.shutdown(wait=True)
def send_request( def send_request(
self, self,
request: SendRequestT, request: SendRequestT,

Loading…
Cancel
Save