From cc55ad65b3732df4c49a4ce49de21069bb3b73b3 Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Thu, 13 Nov 2025 11:48:13 +0100 Subject: [PATCH 1/4] Perform LWT metadata protocol handshake --- cassandra/lwt_info.py | 21 +++++++++++++++++++++ cassandra/protocol_features.py | 27 +++++++++++++++++++++++++-- 2 files changed, 46 insertions(+), 2 deletions(-) create mode 100644 cassandra/lwt_info.py diff --git a/cassandra/lwt_info.py b/cassandra/lwt_info.py new file mode 100644 index 0000000000..45750dbcec --- /dev/null +++ b/cassandra/lwt_info.py @@ -0,0 +1,21 @@ +# Copyright 2020 ScyllaDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +class _LwtInfo: + """ + Holds LWT-related information parsed from the server's supported features. + """ + + def __init__(self, lwt_meta_bit_mask): + self.lwt_meta_bit_mask = lwt_meta_bit_mask diff --git a/cassandra/protocol_features.py b/cassandra/protocol_features.py index 4eb7019f84..877998be7d 100644 --- a/cassandra/protocol_features.py +++ b/cassandra/protocol_features.py @@ -1,10 +1,13 @@ import logging from cassandra.shard_info import _ShardingInfo +from cassandra.lwt_info import _LwtInfo log = logging.getLogger(__name__) +LWT_ADD_METADATA_MARK = "SCYLLA_LWT_ADD_METADATA_MARK" +LWT_OPTIMIZATION_META_BIT_MASK = "LWT_OPTIMIZATION_META_BIT_MASK" RATE_LIMIT_ERROR_EXTENSION = "SCYLLA_RATE_LIMIT_ERROR" TABLETS_ROUTING_V1 = "TABLETS_ROUTING_V1" @@ -13,19 +16,22 @@ class ProtocolFeatures(object): shard_id = 0 sharding_info = None tablets_routing_v1 = False + lwt_info = None - def __init__(self, rate_limit_error=None, shard_id=0, sharding_info=None, tablets_routing_v1=False): + def __init__(self, rate_limit_error=None, shard_id=0, sharding_info=None, tablets_routing_v1=False, lwt_info=None): self.rate_limit_error = rate_limit_error self.shard_id = shard_id self.sharding_info = sharding_info self.tablets_routing_v1 = tablets_routing_v1 + self.lwt_info = lwt_info @staticmethod def parse_from_supported(supported): rate_limit_error = ProtocolFeatures.maybe_parse_rate_limit_error(supported) shard_id, sharding_info = ProtocolFeatures.parse_sharding_info(supported) tablets_routing_v1 = ProtocolFeatures.parse_tablets_info(supported) - return ProtocolFeatures(rate_limit_error, shard_id, sharding_info, tablets_routing_v1) + lwt_info = ProtocolFeatures.parse_lwt_info(supported) + return ProtocolFeatures(rate_limit_error, shard_id, sharding_info, tablets_routing_v1, lwt_info) @staticmethod def maybe_parse_rate_limit_error(supported): @@ -49,6 +55,8 @@ def add_startup_options(self, options): options[RATE_LIMIT_ERROR_EXTENSION] = "" if self.tablets_routing_v1: options[TABLETS_ROUTING_V1] = "" + if self.lwt_info is not None: + options[LWT_ADD_METADATA_MARK] = str(self.lwt_info.lwt_meta_bit_mask) @staticmethod def parse_sharding_info(options): @@ -72,3 +80,18 @@ def parse_sharding_info(options): @staticmethod def parse_tablets_info(options): return TABLETS_ROUTING_V1 in options + + @staticmethod + def parse_lwt_info(options): + value_list = options.get(LWT_ADD_METADATA_MARK, [None]) + for value in value_list: + if value is None or not value.startswith(LWT_OPTIMIZATION_META_BIT_MASK + "="): + continue + try: + lwt_meta_bit_mask = int(value[len(LWT_OPTIMIZATION_META_BIT_MASK + "="):]) + return _LwtInfo(lwt_meta_bit_mask) + except Exception as e: + log.exception(f"Error while parsing {LWT_ADD_METADATA_MARK}: {e}") + return None + + return None From 18982c66d25313eb002ee43f32569741557adc96 Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Thu, 13 Nov 2025 13:26:54 +0100 Subject: [PATCH 2/4] Parse LWT flags when creating prepared statement --- cassandra/cluster.py | 4 +++- cassandra/lwt_info.py | 3 +++ cassandra/protocol.py | 12 ++++++++---- cassandra/query.py | 19 +++++++++++++++---- 4 files changed, 29 insertions(+), 9 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 5822a23aa9..8c5cff8c99 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -3109,7 +3109,9 @@ def prepare(self, query, custom_payload=None, keyspace=None): prepared_keyspace = keyspace if keyspace else None prepared_statement = PreparedStatement.from_message( response.query_id, response.bind_metadata, response.pk_indexes, self.cluster.metadata, query, prepared_keyspace, - self._protocol_version, response.column_metadata, response.result_metadata_id, self.cluster.column_encryption_policy) + self._protocol_version, response.column_metadata, response.result_metadata_id, + response.lwt_info.is_lwt(response.flags) if response.lwt_info is not None else False, + self.cluster.column_encryption_policy) prepared_statement.custom_payload = future.custom_payload self.cluster.add_prepared(response.query_id, prepared_statement) diff --git a/cassandra/lwt_info.py b/cassandra/lwt_info.py index 45750dbcec..561b56a082 100644 --- a/cassandra/lwt_info.py +++ b/cassandra/lwt_info.py @@ -19,3 +19,6 @@ class _LwtInfo: def __init__(self, lwt_meta_bit_mask): self.lwt_meta_bit_mask = lwt_meta_bit_mask + + def is_lwt(self, flags): + return (flags & self.lwt_meta_bit_mask) == self.lwt_meta_bit_mask diff --git a/cassandra/protocol.py b/cassandra/protocol.py index d8716f4eeb..8f7e07dde8 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -686,11 +686,13 @@ class ResultMessage(_MessageType): bind_metadata = None pk_indexes = None schema_change_event = None + flags = None + lwt_info = None def __init__(self, kind): self.kind = kind - def recv(self, f, protocol_version, user_type_map, result_metadata, column_encryption_policy): + def recv(self, f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy): if self.kind == RESULT_KIND_VOID: return elif self.kind == RESULT_KIND_ROWS: @@ -698,7 +700,7 @@ def recv(self, f, protocol_version, user_type_map, result_metadata, column_encry elif self.kind == RESULT_KIND_SET_KEYSPACE: self.new_keyspace = read_string(f) elif self.kind == RESULT_KIND_PREPARED: - self.recv_results_prepared(f, protocol_version, user_type_map) + self.recv_results_prepared(f, protocol_version, protocol_features, user_type_map) elif self.kind == RESULT_KIND_SCHEMA_CHANGE: self.recv_results_schema_change(f, protocol_version) else: @@ -708,7 +710,7 @@ def recv(self, f, protocol_version, user_type_map, result_metadata, column_encry def recv_body(cls, f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy): kind = read_int(f) msg = cls(kind) - msg.recv(f, protocol_version, user_type_map, result_metadata, column_encryption_policy) + msg.recv(f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy) return msg def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata, column_encryption_policy): @@ -741,8 +743,9 @@ def decode_row(row): col_md[3].cql_parameterized_type(), str(e))) - def recv_results_prepared(self, f, protocol_version, user_type_map): + def recv_results_prepared(self, f, protocol_version, protocol_features, user_type_map): self.query_id = read_binary_string(f) + self.lwt_info = protocol_features.lwt_info if ProtocolVersion.uses_prepared_metadata(protocol_version): self.result_metadata_id = read_binary_string(f) else: @@ -787,6 +790,7 @@ def recv_results_metadata(self, f, user_type_map): def recv_prepared_metadata(self, f, protocol_version, user_type_map): flags = read_int(f) + self.flags = flags colcount = read_int(f) pk_indexes = None if protocol_version >= 4: diff --git a/cassandra/query.py b/cassandra/query.py index f3922849ab..84f850233b 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -345,6 +345,9 @@ def _set_serial_consistency_level(self, serial_consistency_level): def _del_serial_consistency_level(self): self._serial_consistency_level = None + def is_lwt(self): + return False + serial_consistency_level = property( _get_serial_consistency_level, _set_serial_consistency_level, @@ -454,10 +457,11 @@ class PreparedStatement(object): routing_key_indexes = None _routing_key_index_set = None serial_consistency_level = None # TODO never used? + _is_lwt = False def __init__(self, column_metadata, query_id, routing_key_indexes, query, keyspace, protocol_version, result_metadata, result_metadata_id, - column_encryption_policy=None): + is_lwt=None, column_encryption_policy=None): self.column_metadata = column_metadata self.query_id = query_id self.routing_key_indexes = routing_key_indexes @@ -468,15 +472,16 @@ def __init__(self, column_metadata, query_id, routing_key_indexes, query, self.result_metadata_id = result_metadata_id self.column_encryption_policy = column_encryption_policy self.is_idempotent = False + self._is_lwt = is_lwt @classmethod def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata, query, prepared_keyspace, protocol_version, result_metadata, - result_metadata_id, column_encryption_policy=None): + result_metadata_id, is_lwt, column_encryption_policy=None): if not column_metadata: return PreparedStatement(column_metadata, query_id, None, query, prepared_keyspace, protocol_version, result_metadata, - result_metadata_id, column_encryption_policy) + result_metadata_id, is_lwt, column_encryption_policy) if pk_indexes: routing_key_indexes = pk_indexes @@ -502,7 +507,7 @@ def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata, return PreparedStatement(column_metadata, query_id, routing_key_indexes, query, prepared_keyspace, protocol_version, result_metadata, - result_metadata_id, column_encryption_policy) + result_metadata_id, is_lwt, column_encryption_policy) def bind(self, values): """ @@ -517,6 +522,9 @@ def is_routing_key_index(self, i): self._routing_key_index_set = set(self.routing_key_indexes) if self.routing_key_indexes else set() return i in self._routing_key_index_set + def is_lwt(self): + return self._is_lwt + def __str__(self): consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') return (u'' % @@ -682,6 +690,9 @@ def routing_key(self): return self._routing_key + def is_lwt(self): + return self.prepared_statement.is_lwt() + def __str__(self): consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') return (u'' % From daf2780d442fb86f8a97d4e7f5c85ae986c28dad Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Thu, 13 Nov 2025 14:04:09 +0100 Subject: [PATCH 3/4] Add tests for lwt --- .../standard/test_prepared_statements.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/integration/standard/test_prepared_statements.py b/tests/integration/standard/test_prepared_statements.py index 68a704cd77..e6f86d835b 100644 --- a/tests/integration/standard/test_prepared_statements.py +++ b/tests/integration/standard/test_prepared_statements.py @@ -406,6 +406,29 @@ def test_raise_error_on_prepared_statement_execution_dropped_table(self): with pytest.raises(InvalidRequest): self.session.execute(prepared, [0]) + def test_recognize_lwt_query(self): + self.session.execute("CREATE TABLE IF NOT EXISTS preparedtests.bound_statement_test (a int PRIMARY KEY, b int)") + # Prepare a non-LWT statement + statementNonLWT = self.session.prepare("UPDATE preparedtests.bound_statement_test SET b = ? WHERE a = ?") + # Prepare an LWT statement + statementLWT = self.session.prepare("UPDATE preparedtests.bound_statement_test SET b = ? WHERE a = ? IF b = ?") + + boundNonLWT = statementNonLWT.bind((3, 1)) + boundLWT = statementLWT.bind((3, 1, 5)) + + # Check LWT detection + assert not boundNonLWT.is_lwt() + assert boundLWT.is_lwt() + + self.session.execute("CREATE TABLE IF NOT EXISTS preparedtests.prepared_statement_test (a int PRIMARY KEY, b int)") + # Prepare a non-LWT statement + statementNonLWT = self.session.prepare("UPDATE preparedtests.prepared_statement_test SET b = ? WHERE a = ?") + # Prepare an LWT statement + statementLWT = self.session.prepare("UPDATE preparedtests.prepared_statement_test SET b = ? WHERE a = ? IF b = ?") + # Check LWT detection + assert not statementNonLWT.is_lwt() + assert statementLWT.is_lwt() + @unittest.skipIf((CASSANDRA_VERSION >= Version('3.11.12') and CASSANDRA_VERSION < Version('4.0')) or \ CASSANDRA_VERSION >= Version('4.0.2'), "Fixed server-side in Cassandra 3.11.12, 4.0.2") From 0ac7a8eb847977313133f74a10a3ae4b1483c0af Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Thu, 13 Nov 2025 21:23:13 +0100 Subject: [PATCH 4/4] Don not shuffle replicas for LWT queries --- cassandra/policies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cassandra/policies.py b/cassandra/policies.py index d681980d77..8869a9ce30 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -523,7 +523,7 @@ def make_query_plan(self, working_keyspace=None, query=None): else: replicas = self._cluster_metadata.get_replicas(keyspace, query.routing_key) - if self.shuffle_replicas: + if self.shuffle_replicas and not query.is_lwt(): shuffle(replicas) def yield_in_order(hosts):