diff --git a/src/ferro/metaclass.py b/src/ferro/metaclass.py index 01aeb75..493ce64 100644 --- a/src/ferro/metaclass.py +++ b/src/ferro/metaclass.py @@ -9,7 +9,8 @@ get_origin, ) -from pydantic import BaseModel, Field as PydanticField +from pydantic import BaseModel +from pydantic import Field as PydanticField from pydantic.fields import FieldInfo from ._core import register_model_schema @@ -188,7 +189,14 @@ def __new__(mcs, name, bases, namespace, **kwargs): if isinstance(metadata.to, ForwardRef): target_name = metadata.to.__forward_arg__ - setattr(cls, field_name, ForwardDescriptor(field_name, target_name)) + setattr( + cls, + field_name, + ForwardDescriptor( + target_model_name=target_name, + field_name=field_name, + ), + ) else: setattr(cls, field_name, None) @@ -205,9 +213,9 @@ def __new__(mcs, name, bases, namespace, **kwargs): if "properties" in schema: for f_name, metadata in ferro_fields.items(): if f_name in schema["properties"]: - schema["properties"][f_name][ - "primary_key" - ] = metadata.primary_key + schema["properties"][f_name]["primary_key"] = ( + metadata.primary_key + ) prop = schema["properties"][f_name] is_int = prop.get("type") == "integer" or any( item.get("type") == "integer" diff --git a/src/ferro/relations/__init__.py b/src/ferro/relations/__init__.py index a0a968a..86baf4a 100644 --- a/src/ferro/relations/__init__.py +++ b/src/ferro/relations/__init__.py @@ -46,7 +46,9 @@ def resolve_relationships(): target_model, rel.related_name, RelationshipDescriptor( - model_name, field_name, is_one_to_one=getattr(rel, "unique", False) + target_model_name=model_name, + field_name=field_name, + is_one_to_one=getattr(rel, "unique", False), ), ) elif isinstance(rel, ManyToManyField): @@ -68,8 +70,8 @@ def resolve_relationships(): _MODEL_REGISTRY_PY[model_name], field_name, RelationshipDescriptor( - target_model.__name__, - field_name, + target_model_name=target_model.__name__, + field_name=field_name, is_m2m=True, join_table=join_table, source_col=source_col, @@ -81,8 +83,8 @@ def resolve_relationships(): target_model, rel.related_name, RelationshipDescriptor( - model_name, - rel.related_name, + target_model_name=model_name, + field_name=rel.related_name, is_m2m=True, join_table=join_table, source_col=target_col, # Reversed for the back side @@ -119,12 +121,12 @@ def resolve_relationships(): if "properties" in schema: for f_name, metadata in model_cls.ferro_fields.items(): if f_name in schema["properties"]: - schema["properties"][f_name][ - "primary_key" - ] = metadata.primary_key - schema["properties"][f_name][ - "autoincrement" - ] = metadata.autoincrement + schema["properties"][f_name]["primary_key"] = ( + metadata.primary_key + ) + schema["properties"][f_name]["autoincrement"] = ( + metadata.autoincrement + ) schema["properties"][f_name]["unique"] = metadata.unique schema["properties"][f_name]["index"] = metadata.index diff --git a/src/ferro/relations/descriptors.py b/src/ferro/relations/descriptors.py index 1e16df8..1708493 100644 --- a/src/ferro/relations/descriptors.py +++ b/src/ferro/relations/descriptors.py @@ -1,27 +1,24 @@ +from typing import TYPE_CHECKING + +from pydantic import BaseModel + +if TYPE_CHECKING: + from ferro.models import Model + from ..state import _MODEL_REGISTRY_PY -class RelationshipDescriptor: +class RelationshipDescriptor(BaseModel): """Descriptor that returns either a Query object or a single object (for 1:1).""" - def __init__( - self, - target_model_name: str, - field_name: str, - is_one_to_one: bool = False, - is_m2m: bool = False, - join_table: str | None = None, - source_col: str | None = None, - target_col: str | None = None, - ): - self.target_model_name = target_model_name - self.field_name = field_name - self.is_one_to_one = is_one_to_one - self.is_m2m = is_m2m - self.join_table = join_table - self.source_col = source_col - self.target_col = target_col - self._target_model = None + target_model_name: str + field_name: str + is_one_to_one: bool = False + is_m2m: bool = False + join_table: str | None = None + source_col: str | None = None + target_col: str | None = None + _target_model: Model | None = None def __get__(self, instance, owner): if instance is None: @@ -70,13 +67,12 @@ def __get__(self, instance, owner): return query -class ForwardDescriptor: +class ForwardDescriptor(BaseModel): """Descriptor that handles lazy loading of a related object.""" - def __init__(self, field_name: str, target_model_name: str): - self.field_name = field_name - self.target_model_name = target_model_name - self._target_model = None + target_model_name: str + field_name: str + _target_model: Model | None = None def __get__(self, instance, owner): if instance is None: diff --git a/tests/test_auto_migrate.py b/tests/test_auto_migrate.py index 8d12b69..6f41676 100644 --- a/tests/test_auto_migrate.py +++ b/tests/test_auto_migrate.py @@ -1,7 +1,12 @@ +from typing import Annotated + import pytest +from pydantic import Field + import ferro from ferro import Model -from pydantic import Field +from ferro.base import FerroField, ManyToManyField +from ferro.query import BackRef class AutoMigratedUser(Model): @@ -35,3 +40,39 @@ async def test_connect_without_auto_migrate(): # Manual call still works await ferro.create_tables() assert True + + +@pytest.mark.asyncio +async def test_m2m_join_table_created_during_auto_migrate(): + """Verify that the many-to-many join table is created when auto_migrate=True. + We clear registries, migrate a fresh in-memory DB, then use the M2M API; if the + join table were not created, .add() would fail. No second connection needed.""" + from ferro import clear_registry, connect, reset_engine + from ferro.state import _JOIN_TABLE_REGISTRY, _MODEL_REGISTRY_PY, _PENDING_RELATIONS + + reset_engine() + clear_registry() + _MODEL_REGISTRY_PY.clear() + _PENDING_RELATIONS.clear() + _JOIN_TABLE_REGISTRY.clear() + + class Actor(Model): + id: Annotated[int | None, FerroField(primary_key=True)] = None + name: str + movies: Annotated[list["Movie"], ManyToManyField(related_name="actors")] = None + + class Movie(Model): + id: Annotated[int | None, FerroField(primary_key=True)] = None + title: str + actors: BackRef[Actor] = None + + await connect("sqlite::memory:", auto_migrate=True) + + actor = await Actor.create(name="Alice") + movie = await Movie.create(title="Matrix") + await actor.movies.add(movie) + + linked = await actor.movies.all() + assert len(linked) == 1 + assert linked[0].id == movie.id + assert linked[0].title == "Matrix" diff --git a/tests/test_relationship_engine.py b/tests/test_relationship_engine.py index 5381f86..a7b04ee 100644 --- a/tests/test_relationship_engine.py +++ b/tests/test_relationship_engine.py @@ -1,13 +1,15 @@ -import pytest from typing import Annotated, ForwardRef + +import pytest + from ferro import ( - Model, + BackRef, FerroField, Field, - reset_engine, - clear_registry, ForeignKey, - BackRef, + Model, + clear_registry, + reset_engine, ) @@ -150,9 +152,7 @@ def test_back_ref_via_annotated_field(): class UserAnnotated(Model): id: Annotated[int | None, FerroField(primary_key=True)] = None username: str - posts: Annotated[ - list["PostAnnotated"] | None, Field(back_ref=True) - ] = None + posts: Annotated[list["PostAnnotated"] | None, Field(back_ref=True)] = None class PostAnnotated(Model): id: Annotated[int | None, FerroField(primary_key=True)] = None @@ -179,9 +179,7 @@ def test_back_ref_and_field_back_ref_raises(): class UserDouble(Model): id: Annotated[int | None, FerroField(primary_key=True)] = None username: str - posts: BackRef[list["PostDouble"]] = Field( - default=None, back_ref=True - ) + posts: BackRef[list["PostDouble"]] = Field(default=None, back_ref=True) class PostDouble(Model): id: Annotated[int | None, FerroField(primary_key=True)] = None diff --git a/tests/test_schema_constraints.py b/tests/test_schema_constraints.py index 7a5182d..01e54f2 100644 --- a/tests/test_schema_constraints.py +++ b/tests/test_schema_constraints.py @@ -1,14 +1,16 @@ -import pytest import sqlite3 from typing import Annotated + +import pytest + from ferro import ( - Model, - connect, + BackRef, FerroField, ForeignKey, - BackRef, - reset_engine, + Model, clear_registry, + connect, + reset_engine, )