Skip to content

Commit 9f1ba95

Browse files
committed
Added support for halfvec binary format for pgx - #24
1 parent 499c8e8 commit 9f1ba95

File tree

5 files changed

+59
-10
lines changed

5 files changed

+59
-10
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
## 0.3.0 (unreleased)
22

33
- Added distance functions for Ent
4+
- Added support for `halfvec` binary format for pgx
45
- Dropped support for Go < 1.23
56

67
## 0.2.3 (2025-01-15)

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ require (
3838
github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect
3939
github.com/vmihailenco/tagparser v0.1.2 // indirect
4040
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
41+
github.com/x448/float16 v0.8.4 // indirect
4142
github.com/zclconf/go-cty v1.16.2 // indirect
4243
github.com/zclconf/go-cty-yaml v1.1.0 // indirect
4344
golang.org/x/crypto v0.36.0 // indirect

halfvec.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ func (v *HalfVector) Scan(src interface{}) (err error) {
7171
return v.Parse(string(src))
7272
case string:
7373
return v.Parse(src)
74+
case []float32:
75+
v.vec = src
76+
return nil
7477
default:
7578
return fmt.Errorf("unsupported data type: %T", src)
7679
}

pgx/halfvec.go

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,24 @@ package pgx
22

33
import (
44
"database/sql/driver"
5+
"encoding/binary"
56
"fmt"
7+
"slices"
68

79
"github.com/jackc/pgx/v5"
810
"github.com/jackc/pgx/v5/pgtype"
911
"github.com/pgvector/pgvector-go"
12+
"github.com/x448/float16"
1013
)
1114

1215
// HalfVectorCodec is the codec type for pgvector.HalfVector values. Its
// methods report which wire formats are supported and build the encode and
// scan plans for those formats.
type HalfVectorCodec struct{}
1316

1417
func (HalfVectorCodec) FormatSupported(format int16) bool {
15-
return format == pgx.TextFormatCode
18+
return format == pgx.BinaryFormatCode || format == pgx.TextFormatCode
1619
}
1720

1821
func (HalfVectorCodec) PreferredFormat() int16 {
19-
return pgx.TextFormatCode
22+
return pgx.BinaryFormatCode
2023
}
2124

2225
func (HalfVectorCodec) PlanEncode(m *pgtype.Map, oid uint32, format int16, value any) pgtype.EncodePlan {
@@ -25,35 +28,76 @@ func (HalfVectorCodec) PlanEncode(m *pgtype.Map, oid uint32, format int16, value
2528
return nil
2629
}
2730

28-
if format == pgx.TextFormatCode {
31+
switch format {
32+
case pgx.BinaryFormatCode:
33+
return encodePlanHalfVectorCodecBinary{}
34+
case pgx.TextFormatCode:
2935
return encodePlanHalfVectorCodecText{}
3036
}
3137

3238
return nil
3339
}
3440

41+
type encodePlanHalfVectorCodecBinary struct{}
42+
43+
func (encodePlanHalfVectorCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) {
44+
v := value.(pgvector.HalfVector)
45+
vec := v.Slice()
46+
dim := len(vec)
47+
buf = slices.Grow(buf, 4+2*dim)
48+
buf = binary.BigEndian.AppendUint16(buf, uint16(dim))
49+
buf = binary.BigEndian.AppendUint16(buf, 0)
50+
for _, v := range vec {
51+
buf = binary.BigEndian.AppendUint16(buf, float16.Fromfloat32(v).Bits())
52+
}
53+
return buf, nil
54+
}
55+
3556
type encodePlanHalfVectorCodecText struct{}
3657

3758
func (encodePlanHalfVectorCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) {
3859
v := value.(pgvector.HalfVector)
3960
return v.EncodeText(buf)
4061
}
4162

42-
type scanPlanHalfVectorCodecText struct{}
43-
4463
func (HalfVectorCodec) PlanScan(m *pgtype.Map, oid uint32, format int16, target any) pgtype.ScanPlan {
4564
_, ok := target.(*pgvector.HalfVector)
4665
if !ok {
4766
return nil
4867
}
4968

50-
if format == pgx.TextFormatCode {
69+
switch format {
70+
case pgx.BinaryFormatCode:
71+
return scanPlanHalfVectorCodecBinary{}
72+
case pgx.TextFormatCode:
5173
return scanPlanHalfVectorCodecText{}
5274
}
5375

5476
return nil
5577
}
5678

79+
type scanPlanHalfVectorCodecBinary struct{}
80+
81+
func (scanPlanHalfVectorCodecBinary) Scan(src []byte, dst any) error {
82+
v := (dst).(*pgvector.HalfVector)
83+
buf := src
84+
dim := int(binary.BigEndian.Uint16(buf[0:2]))
85+
unused := binary.BigEndian.Uint16(buf[2:4])
86+
if unused != 0 {
87+
return fmt.Errorf("expected unused to be 0")
88+
}
89+
90+
vec := make([]float32, 0, dim)
91+
offset := 4
92+
for i := 0; i < dim; i++ {
93+
vec = append(vec, float16.Frombits(binary.BigEndian.Uint16(buf[offset:offset+2])).Float32())
94+
offset += 2
95+
}
96+
return v.Scan(vec)
97+
}
98+
99+
type scanPlanHalfVectorCodecText struct{}
100+
57101
func (scanPlanHalfVectorCodecText) Scan(src []byte, dst any) error {
58102
v := (dst).(*pgvector.HalfVector)
59103
return v.Scan(src)

pgx_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,9 @@ func TestPgx(t *testing.T) {
161161
_, err = conn.CopyFrom(
162162
ctx,
163163
pgx.Identifier{"pgx_items"},
164-
[]string{"embedding", "binary_embedding", "sparse_embedding"},
164+
[]string{"embedding", "half_embedding", "binary_embedding", "sparse_embedding"},
165165
pgx.CopyFromSlice(1, func(i int) ([]any, error) {
166-
return []interface{}{"[1,2,3]", "101", "{1:1,2:2,3:3}/3"}, nil
166+
return []interface{}{"[1,2,3]", "[1,2,3]", "101", "{1:1,2:2,3:3}/3"}, nil
167167
}),
168168
)
169169
if err != nil {
@@ -186,9 +186,9 @@ func TestPgx(t *testing.T) {
186186
_, err = pool.CopyFrom(
187187
ctx,
188188
pgx.Identifier{"pgx_items"},
189-
[]string{"embedding", "binary_embedding", "sparse_embedding"},
189+
[]string{"embedding", "half_embedding", "binary_embedding", "sparse_embedding"},
190190
pgx.CopyFromSlice(1, func(i int) ([]any, error) {
191-
return []interface{}{"[1,2,3]", "101", "{1:1,2:2,3:3}/3"}, nil
191+
return []interface{}{"[1,2,3]", "[1,2,3]", "101", "{1:1,2:2,3:3}/3"}, nil
192192
}),
193193
)
194194
if err != nil {

0 commit comments

Comments
 (0)