12 changes: 10 additions & 2 deletions src/MaxText/sequence_packing.py
@@ -20,7 +20,11 @@


def pack_dataset(
-    dataset: tf.data.Dataset, key2length: int | dict[str, int], pad_id: int, keys: None | list[str] = None
+    dataset: tf.data.Dataset,
+    key2length: int | dict[str, int],
+    pad_id: int,
+    keys: None | list[str] = None,
+    packing_batch_size: int = 256,
) -> tf.data.Dataset:
"""Creates a 'packed' version of a dataset on-the-fly.
Adapted from the mesh-tf implementation.
@@ -54,6 +58,7 @@ def pack_dataset(
dataset: a tf.data.Dataset
key2length: an integer, or a dict from feature-key to integer
keys: a list of strings (e.g. ["inputs", "targets"])
+  packing_batch_size: batch size for the intermediate batching step of the packing operation (default 256).
Returns:
a tf.data.Dataset
"""
@@ -80,7 +85,10 @@
dataset = dataset.map(lambda x: {k: x[k][: key2length[k]] for k in keys}, num_parallel_calls=AUTOTUNE)
# Setting batch_size=length ensures that the concatenated sequences (if they
# have length >=1) are sufficient to fill at least one packed example.
-  batch_size = max(key2length.values())
+  # batch_size = max(key2length.values())
+  # The line above sizes the intermediate batch to max_target_length (8192 by default), which can cause OOM.
+  # Switch to a fixed, configurable size instead.
+  batch_size = packing_batch_size
# We pad with a negative value instead of the default 0 because 0 is a
# valid token for some tokenizers for e.g., representing unknown value
dataset = dataset.padded_batch(batch_size, padded_shapes={k: [-1] for k in keys}, padding_values=-1)
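For reference, a minimal usage sketch (not part of the diff) showing how a caller might cap the intermediate batch via the new parameter. The import path mirrors the file shown in the diff, and the toy dataset is an assumption for illustration only.

import tensorflow as tf

# Hypothetical usage sketch; the import path may differ in your checkout.
from MaxText.sequence_packing import pack_dataset

def gen():
  # A few variable-length token sequences to be packed together.
  for seq in ([1, 2, 3], [4, 5], [6, 7, 8, 9], [10]):
    yield {"inputs": tf.constant(seq, dtype=tf.int32)}

raw = tf.data.Dataset.from_generator(
    gen,
    output_signature={"inputs": tf.TensorSpec(shape=[None], dtype=tf.int32)},
)

# packing_batch_size bounds the intermediate padded batch created inside
# pack_dataset. Before this change that batch had max(key2length.values())
# rows (8192 by default), so each padded batch held batch_size * max_length
# elements; a small fixed value keeps peak memory low.
packed = pack_dataset(raw, key2length=8, pad_id=0, packing_batch_size=2)

for example in packed.take(1):
  print({k: v.numpy() for k, v in example.items()})

The tradeoff is that a smaller intermediate batch gives the packer fewer sequences to combine at once, which can slightly reduce packing density in exchange for a lower peak-memory footprint.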