From b4e2cf0eca829feb929275c9a4a5a7c5be14033f Mon Sep 17 00:00:00 2001 From: Aryeh Leib Taurog Date: Mon, 8 Apr 2024 19:33:40 +0300 Subject: [PATCH 01/24] docs: update type annotation fixes #22 --- pgcopy/copy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgcopy/copy.py b/pgcopy/copy.py index 064d198..7b81e7e 100644 --- a/pgcopy/copy.py +++ b/pgcopy/copy.py @@ -242,7 +242,7 @@ class CopyManager(object): :type table: str :param cols: columns in the table into which to copy data - :type cols: list of str + :type cols: iterable of str :raises ValueError: if the table or columns do not exist. """ From bc0b2b5b212fec5e4996f3cbd66f40543e5d53e8 Mon Sep 17 00:00:00 2001 From: Aryeh Leib Taurog Date: Sun, 5 May 2024 17:20:37 +0300 Subject: [PATCH 02/24] refactor: wrap adaptor in backend interface --- pgcopy/backend.py | 21 +++++++++++++++++++++ pgcopy/copy.py | 13 +++++-------- pgcopy/inspect.py | 6 +++--- pgcopy/util.py | 4 +--- 4 files changed, 30 insertions(+), 14 deletions(-) create mode 100644 pgcopy/backend.py diff --git a/pgcopy/backend.py b/pgcopy/backend.py new file mode 100644 index 0000000..f662d33 --- /dev/null +++ b/pgcopy/backend.py @@ -0,0 +1,21 @@ +"psycopg backends" +import importlib + + +class Psycopg2Backend: + def __init__(self, conn): + self.conn = conn + self.adaptor = importlib.import_module("psycopg2") + self.adaptor.extras = importlib.import_module("psycopg2.extras") + + def get_encoding(self): + encodings = self.adaptor.extensions.encodings + return encodings[self.conn.encoding] + + def copystream(self, sql, datastream): + cursor = self.conn.cursor() + cursor.copy_expert(sql, datastream) + + def namedtuple_cursor(self): + factory = self.adaptor.extras.NamedTupleCursor + return self.conn.cursor(cursor_factory=factory) diff --git a/pgcopy/copy.py b/pgcopy/copy.py index 7b81e7e..e785203 100644 --- a/pgcopy/copy.py +++ b/pgcopy/copy.py @@ -10,9 +10,7 @@ except ImportError: pass -from psycopg2.extensions import encodings - -from . import errors, inspect, util +from . import backend, errors, inspect, util from .thread import RaisingThread __all__ = ["CopyManager"] @@ -254,7 +252,7 @@ def __init__(self, conn, table, cols): **type_formatters, **self.type_formatters, } - self.conn = conn + self.backend = backend.Psycopg2Backend(conn) if "." in table: self.schema, self.table = table.split(".", 1) else: @@ -264,8 +262,8 @@ def __init__(self, conn, table, cols): def compile(self): self.formatters = [] - type_dict = inspect.get_types(self.conn, self.schema, self.table) - encoding = encodings[self.conn.encoding] + type_dict = inspect.get_types(self.backend, self.schema, self.table) + encoding = self.backend.get_encoding() for column in self.cols: att = type_dict.get(column) if att is None: @@ -350,9 +348,8 @@ def copystream(self, datastream): columns = '", "'.join(self.cols) cmd = 'COPY "{0}"."{1}" ("{2}") FROM STDIN WITH BINARY' sql = cmd.format(self.schema, self.table, columns) - cursor = self.conn.cursor() try: - cursor.copy_expert(sql, datastream) + self.backend.copystream(sql, datastream) except Exception as e: templ = "error doing binary copy into {0}.{1}:\n{2}" e.message = templ.format(self.schema, self.table, e) diff --git a/pgcopy/inspect.py b/pgcopy/inspect.py index eca2195..c57c1fe 100644 --- a/pgcopy/inspect.py +++ b/pgcopy/inspect.py @@ -1,7 +1,7 @@ -from psycopg2.extras import NamedTupleCursor +"inspect column types" -def get_types(conn, schema, table): +def get_types(backend, schema, table): # for arrays: # typname has '_' prefix # attndims > 0 @@ -24,6 +24,6 @@ def get_types(conn, schema, table): WHERE n.nspname = %s and relname = %s and attnum > 0 ORDER BY c.relname, a.attnum; """ - cursor = conn.cursor(cursor_factory=NamedTupleCursor) + cursor = backend.namedtuple_cursor() cursor.execute(query, (schema, table)) return {r.attname: r for r in cursor} diff --git a/pgcopy/util.py b/pgcopy/util.py index ff29566..3191a6f 100644 --- a/pgcopy/util.py +++ b/pgcopy/util.py @@ -3,7 +3,6 @@ import string from datetime import datetime, time -from psycopg2 import sql from pytz import UTC @@ -34,14 +33,13 @@ def array_iter(arr): def get_schema(conn, table): cur = conn.cursor() - quoted_table = sql.Identifier(table).as_string(cur) query = """ SELECT n.nspname, c.relname FROM pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace WHERE c.oid = %s::regclass """ - cur.execute(query, (quoted_table,)) + cur.execute(query, (f'"{table}"',)) return cur.fetchone()[0] From 5515f3620a766f2f071188d9be2e4dddeab2bcb5 Mon Sep 17 00:00:00 2001 From: Aryeh Leib Taurog Date: Mon, 6 May 2024 00:02:41 +0300 Subject: [PATCH 03/24] refactor: move plumbing to backend create context managers exposing interface similar to psycopg3 --- pgcopy/backend.py | 55 +++++++++++++++++++++++++++++++++++++++++++---- pgcopy/copy.py | 43 ++++++++++++++++-------------------- 2 files changed, 70 insertions(+), 28 deletions(-) diff --git a/pgcopy/backend.py b/pgcopy/backend.py index f662d33..46a5bd9 100644 --- a/pgcopy/backend.py +++ b/pgcopy/backend.py @@ -1,5 +1,8 @@ "psycopg backends" import importlib +import os + +from .thread import RaisingThread class Psycopg2Backend: @@ -12,10 +15,54 @@ def get_encoding(self): encodings = self.adaptor.extensions.encodings return encodings[self.conn.encoding] - def copystream(self, sql, datastream): - cursor = self.conn.cursor() - cursor.copy_expert(sql, datastream) - def namedtuple_cursor(self): factory = self.adaptor.extras.NamedTupleCursor return self.conn.cursor(cursor_factory=factory) + + def copy(self, sql, fobject_factory): + return Psycopg2Copy(self.conn, sql, fobject_factory) + + def threading_copy(self, sql): + return Psycopg2ThreadingCopy(self.conn, sql) + + +class Psycopg2Copy: + def __init__(self, conn, sql, fobject_factory): + self.conn = conn + self.sql = sql + self.datastream = fobject_factory() + + def __enter__(self): + return self.datastream + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is None: + self.datastream.seek(0) + self.copystream() + self.datastream.close() + + def copystream(self): + cursor = self.conn.cursor() + cursor.copy_expert(self.sql, self.datastream) + + +class Psycopg2ThreadingCopy: + def __init__(self, conn, sql): + self.conn = conn + self.sql = sql + r_fd, w_fd = os.pipe() + self.rstream = os.fdopen(r_fd, "rb") + self.wstream = os.fdopen(w_fd, "wb") + + def __enter__(self): + self.copy_thread = RaisingThread(target=self.copystream) + self.copy_thread.start() + return self.wstream + + def __exit__(self, exc_type, exc_val, exc_tb): + self.wstream.close() + self.copy_thread.join() + + def copystream(self): + cursor = self.conn.cursor() + cursor.copy_expert(self.sql, self.rstream) diff --git a/pgcopy/copy.py b/pgcopy/copy.py index e785203..41cabe3 100644 --- a/pgcopy/copy.py +++ b/pgcopy/copy.py @@ -309,11 +309,13 @@ def copy(self, data, fobject_factory=tempfile.TemporaryFile): ``ValueError`` is raised if a null value is provided for a column with non-null constraint. """ - datastream = fobject_factory() - self.writestream(data, datastream) - datastream.seek(0) - self.copystream(datastream) - datastream.close() + try: + with self.backend.copy(self.sql(), fobject_factory) as datastream: + self.writestream(data, datastream) + except Exception as e: + templ = "error doing binary copy into {0}.{1}:\n{2}" + e.message = templ.format(self.schema, self.table, e) + raise e def threading_copy(self, data): """ @@ -322,14 +324,18 @@ def threading_copy(self, data): :param data: the data to be inserted :type data: iterable of iterables """ - r_fd, w_fd = os.pipe() - rstream = os.fdopen(r_fd, "rb") - wstream = os.fdopen(w_fd, "wb") - copy_thread = RaisingThread(target=self.copystream, args=(rstream,)) - copy_thread.start() - self.writestream(data, wstream) - wstream.close() - copy_thread.join() + try: + with self.backend.threading_copy(self.sql()) as datastream: + self.writestream(data, datastream) + except Exception as e: + templ = "error doing binary copy into {0}.{1}:\n{2}" + e.message = templ.format(self.schema, self.table, e) + raise e + + def sql(self): + columns = '", "'.join(self.cols) + cmd = 'COPY "{0}"."{1}" ("{2}") FROM STDIN WITH BINARY' + return cmd.format(self.schema, self.table, columns) def writestream(self, data, datastream): datastream.write(BINCOPY_HEADER) @@ -343,14 +349,3 @@ def writestream(self, data, datastream): rdat.extend(d) datastream.write(struct.pack("".join(fmt), *rdat)) datastream.write(BINCOPY_TRAILER) - - def copystream(self, datastream): - columns = '", "'.join(self.cols) - cmd = 'COPY "{0}"."{1}" ("{2}") FROM STDIN WITH BINARY' - sql = cmd.format(self.schema, self.table, columns) - try: - self.backend.copystream(sql, datastream) - except Exception as e: - templ = "error doing binary copy into {0}.{1}:\n{2}" - e.message = templ.format(self.schema, self.table, e) - raise e From 60f60d9a7a09c08e9a3c23d2766c2da9f2cb5960 Mon Sep 17 00:00:00 2001 From: Aryeh Leib Taurog Date: Mon, 6 May 2024 00:08:02 +0300 Subject: [PATCH 04/24] refactor: move error handling to common function --- pgcopy/copy.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/pgcopy/copy.py b/pgcopy/copy.py index 41cabe3..2cde461 100644 --- a/pgcopy/copy.py +++ b/pgcopy/copy.py @@ -309,13 +309,7 @@ def copy(self, data, fobject_factory=tempfile.TemporaryFile): ``ValueError`` is raised if a null value is provided for a column with non-null constraint. """ - try: - with self.backend.copy(self.sql(), fobject_factory) as datastream: - self.writestream(data, datastream) - except Exception as e: - templ = "error doing binary copy into {0}.{1}:\n{2}" - e.message = templ.format(self.schema, self.table, e) - raise e + self._copy(data, self.backend.copy(self.sql(), fobject_factory)) def threading_copy(self, data): """ @@ -324,8 +318,11 @@ def threading_copy(self, data): :param data: the data to be inserted :type data: iterable of iterables """ + self._copy(data, self.backend.threading_copy(self.sql())) + + def _copy(self, data, copy): try: - with self.backend.threading_copy(self.sql()) as datastream: + with copy as datastream: self.writestream(data, datastream) except Exception as e: templ = "error doing binary copy into {0}.{1}:\n{2}" From 38c33cc0e193817f2260b7905317a073f8eb7c49 Mon Sep 17 00:00:00 2001 From: Aryeh Leib Taurog Date: Mon, 6 May 2024 11:43:24 +0300 Subject: [PATCH 05/24] test: use separate fixture for client encoding --- tests/conftest.py | 9 +++++++-- tests/test_datatypes.py | 4 +++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 422d4cd..bfc6356 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -87,10 +87,15 @@ def drop_db(): @pytest.fixture -def conn(request, db): +def client_encoding(request): + return getattr(request, "param", "UTF8") + + +@pytest.fixture +def conn(request, db, client_encoding): conn = connect() conn.autocommit = False - conn.set_client_encoding(getattr(request, "param", "UTF8")) + conn.set_client_encoding(client_encoding) inst = request.instance if isinstance(inst, TemporaryTable): for extension in inst.extensions: diff --git a/tests/test_datatypes.py b/tests/test_datatypes.py index 9326938..f5c04e0 100644 --- a/tests/test_datatypes.py +++ b/tests/test_datatypes.py @@ -65,7 +65,9 @@ class TestEncoding(TypeMixin): datatypes = ["varchar(12)"] data = [("database",), ("מוסד נתונים",)] - @pytest.mark.parametrize("conn", ["UTF8", "ISO_8859_8", "WIN1255"], indirect=True) + @pytest.mark.parametrize( + "client_encoding", ["UTF8", "ISO_8859_8", "WIN1255"], indirect=True + ) def test_type(self, conn, cursor, schema_table, data): psycopg2.extensions.register_type(psycopg2.extensions.UNICODE, cursor) super(TestEncoding, self).test_type(conn, cursor, schema_table, data) From 5965a07db3f6ee32f825714d3237f0e656426236 Mon Sep 17 00:00:00 2001 From: Aryeh Leib Taurog Date: Mon, 6 May 2024 15:59:19 +0300 Subject: [PATCH 06/24] test: remove direct psycopg2 references --- tests/adaptor.py | 22 ++++++++++++++++++++++ tests/conftest.py | 8 ++++---- tests/test_datatypes.py | 2 -- tests/test_replace.py | 5 ++--- tests/test_threading_copy.py | 3 +-- 5 files changed, 29 insertions(+), 11 deletions(-) create mode 100644 tests/adaptor.py diff --git a/tests/adaptor.py b/tests/adaptor.py new file mode 100644 index 0000000..793da80 --- /dev/null +++ b/tests/adaptor.py @@ -0,0 +1,22 @@ +import importlib +import sys + +import pytest + + +class Psycopg2: + def __init__(self, connection_params, client_encoding): + try: + psycopg2 = importlib.import_module("psycopg2") + extras = importlib.import_module("psycopg2.extras") + except: + pytest.skip("psycopg2 not available") + + self.conn = psycopg2.connect( + connection_factory=extras.LoggingConnection, + **connection_params, + ) + self.conn.initialize(sys.stderr) + self.conn.autocommit = False + self.conn.set_client_encoding(client_encoding) + self.errors = psycopg2.errors diff --git a/tests/conftest.py b/tests/conftest.py index bfc6356..6314b1e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,7 @@ import pytest from psycopg2.extras import LoggingConnection +from . import adaptor from .db import TemporaryTable @@ -91,11 +92,10 @@ def client_encoding(request): return getattr(request, "param", "UTF8") -@pytest.fixture +@pytest.fixture(params=[adaptor.Psycopg2]) def conn(request, db, client_encoding): - conn = connect() - conn.autocommit = False - conn.set_client_encoding(client_encoding) + psycopg2 = request.param(connection_params, client_encoding) + conn = psycopg2.conn inst = request.instance if isinstance(inst, TemporaryTable): for extension in inst.extensions: diff --git a/tests/test_datatypes.py b/tests/test_datatypes.py index f5c04e0..9129817 100644 --- a/tests/test_datatypes.py +++ b/tests/test_datatypes.py @@ -11,7 +11,6 @@ if sys.version_info < (3,): memoryview = buffer -import psycopg2.extensions from pgcopy import CopyManager, util from . import db @@ -69,7 +68,6 @@ class TestEncoding(TypeMixin): "client_encoding", ["UTF8", "ISO_8859_8", "WIN1255"], indirect=True ) def test_type(self, conn, cursor, schema_table, data): - psycopg2.extensions.register_type(psycopg2.extensions.UNICODE, cursor) super(TestEncoding, self).test_type(conn, cursor, schema_table, data) diff --git a/tests/test_replace.py b/tests/test_replace.py index be3eb52..9389d5f 100644 --- a/tests/test_replace.py +++ b/tests/test_replace.py @@ -1,6 +1,5 @@ import contextlib -import psycopg2 import pytest from pgcopy import Replace, util @@ -64,14 +63,14 @@ def test_replace_with_default(self, conn, cursor, schema_table): @contextlib.contextmanager -def replace_raises(conn, table, exc=psycopg2.IntegrityError): +def replace_raises(conn, table): """ Wrap Replace context manager and assert exception is thrown on context exit """ r = Replace(conn, table) yield r.__enter__() - with pytest.raises(exc): + with pytest.raises(conn.IntegrityError): r.__exit__(None, None, None) diff --git a/tests/test_threading_copy.py b/tests/test_threading_copy.py index db229df..ce0f365 100644 --- a/tests/test_threading_copy.py +++ b/tests/test_threading_copy.py @@ -1,6 +1,5 @@ import pytest from pgcopy import CopyManager -from psycopg2.errors import BadCopyFileFormat from . import test_datatypes @@ -21,7 +20,7 @@ def test_threading_copy(self, conn, cursor, schema_table, data): def test_threading_copy_error(self, conn, cursor): data = [{}] mgr = CopyManager(conn, self.table, self.cols) - with pytest.raises(BadCopyFileFormat): + with pytest.raises(conn.DataError): mgr.threading_copy(data) def test_threading_copy_generator(self, conn, cursor, schema_table, data): From ebcbf3b71474ff102c69119dd548d7bee49f4f10 Mon Sep 17 00:00:00 2001 From: Aryeh Leib Taurog Date: Mon, 6 May 2024 15:59:54 +0300 Subject: [PATCH 07/24] test: remove broken tox python2.7 test environment --- tox.ini | 1 - 1 file changed, 1 deletion(-) diff --git a/tox.ini b/tox.ini index 5cd5d9d..abec665 100644 --- a/tox.ini +++ b/tox.ini @@ -15,7 +15,6 @@ setenv = POSTGRES_HOST=localhost POSTGRES_USER=postgres POSTGRES_PASSWORD=postgres -[testenv:python2.7] [testenv:py310-pg12] docker = pg12 [testenv:py310-pg13] From 13498d6ac591b0334225ab8b86f8875c74aca621 Mon Sep 17 00:00:00 2001 From: Aryeh Leib Taurog Date: Mon, 6 May 2024 16:00:30 +0300 Subject: [PATCH 08/24] chore: remove explicit psycopg2 dependency --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 78cfb49..c05dfc5 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ def get_version(package_name, default="0.1"): license="MIT", url="https://pgcopy.readthedocs.io/en/latest/", packages=["pgcopy", "pgcopy.errors", "pgcopy.contrib"], - install_requires=["psycopg2", "pytz"], + install_requires=["pytz"], classifiers=[ "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.8", From d47e73c3c1149b1fba901bba6b6856683b29c9ed Mon Sep 17 00:00:00 2001 From: Aryeh Leib Taurog Date: Mon, 6 May 2024 18:27:58 +0300 Subject: [PATCH 09/24] feat: psycopg3 backend --- pgcopy/backend.py | 44 +++++++++++++++++++++++++++++++++++---- pgcopy/copy.py | 2 +- pgcopy/errors/__init__.py | 4 ++++ tests/adaptor.py | 15 ++++++++++++- tests/conftest.py | 2 +- tests/test_datatypes.py | 9 ++------ tests/test_errors.py | 11 ++++++++++ tox.ini | 2 ++ 8 files changed, 75 insertions(+), 14 deletions(-) diff --git a/pgcopy/backend.py b/pgcopy/backend.py index 46a5bd9..54269c9 100644 --- a/pgcopy/backend.py +++ b/pgcopy/backend.py @@ -1,10 +1,21 @@ "psycopg backends" +import contextlib import importlib import os +from .errors import UnsupportedConnectionError from .thread import RaisingThread +def for_connection(conn): + if hasattr(conn, "set_client_encoding") and hasattr(conn, "encoding"): + return Psycopg2Backend(conn) + if hasattr(conn, "execute"): + return Psycopg3Backend(conn) + message = f"{conn.__class__.__name__} is not a supported connection type" + raise UnsupportedConnectionError(message) + + class Psycopg2Backend: def __init__(self, conn): self.conn = conn @@ -42,8 +53,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.datastream.close() def copystream(self): - cursor = self.conn.cursor() - cursor.copy_expert(self.sql, self.datastream) + with self.conn.cursor() as cur: + cur.copy_expert(self.sql, self.datastream) class Psycopg2ThreadingCopy: @@ -64,5 +75,30 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.copy_thread.join() def copystream(self): - cursor = self.conn.cursor() - cursor.copy_expert(self.sql, self.rstream) + with self.conn.cursor() as cur: + cur.copy_expert(self.sql, self.rstream) + + +class Psycopg3Backend: + def __init__(self, conn): + self.conn = conn + self.adaptor = importlib.import_module("psycopg") + + def get_encoding(self): + return self.conn.info.encoding + + def namedtuple_cursor(self): + factory = self.adaptor.rows.namedtuple_row + return self.conn.cursor(row_factory=factory) + + @contextlib.contextmanager + def copy(self, sql, _): + with self.conn.cursor() as cur: + with cur.copy(sql) as copy: + yield copy + + @contextlib.contextmanager + def threading_copy(self, sql): + with self.conn.cursor() as cur: + with cur.copy(sql) as copy: + yield copy diff --git a/pgcopy/copy.py b/pgcopy/copy.py index 2cde461..df7ae1e 100644 --- a/pgcopy/copy.py +++ b/pgcopy/copy.py @@ -252,7 +252,7 @@ def __init__(self, conn, table, cols): **type_formatters, **self.type_formatters, } - self.backend = backend.Psycopg2Backend(conn) + self.backend = backend.for_connection(conn) if "." in table: self.schema, self.table = table.split(".", 1) else: diff --git a/pgcopy/errors/__init__.py b/pgcopy/errors/__init__.py index fc09d33..90ead89 100644 --- a/pgcopy/errors/__init__.py +++ b/pgcopy/errors/__init__.py @@ -9,3 +9,7 @@ from .py2 import raise_from else: from .py3 import raise_from + + +class UnsupportedConnectionError(TypeError): + "connection type not supported" diff --git a/tests/adaptor.py b/tests/adaptor.py index 793da80..6b89db5 100644 --- a/tests/adaptor.py +++ b/tests/adaptor.py @@ -9,7 +9,7 @@ def __init__(self, connection_params, client_encoding): try: psycopg2 = importlib.import_module("psycopg2") extras = importlib.import_module("psycopg2.extras") - except: + except ModuleNotFoundError: pytest.skip("psycopg2 not available") self.conn = psycopg2.connect( @@ -20,3 +20,16 @@ def __init__(self, connection_params, client_encoding): self.conn.autocommit = False self.conn.set_client_encoding(client_encoding) self.errors = psycopg2.errors + + +class Psycopg3: + def __init__(self, connection_params, client_encoding): + try: + psycopg3 = importlib.import_module("psycopg") + except ModuleNotFoundError: + pytest.skip("psycopg3 not available") + + self.conn = psycopg3.connect(**connection_params) + self.conn.autocommit = False + self.conn.execute(f"SET client_encoding='{client_encoding}'") + self.errors = psycopg3.errors diff --git a/tests/conftest.py b/tests/conftest.py index 6314b1e..c502fcf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -92,7 +92,7 @@ def client_encoding(request): return getattr(request, "param", "UTF8") -@pytest.fixture(params=[adaptor.Psycopg2]) +@pytest.fixture(params=[adaptor.Psycopg2, adaptor.Psycopg3]) def conn(request, db, client_encoding): psycopg2 = request.param(connection_params, client_encoding) conn = psycopg2.conn diff --git a/tests/test_datatypes.py b/tests/test_datatypes.py index 9129817..ad912c5 100644 --- a/tests/test_datatypes.py +++ b/tests/test_datatypes.py @@ -16,10 +16,6 @@ from . import db -def test_connection_encoding(conn): - assert conn.encoding == "UTF8" - - def test_db_encoding(conn): assert conn.info.parameter_status("server_encoding") == "UTF8" @@ -235,8 +231,7 @@ class TestBytea(TypeMixin): ] def cast(self, v): - assert isinstance(v, memoryview) - return bytes(v) + return bytes(v) if isinstance(v, memoryview) else v class TestTime(TypeMixin): @@ -289,7 +284,7 @@ class TestUUID(TypeMixin): ] def cast(self, v): - return uuid.UUID(v) + return uuid.UUID(v) if isinstance(v, str) else v class TestEnum(TypeMixin): diff --git a/tests/test_errors.py b/tests/test_errors.py index bbc9661..36638c4 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -1,3 +1,4 @@ +import pgcopy.errors import pytest from pgcopy import CopyManager @@ -40,3 +41,13 @@ def test_dropped_col(self, conn, cursor, schema): msg = '"{}" is not a column of table "{}"."{}"' with pytest.raises(ValueError, match=msg.format(col, schema, self.table)): CopyManager(conn, self.table, self.cols) + + +def test_unsupported_connection(): + message = "FakeConnection is not a supported connection type" + with pytest.raises(pgcopy.errors.UnsupportedConnectionError, match=message): + CopyManager(FakeConnection(), "fake_table", ["id", "name"]) + + +class FakeConnection: + encoding = "utf8" diff --git a/tox.ini b/tox.ini index abec665..9bb22d4 100644 --- a/tox.ini +++ b/tox.ini @@ -8,6 +8,7 @@ envlist = deps = pytest psycopg2~=2.9 + psycopg[binary] commands = pytest tests/ docker = pg16 setenv = @@ -32,6 +33,7 @@ deps = pytest pytest-cov psycopg2~=2.9 + psycopg[binary] commands = pytest --cov=pgcopy/ tests/ docker = pgvector From c22e62a52d5ac0eb93db57ff7de1f124465c51b1 Mon Sep 17 00:00:00 2001 From: Aryeh Leib Taurog Date: Mon, 13 May 2024 20:46:03 +0300 Subject: [PATCH 10/24] refactor: better backend detection --- pgcopy/backend.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pgcopy/backend.py b/pgcopy/backend.py index 54269c9..003bcea 100644 --- a/pgcopy/backend.py +++ b/pgcopy/backend.py @@ -8,9 +8,10 @@ def for_connection(conn): - if hasattr(conn, "set_client_encoding") and hasattr(conn, "encoding"): + sources = [cls.__module__.split(".")[0] for cls in conn.__class__.mro()] + if "psycopg2" in sources: return Psycopg2Backend(conn) - if hasattr(conn, "execute"): + if "psycopg" in sources: return Psycopg3Backend(conn) message = f"{conn.__class__.__name__} is not a supported connection type" raise UnsupportedConnectionError(message) From a5461926aa20890092a84f0e2f38c095f2ef6438 Mon Sep 17 00:00:00 2001 From: Aryeh Leib Taurog Date: Mon, 13 May 2024 20:47:54 +0300 Subject: [PATCH 11/24] refactor: simplify --- pgcopy/backend.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pgcopy/backend.py b/pgcopy/backend.py index 003bcea..662f2be 100644 --- a/pgcopy/backend.py +++ b/pgcopy/backend.py @@ -24,8 +24,7 @@ def __init__(self, conn): self.adaptor.extras = importlib.import_module("psycopg2.extras") def get_encoding(self): - encodings = self.adaptor.extensions.encodings - return encodings[self.conn.encoding] + return self.adaptor.extensions.encodings[self.conn.encoding] def namedtuple_cursor(self): factory = self.adaptor.extras.NamedTupleCursor From e2be2a6f86e1c1ce84ff17318f4db2299b3794a4 Mon Sep 17 00:00:00 2001 From: Aryeh Leib Taurog Date: Mon, 13 May 2024 20:58:47 +0300 Subject: [PATCH 12/24] test: test available backends but always use psycopg2 to set up the database --- tests/adaptor.py | 56 +++++++++++++++++++++++++++++++++-------------- tests/conftest.py | 25 +++++++++++++++------ 2 files changed, 57 insertions(+), 24 deletions(-) diff --git a/tests/adaptor.py b/tests/adaptor.py index 6b89db5..8a129ac 100644 --- a/tests/adaptor.py +++ b/tests/adaptor.py @@ -1,35 +1,57 @@ +import argparse import importlib import sys -import pytest +def available_adaptors(): + adaptors = [Psycopg2, Psycopg3] + return [a for a in adaptors if a.load()] -class Psycopg2: - def __init__(self, connection_params, client_encoding): + +class Adaptor: + module_names: list[str] + pgcode_attribute = "sqlstate" + m: argparse.Namespace + + @classmethod + def load(cls): try: - psycopg2 = importlib.import_module("psycopg2") - extras = importlib.import_module("psycopg2.extras") + modules = [importlib.import_module(m) for m in cls.module_names] + moddict = dict(zip(map(modname, cls.module_names), modules)) + cls.m = argparse.Namespace(**moddict) + return True except ModuleNotFoundError: - pytest.skip("psycopg2 not available") + return False + + @classmethod + def get_pgcode(cls, err): + return getattr(err, cls.pgcode_attribute, None) + + +def modname(modpath): + return modpath.rsplit(".", 1)[-1] + - self.conn = psycopg2.connect( - connection_factory=extras.LoggingConnection, +class Psycopg2(Adaptor): + module_names = ["psycopg2", "psycopg2.extras"] + pgcode_attribute = "pgcode" + + def __init__(self, connection_params, client_encoding): + self.conn = self.m.psycopg2.connect( + connection_factory=self.m.extras.LoggingConnection, **connection_params, ) self.conn.initialize(sys.stderr) self.conn.autocommit = False self.conn.set_client_encoding(client_encoding) - self.errors = psycopg2.errors + self.unsupported_type = self.m.psycopg2.errors.UndefinedObject -class Psycopg3: - def __init__(self, connection_params, client_encoding): - try: - psycopg3 = importlib.import_module("psycopg") - except ModuleNotFoundError: - pytest.skip("psycopg3 not available") +class Psycopg3(Adaptor): + module_names = ["psycopg"] - self.conn = psycopg3.connect(**connection_params) + def __init__(self, connection_params, client_encoding): + self.conn = self.m.psycopg.connect(**connection_params) self.conn.autocommit = False self.conn.execute(f"SET client_encoding='{client_encoding}'") - self.errors = psycopg3.errors + self.unsupported_type = self.m.psycopg.errors.UndefinedObject diff --git a/tests/conftest.py b/tests/conftest.py index c502fcf..adfbb31 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,7 @@ import pytest from psycopg2.extras import LoggingConnection -from . import adaptor +from .adaptor import available_adaptors from .db import TemporaryTable @@ -92,10 +92,9 @@ def client_encoding(request): return getattr(request, "param", "UTF8") -@pytest.fixture(params=[adaptor.Psycopg2, adaptor.Psycopg3]) -def conn(request, db, client_encoding): - psycopg2 = request.param(connection_params, client_encoding) - conn = psycopg2.conn +@pytest.fixture +def db_ext(request, db): + conn = connect() inst = request.instance if isinstance(inst, TemporaryTable): for extension in inst.extensions: @@ -109,12 +108,24 @@ def conn(request, db, client_encoding): psycopg2.errors.FeatureNotSupported, # postgres >= 15 ): conn.rollback() + + +@pytest.fixture(params=available_adaptors()) +def conn(request, db_ext, client_encoding): + adaptor = request.param(connection_params, client_encoding) + conn = adaptor.conn + inst = request.instance + if isinstance(inst, TemporaryTable): try: with conn.cursor() as cur: cur.execute(inst.create_sql(inst.tempschema)) - except psycopg2.errors.UndefinedObject as e: + except adaptor.unsupported_type as e: + pgcode = adaptor.get_pgcode(e) conn.rollback() - pytest.skip("Unsupported datatype") + if pgcode == "42704": + pytest.skip("Unsupported datatype") + else: + raise yield conn conn.rollback() conn.close() From e44bbdc3902e3396d83cd57f7c01e1e32cc48b2d Mon Sep 17 00:00:00 2001 From: Aryeh Leib Taurog Date: Mon, 13 May 2024 21:02:10 +0300 Subject: [PATCH 13/24] feat: add PyGreSQL backend --- pgcopy/backend.py | 94 ++++++++++++++++++++++++++++++++++-- pgcopy/copy.py | 17 ++++--- tests/adaptor.py | 30 +++++++++++- tests/conftest.py | 2 + tests/test_datatypes.py | 5 +- tests/test_threading_copy.py | 8 +++ 6 files changed, 143 insertions(+), 13 deletions(-) diff --git a/pgcopy/backend.py b/pgcopy/backend.py index 662f2be..edc962e 100644 --- a/pgcopy/backend.py +++ b/pgcopy/backend.py @@ -1,4 +1,5 @@ "psycopg backends" +import codecs import contextlib import importlib import os @@ -13,10 +14,20 @@ def for_connection(conn): return Psycopg2Backend(conn) if "psycopg" in sources: return Psycopg3Backend(conn) + if "pgdb" in sources: + return PyGreSQLBackend(conn) + if "pg8000" in sources: + return Pg8000Backend(conn) message = f"{conn.__class__.__name__} is not a supported connection type" raise UnsupportedConnectionError(message) +def copy_sql(schema, table, columns): + column_list = '", "'.join(columns) + cmd = 'COPY "{0}"."{1}" ("{2}") FROM STDIN WITH BINARY' + return cmd.format(schema, table, column_list) + + class Psycopg2Backend: def __init__(self, conn): self.conn = conn @@ -30,10 +41,12 @@ def namedtuple_cursor(self): factory = self.adaptor.extras.NamedTupleCursor return self.conn.cursor(cursor_factory=factory) - def copy(self, sql, fobject_factory): + def copy(self, schema, table, columns, fobject_factory): + sql = copy_sql(schema, table, columns) return Psycopg2Copy(self.conn, sql, fobject_factory) - def threading_copy(self, sql): + def threading_copy(self, schema, table, columns): + sql = copy_sql(schema, table, columns) return Psycopg2ThreadingCopy(self.conn, sql) @@ -92,13 +105,86 @@ def namedtuple_cursor(self): return self.conn.cursor(row_factory=factory) @contextlib.contextmanager - def copy(self, sql, _): + def copy(self, schema, table, columns, _): + sql = copy_sql(schema, table, columns) with self.conn.cursor() as cur: with cur.copy(sql) as copy: yield copy @contextlib.contextmanager - def threading_copy(self, sql): + def threading_copy(self, schema, table, columns): + sql = copy_sql(schema, table, columns) with self.conn.cursor() as cur: with cur.copy(sql) as copy: yield copy + + +class PyGreSQLBackend: + def __init__(self, conn): + self.conn = conn + + def get_encoding(self): + with self.conn.cursor() as cur: + cur.execute("SHOW client_encoding") + row = cur.fetchone() + return codecs.lookup(row.client_encoding).name + + def namedtuple_cursor(self): + return self.conn.cursor() + + def copy(self, schema, table, columns, fobject_factory): + return PyGreSQLCopy(self.conn, schema, table, columns, fobject_factory) + + +class PyGreSQLCopy: + def __init__(self, conn, schema, table, columns, fobject_factory): + self.conn = conn + self.table = f"{schema}.{table}" + self.columns = columns + self.datastream = fobject_factory() + + def __enter__(self): + return self.datastream + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is None: + self.datastream.seek(0) + self.copystream() + self.datastream.close() + + def copystream(self): + with self.conn.cursor() as cur: + cur.copy_from( + self.datastream, + self.table, + format="binary", + columns=self.columns, + ) + + +class PyGreSQLThreadingCopy: + def __init__(self, conn, schema, table, columns): + self.conn = conn + self.table = f"{schema}.{table}" + self.columns = columns + r_fd, w_fd = os.pipe() + self.rstream = os.fdopen(r_fd, "rb") + self.wstream = os.fdopen(w_fd, "wb") + + def __enter__(self): + self.copy_thread = RaisingThread(target=self.copystream) + self.copy_thread.start() + return self.wstream + + def __exit__(self, exc_type, exc_val, exc_tb): + self.wstream.close() + self.copy_thread.join() + + def copystream(self): + with self.conn.cursor() as cur: + cur.copy_from( + self.rstream, + self.table, + format="binary", + columns=self.columns, + ) diff --git a/pgcopy/copy.py b/pgcopy/copy.py index df7ae1e..be015ba 100644 --- a/pgcopy/copy.py +++ b/pgcopy/copy.py @@ -253,6 +253,10 @@ def __init__(self, conn, table, cols): **self.type_formatters, } self.backend = backend.for_connection(conn) + self.implements_threading_copy = hasattr(self.backend, "threading_copy") + if not self.implements_threading_copy: + self.threading_copy = self.copy + if "." in table: self.schema, self.table = table.split(".", 1) else: @@ -309,7 +313,9 @@ def copy(self, data, fobject_factory=tempfile.TemporaryFile): ``ValueError`` is raised if a null value is provided for a column with non-null constraint. """ - self._copy(data, self.backend.copy(self.sql(), fobject_factory)) + self._copy( + data, self.backend.copy(self.schema, self.table, self.cols, fobject_factory) + ) def threading_copy(self, data): """ @@ -318,7 +324,9 @@ def threading_copy(self, data): :param data: the data to be inserted :type data: iterable of iterables """ - self._copy(data, self.backend.threading_copy(self.sql())) + self._copy( + data, self.backend.threading_copy(self.schema, self.table, self.cols) + ) def _copy(self, data, copy): try: @@ -329,11 +337,6 @@ def _copy(self, data, copy): e.message = templ.format(self.schema, self.table, e) raise e - def sql(self): - columns = '", "'.join(self.cols) - cmd = 'COPY "{0}"."{1}" ("{2}") FROM STDIN WITH BINARY' - return cmd.format(self.schema, self.table, columns) - def writestream(self, data, datastream): datastream.write(BINCOPY_HEADER) count = len(self.cols) diff --git a/tests/adaptor.py b/tests/adaptor.py index 8a129ac..4b1ad29 100644 --- a/tests/adaptor.py +++ b/tests/adaptor.py @@ -1,10 +1,11 @@ import argparse +import codecs import importlib import sys def available_adaptors(): - adaptors = [Psycopg2, Psycopg3] + adaptors = [Psycopg2, Psycopg3, PyGreSQL] return [a for a in adaptors if a.load()] @@ -46,6 +47,10 @@ def __init__(self, connection_params, client_encoding): self.conn.set_client_encoding(client_encoding) self.unsupported_type = self.m.psycopg2.errors.UndefinedObject + @staticmethod + def supports_encoding(encoding): + return True + class Psycopg3(Adaptor): module_names = ["psycopg"] @@ -55,3 +60,26 @@ def __init__(self, connection_params, client_encoding): self.conn.autocommit = False self.conn.execute(f"SET client_encoding='{client_encoding}'") self.unsupported_type = self.m.psycopg.errors.UndefinedObject + + @staticmethod + def supports_encoding(encoding): + return True + + +class PyGreSQL(Adaptor): + module_names = ["pgdb"] + + def __init__(self, connection_params, client_encoding): + self.conn = self.m.pgdb.connect(**connection_params) + self.conn.autocommit = False + self.conn.execute(f"SET client_encoding='{client_encoding}'") + self.unsupported_type = self.conn.ProgrammingError + self.integrity_error = self.conn.IntegrityError + + @staticmethod + def supports_encoding(encoding): + try: + codecs.lookup(encoding) + return True + except LookupError: + return False diff --git a/tests/conftest.py b/tests/conftest.py index adfbb31..a2b5a57 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -112,6 +112,8 @@ def db_ext(request, db): @pytest.fixture(params=available_adaptors()) def conn(request, db_ext, client_encoding): + if not request.param.supports_encoding(client_encoding): + pytest.skip("Unsupported encoding for {request.param}") adaptor = request.param(connection_params, client_encoding) conn = adaptor.conn inst = request.instance diff --git a/tests/test_datatypes.py b/tests/test_datatypes.py index ad912c5..daeaeea 100644 --- a/tests/test_datatypes.py +++ b/tests/test_datatypes.py @@ -17,7 +17,10 @@ def test_db_encoding(conn): - assert conn.info.parameter_status("server_encoding") == "UTF8" + with conn.cursor() as cur: + cur.execute("SHOW server_encoding") + res = cur.fetchone() + assert res[0] == "UTF8" class TypeMixin(db.TemporaryTable): diff --git a/tests/test_threading_copy.py b/tests/test_threading_copy.py index ce0f365..6e3b4da 100644 --- a/tests/test_threading_copy.py +++ b/tests/test_threading_copy.py @@ -12,6 +12,8 @@ class TestThreadingCopy(test_datatypes.TypeMixin): def test_threading_copy(self, conn, cursor, schema_table, data): mgr = CopyManager(conn, self.table, self.cols) + if not mgr.implements_threading_copy: + pytest.skip("threading_copy not implemented") mgr.threading_copy(data) select_list = ",".join(self.cols) cursor.execute(self.select_sql(schema_table)) @@ -20,11 +22,15 @@ def test_threading_copy(self, conn, cursor, schema_table, data): def test_threading_copy_error(self, conn, cursor): data = [{}] mgr = CopyManager(conn, self.table, self.cols) + if not mgr.implements_threading_copy: + pytest.skip("threading_copy not implemented") with pytest.raises(conn.DataError): mgr.threading_copy(data) def test_threading_copy_generator(self, conn, cursor, schema_table, data): mgr = CopyManager(conn, self.table, self.cols) + if not mgr.implements_threading_copy: + pytest.skip("threading_copy not implemented") mgr.threading_copy(iter(data)) select_list = ",".join(self.cols) cursor.execute(self.select_sql(schema_table)) @@ -33,6 +39,8 @@ def test_threading_copy_generator(self, conn, cursor, schema_table, data): def test_threading_copy_empty_generator(self, conn, cursor, schema_table): data = [] mgr = CopyManager(conn, self.table, self.cols) + if not mgr.implements_threading_copy: + pytest.skip("threading_copy not implemented") mgr.threading_copy(iter(data)) select_list = ",".join(self.cols) cursor.execute(self.select_sql(schema_table)) From bc85ec14d3888e970b5833ff6be3306b46c37bc5 Mon Sep 17 00:00:00 2001 From: Aryeh Leib Taurog Date: Mon, 13 May 2024 21:16:00 +0300 Subject: [PATCH 14/24] feat: add pg8000 backend --- pgcopy/backend.py | 64 +++++++++++++++++++++++++++++++++++++++++ tests/adaptor.py | 47 +++++++++++++++++++++++++++++- tests/conftest.py | 17 +++++++++-- tests/test_datatypes.py | 3 +- tests/test_replace.py | 64 +++++++++++++++++++++++------------------ 5 files changed, 162 insertions(+), 33 deletions(-) diff --git a/pgcopy/backend.py b/pgcopy/backend.py index edc962e..c90d3bd 100644 --- a/pgcopy/backend.py +++ b/pgcopy/backend.py @@ -1,5 +1,6 @@ "psycopg backends" import codecs +import collections import contextlib import importlib import os @@ -188,3 +189,66 @@ def copystream(self): format="binary", columns=self.columns, ) + + +class Pg8000Backend: + NamedTupleCursor = None + + def __init__(self, conn): + self.conn = conn + + def get_encoding(self): + with contextlib.closing(self.namedtuple_cursor()) as cur: + cur.execute("SHOW client_encoding") + row = cur.fetchone() + return codecs.lookup(row.client_encoding).name + + def namedtuple_cursor(self): + if not Pg8000Backend.NamedTupleCursor: + cur = self.conn.cursor() + Cursor = cur.__class__ + cur.close() + + class NamedTupleCursor(Cursor): + def __next__(self): + val = super().__next__() + context = self._context + if context is None: + return val # raised an error already + rowclass = getattr(context, "_pgcopy_row_class", None) + if not rowclass: + columns = context.columns + if columns is None or len(columns) == 0: + return val # probably also raised an error + column_names = [col["name"] for col in columns] + rowclass = collections.namedtuple("Row", column_names) + context._pgcopy_row_class = rowclass + return rowclass(*val) + + Pg8000Backend.NamedTupleCursor = NamedTupleCursor + return Pg8000Backend.NamedTupleCursor(self.conn) + + def copy(self, schema, table, columns, fobject_factory): + sql = copy_sql(schema, table, columns) + return Pg8000Copy(self.conn, sql, fobject_factory) + + +class Pg8000Copy: + def __init__(self, conn, sql, fobject_factory): + self.conn = conn + self.sql = sql + self.datastream = fobject_factory() + + def __enter__(self): + return self.datastream + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is None: + self.datastream.seek(0) + self.copystream() + self.datastream.close() + + def copystream(self): + cur = self.conn.cursor() + cur.execute(self.sql, stream=self.datastream) + cur.close() diff --git a/tests/adaptor.py b/tests/adaptor.py index 4b1ad29..fa49877 100644 --- a/tests/adaptor.py +++ b/tests/adaptor.py @@ -1,11 +1,13 @@ import argparse import codecs +import contextlib import importlib +import os import sys def available_adaptors(): - adaptors = [Psycopg2, Psycopg3, PyGreSQL] + adaptors = [Psycopg2, Psycopg3, PyGreSQL, Pg8000] return [a for a in adaptors if a.load()] @@ -46,6 +48,7 @@ def __init__(self, connection_params, client_encoding): self.conn.autocommit = False self.conn.set_client_encoding(client_encoding) self.unsupported_type = self.m.psycopg2.errors.UndefinedObject + self.integrity_error = self.m.psycopg2.errors.IntegrityError @staticmethod def supports_encoding(encoding): @@ -60,6 +63,7 @@ def __init__(self, connection_params, client_encoding): self.conn.autocommit = False self.conn.execute(f"SET client_encoding='{client_encoding}'") self.unsupported_type = self.m.psycopg.errors.UndefinedObject + self.integrity_error = self.m.psycopg.errors.IntegrityError @staticmethod def supports_encoding(encoding): @@ -83,3 +87,44 @@ def supports_encoding(encoding): return True except LookupError: return False + + +class Pg8000(Adaptor): + module_names = ["pg8000.dbapi", "pg8000.exceptions"] + + def __init__(self, connection_params, client_encoding): + params = self.get_connection_parameters(connection_params) + self.conn = self.m.dbapi.connect(**params) + self.conn.autocommit = False + with contextlib.closing(self.conn.cursor()) as cur: + cur.execute(f"SET client_encoding='{client_encoding}'") + self.unsupported_type = self.m.exceptions.DatabaseError + self.integrity_error = self.m.exceptions.DatabaseError + + def get_connection_parameters(self, connection_params): + psycopg2 = importlib.import_module("psycopg2") + conn = psycopg2.connect(**connection_params) + parameters = { + "user": conn.info.user, + "database": conn.info.dbname, + } + host = conn.info.host + if host.startswith("/"): + sock = f"{host}/.s.PGSQL.{conn.info.port}" + if os.path.exists(sock): + parameters["unix_sock"] = sock + return parameters + parameters["host"] = "localhost" + else: + parameters["host"] = host + parameters["port"] = conn.info.port + parameters["password"] = conn.info.password + return parameters + + @staticmethod + def supports_encoding(encoding): + return encoding.upper() == "UTF8" + + @classmethod + def get_pgcode(cls, err): + return err.args[0]["C"] diff --git a/tests/conftest.py b/tests/conftest.py index a2b5a57..5ff710e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +import contextlib import os import re import sys @@ -111,7 +112,7 @@ def db_ext(request, db): @pytest.fixture(params=available_adaptors()) -def conn(request, db_ext, client_encoding): +def adaptor(request, db_ext, client_encoding): if not request.param.supports_encoding(client_encoding): pytest.skip("Unsupported encoding for {request.param}") adaptor = request.param(connection_params, client_encoding) @@ -119,7 +120,7 @@ def conn(request, db_ext, client_encoding): inst = request.instance if isinstance(inst, TemporaryTable): try: - with conn.cursor() as cur: + with contextlib.closing(conn.cursor()) as cur: cur.execute(inst.create_sql(inst.tempschema)) except adaptor.unsupported_type as e: pgcode = adaptor.get_pgcode(e) @@ -128,11 +129,21 @@ def conn(request, db_ext, client_encoding): pytest.skip("Unsupported datatype") else: raise - yield conn + yield adaptor conn.rollback() conn.close() +@pytest.fixture +def conn(adaptor): + return adaptor.conn + + +@pytest.fixture +def integrity_error(adaptor): + return adaptor.integrity_error + + @pytest.fixture def cursor(conn): cur = conn.cursor() diff --git a/tests/test_datatypes.py b/tests/test_datatypes.py index daeaeea..9c0d046 100644 --- a/tests/test_datatypes.py +++ b/tests/test_datatypes.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import unicode_literals +import contextlib import decimal import json import sys @@ -17,7 +18,7 @@ def test_db_encoding(conn): - with conn.cursor() as cur: + with contextlib.closing(conn.cursor()) as cur: cur.execute("SHOW server_encoding") res = cur.fetchone() assert res[0] == "UTF8" diff --git a/tests/test_replace.py b/tests/test_replace.py index 9389d5f..490a397 100644 --- a/tests/test_replace.py +++ b/tests/test_replace.py @@ -6,6 +6,10 @@ from . import db +def tuplist(iterable): + return [tuple(row) for row in iterable] + + class TestRenameReplace(db.TemporaryTable): datatypes = ["integer"] mixed_case = False @@ -19,11 +23,11 @@ def test_rename_replace(self, conn, cursor, schema): with util.RenameReplace(conn, self.table, xform) as temp: cursor.executemany(sql.format(temp), [(36,), (72,)]) cursor.execute("SELECT * FROM {}".format(self.table)) - assert list(cursor) == [(36,), (72,)] + assert tuplist(cursor) == [(36,), (72,)] cursor.execute("SELECT * FROM v") - assert list(cursor) == [(37,), (73,)] + assert tuplist(cursor) == [(37,), (73,)] cursor.execute("SELECT * FROM {}".format(self.table + "_old")) - assert list(cursor) == [(1,), (2,)] + assert tuplist(cursor) == [(1,), (2,)] class TestReplaceFallbackSchema(db.TemporaryTable): @@ -39,7 +43,7 @@ def test_fallback_schema_honors_search_path( with Replace(conn, self.table) as temp: cursor.execute(sql.format(temp), (1,)) cursor.execute("SELECT * FROM {}".format(schema_table)) - assert list(cursor) == [(1,)] + assert tuplist(cursor) == [(1,)] class TestReplaceDefault(db.TemporaryTable): @@ -59,19 +63,23 @@ def test_replace_with_default(self, conn, cursor, schema_table): with Replace(conn, schema_table) as temp: cursor.execute(sql.format(temp), (1,)) cursor.execute("SELECT * FROM {}".format(schema_table)) - assert list(cursor) == [(1, 3)] + assert tuplist(cursor) == [(1, 3)] -@contextlib.contextmanager -def replace_raises(conn, table): - """ - Wrap Replace context manager and assert - exception is thrown on context exit - """ - r = Replace(conn, table) - yield r.__enter__() - with pytest.raises(conn.IntegrityError): - r.__exit__(None, None, None) +@pytest.fixture +def replace_raises(conn, schema_table, integrity_error): + @contextlib.contextmanager + def _replace_raises(): + """ + Wrap Replace context manager and assert + exception is thrown on context exit + """ + r = Replace(conn, schema_table) + yield r.__enter__() + with pytest.raises(integrity_error): + r.__exit__(None, None, None) + + return _replace_raises class TestReplaceNotNull(db.TemporaryTable): @@ -82,15 +90,15 @@ class TestReplaceNotNull(db.TemporaryTable): "integer NOT NULL", ] - def test_replace_not_null(self, conn, cursor, schema_table): + def test_replace_not_null(self, replace_raises, cursor): """ Not-null constraint is added on exit """ sql = 'INSERT INTO {} ("a") VALUES (%s)' - with replace_raises(conn, schema_table) as temp: + with replace_raises() as temp: cursor.execute(sql.format(temp), (1,)) cursor.execute("SELECT * FROM {}".format(temp)) - assert list(cursor) == [(1, None)] + assert tuplist(cursor) == [(1, None)] class TestReplaceConstraint(db.TemporaryTable): @@ -100,12 +108,12 @@ class TestReplaceConstraint(db.TemporaryTable): "integer CHECK (a > 5)", ] - def test_replace_constraint(self, conn, cursor, schema_table): + def test_replace_constraint(self, replace_raises, cursor): sql = 'INSERT INTO {} ("a") VALUES (%s)' - with replace_raises(conn, schema_table) as temp: + with replace_raises() as temp: cursor.execute(sql.format(temp), (1,)) cursor.execute("SELECT * FROM {}".format(temp)) - assert list(cursor) == [(1,)] + assert tuplist(cursor) == [(1,)] class TestReplaceNamedConstraint(db.TemporaryTable): @@ -129,16 +137,16 @@ class TestReplaceUniqueIndex(db.TemporaryTable): "integer UNIQUE", ] - def test_replace_unique_index(self, conn, cursor, schema_table): + def test_replace_unique_index(self, replace_raises, cursor): """ Not-null constraint is added on exit """ sql = 'INSERT INTO {} ("a") VALUES (%s)' - with replace_raises(conn, schema_table) as temp: + with replace_raises() as temp: cursor.execute(sql.format(temp), (1,)) cursor.execute(sql.format(temp), (1,)) cursor.execute("SELECT * FROM {}".format(temp)) - assert list(cursor) == [(1,), (1,)] + assert tuplist(cursor) == [(1,), (1,)] class TestReplaceView(db.TemporaryTable): @@ -152,7 +160,7 @@ def test_replace_with_view(self, conn, cursor, schema_table): with Replace(conn, schema_table) as temp: cursor.execute(sql.format(temp), (1,)) cursor.execute("SELECT * FROM v") - assert list(cursor) == [(2,)] + assert tuplist(cursor) == [(2,)] class TestReplaceViewMultiSchema(db.TemporaryTable): @@ -167,7 +175,7 @@ def test_replace_view_in_different_schema(self, conn, cursor, schema_table): with Replace(conn, schema_table) as temp: cursor.execute(sql.format(temp), (1,)) cursor.execute("SELECT * FROM ns.v") - assert list(cursor) == [(2,)] + assert tuplist(cursor) == [(2,)] class TestReplaceTrigger(db.TemporaryTable): @@ -200,7 +208,7 @@ def test_replace_with_trigger(self, conn, cursor, schema_table): cursor.execute(sql.format(temp), (1, 1)) cursor.execute(sql.format(schema_table), (2, 1)) cursor.execute("SELECT * FROM {}".format(schema_table)) - assert list(cursor) == [(1, 1), (2, 8)] + assert tuplist(cursor) == [(1, 1), (2, 8)] class TestReplaceSequence(db.TemporaryTable): @@ -219,4 +227,4 @@ def test_replace_with_sequence(self, conn, cursor, schema_table): cursor.execute(sql.format(schema_table), (40,)) cursor.execute(sql.format(schema_table), (40,)) cursor.execute("SELECT * FROM {}".format(schema_table)) - assert list(cursor) == [(30, 3), (40, 5)] + assert tuplist(cursor) == [(30, 3), (40, 5)] From 64ddbbebe03ec6135fa7ad3c44d5458f7640ace3 Mon Sep 17 00:00:00 2001 From: Aryeh Leib Taurog Date: Wed, 24 Dec 2025 21:49:47 +0200 Subject: [PATCH 15/24] test: add all backends to tox --- tox.ini | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 9bb22d4..8906ecf 100644 --- a/tox.ini +++ b/tox.ini @@ -9,7 +9,9 @@ deps = pytest psycopg2~=2.9 psycopg[binary] -commands = pytest tests/ + pg8000 + PyGreSQL +commands = python -m pytest tests/ --tb=native docker = pg16 setenv = POSTGRES_DB=pgcopy_tox_test @@ -34,6 +36,8 @@ deps = pytest-cov psycopg2~=2.9 psycopg[binary] + pg8000 + PyGreSQL commands = pytest --cov=pgcopy/ tests/ docker = pgvector From f96621a9f1d21adb04ab572a58887e255fa03c20 Mon Sep 17 00:00:00 2001 From: Aryeh Leib Taurog Date: Sun, 23 Nov 2025 12:07:51 +0200 Subject: [PATCH 16/24] test: add ability to run tests without temporary tables --- tests/db.py | 5 ++++- tests/test_replace.py | 2 ++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/db.py b/tests/db.py index 1f2cd62..b38ebb5 100644 --- a/tests/db.py +++ b/tests/db.py @@ -1,9 +1,12 @@ import hashlib +import os from datetime import date, datetime, time, timedelta from random import randint from pgcopy import util +NO_TEMPORARY_TABLES = bool(os.getenv("NO_TEMPORARY_TABLES")) + genbool = lambda i: 0 == (i % 3) genint = lambda i: i genfloat = lambda i: 1.125 * i @@ -32,7 +35,7 @@ class TemporaryTable(object): - tempschema = True + tempschema = not NO_TEMPORARY_TABLES null = "NOT NULL" data = None extensions = [] diff --git a/tests/test_replace.py b/tests/test_replace.py index 490a397..3893b3f 100644 --- a/tests/test_replace.py +++ b/tests/test_replace.py @@ -37,6 +37,8 @@ class TestReplaceFallbackSchema(db.TemporaryTable): def test_fallback_schema_honors_search_path( self, conn, cursor, schema, schema_table ): + if not db.TemporaryTable.tempschema: + pytest.skip("This test currently requires temp schema") cursor.execute(self.create_sql(tempschema=False)) cursor.execute("SET search_path TO {}".format(schema)) sql = 'INSERT INTO {} ("a") VALUES (%s)' From 57f225b164ed81679ae9f12271747f3db0b32a7f Mon Sep 17 00:00:00 2001 From: Aryeh Leib Taurog Date: Sun, 23 Nov 2025 12:54:20 +0200 Subject: [PATCH 17/24] test: refactor: use super() --- tests/test_replace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_replace.py b/tests/test_replace.py index 3893b3f..2e0391f 100644 --- a/tests/test_replace.py +++ b/tests/test_replace.py @@ -37,7 +37,7 @@ class TestReplaceFallbackSchema(db.TemporaryTable): def test_fallback_schema_honors_search_path( self, conn, cursor, schema, schema_table ): - if not db.TemporaryTable.tempschema: + if not super().tempschema: pytest.skip("This test currently requires temp schema") cursor.execute(self.create_sql(tempschema=False)) cursor.execute("SET search_path TO {}".format(schema)) From 102615fc5a5e71348b84d6bb025d07dc3b7a27d1 Mon Sep 17 00:00:00 2001 From: Aryeh Leib Taurog Date: Sun, 21 Dec 2025 21:26:53 +0200 Subject: [PATCH 18/24] test: recombine extension and table setup --- tests/conftest.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 5ff710e..341cf9d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -93,32 +93,29 @@ def client_encoding(request): return getattr(request, "param", "UTF8") -@pytest.fixture -def db_ext(request, db): - conn = connect() +@pytest.fixture(params=available_adaptors()) +def adaptor(request, db, client_encoding): + if not request.param.supports_encoding(client_encoding): + pytest.skip("Unsupported encoding for {request.param}") + adaptor = request.param(connection_params, client_encoding) + conn = adaptor.conn inst = request.instance if isinstance(inst, TemporaryTable): + # use psycopg2 connection to create extensions if necessary + psycopg2_conn = connect() for extension in inst.extensions: try: - with conn.cursor() as cur: + with psycopg2_conn.cursor() as cur: cur.execute("CREATE EXTENSION {}".format(extension)) - conn.commit() + psycopg2_conn.commit() except ( psycopg2.errors.DuplicateObject, psycopg2.errors.UndefinedFile, # postgres <= 14 psycopg2.errors.FeatureNotSupported, # postgres >= 15 ): - conn.rollback() - + psycopg2_conn.rollback() + psycopg2_conn.close() -@pytest.fixture(params=available_adaptors()) -def adaptor(request, db_ext, client_encoding): - if not request.param.supports_encoding(client_encoding): - pytest.skip("Unsupported encoding for {request.param}") - adaptor = request.param(connection_params, client_encoding) - conn = adaptor.conn - inst = request.instance - if isinstance(inst, TemporaryTable): try: with contextlib.closing(conn.cursor()) as cur: cur.execute(inst.create_sql(inst.tempschema)) From 0a547d68aa22900d07f8f9f58d8c1fc0a2a74fd9 Mon Sep 17 00:00:00 2001 From: Aryeh Leib Taurog Date: Sun, 21 Dec 2025 22:03:06 +0200 Subject: [PATCH 19/24] test: port dsql branch --- tests/conftest.py | 57 +++++++++++++++++++++++------------------ tests/db.py | 34 +++++++++++++++++++----- tests/db_connection.py | 45 ++++++++++++++++++++++++++++++++ tests/test_datatypes.py | 16 +++++++----- tests/test_errors.py | 4 ++- tests/test_replace.py | 29 +++++++++++---------- tests/test_sanity.py | 1 + tests/test_schema.py | 8 ++++-- 8 files changed, 140 insertions(+), 54 deletions(-) create mode 100644 tests/db_connection.py diff --git a/tests/conftest.py b/tests/conftest.py index 341cf9d..c0bf835 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,34 +1,15 @@ import contextlib -import os -import re import sys import psycopg2 import pytest from psycopg2.extras import LoggingConnection +from . import db_connection from .adaptor import available_adaptors from .db import TemporaryTable - -def get_port(): - # this would be much more straightforward if tox-docker would release - # recent updates https://github.com/tox-dev/tox-docker/pull/167 - if os.getenv("TOX_ENV_NAME"): - search_pattern = re.compile(r"PG\w+_5432_TCP_PORT") - for name, val in os.environ.items(): - if search_pattern.fullmatch(name): - return int(val) - return int(os.getenv("POSTGRES_PORT", "5432")) - - -connection_params = { - "dbname": os.getenv("POSTGRES_DB", "pgcopy_test"), - "port": get_port(), - "host": os.getenv("POSTGRES_HOST"), - "user": os.getenv("POSTGRES_USER"), - "password": os.getenv("POSTGRES_PASSWORD"), -} +connection_params = db_connection.get_connection_params() @pytest.fixture(scope="session") @@ -73,8 +54,7 @@ def create_db(): + ".\nThe error is: %s" % exc ) raise RuntimeError(message) - else: - return True + return True def drop_db(): @@ -98,9 +78,30 @@ def adaptor(request, db, client_encoding): if not request.param.supports_encoding(client_encoding): pytest.skip("Unsupported encoding for {request.param}") adaptor = request.param(connection_params, client_encoding) - conn = adaptor.conn inst = request.instance if isinstance(inst, TemporaryTable): + if db_connection.IS_DSQL: + yield from dsql_table(adaptor, inst) + else: + yield from temporary_table(adaptor, inst) + else: + yield from no_table(adaptor) + + +def dsql_table(adaptor, inst): + for adaptor in temporary_table(adaptor, inst): + conn = adaptor.conn + conn.commit() + yield conn + conn.commit() + if drop_sql := inst.drop_sql(): + with contextlib.closing(conn.cursor()) as cur: + cur.execute(drop_sql) + conn.commit() + + +def temporary_table(adaptor, inst): + for adaptor in no_table(adaptor): # use psycopg2 connection to create extensions if necessary psycopg2_conn = connect() for extension in inst.extensions: @@ -116,6 +117,7 @@ def adaptor(request, db, client_encoding): psycopg2_conn.rollback() psycopg2_conn.close() + conn = adaptor.conn try: with contextlib.closing(conn.cursor()) as cur: cur.execute(inst.create_sql(inst.tempschema)) @@ -126,6 +128,11 @@ def adaptor(request, db, client_encoding): pytest.skip("Unsupported datatype") else: raise + yield adaptor + + +def no_table(adaptor): + conn = adaptor.conn yield adaptor conn.rollback() conn.close() @@ -175,4 +182,4 @@ def schema_table(request, schema): def data(request): inst = request.instance if isinstance(inst, TemporaryTable): - return inst.data or inst.generate_data(inst.record_count) + return inst.generate_data() diff --git a/tests/db.py b/tests/db.py index b38ebb5..5f59f6b 100644 --- a/tests/db.py +++ b/tests/db.py @@ -5,7 +5,11 @@ from pgcopy import util -NO_TEMPORARY_TABLES = bool(os.getenv("NO_TEMPORARY_TABLES")) +from . import db_connection + +# pylint: disable=consider-using-f-string + +NO_TEMPORARY_TABLES = bool(os.getenv("NO_TEMPORARY_TABLES")) or db_connection.IS_DSQL genbool = lambda i: 0 == (i % 3) genint = lambda i: i @@ -36,11 +40,16 @@ class TemporaryTable(object): tempschema = not NO_TEMPORARY_TABLES + id_col = True null = "NOT NULL" data = None + datatypes: list extensions = [] record_count = 0 mixed_case = True + table: str + cols: list + select_list: str def colname(self, i): char = chr(ord("a") + i) @@ -52,17 +61,28 @@ def setup_method(self): self.table = self.__class__.__name__ if not self.mixed_case: self.table = self.__class__.__name__.lower() - self.cols = [self.colname(i) for i in range(len(self.datatypes))] - self.select_list = ','.join('"{}"'.format(c) for c in self.cols) + id_cols = ["id"] if self.id_col else [] + self.cols = id_cols + [self.colname(i) for i in range(len(self.datatypes))] + self.select_list = ",".join('"{}"'.format(c) for c in self.cols) def create_sql(self, tempschema=None): col_ids = ['"{}"'.format(c) for c in self.cols] - colsql = [(c, t, self.null) for c, t in zip(col_ids, self.datatypes)] + id_types = ["integer"] if self.id_col else [] + colsql = [(c, t, self.null) for c, t in zip(col_ids, id_types + self.datatypes)] collist = ", ".join(map(" ".join, colsql)) if tempschema: return 'CREATE TEMPORARY TABLE "{}" ({})'.format(self.table, collist) return 'CREATE TABLE "public"."{}" ({})'.format(self.table, collist) - def generate_data(self, count): - gen = [datagen[t] for t in self.datatypes] - return [tuple(g(i) for g in gen) for i in range(count)] + def generate_data(self): + if self.data: + return [(i, *row) for i, row in enumerate(self.data)] + id_types = ["integer"] if self.id_col else [] + datatypes = id_types + self.datatypes + gen = [datagen[t] for t in datatypes] + return [tuple(g(i) for g in gen) for i in range(self.record_count)] + + def drop_sql(self, tempschema=None): + if tempschema: + return + return 'DROP TABLE "public"."{}"'.format(self.table) diff --git a/tests/db_connection.py b/tests/db_connection.py new file mode 100644 index 0000000..ee40d4e --- /dev/null +++ b/tests/db_connection.py @@ -0,0 +1,45 @@ +import os +import re + +DB_HOST = os.getenv("POSTGRES_HOST") +IS_DSQL = bool(DB_HOST and re.match(r"^\w+\.dsql\.\w\w-\w+-[1-9]\.on\.aws$", DB_HOST)) + + +def get_connection_params(): + if IS_DSQL: + return dsql_connection_params() + return { + "dbname": os.getenv("POSTGRES_DB", "pgcopy_test"), + "port": get_port(), + "host": DB_HOST, + "user": os.getenv("POSTGRES_USER"), + "password": os.getenv("POSTGRES_PASSWORD"), + } + + +def dsql_connection_params(): + import boto3 + + dsql = boto3.client("dsql") + admin_token = dsql.generate_db_connect_admin_auth_token( + Hostname=DB_HOST, + ExpiresIn=300, + ) + return { + "user": "admin", + "host": DB_HOST, + "password": admin_token, + "dbname": "postgres", + "sslmode": "require", + } + + +def get_port(): + # this would be much more straightforward if tox-docker would release + # recent updates https://github.com/tox-dev/tox-docker/pull/167 + if os.getenv("TOX_ENV_NAME"): + search_pattern = re.compile(r"PG\w+_5432_TCP_PORT") + for name, val in os.environ.items(): + if search_pattern.fullmatch(name): + return int(val) + return int(os.getenv("POSTGRES_PORT", "5432")) diff --git a/tests/test_datatypes.py b/tests/test_datatypes.py index 9c0d046..8547e9f 100644 --- a/tests/test_datatypes.py +++ b/tests/test_datatypes.py @@ -33,13 +33,17 @@ class TypeMixin(db.TemporaryTable): def test_type(self, conn, cursor, schema_table, data): bincopy = self.copy_manager_class(conn, schema_table, self.cols) bincopy.copy(data) - select_list = ",".join(self.cols) - cursor.execute(self.select_sql(schema_table)) - self.checkResults(cursor, data) + select_list = ",".join('"{}"'.format(c) for c in self.cols[1:]) + cursor.execute(self.select_sql(schema_table, select_list)) + self.checkResults(cursor, [row[1:] for row in data]) - def select_sql(self, schema_table): + def select_sql(self, schema_table, select_list=None): schema, table = schema_table.split(".") - return 'SELECT %s from "%s"."%s"' % (self.select_list, schema, table) + return 'SELECT %s from "%s"."%s" ORDER BY "id"' % ( + select_list or self.select_list, + schema, + table, + ) def checkResults(self, cursor, data): for rec in data: @@ -266,7 +270,7 @@ class TestNumeric(TypeMixin): (decimal.Decimal("-1000"),), (decimal.Decimal("21034.56"),), (decimal.Decimal("-900000.0001"),), - (decimal.Decimal("-1.3E25"),), + (decimal.Decimal("-1.3E11"),), ] diff --git a/tests/test_errors.py b/tests/test_errors.py index 36638c4..b8ff750 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -2,7 +2,7 @@ import pytest from pgcopy import CopyManager -from . import db +from . import db, db_connection class TestErrors(db.TemporaryTable): @@ -22,6 +22,7 @@ def test_notnull(self, conn, schema_table): class TestFormatterDiagnostic(db.TemporaryTable): + id_col = False datatypes = ["varchar"] def test_formatting_diagnostic(self, conn): @@ -31,6 +32,7 @@ def test_formatting_diagnostic(self, conn): copymgr.copy([[23]]) +@pytest.mark.skipif(db_connection.IS_DSQL, reason="drop column not supported on dsql") class TestDroppedCol(db.TemporaryTable): datatypes = ["integer", "integer"] diff --git a/tests/test_replace.py b/tests/test_replace.py index 2e0391f..6a80e71 100644 --- a/tests/test_replace.py +++ b/tests/test_replace.py @@ -3,14 +3,18 @@ import pytest from pgcopy import Replace, util -from . import db +from . import db, db_connection + +@pytest.mark.skipif(db_connection.IS_DSQL, reason="tests not supported on dsql") +class TemporaryTable(db.TemporaryTable): + id_col = False def tuplist(iterable): return [tuple(row) for row in iterable] -class TestRenameReplace(db.TemporaryTable): +class TestRenameReplace(TemporaryTable): datatypes = ["integer"] mixed_case = False @@ -30,7 +34,7 @@ def test_rename_replace(self, conn, cursor, schema): assert tuplist(cursor) == [(1,), (2,)] -class TestReplaceFallbackSchema(db.TemporaryTable): +class TestReplaceFallbackSchema(TemporaryTable): datatypes = ["integer"] mixed_case = False @@ -48,7 +52,7 @@ def test_fallback_schema_honors_search_path( assert tuplist(cursor) == [(1,)] -class TestReplaceDefault(db.TemporaryTable): +class TestReplaceDefault(TemporaryTable): """ Defaults are set on temp table immediately. """ @@ -83,8 +87,7 @@ def _replace_raises(): return _replace_raises - -class TestReplaceNotNull(db.TemporaryTable): +class TestReplaceNotNull(TemporaryTable): mixed_case = False null = "" datatypes = [ @@ -103,7 +106,7 @@ def test_replace_not_null(self, replace_raises, cursor): assert tuplist(cursor) == [(1, None)] -class TestReplaceConstraint(db.TemporaryTable): +class TestReplaceConstraint(TemporaryTable): mixed_case = False null = "" datatypes = [ @@ -118,7 +121,7 @@ def test_replace_constraint(self, replace_raises, cursor): assert tuplist(cursor) == [(1,)] -class TestReplaceNamedConstraint(db.TemporaryTable): +class TestReplaceNamedConstraint(TemporaryTable): mixed_case = False null = "" datatypes = [ @@ -132,7 +135,7 @@ def test_replace_constraint_no_name_conflict(self, conn, schema_table): pass -class TestReplaceUniqueIndex(db.TemporaryTable): +class TestReplaceUniqueIndex(TemporaryTable): mixed_case = False null = "" datatypes = [ @@ -151,7 +154,7 @@ def test_replace_unique_index(self, replace_raises, cursor): assert tuplist(cursor) == [(1,), (1,)] -class TestReplaceView(db.TemporaryTable): +class TestReplaceView(TemporaryTable): mixed_case = False datatypes = ["integer"] @@ -165,7 +168,7 @@ def test_replace_with_view(self, conn, cursor, schema_table): assert tuplist(cursor) == [(2,)] -class TestReplaceViewMultiSchema(db.TemporaryTable): +class TestReplaceViewMultiSchema(TemporaryTable): mixed_case = False tempschema = False datatypes = ["integer"] @@ -180,7 +183,7 @@ def test_replace_view_in_different_schema(self, conn, cursor, schema_table): assert tuplist(cursor) == [(2,)] -class TestReplaceTrigger(db.TemporaryTable): +class TestReplaceTrigger(TemporaryTable): mixed_case = False null = "" datatypes = [ @@ -213,7 +216,7 @@ def test_replace_with_trigger(self, conn, cursor, schema_table): assert tuplist(cursor) == [(1, 1), (2, 8)] -class TestReplaceSequence(db.TemporaryTable): +class TestReplaceSequence(TemporaryTable): mixed_case = False null = "" datatypes = [ diff --git a/tests/test_sanity.py b/tests/test_sanity.py index 994533f..1bae7fc 100644 --- a/tests/test_sanity.py +++ b/tests/test_sanity.py @@ -7,6 +7,7 @@ class TestSanity(db.TemporaryTable): + id_col = False manager = CopyManager method = "copy" record_count = 3 diff --git a/tests/test_schema.py b/tests/test_schema.py index c3ef1df..25b3131 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -10,7 +10,9 @@ class TestPublicSchema(test_datatypes.TypeMixin): def test_default_public(self, conn, cursor, data): bincopy = CopyManager(conn, self.table, self.cols) bincopy.copy(data) - cursor.execute('SELECT %s from public."%s"' % (self.select_list, self.table)) + cursor.execute( + 'SELECT %s from public."%s" ORDER BY 1' % (self.select_list, self.table) + ) self.checkResults(cursor, data) def cast(self, v): @@ -26,7 +28,9 @@ def test_fallback_schema_honors_search_path(self, conn, cursor, data, schema): cursor.execute("SET search_path TO {}".format(schema)) bincopy = CopyManager(conn, self.table, self.cols) bincopy.copy(data) - cursor.execute('SELECT %s from "%s"' % (self.select_list, self.table)) + cursor.execute( + 'SELECT %s from "%s" ORDER BY 1' % (self.select_list, self.table) + ) self.checkResults(cursor, data) def cast(self, v): From 558ed29bac6121347f51fdd40c9cc4de22592f39 Mon Sep 17 00:00:00 2001 From: Aryeh Leib Taurog Date: Wed, 24 Dec 2025 22:18:01 +0200 Subject: [PATCH 20/24] test: minor refactor in db creation --- tests/conftest.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index c0bf835..6c3eb82 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -37,21 +37,21 @@ def create_db(): connect().close() return False except psycopg2.OperationalError as exc: - nosuch_db = 'database "%s" does not exist' % connection_params["dbname"] + dbname = connection_params["dbname"] + nosuch_db = 'database "%s" does not exist' % dbname if nosuch_db in str(exc): try: master = connect(dbname="postgres") master.rollback() master.autocommit = True cursor = master.cursor() - cursor.execute("CREATE DATABASE %s" % connection_params["dbname"]) + cursor.execute("CREATE DATABASE %s" % dbname) cursor.close() master.close() except psycopg2.Error as exc: message = ( - "Unable to connect to or create test db " - + connection_params["dbname"] - + ".\nThe error is: %s" % exc + "Unable to connect to or create test db %s.\nThe error is: %s" + % (dbname, exc) ) raise RuntimeError(message) return True From 1fd9f527f4754652076d2a51d234f8190c547b77 Mon Sep 17 00:00:00 2001 From: Aryeh Leib Taurog Date: Wed, 24 Dec 2025 22:24:02 +0200 Subject: [PATCH 21/24] test: silence linter warnings --- tests/conftest.py | 2 ++ tests/test_datatypes.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 6c3eb82..5467c12 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,8 @@ from .adaptor import available_adaptors from .db import TemporaryTable +# pylint: disable=redefined-outer-name,consider-using-f-string + connection_params = db_connection.get_connection_params() diff --git a/tests/test_datatypes.py b/tests/test_datatypes.py index 722c66b..67e1056 100644 --- a/tests/test_datatypes.py +++ b/tests/test_datatypes.py @@ -9,6 +9,8 @@ import pytest +# pylint: disable=consider-using-f-string + if sys.version_info < (3,): memoryview = buffer From 93dc1b5112730fc0188caf72f4909f2a8deaece6 Mon Sep 17 00:00:00 2001 From: Aryeh Leib Taurog Date: Wed, 24 Dec 2025 22:58:11 +0200 Subject: [PATCH 22/24] test: allow testing with psycopg2 or psycopg3 require either one of the two to run the test suite --- tests/adaptor.py | 37 +++++++++---------- tests/conftest.py | 75 +++------------------------------------ tests/db_connection.py | 80 ++++++++++++++++++++++++++++++++++++++++++ tox.ini | 14 ++++++++ 4 files changed, 118 insertions(+), 88 deletions(-) diff --git a/tests/adaptor.py b/tests/adaptor.py index fa49877..96ff106 100644 --- a/tests/adaptor.py +++ b/tests/adaptor.py @@ -5,6 +5,8 @@ import os import sys +from . import db_connection + def available_adaptors(): adaptors = [Psycopg2, Psycopg3, PyGreSQL, Pg8000] @@ -102,24 +104,23 @@ def __init__(self, connection_params, client_encoding): self.integrity_error = self.m.exceptions.DatabaseError def get_connection_parameters(self, connection_params): - psycopg2 = importlib.import_module("psycopg2") - conn = psycopg2.connect(**connection_params) - parameters = { - "user": conn.info.user, - "database": conn.info.dbname, - } - host = conn.info.host - if host.startswith("/"): - sock = f"{host}/.s.PGSQL.{conn.info.port}" - if os.path.exists(sock): - parameters["unix_sock"] = sock - return parameters - parameters["host"] = "localhost" - else: - parameters["host"] = host - parameters["port"] = conn.info.port - parameters["password"] = conn.info.password - return parameters + with db_connection.conninfo(connection_params) as conninfo: + parameters = { + "user": conninfo.user, + "database": conninfo.dbname, + } + host = conninfo.host + if host.startswith("/"): + sock = f"{host}/.s.PGSQL.{conninfo.port}" + if os.path.exists(sock): + parameters["unix_sock"] = sock + return parameters + parameters["host"] = "localhost" + else: + parameters["host"] = host + parameters["port"] = conninfo.port + parameters["password"] = conninfo.password + return parameters @staticmethod def supports_encoding(encoding): diff --git a/tests/conftest.py b/tests/conftest.py index 5467c12..23599fc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,6 @@ import contextlib -import sys -import psycopg2 import pytest -from psycopg2.extras import LoggingConnection from . import db_connection from .adaptor import available_adaptors @@ -11,63 +8,13 @@ # pylint: disable=redefined-outer-name,consider-using-f-string -connection_params = db_connection.get_connection_params() - @pytest.fixture(scope="session") def db(): - drop = create_db() + drop = db_connection.create_db() yield if drop: - try: - drop_db() - except psycopg2.OperationalError: - pass - - -def connect(**kwargs): - kw = connection_params.copy() - kw.update(kwargs) - conn = psycopg2.connect(connection_factory=LoggingConnection, **kw) - conn.initialize(sys.stderr) - return conn - - -def create_db(): - "connect to test db" - try: - connect().close() - return False - except psycopg2.OperationalError as exc: - dbname = connection_params["dbname"] - nosuch_db = 'database "%s" does not exist' % dbname - if nosuch_db in str(exc): - try: - master = connect(dbname="postgres") - master.rollback() - master.autocommit = True - cursor = master.cursor() - cursor.execute("CREATE DATABASE %s" % dbname) - cursor.close() - master.close() - except psycopg2.Error as exc: - message = ( - "Unable to connect to or create test db %s.\nThe error is: %s" - % (dbname, exc) - ) - raise RuntimeError(message) - return True - - -def drop_db(): - "Drop test db" - master = connect(dbname="postgres") - master.rollback() - master.autocommit = True - cursor = master.cursor() - cursor.execute("DROP DATABASE %s" % connection_params["dbname"]) - cursor.close() - master.close() + db_connection.drop_db() @pytest.fixture @@ -79,7 +26,7 @@ def client_encoding(request): def adaptor(request, db, client_encoding): if not request.param.supports_encoding(client_encoding): pytest.skip("Unsupported encoding for {request.param}") - adaptor = request.param(connection_params, client_encoding) + adaptor = request.param(db_connection.connection_params, client_encoding) inst = request.instance if isinstance(inst, TemporaryTable): if db_connection.IS_DSQL: @@ -104,20 +51,8 @@ def dsql_table(adaptor, inst): def temporary_table(adaptor, inst): for adaptor in no_table(adaptor): - # use psycopg2 connection to create extensions if necessary - psycopg2_conn = connect() - for extension in inst.extensions: - try: - with psycopg2_conn.cursor() as cur: - cur.execute("CREATE EXTENSION {}".format(extension)) - psycopg2_conn.commit() - except ( - psycopg2.errors.DuplicateObject, - psycopg2.errors.UndefinedFile, # postgres <= 14 - psycopg2.errors.FeatureNotSupported, # postgres >= 15 - ): - psycopg2_conn.rollback() - psycopg2_conn.close() + if inst.extensions: + db_connection.create_extensions(inst.extensions) conn = adaptor.conn try: diff --git a/tests/db_connection.py b/tests/db_connection.py index ee40d4e..338e3fe 100644 --- a/tests/db_connection.py +++ b/tests/db_connection.py @@ -1,6 +1,12 @@ +import contextlib import os import re +try: + import psycopg +except ModuleNotFoundError: + import psycopg2 as psycopg + DB_HOST = os.getenv("POSTGRES_HOST") IS_DSQL = bool(DB_HOST and re.match(r"^\w+\.dsql\.\w\w-\w+-[1-9]\.on\.aws$", DB_HOST)) @@ -43,3 +49,77 @@ def get_port(): if search_pattern.fullmatch(name): return int(val) return int(os.getenv("POSTGRES_PORT", "5432")) + + +connection_params = get_connection_params() + + +def connect(**kwargs): + kw = connection_params.copy() + kw.update(kwargs) + conn = psycopg.connect(**kw) + return conn + + +def create_db(): + "connect to test db" + try: + connect().close() + return False + except psycopg.OperationalError as exc: + dbname = connection_params["dbname"] + nosuch_db = 'database "%s" does not exist' % dbname + if nosuch_db in str(exc): + try: + master = connect(dbname="postgres") + master.rollback() + master.autocommit = True + cursor = master.cursor() + cursor.execute("CREATE DATABASE %s" % dbname) + cursor.close() + master.close() + except psycopg.Error as exc: + message = ( + "Unable to connect to or create test db %s.\nThe error is: %s" + % (dbname, exc) + ) + raise RuntimeError(message) + return True + + +def drop_db(): + "Drop test db" + try: + master = connect(dbname="postgres") + master.rollback() + master.autocommit = True + cursor = master.cursor() + cursor.execute("DROP DATABASE %s" % connection_params["dbname"]) + cursor.close() + master.close() + except psycopg.OperationalError: + pass + + +def create_extensions(extensions): + # always use psycopg2 connection to create extensions if necessary + conn = connect() + for extension in extensions: + try: + with conn.cursor() as cur: + cur.execute("CREATE EXTENSION {}".format(extension)) + conn.commit() + except ( + psycopg.errors.DuplicateObject, + psycopg.errors.UndefinedFile, # postgres <= 14 + psycopg.errors.FeatureNotSupported, # postgres >= 15 + ): + conn.rollback() + conn.close() + + +@contextlib.contextmanager +def conninfo(connection_params): + conn = psycopg.connect(**connection_params) + yield conn.info + conn.close() diff --git a/tox.ini b/tox.ini index 66334c5..a53cd57 100644 --- a/tox.ini +++ b/tox.ini @@ -4,6 +4,8 @@ envlist = py312-pg{13,14,15,16,17,18} vector psycopg28 + psycopg29-only + psycopg3-only [testenv] deps = pytest @@ -51,6 +53,18 @@ deps = pytest psycopg2==2.8.* docker = pg14 +[testenv:psycopg29-only] +base_python = python3.12 +deps = + pytest + psycopg2~=2.9 +docker = pg17 +[testenv:psycopg3-only] +base_python = python3.12 +deps = + pytest + psycopg[binary] +docker = pg17 [docker:pg13] image = postgres:13 environment= From 9ffbc8613d52306468e68d87cd42b8ba687ca3a4 Mon Sep 17 00:00:00 2001 From: Aryeh Leib Taurog Date: Wed, 24 Dec 2025 23:37:41 +0200 Subject: [PATCH 23/24] test: update docker compose test --- docker-compose.yml | 3 +-- docker/Dockerfile | 6 +++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/docker-compose.yml b/docker-compose.yml index 2dc8272..671b1af 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,4 +1,3 @@ -version: '2' services: pgcopy: build: @@ -15,7 +14,7 @@ services: - pgsql pgsql: - image: pgvector/pgvector:pg12 + image: pgvector/pgvector:pg18 environment: - POSTGRES_DB=pgcopy_test - POSTGRES_USER=postgres diff --git a/docker/Dockerfile b/docker/Dockerfile index c6d24c4..cada991 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.10-slim-buster +FROM python:3.14-slim ARG DEBIAN_FRONTEND=noninteractive @@ -7,13 +7,13 @@ RUN \ apt-get install -y --no-install-recommends \ gcc \ libpq-dev \ - netcat \ + netcat-traditional \ python3 \ python3-dev \ python3-pip \ python3-setuptools -RUN pip3 install pytest==8.1.1 +RUN pip3 install pytest psycopg COPY ./ /opt/install WORKDIR /opt/install RUN pip3 install . From b7faad9cb3b80f69d1451e6d43d2a502d2382ce9 Mon Sep 17 00:00:00 2001 From: Aryeh Leib Taurog Date: Wed, 24 Dec 2025 23:44:29 +0200 Subject: [PATCH 24/24] docs: list supported db adaptors --- README.rst | 17 +++++++++++++++-- docs/installation.rst | 17 +++++++++-------- docs/testing.rst | 5 +++++ pgcopy/copy.py | 8 +++++++- 4 files changed, 36 insertions(+), 11 deletions(-) diff --git a/README.rst b/README.rst index 1bc4b99..8398663 100644 --- a/README.rst +++ b/README.rst @@ -26,6 +26,7 @@ PostgreSQL with `binary copy`_. Features --------- +* Support for multiple db adaptors * Support for many data types * Support for multi-dimensional array types * Support for schema and schema search path @@ -42,7 +43,7 @@ Quickstart from datetime import datetime from pgcopy import CopyManager - import psycopg2 + import psycopg cols = ('id', 'timestamp', 'location', 'temperature') now = datetime.now() records = [ @@ -50,7 +51,7 @@ Quickstart (1, now, 'New York', 75.6), (2, now, 'Moscow', 54.3), ] - conn = psycopg2.connect(database='weather_db') + conn = psycopg.connect(database='weather_db') mgr = CopyManager(conn, 'measurements_table', cols) mgr.copy(records) @@ -59,6 +60,14 @@ Quickstart .. home-end +Supported Adaptors +------------------- + +* psycopg2_ +* psycopg_ +* pg8000_ +* PyGreSQL_ + Supported datatypes ------------------- @@ -96,6 +105,10 @@ Documentation .. _binary copy: http://www.postgresql.org/docs/9.3/static/sql-copy.html .. _psycopg2: https://pypi.org/project/psycopg2/ +.. _psycopg: https://pypi.org/project/psycopg/ +.. _pg8000: https://pypi.org/project/pg8000/ +.. _PyGreSQL: https://pypi.org/project/PyGreSQL/ + .. _pytz: https://pypi.org/project/pytz/ .. _pytest: https://pypi.org/project/pytest/ .. _Tox: https://tox.readthedocs.io/en/latest/ diff --git a/docs/installation.rst b/docs/installation.rst index 4f9b878..8c3e462 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -7,15 +7,14 @@ To install:: Dependencies """""""""""" -pgcopy requires pytz_ and the psycopg2_ db adapter. -pytest_ is required to run the tests. +pgcopy requires Python3, pytz_, and a db adaptor. The supported adaptors are: -Due to technical problems with binary distributions, `psycopg2 versions -2.8 and later`_ have separate packages for binary install. This complicates -installation in some situations, as it requires the dev tools to build psycopg2. +* psycopg2_ +* psycopg_ +* pg8000_ +* PyGreSQL_ -If you do not want to build psycopg2 for each installation, the recommended -approach is to create a psycopg2 wheel for distribution to production machines +pytest_ and one of psycopg2_ or psycopg_ is required to run the tests. Compatibility """"""""""""" @@ -29,7 +28,9 @@ PostgreSQL versions 13 -- 18, as well as `Aurora DSQL`_ Please upgrade to Python 3. .. _psycopg2: https://pypi.org/project/psycopg2/ +.. _psycopg: https://pypi.org/project/psycopg/ +.. _pg8000: https://pypi.org/project/pg8000/ +.. _PyGreSQL: https://pypi.org/project/PyGreSQL/ .. _pytz: https://pypi.org/project/pytz/ .. _pytest: https://pypi.org/project/pytest/ -.. _psycopg2 versions 2.8 and later: https://www.psycopg.org/docs/install#change-in-binary-packages-between-psycopg-2-7-and-2-8 .. _Aurora DSQL: https://docs.aws.amazon.com/aurora-dsql/latest/userguide/what-is-aurora-dsql.html diff --git a/docs/testing.rst b/docs/testing.rst index 95606c8..9fd4c5b 100644 --- a/docs/testing.rst +++ b/docs/testing.rst @@ -15,6 +15,9 @@ database: * ``POSTGRES_PASSWORD`` +One of psycopg2_ or psycopg_ is required to run the tests. The test suite +will automatically discover all supported db adaptors. + For more thorough testing, tox_ with tox-docker_ will run tests on python versions 3.9 -- 3.14 and postgresql versions 13 -- 18:: @@ -37,3 +40,5 @@ boto3 must be installed and ``POSTGRES_HOST`` set to the dsql endpoint. .. _pytest: https://pypi.org/project/pytest/ .. _tox: https://tox.wiki .. _tox-docker: https://tox-docker.readthedocs.io +.. _psycopg2: https://pypi.org/project/psycopg2/ +.. _psycopg: https://pypi.org/project/psycopg/ diff --git a/pgcopy/copy.py b/pgcopy/copy.py index a483bd5..52797c7 100644 --- a/pgcopy/copy.py +++ b/pgcopy/copy.py @@ -236,8 +236,14 @@ class CopyManager(object): Inspects the database on instantiation for the column types. + supported adaptors: + + * psycopg2 + * psycopg + * pg8000 + * PyGreSQL + :param conn: a database connection - :type conn: psycopg2 connection :param table: the table name. Schema may be specified using dot notation: ``schema.table``. :type table: str