From 7faf2601805dde3493d4c2abd720bc1e7af8af5f Mon Sep 17 00:00:00 2001
From: Martin Ahindura
Date: Sat, 1 Mar 2025 18:47:17 +0100
Subject: [PATCH 01/10] WIP: Fix many-to-many relationship updates

---
 examples/blog/conftest.py  |   2 +-
 examples/blog/main.py      |   4 +-
 examples/blog/models.py    |   1 -
 examples/blog/schemas.py   |  19 ++--
 examples/blog/test_main.py |  44 ++++++----
 examples/blog/utils.py     |  48 ++++++++++
 nqlstore/_compat.py        |   4 +
 nqlstore/_sql.py           | 173 +++++++++++++++++++++++++++++++------
 8 files changed, 234 insertions(+), 61 deletions(-)

diff --git a/examples/blog/conftest.py b/examples/blog/conftest.py
index 9f1774e..8aae68e 100644
--- a/examples/blog/conftest.py
+++ b/examples/blog/conftest.py
@@ -7,7 +7,7 @@
 import pytest_asyncio
 import pytest_mock
 from fastapi.testclient import TestClient
-from models import (  # SqlAuthor,
+from models import (
     MongoAuthor,
     MongoComment,
     MongoInternalAuthor,
diff --git a/examples/blog/main.py b/examples/blog/main.py
index 18e42ee..e72644a 100644
--- a/examples/blog/main.py
+++ b/examples/blog/main.py
@@ -13,7 +13,7 @@
 from fastapi.security import OAuth2PasswordRequestForm
 from models import MongoPost, RedisPost, SqlInternalAuthor, SqlPost
 from pydantic import BaseModel
-from schemas import InternalAuthor, Post, TokenResponse
+from schemas import InternalAuthor, Post, TokenResponse, PartialPost
 from stores import MongoStoreDep, RedisStoreDep, SqlStoreDep, clear_stores

 _ACCESS_TOKEN_EXPIRE_MINUTES = 30
@@ -194,7 +194,7 @@ async def update_one(
     mongo: MongoStoreDep,
     current_user: CurrentUserDep,
     id_: int | str,
-    payload: Post,
+    payload: PartialPost,
 ):
     """Update a post"""
     results = []
diff --git a/examples/blog/models.py b/examples/blog/models.py
index 367af41..dc601b2 100644
--- a/examples/blog/models.py
+++ b/examples/blog/models.py
@@ -47,7 +47,6 @@

 # sqlite models
 SqlInternalAuthor = SQLModel("SqlInternalAuthor", InternalAuthor)
-# SqlAuthor = SQLModel("SqlAuthor", Author, table=False)
 SqlComment = SQLModel(
     "SqlComment", Comment, relationships={"author": SqlInternalAuthor | None}
 )
diff --git a/examples/blog/schemas.py b/examples/blog/schemas.py
index d66f12a..c6cf844 100644
--- a/examples/blog/schemas.py
+++ b/examples/blog/schemas.py
@@ -3,7 +3,7 @@
 from datetime import datetime

 from pydantic import BaseModel
-from utils import current_timestamp
+from utils import current_timestamp, Partial

 from nqlstore import Field, Relationship

@@ -33,15 +33,8 @@ class Post(BaseModel):
         disable_on_redis=True,
     )
     author: Author | None = Relationship(default=None)
-    comments: list["Comment"] = Relationship(
-        default=[],
-        disable_on_redis=True,
-    )
-    tags: list["Tag"] = Relationship(
-        default=[],
-        link_model="TagLink",
-        disable_on_redis=True,
-    )
+    comments: list["Comment"] = Relationship(default=[])
+    tags: list["Tag"] = Relationship(default=[], link_model="TagLink")
     created_at: str = Field(index=True, default_factory=current_timestamp)
     updated_at: str = Field(index=True, default_factory=current_timestamp)

@@ -89,7 +82,7 @@ class TagLink(BaseModel):
 class Tag(BaseModel):
     """The tags to help searching for posts"""

-    title: str = Field(index=True, unique=True, full_text_search=True)
+    title: str = Field(index=True, unique=True)


 class TokenResponse(BaseModel):
@@ -97,3 +90,7 @@ class TokenResponse(BaseModel):

     access_token: str
     token_type: str
+
+
+# Partial models
+PartialPost = Partial("PartialPost", Post)
diff --git a/examples/blog/test_main.py b/examples/blog/test_main.py
index 95e7cbb..87206d6 100644
--- a/examples/blog/test_main.py
+++ b/examples/blog/test_main.py
@@ -11,6 +11,7 @@

 _TITLE_SEARCH_TERMS = ["ho", "oo", "work"]
 _TAG_SEARCH_TERMS = ["art", "om"]
+_HEADERS = {"Authorization": f"Bearer {ACCESS_TOKEN}"}


 @pytest.mark.asyncio
@@ -21,9 +22,7 @@ async def test_create_sql_post(
     """POST to /posts creates a post in sql and returns it"""
     timestamp = datetime.now().isoformat()
     with client_with_sql as client:
-        response = client.post(
-            "/posts", json=post, headers={"Authorization": f"Bearer {ACCESS_TOKEN}"}
-        )
+        response = client.post("/posts", json=post, headers=_HEADERS)

         got = response.json()
         post_id = got["id"]
@@ -61,12 +60,12 @@ async def test_create_redis_post(
     client_with_redis: TestClient,
     redis_store: RedisStore,
     post: dict,
+    freezer,
 ):
     """POST to /posts creates a post in redis and returns it"""
+    timestamp = datetime.now().isoformat()
     with client_with_redis as client:
-        response = client.post(
-            "/posts", json=post, headers={"Authorization": f"Bearer {ACCESS_TOKEN}"}
-        )
+        response = client.post("/posts", json=post, headers=_HEADERS)

         got = response.json()
         post_id = got["id"]
@@ -75,7 +74,8 @@
         expected = {
             "id": post_id,
             "title": post["title"],
-            "content": post.get("content"),
+            "content": post.get("content", ""),
+            "author": {**got["author"], **AUTHOR},
             "pk": post_id,
             "tags": [
                 {
@@ -86,6 +86,8 @@
                 for raw, resp in zip(raw_tags, resp_tags)
             ],
             "comments": [],
+            "created_at": timestamp,
+            "updated_at": timestamp,
         }

         db_query = {"id": {"$eq": post_id}}
@@ -102,10 +104,12 @@ async def test_create_mongo_post(
     client_with_mongo: TestClient,
     mongo_store: MongoStore,
     post: dict,
+    freezer,
 ):
     """POST to /posts creates a post in mongo and returns it"""
+    timestamp = datetime.now().isoformat()
     with client_with_mongo as client:
-        response = client.post("/posts", json=post)
+        response = client.post("/posts", json=post, headers=_HEADERS)

         got = response.json()
         post_id = got["id"]
@@ -113,14 +117,12 @@
         expected = {
             "id": post_id,
             "title": post["title"],
-            "content": post.get("content"),
-            "tags": [
-                {
-                    **raw,
-                }
-                for raw in raw_tags
-            ],
+            "content": post.get("content", ""),
+            "author": {"name": AUTHOR["name"]},
+            "tags": raw_tags,
             "comments": [],
+            "created_at": timestamp,
+            "updated_at": timestamp,
         }

         db_query = {"_id": {"$eq": ObjectId(post_id)}}
@@ -138,22 +140,26 @@ async def test_update_sql_post(
     sql_store: SQLStore,
     sql_posts: list[SqlPost],
     index: int,
+    freezer,
 ):
     """PUT to /posts/{id} updates the sql post of given id and returns updated version"""
+    timestamp = datetime.now().isoformat()
     with client_with_sql as client:
         post = sql_posts[index]
+        post_dict = post.model_dump(mode="json", exclude_none=True, exclude_unset=True)
         id_ = post.id
         update = {
+            **post_dict,
             "name": "some other name",
-            "todos": [
-                *post.tags,
+            "tags": [
+                *post_dict["tags"],
                 {"title": "another one"},
                 {"title": "another one again"},
             ],
-            "comments": [*post.comments, *COMMENT_LIST[index:]],
+            "comments": [*post_dict["comments"], *COMMENT_LIST[index:]],
         }

-        response = client.put(f"/posts/{id_}", json=update)
+        response = client.put(f"/posts/{id_}", json=update, headers=_HEADERS)

         got = response.json()
         expected = {
diff --git a/examples/blog/utils.py b/examples/blog/utils.py
index 3f7d2a1..60f7c46 100644
--- a/examples/blog/utils.py
+++ b/examples/blog/utils.py
@@ -1,6 +1,15 @@
 """Some random utilities for the app"""

+import copy
+import sys
 from datetime import datetime
+from typing import TypeVar, Optional, get_args, Literal, Any
+
+from pydantic import BaseModel, create_model
+from nqlstore._field import FieldInfo
+from pydantic.main import IncEx
+
+_T = TypeVar("_T", bound=BaseModel)


 def current_timestamp() -> str:
@@ -10,3 +19,42 @@ def current_timestamp() -> str:
         string of the current datetime
     """
     return datetime.now().isoformat()
+
+
+def Partial(name: str, model: type[_T]) -> type[_T]:
+    """Creates a partial schema from another schema, with all fields optional
+
+    Args:
+        name: the name of the model
+        model: the original model
+
+    Returns:
+        A new model with all its fields optional
+    """
+    fields = {
+        k: (_make_optional(v.annotation), None)
+        for k, v in model.model_fields.items()  # type: str, FieldInfo
+    }
+
+    return create_model(
+        name,
+        # module of the calling function
+        __module__=sys._getframe(1).f_globals["__name__"],
+        __doc__=model.__doc__,
+        __base__=(model,),
+        **fields,
+    )
+
+
+def _make_optional(type_: type) -> type:
+    """Makes a type optional if it is not already optional
+
+    Args:
+        type_: the type to make optional
+
+    Returns:
+        the optional type
+    """
+    if type(None) in get_args(type_):
+        return type_
+    return type_ | None
diff --git a/nqlstore/_compat.py b/nqlstore/_compat.py
index bfc7d83..7a5abe4 100644
--- a/nqlstore/_compat.py
+++ b/nqlstore/_compat.py
@@ -56,6 +56,7 @@
         _ColumnExpressionArgument,
         _ColumnExpressionOrStrLabelArgument,
     )
+    from sqlalchemy import func
     from sqlmodel import SQLModel as _SQLModel
     from sqlmodel import delete, insert, select, update
     from sqlmodel._compat import post_init_field_info
@@ -65,6 +66,7 @@
     from sqlmodel.main import IncEx, NoArgAnyCallable, OnDeleteType
     from sqlmodel.main import RelationshipInfo as _RelationshipInfo
 except ImportError:
+    import types
     from typing import Mapping, Optional, Sequence
     from typing import Set
     from typing import Set as _ColumnExpressionArgument
@@ -90,6 +92,8 @@
     subqueryload = lambda *a, **kwargs: dict(**kwargs)
     DetachedInstanceError = RuntimeError
     IncEx = Set[Any] | dict
+    func = types.ModuleType("func")
+    func.max = lambda *a, **kwargs: dict(**kwargs)


 class _SqlFieldInfo(_FieldInfo): ...
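A quick sketch of how the `Partial` helper above behaves — illustrative only, not part of the patch; the `Item` model here is hypothetical and the helper is assumed importable from the `examples/blog/utils.py` shown above:

```python
from pydantic import BaseModel


class Item(BaseModel):
    title: str
    content: str = ""


# Derive a model whose fields are all optional: each annotation becomes
# `<type> | None` and each default becomes None, so partial payloads validate.
PartialItem = Partial("PartialItem", Item)

PartialItem()                      # valid: nothing is required any more
PartialItem(title="only a title")  # valid: unset fields stay None
```

This is what lets the PUT endpoint in `main.py` switch its payload type from `Post` to `PartialPost`.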
diff --git a/nqlstore/_sql.py b/nqlstore/_sql.py
index 942bd52..a9e71e1 100644
--- a/nqlstore/_sql.py
+++ b/nqlstore/_sql.py
@@ -1,5 +1,6 @@
 """SQL implementation"""

+import copy
 import sys
 from collections.abc import Mapping, MutableMapping
 from typing import Any, Dict, Iterable, Literal, TypeVar, Union
@@ -28,6 +29,7 @@
     sqlite_insert,
     subqueryload,
     update,
+    func,
 )
 from ._field import Field, get_field_definitions
 from .query.parsers import QueryParser
@@ -173,6 +175,7 @@ async def update(
         updates: dict | None = None,
         **kwargs,
     ) -> list[_T]:
+        updates = copy.deepcopy(updates)
         async with AsyncSession(self._engine) as session:
             if query:
                 filters = (*filters, *self._parser.to_sql(model, query=query))
@@ -198,27 +201,40 @@
             except KeyError:
                 pass

-            stmt = (
-                update(model)
-                .where(*non_rel_filters, *rel_filters)
-                .values(**updates)
-                .returning(model.__table__)
-            )
+            if len(updates) > 0:
+                stmt = (
+                    update(model)
+                    .where(*non_rel_filters, *rel_filters)
+                    .values(**updates)
+                    .returning(model)
+                )
+
+                cursor = await session.stream_scalars(stmt)
+                results = await cursor.fetchall()
+                result_ids = [v.id for v in results]
+            else:
+                results = await self.find(model, *filters)
+                result_ids = [v.id for v in results]

-            cursor = await session.stream(stmt)
-            raw_results = await cursor.fetchall()
-            results = [model.model_validate(row._mapping) for row in raw_results]
-            result_ids = [v.id for v in results]
-            # insert_func = await _get_insert(session)
+            insert_func = await _get_insert(session)

             for k, v in embedded_updates.items():
                 field = relations_mapper[k]
                 field_props = field.property
                 field_model = field_props.mapper.class_
+                link_model = model.__sqlmodel_relationships__[k].link_model
+
                 # fk = foreign key
-                fk_field_name = field_props.primaryjoin.right.name
-                fk_field = getattr(field_model, fk_field_name)
-                parent_id_field = field_props.primaryjoin.left.name
+                if link_model is not None:
+                    child_id_field_name = field_props.secondaryjoin.left.name
+                    parent_id_field_name = field_props.primaryjoin.left.name
+                    child_fk_field_name = field_props.secondaryjoin.right.name
+                    parent_fk_field_name = field_props.primaryjoin.right.name
+
+                else:
+                    parent_id_field_name = field_props.primaryjoin.left.name
+                    child_fk_field_name = field_props.primaryjoin.right.name
+                fk_field = getattr(field_model, child_fk_field_name)

                 # get the foreign keys to use in resetting all affected
                 # relationships;
@@ -226,29 +242,127 @@
                 # the old relations.
                 # Note: this operation is strictly replace, not patch
                 embedded_values = []
+                through_table_values = []
                 fk_values = []
                 for parent in results:
                     parent_partial, embedded_value = _parse_embedded(v, field, parent)
-                    if isinstance(embedded_value, Iterable):
-                        embedded_values += embedded_value
-                        fk_values.append(getattr(parent, parent_id_field))
-                    elif isinstance(embedded_value, _SQLModel):
+                    initial_embedded_values_len = len(embedded_values)
+                    if isinstance(embedded_value, _SQLModel):
                         embedded_values.append(embedded_value)
-                        fk_values.append(getattr(parent, parent_id_field))
+                    elif isinstance(embedded_value, Iterable):
+                        embedded_values += embedded_value
+
+                    if link_model is not None:
+                        index_range = (
+                            initial_embedded_values_len,
+                            len(embedded_values),
+                        )
+                        through_table_values.append(
+                            {
+                                parent_fk_field_name: getattr(
+                                    parent, parent_id_field_name
+                                ),
+                                "index_range": index_range,
+                            }
+                        )
+                    else:
+                        fk_values.append(getattr(parent, parent_id_field_name))

                     # for many-to-one relationships, the parent
                     # also needs to be updated.
                     # add the partial update of the parent
                     for key, val in parent_partial.items():
                         setattr(parent, key, val)

-                # insert the related items
                 if len(embedded_values) > 0:
                     # Reset the relationship; delete all other related items
                     # Currently, this operation replaces all past relations
-                    reset_stmt = delete(field_model).where(fk_field.in_(fk_values))
-                    await session.stream(reset_stmt)
+                    if fk_field is not None and len(fk_values) > 0:
+                        reset_stmt = delete(field_model).where(fk_field.in_(fk_values))
+                        await session.stream(reset_stmt)

-                    # insert the latest changes
-                    embed_stmt = insert(field_model).returning(field_model)
-                    await session.stream_scalars(embed_stmt, embedded_values)
+                    # insert the embedded items
+                    try:
+                        # PostgreSQL and SQLite support on_conflict_do_nothing
+                        embed_stmt = (
+                            insert_func(field_model)
+                            .on_conflict_do_nothing()
+                            .returning(field_model)
+                        )
+                    except AttributeError:
+                        # MySQL supports prefix("IGNORE")
+                        # Other databases might fail at this point
+                        embed_stmt = (
+                            insert_func(field_model)
+                            .prefix_with("IGNORE", dialect="mysql")
+                            .returning(field_model)
+                        )
+
+                    embedded_cursor = await session.stream_scalars(
+                        embed_stmt, embedded_values
+                    )
+                    embedded_results = await embedded_cursor.all()
+
+                if len(through_table_values) > 0:
+                    parent_fk_values = [
+                        v[parent_fk_field_name] for v in through_table_values
+                    ]
+                    if len(parent_fk_values) > 0:
+                        # Reset the relationship; delete all other related items
+                        # Currently, this operation replaces all past relations
+                        parent_fk_field = getattr(link_model, parent_id_field_name)
+                        reset_stmt = delete(link_model).where(
+                            parent_fk_field.in_(parent_fk_values)
+                        )
+                        await session.stream(reset_stmt)
+
+                    # insert the through table records
+                    try:
+                        # PostgreSQL and SQLite support on_conflict_do_nothing
+                        through_table_stmt = (
+                            insert_func(link_model)
+                            .on_conflict_do_nothing()
+                            .returning(link_model)
+                        )
+                    except AttributeError:
+                        # MySQL supports prefix("IGNORE")
+                        # Other databases might fail at this point
+                        through_table_stmt = (
+                            insert_func(link_model)
+                            .prefix_with("IGNORE", dialect="mysql")
+                            .returning(link_model)
+                        )
+
+                    # compute the next id auto-incremented
+                    next_id = await session.scalar(func.max(link_model.id))
+                    next_id = (next_id or 0) + 1
+
+                    through_table_values = [
+                        {
+                            parent_fk_field_name: v[parent_fk_field_name],
+                            child_fk_field_name: getattr(
+                                child, child_id_field_name
+                            ),
+                        }
+                        for v in through_table_values
+                        for child in embedded_results[
+                            v["index_range"][0] : v["index_range"][1]
+                        ]
+                    ]
+                    through_table_values = [
+                        link_model(id=next_id + idx, **v)
+                        for idx, v in enumerate(through_table_values)
+                    ]
+                    await session.stream_scalars(
+                        through_table_stmt, through_table_values
+                    )
+
+                # update the updated parents
+                session.add_all(results)

             await session.commit()

-            return await self.find(model, model.id.in_(result_ids))
+            refreshed_results = await self.find(model, model.id.in_(result_ids))
+            return list(refreshed_results)

     async def delete(
         self,
@@ -586,7 +700,7 @@ def _parse_embedded(
         #
         # add a foreign key value to link back to parent
         return {fk_field: fk_value}, field_model.model_validate(value)

-    if direction in (RelationshipDirection.ONETOMANY, RelationshipDirection.MANYTOMANY):
+    if direction == RelationshipDirection.ONETOMANY:
         # add a foreign key values to link back to parent
         if issubclass(wrapper_type, (list, tuple, set)):
             return {}, wrapper_type(
                 [
                     field_model.model_validate(_with_value(v, fk_field, fk_value))
                     for v in value
                 ]
             )

+    if direction == RelationshipDirection.MANYTOMANY:
+        # add a foreign key values to link back to parent
+        if issubclass(wrapper_type, (list, tuple, set)):
+            return {}, wrapper_type([field_model.model_validate(v) for v in value])
+
     raise NotImplementedError(
         f"relationship {direction} of type annotation {wrapper_type} not supported yet"
     )

From dd1d06cccaa740c00b4edb06e8ba7a7a7ba145c7 Mon Sep 17 00:00:00 2001
From: Martin Ahindura
Date: Mon, 3 Mar 2025 23:21:14 +0100
Subject: [PATCH 02/10] WIP: Clean up the SQL CRUD methods

---
 nqlstore/_sql.py | 461 +++++++++++++++++++++++++++--------------------
 1 file changed, 261 insertions(+), 200 deletions(-)

diff --git a/nqlstore/_sql.py b/nqlstore/_sql.py
index a9e71e1..f31960e 100644
--- a/nqlstore/_sql.py
+++ b/nqlstore/_sql.py
@@ -3,7 +3,7 @@
 import copy
 import sys
 from collections.abc import Mapping, MutableMapping
-from typing import Any, Dict, Iterable, Literal, TypeVar, Union
+from typing import Any, Dict, Iterable, Literal, Union

 from pydantic import create_model
 from pydantic.main import ModelT
@@ -35,9 +35,87 @@
 from .query.parsers import QueryParser
 from .query.selectors import QuerySelector

-_T = TypeVar("_T", bound=_SQLModel)
 _Filter = _ColumnExpressionArgument[bool] | bool

+
+class _SQLModelMeta(_SQLModel):
+    """The base class for all SQL models"""
+
+    id: int | None = Field(default=None, primary_key=True)
+    __rel_field_cache__: dict = {}
+    """dict of (name, Field) that have associated relationships"""
+
+    @classmethod
+    @property
+    def __relational_fields__(cls) -> dict[str, Any]:
+        """dict of (name, Field) that have associated relationships"""
+
+        cls_fullname = f"{cls.__module__}.{cls.__qualname__}"
+        try:
+            return cls.__rel_field_cache__[cls_fullname]
+        except KeyError:
+            value = {
+                k: v
+                for k, v in cls.__mapper__.all_orm_descriptors.items()
+                if isinstance(v.property, RelationshipProperty)
+            }
+            cls.__rel_field_cache__[cls_fullname] = value
+            return value
+
+    def model_dump(
+        self,
+        *,
+        mode: Union[Literal["json", "python"], str] = "python",
+        include: IncEx = None,
+        exclude: IncEx = None,
+        context: Union[Dict[str, Any], None] = None,
+        by_alias: bool = False,
+        exclude_unset: bool = False,
+        exclude_defaults: bool = False,
+        exclude_none: bool = False,
+        round_trip: bool = False,
+        warnings: Union[bool, Literal["none", "warn", "error"]] = True,
+        serialize_as_any: bool = False,
+    ) -> Dict[str, Any]:
+        data = super().model_dump(
+            mode=mode,
+            include=include,
+            exclude=exclude,
+            context=context,
+            by_alias=by_alias,
+            exclude_unset=exclude_unset,
+            exclude_defaults=exclude_defaults,
+            exclude_none=exclude_none,
+            round_trip=round_trip,
+            warnings=warnings,
+            serialize_as_any=serialize_as_any,
+        )
+        relations_mappers = self.__class__.__relational_fields__
+        for k, field in relations_mappers.items():
+            if exclude is None or k not in exclude:
+                try:
+                    value = getattr(self, k, None)
+                except DetachedInstanceError:
+                    # ignore lazy loaded values
+                    continue
+
+                if value is not None or not exclude_none:
+                    data[k] = _serialize_embedded(
+                        value,
+                        field=field,
+                        mode=mode,
+                        context=context,
+                        by_alias=by_alias,
+                        exclude_unset=exclude_unset,
+                        exclude_defaults=exclude_defaults,
+                        exclude_none=exclude_none,
+                        round_trip=round_trip,
+                        warnings=warnings,
+                        serialize_as_any=serialize_as_any,
+                    )
+
+        return data
+

 class SQLStore(BaseStore):
     """The store based on SQL relational database"""

@@ -46,7 +124,9 @@
     def __init__(self, uri: str, parser: QueryParser | None = None, **kwargs):
         super().__init__(uri, parser=parser, **kwargs)
         self._engine = create_async_engine(uri, **kwargs)

-    async def register(self, models: list[type[_T]], checkfirst: bool = True):
+    async def register(
+        self, models: list[type[_SQLModelMeta]], checkfirst: bool = True
+    ):
         tables = [v.__table__ for v in models if hasattr(v, "__table__")]
         async with self._engine.begin() as conn:
             await conn.run_sync(
@@ -54,12 +134,15 @@
             )

     async def insert(
-        self, model: type[_T], items: Iterable[_T | dict], **kwargs
-    ) -> list[_T]:
+        self,
+        model: type[_SQLModelMeta],
+        items: Iterable[_SQLModelMeta | dict],
+        **kwargs,
+    ) -> list[_SQLModelMeta]:
         parsed_items = [
             v if isinstance(v, model) else model.model_validate(v) for v in items
         ]
-        relations_mapper = _get_relations_mapper(model)
+        relations_mapper = model.__relational_fields__

         async with AsyncSession(self._engine) as session:
             insert_func = await _get_insert(session)
@@ -123,101 +206,51 @@

     async def find(
         self,
-        model: type[_T],
+        model: type[_SQLModelMeta],
         *filters: _Filter,
         query: QuerySelector | None = None,
         skip: int = 0,
         limit: int | None = None,
         sort: tuple[_ColumnExpressionOrStrLabelArgument[Any]] = (),
         **kwargs,
-    ) -> list[_T]:
+    ) -> list[_SQLModelMeta]:
         async with AsyncSession(self._engine) as session:
             if query:
                 filters = (*filters, *self._parser.to_sql(model, query=query))
-
-            relations = _get_relations(model)
-
-            # eagerly load all relationships so that no validation errors occur due
-            # to missing session if there is an attempt to load them lazily later
-            eager_load_opts = [subqueryload(v) for v in relations]
-
-            filtered_relations = _get_filtered_relations(
-                filters=filters,
-                relations=relations,
+            return await _find(
+                session, model, *filters, skip=skip, limit=limit, sort=sort
             )
-
-            # Note that we need to treat relations that are referenced in the filters
-            # differently from those that are not. This is because filtering basing on a relationship
-            # requires the use of an inner join. Yet an inner join automatically excludes rows
-            # that are have null for a given relationship.
-            #
-            # An outer join on the other hand would just return all the rows in the left table.
-            # We thus need to do an inner join on tables that are being filtered.
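To make the inner-join concern in the comment block above concrete (patch 02 moves that block into the new `_find` helper), here is an illustrative sketch only — it assumes two hypothetical mapped models, `Post` and `Comment`, with a one-to-many relationship, and is not code from this patch:

```python
from sqlmodel import select

# Filtering on a related table requires an INNER join: only posts that
# actually have a comment titled "hello" should come back.
filtered_stmt = (
    select(Post)
    .join_from(Post, Comment)
    .where(Comment.title == "hello")
)

# A relationship that is NOT filtered on must not be inner-joined,
# otherwise posts with zero comments would silently drop out of the
# result set; such relationships are eager-loaded (subqueryload) instead.
unfiltered_stmt = select(Post)
```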
-            stmt = select(model)
-            for rel in filtered_relations:
-                stmt = stmt.join_from(model, rel)
-
-            cursor = await session.stream_scalars(
-                stmt.options(*eager_load_opts)
-                .where(*filters)
-                .limit(limit)
-                .offset(skip)
-                .order_by(*sort)
-            )
-            results = await cursor.all()
-            return list(results)
-
     async def update(
         self,
-        model: type[_T],
+        model: type[_SQLModelMeta],
         *filters: _Filter,
         query: QuerySelector | None = None,
         updates: dict | None = None,
         **kwargs,
-    ) -> list[_T]:
+    ) -> list[_SQLModelMeta]:
         updates = copy.deepcopy(updates)
         async with AsyncSession(self._engine) as session:
             if query:
                 filters = (*filters, *self._parser.to_sql(model, query=query))

-            # Construct filters that have sub queries
-            relations = _get_relations(model)
-            rel_filters, non_rel_filters = _sieve_rel_from_non_rel_filters(
-                filters=filters,
-                relations=relations,
-            )
-            rel_filters = _to_subquery_based_filters(
-                model=model,
-                rel_filters=rel_filters,
-                relations=relations,
+            relational_filters = _get_relational_filters(model, filters)
+            non_relational_filters = _get_non_relational_filters(model, filters)
+
+            # Let's update the fields that are not embedded model field
+            # and return the affected results
+            results = await _update_non_embedded_fields(
+                session,
+                model,
+                *non_relational_filters,
+                *relational_filters,
+                updates=updates,
             )

-            # dealing with nested models in the update
-            relations_mapper = _get_relations_mapper(model)
-            embedded_updates = {}
-            for k in relations_mapper:
-                try:
-                    embedded_updates[k] = updates.pop(k)
-                except KeyError:
-                    pass
-
-            if len(updates) > 0:
-                stmt = (
-                    update(model)
-                    .where(*non_rel_filters, *rel_filters)
-                    .values(**updates)
-                    .returning(model)
-                )
-
-                cursor = await session.stream_scalars(stmt)
-                results = await cursor.fetchall()
-                result_ids = [v.id for v in results]
-            else:
-                results = await self.find(model, *filters)
-                result_ids = [v.id for v in results]
-
+            embedded_updates = _get_relational_updates(model, updates)
+            result_ids = [v.id for v in results]
             insert_func = await _get_insert(session)
-
+            relations_mapper = model.__relational_fields__
             for k, v in embedded_updates.items():
                 field = relations_mapper[k]
                 field_props = field.property
@@ -366,101 +399,33 @@

     async def delete(
         self,
-        model: type[_T],
+        model: type[_SQLModelMeta],
         *filters: _Filter,
         query: QuerySelector | None = None,
         **kwargs,
-    ) -> list[_T]:
+    ) -> list[_SQLModelMeta]:
         async with AsyncSession(self._engine) as session:
             if query:
                 filters = (*filters, *self._parser.to_sql(model, query=query))

             deleted_items = await self.find(model, *filters)

-            # Construct filters that have sub queries
-            relations = _get_relations(model)
-            rel_filters, non_rel_filters = _sieve_rel_from_non_rel_filters(
-                filters=filters,
-                relations=relations,
-            )
-            rel_filters = _to_subquery_based_filters(
-                model=model,
-                rel_filters=rel_filters,
-                relations=relations,
-            )
+            relational_filters = _get_relational_filters(model, filters)
+            non_relational_filters = _get_non_relational_filters(model, filters)
+
             exec_options = {}
-            if len(rel_filters) > 0:
+            if len(relational_filters) > 0:
                 exec_options = {"is_delete_using": True}

             await session.stream(
                 delete(model)
-                .where(*non_rel_filters, *rel_filters)
+                .where(*non_relational_filters, *relational_filters)
                 .execution_options(**exec_options),
             )
             await session.commit()
             return deleted_items


-class _SQLModelMeta(_SQLModel):
-    """The base class for all SQL models"""
-
-    id: int | None = Field(default=None, primary_key=True)
-
-    def model_dump(
-        self,
-        *,
-        mode: Union[Literal["json", "python"], str] = "python",
-        include: IncEx = None,
-        exclude: IncEx = None,
-        context: Union[Dict[str, Any], None] = None,
-        by_alias: bool = False,
-        exclude_unset: bool = False,
-        exclude_defaults: bool = False,
-        exclude_none: bool = False,
-        round_trip: bool = False,
-        warnings: Union[bool, Literal["none", "warn", "error"]] = True,
-        serialize_as_any: bool = False,
-    ) -> Dict[str, Any]:
-        data = super().model_dump(
-            mode=mode,
-            include=include,
-            exclude=exclude,
-            context=context,
-            by_alias=by_alias,
-            exclude_unset=exclude_unset,
-            exclude_defaults=exclude_defaults,
-            exclude_none=exclude_none,
-            round_trip=round_trip,
-            warnings=warnings,
-            serialize_as_any=serialize_as_any,
-        )
-        relations_mappers = _get_relations_mapper(self.__class__)
-        for k, field in relations_mappers.items():
-            if exclude is None or k not in exclude:
-                try:
-                    value = getattr(self, k, None)
-                except DetachedInstanceError:
-                    # ignore lazy loaded values
-                    continue
-
-                if value is not None or not exclude_none:
-                    data[k] = _serialize_embedded(
-                        value,
-                        field=field,
-                        mode=mode,
-                        context=context,
-                        by_alias=by_alias,
-                        exclude_unset=exclude_unset,
-                        exclude_defaults=exclude_defaults,
-                        exclude_none=exclude_none,
-                        round_trip=round_trip,
-                        warnings=warnings,
-                        serialize_as_any=serialize_as_any,
-                    )
-
-        return data
-
-
 def SQLModel(
     name: str,
     schema: type[ModelT],
@@ -504,38 +469,6 @@
     )


-def _get_relations(model: type[_SQLModel]):
-    """Gets all the relational fields of the given model
-
-    Args:
-        model: the SQL model to inspect
-
-    Returns:
-        list of Fields that have associated relationships
-    """
-    return [
-        v
-        for v in model.__mapper__.all_orm_descriptors.values()
-        if isinstance(v.property, RelationshipProperty)
-    ]
-
-
-def _get_relations_mapper(model: type[_SQLModel]) -> dict[str, Any]:
-    """Gets all the relational fields with their names of the given model
-
-    Args:
-        model: the SQL model to inspect
-
-    Returns:
-        dict of (name, Field) that have associated relationships
-    """
-    return {
-        k: v
-        for k, v in model.__mapper__.all_orm_descriptors.items()
-        if isinstance(v.property, RelationshipProperty)
-    }
-
-
 def _get_filtered_tables(filters: Iterable[_Filter]) -> list[Table]:
     """Retrieves the tables that have been referenced in the filters
@@ -569,31 +502,51 @@
     return [rel for rel in relations if rel.property.target in filtered_tables]


-def _sieve_rel_from_non_rel_filters(
-    filters: Iterable[_Filter], relations: Iterable[InstrumentedAttribute[Any]]
-) -> tuple[list[_Filter], list[_Filter]]:
-    """Separates relational filters from non-relational ones
+def _get_relational_filters(
+    model: type[_SQLModelMeta],
+    filters: Iterable[_Filter],
+) -> list[_Filter]:
+    """Gets the filters that are concerned with relationships on this model
+
+    The filters returned are in subquery form since 'update' and 'delete'
+    in sqlalchemy do not support joins, so the only way to attach these filters
+    to the model is through subqueries

     Args:
+        model: the model under consideration
         filters: the tuple of filters to inspect
-        relations: all relations present on the model

     Returns:
-        tuple(rel, non_rel) where rel = list of relational filters,
-        and non_rel = non-relational filters
+        list of filters that are concerned with relationships on this model
     """
-    rel_targets = [v.property.target for v in relations]
-    rel = []
-    non_rel = []
+    relationships = list(model.__relational_fields__.values())
+    targets = [v.property.target for v in relationships]
+    plain_filters = [
+        item
+        for item in filters
+        if any([getattr(v, "table", None) in targets for v in item.get_children()])
+    ]
+    return _to_subquery_based_filters(model, plain_filters, relationships)

-    for filter_ in filters:
-        operands = filter_.get_children()
-        if any([getattr(v, "table", None) in rel_targets for v in operands]):
-            rel.append(filter_)
-        else:
-            non_rel.append(filter_)

-    return rel, non_rel
+def _get_non_relational_filters(
+    model: type[_SQLModelMeta], filters: Iterable[_Filter]
+) -> list[_Filter]:
+    """Gets the filters that are NOT concerned with relationships on this model
+
+    Args:
+        model: the model under consideration
+        filters: the tuple of filters to inspect
+
+    Returns:
+        list of filters that are NOT concerned with relationships on this model
+    """
+    targets = [v.property.target for v in model.__relational_fields__.values()]
+    return [
+        item
+        for item in filters
+        if not any([getattr(v, "table", None) in targets for v in item.get_children()])
+    ]


 def _to_subquery_based_filters(
@@ -770,3 +723,111 @@
         return pg_insert

     return insert
+
+
+async def _update_non_embedded_fields(
+    session: AsyncSession, model: type[_SQLModelMeta], *filters: _Filter, updates: dict
+):
+    """Updates only the non-embedded fields of the model
+
+    It ignores any relationships and returns the updated results
+
+    Args:
+        session: the sqlalchemy session
+        model: the model to be updated
+        filters: the filters against which to match the records that are to be updated
+        updates: the updates to add to each matched record
+
+    Returns:
+        the updated records
+    """
+    non_embedded_updates = _get_non_relational_updates(model, updates)
+    if len(non_embedded_updates) == 0:
+        # if we supplied an empty update dict to update(),
+        # there would be an error
+        return await _find(session, model, *filters)
+
+    stmt = update(model).where(*filters).values(**non_embedded_updates).returning(model)
+    cursor = await session.stream_scalars(stmt)
+    return await cursor.fetchall()
+
+
+def _get_relational_updates(model: type[_SQLModelMeta], updates: dict) -> dict:
+    """Gets the updates that affect only the relationships on this model
+
+    Args:
+        model: the model to be updated
+        updates: the dict of new values to be updated on the matched records
+
+    Returns:
+        a dict with only updates concerning the relationships of the given model
+    """
+    return {k: v for k, v in updates.items() if k in model.__relational_fields__}
+
+
+def _get_non_relational_updates(model: type[_SQLModelMeta], updates: dict) -> dict:
+    """Gets the updates that do not affect relationships on this model
+
+    Args:
+        model: the model to be updated
+        updates: the dict of new values to be updated on the matched records
+
+    Returns:
+        a dict with only updates that do not affect relationships on this model
+    """
+    return {k: v for k, v in updates.items() if k not in model.__relational_fields__}
+
+
+async def _find(
+    session: AsyncSession,
+    model: type[_SQLModelMeta],
+    /,
+    *filters: _Filter,
+    skip: int = 0,
+    limit: int | None = None,
+    sort: tuple[_ColumnExpressionOrStrLabelArgument[Any]] = (),
+) -> list[_SQLModelMeta]:
+    """Finds the records that match the given filters
+
+    Args:
+        session: the sqlalchemy session
+        model: the model that is to be searched
+        filters: the filters to match
+        skip: number of records to ignore at the top of the returned results; default is 0
+        limit: maximum number of records to return; default is None.
+        sort: fields to sort by; default = ()
+
+    Returns:
+        the records that match the given filters
+    """
+    relations = list(model.__relational_fields__.values())
+
+    # eagerly load all relationships so that no validation errors occur due
+    # to missing session if there is an attempt to load them lazily later
+    eager_load_opts = [subqueryload(v) for v in relations]
+
+    filtered_relations = _get_filtered_relations(
+        filters=filters,
+        relations=relations,
+    )
+
+    # Note that we need to treat relations that are referenced in the filters
+    # differently from those that are not. This is because filtering based on a relationship
+    # requires the use of an inner join. Yet an inner join automatically excludes rows
+    # that have null for a given relationship.
+    #
+    # An outer join on the other hand would just return all the rows in the left table.
+    # We thus need to do an inner join on tables that are being filtered.
+    stmt = select(model)
+    for rel in filtered_relations:
+        stmt = stmt.join_from(model, rel)
+
+    cursor = await session.stream_scalars(
+        stmt.options(*eager_load_opts)
+        .where(*filters)
+        .limit(limit)
+        .offset(skip)
+        .order_by(*sort)
+    )
+    results = await cursor.all()
+    return list(results)

From 5f360414cd09db5d5a60b37b78472f69cbc14f03 Mon Sep 17 00:00:00 2001
From: Martin Ahindura
Date: Mon, 21 Apr 2025 14:11:53 +0200
Subject: [PATCH 03/10] WIP: Enable nested updates in SQL

---
 LIMITATIONS.md             |  15 ++++
 README.md                  |   5 ++
 examples/blog/main.py      |   2 +-
 examples/blog/schemas.py   |   2 +-
 examples/blog/test_main.py |   7 +-
 examples/blog/utils.py     |   5 +-
 nqlstore/_compat.py        |   3 +-
 nqlstore/_sql.py           | 162 +++++++++++++++++++++++++++++++------
 8 files changed, 169 insertions(+), 32 deletions(-)
 create mode 100644 LIMITATIONS.md

diff --git a/LIMITATIONS.md b/LIMITATIONS.md
new file mode 100644
index 0000000..f522e1b
--- /dev/null
+++ b/LIMITATIONS.md
@@ -0,0 +1,15 @@
+# Limitations
+
+## Filtering
+
+### Redis
+
+- Mongo-style regular expression filtering is not supported.
+  This is because native redis regular expression filtering is limited to the most basic text-based search.
+
+## Update Operation
+
+### SQL
+
+- Even though one can update a model to a theoretically infinite number of levels deep,
+  the returned results can only contain 1-level-deep nested models.
diff --git a/README.md b/README.md
index 7fc459b..617fd6c 100644
--- a/README.md
+++ b/README.md
@@ -336,6 +336,11 @@ libraries = await redis_store.delete(

 - [ ] Add documentation site

+## Limitations
+
+This library is limited in some specific cases.
+Read through the [`LIMITATIONS.md`](./LIMITATIONS.md) file for more.
+
 ## Contributions

 Contributions are welcome.
 The docs have to be maintained, the code has to be made cleaner, more idiomatic and faster,
diff --git a/examples/blog/main.py b/examples/blog/main.py
index e72644a..7c6bb7b 100644
--- a/examples/blog/main.py
+++ b/examples/blog/main.py
@@ -13,7 +13,7 @@
 from fastapi.security import OAuth2PasswordRequestForm
 from models import MongoPost, RedisPost, SqlInternalAuthor, SqlPost
 from pydantic import BaseModel
-from schemas import InternalAuthor, Post, TokenResponse, PartialPost
+from schemas import InternalAuthor, PartialPost, Post, TokenResponse
 from stores import MongoStoreDep, RedisStoreDep, SqlStoreDep, clear_stores

 _ACCESS_TOKEN_EXPIRE_MINUTES = 30
diff --git a/examples/blog/schemas.py b/examples/blog/schemas.py
index c6cf844..8fac64f 100644
--- a/examples/blog/schemas.py
+++ b/examples/blog/schemas.py
@@ -3,7 +3,7 @@
 from datetime import datetime

 from pydantic import BaseModel
-from utils import current_timestamp, Partial
+from utils import Partial, current_timestamp

 from nqlstore import Field, Relationship

diff --git a/examples/blog/test_main.py b/examples/blog/test_main.py
index 87206d6..5286396 100644
--- a/examples/blog/test_main.py
+++ b/examples/blog/test_main.py
@@ -150,7 +150,7 @@
         id_ = post.id
         update = {
             **post_dict,
-            "name": "some other name",
+            "title": "some other title",
             "tags": [
                 *post_dict["tags"],
                 {"title": "another one"},
@@ -170,8 +170,9 @@
                     **raw,
                     "id": final["id"],
                     "post_id": final["post_id"],
-                    "author": final["author"],
-                    "author_id": final["author_id"],
+                    "author_id": 1,
+                    "created_at": timestamp,
+                    "updated_at": timestamp,
                 }
                 for raw, final in zip(update["comments"], got["comments"])
             ],
diff --git a/examples/blog/utils.py b/examples/blog/utils.py
index 60f7c46..4a917c0 100644
--- a/examples/blog/utils.py
+++ b/examples/blog/utils.py
@@ -3,12 +3,13 @@
 import copy
 import sys
 from datetime import datetime
-from typing import TypeVar, Optional, get_args, Literal, Any
+from typing import Any, Literal, Optional, TypeVar, get_args

 from pydantic import BaseModel, create_model
-from nqlstore._field import FieldInfo
 from pydantic.main import IncEx

+from nqlstore._field import FieldInfo
+
 _T = TypeVar("_T", bound=BaseModel)

diff --git a/nqlstore/_compat.py b/nqlstore/_compat.py
index 7a5abe4..f22a9ba 100644
--- a/nqlstore/_compat.py
+++ b/nqlstore/_compat.py
@@ -41,7 +41,7 @@
 sql imports; and their default if sqlmodel is missing
 """
 try:
-    from sqlalchemy import Column, Table
+    from sqlalchemy import Column, Table, func
     from sqlalchemy.dialects.postgresql import insert as pg_insert
     from sqlalchemy.dialects.sqlite import insert as sqlite_insert
     from sqlalchemy.ext.asyncio import create_async_engine
@@ -56,7 +56,6 @@
         _ColumnExpressionArgument,
         _ColumnExpressionOrStrLabelArgument,
     )
-    from sqlalchemy import func
     from sqlmodel import SQLModel as _SQLModel
     from sqlmodel import delete, insert, select, update
     from sqlmodel._compat import post_init_field_info
diff --git a/nqlstore/_sql.py b/nqlstore/_sql.py
index f31960e..73841ea 100644
--- a/nqlstore/_sql.py
+++ b/nqlstore/_sql.py
@@ -23,13 +23,13 @@
     _SQLModel,
     create_async_engine,
     delete,
+    func,
     insert,
     pg_insert,
     select,
     sqlite_insert,
     subqueryload,
     update,
-    func,
 )
 from ._field import Field, get_field_definitions
 from .query.parsers import QueryParser
@@ -37,6 +37,7 @@

 _Filter = _ColumnExpressionArgument[bool] | bool

+
 class _SQLModelMeta(_SQLModel):
     """The base class for all SQL models"""

@@ -61,7 +62,6 @@ def __relational_fields__(cls) -> dict[str, Any]:
             cls.__rel_field_cache__[cls_fullname] = value
             return value

-
     def model_dump(
         self,
         *,
@@ -161,20 +161,13 @@
                     for idx, record in enumerate(items):
                         parent = results[idx]
                         raw_value = _get_key_or_prop(record, k)
-                        parent_partial, embedded_value = _parse_embedded(
-                            raw_value, field, parent
-                        )
+                        embedded_value = _embed_related_value(parent, field, raw_value)
+
                         if isinstance(embedded_value, _SQLModel):
                             embedded_values.append(embedded_value)
                         elif isinstance(embedded_value, Iterable):
                             embedded_values += embedded_value

-                        # for many-to-one relationships, the parent
-                        # also needs to be updated.
-                        # add the partial update of the parent
-                        for key, val in parent_partial.items():
-                            setattr(parent, key, val)
-
                 # insert the related items
                 if len(embedded_values) > 0:
                     field_model = field.property.mapper.class_
@@ -237,7 +230,7 @@
             relational_filters = _get_relational_filters(model, filters)
             non_relational_filters = _get_non_relational_filters(model, filters)

-            # Let's update the fields that are not embedded model field
+            # Let's update the fields that are not embedded model fields
             # and return the affected results
             results = await _update_non_embedded_fields(
                 session,
@@ -257,7 +250,7 @@
                 field_model = field_props.mapper.class_
                 link_model = model.__sqlmodel_relationships__[k].link_model

-                # fk = foreign key
+                # fk means foreign key
                 if link_model is not None:
                     child_id_field_name = field_props.secondaryjoin.left.name
                     parent_id_field_name = field_props.primaryjoin.left.name
                     child_fk_field_name = field_props.secondaryjoin.right.name
                     parent_fk_field_name = field_props.primaryjoin.right.name
@@ -271,6 +264,7 @@
                 # get the foreign keys to use in resetting all affected
                 # relationships;
+                # FIXME: comment above is unclear
                 # get parsed embedded values so that they can replace
                 # the old relations.
                 # Note: this operation is strictly replace, not patch
@@ -278,7 +272,7 @@
                 through_table_values = []
                 fk_values = []
                 for parent in results:
-                    parent_partial, embedded_value = _parse_embedded(v, field, parent)
+                    embedded_value = _embed_related_value(parent, field, v)
                     initial_embedded_values_len = len(embedded_values)
                     if isinstance(embedded_value, _SQLModel):
                         embedded_values.append(embedded_value)
@@ -286,6 +280,7 @@
                         embedded_values += embedded_value

                     if link_model is not None:
+                        # FIXME: unclear name 'index_range'
                         index_range = (
                             initial_embedded_values_len,
                             len(embedded_values),
                         )
@@ -301,12 +296,6 @@
                         )
                     else:
                         fk_values.append(getattr(parent, parent_id_field_name))

-                    # for many-to-one relationships, the parent
-                    # also needs to be updated.
-                    # add the partial update of the parent
-                    for key, val in parent_partial.items():
-                        setattr(parent, key, val)
-
                 if len(embedded_values) > 0:
                     # Reset the relationship; delete all other related items
                     # Currently, this operation replaces all past relations
@@ -623,6 +612,111 @@
     return obj


+def _embed_related_value(
+    parent: _SQLModel,
+    related_field: Any,
+    related_value: Iterable[dict | Any] | dict | Any,
+) -> Iterable[_SQLModel] | _SQLModel | None:
+    """Embeds a given relationship into the parent in place and returns the related records
+
+    Args:
+        parent: the model that contains the given relationships
+        related_field: the field that contains the given relationship
+        related_value: the value(s) corresponding to the related field
+
+    Returns:
+        the related record(s)
+    """
+    if related_value is None:
+        return None
+
+    props = related_field.property  # type: RelationshipProperty[Any]
+    wrapper_type = props.collection_class
+    field_model = props.mapper.class_
+    parent_foreign_key_field = props.primaryjoin.right.name
+    direction = props.direction
+
+    if direction == RelationshipDirection.MANYTOONE:
+        related_value_id_key = props.primaryjoin.left.name
+        parent_foreign_key_value = related_value.get(related_value_id_key)
+        # update the foreign key value in the parent
+        setattr(parent, parent_foreign_key_field, parent_foreign_key_value)
+        # create child
+        child = field_model.model_validate(related_value)
+        # update nested relationships
+        for field_name, field_type in field_model.__relational_fields__.items():
+            if isinstance(related_value, dict):
+                nested_related_value = related_value.get(field_name)
+            else:
+                nested_related_value = getattr(related_value, field_name)
+
+            nested_related_records = _embed_related_value(
+                parent=child,
+                related_field=field_type,
+                related_value=nested_related_value,
+            )
+            setattr(child, field_name, nested_related_records)
+
+        return child
+
+    elif direction == RelationshipDirection.ONETOMANY:
+        related_value_id_key = props.primaryjoin.left.name
+        parent_foreign_key_value = getattr(parent, related_value_id_key)
+        # add a foreign key values to link back to parent
+        if issubclass(wrapper_type, (list, tuple, set)):
+            embedded_records = []
+            for v in related_value:
+                child = field_model.model_validate(
+                    _with_value(v, parent_foreign_key_field, parent_foreign_key_value)
+                )
+
+                # update nested relationships
+                for field_name, field_type in field_model.__relational_fields__.items():
+                    if isinstance(v, dict):
+                        nested_related_value = v.get(field_name)
+                    else:
+                        nested_related_value = getattr(v, field_name)
+
+                    nested_related_records = _embed_related_value(
+                        parent=child,
+                        related_field=field_type,
+                        related_value=nested_related_value,
+                    )
+                    setattr(child, field_name, nested_related_records)
+
+                embedded_records.append(child)
+
+            return wrapper_type(embedded_records)
+
+    elif direction == RelationshipDirection.MANYTOMANY:
+        if issubclass(wrapper_type, (list, tuple, set)):
+            embedded_records = []
+            for v in related_value:
+                child = field_model.model_validate(v)
+
+                # update nested relationships
+                for field_name, field_type in field_model.__relational_fields__.items():
+                    if isinstance(v, dict):
+                        nested_related_value = v.get(field_name)
+                    else:
+                        nested_related_value = getattr(v, field_name)
+                    nested_related_records = _embed_related_value(
+                        parent=child,
+                        related_field=field_type,
+                        related_value=nested_related_value,
+                    )
+                    setattr(child, field_name, nested_related_records)
+
+                embedded_records.append(child)
+
+            return wrapper_type(embedded_records)
+
+    raise NotImplementedError(
+        f"relationship {direction} of type annotation {wrapper_type} not supported yet"
+    )
+
+
+# FIXME: Allow multiple levels of nesting
 def _parse_embedded(
     value: Iterable[dict | Any] | dict | Any, field: Any, parent: _SQLModel
 ) -> tuple[dict, Iterable[_SQLModel] | _SQLModel | None]:
@@ -649,7 +743,13 @@
     fk_value = getattr(parent, parent_id_field)
     direction = props.direction

+    # FIXME: Maybe check if any relationship value is passed by checking the keys of value
+    # And then do a recursive embedded parse and return the field_model
+
     if direction == RelationshipDirection.MANYTOONE:
+        if any([k in field_model.__relational_fields__ for k in value]):
+            # FIXME: nested relationships exist
+            pass
         #
         # add a foreign key value to link back to parent
         return {fk_field: fk_value}, field_model.model_validate(value)

@@ -658,15 +758,31 @@
         if issubclass(wrapper_type, (list, tuple, set)):
             return {}, wrapper_type(
                 [
-                    field_model.model_validate(_with_value(v, fk_field, fk_value))
+                    (
+                        field_model.model_validate(_with_value(v, fk_field, fk_value))
+                        # FIXME: Add a proper call to nested recursion for all relational_fields
+                        if any([k in field_model.__relational_fields__ for k in v])
+                        else field_model.model_validate(
+                            _with_value(v, fk_field, fk_value)
+                        )
+                    )
                     for v in value
                 ]
             )

     if direction == RelationshipDirection.MANYTOMANY:
         if issubclass(wrapper_type, (list, tuple, set)):
-            return {}, wrapper_type([field_model.model_validate(v) for v in value])
+            return {}, wrapper_type(
+                [
+                    (
+                        field_model.model_validate(v)
+                        # FIXME: Add a proper call to nested recursion for all relational_fields
+                        if any([k in field_model.__relational_fields__ for k in v])
+                        else field_model.model_validate(v)
+                    )
+                    for v in value
+                ]
+            )

     raise NotImplementedError(
         f"relationship {direction} of type annotation {wrapper_type} not supported yet"

From bdb6f1d2c9a3ca71ffd8a704e6535c966bf3bb36 Mon Sep 17 00:00:00 2001
From: Martin Ahindura
Date: Thu, 1 May 2025 16:17:04 +0200
Subject: [PATCH 04/10] WIP: Clean up the SQL update method

---
 nqlstore/_sql.py | 558 ++++++++++++++++++++++++-----------------------
 1 file changed, 284 insertions(+), 274 deletions(-)

diff --git a/nqlstore/_sql.py b/nqlstore/_sql.py
index 73841ea..44793d3 100644
--- a/nqlstore/_sql.py
+++ b/nqlstore/_sql.py
@@ -3,7 +3,7 @@
 import copy
 import sys
 from collections.abc import Mapping, MutableMapping
-from typing import Any, Dict, Iterable, Literal, Union
+from typing import Any, Dict, Iterable, Literal, Sequence, TypeVar, Union

 from pydantic import create_model
 from pydantic.main import ModelT
@@ -36,6 +36,7 @@
 from .query.selectors import QuerySelector

 _Filter = _ColumnExpressionArgument[bool] | bool
+_T = TypeVar("_T")


 class _SQLModelMeta(_SQLModel):
@@ -145,9 +146,8 @@ async def insert(
         relations_mapper = model.__relational_fields__

         async with AsyncSession(self._engine) as session:
-            insert_func = await _get_insert(session)
-            stmt = insert_func(model).returning(model)
-            cursor = await session.stream_scalars(stmt, parsed_items)
+            insert_stmt = await _get_insert_func(session, model=model)
+            cursor = await session.stream_scalars(insert_stmt, parsed_items)
             results = await cursor.all()
             result_ids = [v.id for v in results]

@@ -161,7 +161,7 @@
                     for idx, record in enumerate(items):
                         parent = results[idx]
                         raw_value = _get_key_or_prop(record, k)
-                        embedded_value = _embed_related_value(parent, field, raw_value)
+                        embedded_value = _embed_value(parent, field, raw_value)

                         if isinstance(embedded_value, _SQLModel):
                             embedded_values.append(embedded_value)
@@ -171,23 +171,7 @@
                 # insert the related items
                 if len(embedded_values) > 0:
                     field_model = field.property.mapper.class_
-
-                    try:
-                        # PostgreSQL and SQLite support on_conflict_do_nothing
-                        embed_stmt = (
-                            insert_func(field_model)
-                            .on_conflict_do_nothing()
-                            .returning(field_model)
-                        )
-                    except AttributeError:
-                        # MySQL supports prefix("IGNORE")
-                        # Other databases might fail at this point
-                        embed_stmt = (
-                            insert_func(field_model)
-                            .prefix_with("IGNORE", dialect="mysql")
-                            .returning(field_model)
-                        )
-
+                    embed_stmt = await _get_insert_func(session, model=field_model)
                     await session.stream_scalars(embed_stmt, embedded_values)

                 # update the updated parents
@@ -239,152 +223,16 @@
                 *relational_filters,
                 updates=updates,
             )
-
-            embedded_updates = _get_relational_updates(model, updates)
             result_ids = [v.id for v in results]
-            insert_func = await _get_insert(session)
-            relations_mapper = model.__relational_fields__
-            for k, v in embedded_updates.items():
-                field = relations_mapper[k]
-                field_props = field.property
-                field_model = field_props.mapper.class_
-                link_model = model.__sqlmodel_relationships__[k].link_model
-
-                # fk means foreign key
-                if link_model is not None:
-                    child_id_field_name = field_props.secondaryjoin.left.name
-                    parent_id_field_name = field_props.primaryjoin.left.name
-                    child_fk_field_name = field_props.secondaryjoin.right.name
-                    parent_fk_field_name = field_props.primaryjoin.right.name
-
-                else:
-                    parent_id_field_name = field_props.primaryjoin.left.name
-                    child_fk_field_name = field_props.primaryjoin.right.name
-                fk_field = getattr(field_model, child_fk_field_name)
-
-                # get the foreign keys to use in resetting all affected
-                # relationships;
-                # FIXME: comment above is unclear
-                # get parsed embedded values so that they can replace
-                # the old relations.
- # Note: this operation is strictly replace, not patch - embedded_values = [] - through_table_values = [] - fk_values = [] - for parent in results: - embedded_value = _embed_related_value(parent, field, v) - initial_embedded_values_len = len(embedded_values) - if isinstance(embedded_value, _SQLModel): - embedded_values.append(embedded_value) - elif isinstance(embedded_value, Iterable): - embedded_values += embedded_value - - if link_model is not None: - # FIXME: unclear name 'index_range' - index_range = ( - initial_embedded_values_len, - len(embedded_values), - ) - through_table_values.append( - { - parent_fk_field_name: getattr( - parent, parent_id_field_name - ), - "index_range": index_range, - } - ) - else: - fk_values.append(getattr(parent, parent_id_field_name)) - - if len(embedded_values) > 0: - # Reset the relationship; delete all other related items - # Currently, this operation replaces all past relations - if fk_field is not None and len(fk_values) > 0: - reset_stmt = delete(field_model).where(fk_field.in_(fk_values)) - await session.stream(reset_stmt) - - # insert the embedded items - try: - # PostgreSQL and SQLite support on_conflict_do_nothing - embed_stmt = ( - insert_func(field_model) - .on_conflict_do_nothing() - .returning(field_model) - ) - except AttributeError: - # MySQL supports prefix("IGNORE") - # Other databases might fail at this point - embed_stmt = ( - insert_func(field_model) - .prefix_with("IGNORE", dialect="mysql") - .returning(field_model) - ) - - embedded_cursor = await session.stream_scalars( - embed_stmt, embedded_values - ) - embedded_results = await embedded_cursor.all() - - if len(through_table_values) > 0: - parent_fk_values = [ - v[parent_fk_field_name] for v in through_table_values - ] - if len(parent_fk_values) > 0: - # Reset the relationship; delete all other related items - # Currently, this operation replaces all past relations - parent_fk_field = getattr(link_model, parent_id_field_name) - reset_stmt = delete(link_model).where( - parent_fk_field.in_(parent_fk_values) - ) - await session.stream(reset_stmt) - - # insert the through table records - try: - # PostgreSQL and SQLite support on_conflict_do_nothing - through_table_stmt = ( - insert_func(link_model) - .on_conflict_do_nothing() - .returning(link_model) - ) - except AttributeError: - # MySQL supports prefix("IGNORE") - # Other databases might fail at this point - through_table_stmt = ( - insert_func(link_model) - .prefix_with("IGNORE", dialect="mysql") - .returning(link_model) - ) - - # compute the next id auto-incremented - next_id = await session.scalar(func.max(link_model.id)) - next_id = (next_id or 0) + 1 - - through_table_values = [ - { - parent_fk_field_name: v[parent_fk_field_name], - child_fk_field_name: getattr( - child, child_id_field_name - ), - } - for v in through_table_values - for child in embedded_results[ - v["index_range"][0] : v["index_range"][1] - ] - ] - through_table_values = [ - link_model(id=next_id + idx, **v) - for idx, v in enumerate(through_table_values) - ] - await session.stream_scalars( - through_table_stmt, through_table_values - ) - - # update the updated parents - session.add_all(results) + # Let's update the embedded fields also + await _update_embedded_fields( + session, model=model, records=results, updates=updates + ) await session.commit() + refreshed_results = await self.find(model, model.id.in_(result_ids)) - return list(refreshed_results) + return refreshed_results async def delete( self, @@ -612,48 +460,48 @@ def _with_value(obj: dict | Any, 
field: str, value: Any) -> Any: return obj -def _embed_related_value( +def _embed_value( parent: _SQLModel, - related_field: Any, - related_value: Iterable[dict | Any] | dict | Any, + relationship: Any, + value: Iterable[dict | Any] | dict | Any, ) -> Iterable[_SQLModel] | _SQLModel | None: - """Embeds a given relationship into the parent in place and returns the related records + """Embeds in place a given value into the parent basing on the given relationship + + Note that the parent itself is changed to include the value Args: parent: the model that contains the given relationships - related_field: the field that contains the given relationship - related_value: the values correspond to the related field + relationship: the given relationship + value: the values correspond to the related field Returns: - the related record(s) + the embedded record(s) """ - if related_value is None: + if value is None: return None - props = related_field.property # type: RelationshipProperty[Any] + props = relationship.property # type: RelationshipProperty[Any] wrapper_type = props.collection_class - field_model = props.mapper.class_ + relationship_model = props.mapper.class_ parent_foreign_key_field = props.primaryjoin.right.name direction = props.direction if direction == RelationshipDirection.MANYTOONE: related_value_id_key = props.primaryjoin.left.name - parent_foreign_key_value = related_value.get(related_value_id_key) + parent_foreign_key_value = value.get(related_value_id_key) # update the foreign key value in the parent setattr(parent, parent_foreign_key_field, parent_foreign_key_value) # create child - child = field_model.model_validate(related_value) + child = relationship_model.model_validate(value) # update nested relationships - for field_name, field_type in field_model.__relational_fields__.items(): - if isinstance(related_value, dict): - nested_related_value = related_value.get(field_name) + for field_name, field_type in relationship_model.__relational_fields__.items(): + if isinstance(value, dict): + nested_related_value = value.get(field_name) else: - nested_related_value = getattr(related_value, field_name) + nested_related_value = getattr(value, field_name) - nested_related_records = _embed_related_value( - parent=child, - related_field=field_type, - related_value=nested_related_value, + nested_related_records = _embed_value( + parent=child, relationship=field_type, value=nested_related_value ) setattr(child, field_name, nested_related_records) @@ -665,22 +513,25 @@ def _embed_related_value( # add a foreign key values to link back to parent if issubclass(wrapper_type, (list, tuple, set)): embedded_records = [] - for v in related_value: - child = field_model.model_validate( + for v in value: + child = relationship_model.model_validate( _with_value(v, parent_foreign_key_field, parent_foreign_key_value) ) # update nested relationships - for field_name, field_type in field_model.__relational_fields__.items(): + for ( + field_name, + field_type, + ) in relationship_model.__relational_fields__.items(): if isinstance(v, dict): nested_related_value = v.get(field_name) else: nested_related_value = getattr(v, field_name) - nested_related_records = _embed_related_value( + nested_related_records = _embed_value( parent=child, - related_field=field_type, - related_value=nested_related_value, + relationship=field_type, + value=nested_related_value, ) setattr(child, field_name, nested_related_records) @@ -691,19 +542,22 @@ def _embed_related_value( elif direction == RelationshipDirection.MANYTOMANY: if 
issubclass(wrapper_type, (list, tuple, set)): embedded_records = [] - for v in related_value: - child = field_model.model_validate(v) + for v in value: + child = relationship_model.model_validate(v) # update nested relationships - for field_name, field_type in field_model.__relational_fields__.items(): + for ( + field_name, + field_type, + ) in relationship_model.__relational_fields__.items(): if isinstance(v, dict): nested_related_value = v.get(field_name) else: nested_related_value = getattr(v, field_name) - nested_related_records = _embed_related_value( + nested_related_records = _embed_value( parent=child, - related_field=field_type, - related_value=nested_related_value, + relationship=field_type, + value=nested_related_value, ) setattr(child, field_name, nested_related_records) @@ -716,79 +570,6 @@ def _embed_related_value( ) -# FIXME: Allow multiple levels of nesting -def _parse_embedded( - value: Iterable[dict | Any] | dict | Any, field: Any, parent: _SQLModel -) -> tuple[dict, Iterable[_SQLModel] | _SQLModel | None]: - """Parses embedded items that can be a single item or many into SQLModels - - Args: - value: the value to parse - field: the field on which these embedded items are - parent: the parent SQLModel to which this value is attached - - Returns: - tuple (parent_partial, embedded_models): where parent_partial is the partial update of the parent - and embedded_models is an iterable of SQLModel instances or a single SQLModel instance - or None if value is None - """ - if value is None: - return {}, None - - props = field.property # type: RelationshipProperty[Any] - wrapper_type = props.collection_class - field_model = props.mapper.class_ - fk_field = props.primaryjoin.right.name - parent_id_field = props.primaryjoin.left.name - fk_value = getattr(parent, parent_id_field) - direction = props.direction - - # FIXME: Maybe check if any relationship value is passed by checking the keys of value - # And then do a recursive embedded parse and return the field_model - - if direction == RelationshipDirection.MANYTOONE: - if any([k in field_model.__relational_fields__ for k in value]): - # FIXME: nested relationships exist - pass - # # add a foreign key value to link back to parent - return {fk_field: fk_value}, field_model.model_validate(value) - - if direction == RelationshipDirection.ONETOMANY: - # add a foreign key values to link back to parent - if issubclass(wrapper_type, (list, tuple, set)): - return {}, wrapper_type( - [ - ( - field_model.model_validate(_with_value(v, fk_field, fk_value)) - # FIXME: Add a proper call to nested recursion for all relational_fields - if any([k in field_model.__relational_fields__ for k in v]) - else field_model.model_validate( - _with_value(v, fk_field, fk_value) - ) - ) - for v in value - ] - ) - - if direction == RelationshipDirection.MANYTOMANY: - if issubclass(wrapper_type, (list, tuple, set)): - return {}, wrapper_type( - [ - ( - field_model.model_validate(v) - # FIXME: Add a proper call to nested recursion for all relational_fields - if any([k in field_model.__relational_fields__ for k in v]) - else field_model.model_validate(v) - ) - for v in value - ] - ) - - raise NotImplementedError( - f"relationship {direction} of type annotation {wrapper_type} not supported yet" - ) - - def _serialize_embedded( value: Iterable[_SQLModel] | _SQLModel, field: Any, **kwargs ) -> Iterable[dict] | dict | None: @@ -820,11 +601,12 @@ def _serialize_embedded( ) -async def _get_insert(session: AsyncSession): +async def _get_insert_func(session: AsyncSession, 
model: type[_SQLModelMeta]):
     """Gets the insert statement for the given session and model
 
     Args:
         session: the async session connecting to the database
+        model: the model for which the insert statement is to be obtained
 
     Returns:
         the insert statement for the given model
@@ -833,12 +615,25 @@ async def _get_insert(session: AsyncSession):
     dialect = conn.dialect
     dialect_name = dialect.name
 
+    native_insert_func = insert
+
     if dialect_name == "sqlite":
-        return sqlite_insert
+        native_insert_func = sqlite_insert
 
     if dialect_name == "postgresql":
-        return pg_insert
-
-    return insert
+        native_insert_func = pg_insert
+
+    # build an insert statement that skips conflicting rows where the dialect allows it
+    try:
+        # PostgreSQL and SQLite support on_conflict_do_nothing
+        return native_insert_func(model).on_conflict_do_nothing().returning(model)
+    except AttributeError:
+        # MySQL supports prefix("IGNORE")
+        # Other databases might fail at this point
+        return (
+            native_insert_func(model)
+            .prefix_with("IGNORE", dialect="mysql")
+            .returning(model)
+        )
 
 
 async def _update_non_embedded_fields(
@@ -868,6 +663,221 @@ async def _update_non_embedded_fields(
     return await cursor.fetchall()
 
 
+async def _update_embedded_fields(
+    session: AsyncSession,
+    model: type[_SQLModelMeta],
+    records: list[_SQLModelMeta],
+    updates: dict,
+):
+    """Updates only the embedded fields of the model for the given records
+
+    It ignores any fields in the `updates` dict that are not for embedded models.
+    Note: this operation replaces the values of the embedded fields with the new values
+    passed in the `updates` dictionary as opposed to patching the pre-existing values.
+
+    Args:
+        session: the sqlalchemy session
+        model: the model to be updated
+        records: the db records to update
+        updates: the updates to add to each record
+    """
+    embedded_updates = _get_relational_updates(model, updates)
+    relations_mapper = model.__relational_fields__
+    for k, v in embedded_updates.items():
+        relationship = relations_mapper[k]
+        link_model = model.__sqlmodel_relationships__[k].link_model
+
+        # this does a replace operation; i.e. removes old values and replaces them with the updates
+        await _bulk_embedded_delete(
+            session, relationship=relationship, data=records, link_model=link_model
+        )
+        await _bulk_embedded_insert(
+            session,
+            relationship=relationship,
+            data=records,
+            link_model=link_model,
+            payload=v,
+        )
+
+    # update the updated parents
+    session.add_all(records)
+
+
+async def _bulk_embedded_insert(
+    session: AsyncSession,
+    relationship: Any,
+    data: list[_SQLModelMeta],
+    link_model: type[_SQLModelMeta] | None,
+    payload: Iterable[dict] | dict,
+) -> Sequence[_SQLModelMeta] | None:
+    """Inserts the payload into the data records following the given relationship
+
+    It also writes the embedded records to the database
+
+    Args:
+        session: the database session
+        relationship: the relationship the payload has with the data's schema
+        data: the records into which the payload is to be embedded
+        link_model: the model for the through table
+        payload: the payload to merge into each record in the data
+
+    Returns:
+        the updated data including the embedded data in each record
+    """
+    relationship_props = relationship.property  # type: RelationshipProperty
+    relationship_model = relationship_props.mapper.class_
+
+    parsed_embedded_records = [_embed_value(v, relationship, payload) for v in data]
+
+    insert_stmt = await _get_insert_func(session, model=relationship_model)
+    embedded_cursor = await session.stream_scalars(
+        insert_stmt, _flatten_list(parsed_embedded_records)
+    )
+    embedded_db_records = await embedded_cursor.all()
+
+    parent_embedded_map = [
+        (parent, embedded_db_records[idx : idx + len(_as_list(raw_embedded))])
+        for idx, (parent, raw_embedded) in enumerate(zip(data, parsed_embedded_records))
+    ]
+
+    # insert through table values
+    await _bulk_insert_through_table_data(
+        session,
+        relationship=relationship,
+        link_model=link_model,
+        parent_embedded_map=parent_embedded_map,
+    )
+
+    return data
+
+
+async def _bulk_insert_through_table_data(
+    session: AsyncSession,
+    relationship: Any,
+    link_model: type[_SQLModelMeta] | None,
+    parent_embedded_map: list[tuple[_SQLModelMeta, list[_SQLModelMeta]]],
+):
+    """Inserts the link records into the through-table represented by the link_model
+
+    Args:
+        session: the database session
+        relationship: the relationship the embedded records are based on
+        link_model: the model for the through table
+        parent_embedded_map: the list of tuples of parent and its associated embedded db records
+    """
+    if link_model is not None:
+        relationship_props = relationship.property  # type: RelationshipProperty
+        child_id_field_name = relationship_props.secondaryjoin.left.name
+        parent_id_field_name = relationship_props.primaryjoin.left.name
+        child_fk_field_name = relationship_props.secondaryjoin.right.name
+        parent_fk_field_name = relationship_props.primaryjoin.right.name
+
+        link_raw_values = [
+            {
+                parent_fk_field_name: getattr(parent, parent_id_field_name),
+                child_fk_field_name: getattr(child, child_id_field_name),
+            }
+            for parent, children in enumerate(parent_embedded_map)
+            for child in children
+        ]
+
+        next_id = await _get_nextid(session, link_model)
+        link_model_values = [
+            link_model(id=next_id + idx, **v) for idx, v in enumerate(link_raw_values)
+        ]
+
+        insert_stmt = await _get_insert_func(session, model=link_model)
+        await session.stream_scalars(insert_stmt, link_model_values)
+
+
+async def _bulk_embedded_delete(
+    session: AsyncSession,
+    relationship: Any,
+    data: list[SQLModel],
+    link_model: type[_SQLModelMeta] | None,
+):
+    """Deletes the embedded records of the given parent records for the given relationship
+
+    Args:
+        session: the database session
+        relationship: the relationship whose embedded records are to be deleted for the given records
+        data: the parent records whose embedded records are to be deleted
+        link_model: the model for the through table
+    """
+    relationship_props = relationship.property  # type: RelationshipProperty
+    relationship_model = relationship_props.mapper.class_
+
+    parent_id_field_name = relationship_props.primaryjoin.left.name
+    parent_foreign_keys = [getattr(item, parent_id_field_name) for item in data]
+
+    if link_model is None:
+        reverse_foreign_key_field_name = relationship_props.primaryjoin.right.name
+        reverse_foreign_key_field = getattr(
+            relationship_model, reverse_foreign_key_field_name
+        )
+        await session.stream(
+            delete(relationship_model).where(
+                reverse_foreign_key_field.in_(parent_foreign_keys)
+            )
+        )
+    else:
+        reverse_foreign_key_field = getattr(link_model, parent_id_field_name)
+        await session.stream(
+            delete(link_model).where(reverse_foreign_key_field.in_(parent_foreign_keys))
+        )
+
+
+async def _get_nextid(session: AsyncSession, model: type[_SQLModelMeta]):
+    """Gets the next auto-incremented id for the given model
+
+    It returns the next value of the auto-incremented integer ID
+
+    Args:
+        session: the database session
+        model: the model under consideration
+
+    Returns:
+        the next auto-incremented integer ID for the given model
+    """
+    # compute the next id auto-incremented
+    next_id = await session.scalar(func.max(model.id))
+    next_id = (next_id or 0) + 1
+    return next_id
+
+
+def _flatten_list(data: list[_T | list[_T]]) -> list[_T]:
+    """Flattens a list that may have lists of items at some indices
+
+    Args:
+        data: the list to flatten
+
+    Returns:
+        the flattened list
+    """
+    results = []
+    for item in data:
+        if isinstance(item, Iterable) and not isinstance(item, Mapping):
+            results += _flatten_list(item)
+        else:
+            results.append(item)
+
+    return results
+
+
+def _as_list(value: Any) -> list:
+    """Converts the value to a list, wrapping it in one if it is not an iterable
+
+    Args:
+        value: the value to wrap in a list if it is not one
+
+    Returns:
+        the value as a list if it is not already one
+    """
+    if isinstance(value, list):
+        return value
+    elif isinstance(value, Iterable) and not isinstance(value, Mapping):
+        return list(value)
+    return [value]
+
+
 def _get_relational_updates(model: type[_SQLModelMeta], updates: dict) -> dict:
     """Gets the updates that affect only the relationships on this model
 

From 7b93782050cc00f5f9e5313b912bd3b99489aa2e Mon Sep 17 00:00:00 2001
From: Martin Ahindura
Date: Thu, 5 Jun 2025 11:07:03 +0300
Subject: [PATCH 05/10] Fix 'AttributeError: 'int' object has no attribute 'id'' with SQL update

---
 nqlstore/_sql.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/nqlstore/_sql.py b/nqlstore/_sql.py
index 44793d3..93a9f61 100644
--- a/nqlstore/_sql.py
+++ b/nqlstore/_sql.py
@@ -698,7 +698,7 @@ async def _update_embedded_fields(
             link_model=link_model,
             payload=v,
         )
-
+    # FIXME: Should the added records be updated with their embedded values?
# update the updated parents session.add_all(records) @@ -776,7 +776,7 @@ async def _bulk_insert_through_table_data( parent_fk_field_name: getattr(parent, parent_id_field_name), child_fk_field_name: getattr(child, child_id_field_name), } - for parent, children in enumerate(parent_embedded_map) + for parent, children in parent_embedded_map for child in children ] @@ -855,7 +855,7 @@ def _flatten_list(data: list[_T | list[_T]]) -> list[_T]: results = [] for item in data: if isinstance(item, Iterable) and not isinstance(item, Mapping): - results += _flatten_list(item) + results += list(item) else: results.append(item) From fa2f5fd8c879aace5f502eeb38c35898510de758 Mon Sep 17 00:00:00 2001 From: Martin Ahindura Date: Thu, 5 Jun 2025 14:19:35 +0300 Subject: [PATCH 06/10] Fix failing tests --- examples/blog/main.py | 7 ++++++- examples/blog/models.py | 4 ++-- examples/blog/test_main.py | 34 ++++++++++++++++++++++------------ 3 files changed, 30 insertions(+), 15 deletions(-) diff --git a/examples/blog/main.py b/examples/blog/main.py index 7c6bb7b..7390f45 100644 --- a/examples/blog/main.py +++ b/examples/blog/main.py @@ -112,8 +112,13 @@ async def search( if redis: # redis's regex search is not mature so we use its full text search + # Unfortunately, redis search does not permit us to search fields that are arrays. redis_query = [ - (_get_redis_field(RedisPost, k) % f"*{v}*") + ( + (_get_redis_field(RedisPost, k) == f"{v}") + if k == "tags.title" + else (_get_redis_field(RedisPost, k) % f"*{v}*") + ) for k, v in query_dict.items() ] results += await redis.find(RedisPost, *redis_query) diff --git a/examples/blog/models.py b/examples/blog/models.py index dc601b2..e03688b 100644 --- a/examples/blog/models.py +++ b/examples/blog/models.py @@ -21,7 +21,7 @@ "MongoPost", Post, embedded_models={ - "author": MongoAuthor, + "author": MongoAuthor | None, "comments": list[MongoComment], "tags": list[MongoTag], }, @@ -39,7 +39,7 @@ "RedisPost", Post, embedded_models={ - "author": RedisAuthor, + "author": RedisAuthor | None, "comments": list[RedisComment], "tags": list[RedisTag], }, diff --git a/examples/blog/test_main.py b/examples/blog/test_main.py index 5286396..cc15e0a 100644 --- a/examples/blog/test_main.py +++ b/examples/blog/test_main.py @@ -199,22 +199,25 @@ async def test_update_redis_post( redis_store: RedisStore, redis_posts: list[RedisPost], index: int, + freezer, ): """PUT to /posts/{id} updates the redis post of given id and returns updated version""" + timestamp = datetime.now().isoformat() with client_with_redis as client: post = redis_posts[index] + post_dict = post.model_dump(mode="json", exclude_none=True, exclude_unset=True) id_ = post.id update = { - "name": "some other name", - "todos": [ - *post.tags, + "title": "some other title", + "tags": [ + *post_dict.get("tags", []), {"title": "another one"}, {"title": "another one again"}, ], - "comments": [*post.comments, *COMMENT_LIST[index:]], + "comments": [*post_dict.get("comments", []), *COMMENT_LIST[index:]], } - response = client.put(f"/posts/{id_}", json=update) + response = client.put(f"/posts/{id_}", json=update, headers=_HEADERS) got = response.json() expected = { @@ -226,6 +229,8 @@ async def test_update_redis_post( "id": final["id"], "author": final["author"], "pk": final["pk"], + "created_at": timestamp, + "updated_at": timestamp, } for raw, final in zip(update["comments"], got["comments"]) ], @@ -272,22 +277,25 @@ async def test_update_mongo_post( mongo_store: MongoStore, mongo_posts: list[MongoPost], index: int, + freezer, ): 
"""PUT to /posts/{id} updates the mongo post of given id and returns updated version""" + timestamp = datetime.now().isoformat() with client_with_mongo as client: post = mongo_posts[index] + post_dict = post.model_dump(mode="json", exclude_none=True, exclude_unset=True) id_ = post.id update = { - "name": "some other name", - "todos": [ - *post.tags, + "title": "some other title", + "tags": [ + *post_dict.get("tags", []), {"title": "another one"}, {"title": "another one again"}, ], - "comments": [*post.comments, *COMMENT_LIST[index:]], + "comments": [*post_dict.get("comments", []), *COMMENT_LIST[index:]], } - response = client.put(f"/posts/{id_}", json=update) + response = client.put(f"/posts/{id_}", json=update, headers=_HEADERS) got = response.json() expected = { @@ -297,6 +305,8 @@ async def test_update_mongo_post( { **raw, "author": final["author"], + "created_at": timestamp, + "updated_at": timestamp, } for raw, final in zip(update["comments"], got["comments"]) ], @@ -545,14 +555,14 @@ async def test_search_sql_by_tag( @pytest.mark.asyncio -@pytest.mark.parametrize("q", _TAG_SEARCH_TERMS) +@pytest.mark.parametrize("q", ["random", "another one", "another one again"]) async def test_search_redis_by_tag( client_with_redis: TestClient, redis_store: RedisStore, redis_posts: list[RedisPost], q: str, ): - """GET /posts?tag={} gets all redis posts with tag containing search item""" + """GET /posts?tag={} gets all redis posts with tag containing search item. Partial searches nit supported.""" with client_with_redis as client: response = client.get(f"/posts?tag={q}") From 394706a275f7e8d3cac44ea78fe1af9a7949525c Mon Sep 17 00:00:00 2001 From: Martin Ahindura Date: Fri, 6 Jun 2025 23:26:51 +0300 Subject: [PATCH 07/10] Fix dependency errors in python 3.13 --- .github/workflows/ci.yml | 3 +-- .gitignore | 1 + CONTRIBUTING.md | 2 +- nqlstore/_sql.py | 26 ++++++++++++++------------ pyproject.toml | 4 ++-- 5 files changed, 19 insertions(+), 17 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ddb25cd..f0c2f53 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -107,8 +107,7 @@ jobs: cd $GITHUB_WORKSPACE/examples/${{ matrix.example_app }} python -m pip install --upgrade pip python --version - pip install -r requirements.txt - pip install -U ../.."[all]" + pip install -U ../.."[all,test]" black --check . pytest . 
diff --git a/.gitignore b/.gitignore
index 2f02596..f67c6d6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -135,6 +135,7 @@ venv/
 ENV/
 env.bak/
 venv.bak/
+.venv3_13
 
 # Spyder project settings
 .spyderproject
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index e3e541b..0ab99fc 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -77,7 +77,7 @@ By contributing, you agree that your contributions will be licensed under its MI
 - Install the dependencies
 
   ```bash
-  pip install -r requirements.txt
+  pip install ."[all,test]"
   ```
 
 - Run the pre-commit installation
diff --git a/nqlstore/_sql.py b/nqlstore/_sql.py
index 93a9f61..18346bf 100644
--- a/nqlstore/_sql.py
+++ b/nqlstore/_sql.py
@@ -47,7 +47,6 @@ class _SQLModelMeta(_SQLModel):
     """dict of (name, Field) that have associated relationships"""
 
     @classmethod
-    @property
     def __relational_fields__(cls) -> dict[str, Any]:
         """dict of (name, Field) that have associated relationships"""
 
@@ -91,7 +90,7 @@ def model_dump(
             warnings=warnings,
             serialize_as_any=serialize_as_any,
         )
-        relations_mappers = self.__class__.__relational_fields__
+        relations_mappers = self.__class__.__relational_fields__()
         for k, field in relations_mappers.items():
             if exclude is None or k not in exclude:
                 try:
@@ -143,7 +142,7 @@ async def insert(
         parsed_items = [
             v if isinstance(v, model) else model.model_validate(v) for v in items
         ]
-        relations_mapper = model.__relational_fields__
+        relations_mapper = model.__relational_fields__()
 
         async with AsyncSession(self._engine) as session:
             insert_stmt = await _get_insert_func(session, model=model)
@@ -356,7 +355,7 @@ def _get_relational_filters(
     Returns:
         list of filters that are concerned with relationships on this model
     """
-    relationships = list(model.__relational_fields__.values())
+    relationships = list(model.__relational_fields__().values())
     targets = [v.property.target for v in relationships]
     plain_filters = [
         item
@@ -378,7 +377,7 @@ def _get_non_relational_filters(
     Returns:
         list of filters that are NOT concerned with relationships on this model
     """
-    targets = [v.property.target for v in model.__relational_fields__.values()]
+    targets = [v.property.target for v in model.__relational_fields__().values()]
     return [
         item
         for item in filters
@@ -494,7 +493,10 @@ def _embed_value(
         # create child
         child = relationship_model.model_validate(value)
         # update nested relationships
-        for field_name, field_type in relationship_model.__relational_fields__.items():
+        for (
+            field_name,
+            field_type,
+        ) in relationship_model.__relational_fields__().items():
             if isinstance(value, dict):
                 nested_related_value = value.get(field_name)
             else:
@@ -522,7 +524,7 @@ def _embed_value(
             for (
                 field_name,
                 field_type,
-            ) in relationship_model.__relational_fields__.items():
+            ) in relationship_model.__relational_fields__().items():
                 if isinstance(v, dict):
                     nested_related_value = v.get(field_name)
                 else:
@@ -549,7 +551,7 @@ def _embed_value(
             for (
                 field_name,
                 field_type,
-            ) in relationship_model.__relational_fields__.items():
+            ) in relationship_model.__relational_fields__().items():
                 if isinstance(v, dict):
                     nested_related_value = v.get(field_name)
                 else:
@@ -682,7 +684,7 @@ async def _update_embedded_fields(
         updates: the updates to add to each record
     """
     embedded_updates = _get_relational_updates(model, updates)
-    relations_mapper = model.__relational_fields__
+    relations_mapper = model.__relational_fields__()
    for k, v in embedded_updates.items():
         relationship = relations_mapper[k]
         link_model = model.__sqlmodel_relationships__[k].link_model
@@ -888,7 +890,7 @@ def _get_relational_updates(model: type[_SQLModelMeta], updates: dict) -> dict:
     Returns:
         a dict with only updates concerning the relationships of the given model
     """
-    return {k: v for k, v in updates.items() if k in model.__relational_fields__}
+    return {k: v for k, v in updates.items() if k in model.__relational_fields__()}
 
 
 def _get_non_relational_updates(model: type[_SQLModelMeta], updates: dict) -> dict:
@@ -901,7 +903,7 @@ def _get_non_relational_updates(model: type[_SQLModelMeta], updates: dict) -> di
     Returns:
         a dict with only updates that do not affect relationships on this model
     """
-    return {k: v for k, v in updates.items() if k not in model.__relational_fields__}
+    return {k: v for k, v in updates.items() if k not in model.__relational_fields__()}
 
 
 async def _find(
@@ -926,7 +928,7 @@ async def _find(
     Returns:
         the records that match the given filters
     """
-    relations = list(model.__relational_fields__.values())
+    relations = list(model.__relational_fields__().values())
 
     # eagerly load all relationships so that no validation errors occur due
     # to missing session if there is an attempt to load them lazily later
diff --git a/pyproject.toml b/pyproject.toml
index 582237b..63fa897 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,13 +33,13 @@ sql = [
     "greenlet~=3.1.1",
 ]
 mongo = ["beanie~=1.29.0"]
-redis = ["redis-om~=0.3.3"]
+redis = ["redis-om~=0.3.3,<0.3.4"]
 all = [
     "sqlmodel~=0.0.22",
     "aiosqlite~=0.20.0",
     "greenlet~=3.1.1",
     "beanie~=1.29.0",
-    "redis-om~=0.3.3",
+    "redis-om~=0.3.3,<0.3.4",
 ]
 
 [project.urls]

From 8aaab278048d9fe8e19c3c79bab3617862d8e5c0 Mon Sep 17 00:00:00 2001
From: Martin Ahindura
Date: Fri, 6 Jun 2025 23:32:15 +0300
Subject: [PATCH 08/10] Fix failing CI tests for examples

---
 .github/workflows/ci.yml | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index f0c2f53..ddb25cd 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -107,7 +107,8 @@ jobs:
           cd $GITHUB_WORKSPACE/examples/${{ matrix.example_app }}
           python -m pip install --upgrade pip
           python --version
-          pip install -U ../.."[all,test]"
+          pip install -r requirements.txt
+          pip install -U ../.."[all]"
           black --check .
           pytest .
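For reviewers wondering about the mechanical `__relational_fields__` -> `__relational_fields__()` churn in PATCH 07: chaining `@classmethod` with `@property` no longer works on Python 3.13 (class properties were deprecated in 3.11 and removed in 3.13), so the class-level property had to become a plain classmethod and every access a call. A minimal sketch of the failure mode under that assumption, with toy classes standing in for `_SQLModelMeta` (the names below are illustrative, not nqlstore code):

```python
# Sketch only: why PATCH 07 drops @property from __relational_fields__.
# LegacyModel/PortableModel are hypothetical stand-ins, not nqlstore code.


class LegacyModel:
    @classmethod
    @property
    def relational_fields(cls) -> dict:
        # Worked as a "class property" on Python 3.9-3.12 (deprecated in 3.11):
        # LegacyModel.relational_fields evaluated to the dict.
        return {"tags": f"{cls.__name__}.tags"}


class PortableModel:
    @classmethod
    def relational_fields(cls) -> dict:
        # Plain classmethod: portable across versions, but every call site
        # must now use parentheses, hence the churn across _sql.py.
        return {"tags": f"{cls.__name__}.tags"}


print(PortableModel.relational_fields())  # {'tags': 'PortableModel.tags'}
# On 3.13, LegacyModel.relational_fields no longer triggers the property's
# getter; it yields the wrapped descriptor instead of a dict, so code that
# expected dict behaviour (`.items()`, `in`) breaks at runtime.
```

This is also why the change touches `model_dump`, `insert`, the filter/update helpers, and `_find` rather than a single definition site.
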
From 5f1ae305a4b14ae761966f00300d3782bf97e592 Mon Sep 17 00:00:00 2001 From: Martin Ahindura Date: Sat, 7 Jun 2025 13:22:02 +0300 Subject: [PATCH 09/10] Fix 'TypeError: field 'TodoList.description' was not initialized with a Field() or Relationship()' --- examples/todos/schemas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/todos/schemas.py b/examples/todos/schemas.py index 3fa0d3c..d765343 100644 --- a/examples/todos/schemas.py +++ b/examples/todos/schemas.py @@ -9,7 +9,7 @@ class TodoList(BaseModel): """A list of Todos""" name: str = Field(index=True, full_text_search=True) - description: str | None = None + description: str | None = Field(default=None) todos: list["Todo"] = Relationship(back_populates="parent", default=[]) From af7121a49694f8f9ca4fe989e868f05361f803f0 Mon Sep 17 00:00:00 2001 From: Martin Ahindura Date: Sat, 7 Jun 2025 13:25:40 +0300 Subject: [PATCH 10/10] Add tests for blog example in github CI --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ddb25cd..8b31786 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -69,7 +69,7 @@ jobs: strategy: matrix: python-version: [ "3.10", "3.11", "3.12", "3.13" ] - example_app: ["todos"] + example_app: ["todos", "blog"] services: mongodb: