Add MiniLM-{L6 & L12}-v2 sentence transformer #89
Merged
29 commits by antimora:
- `2a476e4` Add MiniLM-L12-v2 sentence transformer implementation
- `7e16e35` Add README for minilm-burn
- `d230517` Simplify API: pretrained() returns model and tokenizer path
- `45bef2c` Refactor formatting and imports for consistency
- `b92f8d9` Add minilm-burn to models list
- `f6bd618` Add integration tests comparing with Python sentence-transformers
- `ecdb08b` Fix fmt
- `1927929` Add benchmarks for inference performance
- `a635c87` Add benchmark results for ndarray, wgpu, and tch-cpu backends
- `a7faad7` Fix documentation issues from PR review
- `242c096` Add configurable cache directory for model downloads
- `5d4adbf` Clean up code review issues
- `3bd5698` Simplify key remappings for LayerNorm
- `8e7b9e1` Add support for all-MiniLM-L6-v2 variant
- `ce9d4bb` Add integration test for L6 variant
- `b6190cf` Add L6 vs L12 variant benchmarks
- `b36ed41` Fix attention mask padding conversion in MiniLmModel
- `a9b0bb1` Reformat model loading for improved readability
- `65a47ec` Merge remote-tracking branch 'upstream/main' into minilm-burn
- `8ae8224` Update README.md
- `6e7a171` Clarify backend features are for benchmarks only
- `d6fd570` Add cuda backend to benchmark examples
- `36c368a` Add instructions to view criterion HTML report
- `03446b4` Make hf-hub and tokio optional dependencies
- `fc5f579` Remove tokio dependency and switch to sync HF API
- `6f0cf18` Refactor config loading to use LoadError for errors
- `e3f3c02` Update embedding.rs
- `5ad70c6` Extract shared tokenize_batch helper and tighten public API
- `6c888cb` Replace libm::sqrt with std f64::sqrt and remove libm dependency
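One of the simpler cleanups above, commit `6c888cb`, drops the `libm` crate in favor of the standard library's `f64::sqrt`. As a hedged illustration (the `l2_norm` helper below is hypothetical, not code from this PR), no external crate is needed for square roots once a crate already requires `std`:

```rust
// Hypothetical helper illustrating the libm -> std swap: `sqrt` is an
// inherent method on f64, so `libm::sqrt(x)` (and the libm dependency)
// is unnecessary in std builds.
fn l2_norm(v: &[f64]) -> f64 {
    v.iter().map(|x| x * x).sum::<f64>().sqrt()
}

fn main() {
    // 3-4-5 triangle: the norm of [3, 4] is 5
    println!("{}", l2_norm(&[3.0, 4.0]));
}
```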
antimora File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| [package] | ||
| authors = ["Dilshod Tadjibaev (@antimora)"] | ||
| license = "MIT OR Apache-2.0" | ||
| name = "minilm-burn" | ||
| version = "0.1.0" | ||
| edition = "2024" | ||
| description = "MiniLM sentence transformer with Burn" | ||
|
|
||
| [features] | ||
| default = ["pretrained"] | ||
| pretrained = ["burn/network", "dep:dirs", "dep:hf-hub"] | ||
|
|
||
| # Backend selection | ||
| ndarray = ["burn/ndarray"] | ||
| tch-cpu = ["burn/tch"] | ||
| tch-gpu = ["burn/tch"] | ||
| wgpu = ["burn/wgpu"] | ||
| cuda = ["burn/cuda"] | ||
|
|
||
| [dependencies] | ||
| burn = { version = "0.20.0", default-features = false, features = ["std"] } | ||
| burn-store = { version = "0.20.0", features = ["std", "safetensors"] } | ||
|
|
||
| # Tokenizer | ||
| tokenizers = { version = "0.19.1", default-features = false, features = ["onig"] } | ||
|
|
||
| # HuggingFace model download | ||
| hf-hub = { version = "0.4.3", optional = true } | ||
| dirs = { version = "6.0.0", optional = true } | ||
|
|
||
| # Serialization | ||
| serde = { version = "1.0", default-features = false, features = ["derive", "alloc"] } | ||
| serde_json = "1.0" | ||
|
|
||
|
|
||
|
|
||
| [dev-dependencies] | ||
| clap = { version = "4.5", features = ["derive"] } | ||
| criterion = "0.5" | ||
|
|
||
| [[example]] | ||
| name = "inference" | ||
| required-features = ["pretrained", "ndarray"] | ||
|
|
||
| [[bench]] | ||
| name = "inference" | ||
| harness = false | ||
| required-features = ["pretrained"] |
README.md:

````markdown
# MiniLM-Burn

MiniLM sentence transformer implementation in Rust using [Burn](https://github.com/tracel-ai/burn).

Supports two model variants from HuggingFace:

- [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) - 6 layers, faster
- [all-MiniLM-L12-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2) - 12 layers, better quality (default)

## Usage

```rust
use burn::backend::ndarray::NdArray;
use minilm_burn::{mean_pooling, MiniLmModel};

type B = NdArray<f32>;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let device = Default::default();

    // Load pretrained model and tokenizer (downloads from HuggingFace).
    // Use MiniLmVariant::L6 for faster inference, L12 for better quality.
    let (model, tokenizer) = MiniLmModel::<B>::pretrained(&device, Default::default(), None)?;

    // Tokenize the input text with `tokenizer` to obtain the `input_ids`
    // and `attention_mask` tensors (elided here), then run inference.
    let output = model.forward(input_ids, attention_mask.clone(), None);
    let embeddings = mean_pooling(output.hidden_states, attention_mask);

    Ok(())
}
```

## Features

- `pretrained` - Enables model download utilities (default)
- `ndarray` - NdArray backend (required for the inference example and tests)

Backend features for benchmarks:

- `wgpu` - WebGPU backend
- `cuda` - CUDA backend
- `tch-cpu` - LibTorch CPU backend
- `tch-gpu` - LibTorch GPU backend

## Example

Run the inference example:

```bash
cargo run --example inference --features ndarray --release
```

## Testing

Unit tests:

```bash
cargo test --features ndarray
```

Integration tests (requires model download):

```bash
cargo test --features ndarray -- --ignored
```

## Benchmarks

Run for each backend:

```bash
cargo bench --features ndarray
cargo bench --features wgpu
cargo bench --features cuda
cargo bench --features tch-cpu
```

Results are saved to `target/criterion/` for comparison across backends. View the HTML report:

```bash
open target/criterion/report/index.html
```

### Results (Apple M3 Max)

**L6 vs L12 (single sentence):**

| Variant | ndarray | wgpu  | tch-cpu |
| ------- | ------- | ----- | ------- |
| L6      | 53 ms   | 18 ms | 14 ms   |
| L12     | 105 ms  | 35 ms | 27 ms   |

**L12 batch scaling:**

| Batch size | ndarray | wgpu  | tch-cpu |
| ---------- | ------- | ----- | ------- |
| 1          | 102 ms  | 35 ms | 26 ms   |
| 4          | 387 ms  | 39 ms | 49 ms   |
| 8          | 774 ms  | 44 ms | 77 ms   |
| 16         | 1.54 s  | 73 ms | 130 ms  |

## License

MIT OR Apache-2.0
````
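The README's usage example ends with `mean_pooling`, the standard sentence-transformers step that collapses per-token hidden states into one fixed-size sentence embedding. A backend-free sketch of the idea, using plain slices rather than Burn tensors (an illustration of the technique, not the crate's actual implementation):

```rust
// Mean pooling: average the token embeddings, counting only positions
// where the attention mask is 1 (real tokens), so padding does not
// dilute the sentence embedding.
// `hidden` is [seq_len][dim]; `mask` is [seq_len] of 0/1.
fn mean_pooling(hidden: &[Vec<f32>], mask: &[u32]) -> Vec<f32> {
    let dim = hidden.first().map_or(0, |row| row.len());
    let mut sum = vec![0.0f32; dim];
    let mut count = 0.0f32;
    for (token, &m) in hidden.iter().zip(mask) {
        if m == 1 {
            for (acc, &x) in sum.iter_mut().zip(token) {
                *acc += x;
            }
            count += 1.0;
        }
    }
    let denom = count.max(1e-9); // guard against an all-padding input
    sum.into_iter().map(|s| s / denom).collect()
}

fn main() {
    // Two real tokens plus one padding token that must be ignored
    let hidden = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![9.0, 9.0]];
    let mask = vec![1, 1, 0];
    println!("{:?}", mean_pooling(&hidden, &mask)); // averages the first two rows
}
```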