diff --git a/src/MaxText/sequence_packing.py b/src/MaxText/sequence_packing.py
index 74e5c09ea..88d46c873 100644
--- a/src/MaxText/sequence_packing.py
+++ b/src/MaxText/sequence_packing.py
@@ -20,7 +20,11 @@
 def pack_dataset(
-    dataset: tf.data.Dataset, key2length: int | dict[str, int], pad_id: int, keys: None | list[str] = None
+    dataset: tf.data.Dataset,
+    key2length: int | dict[str, int],
+    pad_id: int,
+    keys: None | list[str] = None,
+    packing_batch_size: int = 256,
 ) -> tf.data.Dataset:
   """Creates a 'packed' version of a dataset on-the-fly.
 
   Adapted from the mesh-tf implementation.
 
@@ -54,6 +58,7 @@ def pack_dataset(
     dataset: a tf.data.Dataset
     key2length: an integer, or a dict from feature-key to integer
     keys: a list of strings (e.g. ["inputs", "targets"])
+    packing_batch_size: batch size for the intermediate batching step of the packing operation (default 256).
   Returns:
     a tf.data.Dataset
   """
@@ -80,7 +85,10 @@ def pack_dataset(
   dataset = dataset.map(lambda x: {k: x[k][: key2length[k]] for k in keys}, num_parallel_calls=AUTOTUNE)
   # Setting batch_size=length ensures that the concatenated sequences (if they
   # have length >=1) are sufficient to fill at least one packed example.
-  batch_size = max(key2length.values())
+  # batch_size = max(key2length.values())
+  # The line above creates very large intermediate batches (batch_size = max_target_length, 8192 by default),
+  # which can cause OOM. Switch to a fixed, configurable size instead.
+  batch_size = packing_batch_size
   # We pad with a negative value instead of the default 0 because 0 is a
   # valid token for some tokenizers for e.g., representing unknown value
   dataset = dataset.padded_batch(batch_size, padded_shapes={k: [-1] for k in keys}, padding_values=-1)
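
For context, here is a minimal usage sketch of the new `packing_batch_size` parameter. The import path, toy data, and parameter values are assumptions for illustration, not taken from this diff:

```python
# Minimal usage sketch (hypothetical toy data; import path assumed from the file location).
import tensorflow as tf

from MaxText.sequence_packing import pack_dataset  # assumed import path

# Toy dataset of variable-length token sequences.
examples = [
    {"inputs": [1, 2, 3], "targets": [4, 5]},
    {"inputs": [6], "targets": [7, 8, 9]},
]


def gen():
  for ex in examples:
    yield ex


ds = tf.data.Dataset.from_generator(
    gen,
    output_signature={
        "inputs": tf.TensorSpec(shape=[None], dtype=tf.int32),
        "targets": tf.TensorSpec(shape=[None], dtype=tf.int32),
    },
)

# Pack into fixed-length examples of length 16. Choosing an intermediate
# batch smaller than the default 256 further reduces host memory during packing.
packed = pack_dataset(ds, key2length=16, pad_id=0, packing_batch_size=64)

for ex in packed.take(1):
  print({k: v.shape for k, v in ex.items()})
```

The key point of the change: the intermediate `padded_batch` size is now decoupled from `max(key2length.values())`, so a large `max_target_length` no longer forces a correspondingly large (and potentially OOM-inducing) intermediate batch.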