diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml
index 605817c33..2fbfc922f 100644
--- a/src/MaxText/configs/base.yml
+++ b/src/MaxText/configs/base.yml
@@ -760,6 +760,9 @@ compiled_trainstep_file: "" # Name of saved serialized compiled train_step, e.g.
 compile_topology: '' # Target hardware version, e.g. 'v5e-256'
 compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
 
+# MaxText Estimator configs
+write_estimator_result: False # Write estimator.py results to a separate file.
+
 decode_sampling_strategy: "greedy" # decode_sampling_strategy should be one of greedy, weighted, nucleus, topk, or composite(top_k -> top_p -> weighted temperature)
 decode_sampling_nucleus_p: -1 # set if you're doing nucleus / top-p
 decode_sampling_top_k: 0 # set if you're doing top-k
diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py
index b9f243a07..4ba27a5e9 100644
--- a/src/MaxText/configs/types.py
+++ b/src/MaxText/configs/types.py
@@ -1178,6 +1178,7 @@ class AOT(BaseModel):
   compiled_trainstep_file: PathStr = Field("", description="Name of saved serialized compiled train_step.")
   compile_topology: str = Field("", description="Target hardware version, e.g. 'v5e-256'.")
   compile_topology_num_slices: int = Field(-1, description="Number of target slices.")
+  write_estimator_result: bool = Field(False, description="Write estimator.py results to a separate file.")
 
 
 class DevelopmentAndDebugging(BaseModel):
diff --git a/src/MaxText/estimator.py b/src/MaxText/estimator.py
index 8cdb3c0d0..e3439b5ca 100644
--- a/src/MaxText/estimator.py
+++ b/src/MaxText/estimator.py
@@ -164,7 +164,34 @@ def next_policy(policy: dict) -> dict[str, str] | None:
   return None
 
 
-def largest_batch_size(base_argv, policy, min_pdb, max_pdb=64) -> int:
+def find_pdb_scalar(config):
+  """Calculates the scaling factor that normalizes the per-device batch (PDB) size.
+
+  In distributed training, the batch size is divided across various mesh axes.
+  When using non-batch sharding (such as tensor parallelism), the raw
+  per-device batch size can become a fractional value.
+
+  This function identifies those non-batch axes (e.g., 'tensor') and computes
+  the multiplier by which a fractional per-device batch size must be scaled
+  to yield an integer, keeping it compatible with the memory and compute
+  estimation logic.
+
+  Args:
+    config: The configuration object containing 'mesh_axes' and the
+      corresponding 'ici_{axis}_parallelism' values.
+
+  Returns:
+    float: The aggregate parallelism degree of all non-data/non-FSDP axes,
+      serving as the integer-normalization constant for the PDB.
+  """
+  pdb_scalar = 1.0
+  for mesh_axis in config.mesh_axes:
+    if mesh_axis not in ("data", "fsdp", "fsdp_transpose", "expert", "stage"):
+      pdb_scalar *= getattr(config, f"ici_{mesh_axis}_parallelism")
+  return pdb_scalar
+
+
+def largest_batch_size(base_argv, policy, min_pdb, max_pdb=64, pdb_scalar=1.0) -> float:
   """
   Finds the largest possible per_device_batch_size (pdb) that does not cause an OOM error.
 
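Note: the scalar above is just the product of the non-batch ICI parallelism degrees. A minimal, runnable sketch of the same logic, using a hypothetical SimpleNamespace stand-in for MaxText's real config object (the axis names and degrees below are illustrative, not from the patch):

from types import SimpleNamespace

# Stand-in config: a mesh with data, FSDP, and 4-way tensor parallelism.
config = SimpleNamespace(
    mesh_axes=["data", "fsdp", "tensor"],
    ici_data_parallelism=1,
    ici_fsdp_parallelism=16,
    ici_tensor_parallelism=4,
)

pdb_scalar = 1.0
for mesh_axis in config.mesh_axes:
  # Only axes that do NOT shard the batch contribute to the scalar.
  if mesh_axis not in ("data", "fsdp", "fsdp_transpose", "expert", "stage"):
    pdb_scalar *= getattr(config, f"ici_{mesh_axis}_parallelism")

print(pdb_scalar)         # 4.0 -> only the 'tensor' axis contributes
print(0.25 * pdb_scalar)  # 1.0 -> a fractional pdb of 0.25 scales to an integer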
@@ -181,26 +208,29 @@ def largest_batch_size(base_argv, policy, min_pdb, max_pdb=64) -> int:
   """
   print(f"Starting binary search for the largest batch size between {min_pdb} and {max_pdb}.")
 
+  if pdb_scalar == 0.0:
+    raise ValueError("pdb_scalar cannot be zero.")
+
   if is_oom(base_argv, policy, min_pdb):
     print(f"OOM at minimum batch size {min_pdb}.")
-    return min_pdb - 1
+    return min_pdb - 1 / pdb_scalar
 
   if not is_oom(base_argv, policy, max_pdb):
     print(f"No OOM at maximum batch size {max_pdb}.")
     return max_pdb
 
-  low, high, result = min_pdb, max_pdb, min_pdb
+  low, high, result = min_pdb * pdb_scalar, max_pdb * pdb_scalar, min_pdb * pdb_scalar
   while low <= high:
     mid = (low + high) // 2
-    if mid < min_pdb:
+    if mid < min_pdb * pdb_scalar:
       low = mid + 1
       continue
-    if not is_oom(base_argv, policy, mid):
+    if not is_oom(base_argv, policy, mid / pdb_scalar):
       result = mid
       low = mid + 1
     else:
       high = mid - 1
-  return result
+  return result / pdb_scalar
 
 
-def is_oom(base_argv, policy: dict, pdb: int) -> bool:
+def is_oom(base_argv, policy: dict, pdb: float) -> bool:
@@ -294,6 +324,7 @@ def search(
     base_argv,
     init_policy: dict = None,
     max_pdb: int = 256,
-) -> list[tuple[int, dict]]:
+    pdb_scalar: float = 1.0,
+) -> list[tuple[float, dict]]:
   """
   Performs the core search algorithm to find the Pareto frontier points.
@@ -308,11 +339,13 @@ def search(
   """
   output_lst = []
   policy = build_full_device_policy(tensor_names) if init_policy is None else init_policy
-  pdb = 1
+  pdb = 1 / pdb_scalar
   while policy is not None:
-    pdb = largest_batch_size(base_argv, policy, min_pdb=pdb, max_pdb=max_pdb)
+    pdb = largest_batch_size(base_argv, policy, min_pdb=pdb, max_pdb=max_pdb, pdb_scalar=pdb_scalar)
     if pdb > 0:
       output_lst.append((pdb, policy))
+    else:
+      break
     policy = next_policy(policy)
   return output_lst
 
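Note: the binary search in largest_batch_size now runs in integer "ticks" of size 1 / pdb_scalar: the bounds are scaled up so the midpoint stays a whole number, and each probe is scaled back down before calling is_oom. A toy sketch of that arithmetic, with a made-up OOM predicate standing in for the real compile-and-check:

def toy_oom(pdb: float) -> bool:
  return pdb > 2.75  # pretend anything above a 2.75 per-device batch OOMs

def largest_batch_size_sketch(min_pdb, max_pdb, pdb_scalar=4.0) -> float:
  # Work in whole ticks: [min_pdb * scalar, max_pdb * scalar].
  low, high, result = min_pdb * pdb_scalar, max_pdb * pdb_scalar, min_pdb * pdb_scalar
  while low <= high:
    mid = (low + high) // 2            # a whole number of ticks
    if not toy_oom(mid / pdb_scalar):  # probe in real (possibly fractional) units
      result, low = mid, mid + 1
    else:
      high = mid - 1
  return result / pdb_scalar

print(largest_batch_size_sketch(0.25, 64))  # 2.75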
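Note: search() walks policies from least to most rematerialization; each step frees more memory, so the best pdb found so far becomes the floor (min_pdb) for the next probe and the resulting (pdb, policy) frontier is non-decreasing in pdb. The new else: break stops the walk once nothing fits even at that floor. A toy illustration of the monotone walk (policies and the OOM rule are invented, and a linear scan stands in for the binary search):

def toy_oom(pdb: float, remat_steps: int) -> bool:
  return pdb > 1.0 + remat_steps  # more remat -> a larger batch fits

frontier = []
pdb = 0.25                        # 1 / pdb_scalar with pdb_scalar = 4
for remat_steps in range(3):      # stand-in for successive next_policy() steps
  while pdb + 0.25 <= 64 and not toy_oom(pdb + 0.25, remat_steps):
    pdb += 0.25                   # linear stand-in for largest_batch_size
  if pdb > 0:
    frontier.append((pdb, f"policy_{remat_steps}"))
  else:
    break

print(frontier)  # [(1.0, 'policy_0'), (2.0, 'policy_1'), (3.0, 'policy_2')]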
@@ -432,6 +465,7 @@ def main(argv_list: Sequence[str]) -> None:
   with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull):
     config = pyconfig.initialize(base_argv)
     train_compile.validate_config(config)
+    pdb_scalar = find_pdb_scalar(config)
 
   # Get the prioritized list of tensors to try rematerializing
   tensor_names = generate_priority_list(config, provided_tensor_names)
@@ -451,25 +485,35 @@ def main(argv_list: Sequence[str]) -> None:
     # MODE 2: No batch size. Search for both batch size and policy.
     print("No batch size provided. Searching for max batch size and policies...")
     # First, find the absolute max batch size that fits *even with full remat*
-    max_pdb = largest_batch_size(base_argv, full_remat_policy, min_pdb=1)
+    max_pdb = largest_batch_size(base_argv, full_remat_policy, min_pdb=1 / pdb_scalar, pdb_scalar=pdb_scalar)
+    suggested_list = [(max_pdb, full_remat_policy)]
     # Now, search for combinations, starting from no-remat up to max_pdb
-    suggested_list = search(tensor_names, base_argv, init_policy=full_device_policy, max_pdb=max_pdb)
+    suggested_list.extend(
+        search(tensor_names, base_argv, init_policy=full_device_policy, max_pdb=max_pdb, pdb_scalar=pdb_scalar)
+    )
 
   end_time = time.time()
   print(f"\nSearch completed in {end_time - start_time:.2f} seconds.")
 
   output_filename = "remat_commands_from_estimator.txt"
-  print(f"Writing {len(suggested_list)} suggested command(s) to {output_filename}...")
 
-  with open(output_filename, "w", encoding="utf-8") as f:
-    for pdb_result, policy_result in suggested_list:
-      # Build the full, runnable command string
-      final_argv = build_argv(base_argv[1:], policy_result, pdb_result)
-      command = "python -m MaxText.train " + " ".join(final_argv)
+  # Only open the file and print the status if the config allows writing
+  if config.write_estimator_result:
+    print(f"Writing {len(suggested_list)} suggested command(s) to {output_filename}...")
 
-      f.write(command + "\n")
-      print(f" - Found valid combo: pdb={pdb_result}, policy={policy_result}")
+    with open(output_filename, "w", encoding="utf-8") as f:
+      for pdb_result, policy_result in suggested_list:
+        # Build the full, runnable command string
+        final_argv = build_argv(base_argv[1:], policy_result, pdb_result)
+        command = "python -m MaxText.train " + " ".join(final_argv)
+
+        f.write(command + "\n")
+        print(f" - Found valid combo: pdb={pdb_result}, policy={policy_result}")
+  else:
+    for pdb_result, policy_result in suggested_list:
+      print(f" - Found valid combo (not saved to file): pdb={pdb_result}, policy={policy_result}")
 
   print("Done.")
diff --git a/src/MaxText/train_compile.py b/src/MaxText/train_compile.py
index e55ca0201..22f5598fa 100644
--- a/src/MaxText/train_compile.py
+++ b/src/MaxText/train_compile.py
@@ -186,6 +186,7 @@ def is_oom(argv: Sequence[str]) -> bool:
   except Exception as e:
     # return true if OOM error happens
     # OOM error looks like
+    # Check failed: entries[i] <= std::numeric_limits::max()
     # jax.errors.JaxRuntimeError: RESOURCE_EXHAUSTED: Allocation ...
     # jax.errors.JaxRuntimeError: INTERNAL: RET_CHECK failure ...
     message = str(e).lower()
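Note: the new comment documents one more failure signature that surfaces as an OOM during AOT compilation. A small sketch of the classification idea: any exception whose text matches a known signature is treated as OOM. The helper looks_like_oom is hypothetical (not the real is_oom), and the signature list is illustrative, lifted from the comments in the diff above; the real function may match on different substrings.

# Hypothetical helper mirroring the matching idea in train_compile.is_oom.
OOM_SIGNATURES = (
    "resource_exhausted",  # jax.errors.JaxRuntimeError: RESOURCE_EXHAUSTED: Allocation ...
    "ret_check failure",   # jax.errors.JaxRuntimeError: INTERNAL: RET_CHECK failure ...
    "check failed",        # Check failed: entries[i] <= std::numeric_limits::max()
)

def looks_like_oom(exc: Exception) -> bool:
  message = str(exc).lower()
  return any(signature in message for signature in OOM_SIGNATURES)

print(looks_like_oom(RuntimeError("RESOURCE_EXHAUSTED: Allocation of 8.5G failed")))  # True
print(looks_like_oom(ValueError("unrelated config error")))                           # False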