diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 38dd501c4a..9797dbf257 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -164,6 +164,9 @@ def get_relationship_to( # If a list, then also get the real field elif origin is list: use_annotation = get_args(annotation)[0] + # If a dict, then use the value type + elif origin is dict: + use_annotation = get_args(annotation)[1] return get_relationship_to( name=name, rel_info=rel_info, annotation=use_annotation diff --git a/tests/test_attribute_keyed_dict.py b/tests/test_attribute_keyed_dict.py new file mode 100644 index 0000000000..9d06196396 --- /dev/null +++ b/tests/test_attribute_keyed_dict.py @@ -0,0 +1,48 @@ +from enum import Enum +from typing import Dict, Optional + +from sqlalchemy.orm.collections import attribute_keyed_dict +from sqlmodel import Field, Index, Relationship, Session, SQLModel, create_engine + + +def test_attribute_keyed_dict_works(clear_sqlmodel): + class Color(str, Enum): + Orange = "Orange" + Blue = "Blue" + + class Child(SQLModel, table=True): + __tablename__ = "children" + __table_args__ = ( + Index("ix_children_parent_id_color", "parent_id", "color", unique=True), + ) + + id: Optional[int] = Field(primary_key=True, default=None) + parent_id: int = Field(foreign_key="parents.id") + color: Color + value: int + + class Parent(SQLModel, table=True): + __tablename__ = "parents" + + id: Optional[int] = Field(primary_key=True, default=None) + children_by_color: Dict[Color, Child] = Relationship( + sa_relationship_kwargs={"collection_class": attribute_keyed_dict("color")} + ) + + engine = create_engine("sqlite://") + SQLModel.metadata.create_all(engine) + with Session(engine) as session: + parent = Parent() + session.add(parent) + session.commit() + session.refresh(parent) + session.add(Child(parent_id=parent.id, color=Color.Orange, value=1)) + session.add(Child(parent_id=parent.id, color=Color.Blue, value=2)) + session.commit() + session.refresh(parent) + assert parent.children_by_color[Color.Orange].parent_id == parent.id + assert parent.children_by_color[Color.Orange].color == Color.Orange + assert parent.children_by_color[Color.Orange].value == 1 + assert parent.children_by_color[Color.Blue].parent_id == parent.id + assert parent.children_by_color[Color.Blue].color == Color.Blue + assert parent.children_by_color[Color.Blue].value == 2