diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 4e80cdc374..3853d528d9 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -21,6 +21,8 @@ from pydantic import VERSION as P_VERSION from pydantic import BaseModel from pydantic.fields import FieldInfo +from sqlalchemy import inspect +from sqlalchemy.orm import Mapper from typing_extensions import Annotated, get_args, get_origin # Reassign variable to make it reexported for mypy @@ -64,6 +66,35 @@ def _is_union_type(t: Any) -> bool: finish_init: ContextVar[bool] = ContextVar("finish_init", default=True) +def set_polymorphic_default_value( + self_instance: _TSQLModel, + values: Dict[str, Any], +) -> bool: + """By default, when init a model, pydantic will set the polymorphic_on + value to field default value. But when inherit a model, the polymorphic_on + should be set to polymorphic_identity value by default.""" + cls = type(self_instance) + mapper = inspect(cls) + ret = False + if isinstance(mapper, Mapper): + polymorphic_on = mapper.polymorphic_on + if polymorphic_on is not None: + polymorphic_property = mapper.get_property_by_column(polymorphic_on) + field_info = get_model_fields(cls).get(polymorphic_property.key) + if field_info: + v = values.get(polymorphic_property.key) + # if model is inherited or polymorphic_on is not explicitly set + # set the polymorphic_on by default + if mapper.inherits or v is None: + setattr( + self_instance, + polymorphic_property.key, + mapper.polymorphic_identity, + ) + ret = True + return ret + + @contextmanager def partial_init() -> Generator[None, None, None]: token = finish_init.set(False) @@ -290,6 +321,8 @@ def sqlmodel_table_construct( if value is not Undefined: setattr(self_instance, key, value) # End SQLModel override + # Override polymorphic_on default value + set_polymorphic_default_value(self_instance, values) return self_instance def sqlmodel_validate( diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 3532e81a8e..0980bb3e53 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -1,5 +1,6 @@ import ipaddress import uuid +import warnings import weakref from datetime import date, datetime, time, timedelta from decimal import Decimal @@ -41,9 +42,10 @@ ) from sqlalchemy import Enum as sa_Enum from sqlalchemy.orm import ( + InstrumentedAttribute, Mapped, + MappedColumn, RelationshipProperty, - declared_attr, registry, relationship, ) @@ -538,7 +540,33 @@ def __new__( config_kwargs = { key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs } - new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs) + is_polymorphic = False + if IS_PYDANTIC_V2: + base_fields = {} + base_annotations = {} + for base in bases[::-1]: + if issubclass(base, BaseModel): + base_fields.update(get_model_fields(base)) + base_annotations.update(base.__annotations__) + if hasattr(base, "__tablename__"): + is_polymorphic = True + # use base_fields overwriting the ones from the class for inherit + # if base is a sqlalchemy model, it's attributes will be an InstrumentedAttribute + # thus pydantic will use the value of the attribute as the default value + base_annotations.update(dict_used["__annotations__"]) + dict_used["__annotations__"] = base_annotations + base_fields.update(dict_used) + dict_used = base_fields + # if is_polymorphic, disable pydantic `shadows an attribute` warning + if is_polymorphic: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="Field name .+ shadows an attribute in parent.+", + ) + new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs) + else: + new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs) new_cls.__annotations__ = { **relationship_annotations, **pydantic_annotations, @@ -558,9 +586,22 @@ def get_config(name: str) -> Any: config_table = get_config("table") if config_table is True: + # sqlalchemy mark a class as table by check if it has __tablename__ attribute + # or if __tablename__ is in __annotations__. Only set __tablename__ if it's + # a table model + if new_cls.__name__ != "SQLModel" and not hasattr(new_cls, "__tablename__"): + setattr(new_cls, "__tablename__", new_cls.__name__.lower()) # noqa: B010 # If it was passed by kwargs, ensure it's also set in config set_config_value(model=new_cls, parameter="table", value=config_table) for k, v in get_model_fields(new_cls).items(): + original_v = getattr(new_cls, k, None) + if ( + isinstance(original_v, InstrumentedAttribute) + and k not in class_dict + ): + # The attribute was already set by SQLAlchemy, don't override it + # Needed for polymorphic models, see #36 + continue col = get_column_from_field(v) setattr(new_cls, k, col) # Set a config flag to tell FastAPI that this should be read with a field @@ -594,7 +635,13 @@ def __init__( # trying to create a new SQLAlchemy, for a new table, with the same name, that # triggers an error base_is_table = any(is_table_model_class(base) for base in bases) - if is_table_model_class(cls) and not base_is_table: + polymorphic_identity = dict_.get("__mapper_args__", {}).get( + "polymorphic_identity" + ) + has_polymorphic = polymorphic_identity is not None + + # allow polymorphic models inherit from table models + if is_table_model_class(cls) and (not base_is_table or has_polymorphic): for rel_name, rel_info in cls.__sqlmodel_relationships__.items(): if rel_info.sa_relationship: # There's a SQLAlchemy relationship declared, that takes precedence @@ -702,13 +749,13 @@ def get_sqlalchemy_type(field: Any) -> Any: raise ValueError(f"{type_} has no matching SQLAlchemy type") -def get_column_from_field(field: Any) -> Column: # type: ignore +def get_column_from_field(field: Any) -> Union[Column, MappedColumn]: # type: ignore if IS_PYDANTIC_V2: field_info = field else: field_info = field.field_info sa_column = getattr(field_info, "sa_column", Undefined) - if isinstance(sa_column, Column): + if isinstance(sa_column, Column) or isinstance(sa_column, MappedColumn): return sa_column sa_type = get_sqlalchemy_type(field) primary_key = getattr(field_info, "primary_key", Undefined) @@ -772,7 +819,6 @@ def get_column_from_field(field: Any) -> Column: # type: ignore class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry): # SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values __slots__ = ("__weakref__",) - __tablename__: ClassVar[Union[str, Callable[..., str]]] __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty[Any]]] __name__: ClassVar[str] metadata: ClassVar[MetaData] @@ -836,10 +882,6 @@ def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]: if not (isinstance(k, str) and k.startswith("_sa_")) ] - @declared_attr # type: ignore - def __tablename__(cls) -> str: - return cls.__name__.lower() - @classmethod def model_validate( cls: Type[_TSQLModel], diff --git a/tests/test_polymorphic_model.py b/tests/test_polymorphic_model.py new file mode 100644 index 0000000000..f17e030a86 --- /dev/null +++ b/tests/test_polymorphic_model.py @@ -0,0 +1,132 @@ +from typing import Optional + +from sqlalchemy import ForeignKey +from sqlalchemy.orm import mapped_column +from sqlmodel import Field, Session, SQLModel, create_engine, select + +from tests.conftest import needs_pydanticv2 + + +@needs_pydanticv2 +def test_polymorphic_joined_table(clear_sqlmodel) -> None: + class Hero(SQLModel, table=True): + __tablename__ = "hero" + id: Optional[int] = Field(default=None, primary_key=True) + hero_type: str = Field(default="hero") + + __mapper_args__ = { + "polymorphic_on": "hero_type", + "polymorphic_identity": "hero", + } + + class DarkHero(Hero): + __tablename__ = "dark_hero" + id: Optional[int] = Field( + default=None, + sa_column=mapped_column(ForeignKey("hero.id"), primary_key=True), + ) + dark_power: str = Field( + default="dark", + sa_column=mapped_column( + nullable=False, use_existing_column=True, default="dark" + ), + ) + + __mapper_args__ = { + "polymorphic_identity": "dark", + } + + engine = create_engine("sqlite:///:memory:", echo=True) + SQLModel.metadata.create_all(engine) + with Session(engine) as db: + hero = Hero() + db.add(hero) + dark_hero = DarkHero() + db.add(dark_hero) + db.commit() + statement = select(DarkHero) + result = db.exec(statement).all() + assert len(result) == 1 + assert isinstance(result[0].dark_power, str) + + +@needs_pydanticv2 +def test_polymorphic_joined_table_with_sqlmodel_field(clear_sqlmodel) -> None: + class Hero(SQLModel, table=True): + __tablename__ = "hero" + id: Optional[int] = Field(default=None, primary_key=True) + hero_type: str = Field(default="hero") + + __mapper_args__ = { + "polymorphic_on": "hero_type", + "polymorphic_identity": "hero", + } + + class DarkHero(Hero): + __tablename__ = "dark_hero" + id: Optional[int] = Field( + default=None, + primary_key=True, + foreign_key="hero.id", + ) + dark_power: str = Field( + default="dark", + sa_column=mapped_column( + nullable=False, use_existing_column=True, default="dark" + ), + ) + + __mapper_args__ = { + "polymorphic_identity": "dark", + } + + engine = create_engine("sqlite:///:memory:", echo=True) + SQLModel.metadata.create_all(engine) + with Session(engine) as db: + hero = Hero() + db.add(hero) + dark_hero = DarkHero() + db.add(dark_hero) + db.commit() + statement = select(DarkHero) + result = db.exec(statement).all() + assert len(result) == 1 + assert isinstance(result[0].dark_power, str) + + +@needs_pydanticv2 +def test_polymorphic_single_table(clear_sqlmodel) -> None: + class Hero(SQLModel, table=True): + __tablename__ = "hero" + id: Optional[int] = Field(default=None, primary_key=True) + hero_type: str = Field(default="hero") + + __mapper_args__ = { + "polymorphic_on": "hero_type", + "polymorphic_identity": "hero", + } + + class DarkHero(Hero): + dark_power: str = Field( + default="dark", + sa_column=mapped_column( + nullable=False, use_existing_column=True, default="dark" + ), + ) + + __mapper_args__ = { + "polymorphic_identity": "dark", + } + + engine = create_engine("sqlite:///:memory:", echo=True) + SQLModel.metadata.create_all(engine) + with Session(engine) as db: + hero = Hero() + db.add(hero) + dark_hero = DarkHero(dark_power="pokey") + db.add(dark_hero) + db.commit() + statement = select(DarkHero) + result = db.exec(statement).all() + assert len(result) == 1 + assert isinstance(result[0].dark_power, str)