diff --git a/docs/index.rst b/docs/index.rst index cb2f0df..181d2dc 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -36,6 +36,7 @@ Contents peewee_async/api peewee_async/connection peewee_async/transaction + peewee_async/signals peewee_async/examples Indices and tables diff --git a/docs/peewee_async/signals.rst b/docs/peewee_async/signals.rst new file mode 100644 index 0000000..bbb405f --- /dev/null +++ b/docs/peewee_async/signals.rst @@ -0,0 +1,43 @@ +Signal support +==================== + + `Signal support`_ has been backported from the original peewee with a few differences. Models with hooks for signals are provided in + ``peewee_async.signals``. To use the signals, you will need all of your project's + models to be a subclass of ``peewee_async.signals.AioModel``, which overrides the + necessary methods to provide support for the various signals. A handler for any signal except ``pre_init`` should be a coroutine function. For obvious reasons + ``pre_init`` signal handler can be only a synchronious function. + +.. code-block:: python + + from peewee_async.signals import AioModel, aio_post_save + + + class MyModel(AioModel): + data = IntegerField() + + @aio_post_save(sender=MyModel) + async def on_save_handler(model_class, instance, created): + await save_in_history_table(instance.data) + + +The following signals are provided: + +``aio_pre_save`` + Called immediately before an object is saved to the database. Provides an + additional keyword argument ``created``, indicating whether the model is being + saved for the first time or updated. +``aio_post_save`` + Called immediately after an object is saved to the database. Provides an + additional keyword argument ``created``, indicating whether the model is being + saved for the first time or updated. +``aio_pre_delete`` + Called immediately before an object is deleted from the database when :py:meth:`Model.aio_delete_instance` + is used. +``aio_post_delete`` + Called immediately after an object is deleted from the database when :py:meth:`Model.aio_delete_instance` + is used. +``pre_init`` + Called when a model class is first instantiated. Can not be async. + + +.. _Signal support: https://docs.peewee-orm.com/en/latest/peewee/playhouse.html#signal-support diff --git a/peewee_async/aio_model.py b/peewee_async/aio_model.py index d19e5f7..f588d16 100644 --- a/peewee_async/aio_model.py +++ b/peewee_async/aio_model.py @@ -5,7 +5,7 @@ from .result_wrappers import fetch_models from .utils import CursorProtocol from typing_extensions import Self -from typing import Tuple, List, Any, cast, Optional, Dict, Union +from typing import Literal, Tuple, List, Any, cast, Optional, Dict, Union async def aio_prefetch(sq: Any, *subqueries: Any, prefetch_type: PREFETCH_TYPE = PREFETCH_TYPE.WHERE) -> Any: @@ -281,7 +281,7 @@ async def aio_delete_instance(self, recursive: bool = False, delete_nullable: bo await model.delete().where(query).aio_execute() return cast(int, await type(self).delete().where(self._pk_expr()).aio_execute()) - async def aio_save(self, force_insert: bool = False, only: Any =None) -> int: + async def aio_save(self, force_insert: bool = False, only: Any =None) -> Union[int, Literal[False]]: """ Async version of **peewee.Model.save** diff --git a/peewee_async/signals.py b/peewee_async/signals.py new file mode 100644 index 0000000..f3e65ec --- /dev/null +++ b/peewee_async/signals.py @@ -0,0 +1,40 @@ +from peewee_async import AioModel as _Model +from typing import Union, Literal, Any +from playhouse.signals import Signal + +class AioSignal(Signal): + async def send(self, instance: "AioModel", *args: Any, **kwargs: Any) -> list[tuple[Any, Any]]: + sender = type(instance) + responses = [] + for n, r, s in self._receiver_list: + if s is None or isinstance(instance, s): + responses.append((r, await r(sender, instance, *args, **kwargs))) + return responses + + +aio_pre_save = AioSignal() +aio_post_save = AioSignal() +aio_pre_delete = AioSignal() +aio_post_delete = AioSignal() +pre_init = Signal() # can't be async ! + + +class AioModel(_Model): + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super(AioModel, self).__init__(*args, **kwargs) + pre_init.send(self) + + async def aio_save(self, force_insert: bool = False, only: Any = None) -> Union[int, Literal[False]]: + pk_value = self._pk if self._meta.primary_key else True + created = force_insert or not bool(pk_value) + await aio_pre_save.send(self, created=created) + ret = await super(AioModel, self).aio_save(force_insert, only) + await aio_post_save.send(self, created=created) + return ret + + async def aio_delete_instance(self, recursive: bool = False, delete_nullable: bool = False) -> int: + await aio_pre_delete.send(self) + ret = await super(AioModel, self).aio_delete_instance(recursive, delete_nullable) + await aio_post_delete.send(self) + return ret diff --git a/tests/models.py b/tests/models.py index 3cdce35..d7240b4 100644 --- a/tests/models.py +++ b/tests/models.py @@ -1,13 +1,15 @@ import uuid -import peewee +import peewee as pw import peewee_async +import peewee_async.signals +import datetime as dt class TestModel(peewee_async.AioModel): __test__ = False # disable pytest warnings - text = peewee.CharField(max_length=100, unique=True) - data = peewee.TextField(default='') + text = pw.CharField(max_length=100, unique=True) + data = pw.TextField(default='') def __str__(self) -> str: return '<%s id=%s> %s' % (self.__class__.__name__, self.id, self.text) @@ -15,7 +17,7 @@ def __str__(self) -> str: class TestModelAlpha(peewee_async.AioModel): __test__ = False - text = peewee.CharField() + text = pw.CharField() def __str__(self) -> str: return '<%s id=%s> %s' % (self.__class__.__name__, self.id, self.text) @@ -23,8 +25,8 @@ def __str__(self) -> str: class TestModelBeta(peewee_async.AioModel): __test__ = False - alpha = peewee.ForeignKeyField(TestModelAlpha, backref='betas') - text = peewee.CharField() + alpha = pw.ForeignKeyField(TestModelAlpha, backref='betas') + text = pw.CharField() def __str__(self) -> str: return '<%s id=%s> %s' % (self.__class__.__name__, self.id, self.text) @@ -32,16 +34,16 @@ def __str__(self) -> str: class TestModelGamma(peewee_async.AioModel): __test__ = False - text = peewee.CharField() - beta = peewee.ForeignKeyField(TestModelBeta, backref='gammas') + text = pw.CharField() + beta = pw.ForeignKeyField(TestModelBeta, backref='gammas') def __str__(self) -> str: return '<%s id=%s> %s' % (self.__class__.__name__, self.id, self.text) class UUIDTestModel(peewee_async.AioModel): - id = peewee.UUIDField(primary_key=True, default=uuid.uuid4) - text = peewee.CharField() + id = pw.UUIDField(primary_key=True, default=uuid.uuid4) + text = pw.CharField() def __str__(self) -> str: return '<%s id=%s> %s' % (self.__class__.__name__, self.id, self.text) @@ -49,19 +51,24 @@ def __str__(self) -> str: class CompositeTestModel(peewee_async.AioModel): """A simple "through" table for many-to-many relationship.""" - task_id = peewee.IntegerField() - product_type = peewee.CharField() + task_id = pw.IntegerField() + product_type = pw.CharField() class Meta: - primary_key = peewee.CompositeKey('task_id', 'product_type') + primary_key = pw.CompositeKey('task_id', 'product_type') class IntegerTestModel(peewee_async.AioModel): __test__ = False # disable pytest warnings - num = peewee.IntegerField() + num = pw.IntegerField() + + +class TestSignalModel(peewee_async.signals.AioModel): + __test__ = False # disable pytest warnings + text = pw.CharField(max_length=100) ALL_MODELS = ( TestModel, UUIDTestModel, TestModelAlpha, TestModelBeta, TestModelGamma, - CompositeTestModel, IntegerTestModel + CompositeTestModel, IntegerTestModel, TestSignalModel ) diff --git a/tests/test_signals.py b/tests/test_signals.py new file mode 100644 index 0000000..8ed3ba3 --- /dev/null +++ b/tests/test_signals.py @@ -0,0 +1,84 @@ +from contextlib import contextmanager +from typing import Any, Callable, Coroutine, Iterator + +from peewee_async.databases import AioDatabase +from tests.conftest import dbs_all +from tests.models import TestSignalModel +from peewee_async.signals import AioModel, aio_pre_save, aio_post_save, aio_post_delete, aio_pre_delete, AioSignal, pre_init + + +@contextmanager +def _connect(signal: AioSignal , receiver: Callable[..., Coroutine[Any, Any, Any]], sender: type[AioModel]) -> Iterator[None]: + signal.connect(receiver=receiver, sender=sender) + yield + signal.disconnect(receiver=receiver, sender=sender) + + +@dbs_all +async def test_aio_pre_save(db: AioDatabase) -> None: + + + async def on_save_handler(model_class: type[TestSignalModel], instance: TestSignalModel, created: bool) -> None: + assert await TestSignalModel.select().aio_exists() is False + assert model_class is TestSignalModel + assert isinstance(instance, TestSignalModel) + assert created + + with _connect(aio_pre_save, receiver=on_save_handler, sender=TestSignalModel): + await TestSignalModel.aio_create(text="aio_create") + + +@dbs_all +async def test_aio_post_save(db: AioDatabase) -> None: + + + async def on_save_handler(model_class: type[TestSignalModel], instance: TestSignalModel, created: bool) -> None: + assert await TestSignalModel.select().aio_exists() is True + assert model_class is TestSignalModel + assert isinstance(instance, TestSignalModel) + assert created + + with _connect(aio_post_save, receiver=on_save_handler, sender=TestSignalModel): + await TestSignalModel.aio_create(text="aio_create") + + +@dbs_all +async def test_aio_pre_delete(db: AioDatabase) -> None: + + t = await TestSignalModel.aio_create(text="aio_create") + + async def on_delete_handler(model_class: type[TestSignalModel], instance: TestSignalModel) -> None: + assert await TestSignalModel.select().aio_exists() is True + assert model_class is TestSignalModel + assert isinstance(instance, TestSignalModel) + + with _connect(aio_pre_delete, receiver=on_delete_handler, sender=TestSignalModel): + await t.aio_delete_instance() + + +@dbs_all +async def test_aio_post_delete(db: AioDatabase) -> None: + + t = await TestSignalModel.aio_create(text="aio_create") + + async def on_delete_handler(model_class: type[TestSignalModel], instance: TestSignalModel) -> None: + assert await TestSignalModel.select().aio_exists() is False + assert model_class is TestSignalModel + assert isinstance(instance, TestSignalModel) + + with _connect(aio_post_delete, receiver=on_delete_handler, sender=TestSignalModel): + await t.aio_delete_instance() + + +@dbs_all +def test_pre_init(db: AioDatabase) -> None: + + def on_init_handler(model_class: type[TestSignalModel], instance: TestSignalModel) -> None: + assert model_class is TestSignalModel + assert instance.text == "text" + + pre_init.connect(receiver=on_init_handler, sender=TestSignalModel) + + TestSignalModel(text="text") + + pre_init.disconnect(receiver=on_init_handler, sender=TestSignalModel) \ No newline at end of file