4 changes: 2 additions & 2 deletions src/MaxText/configs/base.yml
@@ -604,10 +604,10 @@ grain_train_files: ''
grain_eval_files: ''
grain_train_mixture_config_path: '' # Path to a JSON file specifying the mixture weights for Grain training data.
grain_file_type: 'arrayrecord' # arrayrecord or parquet
grain_packing_type: 'first_fit' # 'first_fit' or 'concat_then_split'. See details of the corresponding module in https://google-grain.readthedocs.io/en/latest/grain.experimental.html
grain_packing_type: 'first_fit' # 'first_fit', 'best_fit' or 'concat_then_split'. See details of the corresponding module in https://google-grain.readthedocs.io/en/latest/grain.experimental.html
grain_worker_count: 1 # Set to -1 to enable auto-tuning: automatically determines optimal worker count. See https://google-grain.readthedocs.io/en/latest/_autosummary/grain.experimental.pick_performance_config.html
grain_per_worker_buffer_size: 1
# num_threads and prefetch_buffer_size are per-worker per-dataset.
# When using array_records, they are used in ReadOptions (https://google-grain.readthedocs.io/en/latest/tutorials/data_loader_tutorial.html#per-worker-readoptions)
# The default value matches that in the Grain package. If mixing multiple data sources, consider lowering these values to reduce memory usage.
# When using parquet, grain_num_threads is the number of files to read and interleave in parallel
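For intuition on the new option: first-fit drops each example into the first packing bin with room, while best-fit picks the bin whose remaining space would be smallest after the insert, which can pack more tightly. A toy sketch of the two heuristics (plain Python, not Grain's implementation; the lengths and capacity are made up):

```python
# Toy bin-packing sketch contrasting first-fit vs. best-fit placement.
# Illustration only -- this is not Grain's packing code.
def pack(lengths, capacity, best_fit=False):
  bins = []  # running token count per packed sequence
  for n in lengths:
    fits = [i for i, used in enumerate(bins) if used + n <= capacity]
    if not fits:
      bins.append(n)  # no open bin fits: open a new one
    elif best_fit:
      # best-fit: choose the bin left with the least free space
      i = min(fits, key=lambda i: capacity - (bins[i] + n))
      bins[i] += n
    else:
      bins[fits[0]] += n  # first-fit: first bin that fits
  return bins

print(pack([3, 5, 2, 4], 7))                 # first-fit: [5, 5, 4] -> 3 bins
print(pack([3, 5, 2, 4], 7, best_fit=True))  # best-fit:  [7, 7]   -> 2 bins
```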
4 changes: 2 additions & 2 deletions src/MaxText/configs/types.py
@@ -874,9 +874,9 @@ class DatasetGeneral(BaseModel):
True,
description="Whether to pack multiple short examples into a single sequence.",
)
grain_packing_type: Literal["first_fit", "concat_then_split"] = Field(
grain_packing_type: Literal["first_fit", "best_fit", "concat_then_split"] = Field(
"first_fit",
description="Packing type when using Grain pipeline. 'first_fit' or 'concat_then_split'.",
description="Packing type when using Grain pipeline. 'first_fit', 'best_fit' or 'concat_then_split'.",
)
max_segments_per_seq: int = Field(
32,
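As a side note on what the widened Literal buys: Pydantic rejects any value outside the allowed set when the config is built. A minimal stand-in model (not the real DatasetGeneral) showing that behavior:

```python
# Minimal sketch of Literal validation. PackingDemo is a hypothetical
# stand-in for MaxText's DatasetGeneral, not the actual model.
from typing import Literal
from pydantic import BaseModel, Field, ValidationError

class PackingDemo(BaseModel):
  grain_packing_type: Literal["first_fit", "best_fit", "concat_then_split"] = Field(
      "first_fit",
      description="Packing type when using Grain pipeline.",
  )

print(PackingDemo(grain_packing_type="best_fit").grain_packing_type)  # best_fit
try:
  PackingDemo(grain_packing_type="worst_fit")  # not in the Literal
except ValidationError as err:
  print("rejected:", err.errors()[0]["type"])  # literal_error (pydantic v2)
```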
4 changes: 3 additions & 1 deletion src/MaxText/input_pipeline/_grain_data_processing.py
@@ -23,7 +23,7 @@

import jax

from grain.experimental import pick_performance_config
from grain.experimental import BestFitPackIterDataset, pick_performance_config
import grain.python as grain

from MaxText.utils import gcs_utils
@@ -246,6 +246,8 @@ def pretrain_preprocessing_pipeline(
dataset = grain.experimental.FirstFitPackIterDataset(
dataset, length_struct=length_struct, num_packing_bins=batch_size
)
elif config.grain_packing_type == "best_fit":
dataset = BestFitPackIterDataset(dataset, length_struct=length_struct, num_packing_bins=batch_size)
elif config.grain_packing_type == "concat_then_split":
if config.add_bos and hasattr(tokenizer_model, "bos_id"):
dataset = grain.experimental.ConcatThenSplitIterDataset(
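For reference, a minimal standalone sketch of what the new branch does, assuming BestFitPackIterDataset takes the same arguments as FirstFitPackIterDataset (as the call above implies) and that the installed grain release exports it; the toy feature name and lengths are made up:

```python
# Standalone sketch: best-fit packing a toy dataset with Grain.
# Assumes a grain version recent enough to export BestFitPackIterDataset.
import numpy as np
import grain.python as grain
from grain.experimental import BestFitPackIterDataset

examples = [{"inputs": np.ones(n, dtype=np.int32)} for n in (3, 5, 2, 4)]
dataset = grain.MapDataset.source(examples).to_iter_dataset()

packed = BestFitPackIterDataset(
    dataset,
    length_struct={"inputs": 7},  # max tokens per packed sequence
    num_packing_bins=2,           # open bins to choose among (batch_size above)
)
for elem in packed:
  # Packed elements also carry segmentation metadata (e.g. segment ids).
  print({k: np.asarray(v).shape for k, v in elem.items()})
```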
34 changes: 34 additions & 0 deletions tests/grain_data_processing_test.py
@@ -227,6 +227,40 @@ def test_batch_determinism(self):
super().test_batch_determinism()


class GrainArrayRecordBestFitPackingTest(GrainArrayRecordProcessingTest):
"""Test grain data processing with best_fit packing strategy."""

def setUp(self):
super().setUp()
temp_dir = tempfile.gettempdir()
self.config = pyconfig.initialize(
[sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")],
per_device_batch_size=1,
run_name="test",
mesh_axes=["data"],
logical_axis_rules=[["batch", "data"]],
data_sharding=["data"],
base_output_directory="gs://max-experiments/",
dataset_type="grain",
grain_train_files=os.path.join(
temp_dir, "gcsfuse", "array-record", "c4", "en", "3.0.1", "c4-train.array_record*"
),
grain_packing_type="best_fit", # Use best_fit packing
tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer"),
enable_checkpointing=False,
)
self.mesh_shape_1d = (len(jax.devices()),)
self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes)
self.process_indices = input_pipeline_interface.get_process_loading_real_data(
self.config.data_sharding,
self.config.global_batch_size_to_load,
self.config.global_batch_size_to_train_on,
self.config.max_target_length,
self.mesh,
)
self.train_iter = _grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices)


class GrainParquetProcessingTest(unittest.TestCase):

@classmethod