diff --git a/redis/_parsers/helpers.py b/redis/_parsers/helpers.py index 5468addf62..e1ef122355 100644 --- a/redis/_parsers/helpers.py +++ b/redis/_parsers/helpers.py @@ -701,7 +701,7 @@ def string_keys_to_dict(key_string, callback): _RedisCallbacks = { **string_keys_to_dict( "AUTH COPY EXPIRE EXPIREAT HEXISTS HMSET MOVE MSETNX PERSIST PSETEX " - "PEXPIRE PEXPIREAT RENAMENX SETEX SETNX SMOVE", + "PEXPIRE PEXPIREAT RENAMENX SETEX SETNX SMOVE HSETNX SISMEMBER", bool, ), **string_keys_to_dict("HINCRBYFLOAT INCRBYFLOAT", float), @@ -777,6 +777,7 @@ def string_keys_to_dict(key_string, callback): "SENTINEL SET": bool_ok, "SLOWLOG GET": parse_slowlog_get, "SLOWLOG RESET": bool_ok, + "SMISMEMBER": lambda r: list(map(bool, r)), "SORT": sort_return_tuples, "SSCAN": parse_scan, "TIME": lambda x: (int(x[0]), int(x[1])), @@ -830,6 +831,7 @@ def string_keys_to_dict(key_string, callback): "SENTINEL MASTERS": parse_sentinel_masters, "SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels, "SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels, + "SMISMEMBER": lambda r: list(map(bool, r)), "STRALGO": parse_stralgo, "XINFO CONSUMERS": parse_list_of_dicts, "XINFO GROUPS": parse_list_of_dicts, @@ -868,6 +870,7 @@ def string_keys_to_dict(key_string, callback): "SENTINEL MASTERS": parse_sentinel_masters_resp3, "SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels_resp3, "SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels_resp3, + "SMISMEMBER": lambda r: list(map(bool, r)), "STRALGO": lambda r, **options: ( {str_if_bytes(key): str_if_bytes(value) for key, value in r.items()} if isinstance(r, dict) diff --git a/redis/commands/core.py b/redis/commands/core.py index a8c327f08f..2add022268 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -11,6 +11,7 @@ Awaitable, Callable, Dict, + Generic, Iterable, Iterator, List, @@ -28,6 +29,7 @@ AbsExpiryT, AnyKeyT, BitfieldOffsetT, + BooleanType, ChannelT, CommandsProtocol, ConsumerT, @@ -35,13 +37,22 @@ ExpiryT, FieldT, GroupT, + IntegerType, KeysT, KeyT, Number, + OptionalStringListType, + OptionalStringType, PatternT, ResponseT, + ResponseTypeBoolean, + ResponseTypeInteger, + ResponseTypeOptionalString, + ResponseTypeOptionalStringList, + ResponseTypeStringList, ScriptTextT, StreamIdT, + StringListType, TimeoutSecT, ZScoreBoundT, ) @@ -1442,7 +1453,7 @@ async def shutdown( raise RedisError("SHUTDOWN seems to have failed.") -class BitFieldOperation: +class BitFieldOperation(Generic[ResponseTypeInteger]): """ Command builder for BITFIELD commands. """ @@ -1542,7 +1553,7 @@ def command(self): cmd.extend(ops) return cmd - def execute(self) -> ResponseT: + def execute(self) -> ResponseTypeInteger: """ Execute the operation(s) in a single BITFIELD command. The return value is a list of values corresponding to each operation. If the client @@ -1554,12 +1565,19 @@ def execute(self) -> ResponseT: return self.client.execute_command(*command) -class BasicKeyCommands(CommandsProtocol): +class BasicKeyCommands( + CommandsProtocol, + Generic[ + ResponseTypeBoolean, + ResponseTypeInteger, + ResponseTypeOptionalString, + ], +): """ Redis basic key-based commands """ - def append(self, key: KeyT, value: EncodableT) -> ResponseT: + def append(self, key: KeyT, value: EncodableT) -> ResponseTypeInteger: """ Appends the string ``value`` to the value at ``key``. If ``key`` doesn't already exist, create it with a value of ``value``. @@ -1575,7 +1593,7 @@ def bitcount( start: Union[int, None] = None, end: Union[int, None] = None, mode: Optional[str] = None, - ) -> ResponseT: + ) -> ResponseTypeInteger: """ Returns the count of set bits in the value of ``key``. Optional ``start`` and ``end`` parameters indicate which bytes to consider @@ -1596,7 +1614,7 @@ def bitfield( self: Union["redis.client.Redis", "redis.asyncio.client.Redis"], key: KeyT, default_overflow: Union[str, None] = None, - ) -> BitFieldOperation: + ) -> BitFieldOperation[ResponseTypeInteger]: """ Return a BitFieldOperation instance to conveniently construct one or more bitfield operations on ``key``. @@ -1611,7 +1629,7 @@ def bitfield_ro( encoding: str, offset: BitfieldOffsetT, items: Optional[list] = None, - ) -> ResponseT: + ) -> ResponseTypeInteger: """ Return an array of the specified bitfield values where the first value is found using ``encoding`` and ``offset`` @@ -1628,7 +1646,7 @@ def bitfield_ro( params.extend(["GET", encoding, offset]) return self.execute_command("BITFIELD_RO", *params, keys=[key]) - def bitop(self, operation: str, dest: KeyT, *keys: KeyT) -> ResponseT: + def bitop(self, operation: str, dest: KeyT, *keys: KeyT) -> ResponseTypeInteger: """ Perform a bitwise operation using ``operation`` between ``keys`` and store the result in ``dest``. @@ -1644,7 +1662,7 @@ def bitpos( start: Union[int, None] = None, end: Union[int, None] = None, mode: Optional[str] = None, - ) -> ResponseT: + ) -> ResponseTypeInteger: """ Return the position of the first bit set to 1 or 0 in a string. ``start`` and ``end`` defines search range. The range is interpreted @@ -1674,7 +1692,7 @@ def copy( destination: str, destination_db: Union[str, None] = None, replace: bool = False, - ) -> ResponseT: + ) -> ResponseTypeBoolean: """ Copy the value stored in the ``source`` key to the ``destination`` key. @@ -1820,7 +1838,7 @@ def expiretime(self, key: str) -> int: """ return self.execute_command("EXPIRETIME", key) - def get(self, name: KeyT) -> ResponseT: + def get(self, name: KeyT) -> ResponseTypeOptionalString: """ Return the value at key ``name``, or None if the key doesn't exist @@ -2510,7 +2528,13 @@ def lcs( return self.execute_command("LCS", *pieces, keys=[key1, key2]) -class AsyncBasicKeyCommands(BasicKeyCommands): +class AsyncBasicKeyCommands( + BasicKeyCommands[ + ResponseTypeBoolean, + ResponseTypeInteger, + ResponseTypeOptionalString, + ], +): def __delitem__(self, name: KeyT): raise TypeError("Async Redis client does not support class deletion") @@ -2530,7 +2554,16 @@ async def unwatch(self) -> None: return super().unwatch() -class ListCommands(CommandsProtocol): +class ListCommands( + CommandsProtocol, + Generic[ + ResponseTypeBoolean, + ResponseTypeInteger, + ResponseTypeOptionalString, + ResponseTypeOptionalStringList, + ResponseTypeStringList, + ], +): """ Redis commands for List data type. see: https://redis.io/topics/data-types#lists @@ -2538,7 +2571,7 @@ class ListCommands(CommandsProtocol): def blpop( self, keys: List, timeout: Optional[Number] = 0 - ) -> Union[Awaitable[list], list]: + ) -> ResponseTypeOptionalStringList: """ LPOP a value off of the first non-empty list named in the ``keys`` list. @@ -2558,8 +2591,8 @@ def blpop( return self.execute_command("BLPOP", *keys) def brpop( - self, keys: List, timeout: Optional[Number] = 0 - ) -> Union[Awaitable[list], list]: + self, keys: KeysT, timeout: Optional[Number] = 0 + ) -> ResponseTypeOptionalStringList: """ RPOP a value off of the first non-empty list named in the ``keys`` list. @@ -2580,7 +2613,7 @@ def brpop( def brpoplpush( self, src: str, dst: str, timeout: Optional[Number] = 0 - ) -> Union[Awaitable[Optional[str]], Optional[str]]: + ) -> ResponseTypeOptionalString: """ Pop a value off the tail of ``src``, push it on the head of ``dst`` and then return it. @@ -2635,9 +2668,7 @@ def lmpop( return self.execute_command("LMPOP", *args) - def lindex( - self, name: str, index: int - ) -> Union[Awaitable[Optional[str]], Optional[str]]: + def lindex(self, name: str, index: int) -> ResponseTypeOptionalString: """ Return the item from list ``name`` at position ``index`` @@ -2650,7 +2681,7 @@ def lindex( def linsert( self, name: str, where: str, refvalue: str, value: str - ) -> Union[Awaitable[int], int]: + ) -> ResponseTypeInteger: """ Insert ``value`` in list ``name`` either immediately before or after [``where``] ``refvalue`` @@ -2662,7 +2693,7 @@ def linsert( """ return self.execute_command("LINSERT", name, where, refvalue, value) - def llen(self, name: str) -> Union[Awaitable[int], int]: + def llen(self, name: str) -> IntegerType: """ Return the length of the list ``name`` @@ -2689,7 +2720,7 @@ def lpop( else: return self.execute_command("LPOP", name) - def lpush(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: + def lpush(self, name: str, *values: FieldT) -> ResponseTypeInteger: """ Push ``values`` onto the head of the list ``name`` @@ -2697,7 +2728,7 @@ def lpush(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: """ return self.execute_command("LPUSH", name, *values) - def lpushx(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: + def lpushx(self, name: str, *values: FieldT) -> ResponseTypeInteger: """ Push ``value`` onto the head of the list ``name`` if ``name`` exists @@ -2705,7 +2736,7 @@ def lpushx(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: """ return self.execute_command("LPUSHX", name, *values) - def lrange(self, name: str, start: int, end: int) -> Union[Awaitable[list], list]: + def lrange(self, name: str, start: int, end: int) -> ResponseTypeStringList: """ Return a slice of the list ``name`` between position ``start`` and ``end`` @@ -2717,7 +2748,7 @@ def lrange(self, name: str, start: int, end: int) -> Union[Awaitable[list], list """ return self.execute_command("LRANGE", name, start, end, keys=[name]) - def lrem(self, name: str, count: int, value: str) -> Union[Awaitable[int], int]: + def lrem(self, name: str, count: int, value: str) -> ResponseTypeInteger: """ Remove the first ``count`` occurrences of elements equal to ``value`` from the list stored at ``name``. @@ -2731,7 +2762,7 @@ def lrem(self, name: str, count: int, value: str) -> Union[Awaitable[int], int]: """ return self.execute_command("LREM", name, count, value) - def lset(self, name: str, index: int, value: str) -> Union[Awaitable[str], str]: + def lset(self, name: str, index: int, value: str) -> ResponseTypeBoolean: """ Set element at ``index`` of list ``name`` to ``value`` @@ -2779,7 +2810,7 @@ def rpoplpush(self, src: str, dst: str) -> Union[Awaitable[str], str]: """ return self.execute_command("RPOPLPUSH", src, dst) - def rpush(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: + def rpush(self, name: str, *values: FieldT) -> ResponseTypeInteger: """ Push ``values`` onto the tail of the list ``name`` @@ -2787,7 +2818,7 @@ def rpush(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: """ return self.execute_command("RPUSH", name, *values) - def rpushx(self, name: str, *values: str) -> Union[Awaitable[int], int]: + def rpushx(self, name: str, *values: str) -> ResponseTypeInteger: """ Push ``value`` onto the tail of the list ``name`` if ``name`` exists @@ -3284,13 +3315,20 @@ async def zscan_iter( yield d -class SetCommands(CommandsProtocol): +class SetCommands( + CommandsProtocol, + Generic[ + ResponseTypeBoolean, + ResponseTypeInteger, + ResponseTypeStringList, + ], +): """ Redis commands for Set data type. see: https://redis.io/topics/data-types#sets """ - def sadd(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: + def sadd(self, name: str, *values: FieldT) -> ResponseTypeInteger: """ Add ``value(s)`` to set ``name`` @@ -3298,7 +3336,7 @@ def sadd(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: """ return self.execute_command("SADD", name, *values) - def scard(self, name: str) -> Union[Awaitable[int], int]: + def scard(self, name: str) -> ResponseTypeInteger: """ Return the number of elements in set ``name`` @@ -3306,7 +3344,7 @@ def scard(self, name: str) -> Union[Awaitable[int], int]: """ return self.execute_command("SCARD", name, keys=[name]) - def sdiff(self, keys: List, *args: List) -> Union[Awaitable[list], list]: + def sdiff(self, keys: List, *args: List) -> ResponseTypeStringList: """ Return the difference of sets specified by ``keys`` @@ -3315,9 +3353,7 @@ def sdiff(self, keys: List, *args: List) -> Union[Awaitable[list], list]: args = list_or_args(keys, args) return self.execute_command("SDIFF", *args, keys=args) - def sdiffstore( - self, dest: str, keys: List, *args: List - ) -> Union[Awaitable[int], int]: + def sdiffstore(self, dest: str, keys: List, *args: List) -> ResponseTypeInteger: """ Store the difference of sets specified by ``keys`` into a new set named ``dest``. Returns the number of keys in the new set. @@ -3327,7 +3363,7 @@ def sdiffstore( args = list_or_args(keys, args) return self.execute_command("SDIFFSTORE", dest, *args) - def sinter(self, keys: List, *args: List) -> Union[Awaitable[list], list]: + def sinter(self, keys: List, *args: List) -> ResponseTypeStringList: """ Return the intersection of sets specified by ``keys`` @@ -3338,7 +3374,7 @@ def sinter(self, keys: List, *args: List) -> Union[Awaitable[list], list]: def sintercard( self, numkeys: int, keys: List[str], limit: int = 0 - ) -> Union[Awaitable[int], int]: + ) -> ResponseTypeInteger: """ Return the cardinality of the intersect of multiple sets specified by ``keys``. @@ -3351,9 +3387,7 @@ def sintercard( args = [numkeys, *keys, "LIMIT", limit] return self.execute_command("SINTERCARD", *args, keys=keys) - def sinterstore( - self, dest: str, keys: List, *args: List - ) -> Union[Awaitable[int], int]: + def sinterstore(self, dest: str, keys: List, *args: List) -> ResponseTypeInteger: """ Store the intersection of sets specified by ``keys`` into a new set named ``dest``. Returns the number of keys in the new set. @@ -3363,9 +3397,7 @@ def sinterstore( args = list_or_args(keys, args) return self.execute_command("SINTERSTORE", dest, *args) - def sismember( - self, name: str, value: str - ) -> Union[Awaitable[Union[Literal[0], Literal[1]]], Union[Literal[0], Literal[1]]]: + def sismember(self, name: str, value: str) -> ResponseTypeBoolean: """ Return whether ``value`` is a member of set ``name``: - 1 if the value is a member of the set. @@ -3383,12 +3415,7 @@ def smembers(self, name: str) -> Union[Awaitable[Set], Set]: """ return self.execute_command("SMEMBERS", name, keys=[name]) - def smismember( - self, name: str, values: List, *args: List - ) -> Union[ - Awaitable[List[Union[Literal[0], Literal[1]]]], - List[Union[Literal[0], Literal[1]]], - ]: + def smismember(self, name: str, values: List, *args: List) -> ResponseTypeBoolean: """ Return whether each value in ``values`` is a member of the set ``name`` as a list of ``int`` in the order of ``values``: @@ -3400,7 +3427,7 @@ def smismember( args = list_or_args(values, args) return self.execute_command("SMISMEMBER", name, *args, keys=[name]) - def smove(self, src: str, dst: str, value: str) -> Union[Awaitable[bool], bool]: + def smove(self, src: str, dst: str, value: str) -> ResponseTypeBoolean: """ Move ``value`` from set ``src`` to set ``dst`` atomically @@ -3408,7 +3435,9 @@ def smove(self, src: str, dst: str, value: str) -> Union[Awaitable[bool], bool]: """ return self.execute_command("SMOVE", src, dst, value) - def spop(self, name: str, count: Optional[int] = None) -> Union[str, List, None]: + def spop( + self, name: str, count: Optional[int] = None + ) -> Union[str, List[str], None]: """ Remove and return a random member of set ``name`` @@ -3419,7 +3448,7 @@ def spop(self, name: str, count: Optional[int] = None) -> Union[str, List, None] def srandmember( self, name: str, number: Optional[int] = None - ) -> Union[str, List, None]: + ) -> Union[str, List[str], None]: """ If ``number`` is None, returns a random member of set ``name``. @@ -3432,7 +3461,7 @@ def srandmember( args = (number is not None) and [number] or [] return self.execute_command("SRANDMEMBER", name, *args) - def srem(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: + def srem(self, name: str, *values: FieldT) -> ResponseTypeInteger: """ Remove ``values`` from set ``name`` @@ -3440,7 +3469,7 @@ def srem(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: """ return self.execute_command("SREM", name, *values) - def sunion(self, keys: List, *args: List) -> Union[Awaitable[List], List]: + def sunion(self, keys: List, *args: List) -> ResponseTypeStringList: """ Return the union of sets specified by ``keys`` @@ -3449,9 +3478,7 @@ def sunion(self, keys: List, *args: List) -> Union[Awaitable[List], List]: args = list_or_args(keys, args) return self.execute_command("SUNION", *args, keys=args) - def sunionstore( - self, dest: str, keys: List, *args: List - ) -> Union[Awaitable[int], int]: + def sunionstore(self, dest: str, keys: List, *args: List) -> ResponseTypeInteger: """ Store the union of sets specified by ``keys`` into a new set named ``dest``. Returns the number of keys in the new set. @@ -4918,13 +4945,21 @@ class HashDataPersistOptions(Enum): FXX = "FXX" -class HashCommands(CommandsProtocol): +class HashCommands( + CommandsProtocol, + Generic[ + ResponseTypeBoolean, + ResponseTypeInteger, + ResponseTypeOptionalString, + ResponseTypeStringList, + ], +): """ Redis commands for Hash data type. see: https://redis.io/topics/data-types-intro#redis-hashes """ - def hdel(self, name: str, *keys: str) -> Union[Awaitable[int], int]: + def hdel(self, name: str, *keys: str) -> ResponseTypeInteger: """ Delete ``keys`` from hash ``name`` @@ -4932,7 +4967,7 @@ def hdel(self, name: str, *keys: str) -> Union[Awaitable[int], int]: """ return self.execute_command("HDEL", name, *keys) - def hexists(self, name: str, key: str) -> Union[Awaitable[bool], bool]: + def hexists(self, name: str, key: str) -> ResponseTypeBoolean: """ Returns a boolean indicating if ``key`` exists within hash ``name`` @@ -4940,9 +4975,7 @@ def hexists(self, name: str, key: str) -> Union[Awaitable[bool], bool]: """ return self.execute_command("HEXISTS", name, key, keys=[name]) - def hget( - self, name: str, key: str - ) -> Union[Awaitable[Optional[str]], Optional[str]]: + def hget(self, name: str, key: str) -> ResponseTypeOptionalString: """ Return the value of ``key`` within the hash ``name`` @@ -5032,9 +5065,7 @@ def hgetex( *keys, ) - def hincrby( - self, name: str, key: str, amount: int = 1 - ) -> Union[Awaitable[int], int]: + def hincrby(self, name: str, key: str, amount: int = 1) -> ResponseTypeInteger: """ Increment the value of ``key`` in hash ``name`` by ``amount`` @@ -5060,7 +5091,7 @@ def hkeys(self, name: str) -> Union[Awaitable[List], List]: """ return self.execute_command("HKEYS", name, keys=[name]) - def hlen(self, name: str) -> Union[Awaitable[int], int]: + def hlen(self, name: str) -> ResponseTypeInteger: """ Return the number of elements in hash ``name`` @@ -5075,7 +5106,7 @@ def hset( value: Optional[str] = None, mapping: Optional[dict] = None, items: Optional[list] = None, - ) -> Union[Awaitable[int], int]: + ) -> ResponseTypeInteger: """ Set ``key`` to ``value`` within hash ``name``, ``mapping`` accepts a dict of key/value pairs that will be @@ -5114,7 +5145,7 @@ def hsetex( pxat: Optional[AbsExpiryT] = None, data_persist_option: Optional[HashDataPersistOptions] = None, keepttl: bool = False, - ) -> Union[Awaitable[int], int]: + ) -> ResponseTypeInteger: """ Set ``key`` to ``value`` within hash ``name`` @@ -5184,7 +5215,7 @@ def hsetex( "HSETEX", name, *exp_options, "FIELDS", int(len(pieces) / 2), *pieces ) - def hsetnx(self, name: str, key: str, value: str) -> Union[Awaitable[bool], bool]: + def hsetnx(self, name: str, key: str, value: str) -> ResponseTypeBoolean: """ Set ``key`` to ``value`` within hash ``name`` if ``key`` does not exist. Returns 1 if HSETNX created a field, otherwise 0. @@ -5221,7 +5252,7 @@ def hmget(self, name: str, keys: List, *args: List) -> Union[Awaitable[List], Li args = list_or_args(keys, args) return self.execute_command("HMGET", name, *args, keys=[name]) - def hvals(self, name: str) -> Union[Awaitable[List], List]: + def hvals(self, name: str) -> ResponseTypeStringList: """ Return the list of values within hash ``name`` @@ -5229,7 +5260,7 @@ def hvals(self, name: str) -> Union[Awaitable[List], List]: """ return self.execute_command("HVALS", name, keys=[name]) - def hstrlen(self, name: str, key: str) -> Union[Awaitable[int], int]: + def hstrlen(self, name: str, key: str) -> ResponseTypeInteger: """ Return the number of bytes stored in the value of ``key`` within hash ``name`` @@ -6607,13 +6638,32 @@ def function_stats(self) -> Union[Awaitable[List], List]: class DataAccessCommands( - BasicKeyCommands, + BasicKeyCommands[ + BooleanType, + IntegerType, + OptionalStringType, + ], HyperlogCommands, - HashCommands, + HashCommands[ + BooleanType, + IntegerType, + OptionalStringType, + StringListType, + ], GeoCommands, - ListCommands, + ListCommands[ + BooleanType, + IntegerType, + OptionalStringType, + OptionalStringListType, + StringListType, + ], ScanCommands, - SetCommands, + SetCommands[ + BooleanType, + IntegerType, + StringListType, + ], StreamCommands, SortedSetCommands, ): @@ -6624,13 +6674,32 @@ class DataAccessCommands( class AsyncDataAccessCommands( - AsyncBasicKeyCommands, + AsyncBasicKeyCommands[ + Awaitable[BooleanType], + Awaitable[IntegerType], + Awaitable[OptionalStringType], + ], AsyncHyperlogCommands, - AsyncHashCommands, + AsyncHashCommands[ + Awaitable[BooleanType], + Awaitable[IntegerType], + Awaitable[OptionalStringType], + Awaitable[StringListType], + ], AsyncGeoCommands, - AsyncListCommands, + AsyncListCommands[ + Awaitable[BooleanType], + Awaitable[IntegerType], + Awaitable[OptionalStringType], + Awaitable[OptionalStringListType], + Awaitable[StringListType], + ], AsyncScanCommands, - AsyncSetCommands, + AsyncSetCommands[ + Awaitable[BooleanType], + Awaitable[IntegerType], + Awaitable[StringListType], + ], AsyncStreamCommands, AsyncSortedSetCommands, ): diff --git a/redis/commands/helpers.py b/redis/commands/helpers.py index 859a43aea9..4438bf4ef7 100644 --- a/redis/commands/helpers.py +++ b/redis/commands/helpers.py @@ -1,13 +1,13 @@ import copy import random import string -from typing import List, Tuple +from typing import List, Optional, Tuple import redis from redis.typing import KeysT, KeyT -def list_or_args(keys: KeysT, args: Tuple[KeyT, ...]) -> List[KeyT]: +def list_or_args(keys: KeysT, args: Optional[Tuple[KeyT, ...]]) -> List[KeyT]: # returns a single new list combining keys and args try: iter(keys) diff --git a/redis/typing.py b/redis/typing.py index 24ad607480..e79402db46 100644 --- a/redis/typing.py +++ b/redis/typing.py @@ -6,7 +6,9 @@ Any, Awaitable, Iterable, + List, Mapping, + Optional, Protocol, Type, TypeVar, @@ -51,6 +53,36 @@ ExceptionMappingT = Mapping[str, Union[Type[Exception], Mapping[str, Type[Exception]]]] +# New typing work in progress + +BooleanType = bool +IntegerType = int +OptionalStringType = Optional[str] +StringListType = List[str] +OptionalStringListType = Optional[List[str]] + +ResponseTypeBoolean = TypeVar( + "ResponseTypeBoolean", + bound=Union[Awaitable[BooleanType], BooleanType], +) +ResponseTypeInteger = TypeVar( + "ResponseTypeInteger", + bound=Union[Awaitable[IntegerType], IntegerType], +) +ResponseTypeOptionalString = TypeVar( + "ResponseTypeOptionalString", + bound=Union[Awaitable[OptionalStringType], OptionalStringType], +) +ResponseTypeStringList = TypeVar( + "ResponseTypeStringList", + bound=Union[Awaitable[StringListType], StringListType], +) +ResponseTypeOptionalStringList = TypeVar( + "ResponseTypeOptionalStringList", + bound=Union[Awaitable[OptionalStringListType], OptionalStringListType], +) + + class CommandsProtocol(Protocol): connection_pool: Union["AsyncConnectionPool", "ConnectionPool"]