diff --git a/docs/integration/connection-pool.md b/docs/integration/connection-pool.md new file mode 100644 index 0000000..ca87f89 --- /dev/null +++ b/docs/integration/connection-pool.md @@ -0,0 +1,59 @@ +# Connection Pooling + +Connection pooling allows you to reuse database connections across multiple operations, reducing the overhead of establishing new connections. This is particularly useful in web applications where many concurrent requests need database access. + +## Usage + +Pass a `ConnectionPool` (or `AsyncConnectionPool`) to `PgDb` instead of a raw connection: + +```{.python continuation fixture:postgres_container} +from psycopg_pool import ConnectionPool +from pydantic import BaseModel +from typing import Annotated + +from embar.column.common import Integer, Text +from embar.config import EmbarConfig +from embar.db.pg import PgDb +from embar.table import Table + + +class User(Table): + embar_config: EmbarConfig = EmbarConfig(table_name="users") + id: Integer = Integer(primary=True) + name: Text = Text() + + +# Create a connection pool +pool = ConnectionPool("postgres://pg:pw@localhost:25432/db", open=True) + +# Pass the pool to PgDb +db = PgDb(pool) + +# Run migrations +db.migrate([User]).run() + +# Insert a user +db.insert(User).values(User(id=1, name="Alice")).run() + +# Query it back +class UserRead(BaseModel): + id: Annotated[int, User.id] + name: Annotated[str, User.name] + +users = db.select(UserRead).from_(User).run() +print(users) +# [UserRead(id=1, name='Alice')] + +# Clean up +pool.close() +``` + +## Unopened Pools + +If you create a pool with `open=False`, it will be automatically opened on first use: + +```python notest +pool = ConnectionPool("postgres://...", open=False) +db = PgDb(pool) +# Pool is opened automatically when the first query runs +``` diff --git a/mkdocs.yml b/mkdocs.yml index e623e40..c4907da 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -49,6 +49,7 @@ nav: - Raw SQL: query/raw-sql.md - Integration: + - Connection pool: integration/connection-pool.md - Migrations: migrations.md - Types: types.md diff --git a/pyproject.toml b/pyproject.toml index 156ca80..0747da4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ requires-python = ">=3.14" dependencies = [ "psycopg[binary]>=3.2.11", + "psycopg-pool>=3.3.0", "pydantic>=2.12.4", ] diff --git a/src/embar/db/pg.py b/src/embar/db/pg.py index e314e35..78ed723 100644 --- a/src/embar/db/pg.py +++ b/src/embar/db/pg.py @@ -13,6 +13,7 @@ from psycopg import AsyncConnection, AsyncTransaction, Connection, Transaction from psycopg.types.json import Json +from psycopg_pool import AsyncConnectionPool, ConnectionPool from pydantic import BaseModel from embar.column.base import EnumBase @@ -28,6 +29,68 @@ from embar.table import Table +class ConnectionWrapper[C: Connection | ConnectionPool]: + conn_or_pool: C + + def __init__(self, conn_or_pool: C): + self.conn_or_pool = conn_or_pool + self._cm: AbstractContextManager[Connection] | None = None + + def __enter__(self) -> Connection: + if isinstance(self.conn_or_pool, Connection): + return self.conn_or_pool + + # Ensure pool is open (idempotent if already open) + self.conn_or_pool.open() + + self._cm = self.conn_or_pool.connection() + return self._cm.__enter__() + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: types.TracebackType | None, + ) -> bool | None: + if self._cm is not None: + return self._cm.__exit__(exc_type, exc_val, exc_tb) + return None + + def close(self): + self.conn_or_pool.close() + + +class AsyncConnectionWrapper[C: AsyncConnection | AsyncConnectionPool]: + conn_or_pool: C + + def __init__(self, conn_or_pool: C): + self.conn_or_pool = conn_or_pool + self._cm: AbstractAsyncContextManager[AsyncConnection] | None = None + + async def __aenter__(self) -> AsyncConnection: + if isinstance(self.conn_or_pool, AsyncConnection): + return self.conn_or_pool + + # Ensure pool is open (must be awaited for async pools) + await self.conn_or_pool.open() + + self._cm = self.conn_or_pool.connection() + return await self._cm.__aenter__() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: types.TracebackType | None, + ) -> bool | None: + if self._cm is not None: + return await self._cm.__aexit__(exc_type, exc_val, exc_tb) + return None + + async def close(self): + await self.conn_or_pool.close() + + @final class PgDb(DbBase): """ @@ -35,21 +98,21 @@ class PgDb(DbBase): """ db_type = "postgres" - conn: Connection + conn_wrapper: ConnectionWrapper[Connection | ConnectionPool] _commit_after_execute: bool = True - def __init__(self, connection: Connection): + def __init__(self, connection_or_pool: Connection | ConnectionPool): """ Create a new PgDb instance. """ - self.conn = connection + self.conn_wrapper = ConnectionWrapper(connection_or_pool) def close(self): """ Close the database connection. """ - if self.conn: - self.conn.close() + if self.conn_wrapper: + self.conn_wrapper.close() def transaction(self) -> PgDbTransaction: """ @@ -63,9 +126,7 @@ def transaction(self) -> PgDbTransaction: ... ``` """ - db_copy = PgDb(self.conn) - db_copy._commit_after_execute = False - return PgDbTransaction(db_copy) + return PgDbTransaction(self) def select[M: BaseModel](self, model: type[M]) -> SelectQuery[M, Self]: """ @@ -122,9 +183,10 @@ def execute(self, query: QuerySingle) -> None: """ Execute a query without returning results. """ - self.conn.execute(query.sql, query.params) # pyright:ignore[reportArgumentType] - if self._commit_after_execute: - self.conn.commit() + with self.conn_wrapper as conn: + conn.execute(query.sql, query.params) # pyright:ignore[reportArgumentType] + if self._commit_after_execute: + conn.commit() @override def executemany(self, query: QueryMany): @@ -132,32 +194,34 @@ def executemany(self, query: QueryMany): Execute a query with multiple parameter sets. """ params = _jsonify_dicts(query.many_params) - with self.conn.cursor() as cur: - cur.executemany(query.sql, params) # pyright:ignore[reportArgumentType] - if self._commit_after_execute: - self.conn.commit() + with self.conn_wrapper as conn: + with conn.cursor() as cur: + cur.executemany(query.sql, params) # pyright:ignore[reportArgumentType] + if self._commit_after_execute: + conn.commit() @override def fetch(self, query: QuerySingle | QueryMany) -> list[dict[str, Any]]: """ Execute a query and return results as a list of dicts. """ - with self.conn.cursor() as cur: - if isinstance(query, QuerySingle): - cur.execute(query.sql, query.params) # pyright:ignore[reportArgumentType] - else: - cur.executemany(query.sql, query.many_params, returning=True) # pyright:ignore[reportArgumentType] - - if cur.description is None: - return [] - columns: list[str] = [desc[0] for desc in cur.description] - results: list[dict[str, Any]] = [] - for row in cur.fetchall(): - data = dict(zip(columns, row)) - results.append(data) - if self._commit_after_execute: - self.conn.commit() # Commit after SELECT - return results + with self.conn_wrapper as conn: + with conn.cursor() as cur: + if isinstance(query, QuerySingle): + cur.execute(query.sql, query.params) # pyright:ignore[reportArgumentType] + else: + cur.executemany(query.sql, query.many_params, returning=True) # pyright:ignore[reportArgumentType] + + if cur.description is None: + return [] + columns: list[str] = [desc[0] for desc in cur.description] + results: list[dict[str, Any]] = [] + for row in cur.fetchall(): + data = dict(zip(columns, row)) + results.append(data) + if self._commit_after_execute: + conn.commit() # Commit after SELECT + return results @override def truncate(self, schema: str | None = None): @@ -169,10 +233,11 @@ def truncate(self, schema: str | None = None): if tables is None: return table_names = ", ".join(tables) - with self.conn.cursor() as cursor: - cursor.execute(f"TRUNCATE TABLE {table_names} CASCADE") # pyright:ignore[reportArgumentType] - if self._commit_after_execute: - self.conn.commit() + with self.conn_wrapper as conn: + with conn.cursor() as cursor: + cursor.execute(f"TRUNCATE TABLE {table_names} CASCADE") # pyright:ignore[reportArgumentType] + if self._commit_after_execute: + conn.commit() @override def drop_tables(self, schema: str | None = None): @@ -184,20 +249,22 @@ def drop_tables(self, schema: str | None = None): if tables is None: return table_names = ", ".join(tables) - with self.conn.cursor() as cursor: - cursor.execute(f"DROP TABLE {table_names} CASCADE") # pyright:ignore[reportArgumentType] - if self._commit_after_execute: - self.conn.commit() + with self.conn_wrapper as conn: + with conn.cursor() as cursor: + cursor.execute(f"DROP TABLE {table_names} CASCADE") # pyright:ignore[reportArgumentType] + if self._commit_after_execute: + conn.commit() def _get_live_table_names(self, schema: str) -> list[str] | None: - with self.conn.cursor() as cursor: - # Get all table names from public schema - cursor.execute(f"SELECT tablename FROM pg_tables WHERE schemaname = '{schema}'") # pyright:ignore[reportArgumentType] - tables = cursor.fetchall() - if not tables: - return None - table_names = [f'"{table[0]}"' for table in tables] - return table_names + with self.conn_wrapper as conn: + with conn.cursor() as cursor: + # Get all table names from public schema + cursor.execute(f"SELECT tablename FROM pg_tables WHERE schemaname = '{schema}'") # pyright:ignore[reportArgumentType] + tables = cursor.fetchall() + if not tables: + return None + table_names = [f'"{table[0]}"' for table in tables] + return table_names class PgDbTransaction: @@ -206,15 +273,33 @@ class PgDbTransaction: """ _db: PgDb + _conn_cm: AbstractContextManager[Connection] | None = None _tx: AbstractContextManager[Transaction] | None = None def __init__(self, db: PgDb): self._db = db def __enter__(self) -> PgDb: - self._tx = self._db.conn.transaction() + pool_or_conn = self._db.conn_wrapper.conn_or_pool + + if isinstance(pool_or_conn, ConnectionPool): + # Ensure pool is open (idempotent if already open) + pool_or_conn.open() + + # Check out a dedicated connection for the transaction + self._conn_cm = pool_or_conn.connection() + conn = self._conn_cm.__enter__() + else: + conn = pool_or_conn + + # Create a PgDb that uses this single connection (no auto-commit) + tx_db = PgDb(conn) + tx_db._commit_after_execute = False # pyright: ignore[reportPrivateUsage] + self._db = tx_db + + self._tx = conn.transaction() self._tx.__enter__() - return self._db + return tx_db def __exit__( self, @@ -222,9 +307,12 @@ def __exit__( exc_val: BaseException | None, exc_tb: types.TracebackType | None, ): - if self._tx is None: - return False - return self._tx.__exit__(exc_type, exc_val, exc_tb) + result = None + if self._tx is not None: + result = self._tx.__exit__(exc_type, exc_val, exc_tb) + if self._conn_cm is not None: + self._conn_cm.__exit__(exc_type, exc_val, exc_tb) + return result @final @@ -234,21 +322,21 @@ class AsyncPgDb(AsyncDbBase): """ db_type = "postgres" - conn: AsyncConnection + conn_wrapper: AsyncConnectionWrapper[AsyncConnection | AsyncConnectionPool] _commit_after_execute: bool = True - def __init__(self, connection: AsyncConnection): + def __init__(self, connection_or_pool: AsyncConnection | AsyncConnectionPool): """ Create a new AsyncPgDb instance. """ - self.conn = connection + self.conn_wrapper = AsyncConnectionWrapper(connection_or_pool) async def close(self): """ Close the database connection. """ - if self.conn: - await self.conn.close() + if self.conn_wrapper: + await self.conn_wrapper.close() def transaction(self) -> AsyncPgDbTransaction: """ @@ -262,9 +350,7 @@ def transaction(self) -> AsyncPgDbTransaction: ... ``` """ - db_copy = AsyncPgDb(self.conn) - db_copy._commit_after_execute = False - return AsyncPgDbTransaction(db_copy) + return AsyncPgDbTransaction(self) def select[M: BaseModel](self, model: type[M]) -> SelectQuery[M, Self]: """ @@ -321,9 +407,10 @@ async def execute(self, query: QuerySingle) -> None: """ Execute a query without returning results. """ - await self.conn.execute(query.sql, query.params) # pyright:ignore[reportArgumentType] - if self._commit_after_execute: - await self.conn.commit() + async with self.conn_wrapper as conn: + await conn.execute(query.sql, query.params) # pyright:ignore[reportArgumentType] + if self._commit_after_execute: + await conn.commit() @override async def executemany(self, query: QueryMany): @@ -331,33 +418,35 @@ async def executemany(self, query: QueryMany): Execute a query with multiple parameter sets. """ params = _jsonify_dicts(query.many_params) - async with self.conn.cursor() as cur: - await cur.executemany(query.sql, params) # pyright:ignore[reportArgumentType] + async with self.conn_wrapper as conn: + async with conn.cursor() as cur: + await cur.executemany(query.sql, params) # pyright:ignore[reportArgumentType] if self._commit_after_execute: - await self.conn.commit() + await conn.commit() @override async def fetch(self, query: QuerySingle | QueryMany) -> list[dict[str, Any]]: """ Execute a query and return results as a list of dicts. """ - async with self.conn.cursor() as cur: - if isinstance(query, QuerySingle): - await cur.execute(query.sql, query.params) # pyright:ignore[reportArgumentType] - else: - await cur.executemany(query.sql, query.many_params, returning=True) # pyright:ignore[reportArgumentType] - - if cur.description is None: - return [] - columns: list[str] = [desc[0] for desc in cur.description] - results: list[dict[str, Any]] = [] - - for row in await cur.fetchall(): - data = dict(zip(columns, row)) - results.append(data) - if self._commit_after_execute: - await self.conn.commit() - return results + async with self.conn_wrapper as conn: + async with conn.cursor() as cur: + if isinstance(query, QuerySingle): + await cur.execute(query.sql, query.params) # pyright:ignore[reportArgumentType] + else: + await cur.executemany(query.sql, query.many_params, returning=True) # pyright:ignore[reportArgumentType] + + if cur.description is None: + return [] + columns: list[str] = [desc[0] for desc in cur.description] + results: list[dict[str, Any]] = [] + + for row in await cur.fetchall(): + data = dict(zip(columns, row)) + results.append(data) + if self._commit_after_execute: + await conn.commit() + return results @override async def truncate(self, schema: str | None = None): @@ -369,10 +458,11 @@ async def truncate(self, schema: str | None = None): if tables is None: return table_names = ", ".join(tables) - async with self.conn.cursor() as cursor: - await cursor.execute(f"TRUNCATE TABLE {table_names} CASCADE") # pyright:ignore[reportArgumentType] - if self._commit_after_execute: - await self.conn.commit() + async with self.conn_wrapper as conn: + async with conn.cursor() as cursor: + await cursor.execute(f"TRUNCATE TABLE {table_names} CASCADE") # pyright:ignore[reportArgumentType] + if self._commit_after_execute: + await conn.commit() @override async def drop_tables(self, schema: str | None = None): @@ -384,20 +474,22 @@ async def drop_tables(self, schema: str | None = None): if tables is None: return table_names = ", ".join(tables) - async with self.conn.cursor() as cursor: - await cursor.execute(f"DROP TABLE {table_names} CASCADE") # pyright:ignore[reportArgumentType] - if self._commit_after_execute: - await self.conn.commit() + async with self.conn_wrapper as conn: + async with conn.cursor() as cursor: + await cursor.execute(f"DROP TABLE {table_names} CASCADE") # pyright:ignore[reportArgumentType] + if self._commit_after_execute: + await conn.commit() async def _get_live_table_names(self, schema: str) -> list[str] | None: - async with self.conn.cursor() as cursor: - # Get all table names from public schema - await cursor.execute(f"SELECT tablename FROM pg_tables WHERE schemaname = '{schema}'") # pyright:ignore[reportArgumentType] - tables = await cursor.fetchall() - if not tables: - return None - table_names = [f'"{table[0]}"' for table in tables] - return table_names + async with self.conn_wrapper as conn: + async with conn.cursor() as cursor: + # Get all table names from public schema + await cursor.execute(f"SELECT tablename FROM pg_tables WHERE schemaname = '{schema}'") # pyright:ignore[reportArgumentType] + tables = await cursor.fetchall() + if not tables: + return None + table_names = [f'"{table[0]}"' for table in tables] + return table_names class AsyncPgDbTransaction: @@ -406,15 +498,33 @@ class AsyncPgDbTransaction: """ _db: AsyncPgDb + _conn_cm: AbstractAsyncContextManager[AsyncConnection] | None = None _tx: AbstractAsyncContextManager[AsyncTransaction] | None = None def __init__(self, db: AsyncPgDb): self._db = db async def __aenter__(self) -> AsyncPgDb: - self._tx = self._db.conn.transaction() + pool_or_conn = self._db.conn_wrapper.conn_or_pool + + if isinstance(pool_or_conn, AsyncConnectionPool): + # Ensure pool is open + await pool_or_conn.open() + + # Check out a dedicated connection for the transaction + self._conn_cm = pool_or_conn.connection() + conn = await self._conn_cm.__aenter__() + else: + conn = pool_or_conn + + # Create an AsyncPgDb that uses this single connection (no auto-commit) + tx_db = AsyncPgDb(conn) + tx_db._commit_after_execute = False # pyright: ignore[reportPrivateUsage] + self._db = tx_db + + self._tx = conn.transaction() await self._tx.__aenter__() - return self._db + return tx_db async def __aexit__( self, @@ -422,9 +532,12 @@ async def __aexit__( exc_val: BaseException | None, exc_tb: types.TracebackType | None, ): - if self._tx is None: - return False - return await self._tx.__aexit__(exc_type, exc_val, exc_tb) + result = None + if self._tx is not None: + result = await self._tx.__aexit__(exc_type, exc_val, exc_tb) + if self._conn_cm is not None: + await self._conn_cm.__aexit__(exc_type, exc_val, exc_tb) + return result def _jsonify_dicts(params: Sequence[dict[str, Any]]) -> list[dict[str, Any]]: diff --git a/uv.lock b/uv.lock index 7808c32..d5b412f 100644 --- a/uv.lock +++ b/uv.lock @@ -141,6 +141,7 @@ name = "embar" source = { editable = "." } dependencies = [ { name = "psycopg", extra = ["binary"] }, + { name = "psycopg-pool" }, { name = "pydantic" }, ] @@ -165,6 +166,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "psycopg", extras = ["binary"], specifier = ">=3.2.11" }, + { name = "psycopg-pool", specifier = ">=3.3.0" }, { name = "pydantic", specifier = ">=2.12.4" }, ] @@ -580,6 +582,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/53/cf/10c3e95827a3ca8af332dfc471befec86e15a14dc83cee893c49a4910dad/psycopg_binary-3.2.12-cp314-cp314-win_amd64.whl", hash = "sha256:48a8e29f3e38fcf8d393b8fe460d83e39c107ad7e5e61cd3858a7569e0554a39", size = 3005787, upload-time = "2025-10-26T00:36:06.783Z" }, ] +[[package]] +name = "psycopg-pool" +version = "3.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/56/9a/9470d013d0d50af0da9c4251614aeb3c1823635cab3edc211e3839db0bcf/psycopg_pool-3.3.0.tar.gz", hash = "sha256:fa115eb2860bd88fce1717d75611f41490dec6135efb619611142b24da3f6db5", size = 31606, upload-time = "2025-12-01T11:34:33.11Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e7/c3/26b8a0908a9db249de3b4169692e1c7c19048a9bc41a4d3209cee7dbb758/psycopg_pool-3.3.0-py3-none-any.whl", hash = "sha256:2e44329155c410b5e8666372db44276a8b1ebd8c90f1c3026ebba40d4bc81063", size = 39995, upload-time = "2025-12-01T11:34:29.761Z" }, +] + [[package]] name = "pydantic" version = "2.12.4"