Add MiniLM-{L6 & L12}-v2 sentence transformer #89
Conversation
Implements the all-MiniLM-L12-v2 model using Burn's built-in `TransformerEncoder` and burn-store for weight loading from safetensors.
- Load config from HuggingFace's `config.json` via serde
- Key remapping from HuggingFace BERT to Burn `TransformerEncoder`
- Mean pooling for sentence embeddings (sketched below)
- Example with HuggingFace download and cosine similarity
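For reference, attention-mask-weighted mean pooling looks roughly like the sketch below. This is illustrative only, written against a recent Burn tensor API (`unsqueeze_dim`, `repeat_dim`), not the crate's exact implementation.

```rust
use burn::prelude::*;

/// Illustrative sketch of attention-mask-weighted mean pooling.
/// `hidden` is [batch, seq_len, hidden]; `mask` is [batch, seq_len]
/// with 1.0 for real tokens and 0.0 for padding.
fn mean_pooling<B: Backend>(hidden: Tensor<B, 3>, mask: Tensor<B, 2>) -> Tensor<B, 2> {
    let [_batch, _seq_len, hidden_size] = hidden.dims();
    // Broadcast the mask across the hidden dimension: [batch, seq_len, hidden].
    let mask = mask.unsqueeze_dim::<3>(2).repeat_dim(2, hidden_size);
    // Sum embeddings of real tokens only, then divide by the token count
    // (clamped to avoid division by zero for all-padding rows).
    let summed = (hidden * mask.clone()).sum_dim(1).squeeze::<2>(1);
    let counts = mask.sum_dim(1).squeeze::<2>(1).clamp_min(1e-9);
    summed / counts
}
```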
Reformatted code in loader.rs, model.rs, and pooling.rs for improved readability and consistency. Adjusted import order and indentation, and expanded some array initializations for clarity in tests. No functional changes were made.
Results measured on an Apple M3 Max, comparing performance across all supported backends.
Pull request overview
Introduces a new minilm-burn crate implementing the all-MiniLM-L12-v2 sentence-transformer model on top of Burn, with support for multiple backends, pretrained weight loading from Hugging Face, and documentation/examples/benchmarks.
Changes:
- Add MiniLM-specific embedding, encoder, pooling, and normalization modules plus a `MiniLmModel` configuration and forward pass.
- Implement HF Hub-based weight and tokenizer loading, along with a `pretrained` convenience API, examples, and benchmarks across backends.
- Add integration tests against Python `sentence-transformers` outputs and update repo-level documentation/README to list the new model.
Reviewed changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| `minilm-burn/src/embedding.rs` | Defines `MiniLmEmbeddingsConfig` and `MiniLmEmbeddings` (word/position/token-type embeddings + layer norm + dropout) matching the MiniLM/BERT-style embedding stack. |
| `minilm-burn/src/model.rs` | Adds `MiniLmConfig`, `MiniLmModel`, and `MiniLmOutput`, wiring Burn's `TransformerEncoder` to MiniLM's config and attention mask semantics. |
| `minilm-burn/src/pooling.rs` | Implements `mean_pooling` and `normalize_l2` utilities plus a unit test for mean pooling on the ndarray backend. |
| `minilm-burn/src/loader.rs` | Introduces `LoadError`, HF safetensor key remapping and loading, HF Hub download utilities, config loading, and `MiniLmModel::pretrained`. |
| `minilm-burn/src/lib.rs` | Exposes the MiniLM public API and adds crate-level documentation and a usage example. |
| `minilm-burn/tests/integration_test.rs` | Adds ndarray-based integration tests that compare MiniLM Rust embeddings and cosine similarities against Python `sentence-transformers` references. |
| `minilm-burn/scripts/generate_reference.py` | Script to generate reference embeddings and cosine similarities from Python `sentence-transformers` for use in integration tests. |
| `minilm-burn/scripts/debug_embeddings.py` | Small helper script to inspect raw MiniLM embeddings and norms in Python for debugging. |
| `minilm-burn/examples/inference.rs` | Demonstrates end-to-end inference with the pretrained MiniLM model, tokenization, pooling, and cosine similarity computation on the ndarray backend. |
| `minilm-burn/benches/inference.rs` | Adds Criterion benchmarks for forward passes, batching, full pipeline, and pooling/normalization across multiple backends. |
| `minilm-burn/README.md` | Documents the new crate's usage, features, testing strategy, and benchmark results. |
| `minilm-burn/Cargo.toml` | Declares the new crate, its features (including multi-backend and pretrained support), and dependencies (Burn, burn-store, tokenizers, hf-hub, tokio, etc.). |
| `README.md` | Updates the root repository overview and tables to include the MiniLM model and its subcrate, and switches to reference-style links. |
- Fix doc example to use `MiniLmModel::pretrained` (not `MiniLmConfig`)
- Update `HfModelFiles` doc to reflect struct with 3 fields
- Fix generate_reference.py to use `normalize_embeddings=True`
- Use `dirs::cache_dir()` for platform-appropriate default location (sketched below)
- Allow custom cache path via `pretrained(device, Some(path))`
- Downloads to `~/.cache/burn-models/` (Linux) or `~/Library/Caches/burn-models/` (macOS)
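A minimal sketch of that default-location logic, assuming the `dirs` crate; the `burn-models` subdirectory name follows the commit message above.

```rust
use std::path::PathBuf;

/// Sketch of the default cache location. `dirs::cache_dir()` resolves to
/// ~/.cache on Linux and ~/Library/Caches on macOS; "burn-models" is the
/// subdirectory named in the commit message.
fn default_cache_dir() -> PathBuf {
    dirs::cache_dir()
        .unwrap_or_else(|| PathBuf::from("."))
        .join("burn-models")
}
```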
- Remove hardcoded hidden_size (384), derive from tensor dims - Add normalize_l2 to example (matches sentence-transformers default) - Remove debug_embeddings.py script
`PyTorchToBurnAdapter` handles `weight`→`gamma` and `bias`→`beta` automatically.
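The remaining remapping is purely on key names. As a hypothetical illustration of the idea (the real mapping table lives in `loader.rs`, and the Burn-side names below are assumptions):

```rust
/// Hypothetical sketch of HF-BERT → Burn TransformerEncoder key remapping;
/// the target names here are illustrative, not the crate's exact mapping.
fn remap_key(hf_key: &str) -> String {
    hf_key
        .replace("encoder.layer.", "encoder.layers.")
        .replace("attention.self.query", "mha.query")
        .replace("attention.self.key", "mha.key")
        .replace("attention.self.value", "mha.value")
        .replace("attention.output.dense", "mha.output")
}
```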
- Add `MiniLmVariant` enum (L6, L12) for model selection: L6 = 6 layers, faster inference; L12 = 12 layers, better quality (default)
- Update `pretrained()` to accept a variant parameter (see the sketch after the numbers below)
L6 is ~2x faster than L12 across all backends:
- ndarray: 53ms vs 105ms
- wgpu: 18ms vs 35ms
- tch-cpu: 14ms vs 27ms
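A sketch of what the variant API could look like, based on the bullets above; the HF repo ids are the public sentence-transformers model ids, everything else is an assumption:

```rust
/// Sketch of the variant selection described above.
#[derive(Clone, Copy, Debug, Default)]
pub enum MiniLmVariant {
    /// 6 encoder layers: faster inference.
    L6,
    /// 12 encoder layers: better quality (default).
    #[default]
    L12,
}

impl MiniLmVariant {
    pub fn num_layers(self) -> usize {
        match self {
            Self::L6 => 6,
            Self::L12 => 12,
        }
    }

    /// Public Hugging Face model ids for each variant.
    pub fn repo_id(self) -> &'static str {
        match self {
            Self::L6 => "sentence-transformers/all-MiniLM-L6-v2",
            Self::L12 => "sentence-transformers/all-MiniLM-L12-v2",
        }
    }
}
```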
Replaces the use of `equal_elem(0)` with a comparison against a zeros tensor when creating the padding mask. This ensures compatibility with tensor operations and device placement.
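A minimal sketch of that mask construction, assuming the pad token id is 0:

```rust
use burn::prelude::*;

/// Sketch of the padding-mask construction described above. Building the
/// zeros tensor with the input's shape and device keeps both operands of
/// the comparison on the same device.
fn padding_mask<B: Backend>(token_ids: Tensor<B, 2, Int>) -> Tensor<B, 2, Bool> {
    let zeros = Tensor::zeros(token_ids.shape(), &token_ids.device());
    token_ids.equal(zeros) // true where the position is padding
}
```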
Refactored lines where MiniLmModel is loaded to improve code readability by reducing line length and aligning with Rust formatting conventions. No functional changes were made.
Looks like the CI is failing on burn-import downstream dep 😅
/edit: whoops, thought this was the same as the one in the burn-onnx model checks just for comparison, but it's actually L12 not L6. Removed part of my comment.
minilm-burn/README.md (Outdated)

```
cargo bench --features tch-cpu
```

Results are saved to `target/criterion/` for comparison across backends.
I wasn't aware of criterion, it seems to generate pretty detailed reports with a bunch of tables and graphs!
That's pretty cool. Is it actually useful to have the HTML report in this case, though? If so, there should be instructions for visualizing the reports.
Good call! Added instructions to view the Criterion HTML report (open `target/criterion/report/index.html`). It generates detailed comparison tables and graphs that are useful for tracking performance across backends and between runs.
Marks 'hf-hub' and 'tokio' as optional dependencies and includes them in the 'pretrained' feature. This allows for more flexible builds and reduces unnecessary dependencies when the 'pretrained' feature is not enabled.
Eliminated the optional tokio dependency from Cargo.toml and updated the loader to use the synchronous Hugging Face API instead of the async version. This simplifies the codebase by removing async requirements for model downloading.
Updated MiniLmConfig::load_from_hf to return LoadError instead of std::io::Error, ensuring consistent error handling throughout the loader and model modules.
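A sketch of what that unified error handling could look like, assuming `thiserror` and serde; the variant and field names are illustrative, not the crate's exact definitions:

```rust
use std::path::Path;

/// Illustrative subset of HuggingFace's BERT-style config.json fields.
#[derive(serde::Deserialize)]
pub struct MiniLmConfig {
    pub hidden_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub intermediate_size: usize,
    pub vocab_size: usize,
    pub max_position_embeddings: usize,
}

/// Sketch of a unified error type for the loader.
#[derive(Debug, thiserror::Error)]
pub enum LoadError {
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),
    #[error("invalid config: {0}")]
    Config(#[from] serde_json::Error),
}

impl MiniLmConfig {
    /// Both error sources convert into `LoadError` via `?` and `#[from]`.
    pub fn load_from_hf(path: &Path) -> Result<Self, LoadError> {
        let json = std::fs::read_to_string(path)?;
        Ok(serde_json::from_str(&json)?)
    }
}
```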
- Add `tokenize_batch()` to deduplicate tokenization logic across example, tests, and benchmarks
- Make `MiniLmEmbeddingsConfig` `pub(crate)` since it's internal-only
- Add `compile_error!` when multiple backend features are enabled (sketched below)
- Simplify benchmark backend cfg guards
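The mutual-exclusion guard mentioned above could look like this; the feature names are assumed to match the crate's Cargo.toml:

```rust
// Sketch of the exclusive-backend guard: fail the build early when
// conflicting backend features are enabled together.
#[cfg(all(feature = "ndarray", feature = "wgpu"))]
compile_error!("features `ndarray` and `wgpu` cannot be enabled together; pick one backend");

#[cfg(all(feature = "ndarray", feature = "tch-cpu"))]
compile_error!("features `ndarray` and `tch-cpu` cannot be enabled together; pick one backend");

#[cfg(all(feature = "wgpu", feature = "tch-cpu"))]
compile_error!("features `wgpu` and `tch-cpu` cannot be enabled together; pick one backend");
```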
@laggui ready
Adds `minilm-burn` crate implementing the all-MiniLM-L12-v2 and all-MiniLM-L6-v2 sentence transformer models.

Features
- Pretrained weight loading via `MiniLmModel::pretrained(&device)`
- Config loaded from HuggingFace's `config.json` via serde

Usage
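A hedged usage sketch assembled from the API described in this PR; the exact `pretrained` signature, return type, and helper names are assumptions and may differ from the merged crate:

```rust
use minilm_burn::{mean_pooling, normalize_l2, MiniLmModel, MiniLmVariant};

type B = burn::backend::NdArray<f32>;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let device = Default::default();
    // Downloads weights from HF Hub on first use (`pretrained` feature);
    // `None` keeps the default cache directory under burn-models/.
    let model = MiniLmModel::<B>::pretrained(&device, MiniLmVariant::L12, None)?;

    // Tokenize, run the encoder, then pool token embeddings into sentence
    // embeddings and L2-normalize them (names follow the crate modules):
    // let output = model.forward(input_ids, attention_mask.clone());
    // let embeddings = normalize_l2(mean_pooling(output.hidden_states, attention_mask));
    Ok(())
}
```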
Benchmarks (Apple M3 Max)

| Backend | L6 | L12 |
|---|---|---|
| ndarray | 53ms | 105ms |
| wgpu | 18ms | 35ms |
| tch-cpu | 14ms | 27ms |

Testing

```
cargo test --features ndarray
```