diff --git a/README.rst b/README.rst index 1bc4b99..8398663 100644 --- a/README.rst +++ b/README.rst @@ -26,6 +26,7 @@ PostgreSQL with `binary copy`_. Features --------- +* Support for multiple db adaptors * Support for many data types * Support for multi-dimensional array types * Support for schema and schema search path @@ -42,7 +43,7 @@ Quickstart from datetime import datetime from pgcopy import CopyManager - import psycopg2 + import psycopg cols = ('id', 'timestamp', 'location', 'temperature') now = datetime.now() records = [ @@ -50,7 +51,7 @@ Quickstart (1, now, 'New York', 75.6), (2, now, 'Moscow', 54.3), ] - conn = psycopg2.connect(database='weather_db') + conn = psycopg.connect(database='weather_db') mgr = CopyManager(conn, 'measurements_table', cols) mgr.copy(records) @@ -59,6 +60,14 @@ Quickstart .. home-end +Supported Adaptors +------------------- + +* psycopg2_ +* psycopg_ +* pg8000_ +* PyGreSQL_ + Supported datatypes ------------------- @@ -96,6 +105,10 @@ Documentation .. _binary copy: http://www.postgresql.org/docs/9.3/static/sql-copy.html .. _psycopg2: https://pypi.org/project/psycopg2/ +.. _psycopg: https://pypi.org/project/psycopg/ +.. _pg8000: https://pypi.org/project/pg8000/ +.. _PyGreSQL: https://pypi.org/project/PyGreSQL/ + .. _pytz: https://pypi.org/project/pytz/ .. _pytest: https://pypi.org/project/pytest/ .. _Tox: https://tox.readthedocs.io/en/latest/ diff --git a/docker-compose.yml b/docker-compose.yml index 2dc8272..671b1af 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,4 +1,3 @@ -version: '2' services: pgcopy: build: @@ -15,7 +14,7 @@ services: - pgsql pgsql: - image: pgvector/pgvector:pg12 + image: pgvector/pgvector:pg18 environment: - POSTGRES_DB=pgcopy_test - POSTGRES_USER=postgres diff --git a/docker/Dockerfile b/docker/Dockerfile index c6d24c4..cada991 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.10-slim-buster +FROM python:3.14-slim ARG DEBIAN_FRONTEND=noninteractive @@ -7,13 +7,13 @@ RUN \ apt-get install -y --no-install-recommends \ gcc \ libpq-dev \ - netcat \ + netcat-traditional \ python3 \ python3-dev \ python3-pip \ python3-setuptools -RUN pip3 install pytest==8.1.1 +RUN pip3 install pytest psycopg COPY ./ /opt/install WORKDIR /opt/install RUN pip3 install . diff --git a/docs/installation.rst b/docs/installation.rst index 4f9b878..8c3e462 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -7,15 +7,14 @@ To install:: Dependencies """""""""""" -pgcopy requires pytz_ and the psycopg2_ db adapter. -pytest_ is required to run the tests. +pgcopy requires Python3, pytz_, and a db adaptor. The supported adaptors are: -Due to technical problems with binary distributions, `psycopg2 versions -2.8 and later`_ have separate packages for binary install. This complicates -installation in some situations, as it requires the dev tools to build psycopg2. +* psycopg2_ +* psycopg_ +* pg8000_ +* PyGreSQL_ -If you do not want to build psycopg2 for each installation, the recommended -approach is to create a psycopg2 wheel for distribution to production machines +pytest_ and one of psycopg2_ or psycopg_ is required to run the tests. Compatibility """"""""""""" @@ -29,7 +28,9 @@ PostgreSQL versions 13 -- 18, as well as `Aurora DSQL`_ Please upgrade to Python 3. .. _psycopg2: https://pypi.org/project/psycopg2/ +.. _psycopg: https://pypi.org/project/psycopg/ +.. _pg8000: https://pypi.org/project/pg8000/ +.. _PyGreSQL: https://pypi.org/project/PyGreSQL/ .. _pytz: https://pypi.org/project/pytz/ .. _pytest: https://pypi.org/project/pytest/ -.. _psycopg2 versions 2.8 and later: https://www.psycopg.org/docs/install#change-in-binary-packages-between-psycopg-2-7-and-2-8 .. _Aurora DSQL: https://docs.aws.amazon.com/aurora-dsql/latest/userguide/what-is-aurora-dsql.html diff --git a/docs/testing.rst b/docs/testing.rst index 95606c8..9fd4c5b 100644 --- a/docs/testing.rst +++ b/docs/testing.rst @@ -15,6 +15,9 @@ database: * ``POSTGRES_PASSWORD`` +One of psycopg2_ or psycopg_ is required to run the tests. The test suite +will automatically discover all supported db adaptors. + For more thorough testing, tox_ with tox-docker_ will run tests on python versions 3.9 -- 3.14 and postgresql versions 13 -- 18:: @@ -37,3 +40,5 @@ boto3 must be installed and ``POSTGRES_HOST`` set to the dsql endpoint. .. _pytest: https://pypi.org/project/pytest/ .. _tox: https://tox.wiki .. _tox-docker: https://tox-docker.readthedocs.io +.. _psycopg2: https://pypi.org/project/psycopg2/ +.. _psycopg: https://pypi.org/project/psycopg/ diff --git a/pgcopy/backend.py b/pgcopy/backend.py new file mode 100644 index 0000000..c90d3bd --- /dev/null +++ b/pgcopy/backend.py @@ -0,0 +1,254 @@ +"psycopg backends" +import codecs +import collections +import contextlib +import importlib +import os + +from .errors import UnsupportedConnectionError +from .thread import RaisingThread + + +def for_connection(conn): + sources = [cls.__module__.split(".")[0] for cls in conn.__class__.mro()] + if "psycopg2" in sources: + return Psycopg2Backend(conn) + if "psycopg" in sources: + return Psycopg3Backend(conn) + if "pgdb" in sources: + return PyGreSQLBackend(conn) + if "pg8000" in sources: + return Pg8000Backend(conn) + message = f"{conn.__class__.__name__} is not a supported connection type" + raise UnsupportedConnectionError(message) + + +def copy_sql(schema, table, columns): + column_list = '", "'.join(columns) + cmd = 'COPY "{0}"."{1}" ("{2}") FROM STDIN WITH BINARY' + return cmd.format(schema, table, column_list) + + +class Psycopg2Backend: + def __init__(self, conn): + self.conn = conn + self.adaptor = importlib.import_module("psycopg2") + self.adaptor.extras = importlib.import_module("psycopg2.extras") + + def get_encoding(self): + return self.adaptor.extensions.encodings[self.conn.encoding] + + def namedtuple_cursor(self): + factory = self.adaptor.extras.NamedTupleCursor + return self.conn.cursor(cursor_factory=factory) + + def copy(self, schema, table, columns, fobject_factory): + sql = copy_sql(schema, table, columns) + return Psycopg2Copy(self.conn, sql, fobject_factory) + + def threading_copy(self, schema, table, columns): + sql = copy_sql(schema, table, columns) + return Psycopg2ThreadingCopy(self.conn, sql) + + +class Psycopg2Copy: + def __init__(self, conn, sql, fobject_factory): + self.conn = conn + self.sql = sql + self.datastream = fobject_factory() + + def __enter__(self): + return self.datastream + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is None: + self.datastream.seek(0) + self.copystream() + self.datastream.close() + + def copystream(self): + with self.conn.cursor() as cur: + cur.copy_expert(self.sql, self.datastream) + + +class Psycopg2ThreadingCopy: + def __init__(self, conn, sql): + self.conn = conn + self.sql = sql + r_fd, w_fd = os.pipe() + self.rstream = os.fdopen(r_fd, "rb") + self.wstream = os.fdopen(w_fd, "wb") + + def __enter__(self): + self.copy_thread = RaisingThread(target=self.copystream) + self.copy_thread.start() + return self.wstream + + def __exit__(self, exc_type, exc_val, exc_tb): + self.wstream.close() + self.copy_thread.join() + + def copystream(self): + with self.conn.cursor() as cur: + cur.copy_expert(self.sql, self.rstream) + + +class Psycopg3Backend: + def __init__(self, conn): + self.conn = conn + self.adaptor = importlib.import_module("psycopg") + + def get_encoding(self): + return self.conn.info.encoding + + def namedtuple_cursor(self): + factory = self.adaptor.rows.namedtuple_row + return self.conn.cursor(row_factory=factory) + + @contextlib.contextmanager + def copy(self, schema, table, columns, _): + sql = copy_sql(schema, table, columns) + with self.conn.cursor() as cur: + with cur.copy(sql) as copy: + yield copy + + @contextlib.contextmanager + def threading_copy(self, schema, table, columns): + sql = copy_sql(schema, table, columns) + with self.conn.cursor() as cur: + with cur.copy(sql) as copy: + yield copy + + +class PyGreSQLBackend: + def __init__(self, conn): + self.conn = conn + + def get_encoding(self): + with self.conn.cursor() as cur: + cur.execute("SHOW client_encoding") + row = cur.fetchone() + return codecs.lookup(row.client_encoding).name + + def namedtuple_cursor(self): + return self.conn.cursor() + + def copy(self, schema, table, columns, fobject_factory): + return PyGreSQLCopy(self.conn, schema, table, columns, fobject_factory) + + +class PyGreSQLCopy: + def __init__(self, conn, schema, table, columns, fobject_factory): + self.conn = conn + self.table = f"{schema}.{table}" + self.columns = columns + self.datastream = fobject_factory() + + def __enter__(self): + return self.datastream + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is None: + self.datastream.seek(0) + self.copystream() + self.datastream.close() + + def copystream(self): + with self.conn.cursor() as cur: + cur.copy_from( + self.datastream, + self.table, + format="binary", + columns=self.columns, + ) + + +class PyGreSQLThreadingCopy: + def __init__(self, conn, schema, table, columns): + self.conn = conn + self.table = f"{schema}.{table}" + self.columns = columns + r_fd, w_fd = os.pipe() + self.rstream = os.fdopen(r_fd, "rb") + self.wstream = os.fdopen(w_fd, "wb") + + def __enter__(self): + self.copy_thread = RaisingThread(target=self.copystream) + self.copy_thread.start() + return self.wstream + + def __exit__(self, exc_type, exc_val, exc_tb): + self.wstream.close() + self.copy_thread.join() + + def copystream(self): + with self.conn.cursor() as cur: + cur.copy_from( + self.rstream, + self.table, + format="binary", + columns=self.columns, + ) + + +class Pg8000Backend: + NamedTupleCursor = None + + def __init__(self, conn): + self.conn = conn + + def get_encoding(self): + with contextlib.closing(self.namedtuple_cursor()) as cur: + cur.execute("SHOW client_encoding") + row = cur.fetchone() + return codecs.lookup(row.client_encoding).name + + def namedtuple_cursor(self): + if not Pg8000Backend.NamedTupleCursor: + cur = self.conn.cursor() + Cursor = cur.__class__ + cur.close() + + class NamedTupleCursor(Cursor): + def __next__(self): + val = super().__next__() + context = self._context + if context is None: + return val # raised an error already + rowclass = getattr(context, "_pgcopy_row_class", None) + if not rowclass: + columns = context.columns + if columns is None or len(columns) == 0: + return val # probably also raised an error + column_names = [col["name"] for col in columns] + rowclass = collections.namedtuple("Row", column_names) + context._pgcopy_row_class = rowclass + return rowclass(*val) + + Pg8000Backend.NamedTupleCursor = NamedTupleCursor + return Pg8000Backend.NamedTupleCursor(self.conn) + + def copy(self, schema, table, columns, fobject_factory): + sql = copy_sql(schema, table, columns) + return Pg8000Copy(self.conn, sql, fobject_factory) + + +class Pg8000Copy: + def __init__(self, conn, sql, fobject_factory): + self.conn = conn + self.sql = sql + self.datastream = fobject_factory() + + def __enter__(self): + return self.datastream + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is None: + self.datastream.seek(0) + self.copystream() + self.datastream.close() + + def copystream(self): + cur = self.conn.cursor() + cur.execute(self.sql, stream=self.datastream) + cur.close() diff --git a/pgcopy/copy.py b/pgcopy/copy.py index 17ff060..52797c7 100644 --- a/pgcopy/copy.py +++ b/pgcopy/copy.py @@ -11,9 +11,7 @@ except ImportError: pass -from psycopg2.extensions import encodings - -from . import errors, inspect, util +from . import backend, errors, inspect, util from .thread import RaisingThread __all__ = ["CopyManager"] @@ -238,8 +236,14 @@ class CopyManager(object): Inspects the database on instantiation for the column types. + supported adaptors: + + * psycopg2 + * psycopg + * pg8000 + * PyGreSQL + :param conn: a database connection - :type conn: psycopg2 connection :param table: the table name. Schema may be specified using dot notation: ``schema.table``. :type table: str @@ -257,7 +261,11 @@ def __init__(self, conn, table, cols): **type_formatters, **self.type_formatters, } - self.conn = conn + self.backend = backend.for_connection(conn) + self.implements_threading_copy = hasattr(self.backend, "threading_copy") + if not self.implements_threading_copy: + self.threading_copy = self.copy + if "." in table: self.schema, self.table = table.split(".", 1) else: @@ -267,8 +275,8 @@ def __init__(self, conn, table, cols): def compile(self): self.formatters = [] - type_dict = inspect.get_types(self.conn, self.schema, self.table) - encoding = encodings[self.conn.encoding] + type_dict = inspect.get_types(self.backend, self.schema, self.table) + encoding = self.backend.get_encoding() for column in self.cols: att = type_dict.get(column) if att is None: @@ -314,11 +322,9 @@ def copy(self, data, fobject_factory=tempfile.TemporaryFile): ``ValueError`` is raised if a null value is provided for a column with non-null constraint. """ - datastream = fobject_factory() - self.writestream(data, datastream) - datastream.seek(0) - self.copystream(datastream) - datastream.close() + self._copy( + data, self.backend.copy(self.schema, self.table, self.cols, fobject_factory) + ) def threading_copy(self, data): """ @@ -327,14 +333,18 @@ def threading_copy(self, data): :param data: the data to be inserted :type data: iterable of iterables """ - r_fd, w_fd = os.pipe() - rstream = os.fdopen(r_fd, "rb") - wstream = os.fdopen(w_fd, "wb") - copy_thread = RaisingThread(target=self.copystream, args=(rstream,)) - copy_thread.start() - self.writestream(data, wstream) - wstream.close() - copy_thread.join() + self._copy( + data, self.backend.threading_copy(self.schema, self.table, self.cols) + ) + + def _copy(self, data, copy): + try: + with copy as datastream: + self.writestream(data, datastream) + except Exception as e: + templ = "error doing binary copy into {0}.{1}:\n{2}" + e.message = templ.format(self.schema, self.table, e) + raise e def writestream(self, data, datastream): datastream.write(BINCOPY_HEADER) @@ -348,15 +358,3 @@ def writestream(self, data, datastream): rdat.extend(d) datastream.write(struct.pack("".join(fmt), *rdat)) datastream.write(BINCOPY_TRAILER) - - def copystream(self, datastream): - columns = '", "'.join(self.cols) - cmd = 'COPY "{0}"."{1}" ("{2}") FROM STDIN WITH BINARY' - sql = cmd.format(self.schema, self.table, columns) - cursor = self.conn.cursor() - try: - cursor.copy_expert(sql, datastream) - except Exception as e: - templ = "error doing binary copy into {0}.{1}:\n{2}" - e.message = templ.format(self.schema, self.table, e) - raise e diff --git a/pgcopy/errors/__init__.py b/pgcopy/errors/__init__.py index fc09d33..90ead89 100644 --- a/pgcopy/errors/__init__.py +++ b/pgcopy/errors/__init__.py @@ -9,3 +9,7 @@ from .py2 import raise_from else: from .py3 import raise_from + + +class UnsupportedConnectionError(TypeError): + "connection type not supported" diff --git a/pgcopy/inspect.py b/pgcopy/inspect.py index eca2195..c57c1fe 100644 --- a/pgcopy/inspect.py +++ b/pgcopy/inspect.py @@ -1,7 +1,7 @@ -from psycopg2.extras import NamedTupleCursor +"inspect column types" -def get_types(conn, schema, table): +def get_types(backend, schema, table): # for arrays: # typname has '_' prefix # attndims > 0 @@ -24,6 +24,6 @@ def get_types(conn, schema, table): WHERE n.nspname = %s and relname = %s and attnum > 0 ORDER BY c.relname, a.attnum; """ - cursor = conn.cursor(cursor_factory=NamedTupleCursor) + cursor = backend.namedtuple_cursor() cursor.execute(query, (schema, table)) return {r.attname: r for r in cursor} diff --git a/pgcopy/util.py b/pgcopy/util.py index ff29566..3191a6f 100644 --- a/pgcopy/util.py +++ b/pgcopy/util.py @@ -3,7 +3,6 @@ import string from datetime import datetime, time -from psycopg2 import sql from pytz import UTC @@ -34,14 +33,13 @@ def array_iter(arr): def get_schema(conn, table): cur = conn.cursor() - quoted_table = sql.Identifier(table).as_string(cur) query = """ SELECT n.nspname, c.relname FROM pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace WHERE c.oid = %s::regclass """ - cur.execute(query, (quoted_table,)) + cur.execute(query, (f'"{table}"',)) return cur.fetchone()[0] diff --git a/setup.py b/setup.py index 3aee002..ffb6c74 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ def get_version(package_name, default="0.1"): license="MIT", url="https://pgcopy.readthedocs.io/en/latest/", packages=["pgcopy", "pgcopy.errors", "pgcopy.contrib"], - install_requires=["psycopg2", "pytz"], + install_requires=["pytz"], classifiers=[ "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.9", diff --git a/tests/adaptor.py b/tests/adaptor.py new file mode 100644 index 0000000..96ff106 --- /dev/null +++ b/tests/adaptor.py @@ -0,0 +1,131 @@ +import argparse +import codecs +import contextlib +import importlib +import os +import sys + +from . import db_connection + + +def available_adaptors(): + adaptors = [Psycopg2, Psycopg3, PyGreSQL, Pg8000] + return [a for a in adaptors if a.load()] + + +class Adaptor: + module_names: list[str] + pgcode_attribute = "sqlstate" + m: argparse.Namespace + + @classmethod + def load(cls): + try: + modules = [importlib.import_module(m) for m in cls.module_names] + moddict = dict(zip(map(modname, cls.module_names), modules)) + cls.m = argparse.Namespace(**moddict) + return True + except ModuleNotFoundError: + return False + + @classmethod + def get_pgcode(cls, err): + return getattr(err, cls.pgcode_attribute, None) + + +def modname(modpath): + return modpath.rsplit(".", 1)[-1] + + +class Psycopg2(Adaptor): + module_names = ["psycopg2", "psycopg2.extras"] + pgcode_attribute = "pgcode" + + def __init__(self, connection_params, client_encoding): + self.conn = self.m.psycopg2.connect( + connection_factory=self.m.extras.LoggingConnection, + **connection_params, + ) + self.conn.initialize(sys.stderr) + self.conn.autocommit = False + self.conn.set_client_encoding(client_encoding) + self.unsupported_type = self.m.psycopg2.errors.UndefinedObject + self.integrity_error = self.m.psycopg2.errors.IntegrityError + + @staticmethod + def supports_encoding(encoding): + return True + + +class Psycopg3(Adaptor): + module_names = ["psycopg"] + + def __init__(self, connection_params, client_encoding): + self.conn = self.m.psycopg.connect(**connection_params) + self.conn.autocommit = False + self.conn.execute(f"SET client_encoding='{client_encoding}'") + self.unsupported_type = self.m.psycopg.errors.UndefinedObject + self.integrity_error = self.m.psycopg.errors.IntegrityError + + @staticmethod + def supports_encoding(encoding): + return True + + +class PyGreSQL(Adaptor): + module_names = ["pgdb"] + + def __init__(self, connection_params, client_encoding): + self.conn = self.m.pgdb.connect(**connection_params) + self.conn.autocommit = False + self.conn.execute(f"SET client_encoding='{client_encoding}'") + self.unsupported_type = self.conn.ProgrammingError + self.integrity_error = self.conn.IntegrityError + + @staticmethod + def supports_encoding(encoding): + try: + codecs.lookup(encoding) + return True + except LookupError: + return False + + +class Pg8000(Adaptor): + module_names = ["pg8000.dbapi", "pg8000.exceptions"] + + def __init__(self, connection_params, client_encoding): + params = self.get_connection_parameters(connection_params) + self.conn = self.m.dbapi.connect(**params) + self.conn.autocommit = False + with contextlib.closing(self.conn.cursor()) as cur: + cur.execute(f"SET client_encoding='{client_encoding}'") + self.unsupported_type = self.m.exceptions.DatabaseError + self.integrity_error = self.m.exceptions.DatabaseError + + def get_connection_parameters(self, connection_params): + with db_connection.conninfo(connection_params) as conninfo: + parameters = { + "user": conninfo.user, + "database": conninfo.dbname, + } + host = conninfo.host + if host.startswith("/"): + sock = f"{host}/.s.PGSQL.{conninfo.port}" + if os.path.exists(sock): + parameters["unix_sock"] = sock + return parameters + parameters["host"] = "localhost" + else: + parameters["host"] = host + parameters["port"] = conninfo.port + parameters["password"] = conninfo.password + return parameters + + @staticmethod + def supports_encoding(encoding): + return encoding.upper() == "UTF8" + + @classmethod + def get_pgcode(cls, err): + return err.args[0]["C"] diff --git a/tests/conftest.py b/tests/conftest.py index ec3020b..23599fc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,128 +1,90 @@ -import sys +import contextlib -import psycopg2 import pytest -from psycopg2.extras import LoggingConnection from . import db_connection +from .adaptor import available_adaptors from .db import TemporaryTable -connection_params = db_connection.get_connection_params() +# pylint: disable=redefined-outer-name,consider-using-f-string @pytest.fixture(scope="session") def db(): - drop = create_db() + drop = db_connection.create_db() yield if drop: - try: - drop_db() - except psycopg2.OperationalError: - pass - - -def connect(**kwargs): - kw = connection_params.copy() - kw.update(kwargs) - conn = psycopg2.connect(connection_factory=LoggingConnection, **kw) - conn.initialize(sys.stderr) - return conn - - -def create_db(): - "connect to test db" - try: - connect().close() - return False - except psycopg2.OperationalError as exc: - nosuch_db = 'database "%s" does not exist' % connection_params["dbname"] - if nosuch_db in str(exc): - try: - master = connect(dbname="postgres") - master.rollback() - master.autocommit = True - cursor = master.cursor() - cursor.execute("CREATE DATABASE %s" % connection_params["dbname"]) - cursor.close() - master.close() - except psycopg2.Error as exc: - message = ( - "Unable to connect to or create test db " - + connection_params["dbname"] - + ".\nThe error is: %s" % exc - ) - raise RuntimeError(message) - return True - - -def drop_db(): - "Drop test db" - master = connect(dbname="postgres") - master.rollback() - master.autocommit = True - cursor = master.cursor() - cursor.execute("DROP DATABASE %s" % connection_params["dbname"]) - cursor.close() - master.close() + db_connection.drop_db() @pytest.fixture -def conn(request, db): +def client_encoding(request): + return getattr(request, "param", "UTF8") + + +@pytest.fixture(params=available_adaptors()) +def adaptor(request, db, client_encoding): + if not request.param.supports_encoding(client_encoding): + pytest.skip("Unsupported encoding for {request.param}") + adaptor = request.param(db_connection.connection_params, client_encoding) inst = request.instance if isinstance(inst, TemporaryTable): if db_connection.IS_DSQL: - yield from dsql_table(request, inst) + yield from dsql_table(adaptor, inst) else: - yield from temporary_table(request, inst) + yield from temporary_table(adaptor, inst) else: - yield from no_table(request) + yield from no_table(adaptor) -def dsql_table(request, inst): - for conn in temporary_table(request, inst): +def dsql_table(adaptor, inst): + for adaptor in temporary_table(adaptor, inst): + conn = adaptor.conn conn.commit() yield conn conn.commit() if drop_sql := inst.drop_sql(): - with conn.cursor() as cur: + with contextlib.closing(conn.cursor()) as cur: cur.execute(drop_sql) conn.commit() -def temporary_table(request, inst): - for conn in no_table(request): - for extension in inst.extensions: - try: - with conn.cursor() as cur: - cur.execute("CREATE EXTENSION {}".format(extension)) - conn.commit() - except ( - psycopg2.errors.DuplicateObject, - psycopg2.errors.UndefinedFile, # postgres <= 14 - psycopg2.errors.FeatureNotSupported, # postgres >= 15 - ): - conn.rollback() +def temporary_table(adaptor, inst): + for adaptor in no_table(adaptor): + if inst.extensions: + db_connection.create_extensions(inst.extensions) + + conn = adaptor.conn try: - with conn.cursor() as cur: + with contextlib.closing(conn.cursor()) as cur: cur.execute(inst.create_sql(inst.tempschema)) - except ( - psycopg2.errors.FeatureNotSupported, - psycopg2.errors.UndefinedObject, - ) as e: + except adaptor.unsupported_type as e: + pgcode = adaptor.get_pgcode(e) conn.rollback() - pytest.skip("Unsupported datatype") - yield conn + if pgcode == "42704": + pytest.skip("Unsupported datatype") + else: + raise + yield adaptor -def no_table(request): - conn = connect() - conn.autocommit = False - conn.set_client_encoding(getattr(request, "param", "UTF8")) - yield conn +def no_table(adaptor): + conn = adaptor.conn + yield adaptor conn.rollback() conn.close() +@pytest.fixture +def conn(adaptor): + return adaptor.conn + + +@pytest.fixture +def integrity_error(adaptor): + return adaptor.integrity_error + + @pytest.fixture def cursor(conn): cur = conn.cursor() diff --git a/tests/db_connection.py b/tests/db_connection.py index ee40d4e..338e3fe 100644 --- a/tests/db_connection.py +++ b/tests/db_connection.py @@ -1,6 +1,12 @@ +import contextlib import os import re +try: + import psycopg +except ModuleNotFoundError: + import psycopg2 as psycopg + DB_HOST = os.getenv("POSTGRES_HOST") IS_DSQL = bool(DB_HOST and re.match(r"^\w+\.dsql\.\w\w-\w+-[1-9]\.on\.aws$", DB_HOST)) @@ -43,3 +49,77 @@ def get_port(): if search_pattern.fullmatch(name): return int(val) return int(os.getenv("POSTGRES_PORT", "5432")) + + +connection_params = get_connection_params() + + +def connect(**kwargs): + kw = connection_params.copy() + kw.update(kwargs) + conn = psycopg.connect(**kw) + return conn + + +def create_db(): + "connect to test db" + try: + connect().close() + return False + except psycopg.OperationalError as exc: + dbname = connection_params["dbname"] + nosuch_db = 'database "%s" does not exist' % dbname + if nosuch_db in str(exc): + try: + master = connect(dbname="postgres") + master.rollback() + master.autocommit = True + cursor = master.cursor() + cursor.execute("CREATE DATABASE %s" % dbname) + cursor.close() + master.close() + except psycopg.Error as exc: + message = ( + "Unable to connect to or create test db %s.\nThe error is: %s" + % (dbname, exc) + ) + raise RuntimeError(message) + return True + + +def drop_db(): + "Drop test db" + try: + master = connect(dbname="postgres") + master.rollback() + master.autocommit = True + cursor = master.cursor() + cursor.execute("DROP DATABASE %s" % connection_params["dbname"]) + cursor.close() + master.close() + except psycopg.OperationalError: + pass + + +def create_extensions(extensions): + # always use psycopg2 connection to create extensions if necessary + conn = connect() + for extension in extensions: + try: + with conn.cursor() as cur: + cur.execute("CREATE EXTENSION {}".format(extension)) + conn.commit() + except ( + psycopg.errors.DuplicateObject, + psycopg.errors.UndefinedFile, # postgres <= 14 + psycopg.errors.FeatureNotSupported, # postgres >= 15 + ): + conn.rollback() + conn.close() + + +@contextlib.contextmanager +def conninfo(connection_params): + conn = psycopg.connect(**connection_params) + yield conn.info + conn.close() diff --git a/tests/test_datatypes.py b/tests/test_datatypes.py index 81221a1..67e1056 100644 --- a/tests/test_datatypes.py +++ b/tests/test_datatypes.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import unicode_literals +import contextlib import decimal import json import sys @@ -8,21 +9,21 @@ import pytest +# pylint: disable=consider-using-f-string + if sys.version_info < (3,): memoryview = buffer -import psycopg2.extensions from pgcopy import CopyManager, util from . import db -def test_connection_encoding(conn): - assert conn.encoding == "UTF8" - - def test_db_encoding(conn): - assert conn.info.parameter_status("server_encoding") == "UTF8" + with contextlib.closing(conn.cursor()) as cur: + cur.execute("SHOW server_encoding") + res = cur.fetchone() + assert res[0] == "UTF8" class TypeMixin(db.TemporaryTable): @@ -69,9 +70,10 @@ class TestEncoding(TypeMixin): datatypes = ["varchar(12)"] data = [("database",), ("מוסד נתונים",)] - @pytest.mark.parametrize("conn", ["UTF8", "ISO_8859_8", "WIN1255"], indirect=True) + @pytest.mark.parametrize( + "client_encoding", ["UTF8", "ISO_8859_8", "WIN1255"], indirect=True + ) def test_type(self, conn, cursor, schema_table, data): - psycopg2.extensions.register_type(psycopg2.extensions.UNICODE, cursor) super(TestEncoding, self).test_type(conn, cursor, schema_table, data) @@ -240,8 +242,7 @@ class TestBytea(TypeMixin): ] def cast(self, v): - assert isinstance(v, memoryview) - return bytes(v) + return bytes(v) if isinstance(v, memoryview) else v class TestTime(TypeMixin): @@ -295,7 +296,7 @@ class TestUUID(TypeMixin): ] def cast(self, v): - return uuid.UUID(v) + return uuid.UUID(v) if isinstance(v, str) else v def expected(self, rec): return (self.cast(v) if isinstance(v, str) else v for v in rec) diff --git a/tests/test_errors.py b/tests/test_errors.py index 9d280fe..b8ff750 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -1,3 +1,4 @@ +import pgcopy.errors import pytest from pgcopy import CopyManager @@ -42,3 +43,13 @@ def test_dropped_col(self, conn, cursor, schema): msg = '"{}" is not a column of table "{}"."{}"' with pytest.raises(ValueError, match=msg.format(col, schema, self.table)): CopyManager(conn, self.table, self.cols) + + +def test_unsupported_connection(): + message = "FakeConnection is not a supported connection type" + with pytest.raises(pgcopy.errors.UnsupportedConnectionError, match=message): + CopyManager(FakeConnection(), "fake_table", ["id", "name"]) + + +class FakeConnection: + encoding = "utf8" diff --git a/tests/test_replace.py b/tests/test_replace.py index 603596b..6a80e71 100644 --- a/tests/test_replace.py +++ b/tests/test_replace.py @@ -1,17 +1,19 @@ import contextlib -import psycopg2 import pytest from pgcopy import Replace, util from . import db, db_connection - @pytest.mark.skipif(db_connection.IS_DSQL, reason="tests not supported on dsql") class TemporaryTable(db.TemporaryTable): id_col = False +def tuplist(iterable): + return [tuple(row) for row in iterable] + + class TestRenameReplace(TemporaryTable): datatypes = ["integer"] mixed_case = False @@ -25,11 +27,11 @@ def test_rename_replace(self, conn, cursor, schema): with util.RenameReplace(conn, self.table, xform) as temp: cursor.executemany(sql.format(temp), [(36,), (72,)]) cursor.execute("SELECT * FROM {}".format(self.table)) - assert list(cursor) == [(36,), (72,)] + assert tuplist(cursor) == [(36,), (72,)] cursor.execute("SELECT * FROM v") - assert list(cursor) == [(37,), (73,)] + assert tuplist(cursor) == [(37,), (73,)] cursor.execute("SELECT * FROM {}".format(self.table + "_old")) - assert list(cursor) == [(1,), (2,)] + assert tuplist(cursor) == [(1,), (2,)] class TestReplaceFallbackSchema(TemporaryTable): @@ -47,7 +49,7 @@ def test_fallback_schema_honors_search_path( with Replace(conn, self.table) as temp: cursor.execute(sql.format(temp), (1,)) cursor.execute("SELECT * FROM {}".format(schema_table)) - assert list(cursor) == [(1,)] + assert tuplist(cursor) == [(1,)] class TestReplaceDefault(TemporaryTable): @@ -67,20 +69,23 @@ def test_replace_with_default(self, conn, cursor, schema_table): with Replace(conn, schema_table) as temp: cursor.execute(sql.format(temp), (1,)) cursor.execute("SELECT * FROM {}".format(schema_table)) - assert list(cursor) == [(1, 3)] + assert tuplist(cursor) == [(1, 3)] -@contextlib.contextmanager -def replace_raises(conn, table, exc=psycopg2.IntegrityError): - """ - Wrap Replace context manager and assert - exception is thrown on context exit - """ - r = Replace(conn, table) - yield r.__enter__() - with pytest.raises(exc): - r.__exit__(None, None, None) +@pytest.fixture +def replace_raises(conn, schema_table, integrity_error): + @contextlib.contextmanager + def _replace_raises(): + """ + Wrap Replace context manager and assert + exception is thrown on context exit + """ + r = Replace(conn, schema_table) + yield r.__enter__() + with pytest.raises(integrity_error): + r.__exit__(None, None, None) + return _replace_raises class TestReplaceNotNull(TemporaryTable): mixed_case = False @@ -90,15 +95,15 @@ class TestReplaceNotNull(TemporaryTable): "integer NOT NULL", ] - def test_replace_not_null(self, conn, cursor, schema_table): + def test_replace_not_null(self, replace_raises, cursor): """ Not-null constraint is added on exit """ sql = 'INSERT INTO {} ("a") VALUES (%s)' - with replace_raises(conn, schema_table) as temp: + with replace_raises() as temp: cursor.execute(sql.format(temp), (1,)) cursor.execute("SELECT * FROM {}".format(temp)) - assert list(cursor) == [(1, None)] + assert tuplist(cursor) == [(1, None)] class TestReplaceConstraint(TemporaryTable): @@ -108,12 +113,12 @@ class TestReplaceConstraint(TemporaryTable): "integer CHECK (a > 5)", ] - def test_replace_constraint(self, conn, cursor, schema_table): + def test_replace_constraint(self, replace_raises, cursor): sql = 'INSERT INTO {} ("a") VALUES (%s)' - with replace_raises(conn, schema_table) as temp: + with replace_raises() as temp: cursor.execute(sql.format(temp), (1,)) cursor.execute("SELECT * FROM {}".format(temp)) - assert list(cursor) == [(1,)] + assert tuplist(cursor) == [(1,)] class TestReplaceNamedConstraint(TemporaryTable): @@ -137,16 +142,16 @@ class TestReplaceUniqueIndex(TemporaryTable): "integer UNIQUE", ] - def test_replace_unique_index(self, conn, cursor, schema_table): + def test_replace_unique_index(self, replace_raises, cursor): """ Not-null constraint is added on exit """ sql = 'INSERT INTO {} ("a") VALUES (%s)' - with replace_raises(conn, schema_table) as temp: + with replace_raises() as temp: cursor.execute(sql.format(temp), (1,)) cursor.execute(sql.format(temp), (1,)) cursor.execute("SELECT * FROM {}".format(temp)) - assert list(cursor) == [(1,), (1,)] + assert tuplist(cursor) == [(1,), (1,)] class TestReplaceView(TemporaryTable): @@ -160,7 +165,7 @@ def test_replace_with_view(self, conn, cursor, schema_table): with Replace(conn, schema_table) as temp: cursor.execute(sql.format(temp), (1,)) cursor.execute("SELECT * FROM v") - assert list(cursor) == [(2,)] + assert tuplist(cursor) == [(2,)] class TestReplaceViewMultiSchema(TemporaryTable): @@ -175,7 +180,7 @@ def test_replace_view_in_different_schema(self, conn, cursor, schema_table): with Replace(conn, schema_table) as temp: cursor.execute(sql.format(temp), (1,)) cursor.execute("SELECT * FROM ns.v") - assert list(cursor) == [(2,)] + assert tuplist(cursor) == [(2,)] class TestReplaceTrigger(TemporaryTable): @@ -208,7 +213,7 @@ def test_replace_with_trigger(self, conn, cursor, schema_table): cursor.execute(sql.format(temp), (1, 1)) cursor.execute(sql.format(schema_table), (2, 1)) cursor.execute("SELECT * FROM {}".format(schema_table)) - assert list(cursor) == [(1, 1), (2, 8)] + assert tuplist(cursor) == [(1, 1), (2, 8)] class TestReplaceSequence(TemporaryTable): @@ -227,4 +232,4 @@ def test_replace_with_sequence(self, conn, cursor, schema_table): cursor.execute(sql.format(schema_table), (40,)) cursor.execute(sql.format(schema_table), (40,)) cursor.execute("SELECT * FROM {}".format(schema_table)) - assert list(cursor) == [(30, 3), (40, 5)] + assert tuplist(cursor) == [(30, 3), (40, 5)] diff --git a/tests/test_threading_copy.py b/tests/test_threading_copy.py index db229df..6e3b4da 100644 --- a/tests/test_threading_copy.py +++ b/tests/test_threading_copy.py @@ -1,6 +1,5 @@ import pytest from pgcopy import CopyManager -from psycopg2.errors import BadCopyFileFormat from . import test_datatypes @@ -13,6 +12,8 @@ class TestThreadingCopy(test_datatypes.TypeMixin): def test_threading_copy(self, conn, cursor, schema_table, data): mgr = CopyManager(conn, self.table, self.cols) + if not mgr.implements_threading_copy: + pytest.skip("threading_copy not implemented") mgr.threading_copy(data) select_list = ",".join(self.cols) cursor.execute(self.select_sql(schema_table)) @@ -21,11 +22,15 @@ def test_threading_copy(self, conn, cursor, schema_table, data): def test_threading_copy_error(self, conn, cursor): data = [{}] mgr = CopyManager(conn, self.table, self.cols) - with pytest.raises(BadCopyFileFormat): + if not mgr.implements_threading_copy: + pytest.skip("threading_copy not implemented") + with pytest.raises(conn.DataError): mgr.threading_copy(data) def test_threading_copy_generator(self, conn, cursor, schema_table, data): mgr = CopyManager(conn, self.table, self.cols) + if not mgr.implements_threading_copy: + pytest.skip("threading_copy not implemented") mgr.threading_copy(iter(data)) select_list = ",".join(self.cols) cursor.execute(self.select_sql(schema_table)) @@ -34,6 +39,8 @@ def test_threading_copy_generator(self, conn, cursor, schema_table, data): def test_threading_copy_empty_generator(self, conn, cursor, schema_table): data = [] mgr = CopyManager(conn, self.table, self.cols) + if not mgr.implements_threading_copy: + pytest.skip("threading_copy not implemented") mgr.threading_copy(iter(data)) select_list = ",".join(self.cols) cursor.execute(self.select_sql(schema_table)) diff --git a/tox.ini b/tox.ini index 699e4fa..a53cd57 100644 --- a/tox.ini +++ b/tox.ini @@ -4,18 +4,22 @@ envlist = py312-pg{13,14,15,16,17,18} vector psycopg28 + psycopg29-only + psycopg3-only [testenv] deps = pytest psycopg2~=2.9 -commands = python -m pytest tests/ + psycopg[binary] + pg8000 + PyGreSQL +commands = python -m pytest tests/ --tb=native docker = pg16 setenv = POSTGRES_DB=pgcopy_tox_test POSTGRES_HOST=localhost POSTGRES_USER=postgres POSTGRES_PASSWORD=postgres -[testenv:python2.7] [testenv:py312-pg13] docker = pg13 [testenv:py312-pg14] @@ -37,6 +41,9 @@ deps = pytest pytest-cov psycopg2~=2.9 + psycopg[binary] + pg8000 + PyGreSQL commands = pytest --cov-report=term --cov-report=lcov:coverage.lcov --cov=pgcopy/ tests/ docker = pgvector @@ -46,6 +53,18 @@ deps = pytest psycopg2==2.8.* docker = pg14 +[testenv:psycopg29-only] +base_python = python3.12 +deps = + pytest + psycopg2~=2.9 +docker = pg17 +[testenv:psycopg3-only] +base_python = python3.12 +deps = + pytest + psycopg[binary] +docker = pg17 [docker:pg13] image = postgres:13 environment=