diff --git a/.gitignore b/.gitignore index 5fdaf59..7a05668 100644 --- a/.gitignore +++ b/.gitignore @@ -23,6 +23,7 @@ var/ *.egg-info/ .installed.cfg *.egg +.venv/ # PyInstaller # Usually these files are written by a python script from a template @@ -96,3 +97,6 @@ ENV/ .vscode/* .history *.code-workspace + +# Other +.tool-versions diff --git a/.travis.yml b/.travis.yml index 32cc9d0..3111c26 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,14 +7,12 @@ dist: xenial sudo: true python: - - "3.5" - - "3.6" - "3.7" + - "3.8" + - "3.9" env: - - ASYNCPG_VERSION=0.15.0 - - ASYNCPG_VERSION=0.16.0 - - ASYNCPG_VERSION=0.17.0 + - ASYNCPG_VERSION=0.22.0 # command to install dependencies install: diff --git a/Dockerfile b/Dockerfile index 1a4b0df..3e626c4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,10 +1,10 @@ -FROM python:3.7-alpine +FROM python:3.8-alpine RUN apk update && \ apk add \ - gcc \ - musl-dev \ - postgresql-dev + gcc \ + musl-dev \ + postgresql-dev ADD dev-requirements.txt /repo/dev-requirements.txt RUN pip install -r /repo/dev-requirements.txt diff --git a/README.md b/README.md index d2a8e92..3002601 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,7 @@ for columns is no longer possible and you need to access columns using exact nam 3. 0.18.0 Removed the `insert` method. We found this method was just confusing, and useless as SqlAlchemy can do it for you by defining your table with a primary key. +4. 0.27.0 Now only compatible with version 0.22.0 and greater of asyncpg. ## sqlalchemy ORM @@ -68,3 +69,7 @@ aiopg.sa: 9.541276566000306 asyncpsa: 6.747777451004367 ``` So, seems like its still faster using asyncpg, or in otherwords, this library doesnt add any overhead that is not in aiopg.sa. + +## Versioning + +This software follows [Semantic Versioning](http://semver.org/). diff --git a/asyncpgsa/connection.py b/asyncpgsa/connection.py index baf0d35..f3ae0a3 100644 --- a/asyncpgsa/connection.py +++ b/asyncpgsa/connection.py @@ -1,4 +1,5 @@ from asyncpg import connection +from sqlalchemy import func from sqlalchemy.dialects.postgresql import pypostgresql from sqlalchemy.sql import ClauseElement from sqlalchemy.sql.dml import Insert as InsertObject, Update as UpdateObject @@ -45,7 +46,9 @@ def _execute_default_attr(query, param, attr_name): for col in query.table.columns: attr = getattr(col, attr_name) if attr and param.get(col.name) is None: - if attr.is_scalar: + if attr.is_sequence: + param[col.name] = func.nextval(attr.name) + elif attr.is_scalar: param[col.name] = attr.arg elif attr.is_callable: param[col.name] = attr.arg({}) @@ -85,11 +88,13 @@ def __init__(self, *args, dialect=None, **kwargs): super().__init__(*args, **kwargs) self._dialect = dialect or _dialect - def _execute(self, query, args, limit, timeout, return_status=False): + def _execute(self, query, args, limit, timeout, return_status=False, record_class=None, ignore_custom_codec=False): query, compiled_args = compile_query(query, dialect=self._dialect) args = compiled_args or args return super()._execute(query, args, limit, timeout, - return_status=return_status) + return_status=return_status, + record_class=record_class, + ignore_custom_codec=ignore_custom_codec) async def execute(self, script, *args, **kwargs) -> str: script, params = compile_query(script, dialect=self._dialect) diff --git a/asyncpgsa/testing/mockconnection.py b/asyncpgsa/testing/mockconnection.py index 4e07bff..cf234e3 100644 --- a/asyncpgsa/testing/mockconnection.py +++ b/asyncpgsa/testing/mockconnection.py @@ -9,9 +9,9 @@ def __subclasscheck__(cls, subclass): - if subclass == MockConnection: - return True - return old_subclass_check(cls, subclass) + if subclass == MockConnection: + return True + return old_subclass_check(cls, subclass) old_subclass_check = ConnectionMeta.__subclasscheck__ ConnectionMeta.__subclasscheck__ = __subclasscheck__ @@ -33,6 +33,11 @@ def results(self, result): global results results = result + def set_database_results(self, *dbresults): + self.results = Queue() + for result in dbresults: + self.results.put_nowait(result) + async def general_query(self, query, *args, **kwargs): completed_queries.append((query, *args, kwargs)) return results.get_nowait() diff --git a/asyncpgsa/testing/mockpgsingleton.py b/asyncpgsa/testing/mockpgsingleton.py index 31d463d..a02dd24 100644 --- a/asyncpgsa/testing/mockpgsingleton.py +++ b/asyncpgsa/testing/mockpgsingleton.py @@ -16,7 +16,7 @@ def __init__(self): self.__pool = MockSAPool(connection=self.connection) def get_completed_queries(self): - return self.connection.get_completed_queries + return self.connection.completed_queries def set_database_results(self, *results): self.connection.results = Queue() # reset queue @@ -57,5 +57,3 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_val, exc_tb): pass - - diff --git a/asyncpgsa/testing/mockpool.py b/asyncpgsa/testing/mockpool.py index b2c2b33..5f9d1d0 100644 --- a/asyncpgsa/testing/mockpool.py +++ b/asyncpgsa/testing/mockpool.py @@ -1,4 +1,5 @@ from asyncpgsa import compile_query +from asyncpg import protocol from asyncpg.pool import Pool from .mockconnection import MockConnection @@ -14,7 +15,8 @@ def __init__(self, connection=None): setup=None, loop=None, init=None, - connection_class=MockConnection) + connection_class=MockConnection, + record_class=protocol.Record) self.connection = connection if not self.connection: @@ -57,5 +59,3 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): async def close(self): pass - - diff --git a/asyncpgsa/testing/mockpreparedstmt.py b/asyncpgsa/testing/mockpreparedstmt.py index 08bcdff..8eda5e8 100644 --- a/asyncpgsa/testing/mockpreparedstmt.py +++ b/asyncpgsa/testing/mockpreparedstmt.py @@ -26,6 +26,3 @@ async def __anext__(self): return next(self.iterator) except StopIteration: raise StopAsyncIteration - - - diff --git a/asyncpgsa/transactionmanager.py b/asyncpgsa/transactionmanager.py index 0975e6e..4ec3718 100644 --- a/asyncpgsa/transactionmanager.py +++ b/asyncpgsa/transactionmanager.py @@ -32,7 +32,11 @@ async def __aenter__(self): self.acquire_context = self.pool.acquire(timeout=self.timeout) con = await self.acquire_context.__aenter__() self.transaction = con.transaction(**self.trans_kwargs) - await self.transaction.__aenter__() + try: + await self.transaction.__aenter__() + except Exception: + await asyncio.shield(self.acquire_context.__aexit__()) + raise return con async def __aexit__(self, exc_type, exc_val, exc_tb): diff --git a/asyncpgsa/version.py b/asyncpgsa/version.py index a636f70..07f3dd7 100644 --- a/asyncpgsa/version.py +++ b/asyncpgsa/version.py @@ -1 +1 @@ -__version__ = '0.25.2' +__version__ = '0.27.1' diff --git a/dev-requirements.txt b/dev-requirements.txt index af34394..9b683b4 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -2,4 +2,4 @@ asyncpg pytest pytest-asyncio sqlalchemy -psycopg2 +psycopg2-binary diff --git a/docker-compose.yml b/docker-compose.yml index 5fca230..eb0b02a 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -18,7 +18,7 @@ services: - DB_PASS=password - PYTHONASYNCIODEBUG=1 postgres: - image: postgres:9 + image: postgres:13 hostname: postgres environment: - POSTGRES_DB=postgres diff --git a/setup.py b/setup.py index 41cfe4f..095fb4c 100755 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ name='asyncpgsa', version=version['__version__'], install_requires=[ - 'asyncpg', + 'asyncpg>=0.22.0', 'sqlalchemy', ], packages=['asyncpgsa', 'asyncpgsa.testing'], @@ -18,5 +18,13 @@ license='Apache 2.0', author='nhumrich', author_email='nick.humrich@canopytax.com', - description='sqlalchemy support for asyncpg' + description='sqlalchemy support for asyncpg', + classifiers=[ + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: Implementation :: CPython", + ], ) diff --git a/tests/test_defaults.py b/tests/test_defaults.py index b05c2cf..8ca4f3b 100644 --- a/tests/test_defaults.py +++ b/tests/test_defaults.py @@ -2,7 +2,7 @@ from uuid import uuid4, UUID from datetime import date, datetime, timedelta from asyncpgsa import connection -from sqlalchemy import Table, Column, MetaData, types +from sqlalchemy import Table, Column, MetaData, Sequence, types from sqlalchemy.dialects.postgresql import UUID as PG_UUID metadata = MetaData() @@ -32,6 +32,7 @@ class MyIntEnum(enum.IntEnum): users = Table( 'users', metadata, Column('id', PG_UUID, unique=True, default=uuid4), + Column('serial', types.Integer, Sequence("serial_seq")), Column('name', types.String(60), nullable=False, default=name_default), Column('t_list', types.ARRAY(types.String(60)), nullable=False, @@ -58,6 +59,9 @@ class MyIntEnum(enum.IntEnum): def test_insert_query_defaults(): query = users.insert() new_query, new_params = connection.compile_query(query) + serial_default = query.parameters.get('serial') + assert serial_default.name == 'nextval' + assert serial_default.clause_expr.element.clauses[0].value == 'serial_seq' assert query.parameters.get('name') == name_default assert query.parameters.get('t_list') == t_list_default assert query.parameters.get('t_enum') == t_enum_default @@ -74,6 +78,7 @@ def test_insert_query_defaults_override(): query = users.insert() query = query.values( name='username', + serial=4444, t_list=['l1', 'l2'], t_enum=MyEnum.ITEM_1, t_int_enum=MyIntEnum.ITEM_2, @@ -85,6 +90,7 @@ def test_insert_query_defaults_override(): ) new_query, new_params = connection.compile_query(query) assert query.parameters.get('version') + assert query.parameters.get('serial') == 4444 assert query.parameters.get('name') == 'username' assert query.parameters.get('t_list') == ['l1', 'l2'] assert query.parameters.get('t_enum') == MyEnum.ITEM_1 @@ -101,6 +107,7 @@ def test_update_query(): query = users.update().where(users.c.name == 'default') query = query.values( name='newname', + serial=5555, t_list=['l3', 'l4'], t_enum=MyEnum.ITEM_1, t_int_enum=MyIntEnum.ITEM_2, @@ -112,6 +119,7 @@ def test_update_query(): ) new_query, new_params = connection.compile_query(query) assert query.parameters.get('version') + assert query.parameters.get('serial') == 5555 assert query.parameters.get('name') == 'newname' assert query.parameters.get('t_list') == ['l3', 'l4'] assert query.parameters.get('t_enum') == MyEnum.ITEM_1 diff --git a/tests/test_querying.py b/tests/test_querying.py index 268cfc3..49deb2b 100644 --- a/tests/test_querying.py +++ b/tests/test_querying.py @@ -5,7 +5,7 @@ from uuid import uuid4 from datetime import datetime, timedelta -from sqlalchemy import Table, Column, MetaData, types +from sqlalchemy import Table, Column, MetaData, types, Sequence from sqlalchemy.dialects.postgresql import UUID as PG_UUID from sqlalchemy.engine import create_engine @@ -94,6 +94,7 @@ def test_querying_table(metadata): return Table( 'test_querying_table_' + worker_id, metadata, Column('id', types.Integer, autoincrement=True, primary_key=True), + Column('serial', types.Integer, Sequence("serial_seq")), Column('t_string', types.String(60), onupdate='updated'), Column('t_list', types.ARRAY(types.String(60))), Column('t_enum', types.Enum(MyEnum)), @@ -138,7 +139,12 @@ async def test_fetch_list(test_querying_table, connection): for item, sample_item in zip(data, SAMPLE_DATA): for key in sample_item.keys(): - assert item[key] == sample_item[key] + if key == 'serial': + # Same increment as `id` column + expected = item['id'] + else: + expected = sample_item[key] + assert item[key] == expected async def test_bound_parameters(test_querying_table, connection):