Add ALBERT masked language model #92
Implements albert-burn crate with:
- Factorized embedding parameterization (vocab→128→768 projection)
- Cross-layer parameter sharing (single TransformerEncoderLayer applied 12x)
- Masked LM head with weight-tied decoder (reuses word embeddings)
- HuggingFace safetensors weight loading with key remapping
- Fill-mask inference example (albert/albert-base-v2)
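As a quick sanity check on the factorization, here is a standalone sketch of the parameter arithmetic, assuming the usual albert-base-v2 sizes (30k SentencePiece vocab, E = 128, H = 768):

```rust
// Parameter counts for ALBERT's factorized embedding vs. a dense BERT-style table.
fn embedding_params(vocab: usize, e: usize, h: usize) -> (usize, usize) {
    let dense = vocab * h; // one vocab x H matrix (BERT-style)
    let factorized = vocab * e + e * h; // vocab x E lookup + E x H projection
    (dense, factorized)
}

fn main() {
    // Assumed albert-base-v2 sizes: 30k vocab, E = 128, H = 768.
    let (dense, factorized) = embedding_params(30_000, 128, 768);
    println!("dense = {dense}, factorized = {factorized}");
    println!("savings ~{:.1}x", dense as f64 / factorized as f64);
}
```

With these sizes the embedding table shrinks from ~23.0M to ~3.9M parameters, roughly a 5.9x reduction.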
- Fix MLM decoder bias: load `predictions.decoder.bias` (zeros) instead of `predictions.bias` (stale non-zero values) to match HF's weight tying behavior
- Move embedding projection from AlbertEmbeddings to AlbertEncoder to match HF's architecture, where `embedding_hidden_mapping_in` lives in the encoder
- Point burn dependencies to upstream main (TransformerEncoderLayer is now public)
- Add integration tests comparing logits and top-5 predictions against Python HF
- Add Python reference generation script
Switch from the custom GELU implementation to Burn's `ActivationConfig::GeluApproximate`, and add `layer_norm_eps` support. Add comprehensive integration tests across 3 sentences verifying logits, top-5 predictions, statistics, and per-position L2 norms at 5e-4 relative tolerance.
Support BaseV2, LargeV2, XLargeV2, and XXLargeV2 variants. The pretrained() method now takes a variant parameter (defaults to BaseV2).
The example now accepts a CLI argument (base, large, xlarge, xxlarge). README shows fill-mask results across all four variants, demonstrating quality improvement with model size.
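A hypothetical sketch of how that CLI-to-variant mapping could look (the enum and method names here are illustrative assumptions, not necessarily the crate's actual API):

```rust
// Hypothetical variant enum mirroring the four ALBERT v2 sizes;
// the real albert-burn types may differ.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
enum Variant {
    #[default]
    BaseV2,
    LargeV2,
    XLargeV2,
    XXLargeV2,
}

impl Variant {
    // Map the example's CLI argument to a variant.
    fn from_arg(arg: &str) -> Option<Self> {
        match arg {
            "base" => Some(Self::BaseV2),
            "large" => Some(Self::LargeV2),
            "xlarge" => Some(Self::XLargeV2),
            "xxlarge" => Some(Self::XXLargeV2),
            _ => None,
        }
    }
}

fn main() {
    // Default to BaseV2 when no CLI argument is given.
    let variant = std::env::args()
        .nth(1)
        .and_then(|a| Variant::from_arg(&a))
        .unwrap_or_default();
    println!("{variant:?}");
}
```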
Benchmarks ALBERT BaseV2 forward pass on NdArray (CPU) and WGPU (GPU) with warmup for GPU backends to avoid measuring shader compilation.
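The warmup pattern is roughly: run a few untimed passes first so one-time costs (shader compilation on WGPU, lazy initialization) don't pollute the measurement. A backend-agnostic sketch, where the closure stands in for the model's forward pass:

```rust
use std::time::Instant;

// Warmup-then-measure helper: `forward` stands in for a model forward pass.
// Returns the mean per-sample wall time in seconds.
fn bench<F: FnMut()>(mut forward: F, warmup: usize, samples: usize) -> f64 {
    for _ in 0..warmup {
        forward(); // untimed: absorbs shader compilation / lazy init
    }
    let start = Instant::now();
    for _ in 0..samples {
        forward();
    }
    start.elapsed().as_secs_f64() / samples as f64
}

fn main() {
    // Trivial stand-in workload; a real benchmark would run the model here.
    let mut acc = 0u64;
    let mean = bench(|| acc = acc.wrapping_add(1), 3, 10);
    println!("mean per-sample time: {mean:.9}s");
}
```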
Pull request overview
This pull request adds an ALBERT (A Lite BERT) masked language model implementation using the Burn deep learning framework. The implementation supports all four v2 variants (base, large, xlarge, xxlarge) with pretrained weights loaded from HuggingFace and includes comprehensive integration tests, benchmarks, and examples.
Changes:
- Implemented ALBERT model with factorized embedding parameterization and cross-layer parameter sharing
- Added pretrained weight loading from HuggingFace for all four ALBERT v2 variants
- Included integration tests validating outputs against Python HuggingFace reference, inference example, and benchmarks across multiple backends
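Cross-layer parameter sharing in miniature: instead of 12 independent parameter sets, one set is applied at every depth. A toy illustration with an affine map standing in for the shared TransformerEncoderLayer (numbers are arbitrary):

```rust
// Toy stand-in for ALBERT's shared encoder layer: one (w, b) pair
// reused at every depth instead of 12 independent parameter sets.
fn shared_stack(mut h: f64, depth: usize) -> f64 {
    let (w, b) = (0.5f64, 1.0f64); // the single shared parameter set
    for _ in 0..depth {
        h = w * h + b; // same weights applied at every layer
    }
    h
}

fn main() {
    // Parameter count stays constant no matter how deep the stack is.
    println!("{}", shared_stack(8.0, 12));
}
```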
Reviewed changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| albert-burn/Cargo.toml | Package configuration with dependencies and feature flags for backend selection |
| albert-burn/README.md | Comprehensive documentation with usage examples, feature descriptions, and benchmark results |
| albert-burn/src/lib.rs | Public API surface with module exports |
| albert-burn/src/model.rs | Core ALBERT model structures including config, base model, and masked LM head |
| albert-burn/src/encoder.rs | ALBERT encoder with cross-layer parameter sharing implementation |
| albert-burn/src/embedding.rs | Factorized embeddings with word, position, and token type components |
| albert-burn/src/loader.rs | Weight loading and HuggingFace model download utilities with key remapping |
| albert-burn/src/tokenize.rs | Batch tokenization utility for preparing input tensors |
| albert-burn/tests/integration_test.rs | Integration tests comparing outputs with Python HuggingFace reference |
| albert-burn/examples/inference.rs | Command-line inference example for fill-mask predictions |
| albert-burn/benches/inference.rs | Performance benchmarks across multiple backends (ndarray, wgpu, cuda, tch) |
| albert-burn/scripts/generate_reference.py | Python script for generating reference values from HuggingFace |
| README.md | Updated root README to include ALBERT in the model collection |
@laggui ready for your review.
albert-burn/src/loader.rs
Outdated
/// # Key Mappings (HuggingFace ALBERT → Burn)
///
/// Embedding keys:
/// - `albert.embeddings.*` → `albert.embeddings.*`
/// - `albert.encoder.embedding_hidden_mapping_in` → `albert.embeddings.projection`
/// - `albert.embeddings.LayerNorm` → `albert.embeddings.layer_norm`
///
/// Encoder keys (shared layer):
/// - `albert.encoder.albert_layer_groups.0.albert_layers.0.attention.query` → `albert.encoder.layer.mha.query`
/// - `albert.encoder.albert_layer_groups.0.albert_layers.0.attention.key` → `albert.encoder.layer.mha.key`
/// - `albert.encoder.albert_layer_groups.0.albert_layers.0.attention.value` → `albert.encoder.layer.mha.value`
/// - `albert.encoder.albert_layer_groups.0.albert_layers.0.attention.dense` → `albert.encoder.layer.mha.output`
/// - `albert.encoder.albert_layer_groups.0.albert_layers.0.attention.LayerNorm` → `albert.encoder.layer.norm_1`
/// - `albert.encoder.albert_layer_groups.0.albert_layers.0.ffn` → `albert.encoder.layer.pwff.linear_inner`
/// - `albert.encoder.albert_layer_groups.0.albert_layers.0.ffn_output` → `albert.encoder.layer.pwff.linear_outer`
/// - `albert.encoder.albert_layer_groups.0.albert_layers.0.full_layer_layer_norm` → `albert.encoder.layer.norm_2`
///
/// MLM head keys:
/// - `predictions.dense` → `mlm_dense`
/// - `predictions.LayerNorm` → `mlm_layer_norm`
/// - `predictions.decoder` → `mlm_decoder`
/// - `predictions.bias` → `mlm_bias`
Not sure the keys remapping should be repeated in the docs. This is more of an implementation detail. The mappings are defined with comments in the function, I think it should be sufficient.
Fixed in ba475f3. Removed the redundant key mapping docs from the function docstring.
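For readers of this thread: the remapping itself boils down to ordered prefix rewrites on the safetensors key names. A standalone sketch with only a handful of the rules shown, not the crate's actual table or function:

```rust
// Minimal key remapper: most specific rules first; unmatched keys pass through.
// Only a few of the HuggingFace → Burn mappings are shown for illustration.
fn remap(key: &str) -> String {
    const RULES: &[(&str, &str)] = &[
        (
            "albert.encoder.embedding_hidden_mapping_in",
            "albert.embeddings.projection",
        ),
        (
            "albert.encoder.albert_layer_groups.0.albert_layers.0.attention.query",
            "albert.encoder.layer.mha.query",
        ),
        ("albert.embeddings.LayerNorm", "albert.embeddings.layer_norm"),
        ("predictions.decoder", "mlm_decoder"),
    ];
    for (from, to) in RULES {
        if let Some(rest) = key.strip_prefix(from) {
            return format!("{to}{rest}");
        }
    }
    key.to_string()
}

fn main() {
    println!("{}", remap("albert.embeddings.LayerNorm.weight"));
    println!("{}", remap("predictions.decoder.bias"));
}
```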
albert-burn/README.md
Outdated
Run the fill-mask inference example:

```bash
cargo run --example inference --features ndarray --release
```
That doesn't work, even for other backends, e.g. `cargo run --example inference --features cuda --release` yields:

error: target `inference` in package `albert-burn` requires the features: `pretrained`, `ndarray`
Consider enabling them by passing, e.g., `--features="pretrained ndarray"`
Fixed in ba475f3. Updated to use `--features "pretrained,ndarray"`.
    abs_diff / scale
}

fn assert_close(actual: f32, expected: f32, label: &str) {
Probably could have used `TensorData::assert_approx_eq` but fine for an integration test I guess
Agreed. Kept the custom impl since it provides more descriptive error messages with labels for each assertion, but TensorData::assert_approx_eq would work too.
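A plausible shape for that helper, reconstructed from the snippet in the diff (the 5e-4 tolerance comes from the PR description; the scale floor of 1.0 is an assumption, not necessarily the repo's exact code):

```rust
// Reconstructed sketch of the labeled relative-tolerance helper discussed above.
const REL_TOL: f32 = 5e-4;

fn rel_diff(actual: f32, expected: f32) -> f32 {
    let abs_diff = (actual - expected).abs();
    // Floor the scale so near-zero expected values don't blow up the ratio.
    let scale = expected.abs().max(1.0);
    abs_diff / scale
}

fn assert_close(actual: f32, expected: f32, label: &str) {
    let d = rel_diff(actual, expected);
    // The label makes the panic message identify which value diverged.
    assert!(
        d <= REL_TOL,
        "{label}: actual {actual} vs expected {expected} (rel diff {d})"
    );
}

fn main() {
    assert_close(1.0003, 1.0, "logit[0]");
    println!("ok");
}
```

The labels are the advantage over a generic approx-equality check: a failing logit comparison reports exactly which position and which statistic diverged.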
@antimora not sure why this broke main CI when merging https://github.com/tracel-ai/models/actions/runs/21674714690/job/62491679306 /edit: ahhh shoot looks like there is a breaking change in zip 7.3.0 which was just released... even though it was a minor version.
Depends on the fixes in https://github.com/tracel-ai/burn/pull/4410.
Summary
- Model configuration is read from HF's `config.json`, so variant-specific architectures are not hardcoded

Benchmark Results
ALBERT BaseV2 inference (forward pass), 10 samples, with GPU warmup:
The fastest WGPU configuration is ~4.3x faster than NdArray on CPU. Run with:
`cargo bench --bench inference --features "pretrained,ndarray,wgpu"`

Test plan
- `cargo test -p albert-burn` to verify integration tests pass
- `cargo run -p albert-burn --example inference --release --features pretrained`
- `cargo bench --bench inference --features "pretrained,ndarray,wgpu"`