9292DEFAULT_MAX_RECONNECT_ATTEMPTS = 60
9393DEFAULT_PING_INTERVAL = 120 # in seconds
9494DEFAULT_MAX_OUTSTANDING_PINGS = 2
95+ DEFAULT_MAX_READ_TIMEOUTS = 3
9596DEFAULT_MAX_PAYLOAD_SIZE = 1048576
9697DEFAULT_MAX_FLUSHER_QUEUE_SIZE = 1024
9798DEFAULT_FLUSH_TIMEOUT = 10 # in seconds
@@ -341,6 +342,7 @@ async def connect(
341342 max_reconnect_attempts : int = DEFAULT_MAX_RECONNECT_ATTEMPTS ,
342343 ping_interval : int = DEFAULT_PING_INTERVAL ,
343344 max_outstanding_pings : int = DEFAULT_MAX_OUTSTANDING_PINGS ,
345+ max_read_timeouts : int = DEFAULT_MAX_READ_TIMEOUTS ,
344346 dont_randomize : bool = False ,
345347 flusher_queue_size : int = DEFAULT_MAX_FLUSHER_QUEUE_SIZE ,
346348 no_echo : bool = False ,
@@ -371,6 +373,7 @@ async def connect(
371373 :param discovered_server_cb: Callback to report when a new server joins the cluster.
372374 :param pending_size: Max size of the pending buffer for publishing commands.
373375 :param flush_timeout: Max duration to wait for a forced flush to occur.
376+ :param max_read_timeouts: Maximum number of consecutive read timeouts before considering a connection stale.
374377
375378 Connecting setting all callbacks::
376379
@@ -488,6 +491,7 @@ async def subscribe_handler(msg):
488491 self .options ["reconnect_time_wait" ] = reconnect_time_wait
489492 self .options ["max_reconnect_attempts" ] = max_reconnect_attempts
490493 self .options ["ping_interval" ] = ping_interval
494+ self .options ["max_read_timeouts" ] = max_read_timeouts
491495 self .options ["max_outstanding_pings" ] = max_outstanding_pings
492496 self .options ["no_echo" ] = no_echo
493497 self .options ["user" ] = user
@@ -1030,14 +1034,38 @@ async def request(
10301034 the responses.
10311035
10321036 """
1037+ if not self .is_connected :
1038+ await self ._check_connection_health ()
1039+
1040+ if not self .is_connected :
1041+ if self .is_closed :
1042+ raise errors .ConnectionClosedError
1043+ elif self .is_reconnecting :
1044+ raise errors .ConnectionReconnectingError
1045+ else :
1046+ raise errors .ConnectionClosedError
1047+
10331048 if old_style :
10341049 # FIXME: Support headers in old style requests.
1035- return await self ._request_old_style (subject , payload , timeout = timeout )
1050+ try :
1051+ return await self ._request_old_style (
1052+ subject , payload , timeout = timeout
1053+ )
1054+ except (errors .TimeoutError , asyncio .TimeoutError ):
1055+ await self ._check_connection_health ()
1056+ raise errors .TimeoutError
10361057 else :
1037- msg = await self ._request_new_style (subject , payload , timeout = timeout , headers = headers )
1038- if msg .headers and msg .headers .get (nats .js .api .Header .STATUS ) == NO_RESPONDERS_STATUS :
1039- raise errors .NoRespondersError
1040- return msg
1058+ try :
1059+ msg = await self ._request_new_style (
1060+ subject , payload , timeout = timeout , headers = headers
1061+ )
1062+ if msg .headers and msg .headers .get (nats .js .api .Header .STATUS
1063+ ) == NO_RESPONDERS_STATUS :
1064+ raise errors .NoRespondersError
1065+ return msg
1066+ except (errors .TimeoutError , asyncio .TimeoutError ):
1067+ await self ._check_connection_health ()
1068+ raise errors .TimeoutError
10411069
10421070 async def _request_new_style (
10431071 self ,
@@ -1049,6 +1077,9 @@ async def _request_new_style(
10491077 if self .is_draining_pubs :
10501078 raise errors .ConnectionDrainingError
10511079
1080+ if not self .is_connected :
1081+ raise errors .ConnectionClosedError
1082+
10521083 if not self ._resp_sub_prefix :
10531084 await self ._init_request_sub ()
10541085 assert self ._resp_sub_prefix
@@ -1061,17 +1092,37 @@ async def _request_new_style(
10611092
10621093 # Then use the future to get the response.
10631094 future : asyncio .Future = asyncio .Future ()
1064- future .add_done_callback (lambda f : self ._resp_map .pop (token .decode (), None ))
1065- self ._resp_map [token .decode ()] = future
1095+ token_str = token .decode ()
1096+
1097+ def cleanup_resp_map (f ):
1098+ self ._resp_map .pop (token_str , None )
10661099
1067- # Publish the request
1068- await self .publish ( subject , payload , reply = inbox . decode (), headers = headers )
1100+ future . add_done_callback ( cleanup_resp_map )
1101+ self ._resp_map [ token_str ] = future
10691102
1070- # Wait for the response or give up on timeout.
10711103 try :
1072- return await asyncio .wait_for (future , timeout )
1073- except asyncio .TimeoutError :
1074- raise errors .TimeoutError
1104+ # Publish the request
1105+ await self .publish (
1106+ subject , payload , reply = inbox .decode (), headers = headers
1107+ )
1108+
1109+ if not self .is_connected :
1110+ future .cancel ()
1111+ raise errors .ConnectionClosedError
1112+
1113+ try :
1114+ return await asyncio .wait_for (future , timeout )
1115+ except asyncio .TimeoutError :
1116+ cleanup_resp_map (future )
1117+ raise errors .TimeoutError
1118+ except asyncio .CancelledError :
1119+ cleanup_resp_map (future )
1120+ raise
1121+ except Exception :
1122+ if not future .done ():
1123+ future .cancel ()
1124+ cleanup_resp_map (future )
1125+ raise
10751126
10761127 def new_inbox (self ) -> str :
10771128 """
@@ -1397,6 +1448,35 @@ async def _process_err(self, err_msg: str) -> None:
13971448 # For now we handle similar as other clients and close.
13981449 asyncio .create_task (self ._close (Client .CLOSED , do_cbs ))
13991450
1451+ async def _check_connection_health (self ) -> bool :
1452+ """
1453+ Checks if the connection appears healthy, and if not, attempts reconnection.
1454+
1455+ Returns:
1456+ bool: True if connection is healthy or was successfully reconnected, False otherwise
1457+ """
1458+ if not self .is_connected :
1459+ if self .options [
1460+ "allow_reconnect"
1461+ ] and not self .is_reconnecting and not self .is_closed :
1462+ self ._status = Client .RECONNECTING
1463+ self ._ps .reset ()
1464+
1465+ try :
1466+ if self ._reconnection_task is not None and not self ._reconnection_task .cancelled (
1467+ ):
1468+ self ._reconnection_task .cancel ()
1469+
1470+ self ._reconnection_task = asyncio .get_running_loop (
1471+ ).create_task (self ._attempt_reconnect ())
1472+
1473+ await asyncio .sleep (self .options ["reconnect_time_wait" ])
1474+ return self .is_connected
1475+ except Exception :
1476+ return False
1477+ return False
1478+ return True
1479+
14001480 async def _process_op_err (self , e : Exception ) -> None :
14011481 """
14021482 Process errors which occurred while reading or parsing
@@ -2056,8 +2136,16 @@ async def _ping_interval(self) -> None:
20562136 await self ._send_ping ()
20572137 except (asyncio .CancelledError , RuntimeError , AttributeError ):
20582138 break
2059- # except asyncio.InvalidStateError:
2060- # pass
2139+ except asyncio .InvalidStateError :
2140+ # Handle invalid state errors that can occur when connection state changes
2141+ if self .is_connected :
2142+ await self ._process_op_err (ErrStaleConnection ())
2143+ break
2144+ except Exception as e :
2145+ if self .is_connected :
2146+ await self ._error_cb (e )
2147+ await self ._process_op_err (ErrStaleConnection ())
2148+ break
20612149
20622150 async def _read_loop (self ) -> None :
20632151 """
@@ -2066,6 +2154,8 @@ async def _read_loop(self) -> None:
20662154 In case of error while reading, it will stop running
20672155 and its task has to be rescheduled.
20682156 """
2157+ read_timeout_count = 0
2158+
20692159 while True :
20702160 try :
20712161 should_bail = self .is_closed or self .is_reconnecting
@@ -2077,21 +2167,47 @@ async def _read_loop(self) -> None:
20772167 await self ._process_op_err (err )
20782168 break
20792169
2080- b = await self ._transport .read (DEFAULT_BUFFER_SIZE )
2081- await self ._ps .parse (b )
2170+ # Use a timeout for reading to detect stalled connections
2171+ try :
2172+ read_future = self ._transport .read (DEFAULT_BUFFER_SIZE )
2173+ b = await asyncio .wait_for (
2174+ read_future , timeout = self .options ["ping_interval" ]
2175+ )
2176+ read_timeout_count = 0
2177+ await self ._ps .parse (b )
2178+ except asyncio .TimeoutError :
2179+ read_timeout_count += 1
2180+ if read_timeout_count >= self .options ["max_read_timeouts" ]:
2181+ err = ErrStaleConnection ()
2182+ await self ._error_cb (err )
2183+ await self ._process_op_err (err )
2184+ break
2185+ continue
2186+
20822187 except errors .ProtocolError :
20832188 await self ._process_op_err (errors .ProtocolError ())
20842189 break
2190+ except ConnectionResetError as e :
2191+ await self ._error_cb (e )
2192+ await self ._process_op_err (errors .ConnectionClosedError ())
2193+ break
20852194 except OSError as e :
2195+ await self ._error_cb (e )
20862196 await self ._process_op_err (e )
20872197 break
2198+ except asyncio .InvalidStateError :
2199+ if self .is_connected :
2200+ err = ErrStaleConnection ()
2201+ await self ._error_cb (err )
2202+ await self ._process_op_err (err )
2203+ break
20882204 except asyncio .CancelledError :
20892205 break
20902206 except Exception as ex :
2207+ await self ._error_cb (ex )
2208+ await self ._process_op_err (ex )
20912209 _logger .error ("nats: encountered error" , exc_info = ex )
20922210 break
2093- # except asyncio.InvalidStateError:
2094- # pass
20952211
20962212 async def __aenter__ (self ) -> "Client" :
20972213 """For when NATS client is used in a context manager"""
0 commit comments