From f22ecd0202020db0b11d985628d43b106a7df7ed Mon Sep 17 00:00:00 2001 From: d3vyce Date: Sun, 1 Feb 2026 11:55:54 +0100 Subject: [PATCH] feat: add join to crud functions --- src/fastapi_toolsets/crud/__init__.py | 2 - src/fastapi_toolsets/crud/factory.py | 112 ++++++++++- tests/test_crud.py | 271 ++++++++++++++++++++++++++ 3 files changed, 377 insertions(+), 8 deletions(-) diff --git a/src/fastapi_toolsets/crud/__init__.py b/src/fastapi_toolsets/crud/__init__.py index d95aaf2..093763b 100644 --- a/src/fastapi_toolsets/crud/__init__.py +++ b/src/fastapi_toolsets/crud/__init__.py @@ -4,7 +4,6 @@ from .factory import CrudFactory from .search import ( SearchConfig, - SearchFieldType, get_searchable_fields, ) @@ -13,5 +12,4 @@ "get_searchable_fields", "NoSearchableFieldsError", "SearchConfig", - "SearchFieldType", ] diff --git a/src/fastapi_toolsets/crud/factory.py b/src/fastapi_toolsets/crud/factory.py index a7f5fdc..565178c 100644 --- a/src/fastapi_toolsets/crud/factory.py +++ b/src/fastapi_toolsets/crud/factory.py @@ -17,6 +17,7 @@ from .search import SearchConfig, SearchFieldType, build_search_filters ModelType = TypeVar("ModelType", bound=DeclarativeBase) +JoinType = list[tuple[type[DeclarativeBase], Any]] class AsyncCrud(Generic[ModelType]): @@ -55,6 +56,8 @@ async def get( session: AsyncSession, filters: list[Any], *, + joins: JoinType | None = None, + outer_join: bool = False, with_for_update: bool = False, load_options: list[Any] | None = None, ) -> ModelType: @@ -63,6 +66,8 @@ async def get( Args: session: DB async session filters: List of SQLAlchemy filter conditions + joins: List of (model, condition) tuples for joining related tables + outer_join: Use LEFT OUTER JOIN instead of INNER JOIN with_for_update: Lock the row for update load_options: SQLAlchemy loader options (e.g., selectinload) @@ -73,7 +78,15 @@ async def get( NotFoundError: If no record found MultipleResultsFound: If more than one record found """ - q = select(cls.model).where(and_(*filters)) + q = select(cls.model) + if joins: + for model, condition in joins: + q = ( + q.outerjoin(model, condition) + if outer_join + else q.join(model, condition) + ) + q = q.where(and_(*filters)) if load_options: q = q.options(*load_options) if with_for_update: @@ -90,6 +103,8 @@ async def first( session: AsyncSession, filters: list[Any] | None = None, *, + joins: JoinType | None = None, + outer_join: bool = False, load_options: list[Any] | None = None, ) -> ModelType | None: """Get the first matching record, or None. @@ -97,12 +112,21 @@ async def first( Args: session: DB async session filters: List of SQLAlchemy filter conditions + joins: List of (model, condition) tuples for joining related tables + outer_join: Use LEFT OUTER JOIN instead of INNER JOIN load_options: SQLAlchemy loader options Returns: Model instance or None """ q = select(cls.model) + if joins: + for model, condition in joins: + q = ( + q.outerjoin(model, condition) + if outer_join + else q.join(model, condition) + ) if filters: q = q.where(and_(*filters)) if load_options: @@ -116,6 +140,8 @@ async def get_multi( session: AsyncSession, *, filters: list[Any] | None = None, + joins: JoinType | None = None, + outer_join: bool = False, load_options: list[Any] | None = None, order_by: Any | None = None, limit: int | None = None, @@ -126,6 +152,8 @@ async def get_multi( Args: session: DB async session filters: List of SQLAlchemy filter conditions + joins: List of (model, condition) tuples for joining related tables + outer_join: Use LEFT OUTER JOIN instead of INNER JOIN load_options: SQLAlchemy loader options order_by: Column or list of columns to order by limit: Max number of rows to return @@ -135,6 +163,13 @@ async def get_multi( List of model instances """ q = select(cls.model) + if joins: + for model, condition in joins: + q = ( + q.outerjoin(model, condition) + if outer_join + else q.join(model, condition) + ) if filters: q = q.where(and_(*filters)) if load_options: @@ -254,17 +289,29 @@ async def count( cls: type[Self], session: AsyncSession, filters: list[Any] | None = None, + *, + joins: JoinType | None = None, + outer_join: bool = False, ) -> int: """Count records matching the filters. Args: session: DB async session filters: List of SQLAlchemy filter conditions + joins: List of (model, condition) tuples for joining related tables + outer_join: Use LEFT OUTER JOIN instead of INNER JOIN Returns: Number of matching records """ q = select(func.count()).select_from(cls.model) + if joins: + for model, condition in joins: + q = ( + q.outerjoin(model, condition) + if outer_join + else q.join(model, condition) + ) if filters: q = q.where(and_(*filters)) result = await session.execute(q) @@ -275,17 +322,30 @@ async def exists( cls: type[Self], session: AsyncSession, filters: list[Any], + *, + joins: JoinType | None = None, + outer_join: bool = False, ) -> bool: """Check if a record exists. Args: session: DB async session filters: List of SQLAlchemy filter conditions + joins: List of (model, condition) tuples for joining related tables + outer_join: Use LEFT OUTER JOIN instead of INNER JOIN Returns: True if at least one record matches """ - q = select(cls.model).where(and_(*filters)).exists().select() + q = select(cls.model) + if joins: + for model, condition in joins: + q = ( + q.outerjoin(model, condition) + if outer_join + else q.join(model, condition) + ) + q = q.where(and_(*filters)).exists().select() result = await session.execute(q) return bool(result.scalar()) @@ -295,6 +355,8 @@ async def paginate( session: AsyncSession, *, filters: list[Any] | None = None, + joins: JoinType | None = None, + outer_join: bool = False, load_options: list[Any] | None = None, order_by: Any | None = None, page: int = 1, @@ -307,6 +369,8 @@ async def paginate( Args: session: DB async session filters: List of SQLAlchemy filter conditions + joins: List of (model, condition) tuples for joining related tables + outer_join: Use LEFT OUTER JOIN instead of INNER JOIN load_options: SQLAlchemy loader options order_by: Column or list of columns to order by page: Page number (1-indexed) @@ -319,7 +383,7 @@ async def paginate( """ filters = list(filters) if filters else [] offset = (page - 1) * items_per_page - joins: list[Any] = [] + search_joins: list[Any] = [] # Build search filters if search: @@ -330,11 +394,21 @@ async def paginate( default_fields=cls.searchable_fields, ) filters.extend(search_filters) - joins.extend(search_joins) # Build query with joins q = select(cls.model) - for join_rel in joins: + + # Apply explicit joins + if joins: + for model, condition in joins: + q = ( + q.outerjoin(model, condition) + if outer_join + else q.join(model, condition) + ) + + # Apply search joins (always outer joins for search) + for join_rel in search_joins: q = q.outerjoin(join_rel) if filters: @@ -352,8 +426,20 @@ async def paginate( pk_col = cls.model.__mapper__.primary_key[0] count_q = select(func.count(func.distinct(getattr(cls.model, pk_col.name)))) count_q = count_q.select_from(cls.model) - for join_rel in joins: + + # Apply explicit joins to count query + if joins: + for model, condition in joins: + count_q = ( + count_q.outerjoin(model, condition) + if outer_join + else count_q.join(model, condition) + ) + + # Apply search joins to count query + for join_rel in search_joins: count_q = count_q.outerjoin(join_rel) + if filters: count_q = count_q.where(and_(*filters)) @@ -404,6 +490,20 @@ def CrudFactory( # With search result = await UserCrud.paginate(session, search="john") + + # With joins (inner join by default): + users = await UserCrud.get_multi( + session, + joins=[(Post, Post.user_id == User.id)], + filters=[Post.published == True], + ) + + # With outer join: + users = await UserCrud.get_multi( + session, + joins=[(Post, Post.user_id == User.id)], + outer_join=True, + ) """ cls = type( f"Async{model.__name__}Crud", diff --git a/tests/test_crud.py b/tests/test_crud.py index 8c41b9c..043f01b 100644 --- a/tests/test_crud.py +++ b/tests/test_crud.py @@ -10,6 +10,9 @@ from fastapi_toolsets.exceptions import NotFoundError from .conftest import ( + Post, + PostCreate, + PostCrud, Role, RoleCreate, RoleCrud, @@ -481,3 +484,271 @@ async def test_paginate_with_ordering(self, db_session: AsyncSession): names = [r.name for r in result["data"]] assert names == ["alpha", "bravo", "charlie"] + + +class TestCrudJoins: + """Tests for CRUD operations with joins.""" + + @pytest.mark.anyio + async def test_get_with_join(self, db_session: AsyncSession): + """Get with inner join filters correctly.""" + # Create user with posts + user = await UserCrud.create( + db_session, + UserCreate(username="author", email="author@test.com"), + ) + await PostCrud.create( + db_session, + PostCreate(title="Post 1", author_id=user.id, is_published=True), + ) + + # Get user with join on published posts + fetched = await UserCrud.get( + db_session, + filters=[User.id == user.id, Post.is_published == True], # noqa: E712 + joins=[(Post, Post.author_id == User.id)], + ) + assert fetched.id == user.id + + @pytest.mark.anyio + async def test_first_with_join(self, db_session: AsyncSession): + """First with join returns matching record.""" + user = await UserCrud.create( + db_session, + UserCreate(username="writer", email="writer@test.com"), + ) + await PostCrud.create( + db_session, + PostCreate(title="Draft", author_id=user.id, is_published=False), + ) + + # Find user with unpublished posts + result = await UserCrud.first( + db_session, + filters=[Post.is_published == False], # noqa: E712 + joins=[(Post, Post.author_id == User.id)], + ) + assert result is not None + assert result.id == user.id + + @pytest.mark.anyio + async def test_first_with_outer_join(self, db_session: AsyncSession): + """First with outer join includes records without related data.""" + # User without posts + user = await UserCrud.create( + db_session, + UserCreate(username="no_posts", email="no_posts@test.com"), + ) + + # With outer join, user should be found even without posts + result = await UserCrud.first( + db_session, + filters=[User.id == user.id], + joins=[(Post, Post.author_id == User.id)], + outer_join=True, + ) + assert result is not None + assert result.id == user.id + + @pytest.mark.anyio + async def test_get_multi_with_inner_join(self, db_session: AsyncSession): + """Get multiple with inner join only returns matching records.""" + # User with published post + user1 = await UserCrud.create( + db_session, + UserCreate(username="publisher", email="pub@test.com"), + ) + await PostCrud.create( + db_session, + PostCreate(title="Published", author_id=user1.id, is_published=True), + ) + + # User without posts + await UserCrud.create( + db_session, + UserCreate(username="lurker", email="lurk@test.com"), + ) + + # Inner join should only return user with published post + users = await UserCrud.get_multi( + db_session, + joins=[(Post, Post.author_id == User.id)], + filters=[Post.is_published == True], # noqa: E712 + ) + assert len(users) == 1 + assert users[0].username == "publisher" + + @pytest.mark.anyio + async def test_get_multi_with_outer_join(self, db_session: AsyncSession): + """Get multiple with outer join includes all records.""" + # User with post + user1 = await UserCrud.create( + db_session, + UserCreate(username="has_post", email="has@test.com"), + ) + await PostCrud.create( + db_session, + PostCreate(title="My Post", author_id=user1.id), + ) + + # User without posts + await UserCrud.create( + db_session, + UserCreate(username="no_post", email="no@test.com"), + ) + + # Outer join should return both users + users = await UserCrud.get_multi( + db_session, + joins=[(Post, Post.author_id == User.id)], + outer_join=True, + ) + assert len(users) == 2 + + @pytest.mark.anyio + async def test_count_with_join(self, db_session: AsyncSession): + """Count with join counts correctly.""" + # Create users with different post statuses + user1 = await UserCrud.create( + db_session, + UserCreate(username="active_author", email="active@test.com"), + ) + await PostCrud.create( + db_session, + PostCreate(title="Published 1", author_id=user1.id, is_published=True), + ) + + user2 = await UserCrud.create( + db_session, + UserCreate(username="draft_author", email="draft@test.com"), + ) + await PostCrud.create( + db_session, + PostCreate(title="Draft 1", author_id=user2.id, is_published=False), + ) + + # Count users with published posts + count = await UserCrud.count( + db_session, + filters=[Post.is_published == True], # noqa: E712 + joins=[(Post, Post.author_id == User.id)], + ) + assert count == 1 + + @pytest.mark.anyio + async def test_exists_with_join(self, db_session: AsyncSession): + """Exists with join checks correctly.""" + user = await UserCrud.create( + db_session, + UserCreate(username="poster", email="poster@test.com"), + ) + await PostCrud.create( + db_session, + PostCreate(title="Exists Post", author_id=user.id, is_published=True), + ) + + # Check if user with published post exists + exists = await UserCrud.exists( + db_session, + filters=[Post.is_published == True], # noqa: E712 + joins=[(Post, Post.author_id == User.id)], + ) + assert exists is True + + # Check if user with specific title exists + exists = await UserCrud.exists( + db_session, + filters=[Post.title == "Nonexistent"], + joins=[(Post, Post.author_id == User.id)], + ) + assert exists is False + + @pytest.mark.anyio + async def test_paginate_with_join(self, db_session: AsyncSession): + """Paginate with join works correctly.""" + # Create users with posts + for i in range(5): + user = await UserCrud.create( + db_session, + UserCreate(username=f"author{i}", email=f"author{i}@test.com"), + ) + await PostCrud.create( + db_session, + PostCreate( + title=f"Post {i}", + author_id=user.id, + is_published=i % 2 == 0, + ), + ) + + # Paginate users with published posts + result = await UserCrud.paginate( + db_session, + joins=[(Post, Post.author_id == User.id)], + filters=[Post.is_published == True], # noqa: E712 + page=1, + items_per_page=10, + ) + + assert result["pagination"]["total_count"] == 3 + assert len(result["data"]) == 3 + + @pytest.mark.anyio + async def test_paginate_with_outer_join(self, db_session: AsyncSession): + """Paginate with outer join includes all records.""" + # User with post + user1 = await UserCrud.create( + db_session, + UserCreate(username="with_post", email="with@test.com"), + ) + await PostCrud.create( + db_session, + PostCreate(title="A Post", author_id=user1.id), + ) + + # User without post + await UserCrud.create( + db_session, + UserCreate(username="without_post", email="without@test.com"), + ) + + # Paginate with outer join + result = await UserCrud.paginate( + db_session, + joins=[(Post, Post.author_id == User.id)], + outer_join=True, + page=1, + items_per_page=10, + ) + + assert result["pagination"]["total_count"] == 2 + assert len(result["data"]) == 2 + + @pytest.mark.anyio + async def test_multiple_joins(self, db_session: AsyncSession): + """Multiple joins can be applied.""" + role = await RoleCrud.create(db_session, RoleCreate(name="author_role")) + user = await UserCrud.create( + db_session, + UserCreate( + username="multi_join", + email="multi@test.com", + role_id=role.id, + ), + ) + await PostCrud.create( + db_session, + PostCreate(title="Multi Join Post", author_id=user.id, is_published=True), + ) + + # Join both Role and Post + users = await UserCrud.get_multi( + db_session, + joins=[ + (Role, Role.id == User.role_id), + (Post, Post.author_id == User.id), + ], + filters=[Role.name == "author_role", Post.is_published == True], # noqa: E712 + ) + assert len(users) == 1 + assert users[0].username == "multi_join"