From d34b334f965efcfc23840d9342ee9506fc56cf4d Mon Sep 17 00:00:00 2001 From: Jonathan Nye Date: Fri, 6 Mar 2020 21:41:26 +0100 Subject: [PATCH 1/8] Added the base_route and crud_base Both of these can be used to start a CRUD API adds filtering and sorting to the read_many endpoint --- fastapi_utils/base_route.py | 108 ++++++++++++++++++++++++ fastapi_utils/crud_base.py | 159 ++++++++++++++++++++++++++++++++++++ poetry.lock | 31 +++++-- pyproject.toml | 1 + 4 files changed, 292 insertions(+), 7 deletions(-) create mode 100644 fastapi_utils/base_route.py create mode 100644 fastapi_utils/crud_base.py diff --git a/fastapi_utils/base_route.py b/fastapi_utils/base_route.py new file mode 100644 index 00000000..2b19ba27 --- /dev/null +++ b/fastapi_utils/base_route.py @@ -0,0 +1,108 @@ +from typing import Dict, Generic, List, TypeVar + +from fastapi_utils.crud_base import CRUDBase, Base +from fastapi import Depends, HTTPException +from pydantic import BaseModel +from sqlalchemy.orm import Session + +ResponseModelType = TypeVar("ResponseModelType", bound=BaseModel) +ResponseModelManyType = TypeVar("ResponseModelManyType", bound=BaseModel) +CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) +UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) +IDType = TypeVar("IDType") + + +def get_filter_fields(self) -> List[str]: + """This would need to get overridden for each BaseRoute where the filter fields are defined. + + Returns: + List[str] -- List of fields to filter by + """ + return [] + + +class BaseRoute(Generic[ResponseModelType, ResponseModelManyType, CreateSchemaType, UpdateSchemaType, IDType]): + """A base route that has the basic CRUD endpoints. + + For read_many + + """ + + filter_fields: List[str] = Depends(get_filter_fields) + crud_base = CRUDBase(Base) # type: ignore + db: Session = Depends(None) + object_name = "Base" + + def read_many(self, skip: int = 0, limit: int = 100, sort_by: str = None, **kwargs) -> ResponseModelManyType: + """Reads many from the database with the provided filter and sort parameters. + + Filter parameters need to be specified by overriding this read_many method and calling it like: + + @router.get("/", response_model=List[Person]) + def read_persons( + self, skip: int = 0, limit: int = 100, sort_by: str = None, name: str = None, + ) -> List[Person]: + return super().read_many(skip=skip, limit=limit, sort_by=sort_by, name=name) + + Where the filter fields are defined as parameters. In this case "name" is a filter field + + Keyword Arguments: + skip {int} -- [description] (default: {0}) + limit {int} -- [description] (default: {100}) + sort_by {str} -- Expected in the form "model__field_name:asc,field_name:desc" (default: {None}) + + **kwargs {str} -- Filter field names expected in the form field_name or model__field_name if filtering through + a join. The filter is defined as op:value. For example ==:paul or eq:paul + + The filter op is specified in the crud_base FilterOpEnum. + + Returns: + ResponseModelManyType -- [description] + """ + filter_fields: Dict[str, str] = {} + for field in self.filter_fields: + filter_fields[field] = kwargs.pop(field, None) + results = self.crud_base.get_multi(self.db, skip=skip, limit=limit, filter_by=filter_fields, sort_by=sort_by) + return results + + def create(self, *, obj_in: CreateSchemaType,) -> ResponseModelType: + """ + Create new object. + """ + result = self.crud_base.create(db_session=self.db, obj_in=obj_in) + return result + + def update(self, *, id: IDType, obj_in: UpdateSchemaType,) -> ResponseModelType: + """ + Update an object. + """ + result = self.crud_base.get(db_session=self.db, id=id) + if not result: + raise HTTPException(status_code=404, detail=f"{self.object_name} not found") + # if not crud.user.is_superuser(current_user) and (object.owner_id != current_user.id): + # raise HTTPException(status_code=400, detail="Not enough permissions") + result = self.crud_base.update(db_session=self.db, db_obj=result, obj_in=obj_in) + return result + + def read(self, *, id: IDType,) -> ResponseModelType: + """ + Get object by ID. + """ + result = self.crud_base.get(db_session=self.db, id=id) + if not result: + raise HTTPException(status_code=404, detail=f"{self.object_name} not found") + # if not crud.user.is_superuser(current_user) and (object.owner_id != current_user.id): + # raise HTTPException(status_code=400, detail="Not enough permissions") + return result + + def delete(self, *, id: IDType,) -> ResponseModelType: + """ + Delete an object. + """ + result = self.crud_base.get(db_session=self.db, id=id) + if not result: + raise HTTPException(status_code=404, detail=f"{self.object_name} not found") + # if not crud.user.is_superuser(current_user) and (object.owner_id != current_user.id): + # raise HTTPException(status_code=400, detail="Not enough permissions") + result = self.crud_base.remove(db_session=self.db, id=id) + return result diff --git a/fastapi_utils/crud_base.py b/fastapi_utils/crud_base.py new file mode 100644 index 00000000..8bc67568 --- /dev/null +++ b/fastapi_utils/crud_base.py @@ -0,0 +1,159 @@ +from decimal import Decimal +from enum import Enum +from typing import Dict, Generic, List, Optional, Type, TypeVar, Union + +from sqlalchemy.ext.declarative import declarative_base +from fastapi.encoders import jsonable_encoder +from fastapi_utils.camelcase import snake2camel +from pydantic import BaseModel +from sqlalchemy.orm import Session +from sqlalchemy_filters import apply_filters, apply_sort + +Base = declarative_base() + +ModelType = TypeVar("ModelType", bound=Base) +MultiSchemaType = TypeVar("MultiSchemaType", bound=BaseModel) +CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) +UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) +IDType = TypeVar("IDType") + + +class SortDirectionEnum(str, Enum): + ASC = "asc" + DESC = "desc" + + +class FilterOpEnum(str, Enum): + IS_NULL = "is_null" + IS_NOT_NULL = "is_not_null" + EQ_SYM = "==" + EQ = "eq" + NE_SYM = "!=" + NE = "ne" + GT_SYM = ">" + GT = "gt" + LT_SYM = "<" + LT = "lt" + GE_SYM = ">=" + GE = "ge" + LE_SYM = "<=" + LE = "le" + LIKE = "like" + ILIKE = "ilike" + IN = "in" + NOT_IN = "not_in" + + +class SortField(BaseModel): + field: str + model: Optional[str] = None + direction: SortDirectionEnum = SortDirectionEnum.DESC + + +class FilterField(BaseModel): + field: str + model: Optional[str] = None + op: FilterOpEnum + value: Union[str, int, Decimal] + + +def get_filter_field(field: str, field_name: str, split_character: str = ":") -> FilterField: + model = None + op, value = field.split(":") + if "__" in field_name: + model, field_name = field_name.split("__") + model = snake2camel(model, start_lower=False) + filter_field = FilterField(field=field_name, model=model, op=op, value=value) + return filter_field + + +def get_filter_fields(fields: Optional[Dict[str, str]], split_character: str = ":") -> List[FilterField]: + filter_fields = [] + if fields: + for field_name in fields: + if fields[field_name]: + filter_fields.append(get_filter_field(field=fields[field_name], field_name=field_name)) + return filter_fields + + +def get_sort_field(field: str) -> SortField: + model = None + field_name, direction = field.split(":") + if "__" in field_name: + model, field_name = field_name.split("__") + sort_field = SortField(model=model, field=field_name, direction=direction) + return sort_field + + +def get_sort_fields(sort_string: str, split_character: str = ",") -> List[SortField]: + sort_fields = [] + # There could be many sort fields + if sort_string: + sort_by_fields = sort_string.split(",") + for _to_sort in sort_by_fields: + sort_fields.append(get_sort_field(_to_sort)) + return sort_fields + + +class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): + def __init__(self, model: Type[ModelType]): + """ + CRUD object with default methods to Create, Read, Update, Delete (CRUD). + **Parameters** + * `model`: A SQLAlchemy model class + * `schema`: A Pydantic model (schema) class + """ + self.model = model + + def get(self, db_session: Session, id: IDType) -> Optional[ModelType]: + return db_session.query(self.model).filter(self.model.id == id).first() # type: ignore + + def get_multi( + self, + db_session: Session, + *, + skip: int = 0, + limit: int = 100, + sort_by: Optional[str] = None, + filter_by: Optional[Dict[str, str]] = None, + ) -> List[ModelType]: + + sort_spec_pydantic = get_sort_fields(sort_by) + filter_spec_pydantic = get_filter_fields(filter_by) + + sort_spec = [x.dict(exclude_none=True) for x in sort_spec_pydantic] + filter_spec = [x.dict(exclude_none=True) for x in filter_spec_pydantic] + + query = db_session.query(self.model) + query = apply_filters(query, filter_spec) + query = apply_sort(query, sort_spec) + + count = query.count() + query = query.offset(skip).limit(limit) + + return query.all() + + def create(self, db_session: Session, *, obj_in: CreateSchemaType) -> ModelType: + obj_in_data = jsonable_encoder(obj_in) + db_obj = self.model(**obj_in_data) + db_session.add(db_obj) + db_session.commit() + db_session.refresh(db_obj) + return db_obj + + def update(self, db_session: Session, *, db_obj: ModelType, obj_in: UpdateSchemaType) -> ModelType: + obj_data = jsonable_encoder(db_obj) + update_data = obj_in.dict(skip_defaults=True) + for field in obj_data: + if field in update_data: + setattr(db_obj, field, update_data[field]) + db_session.add(db_obj) + db_session.commit() + db_session.refresh(db_obj) + return db_obj + + def remove(self, db_session: Session, *, id: IDType) -> ModelType: + obj = db_session.query(self.model).get(id) + db_session.delete(obj) + db_session.commit() + return obj diff --git a/poetry.lock b/poetry.lock index b76d9b43..2c994eb4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -603,7 +603,7 @@ security = ["pyOpenSSL (>=0.14)", "cryptography (>=1.3.4)"] socks = ["PySocks (>=1.5.6,<1.5.7 || >1.5.7)", "win-inet-pton"] [[package]] -category = "dev" +category = "main" description = "Python 2 and 3 compatibility utilities" name = "six" optional = false @@ -630,6 +630,24 @@ postgresql_psycopg2binary = ["psycopg2-binary"] postgresql_psycopg2cffi = ["psycopg2cffi"] pymysql = ["pymysql"] +[[package]] +category = "main" +description = "A library to filter SQLAlchemy queries." +name = "sqlalchemy-filters" +optional = false +python-versions = "*" +version = "0.10.0" + +[package.dependencies] +six = ">=1.10.0" +sqlalchemy = ">=1.0.16" + +[package.extras] +dev = ["pytest (4.3.0)", "flake8 (3.7.7)", "coverage (4.5.3)", "sqlalchemy-utils (0.33.11)", "restructuredtext-lint (1.2.2)", "Pygments (2.3.1)"] +mysql = ["mysql-connector-python-rf (2.2.2)"] +postgresql = ["psycopg2 (2.7.7)"] +python2 = ["funcsigs (>=1.0.2)"] + [[package]] category = "dev" description = "SQLAlchemy stubs and mypy plugin" @@ -720,7 +738,7 @@ docs = ["sphinx", "jaraco.packaging (>=3.2)", "rst.linker (>=1.9)"] testing = ["jaraco.itertools", "func-timeout"] [metadata] -content-hash = "2b727851846408766afb773a03cef7946f566c0817623186d8bcfed2e9d62557" +content-hash = "aa4f526dbf926321768fab7d310bdc0a449537a074babec63c01ea88c0cadacf" python-versions = "^3.6" [metadata.files] @@ -886,11 +904,6 @@ markupsafe = [ {file = "MarkupSafe-1.1.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:ba59edeaa2fc6114428f1637ffff42da1e311e29382d81b339c1817d37ec93c6"}, {file = "MarkupSafe-1.1.1-cp37-cp37m-win32.whl", hash = "sha256:b00c1de48212e4cc9603895652c5c410df699856a2853135b3967591e4beebc2"}, {file = "MarkupSafe-1.1.1-cp37-cp37m-win_amd64.whl", hash = "sha256:9bf40443012702a1d2070043cb6291650a0841ece432556f784f004937f0f32c"}, - {file = "MarkupSafe-1.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6788b695d50a51edb699cb55e35487e430fa21f1ed838122d722e0ff0ac5ba15"}, - {file = "MarkupSafe-1.1.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:cdb132fc825c38e1aeec2c8aa9338310d29d337bebbd7baa06889d09a60a1fa2"}, - {file = "MarkupSafe-1.1.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:13d3144e1e340870b25e7b10b98d779608c02016d5184cfb9927a9f10c689f42"}, - {file = "MarkupSafe-1.1.1-cp38-cp38-win32.whl", hash = "sha256:596510de112c685489095da617b5bcbbac7dd6384aeebeda4df6025d0256a81b"}, - {file = "MarkupSafe-1.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:e8313f01ba26fbbe36c7be1966a7b7424942f670f38e666995b88d012765b9be"}, {file = "MarkupSafe-1.1.1.tar.gz", hash = "sha256:29872e92839765e546828bb7754a68c418d927cd064fd4708fab9fe9c8bb116b"}, ] mccabe = [ @@ -1044,6 +1057,10 @@ six = [ sqlalchemy = [ {file = "SQLAlchemy-1.3.13.tar.gz", hash = "sha256:64a7b71846db6423807e96820993fa12a03b89127d278290ca25c0b11ed7b4fb"}, ] +sqlalchemy-filters = [ + {file = "sqlalchemy-filters-0.10.0.tar.gz", hash = "sha256:3b0d4fc39075cd1079e6089ac3165c1930b74fb1804515f109ec80e75fec46c8"}, + {file = "sqlalchemy_filters-0.10.0-py3-none-any.whl", hash = "sha256:34265e3b4605aae6e7c7fe3082b1de148c6295409f4d34286447f8c195bac699"}, +] sqlalchemy-stubs = [ {file = "sqlalchemy-stubs-0.3.tar.gz", hash = "sha256:a3318c810697164e8c818aa2d90bac570c1a0e752ced3ec25455b309c0bee8fd"}, {file = "sqlalchemy_stubs-0.3-py3-none-any.whl", hash = "sha256:ca1250605a39648cc433f5c70cb1a6f9fe0b60bdda4c51e1f9a2ab3651daadc8"}, diff --git a/pyproject.toml b/pyproject.toml index 78d42467..b3ba045b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ python = "^3.6" fastapi = "*" pydantic = "^1.0" sqlalchemy = "^1.3.12" +sqlalchemy-filters = "^0.10.0" [tool.poetry.dev-dependencies] # Starlette features From a825b7da5e23ecc5cac15b3b7f523aa599617b96 Mon Sep 17 00:00:00 2001 From: Jonathan Nye Date: Fri, 6 Mar 2020 21:46:15 +0100 Subject: [PATCH 2/8] updated import sorting --- fastapi_utils/base_route.py | 2 +- fastapi_utils/crud_base.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fastapi_utils/base_route.py b/fastapi_utils/base_route.py index 2b19ba27..0cd2a453 100644 --- a/fastapi_utils/base_route.py +++ b/fastapi_utils/base_route.py @@ -1,7 +1,7 @@ from typing import Dict, Generic, List, TypeVar -from fastapi_utils.crud_base import CRUDBase, Base from fastapi import Depends, HTTPException +from fastapi_utils.crud_base import Base, CRUDBase from pydantic import BaseModel from sqlalchemy.orm import Session diff --git a/fastapi_utils/crud_base.py b/fastapi_utils/crud_base.py index 8bc67568..2fa4f9a1 100644 --- a/fastapi_utils/crud_base.py +++ b/fastapi_utils/crud_base.py @@ -2,10 +2,10 @@ from enum import Enum from typing import Dict, Generic, List, Optional, Type, TypeVar, Union -from sqlalchemy.ext.declarative import declarative_base from fastapi.encoders import jsonable_encoder from fastapi_utils.camelcase import snake2camel from pydantic import BaseModel +from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Session from sqlalchemy_filters import apply_filters, apply_sort From ff18fc5bd6e21af5281a972722a37122e8e2b927 Mon Sep 17 00:00:00 2001 From: Jonathan Nye Date: Fri, 6 Mar 2020 21:53:17 +0100 Subject: [PATCH 3/8] update import sorting again --- .gitignore | 1 + fastapi_utils/base_route.py | 3 ++- fastapi_utils/crud_base.py | 3 ++- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 5f0959b1..86214288 100644 --- a/.gitignore +++ b/.gitignore @@ -87,3 +87,4 @@ venv.bak/ site .bento/ +.vscode/ diff --git a/fastapi_utils/base_route.py b/fastapi_utils/base_route.py index 0cd2a453..240a93af 100644 --- a/fastapi_utils/base_route.py +++ b/fastapi_utils/base_route.py @@ -1,10 +1,11 @@ from typing import Dict, Generic, List, TypeVar from fastapi import Depends, HTTPException -from fastapi_utils.crud_base import Base, CRUDBase from pydantic import BaseModel from sqlalchemy.orm import Session +from fastapi_utils.crud_base import Base, CRUDBase + ResponseModelType = TypeVar("ResponseModelType", bound=BaseModel) ResponseModelManyType = TypeVar("ResponseModelManyType", bound=BaseModel) CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) diff --git a/fastapi_utils/crud_base.py b/fastapi_utils/crud_base.py index 2fa4f9a1..ed436bd4 100644 --- a/fastapi_utils/crud_base.py +++ b/fastapi_utils/crud_base.py @@ -3,12 +3,13 @@ from typing import Dict, Generic, List, Optional, Type, TypeVar, Union from fastapi.encoders import jsonable_encoder -from fastapi_utils.camelcase import snake2camel from pydantic import BaseModel from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Session from sqlalchemy_filters import apply_filters, apply_sort +from fastapi_utils.camelcase import snake2camel + Base = declarative_base() ModelType = TypeVar("ModelType", bound=Base) From 6a252665a8138b0d324a04fed7dbe51cbea4d590 Mon Sep 17 00:00:00 2001 From: Jonathan Nye Date: Fri, 6 Mar 2020 21:41:26 +0100 Subject: [PATCH 4/8] Added the base_route and crud_base Both of these can be used to start a CRUD API adds filtering and sorting to the read_many endpoint --- fastapi_utils/base_route.py | 108 ++++++++++++++++++++++++ fastapi_utils/crud_base.py | 159 ++++++++++++++++++++++++++++++++++++ poetry.lock | 31 +++++-- pyproject.toml | 1 + 4 files changed, 292 insertions(+), 7 deletions(-) create mode 100644 fastapi_utils/base_route.py create mode 100644 fastapi_utils/crud_base.py diff --git a/fastapi_utils/base_route.py b/fastapi_utils/base_route.py new file mode 100644 index 00000000..2b19ba27 --- /dev/null +++ b/fastapi_utils/base_route.py @@ -0,0 +1,108 @@ +from typing import Dict, Generic, List, TypeVar + +from fastapi_utils.crud_base import CRUDBase, Base +from fastapi import Depends, HTTPException +from pydantic import BaseModel +from sqlalchemy.orm import Session + +ResponseModelType = TypeVar("ResponseModelType", bound=BaseModel) +ResponseModelManyType = TypeVar("ResponseModelManyType", bound=BaseModel) +CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) +UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) +IDType = TypeVar("IDType") + + +def get_filter_fields(self) -> List[str]: + """This would need to get overridden for each BaseRoute where the filter fields are defined. + + Returns: + List[str] -- List of fields to filter by + """ + return [] + + +class BaseRoute(Generic[ResponseModelType, ResponseModelManyType, CreateSchemaType, UpdateSchemaType, IDType]): + """A base route that has the basic CRUD endpoints. + + For read_many + + """ + + filter_fields: List[str] = Depends(get_filter_fields) + crud_base = CRUDBase(Base) # type: ignore + db: Session = Depends(None) + object_name = "Base" + + def read_many(self, skip: int = 0, limit: int = 100, sort_by: str = None, **kwargs) -> ResponseModelManyType: + """Reads many from the database with the provided filter and sort parameters. + + Filter parameters need to be specified by overriding this read_many method and calling it like: + + @router.get("/", response_model=List[Person]) + def read_persons( + self, skip: int = 0, limit: int = 100, sort_by: str = None, name: str = None, + ) -> List[Person]: + return super().read_many(skip=skip, limit=limit, sort_by=sort_by, name=name) + + Where the filter fields are defined as parameters. In this case "name" is a filter field + + Keyword Arguments: + skip {int} -- [description] (default: {0}) + limit {int} -- [description] (default: {100}) + sort_by {str} -- Expected in the form "model__field_name:asc,field_name:desc" (default: {None}) + + **kwargs {str} -- Filter field names expected in the form field_name or model__field_name if filtering through + a join. The filter is defined as op:value. For example ==:paul or eq:paul + + The filter op is specified in the crud_base FilterOpEnum. + + Returns: + ResponseModelManyType -- [description] + """ + filter_fields: Dict[str, str] = {} + for field in self.filter_fields: + filter_fields[field] = kwargs.pop(field, None) + results = self.crud_base.get_multi(self.db, skip=skip, limit=limit, filter_by=filter_fields, sort_by=sort_by) + return results + + def create(self, *, obj_in: CreateSchemaType,) -> ResponseModelType: + """ + Create new object. + """ + result = self.crud_base.create(db_session=self.db, obj_in=obj_in) + return result + + def update(self, *, id: IDType, obj_in: UpdateSchemaType,) -> ResponseModelType: + """ + Update an object. + """ + result = self.crud_base.get(db_session=self.db, id=id) + if not result: + raise HTTPException(status_code=404, detail=f"{self.object_name} not found") + # if not crud.user.is_superuser(current_user) and (object.owner_id != current_user.id): + # raise HTTPException(status_code=400, detail="Not enough permissions") + result = self.crud_base.update(db_session=self.db, db_obj=result, obj_in=obj_in) + return result + + def read(self, *, id: IDType,) -> ResponseModelType: + """ + Get object by ID. + """ + result = self.crud_base.get(db_session=self.db, id=id) + if not result: + raise HTTPException(status_code=404, detail=f"{self.object_name} not found") + # if not crud.user.is_superuser(current_user) and (object.owner_id != current_user.id): + # raise HTTPException(status_code=400, detail="Not enough permissions") + return result + + def delete(self, *, id: IDType,) -> ResponseModelType: + """ + Delete an object. + """ + result = self.crud_base.get(db_session=self.db, id=id) + if not result: + raise HTTPException(status_code=404, detail=f"{self.object_name} not found") + # if not crud.user.is_superuser(current_user) and (object.owner_id != current_user.id): + # raise HTTPException(status_code=400, detail="Not enough permissions") + result = self.crud_base.remove(db_session=self.db, id=id) + return result diff --git a/fastapi_utils/crud_base.py b/fastapi_utils/crud_base.py new file mode 100644 index 00000000..8bc67568 --- /dev/null +++ b/fastapi_utils/crud_base.py @@ -0,0 +1,159 @@ +from decimal import Decimal +from enum import Enum +from typing import Dict, Generic, List, Optional, Type, TypeVar, Union + +from sqlalchemy.ext.declarative import declarative_base +from fastapi.encoders import jsonable_encoder +from fastapi_utils.camelcase import snake2camel +from pydantic import BaseModel +from sqlalchemy.orm import Session +from sqlalchemy_filters import apply_filters, apply_sort + +Base = declarative_base() + +ModelType = TypeVar("ModelType", bound=Base) +MultiSchemaType = TypeVar("MultiSchemaType", bound=BaseModel) +CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) +UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) +IDType = TypeVar("IDType") + + +class SortDirectionEnum(str, Enum): + ASC = "asc" + DESC = "desc" + + +class FilterOpEnum(str, Enum): + IS_NULL = "is_null" + IS_NOT_NULL = "is_not_null" + EQ_SYM = "==" + EQ = "eq" + NE_SYM = "!=" + NE = "ne" + GT_SYM = ">" + GT = "gt" + LT_SYM = "<" + LT = "lt" + GE_SYM = ">=" + GE = "ge" + LE_SYM = "<=" + LE = "le" + LIKE = "like" + ILIKE = "ilike" + IN = "in" + NOT_IN = "not_in" + + +class SortField(BaseModel): + field: str + model: Optional[str] = None + direction: SortDirectionEnum = SortDirectionEnum.DESC + + +class FilterField(BaseModel): + field: str + model: Optional[str] = None + op: FilterOpEnum + value: Union[str, int, Decimal] + + +def get_filter_field(field: str, field_name: str, split_character: str = ":") -> FilterField: + model = None + op, value = field.split(":") + if "__" in field_name: + model, field_name = field_name.split("__") + model = snake2camel(model, start_lower=False) + filter_field = FilterField(field=field_name, model=model, op=op, value=value) + return filter_field + + +def get_filter_fields(fields: Optional[Dict[str, str]], split_character: str = ":") -> List[FilterField]: + filter_fields = [] + if fields: + for field_name in fields: + if fields[field_name]: + filter_fields.append(get_filter_field(field=fields[field_name], field_name=field_name)) + return filter_fields + + +def get_sort_field(field: str) -> SortField: + model = None + field_name, direction = field.split(":") + if "__" in field_name: + model, field_name = field_name.split("__") + sort_field = SortField(model=model, field=field_name, direction=direction) + return sort_field + + +def get_sort_fields(sort_string: str, split_character: str = ",") -> List[SortField]: + sort_fields = [] + # There could be many sort fields + if sort_string: + sort_by_fields = sort_string.split(",") + for _to_sort in sort_by_fields: + sort_fields.append(get_sort_field(_to_sort)) + return sort_fields + + +class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): + def __init__(self, model: Type[ModelType]): + """ + CRUD object with default methods to Create, Read, Update, Delete (CRUD). + **Parameters** + * `model`: A SQLAlchemy model class + * `schema`: A Pydantic model (schema) class + """ + self.model = model + + def get(self, db_session: Session, id: IDType) -> Optional[ModelType]: + return db_session.query(self.model).filter(self.model.id == id).first() # type: ignore + + def get_multi( + self, + db_session: Session, + *, + skip: int = 0, + limit: int = 100, + sort_by: Optional[str] = None, + filter_by: Optional[Dict[str, str]] = None, + ) -> List[ModelType]: + + sort_spec_pydantic = get_sort_fields(sort_by) + filter_spec_pydantic = get_filter_fields(filter_by) + + sort_spec = [x.dict(exclude_none=True) for x in sort_spec_pydantic] + filter_spec = [x.dict(exclude_none=True) for x in filter_spec_pydantic] + + query = db_session.query(self.model) + query = apply_filters(query, filter_spec) + query = apply_sort(query, sort_spec) + + count = query.count() + query = query.offset(skip).limit(limit) + + return query.all() + + def create(self, db_session: Session, *, obj_in: CreateSchemaType) -> ModelType: + obj_in_data = jsonable_encoder(obj_in) + db_obj = self.model(**obj_in_data) + db_session.add(db_obj) + db_session.commit() + db_session.refresh(db_obj) + return db_obj + + def update(self, db_session: Session, *, db_obj: ModelType, obj_in: UpdateSchemaType) -> ModelType: + obj_data = jsonable_encoder(db_obj) + update_data = obj_in.dict(skip_defaults=True) + for field in obj_data: + if field in update_data: + setattr(db_obj, field, update_data[field]) + db_session.add(db_obj) + db_session.commit() + db_session.refresh(db_obj) + return db_obj + + def remove(self, db_session: Session, *, id: IDType) -> ModelType: + obj = db_session.query(self.model).get(id) + db_session.delete(obj) + db_session.commit() + return obj diff --git a/poetry.lock b/poetry.lock index 8ef126fa..bc862aa1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -603,7 +603,7 @@ security = ["pyOpenSSL (>=0.14)", "cryptography (>=1.3.4)"] socks = ["PySocks (>=1.5.6,<1.5.7 || >1.5.7)", "win-inet-pton"] [[package]] -category = "dev" +category = "main" description = "Python 2 and 3 compatibility utilities" name = "six" optional = false @@ -630,6 +630,24 @@ postgresql_psycopg2binary = ["psycopg2-binary"] postgresql_psycopg2cffi = ["psycopg2cffi"] pymysql = ["pymysql"] +[[package]] +category = "main" +description = "A library to filter SQLAlchemy queries." +name = "sqlalchemy-filters" +optional = false +python-versions = "*" +version = "0.10.0" + +[package.dependencies] +six = ">=1.10.0" +sqlalchemy = ">=1.0.16" + +[package.extras] +dev = ["pytest (4.3.0)", "flake8 (3.7.7)", "coverage (4.5.3)", "sqlalchemy-utils (0.33.11)", "restructuredtext-lint (1.2.2)", "Pygments (2.3.1)"] +mysql = ["mysql-connector-python-rf (2.2.2)"] +postgresql = ["psycopg2 (2.7.7)"] +python2 = ["funcsigs (>=1.0.2)"] + [[package]] category = "dev" description = "SQLAlchemy stubs and mypy plugin" @@ -720,7 +738,7 @@ docs = ["sphinx", "jaraco.packaging (>=3.2)", "rst.linker (>=1.9)"] testing = ["jaraco.itertools", "func-timeout"] [metadata] -content-hash = "2b727851846408766afb773a03cef7946f566c0817623186d8bcfed2e9d62557" +content-hash = "aa4f526dbf926321768fab7d310bdc0a449537a074babec63c01ea88c0cadacf" python-versions = "^3.6" [metadata.files] @@ -886,11 +904,6 @@ markupsafe = [ {file = "MarkupSafe-1.1.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:ba59edeaa2fc6114428f1637ffff42da1e311e29382d81b339c1817d37ec93c6"}, {file = "MarkupSafe-1.1.1-cp37-cp37m-win32.whl", hash = "sha256:b00c1de48212e4cc9603895652c5c410df699856a2853135b3967591e4beebc2"}, {file = "MarkupSafe-1.1.1-cp37-cp37m-win_amd64.whl", hash = "sha256:9bf40443012702a1d2070043cb6291650a0841ece432556f784f004937f0f32c"}, - {file = "MarkupSafe-1.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6788b695d50a51edb699cb55e35487e430fa21f1ed838122d722e0ff0ac5ba15"}, - {file = "MarkupSafe-1.1.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:cdb132fc825c38e1aeec2c8aa9338310d29d337bebbd7baa06889d09a60a1fa2"}, - {file = "MarkupSafe-1.1.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:13d3144e1e340870b25e7b10b98d779608c02016d5184cfb9927a9f10c689f42"}, - {file = "MarkupSafe-1.1.1-cp38-cp38-win32.whl", hash = "sha256:596510de112c685489095da617b5bcbbac7dd6384aeebeda4df6025d0256a81b"}, - {file = "MarkupSafe-1.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:e8313f01ba26fbbe36c7be1966a7b7424942f670f38e666995b88d012765b9be"}, {file = "MarkupSafe-1.1.1.tar.gz", hash = "sha256:29872e92839765e546828bb7754a68c418d927cd064fd4708fab9fe9c8bb116b"}, ] mccabe = [ @@ -1044,6 +1057,10 @@ six = [ sqlalchemy = [ {file = "SQLAlchemy-1.3.13.tar.gz", hash = "sha256:64a7b71846db6423807e96820993fa12a03b89127d278290ca25c0b11ed7b4fb"}, ] +sqlalchemy-filters = [ + {file = "sqlalchemy-filters-0.10.0.tar.gz", hash = "sha256:3b0d4fc39075cd1079e6089ac3165c1930b74fb1804515f109ec80e75fec46c8"}, + {file = "sqlalchemy_filters-0.10.0-py3-none-any.whl", hash = "sha256:34265e3b4605aae6e7c7fe3082b1de148c6295409f4d34286447f8c195bac699"}, +] sqlalchemy-stubs = [ {file = "sqlalchemy-stubs-0.3.tar.gz", hash = "sha256:a3318c810697164e8c818aa2d90bac570c1a0e752ced3ec25455b309c0bee8fd"}, {file = "sqlalchemy_stubs-0.3-py3-none-any.whl", hash = "sha256:ca1250605a39648cc433f5c70cb1a6f9fe0b60bdda4c51e1f9a2ab3651daadc8"}, diff --git a/pyproject.toml b/pyproject.toml index 750a4ceb..44c1db00 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ python = "^3.6" fastapi = "*" pydantic = "^1.0" sqlalchemy = "^1.3.12" +sqlalchemy-filters = "^0.10.0" [tool.poetry.dev-dependencies] # Starlette features From 5fd201585789631d297cbbfa4050a59166d84436 Mon Sep 17 00:00:00 2001 From: Jonathan Nye Date: Fri, 6 Mar 2020 21:46:15 +0100 Subject: [PATCH 5/8] updated import sorting --- fastapi_utils/base_route.py | 2 +- fastapi_utils/crud_base.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fastapi_utils/base_route.py b/fastapi_utils/base_route.py index 2b19ba27..0cd2a453 100644 --- a/fastapi_utils/base_route.py +++ b/fastapi_utils/base_route.py @@ -1,7 +1,7 @@ from typing import Dict, Generic, List, TypeVar -from fastapi_utils.crud_base import CRUDBase, Base from fastapi import Depends, HTTPException +from fastapi_utils.crud_base import Base, CRUDBase from pydantic import BaseModel from sqlalchemy.orm import Session diff --git a/fastapi_utils/crud_base.py b/fastapi_utils/crud_base.py index 8bc67568..2fa4f9a1 100644 --- a/fastapi_utils/crud_base.py +++ b/fastapi_utils/crud_base.py @@ -2,10 +2,10 @@ from enum import Enum from typing import Dict, Generic, List, Optional, Type, TypeVar, Union -from sqlalchemy.ext.declarative import declarative_base from fastapi.encoders import jsonable_encoder from fastapi_utils.camelcase import snake2camel from pydantic import BaseModel +from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Session from sqlalchemy_filters import apply_filters, apply_sort From 0e2481212557a63378cddc2a711e2501e8d11b6e Mon Sep 17 00:00:00 2001 From: Jonathan Nye Date: Fri, 6 Mar 2020 21:53:17 +0100 Subject: [PATCH 6/8] update import sorting again --- .gitignore | 1 + fastapi_utils/base_route.py | 3 ++- fastapi_utils/crud_base.py | 3 ++- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 5f0959b1..86214288 100644 --- a/.gitignore +++ b/.gitignore @@ -87,3 +87,4 @@ venv.bak/ site .bento/ +.vscode/ diff --git a/fastapi_utils/base_route.py b/fastapi_utils/base_route.py index 0cd2a453..240a93af 100644 --- a/fastapi_utils/base_route.py +++ b/fastapi_utils/base_route.py @@ -1,10 +1,11 @@ from typing import Dict, Generic, List, TypeVar from fastapi import Depends, HTTPException -from fastapi_utils.crud_base import Base, CRUDBase from pydantic import BaseModel from sqlalchemy.orm import Session +from fastapi_utils.crud_base import Base, CRUDBase + ResponseModelType = TypeVar("ResponseModelType", bound=BaseModel) ResponseModelManyType = TypeVar("ResponseModelManyType", bound=BaseModel) CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) diff --git a/fastapi_utils/crud_base.py b/fastapi_utils/crud_base.py index 2fa4f9a1..ed436bd4 100644 --- a/fastapi_utils/crud_base.py +++ b/fastapi_utils/crud_base.py @@ -3,12 +3,13 @@ from typing import Dict, Generic, List, Optional, Type, TypeVar, Union from fastapi.encoders import jsonable_encoder -from fastapi_utils.camelcase import snake2camel from pydantic import BaseModel from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Session from sqlalchemy_filters import apply_filters, apply_sort +from fastapi_utils.camelcase import snake2camel + Base = declarative_base() ModelType = TypeVar("ModelType", bound=Base) From 8aaa3d3ffdb16abd08ba70c0544756ebabeb3d1b Mon Sep 17 00:00:00 2001 From: Jonathan Nye Date: Sat, 7 Mar 2020 11:24:15 +0100 Subject: [PATCH 7/8] move routes into routes module --- fastapi_utils/{ => crud}/crud_base.py | 0 fastapi_utils/{base_route.py => crud/route.py} | 15 +++------------ 2 files changed, 3 insertions(+), 12 deletions(-) rename fastapi_utils/{ => crud}/crud_base.py (100%) rename fastapi_utils/{base_route.py => crud/route.py} (91%) diff --git a/fastapi_utils/crud_base.py b/fastapi_utils/crud/crud_base.py similarity index 100% rename from fastapi_utils/crud_base.py rename to fastapi_utils/crud/crud_base.py diff --git a/fastapi_utils/base_route.py b/fastapi_utils/crud/route.py similarity index 91% rename from fastapi_utils/base_route.py rename to fastapi_utils/crud/route.py index 240a93af..2ce50bbd 100644 --- a/fastapi_utils/base_route.py +++ b/fastapi_utils/crud/route.py @@ -1,4 +1,4 @@ -from typing import Dict, Generic, List, TypeVar +from typing import Dict, Generic, List, TypeVar, ClassVar, Tuple from fastapi import Depends, HTTPException from pydantic import BaseModel @@ -13,23 +13,14 @@ IDType = TypeVar("IDType") -def get_filter_fields(self) -> List[str]: - """This would need to get overridden for each BaseRoute where the filter fields are defined. - - Returns: - List[str] -- List of fields to filter by - """ - return [] - - -class BaseRoute(Generic[ResponseModelType, ResponseModelManyType, CreateSchemaType, UpdateSchemaType, IDType]): +class CRUDRoute(Generic[ResponseModelType, ResponseModelManyType, CreateSchemaType, UpdateSchemaType, IDType]): """A base route that has the basic CRUD endpoints. For read_many """ - filter_fields: List[str] = Depends(get_filter_fields) + filter_fields: ClassVar[Tuple[str]] = () crud_base = CRUDBase(Base) # type: ignore db: Session = Depends(None) object_name = "Base" From a08659bf08c6363c0ac142b88bbd29b495fb63c1 Mon Sep 17 00:00:00 2001 From: Jonathan Nye Date: Sat, 7 Mar 2020 13:16:47 +0100 Subject: [PATCH 8/8] Update CRUDBase name Specify crud_base as classvar Add filter_fields validation --- fastapi_utils/base_route.py | 109 ------------- fastapi_utils/crud/__init__.py | 4 + fastapi_utils/crud/{crud_base.py => base.py} | 14 +- fastapi_utils/crud/route.py | 19 ++- fastapi_utils/crud_base.py | 160 ------------------- 5 files changed, 23 insertions(+), 283 deletions(-) delete mode 100644 fastapi_utils/base_route.py create mode 100644 fastapi_utils/crud/__init__.py rename fastapi_utils/crud/{crud_base.py => base.py} (95%) delete mode 100644 fastapi_utils/crud_base.py diff --git a/fastapi_utils/base_route.py b/fastapi_utils/base_route.py deleted file mode 100644 index 240a93af..00000000 --- a/fastapi_utils/base_route.py +++ /dev/null @@ -1,109 +0,0 @@ -from typing import Dict, Generic, List, TypeVar - -from fastapi import Depends, HTTPException -from pydantic import BaseModel -from sqlalchemy.orm import Session - -from fastapi_utils.crud_base import Base, CRUDBase - -ResponseModelType = TypeVar("ResponseModelType", bound=BaseModel) -ResponseModelManyType = TypeVar("ResponseModelManyType", bound=BaseModel) -CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) -UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) -IDType = TypeVar("IDType") - - -def get_filter_fields(self) -> List[str]: - """This would need to get overridden for each BaseRoute where the filter fields are defined. - - Returns: - List[str] -- List of fields to filter by - """ - return [] - - -class BaseRoute(Generic[ResponseModelType, ResponseModelManyType, CreateSchemaType, UpdateSchemaType, IDType]): - """A base route that has the basic CRUD endpoints. - - For read_many - - """ - - filter_fields: List[str] = Depends(get_filter_fields) - crud_base = CRUDBase(Base) # type: ignore - db: Session = Depends(None) - object_name = "Base" - - def read_many(self, skip: int = 0, limit: int = 100, sort_by: str = None, **kwargs) -> ResponseModelManyType: - """Reads many from the database with the provided filter and sort parameters. - - Filter parameters need to be specified by overriding this read_many method and calling it like: - - @router.get("/", response_model=List[Person]) - def read_persons( - self, skip: int = 0, limit: int = 100, sort_by: str = None, name: str = None, - ) -> List[Person]: - return super().read_many(skip=skip, limit=limit, sort_by=sort_by, name=name) - - Where the filter fields are defined as parameters. In this case "name" is a filter field - - Keyword Arguments: - skip {int} -- [description] (default: {0}) - limit {int} -- [description] (default: {100}) - sort_by {str} -- Expected in the form "model__field_name:asc,field_name:desc" (default: {None}) - - **kwargs {str} -- Filter field names expected in the form field_name or model__field_name if filtering through - a join. The filter is defined as op:value. For example ==:paul or eq:paul - - The filter op is specified in the crud_base FilterOpEnum. - - Returns: - ResponseModelManyType -- [description] - """ - filter_fields: Dict[str, str] = {} - for field in self.filter_fields: - filter_fields[field] = kwargs.pop(field, None) - results = self.crud_base.get_multi(self.db, skip=skip, limit=limit, filter_by=filter_fields, sort_by=sort_by) - return results - - def create(self, *, obj_in: CreateSchemaType,) -> ResponseModelType: - """ - Create new object. - """ - result = self.crud_base.create(db_session=self.db, obj_in=obj_in) - return result - - def update(self, *, id: IDType, obj_in: UpdateSchemaType,) -> ResponseModelType: - """ - Update an object. - """ - result = self.crud_base.get(db_session=self.db, id=id) - if not result: - raise HTTPException(status_code=404, detail=f"{self.object_name} not found") - # if not crud.user.is_superuser(current_user) and (object.owner_id != current_user.id): - # raise HTTPException(status_code=400, detail="Not enough permissions") - result = self.crud_base.update(db_session=self.db, db_obj=result, obj_in=obj_in) - return result - - def read(self, *, id: IDType,) -> ResponseModelType: - """ - Get object by ID. - """ - result = self.crud_base.get(db_session=self.db, id=id) - if not result: - raise HTTPException(status_code=404, detail=f"{self.object_name} not found") - # if not crud.user.is_superuser(current_user) and (object.owner_id != current_user.id): - # raise HTTPException(status_code=400, detail="Not enough permissions") - return result - - def delete(self, *, id: IDType,) -> ResponseModelType: - """ - Delete an object. - """ - result = self.crud_base.get(db_session=self.db, id=id) - if not result: - raise HTTPException(status_code=404, detail=f"{self.object_name} not found") - # if not crud.user.is_superuser(current_user) and (object.owner_id != current_user.id): - # raise HTTPException(status_code=400, detail="Not enough permissions") - result = self.crud_base.remove(db_session=self.db, id=id) - return result diff --git a/fastapi_utils/crud/__init__.py b/fastapi_utils/crud/__init__.py new file mode 100644 index 00000000..43ef8504 --- /dev/null +++ b/fastapi_utils/crud/__init__.py @@ -0,0 +1,4 @@ +from .base import CRUDBase +from .route import CRUDRoute + +__all__ = ["CRUDBase", "CRUDRoute"] diff --git a/fastapi_utils/crud/crud_base.py b/fastapi_utils/crud/base.py similarity index 95% rename from fastapi_utils/crud/crud_base.py rename to fastapi_utils/crud/base.py index ed436bd4..2c46701b 100644 --- a/fastapi_utils/crud/crud_base.py +++ b/fastapi_utils/crud/base.py @@ -6,13 +6,11 @@ from pydantic import BaseModel from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Session -from sqlalchemy_filters import apply_filters, apply_sort from fastapi_utils.camelcase import snake2camel +from sqlalchemy_filters import apply_filters, apply_sort -Base = declarative_base() - -ModelType = TypeVar("ModelType", bound=Base) +ModelType = TypeVar("ModelType") MultiSchemaType = TypeVar("MultiSchemaType", bound=BaseModel) CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) @@ -86,7 +84,7 @@ def get_sort_field(field: str) -> SortField: return sort_field -def get_sort_fields(sort_string: str, split_character: str = ",") -> List[SortField]: +def get_sort_fields(sort_string: Optional[str], split_character: str = ",") -> List[SortField]: sort_fields = [] # There could be many sort fields if sort_string: @@ -109,14 +107,14 @@ def __init__(self, model: Type[ModelType]): def get(self, db_session: Session, id: IDType) -> Optional[ModelType]: return db_session.query(self.model).filter(self.model.id == id).first() # type: ignore - def get_multi( + def get_many( self, db_session: Session, *, skip: int = 0, limit: int = 100, - sort_by: Optional[str] = None, filter_by: Optional[Dict[str, str]] = None, + sort_by: Optional[str] = None, ) -> List[ModelType]: sort_spec_pydantic = get_sort_fields(sort_by) @@ -129,7 +127,7 @@ def get_multi( query = apply_filters(query, filter_spec) query = apply_sort(query, sort_spec) - count = query.count() + # count = query.count() query = query.offset(skip).limit(limit) return query.all() diff --git a/fastapi_utils/crud/route.py b/fastapi_utils/crud/route.py index 2ce50bbd..332f4056 100644 --- a/fastapi_utils/crud/route.py +++ b/fastapi_utils/crud/route.py @@ -1,15 +1,16 @@ -from typing import Dict, Generic, List, TypeVar, ClassVar, Tuple +from typing import ClassVar, Dict, Generic, Tuple, TypeVar from fastapi import Depends, HTTPException from pydantic import BaseModel from sqlalchemy.orm import Session -from fastapi_utils.crud_base import Base, CRUDBase +from fastapi_utils.crud import CRUDBase ResponseModelType = TypeVar("ResponseModelType", bound=BaseModel) ResponseModelManyType = TypeVar("ResponseModelManyType", bound=BaseModel) CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) +CRUDBaseType = TypeVar("CRUDBaseType", bound=CRUDBase) IDType = TypeVar("IDType") @@ -20,10 +21,10 @@ class CRUDRoute(Generic[ResponseModelType, ResponseModelManyType, CreateSchemaTy """ + crud_base: ClassVar[CRUDBaseType] filter_fields: ClassVar[Tuple[str]] = () - crud_base = CRUDBase(Base) # type: ignore db: Session = Depends(None) - object_name = "Base" + object_name: ClassVar[str] = "CRUDBase" def read_many(self, skip: int = 0, limit: int = 100, sort_by: str = None, **kwargs) -> ResponseModelManyType: """Reads many from the database with the provided filter and sort parameters. @@ -43,7 +44,8 @@ def read_persons( limit {int} -- [description] (default: {100}) sort_by {str} -- Expected in the form "model__field_name:asc,field_name:desc" (default: {None}) - **kwargs {str} -- Filter field names expected in the form field_name or model__field_name if filtering through + **kwargs {str} -- Filter field names expected in the form field_name or model__field_name if + filtering through a join. The filter is defined as op:value. For example ==:paul or eq:paul The filter op is specified in the crud_base FilterOpEnum. @@ -52,9 +54,14 @@ def read_persons( ResponseModelManyType -- [description] """ filter_fields: Dict[str, str] = {} + for field in self.filter_fields: filter_fields[field] = kwargs.pop(field, None) - results = self.crud_base.get_multi(self.db, skip=skip, limit=limit, filter_by=filter_fields, sort_by=sort_by) + + if len(kwargs) != 0: + raise ValueError(f"Method parameters have not been added to class filter fields {kwargs.keys()}") + + results = self.crud_base.get_many(self.db, skip=skip, limit=limit, filter_by=filter_fields, sort_by=sort_by) return results def create(self, *, obj_in: CreateSchemaType,) -> ResponseModelType: diff --git a/fastapi_utils/crud_base.py b/fastapi_utils/crud_base.py deleted file mode 100644 index ed436bd4..00000000 --- a/fastapi_utils/crud_base.py +++ /dev/null @@ -1,160 +0,0 @@ -from decimal import Decimal -from enum import Enum -from typing import Dict, Generic, List, Optional, Type, TypeVar, Union - -from fastapi.encoders import jsonable_encoder -from pydantic import BaseModel -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import Session -from sqlalchemy_filters import apply_filters, apply_sort - -from fastapi_utils.camelcase import snake2camel - -Base = declarative_base() - -ModelType = TypeVar("ModelType", bound=Base) -MultiSchemaType = TypeVar("MultiSchemaType", bound=BaseModel) -CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) -UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) -IDType = TypeVar("IDType") - - -class SortDirectionEnum(str, Enum): - ASC = "asc" - DESC = "desc" - - -class FilterOpEnum(str, Enum): - IS_NULL = "is_null" - IS_NOT_NULL = "is_not_null" - EQ_SYM = "==" - EQ = "eq" - NE_SYM = "!=" - NE = "ne" - GT_SYM = ">" - GT = "gt" - LT_SYM = "<" - LT = "lt" - GE_SYM = ">=" - GE = "ge" - LE_SYM = "<=" - LE = "le" - LIKE = "like" - ILIKE = "ilike" - IN = "in" - NOT_IN = "not_in" - - -class SortField(BaseModel): - field: str - model: Optional[str] = None - direction: SortDirectionEnum = SortDirectionEnum.DESC - - -class FilterField(BaseModel): - field: str - model: Optional[str] = None - op: FilterOpEnum - value: Union[str, int, Decimal] - - -def get_filter_field(field: str, field_name: str, split_character: str = ":") -> FilterField: - model = None - op, value = field.split(":") - if "__" in field_name: - model, field_name = field_name.split("__") - model = snake2camel(model, start_lower=False) - filter_field = FilterField(field=field_name, model=model, op=op, value=value) - return filter_field - - -def get_filter_fields(fields: Optional[Dict[str, str]], split_character: str = ":") -> List[FilterField]: - filter_fields = [] - if fields: - for field_name in fields: - if fields[field_name]: - filter_fields.append(get_filter_field(field=fields[field_name], field_name=field_name)) - return filter_fields - - -def get_sort_field(field: str) -> SortField: - model = None - field_name, direction = field.split(":") - if "__" in field_name: - model, field_name = field_name.split("__") - sort_field = SortField(model=model, field=field_name, direction=direction) - return sort_field - - -def get_sort_fields(sort_string: str, split_character: str = ",") -> List[SortField]: - sort_fields = [] - # There could be many sort fields - if sort_string: - sort_by_fields = sort_string.split(",") - for _to_sort in sort_by_fields: - sort_fields.append(get_sort_field(_to_sort)) - return sort_fields - - -class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): - def __init__(self, model: Type[ModelType]): - """ - CRUD object with default methods to Create, Read, Update, Delete (CRUD). - **Parameters** - * `model`: A SQLAlchemy model class - * `schema`: A Pydantic model (schema) class - """ - self.model = model - - def get(self, db_session: Session, id: IDType) -> Optional[ModelType]: - return db_session.query(self.model).filter(self.model.id == id).first() # type: ignore - - def get_multi( - self, - db_session: Session, - *, - skip: int = 0, - limit: int = 100, - sort_by: Optional[str] = None, - filter_by: Optional[Dict[str, str]] = None, - ) -> List[ModelType]: - - sort_spec_pydantic = get_sort_fields(sort_by) - filter_spec_pydantic = get_filter_fields(filter_by) - - sort_spec = [x.dict(exclude_none=True) for x in sort_spec_pydantic] - filter_spec = [x.dict(exclude_none=True) for x in filter_spec_pydantic] - - query = db_session.query(self.model) - query = apply_filters(query, filter_spec) - query = apply_sort(query, sort_spec) - - count = query.count() - query = query.offset(skip).limit(limit) - - return query.all() - - def create(self, db_session: Session, *, obj_in: CreateSchemaType) -> ModelType: - obj_in_data = jsonable_encoder(obj_in) - db_obj = self.model(**obj_in_data) - db_session.add(db_obj) - db_session.commit() - db_session.refresh(db_obj) - return db_obj - - def update(self, db_session: Session, *, db_obj: ModelType, obj_in: UpdateSchemaType) -> ModelType: - obj_data = jsonable_encoder(db_obj) - update_data = obj_in.dict(skip_defaults=True) - for field in obj_data: - if field in update_data: - setattr(db_obj, field, update_data[field]) - db_session.add(db_obj) - db_session.commit() - db_session.refresh(db_obj) - return db_obj - - def remove(self, db_session: Session, *, id: IDType) -> ModelType: - obj = db_session.query(self.model).get(id) - db_session.delete(obj) - db_session.commit() - return obj