diff --git a/.devcontainer/recipes/requirements.txt b/.devcontainer/recipes/requirements.txt
index 585c6e3731..da0dd58c8a 100644
--- a/.devcontainer/recipes/requirements.txt
+++ b/.devcontainer/recipes/requirements.txt
@@ -12,7 +12,7 @@ torchdata
 torchmetrics
 tqdm
 transformer_engine
-transformers<5.0
+transformers
 typer
 wandb
 zstandard
diff --git a/bionemo-recipes/models/amplify/pyproject.toml b/bionemo-recipes/models/amplify/pyproject.toml
index 36a936acd6..aef2a73636 100644
--- a/bionemo-recipes/models/amplify/pyproject.toml
+++ b/bionemo-recipes/models/amplify/pyproject.toml
@@ -18,7 +18,7 @@ dependencies = [
     "pytest",
     "torch==2.6.0a0+ecf3bae40a.nv25.01",
     "transformer_engine[pytorch]",
-    "transformers<5.0",
+    "transformers<5.0",  # TODO(BIO-143): update AMPLIFY to support Transformers v5
     "xformers",
 ]
diff --git a/bionemo-recipes/models/esm2/pyproject.toml b/bionemo-recipes/models/esm2/pyproject.toml
index dd55bcfcb2..6fe871a015 100644
--- a/bionemo-recipes/models/esm2/pyproject.toml
+++ b/bionemo-recipes/models/esm2/pyproject.toml
@@ -19,7 +19,7 @@ dependencies = [
     "torch",
     "torchao!=0.14.0",
     "transformer_engine[pytorch]",
-    "transformers<5.0",
+    "transformers",
 ]
diff --git a/bionemo-recipes/models/esm2/src/esm/convert.py b/bionemo-recipes/models/esm2/src/esm/convert.py
index 22fc1dcde1..0406e1bbb2 100644
--- a/bionemo-recipes/models/esm2/src/esm/convert.py
+++ b/bionemo-recipes/models/esm2/src/esm/convert.py
@@ -131,7 +131,7 @@ def convert_esm_te_to_hf(model_te: nn.Module, **config_kwargs) -> nn.Module:
         ],
     )

-    output_model.tie_weights()
+    output_model.post_init()

     # Note: contact_head parameters are not preserved in TE models
     # They are lost during HF -> TE conversion and cannot be recovered
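# --- Reviewer sketch (not part of the patch) ---------------------------------
# Context for the tie_weights() -> post_init() swap above: post_init() performs
# the full HuggingFace finalization pass, including weight tying, which is also
# why the standalone self.init_weights() calls are dropped later in this diff.
# A minimal illustration assuming the classes from this diff; the tiny config
# values are illustrative, not a recommended configuration.
from esm.modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM

config = NVEsmConfig(num_hidden_layers=1, hidden_size=32, num_attention_heads=2)
model = NVEsmForMaskedLM(config)  # __init__ ends with self.post_init()

# If tying succeeded, the decoder and the embedding share one Parameter.
assert model.lm_head.decoder.weight is model.esm.embeddings.word_embeddings.weight
# ------------------------------------------------------------------------------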
diff --git a/bionemo-recipes/models/esm2/src/esm/modeling_esm_te.py b/bionemo-recipes/models/esm2/src/esm/modeling_esm_te.py
index 40557a14bf..00fdf23128 100644
--- a/bionemo-recipes/models/esm2/src/esm/modeling_esm_te.py
+++ b/bionemo-recipes/models/esm2/src/esm/modeling_esm_te.py
@@ -22,7 +22,7 @@
 Adapted from `modeling_esm.py` in huggingface/transformers.
 """

-from typing import Literal, Optional, Unpack
+from typing import ClassVar, Literal, Optional, Unpack

 # TODO: put import guard around transformer_engine here, with an informative error message around
 # installation and the nvidia docker container.
@@ -256,10 +256,34 @@ def init_empty_weights(self):
         # Meta-device init seems to break weight tying, so we re-tie the weights here.
         self.tie_weights()

-    @classmethod
-    def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
-        """Override the default get_init_context method to allow for fp8 model initialization."""
-        return []
+    def _init_weights(self, module):
+        """Initialize module weights.
+
+        We only use this method for standard PyTorch modules; TE modules handle their own weight initialization through
+        `init_method` parameters and the `reset_parameters` method.
+        """
+        if module.__module__.startswith("transformer_engine.pytorch"):
+            # Notably, we need to avoid calling the parent method for TE modules, since the default _init_weights will
+            # assume any class with `LayerNorm` in the name should have weights initialized to 1.0, breaking
+            # `LayerNormLinear` and `LayerNormMLP` modules that use `weight` for the linear layer and
+            # `layer_norm_weight` for the layer norm. Instead, we call `reset_parameters` if the module has it and the
+            # weights are not in fp8. We still need to figure out why this raises an error if we're using
+            # `quantized_model_init`.
+            if hasattr(module, "reset_parameters") and not getattr(module, "primary_weights_in_fp8", False):
+                module.reset_parameters()
+            return
+
+        super()._init_weights(module)
+
+    def state_dict(self, *args, **kwargs):
+        """Override state_dict to filter out TransformerEngine's _extra_state keys.
+
+        TransformerEngine layers add _extra_state attributes that are not compatible with HuggingFace v5 model loading.
+        These are filtered out to ensure checkpoints can be loaded with from_pretrained().
+        """
+        state_dict = super().state_dict(*args, **kwargs)
+        # Filter out _extra_state keys, which are TransformerEngine-specific and not loadable.
+        return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")}


 class NVEsmModel(NVEsmPreTrainedModel):
@@ -367,7 +391,7 @@ def forward(
 class NVEsmForMaskedLM(NVEsmPreTrainedModel):
     """NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling."""

-    _tied_weights_keys = ("lm_head.decoder.weight",)
+    _tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.decoder.weight": "esm.embeddings.word_embeddings.weight"}

     def __init__(self, config: NVEsmConfig):
         """Initialize a NVEsmForMaskedLM.
@@ -386,7 +410,6 @@ def __init__(self, config: NVEsmConfig):
         self.esm = NVEsmModel(config, add_pooling_layer=False)
         self.lm_head = NVEsmLMHead(config)

-        self.init_weights()
         self.post_init()

     def get_output_embeddings(self):
@@ -614,7 +637,6 @@ def __init__(self, config):
             init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range),
         )

-        self.init_weights()
         self.post_init()

     def forward(
diff --git a/bionemo-recipes/models/esm2/tests/conftest.py b/bionemo-recipes/models/esm2/tests/conftest.py
index c27c66b8df..57aeadda84 100644
--- a/bionemo-recipes/models/esm2/tests/conftest.py
+++ b/bionemo-recipes/models/esm2/tests/conftest.py
@@ -15,6 +15,7 @@

 import importlib
 import os
+import socket

 import pytest
 import transformer_engine.pytorch
@@ -32,6 +33,16 @@
 os.environ["TRITON_LIBCUDA_PATH"] = "/usr/local/cuda/lib64"


+@pytest.fixture
+def unused_tcp_port():
+    """Find and return an unused TCP port for torchrun rendezvous."""
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind(("", 0))
+        s.listen(1)
+        port = s.getsockname()[1]
+    return port
+
+
 @pytest.fixture(autouse=True)
 def use_te_debug(monkeypatch):
     monkeypatch.setenv("NVTE_DEBUG", "1")
diff --git a/bionemo-recipes/models/esm2/tests/test_convert.py b/bionemo-recipes/models/esm2/tests/test_convert.py
index 92df86b0f0..644f7fdcd6 100644
--- a/bionemo-recipes/models/esm2/tests/test_convert.py
+++ b/bionemo-recipes/models/esm2/tests/test_convert.py
@@ -38,6 +38,13 @@ def test_convert_te_to_hf_roundtrip():
         torch.testing.assert_close(original_state_dict[key], converted_state_dict[key], atol=1e-5, rtol=1e-5)


+def test_load_from_converted_checkpoint(te_model_checkpoint):
+    """Smoke test that a converted TE checkpoint loads with from_pretrained()."""
+    from esm.modeling_esm_te import NVEsmForMaskedLM
+
+    NVEsmForMaskedLM.from_pretrained(te_model_checkpoint)
+
+
 def test_qkv_unpacking():
     """Test that QKV unpacking works correctly."""
     from esm.convert import convert_esm_hf_to_te, convert_esm_te_to_hf
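# --- Reviewer sketch (not part of the patch) ---------------------------------
# Observable contract of the state_dict() override and the new
# test_load_from_converted_checkpoint above: no TE `_extra_state` buffers leak
# into checkpoints, so save_pretrained()/from_pretrained() round-trips cleanly.
# Class names come from this diff; the tiny config and tmp dir are illustrative.
import tempfile

from esm.modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM

config = NVEsmConfig(num_hidden_layers=1, hidden_size=32, num_attention_heads=2)
model = NVEsmForMaskedLM(config)

# The override filters every key ending in `_extra_state`.
assert not any(k.endswith("_extra_state") for k in model.state_dict())

with tempfile.TemporaryDirectory() as ckpt_dir:
    model.save_pretrained(ckpt_dir)
    reloaded = NVEsmForMaskedLM.from_pretrained(ckpt_dir)  # should not raise
# ------------------------------------------------------------------------------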
diff --git a/bionemo-recipes/models/esm2/tests/test_distributed_fp8.py b/bionemo-recipes/models/esm2/tests/test_distributed_fp8.py
index ff9832f5dd..c0817324ac 100644
--- a/bionemo-recipes/models/esm2/tests/test_distributed_fp8.py
+++ b/bionemo-recipes/models/esm2/tests/test_distributed_fp8.py
@@ -38,10 +38,12 @@ def requires_fp8(func):
     "strategy", ["ddp", "fsdp2", pytest.param("mfsdp", marks=pytest.mark.xfail(reason="BIONEMO-2999"))]
 )
 @requires_fp8
-def test_single_process_attaches_correct_fp8_recipe(strategy):
+def test_single_process_attaches_correct_fp8_recipe(strategy, unused_tcp_port):
     cmd = [
         "torchrun",
         "--nproc_per_node=1",
+        "--rdzv-backend=c10d",
+        f"--rdzv-endpoint=localhost:{unused_tcp_port}",
         os.path.relpath(__file__),
         "--strategy",
         strategy,
@@ -66,10 +68,12 @@
 )
 @requires_fp8
 @requires_multi_gpu
-def test_multi_process_fp8_recipes_are_synced(strategy):
+def test_multi_process_fp8_recipes_are_synced(strategy, unused_tcp_port):
     cmd = [
         "torchrun",
         "--nproc_per_node=2",
+        "--rdzv-backend=c10d",
+        f"--rdzv-endpoint=localhost:{unused_tcp_port}",
         os.path.relpath(__file__),
         "--strategy",
         strategy,
@@ -207,11 +211,14 @@ def is_main_process(self) -> bool:
     outputs.loss.backward()

-    fp8_extra_states = {key: val for key, val in model.state_dict().items() if key.endswith("_extra_state")}
-
-    # For some reason, this one doesn't get an fp8 recipe? It's the only te.LayerNorm.
-    key = filter(lambda x: x.endswith("encoder.emb_layer_norm_after._extra_state"), fp8_extra_states.keys())
-    fp8_extra_states.pop(next(key))
+    # Access FP8 extra states directly from modules instead of state_dict(),
+    # since state_dict() now filters them out for HuggingFace compatibility.
+    fp8_extra_states = {}
+    for name, module in model.named_modules():
+        if type(module).get_extra_state is not torch.nn.Module.get_extra_state:
+            extra_state = module.get_extra_state()
+            if extra_state is not None and len(extra_state) > 0:
+                fp8_extra_states[f"{name}._extra_state"] = extra_state

     # lm_head.dense and lm_head.decoder are BF16, not FP8, so exclude them from FP8 checks
     fp8_extra_states = {key: val for key, val in fp8_extra_states.items() if "lm_head." not in key}
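# --- Reviewer sketch (not part of the patch) ---------------------------------
# The extra-state protocol the collection loop above relies on: PyTorch only
# emits `<prefix>._extra_state` entries for modules whose class overrides
# nn.Module.get_extra_state(), and the class-identity check mirrors what
# nn.Module.state_dict() does internally. The toy module below is illustrative.
import torch


class WithExtra(torch.nn.Module):
    def get_extra_state(self):
        return {"recipe": "delayed-scaling"}  # illustrative payload

    def set_extra_state(self, state):
        pass


parent = torch.nn.Module()
parent.child = WithExtra()

assert "child._extra_state" in parent.state_dict()
assert type(parent.child).get_extra_state is not torch.nn.Module.get_extra_state
# ------------------------------------------------------------------------------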
diff --git a/bionemo-recipes/models/esm2/tests/test_distributed_strategies.py b/bionemo-recipes/models/esm2/tests/test_distributed_strategies.py
index f893ab2d4b..c208ec869a 100644
--- a/bionemo-recipes/models/esm2/tests/test_distributed_strategies.py
+++ b/bionemo-recipes/models/esm2/tests/test_distributed_strategies.py
@@ -41,10 +41,12 @@
     ],
 )
 @pytest.mark.parametrize("backend", ["te", "eager"])
-def test_ddp_vs_fsdp_single_gpu(strategy, backend):
+def test_ddp_vs_fsdp_single_gpu(strategy, backend, unused_tcp_port):
     cmd = [
         "torchrun",
         "--nproc_per_node=1",
+        "--rdzv-backend=c10d",
+        f"--rdzv-endpoint=localhost:{unused_tcp_port}",
         os.path.relpath(__file__),
         "--strategy",
         strategy,
@@ -69,10 +71,12 @@
 @requires_multi_gpu
 @pytest.mark.parametrize("strategy", ["fsdp2", pytest.param("mfsdp", marks=pytest.mark.xfail(reason="BIONEMO-2726"))])
 @pytest.mark.parametrize("backend", ["te", "eager"])
-def test_ddp_vs_fsdp_multi_gpu(strategy, backend):
+def test_ddp_vs_fsdp_multi_gpu(strategy, backend, unused_tcp_port):
     cmd = [
         "torchrun",
         "--nproc_per_node=2",
+        "--rdzv-backend=c10d",
+        f"--rdzv-endpoint=localhost:{unused_tcp_port}",
         os.path.relpath(__file__),
         "--strategy",
         strategy,
@@ -160,20 +164,28 @@ def is_main_process(self) -> bool:
         return self.rank == 0

 def run_forward_backward(use_te: bool, strategy: Strategy, input_data: dict, dist_config: DistributedConfig):
+    # Set seed for reproducible model initialization across strategies
+    torch.manual_seed(42)
+    torch.cuda.manual_seed_all(42)
+
     device_mesh = init_device_mesh(
         "cuda",
-        mesh_shape=(dist_config.world_size,),
-        mesh_dim_names=("dp",),
+        mesh_shape=(dist_config.world_size, 1),
+        mesh_dim_names=("dp", "tp"),  # mfsdp requires us to give a tp mesh dimension.
     )

     device = f"cuda:{dist_config.local_rank}"

     if use_te:
-        model = AutoModelForMaskedLM.from_pretrained(
-            "nvidia/esm2_t6_8M_UR50D",
+        # Import local model classes to avoid using outdated code from HF Hub
+        from esm.modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM
+
+        config = NVEsmConfig.from_pretrained(
+            "facebook/esm2_t6_8M_UR50D",
             dtype=torch.bfloat16,
-            trust_remote_code=True,
+            revision="c731040f",
         )
+        model = NVEsmForMaskedLM(config)
         transformer_layers = model.esm.encoder.layers
     else:
         model = AutoModelForMaskedLM.from_pretrained(
diff --git a/bionemo-recipes/models/esm2/tests/test_meta_device_init.py b/bionemo-recipes/models/esm2/tests/test_meta_device_init.py
index a8759d1ec4..82a10389f2 100644
--- a/bionemo-recipes/models/esm2/tests/test_meta_device_init.py
+++ b/bionemo-recipes/models/esm2/tests/test_meta_device_init.py
@@ -44,10 +44,6 @@
 )


-def msg(x):
-    return f"Mismatch in module {name}: {x}"
-
-
 def verify_model_parameters_initialized_correctly(
     model: NVEsmForMaskedLM, atol=1e-3, rtol=1e-4, should_be_fp8: bool = False
 ):
@@ -57,6 +53,10 @@
         assert str(parameter.device).startswith("cuda"), f"Parameter {name} is not on the cuda device"

     for name, module in model.named_modules():
+
+        def msg(x):
+            return f"Mismatch in module {name}: {x}"
+
         if isinstance(module, torch.nn.Embedding):
             torch.testing.assert_close(module.weight.mean().item(), 0.0, atol=atol, rtol=rtol, msg=msg)
             torch.testing.assert_close(
@@ -118,8 +118,12 @@
         torch.testing.assert_close(module.inv_freq, expected_inv_freq, msg=msg)


-def verify_pretrained_model_sanity(model: NVEsmForTokenClassification, atol=1e-3, rtol=1e-4):
+def verify_pretrained_model_sanity(model: NVEsmForTokenClassification, atol=1e-2, rtol=1e-3):
     for name, p in model.named_parameters():
+
+        def msg(x):
+            return f"Mismatch in parameter {name}: {x}"
+
         assert p.numel() > 0, f"{name} is empty"
         assert torch.isfinite(p).all(), f"{name} has NaN/Inf"
@@ -187,14 +191,11 @@
 def test_model_for_token_classification_init(te_model_checkpoint):
-    config = NVEsmConfig.from_pretrained(te_model_checkpoint, trust_remote_code=True)
-
     set_seed(42)
-    model = NVEsmForTokenClassification.from_pretrained(
-        te_model_checkpoint, config=config, dtype=torch.bfloat16, trust_remote_code=True
-    )
-    model.to("cuda")
+    config = NVEsmConfig.from_pretrained(te_model_checkpoint)
+    model = NVEsmForTokenClassification.from_pretrained(te_model_checkpoint, config=config, dtype=torch.bfloat16)

+    model.to("cuda")
     verify_pretrained_model_sanity(model)
diff --git a/bionemo-recipes/models/llama3/modeling_llama_te.py b/bionemo-recipes/models/llama3/modeling_llama_te.py
index 358f023e9a..8ed5351dd2 100644
--- a/bionemo-recipes/models/llama3/modeling_llama_te.py
+++ b/bionemo-recipes/models/llama3/modeling_llama_te.py
@@ -14,7 +14,7 @@
 # limitations under the License.

 from collections import OrderedDict
-from typing import Unpack
+from typing import ClassVar, Unpack

 import torch
 import torch.nn as nn
@@ -88,6 +88,17 @@ def _init_weights(self, module):

         super()._init_weights(module)

+    def state_dict(self, *args, **kwargs):
+        """Override state_dict to filter out TransformerEngine's _extra_state keys.
+
+        TransformerEngine layers add _extra_state attributes that are not compatible with
+        HuggingFace v5 model loading. These are filtered out to ensure
+        checkpoints can be loaded with from_pretrained().
+        """
+        state_dict = super().state_dict(*args, **kwargs)
+        # Filter out _extra_state keys, which are TransformerEngine-specific and not loadable.
+        return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")}
+

 class NVLlamaModel(NVLlamaPreTrainedModel):
     """Llama3 model implemented in Transformer Engine."""
@@ -260,7 +271,7 @@ def forward(
 class NVLlamaForCausalLM(NVLlamaPreTrainedModel, transformers.GenerationMixin):
     """Llama3 model with causal language head."""

-    _tied_weights_keys = ("lm_head.weight",)
+    _tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.weight": "model.embed_tokens.weight"}

     def __init__(self, config):
         """Initialize the NVLlamaForCausalLM model."""
diff --git a/bionemo-recipes/models/llama3/requirements.txt b/bionemo-recipes/models/llama3/requirements.txt
index ee5e991fb5..ec6a547cb8 100644
--- a/bionemo-recipes/models/llama3/requirements.txt
+++ b/bionemo-recipes/models/llama3/requirements.txt
@@ -2,4 +2,4 @@ lm-eval  # For testing
 torch
 torchao!=0.14.0
 transformer_engine[pytorch]
-transformers<5.0
+transformers
diff --git a/bionemo-recipes/models/llama3/tests/conftest.py b/bionemo-recipes/models/llama3/tests/conftest.py
index 98c0f2307e..5cf9deeff9 100644
--- a/bionemo-recipes/models/llama3/tests/conftest.py
+++ b/bionemo-recipes/models/llama3/tests/conftest.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import socket
 import sys
 from pathlib import Path

@@ -25,6 +26,16 @@
 sys.path.append(Path(__file__).parent.as_posix())


+@pytest.fixture
+def unused_tcp_port():
+    """Find and return an unused TCP port for torchrun rendezvous."""
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind(("", 0))
+        s.listen(1)
+        port = s.getsockname()[1]
+    return port
+
+
 @pytest.fixture(scope="session")
 def recipe_path() -> Path:
     """Return the root directory of the recipe."""
diff --git a/bionemo-recipes/models/llama3/tests/test_cp_bshd.py b/bionemo-recipes/models/llama3/tests/test_cp_bshd.py
index dcdb41b45c..c91f2c1451 100644
--- a/bionemo-recipes/models/llama3/tests/test_cp_bshd.py
+++ b/bionemo-recipes/models/llama3/tests/test_cp_bshd.py
@@ -15,10 +15,15 @@

 import os
 import subprocess
+import sys
 import tempfile
 from dataclasses import dataclass, field
 from pathlib import Path
+
+# Add parent directory to sys.path so we can import from the model package
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
 import pytest
 import torch
 from torch.distributed.device_mesh import init_device_mesh
@@ -140,11 +145,13 @@ def is_main_process(self) -> bool:


 @skip_in_ci
-def test_context_parallel_equivalence_1process(recipe_path: Path):
+def test_context_parallel_equivalence_1process(recipe_path: Path, unused_tcp_port):
     """Test that context parallelism works with 1 process, verifying results match non-distributed run."""
     cmd = [
         "torchrun",
         "--nproc_per_node=1",
+        "--rdzv-backend=c10d",
+        f"--rdzv-endpoint=localhost:{unused_tcp_port}",
         os.path.relpath(__file__),
     ]
     result = subprocess.run(
@@ -164,7 +171,7 @@

 @skip_in_ci
 @requires_multi_gpu
-def test_context_parallel_equivalence_2process(recipe_path: Path):
+def test_context_parallel_equivalence_2process(recipe_path: Path, unused_tcp_port):
     """Test context parallel equivalence between 2 processes.

     In one instance, we run the model in non-distributed mode and in the other
@@ -176,6 +183,8 @@
     cmd = [
         "torchrun",
         "--nproc_per_node=2",
+        "--rdzv-backend=c10d",
+        f"--rdzv-endpoint=localhost:{unused_tcp_port}",
         os.path.relpath(__file__),
     ]
     result = subprocess.run(
diff --git a/bionemo-recipes/models/llama3/tests/test_cp_thd.py b/bionemo-recipes/models/llama3/tests/test_cp_thd.py
index b30067caae..bc472fc2aa 100644
--- a/bionemo-recipes/models/llama3/tests/test_cp_thd.py
+++ b/bionemo-recipes/models/llama3/tests/test_cp_thd.py
@@ -15,10 +15,15 @@

 import os
 import subprocess
+import sys
 import tempfile
 from dataclasses import dataclass, field
 from pathlib import Path
+
+# Add parent directory to sys.path so we can import from the model package
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
 import pytest
 import torch
 from torch.distributed.device_mesh import init_device_mesh
@@ -170,7 +175,7 @@
 @skip_in_ci
 @requires_multi_gpu
 @requires_datacenter_hardware
-def test_context_parallel_equivalence_2process(recipe_path: Path):
+def test_context_parallel_equivalence_2process(recipe_path: Path, unused_tcp_port):
     """Test context parallel equivalence between 2 processes.

     In one instance, we run the model in non-distributed mode and in the other
@@ -182,6 +187,8 @@
     cmd = [
         "torchrun",
         "--nproc_per_node=2",
+        "--rdzv-backend=c10d",
+        f"--rdzv-endpoint=localhost:{unused_tcp_port}",
         os.path.relpath(__file__),
     ]
     result = subprocess.run(
diff --git a/bionemo-recipes/models/llama3/tests/test_meta_device_init.py b/bionemo-recipes/models/llama3/tests/test_meta_device_init.py
index c9c55078d5..0d880cb093 100644
--- a/bionemo-recipes/models/llama3/tests/test_meta_device_init.py
+++ b/bionemo-recipes/models/llama3/tests/test_meta_device_init.py
@@ -199,10 +199,12 @@ def test_meta_fp8_init(fp8_recipe):


 @pytest.mark.parametrize("num_gpus", [1, pytest.param(2, marks=requires_multi_gpu)])
-def test_meta_device_init_after_fully_shard(num_gpus: int):
+def test_meta_device_init_after_fully_shard(num_gpus: int, unused_tcp_port):
     cmd = [
         "torchrun",
         f"--nproc_per_node={num_gpus}",
+        "--rdzv-backend=c10d",
+        f"--rdzv-endpoint=localhost:{unused_tcp_port}",
         os.path.relpath(__file__),
     ]
diff --git a/bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py b/bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py
index 40557a14bf..00fdf23128 100644
--- a/bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py
+++ b/bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py
@@ -22,7 +22,7 @@
 Adapted from `modeling_esm.py` in huggingface/transformers.
 """

-from typing import Literal, Optional, Unpack
+from typing import ClassVar, Literal, Optional, Unpack

 # TODO: put import guard around transformer_engine here, with an informative error message around
 # installation and the nvidia docker container.
@@ -256,10 +256,34 @@ def init_empty_weights(self):
         # Meta-device init seems to break weight tying, so we re-tie the weights here.
         self.tie_weights()

-    @classmethod
-    def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
-        """Override the default get_init_context method to allow for fp8 model initialization."""
-        return []
+    def _init_weights(self, module):
+        """Initialize module weights.
+
+        We only use this method for standard PyTorch modules; TE modules handle their own weight initialization through
+        `init_method` parameters and the `reset_parameters` method.
+        """
+        if module.__module__.startswith("transformer_engine.pytorch"):
+            # Notably, we need to avoid calling the parent method for TE modules, since the default _init_weights will
+            # assume any class with `LayerNorm` in the name should have weights initialized to 1.0, breaking
+            # `LayerNormLinear` and `LayerNormMLP` modules that use `weight` for the linear layer and
+            # `layer_norm_weight` for the layer norm. Instead, we call `reset_parameters` if the module has it and the
+            # weights are not in fp8. We still need to figure out why this raises an error if we're using
+            # `quantized_model_init`.
+            if hasattr(module, "reset_parameters") and not getattr(module, "primary_weights_in_fp8", False):
+                module.reset_parameters()
+            return
+
+        super()._init_weights(module)
+
+    def state_dict(self, *args, **kwargs):
+        """Override state_dict to filter out TransformerEngine's _extra_state keys.
+
+        TransformerEngine layers add _extra_state attributes that are not compatible with HuggingFace v5 model loading.
+        These are filtered out to ensure checkpoints can be loaded with from_pretrained().
+        """
+        state_dict = super().state_dict(*args, **kwargs)
+        # Filter out _extra_state keys, which are TransformerEngine-specific and not loadable.
+        return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")}


 class NVEsmModel(NVEsmPreTrainedModel):
@@ -367,7 +391,7 @@ def forward(
 class NVEsmForMaskedLM(NVEsmPreTrainedModel):
     """NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling."""

-    _tied_weights_keys = ("lm_head.decoder.weight",)
+    _tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.decoder.weight": "esm.embeddings.word_embeddings.weight"}

     def __init__(self, config: NVEsmConfig):
         """Initialize a NVEsmForMaskedLM.
@@ -386,7 +410,6 @@
         self.esm = NVEsmModel(config, add_pooling_layer=False)
         self.lm_head = NVEsmLMHead(config)

-        self.init_weights()
         self.post_init()

     def get_output_embeddings(self):
@@ -614,7 +637,6 @@ def __init__(self, config):
             init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range),
         )

-        self.init_weights()
         self.post_init()

     def forward(
diff --git a/bionemo-recipes/recipes/esm2_accelerate_te/requirements.txt b/bionemo-recipes/recipes/esm2_accelerate_te/requirements.txt
index 3e9794fb52..b23f756164 100644
--- a/bionemo-recipes/recipes/esm2_accelerate_te/requirements.txt
+++ b/bionemo-recipes/recipes/esm2_accelerate_te/requirements.txt
@@ -5,5 +5,5 @@ hydra-core
 torchao!=0.14.0
 torchmetrics
 transformer_engine[pytorch]
-transformers<5.0
+transformers
 wandb
diff --git a/bionemo-recipes/recipes/esm2_native_te/checkpoint.py b/bionemo-recipes/recipes/esm2_native_te/checkpoint.py
index 9709e3e2d4..a168eeeb89 100644
--- a/bionemo-recipes/recipes/esm2_native_te/checkpoint.py
+++ b/bionemo-recipes/recipes/esm2_native_te/checkpoint.py
@@ -127,7 +127,8 @@ def load_checkpoint_ddp(
         weights_only=True,
     )

-    model.load_state_dict(checkpoint["model"])
+    # strict=False: TE modules expect _extra_state entries that our state_dict() override filters out.
+    model.load_state_dict(checkpoint["model"], strict=False)
     optimizer.load_state_dict(checkpoint["optimizer"])
     scheduler.load_state_dict(checkpoint["scheduler"])
     dataloader = load_dataloader(dataloader, checkpoint_path, dist_config)
diff --git a/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py b/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py
index 40557a14bf..00fdf23128 100644
--- a/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py
+++ b/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py
@@ -22,7 +22,7 @@
 Adapted from `modeling_esm.py` in huggingface/transformers.
 """

-from typing import Literal, Optional, Unpack
+from typing import ClassVar, Literal, Optional, Unpack

 # TODO: put import guard around transformer_engine here, with an informative error message around
 # installation and the nvidia docker container.
@@ -256,10 +256,34 @@ def init_empty_weights(self):
         # Meta-device init seems to break weight tying, so we re-tie the weights here.
         self.tie_weights()

-    @classmethod
-    def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
-        """Override the default get_init_context method to allow for fp8 model initialization."""
-        return []
+    def _init_weights(self, module):
+        """Initialize module weights.
+
+        We only use this method for standard PyTorch modules; TE modules handle their own weight initialization through
+        `init_method` parameters and the `reset_parameters` method.
+        """
+        if module.__module__.startswith("transformer_engine.pytorch"):
+            # Notably, we need to avoid calling the parent method for TE modules, since the default _init_weights will
+            # assume any class with `LayerNorm` in the name should have weights initialized to 1.0, breaking
+            # `LayerNormLinear` and `LayerNormMLP` modules that use `weight` for the linear layer and
+            # `layer_norm_weight` for the layer norm. Instead, we call `reset_parameters` if the module has it and the
+            # weights are not in fp8. We still need to figure out why this raises an error if we're using
+            # `quantized_model_init`.
+            if hasattr(module, "reset_parameters") and not getattr(module, "primary_weights_in_fp8", False):
+                module.reset_parameters()
+            return
+
+        super()._init_weights(module)
+
+    def state_dict(self, *args, **kwargs):
+        """Override state_dict to filter out TransformerEngine's _extra_state keys.
+
+        TransformerEngine layers add _extra_state attributes that are not compatible with HuggingFace v5 model loading.
+        These are filtered out to ensure checkpoints can be loaded with from_pretrained().
+        """
+        state_dict = super().state_dict(*args, **kwargs)
+        # Filter out _extra_state keys, which are TransformerEngine-specific and not loadable.
+        return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")}


 class NVEsmModel(NVEsmPreTrainedModel):
@@ -367,7 +391,7 @@ def forward(
 class NVEsmForMaskedLM(NVEsmPreTrainedModel):
     """NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling."""

-    _tied_weights_keys = ("lm_head.decoder.weight",)
+    _tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.decoder.weight": "esm.embeddings.word_embeddings.weight"}

     def __init__(self, config: NVEsmConfig):
         """Initialize a NVEsmForMaskedLM.
@@ -386,7 +410,6 @@
         self.esm = NVEsmModel(config, add_pooling_layer=False)
         self.lm_head = NVEsmLMHead(config)

-        self.init_weights()
         self.post_init()

     def get_output_embeddings(self):
@@ -614,7 +637,6 @@ def __init__(self, config):
             init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range),
         )

-        self.init_weights()
         self.post_init()

     def forward(
diff --git a/bionemo-recipes/recipes/esm2_native_te/requirements.txt b/bionemo-recipes/recipes/esm2_native_te/requirements.txt
index 438cd3dfa8..b18607fd7a 100644
--- a/bionemo-recipes/recipes/esm2_native_te/requirements.txt
+++ b/bionemo-recipes/recipes/esm2_native_te/requirements.txt
@@ -8,6 +8,6 @@ torchdata
 torchmetrics
 tqdm
 transformer_engine[pytorch]
-transformers<5.0
+transformers
 wandb
 nvdlfw_inspect @ git+https://github.com/NVIDIA/nvidia-dlfw-inspect
diff --git a/bionemo-recipes/recipes/esm2_native_te/tests/test_stop_and_go.py b/bionemo-recipes/recipes/esm2_native_te/tests/test_stop_and_go.py
index 7102fd4446..a8b8afc6af 100644
--- a/bionemo-recipes/recipes/esm2_native_te/tests/test_stop_and_go.py
+++ b/bionemo-recipes/recipes/esm2_native_te/tests/test_stop_and_go.py
@@ -69,7 +69,7 @@ def test_stop_and_go_checkpointing_and_dataloader_restoration_single_gpu(tmp_pat
     )

     # Setup the model
-    config = AutoConfig.from_pretrained("nvidia/esm2_t6_8M_UR50D", trust_remote_code=True, dtype=torch.bfloat16)
+    config = AutoConfig.from_pretrained("example_8m_checkpoint", trust_remote_code=True, dtype=torch.bfloat16)
     model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True)

     # The huggingface model has a contact head that we don't use in masked language pre-training, so we delete it to
@@ -153,7 +153,7 @@
         torch.save(batch, f"{step10_path_reference}_batch.pt")
         torch.save(grads, f"{step10_path_reference}_grads.pt")

     # Create fresh model, optimizer, scheduler for the resume test
-    config = AutoConfig.from_pretrained("nvidia/esm2_t6_8M_UR50D", trust_remote_code=True, dtype=torch.bfloat16)
+    config = AutoConfig.from_pretrained("example_8m_checkpoint", trust_remote_code=True, dtype=torch.bfloat16)
     resumed_model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True)

     try:
diff --git a/bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py b/bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py
index 40557a14bf..00fdf23128 100644
--- a/bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py
+++ b/bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py
@@ -22,7 +22,7 @@
 Adapted from `modeling_esm.py` in huggingface/transformers.
 """

-from typing import Literal, Optional, Unpack
+from typing import ClassVar, Literal, Optional, Unpack

 # TODO: put import guard around transformer_engine here, with an informative error message around
 # installation and the nvidia docker container.
@@ -256,10 +256,34 @@ def init_empty_weights(self):
         # Meta-device init seems to break weight tying, so we re-tie the weights here.
         self.tie_weights()

-    @classmethod
-    def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
-        """Override the default get_init_context method to allow for fp8 model initialization."""
-        return []
+    def _init_weights(self, module):
+        """Initialize module weights.
+
+        We only use this method for standard PyTorch modules; TE modules handle their own weight initialization through
+        `init_method` parameters and the `reset_parameters` method.
+        """
+        if module.__module__.startswith("transformer_engine.pytorch"):
+            # Notably, we need to avoid calling the parent method for TE modules, since the default _init_weights will
+            # assume any class with `LayerNorm` in the name should have weights initialized to 1.0, breaking
+            # `LayerNormLinear` and `LayerNormMLP` modules that use `weight` for the linear layer and
+            # `layer_norm_weight` for the layer norm. Instead, we call `reset_parameters` if the module has it and the
+            # weights are not in fp8. We still need to figure out why this raises an error if we're using
+            # `quantized_model_init`.
+            if hasattr(module, "reset_parameters") and not getattr(module, "primary_weights_in_fp8", False):
+                module.reset_parameters()
+            return
+
+        super()._init_weights(module)
+
+    def state_dict(self, *args, **kwargs):
+        """Override state_dict to filter out TransformerEngine's _extra_state keys.
+
+        TransformerEngine layers add _extra_state attributes that are not compatible with HuggingFace v5 model loading.
+        These are filtered out to ensure checkpoints can be loaded with from_pretrained().
+        """
+        state_dict = super().state_dict(*args, **kwargs)
+        # Filter out _extra_state keys, which are TransformerEngine-specific and not loadable.
+        return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")}


 class NVEsmModel(NVEsmPreTrainedModel):
@@ -367,7 +391,7 @@ def forward(
 class NVEsmForMaskedLM(NVEsmPreTrainedModel):
     """NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling."""

-    _tied_weights_keys = ("lm_head.decoder.weight",)
+    _tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.decoder.weight": "esm.embeddings.word_embeddings.weight"}

     def __init__(self, config: NVEsmConfig):
         """Initialize a NVEsmForMaskedLM.
@@ -386,7 +410,6 @@
         self.esm = NVEsmModel(config, add_pooling_layer=False)
         self.lm_head = NVEsmLMHead(config)

-        self.init_weights()
         self.post_init()

     def get_output_embeddings(self):
@@ -614,7 +637,6 @@ def __init__(self, config):
             init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range),
         )

-        self.init_weights()
         self.post_init()

     def forward(
diff --git a/bionemo-recipes/recipes/esm2_peft_te/requirements.txt b/bionemo-recipes/recipes/esm2_peft_te/requirements.txt
index f10a94af3f..8810ca2003 100644
--- a/bionemo-recipes/recipes/esm2_peft_te/requirements.txt
+++ b/bionemo-recipes/recipes/esm2_peft_te/requirements.txt
@@ -4,5 +4,5 @@ torch
 torchao!=0.14.0
 tqdm
 transformer_engine[pytorch]
-transformers<5.0
+transformers
 wandb
diff --git a/bionemo-recipes/recipes/geneformer_native_te_mfsdp_fp8/requirements.txt b/bionemo-recipes/recipes/geneformer_native_te_mfsdp_fp8/requirements.txt
index 9deac3c255..f303729ec0 100644
--- a/bionemo-recipes/recipes/geneformer_native_te_mfsdp_fp8/requirements.txt
+++ b/bionemo-recipes/recipes/geneformer_native_te_mfsdp_fp8/requirements.txt
@@ -6,5 +6,5 @@ torch
 torchao!=0.14.0
 tqdm
 transformer_engine
-transformers<5.0
+transformers
 wandb
diff --git a/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py b/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py
index 358f023e9a..8ed5351dd2 100644
--- a/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py
+++ b/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py
@@ -14,7 +14,7 @@
 # limitations under the License.

 from collections import OrderedDict
-from typing import Unpack
+from typing import ClassVar, Unpack

 import torch
 import torch.nn as nn
@@ -88,6 +88,17 @@ def _init_weights(self, module):

         super()._init_weights(module)

+    def state_dict(self, *args, **kwargs):
+        """Override state_dict to filter out TransformerEngine's _extra_state keys.
+
+        TransformerEngine layers add _extra_state attributes that are not compatible with
+        HuggingFace v5 model loading. These are filtered out to ensure
+        checkpoints can be loaded with from_pretrained().
+        """
+        state_dict = super().state_dict(*args, **kwargs)
+        # Filter out _extra_state keys, which are TransformerEngine-specific and not loadable.
+        return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")}
+

 class NVLlamaModel(NVLlamaPreTrainedModel):
     """Llama3 model implemented in Transformer Engine."""
@@ -260,7 +271,7 @@ def forward(
 class NVLlamaForCausalLM(NVLlamaPreTrainedModel, transformers.GenerationMixin):
     """Llama3 model with causal language head."""

-    _tied_weights_keys = ("lm_head.weight",)
+    _tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.weight": "model.embed_tokens.weight"}

     def __init__(self, config):
         """Initialize the NVLlamaForCausalLM model."""
diff --git a/bionemo-recipes/recipes/llama3_native_te/requirements.txt b/bionemo-recipes/recipes/llama3_native_te/requirements.txt
index 8d354b6b71..073d9b39e3 100644
--- a/bionemo-recipes/recipes/llama3_native_te/requirements.txt
+++ b/bionemo-recipes/recipes/llama3_native_te/requirements.txt
@@ -6,7 +6,7 @@ torchdata
 torchmetrics
 tqdm
 transformer_engine[pytorch]
-transformers<5.0
+transformers
 wandb
 zstandard
 nvdlfw_inspect @ git+https://github.com/NVIDIA/nvidia-dlfw-inspect
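# --- Reviewer sketch (not part of the patch) ---------------------------------
# Why every torchrun invocation in these tests now pins its own rendezvous
# endpoint: without one, concurrent torchrun processes all contend for the
# default port (29500) and fail with "address already in use" when tests run
# in parallel. The conftest fixtures reserve a free port the same way this
# sketch does; `train.py` is a placeholder script name.
import socket

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    s.bind(("", 0))
    port = s.getsockname()[1]

cmd = [
    "torchrun",
    "--nproc_per_node=1",
    "--rdzv-backend=c10d",
    f"--rdzv-endpoint=localhost:{port}",
    "train.py",  # placeholder
]
# ------------------------------------------------------------------------------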