diff --git a/docs/peewee_async/api.rst b/docs/peewee_async/api.rst index 4b7843a..23e9106 100644 --- a/docs/peewee_async/api.rst +++ b/docs/peewee_async/api.rst @@ -21,6 +21,8 @@ Databases .. automethod:: peewee_async.databases.AioDatabase.aio_atomic +.. automethod:: peewee_async.databases.AioDatabase.aio_transaction + .. autoclass:: peewee_async.PsycopgDatabase :members: init diff --git a/docs/peewee_async/transaction.rst b/docs/peewee_async/transaction.rst index 22b9852..57abe15 100644 --- a/docs/peewee_async/transaction.rst +++ b/docs/peewee_async/transaction.rst @@ -1,10 +1,11 @@ Transactions ========================= -Peewee-async provides similiar to peewee interface for working with transactions. The interface is the :py:meth:`~peewee_async.databases.AioDatabase.aio_atomic` method, -which also supports nested transactions and works as context manager. **aio_atomic()** blocks will be run in a transaction or savepoint, depending on the level of nesting. +Peewee-async provides several interfaces similiar to peewee for working with transactions. +The most general interface are :py:meth:`~peewee_async.databases.AioDatabase.aio_atomic` and :py:meth:`~peewee_async.databases.AioDatabase.aio_transaction` methods which work as context managers. +The :py:meth:`~peewee_async.databases.AioDatabase.aio_atomic` method supports nested transactions and run the block of code in a transaction or savepoint, depending on the level of nesting. If an exception occurs in a wrapped block, the current transaction/savepoint will be rolled back. Otherwise the statements will be committed at the end of the wrapped block. .. code-block:: python @@ -17,6 +18,16 @@ If an exception occurs in a wrapped block, the current transaction/savepoint wil # RELEASE SAVEPOINT PWASYNC__e83bf5fc118f4e28b0fbdac90ab857ca # COMMIT +The :py:meth:`~peewee_async.databases.AioDatabase.aio_transcation` method does not allow nested transactions and run the block of code in a transaction. + +.. code-block:: python + + async with db.aio_atomic(): # BEGIN + await TestModel.aio_create(text='FOO') # INSERT INTO "testmodel" ("text", "data") VALUES ('FOO', '') RETURNING "testmodel"."id" + # COMMIT + +Using nested :py:meth:`~peewee_async.databases.AioDatabase.aio_transcation` will lead to **OperationalError**. + Manual management +++++++++++++++++ diff --git a/peewee_async/databases.py b/peewee_async/databases.py index 24c3a92..d15d41b 100644 --- a/peewee_async/databases.py +++ b/peewee_async/databases.py @@ -1,7 +1,7 @@ import contextlib import logging import warnings -from typing import Type, Optional, Any, AsyncIterator, Iterator, Dict, List +from typing import Type, Optional, Any, AsyncIterator, Iterator, Dict, List, AsyncContextManager import peewee from playhouse import postgres_ext as ext @@ -97,22 +97,33 @@ async def aio_close(self) -> None: await self.pool_backend.close() - @contextlib.asynccontextmanager - async def aio_atomic(self) -> AsyncIterator[None]: - """Similar to peewee `Database.atomic()` method, but returns - asynchronous context manager. + def aio_atomic(self) -> AsyncContextManager[None]: + """Create an async context-manager which runs any queries in the wrapped block in a transaction (or save-point if blocks are nested). + Calls to :meth:`.aio_atomic()` can be nested. + """ + return self._aio_atomic(use_savepoint=True) + + def aio_transaction(self) -> AsyncContextManager[None]: + """Create an async context-manager that runs all queries in the wrapped block in a transaction. + + Calls to :meth:`.aio_transaction()` cannot be nested. If so OperationalError will be raised. """ + return self._aio_atomic(use_savepoint=False) + + @contextlib.asynccontextmanager + async def _aio_atomic(self, use_savepoint: bool = False) -> AsyncIterator[None]: + async with self.aio_connection() as connection: _connection_context = connection_context.get() assert _connection_context is not None - begin_transaction = _connection_context.transaction_is_opened is False + if _connection_context.transaction_is_opened and not use_savepoint: + raise peewee.OperationalError("Transaction already opened") try: - async with Transaction(connection, is_savepoint=begin_transaction is False): + async with Transaction(connection, is_savepoint=_connection_context.transaction_is_opened): _connection_context.transaction_is_opened = True yield finally: - if begin_transaction is True: - _connection_context.transaction_is_opened = False + _connection_context.transaction_is_opened = False def set_allow_sync(self, value: bool) -> None: """Allow or forbid sync queries for the database. See also diff --git a/pyproject.toml b/pyproject.toml index e1a4676..06206d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "peewee-async" -version = "1.0.0" +version = "1.1.0" description = "Asynchronous interface for peewee ORM powered by asyncio." authors = ["Alexey Kinev ", "Gorshkov Nikolay(contributor) "] readme = "README.md" diff --git a/tests/test_transaction.py b/tests/test_transaction.py index cba71d8..b295362 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -1,8 +1,9 @@ import asyncio import pytest -from peewee import IntegrityError +from peewee import IntegrityError, OperationalError from pytest_mock import MockerFixture +from typing import AsyncContextManager, cast from peewee_async import Transaction from peewee_async.databases import AioDatabase @@ -13,31 +14,43 @@ class FakeConnectionError(Exception): pass +transaction_methods = pytest.mark.parametrize( + "transaction_method", ["aio_transaction", "aio_atomic"] +) +def _transaction_method(db: AioDatabase, transaction_method: str) -> AsyncContextManager[None]: + return cast(AsyncContextManager[None], getattr(db, transaction_method)()) + +@transaction_methods @dbs_all -async def test_transaction_error_on_begin(db: AioDatabase, mocker: MockerFixture) -> None: +async def test_transaction_error_on_begin( + db: AioDatabase, + transaction_method: str, + mocker: MockerFixture +) -> None: mocker.patch.object(Transaction, "begin", side_effect=FakeConnectionError) with pytest.raises(FakeConnectionError): - async with db.aio_atomic(): + async with _transaction_method(db, transaction_method): await TestModel.aio_create(text='FOO') assert db.pool_backend.has_acquired_connections() is False - +@transaction_methods @dbs_all -async def test_transaction_error_on_commit(db: AioDatabase, mocker: MockerFixture) -> None: +async def test_transaction_error_on_commit(transaction_method: str, db: AioDatabase, mocker: MockerFixture) -> None: mocker.patch.object(Transaction, "commit", side_effect=FakeConnectionError) with pytest.raises(FakeConnectionError): - async with db.aio_atomic(): + async with _transaction_method(db, transaction_method): await TestModel.aio_create(text='FOO') assert db.pool_backend.has_acquired_connections() is False +@transaction_methods @dbs_all -async def test_transaction_error_on_rollback(db: AioDatabase, mocker: MockerFixture) -> None: +async def test_transaction_error_on_rollback(transaction_method: str, db: AioDatabase, mocker: MockerFixture) -> None: await TestModel.aio_create(text='FOO', data="") mocker.patch.object(Transaction, "rollback", side_effect=FakeConnectionError) with pytest.raises(FakeConnectionError): - async with db.aio_atomic(): + async with _transaction_method(db, transaction_method): await TestModel.update(data="BAR").aio_execute() assert await TestModel.aio_get_or_none(data="BAR") is not None await TestModel.aio_create(text='FOO') @@ -45,21 +58,22 @@ async def test_transaction_error_on_rollback(db: AioDatabase, mocker: MockerFixt assert db.pool_backend.has_acquired_connections() is False +@transaction_methods @dbs_all -async def test_transaction_success(db: AioDatabase) -> None: - async with db.aio_atomic(): +async def test_transaction_success(transaction_method: str,db: AioDatabase) -> None: + async with _transaction_method(db, transaction_method): await TestModel.aio_create(text='FOO') assert await TestModel.aio_get_or_none(text="FOO") is not None assert db.pool_backend.has_acquired_connections() is False - +@transaction_methods @dbs_all -async def test_transaction_rollback(db: AioDatabase) -> None: +async def test_transaction_rollback(transaction_method: str, db: AioDatabase) -> None: await TestModel.aio_create(text='FOO', data="") with pytest.raises(IntegrityError): - async with db.aio_atomic(): + async with _transaction_method(db, transaction_method): await TestModel.update(data="BAR").aio_execute() assert await TestModel.aio_get_or_none(data="BAR") is not None await TestModel.aio_create(text='FOO') @@ -116,6 +130,26 @@ async def test_transaction_manual_work(db: AioDatabase) -> None: assert db.pool_backend.has_acquired_connections() is False +@pytest.mark.parametrize( + ("method1", "method2"), + [ + ("aio_atomic", "aio_transaction"), + ("aio_transaction", "aio_transaction"), + ], +) +@dbs_all +async def test_nested_transaction__error(method1: str, method2: str, db: AioDatabase) -> None: + + with pytest.raises(OperationalError): + async with _transaction_method(db, method1): + await TestModel.aio_create(text='FOO') + async with _transaction_method(db, method2): + await TestModel.update(text="BAR").aio_execute() + + assert await TestModel.aio_get_or_none(text='FOO') is None + assert db.pool_backend.has_acquired_connections() is False + + @dbs_all async def test_savepoint_success(db: AioDatabase) -> None: async with db.aio_atomic(): @@ -165,8 +199,9 @@ async def test_savepoint_manual_work(db: AioDatabase) -> None: assert db.pool_backend.has_acquired_connections() is False +@transaction_methods @dbs_all -async def test_acid_when_connetion_has_been_broken(db: AioDatabase) -> None: +async def test_acid_when_connetion_has_been_broken(transaction_method:str, db: AioDatabase) -> None: async def restart_connections(event_for_lock: asyncio.Event) -> None: event_for_lock.set() await asyncio.sleep(0.05) @@ -185,7 +220,7 @@ async def restart_connections(event_for_lock: asyncio.Event) -> None: async def insert_records(event_for_wait: asyncio.Event) -> None: await event_for_wait.wait() - async with db.aio_atomic(): + async with _transaction_method(db, transaction_method): # BEGIN # INSERT 1 await TestModel.aio_create(text="1")