Feature or Model Request
What problem are you trying to solve?
Current sequence packing strategies in MaxText (specifically first_fit) prioritize placement speed but do not necessarily optimize for the densest possible packing. This results in unnecessary padding tokens within batches, which consumes compute resources without contributing to model learning.
Why is this problem important?
Minimizing padding is crucial for training efficiency. "Best Fit" packing places each sequence into the bin where it fits most tightly, significantly reducing the fraction of padding tokens compared to the current "First Fit" default.
Benchmarks demonstrate that Best Fit can reduce padding by up to 27.6% (at 1024 bins) compared to first_fit, directly translating to higher data throughput and more efficient training per step with negligible overhead.
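For intuition, here is a minimal sketch of the two bin-selection heuristics; it is illustrative only, not Grain's actual implementation:

```python
from typing import Optional


def first_fit_bin(remaining: list[int], seq_len: int) -> Optional[int]:
  """Return the index of the first bin with enough remaining capacity."""
  for i, space in enumerate(remaining):
    if space >= seq_len:
      return i
  return None  # no bin fits; the packer would flush and start a new batch of bins


def best_fit_bin(remaining: list[int], seq_len: int) -> Optional[int]:
  """Return the index of the bin that the sequence fills most tightly."""
  best_i, best_leftover = None, None
  for i, space in enumerate(remaining):
    leftover = space - seq_len
    if leftover >= 0 and (best_leftover is None or leftover < best_leftover):
      best_i, best_leftover = i, leftover
  return best_i


# Bins with 300, 120 and 700 tokens of free space; an incoming 100-token sequence.
remaining = [300, 120, 700]
print(first_fit_bin(remaining, 100))  # 0 -> leaves 200 tokens of slack in bin 0
print(best_fit_bin(remaining, 100))   # 1 -> leaves only 20 tokens of slack in bin 1
```

Over many sequences, the tighter placements leave fewer stranded gaps in each bin, and those gaps are what end up as padding.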
Describe your requested feature or solution
I request the integration of grain.experimental.BestFitPackIterDataset into the MaxText _grain_data_processing.py pipeline.
The feature has recently been merged into upstream Grain (Pull Request #1028).
We should enable this via a configuration flag (e.g., config.grain_packing_type = "best_fit"). The implementation would mirror the existing first_fit setup but use the new class:
```python
elif config.grain_packing_type == "best_fit":
  dataset = grain.experimental.BestFitPackIterDataset(
      dataset, length_struct=length_struct, num_packing_bins=batch_size
  )
```

Benchmark Results
We benchmarked the packing algorithms using a mock dataset of 20,000 variable-length sequences (random lengths up to 1024 tokens) across different batch sizes (bins). The results confirm that BestFit significantly reduces total padding tokens compared to FirstFit with negligible processing overhead.
| Batch Size (Bins) | Algorithm | Time (s) | Total Padding Tokens | Padding Reduction (vs FirstFit) |
|---|---|---|---|---|
| 30 | FirstFitPackedBatch | 3.01 | 2,646,804 | - |
| 30 | BestFitPackedBatch | 3.20 | 2,382,612 | ▼ 9.98% |
| 128 | FirstFitPackedBatch | 2.98 | 1,723,156 | - |
| 128 | BestFitPackedBatch | 3.26 | 1,411,860 | ▼ 18.07% |
| 1024 | FirstFitPackedBatch | 3.11 | 926,484 | - |
| 1024 | BestFitPackedBatch | 3.44 | 670,484 | ▼ 27.63% |
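For orientation, the padding measurement can be approximated with a small standalone simulation along the following lines. This is a simplified sketch (it reuses `first_fit_bin` / `best_fit_bin` from the sketch above), not the harness that produced the table; its flush-everything-on-overflow behaviour and its treatment of the final partial batch are assumptions, not Grain's exact packing semantics.

```python
import random


def simulate_padding(choose_bin, num_bins, capacity=1024, num_sequences=20_000, seed=0):
  """Count padding when packing random-length sequences into `num_bins` open bins."""
  rng = random.Random(seed)
  remaining = [capacity] * num_bins
  padding = 0
  for _ in range(num_sequences):
    seq_len = rng.randint(1, capacity)
    i = choose_bin(remaining, seq_len)
    if i is None:  # nothing fits: emit the whole batch of bins and start fresh
      padding += sum(remaining)
      remaining = [capacity] * num_bins
      i = choose_bin(remaining, seq_len)
    remaining[i] -= seq_len
  # Final partial batch: count leftover space only in bins that were actually used.
  return padding + sum(space for space in remaining if space < capacity)


for bins in (30, 128, 1024):
  ff = simulate_padding(first_fit_bin, bins)
  bf = simulate_padding(best_fit_bin, bins)
  print(f"bins={bins}: first_fit={ff} best_fit={bf} reduction={1 - bf / ff:.1%}")
```

The qualitative trend matches the table: with more concurrently open bins, best fit has more candidate placements to choose from, so its advantage over first fit grows.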
Additional context or examples
- Upstream Grain PR: original PR, copybara PR