Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion cassandra/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -3109,7 +3109,9 @@
prepared_keyspace = keyspace if keyspace else None
prepared_statement = PreparedStatement.from_message(
response.query_id, response.bind_metadata, response.pk_indexes, self.cluster.metadata, query, prepared_keyspace,
self._protocol_version, response.column_metadata, response.result_metadata_id, self.cluster.column_encryption_policy)
self._protocol_version, response.column_metadata, response.result_metadata_id,
response.lwt_info.is_lwt(response.flags) if response.lwt_info is not None else False,
self.cluster.column_encryption_policy)
prepared_statement.custom_payload = future.custom_payload

self.cluster.add_prepared(response.query_id, prepared_statement)
Expand Down Expand Up @@ -4316,7 +4318,7 @@
fn, args, kwargs = task
kwargs = dict(kwargs)
future = self._executor.submit(fn, *args, **kwargs)
future.add_done_callback(self._log_if_failed)

Check failure on line 4321 in cassandra/cluster.py

View workflow job for this annotation

GitHub Actions / test libev (3.12)

cannot schedule new futures after shutdown

Check failure on line 4321 in cassandra/cluster.py

View workflow job for this annotation

GitHub Actions / test libev (3.11)

cannot schedule new futures after shutdown
else:
self._queue.put_nowait((run_at, i, task))
break
Expand Down
24 changes: 24 additions & 0 deletions cassandra/lwt_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright 2020 ScyllaDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

class _LwtInfo:
"""
Holds LWT-related information parsed from the server's supported features.
"""

def __init__(self, lwt_meta_bit_mask):
self.lwt_meta_bit_mask = lwt_meta_bit_mask

def is_lwt(self, flags):
return (flags & self.lwt_meta_bit_mask) == self.lwt_meta_bit_mask
2 changes: 1 addition & 1 deletion cassandra/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ def make_query_plan(self, working_keyspace=None, query=None):
else:
replicas = self._cluster_metadata.get_replicas(keyspace, query.routing_key)

if self.shuffle_replicas:
if self.shuffle_replicas and not query.is_lwt():
shuffle(replicas)

def yield_in_order(hosts):
Expand Down
12 changes: 8 additions & 4 deletions cassandra/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,19 +686,21 @@ class ResultMessage(_MessageType):
bind_metadata = None
pk_indexes = None
schema_change_event = None
flags = None
lwt_info = None

def __init__(self, kind):
self.kind = kind

def recv(self, f, protocol_version, user_type_map, result_metadata, column_encryption_policy):
def recv(self, f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy):
if self.kind == RESULT_KIND_VOID:
return
elif self.kind == RESULT_KIND_ROWS:
self.recv_results_rows(f, protocol_version, user_type_map, result_metadata, column_encryption_policy)
elif self.kind == RESULT_KIND_SET_KEYSPACE:
self.new_keyspace = read_string(f)
elif self.kind == RESULT_KIND_PREPARED:
self.recv_results_prepared(f, protocol_version, user_type_map)
self.recv_results_prepared(f, protocol_version, protocol_features, user_type_map)
elif self.kind == RESULT_KIND_SCHEMA_CHANGE:
self.recv_results_schema_change(f, protocol_version)
else:
Expand All @@ -708,7 +710,7 @@ def recv(self, f, protocol_version, user_type_map, result_metadata, column_encry
def recv_body(cls, f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy):
kind = read_int(f)
msg = cls(kind)
msg.recv(f, protocol_version, user_type_map, result_metadata, column_encryption_policy)
msg.recv(f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy)
return msg

def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata, column_encryption_policy):
Expand Down Expand Up @@ -741,8 +743,9 @@ def decode_row(row):
col_md[3].cql_parameterized_type(),
str(e)))

def recv_results_prepared(self, f, protocol_version, user_type_map):
def recv_results_prepared(self, f, protocol_version, protocol_features, user_type_map):
self.query_id = read_binary_string(f)
self.lwt_info = protocol_features.lwt_info
if ProtocolVersion.uses_prepared_metadata(protocol_version):
self.result_metadata_id = read_binary_string(f)
else:
Expand Down Expand Up @@ -787,6 +790,7 @@ def recv_results_metadata(self, f, user_type_map):

def recv_prepared_metadata(self, f, protocol_version, user_type_map):
flags = read_int(f)
self.flags = flags
colcount = read_int(f)
pk_indexes = None
if protocol_version >= 4:
Expand Down
27 changes: 25 additions & 2 deletions cassandra/protocol_features.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import logging

from cassandra.shard_info import _ShardingInfo
from cassandra.lwt_info import _LwtInfo

log = logging.getLogger(__name__)


LWT_ADD_METADATA_MARK = "SCYLLA_LWT_ADD_METADATA_MARK"
LWT_OPTIMIZATION_META_BIT_MASK = "LWT_OPTIMIZATION_META_BIT_MASK"
RATE_LIMIT_ERROR_EXTENSION = "SCYLLA_RATE_LIMIT_ERROR"
TABLETS_ROUTING_V1 = "TABLETS_ROUTING_V1"

Expand All @@ -13,19 +16,22 @@ class ProtocolFeatures(object):
shard_id = 0
sharding_info = None
tablets_routing_v1 = False
lwt_info = None

def __init__(self, rate_limit_error=None, shard_id=0, sharding_info=None, tablets_routing_v1=False):
def __init__(self, rate_limit_error=None, shard_id=0, sharding_info=None, tablets_routing_v1=False, lwt_info=None):
self.rate_limit_error = rate_limit_error
self.shard_id = shard_id
self.sharding_info = sharding_info
self.tablets_routing_v1 = tablets_routing_v1
self.lwt_info = lwt_info

@staticmethod
def parse_from_supported(supported):
rate_limit_error = ProtocolFeatures.maybe_parse_rate_limit_error(supported)
shard_id, sharding_info = ProtocolFeatures.parse_sharding_info(supported)
tablets_routing_v1 = ProtocolFeatures.parse_tablets_info(supported)
return ProtocolFeatures(rate_limit_error, shard_id, sharding_info, tablets_routing_v1)
lwt_info = ProtocolFeatures.parse_lwt_info(supported)
return ProtocolFeatures(rate_limit_error, shard_id, sharding_info, tablets_routing_v1, lwt_info)

@staticmethod
def maybe_parse_rate_limit_error(supported):
Expand All @@ -49,6 +55,8 @@ def add_startup_options(self, options):
options[RATE_LIMIT_ERROR_EXTENSION] = ""
if self.tablets_routing_v1:
options[TABLETS_ROUTING_V1] = ""
if self.lwt_info is not None:
options[LWT_ADD_METADATA_MARK] = str(self.lwt_info.lwt_meta_bit_mask)

@staticmethod
def parse_sharding_info(options):
Expand All @@ -72,3 +80,18 @@ def parse_sharding_info(options):
@staticmethod
def parse_tablets_info(options):
return TABLETS_ROUTING_V1 in options

@staticmethod
def parse_lwt_info(options):
value_list = options.get(LWT_ADD_METADATA_MARK, [None])
for value in value_list:
if value is None or not value.startswith(LWT_OPTIMIZATION_META_BIT_MASK + "="):
continue
try:
lwt_meta_bit_mask = int(value[len(LWT_OPTIMIZATION_META_BIT_MASK + "="):])
return _LwtInfo(lwt_meta_bit_mask)
except Exception as e:
log.exception(f"Error while parsing {LWT_ADD_METADATA_MARK}: {e}")
return None

return None
19 changes: 15 additions & 4 deletions cassandra/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,9 @@ def _set_serial_consistency_level(self, serial_consistency_level):
def _del_serial_consistency_level(self):
self._serial_consistency_level = None

def is_lwt(self):
return False

serial_consistency_level = property(
_get_serial_consistency_level,
_set_serial_consistency_level,
Expand Down Expand Up @@ -454,10 +457,11 @@ class PreparedStatement(object):
routing_key_indexes = None
_routing_key_index_set = None
serial_consistency_level = None # TODO never used?
_is_lwt = False

def __init__(self, column_metadata, query_id, routing_key_indexes, query,
keyspace, protocol_version, result_metadata, result_metadata_id,
column_encryption_policy=None):
is_lwt=None, column_encryption_policy=None):
self.column_metadata = column_metadata
self.query_id = query_id
self.routing_key_indexes = routing_key_indexes
Expand All @@ -468,15 +472,16 @@ def __init__(self, column_metadata, query_id, routing_key_indexes, query,
self.result_metadata_id = result_metadata_id
self.column_encryption_policy = column_encryption_policy
self.is_idempotent = False
self._is_lwt = is_lwt

@classmethod
def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata,
query, prepared_keyspace, protocol_version, result_metadata,
result_metadata_id, column_encryption_policy=None):
result_metadata_id, is_lwt, column_encryption_policy=None):
if not column_metadata:
return PreparedStatement(column_metadata, query_id, None,
query, prepared_keyspace, protocol_version, result_metadata,
result_metadata_id, column_encryption_policy)
result_metadata_id, is_lwt, column_encryption_policy)

if pk_indexes:
routing_key_indexes = pk_indexes
Expand All @@ -502,7 +507,7 @@ def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata,

return PreparedStatement(column_metadata, query_id, routing_key_indexes,
query, prepared_keyspace, protocol_version, result_metadata,
result_metadata_id, column_encryption_policy)
result_metadata_id, is_lwt, column_encryption_policy)

def bind(self, values):
"""
Expand All @@ -517,6 +522,9 @@ def is_routing_key_index(self, i):
self._routing_key_index_set = set(self.routing_key_indexes) if self.routing_key_indexes else set()
return i in self._routing_key_index_set

def is_lwt(self):
return self._is_lwt

def __str__(self):
consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
return (u'<PreparedStatement query="%s", consistency=%s>' %
Expand Down Expand Up @@ -682,6 +690,9 @@ def routing_key(self):

return self._routing_key

def is_lwt(self):
return self.prepared_statement.is_lwt()

def __str__(self):
consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
return (u'<BoundStatement query="%s", values=%s, consistency=%s>' %
Expand Down
23 changes: 23 additions & 0 deletions tests/integration/standard/test_prepared_statements.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,29 @@ def test_raise_error_on_prepared_statement_execution_dropped_table(self):
with pytest.raises(InvalidRequest):
self.session.execute(prepared, [0])

def test_recognize_lwt_query(self):
self.session.execute("CREATE TABLE IF NOT EXISTS preparedtests.bound_statement_test (a int PRIMARY KEY, b int)")
# Prepare a non-LWT statement
statementNonLWT = self.session.prepare("UPDATE preparedtests.bound_statement_test SET b = ? WHERE a = ?")
# Prepare an LWT statement
statementLWT = self.session.prepare("UPDATE preparedtests.bound_statement_test SET b = ? WHERE a = ? IF b = ?")

boundNonLWT = statementNonLWT.bind((3, 1))
boundLWT = statementLWT.bind((3, 1, 5))

# Check LWT detection
assert not boundNonLWT.is_lwt()
assert boundLWT.is_lwt()

self.session.execute("CREATE TABLE IF NOT EXISTS preparedtests.prepared_statement_test (a int PRIMARY KEY, b int)")
# Prepare a non-LWT statement
statementNonLWT = self.session.prepare("UPDATE preparedtests.prepared_statement_test SET b = ? WHERE a = ?")
# Prepare an LWT statement
statementLWT = self.session.prepare("UPDATE preparedtests.prepared_statement_test SET b = ? WHERE a = ? IF b = ?")
# Check LWT detection
assert not statementNonLWT.is_lwt()
assert statementLWT.is_lwt()

@unittest.skipIf((CASSANDRA_VERSION >= Version('3.11.12') and CASSANDRA_VERSION < Version('4.0')) or \
CASSANDRA_VERSION >= Version('4.0.2'),
"Fixed server-side in Cassandra 3.11.12, 4.0.2")
Expand Down
Loading