diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index ddb25cd..8b31786 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -69,7 +69,7 @@ jobs:
     strategy:
       matrix:
         python-version: [ "3.10", "3.11", "3.12", "3.13" ]
-        example_app: ["todos"]
+        example_app: ["todos", "blog"]

     services:
       mongodb:
diff --git a/.gitignore b/.gitignore
index 2f02596..f67c6d6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -135,6 +135,7 @@ venv/
 ENV/
 env.bak/
 venv.bak/
+.venv3_13

 # Spyder project settings
 .spyderproject
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index e3e541b..0ab99fc 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -77,7 +77,7 @@ By contributing, you agree that your contributions will be licensed under its MI
 - Install the dependencies

   ```bash
-  pip install -r requirements.txt
+  pip install ".[all,test]"
   ```

 - Run the pre-commit installation
diff --git a/LIMITATIONS.md b/LIMITATIONS.md
new file mode 100644
index 0000000..f522e1b
--- /dev/null
+++ b/LIMITATIONS.md
@@ -0,0 +1,15 @@
+# Limitations
+
+## Filtering
+
+### Redis
+
+- Mongo-style regular expression filtering is not supported.
+  This is because native Redis regular-expression filtering is limited to the most basic text-based search.
+
+## Update Operation
+
+### SQL
+
+- Even though one can update a model to a theoretically infinite number of levels deep,
+  the returned results can only contain nested models one level deep.
diff --git a/README.md b/README.md
index 7fc459b..617fd6c 100644
--- a/README.md
+++ b/README.md
@@ -336,6 +336,11 @@ libraries = await redis_store.delete(

 - [ ] Add documentation site

+## Limitations
+
+This library is limited in some specific cases.
+Read through the [`LIMITATIONS.md`](./LIMITATIONS.md) file for more details.
+
 ## Contributions

 Contributions are welcome. The docs have to maintained, the code has to be made cleaner, more idiomatic and faster,
diff --git a/examples/blog/conftest.py b/examples/blog/conftest.py
index 9f1774e..8aae68e 100644
--- a/examples/blog/conftest.py
+++ b/examples/blog/conftest.py
@@ -7,7 +7,7 @@ import pytest_asyncio
 import pytest_mock
 from fastapi.testclient import TestClient

-from models import (  # SqlAuthor,
+from models import (
     MongoAuthor,
     MongoComment,
     MongoInternalAuthor,
diff --git a/examples/blog/main.py b/examples/blog/main.py
index 18e42ee..7390f45 100644
--- a/examples/blog/main.py
+++ b/examples/blog/main.py
@@ -13,7 +13,7 @@ from fastapi.security import OAuth2PasswordRequestForm
 from models import MongoPost, RedisPost, SqlInternalAuthor, SqlPost
 from pydantic import BaseModel

-from schemas import InternalAuthor, Post, TokenResponse
+from schemas import InternalAuthor, PartialPost, Post, TokenResponse
 from stores import MongoStoreDep, RedisStoreDep, SqlStoreDep, clear_stores

 _ACCESS_TOKEN_EXPIRE_MINUTES = 30
@@ -112,8 +112,13 @@ async def search(

     if redis:
         # redis's regex search is not mature so we use its full text search
+        # Unfortunately, Redis search does not permit us to search fields that are arrays.
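+        # Hence the special case below: "tags.title" (an array field) is matched
+        # exactly with `==`, while scalar fields use a wildcard full-text match of
+        # the form `field % "*{term}*"`.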
redis_query = [ - (_get_redis_field(RedisPost, k) % f"*{v}*") + ( + (_get_redis_field(RedisPost, k) == f"{v}") + if k == "tags.title" + else (_get_redis_field(RedisPost, k) % f"*{v}*") + ) for k, v in query_dict.items() ] results += await redis.find(RedisPost, *redis_query) @@ -194,7 +199,7 @@ async def update_one( mongo: MongoStoreDep, current_user: CurrentUserDep, id_: int | str, - payload: Post, + payload: PartialPost, ): """Update a post""" results = [] diff --git a/examples/blog/models.py b/examples/blog/models.py index 367af41..e03688b 100644 --- a/examples/blog/models.py +++ b/examples/blog/models.py @@ -21,7 +21,7 @@ "MongoPost", Post, embedded_models={ - "author": MongoAuthor, + "author": MongoAuthor | None, "comments": list[MongoComment], "tags": list[MongoTag], }, @@ -39,7 +39,7 @@ "RedisPost", Post, embedded_models={ - "author": RedisAuthor, + "author": RedisAuthor | None, "comments": list[RedisComment], "tags": list[RedisTag], }, @@ -47,7 +47,6 @@ # sqlite models SqlInternalAuthor = SQLModel("SqlInternalAuthor", InternalAuthor) -# SqlAuthor = SQLModel("SqlAuthor", Author, table=False) SqlComment = SQLModel( "SqlComment", Comment, relationships={"author": SqlInternalAuthor | None} ) diff --git a/examples/blog/schemas.py b/examples/blog/schemas.py index d66f12a..8fac64f 100644 --- a/examples/blog/schemas.py +++ b/examples/blog/schemas.py @@ -3,7 +3,7 @@ from datetime import datetime from pydantic import BaseModel -from utils import current_timestamp +from utils import Partial, current_timestamp from nqlstore import Field, Relationship @@ -33,15 +33,8 @@ class Post(BaseModel): disable_on_redis=True, ) author: Author | None = Relationship(default=None) - comments: list["Comment"] = Relationship( - default=[], - disable_on_redis=True, - ) - tags: list["Tag"] = Relationship( - default=[], - link_model="TagLink", - disable_on_redis=True, - ) + comments: list["Comment"] = Relationship(default=[]) + tags: list["Tag"] = Relationship(default=[], link_model="TagLink") created_at: str = Field(index=True, default_factory=current_timestamp) updated_at: str = Field(index=True, default_factory=current_timestamp) @@ -89,7 +82,7 @@ class TagLink(BaseModel): class Tag(BaseModel): """The tags to help searching for posts""" - title: str = Field(index=True, unique=True, full_text_search=True) + title: str = Field(index=True, unique=True) class TokenResponse(BaseModel): @@ -97,3 +90,7 @@ class TokenResponse(BaseModel): access_token: str token_type: str + + +# Partial models +PartialPost = Partial("PartialPost", Post) diff --git a/examples/blog/test_main.py b/examples/blog/test_main.py index 95e7cbb..cc15e0a 100644 --- a/examples/blog/test_main.py +++ b/examples/blog/test_main.py @@ -11,6 +11,7 @@ _TITLE_SEARCH_TERMS = ["ho", "oo", "work"] _TAG_SEARCH_TERMS = ["art", "om"] +_HEADERS = {"Authorization": f"Bearer {ACCESS_TOKEN}"} @pytest.mark.asyncio @@ -21,9 +22,7 @@ async def test_create_sql_post( """POST to /posts creates a post in sql and returns it""" timestamp = datetime.now().isoformat() with client_with_sql as client: - response = client.post( - "/posts", json=post, headers={"Authorization": f"Bearer {ACCESS_TOKEN}"} - ) + response = client.post("/posts", json=post, headers=_HEADERS) got = response.json() post_id = got["id"] @@ -61,12 +60,12 @@ async def test_create_redis_post( client_with_redis: TestClient, redis_store: RedisStore, post: dict, + freezer, ): """POST to /posts creates a post in redis and returns it""" + timestamp = datetime.now().isoformat() with client_with_redis as client: - 
response = client.post( - "/posts", json=post, headers={"Authorization": f"Bearer {ACCESS_TOKEN}"} - ) + response = client.post("/posts", json=post, headers=_HEADERS) got = response.json() post_id = got["id"] @@ -75,7 +74,8 @@ async def test_create_redis_post( expected = { "id": post_id, "title": post["title"], - "content": post.get("content"), + "content": post.get("content", ""), + "author": {**got["author"], **AUTHOR}, "pk": post_id, "tags": [ { @@ -86,6 +86,8 @@ async def test_create_redis_post( for raw, resp in zip(raw_tags, resp_tags) ], "comments": [], + "created_at": timestamp, + "updated_at": timestamp, } db_query = {"id": {"$eq": post_id}} @@ -102,10 +104,12 @@ async def test_create_mongo_post( client_with_mongo: TestClient, mongo_store: MongoStore, post: dict, + freezer, ): """POST to /posts creates a post in redis and returns it""" + timestamp = datetime.now().isoformat() with client_with_mongo as client: - response = client.post("/posts", json=post) + response = client.post("/posts", json=post, headers=_HEADERS) got = response.json() post_id = got["id"] @@ -113,14 +117,12 @@ async def test_create_mongo_post( expected = { "id": post_id, "title": post["title"], - "content": post.get("content"), - "tags": [ - { - **raw, - } - for raw in raw_tags - ], + "content": post.get("content", ""), + "author": {"name": AUTHOR["name"]}, + "tags": raw_tags, "comments": [], + "created_at": timestamp, + "updated_at": timestamp, } db_query = {"_id": {"$eq": ObjectId(post_id)}} @@ -138,22 +140,26 @@ async def test_update_sql_post( sql_store: SQLStore, sql_posts: list[SqlPost], index: int, + freezer, ): """PUT to /posts/{id} updates the sql post of given id and returns updated version""" + timestamp = datetime.now().isoformat() with client_with_sql as client: post = sql_posts[index] + post_dict = post.model_dump(mode="json", exclude_none=True, exclude_unset=True) id_ = post.id update = { - "name": "some other name", - "todos": [ - *post.tags, + **post_dict, + "title": "some other title", + "tags": [ + *post_dict["tags"], {"title": "another one"}, {"title": "another one again"}, ], - "comments": [*post.comments, *COMMENT_LIST[index:]], + "comments": [*post_dict["comments"], *COMMENT_LIST[index:]], } - response = client.put(f"/posts/{id_}", json=update) + response = client.put(f"/posts/{id_}", json=update, headers=_HEADERS) got = response.json() expected = { @@ -164,8 +170,9 @@ async def test_update_sql_post( **raw, "id": final["id"], "post_id": final["post_id"], - "author": final["author"], - "author_id": final["author_id"], + "author_id": 1, + "created_at": timestamp, + "updated_at": timestamp, } for raw, final in zip(update["comments"], got["comments"]) ], @@ -192,22 +199,25 @@ async def test_update_redis_post( redis_store: RedisStore, redis_posts: list[RedisPost], index: int, + freezer, ): """PUT to /posts/{id} updates the redis post of given id and returns updated version""" + timestamp = datetime.now().isoformat() with client_with_redis as client: post = redis_posts[index] + post_dict = post.model_dump(mode="json", exclude_none=True, exclude_unset=True) id_ = post.id update = { - "name": "some other name", - "todos": [ - *post.tags, + "title": "some other title", + "tags": [ + *post_dict.get("tags", []), {"title": "another one"}, {"title": "another one again"}, ], - "comments": [*post.comments, *COMMENT_LIST[index:]], + "comments": [*post_dict.get("comments", []), *COMMENT_LIST[index:]], } - response = client.put(f"/posts/{id_}", json=update) + response = client.put(f"/posts/{id_}", 
json=update, headers=_HEADERS) got = response.json() expected = { @@ -219,6 +229,8 @@ async def test_update_redis_post( "id": final["id"], "author": final["author"], "pk": final["pk"], + "created_at": timestamp, + "updated_at": timestamp, } for raw, final in zip(update["comments"], got["comments"]) ], @@ -265,22 +277,25 @@ async def test_update_mongo_post( mongo_store: MongoStore, mongo_posts: list[MongoPost], index: int, + freezer, ): """PUT to /posts/{id} updates the mongo post of given id and returns updated version""" + timestamp = datetime.now().isoformat() with client_with_mongo as client: post = mongo_posts[index] + post_dict = post.model_dump(mode="json", exclude_none=True, exclude_unset=True) id_ = post.id update = { - "name": "some other name", - "todos": [ - *post.tags, + "title": "some other title", + "tags": [ + *post_dict.get("tags", []), {"title": "another one"}, {"title": "another one again"}, ], - "comments": [*post.comments, *COMMENT_LIST[index:]], + "comments": [*post_dict.get("comments", []), *COMMENT_LIST[index:]], } - response = client.put(f"/posts/{id_}", json=update) + response = client.put(f"/posts/{id_}", json=update, headers=_HEADERS) got = response.json() expected = { @@ -290,6 +305,8 @@ async def test_update_mongo_post( { **raw, "author": final["author"], + "created_at": timestamp, + "updated_at": timestamp, } for raw, final in zip(update["comments"], got["comments"]) ], @@ -538,14 +555,14 @@ async def test_search_sql_by_tag( @pytest.mark.asyncio -@pytest.mark.parametrize("q", _TAG_SEARCH_TERMS) +@pytest.mark.parametrize("q", ["random", "another one", "another one again"]) async def test_search_redis_by_tag( client_with_redis: TestClient, redis_store: RedisStore, redis_posts: list[RedisPost], q: str, ): - """GET /posts?tag={} gets all redis posts with tag containing search item""" + """GET /posts?tag={} gets all redis posts with tag containing search item. 
Partial searches are not supported."""

     with client_with_redis as client:
         response = client.get(f"/posts?tag={q}")
diff --git a/examples/blog/utils.py b/examples/blog/utils.py
index 3f7d2a1..4a917c0 100644
--- a/examples/blog/utils.py
+++ b/examples/blog/utils.py
@@ -1,6 +1,16 @@
 """Some random utilities for the app"""

+import copy
+import sys
 from datetime import datetime
+from typing import Any, Literal, Optional, TypeVar, get_args
+
+from pydantic import BaseModel, create_model
+from pydantic.main import IncEx
+
+from nqlstore._field import FieldInfo
+
+_T = TypeVar("_T", bound=BaseModel)


 def current_timestamp() -> str:
@@ -10,3 +20,42 @@
     string of the current datetime
     """
     return datetime.now().isoformat()
+
+
+def Partial(name: str, model: type[_T]) -> type[_T]:
+    """Creates a partial schema from another schema, with all fields optional
+
+    Args:
+        name: the name of the model
+        model: the original model
+
+    Returns:
+        A new model with all its fields optional
+    """
+    fields = {
+        k: (_make_optional(v.annotation), None)
+        for k, v in model.model_fields.items()  # type: str, FieldInfo
+    }
+
+    return create_model(
+        name,
+        # module of the calling function
+        __module__=sys._getframe(1).f_globals["__name__"],
+        __doc__=model.__doc__,
+        __base__=(model,),
+        **fields,
+    )
+
+
+def _make_optional(type_: type) -> type:
+    """Makes a type optional if it is not already optional
+
+    Args:
+        type_: the type to make optional
+
+    Returns:
+        the optional type
+    """
+    if type(None) in get_args(type_):
+        return type_
+    return type_ | None
diff --git a/examples/todos/schemas.py b/examples/todos/schemas.py
index 3fa0d3c..d765343 100644
--- a/examples/todos/schemas.py
+++ b/examples/todos/schemas.py
@@ -9,7 +9,7 @@ class TodoList(BaseModel):
     """A list of Todos"""

     name: str = Field(index=True, full_text_search=True)
-    description: str | None = None
+    description: str | None = Field(default=None)
     todos: list["Todo"] = Relationship(back_populates="parent", default=[])


diff --git a/nqlstore/_compat.py b/nqlstore/_compat.py
index bfc7d83..f22a9ba 100644
--- a/nqlstore/_compat.py
+++ b/nqlstore/_compat.py
@@ -41,7 +41,7 @@
 sql imports; and their default if sqlmodel is missing
 """
 try:
-    from sqlalchemy import Column, Table
+    from sqlalchemy import Column, Table, func
     from sqlalchemy.dialects.postgresql import insert as pg_insert
     from sqlalchemy.dialects.sqlite import insert as sqlite_insert
     from sqlalchemy.ext.asyncio import create_async_engine
@@ -65,6 +65,7 @@
     from sqlmodel.main import IncEx, NoArgAnyCallable, OnDeleteType
     from sqlmodel.main import RelationshipInfo as _RelationshipInfo
 except ImportError:
+    import types
     from typing import Mapping, Optional, Sequence
     from typing import Set
     from typing import Set as _ColumnExpressionArgument
@@ -90,6 +91,8 @@
     subqueryload = lambda *a, **kwargs: dict(**kwargs)
     DetachedInstanceError = RuntimeError
     IncEx = Set[Any] | dict
+    func = types.ModuleType("func")
+    func.max = lambda *a, **kwargs: dict(**kwargs)


 class _SqlFieldInfo(_FieldInfo): ...
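For orientation (not part of the diff): a minimal sketch of how the `Partial` helper added in `examples/blog/utils.py` above is meant to be used. The `Item` model here is hypothetical; only `Partial` itself comes from the diff.

```python
from pydantic import BaseModel

from utils import Partial  # the helper introduced in examples/blog/utils.py


class Item(BaseModel):
    """A hypothetical schema with one required and one defaulted field."""

    title: str
    content: str = ""


# Every field on the generated model becomes optional (defaulting to None),
# so a PUT payload may carry only the fields that actually changed.
PartialItem = Partial("PartialItem", Item)

patch = PartialItem(title="new title")
assert patch.title == "new title"
assert patch.content is None  # unset fields default to None
```

This is why `update_one` in `examples/blog/main.py` can switch its `payload` annotation from `Post` to `PartialPost` without rejecting sparse update bodies.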
diff --git a/nqlstore/_sql.py b/nqlstore/_sql.py index 942bd52..18346bf 100644 --- a/nqlstore/_sql.py +++ b/nqlstore/_sql.py @@ -1,8 +1,9 @@ """SQL implementation""" +import copy import sys from collections.abc import Mapping, MutableMapping -from typing import Any, Dict, Iterable, Literal, TypeVar, Union +from typing import Any, Dict, Iterable, Literal, Sequence, TypeVar, Union from pydantic import create_model from pydantic.main import ModelT @@ -22,6 +23,7 @@ _SQLModel, create_async_engine, delete, + func, insert, pg_insert, select, @@ -33,8 +35,86 @@ from .query.parsers import QueryParser from .query.selectors import QuerySelector -_T = TypeVar("_T", bound=_SQLModel) _Filter = _ColumnExpressionArgument[bool] | bool +_T = TypeVar("_T") + + +class _SQLModelMeta(_SQLModel): + """The base class for all SQL models""" + + id: int | None = Field(default=None, primary_key=True) + __rel_field_cache__: dict = {} + """dict of (name, Field) that have associated relationships""" + + @classmethod + def __relational_fields__(cls) -> dict[str, Any]: + """dict of (name, Field) that have associated relationships""" + + cls_fullname = f"{cls.__module__}.{cls.__qualname__}" + try: + return cls.__rel_field_cache__[cls_fullname] + except KeyError: + value = { + k: v + for k, v in cls.__mapper__.all_orm_descriptors.items() + if isinstance(v.property, RelationshipProperty) + } + cls.__rel_field_cache__[cls_fullname] = value + return value + + def model_dump( + self, + *, + mode: Union[Literal["json", "python"], str] = "python", + include: IncEx = None, + exclude: IncEx = None, + context: Union[Dict[str, Any], None] = None, + by_alias: bool = False, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + round_trip: bool = False, + warnings: Union[bool, Literal["none", "warn", "error"]] = True, + serialize_as_any: bool = False, + ) -> Dict[str, Any]: + data = super().model_dump( + mode=mode, + include=include, + exclude=exclude, + context=context, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + round_trip=round_trip, + warnings=warnings, + serialize_as_any=serialize_as_any, + ) + relations_mappers = self.__class__.__relational_fields__() + for k, field in relations_mappers.items(): + if exclude is None or k not in exclude: + try: + value = getattr(self, k, None) + except DetachedInstanceError: + # ignore lazy loaded values + continue + + if value is not None or not exclude_none: + data[k] = _serialize_embedded( + value, + field=field, + mode=mode, + context=context, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + round_trip=round_trip, + warnings=warnings, + serialize_as_any=serialize_as_any, + ) + + return data class SQLStore(BaseStore): @@ -44,7 +124,9 @@ def __init__(self, uri: str, parser: QueryParser | None = None, **kwargs): super().__init__(uri, parser=parser, **kwargs) self._engine = create_async_engine(uri, **kwargs) - async def register(self, models: list[type[_T]], checkfirst: bool = True): + async def register( + self, models: list[type[_SQLModelMeta]], checkfirst: bool = True + ): tables = [v.__table__ for v in models if hasattr(v, "__table__")] async with self._engine.begin() as conn: await conn.run_sync( @@ -52,17 +134,19 @@ async def register(self, models: list[type[_T]], checkfirst: bool = True): ) async def insert( - self, model: type[_T], items: Iterable[_T | dict], **kwargs - ) -> list[_T]: + self, + model: 
type[_SQLModelMeta], + items: Iterable[_SQLModelMeta | dict], + **kwargs, + ) -> list[_SQLModelMeta]: parsed_items = [ v if isinstance(v, model) else model.model_validate(v) for v in items ] - relations_mapper = _get_relations_mapper(model) + relations_mapper = model.__relational_fields__() async with AsyncSession(self._engine) as session: - insert_func = await _get_insert(session) - stmt = insert_func(model).returning(model) - cursor = await session.stream_scalars(stmt, parsed_items) + insert_stmt = await _get_insert_func(session, model=model) + cursor = await session.stream_scalars(insert_stmt, parsed_items) results = await cursor.all() result_ids = [v.id for v in results] @@ -76,40 +160,17 @@ async def insert( for idx, record in enumerate(items): parent = results[idx] raw_value = _get_key_or_prop(record, k) - parent_partial, embedded_value = _parse_embedded( - raw_value, field, parent - ) + embedded_value = _embed_value(parent, field, raw_value) + if isinstance(embedded_value, _SQLModel): embedded_values.append(embedded_value) elif isinstance(embedded_value, Iterable): embedded_values += embedded_value - # for many-to-one relationships, the parent - # also needs to be updated. - # add the partial update of the parent - for key, val in parent_partial.items(): - setattr(parent, key, val) - # insert the related items if len(embedded_values) > 0: field_model = field.property.mapper.class_ - - try: - # PostgreSQL and SQLite support on_conflict_do_nothing - embed_stmt = ( - insert_func(field_model) - .on_conflict_do_nothing() - .returning(field_model) - ) - except AttributeError: - # MySQL supports prefix("IGNORE") - # Other databases might fail at this point - embed_stmt = ( - insert_func(field_model) - .prefix_with("IGNORE", dialect="mysql") - .returning(field_model) - ) - + embed_stmt = await _get_insert_func(session, model=field_model) await session.stream_scalars(embed_stmt, embedded_values) # update the updated parents @@ -121,232 +182,86 @@ async def insert( async def find( self, - model: type[_T], + model: type[_SQLModelMeta], *filters: _Filter, query: QuerySelector | None = None, skip: int = 0, limit: int | None = None, sort: tuple[_ColumnExpressionOrStrLabelArgument[Any]] = (), **kwargs, - ) -> list[_T]: + ) -> list[_SQLModelMeta]: async with AsyncSession(self._engine) as session: if query: filters = (*filters, *self._parser.to_sql(model, query=query)) - - relations = _get_relations(model) - - # eagerly load all relationships so that no validation errors occur due - # to missing session if there is an attempt to load them lazily later - eager_load_opts = [subqueryload(v) for v in relations] - - filtered_relations = _get_filtered_relations( - filters=filters, - relations=relations, + return await _find( + session, model, *filters, skip=skip, limit=limit, sort=sort ) - # Note that we need to treat relations that are referenced in the filters - # differently from those that are not. This is because filtering basing on a relationship - # requires the use of an inner join. Yet an inner join automatically excludes rows - # that are have null for a given relationship. - # - # An outer join on the other hand would just return all the rows in the left table. - # We thus need to do an inner join on tables that are being filtered. 
- stmt = select(model) - for rel in filtered_relations: - stmt = stmt.join_from(model, rel) - - cursor = await session.stream_scalars( - stmt.options(*eager_load_opts) - .where(*filters) - .limit(limit) - .offset(skip) - .order_by(*sort) - ) - results = await cursor.all() - return list(results) - async def update( self, - model: type[_T], + model: type[_SQLModelMeta], *filters: _Filter, query: QuerySelector | None = None, updates: dict | None = None, **kwargs, - ) -> list[_T]: + ) -> list[_SQLModelMeta]: + updates = copy.deepcopy(updates) async with AsyncSession(self._engine) as session: if query: filters = (*filters, *self._parser.to_sql(model, query=query)) - # Construct filters that have sub queries - relations = _get_relations(model) - rel_filters, non_rel_filters = _sieve_rel_from_non_rel_filters( - filters=filters, - relations=relations, - ) - rel_filters = _to_subquery_based_filters( - model=model, - rel_filters=rel_filters, - relations=relations, + relational_filters = _get_relational_filters(model, filters) + non_relational_filters = _get_non_relational_filters(model, filters) + + # Let's update the fields that are not embedded model fields + # and return the affected results + results = await _update_non_embedded_fields( + session, + model, + *non_relational_filters, + *relational_filters, + updates=updates, ) - - # dealing with nested models in the update - relations_mapper = _get_relations_mapper(model) - embedded_updates = {} - for k in relations_mapper: - try: - embedded_updates[k] = updates.pop(k) - except KeyError: - pass - - stmt = ( - update(model) - .where(*non_rel_filters, *rel_filters) - .values(**updates) - .returning(model.__table__) - ) - - cursor = await session.stream(stmt) - raw_results = await cursor.fetchall() - results = [model.model_validate(row._mapping) for row in raw_results] result_ids = [v.id for v in results] - # insert_func = await _get_insert(session) - - for k, v in embedded_updates.items(): - field = relations_mapper[k] - field_props = field.property - field_model = field_props.mapper.class_ - # fk = foreign key - fk_field_name = field_props.primaryjoin.right.name - fk_field = getattr(field_model, fk_field_name) - parent_id_field = field_props.primaryjoin.left.name - - # get the foreign keys to use in resetting all affected - # relationships; - # get parsed embedded values so that they can replace - # the old relations. 
- # Note: this operation is strictly replace, not patch - embedded_values = [] - fk_values = [] - for parent in results: - parent_partial, embedded_value = _parse_embedded(v, field, parent) - if isinstance(embedded_value, Iterable): - embedded_values += embedded_value - fk_values.append(getattr(parent, parent_id_field)) - elif isinstance(embedded_value, _SQLModel): - embedded_values.append(embedded_value) - fk_values.append(getattr(parent, parent_id_field)) - - # insert the related items - if len(embedded_values) > 0: - # Reset the relationship; delete all other related items - # Currently, this operation replaces all past relations - reset_stmt = delete(field_model).where(fk_field.in_(fk_values)) - await session.stream(reset_stmt) - - # insert the latest changes - embed_stmt = insert(field_model).returning(field_model) - await session.stream_scalars(embed_stmt, embedded_values) + # Let's update the embedded fields also + await _update_embedded_fields( + session, model=model, records=results, updates=updates + ) await session.commit() - return await self.find(model, model.id.in_(result_ids)) + + refreshed_results = await self.find(model, model.id.in_(result_ids)) + return refreshed_results async def delete( self, - model: type[_T], + model: type[_SQLModelMeta], *filters: _Filter, query: QuerySelector | None = None, **kwargs, - ) -> list[_T]: + ) -> list[_SQLModelMeta]: async with AsyncSession(self._engine) as session: if query: filters = (*filters, *self._parser.to_sql(model, query=query)) deleted_items = await self.find(model, *filters) - # Construct filters that have sub queries - relations = _get_relations(model) - rel_filters, non_rel_filters = _sieve_rel_from_non_rel_filters( - filters=filters, - relations=relations, - ) - rel_filters = _to_subquery_based_filters( - model=model, - rel_filters=rel_filters, - relations=relations, - ) + relational_filters = _get_relational_filters(model, filters) + non_relational_filters = _get_non_relational_filters(model, filters) + exec_options = {} - if len(rel_filters) > 0: + if len(relational_filters) > 0: exec_options = {"is_delete_using": True} await session.stream( delete(model) - .where(*non_rel_filters, *rel_filters) + .where(*non_relational_filters, *relational_filters) .execution_options(**exec_options), ) await session.commit() return deleted_items -class _SQLModelMeta(_SQLModel): - """The base class for all SQL models""" - - id: int | None = Field(default=None, primary_key=True) - - def model_dump( - self, - *, - mode: Union[Literal["json", "python"], str] = "python", - include: IncEx = None, - exclude: IncEx = None, - context: Union[Dict[str, Any], None] = None, - by_alias: bool = False, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - round_trip: bool = False, - warnings: Union[bool, Literal["none", "warn", "error"]] = True, - serialize_as_any: bool = False, - ) -> Dict[str, Any]: - data = super().model_dump( - mode=mode, - include=include, - exclude=exclude, - context=context, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - round_trip=round_trip, - warnings=warnings, - serialize_as_any=serialize_as_any, - ) - relations_mappers = _get_relations_mapper(self.__class__) - for k, field in relations_mappers.items(): - if exclude is None or k not in exclude: - try: - value = getattr(self, k, None) - except DetachedInstanceError: - # ignore lazy loaded values - continue - - if value is not None or not exclude_none: - data[k] = 
_serialize_embedded(
-                    value,
-                    field=field,
-                    mode=mode,
-                    context=context,
-                    by_alias=by_alias,
-                    exclude_unset=exclude_unset,
-                    exclude_defaults=exclude_defaults,
-                    exclude_none=exclude_none,
-                    round_trip=round_trip,
-                    warnings=warnings,
-                    serialize_as_any=serialize_as_any,
-                )
-
-        return data
-
-
 def SQLModel(
     name: str,
     schema: type[ModelT],
@@ -390,38 +305,6 @@
 )


-def _get_relations(model: type[_SQLModel]):
-    """Gets all the relational fields of the given model
-
-    Args:
-        model: the SQL model to inspect
-
-    Returns:
-        list of Fields that have associated relationships
-    """
-    return [
-        v
-        for v in model.__mapper__.all_orm_descriptors.values()
-        if isinstance(v.property, RelationshipProperty)
-    ]
-
-
-def _get_relations_mapper(model: type[_SQLModel]) -> dict[str, Any]:
-    """Gets all the relational fields with their names of the given model
-
-    Args:
-        model: the SQL model to inspect
-
-    Returns:
-        dict of (name, Field) that have associated relationships
-    """
-    return {
-        k: v
-        for k, v in model.__mapper__.all_orm_descriptors.items()
-        if isinstance(v.property, RelationshipProperty)
-    }
-
-
 def _get_filtered_tables(filters: Iterable[_Filter]) -> list[Table]:
     """Retrieves the tables that have been referenced in the filters

@@ -455,31 +338,51 @@
     return [rel for rel in relations if rel.property.target in filtered_tables]


-def _sieve_rel_from_non_rel_filters(
-    filters: Iterable[_Filter], relations: Iterable[InstrumentedAttribute[Any]]
-) -> tuple[list[_Filter], list[_Filter]]:
-    """Separates relational filters from non-relational ones
+def _get_relational_filters(
+    model: type[_SQLModelMeta],
+    filters: Iterable[_Filter],
+) -> list[_Filter]:
+    """Gets the filters that are concerned with relationships on this model
+
+    The returned filters are in subquery form because 'update' and 'delete'
+    in sqlalchemy do not support joins, so the only way to attach these filters
+    to the model is through subqueries

     Args:
+        model: the model under consideration
         filters: the tuple of filters to inspect
-        relations: all relations present on the model

     Returns:
-        tuple(rel, non_rel) where rel = list of relational filters,
-        and non_rel = non-relational filters
+        list of filters that are concerned with relationships on this model
     """
-    rel_targets = [v.property.target for v in relations]
-    rel = []
-    non_rel = []
-
-    for filter_ in filters:
-        operands = filter_.get_children()
-        if any([getattr(v, "table", None) in rel_targets for v in operands]):
-            rel.append(filter_)
-        else:
-            non_rel.append(filter_)
+    relationships = list(model.__relational_fields__().values())
+    targets = [v.property.target for v in relationships]
+    plain_filters = [
+        item
+        for item in filters
+        if any([getattr(v, "table", None) in targets for v in item.get_children()])
+    ]
+    return _to_subquery_based_filters(model, plain_filters, relationships)

-    return rel, non_rel
+
+def _get_non_relational_filters(
+    model: type[_SQLModelMeta], filters: Iterable[_Filter]
+) -> list[_Filter]:
+    """Gets the filters that are NOT concerned with relationships on this model
+
+    Args:
+        model: the model under consideration
+        filters: the tuple of filters to inspect
+
+    Returns:
+        list of filters that are NOT concerned with relationships on this model
+    """
+    targets = [v.property.target for v in model.__relational_fields__().values()]
+    return [
+        item
+        for item in filters
+        if not any([getattr(v, "table", None) in targets for v in item.get_children()])
+    ]


 def _to_subquery_based_filters(
@@ -556,45 +459,113 @@
def _with_value(obj: dict | Any, field: str, value: Any) -> Any:
     return obj


-def _parse_embedded(
-    value: Iterable[dict | Any] | dict | Any, field: Any, parent: _SQLModel
-) -> tuple[dict, Iterable[_SQLModel] | _SQLModel | None]:
-    """Parses embedded items that can be a single item or many into SQLModels
+def _embed_value(
+    parent: _SQLModel,
+    relationship: Any,
+    value: Iterable[dict | Any] | dict | Any,
+) -> Iterable[_SQLModel] | _SQLModel | None:
+    """Embeds in place a given value into the parent based on the given relationship
+
+    Note that the parent itself is changed to include the value

     Args:
-        value: the value to parse
-        field: the field on which these embedded items are
-        parent: the parent SQLModel to which this value is attached
+        parent: the model that contains the given relationships
+        relationship: the given relationship
+        value: the value(s) corresponding to the related field

     Returns:
-        tuple (parent_partial, embedded_models): where parent_partial is the partial update of the parent
-        and embedded_models is an iterable of SQLModel instances or a single SQLModel instance
-        or None if value is None
+        the embedded record(s)
     """
     if value is None:
-        return {}, None
+        return None

-    props = field.property  # type: RelationshipProperty[Any]
+    props = relationship.property  # type: RelationshipProperty[Any]
     wrapper_type = props.collection_class
-    field_model = props.mapper.class_
-    fk_field = props.primaryjoin.right.name
-    parent_id_field = props.primaryjoin.left.name
-    fk_value = getattr(parent, parent_id_field)
+    relationship_model = props.mapper.class_
+    parent_foreign_key_field = props.primaryjoin.right.name
     direction = props.direction

     if direction == RelationshipDirection.MANYTOONE:
-        #
-        # add a foreign key value to link back to parent
-        return {fk_field: fk_value}, field_model.model_validate(value)
+        related_value_id_key = props.primaryjoin.left.name
+        parent_foreign_key_value = value.get(related_value_id_key)
+        # update the foreign key value in the parent
+        setattr(parent, parent_foreign_key_field, parent_foreign_key_value)
+        # create child
+        child = relationship_model.model_validate(value)
+        # update nested relationships
+        for (
+            field_name,
+            field_type,
+        ) in relationship_model.__relational_fields__().items():
+            if isinstance(value, dict):
+                nested_related_value = value.get(field_name)
+            else:
+                nested_related_value = getattr(value, field_name)
+
+            nested_related_records = _embed_value(
+                parent=child, relationship=field_type, value=nested_related_value
+            )
+            setattr(child, field_name, nested_related_records)
+
+        return child

-    if direction in (RelationshipDirection.ONETOMANY, RelationshipDirection.MANYTOMANY):
+    elif direction == RelationshipDirection.ONETOMANY:
+        related_value_id_key = props.primaryjoin.left.name
+        parent_foreign_key_value = getattr(parent, related_value_id_key)
         # add a foreign key values to link back to parent
         if issubclass(wrapper_type, (list, tuple, set)):
-            return {}, wrapper_type(
-                [
-                    field_model.model_validate(_with_value(v, fk_field, fk_value))
-                    for v in value
-                ]
-            )
+            embedded_records = []
+            for v in value:
+                child = relationship_model.model_validate(
+                    _with_value(v, parent_foreign_key_field, parent_foreign_key_value)
+                )
+
+                # update nested relationships
+                for (
+                    field_name,
+                    field_type,
+                ) in relationship_model.__relational_fields__().items():
+                    if isinstance(v, dict):
+                        nested_related_value = v.get(field_name)
+                    else:
+                        nested_related_value = getattr(v, field_name)
+
+                    nested_related_records = _embed_value(
+                        parent=child,
+                        relationship=field_type,
+                        value=nested_related_value,
+                    )
+                    setattr(child, field_name, nested_related_records)
+
+                embedded_records.append(child)
+
+            return wrapper_type(embedded_records)
+
+    elif direction == RelationshipDirection.MANYTOMANY:
+        if issubclass(wrapper_type, (list, tuple, set)):
+            embedded_records = []
+            for v in value:
+                child = relationship_model.model_validate(v)
+
+                # update nested relationships
+                for (
+                    field_name,
+                    field_type,
+                ) in relationship_model.__relational_fields__().items():
+                    if isinstance(v, dict):
+                        nested_related_value = v.get(field_name)
+                    else:
+                        nested_related_value = getattr(v, field_name)
+                    nested_related_records = _embed_value(
+                        parent=child,
+                        relationship=field_type,
+                        value=nested_related_value,
+                    )
+                    setattr(child, field_name, nested_related_records)
+
+                embedded_records.append(child)
+
+            return wrapper_type(embedded_records)

     raise NotImplementedError(
         f"relationship {direction} of type annotation {wrapper_type} not supported yet"
@@ -632,11 +603,12 @@ def _serialize_embedded(
     )


-async def _get_insert(session: AsyncSession):
+async def _get_insert_func(session: AsyncSession, model: type[_SQLModelMeta]):
     """Gets the insert statement for the given session

     Args:
         session: the async session connecting to the database
+        model: the model for which the insert statement is to be obtained

     Returns:
-        the insert function
+        the insert statement
     """
@@ -645,9 +617,345 @@
     dialect = conn.dialect
     dialect_name = dialect.name

+    native_insert_func = insert
+
     if dialect_name == "sqlite":
-        return sqlite_insert
+        native_insert_func = sqlite_insert

     if dialect_name == "postgresql":
-        return pg_insert
+        native_insert_func = pg_insert
+
+    # build an insert statement that skips conflicting rows
+    try:
+        # PostgreSQL and SQLite support on_conflict_do_nothing
+        return native_insert_func(model).on_conflict_do_nothing().returning(model)
+    except AttributeError:
+        # MySQL supports prefix("IGNORE")
+        # Other databases might fail at this point
+        return (
+            native_insert_func(model)
+            .prefix_with("IGNORE", dialect="mysql")
+            .returning(model)
+        )
+
+
+async def _update_non_embedded_fields(
+    session: AsyncSession, model: type[_SQLModelMeta], *filters: _Filter, updates: dict
+):
+    """Updates only the non-embedded fields of the model
+
+    It ignores any relationships and returns the updated results
+
+    Args:
+        session: the sqlalchemy session
+        model: the model to be updated
+        filters: the filters against which to match the records that are to be updated
+        updates: the updates to add to each matched record
+
+    Returns:
+        the updated records
+    """
+    non_embedded_updates = _get_non_relational_updates(model, updates)
+    if len(non_embedded_updates) == 0:
+        # supplying an empty values dict to update() would raise an error
+        return await _find(session, model, *filters)
+
+    stmt = update(model).where(*filters).values(**non_embedded_updates).returning(model)
+    cursor = await session.stream_scalars(stmt)
+    return await cursor.fetchall()
+
+
+async def _update_embedded_fields(
+    session: AsyncSession,
+    model: type[_SQLModelMeta],
+    records: list[_SQLModelMeta],
+    updates: dict,
+):
+    """Updates only the embedded fields of the model for the given records
+
+    It ignores any fields in the `updates` dict that are not for embedded models.
+    Note: this operation replaces the values of the embedded fields with the new values
+    passed in the `updates` dictionary, as opposed to patching the pre-existing values.
-    return insert
+
+    Args:
+        session: the sqlalchemy session
+        model: the model to be updated
+        records: the db records to update
+        updates: the updates to add to each record
+    """
+    embedded_updates = _get_relational_updates(model, updates)
+    relations_mapper = model.__relational_fields__()
+    for k, v in embedded_updates.items():
+        relationship = relations_mapper[k]
+        link_model = model.__sqlmodel_relationships__[k].link_model
+
+        # this does a replace operation; i.e. removes old values and replaces them with the updates
+        await _bulk_embedded_delete(
+            session, relationship=relationship, data=records, link_model=link_model
+        )
+        await _bulk_embedded_insert(
+            session,
+            relationship=relationship,
+            data=records,
+            link_model=link_model,
+            payload=v,
+        )
+    # FIXME: Should the added records be updated with their embedded values?
+    # update the updated parents
+    session.add_all(records)
+
+
+async def _bulk_embedded_insert(
+    session: AsyncSession,
+    relationship: Any,
+    data: list[_SQLModelMeta],
+    link_model: type[_SQLModelMeta] | None,
+    payload: Iterable[dict] | dict,
+) -> Sequence[_SQLModelMeta] | None:
+    """Inserts the payload into the data following the given relationship
+
+    It also updates the database
+
+    Args:
+        session: the database session
+        relationship: the relationship the payload has with the data's schema
+        data: the parent records into which the payload is to be embedded
+        link_model: the model for the through table
+        payload: the payload to merge into each record in the data
+
+    Returns:
+        the updated data including the embedded data in each record
+    """
+    relationship_props = relationship.property  # type: RelationshipProperty
+    relationship_model = relationship_props.mapper.class_
+
+    parsed_embedded_records = [_embed_value(v, relationship, payload) for v in data]
+
+    insert_stmt = await _get_insert_func(session, model=relationship_model)
+    embedded_cursor = await session.stream_scalars(
+        insert_stmt, _flatten_list(parsed_embedded_records)
+    )
+    embedded_db_records = await embedded_cursor.all()
+
+    # align each parent with its slice of the flat list of inserted children;
+    # the offset must advance by the number of children each parent contributed
+    parent_embedded_map = []
+    offset = 0
+    for parent, raw_embedded in zip(data, parsed_embedded_records):
+        count = len(_as_list(raw_embedded))
+        parent_embedded_map.append(
+            (parent, embedded_db_records[offset : offset + count])
+        )
+        offset += count
+
+    # insert through table values
+    await _bulk_insert_through_table_data(
+        session,
+        relationship=relationship,
+        link_model=link_model,
+        parent_embedded_map=parent_embedded_map,
+    )
+
+    return data
+
+
+async def _bulk_insert_through_table_data(
+    session: AsyncSession,
+    relationship: Any,
+    link_model: type[_SQLModelMeta] | None,
+    parent_embedded_map: list[tuple[_SQLModelMeta, list[_SQLModelMeta]]],
+):
+    """Inserts the link records into the through-table represented by the link_model
+
+    Args:
+        session: the database session
+        relationship: the relationship the embedded records are based on
+        link_model: the model for the through table
+        parent_embedded_map: the list of tuples of parent and its associated embedded db records
+    """
+    if link_model is not None:
+        relationship_props = relationship.property  # type: RelationshipProperty
+        child_id_field_name = relationship_props.secondaryjoin.left.name
+        parent_id_field_name = relationship_props.primaryjoin.left.name
+        child_fk_field_name = relationship_props.secondaryjoin.right.name
+        parent_fk_field_name = relationship_props.primaryjoin.right.name
+
+        link_raw_values = [
+            {
+                parent_fk_field_name: getattr(parent, parent_id_field_name),
+                child_fk_field_name: getattr(child, child_id_field_name),
+            }
+            for parent, children in parent_embedded_map
+            for child in children
+        ]
+
+        next_id = await _get_nextid(session, link_model)
+        link_model_values = [
+            link_model(id=next_id + idx, **v) for idx, v in enumerate(link_raw_values)
+        ]
+
+        insert_stmt = await _get_insert_func(session, model=link_model)
+        await session.stream_scalars(insert_stmt, link_model_values)
+
+
+async def _bulk_embedded_delete(
+    session: AsyncSession,
+    relationship: Any,
+    data: list[_SQLModelMeta],
+    link_model: type[_SQLModelMeta] | None,
+):
+    """Deletes the embedded records of the given parent records for the given relationship
+
+    Args:
+        session: the database session
+        relationship: the relationship whose embedded records are to be deleted for the given records
+        data: the parent records whose embedded records are to be deleted
+        link_model: the model for the through table
+    """
+    relationship_props = relationship.property  # type: RelationshipProperty
+    relationship_model = relationship_props.mapper.class_
+
+    parent_id_field_name = relationship_props.primaryjoin.left.name
+    parent_foreign_keys = [getattr(item, parent_id_field_name) for item in data]
+
+    if link_model is None:
+        reverse_foreign_key_field_name = relationship_props.primaryjoin.right.name
+        reverse_foreign_key_field = getattr(
+            relationship_model, reverse_foreign_key_field_name
+        )
+        await session.stream(
+            delete(relationship_model).where(
+                reverse_foreign_key_field.in_(parent_foreign_keys)
+            )
+        )
+    else:
+        reverse_foreign_key_field = getattr(link_model, parent_id_field_name)
+        await session.stream(
+            delete(link_model).where(reverse_foreign_key_field.in_(parent_foreign_keys))
+        )
+
+
+async def _get_nextid(session: AsyncSession, model: type[_SQLModelMeta]):
+    """Gets the next auto-incremented id for the given model
+
+    It computes it as one more than the current maximum id
+
+    Args:
+        session: the database session
+        model: the model under consideration
+
+    Returns:
+        the next auto-incremented integer ID for the given model
+    """
+    # compute the next auto-incremented id
+    next_id = await session.scalar(select(func.max(model.id)))
+    next_id = (next_id or 0) + 1
+    return next_id
+
+
+def _flatten_list(data: list[_T | list[_T]]) -> list[_T]:
+    """Flattens a list that may have lists of items at some indices
+
+    Args:
+        data: the list to flatten
+
+    Returns:
+        the flattened list
+    """
+    results = []
+    for item in data:
+        if isinstance(item, Iterable) and not isinstance(item, Mapping):
+            results += list(item)
+        else:
+            results.append(item)
+
+    return results
+
+
+def _as_list(value: Any) -> list:
+    """Wraps the value in a list if it is not already one
+
+    Args:
+        value: the value to wrap in a list if it is not one
+
+    Returns:
+        the value as a list if it is not already one
+    """
+    if isinstance(value, list):
+        return value
+    elif isinstance(value, Iterable) and not isinstance(value, Mapping):
+        return list(value)
+    return [value]
+
+
+def _get_relational_updates(model: type[_SQLModelMeta], updates: dict) -> dict:
+    """Gets the updates that affect only the relationships on this model
+
+    Args:
+        model: the model to be updated
+        updates: the dict of new values to update on the matched records
+
+    Returns:
+        a dict with only updates concerning the relationships of the given model
+    """
+    return {k: v for k, v in updates.items() if k in model.__relational_fields__()}
+
+
+def _get_non_relational_updates(model: type[_SQLModelMeta], updates: dict) -> dict:
+    """Gets the updates that do not affect relationships on this model
+
+    Args:
+        model: the model to be updated
+        updates: the dict of new values to update on the matched records
+
+    Returns:
+        a dict with only updates that do not affect relationships on this model
+    """
+    return {k: v for k, v in updates.items() if k not in model.__relational_fields__()}
+
+
+async def _find(
+    session: AsyncSession,
+    model: type[_SQLModelMeta],
+    /,
+    *filters: _Filter,
+    skip: int = 0,
+    limit: int | None = None,
+    sort: tuple[_ColumnExpressionOrStrLabelArgument[Any]] = (),
+) -> list[_SQLModelMeta]:
+    """Finds the records that match the given filters
+
+    Args:
+        session: the sqlalchemy session
+        model: the model that is to be searched
+        filters: the filters to match
+        skip: number of records to ignore at the top of the returned results; default is 0
+        limit: maximum number of records to return; default is None.
+        sort: fields to sort by; default = ()
+
+    Returns:
+        the records that match the given filters
+    """
+    relations = list(model.__relational_fields__().values())
+
+    # eagerly load all relationships so that no validation errors occur due
+    # to missing session if there is an attempt to load them lazily later
+    eager_load_opts = [subqueryload(v) for v in relations]
+
+    filtered_relations = _get_filtered_relations(
+        filters=filters,
+        relations=relations,
+    )
+
+    # Note that we need to treat relations that are referenced in the filters
+    # differently from those that are not. This is because filtering based on a relationship
+    # requires the use of an inner join. Yet an inner join automatically excludes rows
+    # that have null for a given relationship.
+    #
+    # An outer join on the other hand would just return all the rows in the left table.
+    # We thus need to do an inner join on tables that are being filtered.
+    stmt = select(model)
+    for rel in filtered_relations:
+        stmt = stmt.join_from(model, rel)
+
+    cursor = await session.stream_scalars(
+        stmt.options(*eager_load_opts)
+        .where(*filters)
+        .limit(limit)
+        .offset(skip)
+        .order_by(*sort)
+    )
+    results = await cursor.all()
+    return list(results)
diff --git a/pyproject.toml b/pyproject.toml
index 582237b..63fa897 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,13 +33,13 @@ sql = [
     "greenlet~=3.1.1",
 ]
 mongo = ["beanie~=1.29.0"]
-redis = ["redis-om~=0.3.3"]
+redis = ["redis-om~=0.3.3,<0.3.4"]
 all = [
     "sqlmodel~=0.0.22",
     "aiosqlite~=0.20.0",
    "greenlet~=3.1.1",
     "beanie~=1.29.0",
-    "redis-om~=0.3.3",
+    "redis-om~=0.3.3,<0.3.4",
 ]

 [project.urls]
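A note on the tightened `redis-om` pin above: `~=0.3.3` alone means ">=0.3.3, <0.4", so the extra `<0.3.4` effectively pins the dependency to the 0.3.3 release. A quick way to sanity-check the combined specifier, using the third-party `packaging` library (an assumption here, not a dependency of this diff):

```python
from packaging.specifiers import SpecifierSet

# the combined constraint from pyproject.toml
spec = SpecifierSet("~=0.3.3,<0.3.4")

assert "0.3.3" in spec
assert "0.3.4" not in spec  # would have been allowed by ~=0.3.3 alone
assert "0.4.0" not in spec
```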