@@ -44,6 +44,11 @@ fp8_config:
fp8_model_init_kwargs:
enabled: false # If this is set to true, fp8_config.enabled must also be set to true.

# FP32 Master Weights
# When enabled, the model is initialized in FP32 and MixedPrecisionPolicy casts its weights to BF16 for compute.
# This matches Megatron's main_params_dtype=torch.float32 for better numerical stability.
use_fp32_master_weights: false

# Optimizer config
adamw_kwargs:
lr: 3e-3
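
For reference, the storage/compute split this flag enables looks roughly like the sketch below; it uses a toy `nn.Sequential` in place of the real model and assumes `torch.distributed` plus a device mesh are already initialized (the recipe's actual wiring, including `cast_forward_inputs`, is in the `train_fsdp2.py` diff further down):

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

# Sketch only: a placeholder module standing in for the Llama3 model,
# initialized in FP32 as the config comment describes.
model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024)).to(torch.float32)

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,  # forward/backward compute runs on BF16 copies
    reduce_dtype=torch.float32,  # gradient reductions accumulate in FP32
)
fully_shard(model, mp_policy=mp_policy)

# The sharded parameters the optimizer steps on remain FP32 master weights.
assert all(p.dtype == torch.float32 for p in model.parameters())
```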
66 changes: 66 additions & 0 deletions bionemo-recipes/recipes/llama3_native_te/tests/test_train.py
@@ -502,3 +502,69 @@ def test_sanity_fsdp2_fp8_stats_logging(tmp_path, recipe_path):
assert fp8_log_dir.exists()
assert (fp8_log_dir / "rank_0" / "nvdlfw_inspect_logs" / "nvdlfw_inspect_globalrank-0.log").exists()
assert (fp8_log_dir / "rank_0" / "nvdlfw_inspect_statistics_logs" / "nvdlfw_inspect_globalrank-0.log").exists()


def test_train_fsdp2_fp32_master_weights(tmp_path, recipe_path):
"""Test FSDP2 training with FP32 master weights and BF16 compute.

This test validates that the MixedPrecisionPolicy correctly:
- Stores master weights in FP32
- Casts to BF16 for forward/backward compute
- Accumulates gradients in FP32
- Preserves training convergence
"""
with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
sanity_config = compose(
config_name="L0_sanity",
overrides=[
f"+wandb.dir={tmp_path}",
f"checkpoint.ckpt_dir={tmp_path}",
"checkpoint.resume_from_checkpoint=false",
"use_fp32_master_weights=true",
"use_torch_compile=false", # Disable compile to simplify debugging
"fp8_config.enabled=false", # Disable FP8 to isolate FP32 master weights
"config_kwargs.attn_input_format=bshd",
],
)

final_loss = main_fsdp2(sanity_config)
gc.collect()
torch.cuda.empty_cache()

# FP32 master weights should achieve the same or better convergence than the default BF16 path
assert final_loss < 8.0, f"Final loss {final_loss} is too high, expected < 8.0"


def test_train_fsdp2_fp32_master_weights_with_fp8(tmp_path, recipe_path):
"""Test FSDP2 training with FP32 master weights and FP8 compute.

This test validates the full precision hierarchy:
- FP32 master weights (stored by optimizer via MixedPrecisionPolicy)
- BF16 intermediate compute (via torch.autocast)
- FP8 compute for eligible TE layers (via te.fp8_autocast)

This matches Megatron's approach where you "use the FP32 weights to get
a really precise quantization for FP8" (per Cory's Slack comment).
"""
with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
sanity_config = compose(
config_name="L0_sanity",
overrides=[
f"+wandb.dir={tmp_path}",
f"checkpoint.ckpt_dir={tmp_path}",
"checkpoint.resume_from_checkpoint=false",
"use_fp32_master_weights=true",
"use_torch_compile=false", # Disable compile to simplify debugging
"fp8_config.enabled=true", # Enable FP8 with FP32 master weights
"config_kwargs.attn_input_format=bshd",
# FP8 requires last dim divisible by 16
"+dataset.pad_sequences_to_be_divisible_by=16",
],
)

final_loss = main_fsdp2(sanity_config)
gc.collect()
torch.cuda.empty_cache()

# FP32 master weights + FP8 should achieve good convergence
assert final_loss < 8.0, f"Final loss {final_loss} is too high, expected < 8.0"
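
For illustration, the precision hierarchy this second test exercises can be sketched as a single training step; `model`, `batch`, `optimizer`, and `fp8_recipe` are assumed to be set up as in the recipe, whose actual loop is in the `train_fsdp2.py` diff below:

```python
import torch
import transformer_engine.pytorch as te

# FP32 master weights are held by FSDP2 and the optimizer; torch.autocast provides
# BF16 intermediate compute; te.fp8_autocast lets eligible TE layers run in FP8.
with torch.autocast("cuda", dtype=torch.bfloat16):
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        outputs = model(**batch)

outputs.loss.backward()  # gradients are reduced in FP32 (reduce_dtype)
optimizer.step()         # AdamW updates the FP32 master weights
optimizer.zero_grad()
```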
39 changes: 33 additions & 6 deletions bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py
@@ -25,7 +25,7 @@
import transformer_engine.pytorch
from omegaconf import DictConfig, OmegaConf
from torch.distributed.device_mesh import init_device_mesh
-from torch.distributed.fsdp import fully_shard
+from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard
from torch.optim import AdamW
from transformer_engine.common.recipe import Format
from transformers.models.llama.configuration_llama import LlamaConfig
@@ -83,8 +83,17 @@ def main(args: DictConfig) -> float | None:
config_class = LlamaConfig
model_class = LlamaForCausalLM

# Determine dtype for model initialization
# When use_fp32_master_weights=True, we create the model in FP32 and use MixedPrecisionPolicy
# to cast to BF16 for forward/backward. This matches Megatron's main_params_dtype=torch.float32
use_fp32_master_weights = getattr(args, "use_fp32_master_weights", False)
model_dtype = torch.float32 if use_fp32_master_weights else torch.bfloat16

if use_fp32_master_weights:
logger.info("FP32 master weights enabled: model init in FP32, compute in BF16")

# Create an empty Llama3 model with a causal language model head, e.g. "meta-llama/Meta-Llama-3-8B".
-config = config_class.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs)
+config = config_class.from_pretrained(args.config_name_or_path, dtype=model_dtype, **args.config_kwargs)

# Optionally use transformer engine to initialize only fp8 versions of weights by setting
# `fp8_config.fp8_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16 and fp8
@@ -97,11 +106,24 @@ def main(args: DictConfig) -> float | None:

logger.info("Initialized Model:\n%s", model)

# Create MixedPrecisionPolicy for FSDP when using FP32 master weights
# FSDP keeps the master weights in FP32, casts BF16 copies for forward/backward, and reduces gradients in FP32 for the optimizer step
mp_policy = None
if use_fp32_master_weights:
mp_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16, # Cast params to BF16 for forward/backward compute
reduce_dtype=torch.float32, # Accumulate gradients in FP32 for precision
cast_forward_inputs=False, # Do not cast inputs to param_dtype (BF16)
)
logger.info(
"MixedPrecisionPolicy: param_dtype=bf16, reduce_dtype=fp32, output_dtype=bf16, cast_forward_inputs=False"
)

# Shard the transformer layers with FSDP. For Llama3, the transformer stack is in model.model.layers.
# Each decoder layer should be individually sharded before sharding the full model.
for layer in model.model.layers:
-fully_shard(layer, mesh=device_mesh["dp"])
-fully_shard(model, mesh=device_mesh["dp"])
+fully_shard(layer, mesh=device_mesh["dp"], mp_policy=mp_policy)
+fully_shard(model, mesh=device_mesh["dp"], mp_policy=mp_policy)

# If we're using meta device, we need to move sharded weights to the cuda device and initialize the parameters.
if args.use_meta_device and isinstance(model, NVLlamaForCausalLM):
@@ -157,15 +179,20 @@ def main(args: DictConfig) -> float | None:
logger.info(f"Starting training loop from step {start_step} to {args.num_train_steps}")
step = start_step
micro_step = 0 # Gradient accumulation step counter

# Create autocast context for FP32 master weights (casts compute to BF16)
autocast_ctx = torch.autocast("cuda", dtype=torch.bfloat16) if use_fp32_master_weights else nullcontext()

while step < args.num_train_steps:
for batch in train_dataloader:
batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901

micro_step += 1

# Forward pass with mixed precision.
-with transformer_engine.pytorch.fp8_autocast(enabled=args.fp8_config.enabled, fp8_recipe=fp8_recipe):
-    outputs = model(**batch)
+with autocast_ctx:
+    with transformer_engine.pytorch.fp8_autocast(enabled=args.fp8_config.enabled, fp8_recipe=fp8_recipe):
+        outputs = model(**batch)

# Backward pass - scale loss by grad_acc_steps for proper gradient averaging
loss = outputs.loss / args.grad_acc_steps