Skip to content

✨ Add IntEnum for sqltypes #1337

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions sqlmodel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,4 @@
from .sql.expression import type_coerce as type_coerce
from .sql.expression import within_group as within_group
from .sql.sqltypes import AutoString as AutoString
from .sql.sqltypes import IntEnum as IntEnum
41 changes: 40 additions & 1 deletion sqlmodel/sql/sqltypes.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import Any, cast
from enum import IntEnum as _IntEnum
from typing import Any, Optional, Type, TypeVar, cast

from sqlalchemy import types
from sqlalchemy.engine.interfaces import Dialect

_TIntEnum = TypeVar("_TIntEnum", bound="_IntEnum")


class AutoString(types.TypeDecorator): # type: ignore
impl = types.String
Expand All @@ -14,3 +17,39 @@ def load_dialect_impl(self, dialect: Dialect) -> "types.TypeEngine[Any]":
if impl.length is None and dialect.name == "mysql":
return dialect.type_descriptor(types.String(self.mysql_default_length))
return super().load_dialect_impl(dialect)


class IntEnum(types.TypeDecorator): # type: ignore
impl = types.Integer
cache_ok = True

def __init__(self, enum_type: Type[_TIntEnum], *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)

# validate the input enum type
if not issubclass(enum_type, _IntEnum):
raise TypeError("Input must be enum.IntEnum")

self.enum_type = enum_type

def process_result_value( # type: ignore[override]
self,
value: Optional[int],
dialect: Dialect,
) -> Optional[_TIntEnum]:
if value is None:
return None

result = self.enum_type(value)
return result

def process_bind_param(
self,
value: Optional[_TIntEnum],
dialect: Dialect,
) -> Optional[int]:
if value is None:
return None

result = value.value
return result
34 changes: 26 additions & 8 deletions tests/test_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,22 @@ def test_json_schema_flat_model_pydantic_v1():
"properties": {
"id": {"title": "Id", "type": "string", "format": "uuid"},
"enum_field": {"$ref": "#/definitions/MyEnum1"},
"int_enum_field": {"$ref": "#/definitions/MyEnum3"},
},
"required": ["id", "enum_field"],
"required": ["id", "enum_field", "int_enum_field"],
"definitions": {
"MyEnum1": {
"title": "MyEnum1",
"description": "An enumeration.",
"enum": ["A", "B"],
"type": "string",
}
},
"MyEnum3": {
"title": "MyEnum3",
"description": "An enumeration.",
"enum": [1, 2],
"type": "integer",
},
},
}

Expand All @@ -84,15 +91,22 @@ def test_json_schema_inherit_model_pydantic_v1():
"properties": {
"id": {"title": "Id", "type": "string", "format": "uuid"},
"enum_field": {"$ref": "#/definitions/MyEnum2"},
"int_enum_field": {"$ref": "#/definitions/MyEnum3"},
},
"required": ["id", "enum_field"],
"required": ["id", "enum_field", "int_enum_field"],
"definitions": {
"MyEnum2": {
"title": "MyEnum2",
"description": "An enumeration.",
"enum": ["C", "D"],
"type": "string",
}
},
"MyEnum3": {
"title": "MyEnum3",
"description": "An enumeration.",
"enum": [1, 2],
"type": "integer",
},
},
}

Expand All @@ -105,10 +119,12 @@ def test_json_schema_flat_model_pydantic_v2():
"properties": {
"id": {"title": "Id", "type": "string", "format": "uuid"},
"enum_field": {"$ref": "#/$defs/MyEnum1"},
"int_enum_field": {"$ref": "#/$defs/MyEnum3"},
},
"required": ["id", "enum_field"],
"required": ["id", "enum_field", "int_enum_field"],
"$defs": {
"MyEnum1": {"enum": ["A", "B"], "title": "MyEnum1", "type": "string"}
"MyEnum1": {"enum": ["A", "B"], "title": "MyEnum1", "type": "string"},
"MyEnum3": {"enum": [1, 2], "title": "MyEnum3", "type": "integer"},
},
}

Expand All @@ -121,9 +137,11 @@ def test_json_schema_inherit_model_pydantic_v2():
"properties": {
"id": {"title": "Id", "type": "string", "format": "uuid"},
"enum_field": {"$ref": "#/$defs/MyEnum2"},
"int_enum_field": {"$ref": "#/$defs/MyEnum3"},
},
"required": ["id", "enum_field"],
"required": ["id", "enum_field", "int_enum_field"],
"$defs": {
"MyEnum2": {"enum": ["C", "D"], "title": "MyEnum2", "type": "string"}
"MyEnum2": {"enum": ["C", "D"], "title": "MyEnum2", "type": "string"},
"MyEnum3": {"enum": [1, 2], "title": "MyEnum3", "type": "integer"},
},
}
9 changes: 8 additions & 1 deletion tests/test_enums_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import enum
import uuid

from sqlmodel import Field, SQLModel
from sqlmodel import Field, IntEnum, SQLModel


class MyEnum1(str, enum.Enum):
Expand All @@ -14,14 +14,21 @@ class MyEnum2(str, enum.Enum):
D = "D"


class MyEnum3(enum.IntEnum):
E = 1
F = 2


class BaseModel(SQLModel):
id: uuid.UUID = Field(primary_key=True)
enum_field: MyEnum2
int_enum_field: MyEnum3


class FlatModel(SQLModel, table=True):
id: uuid.UUID = Field(primary_key=True)
enum_field: MyEnum1
int_enum_field: MyEnum3 = Field(sa_type=IntEnum(MyEnum3))


class InheritModel(BaseModel, table=True):
Expand Down