From 07bf10cc76f0f9630c010a758fbc2a96abfb3efc Mon Sep 17 00:00:00 2001 From: KunxiSun Date: Thu, 3 Apr 2025 11:24:51 +0800 Subject: [PATCH 01/10] feat: add IntEnum for sqltypes --- sqlmodel/__init__.py | 1 + sqlmodel/sql/sqltypes.py | 57 +++++++++++++++++++++++++++++++++++++- tests/test_enums.py | 30 ++++++++++++++++---- tests/test_enums_models.py | 8 ++++-- 4 files changed, 87 insertions(+), 9 deletions(-) diff --git a/sqlmodel/__init__.py b/sqlmodel/__init__.py index 28bb305237..20473611d9 100644 --- a/sqlmodel/__init__.py +++ b/sqlmodel/__init__.py @@ -141,3 +141,4 @@ from .sql.expression import type_coerce as type_coerce from .sql.expression import within_group as within_group from .sql.sqltypes import AutoString as AutoString +from .sql.sqltypes import IntEnum as IntEnum diff --git a/sqlmodel/sql/sqltypes.py b/sqlmodel/sql/sqltypes.py index 512daacbab..207a33d441 100644 --- a/sqlmodel/sql/sqltypes.py +++ b/sqlmodel/sql/sqltypes.py @@ -1,4 +1,5 @@ -from typing import Any, cast +from typing import Any, cast, Optional +from enum import IntEnum as _IntEnum from sqlalchemy import types from sqlalchemy.engine.interfaces import Dialect @@ -14,3 +15,57 @@ def load_dialect_impl(self, dialect: Dialect) -> "types.TypeEngine[Any]": if impl.length is None and dialect.name == "mysql": return dialect.type_descriptor(types.String(self.mysql_default_length)) return super().load_dialect_impl(dialect) + +class IntEnum(types.TypeDecorator): # type: ignore + """TypeDecorator for Integer-enum conversion. + + Automatically converts Python enum.IntEnum <-> database integers. + + Args: + enum_type (enum.IntEnum): Integer enum class (subclass of enum.IntEnum) + + Example: + >>> class HeroStatus(enum.IntEnum): + ... ACTIVE = 1 + ... DISABLE = 2 + >>>> + >>> from sqlmodel import IntEnum + >>> class Hero(SQLModel): + ... hero_status: HeroStatus = Field(sa_type=sqlmodel.IntEnum(HeroStatus)) + >>> user.hero_status == Status.ACTIVE # Loads back as enum + + Returns: + Optional[enum.IntEnum]: Converted enum instance (None if database value is NULL) + + Raises: + TypeError: For invalid enum types + """ + + impl = types.Integer + + def __init__(self, enum_type: _IntEnum, *args, **kwargs): + super().__init__(*args, **kwargs) + + # validate the input enum type + if not issubclass(enum_type, _IntEnum): + raise TypeError( + f"Input must be enum.IntEnum" + ) + + self.enum_type = enum_type + + def process_result_value(self, value: Optional[int], dialect) -> Optional[_IntEnum]: + + if value is None: + return None + + result = self.enum_type(value) + return result + + def process_bind_param(self, value: Optional[_IntEnum], dialect) -> Optional[int]: + + if value is None: + return None + + result = value.value + return result diff --git a/tests/test_enums.py b/tests/test_enums.py index 2808f3f9a9..bf3fc2c651 100644 --- a/tests/test_enums.py +++ b/tests/test_enums.py @@ -63,14 +63,21 @@ def test_json_schema_flat_model_pydantic_v1(): "properties": { "id": {"title": "Id", "type": "string", "format": "uuid"}, "enum_field": {"$ref": "#/definitions/MyEnum1"}, + "int_enum_field": {"$ref": "#/definitions/MyEnum3"}, }, - "required": ["id", "enum_field"], + "required": ["id", "enum_field", "int_enum_field"], "definitions": { "MyEnum1": { "title": "MyEnum1", "description": "An enumeration.", "enum": ["A", "B"], "type": "string", + }, + "MyEnum3": { + "title": "MyEnum3", + "description": "An enumeration.", + "enum": [1, 3], + "type": "int", } }, } @@ -84,14 +91,21 @@ def test_json_schema_inherit_model_pydantic_v1(): "properties": { "id": {"title": "Id", "type": "string", "format": "uuid"}, "enum_field": {"$ref": "#/definitions/MyEnum2"}, + "int_enum_field": {"$ref": "#/definitions/MyEnum3"}, }, - "required": ["id", "enum_field"], + "required": ["id", "enum_field", "int_enum_field"], "definitions": { "MyEnum2": { "title": "MyEnum2", "description": "An enumeration.", "enum": ["C", "D"], "type": "string", + }, + "MyEnum3": { + "title": "MyEnum3", + "description": "An int enumeration.", + "enum": [1, 3], + "type": "int", } }, } @@ -105,10 +119,12 @@ def test_json_schema_flat_model_pydantic_v2(): "properties": { "id": {"title": "Id", "type": "string", "format": "uuid"}, "enum_field": {"$ref": "#/$defs/MyEnum1"}, + "int_enum_field": {"$ref": "#/$defs/MyEnum3"}, }, - "required": ["id", "enum_field"], + "required": ["id", "enum_field", "int_enum_field"], "$defs": { - "MyEnum1": {"enum": ["A", "B"], "title": "MyEnum1", "type": "string"} + "MyEnum1": {"enum": ["A", "B"], "title": "MyEnum1", "type": "string"}, + "MyEnum3": {"enum": [1, 2], "title": "MyEnum3", "type": "integer"}, }, } @@ -121,9 +137,11 @@ def test_json_schema_inherit_model_pydantic_v2(): "properties": { "id": {"title": "Id", "type": "string", "format": "uuid"}, "enum_field": {"$ref": "#/$defs/MyEnum2"}, + "int_enum_field": {"$ref": "#/$defs/MyEnum3"}, }, - "required": ["id", "enum_field"], + "required": ["id", "enum_field", "int_enum_field"], "$defs": { - "MyEnum2": {"enum": ["C", "D"], "title": "MyEnum2", "type": "string"} + "MyEnum2": {"enum": ["C", "D"], "title": "MyEnum2", "type": "string"}, + "MyEnum3": {"enum": [1, 2], "title": "MyEnum3", "type": "integer"}, }, } diff --git a/tests/test_enums_models.py b/tests/test_enums_models.py index b46ccb7d2b..04635d12c4 100644 --- a/tests/test_enums_models.py +++ b/tests/test_enums_models.py @@ -1,7 +1,7 @@ import enum import uuid -from sqlmodel import Field, SQLModel +from sqlmodel import Field, SQLModel, IntEnum class MyEnum1(str, enum.Enum): @@ -13,15 +13,19 @@ class MyEnum2(str, enum.Enum): C = "C" D = "D" +class MyEnum3(enum.IntEnum): + E = 1 + F = 2 class BaseModel(SQLModel): id: uuid.UUID = Field(primary_key=True) enum_field: MyEnum2 - + int_enum_field: MyEnum3 class FlatModel(SQLModel, table=True): id: uuid.UUID = Field(primary_key=True) enum_field: MyEnum1 + int_enum_field: MyEnum3 = Field(sa_type=IntEnum(MyEnum3)) class InheritModel(BaseModel, table=True): From 0f49d7a87bcbbd9c8039f1d122ae5d2b82178c71 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 3 Apr 2025 03:31:33 +0000 Subject: [PATCH 02/10] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/sql/sqltypes.py | 11 ++++------- tests/test_enums.py | 4 ++-- tests/test_enums_models.py | 5 ++++- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/sqlmodel/sql/sqltypes.py b/sqlmodel/sql/sqltypes.py index 207a33d441..805635f15d 100644 --- a/sqlmodel/sql/sqltypes.py +++ b/sqlmodel/sql/sqltypes.py @@ -1,5 +1,5 @@ -from typing import Any, cast, Optional from enum import IntEnum as _IntEnum +from typing import Any, Optional, cast from sqlalchemy import types from sqlalchemy.engine.interfaces import Dialect @@ -16,6 +16,7 @@ def load_dialect_impl(self, dialect: Dialect) -> "types.TypeEngine[Any]": return dialect.type_descriptor(types.String(self.mysql_default_length)) return super().load_dialect_impl(dialect) + class IntEnum(types.TypeDecorator): # type: ignore """TypeDecorator for Integer-enum conversion. @@ -27,7 +28,7 @@ class IntEnum(types.TypeDecorator): # type: ignore Example: >>> class HeroStatus(enum.IntEnum): ... ACTIVE = 1 - ... DISABLE = 2 + ... DISABLE = 2 >>>> >>> from sqlmodel import IntEnum >>> class Hero(SQLModel): @@ -48,14 +49,11 @@ def __init__(self, enum_type: _IntEnum, *args, **kwargs): # validate the input enum type if not issubclass(enum_type, _IntEnum): - raise TypeError( - f"Input must be enum.IntEnum" - ) + raise TypeError("Input must be enum.IntEnum") self.enum_type = enum_type def process_result_value(self, value: Optional[int], dialect) -> Optional[_IntEnum]: - if value is None: return None @@ -63,7 +61,6 @@ def process_result_value(self, value: Optional[int], dialect) -> Optional[_IntEn return result def process_bind_param(self, value: Optional[_IntEnum], dialect) -> Optional[int]: - if value is None: return None diff --git a/tests/test_enums.py b/tests/test_enums.py index bf3fc2c651..f9ec8e27d1 100644 --- a/tests/test_enums.py +++ b/tests/test_enums.py @@ -78,7 +78,7 @@ def test_json_schema_flat_model_pydantic_v1(): "description": "An enumeration.", "enum": [1, 3], "type": "int", - } + }, }, } @@ -106,7 +106,7 @@ def test_json_schema_inherit_model_pydantic_v1(): "description": "An int enumeration.", "enum": [1, 3], "type": "int", - } + }, }, } diff --git a/tests/test_enums_models.py b/tests/test_enums_models.py index 04635d12c4..c6dac3723a 100644 --- a/tests/test_enums_models.py +++ b/tests/test_enums_models.py @@ -1,7 +1,7 @@ import enum import uuid -from sqlmodel import Field, SQLModel, IntEnum +from sqlmodel import Field, IntEnum, SQLModel class MyEnum1(str, enum.Enum): @@ -13,15 +13,18 @@ class MyEnum2(str, enum.Enum): C = "C" D = "D" + class MyEnum3(enum.IntEnum): E = 1 F = 2 + class BaseModel(SQLModel): id: uuid.UUID = Field(primary_key=True) enum_field: MyEnum2 int_enum_field: MyEnum3 + class FlatModel(SQLModel, table=True): id: uuid.UUID = Field(primary_key=True) enum_field: MyEnum1 From f462b3e83f2c4868529c52b887dc69de4061638a Mon Sep 17 00:00:00 2001 From: KunxiSun Date: Thu, 3 Apr 2025 12:02:47 +0800 Subject: [PATCH 03/10] fix: type lint check --- sqlmodel/sql/sqltypes.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/sqlmodel/sql/sqltypes.py b/sqlmodel/sql/sqltypes.py index 207a33d441..0cf4a2e0b5 100644 --- a/sqlmodel/sql/sqltypes.py +++ b/sqlmodel/sql/sqltypes.py @@ -1,10 +1,13 @@ -from typing import Any, cast, Optional +from typing import Any, cast, Optional, TypeVar from enum import IntEnum as _IntEnum from sqlalchemy import types from sqlalchemy.engine.interfaces import Dialect +_TIntEnum = TypeVar('_TIntEnum', bound="_IntEnum") + + class AutoString(types.TypeDecorator): # type: ignore impl = types.String cache_ok = True @@ -43,7 +46,7 @@ class IntEnum(types.TypeDecorator): # type: ignore impl = types.Integer - def __init__(self, enum_type: _IntEnum, *args, **kwargs): + def __init__(self, enum_type: _TIntEnum, *args, **kwargs): super().__init__(*args, **kwargs) # validate the input enum type @@ -54,7 +57,7 @@ def __init__(self, enum_type: _IntEnum, *args, **kwargs): self.enum_type = enum_type - def process_result_value(self, value: Optional[int], dialect) -> Optional[_IntEnum]: + def process_result_value(self, value: Optional[int], dialect: Dialect) -> Optional[_TIntEnum]: if value is None: return None @@ -62,7 +65,7 @@ def process_result_value(self, value: Optional[int], dialect) -> Optional[_IntEn result = self.enum_type(value) return result - def process_bind_param(self, value: Optional[_IntEnum], dialect) -> Optional[int]: + def process_bind_param(self, value: Optional[_TIntEnum], dialect: Dialect) -> Optional[int]: if value is None: return None From acb12ca567db23c7bfab2281bfc013df1f7ac22e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 3 Apr 2025 04:04:43 +0000 Subject: [PATCH 04/10] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/sql/sqltypes.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/sqlmodel/sql/sqltypes.py b/sqlmodel/sql/sqltypes.py index 4dae276c89..54f68b11ab 100644 --- a/sqlmodel/sql/sqltypes.py +++ b/sqlmodel/sql/sqltypes.py @@ -1,11 +1,10 @@ -from typing import Any, cast, Optional, TypeVar from enum import IntEnum as _IntEnum +from typing import Any, Optional, TypeVar, cast from sqlalchemy import types from sqlalchemy.engine.interfaces import Dialect - -_TIntEnum = TypeVar('_TIntEnum', bound="_IntEnum") +_TIntEnum = TypeVar("_TIntEnum", bound="_IntEnum") class AutoString(types.TypeDecorator): # type: ignore @@ -56,16 +55,18 @@ def __init__(self, enum_type: _TIntEnum, *args, **kwargs): self.enum_type = enum_type - def process_result_value(self, value: Optional[int], dialect: Dialect) -> Optional[_TIntEnum]: - + def process_result_value( + self, value: Optional[int], dialect: Dialect + ) -> Optional[_TIntEnum]: if value is None: return None result = self.enum_type(value) return result - def process_bind_param(self, value: Optional[_TIntEnum], dialect: Dialect) -> Optional[int]: - + def process_bind_param( + self, value: Optional[_TIntEnum], dialect: Dialect + ) -> Optional[int]: if value is None: return None From 1bd5deecb2f6a17122565f2ac08c30c1363e93d7 Mon Sep 17 00:00:00 2001 From: KunxiSun Date: Thu, 3 Apr 2025 12:50:48 +0800 Subject: [PATCH 05/10] fix: sqltypes lint check --- sqlmodel/sql/sqltypes.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/sqlmodel/sql/sqltypes.py b/sqlmodel/sql/sqltypes.py index 54f68b11ab..8acf4f7ef3 100644 --- a/sqlmodel/sql/sqltypes.py +++ b/sqlmodel/sql/sqltypes.py @@ -1,5 +1,5 @@ from enum import IntEnum as _IntEnum -from typing import Any, Optional, TypeVar, cast +from typing import Any, Optional, TypeVar, Type, cast, Any from sqlalchemy import types from sqlalchemy.engine.interfaces import Dialect @@ -46,7 +46,7 @@ class IntEnum(types.TypeDecorator): # type: ignore impl = types.Integer - def __init__(self, enum_type: _TIntEnum, *args, **kwargs): + def __init__(self, enum_type: Type[_TIntEnum], *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) # validate the input enum type @@ -55,9 +55,12 @@ def __init__(self, enum_type: _TIntEnum, *args, **kwargs): self.enum_type = enum_type - def process_result_value( - self, value: Optional[int], dialect: Dialect + def process_result_value( # type: ignore[override] + self, + value: Optional[int], + dialect: Dialect, ) -> Optional[_TIntEnum]: + if value is None: return None @@ -65,8 +68,11 @@ def process_result_value( return result def process_bind_param( - self, value: Optional[_TIntEnum], dialect: Dialect + self, + value: Optional[_TIntEnum], + dialect: Dialect, ) -> Optional[int]: + if value is None: return None From 2d3185d1539bde1984a0cdf075342d47d0fe77ee Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 3 Apr 2025 04:53:07 +0000 Subject: [PATCH 06/10] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/sql/sqltypes.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sqlmodel/sql/sqltypes.py b/sqlmodel/sql/sqltypes.py index 8acf4f7ef3..2821ef7b55 100644 --- a/sqlmodel/sql/sqltypes.py +++ b/sqlmodel/sql/sqltypes.py @@ -1,5 +1,5 @@ from enum import IntEnum as _IntEnum -from typing import Any, Optional, TypeVar, Type, cast, Any +from typing import Any, Optional, Type, TypeVar, cast from sqlalchemy import types from sqlalchemy.engine.interfaces import Dialect @@ -60,7 +60,6 @@ def process_result_value( # type: ignore[override] value: Optional[int], dialect: Dialect, ) -> Optional[_TIntEnum]: - if value is None: return None @@ -72,7 +71,6 @@ def process_bind_param( value: Optional[_TIntEnum], dialect: Dialect, ) -> Optional[int]: - if value is None: return None From f388555905795cc0ec1cf1c4323d5e4b3bc56704 Mon Sep 17 00:00:00 2001 From: KunxiSun Date: Thu, 3 Apr 2025 13:01:33 +0800 Subject: [PATCH 07/10] fix: unit test pydanticv1 --- tests/test_enums.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_enums.py b/tests/test_enums.py index f9ec8e27d1..93321c0856 100644 --- a/tests/test_enums.py +++ b/tests/test_enums.py @@ -76,8 +76,8 @@ def test_json_schema_flat_model_pydantic_v1(): "MyEnum3": { "title": "MyEnum3", "description": "An enumeration.", - "enum": [1, 3], - "type": "int", + "enum": [1, 2], + "type": "integer", }, }, } @@ -104,8 +104,8 @@ def test_json_schema_inherit_model_pydantic_v1(): "MyEnum3": { "title": "MyEnum3", "description": "An int enumeration.", - "enum": [1, 3], - "type": "int", + "enum": [1, 2], + "type": "integer", }, }, } From bfcd346950c46d3de2c38d143531f5b8ef798f3d Mon Sep 17 00:00:00 2001 From: KunxiSun Date: Thu, 3 Apr 2025 14:57:52 +0800 Subject: [PATCH 08/10] fix: pydanticv1 unitest --- tests/test_enums.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_enums.py b/tests/test_enums.py index 93321c0856..aeff988a43 100644 --- a/tests/test_enums.py +++ b/tests/test_enums.py @@ -103,7 +103,7 @@ def test_json_schema_inherit_model_pydantic_v1(): }, "MyEnum3": { "title": "MyEnum3", - "description": "An int enumeration.", + "description": "An enumeration.", "enum": [1, 2], "type": "integer", }, From 921c2950796bb5d9fca665ab967b935326c78b65 Mon Sep 17 00:00:00 2001 From: KunxiSun Date: Thu, 17 Apr 2025 22:58:24 +0800 Subject: [PATCH 09/10] doc: no class docstring like other s --- sqlmodel/sql/sqltypes.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/sqlmodel/sql/sqltypes.py b/sqlmodel/sql/sqltypes.py index 2821ef7b55..f3b553e4ea 100644 --- a/sqlmodel/sql/sqltypes.py +++ b/sqlmodel/sql/sqltypes.py @@ -20,30 +20,6 @@ def load_dialect_impl(self, dialect: Dialect) -> "types.TypeEngine[Any]": class IntEnum(types.TypeDecorator): # type: ignore - """TypeDecorator for Integer-enum conversion. - - Automatically converts Python enum.IntEnum <-> database integers. - - Args: - enum_type (enum.IntEnum): Integer enum class (subclass of enum.IntEnum) - - Example: - >>> class HeroStatus(enum.IntEnum): - ... ACTIVE = 1 - ... DISABLE = 2 - >>>> - >>> from sqlmodel import IntEnum - >>> class Hero(SQLModel): - ... hero_status: HeroStatus = Field(sa_type=sqlmodel.IntEnum(HeroStatus)) - >>> user.hero_status == Status.ACTIVE # Loads back as enum - - Returns: - Optional[enum.IntEnum]: Converted enum instance (None if database value is NULL) - - Raises: - TypeError: For invalid enum types - """ - impl = types.Integer def __init__(self, enum_type: Type[_TIntEnum], *args: Any, **kwargs: Any): From 34fc6d45ee481cbd05edfcee9a45d30ac1a389b0 Mon Sep 17 00:00:00 2001 From: KunxiSun Date: Thu, 17 Apr 2025 22:59:01 +0800 Subject: [PATCH 10/10] perfomance: add cache and remove warning --- sqlmodel/sql/sqltypes.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sqlmodel/sql/sqltypes.py b/sqlmodel/sql/sqltypes.py index f3b553e4ea..2e14aeb5d3 100644 --- a/sqlmodel/sql/sqltypes.py +++ b/sqlmodel/sql/sqltypes.py @@ -21,6 +21,7 @@ def load_dialect_impl(self, dialect: Dialect) -> "types.TypeEngine[Any]": class IntEnum(types.TypeDecorator): # type: ignore impl = types.Integer + cache_ok = True def __init__(self, enum_type: Type[_TIntEnum], *args: Any, **kwargs: Any): super().__init__(*args, **kwargs)