From a5b27060aeecef65cf4bf4c56ada553c2076c7d8 Mon Sep 17 00:00:00 2001 From: Chris Arderne Date: Wed, 24 Dec 2025 13:52:58 +0000 Subject: [PATCH] add pgvector support --- docs/conftest.py | 2 +- docs/extensions/vector.md | 110 ++++++++++++++ mkdocs.yml | 3 + src/embar/column/pg.py | 29 ++++ src/embar/constraint.py | 6 +- src/embar/model.py | 13 +- src/embar/query/clause_base.py | 25 ++++ src/embar/query/conflict.py | 2 +- src/embar/query/delete.py | 24 ++-- src/embar/query/having.py | 4 +- src/embar/query/join.py | 18 +-- src/embar/query/order_by.py | 255 ++++++++++++++++++--------------- src/embar/query/select.py | 29 ++-- src/embar/query/update.py | 10 +- src/embar/query/vector.py | 70 +++++++++ src/embar/query/where.py | 209 ++++++++++++--------------- tests/test_vector.py | 124 ++++++++++++++++ 17 files changed, 654 insertions(+), 279 deletions(-) create mode 100644 docs/extensions/vector.md create mode 100644 src/embar/query/clause_base.py create mode 100644 src/embar/query/vector.py create mode 100644 tests/test_vector.py diff --git a/docs/conftest.py b/docs/conftest.py index b5990c8..db80fd1 100644 --- a/docs/conftest.py +++ b/docs/conftest.py @@ -9,7 +9,7 @@ def postgres_container_raw(request: pytest.FixtureRequest): """Session-scoped postgres container for docs tests.""" try: - with PostgresContainer("postgres:18-alpine3.22", port=25432) as postgres: + with PostgresContainer("pgvector/pgvector:0.8.1-pg18-trixie", port=25432) as postgres: request.addfinalizer(postgres.stop) yield postgres except Exception as e: diff --git a/docs/extensions/vector.md b/docs/extensions/vector.md new file mode 100644 index 0000000..8e57437 --- /dev/null +++ b/docs/extensions/vector.md @@ -0,0 +1,110 @@ +# Vector + +Embar supports [pgvector](https://github.com/pgvector/pgvector), the open-source vector similarity search extension for PostgreSQL. + +Before using vector columns, you must install and activate the extension: + +```sql +CREATE EXTENSION vector; +``` + +## Creating a Vector Column + +Use `Vector` to store embeddings with a fixed dimension: + +```{.python fixture:postgres_container} +import asyncio +import psycopg + +from embar.column.common import Integer +from embar.column.pg import Vector +from embar.db.pg import AsyncPgDb +from embar.table import Table + +class Document(Table): + id: Integer = Integer() + embedding: Vector = Vector(3) # 3-dimensional vector + +async def get_db(): + database_url = "postgres://pg:pw@localhost:25432/db" + conn = await psycopg.AsyncConnection.connect(database_url) + await conn.execute("CREATE EXTENSION IF NOT EXISTS vector") + db = AsyncPgDb(conn) + await db.migrate([Document]) + return db + +async def setup(): + db = await get_db() + # Insert some documents with embeddings + await db.insert(Document).values(Document(id=1, embedding=[1.0, 0.0, 0.0])) + await db.insert(Document).values(Document(id=2, embedding=[0.0, 1.0, 0.0])) + await db.insert(Document).values(Document(id=3, embedding=[0.0, 0.0, 1.0])) + +asyncio.run(setup()) +``` + +## L2 Distance + +Use `L2Distance` for Euclidean distance searches. This uses the `<->` operator. + +### Order By L2 Distance + +Find documents ordered by distance to a query vector: + +```{.python continuation} +from embar.query.vector import L2Distance + +async def order_by_l2(): + db = await get_db() + query_vector = [1.0, 0.5, 0.0] + docs = await ( + db.select(Document.all()) + .from_(Document) + .order_by(L2Distance(Document.embedding, query_vector)) + ) + print([d.id for d in docs]) + +asyncio.run(order_by_l2()) +``` + +### Filter By L2 Distance + +Find documents within a distance threshold: + +```{.python continuation} +from embar.query.where import Lt + +async def filter_by_l2(): + db = await get_db() + query_vector = [1.0, 0.0, 0.0] + docs = await ( + db.select(Document.all()) + .from_(Document) + .where(Lt(L2Distance(Document.embedding, query_vector), 0.5)) + ) + print([d.id for d in docs]) + +asyncio.run(filter_by_l2()) +``` + +## Cosine Distance + +Use `CosineDistance` for cosine similarity searches. This uses the `<=>` operator. + +### Order By Cosine Distance + +```{.python continuation} +from embar.query.vector import CosineDistance + +async def order_by_cosine(): + db = await get_db() + query_vector = [1.0, 0.5, 0.0] + docs = await ( + db.select(Document.all()) + .from_(Document) + .order_by(CosineDistance(Document.embedding, query_vector)) + ) + print([d.id for d in docs]) + +asyncio.run(order_by_cosine()) +``` diff --git a/mkdocs.yml b/mkdocs.yml index e623e40..ffc1fe0 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -48,6 +48,9 @@ nav: - Transactions: query/transactions.md - Raw SQL: query/raw-sql.md + - Extensions: + - Vector: extensions/vector.md + - Integration: - Migrations: migrations.md - Types: types.md diff --git a/src/embar/column/pg.py b/src/embar/column/pg.py index e5f7566..015f527 100644 --- a/src/embar/column/pg.py +++ b/src/embar/column/pg.py @@ -32,6 +32,7 @@ "Timestamp", "Varchar", "EnumCol", + "Vector", ] @@ -344,3 +345,31 @@ def __init__( self._sql_type = pg_enum.name super().__init__(name=name, default=default, primary=primary, not_null=not_null) + + +# Extension: pgvector +# Should also support `halfvec` and `bit` +class Vector(Column[list[float]]): + """ + Vector column using [pgvector](https://github.com/pgvector/pgvector). + + This assumes the extension is already installed and activated with + CREATE EXTENSION vector; + """ + + _sql_type: str = "VECTOR" + _py_type: Type = list[float] + + def __init__( + self, + length: int, + name: str | None = None, + default: list[float] | None = None, + primary: bool = False, + not_null: bool = False, + ): + """ + Create a new Vector instance. + """ + self._extra_args: tuple[int] | tuple[int, int] | None = (length,) + super().__init__(name=name, default=default, primary=primary, not_null=not_null) diff --git a/src/embar/constraint.py b/src/embar/constraint.py index 0d7c64b..3b43c94 100644 --- a/src/embar/constraint.py +++ b/src/embar/constraint.py @@ -6,8 +6,8 @@ from embar.column.base import ColumnBase from embar.constraint_base import Constraint from embar.custom_types import PyType +from embar.query.clause_base import ClauseBase from embar.query.query import QuerySingle -from embar.query.where import WhereClause class Index: @@ -70,7 +70,7 @@ class IndexReady(Constraint): unique: bool name: str columns: tuple[Callable[[], ColumnBase], ...] - _where_clause: Callable[[], WhereClause] | None = None + _where_clause: Callable[[], ClauseBase] | None = None def __init__(self, name: str, unique: bool, *columns: Callable[[], ColumnBase]): """ @@ -80,7 +80,7 @@ def __init__(self, name: str, unique: bool, *columns: Callable[[], ColumnBase]): self.unique = unique self.columns = columns - def where(self, where_clause: Callable[[], WhereClause]) -> Self: + def where(self, where_clause: Callable[[], ClauseBase]) -> Self: """ Add a WHERE clause to create a partial index. """ diff --git a/src/embar/model.py b/src/embar/model.py index 585245f..f9dd2f3 100644 --- a/src/embar/model.py +++ b/src/embar/model.py @@ -1,3 +1,4 @@ +import json from typing import ( Annotated, Any, @@ -8,7 +9,7 @@ get_type_hints, ) -from pydantic import BaseModel, Field, create_model +from pydantic import BaseModel, BeforeValidator, Field, create_model from embar.column.base import ColumnBase from embar.db.base import DbType @@ -160,6 +161,10 @@ class MyTable(Table): ... fields_dict: dict[str, Any] = {} for field_name, column in cls._fields.items(): # pyright:ignore[reportPrivateUsage] field_type = column.info.py_type + + if column.info.col_type == "VECTOR": + field_type = Annotated[field_type, BeforeValidator(_parse_json_list)] + fields_dict[field_name] = ( Annotated[field_type, column], Field(default_factory=lambda a=column: column.info.fqn()), @@ -185,3 +190,9 @@ def upgrade_model_nested_fields[B: BaseModel](model: type[B]) -> type[B]: new_class.model_rebuild() return new_class + + +def _parse_json_list(v: Any): + if isinstance(v, str): + return json.loads(v) + return v diff --git a/src/embar/query/clause_base.py b/src/embar/query/clause_base.py new file mode 100644 index 0000000..6cf8dcf --- /dev/null +++ b/src/embar/query/clause_base.py @@ -0,0 +1,25 @@ +from abc import ABC, abstractmethod +from typing import Callable + +from embar.query.query import QuerySingle + +# Where clauses get passed a get_count() function that returns a monotonically +# increasing integer. This allows each SQL binding parameter to get a unique +# name like `%(eq_id_2)s` in psycopg format. +type GetCount = Callable[[], int] + + +class ClauseBase(ABC): + """ + ABC for ORDER BY and WHERE clauses. + + Not all use the get_count() directly (those with no bindings) + but their children might. + """ + + @abstractmethod + def sql(self, get_count: GetCount) -> QuerySingle: + """ + Generate the SQL fragment for this clause. + """ + ... diff --git a/src/embar/query/conflict.py b/src/embar/query/conflict.py index 5f84ff2..43de24d 100644 --- a/src/embar/query/conflict.py +++ b/src/embar/query/conflict.py @@ -2,8 +2,8 @@ from typing import override from embar.custom_types import PyType +from embar.query.clause_base import GetCount from embar.query.query import QuerySingle -from embar.query.where import GetCount # require at least one element in tuple TupleAtLeastOne = tuple[str, *tuple[str, ...]] diff --git a/src/embar/query/delete.py b/src/embar/query/delete.py index eebb5af..0c69f20 100644 --- a/src/embar/query/delete.py +++ b/src/embar/query/delete.py @@ -11,9 +11,9 @@ from embar.model import ( generate_model, ) -from embar.query.order_by import Asc, BareColumn, Desc, OrderBy, OrderByClause, RawSqlOrder +from embar.query.clause_base import ClauseBase +from embar.query.order_by import Asc, BareColumn, Desc, OrderBy, RawSqlOrder from embar.query.query import QuerySingle -from embar.query.where import WhereClause from embar.sql import Sql from embar.table import Table @@ -30,7 +30,7 @@ class DeleteQueryReady[T: Table, Db: AllDbBase]: table: type[T] _db: Db - _where_clause: WhereClause | None = None + _where_clause: ClauseBase | None = None _order_clause: OrderBy | None = None _limit_value: int | None = None @@ -44,7 +44,7 @@ def __init__(self, table: type[T], db: Db): def returning(self) -> DeleteQueryReturning[T, Db]: return DeleteQueryReturning(self.table, self._db, self._where_clause, self._order_clause, self._limit_value) - def where(self, where_clause: WhereClause) -> Self: + def where(self, where_clause: ClauseBase) -> Self: """ Add a WHERE clause to the query. """ @@ -79,7 +79,7 @@ class User(Table): ``` """ # Convert each clause to an OrderByClause - order_clauses: list[OrderByClause] = [] + order_clauses: list[ClauseBase] = [] for clause in clauses: if isinstance(clause, (Asc, Desc)): order_clauses.append(clause) @@ -170,8 +170,9 @@ def get_count() -> int: params = {**params, **where_data.params} if self._order_clause is not None: - order_by_sql = self._order_clause.sql() - sql += f"\nORDER BY {order_by_sql}" + order_by_query = self._order_clause.sql(get_count) + sql += f"\nORDER BY {order_by_query.sql}" + params = {**params, **order_by_query.params} if self._limit_value is not None: sql += f"\nLIMIT {self._limit_value}" @@ -189,7 +190,7 @@ class DeleteQueryReturning[T: Table, Db: AllDbBase]: table: type[T] _db: Db - _where_clause: WhereClause | None = None + _where_clause: ClauseBase | None = None _order_clause: OrderBy | None = None _limit_value: int | None = None @@ -197,7 +198,7 @@ def __init__( self, table: type[T], db: Db, - where_clause: WhereClause | None, + where_clause: ClauseBase | None, order_clause: OrderBy | None, limit_value: int | None, ): @@ -287,8 +288,9 @@ def get_count() -> int: params = {**params, **where_data.params} if self._order_clause is not None: - order_by_sql = self._order_clause.sql() - sql += f"\nORDER BY {order_by_sql}" + order_by_query = self._order_clause.sql(get_count) + sql += f"\nORDER BY {order_by_query.sql}" + params = {**params, **order_by_query.params} if self._limit_value is not None: sql += f"\nLIMIT {self._limit_value}" diff --git a/src/embar/query/having.py b/src/embar/query/having.py index f138040..0d3f5c2 100644 --- a/src/embar/query/having.py +++ b/src/embar/query/having.py @@ -2,7 +2,7 @@ from dataclasses import dataclass -from embar.query.where import WhereClause +from embar.query.clause_base import ClauseBase @dataclass @@ -32,4 +32,4 @@ class Having: ``` """ - clause: WhereClause + clause: ClauseBase diff --git a/src/embar/query/join.py b/src/embar/query/join.py index 652704e..c9882ae 100644 --- a/src/embar/query/join.py +++ b/src/embar/query/join.py @@ -3,8 +3,8 @@ from abc import ABC, abstractmethod from typing import override +from embar.query.clause_base import ClauseBase, GetCount from embar.query.query import QuerySingle -from embar.query.where import GetCount, WhereClause from embar.table import Table @@ -27,9 +27,9 @@ class LeftJoin(JoinClause): """ table: type[Table] - on: WhereClause + on: ClauseBase - def __init__(self, table: type[Table], on: WhereClause): + def __init__(self, table: type[Table], on: ClauseBase): """ Create a new LeftJoin instance. """ @@ -53,9 +53,9 @@ class RightJoin(JoinClause): """ table: type[Table] - on: WhereClause + on: ClauseBase - def __init__(self, table: type[Table], on: WhereClause): + def __init__(self, table: type[Table], on: ClauseBase): """ Create a new RightJoin instance. """ @@ -79,9 +79,9 @@ class InnerJoin(JoinClause): """ table: type[Table] - on: WhereClause + on: ClauseBase - def __init__(self, table: type[Table], on: WhereClause): + def __init__(self, table: type[Table], on: ClauseBase): """ Create a new InnerJoin instance. """ @@ -105,9 +105,9 @@ class FullJoin(JoinClause): """ table: type[Table] - on: WhereClause + on: ClauseBase - def __init__(self, table: type[Table], on: WhereClause): + def __init__(self, table: type[Table], on: ClauseBase): """ Create a new FullJoin instance. """ diff --git a/src/embar/query/order_by.py b/src/embar/query/order_by.py index 14d6a3e..8c48e9a 100644 --- a/src/embar/query/order_by.py +++ b/src/embar/query/order_by.py @@ -1,32 +1,121 @@ """Order by clause for sorting query results.""" -from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Literal, override from embar.column.base import ColumnBase +from embar.custom_types import PyType +from embar.query.clause_base import ClauseBase, GetCount +from embar.query.query import QuerySingle from embar.sql import Sql type NullsOrdering = Literal["first", "last"] -class OrderByClause(ABC): +@dataclass +class OrderBy: """ - Base class for ORDER BY clause components. + Represents an ORDER BY clause for sorting query results. + + ```python + from embar.query.order_by import OrderBy, Asc, Desc, BareColumn + from embar.column.base import ColumnBase, ColumnInfo + + col1 = ColumnBase() + col1.info = ColumnInfo( + _table_name=lambda: "users", + name="age", + col_type="INTEGER", + py_type=int, + primary=False, + not_null=False + ) + + col2 = ColumnBase() + col2.info = ColumnInfo( + _table_name=lambda: "users", + name="name", + col_type="TEXT", + py_type=str, + primary=False, + not_null=False + ) - An ORDER BY clause can contain: - - A bare column (defaults to ASC) - - An Asc(column) or Desc(column) wrapper with optional nulls handling - - Raw SQL via Sql(t"...") + order = OrderBy(( + Desc(col1), + Asc(col2, nulls="first"), + )) + sql = order.sql(lambda: 0) + print(sql) + assert sql.sql == '"users"."age" DESC, "users"."name" ASC NULLS FIRST' + ``` """ - @abstractmethod - def sql(self) -> str: - """Generate the SQL fragment for this ORDER BY component.""" - ... + clauses: tuple[ClauseBase, ...] + + def sql(self, get_count: GetCount) -> QuerySingle: + """ + Generate the full ORDER BY SQL clause. + + ```python + from embar.query.order_by import OrderBy, Asc, BareColumn + from embar.column.base import ColumnBase, ColumnInfo + + col1 = ColumnBase() + col1.info = ColumnInfo( + _table_name=lambda: "users", + name="id", + col_type="INTEGER", + py_type=int, + primary=False, + not_null=False + ) + + col2 = ColumnBase() + col2.info = ColumnInfo( + _table_name=lambda: "users", + name="name", + col_type="TEXT", + py_type=str, + primary=False, + not_null=False + ) + + order = OrderBy((BareColumn(col1), Asc(col2))) + sql = order.sql(lambda: 0) + assert sql.sql == '"users"."id", "users"."name" ASC' + ``` + """ + queries = [clause.sql(get_count) for clause in self.clauses] + params = {k: v for d in queries for k, v in d.params.items()} -class Asc(OrderByClause): + sql = ", ".join(q.sql for q in queries) + return QuerySingle(sql=sql, params=params) + + +def _asc_or_desc_sql( + clause: ColumnBase | ClauseBase, + nulls: NullsOrdering | None, + asc: bool, + get_count: GetCount, +) -> QuerySingle: + """Generate the SQL fragment.""" + params: dict[str, PyType] = {} + direction = "ASC" if asc else "DESC" + if isinstance(clause, ColumnBase): + sql = f"{clause.info.fqn()} {direction}" + else: + query = clause.sql(get_count) + sql = f"{query.sql} {direction}" + params = query.params + + if nulls is not None: + return QuerySingle(sql=f"{sql} NULLS {nulls.upper()}", params=params) + return QuerySingle(sql=sql, params=params) + + +class Asc(ClauseBase): """ Represents an ascending sort order for a column. @@ -44,15 +133,15 @@ class Asc(OrderByClause): not_null=False ) asc = Asc(col, nulls="last") - sql = asc.sql() - assert sql == '"users"."age" ASC NULLS LAST' + sql = asc.sql(lambda: 0) + assert sql.sql == '"users"."age" ASC NULLS LAST' ``` """ - col: ColumnBase + clause: ColumnBase | ClauseBase nulls: NullsOrdering | None - def __init__(self, col: ColumnBase, nulls: NullsOrdering | None = None): + def __init__(self, clause: ColumnBase | ClauseBase, nulls: NullsOrdering | None = None): """ Create an ascending sort order. @@ -60,19 +149,21 @@ def __init__(self, col: ColumnBase, nulls: NullsOrdering | None = None): col: The column to sort by nulls: Optional nulls ordering ("first" or "last") """ - self.col = col + self.clause = clause self.nulls = nulls @override - def sql(self) -> str: + def sql(self, get_count: GetCount) -> QuerySingle: """Generate the SQL fragment.""" - base = f"{self.col.info.fqn()} ASC" - if self.nulls is not None: - return f"{base} NULLS {self.nulls.upper()}" - return base + return _asc_or_desc_sql( + self.clause, + self.nulls, + True, + get_count, + ) -class Desc(OrderByClause): +class Desc(ClauseBase): """ Represents a descending sort order for a column. @@ -90,15 +181,15 @@ class Desc(OrderByClause): not_null=False ) desc = Desc(col, nulls="first") - sql = desc.sql() - assert sql == '"users"."age" DESC NULLS FIRST' + sql = desc.sql(lambda: 0) + assert sql.sql == '"users"."age" DESC NULLS FIRST' ``` """ - col: ColumnBase + clause: ColumnBase | ClauseBase nulls: NullsOrdering | None - def __init__(self, col: ColumnBase, nulls: NullsOrdering | None = None): + def __init__(self, clause: ColumnBase | ClauseBase, nulls: NullsOrdering | None = None): """ Create a descending sort order. @@ -106,19 +197,21 @@ def __init__(self, col: ColumnBase, nulls: NullsOrdering | None = None): col: The column to sort by nulls: Optional nulls ordering ("first" or "last") """ - self.col = col + self.clause = clause self.nulls = nulls @override - def sql(self) -> str: + def sql(self, get_count: GetCount) -> QuerySingle: """Generate the SQL fragment.""" - base = f"{self.col.info.fqn()} DESC" - if self.nulls is not None: - return f"{base} NULLS {self.nulls.upper()}" - return base + return _asc_or_desc_sql( + self.clause, + self.nulls, + False, + get_count, + ) -class BareColumn(OrderByClause): +class BareColumn(ClauseBase): """ Represents a bare column reference (defaults to ASC). @@ -138,8 +231,8 @@ class BareColumn(OrderByClause): not_null=False ) bare = BareColumn(col) - sql = bare.sql() - assert sql == '"users"."id"' + sql = bare.sql(lambda: 0) + assert sql.sql == '"users"."id"' ``` """ @@ -150,12 +243,12 @@ def __init__(self, col: ColumnBase): self.col = col @override - def sql(self) -> str: + def sql(self, get_count: GetCount) -> QuerySingle: """Generate the SQL fragment (just the column FQN).""" - return self.col.info.fqn() + return QuerySingle(sql=self.col.info.fqn(), params=None) -class RawSqlOrder(OrderByClause): +class RawSqlOrder(ClauseBase): """ Represents raw SQL in an ORDER BY clause. @@ -169,8 +262,8 @@ class User(Table): id: Integer = Integer() raw = RawSqlOrder(Sql(t"{User.id} DESC")) - sql = raw.sql() - assert sql == '"user"."id" DESC' + sql = raw.sql(lambda: 0) + assert sql.sql == '"user"."id" DESC' ``` """ @@ -181,82 +274,6 @@ def __init__(self, sql_obj: Sql): self.sql_obj = sql_obj @override - def sql(self) -> str: + def sql(self, get_count: GetCount) -> QuerySingle: """Generate the SQL fragment.""" - return self.sql_obj.sql() - - -@dataclass -class OrderBy: - """ - Represents an ORDER BY clause for sorting query results. - - ```python - from embar.query.order_by import OrderBy, Asc, Desc, BareColumn - from embar.column.base import ColumnBase, ColumnInfo - - col1 = ColumnBase() - col1.info = ColumnInfo( - _table_name=lambda: "users", - name="age", - col_type="INTEGER", - py_type=int, - primary=False, - not_null=False - ) - - col2 = ColumnBase() - col2.info = ColumnInfo( - _table_name=lambda: "users", - name="name", - col_type="TEXT", - py_type=str, - primary=False, - not_null=False - ) - - order = OrderBy(( - Desc(col1), - Asc(col2, nulls="first"), - )) - sql = order.sql() - assert sql == '"users"."age" DESC, "users"."name" ASC NULLS FIRST' - ``` - """ - - clauses: tuple[OrderByClause, ...] - - def sql(self) -> str: - """ - Generate the full ORDER BY SQL clause. - - ```python - from embar.query.order_by import OrderBy, Asc, BareColumn - from embar.column.base import ColumnBase, ColumnInfo - - col1 = ColumnBase() - col1.info = ColumnInfo( - _table_name=lambda: "users", - name="id", - col_type="INTEGER", - py_type=int, - primary=False, - not_null=False - ) - - col2 = ColumnBase() - col2.info = ColumnInfo( - _table_name=lambda: "users", - name="name", - col_type="TEXT", - py_type=str, - primary=False, - not_null=False - ) - - order = OrderBy((BareColumn(col1), Asc(col2))) - sql = order.sql() - assert sql == '"users"."id", "users"."name" ASC' - ``` - """ - return ", ".join(clause.sql() for clause in self.clauses) + return QuerySingle(sql=self.sql_obj.sql(), params=None) diff --git a/src/embar/query/select.py b/src/embar/query/select.py index 41c70ba..3583012 100644 --- a/src/embar/query/select.py +++ b/src/embar/query/select.py @@ -15,12 +15,12 @@ to_sql_columns, upgrade_model_nested_fields, ) +from embar.query.clause_base import ClauseBase from embar.query.group_by import GroupBy from embar.query.having import Having from embar.query.join import CrossJoin, FullJoin, InnerJoin, JoinClause, LeftJoin, RightJoin -from embar.query.order_by import Asc, BareColumn, Desc, OrderBy, OrderByClause, RawSqlOrder +from embar.query.order_by import Asc, BareColumn, Desc, OrderBy, RawSqlOrder from embar.query.query import QuerySingle -from embar.query.where import WhereClause from embar.sql import Sql from embar.table import Table @@ -108,7 +108,7 @@ class SelectQueryReady[M: BaseModel, T: Table, Db: AllDbBase]: _distinct: bool _joins: list[JoinClause] - _where_clause: WhereClause | None = None + _where_clause: ClauseBase | None = None _group_clause: GroupBy | None = None _having_clause: Having | None = None _order_clause: OrderBy | None = None @@ -125,28 +125,28 @@ def __init__(self, model: type[M], table: type[T], db: Db, distinct: bool): self._distinct = distinct self._joins = [] - def left_join(self, table: type[Table], on: WhereClause) -> Self: + def left_join(self, table: type[Table], on: ClauseBase) -> Self: """ Add a LEFT JOIN clause to the query. """ self._joins.append(LeftJoin(table, on)) return self - def right_join(self, table: type[Table], on: WhereClause) -> Self: + def right_join(self, table: type[Table], on: ClauseBase) -> Self: """ Add a RIGHT JOIN clause to the query. """ self._joins.append(RightJoin(table, on)) return self - def inner_join(self, table: type[Table], on: WhereClause) -> Self: + def inner_join(self, table: type[Table], on: ClauseBase) -> Self: """ Add an INNER JOIN clause to the query. """ self._joins.append(InnerJoin(table, on)) return self - def full_join(self, table: type[Table], on: WhereClause) -> Self: + def full_join(self, table: type[Table], on: ClauseBase) -> Self: """ Add a FULL OUTER JOIN clause to the query. """ @@ -160,7 +160,7 @@ def cross_join(self, table: type[Table]) -> Self: self._joins.append(CrossJoin(table)) return self - def where(self, where_clause: WhereClause) -> Self: + def where(self, where_clause: ClauseBase) -> Self: """ Add a WHERE clause to the query. """ @@ -174,7 +174,7 @@ def group_by(self, *cols: ColumnBase) -> Self: self._group_clause = GroupBy(cols) return self - def having(self, clause: WhereClause) -> Self: + def having(self, clause: ClauseBase) -> Self: """ Add a HAVING clause to filter grouped/aggregated results. @@ -204,7 +204,7 @@ class User(Table): self._having_clause = Having(clause) return self - def order_by(self, *clauses: ColumnBase | Asc | Desc | Sql) -> Self: + def order_by(self, *clauses: ColumnBase | Asc | Desc | ClauseBase | Sql) -> Self: """ Add an ORDER BY clause to sort query results. @@ -248,12 +248,14 @@ class User(Table): ``` """ # Convert each clause to an OrderByClause - order_clauses: list[OrderByClause] = [] + order_clauses: list[ClauseBase] = [] for clause in clauses: if isinstance(clause, (Asc, Desc)): order_clauses.append(clause) elif isinstance(clause, Sql): order_clauses.append(RawSqlOrder(clause)) + elif isinstance(clause, ClauseBase): + order_clauses.append(clause) else: order_clauses.append(BareColumn(clause)) @@ -419,8 +421,9 @@ def get_count() -> int: params = {**params, **having_data.params} if self._order_clause is not None: - order_by_sql = self._order_clause.sql() - sql += f"\nORDER BY {order_by_sql}" + order_by_query = self._order_clause.sql(get_count) + sql += f"\nORDER BY {order_by_query.sql}" + params = {**params, **order_by_query.params} if self._limit_value is not None: sql += f"\nLIMIT {self._limit_value}" diff --git a/src/embar/query/update.py b/src/embar/query/update.py index 04b9236..13592b0 100644 --- a/src/embar/query/update.py +++ b/src/embar/query/update.py @@ -7,8 +7,8 @@ from embar.db.base import AllDbBase, AsyncDbBase, DbBase from embar.model import generate_model +from embar.query.clause_base import ClauseBase from embar.query.query import QuerySingle -from embar.query.where import WhereClause from embar.table import Table @@ -53,7 +53,7 @@ class UpdateQueryReady[T: Table, Db: AllDbBase]: table: type[T] _db: Db data: Mapping[str, Any] - _where_clause: WhereClause | None = None + _where_clause: ClauseBase | None = None def __init__(self, table: type[T], db: Db, data: Mapping[str, Any]): """ @@ -63,7 +63,7 @@ def __init__(self, table: type[T], db: Db, data: Mapping[str, Any]): self._db = db self.data = data - def where(self, where_clause: WhereClause) -> Self: + def where(self, where_clause: ClauseBase) -> Self: """ Add a WHERE clause to limit which rows are updated. """ @@ -151,9 +151,9 @@ class UpdateQueryReturning[T: Table, Db: AllDbBase]: table: type[T] _db: Db data: Mapping[str, Any] - _where_clause: WhereClause | None = None + _where_clause: ClauseBase | None = None - def __init__(self, table: type[T], db: Db, data: Mapping[str, Any], where_clause: WhereClause | None): + def __init__(self, table: type[T], db: Db, data: Mapping[str, Any], where_clause: ClauseBase | None): """ Create a new UpdateQueryReturning instance. """ diff --git a/src/embar/query/vector.py b/src/embar/query/vector.py new file mode 100644 index 0000000..2c835ed --- /dev/null +++ b/src/embar/query/vector.py @@ -0,0 +1,70 @@ +""" +Code specific to the pgvector extension. + +The base vector column is still defined in embar.column.pg +""" + +from typing import override + +from embar.column.base import ColumnInfo +from embar.column.common import Column +from embar.query.clause_base import ClauseBase, GetCount +from embar.query.query import QuerySingle + + +class L2Distance(ClauseBase): + """ + Get the L2 Distance using pgvector. + + Assumes pgvector extension is installed and activated. + + Creates a query like col_a <-> '[1,2,3]' or col_a <-> col_b. + """ + + left: ColumnInfo + right: list[float] | ColumnInfo + + def __init__(self, left: Column[list[float]], right: list[float] | Column[list[float]]): + self.left = left.info + self.right = right.info if isinstance(right, Column) else right + + @override + def sql(self, get_count: GetCount) -> QuerySingle: + count = get_count() + name = f"l2distance_{self.left.name}_{count}" + if isinstance(self.right, ColumnInfo): + return QuerySingle(sql=f"{self.left.fqn()} <-> {self.right.fqn()}") + + # pgvector expects an argument of the form '[1,2,3]' + stringified = str(self.right).replace(" ", "") + + return QuerySingle(sql=f"{self.left.fqn()} <-> %({name})s", params={name: stringified}) + + +class CosineDistance(ClauseBase): + """ + Get the Cosine Distance using pgvector. + + Assumes pgvector extension is installed and activated. + + Creates a query like col_a <=> '[1,2,3]' or col_a <=> col_b. + """ + + left: ColumnInfo + right: list[float] | ColumnInfo + + def __init__(self, left: Column[list[float]], right: list[float] | Column[list[float]]): + self.left = left.info + self.right = right.info if isinstance(right, Column) else right + + @override + def sql(self, get_count: GetCount) -> QuerySingle: + count = get_count() + name = f"l2distance_{self.left.name}_{count}" + if isinstance(self.right, ColumnInfo): + return QuerySingle(sql=f"{self.left.fqn()} <=> {self.right.fqn()}") + + # pgvector expects an argument of the form '[1,2,3]' + stringified = str(self.right).replace(" ", "") + + return QuerySingle(sql=f"{self.left.fqn()} <=> %({name})s", params={name: stringified}) diff --git a/src/embar/query/where.py b/src/embar/query/where.py index a60e1b2..c5b6ca3 100644 --- a/src/embar/query/where.py +++ b/src/embar/query/where.py @@ -1,178 +1,151 @@ """Where clauses for filtering queries.""" -from abc import ABC, abstractmethod -from typing import Any, Callable, Protocol, override +from typing import Any, Protocol, override from embar.column.base import ColumnInfo from embar.column.common import Column from embar.custom_types import PyType +from embar.query.clause_base import ClauseBase, GetCount from embar.query.query import QuerySingle -# Where clauses get passed a get_count() function that returns a monotonically -# increasing integer. This allows each SQL binding parameter to get a unique -# name like `%(eq_id_2)s` in psycopg format. -type GetCount = Callable[[], int] +def _gen_comparison_sql( + left: ColumnInfo | ClauseBase, + right: ColumnInfo | PyType, + operator: str, + name_root: str, + get_count: GetCount, +) -> QuerySingle: + """Generate SQL for binary comparison operators.""" + name = "vals" + params: dict[str, PyType] = {} + + if isinstance(left, ColumnInfo): + left_sql = left.fqn() + name = left.name + else: + left_result = left.sql(get_count) + left_sql = left_result.sql + params.update(left_result.params) + + if isinstance(right, ColumnInfo): + right_sql = right.fqn() + else: + count = get_count() + param_name = f"{name_root}_{name}_{count}" + right_sql = f"%({param_name})s" + params[param_name] = right -class WhereClause(ABC): - """ - ABC for Where clauses. - - Not all use the get_count() directly (those with no bindings) - but their children might. - """ - - @abstractmethod - def sql(self, get_count: GetCount) -> QuerySingle: - """ - Generate the SQL for this where clause. - """ - ... + return QuerySingle(sql=f"{left_sql} {operator} {right_sql}", params=params) # Comparison operators -class Eq[T: PyType](WhereClause): +class Eq[T: PyType](ClauseBase): """ Checks if a column value is equal to another column or a passed param. Right now the left must always be a column, maybe that must be loosened. """ - left: ColumnInfo - right: PyType | ColumnInfo + left: ColumnInfo | ClauseBase + right: ColumnInfo | PyType - def __init__(self, left: Column[T], right: T | Column[T]): - self.left = left.info + def __init__(self, left: Column[T] | ClauseBase, right: Column[T] | T): + self.left = left.info if isinstance(left, Column) else left self.right = right.info if isinstance(right, Column) else right @override def sql(self, get_count: GetCount) -> QuerySingle: - count = get_count() - name = f"eq_{self.left.name}_{count}" - - if isinstance(self.right, ColumnInfo): - return QuerySingle(sql=f"{self.left.fqn()} = {self.right.fqn()}") + return _gen_comparison_sql(self.left, self.right, "=", "eq", get_count) - return QuerySingle(sql=f"{self.left.fqn()} = %({name})s", params={name: self.right}) - -class Ne[T: PyType](WhereClause): +class Ne[T: PyType](ClauseBase): """ Checks if a column value is not equal to another column or a passed param. """ - left: ColumnInfo - right: PyType | ColumnInfo + left: ColumnInfo | ClauseBase + right: ColumnInfo | PyType - def __init__(self, left: Column[T], right: T | Column[T]): - self.left = left.info + def __init__(self, left: Column[T] | ClauseBase, right: Column[T] | T): + self.left = left.info if isinstance(left, Column) else left self.right = right.info if isinstance(right, Column) else right @override def sql(self, get_count: GetCount) -> QuerySingle: - count = get_count() - name = f"ne_{self.left.name}_{count}" + return _gen_comparison_sql(self.left, self.right, "!=", "ne", get_count) - if isinstance(self.right, ColumnInfo): - return QuerySingle(sql=f"{self.left.fqn()} != {self.right.fqn()}") - - return QuerySingle(sql=f"{self.left.fqn()} != %({name})s", params={name: self.right}) - -class Gt[T: PyType](WhereClause): +class Gt[T: PyType](ClauseBase): """ Checks if a column value is greater than another column or a passed param. """ - left: ColumnInfo - right: PyType | ColumnInfo + left: ColumnInfo | ClauseBase + right: ColumnInfo | PyType - def __init__(self, left: Column[T], right: T | Column[T]): - self.left = left.info + def __init__(self, left: Column[T] | ClauseBase, right: Column[T] | T): + self.left = left.info if isinstance(left, Column) else left self.right = right.info if isinstance(right, Column) else right @override def sql(self, get_count: GetCount) -> QuerySingle: - count = get_count() - name = f"gt_{self.left.name}_{count}" - - if isinstance(self.right, ColumnInfo): - return QuerySingle(sql=f"{self.left.fqn()} > {self.right.fqn()}") - - return QuerySingle(sql=f"{self.left.fqn()} > %({name})s", params={name: self.right}) + return _gen_comparison_sql(self.left, self.right, ">", "gt", get_count) -class Gte[T: PyType](WhereClause): +class Gte[T: PyType](ClauseBase): """ Checks if a column value is greater than or equal to another column or a passed param. """ - left: ColumnInfo - right: PyType | ColumnInfo + left: ColumnInfo | ClauseBase + right: ColumnInfo | PyType - def __init__(self, left: Column[T], right: T | Column[T]): - self.left = left.info + def __init__(self, left: Column[T] | ClauseBase, right: Column[T] | T): + self.left = left.info if isinstance(left, Column) else left self.right = right.info if isinstance(right, Column) else right @override def sql(self, get_count: GetCount) -> QuerySingle: - count = get_count() - name = f"gte_{self.left.name}_{count}" - - if isinstance(self.right, ColumnInfo): - return QuerySingle(sql=f"{self.left.fqn()} >= {self.right.fqn()}") - - return QuerySingle(sql=f"{self.left.fqn()} >= %({name})s", params={name: self.right}) + return _gen_comparison_sql(self.left, self.right, ">=", "gte", get_count) -class Lt[T: PyType](WhereClause): +class Lt[T: PyType](ClauseBase): """ Checks if a column value is less than another column or a passed param. """ - left: ColumnInfo - right: PyType | ColumnInfo + left: ColumnInfo | ClauseBase + right: ColumnInfo | PyType - def __init__(self, left: Column[T], right: T | Column[T]): - self.left = left.info + def __init__(self, left: Column[T] | ClauseBase, right: Column[T] | T): + self.left = left.info if isinstance(left, Column) else left self.right = right.info if isinstance(right, Column) else right @override def sql(self, get_count: GetCount) -> QuerySingle: - count = get_count() - name = f"lt_{self.left.name}_{count}" - - if isinstance(self.right, ColumnInfo): - return QuerySingle(sql=f"{self.left.fqn()} < {self.right.fqn()}") + return _gen_comparison_sql(self.left, self.right, "<", "lt", get_count) - return QuerySingle(sql=f"{self.left.fqn()} < %({name})s", params={name: self.right}) - -class Lte[T: PyType](WhereClause): +class Lte[T: PyType](ClauseBase): """ Checks if a column value is less than or equal to another column or a passed param. """ - left: ColumnInfo - right: PyType | ColumnInfo + left: ColumnInfo | ClauseBase + right: ColumnInfo | PyType - def __init__(self, left: Column[T], right: T | Column[T]): - self.left = left.info + def __init__(self, left: Column[T] | ClauseBase, right: Column[T] | T): + self.left = left.info if isinstance(left, Column) else left self.right = right.info if isinstance(right, Column) else right @override def sql(self, get_count: GetCount) -> QuerySingle: - count = get_count() - name = f"lte_{self.left.name}_{count}" - - if isinstance(self.right, ColumnInfo): - return QuerySingle(sql=f"{self.left.fqn()} <= {self.right.fqn()}") - - return QuerySingle(sql=f"{self.left.fqn()} <= %({name})s", params={name: self.right}) + return _gen_comparison_sql(self.left, self.right, "<=", "lte", get_count) # String matching operators -class Like[T: PyType](WhereClause): +class Like[T: PyType](ClauseBase): left: ColumnInfo right: PyType | ColumnInfo @@ -190,7 +163,7 @@ def sql(self, get_count: GetCount) -> QuerySingle: return QuerySingle(sql=f"{self.left.fqn()} LIKE %({name})s", params={name: self.right}) -class Ilike[T: PyType](WhereClause): +class Ilike[T: PyType](ClauseBase): """ Case-insensitive LIKE pattern matching. """ @@ -212,7 +185,7 @@ def sql(self, get_count: GetCount) -> QuerySingle: return QuerySingle(sql=f"{self.left.fqn()} ILIKE %({name})s", params={name: self.right}) -class NotLike[T: PyType](WhereClause): +class NotLike[T: PyType](ClauseBase): """ Negated LIKE pattern matching. """ @@ -235,7 +208,7 @@ def sql(self, get_count: GetCount) -> QuerySingle: # Null checks -class IsNull(WhereClause): +class IsNull(ClauseBase): """ Checks if a column value is NULL. """ @@ -250,7 +223,7 @@ def sql(self, get_count: GetCount) -> QuerySingle: return QuerySingle(sql=f"{self.column.fqn()} IS NULL") -class IsNotNull(WhereClause): +class IsNotNull(ClauseBase): """ Checks if a column value is NOT NULL. """ @@ -266,7 +239,7 @@ def sql(self, get_count: GetCount) -> QuerySingle: # Array/list operations -class InArray[T: PyType](WhereClause): +class InArray[T: PyType](ClauseBase): """ Checks if a column value is in a list of values. """ @@ -285,7 +258,7 @@ def sql(self, get_count: GetCount) -> QuerySingle: return QuerySingle(sql=f"{self.column.fqn()} = ANY(%({name})s)", params={name: self.values}) -class NotInArray[T: PyType](WhereClause): +class NotInArray[T: PyType](ClauseBase): """ Checks if a column value is not in a list of values. """ @@ -307,7 +280,7 @@ def sql(self, get_count: GetCount) -> QuerySingle: # Range operations -class Between[T: PyType](WhereClause): +class Between[T: PyType](ClauseBase): """ Checks if a column value is between two values (inclusive). """ @@ -332,7 +305,7 @@ def sql(self, get_count: GetCount) -> QuerySingle: ) -class NotBetween[T: PyType](WhereClause): +class NotBetween[T: PyType](ClauseBase): """ Checks if a column value is not between two values (inclusive). """ @@ -362,7 +335,7 @@ class SqlAble(Protocol): def sql(self) -> QuerySingle: ... -class Exists(WhereClause): +class Exists(ClauseBase): """ Check if a subquery result exists. """ @@ -378,7 +351,7 @@ def sql(self, get_count: GetCount) -> QuerySingle: return QuerySingle(f"EXISTS ({query.sql})", query.params) -class NotExists(WhereClause): +class NotExists(ClauseBase): """ Check if a subquery result does not exist. """ @@ -395,14 +368,14 @@ def sql(self, get_count: GetCount) -> QuerySingle: # Logical operators -class Not(WhereClause): +class Not(ClauseBase): """ Negates a where clause. """ - clause: WhereClause + clause: ClauseBase - def __init__(self, clause: WhereClause): + def __init__(self, clause: ClauseBase): self.clause = clause @override @@ -411,11 +384,15 @@ def sql(self, get_count: GetCount) -> QuerySingle: return QuerySingle(sql=f"NOT ({inner.sql})", params=inner.params) -class And(WhereClause): - left: WhereClause - right: WhereClause +class And(ClauseBase): + """ + AND two clauses. + """ + + left: ClauseBase + right: ClauseBase - def __init__(self, left: WhereClause, right: WhereClause): + def __init__(self, left: ClauseBase, right: ClauseBase): self.left = left self.right = right @@ -428,11 +405,15 @@ def sql(self, get_count: GetCount) -> QuerySingle: return QuerySingle(sql=sql, params=params) -class Or(WhereClause): - left: WhereClause - right: WhereClause +class Or(ClauseBase): + """ + OR two clauses. + """ + + left: ClauseBase + right: ClauseBase - def __init__(self, left: WhereClause, right: WhereClause): + def __init__(self, left: ClauseBase, right: ClauseBase): self.left = left self.right = right diff --git a/tests/test_vector.py b/tests/test_vector.py new file mode 100644 index 0000000..b80efdb --- /dev/null +++ b/tests/test_vector.py @@ -0,0 +1,124 @@ +"""Tests for pgvector support (L2Distance).""" + +from typing import Annotated + +from pydantic import BaseModel + +from embar.column.common import Integer +from embar.column.pg import Vector +from embar.config import EmbarConfig +from embar.db.pg import PgDb +from embar.query.vector import CosineDistance, L2Distance +from embar.query.where import Gt, Lt +from embar.table import Table + + +class Embedding(Table): + embar_config: EmbarConfig = EmbarConfig(table_name="embeddings") + + id: Integer = Integer(primary=True) + vec_a: Vector = Vector(3) + vec_b: Vector = Vector(3) + + +def test_order_by_l2distance_with_literal(db_dummy: PgDb): + """Test ORDER BY with L2Distance using a literal vector.""" + + class EmbeddingSel(BaseModel): + id: Annotated[int, Embedding.id] + + # fmt: off + query = ( + db_dummy.select(EmbeddingSel) + .from_(Embedding) + .order_by(L2Distance(Embedding.vec_a, [1.0, 2.0, 3.0])) + ) + # fmt: on + + sql_result = query.sql() + assert "ORDER BY" in sql_result.sql + assert "<->" in sql_result.sql + assert '"embeddings"."vec_a"' in sql_result.sql + + +def test_order_by_l2distance_with_column(db_dummy: PgDb): + """Test ORDER BY with L2Distance comparing two vector columns.""" + + class EmbeddingSel(BaseModel): + id: Annotated[int, Embedding.id] + + # fmt: off + query = ( + db_dummy.select(EmbeddingSel) + .from_(Embedding) + .order_by(L2Distance(Embedding.vec_a, Embedding.vec_b)) + ) + # fmt: on + + sql_result = query.sql() + assert "ORDER BY" in sql_result.sql + assert "<->" in sql_result.sql + assert '"embeddings"."vec_a"' in sql_result.sql + assert '"embeddings"."vec_b"' in sql_result.sql + + +def test_where_l2distance_with_lt(db_dummy: PgDb): + """Test WHERE clause filtering by L2Distance < threshold.""" + + class EmbeddingSel(BaseModel): + id: Annotated[int, Embedding.id] + + # fmt: off + query = ( + db_dummy.select(EmbeddingSel) + .from_(Embedding) + .where(Lt(L2Distance(Embedding.vec_a, [1.0, 2.0, 3.0]), 0.5)) + ) + # fmt: on + + sql_result = query.sql() + assert "WHERE" in sql_result.sql + assert "<->" in sql_result.sql + assert "<" in sql_result.sql + assert '"embeddings"."vec_a"' in sql_result.sql + + +def test_where_l2distance_with_gt(db_dummy: PgDb): + """Test WHERE clause filtering by L2Distance > threshold.""" + + class EmbeddingSel(BaseModel): + id: Annotated[int, Embedding.id] + + # fmt: off + query = ( + db_dummy.select(EmbeddingSel) + .from_(Embedding) + .where(Gt(L2Distance(Embedding.vec_a, [1.0, 2.0, 3.0]), 0.5)) + ) + # fmt: on + + sql_result = query.sql() + assert "WHERE" in sql_result.sql + assert "<->" in sql_result.sql + assert ">" in sql_result.sql + assert '"embeddings"."vec_a"' in sql_result.sql + + +def test_order_by_cosine_distance_with_literal(db_dummy: PgDb): + """Test ORDER BY with CosineDistance using a literal vector.""" + + class EmbeddingSel(BaseModel): + id: Annotated[int, Embedding.id] + + # fmt: off + query = ( + db_dummy.select(EmbeddingSel) + .from_(Embedding) + .order_by(CosineDistance(Embedding.vec_a, [1.0, 2.0, 3.0])) + ) + # fmt: on + + sql_result = query.sql() + assert "ORDER BY" in sql_result.sql + assert "<=>" in sql_result.sql + assert '"embeddings"."vec_a"' in sql_result.sql