Skip to content

Commit 6275ebb

Browse files
committed
refactor: enhance model relationship descriptors and improve field handling
- Updated RelationshipDescriptor and ForwardDescriptor to inherit from BaseModel for better integration with Pydantic. - Refactored the initialization of descriptors to use class attributes instead of constructor parameters for cleaner code. - Improved the handling of primary key and autoincrement properties in the schema generation process. - Adjusted the metaclass to ensure proper assignment of target model names and field names in relationship descriptors.
1 parent b3c2cde commit 6275ebb

6 files changed

Lines changed: 104 additions & 57 deletions

File tree

src/ferro/metaclass.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
get_origin,
1010
)
1111

12-
from pydantic import BaseModel, Field as PydanticField
12+
from pydantic import BaseModel
13+
from pydantic import Field as PydanticField
1314
from pydantic.fields import FieldInfo
1415

1516
from ._core import register_model_schema
@@ -188,7 +189,14 @@ def __new__(mcs, name, bases, namespace, **kwargs):
188189
if isinstance(metadata.to, ForwardRef):
189190
target_name = metadata.to.__forward_arg__
190191

191-
setattr(cls, field_name, ForwardDescriptor(field_name, target_name))
192+
setattr(
193+
cls,
194+
field_name,
195+
ForwardDescriptor(
196+
target_model_name=target_name,
197+
field_name=field_name,
198+
),
199+
)
192200
else:
193201
setattr(cls, field_name, None)
194202

@@ -205,9 +213,9 @@ def __new__(mcs, name, bases, namespace, **kwargs):
205213
if "properties" in schema:
206214
for f_name, metadata in ferro_fields.items():
207215
if f_name in schema["properties"]:
208-
schema["properties"][f_name][
209-
"primary_key"
210-
] = metadata.primary_key
216+
schema["properties"][f_name]["primary_key"] = (
217+
metadata.primary_key
218+
)
211219
prop = schema["properties"][f_name]
212220
is_int = prop.get("type") == "integer" or any(
213221
item.get("type") == "integer"

src/ferro/relations/__init__.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ def resolve_relationships():
4646
target_model,
4747
rel.related_name,
4848
RelationshipDescriptor(
49-
model_name, field_name, is_one_to_one=getattr(rel, "unique", False)
49+
target_model_name=model_name,
50+
field_name=field_name,
51+
is_one_to_one=getattr(rel, "unique", False),
5052
),
5153
)
5254
elif isinstance(rel, ManyToManyField):
@@ -68,8 +70,8 @@ def resolve_relationships():
6870
_MODEL_REGISTRY_PY[model_name],
6971
field_name,
7072
RelationshipDescriptor(
71-
target_model.__name__,
72-
field_name,
73+
target_model_name=target_model.__name__,
74+
field_name=field_name,
7375
is_m2m=True,
7476
join_table=join_table,
7577
source_col=source_col,
@@ -81,8 +83,8 @@ def resolve_relationships():
8183
target_model,
8284
rel.related_name,
8385
RelationshipDescriptor(
84-
model_name,
85-
rel.related_name,
86+
target_model_name=model_name,
87+
field_name=rel.related_name,
8688
is_m2m=True,
8789
join_table=join_table,
8890
source_col=target_col, # Reversed for the back side
@@ -119,12 +121,12 @@ def resolve_relationships():
119121
if "properties" in schema:
120122
for f_name, metadata in model_cls.ferro_fields.items():
121123
if f_name in schema["properties"]:
122-
schema["properties"][f_name][
123-
"primary_key"
124-
] = metadata.primary_key
125-
schema["properties"][f_name][
126-
"autoincrement"
127-
] = metadata.autoincrement
124+
schema["properties"][f_name]["primary_key"] = (
125+
metadata.primary_key
126+
)
127+
schema["properties"][f_name]["autoincrement"] = (
128+
metadata.autoincrement
129+
)
128130
schema["properties"][f_name]["unique"] = metadata.unique
129131
schema["properties"][f_name]["index"] = metadata.index
130132

src/ferro/relations/descriptors.py

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,24 @@
1+
from typing import TYPE_CHECKING
2+
3+
from pydantic import BaseModel
4+
5+
if TYPE_CHECKING:
6+
from ferro.models import Model
7+
18
from ..state import _MODEL_REGISTRY_PY
29

310

4-
class RelationshipDescriptor:
11+
class RelationshipDescriptor(BaseModel):
512
"""Descriptor that returns either a Query object or a single object (for 1:1)."""
613

7-
def __init__(
8-
self,
9-
target_model_name: str,
10-
field_name: str,
11-
is_one_to_one: bool = False,
12-
is_m2m: bool = False,
13-
join_table: str | None = None,
14-
source_col: str | None = None,
15-
target_col: str | None = None,
16-
):
17-
self.target_model_name = target_model_name
18-
self.field_name = field_name
19-
self.is_one_to_one = is_one_to_one
20-
self.is_m2m = is_m2m
21-
self.join_table = join_table
22-
self.source_col = source_col
23-
self.target_col = target_col
24-
self._target_model = None
14+
target_model_name: str
15+
field_name: str
16+
is_one_to_one: bool = False
17+
is_m2m: bool = False
18+
join_table: str | None = None
19+
source_col: str | None = None
20+
target_col: str | None = None
21+
_target_model: Model | None = None
2522

2623
def __get__(self, instance, owner):
2724
if instance is None:
@@ -70,13 +67,12 @@ def __get__(self, instance, owner):
7067
return query
7168

7269

73-
class ForwardDescriptor:
70+
class ForwardDescriptor(BaseModel):
7471
"""Descriptor that handles lazy loading of a related object."""
7572

76-
def __init__(self, field_name: str, target_model_name: str):
77-
self.field_name = field_name
78-
self.target_model_name = target_model_name
79-
self._target_model = None
73+
target_model_name: str
74+
field_name: str
75+
_target_model: Model | None = None
8076

8177
def __get__(self, instance, owner):
8278
if instance is None:

tests/test_auto_migrate.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1+
from typing import Annotated
2+
13
import pytest
4+
from pydantic import Field
5+
26
import ferro
37
from ferro import Model
4-
from pydantic import Field
8+
from ferro.base import FerroField, ManyToManyField
9+
from ferro.query import BackRef
510

611

712
class AutoMigratedUser(Model):
@@ -35,3 +40,39 @@ async def test_connect_without_auto_migrate():
3540
# Manual call still works
3641
await ferro.create_tables()
3742
assert True
43+
44+
45+
@pytest.mark.asyncio
46+
async def test_m2m_join_table_created_during_auto_migrate():
47+
"""Verify that the many-to-many join table is created when auto_migrate=True.
48+
We clear registries, migrate a fresh in-memory DB, then use the M2M API; if the
49+
join table were not created, .add() would fail. No second connection needed."""
50+
from ferro import clear_registry, connect, reset_engine
51+
from ferro.state import _JOIN_TABLE_REGISTRY, _MODEL_REGISTRY_PY, _PENDING_RELATIONS
52+
53+
reset_engine()
54+
clear_registry()
55+
_MODEL_REGISTRY_PY.clear()
56+
_PENDING_RELATIONS.clear()
57+
_JOIN_TABLE_REGISTRY.clear()
58+
59+
class Actor(Model):
60+
id: Annotated[int | None, FerroField(primary_key=True)] = None
61+
name: str
62+
movies: Annotated[list["Movie"], ManyToManyField(related_name="actors")] = None
63+
64+
class Movie(Model):
65+
id: Annotated[int | None, FerroField(primary_key=True)] = None
66+
title: str
67+
actors: BackRef[Actor] = None
68+
69+
await connect("sqlite::memory:", auto_migrate=True)
70+
71+
actor = await Actor.create(name="Alice")
72+
movie = await Movie.create(title="Matrix")
73+
await actor.movies.add(movie)
74+
75+
linked = await actor.movies.all()
76+
assert len(linked) == 1
77+
assert linked[0].id == movie.id
78+
assert linked[0].title == "Matrix"

tests/test_relationship_engine.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
1-
import pytest
21
from typing import Annotated, ForwardRef
2+
3+
import pytest
4+
35
from ferro import (
4-
Model,
6+
BackRef,
57
FerroField,
68
Field,
7-
reset_engine,
8-
clear_registry,
99
ForeignKey,
10-
BackRef,
10+
Model,
11+
clear_registry,
12+
reset_engine,
1113
)
1214

1315

@@ -150,9 +152,7 @@ def test_back_ref_via_annotated_field():
150152
class UserAnnotated(Model):
151153
id: Annotated[int | None, FerroField(primary_key=True)] = None
152154
username: str
153-
posts: Annotated[
154-
list["PostAnnotated"] | None, Field(back_ref=True)
155-
] = None
155+
posts: Annotated[list["PostAnnotated"] | None, Field(back_ref=True)] = None
156156

157157
class PostAnnotated(Model):
158158
id: Annotated[int | None, FerroField(primary_key=True)] = None
@@ -179,9 +179,7 @@ def test_back_ref_and_field_back_ref_raises():
179179
class UserDouble(Model):
180180
id: Annotated[int | None, FerroField(primary_key=True)] = None
181181
username: str
182-
posts: BackRef[list["PostDouble"]] = Field(
183-
default=None, back_ref=True
184-
)
182+
posts: BackRef[list["PostDouble"]] = Field(default=None, back_ref=True)
185183

186184
class PostDouble(Model):
187185
id: Annotated[int | None, FerroField(primary_key=True)] = None

tests/test_schema_constraints.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
1-
import pytest
21
import sqlite3
32
from typing import Annotated
3+
4+
import pytest
5+
46
from ferro import (
5-
Model,
6-
connect,
7+
BackRef,
78
FerroField,
89
ForeignKey,
9-
BackRef,
10-
reset_engine,
10+
Model,
1111
clear_registry,
12+
connect,
13+
reset_engine,
1214
)
1315

1416

0 commit comments

Comments
 (0)