Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions peewee_async/databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,20 @@ async def _aio_atomic(self, use_savepoint: bool = False) -> AsyncIterator[None]:
async with self.aio_connection() as connection:
_connection_context = connection_context.get()
assert _connection_context is not None
if _connection_context.transaction_is_opened and not use_savepoint:

_is_root = not _connection_context.transaction_is_opened
_is_nested = _connection_context.transaction_is_opened

if _is_nested and not use_savepoint:
raise peewee.OperationalError("Transaction already opened")
try:
async with Transaction(connection, is_savepoint=_connection_context.transaction_is_opened):
_connection_context.transaction_is_opened = True
async with Transaction(connection, is_savepoint=_is_nested):
if _is_root:
_connection_context.transaction_is_opened = True
yield
finally:
_connection_context.transaction_is_opened = False
if _is_root:
_connection_context.transaction_is_opened = False

def set_allow_sync(self, value: bool) -> None:
"""Allow or forbid sync queries for the database. See also
Expand Down
10 changes: 8 additions & 2 deletions tests/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,20 @@ async def test_nested_transaction__error(method1: str, method2: str, db: AioData


@dbs_all
async def test_savepoint_success(db: AioDatabase) -> None:
async def test_savepoints_success(db: AioDatabase) -> None:
async with db.aio_atomic():
await TestModel.aio_create(text='FOO')

async with db.aio_atomic():
await TestModel.update(text="BAR").aio_execute()

assert await TestModel.aio_get_or_none(text="BAR") is not None
async with db.aio_atomic():
await TestModel.update(text="BAZ").aio_execute()

async with db.aio_atomic():
await TestModel.update(text="QUX").aio_execute()

assert await TestModel.aio_get_or_none(text="QUX") is not None
assert db.pool_backend.has_acquired_connections() is False


Expand Down