Skip to content

Commit 0dd0c82

Browse files
committed
Added support for Postgres arrays for pgx - closes #20
1 parent 1ec55f1 commit 0dd0c82

File tree

3 files changed

+38
-4
lines changed

3 files changed

+38
-4
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## 0.2.3 (unreleased)
2+
3+
- Added support for Postgres arrays for pgx
4+
15
## 0.2.2 (2024-08-10)
26

37
- Added support for `CopyFrom` with `string` values

pgx/register.go

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,12 @@ import (
1010

1111
func RegisterTypes(ctx context.Context, conn *pgx.Conn) error {
1212
var vectorOid *uint32
13+
var vectorArrayOid *uint32
1314
var halfvecOid *uint32
15+
var halfvecArrayOid *uint32
1416
var sparsevecOid *uint32
15-
err := conn.QueryRow(ctx, "SELECT to_regtype('vector')::oid, to_regtype('halfvec')::oid, to_regtype('sparsevec')::oid").Scan(&vectorOid, &halfvecOid, &sparsevecOid)
17+
var sparsevecArrayOid *uint32
18+
err := conn.QueryRow(ctx, "SELECT to_regtype('vector')::oid, to_regtype('_vector')::oid, to_regtype('halfvec')::oid, to_regtype('_halfvec')::oid, to_regtype('sparsevec')::oid, to_regtype('_sparsevec')::oid").Scan(&vectorOid, &vectorArrayOid, &halfvecOid, &halfvecArrayOid, &sparsevecOid, &sparsevecArrayOid)
1619
if err != nil {
1720
return err
1821
}
@@ -22,14 +25,20 @@ func RegisterTypes(ctx context.Context, conn *pgx.Conn) error {
2225
}
2326

2427
tm := conn.TypeMap()
25-
tm.RegisterType(&pgtype.Type{Name: "vector", OID: *vectorOid, Codec: &VectorCodec{}})
28+
vectorType := pgtype.Type{Name: "vector", OID: *vectorOid, Codec: &VectorCodec{}}
29+
tm.RegisterType(&vectorType)
30+
tm.RegisterType(&pgtype.Type{Name: "_vector", OID: *vectorArrayOid, Codec: &pgtype.ArrayCodec{ElementType: &vectorType}})
2631

2732
if halfvecOid != nil {
28-
tm.RegisterType(&pgtype.Type{Name: "halfvec", OID: *halfvecOid, Codec: &HalfVectorCodec{}})
33+
halfvecType := pgtype.Type{Name: "halfvec", OID: *halfvecOid, Codec: &HalfVectorCodec{}}
34+
tm.RegisterType(&halfvecType)
35+
tm.RegisterType(&pgtype.Type{Name: "_halfvec", OID: *halfvecArrayOid, Codec: &pgtype.ArrayCodec{ElementType: &halfvecType}})
2936
}
3037

3138
if sparsevecOid != nil {
32-
tm.RegisterType(&pgtype.Type{Name: "sparsevec", OID: *sparsevecOid, Codec: &SparseVectorCodec{}})
39+
sparsevecType := pgtype.Type{Name: "sparsevec", OID: *sparsevecOid, Codec: &SparseVectorCodec{}}
40+
tm.RegisterType(&sparsevecType)
41+
tm.RegisterType(&pgtype.Type{Name: "_sparsevec", OID: *sparsevecArrayOid, Codec: &pgtype.ArrayCodec{ElementType: &sparsevecType}})
3342
}
3443

3544
return nil

pgx_test.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,4 +169,25 @@ func TestPgx(t *testing.T) {
169169
if err != nil {
170170
panic(err)
171171
}
172+
173+
embeddings := []pgvector.Vector{pgvector.NewVector([]float32{1, 2, 3}), pgvector.NewVector([]float32{4, 5, 6})}
174+
halfEmbeddings := []pgvector.HalfVector{pgvector.NewHalfVector([]float32{1, 2, 3}), pgvector.NewHalfVector([]float32{4, 5, 6})}
175+
sparseEmbeddings := []pgvector.SparseVector{pgvector.NewSparseVector([]float32{1, 2, 3}), pgvector.NewSparseVector([]float32{4, 5, 6})}
176+
row = conn.QueryRow(ctx, "SELECT $1::vector[], $2::halfvec[], $3::sparsevec[]", embeddings, halfEmbeddings, sparseEmbeddings)
177+
var scanEmbeddings []pgvector.Vector
178+
var scanHalfEmbeddings []pgvector.HalfVector
179+
var scanSparseEmbeddings []pgvector.SparseVector
180+
err = row.Scan(&scanEmbeddings, &scanHalfEmbeddings, &scanSparseEmbeddings)
181+
if err != nil {
182+
panic(err)
183+
}
184+
if !reflect.DeepEqual(scanEmbeddings, embeddings) {
185+
t.Error()
186+
}
187+
if !reflect.DeepEqual(scanHalfEmbeddings, halfEmbeddings) {
188+
t.Error()
189+
}
190+
if !reflect.DeepEqual(scanSparseEmbeddings, sparseEmbeddings) {
191+
t.Error()
192+
}
172193
}

0 commit comments

Comments
 (0)