2 changes: 1 addition & 1 deletion .devcontainer/recipes/requirements.txt
@@ -12,7 +12,7 @@ torchdata
torchmetrics
tqdm
transformer_engine
-transformers<5.0
+transformers
typer
wandb
zstandard
2 changes: 1 addition & 1 deletion bionemo-recipes/models/amplify/pyproject.toml
@@ -18,7 +18,7 @@ dependencies = [
"pytest",
"torch==2.6.0a0+ecf3bae40a.nv25.01",
"transformer_engine[pytorch]",
"transformers<5.0",
"transformers<5.0", # TODO(BIO-143): update AMPLIFY to support Transformers v5
"xformers",
]

2 changes: 1 addition & 1 deletion bionemo-recipes/models/esm2/pyproject.toml
@@ -19,7 +19,7 @@ dependencies = [
"torch",
"torchao!=0.14.0",
"transformer_engine[pytorch]",
"transformers<5.0",
"transformers",
]


2 changes: 1 addition & 1 deletion bionemo-recipes/models/esm2/src/esm/convert.py
@@ -131,7 +131,7 @@ def convert_esm_te_to_hf(model_te: nn.Module, **config_kwargs) -> nn.Module:
],
)

-    output_model.tie_weights()
+    output_model.post_init()

# Note: contact_head parameters are not preserved in TE models
# They are lost during HF -> TE conversion and cannot be recovered
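Note on the change above: in Transformers, `post_init()` is the standard post-construction hook and subsumes weight tying, which is why it replaces the bare `tie_weights()` call. A simplified sketch of the relationship (paraphrased, not the library's verbatim code; the exact hooks vary by version):

```python
# Rough sketch of transformers.PreTrainedModel.post_init() (v4.x era).
def post_init(self):
    self.init_weights()  # initializes weights and, for tied models, calls tie_weights()
    self._backward_compatibility_gradient_checkpointing()
```

The same reasoning explains the dropped `self.init_weights()` calls later in this PR: `post_init()` already performs that step.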
38 changes: 30 additions & 8 deletions bionemo-recipes/models/esm2/src/esm/modeling_esm_te.py
@@ -22,7 +22,7 @@
Adapted from `modeling_esm.py` in huggingface/transformers.
"""

-from typing import Literal, Optional, Unpack
+from typing import ClassVar, Literal, Optional, Unpack

# TODO: put import guard around transformer_engine here, with an informative error message around
# installation and the nvidia docker container.
@@ -256,10 +256,34 @@ def init_empty_weights(self):
        # Meta-device init seems to break weight tying, so we re-tie the weights here.
        self.tie_weights()

-    @classmethod
-    def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
-        """Override the default get_init_context method to allow for fp8 model initialization."""
-        return []
+    def _init_weights(self, module):
+        """Initialize module weights.
+
+        We only use this method for standard PyTorch modules; TE modules handle their own weight
+        initialization through `init_method` parameters and the `reset_parameters` method.
+        """
+        if module.__module__.startswith("transformer_engine.pytorch"):
+            # Notably, we need to avoid calling the parent method for TE modules, since the default
+            # _init_weights will assume any class with `LayerNorm` in the name should have weights
+            # initialized to 1.0, breaking `LayerNormLinear` and `LayerNormMLP` modules that use
+            # `weight` for the linear layer and `layer_norm_weight` for the layer norm. Instead, we
+            # call `reset_parameters` if the module has it and the weights are not in fp8. We still
+            # need to figure out why this raises an error if we're using `quantized_model_init`.
+            if hasattr(module, "reset_parameters") and not getattr(module, "primary_weights_in_fp8", False):
+                module.reset_parameters()
+            return
+
+        super()._init_weights(module)
+
+    def state_dict(self, *args, **kwargs):
+        """Override state_dict to filter out TransformerEngine's _extra_state keys.
+
+        TransformerEngine layers add _extra_state attributes that are not compatible with
+        HuggingFace v5 model loading. These are filtered out to ensure checkpoints can be
+        loaded with from_pretrained().
+        """
+        state_dict = super().state_dict(*args, **kwargs)
+        # Filter out _extra_state keys, which are TransformerEngine-specific and not loadable.
+        return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")}


class NVEsmModel(NVEsmPreTrainedModel):
@@ -367,7 +391,7 @@ def forward(
class NVEsmForMaskedLM(NVEsmPreTrainedModel):
    """NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling."""

-    _tied_weights_keys = ("lm_head.decoder.weight",)
+    _tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.decoder.weight": "esm.embeddings.word_embeddings.weight"}

def __init__(self, config: NVEsmConfig):
"""Initialize a NVEsmForMaskedLM.
@@ -386,7 +410,6 @@ def __init__(self, config: NVEsmConfig):
        self.esm = NVEsmModel(config, add_pooling_layer=False)
        self.lm_head = NVEsmLMHead(config)

-        self.init_weights()
        self.post_init()

def get_output_embeddings(self):
@@ -614,7 +637,6 @@ def __init__(self, config):
            init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range),
        )

-        self.init_weights()
        self.post_init()

    def forward(
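The `_tied_weights_keys` change in this file tracks a Transformers v5 API change: the attribute moves from a sequence of tied parameter names to an explicit target-to-source mapping. A minimal sketch contrasting the two conventions, using the names from this diff (the wrapper classes are illustrative):

```python
from typing import ClassVar


class V4Style:
    # Transformers v4: a flat tuple naming the parameters tied to the input embeddings.
    _tied_weights_keys = ("lm_head.decoder.weight",)


class V5Style:
    # Transformers v5: an explicit {target: source} mapping, as adopted in this diff.
    _tied_weights_keys: ClassVar[dict[str, str]] = {
        "lm_head.decoder.weight": "esm.embeddings.word_embeddings.weight"
    }
```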
11 changes: 11 additions & 0 deletions bionemo-recipes/models/esm2/tests/conftest.py
@@ -15,6 +15,7 @@

import importlib
import os
+import socket

import pytest
import transformer_engine.pytorch
@@ -32,6 +33,16 @@
os.environ["TRITON_LIBCUDA_PATH"] = "/usr/local/cuda/lib64"


+@pytest.fixture
+def unused_tcp_port():
+    """Find and return an unused TCP port for torchrun rendezvous."""
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind(("", 0))
+        s.listen(1)
+        port = s.getsockname()[1]
+    return port


@pytest.fixture(autouse=True)
def use_te_debug(monkeypatch):
monkeypatch.setenv("NVTE_DEBUG", "1")
5 changes: 5 additions & 0 deletions bionemo-recipes/models/esm2/tests/test_convert.py
@@ -38,6 +38,11 @@ def test_convert_te_to_hf_roundtrip():
torch.testing.assert_close(original_state_dict[key], converted_state_dict[key], atol=1e-5, rtol=1e-5)


+def test_load_from_converted_checkpoint(te_model_checkpoint):
+    from esm.modeling_esm_te import NVEsmForMaskedLM
+    NVEsmForMaskedLM.from_pretrained(te_model_checkpoint)


def test_qkv_unpacking():
"""Test that QKV unpacking works correctly."""
from esm.convert import convert_esm_hf_to_te, convert_esm_te_to_hf
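The new test relies on a `te_model_checkpoint` fixture that is not part of this diff. A hypothetical sketch of what such a fixture could look like (the real one lives in conftest.py and may differ; `convert_esm_hf_to_te` is the converter imported elsewhere in this file):

```python
import pytest
from transformers import AutoModelForMaskedLM

from esm.convert import convert_esm_hf_to_te


@pytest.fixture
def te_model_checkpoint(tmp_path):
    """Hypothetical: convert a small HF ESM-2 checkpoint to TE and save it to disk."""
    hf_model = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
    te_model = convert_esm_hf_to_te(hf_model)
    te_model.save_pretrained(tmp_path / "te_checkpoint")
    return str(tmp_path / "te_checkpoint")
```

With the `state_dict` override from modeling_esm_te.py in place, the saved checkpoint carries no `_extra_state` keys, so `from_pretrained()` can load it cleanly.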
21 changes: 14 additions & 7 deletions bionemo-recipes/models/esm2/tests/test_distributed_fp8.py
@@ -38,10 +38,12 @@ def requires_fp8(func):
"strategy", ["ddp", "fsdp2", pytest.param("mfsdp", marks=pytest.mark.xfail(reason="BIONEMO-2999"))]
)
@requires_fp8
-def test_single_process_attaches_correct_fp8_recipe(strategy):
+def test_single_process_attaches_correct_fp8_recipe(strategy, unused_tcp_port):
    cmd = [
        "torchrun",
        "--nproc_per_node=1",
+        "--rdzv-backend=c10d",
+        f"--rdzv-endpoint=localhost:{unused_tcp_port}",
        os.path.relpath(__file__),
        "--strategy",
        strategy,
@@ -66,10 +68,12 @@ def test_single_process_attaches_correct_fp8_recipe(strategy):
)
@requires_fp8
@requires_multi_gpu
-def test_multi_process_fp8_recipes_are_synced(strategy):
+def test_multi_process_fp8_recipes_are_synced(strategy, unused_tcp_port):
    cmd = [
        "torchrun",
        "--nproc_per_node=2",
+        "--rdzv-backend=c10d",
+        f"--rdzv-endpoint=localhost:{unused_tcp_port}",
        os.path.relpath(__file__),
        "--strategy",
        strategy,
@@ -207,11 +211,14 @@ def is_main_process(self) -> bool:

outputs.loss.backward()

-    fp8_extra_states = {key: val for key, val in model.state_dict().items() if key.endswith("_extra_state")}
-
-    # For some reason, this one doesn't get an fp8 recipe? It's the only te.LayerNorm.
-    key = filter(lambda x: x.endswith("encoder.emb_layer_norm_after._extra_state"), fp8_extra_states.keys())
-    fp8_extra_states.pop(next(key))
+    # Access FP8 extra states directly from modules instead of state_dict()
+    # since state_dict() now filters them out for HuggingFace compatibility
+    fp8_extra_states = {}
+    for name, module in model.named_modules():
+        if hasattr(module, "_extra_state") and callable(module._extra_state):
+            extra_state = module._extra_state()
+            if extra_state is not None and len(extra_state) > 0:
+                fp8_extra_states[f"{name}._extra_state"] = extra_state

# lm_head.dense and lm_head.decoder are BF16, not FP8, so exclude them from FP8 checks
fp8_extra_states = {key: val for key, val in fp8_extra_states.items() if "lm_head." not in key}
@@ -41,10 +41,12 @@
],
)
@pytest.mark.parametrize("backend", ["te", "eager"])
-def test_ddp_vs_fsdp_single_gpu(strategy, backend):
+def test_ddp_vs_fsdp_single_gpu(strategy, backend, unused_tcp_port):
    cmd = [
        "torchrun",
        "--nproc_per_node=1",
+        "--rdzv-backend=c10d",
+        f"--rdzv-endpoint=localhost:{unused_tcp_port}",
        os.path.relpath(__file__),
        "--strategy",
        strategy,
@@ -69,10 +71,12 @@ def test_ddp_vs_fsdp_single_gpu(strategy, backend):
@requires_multi_gpu
@pytest.mark.parametrize("strategy", ["fsdp2", pytest.param("mfsdp", marks=pytest.mark.xfail(reason="BIONEMO-2726"))])
@pytest.mark.parametrize("backend", ["te", "eager"])
-def test_ddp_vs_fsdp_multi_gpu(strategy, backend):
+def test_ddp_vs_fsdp_multi_gpu(strategy, backend, unused_tcp_port):
    cmd = [
        "torchrun",
        "--nproc_per_node=2",
+        "--rdzv-backend=c10d",
+        f"--rdzv-endpoint=localhost:{unused_tcp_port}",
        os.path.relpath(__file__),
        "--strategy",
        strategy,
@@ -160,20 +164,28 @@ def is_main_process(self) -> bool:
return self.rank == 0

def run_forward_backward(use_te: bool, strategy: Strategy, input_data: dict, dist_config: DistributedConfig):
+    # Set seed for reproducible model initialization across strategies
+    torch.manual_seed(42)
+    torch.cuda.manual_seed_all(42)
+
    device_mesh = init_device_mesh(
        "cuda",
-        mesh_shape=(dist_config.world_size,),
-        mesh_dim_names=("dp",),
+        mesh_shape=(dist_config.world_size, 1),
+        mesh_dim_names=("dp", "tp"),  # mfsdp requires us to give a tp mesh dimension.
    )

    device = f"cuda:{dist_config.local_rank}"

    if use_te:
-        model = AutoModelForMaskedLM.from_pretrained(
-            "nvidia/esm2_t6_8M_UR50D",
+        # Import local model classes to avoid using outdated code from HF Hub
+        from esm.modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM
+
+        config = NVEsmConfig.from_pretrained(
+            "facebook/esm2_t6_8M_UR50D",
            dtype=torch.bfloat16,
            trust_remote_code=True,
            revision="c731040f",
        )
+        model = NVEsmForMaskedLM(config)
        transformer_layers = model.esm.encoder.layers
    else:
        model = AutoModelForMaskedLM.from_pretrained(
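Two details above are worth a note: seeding before model construction gives the compared strategies identical initial weights, and megatron-fsdp expects a named tensor-parallel mesh dimension even when tensor parallelism is unused, hence the degenerate size-1 "tp" axis. A standalone sketch of the mesh setup (assumes the process group is already initialized, e.g. by torchrun):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh

# A 2-D mesh with a trivial tp axis: data parallelism is unchanged, while the
# named "tp" dimension satisfies megatron-fsdp's mesh requirements.
world_size = torch.distributed.get_world_size()
device_mesh = init_device_mesh(
    "cuda",
    mesh_shape=(world_size, 1),
    mesh_dim_names=("dp", "tp"),
)
dp_mesh = device_mesh["dp"]  # the slice a data-parallel wrapper consumes
```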
24 changes: 13 additions & 11 deletions bionemo-recipes/models/esm2/tests/test_meta_device_init.py
@@ -44,10 +44,6 @@
)


-def msg(x):
-    return f"Mismatch in module {name}: {x}"


def verify_model_parameters_initialized_correctly(
model: NVEsmForMaskedLM, atol=1e-3, rtol=1e-4, should_be_fp8: bool = False
):
@@ -57,6 +53,10 @@ def verify_model_parameters_initialized_correctly(
assert str(parameter.device).startswith("cuda"), f"Parameter {name} is not on the cuda device"

for name, module in model.named_modules():

+        def msg(x):
+            return f"Mismatch in module {name}: {x}"

if isinstance(module, torch.nn.Embedding):
torch.testing.assert_close(module.weight.mean().item(), 0.0, atol=atol, rtol=rtol, msg=msg)
torch.testing.assert_close(
@@ -118,8 +118,12 @@ def verify_model_parameters_initialized_correctly(
torch.testing.assert_close(module.inv_freq, expected_inv_freq, msg=msg)


-def verify_pretrained_model_sanity(model: NVEsmForTokenClassification, atol=1e-3, rtol=1e-4):
+def verify_pretrained_model_sanity(model: NVEsmForTokenClassification, atol=1e-2, rtol=1e-3):
for name, p in model.named_parameters():

+        def msg(x):
+            return f"Mismatch in parameter {name}: {x}"

assert p.numel() > 0, f"{name} is empty"
assert torch.isfinite(p).all(), f"{name} has NaN/Inf"

@@ -187,14 +191,12 @@ def test_meta_fp8_init(fp8_recipe):


def test_model_for_token_classification_init(te_model_checkpoint):
-    config = NVEsmConfig.from_pretrained(te_model_checkpoint, trust_remote_code=True)
-
    set_seed(42)
-    model = NVEsmForTokenClassification.from_pretrained(
-        te_model_checkpoint, config=config, dtype=torch.bfloat16, trust_remote_code=True
-    )
-    model.to("cuda")
-
+    config = NVEsmConfig.from_pretrained(te_model_checkpoint)
+    model = NVEsmForTokenClassification.from_pretrained(te_model_checkpoint, config=config, dtype=torch.bfloat16)
+    # model.classifier.reset_parameters()
+    model.to("cuda")
    verify_pretrained_model_sanity(model)


15 changes: 13 additions & 2 deletions bionemo-recipes/models/llama3/modeling_llama_te.py
@@ -14,7 +14,7 @@
# limitations under the License.

from collections import OrderedDict
-from typing import Unpack
+from typing import ClassVar, Unpack

import torch
import torch.nn as nn
@@ -88,6 +88,17 @@ def _init_weights(self, module):

super()._init_weights(module)

+    def state_dict(self, *args, **kwargs):
+        """Override state_dict to filter out TransformerEngine's _extra_state keys.
+
+        TransformerEngine layers add _extra_state attributes that are not compatible with
+        standard PyTorch/HuggingFace model loading. These are filtered out to ensure
+        checkpoints can be loaded with from_pretrained().
+        """
+        state_dict = super().state_dict(*args, **kwargs)
+        # Filter out _extra_state keys, which are TransformerEngine-specific and not loadable.
+        return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")}


class NVLlamaModel(NVLlamaPreTrainedModel):
"""Llama3 model implemented in Transformer Engine."""
@@ -260,7 +271,7 @@ def forward(
class NVLlamaForCausalLM(NVLlamaPreTrainedModel, transformers.GenerationMixin):
"""Llama3 model with causal language head."""

-    _tied_weights_keys = ("lm_head.weight",)
+    _tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.weight": "model.embed_tokens.weight"}

def __init__(self, config):
"""Initialize the NVLlamaForCausalLM model."""
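A quick way to sanity-check the override, as a hedged sketch (the helper is illustrative, not part of the PR):

```python
def assert_no_te_extra_state(model) -> None:
    """Illustrative helper: confirm the state_dict override stripped TE-only keys."""
    leftover = [k for k in model.state_dict() if k.endswith("_extra_state")]
    assert not leftover, f"TE _extra_state keys leaked into state_dict: {leftover}"
```

Any model carrying the override (here `NVLlamaForCausalLM`, or the ESM variant earlier in this PR) should pass this check before `save_pretrained()` is called.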
2 changes: 1 addition & 1 deletion bionemo-recipes/models/llama3/requirements.txt
@@ -2,4 +2,4 @@ lm-eval  # For testing
torch
torchao!=0.14.0
transformer_engine[pytorch]
-transformers<5.0
+transformers
11 changes: 11 additions & 0 deletions bionemo-recipes/models/llama3/tests/conftest.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import socket
import sys
from pathlib import Path

@@ -25,6 +26,16 @@
sys.path.append(Path(__file__).parent.as_posix())


+@pytest.fixture
+def unused_tcp_port():
+    """Find and return an unused TCP port for torchrun rendezvous."""
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind(("", 0))
+        s.listen(1)
+        port = s.getsockname()[1]
+    return port


@pytest.fixture(scope="session")
def recipe_path() -> Path:
"""Return the root directory of the recipe."""
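The llama3 call sites for this fixture are not shown in this diff; a hedged sketch of how a test would consume it, mirroring the esm2 tests above (the test name is illustrative):

```python
import os
import subprocess


def test_distributed_smoke(unused_tcp_port):
    # A per-test rendezvous endpoint keeps parallel pytest invocations from
    # colliding on torchrun's default c10d port.
    cmd = [
        "torchrun",
        "--nproc_per_node=1",
        "--rdzv-backend=c10d",
        f"--rdzv-endpoint=localhost:{unused_tcp_port}",
        os.path.relpath(__file__),
    ]
    subprocess.run(cmd, check=True)
```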