@bzantium bzantium commented Dec 24, 2025

Description

This PR integrates the "Best Fit" packing algorithm (grain.experimental.BestFitPackIterDataset) into the Grain input pipeline in MaxText. Users can now select this strategy by setting grain_packing_type="best_fit" in their configuration.

Previously, users were limited to first_fit (the default) or concat_then_split. While first_fit is efficient for placement speed, it often leaves suboptimal gaps in packed sequences, resulting in unnecessary padding tokens.

Why this change is being made:
Minimizing padding tokens is critical for training efficiency. By finding the tightest fit for variable-length sequences, we can significantly increase the effective data throughput per training step.

The problem being solved:
The default first_fit algorithm prioritizes placement speed but does not optimize for density. Benchmarks demonstrate that best_fit can reduce padding by up to 27.6% (at batch size 1024) compared to first_fit, with negligible processing overhead. This allows more real data to be processed within the same compute budget.
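To illustrate why the placement rule affects padding, here is a minimal, self-contained sketch of the two heuristics on plain integer lengths. It is a toy offline model, not the Grain implementation: the function names `pack` and `total_padding` are hypothetical, and the real `BestFitPackIterDataset` works on a stream of examples and may behave differently in detail (for example, in how many bins it keeps open).

```python
def pack(lengths, capacity, strategy):
    """Pack sequence lengths into bins of size `capacity`.

    Returns the free (padding) space left in each bin.
    Toy model only: bins are never closed, unlike a streaming packer.
    """
    free = []  # remaining free space per open bin
    for n in lengths:
        if n > capacity:
            raise ValueError(f"sequence of length {n} exceeds capacity {capacity}")
        # Bins that can still hold this sequence, in creation order.
        candidates = [i for i, f in enumerate(free) if f >= n]
        if not candidates:
            free.append(capacity - n)          # open a new bin
        elif strategy == "first_fit":
            free[candidates[0]] -= n           # earliest bin that fits
        elif strategy == "best_fit":
            tightest = min(candidates, key=lambda i: free[i])
            free[tightest] -= n                # bin with the least leftover space
        else:
            raise ValueError(f"unknown strategy: {strategy}")
    return free


def total_padding(free):
    """Total number of padding tokens across all bins."""
    return sum(free)


lengths = [5, 6, 4, 5]
first = pack(lengths, capacity=10, strategy="first_fit")
best = pack(lengths, capacity=10, strategy="best_fit")
print(len(first), total_padding(first))  # 3 bins, 10 padding tokens
print(len(best), total_padding(best))    # 2 bins, 0 padding tokens
```

In this hand-picked example, first-fit places the length-4 sequence in the first bin with room, which leaves no bin able to hold the final length-5 sequence. Best-fit places it in the tightest gap and fills both bins exactly. The numbers are illustrative only and are unrelated to the benchmark figures quoted above.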

Specific implementation:

  • Updated MaxText/configs/base.yml and MaxText/configs/types.py to include "best_fit" as a valid option for grain_packing_type.
  • Modified MaxText/input_pipeline/_grain_data_processing.py to conditionally instantiate grain.experimental.BestFitPackIterDataset when the config option is set.
  • Added dedicated unit tests (GrainArrayRecordBestFitPackingTest and GrainParquetBestFitPackingTest) in tests/grain_data_processing_test.py to ensure the new packing strategy works correctly for both ArrayRecord and Parquet formats.

References:
This feature leverages the upstream implementation merged in Grain PR #1028.

FIXES: #2884

Notice 1: Once all tests pass, the "pull ready" label will automatically be assigned.
This label is used for administrative purposes. Please do not add it manually.

Notice 2: For external contributions, our settings currently require an approval from a MaxText maintainer to trigger CI tests.

Tests

I have tested this change via new unit tests and local verification.

1. Unit Tests
Added GrainArrayRecordBestFitPackingTest and GrainParquetBestFitPackingTest to tests/grain_data_processing_test.py.

  • Command: python3 tests/grain_data_processing_test.py
  • Result: Tests passed, confirming that the pipeline initializes and processes batches correctly under the new packing mode.

2. Benchmark
(Referenced in the issue) Verified that best_fit produces significantly lower padding counts compared to first_fit on mock variable-length data.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

codecov bot commented Dec 24, 2025

Codecov Report

❌ Patch coverage is 0% with 3 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| ...c/MaxText/input_pipeline/_grain_data_processing.py | 0.00% | 3 Missing ⚠️ |

…dding

Signed-off-by: bzantium <ryumin93@gmail.com>

Development

Successfully merging this pull request may close these issues.

Implement Best Fit Packing Algorithm in Grain Pipeline for Reduced Padding