diff --git a/examples/puzzletron/README.md b/examples/puzzletron/README.md index 48f64d3c4..426186d93 100644 --- a/examples/puzzletron/README.md +++ b/examples/puzzletron/README.md @@ -275,21 +275,9 @@ vllm bench throughput --model path/to/model --input-len 2000 --output-len 100 -- ## Knowledge Distillation -To recover degradation in the quality of the compressed model, we can use knowledge distillation. This allows transferring the capabilities of the original model to the pruned one. For this, we will use [NeMo framework](https://github.com/NVIDIA-NeMo/NeMo) with the [nemo:25.07](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo?version=25.07) container. +To recover degradation in the quality of the compressed model, we can use knowledge distillation. This allows transferring the capabilities of the original model to the pruned one. -First, convert the HF model to NeMo format: - -```bash -python -m nemo_export/convert_hf_to_nemo --input-ckpt-path path/to/HF-model --output-ckpt-path path/to/save/model-nemo -``` - -Now you can utilize all the training features available in NeMo, including distillation. Please refer to the [NeMo distillation documentation](https://docs.nvidia.com/nemo-framework/user-guide/latest/model-optimization/distillation/distillation.html). - -[Optional] Once distillation is complete, you can convert the distilled model back to the HuggingFace format. - -```bash -python -m nemo_export/convert_nemo_to_hf --input-ckpt-path path/to/nemo-model --output-ckpt-path path/to/save/model-HF -``` +See [mbridge_distillation/README.md](./mbridge_distillation/README.md) for instructions on using Megatron-Bridge for knowledge distillation. ## Advanced Usage diff --git a/examples/puzzletron/mbridge_distillation/README.md b/examples/puzzletron/mbridge_distillation/README.md new file mode 100644 index 000000000..1ae84e65f --- /dev/null +++ b/examples/puzzletron/mbridge_distillation/README.md @@ -0,0 +1,146 @@ +# Knowledge Distillation with Megatron-Bridge + +This guide shows how to perform knowledge distillation on Puzzletron-compressed AnyModel checkpoints using Megatron-Bridge. + +## Overview + +1. Set up the environment with Megatron-Bridge +2. Convert AnyModel checkpoints (student and teacher) to Megatron-Bridge format +3. Run knowledge distillation training + +## Setup + +> **Temporary Setup:** The NeMo docker container includes Megatron-Bridge (main branch), but Puzzletron requires a specific version/branch of Megatron-Bridge that is not included by default. This manual setup is required to use the Puzzletron-compatible version. Once the container includes the required version, this setup step will no longer be necessary. + +**Note:** Set `$WORKSPACE` to your project root directory before running these commands: + +```bash +export WORKSPACE=/path/to/your/project +``` + +1. **Clone Megatron-Bridge:** + + Clone [Megatron-Bridge](https://github.com/NVIDIA-NeMo/Megatron-Bridge) and checkout the specific commit required for Puzzletron: + + ```bash + cd $WORKSPACE + git clone https://github.com/NVIDIA-NeMo/Megatron-Bridge.git + cd Megatron-Bridge + git checkout 960a718cb8989676b258e107d538642717e22e39 + ``` + +2. **Initialize Megatron-Bridge submodules:** + + ```bash + cd $WORKSPACE/Megatron-Bridge + git submodule init + git submodule update + ``` + +3. 
**Start Docker container with mounts:** + + Use the [NeMo 25.11 container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo?version=25.11): + + ```bash + docker run --gpus all -it --rm \ + -v $WORKSPACE:/workspace \ + -v $WORKSPACE/Megatron-Bridge/3rdparty/Megatron-LM:/opt/megatron-lm \ + nvcr.io/nvidia/nemo:25.11 \ + /bin/bash + ``` + + **Note:** The mount `/opt/megatron-lm` is required because Megatron-Bridge depends on the Megatron-LM submodule. + +4. **Set up the environment inside the container:** + + ```bash + export PYTHONPATH="/workspace/Megatron-Bridge/src:/workspace/Model-Optimizer:${PYTHONPATH}" + ``` + +## Dataset Preparation + +This section describes how to prepare datasets for knowledge distillation. We provide examples using a toy dataset (WikiText-103) for illustration purposes, and note how to adapt the process for production datasets like Nemotron-Post-Training-Dataset-v2. + +> **Note:** For actual knowledge distillation, use a larger, more representative dataset like [Nemotron-Post-Training-Dataset-v2](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2). + +### Step 1: Download Dataset + +First, download the dataset and save it in JSONL format. For WikiText-103, you can use the following script: + +```python +# download_hf_wikitext_dataset.py +import json +import os +from datasets import load_dataset + +DATA_PATH = "path/to/hf_datasets/wikitext-103-v1" +# Load the WikiText-103 dataset +dataset = load_dataset("wikitext", "wikitext-103-v1", split="train") + +# Define the destination folder +os.makedirs(DATA_PATH, exist_ok=True) + +# Save splits to JSONL files +with open(f"{DATA_PATH}/wikitext-train.jsonl", "w") as file: + file.writelines(json.dumps(item) + "\n" for item in dataset) + +print(f"Raw dataset saved to {DATA_PATH}/wikitext-train.jsonl") +``` + +### Step 2: Tokenize Dataset + +Next, tokenize the JSONL dataset using the tokenizer from your model. 
This converts the text data into token IDs that can be used for training: + +```python +# tokenize_wikitext_dataset.py +from modelopt.torch.utils.plugins import megatron_preprocess_data + +DATA_PATH = "path/to/hf_datasets/wikitext-103-v1" +HF_MODEL_NAME_OR_PATH = "path/to/your/model/checkpoint" + +megatron_preprocess_data( + input_path=f"{DATA_PATH}/wikitext-train.jsonl", + output_dir=DATA_PATH, + tokenizer_name_or_path=HF_MODEL_NAME_OR_PATH, + json_keys=["text"], + workers=32, + log_interval=100000, +) +``` + +## Step 1: Convert Checkpoints to Megatron-Bridge Format + +Convert both student and teacher checkpoints: + +```bash +# Convert student checkpoint +torchrun --nproc_per_node=1 examples/puzzletron/mbridge_distillation/import_anymodel_to_mbridge.py \ + --input-ckpt-path /path/to/student/anymodel/checkpoint \ + --output-ckpt-path /path/to/student/mbridge/checkpoint + +# Convert teacher checkpoint +torchrun --nproc_per_node=1 examples/puzzletron/mbridge_distillation/import_anymodel_to_mbridge.py \ + --input-ckpt-path /path/to/teacher/anymodel/checkpoint \ + --output-ckpt-path /path/to/teacher/mbridge/checkpoint +``` + +## Step 2: Run Knowledge Distillation + +Run distillation with tokenized dataset: + +```bash +torchrun --nproc_per_node=8 examples/puzzletron/mbridge_distillation/distill_anymodel.py \ + --student-mbridge-ckpt /path/to/student/mbridge/checkpoint/iter_0000000 \ + --teacher-mbridge-ckpt /path/to/teacher/mbridge/checkpoint/iter_0000000 \ + --data-path /path/to/tokenized/dataset \ + --output-dir ./distilled_output \ + dataset.sequence_length=8192 \ + model.tensor_model_parallel_size=8 \ + model.teacher.tensor_model_parallel_size=8 \ + train.global_batch_size=4 \ + train.micro_batch_size=1 \ + train.train_iters=5000 \ + logger.log_interval=1 +``` + +The distilled checkpoint will be saved to `--output-dir`. diff --git a/examples/puzzletron/mbridge_distillation/distill_anymodel.py b/examples/puzzletron/mbridge_distillation/distill_anymodel.py new file mode 100644 index 000000000..5d6e3eb9c --- /dev/null +++ b/examples/puzzletron/mbridge_distillation/distill_anymodel.py @@ -0,0 +1,295 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Knowledge Distillation Script for AnyModel Checkpoints. + +This script performs knowledge distillation between student and teacher models +that have been converted to Megatron-Bridge format using import_anymodel_to_mbridge.py. + +The distillation uses KL-Divergence loss between student and teacher logits +with temperature scaling (standard knowledge distillation from Hinton et al., 2015). 
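+
+As a rough sketch (assuming the standard Hinton-style formulation; ModelOpt's
+exact weighting may differ), the per-token distillation loss is approximately:
+
+    kd_loss = kd_loss_scale * T**2 * kl_div(
+        log_softmax(student_logits / T), softmax(teacher_logits / T)
+    )
+
+where T is logit_kl_temperature. With skip_lm_loss=True (used in this script),
+this is the only training objective; otherwise it is added to the usual LM loss.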
+ +Usage: + cd /workspace/Model-Optimizer + + # TODO: remove this once Megatron-Bridge is installed in the environment + export PYTHONPATH="/workspace/Megatron-Bridge/src:/workspace/Model-Optimizer:${PYTHONPATH}" + + # Basic usage (uses model's max seq_length, which may be very large) + torchrun --nproc_per_node=1 examples/puzzletron/mbridge_distillation/distill_anymodel.py \ + --student-mbridge-ckpt /path/to/student/iter_0000000 \ + --teacher-mbridge-ckpt /path/to/teacher/iter_0000000 \ + --data-path /path/to/tokenized/dataset \ + --output-dir ./distilled_output + + # Recommended: Override sequence length and other training params for faster training + torchrun --nproc_per_node=8 examples/puzzletron/mbridge_distillation/distill_anymodel.py \ + --student-mbridge-ckpt /path/to/student/iter_0000000 \ + --teacher-mbridge-ckpt /path/to/teacher/iter_0000000 \ + --data-path /path/to/tokenized/dataset \ + --output-dir ./distilled_output \ + dataset.sequence_length=8192 \ + model.tensor_model_parallel_size=8 \ + model.teacher.tensor_model_parallel_size=8 \ + train.global_batch_size=4 \ + train.micro_batch_size=1 \ + train.train_iters=5000 \ + logger.log_interval=1 +""" + +import argparse +import logging +import os +import sys + +import torch +from megatron.bridge.models.distillation_provider import convert_to_distillation_provider +from megatron.bridge.training.checkpointing import get_checkpoint_run_config_filename +from megatron.bridge.training.config import ( + CheckpointConfig, + ConfigContainer, + DistributedDataParallelConfig, + DistributedInitConfig, + GPTDatasetConfig, + LoggerConfig, + OptimizerConfig, + RerunStateMachineConfig, + RNGConfig, + SchedulerConfig, + TrainingConfig, + ValidationConfig, +) +from megatron.bridge.training.distill import distill +from megatron.bridge.training.model_load_save import load_model_config +from megatron.bridge.training.post_training.distillation import ModelOptDistillConfig +from megatron.bridge.training.tokenizers.config import TokenizerConfig +from megatron.bridge.training.utils.omegaconf_utils import ( + apply_overrides, + create_omegaconf_dict_config, + parse_hydra_overrides, +) +from megatron.bridge.utils.common_utils import get_rank_safe +from omegaconf import OmegaConf + +# Import GenericHeterogeneousProvider so it can be instantiated when loading +# checkpoint configs that reference it (e.g., run_config.yaml with +# _target_: modelopt.torch.puzzletron.export.mbridge.base.GenericHeterogeneousProvider) +import modelopt.torch.puzzletron.export.mbridge # noqa: F401 + +logger: logging.Logger = logging.getLogger(__name__) + + +def create_distillation_config() -> ModelOptDistillConfig: + """Create KD config with output layer distillation only.""" + return ModelOptDistillConfig( + logit_layers=["output_layer", "output_layer"], + intermediate_layer_pairs=[], + skip_lm_loss=True, + kd_loss_scale=1.0, + logit_kl_temperature=1.0, + ) + + +def create_base_config( + student_model_provider, + data_path: str, + student_ckpt: str, + output_dir: str, + use_bf16: bool, + use_fp16: bool, +) -> ConfigContainer: + """Create base ConfigContainer with defaults.""" + return ConfigContainer( + model=student_model_provider, + train=TrainingConfig(global_batch_size=1, micro_batch_size=1, train_iters=100), + optimizer=OptimizerConfig( + optimizer="adam", + lr=1e-4, + min_lr=1e-5, + weight_decay=0.01, + bf16=use_bf16, + fp16=use_fp16, + ), + scheduler=SchedulerConfig( + lr_decay_style="linear", + lr_warmup_iters=0, + start_weight_decay=0.01, + end_weight_decay=0.01, + 
weight_decay_incr_style="constant", + ), + dataset=GPTDatasetConfig( + random_seed=1234, + blend=[[data_path], [1.0]], + split="9999,8,2", + seq_length=student_model_provider.seq_length, + reset_attention_mask=False, + reset_position_ids=False, + eod_mask_loss=False, + dataloader_type="single", + ), + checkpoint=CheckpointConfig(load=student_ckpt, save=output_dir), + logger=LoggerConfig(), + tokenizer=TokenizerConfig(tokenizer_type="HuggingFaceTokenizer", tokenizer_model=None), + validation=ValidationConfig(eval_interval=500, eval_iters=100), + ddp=DistributedDataParallelConfig(grad_reduce_in_fp32=True), + dist=DistributedInitConfig(), + rng=RNGConfig(), + rerun_state_machine=RerunStateMachineConfig(), + ) + + +def merge_checkpoint_configs( + cfg: ConfigContainer, checkpoint_path: str, use_bf16: bool, use_fp16: bool +) -> None: + """Merge tokenizer and optimizer configs from checkpoint if available.""" + try: + run_config_path = get_checkpoint_run_config_filename(checkpoint_path) + checkpoint_cfg = ConfigContainer.from_yaml(run_config_path) + if checkpoint_cfg.tokenizer is not None: + cfg.tokenizer = checkpoint_cfg.tokenizer + if checkpoint_cfg.optimizer is not None: + for key, value in checkpoint_cfg.optimizer.__dict__.items(): + if ( + value is not None + and hasattr(cfg.optimizer, key) + and key not in ("bf16", "fp16") + ): + setattr(cfg.optimizer, key, value) + # Ensure bf16/fp16 are set correctly based on model dtype + cfg.optimizer.bf16 = use_bf16 + cfg.optimizer.fp16 = use_fp16 + except Exception as e: + logger.warning(f"Could not load additional configs from checkpoint: {e}") + + +def parse_cli_args() -> tuple[argparse.Namespace, list[str]]: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Knowledge distillation with AnyModel checkpoints", + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument( + "--student-mbridge-ckpt", + type=str, + required=True, + help="Path to student checkpoint in MBridge format (must be iter_XXXXXXX directory).", + ) + parser.add_argument( + "--teacher-mbridge-ckpt", + type=str, + required=True, + help="Path to teacher checkpoint in MBridge format (must be iter_XXXXXXX directory).", + ) + parser.add_argument( + "--data-path", + type=str, + required=True, + help="Path to tokenized dataset (without .bin extension).", + ) + parser.add_argument( + "--output-dir", + type=str, + default="./distilled_output", + help="Output directory for distilled checkpoint.", + ) + parser.add_argument( + "--config-file", + type=str, + default=None, + help="Path to YAML OmegaConf override file (optional).", + ) + parser.add_argument("--debug", action="store_true", help="Enable debug logging") + + args, cli_overrides = parser.parse_known_args() + return args, cli_overrides + + +def main() -> None: + """Main distillation function.""" + args, cli_overrides = parse_cli_args() + + if args.debug: + logging.basicConfig(level=logging.DEBUG) + else: + logging.basicConfig(level=logging.INFO) + + logger.info("Megatron-Bridge Knowledge Distillation Script (AnyModel)") + logger.info("=" * 70) + + # Load model configs from checkpoints + logger.info("Loading model configs from MBridge checkpoints...") + student_model_provider, _ = load_model_config(args.student_mbridge_ckpt) + teacher_model_provider, _ = load_model_config(args.teacher_mbridge_ckpt) + + # Detect model dtype for optimizer config + model_params_dtype = getattr(student_model_provider, "params_dtype", torch.float32) + use_bf16 = model_params_dtype == torch.bfloat16 + 
use_fp16 = model_params_dtype == torch.float16 + + # Create base config with defaults + cfg = create_base_config( + student_model_provider, + args.data_path, + args.student_mbridge_ckpt, + args.output_dir, + use_bf16, + use_fp16, + ) + + # Merge tokenizer and optimizer from checkpoint if available + merge_checkpoint_configs(cfg, args.student_mbridge_ckpt, use_bf16, use_fp16) + + # Create distillation config and convert to DistillationProvider + kd_config = create_distillation_config() + cfg.model = convert_to_distillation_provider( + student_provider=student_model_provider, + teacher_provider=teacher_model_provider, + kd_config=kd_config, + ) + + # Apply YAML and CLI overrides + merged_omega_conf, excluded_fields = create_omegaconf_dict_config(cfg) + if args.config_file: + if not os.path.exists(args.config_file): + logger.error(f"Override YAML file not found: {args.config_file}") + sys.exit(1) + yaml_overrides_omega = OmegaConf.load(args.config_file) + merged_omega_conf = OmegaConf.merge(merged_omega_conf, yaml_overrides_omega) + if cli_overrides: + merged_omega_conf = parse_hydra_overrides(merged_omega_conf, cli_overrides) + apply_overrides(cfg, OmegaConf.to_container(merged_omega_conf, resolve=True), excluded_fields) + + # Sync model seq_length with dataset sequence_length if they differ + if ( + hasattr(cfg.dataset, "sequence_length") + and cfg.dataset.sequence_length != cfg.model.seq_length + ): + cfg.model.seq_length = cfg.dataset.sequence_length + if hasattr(cfg.model, "teacher") and cfg.model.teacher is not None: + cfg.model.teacher.seq_length = cfg.dataset.sequence_length + + if get_rank_safe() == 0: + logger.info("--- Final Configuration ---") + cfg.print_yaml() + + # Run distillation + logger.info("Starting distillation training...") + distill(cfg) + + logger.info(f"āœ“ Distillation complete! Checkpoints saved to: {args.output_dir}") + + +if __name__ == "__main__": + main() diff --git a/examples/puzzletron/mbridge_distillation/import_anymodel_to_mbridge.py b/examples/puzzletron/mbridge_distillation/import_anymodel_to_mbridge.py new file mode 100644 index 000000000..1127d228d --- /dev/null +++ b/examples/puzzletron/mbridge_distillation/import_anymodel_to_mbridge.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Import AnyModel checkpoint to Megatron-Bridge format. + +This script converts a HuggingFace AnyModel checkpoint to Megatron-Bridge format, +similar to NeMo's convert_hf_to_nemo. 
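+
+The output is a Megatron distributed checkpoint; its iter_XXXXXXX subdirectory
+is what distill_anymodel.py expects for --student-mbridge-ckpt and
+--teacher-mbridge-ckpt.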
+ +Usage: + cd /workspace/Model-Optimizer + + # TODO: remove this once Megatron-Bridge is installed in the environment + export PYTHONPATH="/workspace/Megatron-Bridge/src:/workspace/Model-Optimizer:${PYTHONPATH}" + + torchrun --nproc_per_node=1 examples/puzzletron/mbridge_distillation/import_anymodel_to_mbridge.py \ + --input-ckpt-path /path/to/anymodel/checkpoint \ + --output-ckpt-path /path/to/save/mbridge/checkpoint +""" + +import argparse +from pathlib import Path + +from megatron.bridge import AutoBridge + +# Import all heterogeneous bridges to register them +# This will override homogeneous bridges (e.g., LlamaBridge, Qwen3Bridge) with +# heterogeneous versions (PuzzletronLlamaAnyModelBridge, PuzzletronQwen3AnyModelBridge) +# that support block_configs for AnyModel checkpoints. +import modelopt.torch.puzzletron.export.mbridge # noqa: F401 + + +def parse_args() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Convert AnyModel checkpoint to Megatron-Bridge format", + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument( + "--input-ckpt-path", + type=str, + required=True, + help="Path to input AnyModel checkpoint (HuggingFace format)", + ) + parser.add_argument( + "--output-ckpt-path", + type=str, + required=True, + help="Path to save output Megatron-Bridge checkpoint", + ) + return parser.parse_args() + + +def main() -> None: + """Main function to import HF AnyModel and save as Megatron checkpoint.""" + args = parse_args() + + input_path = Path(args.input_ckpt_path) + output_path = Path(args.output_ckpt_path) + + print(f"Importing AnyModel checkpoint from: {input_path}") + print(f"Saving Megatron-Bridge checkpoint to: {output_path}") + print() + + # Create output directory if it doesn't exist + output_path.mkdir(parents=True, exist_ok=True) + + # Import and save as Megatron checkpoint + AutoBridge.import_ckpt( + hf_model_id=str(input_path), + megatron_path=str(output_path), + trust_remote_code=True, + ) + + print(f"\nāœ“ Successfully saved Megatron-Bridge checkpoint to: {output_path}") + + +if __name__ == "__main__": + main() diff --git a/modelopt/torch/puzzletron/export/mbridge/__init__.py b/modelopt/torch/puzzletron/export/mbridge/__init__.py new file mode 100644 index 000000000..471e68984 --- /dev/null +++ b/modelopt/torch/puzzletron/export/mbridge/__init__.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Megatron-Bridge adapters for Puzzletron AnyModel checkpoints. + +This module provides bridges for converting Puzzletron AnyModel checkpoints +(heterogeneous layer architectures) to Megatron-Core format via Megatron-Bridge. 
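+
+Importing this module registers the bridges as a side effect, so a minimal
+conversion sketch (paths are placeholders) mirrors import_anymodel_to_mbridge.py:
+
+    import modelopt.torch.puzzletron.export.mbridge  # noqa: F401
+    from megatron.bridge import AutoBridge
+
+    AutoBridge.import_ckpt(
+        hf_model_id="/path/to/anymodel/checkpoint",
+        megatron_path="/path/to/mbridge/checkpoint",
+        trust_remote_code=True,
+    )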
+""" + +# Import to register bridges (side effect) +from modelopt.torch.puzzletron.export.mbridge.base import HeterogeneousBridgeMixin +from modelopt.torch.puzzletron.export.mbridge.llama import ( # noqa: F401 + PuzzletronLlamaAnyModelBridge, +) +from modelopt.torch.puzzletron.export.mbridge.qwen3 import ( # noqa: F401 + PuzzletronQwen3AnyModelBridge, +) + +__all__ = [ + "HeterogeneousBridgeMixin", + "PuzzletronLlamaAnyModelBridge", + "PuzzletronQwen3AnyModelBridge", +] diff --git a/modelopt/torch/puzzletron/export/mbridge/base.py b/modelopt/torch/puzzletron/export/mbridge/base.py new file mode 100644 index 000000000..eff9f0b73 --- /dev/null +++ b/modelopt/torch/puzzletron/export/mbridge/base.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Mixin class for bridges that support heterogeneous layer architectures. + +This module provides a mixin class for converting models with block_configs +(heterogeneous layer configurations) to Megatron-Core format via Megatron-Bridge. +""" + +import dataclasses +import json +from collections.abc import Callable +from dataclasses import dataclass + +from megatron.bridge.models.gpt_provider import GPTModelProvider +from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM +from megatron.bridge.models.transformer_config import HeterogeneousTransformerConfig +from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import ( + get_gpt_heterogeneous_layer_spec, +) +from megatron.core.transformer.spec_utils import ModuleSpec + + +def heterogeneous_layer_spec(config) -> ModuleSpec: + """Get GPT heterogeneous layer spec using Transformer Engine.""" + return get_gpt_heterogeneous_layer_spec(config, use_te=True) + + +@dataclass +class GenericHeterogeneousProvider(GPTModelProvider, HeterogeneousTransformerConfig): + """Generic provider for AnyModel checkpoints with block_configs.""" + + # Heterogeneous configuration fields + heterogeneous_layers_config_path: str | None = None + heterogeneous_layers_config_encoded_json: str = "" + transformer_layer_spec: ModuleSpec | Callable = heterogeneous_layer_spec + + def __getattr__(self, name: str): + """Handle missing attributes for OmegaConf compatibility. + + OmegaConf conversion tries to access per_block_parameters which may not + be initialized when loading from YAML. Return empty list as fallback. + """ + if name == "per_block_parameters": + return [] + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") + + +class HeterogeneousBridgeMixin: + """Mixin for bridges supporting heterogeneous layer architectures (block_configs). + + Must be used with multiple inheritance alongside a model-specific bridge. 
+ Example: class PuzzletronLlamaAnyModelBridge(HeterogeneousBridgeMixin, LlamaBridge) + """ + + def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> GPTModelProvider: + """Convert HF AnyModel config to Megatron GPTModelProvider.""" + + parent_provider = super().provider_bridge(hf_pretrained) # type: ignore[misc] + + provider_kwargs = dataclasses.asdict(parent_provider) + + provider_kwargs["heterogeneous_layers_config_encoded_json"] = ( + self._build_heterogeneous_config_json(hf_pretrained.config) + ) + return GenericHeterogeneousProvider(**provider_kwargs) + + @classmethod + def megatron_to_hf_config(cls, provider: GPTModelProvider) -> dict: + raise NotImplementedError( + "megatron_to_hf_config() not yet implemented for AnyModel bridges. " + "AnyModel bridges require special handling for heterogeneous layer configurations." + ) + + def _build_heterogeneous_config_json(self, hf_config) -> str: + """Build heterogeneous layers config JSON from HF config.""" + + hf_config_dict = json.loads(hf_config.to_json_string()) + + mcore_block_configs = [ + self._convert_block_config(block) for block in hf_config_dict["block_configs"] + ] + return json.dumps({"block_configs": mcore_block_configs}, ensure_ascii=False) + + def _convert_block_config(self, block: dict) -> dict: + """Convert a single block config from HF format to MCore format.""" + return { + "attention": self._convert_attention_config(block["attention"]), + "ffn": self._convert_ffn_config(block["ffn"]), + } + + def _convert_attention_config(self, attention_config: dict) -> dict: + """Convert attention config from HF format to MCore format.""" + attention_config = attention_config.copy() + attention_config["num_query_groups"] = attention_config.pop("num_key_value_heads") + return attention_config + + def _convert_ffn_config(self, ffn_config: dict) -> dict: + """Convert FFN/MLP config from HF format to MCore format.""" + ffn_config = ffn_config.copy() + ffn_config["ffn_hidden_size"] = ffn_config.pop("intermediate_size") + return ffn_config diff --git a/modelopt/torch/puzzletron/export/mbridge/distillation_provider.py b/modelopt/torch/puzzletron/export/mbridge/distillation_provider.py new file mode 100644 index 000000000..d9d7f2266 --- /dev/null +++ b/modelopt/torch/puzzletron/export/mbridge/distillation_provider.py @@ -0,0 +1,173 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO: Upstream this fix to Megatron-Bridge and remove this local copy. 
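+
+"""Local copy (with fixes) of Megatron-Bridge's distillation provider for Puzzletron AnyModels.
+
+``convert_to_distillation_provider()`` re-bases the student provider's class so that
+``provide()`` builds the student model, finalizes and builds the teacher, and wraps
+both in a ModelOpt ``kd_loss`` DistillationModel.
+"""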
+ +import logging +from dataclasses import dataclass, fields +from typing import TYPE_CHECKING, Any, Optional + +from megatron.bridge.models.gpt_provider import GPTModelProvider +from megatron.bridge.models.mamba.mamba_provider import MambaModelProvider +from megatron.bridge.models.transformer_config import TransformerConfig +from megatron.core.models.gpt import GPTModel as MCoreGPTModel + +import modelopt.torch.distill as mtd +import modelopt.torch.distill.plugins.megatron as mtd_mcore + +if TYPE_CHECKING: + from megatron.bridge.training.post_training.distillation import ModelOptDistillConfig + + +logger = logging.getLogger(__name__) + + +@dataclass +class DistillationProvider(TransformerConfig): + """Provider for Megatron Core GPT models in distillation mode. + + Please use `convert_to_distillation_provider()` to create an instance of this class. + """ + + teacher: Optional[GPTModelProvider | MambaModelProvider] = None + kd_config: Optional["ModelOptDistillConfig"] = None + + def __init__(self, *args, **kwargs): + raise NotImplementedError( + "Use `convert_to_distillation_provider()` to create an instance of this class." + ) + + def __post_init__(self): + assert getattr(self, "teacher", None) is not None, "Teacher model must be provided." + + shared_attrs = [ + "tensor_model_parallel_size", + "pipeline_model_parallel_size", + "context_parallel_size", + "seq_length", + "pipeline_dtype", + ] + for attr in shared_attrs: + if getattr(self, attr) != getattr(self.teacher, attr): + raise ValueError(f"Student and teacher providers must have the same {attr}.") + + # Logits are overwritten in-place when TE cross-entropy loss is enabled, so switch it back to native version. + self.cross_entropy_fusion_impl = "native" + + # Hack to dynamically subclass other providers and still use their methods + self._super_class = self.__class__.__bases__[0] + + def provide(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreGPTModel: + """Configure and instantiate a ModelOpt DistillationModel based on this configuration. + + Args: + pre_process: Whether to include pre-processing in the model, defaults to first pipeline stage + post_process: Whether to include post-processing in the model, defaults to last pipeline stage + vp_stage: Virtual pipeline stage + + Returns: + MCoreGPTModel: Configured ModelOpt DistillationModel instance + """ + if vp_stage is not None: + raise ValueError("ModelOpt KD currently does not support virtual-pipeline parallel.") + + assert self.teacher is not None, "Teacher model must be provided." + student_model = self._super_class.provide(self, pre_process, post_process, vp_stage) # type: ignore[attr-defined] + + # Finalize teacher provider before creating model (required for heterogeneous models). + # + # per_block_parameters is an attribute of HeterogeneousTransformerConfig (defined in + # MCoreHeterogeneousTransformerConfig, heterogeneous_config.py:197). It's created during + # provider creation (bridge.to_megatron_provider()), but finalize() ensures they're consistent + # with current parallelism settings and distributed context. Student model creation (above) + # initializes parallel_state (process groups, TP/PP config), which weight loading/scatter + # requires. During teacher model creation, get_config_for_layer() is called (transformer_block.py:341) + # for each layer, which uses per_block_parameters and current tensor_model_parallel_size to + # determine layer architecture. 
Without finalize() in this context, architecture expectations + # don't match checkpoint weights, causing: + # ValueError: ProcessGroupNCCL::scatter: invalid tensor size at index 0 + # (expected (2880, 4096), got (3584, 4096)) + # + # Note: This explanation needs to be confirmed yet. + self.teacher.finalize() + + # Hack to get teacher's pre-wrap hooks called to potentially load HF weights + teacher_model = self.teacher.provide_distributed_model( + wrap_with_ddp=False, mixed_precision_wrapper=None + )[0] + + kd_cfg = mtd_mcore.setup_distillation_config( + self.kd_config, student_model.config, teacher_model.config + ) + modelopt_cfg = { + "teacher_model": teacher_model, + "criterion": kd_cfg.criterion, + "loss_balancer": kd_cfg.loss_balancer, + } + kd_model = mtd.convert(student_model, mode=[("kd_loss", modelopt_cfg)]) + mtd_mcore.adjust_distillation_model_for_mcore(kd_model, kd_cfg) + + return kd_model + + def to_cfg_dict(self) -> dict[str, Any]: + """Custom method to save equivalent to the original provider class. + + Used by `_ConfigContainerBase` to serialize the main `ConfigContainer` to YAML. + There is no need to restore a `DistillationProvider` from the run config file, as + it can always be re-converted using the original student provider. + + Returns: + Dictionary representation of this provider class + """ + from megatron.bridge.training.utils.config_utils import _ConfigContainerBase + + result = {"_target_": f"{self._super_class.__module__}.{self._super_class.__qualname__}"} + for field in fields(self): + if field.name.startswith("_") or field.name in ["teacher", "kd_config"]: + continue + result[field.name] = _ConfigContainerBase._convert_value_to_dict( + getattr(self, field.name) + ) + return result + + def __setattr__(self, name, value): + super().__setattr__(name, value) + # Mirror to teacher if it has that attribute + if hasattr(self.teacher, name): + setattr(self.teacher, name, value) + + +def convert_to_distillation_provider( + student_provider: GPTModelProvider | MambaModelProvider, + teacher_provider: GPTModelProvider | MambaModelProvider, + kd_config: Optional["ModelOptDistillConfig"] = None, +) -> "DistillationProvider": + """Convert a given model provider to a DistillationProvider.""" + + assert isinstance(student_provider, (GPTModelProvider, MambaModelProvider)), ( + "Student provider must be a subclass of GPTModelProvider or MambaModelProvider." + ) + assert isinstance(teacher_provider, (GPTModelProvider, MambaModelProvider)), ( + "Teacher provider must be a subclass of GPTModelProvider or MambaModelProvider." + ) + + DistillationProvider.__bases__ = (type(student_provider),) + student_provider.__class__ = DistillationProvider + + student_provider.teacher = teacher_provider + student_provider.kd_config = kd_config + student_provider.__post_init__() + + return student_provider diff --git a/modelopt/torch/puzzletron/export/mbridge/llama.py b/modelopt/torch/puzzletron/export/mbridge/llama.py new file mode 100644 index 000000000..d193a02e9 --- /dev/null +++ b/modelopt/torch/puzzletron/export/mbridge/llama.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Megatron Bridge for Puzzletron Llama-based AnyModel heterogeneous checkpoints.""" + +from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge +from megatron.bridge.models.llama.llama_bridge import LlamaBridge +from megatron.core.models.gpt.gpt_model import GPTModel +from transformers import LlamaForCausalLM + +from modelopt.torch.puzzletron.export.mbridge.base import HeterogeneousBridgeMixin + + +@MegatronModelBridge.register_bridge(source=LlamaForCausalLM, target=GPTModel, model_type="llama") +class PuzzletronLlamaAnyModelBridge(HeterogeneousBridgeMixin, LlamaBridge): + """ + Megatron Bridge for Puzzletron Llama-based AnyModel checkpoints. + + Extends LlamaBridge with support for heterogeneous layer architectures (block_configs). + All Llama-specific settings are inherited from LlamaBridge. + """ + + # provider_bridge() is inherited from HeterogeneousBridgeMixin + # It automatically reuses LlamaBridge.provider_bridge() and adds heterogeneous config + # mapping_registry() is inherited from LlamaBridge diff --git a/modelopt/torch/puzzletron/export/mbridge/qwen3.py b/modelopt/torch/puzzletron/export/mbridge/qwen3.py new file mode 100644 index 000000000..664e87e08 --- /dev/null +++ b/modelopt/torch/puzzletron/export/mbridge/qwen3.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Megatron Bridge for Puzzletron Qwen3-based AnyModel heterogeneous checkpoints.""" + +from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge +from megatron.bridge.models.qwen.qwen3_bridge import Qwen3Bridge +from megatron.core.models.gpt.gpt_model import GPTModel +from transformers import Qwen3ForCausalLM + +from modelopt.torch.puzzletron.export.mbridge.base import HeterogeneousBridgeMixin + + +@MegatronModelBridge.register_bridge(source=Qwen3ForCausalLM, target=GPTModel, model_type="qwen3") +class PuzzletronQwen3AnyModelBridge(HeterogeneousBridgeMixin, Qwen3Bridge): + """ + Megatron Bridge for Puzzletron Qwen3-based AnyModel checkpoints. + + Extends Qwen3Bridge with support for heterogeneous layer architectures (block_configs). + All Qwen3-specific settings are inherited from Qwen3Bridge. + """ + + # provider_bridge() is inherited from HeterogeneousBridgeMixin + # It automatically reuses Qwen3Bridge.provider_bridge() and adds heterogeneous config + # mapping_registry() is inherited from Qwen3Bridge