@ -1,7 +1,5 @@
import logging
import queue
import threading
import time
from collections . abc import Callable
from concurrent . futures import ThreadPoolExecutor
from contextlib import ExitStack
@ -80,13 +78,10 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
self . _completed = False
self . _on_complete = on_complete
self . _entered = False # Track if we're in a context manager
self . _cancel_event = threading . Event ( )
def __enter__ ( self ) - > " RequestResponder[ReceiveRequestT, SendResultT] " :
""" Enter the context manager, enabling request cancellation tracking. """
self . _entered = True
self . _cancel_event = threading . Event ( )
self . _cancel_event . clear ( )
return self
def __exit__ (
@ -101,9 +96,6 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
self . _on_complete ( self )
finally :
self . _entered = False
if not self . _cancel_event :
raise RuntimeError ( " No active cancel scope " )
self . _cancel_event . set ( )
def respond ( self , response : SendResultT | ErrorData ) - > None :
""" Send a response for this request.
@ -117,7 +109,6 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
raise RuntimeError ( " RequestResponder must be used as a context manager " )
assert not self . _completed , " Request already responded to "
if not self . cancelled :
self . _completed = True
self . _session . _send_response ( request_id = self . request_id , response = response )
@ -127,7 +118,6 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
if not self . _entered :
raise RuntimeError ( " RequestResponder must be used as a context manager " )
self . _cancel_event . set ( )
self . _completed = True # Mark as completed so it's removed from in_flight
# Send an error response to indicate cancellation
self . _session . _send_response (
@ -135,14 +125,6 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
response = ErrorData ( code = 0 , message = " Request cancelled " , data = None ) ,
)
@property
def in_flight ( self ) - > bool :
return not self . _completed and not self . cancelled
@property
def cancelled ( self ) - > bool :
return self . _cancel_event . is_set ( )
class BaseSession (
Generic [
@ -184,11 +166,9 @@ class BaseSession(
self . _in_flight = { }
self . _exit_stack = ExitStack ( )
self . _futures = [ ]
self . _request_id_lock = threading . Lock ( )
def __enter__ ( self ) - > Self :
self . _executor = ThreadPoolExecutor ( )
self . _stop_event = threading . Event ( )
self . _receiver_future = self . _executor . submit ( self . _receive_loop )
return self
@ -196,21 +176,8 @@ class BaseSession(
self , exc_type : type [ BaseException ] | None , exc_val : BaseException | None , exc_tb : TracebackType | None
) - > None :
self . _exit_stack . close ( )
self . _stop_event . set ( )
self . _wait_for_futures ( timeout = 5 )
def _wait_for_futures ( self , timeout = None ) :
end_time = time . time ( ) + timeout if timeout else None
for future in list ( self . _futures ) :
try :
remaining = end_time - time . time ( ) if end_time else None
if remaining is not None and remaining < = 0 :
break
future . result ( timeout = remaining )
except Exception as e :
logging . exception ( f " Error waiting for task: { e } " )
self . _read_stream . put ( None )
self . _write_stream . put ( None )
def send_request (
self ,
@ -247,7 +214,7 @@ class BaseSession(
timeout = request_read_timeout_seconds . total_seconds ( )
elif self . _session_read_timeout_seconds is not None :
timeout = self . _session_read_timeout_seconds . total_seconds ( )
while not self . _stop_event . is_set ( ) :
while True :
try :
response_or_error = response_queue . get ( timeout = timeout )
break
@ -316,7 +283,7 @@ class BaseSession(
Main message processing loop .
In a real synchronous implementation , this would likely run in a separate thread .
"""
while not self . _stop_event . is_set ( ) :
while True :
try :
# Attempt to receive a message (this would be blocking in a synchronous context)
message = self . _read_stream . get ( timeout = DEFAULT_RESPONSE_READ_TIMEOUT )
@ -378,12 +345,9 @@ class BaseSession(
else :
self . _handle_incoming ( RuntimeError ( f " Received response with an unknown request ID: { message } " ) )
except queue . Empty :
if self . _stop_event . is_set ( ) :
break
continue
except Exception as e :
logging . exception ( " Error in message processing loop " )
self . _stop_event . set ( )
def _received_request ( self , responder : RequestResponder [ ReceiveRequestT , SendResultT ] ) - > None :
"""