diff --git a/peewee_async/databases.py b/peewee_async/databases.py index d15d41b..4302cd2 100644 --- a/peewee_async/databases.py +++ b/peewee_async/databases.py @@ -116,14 +116,20 @@ async def _aio_atomic(self, use_savepoint: bool = False) -> AsyncIterator[None]: async with self.aio_connection() as connection: _connection_context = connection_context.get() assert _connection_context is not None - if _connection_context.transaction_is_opened and not use_savepoint: + + _is_root = not _connection_context.transaction_is_opened + _is_nested = _connection_context.transaction_is_opened + + if _is_nested and not use_savepoint: raise peewee.OperationalError("Transaction already opened") try: - async with Transaction(connection, is_savepoint=_connection_context.transaction_is_opened): - _connection_context.transaction_is_opened = True + async with Transaction(connection, is_savepoint=_is_nested): + if _is_root: + _connection_context.transaction_is_opened = True yield finally: - _connection_context.transaction_is_opened = False + if _is_root: + _connection_context.transaction_is_opened = False def set_allow_sync(self, value: bool) -> None: """Allow or forbid sync queries for the database. See also diff --git a/tests/test_transaction.py b/tests/test_transaction.py index b295362..2aaa437 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -151,14 +151,20 @@ async def test_nested_transaction__error(method1: str, method2: str, db: AioData @dbs_all -async def test_savepoint_success(db: AioDatabase) -> None: +async def test_savepoints_success(db: AioDatabase) -> None: async with db.aio_atomic(): await TestModel.aio_create(text='FOO') async with db.aio_atomic(): await TestModel.update(text="BAR").aio_execute() - assert await TestModel.aio_get_or_none(text="BAR") is not None + async with db.aio_atomic(): + await TestModel.update(text="BAZ").aio_execute() + + async with db.aio_atomic(): + await TestModel.update(text="QUX").aio_execute() + + assert await TestModel.aio_get_or_none(text="QUX") is not None assert db.pool_backend.has_acquired_connections() is False