From a912bd25f4a99f3929ba86b5523bcf388a0a6619 Mon Sep 17 00:00:00 2001 From: skvortsov_k Date: Fri, 22 Aug 2025 15:39:48 +0300 Subject: [PATCH 1/2] fix: capture transaction begin/end --- peewee_async/databases.py | 4 +++- tests/test_transaction.py | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/peewee_async/databases.py b/peewee_async/databases.py index d15d41b..83443f8 100644 --- a/peewee_async/databases.py +++ b/peewee_async/databases.py @@ -116,6 +116,7 @@ async def _aio_atomic(self, use_savepoint: bool = False) -> AsyncIterator[None]: async with self.aio_connection() as connection: _connection_context = connection_context.get() assert _connection_context is not None + begin_transaction = _connection_context.transaction_is_opened is False if _connection_context.transaction_is_opened and not use_savepoint: raise peewee.OperationalError("Transaction already opened") try: @@ -123,7 +124,8 @@ async def _aio_atomic(self, use_savepoint: bool = False) -> AsyncIterator[None]: _connection_context.transaction_is_opened = True yield finally: - _connection_context.transaction_is_opened = False + if begin_transaction: + _connection_context.transaction_is_opened = False def set_allow_sync(self, value: bool) -> None: """Allow or forbid sync queries for the database. See also diff --git a/tests/test_transaction.py b/tests/test_transaction.py index b295362..b964952 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -199,6 +199,24 @@ async def test_savepoint_manual_work(db: AioDatabase) -> None: assert db.pool_backend.has_acquired_connections() is False +@dbs_all +async def test_nested_savepoints_success(db: AioDatabase) -> None: + async with db.aio_atomic(): + await TestModel.aio_create(text='FOO') + + async with db.aio_atomic(): + await TestModel.update(text="BAR").aio_execute() + + async with db.aio_atomic(): + await TestModel.update(text="BAZ").aio_execute() + + async with db.aio_atomic(): + await TestModel.update(text="QUX").aio_execute() + + assert await TestModel.aio_get_or_none(text="QUX") is not None + assert db.pool_backend.has_acquired_connections() is False + + @transaction_methods @dbs_all async def test_acid_when_connetion_has_been_broken(transaction_method:str, db: AioDatabase) -> None: From e7068a3830a248532aa3be98c7be1e19069294c7 Mon Sep 17 00:00:00 2001 From: kalombo Date: Sat, 23 Aug 2025 15:33:14 +0500 Subject: [PATCH 2/2] small refactoring --- peewee_async/databases.py | 14 +++++++++----- tests/test_transaction.py | 28 ++++++++-------------------- 2 files changed, 17 insertions(+), 25 deletions(-) diff --git a/peewee_async/databases.py b/peewee_async/databases.py index 83443f8..4302cd2 100644 --- a/peewee_async/databases.py +++ b/peewee_async/databases.py @@ -116,15 +116,19 @@ async def _aio_atomic(self, use_savepoint: bool = False) -> AsyncIterator[None]: async with self.aio_connection() as connection: _connection_context = connection_context.get() assert _connection_context is not None - begin_transaction = _connection_context.transaction_is_opened is False - if _connection_context.transaction_is_opened and not use_savepoint: + + _is_root = not _connection_context.transaction_is_opened + _is_nested = _connection_context.transaction_is_opened + + if _is_nested and not use_savepoint: raise peewee.OperationalError("Transaction already opened") try: - async with Transaction(connection, is_savepoint=_connection_context.transaction_is_opened): - _connection_context.transaction_is_opened = True + async with Transaction(connection, is_savepoint=_is_nested): + if _is_root: + _connection_context.transaction_is_opened = True yield finally: - if begin_transaction: + if _is_root: _connection_context.transaction_is_opened = False def set_allow_sync(self, value: bool) -> None: diff --git a/tests/test_transaction.py b/tests/test_transaction.py index b964952..2aaa437 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -151,14 +151,20 @@ async def test_nested_transaction__error(method1: str, method2: str, db: AioData @dbs_all -async def test_savepoint_success(db: AioDatabase) -> None: +async def test_savepoints_success(db: AioDatabase) -> None: async with db.aio_atomic(): await TestModel.aio_create(text='FOO') async with db.aio_atomic(): await TestModel.update(text="BAR").aio_execute() - assert await TestModel.aio_get_or_none(text="BAR") is not None + async with db.aio_atomic(): + await TestModel.update(text="BAZ").aio_execute() + + async with db.aio_atomic(): + await TestModel.update(text="QUX").aio_execute() + + assert await TestModel.aio_get_or_none(text="QUX") is not None assert db.pool_backend.has_acquired_connections() is False @@ -199,24 +205,6 @@ async def test_savepoint_manual_work(db: AioDatabase) -> None: assert db.pool_backend.has_acquired_connections() is False -@dbs_all -async def test_nested_savepoints_success(db: AioDatabase) -> None: - async with db.aio_atomic(): - await TestModel.aio_create(text='FOO') - - async with db.aio_atomic(): - await TestModel.update(text="BAR").aio_execute() - - async with db.aio_atomic(): - await TestModel.update(text="BAZ").aio_execute() - - async with db.aio_atomic(): - await TestModel.update(text="QUX").aio_execute() - - assert await TestModel.aio_get_or_none(text="QUX") is not None - assert db.pool_backend.has_acquired_connections() is False - - @transaction_methods @dbs_all async def test_acid_when_connetion_has_been_broken(transaction_method:str, db: AioDatabase) -> None: