Skip to content

Commit b1a24a6

Browse files
VincentRPSfruch
authored andcommitted
refactor: slightly improve typing
1 parent d0f472f commit b1a24a6

File tree

2 files changed

+40
-26
lines changed

2 files changed

+40
-26
lines changed

cassandra/cqlengine/models.py

+17-9
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import logging
1616
import re
17+
from typing import Any, Type, TypeVar
1718
import six
1819
from warnings import warn
1920

@@ -86,11 +87,12 @@ class QuerySetDescriptor(object):
8687
it's declared on everytime it's accessed
8788
"""
8889

89-
def __get__(self, obj, model):
90+
def __get__(self, obj: Any, model: "BaseModel"):
9091
""" :rtype: ModelQuerySet """
9192
if model.__abstract__:
9293
raise CQLEngineException('cannot execute queries against abstract models')
9394
queryset = model.__queryset__(model)
95+
queryset
9496

9597
# if this is a concrete polymorphic model, and the discriminator
9698
# key is an indexed column, add a filter clause to only return
@@ -329,6 +331,7 @@ def __delete__(self, instance):
329331
else:
330332
raise AttributeError('cannot delete {0} columns'.format(self.column.column_name))
331333

334+
M = TypeVar('M', bound='BaseModel')
332335

333336
class BaseModel(object):
334337
"""
@@ -341,7 +344,7 @@ class DoesNotExist(_DoesNotExist):
341344
class MultipleObjectsReturned(_MultipleObjectsReturned):
342345
pass
343346

344-
objects = QuerySetDescriptor()
347+
objects: query.ModelQuerySet = QuerySetDescriptor()
345348
ttl = TTLDescriptor()
346349
consistency = ConsistencyDescriptor()
347350
iff = ConditionalDescriptor()
@@ -422,6 +425,7 @@ def __str__(self):
422425
return '{0} <{1}>'.format(self.__class__.__name__,
423426
', '.join('{0}={1}'.format(k, getattr(self, k)) for k in self._primary_keys.keys()))
424427

428+
425429
@classmethod
426430
def _routing_key_from_values(cls, pk_values, protocol_version):
427431
return cls._key_serializer(pk_values, protocol_version)
@@ -658,7 +662,7 @@ def _as_dict(self):
658662
return values
659663

660664
@classmethod
661-
def create(cls, **kwargs):
665+
def create(cls: Type[M], **kwargs) -> M:
662666
"""
663667
Create an instance of this model in the database.
664668
@@ -673,7 +677,7 @@ def create(cls, **kwargs):
673677
return cls.objects.create(**kwargs)
674678

675679
@classmethod
676-
def all(cls):
680+
def all(cls: Type[M]) -> list[M]:
677681
"""
678682
Returns a queryset representing all stored objects
679683
@@ -682,7 +686,7 @@ def all(cls):
682686
return cls.objects.all()
683687

684688
@classmethod
685-
def filter(cls, *args, **kwargs):
689+
def filter(cls: Type[M], *args, **kwargs):
686690
"""
687691
Returns a queryset based on filter parameters.
688692
@@ -691,15 +695,15 @@ def filter(cls, *args, **kwargs):
691695
return cls.objects.filter(*args, **kwargs)
692696

693697
@classmethod
694-
def get(cls, *args, **kwargs):
698+
def get(cls: Type[M], *args, **kwargs) -> M:
695699
"""
696700
Returns a single object based on the passed filter constraints.
697701
698702
This is a pass-through to the model objects().:method:`~cqlengine.queries.get`.
699703
"""
700704
return cls.objects.get(*args, **kwargs)
701705

702-
def timeout(self, timeout):
706+
def timeout(self: M, timeout: float | None) -> M:
703707
"""
704708
Sets a timeout for use in :meth:`~.save`, :meth:`~.update`, and :meth:`~.delete`
705709
operations
@@ -708,7 +712,7 @@ def timeout(self, timeout):
708712
self._timeout = timeout
709713
return self
710714

711-
def save(self):
715+
def save(self: M) -> M:
712716
"""
713717
Saves an object to the database.
714718
@@ -744,7 +748,7 @@ def save(self):
744748

745749
return self
746750

747-
def update(self, **values):
751+
def update(self: M, **values) -> M:
748752
"""
749753
Performs an update on the model instance. You can pass in values to set on the model
750754
for updating, or you can call without values to execute an update against any modified
@@ -835,9 +839,13 @@ def _class_get_connection(cls):
835839
def _inst_get_connection(self):
836840
return self._connection or self.__connection__
837841

842+
def __getitem__(self, s: slice | int) -> M | list[M]:
843+
return self.objects.__getitem__(s)
844+
838845
_get_connection = hybrid_classmethod(_class_get_connection, _inst_get_connection)
839846

840847

848+
841849
class ModelMetaClass(type):
842850

843851
def __new__(cls, name, bases, attrs):

cassandra/cqlengine/query.py

+23-17
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
1517
import copy
1618
from datetime import datetime, timedelta
1719
from functools import partial
1820
import time
21+
from typing import TYPE_CHECKING, ClassVar, Type, TypeVar
1922
import six
2023
from warnings import warn
2124

@@ -336,10 +339,12 @@ def __enter__(self):
336339
def __exit__(self, exc_type, exc_val, exc_tb):
337340
return
338341

342+
if TYPE_CHECKING:
343+
from .models import M
339344

340345
class AbstractQuerySet(object):
341346

342-
def __init__(self, model):
347+
def __init__(self, model: Type[M]):
343348
super(AbstractQuerySet, self).__init__()
344349
self.model = model
345350

@@ -529,7 +534,7 @@ def __iter__(self):
529534

530535
idx += 1
531536

532-
def __getitem__(self, s):
537+
def __getitem__(self, s: slice | int) -> M | list[M]:
533538
self._execute_query()
534539

535540
if isinstance(s, slice):
@@ -602,7 +607,7 @@ def batch(self, batch_obj):
602607
clone._batch = batch_obj
603608
return clone
604609

605-
def first(self):
610+
def first(self) -> M | None:
606611
try:
607612
return six.next(iter(self))
608613
except StopIteration:
@@ -619,7 +624,7 @@ def all(self):
619624
"""
620625
return copy.deepcopy(self)
621626

622-
def consistency(self, consistency):
627+
def consistency(self, consistency: int):
623628
"""
624629
Sets the consistency level for the operation. See :class:`.ConsistencyLevel`.
625630
@@ -743,7 +748,7 @@ def filter(self, *args, **kwargs):
743748

744749
return clone
745750

746-
def get(self, *args, **kwargs):
751+
def get(self, *args, **kwargs) -> M:
747752
"""
748753
Returns a single instance matching this query, optionally with additional filter kwargs.
749754
@@ -784,7 +789,7 @@ def _get_ordering_condition(self, colname):
784789

785790
return colname, order_type
786791

787-
def order_by(self, *colnames):
792+
def order_by(self, *colnames: str):
788793
"""
789794
Sets the column(s) to be used for ordering
790795
@@ -828,7 +833,7 @@ class Comment(Model):
828833
clone._order.extend(conditions)
829834
return clone
830835

831-
def count(self):
836+
def count(self) -> int:
832837
"""
833838
Returns the number of rows matched by this query.
834839
@@ -881,7 +886,7 @@ class Automobile(Model):
881886

882887
return clone
883888

884-
def limit(self, v):
889+
def limit(self, v: int):
885890
"""
886891
Limits the number of results returned by Cassandra. Use *0* or *None* to disable.
887892
@@ -913,7 +918,7 @@ def limit(self, v):
913918
clone._limit = v
914919
return clone
915920

916-
def fetch_size(self, v):
921+
def fetch_size(self, v: int):
917922
"""
918923
Sets the number of rows that are fetched at a time.
919924
@@ -969,15 +974,15 @@ def _only_or_defer(self, action, fields):
969974

970975
return clone
971976

972-
def only(self, fields):
977+
def only(self, fields: list[str]):
973978
""" Load only these fields for the returned query """
974979
return self._only_or_defer('only', fields)
975980

976-
def defer(self, fields):
981+
def defer(self, fields: list[str]):
977982
""" Don't load these fields for the returned query """
978983
return self._only_or_defer('defer', fields)
979984

980-
def create(self, **kwargs):
985+
def create(self, **kwargs) -> M:
981986
return self.model(**kwargs) \
982987
.batch(self._batch) \
983988
.ttl(self._ttl) \
@@ -1014,7 +1019,7 @@ def __eq__(self, q):
10141019
def __ne__(self, q):
10151020
return not (self != q)
10161021

1017-
def timeout(self, timeout):
1022+
def timeout(self, timeout: float | None):
10181023
"""
10191024
:param timeout: Timeout for the query (in seconds)
10201025
:type timeout: float or None
@@ -1065,6 +1070,7 @@ def _get_result_constructor(self):
10651070
"""
10661071
return ResultObject
10671072

1073+
T = TypeVar('T', 'ModelQuerySet')
10681074

10691075
class ModelQuerySet(AbstractQuerySet):
10701076
"""
@@ -1157,7 +1163,7 @@ def values_list(self, *fields, **kwargs):
11571163
clone._flat_values_list = flat
11581164
return clone
11591165

1160-
def ttl(self, ttl):
1166+
def ttl(self: T, ttl: int) -> T:
11611167
"""
11621168
Sets the ttl (in seconds) for modified data.
11631169
@@ -1167,15 +1173,15 @@ def ttl(self, ttl):
11671173
clone._ttl = ttl
11681174
return clone
11691175

1170-
def timestamp(self, timestamp):
1176+
def timestamp(self: T, timestamp: datetime) -> T:
11711177
"""
11721178
Allows for custom timestamps to be saved with the record.
11731179
"""
11741180
clone = copy.deepcopy(self)
11751181
clone._timestamp = timestamp
11761182
return clone
11771183

1178-
def if_not_exists(self):
1184+
def if_not_exists(self: T) -> T:
11791185
"""
11801186
Check the existence of an object before insertion.
11811187
@@ -1187,7 +1193,7 @@ def if_not_exists(self):
11871193
clone._if_not_exists = True
11881194
return clone
11891195

1190-
def if_exists(self):
1196+
def if_exists(self: T) -> T:
11911197
"""
11921198
Check the existence of an object before an update or delete.
11931199

0 commit comments

Comments
 (0)