Add ALBERT masked language model #92

Merged

laggui merged 12 commits into tracel-ai:main from antimora:albert-burn on Feb 4, 2026

Conversation

@antimora (Collaborator) commented Jan 30, 2026

Depends on fixes from https://github.com/tracel-ai/burn/pull/4410.

Summary

  • Add ALBERT (A Lite BERT) masked language model implementation using the Burn deep learning framework
  • Support all four v2 variants (base, large, xlarge, xxlarge) with pretrained weights loaded from HuggingFace
  • Model configurations are read from HuggingFace's config.json, so variant-specific architectures are not hardcoded
  • Inference results show high accuracy — all variants correctly predict "paris" as the top or near-top answer for the masked sentence "The capital of france is [MASK]."

Benchmark Results

ALBERT BaseV2 inference (forward pass), 10 samples, with GPU warmup:

| Backend       | Fastest  | Median   | Mean     |
|---------------|----------|----------|----------|
| NdArray (CPU) | 67.54 ms | 68.55 ms | 69.32 ms |
| WGPU (GPU)    | 15.80 ms | 51.68 ms | 62.71 ms |

The fastest WGPU run is ~4.3x faster than the fastest NdArray (CPU) run (67.54 ms / 15.80 ms ≈ 4.3). Run with:

```bash
cargo bench --bench inference --features "pretrained,ndarray,wgpu"
```

Test plan

  • Run cargo test -p albert-burn to verify integration tests pass
  • Run inference example: cargo run -p albert-burn --example inference --release --features pretrained
  • Verify pretrained weights load correctly for each variant
  • Run benchmarks: cargo bench --bench inference --features "pretrained,ndarray,wgpu"

Implements albert-burn crate with:
- Factorized embedding parameterization (vocab→128→768 projection)
- Cross-layer parameter sharing (single TransformerEncoderLayer applied 12x); both ideas are sketched after this list
- Masked LM head with weight-tied decoder (reuses word embeddings)
- HuggingFace safetensors weight loading with key remapping
- Fill-mask inference example (albert/albert-base-v2)
- Fix MLM decoder bias: load predictions.decoder.bias (zeros) instead of
  predictions.bias (stale non-zero values) to match HF's weight tying behavior
- Move embedding projection from AlbertEmbeddings to AlbertEncoder to match
  HF's architecture where embedding_hidden_mapping_in lives in the encoder
- Point burn dependencies to upstream main (TransformerEncoderLayer is now public)
- Add integration tests comparing logits and top-5 predictions against Python HF
- Add Python reference generation script
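
To make the first two bullets concrete, here is a minimal, self-contained sketch of factorized embedding parameterization and cross-layer parameter sharing. It is illustrative only: `AlbertSketch` and its fields are not the crate's actual types, and a plain `Linear` stands in for the shared `TransformerEncoderLayer`.

```rust
use burn::nn::{Embedding, EmbeddingConfig, Linear, LinearConfig};
use burn::prelude::*;

#[derive(Module, Debug)]
struct AlbertSketch<B: Backend> {
    word_embeddings: Embedding<B>, // [vocab_size, 128]: small embedding dimension
    projection: Linear<B>,         // 128 -> 768 (in this crate the projection lives in the encoder)
    shared_layer: Linear<B>,       // stand-in for the single shared transformer layer
    num_layers: usize,             // 12 for base-v2
}

impl<B: Backend> AlbertSketch<B> {
    fn new(device: &B::Device) -> Self {
        Self {
            word_embeddings: EmbeddingConfig::new(30_000, 128).init(device),
            projection: LinearConfig::new(128, 768).init(device),
            shared_layer: LinearConfig::new(768, 768).init(device),
            num_layers: 12,
        }
    }

    fn forward(&self, token_ids: Tensor<B, 2, Int>) -> Tensor<B, 3> {
        // Factorized embedding parameterization: embed into a small space, then project up.
        let embedded = self.word_embeddings.forward(token_ids); // [batch, seq, 128]
        let mut hidden = self.projection.forward(embedded);     // [batch, seq, 768]
        // Cross-layer parameter sharing: the *same* weights are applied at every layer.
        for _ in 0..self.num_layers {
            hidden = self.shared_layer.forward(hidden);
        }
        hidden
    }
}
```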
Switch from the custom GELU implementation to Burn's ActivationConfig::GeluApproximate and add layer_norm_eps support. Add comprehensive integration tests across 3 sentences verifying logits, top-5 predictions, statistics, and per-position L2 norms at 5e-4 relative tolerance.
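
A sketch of what such a relative-tolerance check can look like (the name matches the test helper, but the exact scale handling here is illustrative, not the crate's code):

```rust
/// Asserts that `actual` is within a relative tolerance of `expected`,
/// printing `label` on failure so the failing quantity is easy to identify.
fn assert_close(actual: f32, expected: f32, label: &str) {
    const REL_TOL: f32 = 5e-4;
    // Assumption: clamp the scale so near-zero expected values don't blow up the ratio.
    let scale = expected.abs().max(1.0);
    let rel = (actual - expected).abs() / scale;
    assert!(
        rel <= REL_TOL,
        "{label}: actual={actual}, expected={expected}, relative diff {rel} exceeds {REL_TOL}"
    );
}
```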
Support BaseV2, LargeV2, XLargeV2, and XXLargeV2 variants.
The pretrained() method now takes a variant parameter (defaults to BaseV2).
The example now accepts a CLI argument (base, large, xlarge, xxlarge).
README shows fill-mask results across all four variants, demonstrating
quality improvement with model size.
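
For illustration, selecting the variant from the example's CLI argument could look roughly like this (the enum and function names are assumptions, not the crate's actual API):

```rust
#[derive(Clone, Copy, Debug)]
enum AlbertVariant {
    BaseV2,
    LargeV2,
    XLargeV2,
    XXLargeV2,
}

/// Maps the example's CLI argument (base | large | xlarge | xxlarge) to a variant,
/// defaulting to BaseV2 when no argument is given.
fn parse_variant(arg: Option<&str>) -> AlbertVariant {
    match arg {
        None | Some("base") => AlbertVariant::BaseV2,
        Some("large") => AlbertVariant::LargeV2,
        Some("xlarge") => AlbertVariant::XLargeV2,
        Some("xxlarge") => AlbertVariant::XXLargeV2,
        Some(other) => panic!("unknown variant `{other}` (expected base|large|xlarge|xxlarge)"),
    }
}

fn main() {
    let args: Vec<String> = std::env::args().collect();
    let variant = parse_variant(args.get(1).map(String::as_str));
    println!("selected variant: {variant:?}");
}
```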
Benchmarks ALBERT BaseV2 forward pass on NdArray (CPU) and WGPU (GPU)
with warmup for GPU backends to avoid measuring shader compilation.
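
The warmup pattern itself is simple. A backend-agnostic sketch (a hypothetical helper, not the actual code in albert-burn/benches/inference.rs):

```rust
use std::time::{Duration, Instant};

/// Runs `forward` a few times unmeasured, then returns per-sample timings.
/// Assumption: `forward` blocks until results are materialized (e.g. by reading the
/// output back); otherwise GPU backends would only be timed on kernel dispatch.
fn bench_forward<F: FnMut()>(mut forward: F, warmup: usize, samples: usize) -> Vec<Duration> {
    for _ in 0..warmup {
        forward(); // unmeasured: absorbs shader compilation / autotuning on GPU backends
    }
    (0..samples)
        .map(|_| {
            let start = Instant::now();
            forward(); // measured forward pass
            start.elapsed()
        })
        .collect()
}
```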

Copilot AI left a comment


Pull request overview

This pull request adds an ALBERT (A Lite BERT) masked language model implementation using the Burn deep learning framework. The implementation supports all four v2 variants (base, large, xlarge, xxlarge) with pretrained weights loaded from HuggingFace and includes comprehensive integration tests, benchmarks, and examples.

Changes:

  • Implemented ALBERT model with factorized embedding parameterization and cross-layer parameter sharing
  • Added pretrained weight loading from HuggingFace for all four ALBERT v2 variants
  • Included integration tests validating outputs against Python HuggingFace reference, inference example, and benchmarks across multiple backends

Reviewed changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 3 comments.

| File | Description |
|------|-------------|
| albert-burn/Cargo.toml | Package configuration with dependencies and feature flags for backend selection |
| albert-burn/README.md | Comprehensive documentation with usage examples, feature descriptions, and benchmark results |
| albert-burn/src/lib.rs | Public API surface with module exports |
| albert-burn/src/model.rs | Core ALBERT model structures including config, base model, and masked LM head |
| albert-burn/src/encoder.rs | ALBERT encoder with cross-layer parameter sharing implementation |
| albert-burn/src/embedding.rs | Factorized embeddings with word, position, and token type components |
| albert-burn/src/loader.rs | Weight loading and HuggingFace model download utilities with key remapping |
| albert-burn/src/tokenize.rs | Batch tokenization utility for preparing input tensors |
| albert-burn/tests/integration_test.rs | Integration tests comparing outputs with Python HuggingFace reference |
| albert-burn/examples/inference.rs | Command-line inference example for fill-mask predictions |
| albert-burn/benches/inference.rs | Performance benchmarks across multiple backends (ndarray, wgpu, cuda, tch) |
| albert-burn/scripts/generate_reference.py | Python script for generating reference values from HuggingFace |
| README.md | Updated root README to include ALBERT in the model collection |


antimora mentioned this pull request on Feb 1, 2026
@antimora (Collaborator, Author) commented Feb 3, 2026

@laggui ready for your review.

@laggui (Member) left a comment


Just a couple of comments

Comment on lines 29 to 50
```rust
/// # Key Mappings (HuggingFace ALBERT → Burn)
///
/// Embedding keys:
/// - `albert.embeddings.*` → `albert.embeddings.*`
/// - `albert.encoder.embedding_hidden_mapping_in` → `albert.embeddings.projection`
/// - `albert.embeddings.LayerNorm` → `albert.embeddings.layer_norm`
///
/// Encoder keys (shared layer):
/// - `albert.encoder.albert_layer_groups.0.albert_layers.0.attention.query` → `albert.encoder.layer.mha.query`
/// - `albert.encoder.albert_layer_groups.0.albert_layers.0.attention.key` → `albert.encoder.layer.mha.key`
/// - `albert.encoder.albert_layer_groups.0.albert_layers.0.attention.value` → `albert.encoder.layer.mha.value`
/// - `albert.encoder.albert_layer_groups.0.albert_layers.0.attention.dense` → `albert.encoder.layer.mha.output`
/// - `albert.encoder.albert_layer_groups.0.albert_layers.0.attention.LayerNorm` → `albert.encoder.layer.norm_1`
/// - `albert.encoder.albert_layer_groups.0.albert_layers.0.ffn` → `albert.encoder.layer.pwff.linear_inner`
/// - `albert.encoder.albert_layer_groups.0.albert_layers.0.ffn_output` → `albert.encoder.layer.pwff.linear_outer`
/// - `albert.encoder.albert_layer_groups.0.albert_layers.0.full_layer_layer_norm` → `albert.encoder.layer.norm_2`
///
/// MLM head keys:
/// - `predictions.dense` → `mlm_dense`
/// - `predictions.LayerNorm` → `mlm_layer_norm`
/// - `predictions.decoder` → `mlm_decoder`
/// - `predictions.bias` → `mlm_bias`
```
Member


Not sure the key remapping should be repeated in the docs. This is more of an implementation detail. The mappings are defined with comments in the function; I think that should be sufficient.

Collaborator Author


Fixed in ba475f3. Removed the redundant key mapping docs from the function docstring.

Run the fill-mask inference example:

```bash
cargo run --example inference --features ndarray --release
```
Member


That doesn't work, even for other backends, e.g.

```bash
cargo run --example inference --features cuda --release
```

yields

```
error: target `inference` in package `albert-burn` requires the features: `pretrained`, `ndarray`
Consider enabling them by passing, e.g., `--features="pretrained ndarray"`
```

Collaborator Author


Fixed in ba475f3. Updated to use --features "pretrained,ndarray".

```rust
abs_diff / scale
}

fn assert_close(actual: f32, expected: f32, label: &str) {
```
Member


Probably could have used the TensorData assert_approx_eq but fine for an integration test I guess

Collaborator Author


Agreed. Kept the custom impl since it provides more descriptive error messages with labels for each assertion, but TensorData::assert_approx_eq would work too.

antimora requested a review from laggui on February 3, 2026 at 21:15
laggui merged commit 55bb0bb into tracel-ai:main on Feb 4, 2026
2 checks passed
@laggui (Member) commented Feb 4, 2026

@antimora not sure why this broke main CI when merging

https://github.com/tracel-ai/models/actions/runs/21674714690/job/62491679306

```
   Compiling burn-store v0.20.1
error[E0308]: mismatched types
  --> /home/runner/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/burn-store-0.20.1/src/pytorch/lazy_data.rs:90:25
   |
87 |             file_list.push((name, offset, compressed_size));
   |             ---------      ------------------------------- this argument has type `(std::string::String, std::option::Option<u64>, u64)`...
   |             |
   |             ... which causes `file_list` to have type `std::vec::Vec<(std::string::String, std::option::Option<u64>, u64)>`
...
90 |         Ok(Self { path, file_list })
   |                         ^^^^^^^^^ expected `Vec<(String, u64, u64)>`, found `Vec<(String, Option<u64>, u64)>`
   |
   = note: expected struct `std::vec::Vec<(std::string::String, u64, _)>`
              found struct `std::vec::Vec<(std::string::String, std::option::Option<u64>, _)>`

   Compiling burn v0.20.1
For more information about this error, try `rustc --explain E0308`.
```

/edit: ahhh shoot looks like there is a breaking change in zip 7.3.0 which was just released... even though it was a minor version. file.data_start() now appears to return an Option<u64>.
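
A sketch of what adapting a call site to the new Option<u64> return might look like (hypothetical, not the actual burn-store fix):

```rust
/// With zip 7.3.0, data_start() returns Option<u64>, so the offset has to be resolved
/// before building the (name, offset, size) tuple that `file_list` expects.
fn push_entry(
    file_list: &mut Vec<(String, u64, u64)>,
    name: String,
    data_start: Option<u64>,
    compressed_size: u64,
) -> Result<(), String> {
    let offset = data_start.ok_or_else(|| format!("{name}: missing data offset"))?;
    file_list.push((name, offset, compressed_size));
    Ok(())
}
```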

