Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions src/fastapi_toolsets/crud/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from .factory import CrudFactory
from .search import (
SearchConfig,
SearchFieldType,
get_searchable_fields,
)

Expand All @@ -13,5 +12,4 @@
"get_searchable_fields",
"NoSearchableFieldsError",
"SearchConfig",
"SearchFieldType",
]
112 changes: 106 additions & 6 deletions src/fastapi_toolsets/crud/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .search import SearchConfig, SearchFieldType, build_search_filters

ModelType = TypeVar("ModelType", bound=DeclarativeBase)
JoinType = list[tuple[type[DeclarativeBase], Any]]


class AsyncCrud(Generic[ModelType]):
Expand Down Expand Up @@ -55,6 +56,8 @@ async def get(
session: AsyncSession,
filters: list[Any],
*,
joins: JoinType | None = None,
outer_join: bool = False,
with_for_update: bool = False,
load_options: list[Any] | None = None,
) -> ModelType:
Expand All @@ -63,6 +66,8 @@ async def get(
Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
with_for_update: Lock the row for update
load_options: SQLAlchemy loader options (e.g., selectinload)

Expand All @@ -73,7 +78,15 @@ async def get(
NotFoundError: If no record found
MultipleResultsFound: If more than one record found
"""
q = select(cls.model).where(and_(*filters))
q = select(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
q = q.where(and_(*filters))
if load_options:
q = q.options(*load_options)
if with_for_update:
Expand All @@ -90,19 +103,30 @@ async def first(
session: AsyncSession,
filters: list[Any] | None = None,
*,
joins: JoinType | None = None,
outer_join: bool = False,
load_options: list[Any] | None = None,
) -> ModelType | None:
"""Get the first matching record, or None.

Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
load_options: SQLAlchemy loader options

Returns:
Model instance or None
"""
q = select(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
if filters:
q = q.where(and_(*filters))
if load_options:
Expand All @@ -116,6 +140,8 @@ async def get_multi(
session: AsyncSession,
*,
filters: list[Any] | None = None,
joins: JoinType | None = None,
outer_join: bool = False,
load_options: list[Any] | None = None,
order_by: Any | None = None,
limit: int | None = None,
Expand All @@ -126,6 +152,8 @@ async def get_multi(
Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
load_options: SQLAlchemy loader options
order_by: Column or list of columns to order by
limit: Max number of rows to return
Expand All @@ -135,6 +163,13 @@ async def get_multi(
List of model instances
"""
q = select(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
if filters:
q = q.where(and_(*filters))
if load_options:
Expand Down Expand Up @@ -254,17 +289,29 @@ async def count(
cls: type[Self],
session: AsyncSession,
filters: list[Any] | None = None,
*,
joins: JoinType | None = None,
outer_join: bool = False,
) -> int:
"""Count records matching the filters.

Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN

Returns:
Number of matching records
"""
q = select(func.count()).select_from(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
if filters:
q = q.where(and_(*filters))
result = await session.execute(q)
Expand All @@ -275,17 +322,30 @@ async def exists(
cls: type[Self],
session: AsyncSession,
filters: list[Any],
*,
joins: JoinType | None = None,
outer_join: bool = False,
) -> bool:
"""Check if a record exists.

Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN

Returns:
True if at least one record matches
"""
q = select(cls.model).where(and_(*filters)).exists().select()
q = select(cls.model)
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)
q = q.where(and_(*filters)).exists().select()
result = await session.execute(q)
return bool(result.scalar())

Expand All @@ -295,6 +355,8 @@ async def paginate(
session: AsyncSession,
*,
filters: list[Any] | None = None,
joins: JoinType | None = None,
outer_join: bool = False,
load_options: list[Any] | None = None,
order_by: Any | None = None,
page: int = 1,
Expand All @@ -307,6 +369,8 @@ async def paginate(
Args:
session: DB async session
filters: List of SQLAlchemy filter conditions
joins: List of (model, condition) tuples for joining related tables
outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
load_options: SQLAlchemy loader options
order_by: Column or list of columns to order by
page: Page number (1-indexed)
Expand All @@ -319,7 +383,7 @@ async def paginate(
"""
filters = list(filters) if filters else []
offset = (page - 1) * items_per_page
joins: list[Any] = []
search_joins: list[Any] = []

# Build search filters
if search:
Expand All @@ -330,11 +394,21 @@ async def paginate(
default_fields=cls.searchable_fields,
)
filters.extend(search_filters)
joins.extend(search_joins)

# Build query with joins
q = select(cls.model)
for join_rel in joins:

# Apply explicit joins
if joins:
for model, condition in joins:
q = (
q.outerjoin(model, condition)
if outer_join
else q.join(model, condition)
)

# Apply search joins (always outer joins for search)
for join_rel in search_joins:
q = q.outerjoin(join_rel)

if filters:
Expand All @@ -352,8 +426,20 @@ async def paginate(
pk_col = cls.model.__mapper__.primary_key[0]
count_q = select(func.count(func.distinct(getattr(cls.model, pk_col.name))))
count_q = count_q.select_from(cls.model)
for join_rel in joins:

# Apply explicit joins to count query
if joins:
for model, condition in joins:
count_q = (
count_q.outerjoin(model, condition)
if outer_join
else count_q.join(model, condition)
)

# Apply search joins to count query
for join_rel in search_joins:
count_q = count_q.outerjoin(join_rel)

if filters:
count_q = count_q.where(and_(*filters))

Expand Down Expand Up @@ -404,6 +490,20 @@ def CrudFactory(

# With search
result = await UserCrud.paginate(session, search="john")

# With joins (inner join by default):
users = await UserCrud.get_multi(
session,
joins=[(Post, Post.user_id == User.id)],
filters=[Post.published == True],
)

# With outer join:
users = await UserCrud.get_multi(
session,
joins=[(Post, Post.user_id == User.id)],
outer_join=True,
)
"""
cls = type(
f"Async{model.__name__}Crud",
Expand Down
Loading