Merged
29 commits
2a476e4
Add MiniLM-L12-v2 sentence transformer implementation
antimora Jan 23, 2026
7e16e35
Add README for minilm-burn
antimora Jan 23, 2026
d230517
Simplify API: pretrained() returns model and tokenizer path
antimora Jan 23, 2026
45bef2c
Refactor formatting and imports for consistency
antimora Jan 23, 2026
b92f8d9
Add minilm-burn to models list
antimora Jan 23, 2026
f6bd618
Add integration tests comparing with Python sentence-transformers
antimora Jan 23, 2026
ecdb08b
Fix fmt
antimora Jan 23, 2026
1927929
Add benchmarks for inference performance
antimora Jan 23, 2026
a635c87
Add benchmark results for ndarray, wgpu, and tch-cpu backends
antimora Jan 23, 2026
a7faad7
Fix documentation issues from PR review
antimora Jan 23, 2026
242c096
Add configurable cache directory for model downloads
antimora Jan 23, 2026
5d4adbf
Clean up code review issues
antimora Jan 23, 2026
3bd5698
Simplify key remappings for LayerNorm
antimora Jan 23, 2026
8e7b9e1
Add support for all-MiniLM-L6-v2 variant
antimora Jan 24, 2026
ce9d4bb
Add integration test for L6 variant
antimora Jan 24, 2026
b6190cf
Add L6 vs L12 variant benchmarks
antimora Jan 24, 2026
b36ed41
Fix attention mask padding conversion in MiniLmModel
antimora Jan 25, 2026
a9b0bb1
Reformat model loading for improved readability
antimora Jan 25, 2026
65a47ec
Merge remote-tracking branch 'upstream/main' into minilm-burn
antimora Jan 29, 2026
8ae8224
Update README.md
antimora Jan 29, 2026
6e7a171
Clarify backend features are for benchmarks only
antimora Jan 29, 2026
d6fd570
Add cuda backend to benchmark examples
antimora Jan 29, 2026
36c368a
Add instructions to view criterion HTML report
antimora Jan 29, 2026
03446b4
Make hf-hub and tokio optional dependencies
antimora Jan 29, 2026
fc5f579
Remove tokio dependency and switch to sync HF API
antimora Jan 29, 2026
6f0cf18
Refactor config loading to use LoadError for errors
antimora Jan 29, 2026
e3f3c02
Update embedding.rs
antimora Jan 29, 2026
5ad70c6
Extract shared tokenize_batch helper and tighten public API
antimora Jan 29, 2026
6c888cb
Replace libm::sqrt with std f64::sqrt and remove libm dependency
antimora Jan 29, 2026
80 changes: 59 additions & 21 deletions README.md
@@ -1,35 +1,36 @@
# 🔥 Models 🔥

Welcome to the Models repository! Here, you'll find a diverse collection of deep learning models and
examples constructed using the [Burn](https://github.com/burn-rs/burn) deep learning framework.
examples constructed using the [Burn] deep learning framework.

## Collection of Official Models

| Model | Description | Repository Link |
|-------------------------------------------------|----------------------------------------------------------|---------------------------------------|
| [Llama](https://github.com/meta-llama/llama3) | Llama 3 and TinyLlama large language models. | [llama-burn](llama-burn/) |
| [MobileNetV2](https://arxiv.org/abs/1801.04381) | A CNN model targeted at mobile devices. | [mobilenetv2-burn](mobilenetv2-burn/) |
| [SqueezeNet](https://arxiv.org/abs/1602.07360) | A small CNN-based model for image classification. | [squeezenet-burn](squeezenet-burn/) |
| [ResNet](https://arxiv.org/abs/1512.03385) | A CNN based on residual blocks with skip connections. | [resnet-burn](resnet-burn/) |
| [RoBERTa](https://arxiv.org/abs/1907.11692) | A robustly optimized BERT pretraining approach. | [bert-burn](bert-burn/) |
| [YOLOX](https://arxiv.org/abs/2107.08430) | A single-stage object detector based on the YOLO series. | [yolox-burn](yolox-burn/) |
| Model | Description | Repository |
| ------------- | ----------------------------- | ------------------------------------- |
| [Llama] | Large language models | [llama-burn](llama-burn/) |
| [MiniLM] | Sentence embeddings | [minilm-burn](minilm-burn/) |
| [MobileNetV2] | Mobile image classification | [mobilenetv2-burn](mobilenetv2-burn/) |
| [SqueezeNet] | Compact image classification | [squeezenet-burn](squeezenet-burn/) |
| [ResNet] | Residual image classification | [resnet-burn](resnet-burn/) |
| [RoBERTa] | Text encoder | [bert-burn](bert-burn/) |
| [YOLOX] | Object detection | [yolox-burn](yolox-burn/) |

## Community Contributions

Explore the curated list of models developed by the community ♥.

| Model | Description | Repository Link |
|--------------------------------------------------|-------------------------------------------------------------------|-----------------------------------------------------------------------------------|
| [Llama 2](https://arxiv.org/abs/2307.09288) | LLMs by Meta AI, ranging from 7 billion to 70 billion parameters. | [Gadersd/llama2-burn](https://github.com/Gadersd/llama2-burn) |
| [Whisper](https://arxiv.org/abs/2212.04356) | A general-purpose speech recognition model by OpenAI. | [Gadersd/whisper-burn](https://github.com/Gadersd/whisper-burn) |
| Stable Diffusion v1.4 | An image generation model developed by Stability AI. | [Gadersd/stable-diffusion-burn](https://github.com/Gadersd/stable-diffusion-burn) |
| kord (music note predictor) | A music theory model that can detect notes in short audio clips. | [twitchax/kord](https://github.com/twitchax/kord) |
| Whisper-Live | A fork of [Gadersd/whisper-burn](https://github.com/Gadersd/whisper-burn) which has been updated for Burn 13 and provides live transcription | [sudomonikers/whisper-burn](https://github.com/sudomonikers/whisper-burn) |
| [Inception V3](https://arxiv.org/abs/1512.00567) | A CNN used for calculating FID scores. | [varonroy/inception-v3-burn](https://github.com/varonroy/inception-v3-burn/) |
| [CRAFT](https://arxiv.org/abs/1904.01941) | A CNN for character-region aware text detection | [wingertge/craft-burn](https://github.com/wingertge/craft-burn) |
| [RWKV v7](https://arxiv.org/abs/2503.14456) | A large language model architecture that can be used like transformer models (parallel processing of tokens) and like RNNs (sequential generation). | [dymat/rwkv-burn](https://github.com/dymat/rwkv-burn) |
| [Single Shot Detector](https://arxiv.org/abs/1512.02325) | A trainable implementation of the Single Shot MultiBox Detector (SSD) multiclass object detection model. | [catch-twenty-two/rust-ssd](https://github.com/catch-twenty-two/rust-ssd) |
| [DeepSeek-OCR-2](https://huggingface.co/deepseek-ai/DeepSeek-OCR-2) | Unofficial pure-Rust inference (Burn) for DeepSeek-OCR-2 OCR. | [huahuadeliaoliao/DeepSeek-OCR-2-burn](https://github.com/huahuadeliaoliao/DeepSeek-OCR-2-burn) |
| Model | Description | Repository |
| --------------------- | -------------------------- | -------------------------------------- |
| [Llama 2] | Large language models | [Gadersd/llama2-burn] |
| [Whisper] | Speech recognition | [Gadersd/whisper-burn] |
| Stable Diffusion v1.4 | Image generation | [Gadersd/stable-diffusion-burn] |
| kord | Music note detection | [twitchax/kord] |
| Whisper-Live | Live speech transcription | [sudomonikers/whisper-burn] |
| [Inception V3] | Image classification | [varonroy/inception-v3-burn] |
| [CRAFT] | Text detection | [wingertge/craft-burn] |
| [RWKV v7] | Hybrid transformer/RNN LLM | [dymat/rwkv-burn] |
| [SSD] | Object detection | [catch-twenty-two/rust-ssd] |
| [DeepSeek-OCR-2] | OCR inference | [huahuadeliaoliao/DeepSeek-OCR-2-burn] |

## License Information

@@ -43,3 +44,40 @@ license information in the NOTICES.md file under the corresponding model directory.

Community models linked in this repository may fall under different licenses, so please consult the
respective repositories for specific license information.

<!-- Reference Links -->

[Burn]: https://github.com/burn-rs/burn

<!-- Official Models -->

[Llama]: https://github.com/meta-llama/llama3
[MiniLM]: https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2
[MobileNetV2]: https://arxiv.org/abs/1801.04381
[SqueezeNet]: https://arxiv.org/abs/1602.07360
[ResNet]: https://arxiv.org/abs/1512.03385
[RoBERTa]: https://arxiv.org/abs/1907.11692
[YOLOX]: https://arxiv.org/abs/2107.08430

<!-- Community Models -->

[Llama 2]: https://arxiv.org/abs/2307.09288
[Whisper]: https://arxiv.org/abs/2212.04356
[Inception V3]: https://arxiv.org/abs/1512.00567
[CRAFT]: https://arxiv.org/abs/1904.01941
[RWKV v7]: https://arxiv.org/abs/2503.14456
[SSD]: https://arxiv.org/abs/1512.02325
[DeepSeek-OCR-2]: https://huggingface.co/deepseek-ai/DeepSeek-OCR-2

<!-- Community Repositories -->

[Gadersd/llama2-burn]: https://github.com/Gadersd/llama2-burn
[Gadersd/whisper-burn]: https://github.com/Gadersd/whisper-burn
[Gadersd/stable-diffusion-burn]: https://github.com/Gadersd/stable-diffusion-burn
[twitchax/kord]: https://github.com/twitchax/kord
[sudomonikers/whisper-burn]: https://github.com/sudomonikers/whisper-burn
[varonroy/inception-v3-burn]: https://github.com/varonroy/inception-v3-burn/
[wingertge/craft-burn]: https://github.com/wingertge/craft-burn
[dymat/rwkv-burn]: https://github.com/dymat/rwkv-burn
[catch-twenty-two/rust-ssd]: https://github.com/catch-twenty-two/rust-ssd
[huahuadeliaoliao/DeepSeek-OCR-2-burn]: https://github.com/huahuadeliaoliao/DeepSeek-OCR-2-burn
48 changes: 48 additions & 0 deletions minilm-burn/Cargo.toml
@@ -0,0 +1,48 @@
[package]
authors = ["Dilshod Tadjibaev (@antimora)"]
license = "MIT OR Apache-2.0"
name = "minilm-burn"
version = "0.1.0"
edition = "2024"
description = "MiniLM sentence transformer with Burn"

[features]
default = ["pretrained"]
pretrained = ["burn/network", "dep:dirs", "dep:hf-hub"]

# Backend selection
ndarray = ["burn/ndarray"]
tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu"]
cuda = ["burn/cuda"]

[dependencies]
burn = { version = "0.20.0", default-features = false, features = ["std"] }
burn-store = { version = "0.20.0", features = ["std", "safetensors"] }

# Tokenizer
tokenizers = { version = "0.19.1", default-features = false, features = ["onig"] }

# HuggingFace model download
hf-hub = { version = "0.4.3", optional = true }
dirs = { version = "6.0.0", optional = true }

# Serialization
serde = { version = "1.0", default-features = false, features = ["derive", "alloc"] }
serde_json = "1.0"



[dev-dependencies]
clap = { version = "4.5", features = ["derive"] }
criterion = "0.5"

[[example]]
name = "inference"
required-features = ["pretrained", "ndarray"]

[[bench]]
name = "inference"
harness = false
required-features = ["pretrained"]
102 changes: 102 additions & 0 deletions minilm-burn/README.md
@@ -0,0 +1,102 @@
# MiniLM-Burn

MiniLM sentence transformer implementation in Rust using [Burn](https://github.com/tracel-ai/burn).

Supports two model variants from HuggingFace:
- [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) - 6 layers, faster
- [all-MiniLM-L12-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2) - 12 layers, better quality (default)

## Usage

```rust
use burn::backend::ndarray::NdArray;
use minilm_burn::{mean_pooling, MiniLmModel};

type B = NdArray<f32>;

fn main() -> Result<(), Box<dyn std::error::Error>> {
let device = Default::default();

// Load pretrained model and tokenizer (downloads from HuggingFace)
// Use MiniLmVariant::L6 for faster inference, L12 for better quality
let (model, tokenizer) = MiniLmModel::<B>::pretrained(&device, Default::default(), None)?;

// Tokenize the input texts to produce `input_ids` and `attention_mask`
// (tokenization elided here; see the inference example for the full pipeline)
let output = model.forward(input_ids, attention_mask.clone(), None);
let embeddings = mean_pooling(output.hidden_states, attention_mask);

Ok(())
}
```
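The `mean_pooling` step above averages token embeddings over the positions marked real by the attention mask, so padding tokens do not dilute the sentence embedding. A minimal standalone sketch of the idea, using plain `Vec`s for a single sequence (the crate's `mean_pooling` works on Burn tensors of shape `[batch, seq_len, hidden]`; the function name and shapes here are illustrative, not the crate's exact API):

```rust
// Hypothetical standalone sketch of attention-masked mean pooling.
// hidden: one row of length `dim` per token; mask: 1.0 for real tokens, 0.0 for padding.
fn mean_pool(hidden: &[Vec<f32>], mask: &[f32]) -> Vec<f32> {
    let dim = hidden[0].len();
    let mut sum = vec![0.0f32; dim];
    let mut count = 0.0f32;
    for (token, &m) in hidden.iter().zip(mask) {
        if m > 0.5 {
            // Only real tokens contribute; padding positions are skipped.
            for (s, &v) in sum.iter_mut().zip(token) {
                *s += v;
            }
            count += 1.0;
        }
    }
    // Guard against an all-padding sequence.
    let count = count.max(1e-9);
    sum.into_iter().map(|s| s / count).collect()
}

fn main() {
    // Two real tokens followed by one padding token.
    let hidden = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![9.0, 9.0]];
    let mask = vec![1.0, 1.0, 0.0];
    let embedding = mean_pool(&hidden, &mask);
    assert_eq!(embedding, vec![2.0, 3.0]);
    println!("{:?}", embedding);
}
```

Batched tensor code folds the same masked sum and count into two reductions over the sequence dimension; the scalar loop above is only meant to make the masking behavior explicit.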

## Features

- `pretrained` - Enables model download utilities (default)
- `ndarray` - NdArray backend (required for inference example and tests)

Backend features for benchmarks:
- `wgpu` - WebGPU backend
- `cuda` - CUDA backend
- `tch-cpu` - LibTorch CPU backend
- `tch-gpu` - LibTorch GPU backend

## Example

Run the inference example:

```bash
cargo run --example inference --features ndarray --release
```

## Testing

Unit tests:

```bash
cargo test --features ndarray
```

Integration tests (requires model download):

```bash
cargo test --features ndarray -- --ignored
```

## Benchmarks

Run for each backend:

```bash
cargo bench --features ndarray
cargo bench --features wgpu
cargo bench --features cuda
cargo bench --features tch-cpu
```

Results are saved to `target/criterion/` for comparison across backends. View the HTML report:

```bash
open target/criterion/report/index.html
```

### Results (Apple M3 Max)

**L6 vs L12 (single sentence):**

| Variant | ndarray | wgpu | tch-cpu |
| ------- | ------- | ----- | ------- |
| L6 | 53 ms | 18 ms | 14 ms |
| L12 | 105 ms | 35 ms | 27 ms |

**L12 batch scaling:**

| Batch size | ndarray | wgpu | tch-cpu |
| ---------- | ------- | ----- | ------- |
| 1 | 102 ms | 35 ms | 26 ms |
| 4 | 387 ms | 39 ms | 49 ms |
| 8 | 774 ms | 44 ms | 77 ms |
| 16 | 1.54 s | 73 ms | 130 ms |

## License

MIT OR Apache-2.0