diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml
index cba8836de3..e53281db90 100644
--- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml
+++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml
@@ -44,6 +44,11 @@ fp8_config:
   fp8_model_init_kwargs:
     enabled: false # If this is set to true, fp8_config.enabled must also be set to true.
 
+# FP32 Master Weights
+# When enabled, the model is initialized in FP32 and MixedPrecisionPolicy casts parameters to BF16 for compute.
+# This matches Megatron's main_params_dtype=torch.float32 for better numerical stability.
+use_fp32_master_weights: false
+
 # Optimizer config
 adamw_kwargs:
   lr: 3e-3
diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py
index a33fc27f87..1b6f9fc21e 100644
--- a/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py
+++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py
@@ -502,3 +502,69 @@ def test_sanity_fsdp2_fp8_stats_logging(tmp_path, recipe_path):
     assert fp8_log_dir.exists()
     assert (fp8_log_dir / "rank_0" / "nvdlfw_inspect_logs" / "nvdlfw_inspect_globalrank-0.log").exists()
     assert (fp8_log_dir / "rank_0" / "nvdlfw_inspect_statistics_logs" / "nvdlfw_inspect_globalrank-0.log").exists()
+
+
+def test_train_fsdp2_fp32_master_weights(tmp_path, recipe_path):
+    """Test FSDP2 training with FP32 master weights and BF16 compute.
+
+    This test validates that the MixedPrecisionPolicy correctly:
+    - Stores master weights in FP32
+    - Casts to BF16 for forward/backward compute
+    - Accumulates gradients in FP32
+    - Allows training to converge properly
+    """
+    with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
+        sanity_config = compose(
+            config_name="L0_sanity",
+            overrides=[
+                f"+wandb.dir={tmp_path}",
+                f"checkpoint.ckpt_dir={tmp_path}",
+                "checkpoint.resume_from_checkpoint=false",
+                "use_fp32_master_weights=true",
+                "use_torch_compile=false",  # Disable compile to simplify debugging
+                "fp8_config.enabled=false",  # Disable FP8 to isolate FP32 master weights
+                "config_kwargs.attn_input_format=bshd",
+            ],
+        )
+
+    final_loss = main_fsdp2(sanity_config)
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    # FP32 master weights should achieve the same or better convergence
+    assert final_loss < 8.0, f"Final loss {final_loss} is too high, expected < 8.0"
+
+
+def test_train_fsdp2_fp32_master_weights_with_fp8(tmp_path, recipe_path):
+    """Test FSDP2 training with FP32 master weights and FP8 compute.
+
+    This test validates the full precision hierarchy:
+    - FP32 master weights (kept sharded in FP32 for the optimizer via MixedPrecisionPolicy)
+    - BF16 intermediate compute (via torch.autocast)
+    - FP8 compute for eligible TE layers (via te.fp8_autocast)
+
+    This matches Megatron's approach where you "use the FP32 weights to get
+    a really precise quantization for FP8" (per Cory's Slack comment).
+    """
+    with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
+        sanity_config = compose(
+            config_name="L0_sanity",
+            overrides=[
+                f"+wandb.dir={tmp_path}",
+                f"checkpoint.ckpt_dir={tmp_path}",
+                "checkpoint.resume_from_checkpoint=false",
+                "use_fp32_master_weights=true",
+                "use_torch_compile=false",  # Disable compile to simplify debugging
+                "fp8_config.enabled=true",  # Enable FP8 with FP32 master weights
+                "config_kwargs.attn_input_format=bshd",
+                # FP8 requires sequence lengths padded to a multiple of 16
+                "+dataset.pad_sequences_to_be_divisible_by=16",
+            ],
+        )
+
+    final_loss = main_fsdp2(sanity_config)
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    # FP32 master weights + FP8 should achieve good convergence
+    assert final_loss < 8.0, f"Final loss {final_loss} is too high, expected < 8.0"
diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py
index 10a28a27cf..544c0371e2 100644
--- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py
+++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py
@@ -25,7 +25,7 @@
 import transformer_engine.pytorch
 from omegaconf import DictConfig, OmegaConf
 from torch.distributed.device_mesh import init_device_mesh
-from torch.distributed.fsdp import fully_shard
+from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard
 from torch.optim import AdamW
 from transformer_engine.common.recipe import Format
 from transformers.models.llama.configuration_llama import LlamaConfig
@@ -83,8 +83,17 @@ def main(args: DictConfig) -> float | None:
         config_class = LlamaConfig
         model_class = LlamaForCausalLM
 
+    # Determine the dtype for model initialization.
+    # When use_fp32_master_weights=True, we create the model in FP32 and use MixedPrecisionPolicy
+    # to cast to BF16 for forward/backward. This matches Megatron's main_params_dtype=torch.float32.
+    use_fp32_master_weights = getattr(args, "use_fp32_master_weights", False)
+    model_dtype = torch.float32 if use_fp32_master_weights else torch.bfloat16
+
+    if use_fp32_master_weights:
+        logger.info("FP32 master weights enabled: model init in FP32, compute in BF16")
+
     # Create an empty Llama3 model with a causal language model head, e.g. "meta-llama/Meta-Llama-3-8B".
-    config = config_class.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs)
+    config = config_class.from_pretrained(args.config_name_or_path, dtype=model_dtype, **args.config_kwargs)
 
     # Optionally use transformer engine to initialize only fp8 versions of weights by setting
     # `fp8_config.fp8_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16 and fp8
@@ -97,11 +106,24 @@ def main(args: DictConfig) -> float | None:
 
     logger.info("Initialized Model:\n%s", model)
 
+    # Create a MixedPrecisionPolicy for FSDP when using FP32 master weights.
+    # The optimizer steps on the sharded FP32 parameters; FSDP casts BF16 copies of them for forward/backward.
+    mp_policy = None
+    if use_fp32_master_weights:
+        mp_policy = MixedPrecisionPolicy(
+            param_dtype=torch.bfloat16,  # Cast params to BF16 for forward/backward compute
+            reduce_dtype=torch.float32,  # Accumulate gradients in FP32 for precision
+            cast_forward_inputs=False,  # Do not cast inputs to param_dtype; torch.autocast handles activations
+        )
+        logger.info(
+            "MixedPrecisionPolicy: param_dtype=bf16, reduce_dtype=fp32, cast_forward_inputs=False"
+        )
+
     # Shard the transformer layers with FSDP. For Llama3, the transformer stack is in model.model.layers.
     # Each decoder layer should be individually sharded before sharding the full model.
     for layer in model.model.layers:
-        fully_shard(layer, mesh=device_mesh["dp"])
-    fully_shard(model, mesh=device_mesh["dp"])
+        fully_shard(layer, mesh=device_mesh["dp"], mp_policy=mp_policy)
+    fully_shard(model, mesh=device_mesh["dp"], mp_policy=mp_policy)
 
     # If we're using meta device, we need to move sharded weights to the cuda device and initialize the parameters.
     if args.use_meta_device and isinstance(model, NVLlamaForCausalLM):
@@ -157,6 +179,10 @@ def main(args: DictConfig) -> float | None:
     logger.info(f"Starting training loop from step {start_step} to {args.num_train_steps}")
     step = start_step
     micro_step = 0  # Gradient accumulation step counter
+
+    # Create autocast context for FP32 master weights (casts compute to BF16)
+    autocast_ctx = torch.autocast("cuda", dtype=torch.bfloat16) if use_fp32_master_weights else nullcontext()
+
     while step < args.num_train_steps:
         for batch in train_dataloader:
             batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}  # noqa: PLW2901
@@ -164,8 +190,9 @@ def main(args: DictConfig) -> float | None:
             micro_step += 1
 
             # Forward pass with mixed precision.
-            with transformer_engine.pytorch.fp8_autocast(enabled=args.fp8_config.enabled, fp8_recipe=fp8_recipe):
-                outputs = model(**batch)
+            with autocast_ctx:
+                with transformer_engine.pytorch.fp8_autocast(enabled=args.fp8_config.enabled, fp8_recipe=fp8_recipe):
+                    outputs = model(**batch)
 
             # Backward pass - scale loss by grad_acc_steps for proper gradient averaging
             loss = outputs.loss / args.grad_acc_steps
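
For reference, the wiring this patch adds can be sketched outside the recipe in a few lines of FSDP2 code. The snippet below is an illustrative sketch, not part of the patch: it assumes a Llama-style module exposing model.model.layers, plus hypothetical fp8_recipe, batch, and grad_acc_steps objects, and mirrors how MixedPrecisionPolicy, fully_shard, torch.autocast, and transformer_engine's fp8_autocast fit together in train_fsdp2.py.

import torch
import transformer_engine.pytorch as te
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard


def shard_with_fp32_master_weights(model, world_size):
    """Keep FP32 master weights sharded by FSDP2 while compute runs in BF16."""
    mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp",))
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,   # BF16 parameter copies for forward/backward
        reduce_dtype=torch.float32,   # gradient reduction accumulates in FP32
        cast_forward_inputs=False,    # inputs (token ids) are left untouched
    )
    # Shard each decoder layer first, then the root module, mirroring the recipe.
    for layer in model.model.layers:
        fully_shard(layer, mesh=mesh["dp"], mp_policy=mp_policy)
    fully_shard(model, mesh=mesh["dp"], mp_policy=mp_policy)
    return model


def train_step(model, batch, fp8_recipe, grad_acc_steps, fp8_enabled=True):
    """One micro-step: BF16 autocast wraps TE's fp8_autocast; loss is scaled for accumulation."""
    with torch.autocast("cuda", dtype=torch.bfloat16):
        with te.fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe):
            outputs = model(**batch)
    loss = outputs.loss / grad_acc_steps
    loss.backward()  # FP32 gradients land on the sharded FP32 master weights
    return loss.detach()

The optimizer (AdamW in the recipe) then steps directly on the sharded FP32 parameters, which is also what gives the FP8 path a more precise source for quantization than BF16-initialized weights, per the rationale quoted in the test docstring.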