Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
def postgres_container_raw(request: pytest.FixtureRequest):
"""Session-scoped postgres container for docs tests."""
try:
with PostgresContainer("postgres:18-alpine3.22", port=25432) as postgres:
with PostgresContainer("pgvector/pgvector:0.8.1-pg18-trixie", port=25432) as postgres:
request.addfinalizer(postgres.stop)
yield postgres
except Exception as e:
Expand Down
110 changes: 110 additions & 0 deletions docs/extensions/vector.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Vector

Embar supports [pgvector](https://github.com/pgvector/pgvector), the open-source vector similarity search extension for PostgreSQL.

Before using vector columns, you must install and activate the extension:

```sql
CREATE EXTENSION vector;
```

## Creating a Vector Column

Use `Vector` to store embeddings with a fixed dimension:

```{.python fixture:postgres_container}
import asyncio
import psycopg

from embar.column.common import Integer
from embar.column.pg import Vector
from embar.db.pg import AsyncPgDb
from embar.table import Table

class Document(Table):
id: Integer = Integer()
embedding: Vector = Vector(3) # 3-dimensional vector

async def get_db():
database_url = "postgres://pg:pw@localhost:25432/db"
conn = await psycopg.AsyncConnection.connect(database_url)
await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
db = AsyncPgDb(conn)
await db.migrate([Document])
return db

async def setup():
db = await get_db()
# Insert some documents with embeddings
await db.insert(Document).values(Document(id=1, embedding=[1.0, 0.0, 0.0]))
await db.insert(Document).values(Document(id=2, embedding=[0.0, 1.0, 0.0]))
await db.insert(Document).values(Document(id=3, embedding=[0.0, 0.0, 1.0]))

asyncio.run(setup())
```

## L2 Distance

Use `L2Distance` for Euclidean distance searches. This uses the `<->` operator.

### Order By L2 Distance

Find documents ordered by distance to a query vector:

```{.python continuation}
from embar.query.vector import L2Distance

async def order_by_l2():
db = await get_db()
query_vector = [1.0, 0.5, 0.0]
docs = await (
db.select(Document.all())
.from_(Document)
.order_by(L2Distance(Document.embedding, query_vector))
)
print([d.id for d in docs])

asyncio.run(order_by_l2())
```

### Filter By L2 Distance

Find documents within a distance threshold:

```{.python continuation}
from embar.query.where import Lt

async def filter_by_l2():
db = await get_db()
query_vector = [1.0, 0.0, 0.0]
docs = await (
db.select(Document.all())
.from_(Document)
.where(Lt(L2Distance(Document.embedding, query_vector), 0.5))
)
print([d.id for d in docs])

asyncio.run(filter_by_l2())
```

## Cosine Distance

Use `CosineDistance` for cosine similarity searches. This uses the `<=>` operator.

### Order By Cosine Distance

```{.python continuation}
from embar.query.vector import CosineDistance

async def order_by_cosine():
db = await get_db()
query_vector = [1.0, 0.5, 0.0]
docs = await (
db.select(Document.all())
.from_(Document)
.order_by(CosineDistance(Document.embedding, query_vector))
)
print([d.id for d in docs])

asyncio.run(order_by_cosine())
```
3 changes: 3 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ nav:
- Transactions: query/transactions.md
- Raw SQL: query/raw-sql.md

- Extensions:
- Vector: extensions/vector.md

- Integration:
- Migrations: migrations.md
- Types: types.md
Expand Down
29 changes: 29 additions & 0 deletions src/embar/column/pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"Timestamp",
"Varchar",
"EnumCol",
"Vector",
]


Expand Down Expand Up @@ -344,3 +345,31 @@ def __init__(
self._sql_type = pg_enum.name

super().__init__(name=name, default=default, primary=primary, not_null=not_null)


# Extension: pgvector
# Should also support `halfvec` and `bit`
class Vector(Column[list[float]]):
"""
Vector column using [pgvector](https://github.com/pgvector/pgvector).

This assumes the extension is already installed and activated with
CREATE EXTENSION vector;
"""

_sql_type: str = "VECTOR"
_py_type: Type = list[float]

def __init__(
self,
length: int,
name: str | None = None,
default: list[float] | None = None,
primary: bool = False,
not_null: bool = False,
):
"""
Create a new Vector instance.
"""
self._extra_args: tuple[int] | tuple[int, int] | None = (length,)
super().__init__(name=name, default=default, primary=primary, not_null=not_null)
6 changes: 3 additions & 3 deletions src/embar/constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from embar.column.base import ColumnBase
from embar.constraint_base import Constraint
from embar.custom_types import PyType
from embar.query.clause_base import ClauseBase
from embar.query.query import QuerySingle
from embar.query.where import WhereClause


class Index:
Expand Down Expand Up @@ -70,7 +70,7 @@ class IndexReady(Constraint):
unique: bool
name: str
columns: tuple[Callable[[], ColumnBase], ...]
_where_clause: Callable[[], WhereClause] | None = None
_where_clause: Callable[[], ClauseBase] | None = None

def __init__(self, name: str, unique: bool, *columns: Callable[[], ColumnBase]):
"""
Expand All @@ -80,7 +80,7 @@ def __init__(self, name: str, unique: bool, *columns: Callable[[], ColumnBase]):
self.unique = unique
self.columns = columns

def where(self, where_clause: Callable[[], WhereClause]) -> Self:
def where(self, where_clause: Callable[[], ClauseBase]) -> Self:
"""
Add a WHERE clause to create a partial index.
"""
Expand Down
13 changes: 12 additions & 1 deletion src/embar/model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from typing import (
Annotated,
Any,
Expand All @@ -8,7 +9,7 @@
get_type_hints,
)

from pydantic import BaseModel, Field, create_model
from pydantic import BaseModel, BeforeValidator, Field, create_model

from embar.column.base import ColumnBase
from embar.db.base import DbType
Expand Down Expand Up @@ -160,6 +161,10 @@ class MyTable(Table): ...
fields_dict: dict[str, Any] = {}
for field_name, column in cls._fields.items(): # pyright:ignore[reportPrivateUsage]
field_type = column.info.py_type

if column.info.col_type == "VECTOR":
field_type = Annotated[field_type, BeforeValidator(_parse_json_list)]

fields_dict[field_name] = (
Annotated[field_type, column],
Field(default_factory=lambda a=column: column.info.fqn()),
Expand All @@ -185,3 +190,9 @@ def upgrade_model_nested_fields[B: BaseModel](model: type[B]) -> type[B]:
new_class.model_rebuild()

return new_class


def _parse_json_list(v: Any):
if isinstance(v, str):
return json.loads(v)
return v
25 changes: 25 additions & 0 deletions src/embar/query/clause_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from abc import ABC, abstractmethod
from typing import Callable

from embar.query.query import QuerySingle

# Where clauses get passed a get_count() function that returns a monotonically
# increasing integer. This allows each SQL binding parameter to get a unique
# name like `%(eq_id_2)s` in psycopg format.
type GetCount = Callable[[], int]


class ClauseBase(ABC):
"""
ABC for ORDER BY and WHERE clauses.

Not all use the get_count() directly (those with no bindings)
but their children might.
"""

@abstractmethod
def sql(self, get_count: GetCount) -> QuerySingle:
"""
Generate the SQL fragment for this clause.
"""
...
2 changes: 1 addition & 1 deletion src/embar/query/conflict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from typing import override

from embar.custom_types import PyType
from embar.query.clause_base import GetCount
from embar.query.query import QuerySingle
from embar.query.where import GetCount

# require at least one element in tuple
TupleAtLeastOne = tuple[str, *tuple[str, ...]]
Expand Down
24 changes: 13 additions & 11 deletions src/embar/query/delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
from embar.model import (
generate_model,
)
from embar.query.order_by import Asc, BareColumn, Desc, OrderBy, OrderByClause, RawSqlOrder
from embar.query.clause_base import ClauseBase
from embar.query.order_by import Asc, BareColumn, Desc, OrderBy, RawSqlOrder
from embar.query.query import QuerySingle
from embar.query.where import WhereClause
from embar.sql import Sql
from embar.table import Table

Expand All @@ -30,7 +30,7 @@ class DeleteQueryReady[T: Table, Db: AllDbBase]:
table: type[T]
_db: Db

_where_clause: WhereClause | None = None
_where_clause: ClauseBase | None = None
_order_clause: OrderBy | None = None
_limit_value: int | None = None

Expand All @@ -44,7 +44,7 @@ def __init__(self, table: type[T], db: Db):
def returning(self) -> DeleteQueryReturning[T, Db]:
return DeleteQueryReturning(self.table, self._db, self._where_clause, self._order_clause, self._limit_value)

def where(self, where_clause: WhereClause) -> Self:
def where(self, where_clause: ClauseBase) -> Self:
"""
Add a WHERE clause to the query.
"""
Expand Down Expand Up @@ -79,7 +79,7 @@ class User(Table):
```
"""
# Convert each clause to an OrderByClause
order_clauses: list[OrderByClause] = []
order_clauses: list[ClauseBase] = []
for clause in clauses:
if isinstance(clause, (Asc, Desc)):
order_clauses.append(clause)
Expand Down Expand Up @@ -170,8 +170,9 @@ def get_count() -> int:
params = {**params, **where_data.params}

if self._order_clause is not None:
order_by_sql = self._order_clause.sql()
sql += f"\nORDER BY {order_by_sql}"
order_by_query = self._order_clause.sql(get_count)
sql += f"\nORDER BY {order_by_query.sql}"
params = {**params, **order_by_query.params}

if self._limit_value is not None:
sql += f"\nLIMIT {self._limit_value}"
Expand All @@ -189,15 +190,15 @@ class DeleteQueryReturning[T: Table, Db: AllDbBase]:
table: type[T]
_db: Db

_where_clause: WhereClause | None = None
_where_clause: ClauseBase | None = None
_order_clause: OrderBy | None = None
_limit_value: int | None = None

def __init__(
self,
table: type[T],
db: Db,
where_clause: WhereClause | None,
where_clause: ClauseBase | None,
order_clause: OrderBy | None,
limit_value: int | None,
):
Expand Down Expand Up @@ -287,8 +288,9 @@ def get_count() -> int:
params = {**params, **where_data.params}

if self._order_clause is not None:
order_by_sql = self._order_clause.sql()
sql += f"\nORDER BY {order_by_sql}"
order_by_query = self._order_clause.sql(get_count)
sql += f"\nORDER BY {order_by_query.sql}"
params = {**params, **order_by_query.params}

if self._limit_value is not None:
sql += f"\nLIMIT {self._limit_value}"
Expand Down
4 changes: 2 additions & 2 deletions src/embar/query/having.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dataclasses import dataclass

from embar.query.where import WhereClause
from embar.query.clause_base import ClauseBase


@dataclass
Expand Down Expand Up @@ -32,4 +32,4 @@ class Having:
```
"""

clause: WhereClause
clause: ClauseBase
Loading