diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 45a41997fe..d708b352af 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -643,6 +643,12 @@ def __init__( ModelMetaclass.__init__(cls, classname, bases, dict_, **kw) +class SchemaEnum(sa_Enum): + def __init__(self, *args: Any, **kwargs: Any) -> None: + kwargs["inherit_schema"] = True + super().__init__(*args, **kwargs) + + def get_sqlalchemy_type(field: Any) -> Any: if IS_PYDANTIC_V2: field_info = field @@ -657,7 +663,7 @@ def get_sqlalchemy_type(field: Any) -> Any: # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI if issubclass(type_, Enum): - return sa_Enum(type_) + return SchemaEnum(type_) if issubclass( type_, (