3 changes: 3 additions & 0 deletions src/MaxText/configs/base.yml
@@ -760,6 +760,9 @@ compiled_trainstep_file: "" # Name of saved serialized compiled train_step, e.g.
compile_topology: '' # Target hardware version, e.g. 'v5e-256'
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.

# MaxText Estimator configs
write_estimator_result: False

decode_sampling_strategy: "greedy" # decode_sampling_strategy should be one of greedy, weighted, nucleus, topk, or composite(top_k -> top_p -> weighted temperature)
decode_sampling_nucleus_p: -1 # set if you're doing nucleus / top-p
decode_sampling_top_k: 0 # set if you're doing top-k
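The flag defaults to False, so an estimator run writes no file unless explicitly asked. A hedged usage sketch, assuming the estimator is invoked as a module and accepts MaxText's usual key=value overrides (the exact entry point is an assumption, not shown in this diff):

python -m MaxText.estimator src/MaxText/configs/base.yml write_estimator_result=True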
1 change: 1 addition & 0 deletions src/MaxText/configs/types.py
@@ -1178,6 +1178,7 @@ class AOT(BaseModel):
compiled_trainstep_file: PathStr = Field("", description="Name of saved serialized compiled train_step.")
compile_topology: str = Field("", description="Target hardware version, e.g. 'v5e-256'.")
compile_topology_num_slices: int = Field(-1, description="Number of target slices.")
write_estimator_result: bool = Field(False, description="Write estimator.py results in a separate file.")


class DevelopmentAndDebugging(BaseModel):
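A minimal, runnable sketch of how the new Pydantic field behaves (the AOTExample class is a hypothetical stand-in; only the field definition mirrors the line above):

from pydantic import BaseModel, Field

class AOTExample(BaseModel):
    # Mirrors the new AOT field; defaults to False unless overridden.
    write_estimator_result: bool = Field(False, description="Write estimator.py results in a separate file.")

print(AOTExample().write_estimator_result)  # False
print(AOTExample(write_estimator_result=True).write_estimator_result)  # True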
79 changes: 62 additions & 17 deletions src/MaxText/estimator.py
@@ -164,7 +164,34 @@ def next_policy(policy: dict) -> dict[str, str] | None:
return None


def largest_batch_size(base_argv, policy, min_pdb, max_pdb=64) -> int:
def find_pdb_scalar(config):
"""Calculates the scaling factor to normalize the Per-Device Batch (PDB) size.

In distributed training, the batch size is divided across various mesh axes.
When using non-batch-based sharding (like Tensor Parallelism), the raw
per-device batch size can become a fractional value.

This function identifies those non-batch axes (e.g., 'tensor') and calculates
a multiplier. This scalar represents the value by which a fractional per-device
batch size must be multiplied to result in an integer value, ensuring
compatibility with memory and compute estimation logic.

Args:
config: The configuration object containing 'mesh_axes' and the
corresponding 'ici_{axis}_parallelism' values.

Returns:
float: The aggregate parallelism degree of all non-data/non-FSDP axes,
serving as the integer-normalization constant for the PDB.
"""
pdb_scalar = 1.0
for mesh_axis in config.mesh_axes:
if mesh_axis not in ("data", "fsdp", "fsdp_transpose", "expert", "stage"):
pdb_scalar *= getattr(config, f"ici_{mesh_axis}_parallelism")
return pdb_scalar
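# Illustrative walk-through (mesh values assumed, not taken from this diff):
# with mesh_axes = ["data", "fsdp", "tensor"] and ici_tensor_parallelism = 4,
# only "tensor" is a non-batch axis, so find_pdb_scalar returns 4.0. A
# fractional per-device batch size of 0.25 then normalizes to 0.25 * 4.0 = 1,
# an integer the estimation logic can work with.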


def largest_batch_size(base_argv, policy, min_pdb, max_pdb=64, pdb_scalar=1.0) -> int:
"""
Finds the largest possible per_device_batch_size (pdb) that does not cause an OOM error.

@@ -181,26 +208,29 @@ def largest_batch_size(base_argv, policy, min_pdb, max_pdb=64) -> int:
"""
print(f"Starting binary search for the largest batch size between {min_pdb} and {max_pdb}.")

if pdb_scalar == 0.0:
raise ValueError("pdb_scalar cannot be zero.")

if is_oom(base_argv, policy, min_pdb):
print(f"OOM at minimum batch size {min_pdb}.")
return min_pdb - 1
return min_pdb - 1 / pdb_scalar
if not is_oom(base_argv, policy, max_pdb):
print(f"No OOM at maximum batch size {max_pdb}.")
return max_pdb

low, high, result = min_pdb, max_pdb, min_pdb
low, high, result = min_pdb * pdb_scalar, max_pdb * pdb_scalar, min_pdb * pdb_scalar
while low <= high:
mid = (low + high) // 2
if mid < min_pdb:
low = mid + 1
continue

if not is_oom(base_argv, policy, mid):
if not is_oom(base_argv, policy, mid / pdb_scalar):
result = mid
low = mid + 1
else:
high = mid - 1
return result
return result / pdb_scalar
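# Worked example of the scaled search (values assumed): with min_pdb = 0.25,
# max_pdb = 64, and pdb_scalar = 4.0, the binary search runs over the
# integer-valued range low = 1 .. high = 256; a midpoint mid = 128 is probed
# as is_oom(..., 128 / 4.0), i.e. a per-device batch of 32. Dividing the
# final integer result by pdb_scalar makes fractional batch sizes such as
# 127 / 4.0 = 31.75 reachable without the search ever handling non-integers.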


def is_oom(base_argv, policy: dict, pdb: int) -> bool:
@@ -294,6 +324,7 @@ def search(
base_argv,
init_policy: dict = None,
max_pdb: int = 256,
pdb_scalar: float = 1.0,
) -> list[tuple[int, dict]]:
"""
Performs the core search algorithm to find the Pareto frontier points.
@@ -308,11 +339,13 @@
"""
output_lst = []
policy = build_full_device_policy(tensor_names) if init_policy is None else init_policy
pdb = 1
pdb = 1 / pdb_scalar
while policy is not None:
pdb = largest_batch_size(base_argv, policy, min_pdb=pdb, max_pdb=max_pdb)
pdb = largest_batch_size(base_argv, policy, min_pdb=pdb, max_pdb=max_pdb, pdb_scalar=pdb_scalar)
if pdb > 0:
output_lst.append((pdb, policy))
else:
break
policy = next_policy(policy)
return output_lst
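# Sketch of the result shape (policy dicts are hypothetical): a list of
# (largest fitting per-device batch size, policy) pairs, e.g.
#   [(12.0, {"decoder_layer_input": "device"}),
#    (16.0, {"decoder_layer_input": "remat"})]
# The new else/break exits as soon as a policy cannot fit even the current
# minimum batch size, instead of recording a non-positive pdb.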

@@ -432,6 +465,7 @@ def main(argv_list: Sequence[str]) -> None:
with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull):
config = pyconfig.initialize(base_argv)
train_compile.validate_config(config)
pdb_scalar = find_pdb_scalar(config)

# Get the prioritized list of tensors to try rematerializing
tensor_names = generate_priority_list(config, provided_tensor_names)
@@ -451,25 +485,36 @@
# MODE 2: No batch size. Search for both batch size and policy.
print("No batch size provided. Searching for max batch size and policies...")
# First, find the absolute max batch size that fits *even with full remat*
max_pdb = largest_batch_size(base_argv, full_remat_policy, min_pdb=1)
max_pdb = largest_batch_size(base_argv, full_remat_policy, min_pdb=1 / pdb_scalar, pdb_scalar=pdb_scalar)
suggested_list = [(max_pdb, full_remat_policy)]

# Now, search for combinations, starting from no-remat up to max_pdb
suggested_list = search(tensor_names, base_argv, init_policy=full_device_policy, max_pdb=max_pdb)
suggested_list.extend(
search(tensor_names, base_argv, init_policy=full_device_policy, max_pdb=max_pdb, pdb_scalar=pdb_scalar)
)

end_time = time.time()
print(f"\nSearch completed in {end_time - start_time:.2f} seconds.")

output_filename = "remat_commands_from_estimator.txt"
print(f"Writing {len(suggested_list)} suggested command(s) to {output_filename}...")

with open(output_filename, "w", encoding="utf-8") as f:
for pdb_result, policy_result in suggested_list:
# Build the full, runnable command string
final_argv = build_argv(base_argv[1:], policy_result, pdb_result)
command = "python -m MaxText.train " + " ".join(final_argv)
# Only open the file and print the status if the config allows writing
if config.write_estimator_result:
print(f"Writing {len(suggested_list)} suggested command(s) to {output_filename}...")

f.write(command + "\n")
print(f" - Found valid combo: pdb={pdb_result}, policy={policy_result}")
with open(output_filename, "w", encoding="utf-8") as f:
for pdb_result, policy_result in suggested_list:
# Build the full, runnable command string
final_argv = build_argv(base_argv[1:], policy_result, pdb_result)
command = "python -m MaxText.train " + " ".join(final_argv)

f.write(command + "\n")
print(f" - Found valid combo: pdb={pdb_result}, policy={policy_result}")

print("Done.")
else:
for pdb_result, policy_result in suggested_list:
print(f" - Found valid combo (not saved to file): pdb={pdb_result}, policy={policy_result}")

print("Done.")

1 change: 1 addition & 0 deletions src/MaxText/train_compile.py
@@ -186,6 +186,7 @@ def is_oom(argv: Sequence[str]) -> bool:
except Exception as e:
# return true if OOM error happens
# OOM error looks like
# Check failed: entries[i] <= std::numeric_limits<uint32_t>::max()
# jax.errors.JaxRuntimeError: RESOURCE_EXHAUSTED: Allocation ...
# jax.errors.JaxRuntimeError: INTERNAL: RET_CHECK failure ...
message = str(e).lower()
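The comment block above lists the error signatures treated as out-of-memory. A minimal sketch of the matching pattern, assuming simple substring checks over the lowercased message (the marker list here is illustrative, not the exact one in train_compile.py):

def looks_like_oom(e: Exception) -> bool:
    # Lowercase once, then scan for known out-of-memory signatures.
    message = str(e).lower()
    markers = (
        "resource_exhausted",
        "ret_check failure",
        "check failed: entries[i] <= std::numeric_limits<uint32_t>::max()",
    )
    return any(marker in message for marker in markers)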