From 6d93a46fe0897ddc785a37a7c41312603cc14e45 Mon Sep 17 00:00:00 2001 From: John Lyu Date: Tue, 26 Nov 2024 10:43:13 +0800 Subject: [PATCH 01/18] support sqlalchemy polymorphic --- sqlmodel/_compat.py | 14 ++++ sqlmodel/main.py | 47 ++++++++++-- tests/test_polymorphic_model.py | 127 ++++++++++++++++++++++++++++++++ 3 files changed, 180 insertions(+), 8 deletions(-) create mode 100644 tests/test_polymorphic_model.py diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 4e80cdc374..6b7d53b165 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -21,6 +21,7 @@ from pydantic import VERSION as P_VERSION from pydantic import BaseModel from pydantic.fields import FieldInfo +from sqlalchemy import inspect from typing_extensions import Annotated, get_args, get_origin # Reassign variable to make it reexported for mypy @@ -290,6 +291,19 @@ def sqlmodel_table_construct( if value is not Undefined: setattr(self_instance, key, value) # End SQLModel override + # Override polymorphic_on default value + mapper = inspect(cls) + polymorphic_on = mapper.polymorphic_on + polymorphic_property = mapper.get_property_by_column(polymorphic_on) + field_info = cls.model_fields.get(polymorphic_property.key) + if field_info: + v = values.get(polymorphic_property.key) + # if model is inherited or polymorphic_on is not explicitly set + # set the polymorphic_on by default + if mapper.inherits or v is None: + setattr( + self_instance, polymorphic_property.key, mapper.polymorphic_identity + ) return self_instance def sqlmodel_validate( diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 3532e81a8e..fcba557872 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -41,9 +41,10 @@ ) from sqlalchemy import Enum as sa_Enum from sqlalchemy.orm import ( + InstrumentedAttribute, Mapped, + MappedColumn, RelationshipProperty, - declared_attr, registry, relationship, ) @@ -544,6 +545,15 @@ def __new__( **pydantic_annotations, **new_cls.__annotations__, } + # pydantic will set class attribute value inherited from parent as field + # default value, reset it back + base_fields = {} + for base in bases[::-1]: + if issubclass(base, BaseModel): + base_fields.update(base.model_fields) + for k, v in new_cls.model_fields.items(): + if isinstance(v.default, InstrumentedAttribute): + new_cls.model_fields[k] = base_fields.get(k) def get_config(name: str) -> Any: config_class_value = get_config_value( @@ -558,9 +568,19 @@ def get_config(name: str) -> Any: config_table = get_config("table") if config_table is True: + if new_cls.__name__ != "SQLModel" and not hasattr(new_cls, "__tablename__"): + new_cls.__tablename__ = new_cls.__name__.lower() # If it was passed by kwargs, ensure it's also set in config set_config_value(model=new_cls, parameter="table", value=config_table) for k, v in get_model_fields(new_cls).items(): + original_v = getattr(new_cls, k, None) + if ( + isinstance(original_v, InstrumentedAttribute) + and k not in class_dict + ): + # The attribute was already set by SQLAlchemy, don't override it + # Needed for polymorphic models, see #36 + continue col = get_column_from_field(v) setattr(new_cls, k, col) # Set a config flag to tell FastAPI that this should be read with a field @@ -594,7 +614,13 @@ def __init__( # trying to create a new SQLAlchemy, for a new table, with the same name, that # triggers an error base_is_table = any(is_table_model_class(base) for base in bases) - if is_table_model_class(cls) and not base_is_table: + polymorphic_identity = dict_.get("__mapper_args__", {}).get( + "polymorphic_identity" + ) + has_polymorphic = polymorphic_identity is not None + + # allow polymorphic models inherit from table models + if is_table_model_class(cls) and (not base_is_table or has_polymorphic): for rel_name, rel_info in cls.__sqlmodel_relationships__.items(): if rel_info.sa_relationship: # There's a SQLAlchemy relationship declared, that takes precedence @@ -641,6 +667,16 @@ def __init__( # Ref: https://github.com/sqlalchemy/sqlalchemy/commit/428ea01f00a9cc7f85e435018565eb6da7af1b77 # Tag: 1.4.36 DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw) + # # patch sqlmodel field's default value to polymorphic_identity + # if has_polymorphic: + # mapper = inspect(cls) + # polymorphic_on = mapper.polymorphic_on + # polymorphic_property = mapper.get_property_by_column(polymorphic_on) + # field = cls.model_fields.get(polymorphic_property.key) + # def get__polymorphic_identity__(kw): + # return polymorphic_identity + # if field: + # field.default_factory = get__polymorphic_identity__ else: ModelMetaclass.__init__(cls, classname, bases, dict_, **kw) @@ -708,7 +744,7 @@ def get_column_from_field(field: Any) -> Column: # type: ignore else: field_info = field.field_info sa_column = getattr(field_info, "sa_column", Undefined) - if isinstance(sa_column, Column): + if isinstance(sa_column, Column) or isinstance(sa_column, MappedColumn): return sa_column sa_type = get_sqlalchemy_type(field) primary_key = getattr(field_info, "primary_key", Undefined) @@ -772,7 +808,6 @@ def get_column_from_field(field: Any) -> Column: # type: ignore class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry): # SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values __slots__ = ("__weakref__",) - __tablename__: ClassVar[Union[str, Callable[..., str]]] __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty[Any]]] __name__: ClassVar[str] metadata: ClassVar[MetaData] @@ -836,10 +871,6 @@ def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]: if not (isinstance(k, str) and k.startswith("_sa_")) ] - @declared_attr # type: ignore - def __tablename__(cls) -> str: - return cls.__name__.lower() - @classmethod def model_validate( cls: Type[_TSQLModel], diff --git a/tests/test_polymorphic_model.py b/tests/test_polymorphic_model.py new file mode 100644 index 0000000000..c9c83301b4 --- /dev/null +++ b/tests/test_polymorphic_model.py @@ -0,0 +1,127 @@ +from typing import Optional + +from sqlalchemy import ForeignKey +from sqlalchemy.orm import mapped_column +from sqlmodel import Field, Session, SQLModel, create_engine, select + + +def test_polymorphic_joined_table(clear_sqlmodel) -> None: + class Hero(SQLModel, table=True): + __tablename__ = "hero" + id: Optional[int] = Field(default=None, primary_key=True) + hero_type: str = Field(default="hero") + + __mapper_args__ = { + "polymorphic_on": "hero_type", + "polymorphic_identity": "hero", + } + + class DarkHero(Hero): + __tablename__ = "dark_hero" + id: Optional[int] = Field( + default=None, + sa_column=mapped_column(ForeignKey("hero.id"), primary_key=True), + ) + dark_power: str = Field( + default="dark", + sa_column=mapped_column( + nullable=False, use_existing_column=True, default="dark" + ), + ) + + __mapper_args__ = { + "polymorphic_identity": "dark", + } + + engine = create_engine("sqlite:///:memory:", echo=True) + SQLModel.metadata.create_all(engine) + with Session(engine) as db: + hero = Hero() + db.add(hero) + dark_hero = DarkHero() + db.add(dark_hero) + db.commit() + statement = select(DarkHero) + result = db.exec(statement).all() + assert len(result) == 1 + assert isinstance(result[0].dark_power, str) + + +def test_polymorphic_joined_table_sm_field(clear_sqlmodel) -> None: + class Hero(SQLModel, table=True): + __tablename__ = "hero" + id: Optional[int] = Field(default=None, primary_key=True) + hero_type: str = Field(default="hero") + + __mapper_args__ = { + "polymorphic_on": "hero_type", + "polymorphic_identity": "hero", + } + + class DarkHero(Hero): + __tablename__ = "dark_hero" + id: Optional[int] = Field( + default=None, + primary_key=True, + foreign_key="hero.id", + ) + dark_power: str = Field( + default="dark", + sa_column=mapped_column( + nullable=False, use_existing_column=True, default="dark" + ), + ) + + __mapper_args__ = { + "polymorphic_identity": "dark", + } + + engine = create_engine("sqlite:///:memory:", echo=True) + SQLModel.metadata.create_all(engine) + with Session(engine) as db: + hero = Hero() + db.add(hero) + dark_hero = DarkHero() + db.add(dark_hero) + db.commit() + statement = select(DarkHero) + result = db.exec(statement).all() + assert len(result) == 1 + assert isinstance(result[0].dark_power, str) + + +def test_polymorphic_single_table(clear_sqlmodel) -> None: + class Hero(SQLModel, table=True): + __tablename__ = "hero" + id: Optional[int] = Field(default=None, primary_key=True) + hero_type: str = Field(default="hero") + + __mapper_args__ = { + "polymorphic_on": "hero_type", + "polymorphic_identity": "hero", + } + + class DarkHero(Hero): + dark_power: str = Field( + default="dark", + sa_column=mapped_column( + nullable=False, use_existing_column=True, default="dark" + ), + ) + + __mapper_args__ = { + "polymorphic_identity": "dark", + } + + engine = create_engine("sqlite:///:memory:", echo=True) + SQLModel.metadata.create_all(engine) + with Session(engine) as db: + hero = Hero() + db.add(hero) + dark_hero = DarkHero() + db.add(dark_hero) + db.commit() + statement = select(DarkHero) + result = db.exec(statement).all() + assert len(result) == 1 + assert isinstance(result[0].dark_power, str) From 589237b80f24e825bb63683cf7bc8a8b4ed44b95 Mon Sep 17 00:00:00 2001 From: John Lyu Date: Tue, 26 Nov 2024 10:51:53 +0800 Subject: [PATCH 02/18] improve docs --- sqlmodel/main.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index fcba557872..f85dfc4b95 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -568,6 +568,9 @@ def get_config(name: str) -> Any: config_table = get_config("table") if config_table is True: + # sqlalchemy mark a class as table by check if it has __tablename__ attribute + # or if __tablename__ is in __annotations__. Only set __tablename__ if it's + # a table model if new_cls.__name__ != "SQLModel" and not hasattr(new_cls, "__tablename__"): new_cls.__tablename__ = new_cls.__name__.lower() # If it was passed by kwargs, ensure it's also set in config @@ -667,16 +670,6 @@ def __init__( # Ref: https://github.com/sqlalchemy/sqlalchemy/commit/428ea01f00a9cc7f85e435018565eb6da7af1b77 # Tag: 1.4.36 DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw) - # # patch sqlmodel field's default value to polymorphic_identity - # if has_polymorphic: - # mapper = inspect(cls) - # polymorphic_on = mapper.polymorphic_on - # polymorphic_property = mapper.get_property_by_column(polymorphic_on) - # field = cls.model_fields.get(polymorphic_property.key) - # def get__polymorphic_identity__(kw): - # return polymorphic_identity - # if field: - # field.default_factory = get__polymorphic_identity__ else: ModelMetaclass.__init__(cls, classname, bases, dict_, **kw) From 4071b0fc6485c9f8f037140af7e4bd65c0d99774 Mon Sep 17 00:00:00 2001 From: John Lyu Date: Tue, 26 Nov 2024 10:54:26 +0800 Subject: [PATCH 03/18] fix polymorphic_on check --- sqlmodel/_compat.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 6b7d53b165..52a98e515d 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -294,16 +294,19 @@ def sqlmodel_table_construct( # Override polymorphic_on default value mapper = inspect(cls) polymorphic_on = mapper.polymorphic_on - polymorphic_property = mapper.get_property_by_column(polymorphic_on) - field_info = cls.model_fields.get(polymorphic_property.key) - if field_info: - v = values.get(polymorphic_property.key) - # if model is inherited or polymorphic_on is not explicitly set - # set the polymorphic_on by default - if mapper.inherits or v is None: - setattr( - self_instance, polymorphic_property.key, mapper.polymorphic_identity - ) + if polymorphic_on: + polymorphic_property = mapper.get_property_by_column(polymorphic_on) + field_info = cls.model_fields.get(polymorphic_property.key) + if field_info: + v = values.get(polymorphic_property.key) + # if model is inherited or polymorphic_on is not explicitly set + # set the polymorphic_on by default + if mapper.inherits or v is None: + setattr( + self_instance, + polymorphic_property.key, + mapper.polymorphic_identity, + ) return self_instance def sqlmodel_validate( From 48f2a88752888c11f060155643b4f4becabff9ff Mon Sep 17 00:00:00 2001 From: John Lyu Date: Tue, 26 Nov 2024 10:57:07 +0800 Subject: [PATCH 04/18] fix polymorphic_on check --- sqlmodel/_compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 52a98e515d..740e27a37d 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -294,7 +294,7 @@ def sqlmodel_table_construct( # Override polymorphic_on default value mapper = inspect(cls) polymorphic_on = mapper.polymorphic_on - if polymorphic_on: + if polymorphic_on is not None: polymorphic_property = mapper.get_property_by_column(polymorphic_on) field_info = cls.model_fields.get(polymorphic_property.key) if field_info: From e6ad74d50a943f006cb4876193d55edba043a727 Mon Sep 17 00:00:00 2001 From: John Lyu Date: Tue, 26 Nov 2024 11:17:55 +0800 Subject: [PATCH 05/18] fix lint --- sqlmodel/_compat.py | 30 ++++++++++++++++-------------- sqlmodel/main.py | 9 +++++---- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 740e27a37d..10742d80d5 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -22,6 +22,7 @@ from pydantic import BaseModel from pydantic.fields import FieldInfo from sqlalchemy import inspect +from sqlalchemy.orm import Mapper from typing_extensions import Annotated, get_args, get_origin # Reassign variable to make it reexported for mypy @@ -293,20 +294,21 @@ def sqlmodel_table_construct( # End SQLModel override # Override polymorphic_on default value mapper = inspect(cls) - polymorphic_on = mapper.polymorphic_on - if polymorphic_on is not None: - polymorphic_property = mapper.get_property_by_column(polymorphic_on) - field_info = cls.model_fields.get(polymorphic_property.key) - if field_info: - v = values.get(polymorphic_property.key) - # if model is inherited or polymorphic_on is not explicitly set - # set the polymorphic_on by default - if mapper.inherits or v is None: - setattr( - self_instance, - polymorphic_property.key, - mapper.polymorphic_identity, - ) + if isinstance(mapper, Mapper): + polymorphic_on = mapper.polymorphic_on + if polymorphic_on is not None: + polymorphic_property = mapper.get_property_by_column(polymorphic_on) + field_info = cls.model_fields.get(polymorphic_property.key) + if field_info: + v = values.get(polymorphic_property.key) + # if model is inherited or polymorphic_on is not explicitly set + # set the polymorphic_on by default + if mapper.inherits or v is None: + setattr( + self_instance, + polymorphic_property.key, + mapper.polymorphic_identity, + ) return self_instance def sqlmodel_validate( diff --git a/sqlmodel/main.py b/sqlmodel/main.py index f85dfc4b95..923079e9b6 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -551,9 +551,10 @@ def __new__( for base in bases[::-1]: if issubclass(base, BaseModel): base_fields.update(base.model_fields) - for k, v in new_cls.model_fields.items(): + fields = get_model_fields(new_cls) + for k, v in fields.items(): if isinstance(v.default, InstrumentedAttribute): - new_cls.model_fields[k] = base_fields.get(k) + fields[k] = base_fields.get(k, FieldInfo()) def get_config(name: str) -> Any: config_class_value = get_config_value( @@ -572,7 +573,7 @@ def get_config(name: str) -> Any: # or if __tablename__ is in __annotations__. Only set __tablename__ if it's # a table model if new_cls.__name__ != "SQLModel" and not hasattr(new_cls, "__tablename__"): - new_cls.__tablename__ = new_cls.__name__.lower() + setattr(new_cls, "__tablename__", new_cls.__name__.lower()) # noqa: B010 # If it was passed by kwargs, ensure it's also set in config set_config_value(model=new_cls, parameter="table", value=config_table) for k, v in get_model_fields(new_cls).items(): @@ -731,7 +732,7 @@ def get_sqlalchemy_type(field: Any) -> Any: raise ValueError(f"{type_} has no matching SQLAlchemy type") -def get_column_from_field(field: Any) -> Column: # type: ignore +def get_column_from_field(field: Any) -> Column | MappedColumn: # type: ignore if IS_PYDANTIC_V2: field_info = field else: From 277953a614ee95ed6b09ec28d8046e0f0c35c90d Mon Sep 17 00:00:00 2001 From: John Lyu Date: Tue, 26 Nov 2024 11:19:28 +0800 Subject: [PATCH 06/18] fix pydantic v1 support --- sqlmodel/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 923079e9b6..1ce0b51a5f 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -550,7 +550,7 @@ def __new__( base_fields = {} for base in bases[::-1]: if issubclass(base, BaseModel): - base_fields.update(base.model_fields) + base_fields.update(get_model_fields(base)) fields = get_model_fields(new_cls) for k, v in fields.items(): if isinstance(v.default, InstrumentedAttribute): From 4aade030813383d6f81a4173a3357dd4dd0e6349 Mon Sep 17 00:00:00 2001 From: John Lyu Date: Tue, 26 Nov 2024 11:22:05 +0800 Subject: [PATCH 07/18] fix type hint for <3.10 --- sqlmodel/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 1ce0b51a5f..d7a5ba52b1 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -732,7 +732,7 @@ def get_sqlalchemy_type(field: Any) -> Any: raise ValueError(f"{type_} has no matching SQLAlchemy type") -def get_column_from_field(field: Any) -> Column | MappedColumn: # type: ignore +def get_column_from_field(field: Any) -> Union[Column, MappedColumn]: # type: ignore if IS_PYDANTIC_V2: field_info = field else: From a3044bbf68b8bdb91124d32ae141e79cbc8836c8 Mon Sep 17 00:00:00 2001 From: John Lyu Date: Tue, 26 Nov 2024 11:24:00 +0800 Subject: [PATCH 08/18] add needs_pydanticv2 mark to test --- tests/test_polymorphic_model.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test_polymorphic_model.py b/tests/test_polymorphic_model.py index c9c83301b4..8cded1bac5 100644 --- a/tests/test_polymorphic_model.py +++ b/tests/test_polymorphic_model.py @@ -4,7 +4,10 @@ from sqlalchemy.orm import mapped_column from sqlmodel import Field, Session, SQLModel, create_engine, select +from tests.conftest import needs_pydanticv2 + +@needs_pydanticv2 def test_polymorphic_joined_table(clear_sqlmodel) -> None: class Hero(SQLModel, table=True): __tablename__ = "hero" @@ -47,6 +50,7 @@ class DarkHero(Hero): assert isinstance(result[0].dark_power, str) +@needs_pydanticv2 def test_polymorphic_joined_table_sm_field(clear_sqlmodel) -> None: class Hero(SQLModel, table=True): __tablename__ = "hero" @@ -90,6 +94,7 @@ class DarkHero(Hero): assert isinstance(result[0].dark_power, str) +@needs_pydanticv2 def test_polymorphic_single_table(clear_sqlmodel) -> None: class Hero(SQLModel, table=True): __tablename__ = "hero" From 015601cd5bd4d805ab8474429d8de7310c883060 Mon Sep 17 00:00:00 2001 From: John Lyu Date: Tue, 3 Dec 2024 17:59:21 +0800 Subject: [PATCH 09/18] improve code structure --- sqlmodel/_compat.py | 42 ++++++++++++++++++++------------- tests/test_polymorphic_model.py | 4 ++-- 2 files changed, 28 insertions(+), 18 deletions(-) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 10742d80d5..8a7e6fd75d 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -66,6 +66,29 @@ def _is_union_type(t: Any) -> bool: finish_init: ContextVar[bool] = ContextVar("finish_init", default=True) +def set_polymorphic_default_value(self_instance, values): + """By defalut, when init a model, pydantic will set the polymorphic_on + value to field default value. But when inherit a model, the polymorphic_on + should be set to polymorphic_identity value by default.""" + cls = type(self_instance) + mapper = inspect(cls) + if isinstance(mapper, Mapper): + polymorphic_on = mapper.polymorphic_on + if polymorphic_on is not None: + polymorphic_property = mapper.get_property_by_column(polymorphic_on) + field_info = get_model_fields(cls).get(polymorphic_property.key) + if field_info: + v = values.get(polymorphic_property.key) + # if model is inherited or polymorphic_on is not explicitly set + # set the polymorphic_on by default + if mapper.inherits or v is None: + setattr( + self_instance, + polymorphic_property.key, + mapper.polymorphic_identity, + ) + + @contextmanager def partial_init() -> Generator[None, None, None]: token = finish_init.set(False) @@ -293,22 +316,7 @@ def sqlmodel_table_construct( setattr(self_instance, key, value) # End SQLModel override # Override polymorphic_on default value - mapper = inspect(cls) - if isinstance(mapper, Mapper): - polymorphic_on = mapper.polymorphic_on - if polymorphic_on is not None: - polymorphic_property = mapper.get_property_by_column(polymorphic_on) - field_info = cls.model_fields.get(polymorphic_property.key) - if field_info: - v = values.get(polymorphic_property.key) - # if model is inherited or polymorphic_on is not explicitly set - # set the polymorphic_on by default - if mapper.inherits or v is None: - setattr( - self_instance, - polymorphic_property.key, - mapper.polymorphic_identity, - ) + set_polymorphic_default_value(self_instance, values) return self_instance def sqlmodel_validate( @@ -592,3 +600,5 @@ def sqlmodel_init(*, self: "SQLModel", data: Dict[str, Any]) -> None: for key in non_pydantic_keys: if key in self.__sqlmodel_relationships__: setattr(self, key, data[key]) + # Override polymorphic_on default value + set_polymorphic_default_value(self, values) diff --git a/tests/test_polymorphic_model.py b/tests/test_polymorphic_model.py index 8cded1bac5..f17e030a86 100644 --- a/tests/test_polymorphic_model.py +++ b/tests/test_polymorphic_model.py @@ -51,7 +51,7 @@ class DarkHero(Hero): @needs_pydanticv2 -def test_polymorphic_joined_table_sm_field(clear_sqlmodel) -> None: +def test_polymorphic_joined_table_with_sqlmodel_field(clear_sqlmodel) -> None: class Hero(SQLModel, table=True): __tablename__ = "hero" id: Optional[int] = Field(default=None, primary_key=True) @@ -123,7 +123,7 @@ class DarkHero(Hero): with Session(engine) as db: hero = Hero() db.add(hero) - dark_hero = DarkHero() + dark_hero = DarkHero(dark_power="pokey") db.add(dark_hero) db.commit() statement = select(DarkHero) From 66c1d93cfef06f595b2675ad802f4841678d9faa Mon Sep 17 00:00:00 2001 From: John Lyu Date: Tue, 3 Dec 2024 18:04:46 +0800 Subject: [PATCH 10/18] lint --- sqlmodel/_compat.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 8a7e6fd75d..9a0e570e8e 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -66,12 +66,16 @@ def _is_union_type(t: Any) -> bool: finish_init: ContextVar[bool] = ContextVar("finish_init", default=True) -def set_polymorphic_default_value(self_instance, values): +def set_polymorphic_default_value( + self_instance: _TSQLModel, + values: Dict[str, Any], +) -> bool: """By defalut, when init a model, pydantic will set the polymorphic_on value to field default value. But when inherit a model, the polymorphic_on should be set to polymorphic_identity value by default.""" cls = type(self_instance) mapper = inspect(cls) + ret = False if isinstance(mapper, Mapper): polymorphic_on = mapper.polymorphic_on if polymorphic_on is not None: @@ -87,6 +91,8 @@ def set_polymorphic_default_value(self_instance, values): polymorphic_property.key, mapper.polymorphic_identity, ) + ret = True + return ret @contextmanager From 0efd1bfaa0cc6d0b1f590072e996d99f348111b5 Mon Sep 17 00:00:00 2001 From: John Lyu Date: Tue, 3 Dec 2024 18:08:01 +0800 Subject: [PATCH 11/18] remove effort of pydantic v1 --- sqlmodel/_compat.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 9a0e570e8e..0b91702e26 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -606,5 +606,3 @@ def sqlmodel_init(*, self: "SQLModel", data: Dict[str, Any]) -> None: for key in non_pydantic_keys: if key in self.__sqlmodel_relationships__: setattr(self, key, data[key]) - # Override polymorphic_on default value - set_polymorphic_default_value(self, values) From 84d739e0e4baa659b480bce3db8808feca811bac Mon Sep 17 00:00:00 2001 From: John Lyu Date: Thu, 12 Dec 2024 10:41:35 +0800 Subject: [PATCH 12/18] Update sqlmodel/_compat.py Co-authored-by: John Pocock --- sqlmodel/_compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 0b91702e26..3853d528d9 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -70,7 +70,7 @@ def set_polymorphic_default_value( self_instance: _TSQLModel, values: Dict[str, Any], ) -> bool: - """By defalut, when init a model, pydantic will set the polymorphic_on + """By default, when init a model, pydantic will set the polymorphic_on value to field default value. But when inherit a model, the polymorphic_on should be set to polymorphic_identity value by default.""" cls = type(self_instance) From dbd0101c40a67f859aa301b5dbaba1b966f05a25 Mon Sep 17 00:00:00 2001 From: John Lyu Date: Wed, 5 Feb 2025 10:21:13 +0800 Subject: [PATCH 13/18] fix default value is InstrumentedAttribute in inherit --- sqlmodel/main.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index d7a5ba52b1..b660a2fd96 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -539,22 +539,24 @@ def __new__( config_kwargs = { key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs } - new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs) + base_fields = {} + base_annotations = {} + for base in bases[::-1]: + if issubclass(base, BaseModel): + base_fields.update(get_model_fields(base)) + base_annotations.update(base.__annotations__) + # use base_fields overwriting the ones from the class for inherit + # if base is a sqlalchemy model, it's attributes will be an InstrumentedAttribute + # thus pydantic will use the value of the attribute as the default value + dict_used["__annotations__"].update(base_annotations) + new_cls = super().__new__( + cls, name, bases, dict_used | base_fields, **config_kwargs + ) new_cls.__annotations__ = { **relationship_annotations, **pydantic_annotations, **new_cls.__annotations__, } - # pydantic will set class attribute value inherited from parent as field - # default value, reset it back - base_fields = {} - for base in bases[::-1]: - if issubclass(base, BaseModel): - base_fields.update(get_model_fields(base)) - fields = get_model_fields(new_cls) - for k, v in fields.items(): - if isinstance(v.default, InstrumentedAttribute): - fields[k] = base_fields.get(k, FieldInfo()) def get_config(name: str) -> Any: config_class_value = get_config_value( From b1ed8c36caf001438d12dbcf6e65c93a5bee6417 Mon Sep 17 00:00:00 2001 From: John Lyu Date: Wed, 5 Feb 2025 14:20:42 +0800 Subject: [PATCH 14/18] fix inherit order --- sqlmodel/main.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index b660a2fd96..a8d4936be0 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -548,9 +548,10 @@ def __new__( # use base_fields overwriting the ones from the class for inherit # if base is a sqlalchemy model, it's attributes will be an InstrumentedAttribute # thus pydantic will use the value of the attribute as the default value - dict_used["__annotations__"].update(base_annotations) + base_annotations.update(dict_used["__annotations__"]) + dict_used["__annotations__"] = base_annotations new_cls = super().__new__( - cls, name, bases, dict_used | base_fields, **config_kwargs + cls, name, bases, base_fields | dict_used, **config_kwargs ) new_cls.__annotations__ = { **relationship_annotations, From 5d1bf5c2f3e8a63e817313f58ad97fb9ffddf34a Mon Sep 17 00:00:00 2001 From: John Lyu Date: Wed, 5 Feb 2025 14:23:02 +0800 Subject: [PATCH 15/18] support python < 3.9 --- sqlmodel/main.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index a8d4936be0..e0e1cc7440 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -550,8 +550,9 @@ def __new__( # thus pydantic will use the value of the attribute as the default value base_annotations.update(dict_used["__annotations__"]) dict_used["__annotations__"] = base_annotations + base_fields.update(dict_used) new_cls = super().__new__( - cls, name, bases, base_fields | dict_used, **config_kwargs + cls, name, bases, base_fields, **config_kwargs ) new_cls.__annotations__ = { **relationship_annotations, From ccbb92aeafe601cd6d3366afa690af568ccee064 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 5 Feb 2025 06:24:38 +0000 Subject: [PATCH 16/18] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/main.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index e0e1cc7440..69a39f42ed 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -551,9 +551,7 @@ def __new__( base_annotations.update(dict_used["__annotations__"]) dict_used["__annotations__"] = base_annotations base_fields.update(dict_used) - new_cls = super().__new__( - cls, name, bases, base_fields, **config_kwargs - ) + new_cls = super().__new__(cls, name, bases, base_fields, **config_kwargs) new_cls.__annotations__ = { **relationship_annotations, **pydantic_annotations, From d0d02887978632bd749ff47433c679661f0df006 Mon Sep 17 00:00:00 2001 From: John Lyu Date: Wed, 5 Feb 2025 15:34:03 +0800 Subject: [PATCH 17/18] skip polymorphic in pydantic v1 --- sqlmodel/main.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index e0e1cc7440..3ee2490e5d 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -539,21 +539,21 @@ def __new__( config_kwargs = { key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs } - base_fields = {} - base_annotations = {} - for base in bases[::-1]: - if issubclass(base, BaseModel): - base_fields.update(get_model_fields(base)) - base_annotations.update(base.__annotations__) - # use base_fields overwriting the ones from the class for inherit - # if base is a sqlalchemy model, it's attributes will be an InstrumentedAttribute - # thus pydantic will use the value of the attribute as the default value - base_annotations.update(dict_used["__annotations__"]) - dict_used["__annotations__"] = base_annotations - base_fields.update(dict_used) - new_cls = super().__new__( - cls, name, bases, base_fields, **config_kwargs - ) + if IS_PYDANTIC_V2: + base_fields = {} + base_annotations = {} + for base in bases[::-1]: + if issubclass(base, BaseModel): + base_fields.update(get_model_fields(base)) + base_annotations.update(base.__annotations__) + # use base_fields overwriting the ones from the class for inherit + # if base is a sqlalchemy model, it's attributes will be an InstrumentedAttribute + # thus pydantic will use the value of the attribute as the default value + base_annotations.update(dict_used["__annotations__"]) + dict_used["__annotations__"] = base_annotations + base_fields.update(dict_used) + dict_used = base_fields + new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs) new_cls.__annotations__ = { **relationship_annotations, **pydantic_annotations, From c1dff79a8852eea08b5bbbb4a8a96e5a79820dcb Mon Sep 17 00:00:00 2001 From: John Lyu Date: Wed, 12 Mar 2025 16:10:46 +0800 Subject: [PATCH 18/18] disable pydantic warning during polymorphic --- sqlmodel/main.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 3ee2490e5d..0980bb3e53 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -1,5 +1,6 @@ import ipaddress import uuid +import warnings import weakref from datetime import date, datetime, time, timedelta from decimal import Decimal @@ -539,6 +540,7 @@ def __new__( config_kwargs = { key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs } + is_polymorphic = False if IS_PYDANTIC_V2: base_fields = {} base_annotations = {} @@ -546,6 +548,8 @@ def __new__( if issubclass(base, BaseModel): base_fields.update(get_model_fields(base)) base_annotations.update(base.__annotations__) + if hasattr(base, "__tablename__"): + is_polymorphic = True # use base_fields overwriting the ones from the class for inherit # if base is a sqlalchemy model, it's attributes will be an InstrumentedAttribute # thus pydantic will use the value of the attribute as the default value @@ -553,7 +557,16 @@ def __new__( dict_used["__annotations__"] = base_annotations base_fields.update(dict_used) dict_used = base_fields - new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs) + # if is_polymorphic, disable pydantic `shadows an attribute` warning + if is_polymorphic: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="Field name .+ shadows an attribute in parent.+", + ) + new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs) + else: + new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs) new_cls.__annotations__ = { **relationship_annotations, **pydantic_annotations,