Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ var/
*.egg-info/
.installed.cfg
*.egg
.venv/

# PyInstaller
# Usually these files are written by a python script from a template
Expand Down Expand Up @@ -96,3 +97,6 @@ ENV/
.vscode/*
.history
*.code-workspace

# Other
.tool-versions
8 changes: 3 additions & 5 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@ dist: xenial
sudo: true

python:
- "3.5"
- "3.6"
- "3.7"
- "3.8"
- "3.9"

env:
- ASYNCPG_VERSION=0.15.0
- ASYNCPG_VERSION=0.16.0
- ASYNCPG_VERSION=0.17.0
- ASYNCPG_VERSION=0.22.0

# command to install dependencies
install:
Expand Down
8 changes: 4 additions & 4 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
FROM python:3.7-alpine
FROM python:3.8-alpine

RUN apk update && \
apk add \
gcc \
musl-dev \
postgresql-dev
gcc \
musl-dev \
postgresql-dev

ADD dev-requirements.txt /repo/dev-requirements.txt
RUN pip install -r /repo/dev-requirements.txt
Expand Down
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ for columns is no longer possible and you need to access columns using exact nam

3. 0.18.0 Removed the `insert` method. We found this method was just confusing, and useless as SqlAlchemy can do it for you by defining your table with a primary key.

4. 0.27.0 Now only compatible with version 0.22.0 and greater of asyncpg.

## sqlalchemy ORM

Expand Down Expand Up @@ -68,3 +69,7 @@ aiopg.sa: 9.541276566000306
asyncpsa: 6.747777451004367
```
So, seems like its still faster using asyncpg, or in otherwords, this library doesnt add any overhead that is not in aiopg.sa.

## Versioning

This software follows [Semantic Versioning](http://semver.org/).
11 changes: 8 additions & 3 deletions asyncpgsa/connection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from asyncpg import connection
from sqlalchemy import func
from sqlalchemy.dialects.postgresql import pypostgresql
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.dml import Insert as InsertObject, Update as UpdateObject
Expand Down Expand Up @@ -45,7 +46,9 @@ def _execute_default_attr(query, param, attr_name):
for col in query.table.columns:
attr = getattr(col, attr_name)
if attr and param.get(col.name) is None:
if attr.is_scalar:
if attr.is_sequence:
param[col.name] = func.nextval(attr.name)
elif attr.is_scalar:
param[col.name] = attr.arg
elif attr.is_callable:
param[col.name] = attr.arg({})
Expand Down Expand Up @@ -85,11 +88,13 @@ def __init__(self, *args, dialect=None, **kwargs):
super().__init__(*args, **kwargs)
self._dialect = dialect or _dialect

def _execute(self, query, args, limit, timeout, return_status=False):
def _execute(self, query, args, limit, timeout, return_status=False, record_class=None, ignore_custom_codec=False):
query, compiled_args = compile_query(query, dialect=self._dialect)
args = compiled_args or args
return super()._execute(query, args, limit, timeout,
return_status=return_status)
return_status=return_status,
record_class=record_class,
ignore_custom_codec=ignore_custom_codec)

async def execute(self, script, *args, **kwargs) -> str:
script, params = compile_query(script, dialect=self._dialect)
Expand Down
11 changes: 8 additions & 3 deletions asyncpgsa/testing/mockconnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@


def __subclasscheck__(cls, subclass):
if subclass == MockConnection:
return True
return old_subclass_check(cls, subclass)
if subclass == MockConnection:
return True
return old_subclass_check(cls, subclass)

old_subclass_check = ConnectionMeta.__subclasscheck__
ConnectionMeta.__subclasscheck__ = __subclasscheck__
Expand All @@ -33,6 +33,11 @@ def results(self, result):
global results
results = result

def set_database_results(self, *dbresults):
self.results = Queue()
for result in dbresults:
self.results.put_nowait(result)

async def general_query(self, query, *args, **kwargs):
completed_queries.append((query, *args, kwargs))
return results.get_nowait()
Expand Down
4 changes: 1 addition & 3 deletions asyncpgsa/testing/mockpgsingleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self):
self.__pool = MockSAPool(connection=self.connection)

def get_completed_queries(self):
return self.connection.get_completed_queries
return self.connection.completed_queries

def set_database_results(self, *results):
self.connection.results = Queue() # reset queue
Expand Down Expand Up @@ -57,5 +57,3 @@ async def __aenter__(self):

async def __aexit__(self, exc_type, exc_val, exc_tb):
pass


6 changes: 3 additions & 3 deletions asyncpgsa/testing/mockpool.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from asyncpgsa import compile_query
from asyncpg import protocol
from asyncpg.pool import Pool

from .mockconnection import MockConnection
Expand All @@ -14,7 +15,8 @@ def __init__(self, connection=None):
setup=None,
loop=None,
init=None,
connection_class=MockConnection)
connection_class=MockConnection,
record_class=protocol.Record)

self.connection = connection
if not self.connection:
Expand Down Expand Up @@ -57,5 +59,3 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):

async def close(self):
pass


3 changes: 0 additions & 3 deletions asyncpgsa/testing/mockpreparedstmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,3 @@ async def __anext__(self):
return next(self.iterator)
except StopIteration:
raise StopAsyncIteration



6 changes: 5 additions & 1 deletion asyncpgsa/transactionmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@ async def __aenter__(self):
self.acquire_context = self.pool.acquire(timeout=self.timeout)
con = await self.acquire_context.__aenter__()
self.transaction = con.transaction(**self.trans_kwargs)
await self.transaction.__aenter__()
try:
await self.transaction.__aenter__()
except Exception:
await asyncio.shield(self.acquire_context.__aexit__())
raise
return con

async def __aexit__(self, exc_type, exc_val, exc_tb):
Expand Down
2 changes: 1 addition & 1 deletion asyncpgsa/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.25.2'
__version__ = '0.27.1'
2 changes: 1 addition & 1 deletion dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ asyncpg
pytest
pytest-asyncio
sqlalchemy
psycopg2
psycopg2-binary
2 changes: 1 addition & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ services:
- DB_PASS=password
- PYTHONASYNCIODEBUG=1
postgres:
image: postgres:9
image: postgres:13
hostname: postgres
environment:
- POSTGRES_DB=postgres
Expand Down
12 changes: 10 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,21 @@
name='asyncpgsa',
version=version['__version__'],
install_requires=[
'asyncpg',
'asyncpg>=0.22.0',
'sqlalchemy',
],
packages=['asyncpgsa', 'asyncpgsa.testing'],
url='https://github.com/canopytax/asyncpgsa',
license='Apache 2.0',
author='nhumrich',
author_email='nick.humrich@canopytax.com',
description='sqlalchemy support for asyncpg'
description='sqlalchemy support for asyncpg',
classifiers=[
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: Implementation :: CPython",
],
)
10 changes: 9 additions & 1 deletion tests/test_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from uuid import uuid4, UUID
from datetime import date, datetime, timedelta
from asyncpgsa import connection
from sqlalchemy import Table, Column, MetaData, types
from sqlalchemy import Table, Column, MetaData, Sequence, types
from sqlalchemy.dialects.postgresql import UUID as PG_UUID

metadata = MetaData()
Expand Down Expand Up @@ -32,6 +32,7 @@ class MyIntEnum(enum.IntEnum):
users = Table(
'users', metadata,
Column('id', PG_UUID, unique=True, default=uuid4),
Column('serial', types.Integer, Sequence("serial_seq")),
Column('name', types.String(60), nullable=False,
default=name_default),
Column('t_list', types.ARRAY(types.String(60)), nullable=False,
Expand All @@ -58,6 +59,9 @@ class MyIntEnum(enum.IntEnum):
def test_insert_query_defaults():
query = users.insert()
new_query, new_params = connection.compile_query(query)
serial_default = query.parameters.get('serial')
assert serial_default.name == 'nextval'
assert serial_default.clause_expr.element.clauses[0].value == 'serial_seq'
assert query.parameters.get('name') == name_default
assert query.parameters.get('t_list') == t_list_default
assert query.parameters.get('t_enum') == t_enum_default
Expand All @@ -74,6 +78,7 @@ def test_insert_query_defaults_override():
query = users.insert()
query = query.values(
name='username',
serial=4444,
t_list=['l1', 'l2'],
t_enum=MyEnum.ITEM_1,
t_int_enum=MyIntEnum.ITEM_2,
Expand All @@ -85,6 +90,7 @@ def test_insert_query_defaults_override():
)
new_query, new_params = connection.compile_query(query)
assert query.parameters.get('version')
assert query.parameters.get('serial') == 4444
assert query.parameters.get('name') == 'username'
assert query.parameters.get('t_list') == ['l1', 'l2']
assert query.parameters.get('t_enum') == MyEnum.ITEM_1
Expand All @@ -101,6 +107,7 @@ def test_update_query():
query = users.update().where(users.c.name == 'default')
query = query.values(
name='newname',
serial=5555,
t_list=['l3', 'l4'],
t_enum=MyEnum.ITEM_1,
t_int_enum=MyIntEnum.ITEM_2,
Expand All @@ -112,6 +119,7 @@ def test_update_query():
)
new_query, new_params = connection.compile_query(query)
assert query.parameters.get('version')
assert query.parameters.get('serial') == 5555
assert query.parameters.get('name') == 'newname'
assert query.parameters.get('t_list') == ['l3', 'l4']
assert query.parameters.get('t_enum') == MyEnum.ITEM_1
Expand Down
10 changes: 8 additions & 2 deletions tests/test_querying.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from uuid import uuid4
from datetime import datetime, timedelta

from sqlalchemy import Table, Column, MetaData, types
from sqlalchemy import Table, Column, MetaData, types, Sequence
from sqlalchemy.dialects.postgresql import UUID as PG_UUID
from sqlalchemy.engine import create_engine

Expand Down Expand Up @@ -94,6 +94,7 @@ def test_querying_table(metadata):
return Table(
'test_querying_table_' + worker_id, metadata,
Column('id', types.Integer, autoincrement=True, primary_key=True),
Column('serial', types.Integer, Sequence("serial_seq")),
Column('t_string', types.String(60), onupdate='updated'),
Column('t_list', types.ARRAY(types.String(60))),
Column('t_enum', types.Enum(MyEnum)),
Expand Down Expand Up @@ -138,7 +139,12 @@ async def test_fetch_list(test_querying_table, connection):

for item, sample_item in zip(data, SAMPLE_DATA):
for key in sample_item.keys():
assert item[key] == sample_item[key]
if key == 'serial':
# Same increment as `id` column
expected = item['id']
else:
expected = sample_item[key]
assert item[key] == expected


async def test_bound_parameters(test_querying_table, connection):
Expand Down