Implement Best Fit Packing Algorithm in Grain Pipeline for Reduced Padding #2886
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Description
This PR integrates the "Best Fit" packing algorithm (
grain.experimental.BestFitPackIterDataset) into the Grain input pipeline in MaxText. Users can now select this strategy by settinggrain_packing_type="best_fit"in their configuration.Previously, users were limited to
first_fit(the default) orconcat_then_split. Whilefirst_fitis efficient for placement speed, it often leaves suboptimal gaps in packed sequences, resulting in unnecessary padding tokens.Why this change is being made:
Minimizing padding tokens is critical for training efficiency. By finding the tightest fit for variable-length sequences, we can significantly increase the effective data throughput per training step.
The problem being solved:
The default
first_fitalgorithm prioritizes placement speed but does not optimize for density. Benchmarks demonstrate thatbest_fitcan reduce padding by up to 27.6% (at batch size 1024) compared tofirst_fit, with negligible processing overhead. This allows more real data to be processed within the same compute budget.Specific implementation:
MaxText/configs/base.ymlandMaxText/configs/types.pyto include"best_fit"as a valid option forgrain_packing_type.MaxText/input_pipeline/_grain_data_processing.pyto conditionally instantiategrain.experimental.BestFitPackIterDatasetwhen the config option is set.GrainArrayRecordBestFitPackingTestandGrainParquetBestFitPackingTest) intests/grain_data_processing_test.pyto ensure the new packing strategy works correctly for both ArrayRecord and Parquet formats.References:
This feature leverages the upstream implementation merged in Grain PR #1028.
FIXES: #2884
Notice 1: Once all tests pass, the "pull ready" label will automatically be assigned.
This label is used for administrative purposes. Please do not add it manually.
Notice 2: For external contributions, our settings currently require an approval from a MaxText maintainer to trigger CI tests.
Tests
I have tested this change via new unit tests and local verification.
1. Unit Tests
Added
GrainArrayRecordBestFitPackingTestandGrainParquetBestFitPackingTesttotests/grain_data_processing_test.py.python3 tests/grain_data_processing_test.py2. Benchmark
(Referenced in the issue) Verified that
best_fitproduces significantly lower padding counts compared tofirst_fiton mock variable-length data.Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.