Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 69 additions & 10 deletions fquery/sqlmodel.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import fields, is_dataclass
from dataclasses import _FIELD, dataclass, field, fields, is_dataclass
from datetime import date, datetime, time
from typing import (
ClassVar,
Expand Down Expand Up @@ -42,6 +42,33 @@
SQL_PK = {"metadata": {"SQL": {"primary_key": True}}}


def unique():
pass


def foreignkey(name):
return field(
default=None, metadata={"SQL": {"relationship": True, "back_populates": False}}
)


def one_to_many():
return field(default=None, metadata={"SQL": {"relationship": True}})


def many_to_one(back_populates=None):
ret = field(
default=None, metadata={"SQL": {"relationship": True, "many_to_one": True}}
)
if back_populates is not None:
ret.metadata["SQL"][back_populates] = back_populates
return ret


def sqlmodel(cls):
return model()(dataclass(kw_only=True)(cls))


def model(table: bool = True, table_name: str = None, global_id: bool = False):
"""
A decorator that generates a SQLModel from a dataclass.
Expand All @@ -58,6 +85,8 @@ def sqlmodel(self) -> SQLModel:
return self.__sqlmodel__(**attrs)

def get_field_def(cls, field) -> Union[Field, Relationship]:
if field.default == unique:
return Field(unique=True)
sql_meta = field.metadata.get("SQL", {})
has_foreign_key = bool(sql_meta.get("foreign_key", None))
has_relationship = bool(sql_meta.get("relationship", None))
Expand All @@ -68,7 +97,7 @@ def get_field_def(cls, field) -> Union[Field, Relationship]:
# TODO: revisit the idea of using string for unknown types
sa_column=Column(
SA_TYPEMAP.get(field.type, String),
GLOBAL_ID_SEQ if global_id else None,
GLOBAL_ID_SEQ if global_id else cls.id_seq,
primary_key=(
field.name == "id"
or field.metadata.get("SQL", {}).get("primary_key", False)
Expand Down Expand Up @@ -135,15 +164,44 @@ def patch_back_populates_types(field, back_populates, cls, sqlmodel_cls):
):
sqlmodel_cls.__annotations__[field.name] = Optional[sqlmodel_cls]

def default_table_name(clsname: str) -> str:
return inflection.underscore(inflection.pluralize(clsname))

def decorator(cls):
# Check if the class is a dataclass
if not is_dataclass(cls):
raise ValueError("The class must be a dataclass")

nonlocal table_name
table_name = table_name or inflection.underscore(
inflection.pluralize(cls.__name__)
)
table_name = table_name or default_table_name(cls.__name__)

if not global_id:
cls.id_seq = Sequence(f"{table_name}_seq")

# Insert any foreign keys as necessary
for cfield in fields(cls):
sql_meta = cfield.metadata.get("SQL", {})
has_relationship = bool(sql_meta.get("relationship", None))
if has_relationship:
many_to_one = sql_meta.get("many_to_one", False)
foreign_key_name = cfield.name + "_id"
key_table_name = table_name
if many_to_one:
type_class = cfield.type
other_class = type_class.__args__[0]
other_class = getattr(other_class, "__name__", None)
key_table_name = default_table_name(other_class)
back_populates = sql_meta.get("back_populates", None)
if back_populates is False or many_to_one:
new_field = field(
default=None,
metadata={"SQL": {"foreign_key": f"{key_table_name}.id"}},
)
new_field._field_type = _FIELD
new_field.name = foreign_key_name
new_field.type = Optional[int]
cls.__dataclass_fields__[foreign_key_name] = new_field
setattr(cls, new_field.name, new_field.default)

# Generate the SQLModel class
sqlmodel_cls = type(
Expand All @@ -167,16 +225,17 @@ def decorator(cls):
# For SQLModel's SQLModelMetaClass
table=table,
)

cls.__sqlmodel__ = sqlmodel_cls
# Update type annotations in any class with a relationship with this class to point
# to the SQLModel, not the dataclass
for field in fields(cls):
if not field.name in sqlmodel_cls.__sqlmodel_relationships__:
for cfield in fields(cls):
if not cfield.name in sqlmodel_cls.__sqlmodel_relationships__:
continue
rel = sqlmodel_cls.__sqlmodel_relationships__.get(field.name, None)
rel = sqlmodel_cls.__sqlmodel_relationships__.get(cfield.name, None)
if rel and hasattr(rel, "back_populates"):
patch_back_populates_types(field, rel.back_populates, cls, sqlmodel_cls)
patch_back_populates_types(
cfield, rel.back_populates, cls, sqlmodel_cls
)
cls.sqlmodel = sqlmodel
return cls

Expand Down
53 changes: 26 additions & 27 deletions tests/test_sqlmodel.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,41 @@
from dataclasses import dataclass, field
from dataclasses import field
from datetime import datetime
from typing import List, Optional

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlmodel import SQLModel

from fquery.sqlmodel import SQL_PK, model
from fquery.sqlmodel import (
SQL_PK,
foreignkey,
many_to_one,
one_to_many,
sqlmodel,
unique,
)


@model(global_id=True)
@dataclass(kw_only=True)
@sqlmodel
class User:
id: int | None = None
name: str
email: str
email: str = unique()
created_at: datetime = None
updated_at: datetime = None
friend_id: Optional[int] = field(
default=None, metadata={"SQL": {"foreign_key": "users.id"}}
)
friend: Optional["User"] = field(
default=None, metadata={"SQL": {"relationship": True, "back_populates": False}}
)
reviews: List["Review"] = field(
default=None, metadata={"SQL": {"relationship": True}}
)


@model(global_id=True)
@dataclass(kw_only=True)

friend: Optional["User"] = foreignkey("users.id")
reviews: List["Review"] = one_to_many()


@sqlmodel
class Review:
id: int | None = None
score: int
user_id: Optional[int] = field(
default=None, metadata={"SQL": {"foreign_key": "users.id"}}
)
user: Optional[User] = field(
default=None, metadata={"SQL": {"relationship": True, "many_to_one": True}}
)
user: Optional[User] = many_to_one("users.id")


@model(global_id=True)
@dataclass
@sqlmodel
class Relation:
src: int | None = field(**SQL_PK)
type: int = field(**SQL_PK)
Expand Down Expand Up @@ -93,7 +86,13 @@ def test_sqlmodel():
session.add(user1.sqlmodel())
session.commit()

relation = Relation(user.id, 1, user1.id, datetime.now(), datetime.now())
relation = Relation(
src=user.id,
type=1,
dst=user1.id,
created_at=datetime.now(),
updated_at=datetime.now(),
)
session.add(relation.sqlmodel())
session.commit()
# Read all users from the database
Expand Down
Loading