
Commit 34d12ce

feat: Add vocabulary quantization support (#30)
1 parent 3fe7e77 commit 34d12ce

File tree

11 files changed: +315 -62 lines changed

src/model.rs

Lines changed: 75 additions & 6 deletions
```diff
@@ -12,6 +12,8 @@ use tokenizers::Tokenizer;
 pub struct StaticModel {
     tokenizer: Tokenizer,
     embeddings: Array2<f32>,
+    weights: Option<Vec<f32>>,
+    token_mapping: Option<Vec<usize>>,
     normalize: bool,
     median_token_length: usize,
     unk_token_id: Option<usize>,
@@ -115,9 +117,48 @@ impl StaticModel {
         };
         let embeddings = Array2::from_shape_vec((rows, cols), floats).context("failed to build embeddings array")?;
 
+        // Load optional weights for vocabulary quantization
+        let weights = match safet.tensor("weights") {
+            Ok(t) => {
+                let raw = t.data();
+                let v: Vec<f32> = match t.dtype() {
+                    Dtype::F64 => raw
+                        .chunks_exact(8)
+                        .map(|b| f64::from_le_bytes(b.try_into().unwrap()) as f32)
+                        .collect(),
+                    Dtype::F32 => raw
+                        .chunks_exact(4)
+                        .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
+                        .collect(),
+                    Dtype::F16 => raw
+                        .chunks_exact(2)
+                        .map(|b| half::f16::from_le_bytes(b.try_into().unwrap()).to_f32())
+                        .collect(),
+                    other => return Err(anyhow!("unsupported weights dtype: {:?}", other)),
+                };
+                Some(v)
+            }
+            Err(_) => None,
+        };
+
+        // Load optional token mapping for vocabulary quantization
+        let token_mapping = match safet.tensor("mapping") {
+            Ok(t) => {
+                let raw = t.data();
+                let v: Vec<usize> = raw
+                    .chunks_exact(4)
+                    .map(|b| i32::from_le_bytes(b.try_into().unwrap()) as usize)
+                    .collect();
+                Some(v)
+            }
+            Err(_) => None,
+        };
+
         Ok(Self {
             tokenizer,
             embeddings,
+            weights,
+            token_mapping,
             normalize,
             median_token_length,
             unk_token_id: Some(unk_token_id),
```
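A note on the `Dtype::F16` arm above: safetensors stores half-precision values as raw little-endian byte pairs, and Rust has no native `f16` primitive, hence the `half` crate. A minimal standalone sketch of that decoding step (the function name is ours, for illustration only):

```rust
// Decode a little-endian f16 buffer into f32 values, mirroring the
// Dtype::F16 arm in the diff above. Assumes the `half` crate.
fn decode_f16_le(raw: &[u8]) -> Vec<f32> {
    raw.chunks_exact(2)
        .map(|b| half::f16::from_le_bytes([b[0], b[1]]).to_f32())
        .collect()
}

fn main() {
    // 1.0 (0x3C00) and -2.5 (0xC100) in IEEE 754 half precision, little-endian
    let raw = [0x00u8, 0x3C, 0x00, 0xC1];
    assert_eq!(decode_f16_le(&raw), vec![1.0, -2.5]);
}
```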
```diff
@@ -202,18 +243,46 @@ impl StaticModel {
 
     /// Mean-pool a single token-ID list into a vector
     fn pool_ids(&self, ids: Vec<u32>) -> Vec<f32> {
-        let mut sum = vec![0.0; self.embeddings.ncols()];
+        let dim = self.embeddings.ncols();
+        let mut sum = vec![0.0; dim];
+        let mut cnt = 0usize;
+
         for &id in &ids {
-            let row = self.embeddings.row(id as usize);
+            let tok = id as usize;
+
+            // Remap: row = token_mapping[id] or id
+            let row_idx = if let Some(m) = &self.token_mapping {
+                *m.get(tok).unwrap_or(&tok)
+            } else {
+                tok
+            };
+
+            // Scale by per-token weight if present
+            let scale = if let Some(w) = &self.weights {
+                *w.get(tok).unwrap_or(&1.0)
+            } else {
+                1.0
+            };
+
+            let row = self.embeddings.row(row_idx);
             for (i, &v) in row.iter().enumerate() {
-                sum[i] += v;
+                sum[i] += v * scale;
             }
+            cnt += 1;
         }
-        let cnt = ids.len().max(1) as f32;
-        sum.iter_mut().for_each(|x| *x /= cnt);
+
+        // Mean pool the embeddings
+        let denom = (cnt.max(1)) as f32;
+        for x in &mut sum {
+            *x /= denom;
+        }
+
+        // Normalize the embeddings if required
         if self.normalize {
             let norm = sum.iter().map(|&v| v * v).sum::<f32>().sqrt().max(1e-12);
-            sum.iter_mut().for_each(|x| *x /= norm);
+            for x in &mut sum {
+                *x /= norm;
+            }
         }
         sum
     }
```
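To make the remap-and-scale logic concrete: the mapping lets several token ids point at one embedding row (presumably how the quantized vocabulary saves space), while the weights rescale each token's contribution. A self-contained sketch of the same pooling over plain slices; the two-row table, mapping, and weights are made-up illustration values, not data from the model:

```rust
// Weighted mean pooling with an optional token-to-row mapping,
// mirroring pool_ids above.
fn pool(
    ids: &[u32],
    embeddings: &[Vec<f32>],   // one row per (deduplicated) vocabulary entry
    mapping: Option<&[usize]>, // token id -> row index
    weights: Option<&[f32]>,   // per-token scale
) -> Vec<f32> {
    let mut sum = vec![0.0; embeddings[0].len()];
    for &id in ids {
        let tok = id as usize;
        let row_idx = mapping.map_or(tok, |m| m[tok]);
        let scale = weights.map_or(1.0, |w| w[tok]);
        for (s, &v) in sum.iter_mut().zip(&embeddings[row_idx]) {
            *s += v * scale;
        }
    }
    let denom = ids.len().max(1) as f32;
    sum.iter_mut().for_each(|s| *s /= denom);
    sum
}

fn main() {
    // Tokens 0 and 1 share row 0; token 1 is down-weighted to 0.5.
    let table = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
    let mapping = [0usize, 0, 1];
    let weights = [1.0f32, 0.5, 1.0];
    // mean of 1.0 * [1, 2] and 0.5 * [1, 2] = [0.75, 1.5]
    assert_eq!(pool(&[0, 1], &table, Some(&mapping), Some(&weights)), vec![0.75, 1.5]);
}
```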

tests/common.rs

Lines changed: 22 additions & 0 deletions

```diff
@@ -1,3 +1,4 @@
+#![allow(dead_code)]
 use model2vec_rs::model::StaticModel;
 
 /// Load the small float32 test model from fixtures
@@ -10,3 +11,24 @@ pub fn load_test_model() -> StaticModel {
     )
     .expect("Failed to load test model")
 }
+
+/// Load the vocab quantized test model from fixtures
+pub fn load_test_model_vocab_quantized() -> StaticModel {
+    StaticModel::from_pretrained(
+        "tests/fixtures/test-model-vocab-quantized",
+        None, // token
+        None, // normalize
+        None, // subfolder
+    )
+    .expect("Failed to load test model")
+}
+
+pub fn encode_with_model(path: &str) -> Vec<f32> {
+    // Helper function to load the model and encode "hello world"
+    let model = StaticModel::from_pretrained(path, None, None, None)
+        .unwrap_or_else(|e| panic!("Failed to load model at {path}: {e}"));
+
+    let out = model.encode(&["hello world".to_string()]);
+    assert_eq!(out.len(), 1);
+    out.into_iter().next().unwrap()
+}
```
Lines changed: 1 addition & 0 deletions
```diff
@@ -0,0 +1 @@
+[[-0.363402567303566, -0.17280622671667234, 0.39940570500107386, -0.16813011070789555, -0.15618451933846883, -0.14994650725286524, 0.006071081245258223, -0.16688048919255372, 0.22185219763399205, -0.02887293996697886, -0.17287425174176316, -0.01464316366136059, 0.16364637740934518, 0.15876414737891392, 0.06036445094799073, 0.15604592511625437, -0.0671308839546444, -0.016413190151500695, 0.016156947144592284, -0.04046410877612431, 0.08342219180115457, -0.06072982382315607, -0.15155935894530448, -0.27756653365043565, 0.04183386122272067, -0.02478048814648766, 0.048693007647196467, 0.15564136567656622, 0.03729875053535759, -0.06892603188806953, 0.08513432033392887, 0.0036654278831274112, -0.017677908666363845, 0.062035159999304555, -0.1394435606629564, 0.05264278960819571, -0.10000422994390393, 0.162456462739632, 0.0026303158188036926, -0.010224468015697916, -0.12629957405039433, -0.08506841545219175, -0.06720500777509077, -0.04443293593977252, 0.01816271214883152, 0.11269895366859049, 0.15572718186207016, -0.12838458617894438, 0.020126459971623472, -0.16689367919078762, 0.1038076137507656, 0.005876202291780198, 0.11467950137199819, -0.06360069640738063, 0.12898717602987858, 0.06665970239323335, 0.1263998072107107, 0.054322006298590964, 0.02275680905863399, -0.09242075684142392, 0.0003214892909238989, 0.06269664923701938, 0.007532826486481935, 0.006629162182434642]]
```
tests/fixtures/test-model-vocab-quantized/README.md

Lines changed: 100 additions & 0 deletions

---
base_model: BAAI/bge-base-en-v1.5
language:
- en
library_name: model2vec
license: mit
model_name: test-model-vocab-quantized
tags:
- embeddings
- static-embeddings
- sentence-transformers
---

# test-model-vocab-quantized Model Card

This [Model2Vec](https://github.com/MinishLab/model2vec) model is a distilled version of the [BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) Sentence Transformer. It uses static embeddings, allowing text embeddings to be computed orders of magnitude faster on both GPU and CPU. It is designed for applications where computational resources are limited or where real-time performance is critical. Model2Vec models are the smallest, fastest, and most performant static embedders available. The distilled models are up to 50 times smaller and 500 times faster than traditional Sentence Transformers.

## Installation

Install model2vec using pip:
```
pip install model2vec
```

## Usage

### Using Model2Vec

The [Model2Vec library](https://github.com/MinishLab/model2vec) is the fastest and most lightweight way to run Model2Vec models.

Load this model using the `from_pretrained` method:
```python
from model2vec import StaticModel

# Load a pretrained Model2Vec model
model = StaticModel.from_pretrained("test-model-vocab-quantized")

# Compute text embeddings
embeddings = model.encode(["Example sentence"])
```

### Using Sentence Transformers

You can also use the [Sentence Transformers library](https://github.com/UKPLab/sentence-transformers) to load and use the model:

```python
from sentence_transformers import SentenceTransformer

# Load a pretrained Sentence Transformer model
model = SentenceTransformer("test-model-vocab-quantized")

# Compute text embeddings
embeddings = model.encode(["Example sentence"])
```

### Distilling a Model2Vec model

You can distill a Model2Vec model from a Sentence Transformer model using the `distill` method. First, install the `distill` extra with `pip install model2vec[distill]`. Then, run the following code:

```python
from model2vec.distill import distill

# Distill a Sentence Transformer model, in this case the BAAI/bge-base-en-v1.5 model
m2v_model = distill(model_name="BAAI/bge-base-en-v1.5", pca_dims=256)

# Save the model
m2v_model.save_pretrained("m2v_model")
```

## How it works

Model2vec creates a small, fast, and powerful model that outperforms other static embedding models by a large margin on all tasks we could find, while being much faster to create than traditional static embedding models such as GloVe. Best of all, you don't need any data to distill a model using Model2Vec.

It works by passing a vocabulary through a sentence transformer model, then reducing the dimensionality of the resulting embeddings using PCA, and finally weighting the embeddings using [SIF weighting](https://openreview.net/pdf?id=SyK00v5xx). During inference, we simply take the mean of all token embeddings occurring in a sentence.

## Additional Resources

- [Model2Vec Repo](https://github.com/MinishLab/model2vec)
- [Model2Vec Base Models](https://huggingface.co/collections/minishlab/model2vec-base-models-66fd9dd9b7c3b3c0f25ca90e)
- [Model2Vec Results](https://github.com/MinishLab/model2vec/tree/main/results)
- [Model2Vec Tutorials](https://github.com/MinishLab/model2vec/tree/main/tutorials)
- [Website](https://minishlab.github.io/)

## Library Authors

Model2Vec was developed by the [Minish Lab](https://github.com/MinishLab) team consisting of [Stephan Tulkens](https://github.com/stephantul) and [Thomas van Dongen](https://github.com/Pringled).

## Citation

Please cite the [Model2Vec repository](https://github.com/MinishLab/model2vec) if you use this model in your work.
```
@article{minishlab2024model2vec,
  author = {Tulkens, Stephan and {van Dongen}, Thomas},
  title = {Model2Vec: Fast State-of-the-Art Static Embeddings},
  year = {2024},
  url = {https://github.com/MinishLab/model2vec}
}
```
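The SIF weighting mentioned in the model card's "How it works" section computes a per-token weight `w(t) = a / (a + p(t))`, where `p(t)` is the token's relative frequency and `a` is the `sif_coefficient` from this fixture's config.json (0.0001). A minimal sketch, assuming raw token counts as input:

```rust
// SIF weighting (Arora et al., "A Simple but Tough-to-Beat Baseline"):
// w(t) = a / (a + p(t)), with p(t) the token's relative frequency.
fn sif_weights(counts: &[u64], a: f64) -> Vec<f64> {
    let total: f64 = counts.iter().sum::<u64>() as f64;
    counts
        .iter()
        .map(|&c| a / (a + c as f64 / total))
        .collect()
}

fn main() {
    // Frequent tokens are strongly down-weighted; rarer tokens keep more weight
    let w = sif_weights(&[900_000, 90_000, 10_000], 1e-4);
    println!("{w:?}"); // approx [0.000111, 0.00111, 0.0099]
}
```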
tests/fixtures/test-model-vocab-quantized/config.json

Lines changed: 13 additions & 0 deletions

```json
{
  "model_type": "model2vec",
  "architectures": [
    "StaticModel"
  ],
  "tokenizer_name": "BAAI/bge-base-en-v1.5",
  "apply_pca": 64,
  "apply_zipf": null,
  "sif_coefficient": 0.0001,
  "hidden_dim": 64,
  "seq_length": 1000000,
  "normalize": true
}
```
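For readers wiring this fixture into Rust code, a hedged sketch of deserializing these fields with `serde`; the struct is hypothetical, not model2vec-rs's actual config type:

```rust
use serde::Deserialize;

// Hypothetical mirror of the fixture's config.json fields; the crate's
// real loader may differ. serde ignores unknown keys like "architectures".
#[derive(Debug, Deserialize)]
struct ModelConfig {
    model_type: String,
    tokenizer_name: String,
    apply_pca: Option<u32>,
    apply_zipf: Option<bool>,
    sif_coefficient: Option<f64>,
    hidden_dim: usize,
    seq_length: usize,
    normalize: bool,
}

fn main() -> Result<(), serde_json::Error> {
    let raw = r#"{
        "model_type": "model2vec",
        "architectures": ["StaticModel"],
        "tokenizer_name": "BAAI/bge-base-en-v1.5",
        "apply_pca": 64,
        "apply_zipf": null,
        "sif_coefficient": 0.0001,
        "hidden_dim": 64,
        "seq_length": 1000000,
        "normalize": true
    }"#;
    let cfg: ModelConfig = serde_json::from_str(raw)?;
    // `normalize: true` matches the normalize flag read by StaticModel
    println!("{cfg:?}");
    Ok(())
}
```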
tests/fixtures/test-model-vocab-quantized/model.safetensors

378 KB
Binary file not shown.
tests/fixtures/test-model-vocab-quantized/modules.json

Lines changed: 14 additions & 0 deletions

```json
[
  {
    "idx": 0,
    "name": "0",
    "path": ".",
    "type": "sentence_transformers.models.StaticEmbedding"
  },
  {
    "idx": 1,
    "name": "1",
    "path": "1_Normalize",
    "type": "sentence_transformers.models.Normalize"
  }
]
```

tests/fixtures/test-model-vocab-quantized/tokenizer.json

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.

tests/test_model.rs

Lines changed: 0 additions & 45 deletions
```diff
@@ -1,51 +1,6 @@
 mod common;
-use approx::assert_relative_eq;
 use common::load_test_model;
 use model2vec_rs::model::StaticModel;
-use std::fs;
-
-#[test]
-fn test_encode_matches_python_model2vec() {
-    // Load the test model
-    let model = load_test_model();
-
-    // Define the short and long text inputs
-    let long_text = vec!["hello"; 1000].join(" ");
-    let short_text = "hello world".to_string();
-    let cases = vec![
-        ("tests/fixtures/embeddings_short.json", vec![short_text]),
-        ("tests/fixtures/embeddings_long.json", vec![long_text]),
-    ];
-
-    for (fixture_path, inputs) in cases {
-        // Read and parse the Python-generated embedding fixture
-        let fixture =
-            fs::read_to_string(fixture_path).unwrap_or_else(|_| panic!("Fixture not found: {}", fixture_path));
-        let expected: Vec<Vec<f32>> = serde_json::from_str(&fixture).expect("Failed to parse fixture");
-
-        // Encode with the Rust model
-        let output = model.encode(&inputs);
-
-        // Sanity checks
-        assert_eq!(
-            output.len(),
-            expected.len(),
-            "number of sentences mismatch for {}",
-            fixture_path
-        );
-        assert_eq!(
-            output[0].len(),
-            expected[0].len(),
-            "vector dimensionality mismatch for {}",
-            fixture_path
-        );
-
-        // Element-wise comparison
-        for (o, e) in output[0].iter().zip(&expected[0]) {
-            assert_relative_eq!(o, e, max_relative = 1e-5);
-        }
-    }
-}
 
 /// Test that encoding an empty input slice yields an empty output
 #[test]
```
Lines changed: 2 additions & 11 deletions
```diff
@@ -1,15 +1,6 @@
+mod common;
 use approx::assert_relative_eq;
-use model2vec_rs::model::StaticModel;
-
-fn encode_with_model(path: &str) -> Vec<f32> {
-    // Helper function to load the model and encode "hello world"
-    let model = StaticModel::from_pretrained(path, None, None, None)
-        .unwrap_or_else(|e| panic!("Failed to load model at {path}: {e}"));
-
-    let out = model.encode(&["hello world".to_string()]);
-    assert_eq!(out.len(), 1);
-    out.into_iter().next().unwrap()
-}
+use common::encode_with_model;
 
 #[test]
 fn quantized_models_match_float32() {
```
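The diff cuts off at the start of `quantized_models_match_float32`. For the vocab-quantized case, a test in the style of the deleted `test_encode_matches_python_model2vec` might look like the sketch below; the fixture path `tests/fixtures/embeddings_vocab_quantized.json` and the tolerance are assumptions, not values from this commit:

```rust
mod common;

use approx::assert_relative_eq;
use common::encode_with_model;
use std::fs;

#[test]
fn vocab_quantized_matches_fixture_sketch() {
    // Encode "hello world" with the vocab-quantized fixture model
    let output = encode_with_model("tests/fixtures/test-model-vocab-quantized");

    // Compare against a Python-generated expected embedding (path assumed)
    let fixture = fs::read_to_string("tests/fixtures/embeddings_vocab_quantized.json")
        .expect("fixture not found");
    let expected: Vec<Vec<f32>> = serde_json::from_str(&fixture).expect("bad fixture");

    assert_eq!(output.len(), expected[0].len());
    for (o, e) in output.iter().zip(&expected[0]) {
        assert_relative_eq!(o, e, max_relative = 1e-5);
    }
}
```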
