From 479c411e93d2eb9a8e13d9686fe1334bc262750f Mon Sep 17 00:00:00 2001 From: Arun Sharma Date: Fri, 17 Jan 2025 14:51:54 -0800 Subject: [PATCH] sqlmodel: improve syntax --- fquery/sqlmodel.py | 79 ++++++++++++++++++++++++++++++++++++------ tests/test_sqlmodel.py | 53 ++++++++++++++-------------- 2 files changed, 95 insertions(+), 37 deletions(-) diff --git a/fquery/sqlmodel.py b/fquery/sqlmodel.py index b212725..dacf4e6 100644 --- a/fquery/sqlmodel.py +++ b/fquery/sqlmodel.py @@ -1,4 +1,4 @@ -from dataclasses import fields, is_dataclass +from dataclasses import _FIELD, dataclass, field, fields, is_dataclass from datetime import date, datetime, time from typing import ( ClassVar, @@ -42,6 +42,33 @@ SQL_PK = {"metadata": {"SQL": {"primary_key": True}}} +def unique(): + pass + + +def foreignkey(name): + return field( + default=None, metadata={"SQL": {"relationship": True, "back_populates": False}} + ) + + +def one_to_many(): + return field(default=None, metadata={"SQL": {"relationship": True}}) + + +def many_to_one(back_populates=None): + ret = field( + default=None, metadata={"SQL": {"relationship": True, "many_to_one": True}} + ) + if back_populates is not None: + ret.metadata["SQL"][back_populates] = back_populates + return ret + + +def sqlmodel(cls): + return model()(dataclass(kw_only=True)(cls)) + + def model(table: bool = True, table_name: str = None, global_id: bool = False): """ A decorator that generates a SQLModel from a dataclass. @@ -58,6 +85,8 @@ def sqlmodel(self) -> SQLModel: return self.__sqlmodel__(**attrs) def get_field_def(cls, field) -> Union[Field, Relationship]: + if field.default == unique: + return Field(unique=True) sql_meta = field.metadata.get("SQL", {}) has_foreign_key = bool(sql_meta.get("foreign_key", None)) has_relationship = bool(sql_meta.get("relationship", None)) @@ -68,7 +97,7 @@ def get_field_def(cls, field) -> Union[Field, Relationship]: # TODO: revisit the idea of using string for unknown types sa_column=Column( SA_TYPEMAP.get(field.type, String), - GLOBAL_ID_SEQ if global_id else None, + GLOBAL_ID_SEQ if global_id else cls.id_seq, primary_key=( field.name == "id" or field.metadata.get("SQL", {}).get("primary_key", False) @@ -135,15 +164,44 @@ def patch_back_populates_types(field, back_populates, cls, sqlmodel_cls): ): sqlmodel_cls.__annotations__[field.name] = Optional[sqlmodel_cls] + def default_table_name(clsname: str) -> str: + return inflection.underscore(inflection.pluralize(clsname)) + def decorator(cls): # Check if the class is a dataclass if not is_dataclass(cls): raise ValueError("The class must be a dataclass") nonlocal table_name - table_name = table_name or inflection.underscore( - inflection.pluralize(cls.__name__) - ) + table_name = table_name or default_table_name(cls.__name__) + + if not global_id: + cls.id_seq = Sequence(f"{table_name}_seq") + + # Insert any foreign keys as necessary + for cfield in fields(cls): + sql_meta = cfield.metadata.get("SQL", {}) + has_relationship = bool(sql_meta.get("relationship", None)) + if has_relationship: + many_to_one = sql_meta.get("many_to_one", False) + foreign_key_name = cfield.name + "_id" + key_table_name = table_name + if many_to_one: + type_class = cfield.type + other_class = type_class.__args__[0] + other_class = getattr(other_class, "__name__", None) + key_table_name = default_table_name(other_class) + back_populates = sql_meta.get("back_populates", None) + if back_populates is False or many_to_one: + new_field = field( + default=None, + metadata={"SQL": {"foreign_key": f"{key_table_name}.id"}}, + ) + new_field._field_type = _FIELD + new_field.name = foreign_key_name + new_field.type = Optional[int] + cls.__dataclass_fields__[foreign_key_name] = new_field + setattr(cls, new_field.name, new_field.default) # Generate the SQLModel class sqlmodel_cls = type( @@ -167,16 +225,17 @@ def decorator(cls): # For SQLModel's SQLModelMetaClass table=table, ) - cls.__sqlmodel__ = sqlmodel_cls # Update type annotations in any class with a relationship with this class to point # to the SQLModel, not the dataclass - for field in fields(cls): - if not field.name in sqlmodel_cls.__sqlmodel_relationships__: + for cfield in fields(cls): + if not cfield.name in sqlmodel_cls.__sqlmodel_relationships__: continue - rel = sqlmodel_cls.__sqlmodel_relationships__.get(field.name, None) + rel = sqlmodel_cls.__sqlmodel_relationships__.get(cfield.name, None) if rel and hasattr(rel, "back_populates"): - patch_back_populates_types(field, rel.back_populates, cls, sqlmodel_cls) + patch_back_populates_types( + cfield, rel.back_populates, cls, sqlmodel_cls + ) cls.sqlmodel = sqlmodel return cls diff --git a/tests/test_sqlmodel.py b/tests/test_sqlmodel.py index c4c996b..34039eb 100644 --- a/tests/test_sqlmodel.py +++ b/tests/test_sqlmodel.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass, field +from dataclasses import field from datetime import datetime from typing import List, Optional @@ -6,43 +6,36 @@ from sqlalchemy.orm import sessionmaker from sqlmodel import SQLModel -from fquery.sqlmodel import SQL_PK, model +from fquery.sqlmodel import ( + SQL_PK, + foreignkey, + many_to_one, + one_to_many, + sqlmodel, + unique, +) -@model(global_id=True) -@dataclass(kw_only=True) +@sqlmodel class User: id: int | None = None name: str - email: str + email: str = unique() created_at: datetime = None updated_at: datetime = None - friend_id: Optional[int] = field( - default=None, metadata={"SQL": {"foreign_key": "users.id"}} - ) - friend: Optional["User"] = field( - default=None, metadata={"SQL": {"relationship": True, "back_populates": False}} - ) - reviews: List["Review"] = field( - default=None, metadata={"SQL": {"relationship": True}} - ) - - -@model(global_id=True) -@dataclass(kw_only=True) + + friend: Optional["User"] = foreignkey("users.id") + reviews: List["Review"] = one_to_many() + + +@sqlmodel class Review: id: int | None = None score: int - user_id: Optional[int] = field( - default=None, metadata={"SQL": {"foreign_key": "users.id"}} - ) - user: Optional[User] = field( - default=None, metadata={"SQL": {"relationship": True, "many_to_one": True}} - ) + user: Optional[User] = many_to_one("users.id") -@model(global_id=True) -@dataclass +@sqlmodel class Relation: src: int | None = field(**SQL_PK) type: int = field(**SQL_PK) @@ -93,7 +86,13 @@ def test_sqlmodel(): session.add(user1.sqlmodel()) session.commit() - relation = Relation(user.id, 1, user1.id, datetime.now(), datetime.now()) + relation = Relation( + src=user.id, + type=1, + dst=user1.id, + created_at=datetime.now(), + updated_at=datetime.now(), + ) session.add(relation.sqlmodel()) session.commit() # Read all users from the database