From 535be220cb103503edefba4fbd42d128eacf1d65 Mon Sep 17 00:00:00 2001
From: bzantium
Date: Wed, 24 Dec 2025 16:24:26 +0900
Subject: [PATCH] Implement Best Fit Packing Algorithm in Grain Pipeline for
 Reduced Padding

Signed-off-by: bzantium
---
 src/MaxText/configs/base.yml                 |  4 +--
 src/MaxText/configs/types.py                 |  4 +--
 .../input_pipeline/_grain_data_processing.py |  4 ++-
 tests/grain_data_processing_test.py          | 34 +++++++++++++++++++
 4 files changed, 41 insertions(+), 5 deletions(-)

diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml
index 605817c33..d0c5d2647 100644
--- a/src/MaxText/configs/base.yml
+++ b/src/MaxText/configs/base.yml
@@ -604,10 +604,10 @@ grain_train_files: ''
 grain_eval_files: ''
 grain_train_mixture_config_path: '' # Path to a JSON file specifying the mixture weights for Grain training data.
 grain_file_type: 'arrayrecord' # arrayrecord or parquet
-grain_packing_type: 'first_fit' # 'first_fit' or 'concat_then_split'. See details of the corresponding module in https://google-grain.readthedocs.io/en/latest/grain.experimental.html
+grain_packing_type: 'first_fit' # 'first_fit', 'best_fit' or 'concat_then_split'. See details of the corresponding module in https://google-grain.readthedocs.io/en/latest/grain.experimental.html
 grain_worker_count: 1 # Set to -1 to enable auto-tuning: automatically determines optimal worker count. See https://google-grain.readthedocs.io/en/latest/_autosummary/grain.experimental.pick_performance_config.html
 grain_per_worker_buffer_size: 1
-# num_threads and prefetch_buffer_size are per-worker per-dataset. 
+# num_threads and prefetch_buffer_size are per-worker per-dataset.
 # When using array_records, they are used in ReadOptions (https://google-grain.readthedocs.io/en/latest/tutorials/data_loader_tutorial.html#per-worker-readoptions)
 # The default value matches that in the Grain package. If mixing multiple data sources, consider lowering these values to reduce memory usage.
 # When using parquet, grain_num_threads is the number of files to read and interleave in parallel
diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py
index b9f243a07..838366b16 100644
--- a/src/MaxText/configs/types.py
+++ b/src/MaxText/configs/types.py
@@ -874,9 +874,9 @@ class DatasetGeneral(BaseModel):
       True,
       description="Whether to pack multiple short examples into a single sequence.",
   )
-  grain_packing_type: Literal["first_fit", "concat_then_split"] = Field(
+  grain_packing_type: Literal["first_fit", "best_fit", "concat_then_split"] = Field(
       "first_fit",
-      description="Packing type when using Grain pipeline. 'first_fit' or 'concat_then_split'.",
+      description="Packing type when using Grain pipeline. 'first_fit', 'best_fit' or 'concat_then_split'.",
   )
   max_segments_per_seq: int = Field(
       32,
diff --git a/src/MaxText/input_pipeline/_grain_data_processing.py b/src/MaxText/input_pipeline/_grain_data_processing.py
index 50c840603..0154d801e 100644
--- a/src/MaxText/input_pipeline/_grain_data_processing.py
+++ b/src/MaxText/input_pipeline/_grain_data_processing.py
@@ -23,7 +23,7 @@
 
 import jax
 
-from grain.experimental import pick_performance_config
+from grain.experimental import BestFitPackIterDataset, pick_performance_config
 import grain.python as grain
 
 from MaxText.utils import gcs_utils
@@ -246,6 +246,8 @@ def pretrain_preprocessing_pipeline(
     dataset = grain.experimental.FirstFitPackIterDataset(
         dataset, length_struct=length_struct, num_packing_bins=batch_size
     )
+  elif config.grain_packing_type == "best_fit":
+    dataset = BestFitPackIterDataset(dataset, length_struct=length_struct, num_packing_bins=batch_size)
   elif config.grain_packing_type == "concat_then_split":
     if config.add_bos and hasattr(tokenizer_model, "bos_id"):
       dataset = grain.experimental.ConcatThenSplitIterDataset(
diff --git a/tests/grain_data_processing_test.py b/tests/grain_data_processing_test.py
index 6d5c44b50..fc4f49730 100644
--- a/tests/grain_data_processing_test.py
+++ b/tests/grain_data_processing_test.py
@@ -227,6 +227,40 @@ def test_batch_determinism(self):
     super().test_batch_determinism()
 
 
+class GrainArrayRecordBestFitPackingTest(GrainArrayRecordProcessingTest):
+  """Test grain data processing with best_fit packing strategy."""
+
+  def setUp(self):
+    super().setUp()
+    temp_dir = tempfile.gettempdir()
+    self.config = pyconfig.initialize(
+        [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")],
+        per_device_batch_size=1,
+        run_name="test",
+        mesh_axes=["data"],
+        logical_axis_rules=[["batch", "data"]],
+        data_sharding=["data"],
+        base_output_directory="gs://max-experiments/",
+        dataset_type="grain",
+        grain_train_files=os.path.join(
+            temp_dir, "gcsfuse", "array-record", "c4", "en", "3.0.1", "c4-train.array_record*"
+        ),
+        grain_packing_type="best_fit",  # Use best_fit packing
+        tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer"),
+        enable_checkpointing=False,
+    )
+    self.mesh_shape_1d = (len(jax.devices()),)
+    self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes)
+    self.process_indices = input_pipeline_interface.get_process_loading_real_data(
+        self.config.data_sharding,
+        self.config.global_batch_size_to_load,
+        self.config.global_batch_size_to_train_on,
+        self.config.max_target_length,
+        self.mesh,
+    )
+    self.train_iter = _grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices)
+
+
 class GrainParquetProcessingTest(unittest.TestCase):
 
   @classmethod
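---

Note on the algorithm: with grain_packing_type='best_fit', each incoming
example goes into the open packing bin with the smallest remaining capacity
that can still hold it, whereas 'first_fit' takes the first bin that fits.
The tighter placement tends to waste fewer slots per bin. Below is a minimal
plain-Python sketch of that bin-selection rule, for illustration only: it is
not Grain's BestFitPackIterDataset (which packs per-feature length_structs
and tracks segment ids), and the names best_fit_pack, lengths, capacity, and
num_bins are invented for this note.

def best_fit_pack(lengths, capacity, num_bins):
  """Greedy best fit: put each item in the tightest open bin that fits."""
  remaining = []   # remaining capacity of each open bin
  assignment = []  # chosen bin index for each packed item
  for n in lengths:
    # Bins that can still hold this item, keyed by leftover room.
    candidates = [(room, i) for i, room in enumerate(remaining) if room >= n]
    if candidates:
      _, i = min(candidates)      # tightest fit wins
    elif len(remaining) < num_bins:
      remaining.append(capacity)  # open a fresh bin
      i = len(remaining) - 1
    else:
      break                       # all bins full; the packed batch is ready
    remaining[i] -= n
    assignment.append(i)
  return assignment, remaining

For example, with capacity=10, num_bins=2, and lengths [5, 6, 4, 5], best
fit fills both bins exactly ([5, 5] and [6, 4]), while first fit would place
the 4 next to the 5 and then have no room for the final 5.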