From f9cc550d91b919addd358d37ffb27bc7a988753e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roberto=20Santamar=C3=ADa?= Date: Tue, 8 Aug 2023 13:19:03 +0000 Subject: [PATCH 1/7] feat(cluster): support for transactions on cluster-aware client Adds support for transactions based on multi/watch/exec on clusters. Transactions in this mode are limited to a single hash slot. Contributed-by: Scopely --- .gitignore | 1 + CHANGES | 1 + docs/advanced_features.rst | 56 ++- redis/__init__.py | 6 + redis/client.py | 15 +- redis/cluster.py | 475 ++++++++++++++++++++++--- redis/exceptions.py | 18 + tests/test_cluster.py | 34 +- tests/test_cluster_transaction.py | 572 ++++++++++++++++++++++++++++++ tests/test_pipeline.py | 1 - 10 files changed, 1081 insertions(+), 98 deletions(-) create mode 100644 tests/test_cluster_transaction.py diff --git a/.gitignore b/.gitignore index 5f77dcfde4..7184ad4e20 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ vagrant/.vagrant .cache .eggs .idea +.vscode .coverage env venv diff --git a/CHANGES b/CHANGES index 24b52c54db..7f4e08fb31 100644 --- a/CHANGES +++ b/CHANGES @@ -1,3 +1,4 @@ + * Support transactions in ClusterPipeline (originally developed by Scopely and contributed under the MIT License) * Removing support for RedisGraph module. RedisGraph support is deprecated since Redis Stack 7.2 (https://redis.com/blog/redisgraph-eol/) * Fix lock.extend() typedef to accept float TTL extension * Update URL in the readme linking to Redis University diff --git a/docs/advanced_features.rst b/docs/advanced_features.rst index 603e728e84..89ec3fcd43 100644 --- a/docs/advanced_features.rst +++ b/docs/advanced_features.rst @@ -177,20 +177,48 @@ the server. ... pipe.set('foo1', 'bar1').get('foo1').execute() [True, b'bar1'] -Please note: - RedisCluster pipelines currently only support key-based -commands. - The pipeline gets its ‘read_from_replicas’ value from the -cluster’s parameter. Thus, if read from replications is enabled in the -cluster instance, the pipeline will also direct read commands to -replicas. - The ‘transaction’ option is NOT supported in cluster-mode. -In non-cluster mode, the ‘transaction’ option is available when -executing pipelines. This wraps the pipeline commands with MULTI/EXEC -commands, and effectively turns the pipeline commands into a single -transaction block. This means that all commands are executed -sequentially without any interruptions from other clients. However, in -cluster-mode this is not possible, because commands are partitioned -according to their respective destination nodes. This means that we can -not turn the pipeline commands into one transaction block, because in -most cases they are split up into several smaller pipelines. +Please note: + +- RedisCluster pipelines currently only support key-based commands. +- The pipeline gets its ‘read_from_replicas’ value from the + cluster’s parameter. Thus, if reading from replicas is enabled in + the cluster instance, the pipeline will also direct read commands to + replicas. + + +Transactions in clusters +~~~~~~~~~~~~~~~~~~~~~~~~ + +Transactions are supported in cluster-mode with one caveat: all keys of +all commands issued on a transaction pipeline must reside on the +same slot. This is similar to the limitation of multi-key commands in +cluster mode. The reason behind this is that the Redis engine does not offer +a mechanism to block or exchange key data across nodes on the fly.
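+For illustration, the sketch below uses hypothetical key names: because
+both keys share the hash tag ‘{user}’, they are guaranteed to hash to
+the same slot and may therefore take part in a single transaction
+(``rc`` is a connected cluster client):
+
+.. code:: python
+
+    >>> with rc.pipeline(transaction=True) as pipe:
+    ...     pipe.watch('{user}:name', '{user}:visits')
+    ...     visits = pipe.get('{user}:visits')
+    ...     pipe.multi()
+    ...     pipe.set('{user}:visits', int(visits or 0) + 1)
+    ...     pipe.execute()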
A +client may add some logic to abstract engine limitations when running +on a cluster, such as the pipeline behavior explained in the previous +block, but there is no simple way that a client can enforce atomicity +across nodes in a distributed system. + +The compromise of limiting the transaction pipeline to same-slot keys +is exactly that: a compromise. While this behavior is different from +non-transactional cluster pipelines, it simplifies migration of clients +from standalone to cluster under some circumstances. Note that application +code that issues multi/exec commands on a standalone client without +embedding them within a pipeline will eventually raise an ‘AttributeError’. +With this approach, if the application uses ‘client.pipeline(transaction=True)’, +then switching the client to a cluster-aware instance would simplify +code changes (to some extent). This may be true for application code that +makes use of hash tags, since its transactions may already be +mapping all commands to the same slot. + +An alternative is some kind of two-phase commit solution, where a slot +validation is run before the actual commands are run. This could work +with controlled node maintenance but does not cover single node failures. + +Cluster transaction support (pipeline/multi/exec) was originally developed by +Scopely and contributed to redis-py under the MIT License. + + Publish / Subscribe ------------------- diff --git a/redis/__init__.py b/redis/__init__.py index f82a876b2d..14030205e3 100644 --- a/redis/__init__.py +++ b/redis/__init__.py @@ -16,11 +16,14 @@ BusyLoadingError, ChildDeadlockedError, ConnectionError, + CrossSlotTransactionError, DataError, + InvalidPipelineStack, InvalidResponse, OutOfMemoryError, PubSubError, ReadOnlyError, + RedisClusterException, RedisError, ResponseError, TimeoutError, @@ -56,15 +59,18 @@ def int_or_str(value): "ConnectionError", "ConnectionPool", "CredentialProvider", + "CrossSlotTransactionError", "DataError", "from_url", "default_backoff", + "InvalidPipelineStack", "InvalidResponse", "OutOfMemoryError", "PubSubError", "ReadOnlyError", "Redis", "RedisCluster", + "RedisClusterException", "RedisError", "ResponseError", "Sentinel", diff --git a/redis/client.py b/redis/client.py index fda927507a..04be405c53 100755 --- a/redis/client.py +++ b/redis/client.py @@ -31,6 +31,7 @@ ) from redis.connection import ( AbstractConnection, + Connection, ConnectionPool, SSLConnection, UnixDomainSocketConnection, @@ -1279,9 +1280,15 @@ class Pipeline(Redis): UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"} - def __init__(self, connection_pool, response_callbacks, transaction, shard_hint): + def __init__( + self, + connection_pool: ConnectionPool, + response_callbacks, + transaction, + shard_hint, + ): self.connection_pool = connection_pool - self.connection = None + self.connection: Optional[Connection] = None self.response_callbacks = response_callbacks self.transaction = transaction self.shard_hint = shard_hint @@ -1414,7 +1421,9 @@ def pipeline_execute_command(self, *args, **options) -> "Pipeline": self.command_stack.append((args, options)) return self - def _execute_transaction(self, connection, commands, raise_on_error) -> List: + def _execute_transaction( + self, connection: Connection, commands, raise_on_error + ) -> List: cmds = chain([(("MULTI",), {})], commands, [(("EXEC",), {})]) all_cmds = connection.pack_commands( [args for args, options in cmds if EMPTY_RESPONSE not in options] diff --git a/redis/cluster.py b/redis/cluster.py index 39b454babe..e4388885fa
100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -5,16 +5,21 @@ import time from collections import OrderedDict from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from itertools import chain +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union from redis._parsers import CommandsParser, Encoder from redis._parsers.helpers import parse_scan from redis.backoff import default_backoff from redis.cache import CacheConfig, CacheFactory, CacheFactoryInterface, CacheInterface -from redis.client import CaseInsensitiveDict, PubSub, Redis +from redis.client import EMPTY_RESPONSE, CaseInsensitiveDict, PubSub, Redis from redis.commands import READ_COMMANDS, RedisClusterCommands from redis.commands.helpers import list_or_args -from redis.connection import ConnectionPool, parse_url +from redis.connection import ( + Connection, + ConnectionPool, + parse_url, +) from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot from redis.event import ( AfterPooledConnectionsInstantiationEvent, @@ -28,7 +33,10 @@ ClusterDownError, ClusterError, ConnectionError, + CrossSlotTransactionError, DataError, + ExecAbortError, + InvalidPipelineStack, MovedError, RedisClusterException, RedisError, @@ -36,6 +44,7 @@ SlotNotCoveredError, TimeoutError, TryAgainError, + WatchError, ) from redis.lock import Lock from redis.retry import Retry @@ -60,7 +69,7 @@ def get_node_name(host: str, port: Union[str, int]) -> str: reason="Use get_connection(redis_node) instead", version="5.0.3", ) -def get_connection(redis_node, *args, **options): +def get_connection(redis_node: Redis, *args, **options) -> Connection: return redis_node.connection or redis_node.connection_pool.get_connection() @@ -708,7 +717,7 @@ def on_connect(self, connection): if self.user_on_connect_func is not None: self.user_on_connect_func(connection) - def get_redis_connection(self, node): + def get_redis_connection(self, node: "ClusterNode") -> Redis: if not node.redis_connection: with self._lock: if not node.redis_connection: @@ -811,9 +820,6 @@ def pipeline(self, transaction=None, shard_hint=None): if shard_hint: raise RedisClusterException("shard_hint is deprecated in cluster mode") - if transaction: - raise RedisClusterException("transaction is deprecated in cluster mode") - return ClusterPipeline( nodes_manager=self.nodes_manager, commands_parser=self.commands_parser, @@ -825,6 +831,7 @@ def pipeline(self, transaction=None, shard_hint=None): load_balancing_strategy=self.load_balancing_strategy, reinitialize_steps=self.reinitialize_steps, lock=self._lock, + transaction=transaction, ) def lock( @@ -986,7 +993,7 @@ def _get_command_keys(self, *args): redis_conn = self.get_default_node().redis_connection return self.commands_parser.get_keys(redis_conn, *args) - def determine_slot(self, *args): + def determine_slot(self, *args) -> int: """ Figure out what slot to use based on args. @@ -1297,6 +1304,28 @@ def load_external_module(self, funcname, func): """ setattr(self, funcname, func) + def transaction(self, func, *watches, **kwargs): + """ + Convenience method for executing the callable `func` as a transaction + while watching all keys specified in `watches`. The 'func' callable + should expect a single argument which is a Pipeline object. 
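+
+        A sketch of typical usage (``rc`` and the key name are
+        illustrative; the hash tag keeps the transaction on one slot):
+
+            def increment(pipe):
+                current = pipe.get('{user}:count')
+                pipe.multi()
+                pipe.set('{user}:count', int(current or 0) + 1)
+
+            rc.transaction(increment, '{user}:count')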
+ """ + shard_hint = kwargs.pop("shard_hint", None) + value_from_callable = kwargs.pop("value_from_callable", False) + watch_delay = kwargs.pop("watch_delay", None) + with self.pipeline(True, shard_hint) as pipe: + while True: + try: + if watches: + pipe.watch(*watches) + func_value = func(pipe) + exec_value = pipe.execute() + return func_value if value_from_callable else exec_value + except WatchError: + if watch_delay is not None and watch_delay > 0: + time.sleep(watch_delay) + continue + class ClusterNode: def __init__(self, host, port, server_type=None, redis_connection=None): @@ -1390,7 +1419,7 @@ def __init__( event_dispatcher: Optional[EventDispatcher] = None, **kwargs, ): - self.nodes_cache = {} + self.nodes_cache: Dict[str, ClusterNode] = {} self.slots_cache = {} self.startup_nodes = {} self.default_node = None @@ -1490,7 +1519,7 @@ def get_node_from_slot( read_from_replicas=False, load_balancing_strategy=None, server_type=None, - ): + ) -> ClusterNode: """ Gets a node that serves this hash slot """ @@ -1769,6 +1798,16 @@ def remap_host_port(self, host: str, port: int) -> Tuple[str, int]: return self.address_remap((host, port)) return host, port + def find_connection_owner(self, connection: Connection) -> Optional["ClusterNode"]: + node_name = get_node_name(connection.host, connection.port) + for node in self.nodes_cache.values(): + if node.redis_connection: + conn_args = node.redis_connection.connection_pool.connection_kwargs + if node_name == get_node_name( + conn_args.get("host"), conn_args.get("port") + ): + return node + class ClusterPubSub(PubSub): """ @@ -2028,6 +2067,10 @@ class ClusterPipeline(RedisCluster): TryAgainError, ) + NO_SLOTS_COMMANDS = {"UNWATCH"} + IMMEDIATE_EXECUTE_COMMANDS = {"WATCH", "UNWATCH"} + UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"} + def __init__( self, nodes_manager: "NodesManager", @@ -2040,10 +2083,11 @@ def __init__( cluster_error_retry_attempts: int = 3, reinitialize_steps: int = 5, lock=None, + transaction=False, **kwargs, ): """ """ - self.command_stack = [] + self.command_stack: List[PipelineCommand] = [] self.nodes_manager = nodes_manager self.commands_parser = commands_parser self.refresh_table_asap = False @@ -2066,6 +2110,14 @@ def __init__( if lock is None: lock = threading.Lock() self._lock = lock + self.transaction = transaction + self.explicit_transaction = False + self.watching = False + self.transaction_connection: Optional[Connection] = None + self.pipeline_slots: Set[int] = set() + self.slot_migrating = False + self.cluster_error = False + self.executing = False def __repr__(self): """ """ @@ -2093,21 +2145,107 @@ def __bool__(self): "Pipeline instances should always evaluate to True on Python 3+" return True + def _get_client_and_connection_for_transaction(self) -> Tuple[Redis, Connection]: + """ + Find a connection for a pipeline transaction. + + For running an atomic transaction, watched keys are only honored when + the WATCH commands for those keys and the transaction itself are sent + over the same connection. So once we start watching a key, we fetch a + connection to the node that owns that slot and reuse it.
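+
+        If the slot's owner changes later, the pinned connection is released
+        back to its owner's pool and a fresh connection to the new owner is
+        fetched (see the ownership check below).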
+ """ + if not self.pipeline_slots: + raise RedisClusterException( + "At least one command with a key is needed to identify a node" + ) + + node: ClusterNode = self.nodes_manager.get_node_from_slot( + list(self.pipeline_slots)[0], False + ) + redis_node: Redis = self.get_redis_connection(node) + if self.transaction_connection: + if not redis_node.connection_pool.owns_connection( + self.transaction_connection + ): + previous_node = self.nodes_manager.find_connection_owner( + self.transaction_connection + ) + previous_node.connection_pool.release(self.transaction_connection) + self.transaction_connection = None + + if not self.transaction_connection: + self.transaction_connection = get_connection(redis_node, ("INFO",)) + + return redis_node, self.transaction_connection + def execute_command(self, *args, **kwargs): """ Wrapper function for pipeline_execute_command """ - return self.pipeline_execute_command(*args, **kwargs) + slot_number: Optional[int] = None + if args[0] not in self.NO_SLOTS_COMMANDS: + slot_number = self.determine_slot(*args) + + if ( + self.watching or args[0] in self.IMMEDIATE_EXECUTE_COMMANDS + ) and not self.explicit_transaction: + if args[0] == "WATCH": + self._validate_watch() + + if slot_number is not None: + if self.pipeline_slots and slot_number not in self.pipeline_slots: + raise CrossSlotTransactionError( + "Cannot watch or send commands on different slots" + ) + + self.pipeline_slots.add(slot_number) + elif args[0] not in self.NO_SLOTS_COMMANDS: + raise RedisClusterException( + f"Cannot identify slot number for command: {args[0]}, " + "it cannot be triggered in a transaction" + ) + + return self.immediate_execute_command(*args, **kwargs) + else: + if slot_number is not None: + self.pipeline_slots.add(slot_number) + + return self.pipeline_execute_command(*args, **kwargs) def pipeline_execute_command(self, *args, **options): """ - Appends the executed command to the pipeline's command stack + Stage a command to be executed when execute() is next called + + Returns the current Pipeline object back so commands can be + chained together, such as: + + pipe = pipe.set('foo', 'bar').incr('baz').decr('bang') + + At some other point, you can then run: pipe.execute(), + which will execute all commands queued in the pipe.
""" self.command_stack.append( PipelineCommand(args, options, len(self.command_stack)) ) return self + def _get_connection_and_send_command(self, *args, **options): + redis_node, connection = self._get_client_and_connection_for_transaction() + return self._send_command_parse_response( + connection, redis_node, args[0], *args, **options + ) + + def immediate_execute_command(self, *args, **options): + retry = Retry( + default_backoff(), + self.cluster_error_retry_attempts, + ) + retry.update_supported_errors([AskError, MovedError]) + return retry.call_with_retry( + lambda: self._get_connection_and_send_command(*args, **options), + self._reinitialize_on_error, + ) + def raise_first_error(self, stack): """ Raise the first exception on the stack @@ -2118,6 +2256,15 @@ def raise_first_error(self, stack): self.annotate_exception(r, c.position + 1, c.args) raise r + def raise_first_transaction_error(self, responses, stack): + """ + Raise the first exception on the stack + """ + for r, cmd in zip(responses, stack): + if isinstance(r, Exception): + self.annotate_exception(r, cmd.position + 1, cmd.args) + raise r + def annotate_exception(self, exception, number, command): """ Provides extra context to the exception prior to it being handled @@ -2134,8 +2281,14 @@ def execute(self, raise_on_error: bool = True) -> List[Any]: Execute all the commands in the current pipeline """ stack = self.command_stack + if not stack and (not self.watching or not self.pipeline_slots): + return [] + try: - return self.send_cluster_commands(stack, raise_on_error) + if self.transaction or self.explicit_transaction: + return self._execute_transaction_with_retries(stack, raise_on_error) + else: + return self.send_cluster_commands(stack, raise_on_error) finally: self.reset() @@ -2147,29 +2300,35 @@ def reset(self): self.scripts = set() - # TODO: Implement # make sure to reset the connection state in the event that we were # watching something - # if self.watching and self.connection: - # try: - # # call this manually since our unwatch or - # # immediate_execute_command methods can call reset() - # self.connection.send_command('UNWATCH') - # self.connection.read_response() - # except ConnectionError: - # # disconnect will also remove any previous WATCHes - # self.connection.disconnect() + if self.transaction_connection: + try: + # call this manually since our unwatch or + # immediate_execute_command methods can call reset() + self.transaction_connection.send_command("UNWATCH") + self.transaction_connection.read_response() + # we can safely return the connection to the pool here since we're + # sure we're no longer WATCHing anything + node = self.nodes_manager.find_connection_owner( + self.transaction_connection + ) + node.redis_connection.connection_pool.release( + self.transaction_connection + ) + self.transaction_connection = None + except ConnectionError: + # disconnect will also remove any previous WATCHes + if self.transaction_connection: + self.transaction_connection.disconnect() # clean up the other instance attributes self.watching = False self.explicit_transaction = False - - # TODO: Implement - # we can safely return the connection to the pool here since we're - # sure we're no longer WATCHing anything - # if self.connection: - # self.connection_pool.release(self.connection) - # self.connection = None + self.pipeline_slots = set() + self.slot_migrating = False + self.cluster_error = False + self.executing = False def send_cluster_commands( self, stack, raise_on_error=True, allow_redirections=True @@ -2397,30 
+2556,242 @@ def eval(self): raise RedisClusterException("method eval() is not implemented") def multi(self): - """ """ - raise RedisClusterException("method multi() is not implemented") + """ + Start a transactional block of the pipeline after WATCH commands + are issued. End the transactional block with `execute`. + """ - def immediate_execute_command(self, *args, **options): - """ """ - raise RedisClusterException( - "method immediate_execute_command() is not implemented" + # Cluster transaction support (pipeline/multi/exec) originally developed + # by Scopely and contributed to redis-py under the MIT License. + + if self.explicit_transaction: + raise RedisError("Cannot issue nested calls to MULTI") + if self.command_stack: + raise RedisError( + "Commands without an initial WATCH have already been issued" + ) + self.explicit_transaction = True + + def _send_command_parse_response( + self, conn, redis_node: Redis, command_name, *args, **options + ): + """ + Send a command and parse the response + """ + + self.slot_migrating = False + try: + conn.send_command(*args) + output = redis_node.parse_response(conn, command_name, **options) + + except (AskError, MovedError) as slot_error: + self.slot_migrating = True + raise slot_error + + if command_name in self.UNWATCH_COMMANDS: + self.watching = False + return output + + def _disconnect_reset_raise(self, conn, error): + """ + Close the connection, reset watching state and + raise an exception if we were watching, + retry_on_timeout is not set, + or the error is not a TimeoutError + """ + if not conn: + conn = self.transaction_connection + + if conn: + conn.disconnect() + + # if we were already watching a variable, the watch is no longer + # valid since this connection has died. raise a WatchError, which + # indicates the user should retry this transaction. 
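+        # (The retry cannot happen transparently here: the WATCH state
+        # lived on the connection that just died, so only the caller
+        # knows whether re-reading the watched keys is safe.)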
+ if self.watching: + self.reset() + raise WatchError( + "A ConnectionError occurred while watching one or more keys" + ) + # if retry_on_timeout is not set, or the error is not + # a TimeoutError, raise it + if not (conn and conn.retry_on_timeout and isinstance(error, TimeoutError)): + self.reset() + raise + + def _reinitialize_on_error(self, error): + if self.watching: + if self.slot_migrating and self.executing: + raise WatchError("Slot rebalancing occurred while watching keys") + if self.cluster_error: + raise RedisClusterException("Cluster error occurred while watching keys") + + if self.slot_migrating or self.cluster_error: + if self.transaction_connection: + self.transaction_connection = None + + self.reinitialize_counter += 1 + if self._should_reinitialized(): + self.nodes_manager.initialize() + self.reinitialize_counter = 0 + else: + self.nodes_manager.update_moved_exception(error) + + self.slot_migrating = False + self.cluster_error = False + self.executing = False + + def _execute_transaction_with_retries( + self, stack: List["PipelineCommand"], raise_on_error: bool + ): + retry = Retry( + default_backoff(), + self.cluster_error_retry_attempts, + ) + retry.update_supported_errors([AskError, MovedError]) + return retry.call_with_retry( + lambda: self._execute_transaction(stack, raise_on_error), + self._reinitialize_on_error, ) - def _execute_transaction(self, *args, **kwargs): - """ """ - raise RedisClusterException("method _execute_transaction() is not implemented") + def _execute_transaction( + self, stack: List["PipelineCommand"], raise_on_error: bool + ): + if len(self.pipeline_slots) > 1: + raise CrossSlotTransactionError( + "All keys involved in a cluster transaction must map to the same slot" + ) + + self.executing = True + self.slot_migrating = False + self.cluster_error = False + + redis_node, connection = self._get_client_and_connection_for_transaction() + + stack = chain( + [PipelineCommand(("MULTI",))], + stack, + [PipelineCommand(("EXEC",))], + ) + commands = [c.args for c in stack if EMPTY_RESPONSE not in c.options] + packed_commands = connection.pack_commands(commands) + connection.send_packed_command(packed_commands) + errors = [] + + # parse off the response for MULTI + # NOTE: we need to handle ResponseErrors here and continue + # so that we read all the additional command messages from + # the socket + try: + redis_node.parse_response(connection, "MULTI") + except ResponseError as e: + self.annotate_exception(e, 0, "MULTI") + errors.append(e) + + # and all the other commands + for i, command in enumerate(self.command_stack): + if EMPTY_RESPONSE in command.options: + errors.append((i, command.options[EMPTY_RESPONSE])) + else: + try: + _ = redis_node.parse_response(connection, "_") + except (AskError, MovedError) as slot_error: + self.slot_migrating = True + self.annotate_exception(slot_error, i + 1, command.args) + errors.append(slot_error) + except (ClusterDownError, ConnectionError) as cluster_error: + self.cluster_error = True + self.annotate_exception(cluster_error, i + 1, command.args) + raise + except ResponseError as e: + self.annotate_exception(e, i + 1, command.args) + errors.append(e) + + response = None + # parse the EXEC.
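+        # EXEC replies with an array of results on success, with a nil
+        # (parsed as None) when a watched key changed, and parsing raises
+        # ExecAbortError when one of the queued commands was rejected.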
+ try: + response = redis_node.parse_response(connection, "EXEC") + except ExecAbortError: + if errors: + raise errors[0] + raise + + self.executing = False + + # EXEC clears any watched keys + self.watching = False + + if response is None: + raise WatchError("Watched variable changed.") + + # put any parse errors into the response + for i, e in errors: + response.insert(i, e) + + if len(response) != len(self.command_stack): + raise InvalidPipelineStack( + "Unexpected response length for cluster pipeline EXEC." + " Command stack was {} but response had length {}".format( + [c.args[0] for c in self.command_stack], len(response) + ) + ) + + # find any errors in the response and raise if necessary + if raise_on_error or self.slot_migrating: + self.raise_first_transaction_error( + response, + self.command_stack, + ) + + # We have to run response callbacks manually + data = [] + for r, cmd in zip(response, self.command_stack): + if not isinstance(r, Exception): + command_name = cmd.args[0] + if command_name in self.cluster_response_callbacks: + r = self.cluster_response_callbacks[command_name](r, **cmd.options) + data.append(r) + return data def load_scripts(self): """ """ raise RedisClusterException("method load_scripts() is not implemented") + def discard(self): + if self.transaction or self.explicit_transaction: + self.reset() + return + + raise RedisClusterException("DISCARD triggered without MULTI") + + def _validate_watch(self): + if self.explicit_transaction: + raise RedisError("Cannot issue a WATCH after a MULTI") + + self.watching = True + def watch(self, *names): - """ """ - raise RedisClusterException("method watch() is not implemented") + """Watches the values at keys ``names``""" + + # Cluster transaction support (pipeline/multi/exec) originally developed + # by Scopely and contributed to redis-py under the MIT License. + + if self.explicit_transaction: + raise RedisError("Cannot issue a WATCH after a MULTI") + + return self.execute_command("WATCH", *names) def unwatch(self): - """ """ - raise RedisClusterException("method unwatch() is not implemented") + """Unwatches all previously specified keys""" + + # Cluster transaction support (pipeline/multi/exec) originally developed + # by Scopely and contributed to redis-py under the MIT License. + + if self.watching: + return self.execute_command("UNWATCH") + + return True def script_load_for_pipeline(self, *args, **kwargs): """ """ @@ -2432,23 +2803,13 @@ def delete(self, *names): """ "Delete one or more keys specified by ``names``" """ - if len(names) != 1: - raise RedisClusterException( - "deleting multiple keys is not implemented in pipeline command" - ) - - return self.execute_command("DEL", names[0]) + return self.execute_command("DEL", *names) def unlink(self, *names): """ "Unlink one or more keys specified by ``names``" """ - if len(names) != 1: - raise RedisClusterException( - "unlinking multiple keys is not implemented in pipeline command" - ) - - return self.execute_command("UNLINK", names[0]) + return self.execute_command("UNLINK", *names) def block_pipeline_command(name: str) -> Callable[..., Any]: diff --git a/redis/exceptions.py b/redis/exceptions.py index bad447a086..a00ac65ac1 100644 --- a/redis/exceptions.py +++ b/redis/exceptions.py @@ -221,3 +221,21 @@ class SlotNotCoveredError(RedisClusterException): class MaxConnectionsError(ConnectionError): ...
+ + +class CrossSlotTransactionError(RedisClusterException): + """ + Raised when a transaction or watch is triggered in a pipeline + and the keys involved do not all belong to the same slot. + """ + + pass + + +class InvalidPipelineStack(RedisClusterException): + """ + Raised when a pipeline response has an unexpected length. This is + most likely a command-stack handling error. + """ + + pass diff --git a/tests/test_cluster.py b/tests/test_cluster.py index d96342f87a..622a42d7bc 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -3009,24 +3009,10 @@ def test_blocked_methods(self, r): They maybe implemented in the future. """ pipe = r.pipeline() - with pytest.raises(RedisClusterException): - pipe.multi() - - with pytest.raises(RedisClusterException): - pipe.immediate_execute_command() - - with pytest.raises(RedisClusterException): - pipe._execute_transaction(None, None, None) with pytest.raises(RedisClusterException): pipe.load_scripts() - with pytest.raises(RedisClusterException): - pipe.watch() - - with pytest.raises(RedisClusterException): - pipe.unwatch() - with pytest.raises(RedisClusterException): pipe.script_load_for_pipeline(None) @@ -3038,14 +3024,6 @@ def test_blocked_arguments(self, r): Currently some arguments is blocked when using in cluster mode. They maybe implemented in the future. """ - with pytest.raises(RedisClusterException) as ex: - r.pipeline(transaction=True) - - assert ( - str(ex.value).startswith("transaction is deprecated in cluster mode") - is True - ) - with pytest.raises(RedisClusterException) as ex: r.pipeline(shard_hint=True) @@ -3103,7 +3081,7 @@ def test_delete_single(self, r): pipe.delete("a") assert pipe.execute() == [1] - def test_multi_delete_unsupported(self, r): + def test_multi_delete_unsupported_cross_slot(self, r): """ Test that multi-key delete is unsupported when keys map to different slots """ @@ -3113,6 +3091,16 @@ def test_multi_delete_unsupported(self, r): with pytest.raises(RedisClusterException): pipe.delete("a", "b") + def test_multi_delete_supported_single_slot(self, r): + """ + Test that multi-key delete is supported when all keys map to the same slot + """ + with r.pipeline(transaction=False) as pipe: + r["{key}:a"] = 1 + r["{key}:b"] = 2 + pipe.delete("{key}:a", "{key}:b") + assert pipe.execute() + def test_unlink_single(self, r): """ Test a single unlink operation diff --git a/tests/test_cluster_transaction.py b/tests/test_cluster_transaction.py new file mode 100644 index 0000000000..f2bb209131 --- /dev/null +++ b/tests/test_cluster_transaction.py @@ -0,0 +1,572 @@ +from typing import Tuple +from unittest.mock import patch + +import pytest + +import redis +from redis.client import Redis +from redis.cluster import PRIMARY, ClusterNode, NodesManager, RedisCluster + +from .conftest import skip_if_server_version_lt, wait_for_command + + +# Cluster transaction support (pipeline/multi/exec) originally developed +# by Scopely and contributed to redis-py under the MIT License. + + +def _find_source_and_target_node_for_slot( + r: RedisCluster, slot: int +) -> Tuple[ClusterNode, ClusterNode]: + """Returns a pair of ClusterNodes, where the first node is the + one that owns the slot and the second is a possible target + for that slot, i.e. a primary node different from the first + one.
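+    The pair is used by the tests below to simulate ASK/MOVED
+    redirections from the node that owns the slot to the importing one.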
+ """ + node_migrating = r.nodes_manager.get_node_from_slot(slot) + assert node_migrating, f"No node could be found that owns slot #{slot}" + + available_targets = [ + n + for n in r.nodes_manager.startup_nodes.values() + if node_migrating.name != n.name and n.server_type == PRIMARY + ] + + assert available_targets, f"No possible target nodes for slot #{slot}" + return node_migrating, available_targets[0] + + +class TestClusterTransaction: + @pytest.mark.onlycluster + def test_pipeline_is_true(self, r): + "Ensure pipeline instances are not false-y" + with r.pipeline(transaction=True) as pipe: + assert pipe + + @pytest.mark.onlycluster + def test_pipeline_no_transaction_watch(self, r): + r["a"] = 0 + + with r.pipeline(transaction=False) as pipe: + pipe.watch("a") + a = pipe.get("a") + pipe.multi() + pipe.set("a", int(a) + 1) + assert pipe.execute() == [b"OK"] + + @pytest.mark.onlycluster + def test_pipeline_no_transaction_watch_failure(self, r): + r["a"] = 0 + + with r.pipeline(transaction=False) as pipe: + pipe.watch("a") + a = pipe.get("a") + + r["a"] = "bad" + + pipe.multi() + pipe.set("a", int(a) + 1) + + with pytest.raises(redis.WatchError): + pipe.execute() + + assert r["a"] == b"bad" + + @pytest.mark.onlycluster + def test_pipeline_empty_transaction(self, r): + r["a"] = 0 + + with r.pipeline(transaction=True) as pipe: + assert pipe.execute() == [] + + @pytest.mark.onlycluster + def test_exec_error_in_response(self, r): + """ + an invalid pipeline command at exec time adds the exception instance + to the list of returned values + """ + hashkey = "{key}" + r[f"{hashkey}:c"] = "a" + with r.pipeline() as pipe: + pipe.set(f"{hashkey}:a", 1).set(f"{hashkey}:b", 2) + pipe.lpush(f"{hashkey}:c", 3).set(f"{hashkey}:d", 4) + result = pipe.execute(raise_on_error=False) + + assert result[0] + assert r[f"{hashkey}:a"] == b"1" + assert result[1] + assert r[f"{hashkey}:b"] == b"2" + + # we can't lpush to a key that's a string value, so this should + # be a ResponseError exception + assert isinstance(result[2], redis.ResponseError) + assert r[f"{hashkey}:c"] == b"a" + + # since this isn't a transaction, the other commands after the + # error are still executed + assert result[3] + assert r[f"{hashkey}:d"] == b"4" + + # make sure the pipe was restored to a working state + assert pipe.set(f"{hashkey}:z", "zzz").execute() == [True] + assert r[f"{hashkey}:z"] == b"zzz" + + @pytest.mark.onlycluster + def test_exec_error_raised(self, r): + hashkey = "{key}" + r[f"{hashkey}:c"] = "a" + with r.pipeline(transaction=True) as pipe: + pipe.set(f"{hashkey}:a", 1).set(f"{hashkey}:b", 2) + pipe.lpush(f"{hashkey}:c", 3).set(f"{hashkey}:d", 4) + with pytest.raises(redis.ResponseError) as ex: + pipe.execute() + assert str(ex.value).startswith( + "Command # 3 (LPUSH {key}:c 3) of pipeline caused error: " + ) + + # make sure the pipe was restored to a working state + assert pipe.set(f"{hashkey}:z", "zzz").execute() == [b"OK"] + assert r[f"{hashkey}:z"] == b"zzz" + + @pytest.mark.onlycluster + def test_parse_error_raised(self, r): + hashkey = "{key}" + with r.pipeline(transaction=True) as pipe: + # the zrem is invalid because we don't pass any keys to it + pipe.set(f"{hashkey}:a", 1).zrem(f"{hashkey}:b").set(f"{hashkey}:b", 2) + with pytest.raises(redis.ResponseError) as ex: + pipe.execute() + + assert str(ex.value).startswith( + "Command # 2 (ZREM {key}:b) of pipeline caused error: wrong number" + ) + + # make sure the pipe was restored to a working state + assert pipe.set(f"{hashkey}:z", "zzz").execute() == [b"OK"] + 
assert r[f"{hashkey}:z"] == b"zzz" + + @pytest.mark.onlycluster + def test_parse_error_raised_transaction(self, r): + hashkey = "{key}" + with r.pipeline() as pipe: + pipe.multi() + # the zrem is invalid because we don't pass any keys to it + pipe.set(f"{hashkey}:a", 1).zrem(f"{hashkey}:b").set(f"{hashkey}:b", 2) + with pytest.raises(redis.ResponseError) as ex: + pipe.execute() + + assert str(ex.value).startswith( + "Command # 2 (ZREM {key}:b) of pipeline caused error: " + ) + + # make sure the pipe was restored to a working state + assert pipe.set(f"{hashkey}:z", "zzz").execute() == [True] + assert r[f"{hashkey}:z"] == b"zzz" + + @pytest.mark.onlycluster + def test_parse_error_raised_invalid_response_length_transaction(self, r): + hashkey = "{key}" + with r.pipeline() as pipe: + pipe.multi() + pipe.set(f"{hashkey}:a", 1).set(f"{hashkey}:b", 1) + with patch("redis.client.Redis.parse_response") as parse_response_mock: + parse_response_mock.return_value = ["OK"] + with pytest.raises(redis.InvalidPipelineStack) as ex: + pipe.execute() + + assert str(ex.value).startswith( + "Unexpected response length for cluster pipeline EXEC" + ) + + # make sure the pipe was restored to a working state + assert pipe.set(f"{hashkey}:z", "zzz").execute() == [True] + assert r[f"{hashkey}:z"] == b"zzz" + + @pytest.mark.onlycluster + def test_watch_succeed(self, r): + hashkey = "{key}" + r[f"{hashkey}:a"] = 1 + r[f"{hashkey}:b"] = 2 + + with r.pipeline() as pipe: + pipe.watch(f"{hashkey}:a", f"{hashkey}:b") + assert pipe.watching + a_value = pipe.get(f"{hashkey}:a") + b_value = pipe.get(f"{hashkey}:b") + assert a_value == b"1" + assert b_value == b"2" + pipe.multi() + + pipe.set(f"{hashkey}:c", 3) + assert pipe.execute() == [b"OK"] + assert not pipe.watching + + @pytest.mark.onlycluster + def test_watch_failure(self, r): + hashkey = "{key}" + r[f"{hashkey}:a"] = 1 + r[f"{hashkey}:b"] = 2 + + with r.pipeline() as pipe: + pipe.watch(f"{hashkey}:a", f"{hashkey}:b") + r[f"{hashkey}:b"] = 3 + pipe.multi() + pipe.get(f"{hashkey}:a") + with pytest.raises(redis.WatchError): + pipe.execute() + + assert not pipe.watching + + @pytest.mark.onlycluster + def test_cross_slot_watch_single_call_failure(self, r): + with r.pipeline() as pipe: + with pytest.raises(redis.RedisClusterException) as ex: + pipe.watch("a", "b") + + assert str(ex.value).startswith( + "WATCH - all keys must map to the same key slot" + ) + + assert not pipe.watching + + @pytest.mark.onlycluster + def test_cross_slot_watch_multiple_calls_failure(self, r): + with r.pipeline() as pipe: + with pytest.raises(redis.CrossSlotTransactionError) as ex: + pipe.watch("a") + pipe.watch("b") + + assert str(ex.value).startswith( + "Cannot watch or send commands on different slots" + ) + + assert pipe.watching + + @pytest.mark.onlycluster + def test_watch_failure_in_empty_transaction(self, r): + hashkey = "{key}" + r[f"{hashkey}:a"] = 1 + r[f"{hashkey}:b"] = 2 + + with r.pipeline() as pipe: + pipe.watch(f"{hashkey}:a", f"{hashkey}:b") + r[f"{hashkey}:b"] = 3 + pipe.multi() + with pytest.raises(redis.WatchError): + pipe.execute() + + assert not pipe.watching + + @pytest.mark.onlycluster + def test_unwatch(self, r): + hashkey = "{key}" + r[f"{hashkey}:a"] = 1 + r[f"{hashkey}:b"] = 2 + + with r.pipeline() as pipe: + pipe.watch(f"{hashkey}:a", f"{hashkey}:b") + r[f"{hashkey}:b"] = 3 + pipe.unwatch() + assert not pipe.watching + pipe.get(f"{hashkey}:a") + assert pipe.execute() == [b"1"] + + @pytest.mark.onlycluster + def test_watch_exec_auto_unwatch(self, r): + hashkey = 
"{key}" + r[f"{hashkey}:a"] = 1 + r[f"{hashkey}:b"] = 2 + + target_slot = r.determine_slot("GET", f"{hashkey}:a") + target_node = r.nodes_manager.get_node_from_slot(target_slot) + with r.monitor(target_node=target_node) as m: + with r.pipeline() as pipe: + pipe.watch(f"{hashkey}:a", f"{hashkey}:b") + assert pipe.watching + a_value = pipe.get(f"{hashkey}:a") + b_value = pipe.get(f"{hashkey}:b") + assert a_value == b"1" + assert b_value == b"2" + pipe.multi() + pipe.set(f"{hashkey}:c", 3) + assert pipe.execute() == [b"OK"] + assert not pipe.watching + + unwatch_command = wait_for_command( + r, m, "UNWATCH", key=f"{hashkey}:test_watch_exec_auto_unwatch" + ) + assert unwatch_command is not None, ( + "execute should reset and send UNWATCH automatically" + ) + + @pytest.mark.onlycluster + def test_watch_reset_unwatch(self, r): + hashkey = "{key}" + r[f"{hashkey}:a"] = 1 + + target_slot = r.determine_slot("GET", f"{hashkey}:a") + target_node = r.nodes_manager.get_node_from_slot(target_slot) + with r.monitor(target_node=target_node) as m: + with r.pipeline() as pipe: + pipe.watch(f"{hashkey}:a") + assert pipe.watching + pipe.reset() + assert not pipe.watching + + unwatch_command = wait_for_command( + r, m, "UNWATCH", key=f"{hashkey}:test_watch_reset_unwatch" + ) + assert unwatch_command is not None + assert unwatch_command["command"] == "UNWATCH" + + @pytest.mark.onlycluster + def test_transaction_callable(self, r): + hashkey = "{key}" + r[f"{hashkey}:a"] = 1 + r[f"{hashkey}:b"] = 2 + has_run = [] + + def my_transaction(pipe): + a_value = pipe.get(f"{hashkey}:a") + assert a_value in (b"1", b"2") + b_value = pipe.get(f"{hashkey}:b") + assert b_value == b"2" + + # silly run-once code... incr's "a" so WatchError should be raised + # forcing this all to run again. 
this should incr "a" once to "2" + if not has_run: + r.incr(f"{hashkey}:a") + has_run.append("it has") + + pipe.multi() + pipe.set(f"{hashkey}:c", int(a_value) + int(b_value)) + + result = r.transaction(my_transaction, f"{hashkey}:a", f"{hashkey}:b") + assert result == [b"OK"] + assert r[f"{hashkey}:c"] == b"4" + + def test_exec_error_in_no_transaction_pipeline(self, r): + r["a"] = 1 + with r.pipeline(transaction=False) as pipe: + pipe.llen("a") + pipe.expire("a", 100) + + with pytest.raises(redis.ResponseError) as ex: + pipe.execute() + + assert str(ex.value).startswith( + "Command # 1 (LLEN a) of pipeline caused error: " + ) + + assert r["a"] == b"1" + + @pytest.mark.onlycluster + @skip_if_server_version_lt("2.0.0") + def test_pipeline_discard(self, r): + hashkey = "{key}" + + # empty pipeline should raise an error + with r.pipeline() as pipe: + pipe.set(f"{hashkey}:key", "someval") + with pytest.raises(redis.exceptions.RedisClusterException) as ex: + pipe.discard() + + assert str(ex.value).startswith("DISCARD triggered without MULTI") + + # setting a pipeline and discarding should do the same + with r.pipeline() as pipe: + pipe.set(f"{hashkey}:key", "someval") + pipe.set(f"{hashkey}:someotherkey", "val") + response = pipe.execute() + pipe.set(f"{hashkey}:key", "another value!") + with pytest.raises(redis.exceptions.RedisClusterException) as ex: + pipe.discard() + + assert str(ex.value).startswith("DISCARD triggered without MULTI") + + pipe.set(f"{hashkey}:foo", "bar") + response = pipe.execute() + + assert response[0] + assert r.get(f"{hashkey}:foo") == b"bar" + + @pytest.mark.onlycluster + @skip_if_server_version_lt("2.0.0") + def test_transaction_discard(self, r): + hashkey = "{key}" + + # pipelines enabled as transactions can be discarded at any point + with r.pipeline(transaction=True) as pipe: + pipe.watch(f"{hashkey}:key") + pipe.set(f"{hashkey}:key", "someval") + pipe.discard() + + assert not pipe.watching + assert not pipe.command_stack + + # pipelines with multi can be discarded + with r.pipeline() as pipe: + pipe.watch(f"{hashkey}:key") + pipe.multi() + pipe.set(f"{hashkey}:key", "someval") + pipe.discard() + + assert not pipe.watching + assert not pipe.command_stack + + @pytest.mark.onlycluster + def test_retry_transaction_during_unfinished_slot_migration(self, r): + """ + When a transaction is triggered during a migration, MovedError + or AskError may appear (depending on whether the key has already + been migrated or does not exist yet). The patch on parse_response + simulates such an error, but the slot cache is not updated + (meaning the migration is still ongoing), so the pipeline is + retried and eventually fails because the migration never completes.
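+        The final AskError is annotated with the position of the failing
+        command, which is what the assertion below checks.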
+ """ + key = "book" + slot = r.keyslot(key) + node_migrating, node_importing = _find_source_and_target_node_for_slot(r, slot) + + with patch.object(Redis, "parse_response") as parse_response, patch.object( + NodesManager, "_update_moved_slots" + ) as manager_update_moved_slots: + + def ask_redirect_effect(connection, *args, **options): + if "MULTI" in args: + return + elif "EXEC" in args: + raise redis.exceptions.ExecAbortError() + + raise redis.exceptions.AskError(f"{slot} {node_importing.name}") + + parse_response.side_effect = ask_redirect_effect + + with r.pipeline(transaction=True) as pipe: + pipe.multi() + pipe.set(key, "val") + with pytest.raises(redis.exceptions.AskError) as ex: + pipe.execute() + + assert str(ex.value).startswith( + "Command # 1 (SET book val) of pipeline caused error:" + f" {slot} {node_importing.name}" + ) + + manager_update_moved_slots.assert_called() + + @pytest.mark.onlycluster + def test_retry_transaction_during_slot_migration_successful(self, r): + """ + If a MovedError or AskError appears when calling EXEC and no key is watched, + the pipeline is retried after updating the node manager slot table. If the + migration was completed, the transaction may then complete successfully. + """ + key = "book" + slot = r.keyslot(key) + node_migrating, node_importing = _find_source_and_target_node_for_slot(r, slot) + + with patch.object(Redis, "parse_response") as parse_response, patch.object( + NodesManager, "_update_moved_slots" + ) as manager_update_moved_slots: + + def ask_redirect_effect(conn, *args, **options): + # first call should go here, we trigger an AskError + if f"{conn.host}:{conn.port}" == node_migrating.name: + if "MULTI" in args: + return + elif "EXEC" in args: + raise redis.exceptions.ExecAbortError() + + raise redis.exceptions.AskError(f"{slot} {node_importing.name}") + # if the slot table is updated, the next call will go here + elif f"{conn.host}:{conn.port}" == node_importing.name: + if "EXEC" in args: + return [ + "MOCK_OK" + ] # mock value to validate this section was called + return + else: + assert False, f"unexpected node {conn.host}:{conn.port} was called" + + def update_moved_slot(): # simulate slot table update + ask_error = r.nodes_manager._moved_exception + assert ask_error is not None, "No AskError was previously triggered" + assert f"{ask_error.host}:{ask_error.port}" == node_importing.name + r.nodes_manager._moved_exception = None + r.nodes_manager.slots_cache[slot] = [node_importing] + + parse_response.side_effect = ask_redirect_effect + manager_update_moved_slots.side_effect = update_moved_slot + + result = None + with r.pipeline(transaction=True) as pipe: + pipe.multi() + pipe.set(key, "val") + result = pipe.execute() + + assert result and "MOCK_OK" in result, "Target node was not called" + + @pytest.mark.onlycluster + def test_retry_transaction_with_watch_during_slot_migration(self, r): + """ + If a MovedError or AskError appears when calling EXEC and keys were + being watched before the migration started, a WatchError should appear. + These errors imply resetting the connection and connecting to a new node, + so watches are lost anyway and the client code must be notified. 
+ """ + key = "book" + slot = r.keyslot(key) + node_migrating, node_importing = _find_source_and_target_node_for_slot(r, slot) + + with patch.object(Redis, "parse_response") as parse_response: + + def ask_redirect_effect(conn, *args, **options): + if f"{conn.host}:{conn.port}" == node_migrating.name: + # we simulate the watch was sent before the migration started + if "WATCH" in args: + return b"OK" + # but the pipeline was triggered after the migration started + elif "MULTI" in args: + return + elif "EXEC" in args: + raise redis.exceptions.ExecAbortError() + + raise redis.exceptions.AskError(f"{slot} {node_importing.name}") + # we should not try to connect to any other node + else: + assert False, f"unexpected node {conn.host}:{conn.port} was called" + + parse_response.side_effect = ask_redirect_effect + + with r.pipeline(transaction=True) as pipe: + pipe.watch(key) + pipe.multi() + pipe.set(key, "val") + with pytest.raises(redis.exceptions.WatchError) as ex: + pipe.execute() + + assert str(ex.value).startswith( + "Slot rebalancing occurred while watching keys" + ) + + @pytest.mark.onlycluster + def test_retry_transaction_with_watch_after_slot_migration(self, r): + """ + If a MovedError or AskError appears when calling WATCH, the client + must attempt to recover itself before proceeding and no WatchError + should appear. + """ + key = "book" + slot = r.keyslot(key) + r.reinitialize_steps = 1 + + # force a MovedError on the first call to pipe.watch() + # by switching the node that owns the slot to another one + _node_migrating, node_importing = _find_source_and_target_node_for_slot(r, slot) + r.nodes_manager.slots_cache[slot] = [node_importing] + + with r.pipeline(transaction=True) as pipe: + pipe.watch(key) + pipe.multi() + pipe.set(key, "val") + pipe.execute() diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index bbf1ec9eb5..2e9b21ada4 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -330,7 +330,6 @@ def my_transaction(pipe): assert result == [True] assert r["c"] == b"4" - @pytest.mark.onlynoncluster def test_transaction_callable_returns_value_from_callable(self, r): def callback(pipe): # No need to do anything here since we only want the return value From 9244b6764e5b01a47a3bd960dd0a80beeb1ce852 Mon Sep 17 00:00:00 2001 From: Roberto Santamaria Date: Fri, 25 Apr 2025 10:30:26 +0200 Subject: [PATCH 2/7] fix: remove deprecated argument --- redis/cluster.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/cluster.py b/redis/cluster.py index e4388885fa..7b478022f8 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -2174,7 +2174,7 @@ def _get_client_and_connection_for_transaction(self) -> Tuple[Redis, Connection] self.transaction_connection = None if not self.transaction_connection: - self.transaction_connection = get_connection(redis_node, ("INFO",)) + self.transaction_connection = get_connection(redis_node) return redis_node, self.transaction_connection From ea3fc59b7a4a74dfffe9ecc6a92ae91df2648f4a Mon Sep 17 00:00:00 2001 From: Roberto Santamaria Date: Fri, 25 Apr 2025 14:35:03 +0200 Subject: [PATCH 3/7] remove attributions from code --- CHANGES | 2 +- docs/advanced_features.rst | 3 --- redis/cluster.py | 9 --------- tests/test_cluster_transaction.py | 4 ---- 4 files changed, 1 insertion(+), 17 deletions(-) diff --git a/CHANGES b/CHANGES index 7f4e08fb31..33a82bd4f1 100644 --- a/CHANGES +++ b/CHANGES @@ -1,4 +1,4 @@ - * Support transactions in ClusterPipeline (originally developed by Scopely and contributed under the MIT License)
+ * Support transactions in ClusterPipeline * Removing support for RedisGraph module. RedisGraph support is deprecated since Redis Stack 7.2 (https://redis.com/blog/redisgraph-eol/) * Fix lock.extend() typedef to accept float TTL extension * Update URL in the readme linking to Redis University diff --git a/docs/advanced_features.rst b/docs/advanced_features.rst index 89ec3fcd43..5e8542fdcb 100644 --- a/docs/advanced_features.rst +++ b/docs/advanced_features.rst @@ -215,9 +215,6 @@ An alternative is some kind of two-step commit solution, where a slot validation is run before the actual commands are run. This could work with controlled node maintenance but does not cover single node failures. -Cluster transaction support (pipeline/multi/exec) was originally developed by -Scopely and contributed to redis-py under the MIT License. - Publish / Subscribe diff --git a/redis/cluster.py b/redis/cluster.py index 7b478022f8..f87be95494 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -2561,9 +2561,6 @@ def multi(self): are issued. End the transactional block with `execute`. """ - # Cluster transaction support (pipeline/multi/exec) originally developed - # by Scopely and contributed to redis-py under the MIT License. - if self.explicit_transaction: raise RedisError("Cannot issue nested calls to MULTI") if self.command_stack: @@ -2774,9 +2771,6 @@ def _validate_watch(self): def watch(self, *names): """Watches the values at keys ``names``""" - # Cluster transaction support (pipeline/multi/exec) originally developed - # by Scopely and contributed to redis-py under the MIT License. - if self.explicit_transaction: raise RedisError("Cannot issue a WATCH after a MULTI") @@ -2785,9 +2779,6 @@ def watch(self, *names): def unwatch(self): """Unwatches all previously specified keys""" - # Cluster transaction support (pipeline/multi/exec) originally developed - # by Scopely and contributed to redis-py under the MIT License. - if self.watching: return self.execute_command("UNWATCH") diff --git a/tests/test_cluster_transaction.py b/tests/test_cluster_transaction.py index f2bb209131..c70d2ac9bb 100644 --- a/tests/test_cluster_transaction.py +++ b/tests/test_cluster_transaction.py @@ -10,10 +10,6 @@ from .conftest import skip_if_server_version_lt, wait_for_command -# Cluster transaction support (pipeline/multi/exec) originally developed -# by Scopely and contributed to redis-py under the MIT License. 
- - - def _find_source_and_target_node_for_slot( r: RedisCluster, slot: int ) -> Tuple[ClusterNode, ClusterNode]: From a68359bb6a851fb09b5416966df3ff7928bad51f Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Wed, 30 Apr 2025 16:40:29 +0300 Subject: [PATCH 4/7] Refactor ClusterPipeline to use execution strategies --- redis/cluster.py | 1163 ++++++++++++++++++----------- tests/test_cluster.py | 38 +- tests/test_cluster_transaction.py | 38 +- tests/test_pipeline.py | 18 +- 4 files changed, 781 insertions(+), 476 deletions(-) diff --git a/redis/cluster.py b/redis/cluster.py index f87be95494..bf49c18b6c 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -3,6 +3,7 @@ import sys import threading import time +from abc import ABC, abstractmethod from collections import OrderedDict from enum import Enum from itertools import chain @@ -2111,6 +2112,15 @@ def __init__( lock = threading.Lock() self._lock = lock self.transaction = transaction + strategy_class = ( + TransactionStrategy if self.transaction else PipelineStrategy + ) + self._execution_strategy: ExecutionStrategy = strategy_class( + super(), + cluster_response_callbacks=self.cluster_response_callbacks, + cluster_error_retry_attempts=self.cluster_error_retry_attempts, + ) self.explicit_transaction = False self.watching = False self.transaction_connection: Optional[Connection] = None @@ -2145,125 +2155,437 @@ def __bool__(self): "Pipeline instances should always evaluate to True on Python 3+" return True - def _get_client_and_connection_for_transaction(self) -> Tuple[Redis, Connection]: + def execute_command(self, *args, **kwargs): """ - Find a connection for a pipeline transaction. + Wrapper function for pipeline_execute_command + """ + return self._execution_strategy.execute_command(*args, **kwargs) - For running an atomic transaction, watch keys ensure that contents have not been - altered as long as the watch commands for those keys were sent over the same - connection. So once we start watching a key, we fetch a connection to the - node that owns that slot and reuse it. + def pipeline_execute_command(self, *args, **options): + """ + Stage a command to be executed when execute() is next called + + Returns the current Pipeline object back so commands can be + chained together, such as: + + pipe = pipe.set('foo', 'bar').incr('baz').decr('bang') + + At some other point, you can then run: pipe.execute(), + which will execute all commands queued in the pipe. + """ + self._execution_strategy.pipeline_execute_command(*args, **options) + return self + + def annotate_exception(self, exception, number, command): + """ + Provides extra context to the exception prior to it being handled + """ + self._execution_strategy.annotate_exception(exception, number, command) + + def execute(self, raise_on_error: bool = True) -> List[Any]: + """ + Execute all the commands in the current pipeline """ - if not self.pipeline_slots: - raise RedisClusterException( - "At least a command with a key is needed to identify a node" - ) - node: ClusterNode = self.nodes_manager.get_node_from_slot( - list(self.pipeline_slots)[0], False + try: + return self._execution_strategy.execute(raise_on_error) + finally: + self.reset() + + def reset(self): + """ + Reset back to empty pipeline.
+ """ + self._execution_strategy.reset() + + def send_cluster_commands( + self, stack, raise_on_error=True, allow_redirections=True + ): + return self._execution_strategy.send_cluster_commands( + stack, raise_on_error=raise_on_error, allow_redirections=allow_redirections ) + + def exists(self, *keys): + return self._execution_strategy.exists(*keys) + + def eval(self): + """ """ + return self._execution_strategy.eval() + + def multi(self): + """ + Start a transactional block of the pipeline after WATCH commands + are issued. End the transactional block with `execute`. + """ + self._execution_strategy.multi() + + def load_scripts(self): + """ """ + self._execution_strategy.load_scripts() + + def discard(self): + """ """ + self._execution_strategy.discard() + + def watch(self, *names): + """Watches the values at keys ``names``""" + self._execution_strategy.watch(*names) + + def unwatch(self): + """Unwatches all previously specified keys""" + self._execution_strategy.unwatch() + + def script_load_for_pipeline(self, *args, **kwargs): + self._execution_strategy.script_load_for_pipeline(*args, **kwargs) + + def delete(self, *names): + self._execution_strategy.delete(*names) + + def unlink(self, *names): + self._execution_strategy.unlink(*names) + + +def block_pipeline_command(name: str) -> Callable[..., Any]: + """ + Raises an error because some pipelined commands should + be blocked when running in cluster-mode + """ + + def inner(*args, **kwargs): + raise RedisClusterException( + f"ERROR: Calling pipelined function {name} is blocked " + f"when running redis in cluster mode..."
+ ) + + return inner + + +# Blocked pipeline commands +PIPELINE_BLOCKED_COMMANDS = ( + "BGREWRITEAOF", + "BGSAVE", + "BITOP", + "BRPOPLPUSH", + "CLIENT GETNAME", + "CLIENT KILL", + "CLIENT LIST", + "CLIENT SETNAME", + "CLIENT", + "CONFIG GET", + "CONFIG RESETSTAT", + "CONFIG REWRITE", + "CONFIG SET", + "CONFIG", + "DBSIZE", + "ECHO", + "EVALSHA", + "FLUSHALL", + "FLUSHDB", + "INFO", + "KEYS", + "LASTSAVE", + "MGET", + "MGET NONATOMIC", + "MOVE", + "MSET", + "MSET NONATOMIC", + "MSETNX", + "PFCOUNT", + "PFMERGE", + "PING", + "PUBLISH", + "RANDOMKEY", + "READONLY", + "READWRITE", + "RENAME", + "RENAMENX", + "RPOPLPUSH", + "SAVE", + "SCAN", + "SCRIPT EXISTS", + "SCRIPT FLUSH", + "SCRIPT KILL", + "SCRIPT LOAD", + "SCRIPT", + "SDIFF", + "SDIFFSTORE", + "SENTINEL GET MASTER ADDR BY NAME", + "SENTINEL MASTER", + "SENTINEL MASTERS", + "SENTINEL MONITOR", + "SENTINEL REMOVE", + "SENTINEL SENTINELS", + "SENTINEL SET", + "SENTINEL SLAVES", + "SENTINEL", + "SHUTDOWN", + "SINTER", + "SINTERSTORE", + "SLAVEOF", + "SLOWLOG GET", + "SLOWLOG LEN", + "SLOWLOG RESET", + "SLOWLOG", + "SMOVE", + "SORT", + "SUNION", + "SUNIONSTORE", + "TIME", +) +for command in PIPELINE_BLOCKED_COMMANDS: + command = command.replace(" ", "_").lower() + + setattr(ClusterPipeline, command, block_pipeline_command(command)) + + +class PipelineCommand: + """ """ + + def __init__(self, args, options=None, position=None): + self.args = args + if options is None: + options = {} + self.options = options + self.position = position + self.result = None + self.node = None + self.asking = False + + +class NodeCommands: + """ """ + + def __init__(self, parse_response, connection_pool, connection): + """ """ + self.parse_response = parse_response + self.connection_pool = connection_pool + self.connection = connection + self.commands = [] + + def append(self, c): + """ """ + self.commands.append(c) + + def write(self): + """ + Code borrowed from Redis so it can be fixed + """ + connection = self.connection + commands = self.commands + + # We are going to clobber the commands with the write, so go ahead + # and ensure that nothing is sitting there from a previous run. + for c in commands: + c.result = None + + # build up all commands into a single request to increase network perf + # send all the commands and catch connection and timeout errors. + try: + connection.send_packed_command( + connection.pack_commands([c.args for c in commands]) + ) + except (ConnectionError, TimeoutError) as e: + for c in commands: + c.result = e + + def read(self): + """ """ + connection = self.connection + for c in self.commands: + # if there is a result on this command, + # it means we ran into an exception + # like a connection error. Trying to parse + # a response on a connection that + # is no longer open will result in a + # connection error raised by redis-py. + # but redis-py doesn't check in parse_response + # that the sock object is + # still set and if you try to + # read from a closed connection, it will + # result in an AttributeError because + # it will do a readline() call on None. + # This can have all kinds of nasty side-effects. + # Treating this case as a connection error + # is fine because it will dump + # the connection object back into the + # pool and on the next write, it will + # explicitly open the connection and all will be well. 
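
The long comment above reduces to one rule: never read a reply for a command whose send already failed, and treat a failed read as fatal for the rest of the batch. A condensed, runnable sketch of that read loop, with a simplified Cmd record standing in for PipelineCommand and parse_response passed in as a callable:

    from dataclasses import dataclass, field

    @dataclass
    class Cmd:
        args: tuple
        options: dict = field(default_factory=dict)
        result: object = None

    def read_all(parse_response, connection, commands):
        for c in commands:
            # a pre-set result means the send side already failed;
            # don't try to read from what may be a dead socket
            if c.result is not None:
                continue
            try:
                c.result = parse_response(connection, c.args[0], **c.options)
            except (ConnectionError, TimeoutError) as exc:
                # poison the rest of the batch: the connection goes back to
                # the pool and is explicitly reopened on the next write
                for rest in commands:
                    if rest.result is None:
                        rest.result = exc
                return
            except Exception as exc:  # RedisError in the real client
                c.result = exc

    cmds = [Cmd(("GET", "k"))]
    read_all(lambda conn, name, **opts: b"v", None, cmds)
    assert cmds[0].result == b"v"
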
+ if c.result is None: + try: + c.result = self.parse_response(connection, c.args[0], **c.options) + except (ConnectionError, TimeoutError) as e: + for c in self.commands: + c.result = e + return + except RedisError: + c.result = sys.exc_info()[1] + + +class ExecutionStrategy(ABC): + + @abstractmethod def execute_command(self, *args, **kwargs): """ - Wrapper function for pipeline_execute_command + Execution flow for current execution strategy. + + See: ClusterPipeline.execute_command() """ - slot_number: Optional[int] = None - if args[0] not in self.NO_SLOTS_COMMANDS: - slot_number = self.determine_slot(*args) + pass + + @abstractmethod + def annotate_exception(self, exception, number, command): + """ + Annotate exception according to current execution strategy. + + See: ClusterPipeline.annotate_exception() + """ + pass + + @abstractmethod + def pipeline_execute_command(self, *args, **options): + """ + Pipeline execution flow for current execution strategy. + + See: ClusterPipeline.pipeline_execute_command() + """ + pass + + @abstractmethod + def execute(self, raise_on_error: bool = True) -> List[Any]: + """ + Executes current execution strategy. + + See: ClusterPipeline.execute() + """ + pass + + @abstractmethod + def send_cluster_commands(self, stack, raise_on_error=True, allow_redirections=True): + """ + Sends commands according to current execution strategy. + + See: ClusterPipeline.send_cluster_commands() + """ + pass + + @abstractmethod + def reset(self): + """ + Resets current execution strategy. + + See: ClusterPipeline.reset() + """ + pass + + @abstractmethod + def exists(self, *keys): + pass + + @abstractmethod + def eval(self): + pass + + @abstractmethod + def multi(self): + """ + Starts transactional context. + + See: ClusterPipeline.reset() + """ + pass + + @abstractmethod + def load_scripts(self): + pass + + @abstractmethod + def watch(self, *names): + pass + + @abstractmethod + def unwatch(self): + pass + + @abstractmethod + def script_load_for_pipeline(self, *args, **kwargs): + pass + + @abstractmethod + def delete(self, *names): + pass - if ( - self.watching or args[0] in self.IMMEDIATE_EXECUTE_COMMANDS - ) and not self.explicit_transaction: - if args[0] == "WATCH": - self._validate_watch() + @abstractmethod + def unlink(self, *names): + pass - if slot_number is not None: - if self.pipeline_slots and slot_number not in self.pipeline_slots: - raise CrossSlotTransactionError( - "Cannot watch or send commands on different slots" - ) + @abstractmethod + def discard(self): + pass - self.pipeline_slots.add(slot_number) - elif args[0] not in self.NO_SLOTS_COMMANDS: - raise RedisClusterException( - f"Cannot identify slot number for command: {args[0]}," - "it cannot be triggered in a transaction" - ) - return self.immediate_execute_command(*args, **kwargs) - else: - if slot_number is not None: - self.pipeline_slots.add(slot_number) +class AbstractStrategy(ExecutionStrategy): - return self.pipeline_execute_command(*args, **kwargs) + def __init__( + self, + cluster: RedisCluster, + cluster_response_callbacks: Optional[Dict[str, Callable]] = None, + cluster_error_retry_attempts: int = 3, + ): + self._command_queue: List[PipelineCommand] = [] + self._cluster = cluster + self._nodes_manager = self._cluster.nodes_manager + self._cluster_response_callbacks = cluster_response_callbacks + self._cluster_error_retry_attempts = cluster_error_retry_attempts + + @abstractmethod + def execute_command(self, *args, **kwargs): + pass def pipeline_execute_command(self, *args, **options): - """ - 
Stage a command to be executed when execute() is next called + self._command_queue.append( + PipelineCommand(args, options, len(self._command_queue)) + ) - Returns the current Pipeline object back so commands can be - chained together, such as: + @abstractmethod + def execute(self, raise_on_error: bool = True) -> List[Any]: + pass - pipe = pipe.set('foo', 'bar').incr('baz').decr('bang') + @abstractmethod + def send_cluster_commands(self, stack, raise_on_error=True, allow_redirections=True): + pass - At some other point, you can then run: pipe.execute(), - which will execute all commands queued in the pipe. - """ - self.command_stack.append( - PipelineCommand(args, options, len(self.command_stack)) - ) - return self + @abstractmethod + def reset(self): + pass - def _get_connection_and_send_command(self, *args, **options): - redis_node, connection = self._get_client_and_connection_for_transaction() - return self._send_command_parse_response( - connection, redis_node, args[0], *args, **options - ) + def exists(self, *keys): + return self.execute_command("EXISTS", *keys) - def immediate_execute_command(self, *args, **options): - retry = Retry( - default_backoff(), - self.cluster_error_retry_attempts, - ) - retry.update_supported_errors([AskError, MovedError]) - return retry.call_with_retry( - lambda: self._get_connection_and_send_command(*args, **options), - self._reinitialize_on_error, + def eval(self): + """ """ + raise RedisClusterException("method eval() is not implemented") + + def load_scripts(self): + """ """ + raise RedisClusterException("method load_scripts() is not implemented") + + def script_load_for_pipeline(self, *args, **kwargs): + """ """ + raise RedisClusterException( + "method script_load_for_pipeline() is not implemented" ) - def raise_first_error(self, stack): + def delete(self, *names): """ - Raise the first exception on the stack + "Delete a key specified by ``names``" """ - for c in stack: - r = c.result - if isinstance(r, Exception): - self.annotate_exception(r, c.position + 1, c.args) - raise r + return self.execute_command("DEL", *names) - def raise_first_transaction_error(self, responses, stack): + def unlink(self, *names): """ - Raise the first exception on the stack + "Unlink a key specified by ``names``" """ - for r, cmd in zip(responses, stack): - if isinstance(r, Exception): - self.annotate_exception(r, cmd.position + 1, cmd.args) - raise r + return self.execute_command("UNLINK", *names) def annotate_exception(self, exception, number, command): """ @@ -2276,19 +2598,42 @@ def annotate_exception(self, exception, number, command): ) exception.args = (msg,) + exception.args[1:] - def execute(self, raise_on_error: bool = True) -> List[Any]: +class PipelineStrategy(AbstractStrategy): + + def __init__( + self, + cluster: RedisCluster, + cluster_response_callbacks: Optional[Dict[str, Callable]] = None, + cluster_error_retry_attempts: int = 3, + ): + super().__init__( + cluster, + cluster_response_callbacks, + cluster_error_retry_attempts + ) + self.node_flags = cluster.NODE_FLAGS.copy() + self.command_flags = cluster.COMMAND_FLAGS.copy() + + def execute_command(self, *args, **kwargs): + self.pipeline_execute_command(*args, **kwargs) + + def _raise_first_error(self, stack): """ - Execute all the commands in the current pipeline + Raise the first exception on the stack """ - stack = self.command_stack - if not stack and (not self.watching or not self.pipeline_slots): + for c in stack: + r = c.result + if isinstance(r, Exception): + self.annotate_exception(r, c.position + 
1, c.args) + raise r + + def execute(self, raise_on_error: bool = True) -> List[Any]: + stack = self._command_queue + if not stack: return [] try: - if self.transaction or self.explicit_transaction: - return self._execute_transaction_with_retries(stack, raise_on_error) - else: - return self.send_cluster_commands(stack, raise_on_error) + return self.send_cluster_commands(stack, raise_on_error) finally: self.reset() @@ -2296,42 +2641,13 @@ def reset(self): """ Reset back to empty pipeline. """ - self.command_stack = [] - - self.scripts = set() - - # make sure to reset the connection state in the event that we were - # watching something - if self.transaction_connection: - try: - # call this manually since our unwatch or - # immediate_execute_command methods can call reset() - self.transaction_connection.send_command("UNWATCH") - self.transaction_connection.read_response() - # we can safely return the connection to the pool here since we're - # sure we're no longer WATCHing anything - node = self.nodes_manager.find_connection_owner( - self.transaction_connection - ) - node.redis_connection.connection_pool.release( - self.transaction_connection - ) - self.transaction_connection = None - except ConnectionError: - # disconnect will also remove any previous WATCHes - if self.transaction_connection: - self.transaction_connection.disconnect() - - # clean up the other instance attributes - self.watching = False - self.explicit_transaction = False - self.pipeline_slots = set() - self.slot_migrating = False - self.cluster_error = False - self.executing = False + self._command_queue = [] def send_cluster_commands( - self, stack, raise_on_error=True, allow_redirections=True + self, + stack, + raise_on_error=True, + allow_redirections=True ): """ Wrapper for CLUSTERDOWN error handling. @@ -2350,7 +2666,7 @@ def send_cluster_commands( """ if not stack: return [] - retry_attempts = self.cluster_error_retry_attempts + retry_attempts = self._cluster_error_retry_attempts while True: try: return self._send_cluster_commands( @@ -2368,7 +2684,10 @@ def send_cluster_commands( raise e def _send_cluster_commands( - self, stack, raise_on_error=True, allow_redirections=True + self, + stack, + raise_on_error=True, + allow_redirections=True ): """ Send a bunch of cluster commands to the redis cluster. @@ -2412,7 +2731,7 @@ def _send_cluster_commands( ) node = target_nodes[0] - if node == self.get_default_node(): + if node == self._cluster.get_default_node(): is_default_node = True # now that we know the name of the node @@ -2420,7 +2739,7 @@ def _send_cluster_commands( # we can build a list of commands for each node. node_name = node.name if node_name not in nodes: - redis_node = self.get_redis_connection(node) + redis_node = self._cluster.get_redis_connection(node) try: connection = get_connection(redis_node) except (ConnectionError, TimeoutError): @@ -2428,9 +2747,9 @@ def _send_cluster_commands( n.connection_pool.release(n.connection) # Connection retries are being handled in the node's # Retry object. Reinitialize the node -> slot table. - self.nodes_manager.initialize() + self._nodes_manager.initialize() if is_default_node: - self.replace_default_node() + self._cluster.replace_default_node() raise nodes[node_name] = NodeCommands( redis_node.parse_response, @@ -2511,11 +2830,11 @@ def _send_cluster_commands( # If a lot of commands have failed, we'll be setting the # flag to rebuild the slots table from scratch. # So MOVED errors should correct themselves fairly quickly. 
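
send_cluster_commands above is a thin retry wrapper: on ClusterDownError the node table is assumed to have been reset, and the whole stack is replayed up to cluster_error_retry_attempts times. A sketch of that control flow, with a stand-in exception class and a send_once callable replacing _send_cluster_commands:

    class ClusterDownError(Exception):
        """Stand-in for redis.exceptions.ClusterDownError."""

    def send_with_retries(send_once, stack, retry_attempts=3):
        # CLUSTERDOWN is treated as transient: replay the whole stack
        # a bounded number of times before giving up
        while True:
            try:
                return send_once(stack)
            except ClusterDownError:
                if retry_attempts > 0:
                    retry_attempts -= 1
                    continue
                raise

    attempts = {"n": 0}

    def flaky(stack):
        attempts["n"] += 1
        if attempts["n"] < 3:
            raise ClusterDownError("CLUSTERDOWN")
        return [b"OK" for _ in stack]

    assert send_with_retries(flaky, ["SET a 1"]) == [b"OK"]
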
- self.reinitialize_counter += 1 - if self._should_reinitialized(): - self.nodes_manager.initialize() + self._cluster.reinitialize_counter += 1 + if self._cluster._should_reinitialized(): + self._nodes_manager.initialize() if is_default_node: - self.replace_default_node() + self._cluster.replace_default_node() for c in attempt: try: # send each command individually like we @@ -2528,46 +2847,207 @@ def _send_cluster_commands( # to the sequence of commands issued in the stack in pipeline.execute() response = [] for c in sorted(stack, key=lambda x: x.position): - if c.args[0] in self.cluster_response_callbacks: + if c.args[0] in self._cluster_response_callbacks: # Remove keys entry, it needs only for cache. c.options.pop("keys", None) - c.result = self.cluster_response_callbacks[c.args[0]]( + c.result = self._cluster_response_callbacks[c.args[0]]( c.result, **c.options ) response.append(c.result) if raise_on_error: - self.raise_first_error(stack) + self._raise_first_error(stack) return response - def _fail_on_redirect(self, allow_redirections): - """ """ - if not allow_redirections: - raise RedisClusterException( - "ASK & MOVED redirection not allowed in this pipeline" + def _is_nodes_flag(self, target_nodes): + return isinstance(target_nodes, str) and target_nodes in self.node_flags + + def _parse_target_nodes(self, target_nodes): + if isinstance(target_nodes, list): + nodes = target_nodes + elif isinstance(target_nodes, ClusterNode): + # Supports passing a single ClusterNode as a variable + nodes = [target_nodes] + elif isinstance(target_nodes, dict): + # Supports dictionaries of the format {node_name: node}. + # It enables to execute commands with multi nodes as follows: + # rc.cluster_save_config(rc.get_primaries()) + nodes = target_nodes.values() + else: + raise TypeError( + "target_nodes type can be one of the following: " + "node_flag (PRIMARIES, REPLICAS, RANDOM, ALL_NODES)," + "ClusterNode, list, or dict. " + f"The passed type is {type(target_nodes)}" ) + return nodes - def exists(self, *keys): - return self.execute_command("EXISTS", *keys) + def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]: + # Determine which nodes should be executed the command on. + # Returns a list of target nodes. 
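
Before routing, _determine_nodes normalizes whatever the caller passed as target_nodes; the accepted shapes are exactly the ones handled in _parse_target_nodes above. A compact sketch, where Node is a placeholder class and the real method raises TypeError for any other shape:

    def parse_target_nodes(target_nodes):
        if isinstance(target_nodes, list):
            return target_nodes
        if isinstance(target_nodes, dict):
            # {node_name: node} mappings are flattened to their values
            return list(target_nodes.values())
        return [target_nodes]  # a single ClusterNode

    class Node:
        pass

    n1, n2 = Node(), Node()
    assert parse_target_nodes([n1, n2]) == [n1, n2]
    assert parse_target_nodes({"n1": n1}) == [n1]
    assert parse_target_nodes(n1) == [n1]
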
+ command = args[0].upper() + if len(args) >= 2 and f"{args[0]} {args[1]}".upper() in self.command_flags: + command = f"{args[0]} {args[1]}".upper() - def eval(self): - """ """ - raise RedisClusterException("method eval() is not implemented") + nodes_flag = kwargs.pop("nodes_flag", None) + if nodes_flag is not None: + # nodes flag passed by the user + command_flag = nodes_flag + else: + # get the nodes group for this command if it was predefined + command_flag = self.command_flags.get(command) + if command_flag == self._cluster.RANDOM: + # return a random node + return [self._cluster.get_random_node()] + elif command_flag == self._cluster.PRIMARIES: + # return all primaries + return self._cluster.get_primaries() + elif command_flag == self._cluster.REPLICAS: + # return all replicas + return self._cluster.get_replicas() + elif command_flag == self._cluster.ALL_NODES: + # return all nodes + return self._cluster.get_nodes() + elif command_flag == self._cluster.DEFAULT_NODE: + # return the cluster's default node + return [self._nodes_manager.default_node] + elif command in self._cluster.SEARCH_COMMANDS[0]: + return [self._nodes_manager.default_node] + else: + # get the node that holds the key's slot + slot = self._cluster.determine_slot(*args) + node = self._nodes_manager.get_node_from_slot( + slot, + self._cluster.read_from_replicas and command in READ_COMMANDS, + self._cluster.load_balancing_strategy if command in READ_COMMANDS else None, + ) + return [node] def multi(self): + raise RedisClusterException("method multi() is not implemented") + + def discard(self): + raise RedisClusterException("method discard() is not implemented") + + def watch(self, *names): + raise RedisClusterException("method watch() is not implemented") + + def unwatch(self, *names): + raise RedisClusterException("method unwatch() is not implemented") + + +class TransactionStrategy(AbstractStrategy): + + NO_SLOTS_COMMANDS = {"UNWATCH"} + IMMEDIATE_EXECUTE_COMMANDS = {"WATCH", "UNWATCH"} + UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"} + + def __init__( + self, + cluster: RedisCluster, + cluster_response_callbacks: Optional[Dict[str, Callable]] = None, + cluster_error_retry_attempts: int = 3, + ): + super().__init__( + cluster, + cluster_response_callbacks, + cluster_error_retry_attempts + ) + self._explicit_transaction = False + self._watching = False + self._pipeline_slots: Set[int] = set() + self._transaction_connection: Optional[Connection] = None + self._cluster_error = False + self._slot_migrating = False + self._executing = False + + def _get_client_and_connection_for_transaction(self) -> Tuple[Redis, Connection]: """ - Start a transactional block of the pipeline after WATCH commands - are issued. End the transactional block with `execute`. - """ + Find a connection for a pipeline transaction. - if self.explicit_transaction: - raise RedisError("Cannot issue nested calls to MULTI") - if self.command_stack: - raise RedisError( - "Commands without an initial WATCH have already been issued" + For running an atomic transaction, watch keys ensure that contents have not been + altered as long as the watch commands for those keys were sent over the same + connection. So once we start watching a key, we fetch a connection to the + node that owns that slot and reuse it. 
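
The slot-routing fallthrough at the end of _determine_nodes, and the per-key checks in the transaction strategy below, both rest on the cluster's key-to-slot mapping: CRC16 (XMODEM variant) of the key modulo 16384, hashing only the {hash tag} substring when a non-empty one is present. A self-contained sketch of that mapping:

    def crc16(data: bytes) -> int:
        # CRC16/XMODEM (poly 0x1021, init 0), the variant Redis clusters use
        crc = 0
        for byte in data:
            crc ^= byte << 8
            for _ in range(8):
                crc = (crc << 1) ^ 0x1021 if crc & 0x8000 else crc << 1
                crc &= 0xFFFF
        return crc

    def keyslot(key: str) -> int:
        # only the text between the first {...} pair is hashed, which is
        # how {user:1}:name and {user:1}:age land on the same slot
        start = key.find("{")
        if start != -1:
            end = key.find("}", start + 1)
            if end > start + 1:  # the tag must be non-empty
                key = key[start + 1 : end]
        return crc16(key.encode()) % 16384

    assert keyslot("{user:1}:name") == keyslot("{user:1}:age")

Hash tags are what make same-slot transactions practical: an application groups related keys under one tag and the whole group becomes eligible for a single MULTI/EXEC block.
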
+ """ + if not self._pipeline_slots: + raise RedisClusterException( + "At least a command with a key is needed to identify a node" ) - self.explicit_transaction = True + + node: ClusterNode = self._nodes_manager.get_node_from_slot( + list(self._pipeline_slots)[0], False + ) + redis_node: Redis = self._cluster.get_redis_connection(node) + if self._transaction_connection: + if not redis_node.connection_pool.owns_connection( + self._transaction_connection + ): + previous_node = self._nodes_manager.find_connection_owner( + self._transaction_connection + ) + previous_node.connection_pool.release(self._transaction_connection) + self._transaction_connection = None + + if not self._transaction_connection: + self._transaction_connection = get_connection(redis_node) + + return redis_node, self._transaction_connection + + def execute_command(self, *args, **kwargs): + slot_number: Optional[int] = None + if args[0] not in ClusterPipeline.NO_SLOTS_COMMANDS: + slot_number = self._cluster.determine_slot(*args) + + if ( + self._watching or args[0] in self.IMMEDIATE_EXECUTE_COMMANDS + ) and not self._explicit_transaction: + if args[0] == "WATCH": + self._validate_watch() + + if slot_number is not None: + if self._pipeline_slots and slot_number not in self._pipeline_slots: + raise CrossSlotTransactionError( + "Cannot watch or send commands on different slots" + ) + + self._pipeline_slots.add(slot_number) + elif args[0] not in self.NO_SLOTS_COMMANDS: + raise RedisClusterException( + f"Cannot identify slot number for command: {args[0]}," + "it cannot be triggered in a transaction" + ) + + return self._immediate_execute_command(*args, **kwargs) + else: + if slot_number is not None: + self._pipeline_slots.add(slot_number) + + return self.pipeline_execute_command(*args, **kwargs) + + def _validate_watch(self): + if self._explicit_transaction: + raise RedisError("Cannot issue a WATCH after a MULTI") + + self._watching = True + + def _immediate_execute_command(self, *args, **options): + retry = Retry( + default_backoff(), + self._cluster.cluster_error_retry_attempts, + ) + retry.update_supported_errors([AskError, MovedError]) + return retry.call_with_retry( + lambda: self._get_connection_and_send_command(*args, **options), + self._reinitialize_on_error, + ) + + def _get_connection_and_send_command(self, *args, **options): + redis_node, connection = self._get_client_and_connection_for_transaction() + return self._send_command_parse_response( + connection, redis_node, args[0], *args, **options + ) def _send_command_parse_response( self, conn, redis_node: Redis, command_name, *args, **options @@ -2586,64 +3066,53 @@ def _send_command_parse_response( raise slot_error if command_name in self.UNWATCH_COMMANDS: - self.watching = False + self._watching = False return output - def _disconnect_reset_raise(self, conn, error): - """ - Close the connection, reset watching state and - raise an exception if we were watching, - retry_on_timeout is not set, - or the error is not a TimeoutError - """ - if not conn: - conn = self.transaction_connection - - if conn: - conn.disconnect() - - # if we were already watching a variable, the watch is no longer - # valid since this connection has died. raise a WatchError, which - # indicates the user should retry this transaction. 
- if self.watching: - self.reset() - raise WatchError( - "A ConnectionError occurred on while watching one or more keys" - ) - # if retry_on_timeout is not set, or the error is not - # a TimeoutError, raise it - if not (conn and conn.retry_on_timeout and isinstance(error, TimeoutError)): - self.reset() - raise - def _reinitialize_on_error(self, error): - if self.watching: - if self.slot_migrating and self.executing: + if self._watching: + if self.slot_migrating and self._executing: raise WatchError("Slot rebalancing ocurred while watching keys") - if self.cluster_error: + if self._cluster_error: raise RedisClusterException("Cluster error ocurred while watching keys") - if self.slot_migrating or self.cluster_error: - if self.transaction_connection: - self.transaction_connection = None + if self.slot_migrating or self._cluster_error: + if self._transaction_connection: + self._transaction_connection = None - self.reinitialize_counter += 1 - if self._should_reinitialized(): - self.nodes_manager.initialize() + self._cluster.reinitialize_counter += 1 + if self._cluster._should_reinitialized(): + self._nodes_manager.initialize() self.reinitialize_counter = 0 else: - self.nodes_manager.update_moved_exception(error) + self._nodes_manager.update_moved_exception(error) self.slot_migrating = False - self.cluster_error = False - self.executing = False + self._cluster_error = False + self._executing = False + + def _raise_first_error(self, responses, stack): + """ + Raise the first exception on the stack + """ + for r, cmd in zip(responses, stack): + if isinstance(r, Exception): + self.annotate_exception(r, cmd.position + 1, cmd.args) + raise r + + def execute(self, raise_on_error: bool = True) -> List[Any]: + stack = self._command_queue + if not stack and (not self._watching or not self._pipeline_slots): + return [] + + return self._execute_transaction_with_retries(stack, raise_on_error) def _execute_transaction_with_retries( self, stack: List["PipelineCommand"], raise_on_error: bool ): retry = Retry( default_backoff(), - self.cluster_error_retry_attempts, + self._cluster.cluster_error_retry_attempts, ) retry.update_supported_errors([AskError, MovedError]) return retry.call_with_retry( @@ -2654,14 +3123,14 @@ def _execute_transaction_with_retries( def _execute_transaction( self, stack: List["PipelineCommand"], raise_on_error: bool ): - if len(self.pipeline_slots) > 1: + if len(self._pipeline_slots) > 1: raise CrossSlotTransactionError( "All keys involved in a cluster transaction must map to the same slot" ) - self.executing = True + self._executing = True self.slot_migrating = False - self.cluster_error = False + self._cluster_error = False redis_node, connection = self._get_client_and_connection_for_transaction() @@ -2686,7 +3155,7 @@ def _execute_transaction( errors.append(e) # and all the other commands - for i, command in enumerate(self.command_stack): + for i, command in enumerate(self._command_queue): if EMPTY_RESPONSE in command.options: errors.append((i, command.options[EMPTY_RESPONSE])) else: @@ -2697,7 +3166,7 @@ def _execute_transaction( self.annotate_exception(slot_error, i + 1, command.args) errors.append(slot_error) except (ClusterDownError, ConnectionError) as cluster_error: - self.cluster_error = True + self._cluster_error = True self.annotate_exception(cluster_error, i + 1, command.args) raise except ResponseError as e: @@ -2713,10 +3182,10 @@ def _execute_transaction( raise errors[0] raise - self.executing = False + self._executing = False # EXEC clears any watched keys - 
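
From the caller's side, the machinery above supports the usual optimistic-locking loop, unchanged from standalone Redis as long as every key carries the same hash tag. A usage sketch against a hypothetical cluster-aware client rc:

    from redis import WatchError

    def transfer(rc, amount=10):
        # both keys share the {acct:7} hash tag, so the whole
        # transaction maps to a single slot
        with rc.pipeline(transaction=True) as pipe:
            while True:
                try:
                    pipe.watch("{acct:7}:src", "{acct:7}:dst")
                    balance = int(pipe.get("{acct:7}:src") or 0)
                    pipe.multi()
                    pipe.decrby("{acct:7}:src", amount)
                    pipe.incrby("{acct:7}:dst", amount)
                    pipe.execute()  # WatchError here means a watched key changed
                    return balance - amount
                except WatchError:
                    continue  # lost the race; re-read and retry
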
self.watching = False + self._watching = False if response is None: raise WatchError("Watched variable changed.") @@ -2725,255 +3194,91 @@ def _execute_transaction( for i, e in errors: response.insert(i, e) - if len(response) != len(self.command_stack): + if len(response) != len(self._command_queue): raise InvalidPipelineStack( "Unexpected response length for cluster pipeline EXEC." " Command stack was {} but response had length {}".format( - [c.args[0] for c in self.command_stack], len(response) + [c.args[0] for c in self._command_queue], len(response) ) ) # find any errors in the response and raise if necessary if raise_on_error or self.slot_migrating: - self.raise_first_transaction_error( + self._raise_first_error( response, - self.command_stack, + self._command_queue, ) # We have to run response callbacks manually data = [] - for r, cmd in zip(response, self.command_stack): + for r, cmd in zip(response, self._command_queue): if not isinstance(r, Exception): command_name = cmd.args[0] - if command_name in self.cluster_response_callbacks: - r = self.response_callbacks[command_name](r, **cmd.options) + if command_name in self._cluster.cluster_response_callbacks: + r = self._cluster_response_callbacks[command_name](r, **cmd.options) data.append(r) return data - def load_scripts(self): - """ """ - raise RedisClusterException("method load_scripts() is not implemented") + def reset(self): + self._command_queue = [] - def discard(self): - if self.transaction or self.explicit_transaction: - self.reset() - return + # make sure to reset the connection state in the event that we were + # watching something + if self._transaction_connection: + try: + # call this manually since our unwatch or + # immediate_execute_command methods can call reset() + self._transaction_connection.send_command("UNWATCH") + self._transaction_connection.read_response() + # we can safely return the connection to the pool here since we're + # sure we're no longer WATCHing anything + node = self._nodes_manager.find_connection_owner( + self._transaction_connection + ) + node.redis_connection.connection_pool.release( + self._transaction_connection + ) + self._transaction_connection = None + except ConnectionError: + # disconnect will also remove any previous WATCHes + if self._transaction_connection: + self._transaction_connection.disconnect() - if not self.explicit_transaction: - raise RedisClusterException("DISCARD triggered without MULTI") + # clean up the other instance attributes + self._watching = False + self._explicit_transaction = False + self._pipeline_slots = set() + self._slot_migrating = False + self._cluster_error = False + self._executing = False - def _validate_watch(self): - if self.explicit_transaction: - raise RedisError("Cannot issue a WATCH after a MULTI") + def send_cluster_commands(self, stack, raise_on_error=True, allow_redirections=True): + raise NotImplementedError("send_cluster_commands cannot be executed in transactional context.") - self.watching = True + def multi(self): + if self._explicit_transaction: + raise RedisError("Cannot issue nested calls to MULTI") + if self._command_queue: + raise RedisError( + "Commands without an initial WATCH have already been issued" + ) + self._explicit_transaction = True def watch(self, *names): - """Watches the values at keys ``names``""" - - if self.explicit_transaction: + if self._explicit_transaction: raise RedisError("Cannot issue a WATCH after a MULTI") return self.execute_command("WATCH", *names) def unwatch(self): - """Unwatches all previously specified 
keys""" - - if self.watching: + if self._watching: return self.execute_command("UNWATCH") return True - def script_load_for_pipeline(self, *args, **kwargs): - """ """ - raise RedisClusterException( - "method script_load_for_pipeline() is not implemented" - ) - - def delete(self, *names): - """ - "Delete a key specified by ``names``" - """ - return self.execute_command("DEL", *names) - - def unlink(self, *names): - """ - "Unlink a key specified by ``names``" - """ - return self.execute_command("UNLINK", *names) - - -def block_pipeline_command(name: str) -> Callable[..., Any]: - """ - Prints error because some pipelined commands should - be blocked when running in cluster-mode - """ - - def inner(*args, **kwargs): - raise RedisClusterException( - f"ERROR: Calling pipelined function {name} is blocked " - f"when running redis in cluster mode..." - ) - - return inner - - -# Blocked pipeline commands -PIPELINE_BLOCKED_COMMANDS = ( - "BGREWRITEAOF", - "BGSAVE", - "BITOP", - "BRPOPLPUSH", - "CLIENT GETNAME", - "CLIENT KILL", - "CLIENT LIST", - "CLIENT SETNAME", - "CLIENT", - "CONFIG GET", - "CONFIG RESETSTAT", - "CONFIG REWRITE", - "CONFIG SET", - "CONFIG", - "DBSIZE", - "ECHO", - "EVALSHA", - "FLUSHALL", - "FLUSHDB", - "INFO", - "KEYS", - "LASTSAVE", - "MGET", - "MGET NONATOMIC", - "MOVE", - "MSET", - "MSET NONATOMIC", - "MSETNX", - "PFCOUNT", - "PFMERGE", - "PING", - "PUBLISH", - "RANDOMKEY", - "READONLY", - "READWRITE", - "RENAME", - "RENAMENX", - "RPOPLPUSH", - "SAVE", - "SCAN", - "SCRIPT EXISTS", - "SCRIPT FLUSH", - "SCRIPT KILL", - "SCRIPT LOAD", - "SCRIPT", - "SDIFF", - "SDIFFSTORE", - "SENTINEL GET MASTER ADDR BY NAME", - "SENTINEL MASTER", - "SENTINEL MASTERS", - "SENTINEL MONITOR", - "SENTINEL REMOVE", - "SENTINEL SENTINELS", - "SENTINEL SET", - "SENTINEL SLAVES", - "SENTINEL", - "SHUTDOWN", - "SINTER", - "SINTERSTORE", - "SLAVEOF", - "SLOWLOG GET", - "SLOWLOG LEN", - "SLOWLOG RESET", - "SLOWLOG", - "SMOVE", - "SORT", - "SUNION", - "SUNIONSTORE", - "TIME", -) -for command in PIPELINE_BLOCKED_COMMANDS: - command = command.replace(" ", "_").lower() - - setattr(ClusterPipeline, command, block_pipeline_command(command)) - - -class PipelineCommand: - """ """ - - def __init__(self, args, options=None, position=None): - self.args = args - if options is None: - options = {} - self.options = options - self.position = position - self.result = None - self.node = None - self.asking = False - - -class NodeCommands: - """ """ - - def __init__(self, parse_response, connection_pool, connection): - """ """ - self.parse_response = parse_response - self.connection_pool = connection_pool - self.connection = connection - self.commands = [] - - def append(self, c): - """ """ - self.commands.append(c) - - def write(self): - """ - Code borrowed from Redis so it can be fixed - """ - connection = self.connection - commands = self.commands - - # We are going to clobber the commands with the write, so go ahead - # and ensure that nothing is sitting there from a previous run. - for c in commands: - c.result = None - - # build up all commands into a single request to increase network perf - # send all the commands and catch connection and timeout errors. 
- try: - connection.send_packed_command( - connection.pack_commands([c.args for c in commands]) - ) - except (ConnectionError, TimeoutError) as e: - for c in commands: - c.result = e + def discard(self): + if self._explicit_transaction: + self.reset() + return - def read(self): - """ """ - connection = self.connection - for c in self.commands: - # if there is a result on this command, - # it means we ran into an exception - # like a connection error. Trying to parse - # a response on a connection that - # is no longer open will result in a - # connection error raised by redis-py. - # but redis-py doesn't check in parse_response - # that the sock object is - # still set and if you try to - # read from a closed connection, it will - # result in an AttributeError because - # it will do a readline() call on None. - # This can have all kinds of nasty side-effects. - # Treating this case as a connection error - # is fine because it will dump - # the connection object back into the - # pool and on the next write, it will - # explicitly open the connection and all will be well. - if c.result is None: - try: - c.result = self.parse_response(connection, c.args[0], **c.options) - except (ConnectionError, TimeoutError) as e: - for c in self.commands: - c.result = e - return - except RedisError: - c.result = sys.exc_info()[1] + raise RedisClusterException("DISCARD triggered without MULTI") \ No newline at end of file diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 622a42d7bc..1f79ee9ef3 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -265,12 +265,12 @@ def moved_redirection_helper(request, failover=False): slot = 12182 redirect_node = None # Get the current primary that holds this slot - prev_primary = rc.nodes_manager.get_node_from_slot(slot) + prev_primary = rc._nodes_manager.get_node_from_slot(slot) if failover: - if len(rc.nodes_manager.slots_cache[slot]) < 2: + if len(rc._nodes_manager.slots_cache[slot]) < 2: warnings.warn("Skipping this test since it requires to have a replica") return - redirect_node = rc.nodes_manager.slots_cache[slot][1] + redirect_node = rc._nodes_manager.slots_cache[slot][1] else: # Use one of the primaries to be the redirected node redirect_node = rc.get_primaries()[0] @@ -290,7 +290,7 @@ def ok_response(connection, *args, **options): parse_response.side_effect = moved_redirect_effect assert rc.execute_command("SET", "foo", "bar") == "MOCK_OK" - slot_primary = rc.nodes_manager.slots_cache[slot][0] + slot_primary = rc._nodes_manager.slots_cache[slot][0] assert slot_primary == redirect_node if failover: assert rc.get_node(host=r_host, port=r_port).server_type == PRIMARY @@ -722,7 +722,7 @@ def test_all_nodes(self, r): """ Set a list of nodes and it should be possible to iterate over all """ - nodes = [node for node in r.nodes_manager.nodes_cache.values()] + nodes = [node for node in r._nodes_manager.nodes_cache.values()] for i, node in enumerate(r.get_nodes()): assert node in nodes @@ -734,7 +734,7 @@ def test_all_nodes_masters(self, r): """ nodes = [ node - for node in r.nodes_manager.nodes_cache.values() + for node in r._nodes_manager.nodes_cache.values() if node.server_type == PRIMARY ] @@ -805,7 +805,7 @@ def test_get_node_from_key(self, r): """ key = "bar" slot = r.keyslot(key) - slot_nodes = r.nodes_manager.slots_cache.get(slot) + slot_nodes = r._nodes_manager.slots_cache.get(slot) primary = slot_nodes[0] assert r.get_node_from_key(key, replica=False) == primary replica = r.get_node_from_key(key, replica=True) @@ -1003,8 +1003,8 @@ 
class TestClusterRedisCommands: def test_case_insensitive_command_names(self, r): assert ( - r.cluster_response_callbacks["cluster slots"] - == r.cluster_response_callbacks["CLUSTER SLOTS"] + r._cluster_response_callbacks["cluster slots"] + == r._cluster_response_callbacks["CLUSTER SLOTS"] ) def test_get_and_set(self, r): @@ -1275,7 +1275,7 @@ def test_cluster_addslotsrange(self, r): @skip_if_redis_enterprise() def test_cluster_countkeysinslot(self, r): - node = r.nodes_manager.get_node_from_slot(1) + node = r._nodes_manager.get_node_from_slot(1) mock_node_resp(node, 2) assert r.cluster_countkeysinslot(1) == 2 @@ -1450,7 +1450,7 @@ def test_cluster_save_config(self, r): @skip_if_redis_enterprise() def test_cluster_get_keys_in_slot(self, r): response = ["{foo}1", "{foo}2"] - node = r.nodes_manager.get_node_from_slot(12182) + node = r._nodes_manager.get_node_from_slot(12182) mock_node_resp(node, response) keys = r.cluster_get_keys_in_slot(12182, 4) assert keys == response @@ -1476,7 +1476,7 @@ def test_cluster_setslot(self, r): r.cluster_failover(node, "STATE") def test_cluster_setslot_stable(self, r): - node = r.nodes_manager.get_node_from_slot(12182) + node = r._nodes_manager.get_node_from_slot(12182) mock_node_resp(node, "OK") assert r.cluster_setslot_stable(12182) is True assert node.redis_connection.connection.read_response.called @@ -1562,7 +1562,7 @@ def test_info(self, r): r.set("z{1}", 3) # Get node that handles the slot slot = r.keyslot("x{1}") - node = r.nodes_manager.get_node_from_slot(slot) + node = r._nodes_manager.get_node_from_slot(slot) # Run info on that node info = r.info(target_nodes=node) assert isinstance(info, dict) @@ -1618,7 +1618,7 @@ def test_slowlog_get_limit(self, r, slowlog): def test_slowlog_length(self, r, slowlog): r.get("foo") - node = r.nodes_manager.get_node_from_slot(key_slot(b"foo")) + node = r._nodes_manager.get_node_from_slot(key_slot(b"foo")) slowlog_len = r.slowlog_len(target_nodes=node) assert isinstance(slowlog_len, int) @@ -1644,7 +1644,7 @@ def test_memory_stats(self, r): # put a key into the current db to make sure that "db." 
# has data r.set("foo", "bar") - node = r.nodes_manager.get_node_from_slot(key_slot(b"foo")) + node = r._nodes_manager.get_node_from_slot(key_slot(b"foo")) stats = r.memory_stats(target_nodes=node) assert isinstance(stats, dict) for key, value in stats.items(): @@ -2530,7 +2530,7 @@ class TestNodesManager: """ def test_load_balancer(self, r): - n_manager = r.nodes_manager + n_manager = r._nodes_manager lb = n_manager.read_load_balancer slot_1 = 1257 slot_2 = 8975 @@ -3342,7 +3342,7 @@ def raise_error(target_node, *args, **kwargs): # 4 = 2 get_connections per execution * 2 executions assert get_connection.call_count == 4 - for cluster_node in r.nodes_manager.nodes_cache.values(): + for cluster_node in r._nodes_manager.nodes_cache.values(): connection_pool = cluster_node.redis_connection.connection_pool num_of_conns = len(connection_pool._available_connections) assert num_of_conns == connection_pool._created_connections @@ -3428,7 +3428,7 @@ def test_readonly_pipeline_from_readonly_client(self, request): mock_all_nodes_resp(ro, "MOCK_OK") assert readonly_pipe.read_from_replicas is True assert readonly_pipe.get(key).get(key).execute() == ["MOCK_OK", "MOCK_OK"] - slot_nodes = ro.nodes_manager.slots_cache[ro.keyslot(key)] + slot_nodes = ro._nodes_manager.slots_cache[ro.keyslot(key)] if len(slot_nodes) > 1: executed_on_replica = False for node in slot_nodes: @@ -3468,7 +3468,7 @@ def test_readonly_pipeline_with_reading_from_replicas_strategies( mock_all_nodes_resp(ro, "MOCK_OK") assert readonly_pipe.load_balancing_strategy == load_balancing_strategy assert readonly_pipe.get(key).get(key).execute() == ["MOCK_OK", "MOCK_OK"] - slot_nodes = ro.nodes_manager.slots_cache[ro.keyslot(key)] + slot_nodes = ro._nodes_manager.slots_cache[ro.keyslot(key)] executed_on_replicas_only = True for node in slot_nodes: if node.server_type == PRIMARY: diff --git a/tests/test_cluster_transaction.py b/tests/test_cluster_transaction.py index c70d2ac9bb..61134cf9ec 100644 --- a/tests/test_cluster_transaction.py +++ b/tests/test_cluster_transaction.py @@ -185,7 +185,7 @@ def test_watch_succeed(self, r): with r.pipeline() as pipe: pipe.watch(f"{hashkey}:a", f"{hashkey}:b") - assert pipe.watching + assert pipe._watching a_value = pipe.get(f"{hashkey}:a") b_value = pipe.get(f"{hashkey}:b") assert a_value == b"1" @@ -194,7 +194,7 @@ def test_watch_succeed(self, r): pipe.set(f"{hashkey}:c", 3) assert pipe.execute() == [b"OK"] - assert not pipe.watching + assert not pipe._watching @pytest.mark.onlycluster def test_watch_failure(self, r): @@ -210,7 +210,7 @@ def test_watch_failure(self, r): with pytest.raises(redis.WatchError): pipe.execute() - assert not pipe.watching + assert not pipe._watching @pytest.mark.onlycluster def test_cross_slot_watch_single_call_failure(self, r): @@ -222,7 +222,7 @@ def test_cross_slot_watch_single_call_failure(self, r): "WATCH - all keys must map to the same key slot" ) - assert not pipe.watching + assert not pipe._watching @pytest.mark.onlycluster def test_cross_slot_watch_multiple_calls_failure(self, r): @@ -235,7 +235,7 @@ def test_cross_slot_watch_multiple_calls_failure(self, r): "Cannot watch or send commands on different slots" ) - assert pipe.watching + assert pipe._watching @pytest.mark.onlycluster def test_watch_failure_in_empty_transaction(self, r): @@ -250,7 +250,7 @@ def test_watch_failure_in_empty_transaction(self, r): with pytest.raises(redis.WatchError): pipe.execute() - assert not pipe.watching + assert not pipe._watching @pytest.mark.onlycluster def test_unwatch(self, r): @@ 
-262,7 +262,7 @@ def test_unwatch(self, r): pipe.watch(f"{hashkey}:a", f"{hashkey}:b") r[f"{hashkey}:b"] = 3 pipe.unwatch() - assert not pipe.watching + assert not pipe._watching pipe.get(f"{hashkey}:a") assert pipe.execute() == [b"1"] @@ -273,11 +273,11 @@ def test_watch_exec_auto_unwatch(self, r): r[f"{hashkey}:b"] = 2 target_slot = r.determine_slot("GET", f"{hashkey}:a") - target_node = r.nodes_manager.get_node_from_slot(target_slot) + target_node = r._nodes_manager.get_node_from_slot(target_slot) with r.monitor(target_node=target_node) as m: with r.pipeline() as pipe: pipe.watch(f"{hashkey}:a", f"{hashkey}:b") - assert pipe.watching + assert pipe._watching a_value = pipe.get(f"{hashkey}:a") b_value = pipe.get(f"{hashkey}:b") assert a_value == b"1" @@ -285,7 +285,7 @@ def test_watch_exec_auto_unwatch(self, r): pipe.multi() pipe.set(f"{hashkey}:c", 3) assert pipe.execute() == [b"OK"] - assert not pipe.watching + assert not pipe._watching unwatch_command = wait_for_command( r, m, "UNWATCH", key=f"{hashkey}:test_watch_exec_auto_unwatch" @@ -300,13 +300,13 @@ def test_watch_reset_unwatch(self, r): r[f"{hashkey}:a"] = 1 target_slot = r.determine_slot("GET", f"{hashkey}:a") - target_node = r.nodes_manager.get_node_from_slot(target_slot) + target_node = r._nodes_manager.get_node_from_slot(target_slot) with r.monitor(target_node=target_node) as m: with r.pipeline() as pipe: pipe.watch(f"{hashkey}:a") - assert pipe.watching + assert pipe._watching pipe.reset() - assert not pipe.watching + assert not pipe._watching unwatch_command = wait_for_command( r, m, "UNWATCH", key=f"{hashkey}:test_watch_reset_unwatch" @@ -396,7 +396,7 @@ def test_transaction_discard(self, r): pipe.set(f"{hashkey}:key", "someval") pipe.discard() - assert not pipe.watching + assert not pipe._watching assert not pipe.command_stack # pipelines with multi can be discarded @@ -406,7 +406,7 @@ def test_transaction_discard(self, r): pipe.set(f"{hashkey}:key", "someval") pipe.discard() - assert not pipe.watching + assert not pipe._watching assert not pipe.command_stack @pytest.mark.onlycluster @@ -485,11 +485,11 @@ def ask_redirect_effect(conn, *args, **options): assert False, f"unexpected node {conn.host}:{conn.port} was called" def update_moved_slot(): # simulate slot table update - ask_error = r.nodes_manager._moved_exception + ask_error = r._nodes_manager._moved_exception assert ask_error is not None, "No AskError was previously triggered" assert f"{ask_error.host}:{ask_error.port}" == node_importing.name - r.nodes_manager._moved_exception = None - r.nodes_manager.slots_cache[slot] = [node_importing] + r._nodes_manager._moved_exception = None + r._nodes_manager.slots_cache[slot] = [node_importing] parse_response.side_effect = ask_redirect_effect manager_update_moved_slots.side_effect = update_moved_slot @@ -559,7 +559,7 @@ def test_retry_transaction_with_watch_after_slot_migration(self, r): # force a MovedError on the first call to pipe.watch() # by switching the node that owns the slot to another one _node_migrating, node_importing = _find_source_and_target_node_for_slot(r, slot) - r.nodes_manager.slots_cache[slot] = [node_importing] + r._nodes_manager.slots_cache[slot] = [node_importing] with r.pipeline(transaction=True) as pipe: pipe.watch(key) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 2e9b21ada4..2e4b5b2cd5 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -198,7 +198,7 @@ def test_watch_succeed(self, r): with r.pipeline() as pipe: pipe.watch("a", "b") - assert pipe.watching + 
assert pipe._watching a_value = pipe.get("a") b_value = pipe.get("b") assert a_value == b"1" @@ -207,7 +207,7 @@ def test_watch_succeed(self, r): pipe.set("c", 3) assert pipe.execute() == [True] - assert not pipe.watching + assert not pipe._watching @pytest.mark.onlynoncluster def test_watch_failure(self, r): @@ -222,7 +222,7 @@ def test_watch_failure(self, r): with pytest.raises(redis.WatchError): pipe.execute() - assert not pipe.watching + assert not pipe._watching @pytest.mark.onlynoncluster def test_watch_failure_in_empty_transaction(self, r): @@ -236,7 +236,7 @@ def test_watch_failure_in_empty_transaction(self, r): with pytest.raises(redis.WatchError): pipe.execute() - assert not pipe.watching + assert not pipe._watching @pytest.mark.onlynoncluster def test_unwatch(self, r): @@ -247,7 +247,7 @@ def test_unwatch(self, r): pipe.watch("a", "b") r["b"] = 3 pipe.unwatch() - assert not pipe.watching + assert not pipe._watching pipe.get("a") assert pipe.execute() == [b"1"] @@ -259,7 +259,7 @@ def test_watch_exec_no_unwatch(self, r): with r.monitor() as m: with r.pipeline() as pipe: pipe.watch("a", "b") - assert pipe.watching + assert pipe._watching a_value = pipe.get("a") b_value = pipe.get("b") assert a_value == b"1" @@ -267,7 +267,7 @@ def test_watch_exec_no_unwatch(self, r): pipe.multi() pipe.set("c", 3) assert pipe.execute() == [True] - assert not pipe.watching + assert not pipe._watching unwatch_command = wait_for_command(r, m, "UNWATCH") assert unwatch_command is None, "should not send UNWATCH" @@ -279,9 +279,9 @@ def test_watch_reset_unwatch(self, r): with r.monitor() as m: with r.pipeline() as pipe: pipe.watch("a") - assert pipe.watching + assert pipe._watching pipe.reset() - assert not pipe.watching + assert not pipe._watching unwatch_command = wait_for_command(r, m, "UNWATCH") assert unwatch_command is not None From 17e21189091752469083f88088724597a2ef12fd Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 2 May 2025 11:42:57 +0300 Subject: [PATCH 5/7] Refactored strategy to use composition --- redis/cluster.py | 195 +++++++++++++++++++++--------------------- tests/test_cluster.py | 4 +- 2 files changed, 101 insertions(+), 98 deletions(-) diff --git a/redis/cluster.py b/redis/cluster.py index bf49c18b6c..c9b3b0718d 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -2088,7 +2088,6 @@ def __init__( **kwargs, ): """ """ - self.command_stack: List[PipelineCommand] = [] self.nodes_manager = nodes_manager self.commands_parser = commands_parser self.refresh_table_asap = False @@ -2111,23 +2110,12 @@ def __init__( if lock is None: lock = threading.Lock() self._lock = lock - self.transaction = transaction + self.parent_execute_command = super().execute_command self._execution_strategy: ExecutionStrategy = PipelineStrategy( - super(), - cluster_response_callbacks=self.cluster_response_callbacks, - cluster_error_retry_attempts=self.cluster_error_retry_attempts, - ) if not self.transaction else TransactionStrategy( - super(), - cluster_response_callbacks=self.cluster_response_callbacks, - cluster_error_retry_attempts=self.cluster_error_retry_attempts, + self + ) if not transaction else TransactionStrategy( + self ) - self.explicit_transaction = False - self.watching = False - self.transaction_connection: Optional[Connection] = None - self.pipeline_slots: Set[int] = set() - self.slot_migrating = False - self.cluster_error = False - self.executing = False def __repr__(self): """ """ @@ -2149,7 +2137,7 @@ def __del__(self): def __len__(self): """ """ - return len(self.command_stack) + 
return len(self._execution_strategy.command_queue) def __bool__(self): "Pipeline instances should always evaluate to True on Python 3+" @@ -2159,7 +2147,7 @@ def execute_command(self, *args, **kwargs): """ Wrapper function for pipeline_execute_command """ - return self._execution_strategy.execute_command(*args, **kwargs) + return self.pipeline_execute_command(*args, **kwargs) def pipeline_execute_command(self, *args, **options): """ @@ -2423,6 +2411,11 @@ def read(self): class ExecutionStrategy(ABC): + @property + @abstractmethod + def command_queue(self): + pass + @abstractmethod def execute_command(self, *args, **kwargs): """ @@ -2504,6 +2497,11 @@ def watch(self, *names): @abstractmethod def unwatch(self): + """ + Unwatches all previously specified keys + + See: ClusterPipeline.unwatch() + """ pass @abstractmethod @@ -2512,10 +2510,20 @@ def script_load_for_pipeline(self, *args, **kwargs): @abstractmethod def delete(self, *names): + """ + "Delete a key specified by ``names``" + + See: ClusterPipeline.delete() + """ pass @abstractmethod def unlink(self, *names): + """ + "Unlink a key specified by ``names``" + + See: ClusterPipeline.unlink() + """ pass @abstractmethod @@ -2527,15 +2535,19 @@ class AbstractStrategy(ExecutionStrategy): def __init__( self, - cluster: RedisCluster, - cluster_response_callbacks: Optional[Dict[str, Callable]] = None, - cluster_error_retry_attempts: int = 3, + pipe: ClusterPipeline, ): self._command_queue: List[PipelineCommand] = [] - self._cluster = cluster - self._nodes_manager = self._cluster.nodes_manager - self._cluster_response_callbacks = cluster_response_callbacks - self._cluster_error_retry_attempts = cluster_error_retry_attempts + self._pipe = pipe + self._nodes_manager = self._pipe.nodes_manager + + @property + def command_queue(self): + return self._command_queue + + @command_queue.setter + def command_queue(self, queue: List[PipelineCommand]): + self._command_queue = queue @abstractmethod def execute_command(self, *args, **kwargs): @@ -2575,18 +2587,6 @@ def script_load_for_pipeline(self, *args, **kwargs): "method script_load_for_pipeline() is not implemented" ) - def delete(self, *names): - """ - "Delete a key specified by ``names``" - """ - return self.execute_command("DEL", *names) - - def unlink(self, *names): - """ - "Unlink a key specified by ``names``" - """ - return self.execute_command("UNLINK", *names) - def annotate_exception(self, exception, number, command): """ Provides extra context to the exception prior to it being handled @@ -2600,19 +2600,9 @@ def annotate_exception(self, exception, number, command): class PipelineStrategy(AbstractStrategy): - def __init__( - self, - cluster: RedisCluster, - cluster_response_callbacks: Optional[Dict[str, Callable]] = None, - cluster_error_retry_attempts: int = 3, - ): - super().__init__( - cluster, - cluster_response_callbacks, - cluster_error_retry_attempts - ) - self.node_flags = cluster.NODE_FLAGS.copy() - self.command_flags = cluster.COMMAND_FLAGS.copy() + def __init__(self, pipe: ClusterPipeline): + super().__init__(pipe) + self.command_flags = pipe.command_flags def execute_command(self, *args, **kwargs): self.pipeline_execute_command(*args, **kwargs) @@ -2666,7 +2656,7 @@ def send_cluster_commands( """ if not stack: return [] - retry_attempts = self._cluster_error_retry_attempts + retry_attempts = self._pipe.cluster_error_retry_attempts while True: try: return self._send_cluster_commands( @@ -2731,7 +2721,7 @@ def _send_cluster_commands( ) node = target_nodes[0] - if node == 
self._cluster.get_default_node(): + if node == self._pipe.get_default_node(): is_default_node = True # now that we know the name of the node @@ -2739,7 +2729,7 @@ def _send_cluster_commands( # we can build a list of commands for each node. node_name = node.name if node_name not in nodes: - redis_node = self._cluster.get_redis_connection(node) + redis_node = self._pipe.get_redis_connection(node) try: connection = get_connection(redis_node) except (ConnectionError, TimeoutError): @@ -2749,7 +2739,7 @@ def _send_cluster_commands( # Retry object. Reinitialize the node -> slot table. self._nodes_manager.initialize() if is_default_node: - self._cluster.replace_default_node() + self._pipe.replace_default_node() raise nodes[node_name] = NodeCommands( redis_node.parse_response, @@ -2830,16 +2820,16 @@ def _send_cluster_commands( # If a lot of commands have failed, we'll be setting the # flag to rebuild the slots table from scratch. # So MOVED errors should correct themselves fairly quickly. - self._cluster.reinitialize_counter += 1 - if self._cluster._should_reinitialized(): + self._pipe.reinitialize_counter += 1 + if self._pipe._should_reinitialized(): self._nodes_manager.initialize() if is_default_node: - self._cluster.replace_default_node() + self._pipe.replace_default_node() for c in attempt: try: # send each command individually like we # do in the main client. - c.result = super().execute_command(*c.args, **c.options) + c.result = self._pipe.parent_execute_command(*c.args, **c.options) except RedisError as e: c.result = e @@ -2847,10 +2837,10 @@ def _send_cluster_commands( # to the sequence of commands issued in the stack in pipeline.execute() response = [] for c in sorted(stack, key=lambda x: x.position): - if c.args[0] in self._cluster_response_callbacks: + if c.args[0] in self._pipe.cluster_response_callbacks: # Remove keys entry, it needs only for cache. c.options.pop("keys", None) - c.result = self._cluster_response_callbacks[c.args[0]]( + c.result = self._pipe.cluster_response_callbacks[c.args[0]]( c.result, **c.options ) response.append(c.result) @@ -2861,7 +2851,7 @@ def _send_cluster_commands( return response def _is_nodes_flag(self, target_nodes): - return isinstance(target_nodes, str) and target_nodes in self.node_flags + return isinstance(target_nodes, str) and target_nodes in self._pipe.node_flags def _parse_target_nodes(self, target_nodes): if isinstance(target_nodes, list): @@ -2887,7 +2877,7 @@ def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]: # Determine which nodes should be executed the command on. # Returns a list of target nodes. 
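
The node-flag branch chain that follows can equally be read as a dispatch table over the RedisCluster node-group getters. A sketch, with pipe standing for the ClusterPipeline (which inherits those getters and flag constants):

    def nodes_for_flag(pipe, flag):
        # the getters here are the same ones called in the branches below
        dispatch = {
            pipe.RANDOM: lambda: [pipe.get_random_node()],
            pipe.PRIMARIES: pipe.get_primaries,
            pipe.REPLICAS: pipe.get_replicas,
            pipe.ALL_NODES: pipe.get_nodes,
        }
        handler = dispatch.get(flag)
        return handler() if handler is not None else None
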
command = args[0].upper() - if len(args) >= 2 and f"{args[0]} {args[1]}".upper() in self.command_flags: + if len(args) >= 2 and f"{args[0]} {args[1]}".upper() in self._pipe.command_flags: command = f"{args[0]} {args[1]}".upper() nodes_flag = kwargs.pop("nodes_flag", None) @@ -2896,31 +2886,31 @@ def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]: command_flag = nodes_flag else: # get the nodes group for this command if it was predefined - command_flag = self.command_flags.get(command) - if command_flag == self._cluster.RANDOM: + command_flag = self._pipe.command_flags.get(command) + if command_flag == self._pipe.RANDOM: # return a random node - return [self._cluster.get_random_node()] - elif command_flag == self._cluster.PRIMARIES: + return [self._pipe.get_random_node()] + elif command_flag == self._pipe.PRIMARIES: # return all primaries - return self._cluster.get_primaries() - elif command_flag == self._cluster.REPLICAS: + return self._pipe.get_primaries() + elif command_flag == self._pipe.REPLICAS: # return all replicas - return self._cluster.get_replicas() - elif command_flag == self._cluster.ALL_NODES: + return self._pipe.get_replicas() + elif command_flag == self._pipe.ALL_NODES: # return all nodes - return self._cluster.get_nodes() - elif command_flag == self._cluster.DEFAULT_NODE: + return self._pipe.get_nodes() + elif command_flag == self._pipe.DEFAULT_NODE: # return the cluster's default node return [self._nodes_manager.default_node] - elif command in self._cluster.SEARCH_COMMANDS[0]: + elif command in self._pipe.SEARCH_COMMANDS[0]: return [self._nodes_manager.default_node] else: # get the node that holds the key's slot - slot = self._cluster.determine_slot(*args) + slot = self._pipe.determine_slot(*args) node = self._nodes_manager.get_node_from_slot( slot, - self._cluster.read_from_replicas and command in READ_COMMANDS, - self._cluster.load_balancing_strategy if command in READ_COMMANDS else None, + self._pipe.read_from_replicas and command in READ_COMMANDS, + self._pipe.load_balancing_strategy if command in READ_COMMANDS else None, ) return [node] @@ -2936,6 +2926,22 @@ def watch(self, *names): def unwatch(self, *names): raise RedisClusterException("method unwatch() is not implemented") + def delete(self, *names): + if len(names) != 1: + raise RedisClusterException( + "deleting multiple keys is not implemented in pipeline command" + ) + + return self.execute_command("DEL", names[0]) + + def unlink(self, *names): + if len(names) != 1: + raise RedisClusterException( + "unlinking multiple keys is not implemented in pipeline command" + ) + + return self.execute_command("UNLINK", names[0]) + class TransactionStrategy(AbstractStrategy): @@ -2943,17 +2949,8 @@ class TransactionStrategy(AbstractStrategy): IMMEDIATE_EXECUTE_COMMANDS = {"WATCH", "UNWATCH"} UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"} - def __init__( - self, - cluster: RedisCluster, - cluster_response_callbacks: Optional[Dict[str, Callable]] = None, - cluster_error_retry_attempts: int = 3, - ): - super().__init__( - cluster, - cluster_response_callbacks, - cluster_error_retry_attempts - ) + def __init__(self, pipe: ClusterPipeline): + super().__init__(pipe) self._explicit_transaction = False self._watching = False self._pipeline_slots: Set[int] = set() @@ -2979,7 +2976,7 @@ def _get_client_and_connection_for_transaction(self) -> Tuple[Redis, Connection] node: ClusterNode = self._nodes_manager.get_node_from_slot( list(self._pipeline_slots)[0], False ) - redis_node: Redis = 
self._cluster.get_redis_connection(node) + redis_node: Redis = self._pipe.get_redis_connection(node) if self._transaction_connection: if not redis_node.connection_pool.owns_connection( self._transaction_connection @@ -2998,7 +2995,7 @@ def _get_client_and_connection_for_transaction(self) -> Tuple[Redis, Connection] def execute_command(self, *args, **kwargs): slot_number: Optional[int] = None if args[0] not in ClusterPipeline.NO_SLOTS_COMMANDS: - slot_number = self._cluster.determine_slot(*args) + slot_number = self._pipe.determine_slot(*args) if ( self._watching or args[0] in self.IMMEDIATE_EXECUTE_COMMANDS @@ -3035,7 +3032,7 @@ def _validate_watch(self): def _immediate_execute_command(self, *args, **options): retry = Retry( default_backoff(), - self._cluster.cluster_error_retry_attempts, + self._pipe.cluster_error_retry_attempts, ) retry.update_supported_errors([AskError, MovedError]) return retry.call_with_retry( @@ -3080,8 +3077,8 @@ def _reinitialize_on_error(self, error): if self._transaction_connection: self._transaction_connection = None - self._cluster.reinitialize_counter += 1 - if self._cluster._should_reinitialized(): + self._pipe.reinitialize_counter += 1 + if self._pipe._should_reinitialized(): self._nodes_manager.initialize() self.reinitialize_counter = 0 else: @@ -3112,7 +3109,7 @@ def _execute_transaction_with_retries( ): retry = Retry( default_backoff(), - self._cluster.cluster_error_retry_attempts, + self._pipe.cluster_error_retry_attempts, ) retry.update_supported_errors([AskError, MovedError]) return retry.call_with_retry( @@ -3214,8 +3211,8 @@ def _execute_transaction( for r, cmd in zip(response, self._command_queue): if not isinstance(r, Exception): command_name = cmd.args[0] - if command_name in self._cluster.cluster_response_callbacks: - r = self._cluster_response_callbacks[command_name](r, **cmd.options) + if command_name in self._pipe.cluster_response_callbacks: + r = self._pipe.cluster_response_callbacks[command_name](r, **cmd.options) data.append(r) return data @@ -3281,4 +3278,10 @@ def discard(self): self.reset() return - raise RedisClusterException("DISCARD triggered without MULTI") \ No newline at end of file + raise RedisClusterException("DISCARD triggered without MULTI") + + def delete(self, *names): + return self.execute_command("DEL", *names) + + def unlink(self, *names): + return self.execute_command("UNLINK", *names) \ No newline at end of file diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 1f79ee9ef3..529b1efcf8 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -3095,7 +3095,7 @@ def test_multi_delete_supported_single_slot(self, r): """ Test that multi delete operation is unsupported """ - with r.pipeline(transaction=False) as pipe: + with r.pipeline(transaction=True) as pipe: r["{key}:a"] = 1 r["{key}:b"] = 2 pipe.delete("{key}:a", "{key}:b") @@ -3342,7 +3342,7 @@ def raise_error(target_node, *args, **kwargs): # 4 = 2 get_connections per execution * 2 executions assert get_connection.call_count == 4 - for cluster_node in r._nodes_manager.nodes_cache.values(): + for cluster_node in r.nodes_manager.nodes_cache.values(): connection_pool = cluster_node.redis_connection.connection_pool num_of_conns = len(connection_pool._available_connections) assert num_of_conns == connection_pool._created_connections From 0087b4f7d43d16cd96c3f9e5b8bb1de23016dd64 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 2 May 2025 15:31:08 +0300 Subject: [PATCH 6/7] Added test cases --- redis/cluster.py | 15 +- tests/test_cluster.py | 36 +- 
tests/test_cluster_transaction.py | 1010 +++++++++++++++++------------ tests/test_pipeline.py | 19 +- 4 files changed, 636 insertions(+), 444 deletions(-) diff --git a/redis/cluster.py b/redis/cluster.py index c9b3b0718d..fb1317e7c8 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -2116,6 +2116,7 @@ def __init__( ) if not transaction else TransactionStrategy( self ) + self.command_stack = self._execution_strategy.command_queue def __repr__(self): """ """ @@ -2137,7 +2138,7 @@ def __del__(self): def __len__(self): """ """ - return len(self._execution_strategy.command_queue) + return len(self.command_stack) def __bool__(self): "Pipeline instances should always evaluate to True on Python 3+" @@ -2147,7 +2148,7 @@ def execute_command(self, *args, **kwargs): """ Wrapper function for pipeline_execute_command """ - return self.pipeline_execute_command(*args, **kwargs) + return self._execution_strategy.execute_command(*args, **kwargs) def pipeline_execute_command(self, *args, **options): """ @@ -2161,8 +2162,7 @@ def pipeline_execute_command(self, *args, **options): At some other point, you can then run: pipe.execute(), which will execute all commands queued in the pipe. """ - self._execution_strategy.pipeline_execute_command(*args, **options) - return self + return self._execution_strategy.execute_command(*args, **options) def annotate_exception(self, exception, number, command): """ @@ -2557,6 +2557,7 @@ def pipeline_execute_command(self, *args, **options): self._command_queue.append( PipelineCommand(args, options, len(self._command_queue)) ) + return self._pipe @abstractmethod def execute(self, raise_on_error: bool = True) -> List[Any]: @@ -2605,7 +2606,7 @@ def __init__(self, pipe: ClusterPipeline): self.command_flags = pipe.command_flags def execute_command(self, *args, **kwargs): - self.pipeline_execute_command(*args, **kwargs) + return self.pipeline_execute_command(*args, **kwargs) def _raise_first_error(self, stack): """ @@ -3150,6 +3151,10 @@ def _execute_transaction( except ResponseError as e: self.annotate_exception(e, 0, "MULTI") errors.append(e) + except (ClusterDownError, ConnectionError) as cluster_error: + self._cluster_error = True + self.annotate_exception(cluster_error, 0, "MULTI") + raise # and all the other commands for i, command in enumerate(self._command_queue): diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 529b1efcf8..5788900c14 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -265,12 +265,12 @@ def moved_redirection_helper(request, failover=False): slot = 12182 redirect_node = None # Get the current primary that holds this slot - prev_primary = rc._nodes_manager.get_node_from_slot(slot) + prev_primary = rc.nodes_manager.get_node_from_slot(slot) if failover: - if len(rc._nodes_manager.slots_cache[slot]) < 2: + if len(rc.nodes_manager.slots_cache[slot]) < 2: warnings.warn("Skipping this test since it requires to have a replica") return - redirect_node = rc._nodes_manager.slots_cache[slot][1] + redirect_node = rc.nodes_manager.slots_cache[slot][1] else: # Use one of the primaries to be the redirected node redirect_node = rc.get_primaries()[0] @@ -290,7 +290,7 @@ def ok_response(connection, *args, **options): parse_response.side_effect = moved_redirect_effect assert rc.execute_command("SET", "foo", "bar") == "MOCK_OK" - slot_primary = rc._nodes_manager.slots_cache[slot][0] + slot_primary = rc.nodes_manager.slots_cache[slot][0] assert slot_primary == redirect_node if failover: assert rc.get_node(host=r_host, port=r_port).server_type 
== PRIMARY @@ -722,7 +722,7 @@ def test_all_nodes(self, r): """ Set a list of nodes and it should be possible to iterate over all """ - nodes = [node for node in r._nodes_manager.nodes_cache.values()] + nodes = [node for node in r.nodes_manager.nodes_cache.values()] for i, node in enumerate(r.get_nodes()): assert node in nodes @@ -734,7 +734,7 @@ def test_all_nodes_masters(self, r): """ nodes = [ node - for node in r._nodes_manager.nodes_cache.values() + for node in r.nodes_manager.nodes_cache.values() if node.server_type == PRIMARY ] @@ -805,7 +805,7 @@ def test_get_node_from_key(self, r): """ key = "bar" slot = r.keyslot(key) - slot_nodes = r._nodes_manager.slots_cache.get(slot) + slot_nodes = r.nodes_manager.slots_cache.get(slot) primary = slot_nodes[0] assert r.get_node_from_key(key, replica=False) == primary replica = r.get_node_from_key(key, replica=True) @@ -1003,8 +1003,8 @@ class TestClusterRedisCommands: def test_case_insensitive_command_names(self, r): assert ( - r._cluster_response_callbacks["cluster slots"] - == r._cluster_response_callbacks["CLUSTER SLOTS"] + r.cluster_response_callbacks["cluster slots"] + == r.cluster_response_callbacks["CLUSTER SLOTS"] ) def test_get_and_set(self, r): @@ -1275,7 +1275,7 @@ def test_cluster_addslotsrange(self, r): @skip_if_redis_enterprise() def test_cluster_countkeysinslot(self, r): - node = r._nodes_manager.get_node_from_slot(1) + node = r.nodes_manager.get_node_from_slot(1) mock_node_resp(node, 2) assert r.cluster_countkeysinslot(1) == 2 @@ -1450,7 +1450,7 @@ def test_cluster_save_config(self, r): @skip_if_redis_enterprise() def test_cluster_get_keys_in_slot(self, r): response = ["{foo}1", "{foo}2"] - node = r._nodes_manager.get_node_from_slot(12182) + node = r.nodes_manager.get_node_from_slot(12182) mock_node_resp(node, response) keys = r.cluster_get_keys_in_slot(12182, 4) assert keys == response @@ -1476,7 +1476,7 @@ def test_cluster_setslot(self, r): r.cluster_failover(node, "STATE") def test_cluster_setslot_stable(self, r): - node = r._nodes_manager.get_node_from_slot(12182) + node = r.nodes_manager.get_node_from_slot(12182) mock_node_resp(node, "OK") assert r.cluster_setslot_stable(12182) is True assert node.redis_connection.connection.read_response.called @@ -1562,7 +1562,7 @@ def test_info(self, r): r.set("z{1}", 3) # Get node that handles the slot slot = r.keyslot("x{1}") - node = r._nodes_manager.get_node_from_slot(slot) + node = r.nodes_manager.get_node_from_slot(slot) # Run info on that node info = r.info(target_nodes=node) assert isinstance(info, dict) @@ -1618,7 +1618,7 @@ def test_slowlog_get_limit(self, r, slowlog): def test_slowlog_length(self, r, slowlog): r.get("foo") - node = r._nodes_manager.get_node_from_slot(key_slot(b"foo")) + node = r.nodes_manager.get_node_from_slot(key_slot(b"foo")) slowlog_len = r.slowlog_len(target_nodes=node) assert isinstance(slowlog_len, int) @@ -1644,7 +1644,7 @@ def test_memory_stats(self, r): # put a key into the current db to make sure that "db." 
# has data r.set("foo", "bar") - node = r._nodes_manager.get_node_from_slot(key_slot(b"foo")) + node = r.nodes_manager.get_node_from_slot(key_slot(b"foo")) stats = r.memory_stats(target_nodes=node) assert isinstance(stats, dict) for key, value in stats.items(): @@ -2530,7 +2530,7 @@ class TestNodesManager: """ def test_load_balancer(self, r): - n_manager = r._nodes_manager + n_manager = r.nodes_manager lb = n_manager.read_load_balancer slot_1 = 1257 slot_2 = 8975 @@ -3428,7 +3428,7 @@ def test_readonly_pipeline_from_readonly_client(self, request): mock_all_nodes_resp(ro, "MOCK_OK") assert readonly_pipe.read_from_replicas is True assert readonly_pipe.get(key).get(key).execute() == ["MOCK_OK", "MOCK_OK"] - slot_nodes = ro._nodes_manager.slots_cache[ro.keyslot(key)] + slot_nodes = ro.nodes_manager.slots_cache[ro.keyslot(key)] if len(slot_nodes) > 1: executed_on_replica = False for node in slot_nodes: @@ -3468,7 +3468,7 @@ def test_readonly_pipeline_with_reading_from_replicas_strategies( mock_all_nodes_resp(ro, "MOCK_OK") assert readonly_pipe.load_balancing_strategy == load_balancing_strategy assert readonly_pipe.get(key).get(key).execute() == ["MOCK_OK", "MOCK_OK"] - slot_nodes = ro._nodes_manager.slots_cache[ro.keyslot(key)] + slot_nodes = ro.nodes_manager.slots_cache[ro.keyslot(key)] executed_on_replicas_only = True for node in slot_nodes: if node.server_type == PRIMARY: diff --git a/tests/test_cluster_transaction.py b/tests/test_cluster_transaction.py index 61134cf9ec..c416dbd046 100644 --- a/tests/test_cluster_transaction.py +++ b/tests/test_cluster_transaction.py @@ -1,11 +1,15 @@ +import threading from typing import Tuple -from unittest.mock import patch +from unittest.mock import patch, Mock import pytest import redis +from redis import CrossSlotTransactionError, ConnectionPool +from redis.backoff import NoBackoff from redis.client import Redis from redis.cluster import PRIMARY, ClusterNode, NodesManager, RedisCluster +from redis.retry import Retry from .conftest import skip_if_server_version_lt, wait_for_command @@ -32,382 +36,51 @@ def _find_source_and_target_node_for_slot( class TestClusterTransaction: + @pytest.mark.onlycluster - def test_pipeline_is_true(self, r): - "Ensure pipeline instances are not false-y" - with r.pipeline(transaction=True) as pipe: - assert pipe + def test_executes_transaction_against_cluster(self, r): + with r.pipeline(transaction=True) as tx: + tx.set("{foo}bar", "value1") + tx.set("{foo}baz", "value2") + tx.set("{foo}bad", "value3") + tx.get("{foo}bar") + tx.get("{foo}baz") + tx.get("{foo}bad") + assert tx.execute() == [b"OK", b"OK", b"OK", b"value1", b"value2", b"value3"] + + r.flushall() + + tx = r.pipeline(transaction=True) + tx.set("{foo}bar", "value1") + tx.set("{foo}baz", "value2") + tx.set("{foo}bad", "value3") + tx.get("{foo}bar") + tx.get("{foo}baz") + tx.get("{foo}bad") + assert tx.execute() == [b"OK", b"OK", b"OK", b"value1", b"value2", b"value3"] @pytest.mark.onlycluster - def test_pipeline_no_transaction_watch(self, r): - r["a"] = 0 + def test_throws_exception_on_different_hash_slots(self, r): + with r.pipeline(transaction=True) as tx: + tx.set("{foo}bar", "value1") + tx.set("{foobar}baz", "value2") - with r.pipeline(transaction=False) as pipe: - pipe.watch("a") - a = pipe.get("a") - pipe.multi() - pipe.set("a", int(a) + 1) - assert pipe.execute() == [b"OK"] + with pytest.raises( + CrossSlotTransactionError, + match="All keys involved in a cluster transaction must map to the same slot" + ): + tx.execute() @pytest.mark.onlycluster - def 
test_pipeline_no_transaction_watch_failure(self, r): + def test_transaction_with_watched_keys(self, r): r["a"] = 0 - with r.pipeline(transaction=False) as pipe: + with r.pipeline(transaction=True) as pipe: pipe.watch("a") a = pipe.get("a") - - r["a"] = "bad" - pipe.multi() pipe.set("a", int(a) + 1) - - with pytest.raises(redis.WatchError): - pipe.execute() - - assert r["a"] == b"bad" - - @pytest.mark.onlycluster - def test_pipeline_empty_transaction(self, r): - r["a"] = 0 - - with r.pipeline(transaction=True) as pipe: - assert pipe.execute() == [] - - @pytest.mark.onlycluster - def test_exec_error_in_response(self, r): - """ - an invalid pipeline command at exec time adds the exception instance - to the list of returned values - """ - hashkey = "{key}" - r[f"{hashkey}:c"] = "a" - with r.pipeline() as pipe: - pipe.set(f"{hashkey}:a", 1).set(f"{hashkey}:b", 2) - pipe.lpush(f"{hashkey}:c", 3).set(f"{hashkey}:d", 4) - result = pipe.execute(raise_on_error=False) - - assert result[0] - assert r[f"{hashkey}:a"] == b"1" - assert result[1] - assert r[f"{hashkey}:b"] == b"2" - - # we can't lpush to a key that's a string value, so this should - # be a ResponseError exception - assert isinstance(result[2], redis.ResponseError) - assert r[f"{hashkey}:c"] == b"a" - - # since this isn't a transaction, the other commands after the - # error are still executed - assert result[3] - assert r[f"{hashkey}:d"] == b"4" - - # make sure the pipe was restored to a working state - assert pipe.set(f"{hashkey}:z", "zzz").execute() == [True] - assert r[f"{hashkey}:z"] == b"zzz" - - @pytest.mark.onlycluster - def test_exec_error_raised(self, r): - hashkey = "{key}" - r[f"{hashkey}:c"] = "a" - with r.pipeline(transaction=True) as pipe: - pipe.set(f"{hashkey}:a", 1).set(f"{hashkey}:b", 2) - pipe.lpush(f"{hashkey}:c", 3).set(f"{hashkey}:d", 4) - with pytest.raises(redis.ResponseError) as ex: - pipe.execute() - assert str(ex.value).startswith( - "Command # 3 (LPUSH {key}:c 3) of pipeline caused error: " - ) - - # make sure the pipe was restored to a working state - assert pipe.set(f"{hashkey}:z", "zzz").execute() == [b"OK"] - assert r[f"{hashkey}:z"] == b"zzz" - - @pytest.mark.onlycluster - def test_parse_error_raised(self, r): - hashkey = "{key}" - with r.pipeline(transaction=True) as pipe: - # the zrem is invalid because we don't pass any keys to it - pipe.set(f"{hashkey}:a", 1).zrem(f"{hashkey}:b").set(f"{hashkey}:b", 2) - with pytest.raises(redis.ResponseError) as ex: - pipe.execute() - - assert str(ex.value).startswith( - "Command # 2 (ZREM {key}:b) of pipeline caused error: wrong number" - ) - - # make sure the pipe was restored to a working state - assert pipe.set(f"{hashkey}:z", "zzz").execute() == [b"OK"] - assert r[f"{hashkey}:z"] == b"zzz" - - @pytest.mark.onlycluster - def test_parse_error_raised_transaction(self, r): - hashkey = "{key}" - with r.pipeline() as pipe: - pipe.multi() - # the zrem is invalid because we don't pass any keys to it - pipe.set(f"{hashkey}:a", 1).zrem(f"{hashkey}:b").set(f"{hashkey}:b", 2) - with pytest.raises(redis.ResponseError) as ex: - pipe.execute() - - assert str(ex.value).startswith( - "Command # 2 (ZREM {key}:b) of pipeline caused error: " - ) - - # make sure the pipe was restored to a working state - assert pipe.set(f"{hashkey}:z", "zzz").execute() == [True] - assert r[f"{hashkey}:z"] == b"zzz" - - @pytest.mark.onlycluster - def test_parse_error_raised_invalid_response_length_transaction(self, r): - hashkey = "{key}" - with r.pipeline() as pipe: - pipe.multi() - 
pipe.set(f"{hashkey}:a", 1).set(f"{hashkey}:b", 1) - with patch("redis.client.Redis.parse_response") as parse_response_mock: - parse_response_mock.return_value = ["OK"] - with pytest.raises(redis.InvalidPipelineStack) as ex: - pipe.execute() - - assert str(ex.value).startswith( - "Unexpected response length for cluster pipeline EXEC" - ) - - # make sure the pipe was restored to a working state - assert pipe.set(f"{hashkey}:z", "zzz").execute() == [True] - assert r[f"{hashkey}:z"] == b"zzz" - - @pytest.mark.onlycluster - def test_watch_succeed(self, r): - hashkey = "{key}" - r[f"{hashkey}:a"] = 1 - r[f"{hashkey}:b"] = 2 - - with r.pipeline() as pipe: - pipe.watch(f"{hashkey}:a", f"{hashkey}:b") - assert pipe._watching - a_value = pipe.get(f"{hashkey}:a") - b_value = pipe.get(f"{hashkey}:b") - assert a_value == b"1" - assert b_value == b"2" - pipe.multi() - - pipe.set(f"{hashkey}:c", 3) assert pipe.execute() == [b"OK"] - assert not pipe._watching - - @pytest.mark.onlycluster - def test_watch_failure(self, r): - hashkey = "{key}" - r[f"{hashkey}:a"] = 1 - r[f"{hashkey}:b"] = 2 - - with r.pipeline() as pipe: - pipe.watch(f"{hashkey}:a", f"{hashkey}:b") - r[f"{hashkey}:b"] = 3 - pipe.multi() - pipe.get(f"{hashkey}:a") - with pytest.raises(redis.WatchError): - pipe.execute() - - assert not pipe._watching - - @pytest.mark.onlycluster - def test_cross_slot_watch_single_call_failure(self, r): - with r.pipeline() as pipe: - with pytest.raises(redis.RedisClusterException) as ex: - pipe.watch("a", "b") - - assert str(ex.value).startswith( - "WATCH - all keys must map to the same key slot" - ) - - assert not pipe._watching - - @pytest.mark.onlycluster - def test_cross_slot_watch_multiple_calls_failure(self, r): - with r.pipeline() as pipe: - with pytest.raises(redis.CrossSlotTransactionError) as ex: - pipe.watch("a") - pipe.watch("b") - - assert str(ex.value).startswith( - "Cannot watch or send commands on different slots" - ) - - assert pipe._watching - - @pytest.mark.onlycluster - def test_watch_failure_in_empty_transaction(self, r): - hashkey = "{key}" - r[f"{hashkey}:a"] = 1 - r[f"{hashkey}:b"] = 2 - - with r.pipeline() as pipe: - pipe.watch(f"{hashkey}:a", f"{hashkey}:b") - r[f"{hashkey}:b"] = 3 - pipe.multi() - with pytest.raises(redis.WatchError): - pipe.execute() - - assert not pipe._watching - - @pytest.mark.onlycluster - def test_unwatch(self, r): - hashkey = "{key}" - r[f"{hashkey}:a"] = 1 - r[f"{hashkey}:b"] = 2 - - with r.pipeline() as pipe: - pipe.watch(f"{hashkey}:a", f"{hashkey}:b") - r[f"{hashkey}:b"] = 3 - pipe.unwatch() - assert not pipe._watching - pipe.get(f"{hashkey}:a") - assert pipe.execute() == [b"1"] - - @pytest.mark.onlycluster - def test_watch_exec_auto_unwatch(self, r): - hashkey = "{key}" - r[f"{hashkey}:a"] = 1 - r[f"{hashkey}:b"] = 2 - - target_slot = r.determine_slot("GET", f"{hashkey}:a") - target_node = r._nodes_manager.get_node_from_slot(target_slot) - with r.monitor(target_node=target_node) as m: - with r.pipeline() as pipe: - pipe.watch(f"{hashkey}:a", f"{hashkey}:b") - assert pipe._watching - a_value = pipe.get(f"{hashkey}:a") - b_value = pipe.get(f"{hashkey}:b") - assert a_value == b"1" - assert b_value == b"2" - pipe.multi() - pipe.set(f"{hashkey}:c", 3) - assert pipe.execute() == [b"OK"] - assert not pipe._watching - - unwatch_command = wait_for_command( - r, m, "UNWATCH", key=f"{hashkey}:test_watch_exec_auto_unwatch" - ) - assert unwatch_command is not None, ( - "execute should reset and send UNWATCH automatically" - ) - - @pytest.mark.onlycluster - def 
test_watch_reset_unwatch(self, r): - hashkey = "{key}" - r[f"{hashkey}:a"] = 1 - - target_slot = r.determine_slot("GET", f"{hashkey}:a") - target_node = r._nodes_manager.get_node_from_slot(target_slot) - with r.monitor(target_node=target_node) as m: - with r.pipeline() as pipe: - pipe.watch(f"{hashkey}:a") - assert pipe._watching - pipe.reset() - assert not pipe._watching - - unwatch_command = wait_for_command( - r, m, "UNWATCH", key=f"{hashkey}:test_watch_reset_unwatch" - ) - assert unwatch_command is not None - assert unwatch_command["command"] == "UNWATCH" - - @pytest.mark.onlycluster - def test_transaction_callable(self, r): - hashkey = "{key}" - r[f"{hashkey}:a"] = 1 - r[f"{hashkey}:b"] = 2 - has_run = [] - - def my_transaction(pipe): - a_value = pipe.get(f"{hashkey}:a") - assert a_value in (b"1", b"2") - b_value = pipe.get(f"{hashkey}:b") - assert b_value == b"2" - - # silly run-once code... incr's "a" so WatchError should be raised - # forcing this all to run again. this should incr "a" once to "2" - if not has_run: - r.incr(f"{hashkey}:a") - has_run.append("it has") - - pipe.multi() - pipe.set(f"{hashkey}:c", int(a_value) + int(b_value)) - - result = r.transaction(my_transaction, f"{hashkey}:a", f"{hashkey}:b") - assert result == [b"OK"] - assert r[f"{hashkey}:c"] == b"4" - - def test_exec_error_in_no_transaction_pipeline(self, r): - r["a"] = 1 - with r.pipeline(transaction=False) as pipe: - pipe.llen("a") - pipe.expire("a", 100) - - with pytest.raises(redis.ResponseError) as ex: - pipe.execute() - - assert str(ex.value).startswith( - "Command # 1 (LLEN a) of pipeline caused error: " - ) - - assert r["a"] == b"1" - - @pytest.mark.onlycluster - @skip_if_server_version_lt("2.0.0") - def test_pipeline_discard(self, r): - hashkey = "{key}" - - # empty pipeline should raise an error - with r.pipeline() as pipe: - pipe.set(f"{hashkey}:key", "someval") - with pytest.raises(redis.exceptions.RedisClusterException) as ex: - pipe.discard() - - assert str(ex.value).startswith("DISCARD triggered without MULTI") - - # setting a pipeline and discarding should do the same - with r.pipeline() as pipe: - pipe.set(f"{hashkey}:key", "someval") - pipe.set(f"{hashkey}:someotherkey", "val") - response = pipe.execute() - pipe.set(f"{hashkey}:key", "another value!") - with pytest.raises(redis.exceptions.RedisClusterException) as ex: - pipe.discard() - - assert str(ex.value).startswith("DISCARD triggered without MULTI") - - pipe.set(f"{hashkey}:foo", "bar") - response = pipe.execute() - - assert response[0] - assert r.get(f"{hashkey}:foo") == b"bar" - - @pytest.mark.onlycluster - @skip_if_server_version_lt("2.0.0") - def test_transaction_discard(self, r): - hashkey = "{key}" - - # pipelines enabled as transactions can be discarded at any point - with r.pipeline(transaction=True) as pipe: - pipe.watch(f"{hashkey}:key") - pipe.set(f"{hashkey}:key", "someval") - pipe.discard() - - assert not pipe._watching - assert not pipe.command_stack - - # pipelines with multi can be discarded - with r.pipeline() as pipe: - pipe.watch(f"{hashkey}:key") - pipe.multi() - pipe.set(f"{hashkey}:key", "someval") - pipe.discard() - - assert not pipe._watching - assert not pipe.command_stack @pytest.mark.onlycluster def test_retry_transaction_during_unfinished_slot_migration(self, r): @@ -438,7 +111,6 @@ def ask_redirect_effect(connection, *args, **options): parse_response.side_effect = ask_redirect_effect with r.pipeline(transaction=True) as pipe: - pipe.multi() pipe.set(key, "val") with pytest.raises(redis.exceptions.AskError) 
as ex: pipe.execute() @@ -485,11 +157,11 @@ def ask_redirect_effect(conn, *args, **options): assert False, f"unexpected node {conn.host}:{conn.port} was called" def update_moved_slot(): # simulate slot table update - ask_error = r._nodes_manager._moved_exception + ask_error = r.nodes_manager._moved_exception assert ask_error is not None, "No AskError was previously triggered" assert f"{ask_error.host}:{ask_error.port}" == node_importing.name - r._nodes_manager._moved_exception = None - r._nodes_manager.slots_cache[slot] = [node_importing] + r.nodes_manager._moved_exception = None + r.nodes_manager.slots_cache[slot] = [node_importing] parse_response.side_effect = ask_redirect_effect manager_update_moved_slots.side_effect = update_moved_slot @@ -502,49 +174,6 @@ def update_moved_slot(): # simulate slot table update assert result and "MOCK_OK" in result, "Target node was not called" - @pytest.mark.onlycluster - def test_retry_transaction_with_watch_during_slot_migration(self, r): - """ - If a MovedError or AskError appears when calling EXEC and keys were - being watched before the migration started, a WatchError should appear. - These errors imply resetting the connection and connecting to a new node, - so watches are lost anyway and the client code must be notified. - """ - key = "book" - slot = r.keyslot(key) - node_migrating, node_importing = _find_source_and_target_node_for_slot(r, slot) - - with patch.object(Redis, "parse_response") as parse_response: - - def ask_redirect_effect(conn, *args, **options): - if f"{conn.host}:{conn.port}" == node_migrating.name: - # we simulate the watch was sent before the migration started - if "WATCH" in args: - return b"OK" - # but the pipeline was triggered after the migration started - elif "MULTI" in args: - return - elif "EXEC" in args: - raise redis.exceptions.ExecAbortError() - - raise redis.exceptions.AskError(f"{slot} {node_importing.name}") - # we should not try to connect to any other node - else: - assert False, f"unexpected node {conn.host}:{conn.port} was called" - - parse_response.side_effect = ask_redirect_effect - - with r.pipeline(transaction=True) as pipe: - pipe.watch(key) - pipe.multi() - pipe.set(key, "val") - with pytest.raises(redis.exceptions.WatchError) as ex: - pipe.execute() - - assert str(ex.value).startswith( - "Slot rebalancing ocurred while watching keys" - ) - @pytest.mark.onlycluster def test_retry_transaction_with_watch_after_slot_migration(self, r): """ @@ -559,10 +188,567 @@ def test_retry_transaction_with_watch_after_slot_migration(self, r): # force a MovedError on the first call to pipe.watch() # by switching the node that owns the slot to another one _node_migrating, node_importing = _find_source_and_target_node_for_slot(r, slot) - r._nodes_manager.slots_cache[slot] = [node_importing] + r.nodes_manager.slots_cache[slot] = [node_importing] with r.pipeline(transaction=True) as pipe: pipe.watch(key) pipe.multi() pipe.set(key, "val") pipe.execute() + + @pytest.mark.onlycluster + def test_retry_transaction_on_connection_error(self, r, mock_connection): + key = "book" + slot = r.keyslot(key) + + mock_connection.read_response.side_effect = redis.exceptions.ConnectionError("Conn error") + mock_connection.retry = Retry(NoBackoff(), 0) + mock_pool = Mock(spec=ConnectionPool) + mock_pool.get_connection.return_value = mock_connection + mock_pool._available_connections = [mock_connection] + mock_pool._lock = threading.Lock() + + _node_migrating, node_importing = _find_source_and_target_node_for_slot(r, slot) + 
node_importing.redis_connection.connection_pool = mock_pool
+        r.nodes_manager.slots_cache[slot] = [node_importing]
+        r.reinitialize_steps = 1
+
+        with r.pipeline(transaction=True) as pipe:
+            pipe.set(key, "val")
+            pipe.execute()
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
index 2e4b5b2cd5..bbf1ec9eb5 100644
--- a/tests/test_pipeline.py
+++ b/tests/test_pipeline.py
@@ -198,7 +198,7 @@ def test_watch_succeed(self, r):
 
         with r.pipeline() as pipe:
             pipe.watch("a", "b")
-            assert pipe._watching
+            assert pipe.watching
             a_value = pipe.get("a")
             b_value = pipe.get("b")
             assert a_value == b"1"
@@ -207,7 +207,7 @@ def test_watch_succeed(self, r):
             pipe.set("c", 3)
             assert pipe.execute() == [True]
 
-        assert not pipe._watching
+        assert not pipe.watching
 
     @pytest.mark.onlynoncluster
     def test_watch_failure(self, r):
@@ -222,7 +222,7 @@ def test_watch_failure(self, r):
             with pytest.raises(redis.WatchError):
                 pipe.execute()
 
-            assert not pipe._watching
+            assert not pipe.watching
 
     @pytest.mark.onlynoncluster
     def test_watch_failure_in_empty_transaction(self, r):
@@ -236,7 +236,7 @@ def test_watch_failure_in_empty_transaction(self, r):
             with pytest.raises(redis.WatchError):
                 pipe.execute()
 
-            assert not pipe._watching
+            assert not pipe.watching
 
     @pytest.mark.onlynoncluster
     def test_unwatch(self, r):
@@ -247,7 +247,7 @@ def test_unwatch(self, r):
             pipe.watch("a", "b")
             r["b"] = 3
             pipe.unwatch()
-            assert not pipe._watching
+            assert not pipe.watching
             pipe.get("a")
             assert pipe.execute() == [b"1"]
 
@@ -259,7 +259,7 @@ def test_watch_exec_no_unwatch(self, r):
         with r.monitor() as m:
             with r.pipeline() as pipe:
                 pipe.watch("a", "b")
-                assert pipe._watching
+                assert pipe.watching
                 a_value = pipe.get("a")
                 b_value = pipe.get("b")
                 assert a_value == b"1"
@@ -267,7 +267,7 @@ def test_watch_exec_no_unwatch(self, r):
                 pipe.multi()
                 pipe.set("c", 3)
                 assert pipe.execute() == [True]
-                assert 
not pipe._watching + assert not pipe.watching unwatch_command = wait_for_command(r, m, "UNWATCH") assert unwatch_command is None, "should not send UNWATCH" @@ -279,9 +279,9 @@ def test_watch_reset_unwatch(self, r): with r.monitor() as m: with r.pipeline() as pipe: pipe.watch("a") - assert pipe._watching + assert pipe.watching pipe.reset() - assert not pipe._watching + assert not pipe.watching unwatch_command = wait_for_command(r, m, "UNWATCH") assert unwatch_command is not None @@ -330,6 +330,7 @@ def my_transaction(pipe): assert result == [True] assert r["c"] == b"4" + @pytest.mark.onlynoncluster def test_transaction_callable_returns_value_from_callable(self, r): def callback(pipe): # No need to do anything here since we only want the return value From 137c93123abfefe9a288f3aa394bae0331eff7a5 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 2 May 2025 15:52:07 +0300 Subject: [PATCH 7/7] Sync with master --- redis/cluster.py | 73 +++++++++++++++---------------- tests/conftest.py | 2 +- tests/test_cluster_transaction.py | 19 +++++--- 3 files changed, 50 insertions(+), 44 deletions(-) diff --git a/redis/cluster.py b/redis/cluster.py index fecab71fd1..c435d08745 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -5,6 +5,7 @@ import time from abc import ABC, abstractmethod from collections import OrderedDict +from copy import copy from enum import Enum from itertools import chain from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union @@ -2166,7 +2167,7 @@ def __init__( else: self.retry = Retry( backoff=ExponentialWithJitterBackoff(base=1, cap=10), - retries=self.cluster_error_retry_attempts, + retries=cluster_error_retry_attempts, ) self.encoder = Encoder( @@ -2178,10 +2179,8 @@ def __init__( lock = threading.Lock() self._lock = lock self.parent_execute_command = super().execute_command - self._execution_strategy: ExecutionStrategy = PipelineStrategy( - self - ) if not transaction else TransactionStrategy( - self + self._execution_strategy: ExecutionStrategy = ( + PipelineStrategy(self) if not transaction else TransactionStrategy(self) ) self.command_stack = self._execution_strategy.command_queue @@ -2477,7 +2476,6 @@ def read(self): class ExecutionStrategy(ABC): - @property @abstractmethod def command_queue(self): @@ -2520,7 +2518,9 @@ def execute(self, raise_on_error: bool = True) -> List[Any]: pass @abstractmethod - def send_cluster_commands(self, stack, raise_on_error=True, allow_redirections=True): + def send_cluster_commands( + self, stack, raise_on_error=True, allow_redirections=True + ): """ Sends commands according to current execution strategy. 
@@ -2599,10 +2599,9 @@ def discard(self): class AbstractStrategy(ExecutionStrategy): - def __init__( - self, - pipe: ClusterPipeline, + self, + pipe: ClusterPipeline, ): self._command_queue: List[PipelineCommand] = [] self._pipe = pipe @@ -2631,7 +2630,9 @@ def execute(self, raise_on_error: bool = True) -> List[Any]: pass @abstractmethod - def send_cluster_commands(self, stack, raise_on_error=True, allow_redirections=True): + def send_cluster_commands( + self, stack, raise_on_error=True, allow_redirections=True + ): pass @abstractmethod @@ -2666,8 +2667,8 @@ def annotate_exception(self, exception, number, command): ) exception.args = (msg,) + exception.args[1:] -class PipelineStrategy(AbstractStrategy): +class PipelineStrategy(AbstractStrategy): def __init__(self, pipe: ClusterPipeline): super().__init__(pipe) self.command_flags = pipe.command_flags @@ -2702,10 +2703,7 @@ def reset(self): self._command_queue = [] def send_cluster_commands( - self, - stack, - raise_on_error=True, - allow_redirections=True + self, stack, raise_on_error=True, allow_redirections=True ): """ Wrapper for CLUSTERDOWN error handling. @@ -2724,7 +2722,7 @@ def send_cluster_commands( """ if not stack: return [] - retry_attempts = self._pipe.cluster_error_retry_attempts + retry_attempts = self._pipe.retry.get_retries() while True: try: return self._send_cluster_commands( @@ -2742,10 +2740,7 @@ def send_cluster_commands( raise e def _send_cluster_commands( - self, - stack, - raise_on_error=True, - allow_redirections=True + self, stack, raise_on_error=True, allow_redirections=True ): """ Send a bunch of cluster commands to the redis cluster. @@ -2945,7 +2940,10 @@ def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]: # Determine which nodes should be executed the command on. # Returns a list of target nodes. 
command = args[0].upper() - if len(args) >= 2 and f"{args[0]} {args[1]}".upper() in self._pipe.command_flags: + if ( + len(args) >= 2 + and f"{args[0]} {args[1]}".upper() in self._pipe.command_flags + ): command = f"{args[0]} {args[1]}".upper() nodes_flag = kwargs.pop("nodes_flag", None) @@ -2978,7 +2976,9 @@ def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]: node = self._nodes_manager.get_node_from_slot( slot, self._pipe.read_from_replicas and command in READ_COMMANDS, - self._pipe.load_balancing_strategy if command in READ_COMMANDS else None, + self._pipe.load_balancing_strategy + if command in READ_COMMANDS + else None, ) return [node] @@ -3012,7 +3012,6 @@ def unlink(self, *names): class TransactionStrategy(AbstractStrategy): - NO_SLOTS_COMMANDS = {"UNWATCH"} IMMEDIATE_EXECUTE_COMMANDS = {"WATCH", "UNWATCH"} UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"} @@ -3066,7 +3065,7 @@ def execute_command(self, *args, **kwargs): slot_number = self._pipe.determine_slot(*args) if ( - self._watching or args[0] in self.IMMEDIATE_EXECUTE_COMMANDS + self._watching or args[0] in self.IMMEDIATE_EXECUTE_COMMANDS ) and not self._explicit_transaction: if args[0] == "WATCH": self._validate_watch() @@ -3098,10 +3097,7 @@ def _validate_watch(self): self._watching = True def _immediate_execute_command(self, *args, **options): - retry = Retry( - default_backoff(), - self._pipe.cluster_error_retry_attempts, - ) + retry = copy(self._pipe.retry) retry.update_supported_errors([AskError, MovedError]) return retry.call_with_retry( lambda: self._get_connection_and_send_command(*args, **options), @@ -3175,10 +3171,7 @@ def execute(self, raise_on_error: bool = True) -> List[Any]: def _execute_transaction_with_retries( self, stack: List["PipelineCommand"], raise_on_error: bool ): - retry = Retry( - default_backoff(), - self._pipe.cluster_error_retry_attempts, - ) + retry = copy(self._pipe.retry) retry.update_supported_errors([AskError, MovedError]) return retry.call_with_retry( lambda: self._execute_transaction(stack, raise_on_error), @@ -3284,7 +3277,9 @@ def _execute_transaction( if not isinstance(r, Exception): command_name = cmd.args[0] if command_name in self._pipe.cluster_response_callbacks: - r = self._pipe.cluster_response_callbacks[command_name](r, **cmd.options) + r = self._pipe.cluster_response_callbacks[command_name]( + r, **cmd.options + ) data.append(r) return data @@ -3321,8 +3316,12 @@ def reset(self): self._cluster_error = False self._executing = False - def send_cluster_commands(self, stack, raise_on_error=True, allow_redirections=True): - raise NotImplementedError("send_cluster_commands cannot be executed in transactional context.") + def send_cluster_commands( + self, stack, raise_on_error=True, allow_redirections=True + ): + raise NotImplementedError( + "send_cluster_commands cannot be executed in transactional context." 
+        )
 
     def multi(self):
         if self._explicit_transaction:
@@ -3356,4 +3355,4 @@ def delete(self, *names):
         return self.execute_command("DEL", *names)
 
     def unlink(self, *names):
-        return self.execute_command("UNLINK", *names)
\ No newline at end of file
+        return self.execute_command("UNLINK", *names)
diff --git a/tests/conftest.py b/tests/conftest.py
index 7eaccb1acb..9c174974ef 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -30,7 +30,7 @@ from tests.ssl_utils import get_tls_certificates
 
 REDIS_INFO = {}
-default_redis_url = "redis://localhost:6379/0"
+default_redis_url = "redis://localhost:16379/0"
 default_protocol = "2"
 
 default_redismod_url = "redis://localhost:6479"
diff --git a/tests/test_cluster_transaction.py b/tests/test_cluster_transaction.py
index c416dbd046..9de9f20482 100644
--- a/tests/test_cluster_transaction.py
+++ b/tests/test_cluster_transaction.py
@@ -36,7 +36,6 @@ def _find_source_and_target_node_for_slot(
 
 
 class TestClusterTransaction:
-
     @pytest.mark.onlycluster
     def test_executes_transaction_against_cluster(self, r):
@@ -46,7 +45,14 @@ def test_executes_transaction_against_cluster(self, r):
             tx.get("{foo}bar")
             tx.get("{foo}baz")
             tx.get("{foo}bad")
-            assert tx.execute() == [b"OK", b"OK", b"OK", b"value1", b"value2", b"value3"]
+            assert tx.execute() == [
+                b"OK",
+                b"OK",
+                b"OK",
+                b"value1",
+                b"value2",
+                b"value3",
+            ]
 
         r.flushall()
@@ -66,8 +72,8 @@ def test_throws_exception_on_different_hash_slots(self, r):
             tx.set("{foobar}baz", "value2")
 
             with pytest.raises(
-                CrossSlotTransactionError,
-                match="All keys involved in a cluster transaction must map to the same slot"
+                CrossSlotTransactionError,
+                match="All keys involved in a cluster transaction must map to the same slot",
             ):
                 tx.execute()
@@ -201,7 +207,9 @@ def test_retry_transaction_on_connection_error(self, r, mock_connection):
         key = "book"
         slot = r.keyslot(key)
 
-        mock_connection.read_response.side_effect = redis.exceptions.ConnectionError("Conn error")
+        mock_connection.read_response.side_effect = redis.exceptions.ConnectionError(
+            "Conn error"
+        )
         mock_connection.retry = Retry(NoBackoff(), 0)
         mock_pool = Mock(spec=ConnectionPool)
         mock_pool.get_connection.return_value = mock_connection
         mock_pool._available_connections = [mock_connection]
         mock_pool._lock = threading.Lock()
@@ -217,7 +225,6 @@ def test_retry_transaction_on_connection_error(self, r, mock_connection):
             pipe.set(key, "val")
             pipe.execute()