Add MiniLM-{L6 & L12}-v2 sentence transformer#89

Merged
laggui merged 29 commits into tracel-ai:main from antimora:minilm-burn
Jan 29, 2026
Conversation

@antimora
Collaborator

@antimora antimora commented Jan 23, 2026

Adds a minilm-burn crate implementing the all-MiniLM-L6-v2 and all-MiniLM-L12-v2 sentence transformer models.

Features

  • Load pretrained weights from HuggingFace with simple API: MiniLmModel::pretrained(&device)
  • Mean pooling and L2 normalization for sentence embeddings
  • Multi-backend support: ndarray, wgpu, tch-cpu, tch-gpu, cuda
  • Config loaded from HuggingFace's config.json via serde

Usage

```rust
let (model, tokenizer) = MiniLmModel::<B>::pretrained(&device)?;
let output = model.forward(input_ids, attention_mask.clone(), None);
let embeddings = mean_pooling(output.hidden_states, attention_mask);
let embeddings = normalize_l2(embeddings);
```
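Conceptually, mean pooling averages the token embeddings at positions where the attention mask is 1, and L2 normalization scales each resulting vector to unit length. A minimal pure-Rust sketch on plain slices (illustration only; the crate's actual `mean_pooling`/`normalize_l2` operate on Burn tensors):

```rust
// Sketch only: the crate works on Burn tensors, not Vec<f32>.
// Average token embeddings, counting only positions where mask == 1.
fn mean_pooling(hidden: &[Vec<f32>], mask: &[u8]) -> Vec<f32> {
    let dim = hidden[0].len();
    let mut sum = vec![0.0f32; dim];
    let mut count = 0.0f32;
    for (token, &m) in hidden.iter().zip(mask) {
        if m == 1 {
            for (s, v) in sum.iter_mut().zip(token) {
                *s += v;
            }
            count += 1.0;
        }
    }
    sum.iter().map(|s| s / count.max(1e-9)).collect()
}

// Scale a vector to unit L2 norm.
fn normalize_l2(v: &[f32]) -> Vec<f32> {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
    v.iter().map(|x| x / norm).collect()
}

fn main() {
    let hidden = vec![vec![1.0, 0.0], vec![3.0, 4.0], vec![9.0, 9.0]];
    let mask = [1, 1, 0]; // third token is padding and is ignored
    let pooled = mean_pooling(&hidden, &mask); // [2.0, 2.0]
    let emb = normalize_l2(&pooled);
    println!("{:?}", emb);
}
```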

Benchmarks (Apple M3 Max)

| Benchmark | ndarray | wgpu | tch-cpu |
| --- | --- | --- | --- |
| forward (batch=1) | 102 ms | 35 ms | 26 ms |
| forward (batch=16) | 1.54 s | 73 ms | 130 ms |

Testing

  • Unit tests: cargo test --features ndarray
  • Integration tests verify outputs match Python sentence-transformers within 1e-4 tolerance
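The kind of tolerance check the integration tests perform can be sketched as an element-wise comparison against a Python reference vector (values below are made up for illustration; the real tests load reference embeddings generated by `generate_reference.py`):

```rust
// Sketch: verify a Rust-produced embedding matches a Python reference
// within 1e-4 absolute tolerance. The numbers are illustrative only.
fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(x, y)| (x - y).abs())
        .fold(0.0f32, f32::max)
}

fn main() {
    let rust_emb = [0.12345f32, -0.67890, 0.42420];
    let python_ref = [0.12349f32, -0.67895, 0.42415];
    let diff = max_abs_diff(&rust_emb, &python_ref);
    assert!(diff < 1e-4, "embeddings diverge: max abs diff = {diff}");
    println!("max abs diff = {diff}");
}
```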

Implements the all-MiniLM-L12-v2 model using Burn's built-in
TransformerEncoder and burn-store for weight loading from safetensors.

- Load config from HuggingFace's config.json via serde
- Key remapping from HuggingFace BERT to Burn TransformerEncoder
- Mean pooling for sentence embeddings
- Example with HuggingFace download and cosine similarity
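The key-remapping step amounts to rewriting Hugging Face BERT parameter names into Burn's TransformerEncoder naming. A sketch of the idea, with illustrative key names that are not necessarily the crate's exact mapping:

```rust
// Illustrative only: the actual source/target key names used by the
// crate's loader may differ from the ones shown here.
fn remap_key(hf_key: &str) -> String {
    hf_key
        .replace("encoder.layer.", "encoder.layers.")
        .replace("attention.self.query", "mha.query")
        .replace("attention.self.key", "mha.key")
        .replace("attention.self.value", "mha.value")
        .replace("LayerNorm.weight", "norm.gamma")
        .replace("LayerNorm.bias", "norm.beta")
}

fn main() {
    let out = remap_key("encoder.layer.0.attention.self.query.weight");
    println!("{out}"); // encoder.layers.0.mha.query.weight
}
```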
Reformatted code in loader.rs, model.rs, and pooling.rs for improved readability and consistency. Adjusted import order and indentation, and expanded some array initializations for clarity in tests. No functional changes were made.
Results measured on Apple M3 Max showing performance comparison
across all supported backends.

Copilot AI left a comment


Pull request overview

Introduces a new minilm-burn crate implementing the all-MiniLM-L12-v2 sentence-transformer model on top of Burn, with support for multiple backends, pretrained weight loading from Hugging Face, and documentation/examples/benchmarks.

Changes:

  • Add MiniLM-specific embedding, encoder, pooling, and normalization modules plus a MiniLmModel configuration and forward pass.
  • Implement HF Hub-based weight and tokenizer loading, along with a pretrained convenience API, examples, and benchmarks across backends.
  • Add integration tests against Python sentence-transformers outputs and update repo-level documentation/README to list the new model.

Reviewed changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 5 comments.

| File | Description |
| --- | --- |
| minilm-burn/src/embedding.rs | Defines MiniLmEmbeddingsConfig and MiniLmEmbeddings (word/position/token-type embeddings + layer norm + dropout) matching the MiniLM/BERT-style embedding stack. |
| minilm-burn/src/model.rs | Adds MiniLmConfig, MiniLmModel, and MiniLmOutput, wiring Burn's TransformerEncoder to MiniLM's config and attention mask semantics. |
| minilm-burn/src/pooling.rs | Implements mean_pooling and normalize_l2 utilities plus a unit test for mean pooling on the ndarray backend. |
| minilm-burn/src/loader.rs | Introduces LoadError, HF safetensor key remapping and loading, HF Hub download utilities, config loading, and MiniLmModel::pretrained. |
| minilm-burn/src/lib.rs | Exposes the MiniLM public API and adds crate-level documentation and a usage example. |
| minilm-burn/tests/integration_test.rs | Adds ndarray-based integration tests that compare MiniLM Rust embeddings and cosine similarities against Python sentence-transformers references. |
| minilm-burn/scripts/generate_reference.py | Script to generate reference embeddings and cosine similarities from Python sentence-transformers for use in integration tests. |
| minilm-burn/scripts/debug_embeddings.py | Small helper script to inspect raw MiniLM embeddings and norms in Python for debugging. |
| minilm-burn/examples/inference.rs | Demonstrates end-to-end inference with the pretrained MiniLM model, tokenization, pooling, and cosine similarity computation on the ndarray backend. |
| minilm-burn/benches/inference.rs | Adds Criterion benchmarks for forward passes, batching, full pipeline, and pooling/normalization across multiple backends. |
| minilm-burn/README.md | Documents the new crate's usage, features, testing strategy, and benchmark results. |
| minilm-burn/Cargo.toml | Declares the new crate, its features (including multi-backend and pretrained support), and dependencies (Burn, burn-store, tokenizers, hf-hub, tokio, etc.). |
| README.md | Updates the root repository overview and tables to include the MiniLM model and its subcrate, and switches to reference-style links. |


- Fix doc example to use MiniLmModel::pretrained (not MiniLmConfig)
- Update HfModelFiles doc to reflect struct with 3 fields
- Fix generate_reference.py to use normalize_embeddings=True
- Use dirs::cache_dir() for platform-appropriate default location
- Allow custom cache path via pretrained(device, Some(path))
- Downloads to ~/.cache/burn-models/ (Linux) or ~/Library/Caches/burn-models/ (macOS)
- Remove hardcoded hidden_size (384), derive from tensor dims
- Add normalize_l2 to example (matches sentence-transformers default)
- Remove debug_embeddings.py script
PyTorchToBurnAdapter handles weight→gamma and bias→beta automatically.
- Add MiniLmVariant enum (L6, L12) for model selection
- L6: 6 layers, faster inference
- L12: 12 layers, better quality (default)
- Update pretrained() to accept variant parameter
L6 is ~2x faster than L12 across all backends:
- ndarray: 53ms vs 105ms
- wgpu: 18ms vs 35ms
- tch-cpu: 14ms vs 27ms
Replaces use of `equal_elem(0)` with comparison to a zeros tensor for creating the padding mask. This ensures compatibility with tensor operations and device placement.
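What that padding mask computes can be sketched in plain Rust over token-id slices (the crate builds it with Burn tensor ops; treating token id 0 as the pad token is an assumption here):

```rust
// Sketch: mark padding positions (token id == 0) as masked.
// Pad id 0 is an assumption for illustration; the crate uses tensor ops.
fn padding_mask(token_ids: &[i64]) -> Vec<bool> {
    token_ids.iter().map(|&id| id == 0).collect()
}

fn main() {
    let ids = [101i64, 2023, 7592, 102, 0, 0];
    let mask = padding_mask(&ids);
    println!("{mask:?}"); // [false, false, false, false, true, true]
}
```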
Refactored lines where MiniLmModel is loaded to improve code readability by reducing line length and aligning with Rust formatting conventions. No functional changes were made.
Member

@laggui laggui left a comment


Looks like the CI is failing on burn-import downstream dep 😅

/edit: whoops, thought this was the same as the one in the burn-onnx model checks just for comparison, but it's actually L12 not L6. Removed part of my comment.

```
cargo bench --features tch-cpu
```

Results are saved to `target/criterion/` for comparison across backends.
Member


I wasn't aware of criterion, it seems to generate pretty detailed reports with a bunch of tables and graphs!

That's pretty cool. Is the HTML report actually useful in this case, though? If so, there should be instructions for visualizing the reports.

Collaborator Author


Good call! Added instructions to view the criterion HTML report (open target/criterion/report/index.html). It generates detailed comparison tables and graphs that are useful for tracking performance across backends and between runs.

@antimora antimora requested a review from laggui January 29, 2026 16:01
Marks 'hf-hub' and 'tokio' as optional dependencies and includes them in the 'pretrained' feature. This allows for more flexible builds and reduces unnecessary dependencies when the 'pretrained' feature is not enabled.
Eliminated the optional tokio dependency from Cargo.toml and updated the loader to use the synchronous Hugging Face API instead of the async version. This simplifies the codebase by removing async requirements for model downloading.
Updated MiniLmConfig::load_from_hf to return LoadError instead of std::io::Error, ensuring consistent error handling throughout the loader and model modules.
- Add tokenize_batch() to deduplicate tokenization logic across
  example, tests, and benchmarks
- Make MiniLmEmbeddingsConfig pub(crate) since it's internal-only
- Add compile_error! when multiple backend features are enabled
- Simplify benchmark backend cfg guards
@antimora
Collaborator Author

@laggui ready

@antimora antimora changed the title Add MiniLM-L12-v2 sentence transformer Add MiniLM-{L6 & L12}-v2 sentence transformer Jan 29, 2026
@laggui laggui merged commit ac6ff32 into tracel-ai:main Jan 29, 2026
2 checks passed
