Skip to content

Commit 632f8c4

Browse files
committed
improve long-running connection stability
1 parent 1d38f9b commit 632f8c4

File tree

1 file changed

+135
-19
lines changed

1 file changed

+135
-19
lines changed

nats/src/nats/aio/client.py

Lines changed: 135 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@
9292
DEFAULT_MAX_RECONNECT_ATTEMPTS = 60
9393
DEFAULT_PING_INTERVAL = 120 # in seconds
9494
DEFAULT_MAX_OUTSTANDING_PINGS = 2
95+
DEFAULT_MAX_READ_TIMEOUTS = 3
9596
DEFAULT_MAX_PAYLOAD_SIZE = 1048576
9697
DEFAULT_MAX_FLUSHER_QUEUE_SIZE = 1024
9798
DEFAULT_FLUSH_TIMEOUT = 10 # in seconds
@@ -341,6 +342,7 @@ async def connect(
341342
max_reconnect_attempts: int = DEFAULT_MAX_RECONNECT_ATTEMPTS,
342343
ping_interval: int = DEFAULT_PING_INTERVAL,
343344
max_outstanding_pings: int = DEFAULT_MAX_OUTSTANDING_PINGS,
345+
max_read_timeouts: int = DEFAULT_MAX_READ_TIMEOUTS,
344346
dont_randomize: bool = False,
345347
flusher_queue_size: int = DEFAULT_MAX_FLUSHER_QUEUE_SIZE,
346348
no_echo: bool = False,
@@ -371,6 +373,7 @@ async def connect(
371373
:param discovered_server_cb: Callback to report when a new server joins the cluster.
372374
:param pending_size: Max size of the pending buffer for publishing commands.
373375
:param flush_timeout: Max duration to wait for a forced flush to occur.
376+
:param max_read_timeouts: Maximum number of consecutive read timeouts before considering a connection stale.
374377
375378
Connecting while setting all callbacks::
376379
@@ -488,6 +491,7 @@ async def subscribe_handler(msg):
488491
self.options["reconnect_time_wait"] = reconnect_time_wait
489492
self.options["max_reconnect_attempts"] = max_reconnect_attempts
490493
self.options["ping_interval"] = ping_interval
494+
self.options["max_read_timeouts"] = max_read_timeouts
491495
self.options["max_outstanding_pings"] = max_outstanding_pings
492496
self.options["no_echo"] = no_echo
493497
self.options["user"] = user
@@ -1030,14 +1034,38 @@ async def request(
10301034
the responses.
10311035
10321036
"""
1037+
if not self.is_connected:
1038+
await self._check_connection_health()
1039+
1040+
if not self.is_connected:
1041+
if self.is_closed:
1042+
raise errors.ConnectionClosedError
1043+
elif self.is_reconnecting:
1044+
raise errors.ConnectionReconnectingError
1045+
else:
1046+
raise errors.ConnectionClosedError
1047+
10331048
if old_style:
10341049
# FIXME: Support headers in old style requests.
1035-
return await self._request_old_style(subject, payload, timeout=timeout)
1050+
try:
1051+
return await self._request_old_style(
1052+
subject, payload, timeout=timeout
1053+
)
1054+
except (errors.TimeoutError, asyncio.TimeoutError):
1055+
await self._check_connection_health()
1056+
raise errors.TimeoutError
10361057
else:
1037-
msg = await self._request_new_style(subject, payload, timeout=timeout, headers=headers)
1038-
if msg.headers and msg.headers.get(nats.js.api.Header.STATUS) == NO_RESPONDERS_STATUS:
1039-
raise errors.NoRespondersError
1040-
return msg
1058+
try:
1059+
msg = await self._request_new_style(
1060+
subject, payload, timeout=timeout, headers=headers
1061+
)
1062+
if msg.headers and msg.headers.get(nats.js.api.Header.STATUS
1063+
) == NO_RESPONDERS_STATUS:
1064+
raise errors.NoRespondersError
1065+
return msg
1066+
except (errors.TimeoutError, asyncio.TimeoutError):
1067+
await self._check_connection_health()
1068+
raise errors.TimeoutError
10411069

10421070
async def _request_new_style(
10431071
self,
@@ -1049,6 +1077,9 @@ async def _request_new_style(
10491077
if self.is_draining_pubs:
10501078
raise errors.ConnectionDrainingError
10511079

1080+
if not self.is_connected:
1081+
raise errors.ConnectionClosedError
1082+
10521083
if not self._resp_sub_prefix:
10531084
await self._init_request_sub()
10541085
assert self._resp_sub_prefix
@@ -1061,17 +1092,37 @@ async def _request_new_style(
10611092

10621093
# Then use the future to get the response.
10631094
future: asyncio.Future = asyncio.Future()
1064-
future.add_done_callback(lambda f: self._resp_map.pop(token.decode(), None))
1065-
self._resp_map[token.decode()] = future
1095+
token_str = token.decode()
1096+
1097+
def cleanup_resp_map(f):
1098+
self._resp_map.pop(token_str, None)
10661099

1067-
# Publish the request
1068-
await self.publish(subject, payload, reply=inbox.decode(), headers=headers)
1100+
future.add_done_callback(cleanup_resp_map)
1101+
self._resp_map[token_str] = future
10691102

1070-
# Wait for the response or give up on timeout.
10711103
try:
1072-
return await asyncio.wait_for(future, timeout)
1073-
except asyncio.TimeoutError:
1074-
raise errors.TimeoutError
1104+
# Publish the request
1105+
await self.publish(
1106+
subject, payload, reply=inbox.decode(), headers=headers
1107+
)
1108+
1109+
if not self.is_connected:
1110+
future.cancel()
1111+
raise errors.ConnectionClosedError
1112+
1113+
try:
1114+
return await asyncio.wait_for(future, timeout)
1115+
except asyncio.TimeoutError:
1116+
cleanup_resp_map(future)
1117+
raise errors.TimeoutError
1118+
except asyncio.CancelledError:
1119+
cleanup_resp_map(future)
1120+
raise
1121+
except Exception:
1122+
if not future.done():
1123+
future.cancel()
1124+
cleanup_resp_map(future)
1125+
raise
10751126

10761127
def new_inbox(self) -> str:
10771128
"""
@@ -1397,6 +1448,35 @@ async def _process_err(self, err_msg: str) -> None:
13971448
# For now we handle similar as other clients and close.
13981449
asyncio.create_task(self._close(Client.CLOSED, do_cbs))
13991450

1451+
async def _check_connection_health(self) -> bool:
1452+
"""
1453+
Checks if the connection appears healthy, and if not, attempts reconnection.
1454+
1455+
Returns:
1456+
bool: True if the connection is healthy, or is connected again after waiting one ``reconnect_time_wait`` interval for the scheduled reconnect attempt; False otherwise
1457+
"""
1458+
if not self.is_connected:
1459+
if self.options[
1460+
"allow_reconnect"
1461+
] and not self.is_reconnecting and not self.is_closed:
1462+
self._status = Client.RECONNECTING
1463+
self._ps.reset()
1464+
1465+
try:
1466+
if self._reconnection_task is not None and not self._reconnection_task.cancelled(
1467+
):
1468+
self._reconnection_task.cancel()
1469+
1470+
self._reconnection_task = asyncio.get_running_loop(
1471+
).create_task(self._attempt_reconnect())
1472+
1473+
await asyncio.sleep(self.options["reconnect_time_wait"])
1474+
return self.is_connected
1475+
except Exception:
1476+
return False
1477+
return False
1478+
return True
1479+
14001480
async def _process_op_err(self, e: Exception) -> None:
14011481
"""
14021482
Process errors which occurred while reading or parsing
@@ -2056,8 +2136,16 @@ async def _ping_interval(self) -> None:
20562136
await self._send_ping()
20572137
except (asyncio.CancelledError, RuntimeError, AttributeError):
20582138
break
2059-
# except asyncio.InvalidStateError:
2060-
# pass
2139+
except asyncio.InvalidStateError:
2140+
# Handle invalid state errors that can occur when connection state changes
2141+
if self.is_connected:
2142+
await self._process_op_err(ErrStaleConnection())
2143+
break
2144+
except Exception as e:
2145+
if self.is_connected:
2146+
await self._error_cb(e)
2147+
await self._process_op_err(ErrStaleConnection())
2148+
break
20612149

20622150
async def _read_loop(self) -> None:
20632151
"""
@@ -2066,6 +2154,8 @@ async def _read_loop(self) -> None:
20662154
In case of error while reading, it will stop running
20672155
and its task has to be rescheduled.
20682156
"""
2157+
read_timeout_count = 0
2158+
20692159
while True:
20702160
try:
20712161
should_bail = self.is_closed or self.is_reconnecting
@@ -2077,21 +2167,47 @@ async def _read_loop(self) -> None:
20772167
await self._process_op_err(err)
20782168
break
20792169

2080-
b = await self._transport.read(DEFAULT_BUFFER_SIZE)
2081-
await self._ps.parse(b)
2170+
# Use a timeout for reading to detect stalled connections
2171+
try:
2172+
read_future = self._transport.read(DEFAULT_BUFFER_SIZE)
2173+
b = await asyncio.wait_for(
2174+
read_future, timeout=self.options["ping_interval"]
2175+
)
2176+
read_timeout_count = 0
2177+
await self._ps.parse(b)
2178+
except asyncio.TimeoutError:
2179+
read_timeout_count += 1
2180+
if read_timeout_count >= self.options["max_read_timeouts"]:
2181+
err = ErrStaleConnection()
2182+
await self._error_cb(err)
2183+
await self._process_op_err(err)
2184+
break
2185+
continue
2186+
20822187
except errors.ProtocolError:
20832188
await self._process_op_err(errors.ProtocolError())
20842189
break
2190+
except ConnectionResetError as e:
2191+
await self._error_cb(e)
2192+
await self._process_op_err(errors.ConnectionClosedError())
2193+
break
20852194
except OSError as e:
2195+
await self._error_cb(e)
20862196
await self._process_op_err(e)
20872197
break
2198+
except asyncio.InvalidStateError:
2199+
if self.is_connected:
2200+
err = ErrStaleConnection()
2201+
await self._error_cb(err)
2202+
await self._process_op_err(err)
2203+
break
20882204
except asyncio.CancelledError:
20892205
break
20902206
except Exception as ex:
2207+
await self._error_cb(ex)
2208+
await self._process_op_err(ex)
20912209
_logger.error("nats: encountered error", exc_info=ex)
20922210
break
2093-
# except asyncio.InvalidStateError:
2094-
# pass
20952211

20962212
async def __aenter__(self) -> "Client":
20972213
"""For when NATS client is used in a context manager"""

0 commit comments

Comments
 (0)