From 7a7ed1e2e44343c53675638876c4925d85e96650 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Wed, 18 Feb 2026 10:20:27 -0800 Subject: [PATCH 01/16] Add MBridge distillation support for AnyModel checkpoints - Add distill_anymodel.py: Knowledge distillation script for AnyModel checkpoints - Add import_anymodel_to_mbridge.py: Import script to convert HF AnyModel to MBridge format - Update base.py: Simplify HeterogeneousBridgeMixin for AnyModel support --- .../mbridge_distillation/distill_anymodel.py | 295 ++++++++++++++++++ .../import_anymodel_to_mbridge.py | 92 ++++++ .../torch/puzzletron/export/mbridge/base.py | 116 +++++++ 3 files changed, 503 insertions(+) create mode 100644 examples/puzzletron/mbridge_distillation/distill_anymodel.py create mode 100644 examples/puzzletron/mbridge_distillation/import_anymodel_to_mbridge.py create mode 100644 modelopt/torch/puzzletron/export/mbridge/base.py diff --git a/examples/puzzletron/mbridge_distillation/distill_anymodel.py b/examples/puzzletron/mbridge_distillation/distill_anymodel.py new file mode 100644 index 000000000..5d6e3eb9c --- /dev/null +++ b/examples/puzzletron/mbridge_distillation/distill_anymodel.py @@ -0,0 +1,295 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Knowledge Distillation Script for AnyModel Checkpoints. + +This script performs knowledge distillation between student and teacher models +that have been converted to Megatron-Bridge format using import_anymodel_to_mbridge.py. + +The distillation uses KL-Divergence loss between student and teacher logits +with temperature scaling (standard knowledge distillation from Hinton et al., 2015). + +Usage: + cd /workspace/Model-Optimizer + + # TODO: remove this once Megatron-Bridge is installed in the environment + export PYTHONPATH="/workspace/Megatron-Bridge/src:/workspace/Model-Optimizer:${PYTHONPATH}" + + # Basic usage (uses model's max seq_length, which may be very large) + torchrun --nproc_per_node=1 examples/puzzletron/mbridge_distillation/distill_anymodel.py \ + --student-mbridge-ckpt /path/to/student/iter_0000000 \ + --teacher-mbridge-ckpt /path/to/teacher/iter_0000000 \ + --data-path /path/to/tokenized/dataset \ + --output-dir ./distilled_output + + # Recommended: Override sequence length and other training params for faster training + torchrun --nproc_per_node=8 examples/puzzletron/mbridge_distillation/distill_anymodel.py \ + --student-mbridge-ckpt /path/to/student/iter_0000000 \ + --teacher-mbridge-ckpt /path/to/teacher/iter_0000000 \ + --data-path /path/to/tokenized/dataset \ + --output-dir ./distilled_output \ + dataset.sequence_length=8192 \ + model.tensor_model_parallel_size=8 \ + model.teacher.tensor_model_parallel_size=8 \ + train.global_batch_size=4 \ + train.micro_batch_size=1 \ + train.train_iters=5000 \ + logger.log_interval=1 +""" + +import argparse +import logging +import os +import sys + +import torch +from megatron.bridge.models.distillation_provider import convert_to_distillation_provider +from megatron.bridge.training.checkpointing import get_checkpoint_run_config_filename +from megatron.bridge.training.config import ( + CheckpointConfig, + ConfigContainer, + DistributedDataParallelConfig, + DistributedInitConfig, + GPTDatasetConfig, + LoggerConfig, + OptimizerConfig, + RerunStateMachineConfig, + RNGConfig, + SchedulerConfig, + TrainingConfig, + ValidationConfig, +) +from megatron.bridge.training.distill import distill +from megatron.bridge.training.model_load_save import load_model_config +from megatron.bridge.training.post_training.distillation import ModelOptDistillConfig +from megatron.bridge.training.tokenizers.config import TokenizerConfig +from megatron.bridge.training.utils.omegaconf_utils import ( + apply_overrides, + create_omegaconf_dict_config, + parse_hydra_overrides, +) +from megatron.bridge.utils.common_utils import get_rank_safe +from omegaconf import OmegaConf + +# Import GenericHeterogeneousProvider so it can be instantiated when loading +# checkpoint configs that reference it (e.g., run_config.yaml with +# _target_: modelopt.torch.puzzletron.export.mbridge.base.GenericHeterogeneousProvider) +import modelopt.torch.puzzletron.export.mbridge # noqa: F401 + +logger: logging.Logger = logging.getLogger(__name__) + + +def create_distillation_config() -> ModelOptDistillConfig: + """Create KD config with output layer distillation only.""" + return ModelOptDistillConfig( + logit_layers=["output_layer", "output_layer"], + intermediate_layer_pairs=[], + skip_lm_loss=True, + kd_loss_scale=1.0, + logit_kl_temperature=1.0, + ) + + +def create_base_config( + student_model_provider, + data_path: str, + student_ckpt: str, + output_dir: str, + use_bf16: bool, + use_fp16: bool, +) -> ConfigContainer: + """Create base ConfigContainer with defaults.""" + return ConfigContainer( + model=student_model_provider, + train=TrainingConfig(global_batch_size=1, micro_batch_size=1, train_iters=100), + optimizer=OptimizerConfig( + optimizer="adam", + lr=1e-4, + min_lr=1e-5, + weight_decay=0.01, + bf16=use_bf16, + fp16=use_fp16, + ), + scheduler=SchedulerConfig( + lr_decay_style="linear", + lr_warmup_iters=0, + start_weight_decay=0.01, + end_weight_decay=0.01, + weight_decay_incr_style="constant", + ), + dataset=GPTDatasetConfig( + random_seed=1234, + blend=[[data_path], [1.0]], + split="9999,8,2", + seq_length=student_model_provider.seq_length, + reset_attention_mask=False, + reset_position_ids=False, + eod_mask_loss=False, + dataloader_type="single", + ), + checkpoint=CheckpointConfig(load=student_ckpt, save=output_dir), + logger=LoggerConfig(), + tokenizer=TokenizerConfig(tokenizer_type="HuggingFaceTokenizer", tokenizer_model=None), + validation=ValidationConfig(eval_interval=500, eval_iters=100), + ddp=DistributedDataParallelConfig(grad_reduce_in_fp32=True), + dist=DistributedInitConfig(), + rng=RNGConfig(), + rerun_state_machine=RerunStateMachineConfig(), + ) + + +def merge_checkpoint_configs( + cfg: ConfigContainer, checkpoint_path: str, use_bf16: bool, use_fp16: bool +) -> None: + """Merge tokenizer and optimizer configs from checkpoint if available.""" + try: + run_config_path = get_checkpoint_run_config_filename(checkpoint_path) + checkpoint_cfg = ConfigContainer.from_yaml(run_config_path) + if checkpoint_cfg.tokenizer is not None: + cfg.tokenizer = checkpoint_cfg.tokenizer + if checkpoint_cfg.optimizer is not None: + for key, value in checkpoint_cfg.optimizer.__dict__.items(): + if ( + value is not None + and hasattr(cfg.optimizer, key) + and key not in ("bf16", "fp16") + ): + setattr(cfg.optimizer, key, value) + # Ensure bf16/fp16 are set correctly based on model dtype + cfg.optimizer.bf16 = use_bf16 + cfg.optimizer.fp16 = use_fp16 + except Exception as e: + logger.warning(f"Could not load additional configs from checkpoint: {e}") + + +def parse_cli_args() -> tuple[argparse.Namespace, list[str]]: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Knowledge distillation with AnyModel checkpoints", + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument( + "--student-mbridge-ckpt", + type=str, + required=True, + help="Path to student checkpoint in MBridge format (must be iter_XXXXXXX directory).", + ) + parser.add_argument( + "--teacher-mbridge-ckpt", + type=str, + required=True, + help="Path to teacher checkpoint in MBridge format (must be iter_XXXXXXX directory).", + ) + parser.add_argument( + "--data-path", + type=str, + required=True, + help="Path to tokenized dataset (without .bin extension).", + ) + parser.add_argument( + "--output-dir", + type=str, + default="./distilled_output", + help="Output directory for distilled checkpoint.", + ) + parser.add_argument( + "--config-file", + type=str, + default=None, + help="Path to YAML OmegaConf override file (optional).", + ) + parser.add_argument("--debug", action="store_true", help="Enable debug logging") + + args, cli_overrides = parser.parse_known_args() + return args, cli_overrides + + +def main() -> None: + """Main distillation function.""" + args, cli_overrides = parse_cli_args() + + if args.debug: + logging.basicConfig(level=logging.DEBUG) + else: + logging.basicConfig(level=logging.INFO) + + logger.info("Megatron-Bridge Knowledge Distillation Script (AnyModel)") + logger.info("=" * 70) + + # Load model configs from checkpoints + logger.info("Loading model configs from MBridge checkpoints...") + student_model_provider, _ = load_model_config(args.student_mbridge_ckpt) + teacher_model_provider, _ = load_model_config(args.teacher_mbridge_ckpt) + + # Detect model dtype for optimizer config + model_params_dtype = getattr(student_model_provider, "params_dtype", torch.float32) + use_bf16 = model_params_dtype == torch.bfloat16 + use_fp16 = model_params_dtype == torch.float16 + + # Create base config with defaults + cfg = create_base_config( + student_model_provider, + args.data_path, + args.student_mbridge_ckpt, + args.output_dir, + use_bf16, + use_fp16, + ) + + # Merge tokenizer and optimizer from checkpoint if available + merge_checkpoint_configs(cfg, args.student_mbridge_ckpt, use_bf16, use_fp16) + + # Create distillation config and convert to DistillationProvider + kd_config = create_distillation_config() + cfg.model = convert_to_distillation_provider( + student_provider=student_model_provider, + teacher_provider=teacher_model_provider, + kd_config=kd_config, + ) + + # Apply YAML and CLI overrides + merged_omega_conf, excluded_fields = create_omegaconf_dict_config(cfg) + if args.config_file: + if not os.path.exists(args.config_file): + logger.error(f"Override YAML file not found: {args.config_file}") + sys.exit(1) + yaml_overrides_omega = OmegaConf.load(args.config_file) + merged_omega_conf = OmegaConf.merge(merged_omega_conf, yaml_overrides_omega) + if cli_overrides: + merged_omega_conf = parse_hydra_overrides(merged_omega_conf, cli_overrides) + apply_overrides(cfg, OmegaConf.to_container(merged_omega_conf, resolve=True), excluded_fields) + + # Sync model seq_length with dataset sequence_length if they differ + if ( + hasattr(cfg.dataset, "sequence_length") + and cfg.dataset.sequence_length != cfg.model.seq_length + ): + cfg.model.seq_length = cfg.dataset.sequence_length + if hasattr(cfg.model, "teacher") and cfg.model.teacher is not None: + cfg.model.teacher.seq_length = cfg.dataset.sequence_length + + if get_rank_safe() == 0: + logger.info("--- Final Configuration ---") + cfg.print_yaml() + + # Run distillation + logger.info("Starting distillation training...") + distill(cfg) + + logger.info(f"✓ Distillation complete! Checkpoints saved to: {args.output_dir}") + + +if __name__ == "__main__": + main() diff --git a/examples/puzzletron/mbridge_distillation/import_anymodel_to_mbridge.py b/examples/puzzletron/mbridge_distillation/import_anymodel_to_mbridge.py new file mode 100644 index 000000000..1127d228d --- /dev/null +++ b/examples/puzzletron/mbridge_distillation/import_anymodel_to_mbridge.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Import AnyModel checkpoint to Megatron-Bridge format. + +This script converts a HuggingFace AnyModel checkpoint to Megatron-Bridge format, +similar to NeMo's convert_hf_to_nemo. + +Usage: + cd /workspace/Model-Optimizer + + # TODO: remove this once Megatron-Bridge is installed in the environment + export PYTHONPATH="/workspace/Megatron-Bridge/src:/workspace/Model-Optimizer:${PYTHONPATH}" + + torchrun --nproc_per_node=1 examples/puzzletron/mbridge_distillation/import_anymodel_to_mbridge.py \ + --input-ckpt-path /path/to/anymodel/checkpoint \ + --output-ckpt-path /path/to/save/mbridge/checkpoint +""" + +import argparse +from pathlib import Path + +from megatron.bridge import AutoBridge + +# Import all heterogeneous bridges to register them +# This will override homogeneous bridges (e.g., LlamaBridge, Qwen3Bridge) with +# heterogeneous versions (PuzzletronLlamaAnyModelBridge, PuzzletronQwen3AnyModelBridge) +# that support block_configs for AnyModel checkpoints. +import modelopt.torch.puzzletron.export.mbridge # noqa: F401 + + +def parse_args() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Convert AnyModel checkpoint to Megatron-Bridge format", + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument( + "--input-ckpt-path", + type=str, + required=True, + help="Path to input AnyModel checkpoint (HuggingFace format)", + ) + parser.add_argument( + "--output-ckpt-path", + type=str, + required=True, + help="Path to save output Megatron-Bridge checkpoint", + ) + return parser.parse_args() + + +def main() -> None: + """Main function to import HF AnyModel and save as Megatron checkpoint.""" + args = parse_args() + + input_path = Path(args.input_ckpt_path) + output_path = Path(args.output_ckpt_path) + + print(f"Importing AnyModel checkpoint from: {input_path}") + print(f"Saving Megatron-Bridge checkpoint to: {output_path}") + print() + + # Create output directory if it doesn't exist + output_path.mkdir(parents=True, exist_ok=True) + + # Import and save as Megatron checkpoint + AutoBridge.import_ckpt( + hf_model_id=str(input_path), + megatron_path=str(output_path), + trust_remote_code=True, + ) + + print(f"\n✓ Successfully saved Megatron-Bridge checkpoint to: {output_path}") + + +if __name__ == "__main__": + main() diff --git a/modelopt/torch/puzzletron/export/mbridge/base.py b/modelopt/torch/puzzletron/export/mbridge/base.py new file mode 100644 index 000000000..eff9f0b73 --- /dev/null +++ b/modelopt/torch/puzzletron/export/mbridge/base.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Mixin class for bridges that support heterogeneous layer architectures. + +This module provides a mixin class for converting models with block_configs +(heterogeneous layer configurations) to Megatron-Core format via Megatron-Bridge. +""" + +import dataclasses +import json +from collections.abc import Callable +from dataclasses import dataclass + +from megatron.bridge.models.gpt_provider import GPTModelProvider +from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM +from megatron.bridge.models.transformer_config import HeterogeneousTransformerConfig +from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import ( + get_gpt_heterogeneous_layer_spec, +) +from megatron.core.transformer.spec_utils import ModuleSpec + + +def heterogeneous_layer_spec(config) -> ModuleSpec: + """Get GPT heterogeneous layer spec using Transformer Engine.""" + return get_gpt_heterogeneous_layer_spec(config, use_te=True) + + +@dataclass +class GenericHeterogeneousProvider(GPTModelProvider, HeterogeneousTransformerConfig): + """Generic provider for AnyModel checkpoints with block_configs.""" + + # Heterogeneous configuration fields + heterogeneous_layers_config_path: str | None = None + heterogeneous_layers_config_encoded_json: str = "" + transformer_layer_spec: ModuleSpec | Callable = heterogeneous_layer_spec + + def __getattr__(self, name: str): + """Handle missing attributes for OmegaConf compatibility. + + OmegaConf conversion tries to access per_block_parameters which may not + be initialized when loading from YAML. Return empty list as fallback. + """ + if name == "per_block_parameters": + return [] + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") + + +class HeterogeneousBridgeMixin: + """Mixin for bridges supporting heterogeneous layer architectures (block_configs). + + Must be used with multiple inheritance alongside a model-specific bridge. + Example: class PuzzletronLlamaAnyModelBridge(HeterogeneousBridgeMixin, LlamaBridge) + """ + + def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> GPTModelProvider: + """Convert HF AnyModel config to Megatron GPTModelProvider.""" + + parent_provider = super().provider_bridge(hf_pretrained) # type: ignore[misc] + + provider_kwargs = dataclasses.asdict(parent_provider) + + provider_kwargs["heterogeneous_layers_config_encoded_json"] = ( + self._build_heterogeneous_config_json(hf_pretrained.config) + ) + return GenericHeterogeneousProvider(**provider_kwargs) + + @classmethod + def megatron_to_hf_config(cls, provider: GPTModelProvider) -> dict: + raise NotImplementedError( + "megatron_to_hf_config() not yet implemented for AnyModel bridges. " + "AnyModel bridges require special handling for heterogeneous layer configurations." + ) + + def _build_heterogeneous_config_json(self, hf_config) -> str: + """Build heterogeneous layers config JSON from HF config.""" + + hf_config_dict = json.loads(hf_config.to_json_string()) + + mcore_block_configs = [ + self._convert_block_config(block) for block in hf_config_dict["block_configs"] + ] + return json.dumps({"block_configs": mcore_block_configs}, ensure_ascii=False) + + def _convert_block_config(self, block: dict) -> dict: + """Convert a single block config from HF format to MCore format.""" + return { + "attention": self._convert_attention_config(block["attention"]), + "ffn": self._convert_ffn_config(block["ffn"]), + } + + def _convert_attention_config(self, attention_config: dict) -> dict: + """Convert attention config from HF format to MCore format.""" + attention_config = attention_config.copy() + attention_config["num_query_groups"] = attention_config.pop("num_key_value_heads") + return attention_config + + def _convert_ffn_config(self, ffn_config: dict) -> dict: + """Convert FFN/MLP config from HF format to MCore format.""" + ffn_config = ffn_config.copy() + ffn_config["ffn_hidden_size"] = ffn_config.pop("intermediate_size") + return ffn_config From 562f46b750069de9aa78f31fff232f2399309f39 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Wed, 18 Feb 2026 10:24:05 -0800 Subject: [PATCH 02/16] Add missing files from modelopt/torch/puzzletron/export/mbridge - Add __init__.py: Module initialization - Add llama.py: Llama bridge implementation - Add qwen3.py: Qwen3 bridge implementation --- .../puzzletron/export/mbridge/__init__.py | 35 +++++++++++++++++ .../torch/puzzletron/export/mbridge/llama.py | 38 +++++++++++++++++++ .../torch/puzzletron/export/mbridge/qwen3.py | 38 +++++++++++++++++++ 3 files changed, 111 insertions(+) create mode 100644 modelopt/torch/puzzletron/export/mbridge/__init__.py create mode 100644 modelopt/torch/puzzletron/export/mbridge/llama.py create mode 100644 modelopt/torch/puzzletron/export/mbridge/qwen3.py diff --git a/modelopt/torch/puzzletron/export/mbridge/__init__.py b/modelopt/torch/puzzletron/export/mbridge/__init__.py new file mode 100644 index 000000000..471e68984 --- /dev/null +++ b/modelopt/torch/puzzletron/export/mbridge/__init__.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Megatron-Bridge adapters for Puzzletron AnyModel checkpoints. + +This module provides bridges for converting Puzzletron AnyModel checkpoints +(heterogeneous layer architectures) to Megatron-Core format via Megatron-Bridge. +""" + +# Import to register bridges (side effect) +from modelopt.torch.puzzletron.export.mbridge.base import HeterogeneousBridgeMixin +from modelopt.torch.puzzletron.export.mbridge.llama import ( # noqa: F401 + PuzzletronLlamaAnyModelBridge, +) +from modelopt.torch.puzzletron.export.mbridge.qwen3 import ( # noqa: F401 + PuzzletronQwen3AnyModelBridge, +) + +__all__ = [ + "HeterogeneousBridgeMixin", + "PuzzletronLlamaAnyModelBridge", + "PuzzletronQwen3AnyModelBridge", +] diff --git a/modelopt/torch/puzzletron/export/mbridge/llama.py b/modelopt/torch/puzzletron/export/mbridge/llama.py new file mode 100644 index 000000000..d193a02e9 --- /dev/null +++ b/modelopt/torch/puzzletron/export/mbridge/llama.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Megatron Bridge for Puzzletron Llama-based AnyModel heterogeneous checkpoints.""" + +from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge +from megatron.bridge.models.llama.llama_bridge import LlamaBridge +from megatron.core.models.gpt.gpt_model import GPTModel +from transformers import LlamaForCausalLM + +from modelopt.torch.puzzletron.export.mbridge.base import HeterogeneousBridgeMixin + + +@MegatronModelBridge.register_bridge(source=LlamaForCausalLM, target=GPTModel, model_type="llama") +class PuzzletronLlamaAnyModelBridge(HeterogeneousBridgeMixin, LlamaBridge): + """ + Megatron Bridge for Puzzletron Llama-based AnyModel checkpoints. + + Extends LlamaBridge with support for heterogeneous layer architectures (block_configs). + All Llama-specific settings are inherited from LlamaBridge. + """ + + # provider_bridge() is inherited from HeterogeneousBridgeMixin + # It automatically reuses LlamaBridge.provider_bridge() and adds heterogeneous config + # mapping_registry() is inherited from LlamaBridge diff --git a/modelopt/torch/puzzletron/export/mbridge/qwen3.py b/modelopt/torch/puzzletron/export/mbridge/qwen3.py new file mode 100644 index 000000000..664e87e08 --- /dev/null +++ b/modelopt/torch/puzzletron/export/mbridge/qwen3.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Megatron Bridge for Puzzletron Qwen3-based AnyModel heterogeneous checkpoints.""" + +from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge +from megatron.bridge.models.qwen.qwen3_bridge import Qwen3Bridge +from megatron.core.models.gpt.gpt_model import GPTModel +from transformers import Qwen3ForCausalLM + +from modelopt.torch.puzzletron.export.mbridge.base import HeterogeneousBridgeMixin + + +@MegatronModelBridge.register_bridge(source=Qwen3ForCausalLM, target=GPTModel, model_type="qwen3") +class PuzzletronQwen3AnyModelBridge(HeterogeneousBridgeMixin, Qwen3Bridge): + """ + Megatron Bridge for Puzzletron Qwen3-based AnyModel checkpoints. + + Extends Qwen3Bridge with support for heterogeneous layer architectures (block_configs). + All Qwen3-specific settings are inherited from Qwen3Bridge. + """ + + # provider_bridge() is inherited from HeterogeneousBridgeMixin + # It automatically reuses Qwen3Bridge.provider_bridge() and adds heterogeneous config + # mapping_registry() is inherited from Qwen3Bridge From d244ca7f3ad4babc20be4c735167f73cd3925093 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 19 Feb 2026 01:08:05 -0800 Subject: [PATCH 03/16] A tutorial on mbridge distillation for puzzletron/any_model Signed-off-by: Daniel Korzekwa --- examples/puzzletron/README.md | 16 +--- .../puzzletron/mbridge_distillation/README.md | 93 +++++++++++++++++++ 2 files changed, 95 insertions(+), 14 deletions(-) create mode 100644 examples/puzzletron/mbridge_distillation/README.md diff --git a/examples/puzzletron/README.md b/examples/puzzletron/README.md index 48f64d3c4..426186d93 100644 --- a/examples/puzzletron/README.md +++ b/examples/puzzletron/README.md @@ -275,21 +275,9 @@ vllm bench throughput --model path/to/model --input-len 2000 --output-len 100 -- ## Knowledge Distillation -To recover degradation in the quality of the compressed model, we can use knowledge distillation. This allows transferring the capabilities of the original model to the pruned one. For this, we will use [NeMo framework](https://github.com/NVIDIA-NeMo/NeMo) with the [nemo:25.07](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo?version=25.07) container. +To recover degradation in the quality of the compressed model, we can use knowledge distillation. This allows transferring the capabilities of the original model to the pruned one. -First, convert the HF model to NeMo format: - -```bash -python -m nemo_export/convert_hf_to_nemo --input-ckpt-path path/to/HF-model --output-ckpt-path path/to/save/model-nemo -``` - -Now you can utilize all the training features available in NeMo, including distillation. Please refer to the [NeMo distillation documentation](https://docs.nvidia.com/nemo-framework/user-guide/latest/model-optimization/distillation/distillation.html). - -[Optional] Once distillation is complete, you can convert the distilled model back to the HuggingFace format. - -```bash -python -m nemo_export/convert_nemo_to_hf --input-ckpt-path path/to/nemo-model --output-ckpt-path path/to/save/model-HF -``` +See [mbridge_distillation/README.md](./mbridge_distillation/README.md) for instructions on using Megatron-Bridge for knowledge distillation. ## Advanced Usage diff --git a/examples/puzzletron/mbridge_distillation/README.md b/examples/puzzletron/mbridge_distillation/README.md new file mode 100644 index 000000000..faaf4e0ba --- /dev/null +++ b/examples/puzzletron/mbridge_distillation/README.md @@ -0,0 +1,93 @@ +# Knowledge Distillation with Megatron-Bridge + +This guide shows how to perform knowledge distillation on Puzzletron-compressed AnyModel checkpoints using Megatron-Bridge. + +## Overview + +1. Set up the environment with Megatron-Bridge +2. Convert AnyModel checkpoints (student and teacher) to Megatron-Bridge format +3. Run knowledge distillation training + +## Setup + +> **Temporary Setup:** This manual Megatron-Bridge setup is required temporarily until the NeMo docker container includes Megatron-Bridge by default. Once the container is updated, this setup step will no longer be necessary. + +**Note:** Set `$WORKSPACE` to your project root directory before running these commands: + +```bash +export WORKSPACE=/path/to/your/project +``` + +1. **Clone Megatron-Bridge:** + + Clone [Megatron-Bridge](https://github.com/NVIDIA-NeMo/Megatron-Bridge) into your workspace: + + ```bash + cd $WORKSPACE + git clone https://github.com/NVIDIA-NeMo/Megatron-Bridge.git + ``` + +2. **Initialize Megatron-Bridge submodules:** + + ```bash + cd $WORKSPACE/Megatron-Bridge + git submodule init + git submodule update + ``` + +3. **Start Docker container with mounts:** + + Use the [NeMo 25.11 container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo?version=25.11): + + ```bash + docker run --gpus all -it --rm \ + -v $WORKSPACE:/workspace \ + -v $WORKSPACE/Megatron-Bridge/3rdparty/Megatron-LM:/opt/megatron-lm \ + nvcr.io/nvidia/nemo:25.11 \ + /bin/bash + ``` + + **Note:** The mount `/opt/megatron-lm` is required because Megatron-Bridge depends on the Megatron-LM submodule. + +4. **Set up the environment inside the container:** + + ```bash + export PYTHONPATH="/workspace/Megatron-Bridge/src:/workspace/Model-Optimizer:${PYTHONPATH}" + ``` + +## Step 1: Convert Checkpoints to Megatron-Bridge Format + +Convert both student and teacher checkpoints: + +```bash +# Convert student checkpoint +torchrun --nproc_per_node=1 examples/puzzletron/mbridge_distillation/import_anymodel_to_mbridge.py \ + --input-ckpt-path /path/to/student/anymodel/checkpoint \ + --output-ckpt-path /path/to/student/mbridge/checkpoint + +# Convert teacher checkpoint +torchrun --nproc_per_node=1 examples/puzzletron/mbridge_distillation/import_anymodel_to_mbridge.py \ + --input-ckpt-path /path/to/teacher/anymodel/checkpoint \ + --output-ckpt-path /path/to/teacher/mbridge/checkpoint +``` + +## Step 2: Run Knowledge Distillation + +Run distillation with tokenized dataset: + +```bash +torchrun --nproc_per_node=8 examples/puzzletron/mbridge_distillation/distill_anymodel.py \ + --student-mbridge-ckpt /path/to/student/mbridge/checkpoint/iter_0000000 \ + --teacher-mbridge-ckpt /path/to/teacher/mbridge/checkpoint/iter_0000000 \ + --data-path /path/to/tokenized/dataset \ + --output-dir ./distilled_output \ + dataset.sequence_length=8192 \ + model.tensor_model_parallel_size=8 \ + model.teacher.tensor_model_parallel_size=8 \ + train.global_batch_size=4 \ + train.micro_batch_size=1 \ + train.train_iters=5000 \ + logger.log_interval=1 +``` + +The distilled checkpoint will be saved to `--output-dir`. From 018b20845739f7ecbcc30a439862bf3066491816 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 19 Feb 2026 01:20:18 -0800 Subject: [PATCH 04/16] Update distillation readme Signed-off-by: Daniel Korzekwa --- examples/puzzletron/mbridge_distillation/README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/puzzletron/mbridge_distillation/README.md b/examples/puzzletron/mbridge_distillation/README.md index faaf4e0ba..6726089fa 100644 --- a/examples/puzzletron/mbridge_distillation/README.md +++ b/examples/puzzletron/mbridge_distillation/README.md @@ -10,7 +10,7 @@ This guide shows how to perform knowledge distillation on Puzzletron-compressed ## Setup -> **Temporary Setup:** This manual Megatron-Bridge setup is required temporarily until the NeMo docker container includes Megatron-Bridge by default. Once the container is updated, this setup step will no longer be necessary. +> **Temporary Setup:** The NeMo docker container includes Megatron-Bridge (main branch), but Puzzletron requires a specific version/branch of Megatron-Bridge that is not included by default. This manual setup is required to use the Puzzletron-compatible version. Once the container includes the required version, this setup step will no longer be necessary. **Note:** Set `$WORKSPACE` to your project root directory before running these commands: @@ -20,11 +20,13 @@ export WORKSPACE=/path/to/your/project 1. **Clone Megatron-Bridge:** - Clone [Megatron-Bridge](https://github.com/NVIDIA-NeMo/Megatron-Bridge) into your workspace: + Clone [Megatron-Bridge](https://github.com/NVIDIA-NeMo/Megatron-Bridge) and checkout the specific commit required for Puzzletron: ```bash cd $WORKSPACE git clone https://github.com/NVIDIA-NeMo/Megatron-Bridge.git + cd Megatron-Bridge + git checkout 960a718cb8989676b258e107d538642717e22e39 ``` 2. **Initialize Megatron-Bridge submodules:** From 5f30fa91fa6edac4562551484e08e01c0bcca5d0 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 19 Feb 2026 08:59:24 -0800 Subject: [PATCH 05/16] Improve mbridge tutorial for anymodel Signed-off-by: Daniel Korzekwa --- .../puzzletron/mbridge_distillation/README.md | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/examples/puzzletron/mbridge_distillation/README.md b/examples/puzzletron/mbridge_distillation/README.md index 6726089fa..1ae84e65f 100644 --- a/examples/puzzletron/mbridge_distillation/README.md +++ b/examples/puzzletron/mbridge_distillation/README.md @@ -57,6 +57,57 @@ export WORKSPACE=/path/to/your/project export PYTHONPATH="/workspace/Megatron-Bridge/src:/workspace/Model-Optimizer:${PYTHONPATH}" ``` +## Dataset Preparation + +This section describes how to prepare datasets for knowledge distillation. We provide examples using a toy dataset (WikiText-103) for illustration purposes, and note how to adapt the process for production datasets like Nemotron-Post-Training-Dataset-v2. + +> **Note:** For actual knowledge distillation, use a larger, more representative dataset like [Nemotron-Post-Training-Dataset-v2](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2). + +### Step 1: Download Dataset + +First, download the dataset and save it in JSONL format. For WikiText-103, you can use the following script: + +```python +# download_hf_wikitext_dataset.py +import json +import os +from datasets import load_dataset + +DATA_PATH = "path/to/hf_datasets/wikitext-103-v1" +# Load the WikiText-103 dataset +dataset = load_dataset("wikitext", "wikitext-103-v1", split="train") + +# Define the destination folder +os.makedirs(DATA_PATH, exist_ok=True) + +# Save splits to JSONL files +with open(f"{DATA_PATH}/wikitext-train.jsonl", "w") as file: + file.writelines(json.dumps(item) + "\n" for item in dataset) + +print(f"Raw dataset saved to {DATA_PATH}/wikitext-train.jsonl") +``` + +### Step 2: Tokenize Dataset + +Next, tokenize the JSONL dataset using the tokenizer from your model. This converts the text data into token IDs that can be used for training: + +```python +# tokenize_wikitext_dataset.py +from modelopt.torch.utils.plugins import megatron_preprocess_data + +DATA_PATH = "path/to/hf_datasets/wikitext-103-v1" +HF_MODEL_NAME_OR_PATH = "path/to/your/model/checkpoint" + +megatron_preprocess_data( + input_path=f"{DATA_PATH}/wikitext-train.jsonl", + output_dir=DATA_PATH, + tokenizer_name_or_path=HF_MODEL_NAME_OR_PATH, + json_keys=["text"], + workers=32, + log_interval=100000, +) +``` + ## Step 1: Convert Checkpoints to Megatron-Bridge Format Convert both student and teacher checkpoints: From 5f73765d5c9be6992b6614db89d1df0da42723bf Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Fri, 20 Feb 2026 11:03:00 -0800 Subject: [PATCH 06/16] Fixing distillation for heterogenous models (call self.teacher.finalize() on DistillationProvider.provide() Signed-off-by: Daniel Korzekwa --- .../export/mbridge/distillation_provider.py | 173 ++++++++++++++++++ 1 file changed, 173 insertions(+) create mode 100644 modelopt/torch/puzzletron/export/mbridge/distillation_provider.py diff --git a/modelopt/torch/puzzletron/export/mbridge/distillation_provider.py b/modelopt/torch/puzzletron/export/mbridge/distillation_provider.py new file mode 100644 index 000000000..d9d7f2266 --- /dev/null +++ b/modelopt/torch/puzzletron/export/mbridge/distillation_provider.py @@ -0,0 +1,173 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO: Upstream this fix to Megatron-Bridge and remove this local copy. + +import logging +from dataclasses import dataclass, fields +from typing import TYPE_CHECKING, Any, Optional + +from megatron.bridge.models.gpt_provider import GPTModelProvider +from megatron.bridge.models.mamba.mamba_provider import MambaModelProvider +from megatron.bridge.models.transformer_config import TransformerConfig +from megatron.core.models.gpt import GPTModel as MCoreGPTModel + +import modelopt.torch.distill as mtd +import modelopt.torch.distill.plugins.megatron as mtd_mcore + +if TYPE_CHECKING: + from megatron.bridge.training.post_training.distillation import ModelOptDistillConfig + + +logger = logging.getLogger(__name__) + + +@dataclass +class DistillationProvider(TransformerConfig): + """Provider for Megatron Core GPT models in distillation mode. + + Please use `convert_to_distillation_provider()` to create an instance of this class. + """ + + teacher: Optional[GPTModelProvider | MambaModelProvider] = None + kd_config: Optional["ModelOptDistillConfig"] = None + + def __init__(self, *args, **kwargs): + raise NotImplementedError( + "Use `convert_to_distillation_provider()` to create an instance of this class." + ) + + def __post_init__(self): + assert getattr(self, "teacher", None) is not None, "Teacher model must be provided." + + shared_attrs = [ + "tensor_model_parallel_size", + "pipeline_model_parallel_size", + "context_parallel_size", + "seq_length", + "pipeline_dtype", + ] + for attr in shared_attrs: + if getattr(self, attr) != getattr(self.teacher, attr): + raise ValueError(f"Student and teacher providers must have the same {attr}.") + + # Logits are overwritten in-place when TE cross-entropy loss is enabled, so switch it back to native version. + self.cross_entropy_fusion_impl = "native" + + # Hack to dynamically subclass other providers and still use their methods + self._super_class = self.__class__.__bases__[0] + + def provide(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreGPTModel: + """Configure and instantiate a ModelOpt DistillationModel based on this configuration. + + Args: + pre_process: Whether to include pre-processing in the model, defaults to first pipeline stage + post_process: Whether to include post-processing in the model, defaults to last pipeline stage + vp_stage: Virtual pipeline stage + + Returns: + MCoreGPTModel: Configured ModelOpt DistillationModel instance + """ + if vp_stage is not None: + raise ValueError("ModelOpt KD currently does not support virtual-pipeline parallel.") + + assert self.teacher is not None, "Teacher model must be provided." + student_model = self._super_class.provide(self, pre_process, post_process, vp_stage) # type: ignore[attr-defined] + + # Finalize teacher provider before creating model (required for heterogeneous models). + # + # per_block_parameters is an attribute of HeterogeneousTransformerConfig (defined in + # MCoreHeterogeneousTransformerConfig, heterogeneous_config.py:197). It's created during + # provider creation (bridge.to_megatron_provider()), but finalize() ensures they're consistent + # with current parallelism settings and distributed context. Student model creation (above) + # initializes parallel_state (process groups, TP/PP config), which weight loading/scatter + # requires. During teacher model creation, get_config_for_layer() is called (transformer_block.py:341) + # for each layer, which uses per_block_parameters and current tensor_model_parallel_size to + # determine layer architecture. Without finalize() in this context, architecture expectations + # don't match checkpoint weights, causing: + # ValueError: ProcessGroupNCCL::scatter: invalid tensor size at index 0 + # (expected (2880, 4096), got (3584, 4096)) + # + # Note: This explanation needs to be confirmed yet. + self.teacher.finalize() + + # Hack to get teacher's pre-wrap hooks called to potentially load HF weights + teacher_model = self.teacher.provide_distributed_model( + wrap_with_ddp=False, mixed_precision_wrapper=None + )[0] + + kd_cfg = mtd_mcore.setup_distillation_config( + self.kd_config, student_model.config, teacher_model.config + ) + modelopt_cfg = { + "teacher_model": teacher_model, + "criterion": kd_cfg.criterion, + "loss_balancer": kd_cfg.loss_balancer, + } + kd_model = mtd.convert(student_model, mode=[("kd_loss", modelopt_cfg)]) + mtd_mcore.adjust_distillation_model_for_mcore(kd_model, kd_cfg) + + return kd_model + + def to_cfg_dict(self) -> dict[str, Any]: + """Custom method to save equivalent to the original provider class. + + Used by `_ConfigContainerBase` to serialize the main `ConfigContainer` to YAML. + There is no need to restore a `DistillationProvider` from the run config file, as + it can always be re-converted using the original student provider. + + Returns: + Dictionary representation of this provider class + """ + from megatron.bridge.training.utils.config_utils import _ConfigContainerBase + + result = {"_target_": f"{self._super_class.__module__}.{self._super_class.__qualname__}"} + for field in fields(self): + if field.name.startswith("_") or field.name in ["teacher", "kd_config"]: + continue + result[field.name] = _ConfigContainerBase._convert_value_to_dict( + getattr(self, field.name) + ) + return result + + def __setattr__(self, name, value): + super().__setattr__(name, value) + # Mirror to teacher if it has that attribute + if hasattr(self.teacher, name): + setattr(self.teacher, name, value) + + +def convert_to_distillation_provider( + student_provider: GPTModelProvider | MambaModelProvider, + teacher_provider: GPTModelProvider | MambaModelProvider, + kd_config: Optional["ModelOptDistillConfig"] = None, +) -> "DistillationProvider": + """Convert a given model provider to a DistillationProvider.""" + + assert isinstance(student_provider, (GPTModelProvider, MambaModelProvider)), ( + "Student provider must be a subclass of GPTModelProvider or MambaModelProvider." + ) + assert isinstance(teacher_provider, (GPTModelProvider, MambaModelProvider)), ( + "Teacher provider must be a subclass of GPTModelProvider or MambaModelProvider." + ) + + DistillationProvider.__bases__ = (type(student_provider),) + student_provider.__class__ = DistillationProvider + + student_provider.teacher = teacher_provider + student_provider.kd_config = kd_config + student_provider.__post_init__() + + return student_provider From 99400f39a247cc990c658718c09af67ea50d3604 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Sun, 22 Feb 2026 23:24:31 -0800 Subject: [PATCH 07/16] Add original keval distill script. Signed-off-by: Daniel Korzekwa --- .../mbridge_distillation/distill_hf_keval.py | 250 ++++++++++++++++++ 1 file changed, 250 insertions(+) create mode 100644 examples/puzzletron/mbridge_distillation/distill_hf_keval.py diff --git a/examples/puzzletron/mbridge_distillation/distill_hf_keval.py b/examples/puzzletron/mbridge_distillation/distill_hf_keval.py new file mode 100644 index 000000000..31f1cfc71 --- /dev/null +++ b/examples/puzzletron/mbridge_distillation/distill_hf_keval.py @@ -0,0 +1,250 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Distillation script for Megatron-Bridge. + +Loads student and teacher models directly from HuggingFace checkpoints (local or remote) and saves the distilled model +to `/checkpoints` in megatron distributed checkpoint format. + +See `README.md` in this directory for example usage and data preparation instructions. +""" + +import argparse +import os + +import torch +from megatron.bridge import AutoBridge +from megatron.bridge.models.distillation_provider import convert_to_distillation_provider +from megatron.bridge.recipes.utils.optimizer_utils import ( + distributed_fused_adam_with_cosine_annealing, +) +from megatron.bridge.training.config import ( + CheckpointConfig, + ConfigContainer, + GPTDatasetConfig, + LoggerConfig, + MockGPTDatasetConfig, + RNGConfig, + TokenizerConfig, + TrainingConfig, +) +from megatron.bridge.training.distill import distill +from megatron.bridge.training.post_training.distillation import ModelOptDistillConfig +from megatron.core.datasets.utils import get_blend_from_list +from megatron.core.distributed import DistributedDataParallelConfig + +import modelopt.torch.utils.distributed as dist +from modelopt.torch.utils import print_rank_0 + +SEED = 1234 + + +def get_args(): + """Parse command-line arguments.""" + parser = argparse.ArgumentParser(description="Distillation for Megatron-Bridge.") + # Model arguments (accepts HuggingFace input only at the moment) + parser.add_argument( + "--student_hf_path", + type=str, + required=True, + help="HuggingFace model name or path for the student (e.g. Qwen/Qwen3-0.6B)", + ) + parser.add_argument( + "--teacher_hf_path", + type=str, + required=True, + help="HuggingFace model name or path for the teacher (e.g. Qwen/Qwen3-8B)", + ) + # Parallelism arguments + parser.add_argument("--tp_size", type=int, default=1, help="Tensor parallel size") + parser.add_argument("--pp_size", type=int, default=1, help="Pipeline parallel size") + # Dataset arguments + parser.add_argument( + "--data_paths", + nargs="+", + help="List of tokenized data paths to load from (weight1 path1 weight2 path2 ...)", + ) + parser.add_argument( + "--split", type=str, default="99,1,0", help="Train,Val,Test ratios to split data" + ) + parser.add_argument( + "--data_path_to_cache", type=str, default=None, help="Path to cache the dataset indices" + ) + parser.add_argument( + "--use_mock_data", action="store_true", help="Use mock data instead of --data_paths" + ) + # Training & Eval arguments + parser.add_argument( + "--output_dir", type=str, required=True, help="Folder for logging and checkpoint saving" + ) + parser.add_argument( + "--seq_length", + type=int, + default=4096, + help="Number of tokens per input sample. Use 8192 if your dataset has longer sequences.", + ) + parser.add_argument("--mbs", type=int, default=1, help="Micro-batch Size") + parser.add_argument("--gbs", type=int, default=768, help="Global Batch Size") + parser.add_argument( + "--train_iters", type=int, required=True, help="Number of training iterations" + ) + parser.add_argument("--lr", type=float, default=1e-4, help="Peak learning rate") + parser.add_argument("--min_lr", type=float, default=1e-5, help="Minimum learning rate") + parser.add_argument("--lr_warmup_iters", type=int, default=50, help="Number of LR warmup steps") + parser.add_argument( + "--eval_interval", type=int, default=100, help="Validate + checkpoint every steps" + ) + parser.add_argument( + "--eval_iters", type=int, default=32, help="Number of batches per validation stage" + ) + # Logging arguments + parser.add_argument("--log_interval", type=int, default=10, help="Write to log every steps") + parser.add_argument( + "--wandb_project", type=str, help="Wandb project name (required to enable Wandb logging)" + ) + parser.add_argument("--wandb_entity", type=str, help="Wandb entity name (optional)") + parser.add_argument("--wandb_exp_name", type=str, help="Wandb experiment name (optional)") + args = parser.parse_args() + + # Sanity checks + if not args.use_mock_data and not args.data_paths: + raise ValueError("Must provide either --data_paths or set --use_mock_data.") + + print_rank_0("\n==================== Arguments ====================") + for k, v in args.__dict__.items(): + print_rank_0(f"{k:<35} {v}") + print_rank_0("===================================================\n") + + return args + + +def main(args: argparse.Namespace): + checkpoint_dir = os.path.join(args.output_dir, "checkpoints") + tensorboard_dir = os.path.join(args.output_dir, "tb_logs") + + # Build student and teacher model providers + def _build_model_provider(hf_path): + bridge = AutoBridge.from_hf_pretrained(hf_path) + provider = bridge.to_megatron_provider(load_weights=True) + + # Override parallelism / training settings + provider.tensor_model_parallel_size = args.tp_size + provider.pipeline_model_parallel_size = args.pp_size + provider.context_parallel_size = 1 + provider.sequence_parallel = args.tp_size > 1 + provider.seq_length = args.seq_length + provider.pipeline_dtype = torch.bfloat16 + return provider + + # TODO: Support megatron-ckpt as an alternative to HF checkpoints (e.g. /path/to/ckpt/iter_0000000) + # Still requires an HF model name or path to build provider correctly + student_provider = _build_model_provider(args.student_hf_path) + teacher_provider = _build_model_provider(args.teacher_hf_path) + + # Wrap into DistillationProvider + kd_config = ModelOptDistillConfig() + distill_provider = convert_to_distillation_provider( + student_provider, teacher_provider, kd_config + ) + + # Build optimizer and scheduler + optimizer_config, scheduler_config = distributed_fused_adam_with_cosine_annealing( + lr_warmup_iters=args.lr_warmup_iters, + max_lr=args.lr, + min_lr=args.min_lr, + adam_beta2=0.98, + ) + + # Build dataset config + dataset_kwargs = { + "seq_length": args.seq_length, + "path_to_cache": args.data_path_to_cache, + "random_seed": SEED, + "reset_attention_mask": False, + "reset_position_ids": False, + "eod_mask_loss": False, + "num_dataset_builder_threads": 1, + "data_sharding": True, + "dataloader_type": "single", + "skip_getting_attention_mask_from_dataset": True, + } + if args.use_mock_data: + dataset_config = MockGPTDatasetConfig(**dataset_kwargs) + else: + # Convert flat CLI list (e.g. ["1.0", "/path/data"]) to Megatron blend format + blend = get_blend_from_list(args.data_paths) + dataset_config = GPTDatasetConfig(blend=blend, split=args.split, **dataset_kwargs) + + # Assemble ConfigContainer and run distillation + config = ConfigContainer( + model=distill_provider, + train=TrainingConfig( + train_iters=args.train_iters, + eval_interval=args.eval_interval, + eval_iters=args.eval_iters, + global_batch_size=args.gbs, + micro_batch_size=args.mbs, + manual_gc=True, + manual_gc_interval=100, + ), + # TODO: Replace validation args in train with validation config in nemo:26.04 + # validation=ValidationConfig(eval_interval=args.eval_interval, eval_iters=args.eval_iters), + optimizer=optimizer_config, + scheduler=scheduler_config, + ddp=DistributedDataParallelConfig( + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + use_distributed_optimizer=True, + ), + dataset=dataset_config, + logger=LoggerConfig( + log_interval=args.log_interval, + tensorboard_dir=tensorboard_dir, + log_timers_to_tensorboard=True, + # Weights & Biases logging + wandb_project=args.wandb_project, + wandb_entity=args.wandb_entity, # optional + wandb_exp_name=args.wandb_exp_name, + ), + tokenizer=TokenizerConfig( + tokenizer_type="NullTokenizer", vocab_size=distill_provider.vocab_size + ), + checkpoint=CheckpointConfig( + save_interval=args.eval_interval, + save=checkpoint_dir, + load=checkpoint_dir, # Resume from this directory (if exists) + most_recent_k=3, # Keeps 3 most recent checkpoints (not metric-based) + ckpt_format="torch_dist", + async_save=True, + fully_parallel_save=True, + ), + rng=RNGConfig(seed=SEED), + mixed_precision="bf16_mixed", + ) + + print_rank_0("\nStarting distillation...") + distill(config) + print_rank_0(f"\nDistillation done! Saved checkpoint to {checkpoint_dir}\n") + + +if __name__ == "__main__": + dist.setup() + args = get_args() + try: + main(args) + finally: + dist.cleanup() From f2b65ec45d79aca02f669da170a9523b585479ce Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Sun, 22 Feb 2026 23:53:18 -0800 Subject: [PATCH 08/16] Replace null tokenizer with a teacher tokenizer Signed-off-by: Daniel Korzekwa --- .../puzzletron/mbridge_distillation/distill_hf_keval.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/puzzletron/mbridge_distillation/distill_hf_keval.py b/examples/puzzletron/mbridge_distillation/distill_hf_keval.py index 31f1cfc71..8b127c472 100644 --- a/examples/puzzletron/mbridge_distillation/distill_hf_keval.py +++ b/examples/puzzletron/mbridge_distillation/distill_hf_keval.py @@ -221,7 +221,11 @@ def _build_model_provider(hf_path): wandb_exp_name=args.wandb_exp_name, ), tokenizer=TokenizerConfig( - tokenizer_type="NullTokenizer", vocab_size=distill_provider.vocab_size + tokenizer_type="HuggingFaceTokenizer", + # Use teacher tokenizer as the source of knowledge; fallback to student if teacher unavailable + # In distillation, both models should use the same tokenizer to process the same input + tokenizer_model=args.teacher_hf_path, + vocab_size=distill_provider.vocab_size, ), checkpoint=CheckpointConfig( save_interval=args.eval_interval, From e385b72cc9cbcf214981b460e048e2307731e5ab Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Mon, 23 Feb 2026 00:01:03 -0800 Subject: [PATCH 09/16] Ensure exception is printed (and not lost during distributed run) Signed-off-by: Daniel Korzekwa --- .../mbridge_distillation/distill_hf_keval.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/examples/puzzletron/mbridge_distillation/distill_hf_keval.py b/examples/puzzletron/mbridge_distillation/distill_hf_keval.py index 8b127c472..a4750ead7 100644 --- a/examples/puzzletron/mbridge_distillation/distill_hf_keval.py +++ b/examples/puzzletron/mbridge_distillation/distill_hf_keval.py @@ -22,6 +22,7 @@ import argparse import os +import traceback import torch from megatron.bridge import AutoBridge @@ -221,9 +222,13 @@ def _build_model_provider(hf_path): wandb_exp_name=args.wandb_exp_name, ), tokenizer=TokenizerConfig( + # TODO This replaced tokenizer_type="NullTokenizer" + # Why NullTokenizer is not working with container nvidian+nemo+26.02.rc5 and why was it + # used in the first place? tokenizer_type="HuggingFaceTokenizer", - # Use teacher tokenizer as the source of knowledge; fallback to student if teacher unavailable - # In distillation, both models should use the same tokenizer to process the same input + # Use teacher tokenizer as the source of knowledge; + # In distillation, both student and teacher models should use the same tokenizer to + # process the same input tokenizer_model=args.teacher_hf_path, vocab_size=distill_provider.vocab_size, ), @@ -250,5 +255,9 @@ def _build_model_provider(hf_path): args = get_args() try: main(args) + except Exception as e: + print_rank_0(f"✗ MAIN FAILED: {type(e).__name__}: {e}") + print_rank_0(f"Traceback:\n{traceback.format_exc()}") + raise finally: dist.cleanup() From fbcf2a68b68712eab52c4e6f57c09facd5805015 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Mon, 23 Feb 2026 00:58:55 -0800 Subject: [PATCH 10/16] Add anymodel support to mbridge distillation. Signed-off-by: Daniel Korzekwa --- .../mbridge_distillation/distill_hf_keval.py | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/examples/puzzletron/mbridge_distillation/distill_hf_keval.py b/examples/puzzletron/mbridge_distillation/distill_hf_keval.py index a4750ead7..e00a043ed 100644 --- a/examples/puzzletron/mbridge_distillation/distill_hf_keval.py +++ b/examples/puzzletron/mbridge_distillation/distill_hf_keval.py @@ -24,9 +24,9 @@ import os import traceback +import megatron.bridge.models.distillation_provider import torch from megatron.bridge import AutoBridge -from megatron.bridge.models.distillation_provider import convert_to_distillation_provider from megatron.bridge.recipes.utils.optimizer_utils import ( distributed_fused_adam_with_cosine_annealing, ) @@ -45,9 +45,29 @@ from megatron.core.datasets.utils import get_blend_from_list from megatron.core.distributed import DistributedDataParallelConfig +# Import heterogeneous bridges BEFORE AutoBridge.from_hf_pretrained() is called to ensure +# registration takes precedence. The @MegatronModelBridge.register_bridge decorator registers +# bridges when the module is imported. If both LlamaBridge and PuzzletronLlamaAnyModelBridge +# register for the same source (LlamaForCausalLM), the dispatch system uses the last registration. +# +# Note: Currently, bridges are also registered when distillation_provider is imported +# below (via mbridge/__init__.py), but this import will be needed once DistillationProvider +# is upstreamed to Megatron-Bridge and we no longer import from modelopt.torch.puzzletron. +import modelopt.torch.puzzletron.export.mbridge # noqa: F401 import modelopt.torch.utils.distributed as dist + +# Use local copy of distillation_provider with fix for heterogeneous models +# TODO: Remove this local copy once fix is upstreamed to Megatron-Bridge +from modelopt.torch.puzzletron.export.mbridge.distillation_provider import ( + DistillationProvider, + convert_to_distillation_provider, +) from modelopt.torch.utils import print_rank_0 +# Patch upstream module so isinstance checks in distill() work with our local DistillationProvider +# This must come after all imports since it modifies an imported module +megatron.bridge.models.distillation_provider.DistillationProvider = DistillationProvider + SEED = 1234 From 315e1c0d6f04a0dbb69d8fd22972967719278d81 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Mon, 23 Feb 2026 01:26:27 -0800 Subject: [PATCH 11/16] Improve error handling Signed-off-by: Daniel Korzekwa --- .../mbridge_distillation/distill_hf_keval.py | 9 ++++++--- modelopt/torch/puzzletron/export/mbridge/base.py | 11 ++++++++--- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/examples/puzzletron/mbridge_distillation/distill_hf_keval.py b/examples/puzzletron/mbridge_distillation/distill_hf_keval.py index e00a043ed..2d9cd1ec0 100644 --- a/examples/puzzletron/mbridge_distillation/distill_hf_keval.py +++ b/examples/puzzletron/mbridge_distillation/distill_hf_keval.py @@ -40,7 +40,6 @@ TokenizerConfig, TrainingConfig, ) -from megatron.bridge.training.distill import distill from megatron.bridge.training.post_training.distillation import ModelOptDistillConfig from megatron.core.datasets.utils import get_blend_from_list from megatron.core.distributed import DistributedDataParallelConfig @@ -64,10 +63,14 @@ ) from modelopt.torch.utils import print_rank_0 -# Patch upstream module so isinstance checks in distill() work with our local DistillationProvider -# This must come after all imports since it modifies an imported module +# Patch upstream module BEFORE importing distill() so isinstance checks work with our local DistillationProvider +# This must happen before distill() is imported because distill.py imports DistillationProvider at module load time megatron.bridge.models.distillation_provider.DistillationProvider = DistillationProvider +# Import distill() AFTER patching so it uses the patched DistillationProvider + +from megatron.bridge.training.distill import distill # noqa: E402 + SEED = 1234 diff --git a/modelopt/torch/puzzletron/export/mbridge/base.py b/modelopt/torch/puzzletron/export/mbridge/base.py index eff9f0b73..c21b2da47 100644 --- a/modelopt/torch/puzzletron/export/mbridge/base.py +++ b/modelopt/torch/puzzletron/export/mbridge/base.py @@ -52,11 +52,16 @@ class GenericHeterogeneousProvider(GPTModelProvider, HeterogeneousTransformerCon def __getattr__(self, name: str): """Handle missing attributes for OmegaConf compatibility. - OmegaConf conversion tries to access per_block_parameters which may not - be initialized when loading from YAML. Return empty list as fallback. + Returns empty list for per_block_parameters if not yet initialized (before finalize()). + This allows OmegaConf to serialize/deserialize configs without errors. Actual usage + should call finalize() first to set per_block_parameters as a real attribute. """ if name == "per_block_parameters": - return [] + # Return existing attribute if set, otherwise [] for OmegaConf compatibility + try: + return object.__getattribute__(self, name) + except AttributeError: + return [] raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") From 9a98e88ce151e33dddc1f8f8a59b45173312604b Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Mon, 23 Feb 2026 01:53:38 -0800 Subject: [PATCH 12/16] Improve mbridge distillation readme Signed-off-by: Daniel Korzekwa --- .../puzzletron/mbridge_distillation/README.md | 52 ++++--------------- 1 file changed, 11 insertions(+), 41 deletions(-) diff --git a/examples/puzzletron/mbridge_distillation/README.md b/examples/puzzletron/mbridge_distillation/README.md index 1ae84e65f..5dd64084d 100644 --- a/examples/puzzletron/mbridge_distillation/README.md +++ b/examples/puzzletron/mbridge_distillation/README.md @@ -10,58 +10,28 @@ This guide shows how to perform knowledge distillation on Puzzletron-compressed ## Setup -> **Temporary Setup:** The NeMo docker container includes Megatron-Bridge (main branch), but Puzzletron requires a specific version/branch of Megatron-Bridge that is not included by default. This manual setup is required to use the Puzzletron-compatible version. Once the container includes the required version, this setup step will no longer be necessary. +**Start Docker container:** -**Note:** Set `$WORKSPACE` to your project root directory before running these commands: +Use the [NeMo 26.02 container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo?version=26.02): ```bash -export WORKSPACE=/path/to/your/project +docker run --gpus all -it --rm \ + -v /path/to/your/project:/workspace \ + nvcr.io/nvidia/nemo:26.02 \ + /bin/bash ``` -1. **Clone Megatron-Bridge:** +**Set up the environment inside the container:** - Clone [Megatron-Bridge](https://github.com/NVIDIA-NeMo/Megatron-Bridge) and checkout the specific commit required for Puzzletron: - - ```bash - cd $WORKSPACE - git clone https://github.com/NVIDIA-NeMo/Megatron-Bridge.git - cd Megatron-Bridge - git checkout 960a718cb8989676b258e107d538642717e22e39 - ``` - -2. **Initialize Megatron-Bridge submodules:** - - ```bash - cd $WORKSPACE/Megatron-Bridge - git submodule init - git submodule update - ``` - -3. **Start Docker container with mounts:** - - Use the [NeMo 25.11 container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo?version=25.11): - - ```bash - docker run --gpus all -it --rm \ - -v $WORKSPACE:/workspace \ - -v $WORKSPACE/Megatron-Bridge/3rdparty/Megatron-LM:/opt/megatron-lm \ - nvcr.io/nvidia/nemo:25.11 \ - /bin/bash - ``` - - **Note:** The mount `/opt/megatron-lm` is required because Megatron-Bridge depends on the Megatron-LM submodule. - -4. **Set up the environment inside the container:** - - ```bash - export PYTHONPATH="/workspace/Megatron-Bridge/src:/workspace/Model-Optimizer:${PYTHONPATH}" - ``` +```bash +export PYTHONPATH="/workspace/Model-Optimizer:${PYTHONPATH}" +``` ## Dataset Preparation This section describes how to prepare datasets for knowledge distillation. We provide examples using a toy dataset (WikiText-103) for illustration purposes, and note how to adapt the process for production datasets like Nemotron-Post-Training-Dataset-v2. -> **Note:** For actual knowledge distillation, use a larger, more representative dataset like [Nemotron-Post-Training-Dataset-v2](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2). +> **Note:** The WikiText-103 dataset is a small toy dataset used here only for illustration. For actual knowledge distillation, use a larger, more representative dataset like [Nemotron-Post-Training-Dataset-v2](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2). ### Step 1: Download Dataset From d5997cd673c7f231ee59b6973cac96329b1be5a6 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Mon, 23 Feb 2026 02:08:05 -0800 Subject: [PATCH 13/16] Update mbridge distillation readme. Signed-off-by: Daniel Korzekwa --- .../puzzletron/mbridge_distillation/README.md | 55 ++++++++----------- 1 file changed, 23 insertions(+), 32 deletions(-) diff --git a/examples/puzzletron/mbridge_distillation/README.md b/examples/puzzletron/mbridge_distillation/README.md index 5dd64084d..b485ed5a2 100644 --- a/examples/puzzletron/mbridge_distillation/README.md +++ b/examples/puzzletron/mbridge_distillation/README.md @@ -5,8 +5,8 @@ This guide shows how to perform knowledge distillation on Puzzletron-compressed ## Overview 1. Set up the environment with Megatron-Bridge -2. Convert AnyModel checkpoints (student and teacher) to Megatron-Bridge format -3. Run knowledge distillation training +2. Prepare tokenized dataset +3. Run knowledge distillation training directly from HuggingFace checkpoints ## Setup @@ -78,39 +78,30 @@ megatron_preprocess_data( ) ``` -## Step 1: Convert Checkpoints to Megatron-Bridge Format +## Step 1: Run Knowledge Distillation -Convert both student and teacher checkpoints: +Run distillation directly from HuggingFace checkpoints (student and teacher) with tokenized dataset: ```bash -# Convert student checkpoint -torchrun --nproc_per_node=1 examples/puzzletron/mbridge_distillation/import_anymodel_to_mbridge.py \ - --input-ckpt-path /path/to/student/anymodel/checkpoint \ - --output-ckpt-path /path/to/student/mbridge/checkpoint - -# Convert teacher checkpoint -torchrun --nproc_per_node=1 examples/puzzletron/mbridge_distillation/import_anymodel_to_mbridge.py \ - --input-ckpt-path /path/to/teacher/anymodel/checkpoint \ - --output-ckpt-path /path/to/teacher/mbridge/checkpoint +torchrun --nproc_per_node=8 examples/puzzletron/mbridge_distillation/distill_hf_keval.py \ + --student_hf_path /path/to/student/huggingface/checkpoint \ + --teacher_hf_path /path/to/teacher/huggingface/checkpoint \ + --data_paths 1.0 /path/to/tokenized/dataset \ + --output_dir /workspace/mbridge_distillation/distilled_student \ + --seq_length 4096 \ + --tp_size 8 \ + --pp_size 1 \ + --mbs 1 \ + --gbs 4 \ + --train_iters 100 \ + --lr 0.0001 \ + --min_lr 1e-05 \ + --lr_warmup_iters 10 \ + --eval_interval 10 \ + --eval_iters 10 \ + --log_interval 1 ``` -## Step 2: Run Knowledge Distillation +The distilled checkpoint will be saved to `--output_dir`. -Run distillation with tokenized dataset: - -```bash -torchrun --nproc_per_node=8 examples/puzzletron/mbridge_distillation/distill_anymodel.py \ - --student-mbridge-ckpt /path/to/student/mbridge/checkpoint/iter_0000000 \ - --teacher-mbridge-ckpt /path/to/teacher/mbridge/checkpoint/iter_0000000 \ - --data-path /path/to/tokenized/dataset \ - --output-dir ./distilled_output \ - dataset.sequence_length=8192 \ - model.tensor_model_parallel_size=8 \ - model.teacher.tensor_model_parallel_size=8 \ - train.global_batch_size=4 \ - train.micro_batch_size=1 \ - train.train_iters=5000 \ - logger.log_interval=1 -``` - -The distilled checkpoint will be saved to `--output-dir`. +**Note:** The script automatically converts HuggingFace checkpoints to Megatron-Bridge format on-the-fly, so no separate import step is needed. From 2a611c6f168a84aeec32fc1789677f5b712000c3 Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Mon, 23 Feb 2026 02:36:03 -0800 Subject: [PATCH 14/16] delete old code Signed-off-by: Daniel Korzekwa --- .../mbridge_distillation/distill_anymodel.py | 295 ------------------ .../import_anymodel_to_mbridge.py | 92 ------ 2 files changed, 387 deletions(-) delete mode 100644 examples/puzzletron/mbridge_distillation/distill_anymodel.py delete mode 100644 examples/puzzletron/mbridge_distillation/import_anymodel_to_mbridge.py diff --git a/examples/puzzletron/mbridge_distillation/distill_anymodel.py b/examples/puzzletron/mbridge_distillation/distill_anymodel.py deleted file mode 100644 index 5d6e3eb9c..000000000 --- a/examples/puzzletron/mbridge_distillation/distill_anymodel.py +++ /dev/null @@ -1,295 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Knowledge Distillation Script for AnyModel Checkpoints. - -This script performs knowledge distillation between student and teacher models -that have been converted to Megatron-Bridge format using import_anymodel_to_mbridge.py. - -The distillation uses KL-Divergence loss between student and teacher logits -with temperature scaling (standard knowledge distillation from Hinton et al., 2015). - -Usage: - cd /workspace/Model-Optimizer - - # TODO: remove this once Megatron-Bridge is installed in the environment - export PYTHONPATH="/workspace/Megatron-Bridge/src:/workspace/Model-Optimizer:${PYTHONPATH}" - - # Basic usage (uses model's max seq_length, which may be very large) - torchrun --nproc_per_node=1 examples/puzzletron/mbridge_distillation/distill_anymodel.py \ - --student-mbridge-ckpt /path/to/student/iter_0000000 \ - --teacher-mbridge-ckpt /path/to/teacher/iter_0000000 \ - --data-path /path/to/tokenized/dataset \ - --output-dir ./distilled_output - - # Recommended: Override sequence length and other training params for faster training - torchrun --nproc_per_node=8 examples/puzzletron/mbridge_distillation/distill_anymodel.py \ - --student-mbridge-ckpt /path/to/student/iter_0000000 \ - --teacher-mbridge-ckpt /path/to/teacher/iter_0000000 \ - --data-path /path/to/tokenized/dataset \ - --output-dir ./distilled_output \ - dataset.sequence_length=8192 \ - model.tensor_model_parallel_size=8 \ - model.teacher.tensor_model_parallel_size=8 \ - train.global_batch_size=4 \ - train.micro_batch_size=1 \ - train.train_iters=5000 \ - logger.log_interval=1 -""" - -import argparse -import logging -import os -import sys - -import torch -from megatron.bridge.models.distillation_provider import convert_to_distillation_provider -from megatron.bridge.training.checkpointing import get_checkpoint_run_config_filename -from megatron.bridge.training.config import ( - CheckpointConfig, - ConfigContainer, - DistributedDataParallelConfig, - DistributedInitConfig, - GPTDatasetConfig, - LoggerConfig, - OptimizerConfig, - RerunStateMachineConfig, - RNGConfig, - SchedulerConfig, - TrainingConfig, - ValidationConfig, -) -from megatron.bridge.training.distill import distill -from megatron.bridge.training.model_load_save import load_model_config -from megatron.bridge.training.post_training.distillation import ModelOptDistillConfig -from megatron.bridge.training.tokenizers.config import TokenizerConfig -from megatron.bridge.training.utils.omegaconf_utils import ( - apply_overrides, - create_omegaconf_dict_config, - parse_hydra_overrides, -) -from megatron.bridge.utils.common_utils import get_rank_safe -from omegaconf import OmegaConf - -# Import GenericHeterogeneousProvider so it can be instantiated when loading -# checkpoint configs that reference it (e.g., run_config.yaml with -# _target_: modelopt.torch.puzzletron.export.mbridge.base.GenericHeterogeneousProvider) -import modelopt.torch.puzzletron.export.mbridge # noqa: F401 - -logger: logging.Logger = logging.getLogger(__name__) - - -def create_distillation_config() -> ModelOptDistillConfig: - """Create KD config with output layer distillation only.""" - return ModelOptDistillConfig( - logit_layers=["output_layer", "output_layer"], - intermediate_layer_pairs=[], - skip_lm_loss=True, - kd_loss_scale=1.0, - logit_kl_temperature=1.0, - ) - - -def create_base_config( - student_model_provider, - data_path: str, - student_ckpt: str, - output_dir: str, - use_bf16: bool, - use_fp16: bool, -) -> ConfigContainer: - """Create base ConfigContainer with defaults.""" - return ConfigContainer( - model=student_model_provider, - train=TrainingConfig(global_batch_size=1, micro_batch_size=1, train_iters=100), - optimizer=OptimizerConfig( - optimizer="adam", - lr=1e-4, - min_lr=1e-5, - weight_decay=0.01, - bf16=use_bf16, - fp16=use_fp16, - ), - scheduler=SchedulerConfig( - lr_decay_style="linear", - lr_warmup_iters=0, - start_weight_decay=0.01, - end_weight_decay=0.01, - weight_decay_incr_style="constant", - ), - dataset=GPTDatasetConfig( - random_seed=1234, - blend=[[data_path], [1.0]], - split="9999,8,2", - seq_length=student_model_provider.seq_length, - reset_attention_mask=False, - reset_position_ids=False, - eod_mask_loss=False, - dataloader_type="single", - ), - checkpoint=CheckpointConfig(load=student_ckpt, save=output_dir), - logger=LoggerConfig(), - tokenizer=TokenizerConfig(tokenizer_type="HuggingFaceTokenizer", tokenizer_model=None), - validation=ValidationConfig(eval_interval=500, eval_iters=100), - ddp=DistributedDataParallelConfig(grad_reduce_in_fp32=True), - dist=DistributedInitConfig(), - rng=RNGConfig(), - rerun_state_machine=RerunStateMachineConfig(), - ) - - -def merge_checkpoint_configs( - cfg: ConfigContainer, checkpoint_path: str, use_bf16: bool, use_fp16: bool -) -> None: - """Merge tokenizer and optimizer configs from checkpoint if available.""" - try: - run_config_path = get_checkpoint_run_config_filename(checkpoint_path) - checkpoint_cfg = ConfigContainer.from_yaml(run_config_path) - if checkpoint_cfg.tokenizer is not None: - cfg.tokenizer = checkpoint_cfg.tokenizer - if checkpoint_cfg.optimizer is not None: - for key, value in checkpoint_cfg.optimizer.__dict__.items(): - if ( - value is not None - and hasattr(cfg.optimizer, key) - and key not in ("bf16", "fp16") - ): - setattr(cfg.optimizer, key, value) - # Ensure bf16/fp16 are set correctly based on model dtype - cfg.optimizer.bf16 = use_bf16 - cfg.optimizer.fp16 = use_fp16 - except Exception as e: - logger.warning(f"Could not load additional configs from checkpoint: {e}") - - -def parse_cli_args() -> tuple[argparse.Namespace, list[str]]: - """Parse command line arguments.""" - parser = argparse.ArgumentParser( - description="Knowledge distillation with AnyModel checkpoints", - formatter_class=argparse.RawTextHelpFormatter, - ) - parser.add_argument( - "--student-mbridge-ckpt", - type=str, - required=True, - help="Path to student checkpoint in MBridge format (must be iter_XXXXXXX directory).", - ) - parser.add_argument( - "--teacher-mbridge-ckpt", - type=str, - required=True, - help="Path to teacher checkpoint in MBridge format (must be iter_XXXXXXX directory).", - ) - parser.add_argument( - "--data-path", - type=str, - required=True, - help="Path to tokenized dataset (without .bin extension).", - ) - parser.add_argument( - "--output-dir", - type=str, - default="./distilled_output", - help="Output directory for distilled checkpoint.", - ) - parser.add_argument( - "--config-file", - type=str, - default=None, - help="Path to YAML OmegaConf override file (optional).", - ) - parser.add_argument("--debug", action="store_true", help="Enable debug logging") - - args, cli_overrides = parser.parse_known_args() - return args, cli_overrides - - -def main() -> None: - """Main distillation function.""" - args, cli_overrides = parse_cli_args() - - if args.debug: - logging.basicConfig(level=logging.DEBUG) - else: - logging.basicConfig(level=logging.INFO) - - logger.info("Megatron-Bridge Knowledge Distillation Script (AnyModel)") - logger.info("=" * 70) - - # Load model configs from checkpoints - logger.info("Loading model configs from MBridge checkpoints...") - student_model_provider, _ = load_model_config(args.student_mbridge_ckpt) - teacher_model_provider, _ = load_model_config(args.teacher_mbridge_ckpt) - - # Detect model dtype for optimizer config - model_params_dtype = getattr(student_model_provider, "params_dtype", torch.float32) - use_bf16 = model_params_dtype == torch.bfloat16 - use_fp16 = model_params_dtype == torch.float16 - - # Create base config with defaults - cfg = create_base_config( - student_model_provider, - args.data_path, - args.student_mbridge_ckpt, - args.output_dir, - use_bf16, - use_fp16, - ) - - # Merge tokenizer and optimizer from checkpoint if available - merge_checkpoint_configs(cfg, args.student_mbridge_ckpt, use_bf16, use_fp16) - - # Create distillation config and convert to DistillationProvider - kd_config = create_distillation_config() - cfg.model = convert_to_distillation_provider( - student_provider=student_model_provider, - teacher_provider=teacher_model_provider, - kd_config=kd_config, - ) - - # Apply YAML and CLI overrides - merged_omega_conf, excluded_fields = create_omegaconf_dict_config(cfg) - if args.config_file: - if not os.path.exists(args.config_file): - logger.error(f"Override YAML file not found: {args.config_file}") - sys.exit(1) - yaml_overrides_omega = OmegaConf.load(args.config_file) - merged_omega_conf = OmegaConf.merge(merged_omega_conf, yaml_overrides_omega) - if cli_overrides: - merged_omega_conf = parse_hydra_overrides(merged_omega_conf, cli_overrides) - apply_overrides(cfg, OmegaConf.to_container(merged_omega_conf, resolve=True), excluded_fields) - - # Sync model seq_length with dataset sequence_length if they differ - if ( - hasattr(cfg.dataset, "sequence_length") - and cfg.dataset.sequence_length != cfg.model.seq_length - ): - cfg.model.seq_length = cfg.dataset.sequence_length - if hasattr(cfg.model, "teacher") and cfg.model.teacher is not None: - cfg.model.teacher.seq_length = cfg.dataset.sequence_length - - if get_rank_safe() == 0: - logger.info("--- Final Configuration ---") - cfg.print_yaml() - - # Run distillation - logger.info("Starting distillation training...") - distill(cfg) - - logger.info(f"✓ Distillation complete! Checkpoints saved to: {args.output_dir}") - - -if __name__ == "__main__": - main() diff --git a/examples/puzzletron/mbridge_distillation/import_anymodel_to_mbridge.py b/examples/puzzletron/mbridge_distillation/import_anymodel_to_mbridge.py deleted file mode 100644 index 1127d228d..000000000 --- a/examples/puzzletron/mbridge_distillation/import_anymodel_to_mbridge.py +++ /dev/null @@ -1,92 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Import AnyModel checkpoint to Megatron-Bridge format. - -This script converts a HuggingFace AnyModel checkpoint to Megatron-Bridge format, -similar to NeMo's convert_hf_to_nemo. - -Usage: - cd /workspace/Model-Optimizer - - # TODO: remove this once Megatron-Bridge is installed in the environment - export PYTHONPATH="/workspace/Megatron-Bridge/src:/workspace/Model-Optimizer:${PYTHONPATH}" - - torchrun --nproc_per_node=1 examples/puzzletron/mbridge_distillation/import_anymodel_to_mbridge.py \ - --input-ckpt-path /path/to/anymodel/checkpoint \ - --output-ckpt-path /path/to/save/mbridge/checkpoint -""" - -import argparse -from pathlib import Path - -from megatron.bridge import AutoBridge - -# Import all heterogeneous bridges to register them -# This will override homogeneous bridges (e.g., LlamaBridge, Qwen3Bridge) with -# heterogeneous versions (PuzzletronLlamaAnyModelBridge, PuzzletronQwen3AnyModelBridge) -# that support block_configs for AnyModel checkpoints. -import modelopt.torch.puzzletron.export.mbridge # noqa: F401 - - -def parse_args() -> argparse.Namespace: - """Parse command line arguments.""" - parser = argparse.ArgumentParser( - description="Convert AnyModel checkpoint to Megatron-Bridge format", - formatter_class=argparse.RawTextHelpFormatter, - ) - parser.add_argument( - "--input-ckpt-path", - type=str, - required=True, - help="Path to input AnyModel checkpoint (HuggingFace format)", - ) - parser.add_argument( - "--output-ckpt-path", - type=str, - required=True, - help="Path to save output Megatron-Bridge checkpoint", - ) - return parser.parse_args() - - -def main() -> None: - """Main function to import HF AnyModel and save as Megatron checkpoint.""" - args = parse_args() - - input_path = Path(args.input_ckpt_path) - output_path = Path(args.output_ckpt_path) - - print(f"Importing AnyModel checkpoint from: {input_path}") - print(f"Saving Megatron-Bridge checkpoint to: {output_path}") - print() - - # Create output directory if it doesn't exist - output_path.mkdir(parents=True, exist_ok=True) - - # Import and save as Megatron checkpoint - AutoBridge.import_ckpt( - hf_model_id=str(input_path), - megatron_path=str(output_path), - trust_remote_code=True, - ) - - print(f"\n✓ Successfully saved Megatron-Bridge checkpoint to: {output_path}") - - -if __name__ == "__main__": - main() From 66b9fdde1c21d24535ea30572a40076b55fb449f Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Mon, 23 Feb 2026 04:32:23 -0800 Subject: [PATCH 15/16] Make mbrdige distillation to work with nvidian+nemo+26.02.rc5 Signed-off-by: Daniel Korzekwa --- .../mbridge_distillation/distill_hf_keval.py | 10 +----- .../torch/puzzletron/export/mbridge/base.py | 32 +++++++++++++++++-- .../torch/puzzletron/export/mbridge/llama.py | 2 +- .../torch/puzzletron/export/mbridge/qwen3.py | 2 +- 4 files changed, 33 insertions(+), 13 deletions(-) diff --git a/examples/puzzletron/mbridge_distillation/distill_hf_keval.py b/examples/puzzletron/mbridge_distillation/distill_hf_keval.py index 2d9cd1ec0..f3986741a 100644 --- a/examples/puzzletron/mbridge_distillation/distill_hf_keval.py +++ b/examples/puzzletron/mbridge_distillation/distill_hf_keval.py @@ -245,15 +245,7 @@ def _build_model_provider(hf_path): wandb_exp_name=args.wandb_exp_name, ), tokenizer=TokenizerConfig( - # TODO This replaced tokenizer_type="NullTokenizer" - # Why NullTokenizer is not working with container nvidian+nemo+26.02.rc5 and why was it - # used in the first place? - tokenizer_type="HuggingFaceTokenizer", - # Use teacher tokenizer as the source of knowledge; - # In distillation, both student and teacher models should use the same tokenizer to - # process the same input - tokenizer_model=args.teacher_hf_path, - vocab_size=distill_provider.vocab_size, + tokenizer_type="NullTokenizer", vocab_size=distill_provider.vocab_size ), checkpoint=CheckpointConfig( save_interval=args.eval_interval, diff --git a/modelopt/torch/puzzletron/export/mbridge/base.py b/modelopt/torch/puzzletron/export/mbridge/base.py index c21b2da47..54280400d 100644 --- a/modelopt/torch/puzzletron/export/mbridge/base.py +++ b/modelopt/torch/puzzletron/export/mbridge/base.py @@ -24,7 +24,7 @@ import dataclasses import json from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, fields from megatron.bridge.models.gpt_provider import GPTModelProvider from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM @@ -73,12 +73,40 @@ class HeterogeneousBridgeMixin: """ def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> GPTModelProvider: - """Convert HF AnyModel config to Megatron GPTModelProvider.""" + """Convert HF AnyModel config to Megatron GPTModelProvider. + + This method: + 1. Calls the parent bridge's provider_bridge() to get a GPTModelProvider with all + model-specific settings (e.g., LlamaBridge sets normalization="RMSNorm", etc.) + 2. Converts the provider to a dict and filters to only fields accepted by + GenericHeterogeneousProvider (which inherits from GPTModelProvider, so all valid + GPTModelProvider fields are preserved) + 3. Adds heterogeneous configuration and returns GenericHeterogeneousProvider + + All parameters from the parent bridge (e.g., LlamaBridge) are maintained because + GenericHeterogeneousProvider inherits from GPTModelProvider, which includes all + the fields that the parent bridge sets. + """ parent_provider = super().provider_bridge(hf_pretrained) # type: ignore[misc] provider_kwargs = dataclasses.asdict(parent_provider) + # Filter to only fields that GenericHeterogeneousProvider accepts. + # GenericHeterogeneousProvider inherits from GPTModelProvider, so it includes all + # GPTModelProvider fields. Model-specific fields from subclasses (e.g., MistralModelProvider, + # GPTOSSModelProvider) are filtered out because GenericHeterogeneousProvider only inherits + # from GPTModelProvider, not from model-specific subclasses. + # + # Note: This logic may not work for bridges like MistralBridge or GPTOSSBridge if they + # use model-specific parameters not supported by GenericHeterogeneousProvider (e.g., + # scale_factor, yarn_rotary_scaling_factor, moe_* parameters). In such cases, create a + # model-specific heterogeneous provider that inherits from the model-specific provider. + valid_fields = {f.name for f in fields(GenericHeterogeneousProvider)} + + # Only keep kwargs that are valid fields + provider_kwargs = {k: v for k, v in provider_kwargs.items() if k in valid_fields} + provider_kwargs["heterogeneous_layers_config_encoded_json"] = ( self._build_heterogeneous_config_json(hf_pretrained.config) ) diff --git a/modelopt/torch/puzzletron/export/mbridge/llama.py b/modelopt/torch/puzzletron/export/mbridge/llama.py index d193a02e9..b80221529 100644 --- a/modelopt/torch/puzzletron/export/mbridge/llama.py +++ b/modelopt/torch/puzzletron/export/mbridge/llama.py @@ -24,7 +24,7 @@ from modelopt.torch.puzzletron.export.mbridge.base import HeterogeneousBridgeMixin -@MegatronModelBridge.register_bridge(source=LlamaForCausalLM, target=GPTModel, model_type="llama") +@MegatronModelBridge.register_bridge(source=LlamaForCausalLM, target=GPTModel) class PuzzletronLlamaAnyModelBridge(HeterogeneousBridgeMixin, LlamaBridge): """ Megatron Bridge for Puzzletron Llama-based AnyModel checkpoints. diff --git a/modelopt/torch/puzzletron/export/mbridge/qwen3.py b/modelopt/torch/puzzletron/export/mbridge/qwen3.py index 664e87e08..ace20fbf8 100644 --- a/modelopt/torch/puzzletron/export/mbridge/qwen3.py +++ b/modelopt/torch/puzzletron/export/mbridge/qwen3.py @@ -24,7 +24,7 @@ from modelopt.torch.puzzletron.export.mbridge.base import HeterogeneousBridgeMixin -@MegatronModelBridge.register_bridge(source=Qwen3ForCausalLM, target=GPTModel, model_type="qwen3") +@MegatronModelBridge.register_bridge(source=Qwen3ForCausalLM, target=GPTModel) class PuzzletronQwen3AnyModelBridge(HeterogeneousBridgeMixin, Qwen3Bridge): """ Megatron Bridge for Puzzletron Qwen3-based AnyModel checkpoints. From 22cfa6c82abd2b4561caf79c082abeeb34badb8b Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Mon, 23 Feb 2026 04:57:35 -0800 Subject: [PATCH 16/16] rename distill_hf_keval.py to distill_hf.py Signed-off-by: Daniel Korzekwa --- examples/puzzletron/mbridge_distillation/README.md | 2 +- .../mbridge_distillation/{distill_hf_keval.py => distill_hf.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename examples/puzzletron/mbridge_distillation/{distill_hf_keval.py => distill_hf.py} (100%) diff --git a/examples/puzzletron/mbridge_distillation/README.md b/examples/puzzletron/mbridge_distillation/README.md index b485ed5a2..332dcd772 100644 --- a/examples/puzzletron/mbridge_distillation/README.md +++ b/examples/puzzletron/mbridge_distillation/README.md @@ -83,7 +83,7 @@ megatron_preprocess_data( Run distillation directly from HuggingFace checkpoints (student and teacher) with tokenized dataset: ```bash -torchrun --nproc_per_node=8 examples/puzzletron/mbridge_distillation/distill_hf_keval.py \ +torchrun --nproc_per_node=8 examples/puzzletron/mbridge_distillation/distill_hf.py \ --student_hf_path /path/to/student/huggingface/checkpoint \ --teacher_hf_path /path/to/teacher/huggingface/checkpoint \ --data_paths 1.0 /path/to/tokenized/dataset \ diff --git a/examples/puzzletron/mbridge_distillation/distill_hf_keval.py b/examples/puzzletron/mbridge_distillation/distill_hf.py similarity index 100% rename from examples/puzzletron/mbridge_distillation/distill_hf_keval.py rename to examples/puzzletron/mbridge_distillation/distill_hf.py