diff --git a/sqlmodel/__init__.py b/sqlmodel/__init__.py index 28bb305237..20473611d9 100644 --- a/sqlmodel/__init__.py +++ b/sqlmodel/__init__.py @@ -141,3 +141,4 @@ from .sql.expression import type_coerce as type_coerce from .sql.expression import within_group as within_group from .sql.sqltypes import AutoString as AutoString +from .sql.sqltypes import IntEnum as IntEnum diff --git a/sqlmodel/sql/sqltypes.py b/sqlmodel/sql/sqltypes.py index 512daacbab..2e14aeb5d3 100644 --- a/sqlmodel/sql/sqltypes.py +++ b/sqlmodel/sql/sqltypes.py @@ -1,8 +1,11 @@ -from typing import Any, cast +from enum import IntEnum as _IntEnum +from typing import Any, Optional, Type, TypeVar, cast from sqlalchemy import types from sqlalchemy.engine.interfaces import Dialect +_TIntEnum = TypeVar("_TIntEnum", bound="_IntEnum") + class AutoString(types.TypeDecorator): # type: ignore impl = types.String @@ -14,3 +17,39 @@ def load_dialect_impl(self, dialect: Dialect) -> "types.TypeEngine[Any]": if impl.length is None and dialect.name == "mysql": return dialect.type_descriptor(types.String(self.mysql_default_length)) return super().load_dialect_impl(dialect) + + +class IntEnum(types.TypeDecorator): # type: ignore + impl = types.Integer + cache_ok = True + + def __init__(self, enum_type: Type[_TIntEnum], *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + + # validate the input enum type + if not issubclass(enum_type, _IntEnum): + raise TypeError("Input must be enum.IntEnum") + + self.enum_type = enum_type + + def process_result_value( # type: ignore[override] + self, + value: Optional[int], + dialect: Dialect, + ) -> Optional[_TIntEnum]: + if value is None: + return None + + result = self.enum_type(value) + return result + + def process_bind_param( + self, + value: Optional[_TIntEnum], + dialect: Dialect, + ) -> Optional[int]: + if value is None: + return None + + result = value.value + return result diff --git a/tests/test_enums.py b/tests/test_enums.py index 2808f3f9a9..aeff988a43 100644 --- a/tests/test_enums.py +++ b/tests/test_enums.py @@ -63,15 +63,22 @@ def test_json_schema_flat_model_pydantic_v1(): "properties": { "id": {"title": "Id", "type": "string", "format": "uuid"}, "enum_field": {"$ref": "#/definitions/MyEnum1"}, + "int_enum_field": {"$ref": "#/definitions/MyEnum3"}, }, - "required": ["id", "enum_field"], + "required": ["id", "enum_field", "int_enum_field"], "definitions": { "MyEnum1": { "title": "MyEnum1", "description": "An enumeration.", "enum": ["A", "B"], "type": "string", - } + }, + "MyEnum3": { + "title": "MyEnum3", + "description": "An enumeration.", + "enum": [1, 2], + "type": "integer", + }, }, } @@ -84,15 +91,22 @@ def test_json_schema_inherit_model_pydantic_v1(): "properties": { "id": {"title": "Id", "type": "string", "format": "uuid"}, "enum_field": {"$ref": "#/definitions/MyEnum2"}, + "int_enum_field": {"$ref": "#/definitions/MyEnum3"}, }, - "required": ["id", "enum_field"], + "required": ["id", "enum_field", "int_enum_field"], "definitions": { "MyEnum2": { "title": "MyEnum2", "description": "An enumeration.", "enum": ["C", "D"], "type": "string", - } + }, + "MyEnum3": { + "title": "MyEnum3", + "description": "An enumeration.", + "enum": [1, 2], + "type": "integer", + }, }, } @@ -105,10 +119,12 @@ def test_json_schema_flat_model_pydantic_v2(): "properties": { "id": {"title": "Id", "type": "string", "format": "uuid"}, "enum_field": {"$ref": "#/$defs/MyEnum1"}, + "int_enum_field": {"$ref": "#/$defs/MyEnum3"}, }, - "required": ["id", "enum_field"], + "required": ["id", "enum_field", "int_enum_field"], "$defs": { - "MyEnum1": {"enum": ["A", "B"], "title": "MyEnum1", "type": "string"} + "MyEnum1": {"enum": ["A", "B"], "title": "MyEnum1", "type": "string"}, + "MyEnum3": {"enum": [1, 2], "title": "MyEnum3", "type": "integer"}, }, } @@ -121,9 +137,11 @@ def test_json_schema_inherit_model_pydantic_v2(): "properties": { "id": {"title": "Id", "type": "string", "format": "uuid"}, "enum_field": {"$ref": "#/$defs/MyEnum2"}, + "int_enum_field": {"$ref": "#/$defs/MyEnum3"}, }, - "required": ["id", "enum_field"], + "required": ["id", "enum_field", "int_enum_field"], "$defs": { - "MyEnum2": {"enum": ["C", "D"], "title": "MyEnum2", "type": "string"} + "MyEnum2": {"enum": ["C", "D"], "title": "MyEnum2", "type": "string"}, + "MyEnum3": {"enum": [1, 2], "title": "MyEnum3", "type": "integer"}, }, } diff --git a/tests/test_enums_models.py b/tests/test_enums_models.py index b46ccb7d2b..c6dac3723a 100644 --- a/tests/test_enums_models.py +++ b/tests/test_enums_models.py @@ -1,7 +1,7 @@ import enum import uuid -from sqlmodel import Field, SQLModel +from sqlmodel import Field, IntEnum, SQLModel class MyEnum1(str, enum.Enum): @@ -14,14 +14,21 @@ class MyEnum2(str, enum.Enum): D = "D" +class MyEnum3(enum.IntEnum): + E = 1 + F = 2 + + class BaseModel(SQLModel): id: uuid.UUID = Field(primary_key=True) enum_field: MyEnum2 + int_enum_field: MyEnum3 class FlatModel(SQLModel, table=True): id: uuid.UUID = Field(primary_key=True) enum_field: MyEnum1 + int_enum_field: MyEnum3 = Field(sa_type=IntEnum(MyEnum3)) class InheritModel(BaseModel, table=True):