From bceea6b066ad9b8840ff67c05297bb09e8e22d34 Mon Sep 17 00:00:00 2001 From: Savitha Srinivasan Date: Sun, 25 Jan 2026 09:43:19 +0000 Subject: [PATCH 1/6] fp32 changes Signed-off-by: Savitha Srinivasan --- .../hydra_config/defaults.yaml | 5 +++ .../llama3_native_te/tests/test_train.py | 31 ++++++++++++++ .../recipes/llama3_native_te/train_fsdp2.py | 40 ++++++++++++++++--- 3 files changed, 70 insertions(+), 6 deletions(-) diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml index cba8836de3..e53281db90 100644 --- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml @@ -44,6 +44,11 @@ fp8_config: fp8_model_init_kwargs: enabled: false # If this is set to true, fp8_config.enabled must also be set to true. +# FP32 Master Weights +# When enabled, model is initialized in FP32 and MixedPrecisionPolicy casts to BF16 for compute. +# This matches Megatron's main_params_dtype=torch.float32 for better numerical stability. +use_fp32_master_weights: false + # Optimizer config adamw_kwargs: lr: 3e-3 diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py index a33fc27f87..3eb7c5b759 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py @@ -502,3 +502,34 @@ def test_sanity_fsdp2_fp8_stats_logging(tmp_path, recipe_path): assert fp8_log_dir.exists() assert (fp8_log_dir / "rank_0" / "nvdlfw_inspect_logs" / "nvdlfw_inspect_globalrank-0.log").exists() assert (fp8_log_dir / "rank_0" / "nvdlfw_inspect_statistics_logs" / "nvdlfw_inspect_globalrank-0.log").exists() + + +def test_train_fsdp2_fp32_master_weights(tmp_path, recipe_path): + """Test FSDP2 training with FP32 master weights and BF16 compute. 
+ + This test validates that the MixedPrecisionPolicy correctly: + - Stores master weights in FP32 + - Casts to BF16 for forward/backward compute + - Accumulates gradients in FP32 + - Training still converges properly + """ + with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): + sanity_config = compose( + config_name="L0_sanity", + overrides=[ + f"+wandb.dir={tmp_path}", + f"checkpoint.ckpt_dir={tmp_path}", + "checkpoint.resume_from_checkpoint=false", + "use_fp32_master_weights=true", + "use_torch_compile=false", # Disable compile to simplify debugging + "fp8_config.enabled=false", # Disable FP8 to isolate FP32 master weights + "config_kwargs.attn_input_format=bshd", + ], + ) + + final_loss = main_fsdp2(sanity_config) + gc.collect() + torch.cuda.empty_cache() + + # FP32 master weights should achieve same or better convergence + assert final_loss < 8.0, f"Final loss {final_loss} is too high, expected < 8.0" diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py index 10a28a27cf..636896c2bc 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py @@ -25,7 +25,7 @@ import transformer_engine.pytorch from omegaconf import DictConfig, OmegaConf from torch.distributed.device_mesh import init_device_mesh -from torch.distributed.fsdp import fully_shard +from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard from torch.optim import AdamW from transformer_engine.common.recipe import Format from transformers.models.llama.configuration_llama import LlamaConfig @@ -83,8 +83,17 @@ def main(args: DictConfig) -> float | None: config_class = LlamaConfig model_class = LlamaForCausalLM + # Determine dtype for model initialization + # When use_fp32_master_weights=True, we create the model in FP32 and use MixedPrecisionPolicy + # to cast to BF16 for forward/backward. This matches Megatron's main_params_dtype=torch.float32 + use_fp32_master_weights = getattr(args, "use_fp32_master_weights", False) + model_dtype = torch.float32 if use_fp32_master_weights else torch.bfloat16 + + if use_fp32_master_weights: + logger.info("FP32 master weights enabled: model init in FP32, compute in BF16") + # Create an empty Llama3 model with a causal language model head, e.g. "meta-llama/Meta-Llama-3-8B". 
- config = config_class.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs) + config = config_class.from_pretrained(args.config_name_or_path, dtype=model_dtype, **args.config_kwargs) # Optionally use transformer engine to initialize only fp8 versions of weights by setting # `fp8_config.fp8_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16 and fp8 @@ -97,11 +106,25 @@ def main(args: DictConfig) -> float | None: logger.info("Initialized Model:\n%s", model) + # Create MixedPrecisionPolicy for FSDP when using FP32 master weights + # This casts FP32 master weights to BF16 for forward/backward, then back to FP32 for optimizer + mp_policy = None + if use_fp32_master_weights: + mp_policy = MixedPrecisionPolicy( + param_dtype=torch.bfloat16, # Cast params to BF16 for forward/backward compute + reduce_dtype=torch.float32, # Accumulate gradients in FP32 for precision + output_dtype=torch.bfloat16, # Output activations in BF16 + cast_forward_inputs=True, # Cast inputs to param_dtype (BF16) + ) + logger.info( + "MixedPrecisionPolicy: param_dtype=bf16, reduce_dtype=fp32, output_dtype=bf16, cast_forward_inputs=True" + ) + # Shard the transformer layers with FSDP. For Llama3, the transformer stack is in model.model.layers. # Each decoder layer should be individually sharded before sharding the full model. for layer in model.model.layers: - fully_shard(layer, mesh=device_mesh["dp"]) - fully_shard(model, mesh=device_mesh["dp"]) + fully_shard(layer, mesh=device_mesh["dp"], mp_policy=mp_policy) + fully_shard(model, mesh=device_mesh["dp"], mp_policy=mp_policy) # If we're using meta device, we need to move sharded weights to the cuda device and initialize the parameters. if args.use_meta_device and isinstance(model, NVLlamaForCausalLM): @@ -157,6 +180,10 @@ def main(args: DictConfig) -> float | None: logger.info(f"Starting training loop from step {start_step} to {args.num_train_steps}") step = start_step micro_step = 0 # Gradient accumulation step counter + + # Create autocast context for FP32 master weights (casts compute to BF16) + autocast_ctx = torch.autocast("cuda", dtype=torch.bfloat16) if use_fp32_master_weights else nullcontext() + while step < args.num_train_steps: for batch in train_dataloader: batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901 @@ -164,8 +191,9 @@ def main(args: DictConfig) -> float | None: micro_step += 1 # Forward pass with mixed precision. 
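# --- Minimal standalone sketch of the FSDP2 mixed-precision wiring used in this recipe ---
# Assumptions not taken from the patch itself: a single-GPU `torchrun --nproc_per_node=1`
# launch and a tiny `Toy` module standing in for the Llama decoder layers. The point it
# illustrates matches the hunk above: sharded master weights stay in FP32, forward/backward
# compute runs in BF16, and gradient reduction accumulates in FP32.
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard


class Toy(nn.Module):
    """Hypothetical stand-in for a transformer decoder layer."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(64, 64)

    def forward(self, x):
        return self.proj(x).sum()


def demo():
    mesh = init_device_mesh("cuda", (1,), mesh_dim_names=("dp",))
    model = Toy().cuda()  # parameters default to FP32 (the master weights)
    policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,  # unsharded compute copies are BF16
        reduce_dtype=torch.float32,  # gradient reduction/accumulation in FP32
    )
    fully_shard(model, mesh=mesh["dp"], mp_policy=policy)

    loss = model(torch.randn(8, 64, device="cuda"))
    loss.backward()

    # The sharded master parameters and their gradients remain FP32 for the optimizer step.
    p = next(model.parameters())
    print(p.dtype, p.grad.dtype)  # torch.float32 torch.float32


if __name__ == "__main__":
    demo()
# --- End of sketch ---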
- with transformer_engine.pytorch.fp8_autocast(enabled=args.fp8_config.enabled, fp8_recipe=fp8_recipe): - outputs = model(**batch) + with autocast_ctx: + with transformer_engine.pytorch.fp8_autocast(enabled=args.fp8_config.enabled, fp8_recipe=fp8_recipe): + outputs = model(**batch) # Backward pass - scale loss by grad_acc_steps for proper gradient averaging loss = outputs.loss / args.grad_acc_steps From a0da5c52c3bd39f5832c3e83968537fcfcc112dc Mon Sep 17 00:00:00 2001 From: Savitha Srinivasan Date: Mon, 26 Jan 2026 02:34:39 +0000 Subject: [PATCH 2/6] another fp32 test with fp8 enabled --- .../llama3_native_te/tests/test_train.py | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py index 3eb7c5b759..1b6f9fc21e 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py @@ -533,3 +533,38 @@ def test_train_fsdp2_fp32_master_weights(tmp_path, recipe_path): # FP32 master weights should achieve same or better convergence assert final_loss < 8.0, f"Final loss {final_loss} is too high, expected < 8.0" + + +def test_train_fsdp2_fp32_master_weights_with_fp8(tmp_path, recipe_path): + """Test FSDP2 training with FP32 master weights and FP8 compute. + + This test validates the full precision hierarchy: + - FP32 master weights (stored by optimizer via MixedPrecisionPolicy) + - BF16 intermediate compute (via torch.autocast) + - FP8 compute for eligible TE layers (via te.fp8_autocast) + + This matches Megatron's approach where you "use the FP32 weights to get + a really precise quantization for FP8" (per Cory's Slack comment). + """ + with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): + sanity_config = compose( + config_name="L0_sanity", + overrides=[ + f"+wandb.dir={tmp_path}", + f"checkpoint.ckpt_dir={tmp_path}", + "checkpoint.resume_from_checkpoint=false", + "use_fp32_master_weights=true", + "use_torch_compile=false", # Disable compile to simplify debugging + "fp8_config.enabled=true", # Enable FP8 with FP32 master weights + "config_kwargs.attn_input_format=bshd", + # FP8 requires last dim divisible by 16 + "+dataset.pad_sequences_to_be_divisible_by=16", + ], + ) + + final_loss = main_fsdp2(sanity_config) + gc.collect() + torch.cuda.empty_cache() + + # FP32 master weights + FP8 should achieve good convergence + assert final_loss < 8.0, f"Final loss {final_loss} is too high, expected < 8.0" From 9f9d459e4ba243eb8c4b2d0d024bde84facb211c Mon Sep 17 00:00:00 2001 From: Savitha Srinivasan Date: Mon, 26 Jan 2026 04:36:43 +0000 Subject: [PATCH 3/6] Remove extra bf16 autocast Signed-off-by: Savitha Srinivasan --- .../recipes/llama3_native_te/train_fsdp2.py | 58 ++++++++++++++----- 1 file changed, 45 insertions(+), 13 deletions(-) diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py index 636896c2bc..3fb3f3c067 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py @@ -106,9 +106,11 @@ def main(args: DictConfig) -> float | None: logger.info("Initialized Model:\n%s", model) - # Create MixedPrecisionPolicy for FSDP when using FP32 master weights - # This casts FP32 master weights to BF16 for forward/backward, then back to FP32 for optimizer - mp_policy = None + if not isinstance(model, torch.nn.Module): + 
raise TypeError("Expected model_class(config) to return a torch.nn.Module.") + + # Create MixedPrecisionPolicy for FSDP when using FP32 master weights. + # This casts FP32 master weights to BF16 for forward/backward, then back to FP32 for optimizer. if use_fp32_master_weights: mp_policy = MixedPrecisionPolicy( param_dtype=torch.bfloat16, # Cast params to BF16 for forward/backward compute @@ -120,11 +122,17 @@ def main(args: DictConfig) -> float | None: "MixedPrecisionPolicy: param_dtype=bf16, reduce_dtype=fp32, output_dtype=bf16, cast_forward_inputs=True" ) - # Shard the transformer layers with FSDP. For Llama3, the transformer stack is in model.model.layers. - # Each decoder layer should be individually sharded before sharding the full model. - for layer in model.model.layers: - fully_shard(layer, mesh=device_mesh["dp"], mp_policy=mp_policy) - fully_shard(model, mesh=device_mesh["dp"], mp_policy=mp_policy) + # Shard the transformer layers with FSDP. For Llama3, the transformer stack is in model.model.layers. + # Each decoder layer should be individually sharded before sharding the full model. + for layer in model.model.layers: + fully_shard(layer, mesh=device_mesh["dp"], mp_policy=mp_policy) + fully_shard(model, mesh=device_mesh["dp"], mp_policy=mp_policy) + else: + # Shard the transformer layers with FSDP. For Llama3, the transformer stack is in model.model.layers. + # Each decoder layer should be individually sharded before sharding the full model. + for layer in model.model.layers: + fully_shard(layer, mesh=device_mesh["dp"]) + fully_shard(model, mesh=device_mesh["dp"]) # If we're using meta device, we need to move sharded weights to the cuda device and initialize the parameters. if args.use_meta_device and isinstance(model, NVLlamaForCausalLM): @@ -139,6 +147,11 @@ def main(args: DictConfig) -> float | None: if args.fp8_stats_config.enabled: debug_api.infer_and_assign_layer_names(model) + # Log initial parameter dtype before optimizer setup for debugging mixed precision behavior. + first_param = next(model.parameters(), None) + if first_param is not None: + logger.info("Model param dtype before optimizer: %s", first_param.dtype) + # Create optimizer. Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873). optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True)) # type: ignore scheduler = get_cosine_annealing_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs) @@ -162,7 +175,7 @@ def main(args: DictConfig) -> float | None: scheduler=scheduler, ckpt_path=ckpt_path, dist_config=dist_config, - dataloader=train_dataloader, + dataloader=train_dataloader, # type: ignore[arg-type] process_group=device_mesh.get_group("dp"), ) logger.info(f"Checkpoint loaded, resuming from step {start_step}, epoch {epoch}") @@ -181,9 +194,17 @@ def main(args: DictConfig) -> float | None: step = start_step micro_step = 0 # Gradient accumulation step counter - # Create autocast context for FP32 master weights (casts compute to BF16) - autocast_ctx = torch.autocast("cuda", dtype=torch.bfloat16) if use_fp32_master_weights else nullcontext() + # Create autocast context for FP32 master weights (casts compute to BF16). + # Allow override via config for debugging (default: enabled when use_fp32_master_weights=True). 
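# --- Minimal sketch of the nested precision contexts used in the training loop below ---
# Assumptions not taken from the patch itself: an FP8-capable GPU (Hopper or newer), a
# DelayedScaling recipe, and a single `te.Linear` standing in for the full Llama model.
# The outer torch.autocast keeps non-FP8 compute in BF16, while the inner fp8_autocast
# lets Transformer Engine run eligible GEMMs in FP8; the module's own parameters stay FP32.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID)
layer = te.Linear(128, 128).cuda()  # parameters default to FP32 (master weights)
x = torch.randn(32, 128, device="cuda")  # dims divisible by 16, as FP8 requires

with torch.autocast("cuda", dtype=torch.bfloat16):
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        y = layer(x)

loss = y.float().sum()
loss.backward()
print(layer.weight.dtype, y.dtype)  # torch.float32 torch.bfloat16
# --- End of sketch ---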
+ use_autocast = getattr(args, "use_autocast", use_fp32_master_weights) + autocast_ctx = torch.autocast("cuda", dtype=torch.bfloat16) if use_autocast else nullcontext() + if use_fp32_master_weights: + logger.info("FP32 master weights: use_autocast=%s", use_autocast) + if train_dataloader is None: + raise RuntimeError("Expected train_dataloader to be initialized before training.") + + logged_dtypes = False while step < args.num_train_steps: for batch in train_dataloader: batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901 @@ -199,6 +220,17 @@ def main(args: DictConfig) -> float | None: loss = outputs.loss / args.grad_acc_steps loss.backward() + if not logged_dtypes: + grad_param = next((p for p in model.parameters() if p.grad is not None), None) + if grad_param is not None and grad_param.grad is not None: + logger.info( + "Dtypes after first backward: param=%s grad=%s loss=%s", + grad_param.dtype, + grad_param.grad.dtype, + outputs.loss.dtype, + ) + logged_dtypes = True + # Log microbatch step data for accumulation metrics perf_logger.log_micro_step(batch=batch, outputs=outputs) @@ -229,7 +261,7 @@ def main(args: DictConfig) -> float | None: step=step, epoch=epoch, dist_config=dist_config, - dataloader=train_dataloader if args.dataset.use_stateful_dataloader else None, + dataloader=train_dataloader if args.dataset.use_stateful_dataloader else None, # type: ignore[arg-type] process_group=device_mesh.get_group("dp"), max_checkpoints=args.checkpoint.max_checkpoints, async_save=args.checkpoint.async_save, @@ -241,7 +273,7 @@ def main(args: DictConfig) -> float | None: # Dataloader exhausted, incrementing epoch epoch += 1 - dataset_or_sampler.set_epoch(epoch) + dataset_or_sampler.set_epoch(epoch) # type: ignore[attr-defined] # Save final model to a .safetensors file. if args.checkpoint.save_final_model and ckpt_path: From 9d300c542756d721771f350614bfd4c3cf0a0d53 Mon Sep 17 00:00:00 2001 From: Savitha Srinivasan Date: Mon, 26 Jan 2026 20:11:07 +0000 Subject: [PATCH 4/6] Revert "Remove extra bf16 autocast" This reverts commit 9f9d459e4ba243eb8c4b2d0d024bde84facb211c. --- .../recipes/llama3_native_te/train_fsdp2.py | 58 +++++-------------- 1 file changed, 13 insertions(+), 45 deletions(-) diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py index 3fb3f3c067..636896c2bc 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py @@ -106,11 +106,9 @@ def main(args: DictConfig) -> float | None: logger.info("Initialized Model:\n%s", model) - if not isinstance(model, torch.nn.Module): - raise TypeError("Expected model_class(config) to return a torch.nn.Module.") - - # Create MixedPrecisionPolicy for FSDP when using FP32 master weights. - # This casts FP32 master weights to BF16 for forward/backward, then back to FP32 for optimizer. + # Create MixedPrecisionPolicy for FSDP when using FP32 master weights + # This casts FP32 master weights to BF16 for forward/backward, then back to FP32 for optimizer + mp_policy = None if use_fp32_master_weights: mp_policy = MixedPrecisionPolicy( param_dtype=torch.bfloat16, # Cast params to BF16 for forward/backward compute @@ -122,17 +120,11 @@ def main(args: DictConfig) -> float | None: "MixedPrecisionPolicy: param_dtype=bf16, reduce_dtype=fp32, output_dtype=bf16, cast_forward_inputs=True" ) - # Shard the transformer layers with FSDP. 
For Llama3, the transformer stack is in model.model.layers. - # Each decoder layer should be individually sharded before sharding the full model. - for layer in model.model.layers: - fully_shard(layer, mesh=device_mesh["dp"], mp_policy=mp_policy) - fully_shard(model, mesh=device_mesh["dp"], mp_policy=mp_policy) - else: - # Shard the transformer layers with FSDP. For Llama3, the transformer stack is in model.model.layers. - # Each decoder layer should be individually sharded before sharding the full model. - for layer in model.model.layers: - fully_shard(layer, mesh=device_mesh["dp"]) - fully_shard(model, mesh=device_mesh["dp"]) + # Shard the transformer layers with FSDP. For Llama3, the transformer stack is in model.model.layers. + # Each decoder layer should be individually sharded before sharding the full model. + for layer in model.model.layers: + fully_shard(layer, mesh=device_mesh["dp"], mp_policy=mp_policy) + fully_shard(model, mesh=device_mesh["dp"], mp_policy=mp_policy) # If we're using meta device, we need to move sharded weights to the cuda device and initialize the parameters. if args.use_meta_device and isinstance(model, NVLlamaForCausalLM): @@ -147,11 +139,6 @@ def main(args: DictConfig) -> float | None: if args.fp8_stats_config.enabled: debug_api.infer_and_assign_layer_names(model) - # Log initial parameter dtype before optimizer setup for debugging mixed precision behavior. - first_param = next(model.parameters(), None) - if first_param is not None: - logger.info("Model param dtype before optimizer: %s", first_param.dtype) - # Create optimizer. Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873). optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True)) # type: ignore scheduler = get_cosine_annealing_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs) @@ -175,7 +162,7 @@ def main(args: DictConfig) -> float | None: scheduler=scheduler, ckpt_path=ckpt_path, dist_config=dist_config, - dataloader=train_dataloader, # type: ignore[arg-type] + dataloader=train_dataloader, process_group=device_mesh.get_group("dp"), ) logger.info(f"Checkpoint loaded, resuming from step {start_step}, epoch {epoch}") @@ -194,17 +181,9 @@ def main(args: DictConfig) -> float | None: step = start_step micro_step = 0 # Gradient accumulation step counter - # Create autocast context for FP32 master weights (casts compute to BF16). - # Allow override via config for debugging (default: enabled when use_fp32_master_weights=True). 
- use_autocast = getattr(args, "use_autocast", use_fp32_master_weights) - autocast_ctx = torch.autocast("cuda", dtype=torch.bfloat16) if use_autocast else nullcontext() - if use_fp32_master_weights: - logger.info("FP32 master weights: use_autocast=%s", use_autocast) + # Create autocast context for FP32 master weights (casts compute to BF16) + autocast_ctx = torch.autocast("cuda", dtype=torch.bfloat16) if use_fp32_master_weights else nullcontext() - if train_dataloader is None: - raise RuntimeError("Expected train_dataloader to be initialized before training.") - - logged_dtypes = False while step < args.num_train_steps: for batch in train_dataloader: batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901 @@ -220,17 +199,6 @@ def main(args: DictConfig) -> float | None: loss = outputs.loss / args.grad_acc_steps loss.backward() - if not logged_dtypes: - grad_param = next((p for p in model.parameters() if p.grad is not None), None) - if grad_param is not None and grad_param.grad is not None: - logger.info( - "Dtypes after first backward: param=%s grad=%s loss=%s", - grad_param.dtype, - grad_param.grad.dtype, - outputs.loss.dtype, - ) - logged_dtypes = True - # Log microbatch step data for accumulation metrics perf_logger.log_micro_step(batch=batch, outputs=outputs) @@ -261,7 +229,7 @@ def main(args: DictConfig) -> float | None: step=step, epoch=epoch, dist_config=dist_config, - dataloader=train_dataloader if args.dataset.use_stateful_dataloader else None, # type: ignore[arg-type] + dataloader=train_dataloader if args.dataset.use_stateful_dataloader else None, process_group=device_mesh.get_group("dp"), max_checkpoints=args.checkpoint.max_checkpoints, async_save=args.checkpoint.async_save, @@ -273,7 +241,7 @@ def main(args: DictConfig) -> float | None: # Dataloader exhausted, incrementing epoch epoch += 1 - dataset_or_sampler.set_epoch(epoch) # type: ignore[attr-defined] + dataset_or_sampler.set_epoch(epoch) # Save final model to a .safetensors file. if args.checkpoint.save_final_model and ckpt_path: From b657ba1dcb94cebb5f85cdba2ef5e4ff350a237e Mon Sep 17 00:00:00 2001 From: Savitha Srinivasan Date: Mon, 26 Jan 2026 20:31:23 +0000 Subject: [PATCH 5/6] remove cast inptus=true Signed-off-by: Savitha Srinivasan --- bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py index 636896c2bc..23f264e43b 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py @@ -114,10 +114,10 @@ def main(args: DictConfig) -> float | None: param_dtype=torch.bfloat16, # Cast params to BF16 for forward/backward compute reduce_dtype=torch.float32, # Accumulate gradients in FP32 for precision output_dtype=torch.bfloat16, # Output activations in BF16 - cast_forward_inputs=True, # Cast inputs to param_dtype (BF16) + cast_forward_inputs=False, # Do not cast inputs to param_dtype (BF16) ) logger.info( - "MixedPrecisionPolicy: param_dtype=bf16, reduce_dtype=fp32, output_dtype=bf16, cast_forward_inputs=True" + "MixedPrecisionPolicy: param_dtype=bf16, reduce_dtype=fp32, output_dtype=bf16, cast_forward_inputs=False" ) # Shard the transformer layers with FSDP. For Llama3, the transformer stack is in model.model.layers. 
From 4844d74d18caf40b6902f3a94eb9751ea3339fca Mon Sep 17 00:00:00 2001 From: savitha-eng Date: Wed, 28 Jan 2026 01:13:52 +0000 Subject: [PATCH 6/6] remove output casting Signed-off-by: savitha-eng --- bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py | 1 - 1 file changed, 1 deletion(-) diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py index 23f264e43b..544c0371e2 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py @@ -113,7 +113,6 @@ def main(args: DictConfig) -> float | None: mp_policy = MixedPrecisionPolicy( param_dtype=torch.bfloat16, # Cast params to BF16 for forward/backward compute reduce_dtype=torch.float32, # Accumulate gradients in FP32 for precision - output_dtype=torch.bfloat16, # Output activations in BF16 cast_forward_inputs=False, # Do not cast inputs to param_dtype (BF16) ) logger.info(