fix(api): Fix potential thread leak in MCP `BaseSession` (#22169)

The `BaseSession` class in the `core/mcp/session` package uses `ThreadPoolExecutor` 
to run the receive loop but fails to properly clean up the executor and receiver 
future, leading to potential thread leaks.

This PR addresses this issue by:
- Initializing `_executor` and `_receiver_future` attributes to `None` for proper cleanup checks
- Adding graceful shutdown with a 5-second timeout in the `__exit__` method
- Ensuring the ThreadPoolExecutor is properly shut down to prevent resource leaks

This fix prevents memory leaks and hanging threads in long-running scenarios where 
multiple MCP sessions are created and destroyed.

Signed-off-by: neatguycoding <15627489+NeatGuyCoding@users.noreply.github.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
deploy/enterprise
NeatGuyCoding 10 months ago committed by NFish
parent e48483b52d
commit 780748460d

@ -1,7 +1,7 @@
import logging import logging
import queue import queue
from collections.abc import Callable from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
from contextlib import ExitStack from contextlib import ExitStack
from datetime import timedelta from datetime import timedelta
from types import TracebackType from types import TracebackType
@ -171,23 +171,41 @@ class BaseSession(
self._session_read_timeout_seconds = read_timeout_seconds self._session_read_timeout_seconds = read_timeout_seconds
self._in_flight = {} self._in_flight = {}
self._exit_stack = ExitStack() self._exit_stack = ExitStack()
# Initialize executor and future to None for proper cleanup checks
self._executor: ThreadPoolExecutor | None = None
self._receiver_future: Future | None = None
def __enter__(self) -> Self: def __enter__(self) -> Self:
self._executor = ThreadPoolExecutor() # The thread pool is dedicated to running `_receive_loop`. Setting `max_workers` to 1
# ensures no unnecessary threads are created.
self._executor = ThreadPoolExecutor(max_workers=1)
self._receiver_future = self._executor.submit(self._receive_loop) self._receiver_future = self._executor.submit(self._receive_loop)
return self return self
def check_receiver_status(self) -> None: def check_receiver_status(self) -> None:
if self._receiver_future.done(): """`check_receiver_status` ensures that any exceptions raised during the
execution of `_receive_loop` are retrieved and propagated."""
if self._receiver_future and self._receiver_future.done():
self._receiver_future.result() self._receiver_future.result()
def __exit__( def __exit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
) -> None: ) -> None:
self._exit_stack.close()
self._read_stream.put(None) self._read_stream.put(None)
self._write_stream.put(None) self._write_stream.put(None)
# Wait for the receiver loop to finish
if self._receiver_future:
try:
self._receiver_future.result(timeout=5.0) # Wait up to 5 seconds
except TimeoutError:
# If the receiver loop is still running after timeout, we'll force shutdown
pass
# Shutdown the executor
if self._executor:
self._executor.shutdown(wait=True)
def send_request( def send_request(
self, self,
request: SendRequestT, request: SendRequestT,

Loading…
Cancel
Save