diff --git a/flagscale/backends/Megatron-LM/megatron/core/transformer/transformer_config.py.patch b/flagscale/backends/Megatron-LM/megatron/core/transformer/transformer_config.py.patch index f947218e3c..1a297b342a 100644 --- a/flagscale/backends/Megatron-LM/megatron/core/transformer/transformer_config.py.patch +++ b/flagscale/backends/Megatron-LM/megatron/core/transformer/transformer_config.py.patch @@ -1,8 +1,18 @@ diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py -index 5ff62f74c..3c0571b2d 100644 +index 5ff62f74c..46bd78b24 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py -@@ -317,6 +317,15 @@ class TransformerConfig(ModelParallelConfig): +@@ -212,6 +212,9 @@ class TransformerConfig(ModelParallelConfig): + moe_deepep_num_sms: int = 20 + """Number of SMs to use for DeepEP.""" + ++ untie_embeddings_and_output_weights: bool = False ++ """Untie the input word embedding matrix from the output layer's weight matrix. If False, the two matrices are shared (tied).""" ++ + #################### + # initialization + #################### +@@ -317,6 +320,15 @@ class TransformerConfig(ModelParallelConfig): the number of transformer layers to recompute within each pipeline stage. Must be None for 'selective' activation checkpointing."""  @@ -18,7 +28,7 @@ index 5ff62f74c..3c0571b2d 100644 distribute_saved_activations: Optional[bool] = None """If True, distribute recomputed activations across the model parallel group.""" -@@ -417,6 +426,12 @@ class TransformerConfig(ModelParallelConfig): +@@ -417,6 +429,12 @@ class TransformerConfig(ModelParallelConfig): together with fp4 mode (i.e., TransformerConfig.fp4 is not None). Note that not all parameters will be converted to fp4; for example, biases will remain unchanged.""" @@ -31,7 +41,7 @@ index 5ff62f74c..3c0571b2d 100644 #################### # MoE related #################### -@@ -644,6 +659,9 @@ class TransformerConfig(ModelParallelConfig): +@@ -644,6 +662,9 @@ class TransformerConfig(ModelParallelConfig): config_logger_dir: str = "" """When non-empty, dumps entry-point configs to config_logger_dir""" @@ -41,7 +51,7 @@ index 5ff62f74c..3c0571b2d 100644 flash_decode: bool = False """ Use the optimized flash decoding kernel during inference. """ -@@ -705,6 +723,31 @@ class TransformerConfig(ModelParallelConfig): +@@ -705,6 +726,31 @@ class TransformerConfig(ModelParallelConfig): """Transformer implementation to use. Options are 'transformer_engine' for Transformer Engine and 'local' for MCore.""" @@ -73,7 +83,7 @@ index 5ff62f74c..3c0571b2d 100644 def __post_init__(self): """Python dataclass method that is used to modify attributes after initialization. 
See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more -@@ -1481,6 +1524,9 @@ class TransformerConfig(ModelParallelConfig): +@@ -1481,6 +1527,9 @@ class TransformerConfig(ModelParallelConfig): f"the number of layers ({self.num_layers})" ) @@ -83,7 +93,7 @@ index 5ff62f74c..3c0571b2d 100644 @dataclass class MLATransformerConfig(TransformerConfig): -@@ -1569,3 +1615,4 @@ class MLATransformerConfig(TransformerConfig): +@@ -1569,3 +1618,4 @@ class MLATransformerConfig(TransformerConfig): assert ( self.apply_rope_fusion is False ), "Rope Fusion is not compatible with caching latents" diff --git a/flagscale/backends/Megatron-LM/megatron/training/arguments.py.patch b/flagscale/backends/Megatron-LM/megatron/training/arguments.py.patch index e1a001954c..a571a35b75 100644 --- a/flagscale/backends/Megatron-LM/megatron/training/arguments.py.patch +++ b/flagscale/backends/Megatron-LM/megatron/training/arguments.py.patch @@ -1,5 +1,5 @@ diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py -index 1120c7529..ffb59dac4 100644 +index 1120c7529..3fbfe4d5c 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -67,6 +67,7 @@ def add_megatron_arguments(parser: argparse.ArgumentParser): @@ -22,7 +22,18 @@ index 1120c7529..ffb59dac4 100644 # Custom arguments. if extra_args_provider is not None: -@@ -368,63 +374,68 @@ def validate_args(args, defaults={}): +@@ -123,6 +129,10 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False): + assert MultiStorageClientFeature.is_enabled() is False + print('WARNING: The MSC feature is disabled.') + ++ #checkout save hf config path ++ if args.save_hf: ++ assert args.hf_config_path is not None ++ + return args + + +@@ -368,63 +378,68 @@ def validate_args(args, defaults={}): "legacy model format only supports the 'torch' checkpoint format." update_use_dist_ckpt(args) @@ -142,7 +153,7 @@ index 1120c7529..ffb59dac4 100644 if args.hierarchical_context_parallel_sizes: from numpy import prod -@@ -433,8 +444,8 @@ def validate_args(args, defaults={}): +@@ -433,8 +448,8 @@ def validate_args(args, defaults={}): assert args.hierarchical_context_parallel_sizes is not None, \ "--hierarchical-context-parallel-sizes must be set when a2a+p2p is used in cp comm" @@ -153,7 +164,7 @@ index 1120c7529..ffb59dac4 100644 # Deprecated arguments. 
assert args.batch_size is None, '--batch-size argument is no longer ' \ -@@ -530,6 +541,7 @@ def validate_args(args, defaults={}): +@@ -530,6 +545,7 @@ def validate_args(args, defaults={}): if args.virtual_pipeline_model_parallel_size == 1: args.virtual_pipeline_model_parallel_size = None elif args.num_layers_per_virtual_pipeline_stage is not None or args.num_virtual_stages_per_pipeline_rank is not None: @@ -161,7 +172,7 @@ index 1120c7529..ffb59dac4 100644 if args.num_virtual_stages_per_pipeline_rank is None: assert args.decoder_first_pipeline_num_layers is None and args.decoder_last_pipeline_num_layers is None, \ 'please use --num-virtual-stages-per-pipeline-rank to specify virtual pipeline parallel degree when enable uneven pipeline parallelism' -@@ -571,8 +583,9 @@ def validate_args(args, defaults={}): +@@ -571,8 +587,9 @@ def validate_args(args, defaults={}): if args.account_for_loss_in_pipeline_split: num_layers += 1 @@ -173,7 +184,7 @@ index 1120c7529..ffb59dac4 100644 if args.virtual_pipeline_model_parallel_size is not None: if args.overlap_p2p_comm: -@@ -796,12 +809,22 @@ def validate_args(args, defaults={}): +@@ -796,12 +813,22 @@ def validate_args(args, defaults={}): # Checks. if args.ffn_hidden_size is None: if args.swiglu: @@ -202,7 +213,7 @@ index 1120c7529..ffb59dac4 100644 else: args.ffn_hidden_size = 4 * args.hidden_size -@@ -1175,6 +1198,147 @@ def validate_args(args, defaults={}): +@@ -1175,6 +1202,147 @@ def validate_args(args, defaults={}): args.recompute_granularity != 'full' ), 'recompute_granularity must not be full when CUDA Graphs are enabled.' @@ -350,7 +361,7 @@ index 1120c7529..ffb59dac4 100644 # Print arguments. _print_args("arguments", args) -@@ -1585,6 +1749,8 @@ def _add_network_size_args(parser): +@@ -1585,6 +1753,8 @@ def _add_network_size_args(parser): help='Which normalization technique to use.') group.add_argument('--norm-epsilon', type=float, default=1e-5, help='Epsilon for layer norm and RMS norm.') @@ -359,7 +370,7 @@ index 1120c7529..ffb59dac4 100644 group.add_argument('--apply-layernorm-1p', action='store_true', help='Adjust LayerNorm weights such that they are centered ' 'around zero. This improves numerical stability.') -@@ -1608,6 +1774,10 @@ def _add_network_size_args(parser): +@@ -1608,6 +1778,10 @@ def _add_network_size_args(parser): group.add_argument('--glu-linear-offset', type=float, default=0.0, help='Offset term in the GLU activation function: activation_func(x[0]) * (x[1] + offset). 
' 'Only used when gated_linear_unit is True') @@ -370,7 +381,7 @@ index 1120c7529..ffb59dac4 100644 group.add_argument('--onnx-safe', type=bool, required=False, help='Use workarounds for known problems with ' 'Torch ONNX exporter') -@@ -1820,6 +1990,14 @@ def _add_logging_args(parser): +@@ -1820,6 +1994,14 @@ def _add_logging_args(parser): help='The wandb experiment name.') group.add_argument('--wandb-save-dir', type=str, default='', help='Path to save the wandb results locally.') @@ -385,7 +396,7 @@ index 1120c7529..ffb59dac4 100644 group.add_argument('--logging-level', type=int, default=None, help='Set default logging level') return parser -@@ -1854,6 +2032,15 @@ def _add_regularization_args(parser): +@@ -1854,6 +2036,15 @@ def _add_regularization_args(parser): 'numerical stability') group.add_argument('--sgd-momentum', type=float, default=0.9, help='Momentum factor for sgd') @@ -401,7 +412,7 @@ index 1120c7529..ffb59dac4 100644 return parser -@@ -2001,6 +2188,25 @@ def _add_training_args(parser): +@@ -2001,6 +2192,25 @@ def _add_training_args(parser): '"shared_experts": recompute the shared experts in the MoE layer.' '"moe_act", "layernorm", and "mla_up_proj" use output-discarding checkpointing, ' '"core_attn", "mlp", "moe", and "shared_experts" use normal checkpointing.') @@ -427,7 +438,7 @@ index 1120c7529..ffb59dac4 100644 group.add_argument('--no-clone-scatter-output-in-embedding', action='store_false', help='If not set, clone the output of the scatter in embedding layer to GC original tensor.', dest='clone_scatter_output_in_embedding') -@@ -2087,6 +2293,10 @@ def _add_training_args(parser): +@@ -2087,6 +2297,10 @@ def _add_training_args(parser): help='Total number of samples to train over all ' 'training runs. Note that either train-iters or ' 'train-samples should be provided.') @@ -438,7 +449,7 @@ index 1120c7529..ffb59dac4 100644 group.add_argument('--log-interval', type=int, default=100, help='Report loss and timing interval.') group.add_argument('--exit-interval', type=int, default=None, -@@ -2140,7 +2350,7 @@ def _add_training_args(parser): +@@ -2140,7 +2354,7 @@ def _add_training_args(parser): help='Enable bias only in the QKV linear layers', dest='add_qkv_bias') group.add_argument('--optimizer', type=str, default='adam', @@ -447,7 +458,7 @@ index 1120c7529..ffb59dac4 100644 help='Optimizer function') group.add_argument('--optimizer-cpu-offload', action='store_true', help='Offload optimizer state to CPU') -@@ -2210,6 +2420,10 @@ def _add_training_args(parser): +@@ -2210,6 +2424,10 @@ def _add_training_args(parser): help='The communicator group names to use high priority streams.') group.add_argument('--use-te-activation-func', action='store_true', help='Use activation function kernel from Transformer Engine in MLP module.') @@ -458,7 +469,7 @@ index 1120c7529..ffb59dac4 100644 return parser -@@ -2268,11 +2482,26 @@ def _add_learning_rate_args(parser): +@@ -2268,11 +2486,26 @@ def _add_learning_rate_args(parser): 'and initial warmup, the learning rate at each ' 'iteration would be different.') group.add_argument('--lr-decay-style', type=str, default='linear', @@ -486,7 +497,7 @@ index 1120c7529..ffb59dac4 100644 group.add_argument('--lr-decay-iters', type=int, default=None, help='number of iterations to decay learning rate over,' ' If None defaults to `--train-iters`') -@@ -2331,6 +2560,8 @@ def _add_checkpointing_args(parser): +@@ -2331,6 +2564,8 @@ def _add_checkpointing_args(parser): group.add_argument('--save-retain-interval', type=int, default=None, help='Number of 
iterations between retained checkpoints (other' 'checkpoints _except the last checkpoint_ are automatically deleted).') @@ -495,7 +506,7 @@ index 1120c7529..ffb59dac4 100644 group.add_argument('--no-save-optim', action='store_true', default=None, help='Do not save current optimizer.') group.add_argument('--no-save-rng', action='store_true', default=None, -@@ -2380,6 +2611,8 @@ def _add_checkpointing_args(parser): +@@ -2380,6 +2615,8 @@ def _add_checkpointing_args(parser): group.add_argument('--no-use-tokenizer-model-from-checkpoint-args', action='store_false', dest='use_tokenizer_model_from_checkpoint_args', help='If set, do not use tokenizer model path from checkpoint') @@ -504,7 +515,21 @@ index 1120c7529..ffb59dac4 100644 group.add_argument('--exit-on-missing-checkpoint', action='store_true', help="If '--load' is set, but checkpoint is not found " "(e.g., path typo), then exit instead of random " -@@ -2541,7 +2774,7 @@ def _add_distributed_args(parser): +@@ -2455,6 +2692,13 @@ def _add_checkpointing_args(parser): + group.add_argument('--load-model-opt-format', action='store_true', + help='Load a checkpoint for TensorRT model optimizer (nvidia-modelopt).' + 'This function can also be used to load NeMo .nemo sharded checkpoints.') ++ group.add_argument('--load-hf', action='store_true',default=None, ++ help='Use the HF format for warm start, and save it in the torch_dict' ++ 'format while also saving it in the HF format.') ++ group.add_argument('--save-hf', action='store_true',default=None, ++ help='Save as Hugging Face format checkpoint.') ++ group.add_argument('--hf-config-path', default=None, ++ help='Load the HF model from config.') + return parser + + +@@ -2541,7 +2785,7 @@ def _add_distributed_args(parser): default=False, help='if set, overlap pipeline parallel communication in warmup and flush', dest='overlap_p2p_comm_warmup_flush') group.add_argument('--distributed-backend', default='nccl', @@ -513,7 +538,7 @@ index 1120c7529..ffb59dac4 100644 help='Which backend to use for distributed training.') group.add_argument('--distributed-timeout-minutes', type=int, default=10, help='Timeout minutes for torch.distributed.') -@@ -2592,6 +2825,11 @@ def _add_distributed_args(parser): +@@ -2592,6 +2836,11 @@ def _add_distributed_args(parser): 'complete it instead. Also turns on ' '--use-cpu-initialization flag. This is for ' 'external DDP manager.' 
) @@ -525,7 +550,7 @@ index 1120c7529..ffb59dac4 100644 group.add_argument('--account-for-embedding-in-pipeline-split', action='store_true', default=False, help='If set, *input* embedding layer will be treated as a standard transformer' 'layer in the context of partition and placement for pipeline parallelism.') -@@ -2636,6 +2874,10 @@ def _add_distributed_args(parser): +@@ -2636,6 +2885,10 @@ def _add_distributed_args(parser): help='If set, keep the fp8 transpose cache when using Megatron FSDP.') group.add_argument('--enable-full-sharding-in-hsdp', action='store_true', help='If set, enable full sharding in megatron-fsdp Hybrid Sharded Data Parallel (HSDP) mode.') @@ -536,7 +561,7 @@ index 1120c7529..ffb59dac4 100644 group.add_argument('--num-distributed-optimizer-instances', type=int, default=1, help='Number of Distributed Optimizer copies across Data Parallel domain.') group.add_argument('--use-torch-fsdp2', action='store_true', -@@ -2690,6 +2932,9 @@ def _add_validation_args(parser): +@@ -2690,6 +2943,9 @@ def _add_validation_args(parser): group.add_argument('--eval-interval', type=int, default=1000, help='Interval between running evaluation on ' 'validation set.') @@ -546,7 +571,7 @@ index 1120c7529..ffb59dac4 100644 group.add_argument("--test-mode", action="store_true", help='Run all real-time test alongside the experiment.') group.add_argument('--skip-train', action='store_true', default=False, help='If set, bypass the training loop, ' -@@ -2708,6 +2953,8 @@ def _add_tokenizer_args(parser): +@@ -2708,6 +2964,8 @@ def _add_tokenizer_args(parser): 'automatically calculated from vocab-size.') group.add_argument('--vocab-file', type=str, default=None, help='Path to the vocab file.') @@ -555,7 +580,7 @@ index 1120c7529..ffb59dac4 100644 group.add_argument('--merge-file', type=str, default=None, help='Path to the BPE merge file.') group.add_argument('--vocab-extra-ids', type=int, default=0, -@@ -2726,8 +2973,17 @@ def _add_tokenizer_args(parser): +@@ -2726,8 +2984,17 @@ def _add_tokenizer_args(parser): 'MultimodalTokenizer', 'NullTokenizer', 'NullMultimodalTokenizer', @@ -574,7 +599,7 @@ index 1120c7529..ffb59dac4 100644 group.add_argument('--tokenizer-model', type=str, default=None, help='Sentencepiece tokenizer model.') group.add_argument('--tokenizer-metadata', type=str, default=None, -@@ -2768,6 +3024,11 @@ def _add_data_args(parser): +@@ -2768,6 +3035,11 @@ def _add_data_args(parser): group.add_argument('--valid-data-path', nargs='*', default=None, help='The weight and prefix list for an independent validation dataset. ' 'Follows the same pattern rules as --data-path.') @@ -586,7 +611,7 @@ index 1120c7529..ffb59dac4 100644 group.add_argument('--test-data-path', nargs='*', default=None, help='The weight and prefix list for an independent test dataset. 
' 'Follows the same pattern rules as --data-path.') -@@ -2816,11 +3077,18 @@ def _add_data_args(parser): +@@ -2816,11 +3088,18 @@ def _add_data_args(parser): 'end-of-document token.') group.add_argument('--eod-mask-loss', action='store_true', help='Mask loss for the end of document tokens.') @@ -605,7 +630,7 @@ index 1120c7529..ffb59dac4 100644 group.add_argument('--object-storage-cache-path', type=str, default=None, help='Path to cache index files when using s3 or msc dataloader') group.add_argument('--mid-level-dataset-surplus', type=float, default=0.005, -@@ -2897,6 +3165,19 @@ def _add_biencoder_args(parser): +@@ -2897,6 +3176,19 @@ def _add_biencoder_args(parser): return parser @@ -625,7 +650,7 @@ index 1120c7529..ffb59dac4 100644 def _add_vision_args(parser): group = parser.add_argument_group(title="vision") -@@ -2967,6 +3248,8 @@ def _add_vision_args(parser): +@@ -2967,6 +3259,8 @@ def _add_vision_args(parser): help='Whether to layer normalize the q and k attention embeddings.') group.add_argument('--qk-l2-norm', action='store_true', help='Use llama 4 qk l2 norm') @@ -634,7 +659,7 @@ index 1120c7529..ffb59dac4 100644 return parser -@@ -3275,3 +3558,94 @@ def _add_sft_args(parser): +@@ -3275,3 +3569,94 @@ def _add_sft_args(parser): group.add_argument('--sft-tokenizer-prompt-format', type=str, default="nemotron-h-aligned", help='SFT prompt format.') return parser diff --git a/flagscale/backends/Megatron-LM/megatron/training/checkpointing.py.patch b/flagscale/backends/Megatron-LM/megatron/training/checkpointing.py.patch index 8e1c68997d..cf10169ae2 100644 --- a/flagscale/backends/Megatron-LM/megatron/training/checkpointing.py.patch +++ b/flagscale/backends/Megatron-LM/megatron/training/checkpointing.py.patch @@ -1,5 +1,5 @@ diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py -index 104fa6882..722859bf6 100644 +index 104fa6882..7326bea6f 100644 --- a/megatron/training/checkpointing.py +++ b/megatron/training/checkpointing.py @@ -286,12 +286,15 @@ def read_metadata(tracker_filename): @@ -20,7 +20,31 @@ index 104fa6882..722859bf6 100644 torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX) max_iter = iters_cuda[0].item() -@@ -692,6 +695,28 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati +@@ -473,6 +476,23 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati + # Save dataloader state if the dataloader supports it (currently only Megatron Energon). + maybe_save_dataloader_state(train_data_iterator, iteration, getattr(args, "dataloader_save", None)) + ++ # save hf format model weight ++ hf_checkpoint_name = get_checkpoint_name(save_dir, iteration, release=False, pipeline_parallel=pipeline_parallel, ++ tensor_rank=tensor_rank, pipeline_rank=pipeline_rank, expert_parallel=expert_parallel, expert_rank=expert_rank, return_base_dir=True) ++ if args.save_hf and hasattr(args,'hf_config_path'): ++ assert args.hf_config_path is not None, "hf_config_path should not be None" ++ #use megatron bridge ++ from flagscale.train.bridge.models import AutoBridge ++ from flagscale.train.bridge.models.hf_pretrained.safe_config_loader import safe_load_config_with_retry ++ from transformers import AutoConfig ++ #Load the HF model from config ++ config_load=args.hf_config_path ++ config = safe_load_config_with_retry(config_load, trust_remote_code=False) ++ bridge=AutoBridge.from_hf_config(config) ++ #Save the HF model weights in the corresponding iteration's safetensor folder. 
++ safe_save=os.path.join(hf_checkpoint_name, 'safetensor') ++ bridge.save_hf_pretrained(model=model,path=safe_save) ++ + # Save distributed optimizer's custom parameter state. + if ( + args.use_distributed_optimizer +@@ -692,6 +712,28 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati if not torch.distributed.is_initialized() \ or is_last_rank(): def wandb_finalize_fn(): @@ -49,7 +73,7 @@ index 104fa6882..722859bf6 100644 wandb_utils.on_save_checkpoint_success(checkpoint_name, get_checkpoint_tracker_filename(save_dir), save_dir, iteration) if args.async_save: assert async_save_request is not None -@@ -774,9 +799,7 @@ def maybe_save_dataloader_state(train_iterator, iteration, dataloader_save_path) +@@ -774,9 +816,7 @@ def maybe_save_dataloader_state(train_iterator, iteration, dataloader_save_path) torch.distributed.barrier(group=mpu.get_data_parallel_group()) @@ -60,7 +84,7 @@ index 104fa6882..722859bf6 100644 torch.distributed.barrier(group=mpu.get_data_parallel_group()) dataloader_save_dict = {} -@@ -1239,6 +1262,10 @@ def load_args_from_checkpoint( +@@ -1239,6 +1279,10 @@ def load_args_from_checkpoint( checkpoint_args, 'add_bias_linear', not getattr(checkpoint_args, 'disable_bias_linear') ) @@ -71,7 +95,7 @@ index 104fa6882..722859bf6 100644 def _set_arg(arg_name, old_arg_name=None, force=False): if not force and getattr(args, arg_name, None) is not None: return -@@ -1274,6 +1301,8 @@ def load_args_from_checkpoint( +@@ -1274,6 +1318,8 @@ def load_args_from_checkpoint( _set_arg('add_qkv_bias', force=True) _set_arg('squared_relu', force=True) _set_arg('swiglu', force=True) @@ -80,7 +104,25 @@ index 104fa6882..722859bf6 100644 _set_arg('untie_embeddings_and_output_weights', force=True) _set_arg('apply_layernorm_1p', force=True) _set_arg('normalization', force=True) -@@ -1432,6 +1461,14 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', +@@ -1347,6 +1393,17 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', + args = get_args() + load_dir = getattr(args, load_arg) + ++ # load hf format ++ if args.load_hf: ++ # use megatron bridge ++ from flagscale.train.bridge.models import AutoBridge ++ bridge=AutoBridge.from_hf_pretrained(load_dir) ++ bridge.load_hf_weights(ddp_model) ++ # no optimizer weight ++ iteration=0 ++ num_floating_point_operations_so_far=0 ++ return iteration, num_floating_point_operations_so_far ++ + # Check for model-opt format loading + if hasattr(args, 'load_model_opt_format') and args.load_model_opt_format: + print_rank_0(f'Loading checkpoint using ModelOpt format from {load_dir}') +@@ -1432,6 +1489,14 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', mismatch_msg = "(TP, PP) mismatch after resume ({} vs {} from checkpoint)".format( run_tp_pp, ckpt_tp_pp ) @@ -95,7 +137,7 @@ index 104fa6882..722859bf6 100644 # Determine if RNG state will be loaded if (ckpt_tp_pp == run_tp_pp and not release and not args.finetune and not args.no_load_rng -@@ -1468,6 +1505,7 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', +@@ -1468,6 +1533,7 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', ckpt_tp_pp != run_tp_pp and sharded_sd_metadata['distrib_optim_sharding_type'] not in DistributedOptimizer.checkpoint_fully_reshardable_formats @@ -103,7 +145,7 @@ index 104fa6882..722859bf6 100644 ): raise RuntimeError(f"{mismatch_msg}: not supported for DistributedOptimizer with sharding type" f" 
{sharded_sd_metadata['distrib_optim_sharding_type']}." -@@ -1481,7 +1519,7 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', +@@ -1481,7 +1547,7 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', gen_sd_optim = None gen_sd_opt_param_scheduler = None @@ -112,7 +154,7 @@ index 104fa6882..722859bf6 100644 model_sd_kwargs = dict(metadata=sharded_sd_metadata) # Determine if rerun state will be loaded -@@ -1829,3 +1867,4 @@ def load_biencoder_checkpoint(model, only_query_model=False, +@@ -1829,3 +1895,4 @@ def load_biencoder_checkpoint(model, only_query_model=False, print(' successfully loaded {}'.format(checkpoint_name)) return model diff --git a/flagscale/backends/Megatron-LM/megatron/training/yaml_arguments.py.patch b/flagscale/backends/Megatron-LM/megatron/training/yaml_arguments.py.patch new file mode 100644 index 0000000000..8fcd87a278 --- /dev/null +++ b/flagscale/backends/Megatron-LM/megatron/training/yaml_arguments.py.patch @@ -0,0 +1,14 @@ +diff --git a/megatron/training/yaml_arguments.py b/megatron/training/yaml_arguments.py +index 390d503ee..ae1f3069d 100644 +--- a/megatron/training/yaml_arguments.py ++++ b/megatron/training/yaml_arguments.py +@@ -409,7 +409,8 @@ def core_transformer_config_from_yaml(args, transfomer_key = "language_model"): + # Hardcoded + kw_args['deallocate_pipeline_outputs'] = True + kw_args['pipeline_dtype'] = kw_args['params_dtype'] +- kw_args['batch_p2p_comm'] = not args.overlap_p2p_comm ++ kw_args['batch_p2p_comm'] = not args.overlap_p2p_comm ++ kw_args['untie_embeddings_and_output_weights'] = args.untie_embeddings_and_output_weights + + assert args.activation_func in ["swiglu","squaredrelu","gelu"], f"{args.activation_func} is not a supported activation function" + if args.activation_func == "swiglu": diff --git a/flagscale/train/bridge/__init__.py b/flagscale/train/bridge/__init__.py new file mode 100644 index 0000000000..357be95b52 --- /dev/null +++ b/flagscale/train/bridge/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Mainly adapted from: https://github.com/NVIDIA-NeMo/Megatron-Bridge +"""Megatron Bridge - A component of the Megatron ecosystem.""" + +from flagscale.train.bridge.models.conversion.auto_bridge import AutoBridge + +__all__ = ["AutoBridge"] diff --git a/flagscale/train/bridge/models/__init__.py b/flagscale/train/bridge/models/__init__.py new file mode 100644 index 0000000000..0456b6c99b --- /dev/null +++ b/flagscale/train/bridge/models/__init__.py @@ -0,0 +1,99 @@ +# Copyright (c) 2025, BAAI. All rights reserved. 
+# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +from flagscale.train.bridge.models.conversion.auto_bridge import AutoBridge +from flagscale.train.bridge.models.conversion.mapping_registry import MegatronMappingRegistry +from flagscale.train.bridge.models.conversion.model_bridge import MegatronModelBridge +from flagscale.train.bridge.models.conversion.param_mapping import ( + AutoMapping, + ColumnParallelMapping, + GatedMLPMapping, + MegatronParamMapping, + QKVMapping, + ReplicatedMapping, + RowParallelMapping, +) +from flagscale.train.bridge.models.deepseek import ( + DeepSeekModelProvider, + DeepSeekProvider, + DeepSeekV2LiteModelProvider, + DeepSeekV2LiteProvider, + DeepSeekV2ModelProvider, + DeepSeekV2Provider, + DeepSeekV3ModelProvider, + DeepSeekV3Provider, + MoonlightModelProvider16B, + MoonlightProvider, +) +from flagscale.train.bridge.models.gpt_provider import GPTModelProvider +from flagscale.train.bridge.models.qwen import ( + Qwen2ModelProvider, + Qwen2ModelProvider1P5B, + Qwen2ModelProvider7B, + Qwen2ModelProvider72B, + Qwen2ModelProvider500M, + Qwen3ModelProvider, + Qwen3ModelProvider1P7B, + Qwen3ModelProvider4B, + Qwen3ModelProvider8B, + Qwen3ModelProvider14B, + Qwen3ModelProvider32B, + Qwen3ModelProvider600M, + Qwen3MoEModelProvider, + Qwen3MoEModelProvider30B_A3B, + Qwen3MoEModelProvider235B_A22B, + Qwen25ModelProvider1P5B, + Qwen25ModelProvider3B, + Qwen25ModelProvider7B, + Qwen25ModelProvider14B, + Qwen25ModelProvider32B, + Qwen25ModelProvider72B, + Qwen25ModelProvider500M, +) + +__all__ = [ + "AutoBridge", + "MegatronMappingRegistry", + "MegatronModelBridge", + "ColumnParallelMapping", + "GatedMLPMapping", + "MegatronParamMapping", + "QKVMapping", + "ReplicatedMapping", + "RowParallelMapping", + "AutoMapping", + "GPTModelProvider", + "Qwen2ModelProvider", + "Qwen2ModelProvider500M", + "Qwen2ModelProvider1P5B", + "Qwen2ModelProvider7B", + "Qwen2ModelProvider72B", + "Qwen25ModelProvider500M", + "Qwen25ModelProvider1P5B", + "Qwen25ModelProvider3B", + "Qwen25ModelProvider7B", + "Qwen25ModelProvider14B", + "Qwen25ModelProvider32B", + "Qwen25ModelProvider72B", + "Qwen3ModelProvider", + "Qwen3ModelProvider600M", + "Qwen3ModelProvider1P7B", + "Qwen3ModelProvider4B", + "Qwen3ModelProvider8B", + "Qwen3ModelProvider14B", + "Qwen3ModelProvider32B", + "Qwen3MoEModelProvider", + "Qwen3MoEModelProvider30B_A3B", + "Qwen3MoEModelProvider235B_A22B", + "DeepSeekModelProvider", + "DeepSeekProvider", + "DeepSeekV2LiteModelProvider", + "DeepSeekV2LiteProvider", + "DeepSeekV2ModelProvider", + "DeepSeekV2Provider", + "DeepSeekV3ModelProvider", + "DeepSeekV3Provider", + "MoonlightModelProvider16B", + "MoonlightProvider", +] diff --git a/flagscale/train/bridge/models/config.py b/flagscale/train/bridge/models/config.py new file mode 100644 index 0000000000..7ccffc8ccf --- /dev/null +++ b/flagscale/train/bridge/models/config.py @@ -0,0 +1,340 @@ +# Copyright (c) 2025, BAAI. All rights reserved. 
+# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import json + +from dataclasses import fields as dataclass_fields, is_dataclass +from pathlib import Path +from typing import Any, Dict, Literal, Optional, Protocol, Type, TypeVar, Union, runtime_checkable + +import yaml + +from omegaconf import OmegaConf + +from flagscale.train.bridge.utils.instantiate_utils import InstantiationMode, instantiate +from flagscale.train.bridge.utils.yaml_utils import safe_yaml_representers + +# For TOML support +try: + import toml + + HAS_TOML = True +except ImportError: + HAS_TOML = False + + +T = TypeVar("T") +ConfigFormat = Literal["yaml", "json", "toml"] + + +@runtime_checkable +class ConfigProtocol(Protocol): + """Protocol defining the configuration interface for model providers.""" + + @classmethod + def from_hf_pretrained( + cls: Type[T], + pretrained_model_name_or_path: Union[str, Path], + trust_remote_code: bool = False, + mode: InstantiationMode = InstantiationMode.LENIENT, + **kwargs, + ) -> T: + """Load a pretrained model configuration from a directory or file.""" + ... + + def save_hf_pretrained( + self, + save_directory: Union[str, Path], + config_format: ConfigFormat | None = None, + config_name: Optional[str] = None, + **kwargs, + ) -> None: + """Save the model configuration to a directory.""" + ... + + +def from_hf_pretrained( + cls: Type[T], + pretrained_model_name_or_path: Union[str, Path], + trust_remote_code: bool = False, + mode: InstantiationMode = InstantiationMode.LENIENT, + config_name: str = "config", + **kwargs, +) -> T: + """ + Load a pretrained model configuration from a directory or file. + + Args: + cls: The class to instantiate + pretrained_model_name_or_path: Path to a directory containing a config file, + or direct path to a config file (yaml/json/toml) + trust_remote_code: Whether to trust and execute code references (classes/functions) + found in the configuration. Required to be True if the config + contains any class or function references. Default: False + mode: Instantiation mode (STRICT or LENIENT) for the instantiate function + config_name: Base name of the config file (without extension) + **kwargs: Additional keyword arguments to override loaded configuration + + Returns: + Instance of the class with loaded configuration + + Example: + ```python + # Load from directory (looks for config.yaml, config.json, or config.toml) + model = from_hf_pretrained(MyModel, "./saved_model/") + + # Load from specific file + model = from_hf_pretrained(MyModel, "./saved_model/config.yaml") + + # With code references + model = from_pretrained(MyModel, "./saved_model/", trust_remote_code=True) + + # Override configuration values + model = from_pretrained(MyModel, "./saved_model/", temperature=0.8) + ``` + """ + path = Path(pretrained_model_name_or_path) + + # Determine the config file path + if path.is_dir(): + # Look for config files in order of preference + config_file = None + for ext in [".yaml", ".yml", ".json", ".toml"]: + candidate = path / f"{config_name}{ext}" + if candidate.exists(): + config_file = candidate + break + + if config_file is None: + raise FileNotFoundError( + f"No configuration file found in {path}. 
" + f"Expected {config_name}.yaml, {config_name}.json, or {config_name}.toml" + ) + else: + config_file = path + + if not config_file.exists(): + raise FileNotFoundError(f"Configuration file not found at {config_file}") + + # Load the configuration based on file extension + file_ext = config_file.suffix.lower() + + if file_ext in [".yaml", ".yml"]: + with open(config_file, "r", encoding="utf-8") as f: + config_dict = yaml.safe_load(f) + elif file_ext == ".json": + with open(config_file, "r", encoding="utf-8") as f: + config_dict = json.load(f) + elif file_ext == ".toml": + if not HAS_TOML: + raise ImportError( + "TOML support requires the 'toml' package. Install it with: pip install toml" + ) + with open(config_file, "r", encoding="utf-8") as f: + config_dict = toml.load(f) + else: + raise ValueError( + f"Unsupported file format: {file_ext}. Supported formats: .yaml, .yml, .json, .toml" + ) + + # Check for trust_remote_code requirement + if not trust_remote_code and _contains_code_references(config_dict): + raise ValueError( + "This configuration contains class or function references. " + "Loading it requires trust_remote_code=True to prevent arbitrary code execution." + ) + + # Convert to OmegaConf for compatibility with instantiate + omega_conf = OmegaConf.create(config_dict) + + # Merge with kwargs + if kwargs: + override_conf = OmegaConf.create(kwargs) + omega_conf = OmegaConf.merge(omega_conf, override_conf) + + # Add _target_ if not present + if "_target_" not in omega_conf: + omega_conf["_target_"] = f"{cls.__module__}.{cls.__qualname__}" + + # Convert back to container for instantiate + final_config = OmegaConf.to_container(omega_conf, resolve=True) + + # Use instantiate to create the object + return instantiate(final_config, mode=mode) + + +def save_hf_pretrained( + obj: Any, + save_directory: Union[str, Path], + config_format: ConfigFormat = "json", + config_name: str = "config", + **kwargs, +) -> None: + """ + Save the model configuration to a directory. + + Args: + obj: The object to save + save_directory: Directory where to save the configuration + config_format: Format to save in ("yaml", "json", or "toml"). Default: "json" + config_name: Name for the config file (without extension) + **kwargs: Additional metadata to save alongside the configuration + + Example: + ```python + # Save as JSON (default) + save_hf_pretrained(model, "./saved_model/") + + # Save as YAML + save_hf_pretrained(model, "./saved_model/", config_format="yaml") + + # Save with custom name + save_hf_pretrained(model, "./saved_model/", config_name="my_config") + ``` + """ + save_path = Path(save_directory) + save_path.mkdir(parents=True, exist_ok=True) + + # Determine file extension + format_to_ext = {"yaml": ".yaml", "yml": ".yaml", "json": ".json", "toml": ".toml"} + + config_format = config_format.lower() + if config_format not in format_to_ext: + raise ValueError( + f"Unsupported format: {config_format}. Supported formats: {list(format_to_ext.keys())}" + ) + + if config_format == "toml" and not HAS_TOML: + raise ImportError( + "TOML support requires the 'toml' package. 
Install it with: pip install toml" + ) + + config_file = save_path / f"{config_name}{format_to_ext[config_format]}" + + # Get the configuration dictionary + config_dict = _to_dict(obj) + + # Add any additional metadata + if kwargs: + config_dict.update(kwargs) + + # Save based on format + if config_format in ["yaml", "yml"]: + with safe_yaml_representers(): + with open(config_file, "w", encoding="utf-8") as f: + yaml.safe_dump(config_dict, f, default_flow_style=False, sort_keys=False) + elif config_format == "json": + # First convert to YAML string to use the custom representers + with safe_yaml_representers(): + yaml_str = yaml.safe_dump(config_dict, default_flow_style=False) + # Then parse and save as JSON + yaml_dict = yaml.safe_load(yaml_str) + with open(config_file, "w", encoding="utf-8") as f: + json.dump(yaml_dict, f, indent=2, ensure_ascii=False) + elif config_format == "toml": + # First convert to YAML string to use the custom representers + with safe_yaml_representers(): + yaml_str = yaml.safe_dump(config_dict, default_flow_style=False) + # Then parse and save as TOML + yaml_dict = yaml.safe_load(yaml_str) + with open(config_file, "w", encoding="utf-8") as f: + toml.dump(yaml_dict, f) + + print(f"Configuration saved to {config_file}") + + +def _to_dict(obj: Any) -> Dict[str, Any]: + """ + Convert an object to a dictionary representation. + + Args: + obj: The object to convert + + Returns: + Dictionary representation of the object + """ + # Check if this is a ConfigContainer (has to_dict method) + if hasattr(obj, "to_dict") and callable(obj.to_dict): + return obj.to_dict() + + # Otherwise, build dict from dataclass fields or attributes + result = {} + result["_target_"] = f"{obj.__class__.__module__}.{obj.__class__.__qualname__}" + + if is_dataclass(obj): + # Handle dataclass + for field in dataclass_fields(obj): + if field.name.startswith("_"): + continue + value = getattr(obj, field.name) + result[field.name] = _convert_value_to_dict(value) + else: + # Handle regular class + for key, value in obj.__dict__.items(): + if not key.startswith("_"): + result[key] = _convert_value_to_dict(value) + + return result + + +def _convert_value_to_dict(value: Any) -> Any: + """ + Recursively convert a value to a dictionary representation. + + Args: + value: The value to convert + + Returns: + The converted value + """ + if hasattr(value, "_to_dict"): + return value._to_dict() + elif hasattr(value, "to_dict") and callable(value.to_dict): + return value.to_dict() + elif is_dataclass(value) and not isinstance(value, type): + # Handle regular dataclasses + result = {"_target_": f"{value.__class__.__module__}.{value.__class__.__qualname__}"} + for field in dataclass_fields(value): + if not field.name.startswith("_"): + result[field.name] = _convert_value_to_dict(getattr(value, field.name)) + return result + elif isinstance(value, (list, tuple)): + return [_convert_value_to_dict(item) for item in value] + elif isinstance(value, dict): + return {k: _convert_value_to_dict(v) for k, v in value.items()} + else: + return value + + +def _contains_code_references(config_dict: Dict[str, Any]) -> bool: + """ + Check if a configuration dictionary contains code references. 
+ + Args: + config_dict: The configuration dictionary to check + + Returns: + True if code references are found, False otherwise + """ + if isinstance(config_dict, dict): + for key, value in config_dict.items(): + # Check for _target_ that's not a built-in type + if key == "_target_" and isinstance(value, str): + # Consider it a code reference if it's not a basic type + if not value.startswith( + ("builtins.", "str", "int", "float", "bool", "list", "dict", "tuple") + ): + return True + # Check for _call_ = False which indicates a code reference + if key == "_call_" and value is False: + return True + # Recursively check nested structures + if _contains_code_references(value): + return True + elif isinstance(config_dict, (list, tuple)): + for item in config_dict: + if _contains_code_references(item): + return True + + return False diff --git a/flagscale/train/bridge/models/conversion/__init__.py b/flagscale/train/bridge/models/conversion/__init__.py new file mode 100644 index 0000000000..9c36c332cd --- /dev/null +++ b/flagscale/train/bridge/models/conversion/__init__.py @@ -0,0 +1,32 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +# Import model providers for easy access +from flagscale.train.bridge.models.conversion.auto_bridge import AutoBridge +from flagscale.train.bridge.models.conversion.mapping_registry import MegatronMappingRegistry +from flagscale.train.bridge.models.conversion.model_bridge import MegatronModelBridge +from flagscale.train.bridge.models.conversion.param_mapping import ( + AutoMapping, + ColumnParallelMapping, + GatedMLPMapping, + MegatronParamMapping, + QKVMapping, + ReplicatedMapping, + RowParallelMapping, +) +from flagscale.train.bridge.models.conversion.utils import weights_verification_table + +__all__ = [ + "AutoBridge", + "MegatronMappingRegistry", + "MegatronModelBridge", + "ColumnParallelMapping", + "GatedMLPMapping", + "MegatronParamMapping", + "QKVMapping", + "ReplicatedMapping", + "RowParallelMapping", + "AutoMapping", + "weights_verification_table", +] diff --git a/flagscale/train/bridge/models/conversion/auto_bridge.py b/flagscale/train/bridge/models/conversion/auto_bridge.py new file mode 100644 index 0000000000..4507cafefa --- /dev/null +++ b/flagscale/train/bridge/models/conversion/auto_bridge.py @@ -0,0 +1,572 @@ +# Copyright (c) 2025, BAAI. All rights reserved. 
+# +# Mainly adapted from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import dataclasses + +from functools import cached_property, partial +from pathlib import Path +from typing import Any, Generic, Iterable, List, Optional, Type, TypeVar, Union + +import torch.distributed as dist +import transformers + +from transformers import AutoModelForCausalLM +from transformers.configuration_utils import PretrainedConfig +from typing_extensions import Unpack + +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.transformer_config import MLATransformerConfig, TransformerConfig + +from flagscale.train.bridge.models.conversion import model_bridge +from flagscale.train.bridge.models.conversion.model_bridge import ( + HFWeightTuple, + MegatronModelBridge, + WeightConversionTask, +) +from flagscale.train.bridge.models.conversion.utils import get_causal_lm_class_via_auto_map + +# from flagscale.train.bridge.models.gpt_provider import GPTModelProvider +from flagscale.train.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM +from flagscale.train.bridge.models.hf_pretrained.safe_config_loader import ( + safe_load_config_with_retry, +) +from flagscale.train.bridge.models.hf_pretrained.state import SafeTensorsStateSource + +# from flagscale.train.bridge.models.model_provider import GetModelKwargs, ModelParallelKwargs, ModelProviderMixin + + +MegatronModelT = TypeVar("MegatronModelT", bound=MegatronModule) +DataclassT = TypeVar("DataclassT") + + +class AutoBridge(Generic[MegatronModelT]): + """ + Automatically select and instantiate the appropriate bridge for a model. + + This unified bridge class combines automatic model detection with full bridge + functionality for converting models between HuggingFace and Megatron formats. + It handles the conversion of causal language models (e.g., GPT, Llama, Phi) + between HuggingFace's transformers library format and Megatron-Core's distributed + training format. It manages weight mapping, tensor parallelism distribution, and + configuration translation. + + The bridge supports both directions of conversion: + - HuggingFace → Megatron: For training or inference with Megatron + - Megatron → HuggingFace: For saving trained models in HF format + + Args: + hf_pretrained: Either a PreTrainedCausalLM instance with loaded model, + or a PretrainedConfig for configuration-only operations + + Example: + >>> # Load and convert a model to Megatron format + >>> bridge = AutoBridge.from_hf_pretrained("meta-llama/Meta-Llama-3-8B") + >>> provider = bridge.to_megatron_provider() + >>> megatron_model = provider.provide_distributed_model(wrap_with_ddp=False) + + >>> # Export a Megatron model back to HuggingFace format + >>> bridge.save_hf_pretrained(megatron_model, "./exported_model") + + >>> # Convert weights with custom settings + >>> for name, weight in bridge.export_hf_weights( + ... megatron_model, + ... cpu=True + ... ): + ... print(f"Exported {name}: {weight.shape}") + + >>> # Check if a model is supported before loading + >>> if AutoBridge.can_handle("microsoft/phi-2"): + ... bridge = AutoBridge.from_hf_pretrained("microsoft/phi-2") + + Note: + The bridge automatically detects the model architecture and applies + the appropriate weight mappings. Custom architectures require implementing + a MegatronModelBridge subclass. 
+ """ + + def __init__(self, hf_pretrained: PreTrainedCausalLM | PretrainedConfig): + if not isinstance(hf_pretrained, (PreTrainedCausalLM, PretrainedConfig)): + raise ValueError( + "hf_pretrained must be a PreTrainedCausalLM or PretrainedConfig instance" + ) + self.hf_pretrained: PreTrainedCausalLM | PretrainedConfig = hf_pretrained + + @classmethod + def list_supported_models(cls) -> list[str]: + """ + List all model architectures currently supported by the bridge system. + + Returns: + List of supported HuggingFace model architecture names + """ + # Get all registered implementations from the dispatch system + supported = [] + + # Access the dispatch registry to find all registered types + + if hasattr(model_bridge.get_model_bridge, "_exact_types"): + for arch_type in model_bridge.get_model_bridge._exact_types.keys(): + # Support both type and string registrations + if isinstance(arch_type, str): + supported.append(arch_type) + elif hasattr(arch_type, "__name__"): + supported.append(arch_type.__name__) + + return sorted(supported) + + @classmethod + def supports(cls, config: Any) -> bool: + """ + Check if this bridge supports the given model configuration. + + A model is supported if it has at least one architecture ending with 'ForCausalLM' or 'ForConditionalGeneration' + or 'NemotronH_Nano_VL_V2'. + + Args: + config: HuggingFace model config object + + Returns: + True if this bridge can handle the model, False otherwise + """ + architectures = getattr(config, "architectures", []) + if not architectures: + return False + return any( + arch.endswith(("ForCausalLM", "ForConditionalGeneration", "NemotronH_Nano_VL_V2")) + for arch in architectures + ) + + @classmethod + def from_hf_config(cls, config: PretrainedConfig) -> "AutoBridge": + """ + Create an AutoBridge from a HuggingFace configuration. + + This method creates a bridge instance from just a model configuration, + without loading any weights. 
This is useful for: + - Creating Megatron models with random initialization + - Working with model architectures without downloading weights + - Testing and development scenarios + + Args: + config: HuggingFace PretrainedConfig instance containing model + architecture information + + Returns: + AutoBridge: Bridge instance configured for the architecture + + Raises: + ValueError: If the configuration is not for a supported CausalLM model + + Example: + >>> from transformers import AutoConfig + >>> + >>> # Load just the configuration + >>> config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B") + >>> + >>> # Create bridge from config (no weights) + >>> bridge = AutoBridge.from_hf_config(config) + >>> + >>> # Create Megatron model with random initialization + >>> provider = bridge.to_megatron_provider(load_weights=False) + >>> model = provider.provide_distributed_model(wrap_with_ddp=False) + + >>> # Or use for architecture exploration + >>> transformer_config = bridge.transformer_config + >>> print(f"Hidden size: {transformer_config.hidden_size}") + >>> print(f"Num layers: {transformer_config.num_layers}") + + See Also: + from_hf_pretrained: Create bridge with loaded weights + transformer_config: Access the Megatron TransformerConfig + """ + cls._validate_config(config) + model = PreTrainedCausalLM() + model.config = config + import torch + + from accelerate import init_empty_weights + from accelerate.utils import set_module_tensor_to_device + + with init_empty_weights(): + hf_model = AutoModelForCausalLM.from_config(model.config) + + for name, param in hf_model.named_parameters(): + set_module_tensor_to_device( + hf_model, name, "cpu", torch.empty(*param.size(), dtype=model.config.torch_dtype) + ) + model.model = hf_model + return cls(model) + + @classmethod + def from_hf_pretrained(cls, path: Union[str, Path], **kwargs) -> "AutoBridge": + """ + Load an AutoBridge from a pretrained model, automatically detecting the model type. + + This method loads a model from HuggingFace Hub or a local directory and + creates a bridge instance ready for conversion operations. The model + architecture is validated to ensure compatibility. + + Args: + path: HuggingFace model ID or path to model directory + Examples: "meta-llama/Meta-Llama-3-8B", "./my_model" + **kwargs: Additional arguments passed to HuggingFace from_hf_pretrained + Common options include: + - torch_dtype: Model precision (torch.float16, torch.bfloat16) + - device_map: Device placement strategy ("auto", "cuda:0", etc.) + - trust_remote_code: Allow custom model code execution + - attn_implementation: Attention implementation ("flash_attention_2", etc.) + + Returns: + AutoBridge: Bridge instance with loaded model + + Raises: + ValueError: If the model architecture is not supported + + Example: + >>> # Basic loading + >>> bridge = AutoBridge.from_hf_pretrained("gpt2") + + >>> # Load with specific settings + >>> bridge = AutoBridge.from_hf_pretrained( + ... "meta-llama/Meta-Llama-3-8B", + ... torch_dtype=torch.float16, + ... device_map="auto" + ... 
) + + >>> # Works with local paths too + >>> bridge = AutoBridge.from_hf_pretrained("/path/to/model") + """ + # First load just the config to check architecture support + # Use thread-safe config loading to prevent race conditions + config = safe_load_config_with_retry( + path, trust_remote_code=kwargs.get("trust_remote_code", False) + ) + + cls._validate_config(config, str(path)) + + try: + return cls(PreTrainedCausalLM.from_pretrained(path, **kwargs)) + except Exception as e: + raise ValueError(f"Failed to load model with AutoBridge: {e}") from e + + def load_hf_weights( + self, model: list[MegatronModelT], hf_path: str | Path | None = None + ) -> None: + """ + Load HuggingFace weights into a Megatron model. + + This method handles the conversion and distribution of weights from + HuggingFace format to Megatron's distributed format, including proper + tensor parallel and pipeline parallel distribution. + + Args: + model: List of Megatron model instances (one per virtual pipeline stage) + hf_path: Optional path to load weights from. If None, uses weights + from the bridge's hf_pretrained instance + + Returns: + The input model with loaded weights + + Raises: + ValueError: If hf_path is None and bridge was created without weights + + Example: + >>> # Load weights from bridge's pretrained model + >>> bridge = AutoBridge.from_hf_pretrained("gpt2") + >>> megatron_model = create_megatron_model() # Your model creation + >>> bridge.load_hf_weights(megatron_model) + + >>> # Load weights from a different checkpoint + >>> bridge.load_hf_weights(megatron_model, "./finetuned_model") + """ + if hf_path is None: + if not isinstance(self.hf_pretrained, PreTrainedCausalLM): + raise ValueError( + "hf_path is required when hf_pretrained is not a PreTrainedCausalLM instance" + ) + pre_trained = self.hf_pretrained + else: + pre_trained = PreTrainedCausalLM.from_pretrained(hf_path) + # Preserve trust_remote_code setting from the original bridge instance + trust_remote_code = getattr(self.hf_pretrained, 'trust_remote_code', False) + pre_trained = PreTrainedCausalLM.from_pretrained( + hf_path, trust_remote_code=trust_remote_code + ) + # self._model_bridge.load_weights_hf_to_megatron(model, pre_trained) + self._model_bridge.load_weights_hf_to_megatron(pre_trained, model) + + return model + + def save_hf_pretrained( + self, + model: list[MegatronModelT], + path: str | Path, + show_progress: bool = True, + strict: bool = True, + ) -> None: + """ + Save a Megatron model in HuggingFace format. + + This method exports the complete model including configuration, tokenizer, + and weights to a directory that can be loaded with HuggingFace's + from_pretrained methods. + + If the original model was loaded with trust_remote_code=True, any custom + modeling files (e.g., modeling_*.py, configuration_*.py) will be preserved + to ensure the saved model can be loaded properly. + + Args: + model: Megatron model instance or list of instances + path: Directory path to save the model + show_progress: Display progress bar during weight export + + Example: + >>> # Save model after training + >>> bridge.save_hf_pretrained(megatron_model, "./my_finetuned_model") + + >>> # Load the saved model with HuggingFace + >>> from transformers import AutoModelForCausalLM + >>> hf_model = AutoModelForCausalLM.from_pretrained("./my_finetuned_model") + + Note: + This method is collective - all ranks must call it. Only rank 0 + saves the configuration files, while weight saving is coordinated + across all ranks. 
+ """ + if dist.is_available() and dist.is_initialized(): + # Distributed training, only rank 0 saves artifacts + if dist.get_rank() == 0: + self.hf_pretrained.save_artifacts(path) + else: + # No distributed training, save artifacts + self.hf_pretrained.save_artifacts(path) + self.save_hf_weights(model, path, show_progress, strict) + + def save_hf_weights( + self, + model: list[MegatronModelT], + path: str | Path, + show_progress: bool = True, + strict: bool = True, + ) -> None: + """ + Save Megatron model weights in HuggingFace safetensors format. + + This method exports only the model weights (not configuration or tokenizer) + to safetensors files compatible with HuggingFace. It uses streaming save + to handle large models efficiently without requiring all weights in memory + at once. + + The weights are gathered from distributed ranks and saved in the standard + HuggingFace sharded format when the model is large. + + Args: + model: Megatron model instance or list of instances + path: Directory path where weight files will be saved + show_progress: Display progress bar during export + + Raises: + ValueError: If the state source doesn't support streaming save + + Example: + >>> # Save just the weights + >>> bridge.save_hf_weights(megatron_model, "./model_weights") + + >>> # Save without progress bar (useful in scripts) + >>> bridge.save_hf_weights(megatron_model, "./weights", show_progress=False) + + Note: + - This method is collective and must be called by all ranks + - Uses safetensors format for efficient loading and security + - Automatically handles model sharding for large models + - The saved weights can be loaded with HuggingFace's from_pretrained + """ + if dist.is_available() and dist.is_initialized(): + dist.barrier() + dispatch_instance = (self._causal_lm_architecture, self._get_model_instance(model)) + generator = model_bridge.stream_weights_megatron_to_hf( + dispatch_instance, model, self.hf_pretrained, cpu=True, show_progress=show_progress + ) + source = SafeTensorsStateSource(path) + # Check if the state source is SafeTensorsStateSource for streaming save. + if ( + hasattr(self.hf_pretrained, "state") + and hasattr(self.hf_pretrained.state, "source") + # and isinstance(self.hf_pretrained.state.source, SafeTensorsStateSource) + ): + # self.hf_pretrained.state.source.save_generator(generator, path, strict=strict) + source.save_generator(generator, path, strict=strict) + else: + raise ValueError( + "The state source is not a SafeTensorsStateSource, cannot save in streaming mode." + ) + + if dist.is_available() and dist.is_initialized(): + dist.barrier() + + @property + def _model_bridge(self) -> "MegatronModelBridge": + return model_bridge.get_model_bridge(self._causal_lm_architecture) + + @cached_property + def _causal_lm_architecture(self): + """Resolve the model's CausalLM architecture for dispatch. + + Behavior: + - If the model can be imported from transformers directly, return the actual transformers class object. + - Otherwise, if the model uses HuggingFace auto_map, return the architecture's class name as a string (e.g., + "DeepseekV2ForCausalLM"). + + Returns: + str | type: The transformers class for the CausalLM architecture or the architecture's class name as a + string for auto_map models. + + Raises: + ValueError: If no CausalLM architecture is found or cannot be resolved. 
+ """ + if isinstance(self.hf_pretrained, PreTrainedCausalLM): + config = self.hf_pretrained.config + model_name_or_path = getattr(config, "_name_or_path", None) or getattr( + self.hf_pretrained, "model_name_or_path", None + ) + else: + config = self.hf_pretrained + model_name_or_path = getattr(config, "_name_or_path", None) + + architectures = getattr(config, "architectures", []) + + if not architectures: + raise ValueError( + "\nāœ— No architectures found in model config\n\n" + "The model configuration does not specify any architectures.\n" + "This is required for determining the model type." + ) + + causal_lm_arch = None + for architecture_name in architectures: + # TODO: Can we improve this? + if architecture_name.endswith( + ("ForCausalLM", "ForConditionalGeneration", "NemotronH_Nano_VL_V2") + ): + causal_lm_arch = architecture_name + break + + if not causal_lm_arch: + raise ValueError( + f"\nāœ— No CausalLM architecture found\n\n" + f"Model architectures: {architectures}\n\n" + f"None of the architectures end with 'ForCausalLM' or 'ForConditionalGeneration' or" + f"'NemotronH_Nano_VL_V2'.\n" + f"This bridge only supports causal language models.\n" + f"For other model types, use a different bridge class." + ) + + # Try auto_map first + cls = get_causal_lm_class_via_auto_map(model_name_or_path=model_name_or_path, config=config) + if cls is not None: + # For auto_map models, return the class name as a string + return getattr(cls, "__name__", str(cls)) + + try: + return getattr(transformers, causal_lm_arch) + except AttributeError: + raise ValueError( + f"\nāœ— Architecture class '{causal_lm_arch}' not found in transformers\n\n" + f"This could mean:\n" + f"1. The model requires a newer version of transformers\n" + f"2. The model uses a custom modeling file not in the standard library\n" + f"3. There's a typo in the architecture name\n\n" + f"Please verify your transformers installation and the model requirements." + ) + + @classmethod + def _validate_config(cls, config: PretrainedConfig, path: str | None = None) -> None: + # Check if this is a causal LM model + if not cls.supports(config): + architectures = getattr(config, "architectures", []) + raise ValueError( + f"\nāœ— Model architecture not supported by AutoBridge\n\n" + f"Model: {path}\n" + f"Architectures: {architectures}\n\n" + f"AutoBridge only supports models with architectures ending in 'ForCausalLM' or" + f"'ForConditionalGeneration' or 'NemotronH_Nano_VL_V2'.\n" + f"Found architectures that don't match this pattern.\n\n" + f"If this is a different model type (e.g., Vision, Sequence-to-Sequence),\n" + f"you may need to use a different bridge class." 
+ ) + + # Check if we have an implementation for this specific architecture + architecture = None + for arch_name in config.architectures: + if arch_name.endswith( + ("ForCausalLM", "ForConditionalGeneration", "NemotronH_Nano_VL_V2") + ): + architecture = arch_name + break + + if architecture: + # Try auto_map first + arch_class = ( + get_causal_lm_class_via_auto_map(model_name_or_path=path, config=config) + if path + else None + ) + if arch_class is not None: + # For auto_map models, use class-name string + arch_key = getattr(arch_class, "__name__", str(arch_class)) + else: + try: + arch_class = getattr(transformers, architecture) + arch_key = arch_class + except AttributeError: + # Fall back to name-based registration + arch_key = architecture + + # Test if we have a registered implementation (type or class-name string) + has_implementation = False + if hasattr(model_bridge.get_model_bridge, "_exact_types"): + registry = model_bridge.get_model_bridge._exact_types + if isinstance(arch_key, str): + has_implementation = arch_key in registry + else: + has_implementation = (arch_key in registry) or ( + getattr(arch_key, "__name__", None) in registry + ) + + if not has_implementation: + # Get list of supported models + supported_models = cls.list_supported_models() + + raise ValueError( + f"\nāœ— Model architecture '{architecture}' is not yet supported\n\n" + f"Model: {path}\n" + f"Architecture: {architecture}\n\n" + f"Currently supported architectures:\n" + + "\n".join(f" • {model}" for model in supported_models) + + f"\n\nTo add support for {architecture}, you need to:\n" + f"1. Create a new bridge class that inherits from MegatronModelBridge\n" + f"2. Implement the required methods (provider_bridge, mapping_registry)\n" + f"3. Register it with @MegatronModelBridge.register_bridge decorator\n\n" + f"Example implementation:\n" + f" from flagscale.train.bridge.models.conversion.model_bridge import MegatronModelBridge\n" + f" from transformers import {architecture}\n" + f" from megatron.core.models.gpt import GPTModel\n\n" + f" @MegatronModelBridge.register_bridge(source={architecture}, target=GPTModel)\n" + f" class Megatron{architecture.replace('ForCausalLM', '')}Bridge(MegatronModelBridge):\n" + f" def provider_bridge(self, hf_pretrained):\n" + f" # Return a ModelProvider instance\n" + f" ...\n\n" + f" def mapping_registry(self):\n" + f" # Return a MegatronMappingRegistry with weight mappings\n" + f" ...\n\n" + f"For reference implementations, see:\n" + f" • src/megatron/bridge/models/llama/llama_bridge.py\n" + f" • src/megatron/bridge/models/qwen/qwen_2_causal_bridge.py" + ) from None + + def _get_model_instance(self, model: list[MegatronModelT]) -> MegatronModelT: + model_instance = model[0] + while hasattr(model_instance, "module"): + model_instance = model_instance.module + return model_instance diff --git a/flagscale/train/bridge/models/conversion/mapping_registry.py b/flagscale/train/bridge/models/conversion/mapping_registry.py new file mode 100644 index 0000000000..129ba5fc48 --- /dev/null +++ b/flagscale/train/bridge/models/conversion/mapping_registry.py @@ -0,0 +1,266 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import re + +from typing import List, Optional + +from flagscale.train.bridge.models.conversion.param_mapping import MegatronParamMapping + + +class MegatronMappingRegistry: + """ + Registry for weight mappings between model formats with pattern matching support. 
+ + This class serves as a registry of weight mappings between Megatron and external + (typically HuggingFace) model formats. It provides efficient pattern matching + for parameter names using glob-like wildcards (*) and supports both forward + (Megatron → HF) and reverse (HF → Megatron) lookups. + + The registry pre-compiles regex patterns for efficient repeated lookups and + handles the resolution of wildcards in parameter names. + + Args: + *mappings: Variable number of MegatronParamMapping objects defining + the individual weight mappings + + Example: + >>> # Create a mapping registry with various mappings + >>> mapping_registry = MegatronMappingRegistry( + ... AutoMapping( + ... megatron_param="embedding.word_embeddings.weight", + ... hf_param="model.embed_tokens.weight", + ... ), + ... QKVMapping( + ... megatron_param="decoder.layers.*.self_attention.linear_qkv.weight", + ... q="model.layers.*.self_attn.q_proj.weight", + ... k="model.layers.*.self_attn.k_proj.weight", + ... v="model.layers.*.self_attn.v_proj.weight", + ... ), + ... ) + + >>> # Query for a specific layer (wildcards are resolved) + >>> mapping = mapping_registry.megatron_to_hf_lookup("decoder.layers.0.self_attention.linear_qkv.weight") + >>> print(mapping.hf_param) # Will show resolved HF names for layer 0 + + >>> # Reverse lookup from HF name + >>> mapping = mapping_registry.hf_to_megatron_lookup("model.layers.5.self_attn.q_proj.weight") + >>> print(mapping.megatron_param) # Shows corresponding Megatron name + + >>> # Build from a list + >>> mappings = [bridge1, bridge2, bridge3] + >>> mapping_registry = MegatronMappingRegistry(*mappings) + + Note: + Wildcard patterns support: + - '*' matches any sequence of digits (0-9) - designed for layer indices + - '**' matches any sequence of characters - designed for nested paths + """ + + def _convert_pattern_to_regex(self, pattern: str) -> str: + """Convert a pattern with wildcards to regex pattern. + + Args: + pattern: Pattern string with * and ** wildcards + + Returns: + Regex pattern string + + Note: + ** must be processed before * to avoid conflicts. + ** becomes (.*) - matches any characters including dots + * becomes (\\d+) - matches digits only for layer indices + """ + # Escape the pattern first + regex_pattern = re.escape(pattern) + + # Process ** before * to avoid conflicts + # Replace \*\* with (.*) + regex_pattern = regex_pattern.replace(r"\*\*", r"(.*)") + + # Replace remaining \* with (\d+) + regex_pattern = regex_pattern.replace(r"\*", r"(\d+)") + + return regex_pattern + + def __init__(self, *mappings: MegatronParamMapping): + """ + Initialize MegatronMappingRegistry with weight mappings. 
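+
+        Both the forward (Megatron → HF) and reverse (HF → Megatron) lookup
+        patterns are pre-compiled here once, so lookups never re-build regexes
+        at query time.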
+ + Args: + *mappings: MegatronParamMapping objects + """ + self.mappings = list(mappings) + + # Pre-compile patterns for efficiency + self._compiled_patterns = [] + self._reverse_patterns = [] # For hf_param -> megatron lookups + + for mapping in mappings: + # Compile source patterns + if "*" in mapping.megatron_param: + # Convert glob pattern to regex with support for * and ** + pattern = self._convert_pattern_to_regex(mapping.megatron_param) + self._compiled_patterns.append((re.compile(f"^{pattern}$"), mapping)) + else: + self._compiled_patterns.append((None, mapping)) + + # Compile destination patterns for reverse lookups + if isinstance(mapping.hf_param, str): + if "*" in mapping.hf_param: + pattern = self._convert_pattern_to_regex(mapping.hf_param) + self._reverse_patterns.append((re.compile(f"^{pattern}$"), mapping)) + else: + self._reverse_patterns.append((None, mapping)) + else: + # For dict destinations, compile patterns for each value + reverse_dict_patterns = {} + for key, hf_pattern in mapping.hf_param.items(): + if "*" in hf_pattern: + pattern = self._convert_pattern_to_regex(hf_pattern) + reverse_dict_patterns[key] = re.compile(f"^{pattern}$") + else: + reverse_dict_patterns[key] = None + self._reverse_patterns.append((reverse_dict_patterns, mapping)) + + def megatron_to_hf_lookup(self, megatron_param_name: str) -> Optional[MegatronParamMapping]: + """ + Get mapping for a Megatron parameter name. + + This method performs efficient lookups by first checking for exact matches, + then falling back to pattern matching using pre-compiled regex patterns. + When a pattern match is found, wildcards are automatically resolved. + + Args: + megatron_param_name: Megatron parameter name to look up + Example: "decoder.layers.0.self_attention.linear_qkv.weight" + + Returns: + MegatronParamMapping: Bridge instance with resolved wildcards, or None + if no matching mapping is found. The returned bridge will have + all wildcards replaced with actual values. + + Example: + >>> # Query with exact layer number + >>> bridge = state_map.megatron_to_hf_lookup("decoder.layers.5.mlp.linear_fc1.weight") + >>> if bridge: + ... print(f"Maps to: {bridge.hf_param}") # Shows HF name for layer 5 + """ + for pattern, mapping in self._compiled_patterns: + if pattern is None: + # Direct match + if mapping.megatron_param == megatron_param_name: + return mapping + else: + # Pattern match + match = pattern.match(megatron_param_name) + if match: + # Return resolved mapping with wildcards replaced + return mapping.resolve(match.groups()) + return None + + def hf_to_megatron_lookup(self, hf_param_name: str) -> Optional[MegatronParamMapping]: + """ + Get mapping for a destination parameter name (reverse lookup). + + This is useful when you have a destination name and want to find + the corresponding megatron name. 
+ + Args: + hf_param_name: Destination parameter name to look up + + Returns: + MegatronParamMapping with resolved wildcards, or None if no match found + """ + for pattern_info, mapping in self._reverse_patterns: + if isinstance(mapping.hf_param, str): + # Simple string destination + pattern = pattern_info + if pattern is None: + # Direct match + if mapping.hf_param == hf_param_name: + return mapping + else: + # Pattern match + match = pattern.match(hf_param_name) + if match: + return mapping.resolve(match.groups()) + else: + # Dict destination - check each pattern + patterns_dict = pattern_info + for key, pattern in patterns_dict.items(): + if pattern is None: + # Direct match + if mapping.hf_param[key] == hf_param_name: + # Create a simplified mapping for this specific key + return mapping.resolve(()) + else: + # Pattern match + match = pattern.match(hf_param_name) + if match: + return mapping.resolve(match.groups()) + return None + + def get_all_mappings(self) -> List[MegatronParamMapping]: + """Get all mappings in this MegatronMappingRegistry.""" + return self.mappings.copy() + + def get_mappings_by_pattern(self, pattern: str) -> List[MegatronParamMapping]: + """ + Get all mappings that match a given pattern. + + Args: + pattern: Pattern to match (supports * and ** wildcards) + + Returns: + List of matching MegatronParamMapping objects + """ + # Convert pattern to regex using the same logic as _convert_pattern_to_regex + # but for this method we want both * and ** to match anything for search purposes + regex_pattern = re.escape(pattern) + regex_pattern = regex_pattern.replace(r"\*\*", r".*") + regex_pattern = regex_pattern.replace(r"\*", r".*") + compiled_pattern = re.compile(f"^{regex_pattern}$") + + matches = [] + for mapping in self.mappings: + if compiled_pattern.match(mapping.megatron_param): + matches.append(mapping) + + return matches + + def __len__(self) -> int: + """Return number of mappings.""" + return len(self.mappings) + + def __iter__(self): + """Iterate over mappings.""" + return iter(self.mappings) + + def __repr__(self) -> str: + """String representation of MegatronMappingRegistry.""" + return f"MegatronMappingRegistry({len(self.mappings)} mappings)" + + def describe(self) -> str: + """ + Get a human-readable description of all mappings. + + Returns: + Formatted string describing all weight mappings + """ + lines = [f"MegatronMappingRegistry with {len(self.mappings)} mappings:"] + for i, mapping in enumerate(self.mappings): + lines.append(f"\n{i + 1}. {mapping.megatron_param}") + if isinstance(mapping.hf_param, str): + lines.append(f" → {mapping.hf_param}") + else: + lines.append(" → {") + for key, value in mapping.hf_param.items(): + lines.append(f" {key}: {value}") + lines.append(" }") + + # Show bridge type + lines.append(f" bridge: {type(mapping).__name__}") + + return "\n".join(lines) diff --git a/flagscale/train/bridge/models/conversion/model_bridge.py b/flagscale/train/bridge/models/conversion/model_bridge.py new file mode 100644 index 0000000000..9059bcc6f9 --- /dev/null +++ b/flagscale/train/bridge/models/conversion/model_bridge.py @@ -0,0 +1,1036 @@ +# Copyright (c) 2025, BAAI. All rights reserved. 
+# +# Mainly adapted from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import abc +import itertools +import logging +import re + +from dataclasses import dataclass +from typing import ( + Callable, + Generic, + Iterable, + List, + Mapping, + NamedTuple, + Optional, + Type, + TypeVar, + Union, +) + +import torch + +from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn +from transformers.modeling_utils import PreTrainedModel + +from megatron.core import parallel_state +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.utils import get_pg_size, unwrap_model + +from flagscale.train.bridge.models.conversion.mapping_registry import MegatronMappingRegistry +from flagscale.train.bridge.models.conversion.param_mapping import MegatronParamMapping +from flagscale.train.bridge.models.conversion.utils import ( + extract_sort_key, + get_module_and_param_from_name, + persistent_buffers, +) +from flagscale.train.bridge.models.decorators.dispatch import dispatch +from flagscale.train.bridge.utils.common_utils import print_rank_0 + +logger = logging.getLogger(__name__) + +MappingT = TypeVar("MappingT", bound=MegatronParamMapping) +HFPreTrained = TypeVar("HFPreTrained") +MegatronModel = TypeVar("MegatronModel", bound=MegatronModule) +_BridgeImplClass = TypeVar("_BridgeImplClass", bound="MegatronModelBridge") + + +def padding_embedd_size(mcore_weight: torch.Tensor, hf_vocab_size: int): + hf_size = hf_vocab_size + mcore_size = mcore_weight.shape[0] + full_word = {} + is_rank0 = torch.distributed.get_rank() == 0 + # Cut out extra padding we don't need + if mcore_size > hf_size: + full_word = mcore_weight[0:hf_size, :] + if is_rank0: + print(f"> padding embedding size mcore {mcore_size} to hf {hf_size}") + + # Expanding embedding to larger size by replicating final entry + elif mcore_size < hf_size: + padding_size = hf_size - mcore_size + + full_word = torch.cat( + (mcore_weight, mcore_weight[-1].unsqueeze(0).expand(padding_size, -1)) + ) + if is_rank0: + print(f"> padding embedding size mcore {mcore_size} to hf {hf_size}") + # Same size! + else: + full_word = mcore_weight + return full_word + + +class MegatronWeightTuple(NamedTuple): + """Tuple representing a Megatron model weight with its metadata.""" + + param_name: str + weight: torch.Tensor + vp_stage: int + + +class HFWeightTuple(NamedTuple): + """Tuple representing a HuggingFace model weight with its metadata.""" + + param_name: str + weight: torch.Tensor + + +@dataclass(frozen=True) +class WeightConversionTask(Generic[MappingT]): + """A unified task for converting weights between HuggingFace and Megatron formats. + + This class combines both HF->Megatron and Megatron->HF conversion tasks since they + have different method names (hf_to_megatron vs megatron_to_hf) and can coexist safely. + + The task encapsulates all information needed for weight conversion in either direction, + with different fields being relevant depending on the conversion type. + + Attributes: + param_name (str): *unwrapped, local* parameter name (no ``module.`` prefixes). + mapping (MappingT): Concrete :pyclass:`MegatronParamMapping` instance responsible + for weight transformation and distribution. + + pp_rank (Optional[int]): Pipeline-parallel rank that owns the parameter (required for saves). + vp_stage (Optional[int]): Virtual-pipeline stage index (required for loads). 
+ megatron_module (Optional[torch.nn.Module]): Reference to the Megatron model or + sub-module that owns the parameter (required for loads). + param_weight (Optional[torch.Tensor]): The actual parameter tensor that will + receive the converted weight (required for loads). + + """ + + param_name: str + mapping: MappingT + pp_rank: Optional[int] = None + vp_stage: Optional[int] = None + megatron_module: Optional[torch.nn.Module] = None + param_weight: Optional[torch.Tensor] = None + + +def _megatron_local_name_to_global( + models: MegatronModule | List[MegatronModule], + config: TransformerConfig, + param_name: str, + vp_stage: Optional[int] = None, +) -> str: + """Adjust layer number and expert number from local to global numbering.""" + # PP + pp_group = parallel_state.get_pipeline_model_parallel_group() + if "layers." in param_name and get_pg_size(pp_group) > 1: + match = re.match(r"^(.+?\.layers\.\d+)", param_name) + assert match is not None + layer_prefix = match.group(1) + _, layer_module = get_module_and_param_from_name( + models=models, param_name=layer_prefix, vp_stage=vp_stage + ) + + local_layer_number = int(param_name.split("layers.")[1].split(".")[0]) + global_layer_number = layer_module.layer_number - 1 + param_name = param_name.replace( + f"layers.{local_layer_number}.", f"layers.{global_layer_number}." + ) + + # EP + ep_group = parallel_state.get_expert_model_parallel_group() + if ".mlp.experts.linear_fc" in param_name and get_pg_size(ep_group) > 1: + num_experts = config.num_moe_experts + num_experts_per_rank = num_experts // ep_group.size() + + def _update_expert_number(param_name: str, param_type: str) -> str: + """Update expert number from local to global for weight or bias parameters.""" + local_expert_number = int(param_name.split(f".{param_type}")[-1]) + global_expert_number = num_experts_per_rank * ep_group.rank() + local_expert_number + return param_name.replace( + f".{param_type}{local_expert_number}", f".{param_type}{global_expert_number}" + ) + + # Handle weight and bias parameters + if ".weight" in param_name: + param_name = _update_expert_number(param_name, "weight") + elif ".bias" in param_name: + param_name = _update_expert_number(param_name, "bias") + return param_name + + +# class MegatronModelBridge(Generic[HFPreTrained, ModelProviderTarget, MegatronModel]): +class MegatronModelBridge(Generic[HFPreTrained, MegatronModel]): + """ + High-level orchestrator for HuggingFace ↔ Megatron model conversions. + + This abstract base class provides the framework for converting models between + HuggingFace and Megatron formats. It acts as an orchestrator that coordinates + the conversion process without directly handling the complex details of + tensor parallelism or weight transformations. + + The bridge pattern separates concerns: + - MegatronModelBridge: Orchestrates the overall conversion process + - MegatronMappingRegistry: Manages parameter name mappings + - MegatronParamMapping: Handles actual weight transformations and distribution + + Key responsibilities: + 1. Build conversion tasks that map each parameter to its appropriate bridge + 2. Execute tasks with proper error handling and progress tracking + 3. Provide utilities for configuration translation + 4. Handle virtual pipeline parallelism (VP) complexities + + To implement a bridge for a new model architecture: + + 1. Create a subclass decorated with @MegatronModelBridge.register_bridge: + + .. 
code-block:: python + + @MegatronModelBridge.register_bridge(source=LlamaForCausalLM, target=GPTModel) + class MegatronCausalLlamaBridge(MegatronModelBridge): + pass + + 2. Implement provider_bridge to create Megatron configurations: + + .. code-block:: python + + def provider_bridge(self, hf_pretrained) -> LlamaModelProvider: + return LlamaModelProvider( + num_layers=hf_pretrained.config.num_hidden_layers, + hidden_size=hf_pretrained.config.hidden_size, + ... + ) + + 3. Implement mapping_registry to define weight mappings: + + .. code-block:: python + + def mapping_registry(self) -> MegatronMappingRegistry: + return MegatronMappingRegistry( + AutoMapping( + megatron_param="embedding.word_embeddings.weight", + hf_param="model.embed_tokens.weight" + ), + ... + ) + + Example: + .. code-block:: python + + # The bridge is typically not instantiated directly + # Instead, use AutoBridge or AutoBridge which handle this + bridge = AutoBridge.from_hf_pretrained("meta-llama/Meta-Llama-3-8B") + provider = bridge.to_megatron_provider() + + Note: + This class uses generic type parameters to ensure type safety: + - HFPreTrained: The HuggingFace model type + - ModelProviderTarget: The Megatron model provider type + - MegatronModel: The Megatron model type + """ + + @abc.abstractmethod + def mapping_registry(self) -> MegatronMappingRegistry: + """Define weight mappings between HuggingFace and Megatron formats. + + This abstract method must be implemented by subclasses to specify how + parameters map between the two formats. The returned MegatronMappingRegistry + contains all param mappings needed for the model architecture. + + Returns: + MegatronMappingRegistry: MegatronMappingRegistry containing all weight + mapping definitions. + + Example: + .. code-block:: python + + def mapping_registry(self): + return MegatronMappingRegistry( + AutoMapping( + megatron_param="embedding.word_embeddings.weight", + hf_param="model.embed_tokens.weight" + ), + QKVMapping( + megatron_param="decoder.layers.*.self_attention.linear_qkv.weight", + q="model.layers.*.self_attn.q_proj.weight", + k="model.layers.*.self_attn.k_proj.weight", + v="model.layers.*.self_attn.v_proj.weight" + ), + # ... 
more param mappings + ) + """ + raise NotImplementedError("Subclass must implement mapping_registry method") + + def _megatron_global_param_names_all_pp_ranks( + self, megatron_model: Union[MegatronModel, List[MegatronModel]] + ) -> List[str]: + """Get all parameter names across all pipeline parallel ranks.""" + # Cache the result after first call + if hasattr(self, "_cached_param_names"): + return self._cached_param_names + + # Compute the result + pp_group = parallel_state.get_pipeline_model_parallel_group() + model_config = unwrap_model(megatron_model)[0].config + global_param_names = [] + + # Ensure megatron_model is a list for consistent handling + models_list = megatron_model if isinstance(megatron_model, list) else [megatron_model] + + for vp_stage, model in enumerate(models_list): + # persistent buffers are part of the model's state_dict, but not the named_parameters, so we must include them here separately + for local_param_name, _ in itertools.chain( + model.named_parameters(), persistent_buffers(model) + ): + if "_extra_state" in local_param_name: + continue + local_param_name = self._unwrap_name(local_param_name) + global_param_name = _megatron_local_name_to_global( + models_list, model_config, local_param_name, vp_stage + ) + global_param_names.append(global_param_name) + + gathered_global_param_names = [None] * pp_group.size() + torch.distributed.all_gather_object( + gathered_global_param_names, global_param_names, group=pp_group + ) + + # flatten the list, sort it and remove duplicates + # the order matters here, casually re-order will cause a hang. + # e.g. decoder.layers.0.mlp.experts.linear_fc1.weight100 + flattened_names = list(set(sum(gathered_global_param_names, []))) + + # the order cannot be changed, this sync for all ranks for conversion + # change this might cause a hang + gathered_global_param_names = sorted(flattened_names, key=extract_sort_key) + + # Cache the result + self._cached_param_names = gathered_global_param_names + + return self._cached_param_names + + def _with_progress_tracking(self, tasks, description: str, show_progress: bool = True): + """Helper method to wrap an iterable with progress tracking. + + Args: + tasks: Iterable of tasks to process + description: Description for the progress bar + show_progress: Whether to show progress (defaults to True) + + Yields: + Items from the tasks iterable while updating progress + """ + is_main_rank = not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 + bridge_name = self.__class__.__name__ + + if show_progress: + with Progress( + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + TimeRemainingColumn(), + TextColumn("({task.completed}/{task.total})"), + TextColumn("{task.fields[bridge]}"), + disable=not is_main_rank, + ) as progress: + task_id = progress.add_task(description, total=len(tasks), bridge=bridge_name) + + for task in tasks: + yield task + progress.update(task_id, advance=1) + else: + # not using disable above because we notice it will dump some empty progress bar, + # even when disable is set to True + for task in tasks: + yield task + + def load_weights_hf_to_megatron( + self, hf_pretrained: HFPreTrained, megatron_model: Union[MegatronModel, List[MegatronModel]] + ) -> List[MegatronModel]: + """Load HuggingFace weights into Megatron models. + + This method orchestrates the complete weight loading process from HuggingFace + format to Megatron's distributed format. 
It builds a conversion task and + executes it with proper progress tracking and error handling. + + The actual weight transformations and distribution are delegated to the + appropriate MegatronParamMapping instances based on the state mappings. + + Args: + hf_pretrained (HFPreTrained): HuggingFace model or state source containing the + weights to load. + megatron_model (Union[MegatronModel, List[MegatronModel]]): Megatron model instance + or list of model instances (one per virtual pipeline stage). + + Returns: + List[MegatronModel]: The input megatron_model as a list with loaded weights. + + Process: + 1. Build a task mapping each Megatron parameter to its source + 2. For each parameter in the task: + - Fetch source weights from HuggingFace state + - Apply format transformation via the param mapping + - Distribute to appropriate TP/PP ranks + - Copy into the Megatron parameter + + Example: + .. code-block:: python + + hf_model = PreTrainedCausalLM.from_pretrained("gpt2") + megatron_model = create_megatron_model() # Single model or list + bridge.load_weights_hf_to_megatron(hf_model, megatron_model) + + Note: + Progress is shown only on rank 0 to avoid cluttered output in + distributed environments. + + Raises: + ValueError: If hf_pretrained doesn't have state attribute or if weight shapes don't match. + AttributeError: If required HF weights are missing. + """ + if not isinstance(megatron_model, list): + megatron_model = [megatron_model] + + hf_to_megatron_tasks = self.build_conversion_tasks(hf_pretrained, megatron_model) + hf_state_dict: Mapping[str, torch.Tensor] = ( + hf_pretrained.state if hasattr(hf_pretrained, "state") else {} + ) + + description = f"Loading from {hf_pretrained.model_name_or_path}" + for task in self._with_progress_tracking(hf_to_megatron_tasks, description): + # None means megatron module not on current rank, skip if this task is not going to happen + if task.megatron_module is None: + continue + # 1) Fetch source tensor(s) from HF state dict + if isinstance(task.mapping.hf_param, str): + hf_weights = hf_state_dict[task.mapping.hf_param] + else: + hf_weights = {k: hf_state_dict[v] for k, v in task.mapping.hf_param.items()} + + # 2) Delegate conversion & distribution to the bridge + converted_weights = task.mapping.hf_to_megatron(hf_weights, task.megatron_module) + + # 3) Copy into Megatron param if this rank received a shard + if converted_weights is not None: + # Assert that param_weight is not None for HF->Megatron tasks + assert ( + task.param_weight is not None + ), "param_weight is required for HF->Megatron conversion" + + # Check shape compatibility before copying + if converted_weights.shape != task.param_weight.shape: + raise ValueError( + f"Shape mismatch for megatron param {task.mapping.megatron_param}:\n" + f" Expected shape: {task.param_weight.shape}\n" + f" Got shape: {converted_weights.shape}\n" + f" Bridge type: {type(task.mapping).__name__}\n" + f" HF mapping: {task.mapping.hf_param}" + ) + task.param_weight.data.copy_(converted_weights) + + self._broadcast_shared_embeddings(megatron_model) + return megatron_model + + def stream_weights_hf_to_megatron( + self, + hf_pretrained: HFPreTrained, + megatron_model: Union[MegatronModel, List[MegatronModel]], + conversion_tasks: Optional[List[WeightConversionTask]] = None, + ) -> Iterable[MegatronWeightTuple]: + """Generator variant of load_weights_hf_to_megatron for streaming weight conversion. 
+ + This method provides a memory-efficient way to convert weights by yielding + them one at a time instead of loading all at once. Useful for processing + very large models or when implementing custom weight handling logic. + + Args: + hf_pretrained (HFPreTrained): HuggingFace model or state source containing + the weights. + megatron_model (Union[MegatronModel, List[MegatronModel]]): Megatron model instance + or list of model instances to extract configuration from. + conversion_tasks (Optional[List[WeightConversionTask]]): Pre-built conversion tasks. + If not provided, tasks will be built automatically from the models. + + Yields: + MegatronWeightTuple: Named tuples containing: + - vp_stage: Index of the model in megatron_model list + - param_name: Name of the parameter + - weight: Transformed weight tensor for this rank + + Example: + .. code-block:: python + + # Process weights one by one + for weight_tuple in bridge.stream_weights_hf_to_megatron(hf_model, megatron_model): + print(f"Processing {weight_tuple.param_name}: {weight_tuple.weight.shape}") + # Custom processing logic here + + # Or use pre-built conversion tasks + tasks = bridge.build_conversion_tasks(hf_model, megatron_model) + for weight_tuple in bridge.stream_weights_hf_to_megatron(hf_model, megatron_model, tasks): + print(f"Processing {weight_tuple.param_name}: {weight_tuple.weight.shape}") + + Note: + Only yields weights that belong to the current rank after TP/PP distribution. + + Raises: + ValueError: If input parameters are invalid. + """ + + if not isinstance(megatron_model, list): + megatron_model = [megatron_model] + + # Use provided conversion tasks or build them + if conversion_tasks is None: + conversion_tasks = self.build_conversion_tasks(hf_pretrained, megatron_model) + + for task in conversion_tasks: + # None means megatron module not on current rank, skip if this task is not going to happen + if task.megatron_module is None: + continue + hf_state_dict: Mapping[str, torch.Tensor] = hf_pretrained.state + if isinstance(task.mapping.hf_param, str): + hf_weights = hf_state_dict[task.mapping.hf_param] + else: + hf_weights = {k: hf_state_dict[v] for k, v in task.mapping.hf_param.items()} + + converted_weights = task.mapping.hf_to_megatron(hf_weights, task.megatron_module) + if converted_weights is not None: + # Assert that vp_stage is not None for HF->Megatron tasks + yield MegatronWeightTuple(task.param_name, converted_weights, task.vp_stage) + + def stream_weights_megatron_to_hf( + self, + megatron_model: Union[MegatronModel, List[MegatronModel]], + hf_pretrained: HFPreTrained, + cpu: bool = True, + show_progress: bool = True, + conversion_tasks: Optional[List[WeightConversionTask]] = None, + ) -> Iterable[HFWeightTuple]: + """Export Megatron weights to HuggingFace format. + + This method orchestrates the conversion of weights from Megatron's distributed + format back to HuggingFace format. It handles gathering from tensor parallel + ranks, broadcasting across pipeline parallel ranks, and format conversions. + All ranks receive the full tensors. + + The export order is determined automatically: + - First tries safetensors order (if key_to_filename_map is available) + - Falls back to HuggingFace state dict order + + Args: + megatron_model (Union[MegatronModel, List[MegatronModel]]): Megatron model instance + or list of model instances (one per virtual pipeline stage). + hf_pretrained (HFPreTrained): HuggingFace model/config for metadata + and mapping info. 
+ cpu (bool, optional): Whether to move tensors to CPU before yielding. + Defaults to True. + show_progress (bool, optional): Display progress bar during export. + Defaults to True. + conversion_tasks (Optional[List[WeightConversionTask]]): Pre-built conversion tasks. + If not provided, tasks will be built automatically from the models. + + Yields: + HFWeightTuple: Named tuples of (param_name, weight_tensor) in HF format. + + Example: + .. code-block:: python + + # Export weights + for name, weight in bridge.stream_weights_megatron_to_hf(megatron_model, hf_config): + print(f"Exported {name}: {weight.shape}") + + # Or use pre-built conversion tasks + tasks = bridge.build_conversion_tasks(hf_config, megatron_model) + for name, weight in bridge.stream_weights_megatron_to_hf( + megatron_model, hf_config, conversion_tasks=tasks + ): + print(f"Exported {name}: {weight.shape}") + + Raises: + ValueError: If input parameters are invalid. + + Note: + All ranks yield the full tensors after gathering from distributed format. + """ + + if not isinstance(megatron_model, list): + megatron_model = [megatron_model] + # Use provided conversion tasks or build them + if conversion_tasks is None: + conversion_tasks = self.build_conversion_tasks(hf_pretrained, megatron_model) + + megatron_to_hf_tasks = conversion_tasks + model_config = unwrap_model(megatron_model)[0].config + # embeddings_are_tied = model_config.share_embeddings_and_output_weights + embeddings_are_tied = not model_config.untie_embeddings_and_output_weights + for task in self._with_progress_tracking( + megatron_to_hf_tasks, "Converting to HuggingFace", show_progress + ): + converted_weights_dict = task.mapping.megatron_to_hf( + task.param_weight, task.megatron_module + ) + + # All ranks get the full tensor + for hf_name, tensor in converted_weights_dict.items(): + final_tensor = tensor.cpu() + + if hf_name == "model.embed_tokens.weight" or hf_name == "lm_head.weight": + final_tensor = padding_embedd_size( + final_tensor, hf_pretrained.config.vocab_size + ) + + # Handle tied embeddings case + # TODO(yuya): fix this hard coded naming + if embeddings_are_tied and hf_name == "model.embed_tokens.weight": + # Yield the embedding weight + yield HFWeightTuple(hf_name, final_tensor) + + # Also yield as lm_head.weight if it's expected + if hasattr(hf_pretrained, "state") and hasattr(hf_pretrained.state, "source"): + expected_keys = hf_pretrained.state.source.get_all_keys() + if "lm_head.weight" in expected_keys: + final_tensor = final_tensor.detach().clone() + yield HFWeightTuple("lm_head.weight", final_tensor) + elif embeddings_are_tied and hf_name == "lm_head.weight": + # This should not happen when embeddings are tied - assert error + raise ValueError( + "Encountered lm_head.weight when embeddings are tied. This indicates a mapping error." + ) + else: + # Regular case - yield the tensor normally + yield HFWeightTuple(hf_name, final_tensor) + + def dtype_from_hf(self, config, default=None): + """Extract torch dtype from a HuggingFace config. + + This utility method handles the conversion of dtype specifications in + HuggingFace configs to PyTorch dtype objects. Supports both direct + torch.dtype objects and string representations. + + Args: + config: HuggingFace configuration object with a torch_dtype attribute. + default (Any, optional): Default value to return if torch_dtype is + not str or torch.dtype. Defaults to None. + + Returns: + torch.dtype: The corresponding PyTorch dtype. 
+ + Raises: + AssertionError: If config doesn't have torch_dtype attribute. + ValueError: If torch_dtype is neither a string nor torch.dtype. + + Example: + .. code-block:: python + + dtype = bridge.dtype_from_hf(hf_config) + print(dtype) # torch.float16 + """ + assert hasattr(config, "torch_dtype"), "Expected config to have attr `torch_dtype`" + torch_dtype = config.torch_dtype + if isinstance(torch_dtype, torch.dtype): + return torch_dtype + elif isinstance(torch_dtype, str): + return self.dtype_from_str(torch_dtype) + elif default is not None: + return default + + raise ValueError("torch_dtype is not of type str/torch.dtype") + + def dtype_from_str(self, dtype: str) -> torch.dtype: + """Convert a string precision identifier to equivalent torch dtype. + + This utility method handles various string representations of PyTorch + data types, including common abbreviations and mixed precision formats. + + Args: + dtype (str): String representation of dtype (e.g., "float16", "fp16", + "bf16-mixed"). + + Returns: + torch.dtype: Corresponding PyTorch dtype (defaults to float32 if unknown). + + Supported formats: + - float16/fp16/16/16-mixed → torch.float16 + - bfloat16/bf16-mixed → torch.bfloat16 + - Others → torch.float32 (default) + + Example: + .. code-block:: python + + dtype = bridge.dtype_from_str("fp16") + print(dtype) # torch.float16 + + dtype = bridge.dtype_from_str("bf16-mixed") + print(dtype) # torch.bfloat16 + """ + assert isinstance(dtype, str) + if dtype in ["float16", "fp16", "16", "16-mixed"]: + return torch.float16 + elif dtype in ["bfloat16", "bf16-mixed"]: + return torch.bfloat16 + else: + return torch.float32 + + def make_vocab_size_divisible_by(self, vocab_size: int) -> int: + """Calculate an appropriate divisor for vocabulary size padding. + + Megatron requires vocabulary sizes to be divisible by certain values for + efficient tensor parallelism. This method finds the largest power of 2 + (up to 128) that evenly divides the vocabulary size. + + Args: + vocab_size (int): Original vocabulary size from the model. + + Returns: + int: Largest power of 2 (≤ 128) that divides vocab_size. + + Example: + .. code-block:: python + + # For vocab_size=50257 (GPT-2) + divisor = bridge.make_vocab_size_divisible_by(50257) + print(divisor) # 1 (50257 is prime) + + # For vocab_size=32000 (Llama) + divisor = bridge.make_vocab_size_divisible_by(32000) + print(divisor) # 128 + + Note: + The returned value is used by Megatron to potentially pad the + vocabulary to ensure efficient parallelization. + """ + base = 128 + while vocab_size % base != 0: + base //= 2 + return base + + # def _get_provider_from_model(self, model: MegatronModule) -> ModelProviderTarget: + # """Extract provider/config from model.""" + # model = unwrap_model(model) + # return model.config + + def _unwrap_name(self, name: str) -> str: + """Unwrap name from DDP or other wrappers. + + Args: + name: Parameter name that may have 'module.' prefixes + + Returns: + Unwrapped parameter name with 'module.' prefixes removed + + Example: + 'module.module.decoder.weight' -> 'decoder.weight' + """ + if not isinstance(name, str): + raise ValueError(f"name must be a string, got {type(name)}") + + while name.startswith("module."): + name = name[len("module.") :] + return name + + def _broadcast_shared_embeddings( + self, megatron_model: Union[MegatronModel, List[MegatronModel]] + ) -> None: + """Broadcast shared embeddings and output weights across embedding group. 
+ + When embeddings and output weights are shared and pipeline parallelism is enabled, + this method ensures all ranks in the embedding group have the same weights by + broadcasting from rank 0. + + Args: + megatron_model: Megatron model instance or list of model instances. + """ + unwrapped_model = unwrap_model(megatron_model)[0] + # hack for vlm to work properly + if ( + hasattr(unwrapped_model, "language_model") + and unwrapped_model.language_model is not None + ): + unwrapped_model = unwrapped_model.language_model + model_config = unwrapped_model.config + if ( + not model_config.untie_embeddings_and_output_weights + and model_config.pipeline_model_parallel_size > 1 + ): + # Broadcast embeddings and output weights from rank 0 to embedding group + embd_group = parallel_state.get_embedding_group() + embd_group_ranks = torch.distributed.get_process_group_ranks(embd_group) + if embd_group is not None and torch.distributed.get_rank() in embd_group_ranks: + # Get embeddings and output weights from rank 0 + if hasattr(unwrapped_model, "embedding") and hasattr( + unwrapped_model.embedding, "word_embeddings" + ): + embd_weights = unwrapped_model.embedding.word_embeddings.weight.data + else: + assert hasattr(unwrapped_model, "output_layer"), "Output layer not found" + embd_weights = torch.empty_like(unwrapped_model.output_layer.weight.data) + torch.distributed.broadcast(embd_weights, src=embd_group_ranks[0], group=embd_group) + if hasattr(unwrapped_model, "output_layer"): + unwrapped_model.output_layer.weight.data.copy_(embd_weights) + + def build_conversion_tasks( + self, hf_pretrained: HFPreTrained, megatron_model: List[MegatronModel] + ) -> List[None | WeightConversionTask]: + """Construct the conversion tasks between HF and megatron. + + The algorithm walks over every parameter of every destination model, + asks the :class:`MegatronMappingRegistry` whether it has a mapping for that + parameter, and – if the corresponding HF weights actually exist – yields + an :class:`_HFLoadTask` describing exactly how that parameter will be + populated. 
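+
+        Returns:
+            A list with one entry per globally sorted parameter name (type
+            ``List[Optional[WeightConversionTask]]``). Entries whose parameter is
+            owned by another PP rank are placeholder tasks with
+            ``megatron_module=None`` and ``param_weight=None``; they are kept only
+            so that every rank walks the task list in the same order.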
+ """ + + # Ensure hf_pretrained has the required state structure + if not (hasattr(hf_pretrained, "state") and hasattr(hf_pretrained.state, "source")): + raise ValueError("hf_pretrained.state.source is required for weight ordering") + + hf_keys: Iterable[str] = hf_pretrained.state.source.get_all_keys() + mapping_registry = self.mapping_registry() + model_config = unwrap_model(megatron_model)[0].config + # embeddings_are_tied = model_config.share_embeddings_and_output_weights + embeddings_are_tied = not model_config.untie_embeddings_and_output_weights + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + sorted_global_param_names_all_pp_ranks = self._megatron_global_param_names_all_pp_ranks( + megatron_model + ) + + # Filter out output_layer related parameters if embeddings are tied + if embeddings_are_tied: + sorted_global_param_names_all_pp_ranks = [ + name + for name in sorted_global_param_names_all_pp_ranks + if "output_layer" not in name + ] + + global_names_index_dict = { + name: idx for idx, name in enumerate(sorted_global_param_names_all_pp_ranks) + } + + tasks = [None] * len(sorted_global_param_names_all_pp_ranks) + for vp_stage, model in enumerate(megatron_model): + # persistent buffers are part of the model's state_dict, but not the named_parameters, so we must include them here separately + for local_name, _ in itertools.chain( + model.named_parameters(), persistent_buffers(model) + ): + if "_extra_state" in local_name: + continue + + local_name = self._unwrap_name(local_name) + global_name = _megatron_local_name_to_global( + megatron_model, model_config, local_name, vp_stage + ) + # if name removed due to some reason, continue. e.g. embeddings_are_tied + if global_name not in global_names_index_dict: + print_rank_0(f"WARNING: {global_name} not in global_names_index_dict") + continue + global_name_idx = global_names_index_dict[global_name] + mapping = mapping_registry.megatron_to_hf_lookup(global_name) + if not mapping: + logger.warning(f"WARNING: No mapping found for megatron_param: {global_name}") + continue + # ensure hf weights exist + if isinstance(mapping.hf_param, str): + if mapping.hf_param not in hf_keys: + prefix = '.'.join(mapping.hf_param.split('.')[:-2]) + is_rank0 = torch.distributed.get_rank() == 0 + if ('q_proj.weight' in mapping.hf_param) and ( + f'{prefix}.q_a_layernorm.weight' in hf_keys + and f'{prefix}.q_a_proj.weight' in hf_keys + and f'{prefix}.q_b_proj.weight' in hf_keys + ): + if is_rank0: + logger.warning(f"WARNING:mcore no-lora,but hf use lora") + else: + logger.warning(f"WARNING: Can't find {mapping.hf_param} in hf_keys") + continue + else: + missing_params = [ + hf_param + for hf_param in mapping.hf_param.values() + if hf_param not in hf_keys + ] + if missing_params: + logger.warning( + f"WARNING: Can't find the following HF parameters in hf_keys: {missing_params}" + ) + continue + + local_module, local_weights = get_module_and_param_from_name( + megatron_model, local_name, vp_stage + ) + tasks[global_name_idx] = WeightConversionTask( + pp_rank=pp_rank, + vp_stage=vp_stage, + param_name=local_name, + megatron_module=local_module, + param_weight=local_weights, + mapping=mapping, + ) + + # Fill the remaining ones for pp communications + for idx, global_name in enumerate(sorted_global_param_names_all_pp_ranks): + mapping = mapping_registry.megatron_to_hf_lookup(global_name) + if tasks[idx] is None: + # This is an exception here we pass in global name + # we are not using global_name to extract module and weights + # only use it for param 
mapping auto dispatch checks + tasks[idx] = WeightConversionTask( + pp_rank=pp_rank, + vp_stage=None, + param_name=global_name, + megatron_module=None, + param_weight=None, + mapping=mapping, + ) + + return tasks + + @classmethod + def register_bridge( + cls, *, source: Type[PreTrainedModel] | str, target: Type[MegatronModel] + ) -> Callable[[_BridgeImplClass], _BridgeImplClass]: + """Class decorator for registering bridge implementations. + + This decorator registers a MegatronModelBridge subclass with the dispatch + system, enabling automatic routing of conversions based on the source + HuggingFace model type and target Megatron model type. + + Args: + source (Type[PreTrainedModel] | str): HuggingFace PreTrainedModel class + (e.g., LlamaForCausalLM) or the class name as a string. Using a + string allows registering bridges for architectures that are only + available via auto_map. + target (Type[MegatronModel]): Megatron model class (e.g., GPTModel). + + Returns: + Callable[[_BridgeImplClass], _BridgeImplClass]: Decorator function + that registers the bridge implementation. + + Example: + .. code-block:: python + + @MegatronModelBridge.register_bridge(source=LlamaForCausalLM, target=GPTModel) + class MegatronCausalLlamaBridge(MegatronModelBridge): + def provider_bridge(self, hf_pretrained): + # Implementation + pass + + def mapping_registry(self): + # Implementation + pass + + String-based registration is also supported: + + .. code-block:: python + + @MegatronModelBridge.register_bridge(source="DeepseekV3ForCausalLM", target=GPTModel) + class MegatronDeepseekV3Bridge(MegatronModelBridge): + ... + + Note: + The decorated class is registered with multiple dispatchers to handle + different conversion scenarios. The registration is automatic when the + class is defined. + """ + + return create_bridge_decorator(source=source, target=target) + + +def is_tensor_parallel(param) -> bool: + """Check if a parameter is tensor parallel distributed.""" + return hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel + + +# Core dispatch functions +@dispatch +def get_model_bridge(hf_architecture) -> "MegatronModelBridge": + """Get the appropriate model bridge for a given HuggingFace architecture.""" + ... + + +@dispatch +def stream_weights_megatron_to_hf( + dispatch_instance: MegatronModel, + megatron_model: Union[MegatronModel, List[MegatronModel]], + hf_pretrained: HFPreTrained, + cpu: bool = True, + show_progress: bool = True, + conversion_tasks: Optional[List[WeightConversionTask]] = None, +) -> Iterable[HFWeightTuple]: + """Bridge Megatron model state to HuggingFace format.""" + ... + + +def register_bridge_implementation( + *, + source: Type["PreTrainedModel"] | str, + target: Type["MegatronModule"], + bridge_class: Type["MegatronModelBridge"], +) -> None: + """Register a bridge implementation with the dispatch system. + + Args: + source: HuggingFace PreTrainedModel class or the class name as a string. + Using a string allows registering bridges for architectures that are + available only via auto_map. 
+ target: Megatron model class (e.g., GPTModel) + bridge_class: MegatronModelBridge implementation class + """ + bridge_class_name = bridge_class.__name__ + + @get_model_bridge.impl(source) + def _get_model_bridge_impl(_) -> "MegatronModelBridge": + bridge = bridge_class() + return bridge + + @stream_weights_megatron_to_hf.impl((source, target)) + def _megatron_to_hf_registered_impl( + _, + megatron_model: Union[MegatronModel, List[MegatronModel]], + hf_pretrained: HFPreTrained, + cpu: bool = True, + show_progress: bool = True, + conversion_tasks: Optional[List[WeightConversionTask]] = None, + ) -> Iterable[HFWeightTuple]: + bridge = bridge_class() + return bridge.stream_weights_megatron_to_hf( + megatron_model, + hf_pretrained, + cpu=cpu, + show_progress=show_progress, + conversion_tasks=conversion_tasks, + ) + + # Set meaningful names for debugging + _get_model_bridge_impl.__name__ = f"_bridge_with_{bridge_class_name}" + _megatron_to_hf_registered_impl.__name__ = f"_megatron_to_hf_with_{bridge_class_name}" + + +def create_bridge_decorator( + *, source: Type["PreTrainedModel"] | str, target: Type["MegatronModule"] +) -> Callable[[Type["MegatronModelBridge"]], Type["MegatronModelBridge"]]: + """Create a decorator for registering bridge implementations. + + Args: + source: HuggingFace PreTrainedModel class or the class name as a string + (useful for auto_map architectures) + target: Megatron model class + + Returns: + Decorator function that registers the bridge implementation + """ + + def decorator(bridge_class: Type["MegatronModelBridge"]) -> Type["MegatronModelBridge"]: + register_bridge_implementation(source=source, target=target, bridge_class=bridge_class) + return bridge_class + + return decorator diff --git a/flagscale/train/bridge/models/conversion/param_mapping.py b/flagscale/train/bridge/models/conversion/param_mapping.py new file mode 100644 index 0000000000..3f7b189521 --- /dev/null +++ b/flagscale/train/bridge/models/conversion/param_mapping.py @@ -0,0 +1,1785 @@ +# Copyright (c) 2025, BAAI. All rights reserved. 
+# +# Mainly adapted from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import json +import re + +from abc import ABC, abstractmethod +from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union + +import torch +import torch.distributed +import torch.nn as nn + +from megatron.core import mpu +from megatron.core.fp8_utils import FP8_TENSOR_CLASS, HAVE_TE_FP8_TENSOR_CLASS +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.utils import get_pg_rank, get_pg_size + +from flagscale.train.bridge.models.conversion.utils import ( + get_module_and_param_from_name, + remove_non_pickleables, +) + +WeightType = TypeVar("WeightType", torch.Tensor, Dict[str, torch.Tensor]) + +import logging + +logger = logging.getLogger(__name__) + + +def col_padding_size(hf_weight: torch.Tensor, mcore_weight: torch.Tensor, tp_size: int): + hf_size = hf_weight.shape[0] + mcore_size = mcore_weight.shape[0] * tp_size + full_word = {} + is_rank0 = torch.distributed.get_rank() == 0 + # Cut out extra padding we don't need + if hf_size > mcore_size: + full_word = hf_weight[0:mcore_size, :] + if is_rank0: + print(f"> padding TP-ColumnParallelfrom {hf_size} to {mcore_size}") + + # Expanding embedding to larger size by replicating final entry + elif hf_size < mcore_size: + padding_size = mcore_size - hf_size + + full_word = torch.cat((hf_weight, hf_weight[-1].unsqueeze(0).expand(padding_size, -1))) + if is_rank0: + print(f"> padding TP-ColumnParallelfrom {hf_size} to {mcore_size}") + # Same size! + else: + full_word = hf_weight + return full_word + + +class MegatronParamMapping(ABC, Generic[WeightType]): + """ + Abstract base class for weight conversion between Megatron and external formats. + + This class provides the foundation for all weight mappings, handling the complex + conversions between Megatron-Core's distributed tensor formats and standard + (typically HuggingFace) formats. Each concrete mapping implements specific + transformation logic while inheriting common parallel communication patterns. + + Key responsibilities: + - Format transformation (e.g., QKV merging/splitting, gated MLP handling) + - Tensor parallel (TP) distribution and gathering across GPUs + - Pipeline parallel (PP) broadcasting between pipeline stages + - Wildcard pattern resolution for layer-wise mappings + + The mapping abstraction ensures that higher-level code doesn't need to know + about the parallel topology or format differences - it just requests a + conversion and the mapping handles all the complexity. + + Public helper methods for subclasses: + - broadcast_from_pp_rank: Broadcast tensors across pipeline stages + - broadcast_obj_from_pp_rank: Broadcast Python objects across PP ranks + - broadcast_tensor_to_tp_ranks: Broadcast within TP group + - scatter_to_tp_ranks: Distribute tensor shards to TP ranks + - gather_from_tp_ranks: Collect tensor shards from TP ranks + + Example: + .. code-block:: python + + class MyCustomMapping(MegatronParamMapping[torch.Tensor]): + def hf_to_megatron(self, hf_weights, megatron_module): + # Custom transformation logic + transformed = hf_weights.t() # Example: transpose + # Use helpers for distribution + return self.scatter_to_tp_ranks(...) 
+ + def megatron_to_hf(self, megatron_weights, megatron_module): + # Broadcast from owning PP rank + weight = self.broadcast_from_pp_rank(megatron_weights) + # Gather from TP ranks and transform + gathered = self.gather_from_tp_ranks(weight) + return {"custom_weight": gathered[0].t()} + """ + + def __init__(self, megatron_param: str, hf_param: Union[str, Dict[str, str]]): + """Initialize the weight mapping. + + Args: + megatron_param (str): Megatron parameter name pattern (supports * + wildcards). + hf_param (Union[str, Dict[str, str]]): External format name pattern(s). + """ + self.megatron_param = megatron_param + self.hf_param = hf_param + self._validate_patterns() + + # Cache for metadata and tensor_spec_output + self._broadcast_obj_cache = {} + self._tensor_spec_output_cache = {} + + if mpu.is_initialized(): + self.pp_group = mpu.get_pipeline_model_parallel_group() + self.ep_group = mpu.get_expert_model_parallel_group() + self._tp_group = mpu.get_tensor_model_parallel_group() + self._etp_group = mpu.get_expert_tensor_parallel_group() + else: + self.pp_group = None + self.ep_group = None + self._tp_group = None + self._etp_group = None + + @property + def tp_group(self): + """Get the tensor model parallel group.""" + if self.is_expert: + return self._etp_group + return self._tp_group + + @property + def tp_rank(self) -> int: + """Get the tensor model parallel rank.""" + return get_pg_rank(self.tp_group) + + @property + def tp_size(self) -> int: + """Get the tensor model parallel size.""" + return get_pg_size(self.tp_group) + + @property + def pp_rank(self) -> int: + """Get the pipeline model parallel rank.""" + return get_pg_rank(self.pp_group) + + @property + def pp_size(self) -> int: + """Get the pipeline model parallel size.""" + return get_pg_size(self.pp_group) + + @property + def ep_rank(self) -> int: + """Get the expert model parallel rank.""" + return get_pg_rank(self.ep_group) + + @property + def ep_size(self) -> int: + """Get the expert model parallel size.""" + return get_pg_size(self.ep_group) + + @property + def etp_rank(self) -> int: + """Get the expert tensor parallel rank.""" + return get_pg_rank(self.etp_group) + + @property + def etp_size(self) -> int: + """Get the expert tensor parallel size.""" + return get_pg_size(self.etp_group) + + @property + def is_expert(self) -> bool: + """Check if this mapping is for an expert parameter.""" + return ".mlp.experts.linear_fc" in self.megatron_param + + def _resolve_names(self, captures: Tuple[str, ...]) -> Tuple[str, Union[str, Dict[str, str]]]: + """Resolve wildcard patterns with captured values. + + Handles both ** (any characters) and * (digits) wildcards in order. + ** patterns are processed before * patterns to avoid conflicts. 
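+
+        Example (illustrative): ``megatron_param``
+        ``"decoder.layers.*.self_attention.linear_qkv.weight"`` with captures
+        ``("3",)`` resolves to
+        ``"decoder.layers.3.self_attention.linear_qkv.weight"``; a dict-valued
+        ``hf_param`` has each of its values resolved with the same captures.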
+ """ + resolved_megatron_param = self.megatron_param + capture_index = 0 + + # First pass: resolve ** wildcards + while "**" in resolved_megatron_param and capture_index < len(captures): + resolved_megatron_param = resolved_megatron_param.replace( + "**", captures[capture_index], 1 + ) + capture_index += 1 + + # Second pass: resolve * wildcards + while "*" in resolved_megatron_param and capture_index < len(captures): + resolved_megatron_param = resolved_megatron_param.replace( + "*", captures[capture_index], 1 + ) + capture_index += 1 + + if isinstance(self.hf_param, str): + resolved_hf_param = self.hf_param + capture_index = 0 + + # First pass: resolve ** wildcards + while "**" in resolved_hf_param and capture_index < len(captures): + resolved_hf_param = resolved_hf_param.replace("**", captures[capture_index], 1) + capture_index += 1 + + # Second pass: resolve * wildcards + while "*" in resolved_hf_param and capture_index < len(captures): + resolved_hf_param = resolved_hf_param.replace("*", captures[capture_index], 1) + capture_index += 1 + else: + resolved_hf_param = {} + for k, v in self.hf_param.items(): + resolved_v = v + capture_index = 0 + + # First pass: resolve ** wildcards + while "**" in resolved_v and capture_index < len(captures): + resolved_v = resolved_v.replace("**", captures[capture_index], 1) + capture_index += 1 + + # Second pass: resolve * wildcards + while "*" in resolved_v and capture_index < len(captures): + resolved_v = resolved_v.replace("*", captures[capture_index], 1) + capture_index += 1 + + resolved_hf_param[k] = resolved_v + + return resolved_megatron_param, resolved_hf_param + + def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping": + """Create a new mapping with resolved wildcards. + + This default implementation works for mappings with a + (megatron_param, hf_param) constructor. + + Args: + captures (Tuple[str, ...]): Captured wildcard values. + + Returns: + MegatronParamMapping: A new mapping instance with resolved names. + """ + resolved_megatron_param, resolved_hf_param = self._resolve_names(captures) + return type(self)(resolved_megatron_param, resolved_hf_param) + + @abstractmethod + def hf_to_megatron(self, hf_weights: WeightType, megatron_module: nn.Module) -> torch.Tensor: + """Convert hf_weights TO Megatron format. + + This method handles: + 1. Format transformation (if needed) + 2. Tensor parallel distribution (if self.tp_size > 1) + + Args: + hf_weights (WeightType): Source hf_weights in external format. + megatron_module (nn.Module): Target Megatron module (for config + access). + + Returns: + torch.Tensor: Weight tensor ready for the current TP rank. + """ + ... + + @abstractmethod + def megatron_to_hf( + self, megatron_weights: Optional[torch.Tensor], megatron_module: Optional[nn.Module] + ) -> Dict[str, torch.Tensor]: + """Convert weights FROM Megatron format. + + This method handles: + 1. Pipeline parallel broadcasting (if weight is on different PP rank) + 2. Tensor parallel gathering (if needed) + 3. Format transformation + + Args: + megatron_weights (Optional[torch.Tensor]): Weight tensor from current + rank (None if on different PP rank). + megatron_module (Optional[nn.Module]): Module for config access + (None if on different PP rank). + + Returns: + Dict[str, torch.Tensor]: Converted weights (empty dict if not on + TP rank 0). + """ + ... 
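+
+    # Note: these two hooks are driven per parameter by MegatronModelBridge
+    # (load_weights_hf_to_megatron calls hf_to_megatron and
+    # stream_weights_megatron_to_hf calls megatron_to_hf); concrete mappings
+    # typically compose them from the PP/TP helpers on this class
+    # (broadcast_from_pp_rank, scatter_to_tp_ranks, gather_from_tp_ranks, ...).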
+ + def broadcast_from_pp_rank( + self, tensor: Optional[torch.Tensor], cache_key: Optional[str] = None + ) -> Optional[torch.Tensor]: + """Broadcast a tensor from the pipeline-parallel rank that owns it. + + Broadcasts to **all** PP ranks. This mirrors the behaviour of + `broadcast_from_megatron_pp` in the original MMapping implementation and + additionally keeps the tensor-parallel metadata (`tensor_model_parallel`, + `partition_dim`) consistent on every rank. + + Args: + tensor (Optional[torch.Tensor]): The local tensor if the current PP + rank owns it. ``None`` otherwise. + + Returns: + Optional[torch.Tensor]: The broadcasted tensor on every PP rank, or + ``None`` if *no* PP rank owned the tensor (which indicates a bug + in the calling code). + """ + + # Fast-path when we are not using pipeline parallelism. + if self.pp_size == 1: + return tensor + + # ------------------------------------------------------------------ + # 1. Gather (shape, dtype, tensor_parallel flag, partition_dim) from + # every PP rank so that we can find the source rank. + # ------------------------------------------------------------------ + if cache_key is not None and cache_key in self._tensor_spec_output_cache: + tensor_spec_output = self._tensor_spec_output_cache[cache_key] + else: + if tensor is not None: + shape = tensor.shape + dtype = tensor.dtype + tensor_parallel = getattr(tensor, "tensor_model_parallel", None) + partition_dim = getattr(tensor, "partition_dim", None) + tensor_spec = (shape, dtype, tensor_parallel, partition_dim) + else: + tensor_spec = None + + tensor_spec_output: list[Optional[tuple]] = [None] * self.pp_size + torch.distributed.all_gather_object( + tensor_spec_output, tensor_spec, group=self.pp_group + ) + self._tensor_spec_output_cache[cache_key] = tensor_spec_output + + # ------------------------------------------------------------------ + # 2. Identify the owning rank (the only rank with a non-None spec). + # ------------------------------------------------------------------ + target_tensor_spec = None + src_rank = None # Rank *inside* the PP group. + for rank, spec in enumerate(tensor_spec_output): + if spec is not None: + if target_tensor_spec is not None: + raise ValueError( + f"Tensor exists on more than one PP rank. Found on ranks {src_rank} and {rank}." + ) + target_tensor_spec = spec + src_rank = rank + + if target_tensor_spec is None: + # No rank had the tensor – this is an error in the caller. + raise ValueError("Object must exist on at least one PP rank") + + # ------------------------------------------------------------------ + # 3. Ensure every rank has an allocated tensor with the right shape + # and dtype before the broadcast. + # ------------------------------------------------------------------ + if tensor is None: + shape, dtype, tensor_parallel, partition_dim = target_tensor_spec + # Use CPU by default, unless CUDA is available + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + tensor = torch.empty(shape, dtype=dtype, device=device) + if tensor_parallel is not None: + tensor.tensor_model_parallel = tensor_parallel + if partition_dim is not None: + tensor.partition_dim = partition_dim + + # ------------------------------------------------------------------ + # 4. Broadcast from the source PP rank to all other PP ranks. 
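+        #    torch.distributed.broadcast needs a global rank, so the PP-group
+        #    rank found above is translated via get_global_rank.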
+ # ------------------------------------------------------------------ + global_src = torch.distributed.get_global_rank(group=self.pp_group, group_rank=src_rank) + torch.distributed.broadcast(tensor, src=global_src, group=self.pp_group) + + return tensor + + def broadcast_obj_from_pp_rank( + self, obj: Optional[Any], cache_key: Optional[str] = None + ) -> Any: + """Broadcast any Python object from the PP rank that owns it. + + This method is useful for broadcasting configuration objects or + other metadata across pipeline parallel ranks. Results are cached + after the first call to avoid redundant broadcasts. + + Args: + obj (Optional[Any]): Object to broadcast (None on non-owning ranks). + cache_key (Optional[str]): Optional cache key. If not provided, + no caching will be performed. + + Returns: + Any: Broadcasted object on all ranks. + + Raises: + ValueError: If object exists on multiple ranks or no ranks. + """ + if self.pp_size == 1: + return obj + + # Check if we already have a cached result (only if cache_key is provided) + if cache_key is not None and cache_key in self._broadcast_obj_cache: + return self._broadcast_obj_cache[cache_key] + + # ------------------------------------------------------------------ + # 1. Gather presence flags from all PP ranks to find the source rank + # ------------------------------------------------------------------ + has_obj = obj is not None + obj_flags = [None] * self.pp_size + torch.distributed.all_gather_object(obj_flags, has_obj, group=self.pp_group) + + # ------------------------------------------------------------------ + # 2. Identify the owning rank (the only rank with True flag) + # ------------------------------------------------------------------ + src_rank = None # Rank *inside* the PP group + for rank, flag in enumerate(obj_flags): + if flag: + src_rank = rank + + if src_rank is None: + raise ValueError("Object must exist on at least one PP rank") + + # ------------------------------------------------------------------ + # 3. Broadcast the object from the source rank to all ranks + # ------------------------------------------------------------------ + if src_rank is None: + raise ValueError("Could not determine source rank") + + # Use broadcast_object_list which is more robust than all_gather_object + obj_list = [obj] + pp_ranks = torch.distributed.get_process_group_ranks(self.pp_group) + global_src = pp_ranks[src_rank] + torch.distributed.broadcast_object_list(obj_list, src=global_src, group=self.pp_group) + + result = obj_list[0] + + # Cache the result for future calls (only if cache_key is provided) + if cache_key is not None: + self._broadcast_obj_cache[cache_key] = result + + return result + + def clear_broadcast_cache(self): + """Clear the broadcast object cache. + + This can be useful for testing or if the objects being broadcast + might change during the lifetime of the mapping. + """ + self._broadcast_obj_cache.clear() + + def clear_tensor_spec_output_cache(self): + """Clear the tensor spec output cache. + + This can be useful for testing or if the tensor spec output + might change during the lifetime of the mapping. + """ + self._tensor_spec_output_cache.clear() + + def broadcast_tensor_to_tp_ranks(self, tensor: torch.Tensor, src_rank: int = 0) -> torch.Tensor: + """Broadcast a tensor to all TP ranks. + + Args: + tensor (torch.Tensor): The tensor to broadcast. + src_rank (int, optional): The source rank within the TP group. + Defaults to 0. + + Returns: + torch.Tensor: The broadcasted tensor. 
+ """ + if self.tp_size == 1: + return tensor + + global_src = torch.distributed.get_global_rank(group=self.tp_group, group_rank=src_rank) + torch.distributed.broadcast(tensor, src=global_src, group=self.tp_group) + return tensor + + def scatter_to_tp_ranks( + self, + splits: Optional[List[torch.Tensor]], + output_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + src_rank: int = 0, + ) -> torch.Tensor: + """Scatter tensor splits to TP ranks. + + Args: + splits (Optional[List[torch.Tensor]]): A list of tensor shards to + scatter. Only rank `src_rank` needs this. + output_shape (torch.Size): The shape of the output tensor on each rank. + dtype (torch.dtype): The data type of the output tensor. + device (torch.device): The device for the output tensor. + src_rank (int, optional): The source rank for the scatter operation. + Defaults to 0. + + Returns: + torch.Tensor: The scattered tensor shard on the current rank. + """ + if self.tp_size == 1: + return splits[0].to(device=device) if splits else None + + output = torch.empty(output_shape, dtype=dtype, device=device) + global_src = torch.distributed.get_global_rank(group=self.tp_group, group_rank=src_rank) + + scatter_list = None + if self.tp_rank == src_rank and splits: + scatter_list = [s.to(device=device) for s in splits] + + torch.distributed.scatter(output, scatter_list, src=global_src, group=self.tp_group) + return output + + def gather_from_tp_ranks(self, tensor: torch.Tensor) -> List[torch.Tensor]: + """Gather tensors from all TP ranks. + + Args: + tensor (torch.Tensor): The tensor shard to be gathered from the + current rank. + + Returns: + List[torch.Tensor]: A list of tensor shards from all TP ranks. + """ + if self.tp_size == 1: + return [tensor] + + gathered = [torch.empty_like(tensor) for _ in range(self.tp_size)] + torch.distributed.all_gather(gathered, tensor, group=self.tp_group) + return gathered + + def _count_wildcard_groups(self, pattern: str) -> int: + """Count the number of wildcard capture groups in a pattern. + + Args: + pattern: Pattern string with * and ** wildcards + + Returns: + Number of capture groups that will be generated + + Note: + ** counts as 1 group, * counts as 1 group + ** must be counted before * to avoid double-counting + """ + count = 0 + remaining = pattern + + # Count ** patterns first + while "**" in remaining: + count += 1 + remaining = remaining.replace("**", "", 1) + + # Count remaining * patterns + count += remaining.count("*") + + return count + + def _validate_patterns(self): + """Validate wildcard consistency between patterns.""" + megatron_param_wildcards = self._count_wildcard_groups(self.megatron_param) + if isinstance(self.hf_param, str): + hf_param_wildcards = self._count_wildcard_groups(self.hf_param) + if megatron_param_wildcards != hf_param_wildcards: + raise ValueError( + f"Wildcard count mismatch: megatron_param='{self.megatron_param}' has " + f"{megatron_param_wildcards} wildcards, hf_param='{self.hf_param}' has {hf_param_wildcards}" + ) + else: + for key, pattern in self.hf_param.items(): + hf_param_wildcards = self._count_wildcard_groups(pattern) + if megatron_param_wildcards != hf_param_wildcards: + raise ValueError( + f"Wildcard count mismatch: megatron_param='{self.megatron_param}' has " + f"{megatron_param_wildcards} wildcards, hf_param['{key}']='{pattern}' has {hf_param_wildcards}" + ) + + def _normalize_expert_param_name(self, param_name: str) -> str: + """Normalize expert parameter name by replacing trailing numbers with 0. + e.g. 
experts.weight15 -> experts.weight0, experts.bias15 -> experts.bias0 + + Args: + param_name (str): Parameter name that may end with a number. + + Returns: + str: Parameter name with trailing number replaced by 0. + """ + # Use regex to replace any trailing number with 0 + return re.sub(r"\d+$", "0", param_name) + + def _get_config(self, module: nn.Module) -> Any: + """Extract configuration from module hierarchy.""" + current = module + while current is not None: + if hasattr(current, "config"): + return current.config + # Try parent module + if hasattr(current, "_parent"): + current = current._parent + else: + # Walk up the module tree + for parent_module in module.modules(): + for child_name, child_module in parent_module.named_children(): + if child_module is current: + current = parent_module + break + else: + continue + break + else: + current = None + + raise ValueError( + f"Could not find config in module hierarchy for {module.__class__.__name__}. " + f"Ensure the module or its parent has a 'config' attribute." + ) + + def gather_from_ep_ranks( + self, + megatron_weights: Optional[torch.Tensor], + megatron_module: Optional[MegatronModule], + hf_param_name: Optional[str], + ) -> Dict[str, torch.Tensor]: + """Handle expert parallel weight gathering for MoE models. + + This method gathers expert weights across expert-parallel (EP) ranks and + returns a mapping from HF parameter names to the corresponding tensors + from each EP rank. Call this only for confirmed expert parameters + (self.is_expert is True), typically after TP gathering/concatenation in + the export path (Megatron → HF). + + Behavior and notation: + - Let E be the total number of experts (e.g., config.num_moe_experts) and + S be the expert-parallel size (ep_size). We assume E % S == 0. + - Each EP rank owns E/S experts. For a given parameter name, we infer a + local expert index L (0 ≤ L < E/S) on the current EP rank from the + global expert id embedded in the name (works for both .weight and .bias). + - The set of global expert ids that correspond to this local index L + across all EP ranks is: {L + k * (E/S) | k ∈ [0, S-1]}. + + Communication and outputs: + - We perform an all_gather over the EP group to collect the tensor from + every EP rank into a list ordered by EP rank id. + - For each EP rank k, we construct the HF parameter name by replacing the + expert id in `hf_param_name` with (L + k * (E/S)), preserving the rest + of the path, and map that name to the gathered tensor from rank k. + + Example: + - E = 8, S = 2 → E/S = 4. Experts are distributed as: + Rank 0: [0, 1, 2, 3], Rank 1: [4, 5, 6, 7]. + If the local index L = 0 (derived from the param name), this returns: + {"...experts.0.weight": tensor_from_rank0, "...experts.4.weight": tensor_from_rank1} + + Args: + megatron_weights (Optional[torch.Tensor]): The local expert weight tensor + (after any TP handling) on this EP rank. + megatron_module (Optional[MegatronModule]): The Megatron module containing + configuration (used to determine E and E/S). Can be None on non-owning PP + ranks; values will be broadcast across PP. + hf_param_name (Optional[str]): HF parameter name template for the current + (local) expert on this rank. The expert id within this string is replaced + with the appropriate global expert ids for each EP rank. + + Returns: + Dict[str, torch.Tensor]: Mapping from HF parameter names (one per EP rank) + to the corresponding expert tensors gathered from each EP rank. 
+ """ + if megatron_module is None: + num_experts_per_rank = self.broadcast_obj_from_pp_rank(None, "num_experts_per_rank") + else: + model_config = self._get_config(megatron_module) + num_experts = model_config.num_moe_experts + num_experts_per_rank = num_experts // self.ep_size + num_experts_per_rank = self.broadcast_obj_from_pp_rank( + num_experts_per_rank, "num_experts_per_rank" + ) + + # Extract local expert number from parameter name + # Handle both .weight and .bias suffixes + local_expert_number = None + for key in (".weight", ".bias"): + if key in self.megatron_param: + global_expert_number = int(self.megatron_param.split(key)[-1]) + local_expert_number = global_expert_number % num_experts_per_rank + + # Compute global expert numbers for all EP ranks + # use regex to replace the local expert number with the global expert number + gathered_expert_param_names = [ + re.sub( + r"experts\.(\d+)", + f"experts.{int(local_expert_number) + num_experts_per_rank * i}", + str(hf_param_name), + ) + for i in range(self.ep_size) + ] + assert ( + hf_param_name in gathered_expert_param_names + ), f"hf_param_name {hf_param_name} not in gathered_expert_param_names {gathered_expert_param_names}" + + # Gather weights from all EP ranks + gathered_weights = [torch.empty_like(megatron_weights) for _ in range(self.ep_size)] + torch.distributed.all_gather(gathered_weights, megatron_weights, group=self.ep_group) + + # Return dictionary mapping HF parameter names to weights + return { + param_name: gathered_weights[i] + for i, param_name in enumerate(gathered_expert_param_names) + } + + def maybe_dequantize(self, tensor: torch.Tensor) -> torch.Tensor: + """Dequantize FP8 tensor if needed.""" + if HAVE_TE_FP8_TENSOR_CLASS and isinstance(tensor, FP8_TENSOR_CLASS): + return tensor.dequantize(dtype=tensor.dtype) + return tensor + + +class DirectMapping(MegatronParamMapping[torch.Tensor]): + """Direct 1:1 weight mapping with no transformation or tensor parallelism.""" + + def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) -> torch.Tensor: + """Direct copy - no transformation or distribution.""" + return hf_weights + + def megatron_to_hf( + self, megatron_weights: Optional[torch.Tensor], megatron_module: Optional[nn.Module] + ) -> Dict[str, torch.Tensor]: + """Direct copy with PP broadcast.""" + # Handle cross-PP broadcast + megatron_weights = self.broadcast_from_pp_rank( + megatron_weights, cache_key=str(self.hf_param) + ) + + if megatron_weights is None: + return {} + + # Dequantize if needed + megatron_weights = self.maybe_dequantize(megatron_weights) + + return {str(self.hf_param): megatron_weights} + + +class ColumnParallelMapping(MegatronParamMapping[torch.Tensor]): + """ + Mapping for column-parallel linear and embedding weights. + + Column-parallel layers in Megatron split the output dimension across tensor + parallel ranks. This is used for layers where each rank computes a portion + of the output features independently, such as: + - Embedding layers (split vocabulary) + - Linear layers producing hidden states (e.g., QKV projections, MLP up projections) + + The weight matrix is partitioned along dimension 0 (rows), so each TP rank + holds a subset of output features while maintaining all input features. + + **Sharding pattern** + - Original weight: `[output_features, input_features]` + - Rank 0: `[output_features/tp_size, input_features]` + - Rank 1: `[output_features/tp_size, input_features]` + - ... + + **Forward path (HuggingFace → Megatron)** + 1. 
Validate divisibility: output dimension must be divisible by tp_size + 2. Split: Chunk tensor along dim 0 into tp_size equal parts + 3. Scatter: Distribute chunks to respective TP ranks + + **Reverse path (Megatron → HuggingFace)** + 1. Broadcast: Ensure all PP ranks have the tensor + 2. Gather: Collect chunks from all TP ranks + 3. Concatenate: Reassemble along dim 0 on rank 0 + + Example: + .. code-block:: python + + # For a weight of shape [4096, 1024] with tp_size=4: + # Each rank gets [1024, 1024] after column-parallel split + mapping = ColumnParallelMapping("linear.weight", "transformer.linear.weight") + megatron_weights = mapping.hf_to_megatron(hf_weight, megatron_module) + # megatron_weights.shape = [1024, 1024] on each rank + + Note: + This mapping also handles bias terms, which are 1D tensors split + along their only dimension following the same pattern. + """ + + def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) -> torch.Tensor: + """Split weight along dim 0 and distribute to TP ranks.""" + # if self.tp_size == 1: + # return hf_weights + + # Some parameters are named with global expert number, e.g. experts.weight15, + # normalize it to experts.weight0, note we are only use the shape, dtype, device info, + # not the actual value, so it is safe to do this. + normalized_param = self._normalize_expert_param_name(self.megatron_param) + _, target_param = get_module_and_param_from_name(megatron_module, normalized_param) + + if self.tp_size == 1: + full_weight = col_padding_size(hf_weights, target_param, self.tp_size) + return full_weight + + # On rank 0, check for divisibility and split + if self.tp_rank == 0: + if hf_weights is None: + raise ValueError("hf_weights should not be None on rank 0") + + # For MCore MambaMixer, A_log is initialized in FP32 but cast to BF16 when + # saving ckpts, including the ckpt uploaded to HF. Without this cast, + # self.scatter_to_tp_ranks will try to scatter the HF A_log weights in BF16 to + # the Megatron tensor which is in FP32. This will error. So we cast before the scatter. + if hf_weights.dtype != target_param.dtype: + logger.warning( + f"WARNING: Dtype mismatch between HuggingFace weights and Megatron module. " + f"HF dtype: {hf_weights.dtype}. Megatron dtype: {target_param.dtype}. " + f"Casting HF weights to Megatron dtype. THIS MAY RESULT IN A LOSS OF PRECISION. " + ) + hf_weights = hf_weights.to(target_param.dtype) + + # For bias (1D), we still split along dim 0 + # For weight (2D), we split along dim 0 (output dimension) + # full_size = hf_weights.shape[0] + # if full_size % self.tp_size != 0: + # raise ValueError( + # f"Cannot evenly split dimension 0 size {full_size} across {self.tp_size} TP ranks" + # ) + # splits = torch.chunk(hf_weights, self.tp_size, dim=0) + full_weight = col_padding_size(hf_weights, target_param, self.tp_size) + full_size = full_weight.shape[0] + if full_size % self.tp_size != 0: + raise ValueError( + f"Cannot evenly split dimension 0 size {full_size} across {self.tp_size} TP ranks" + ) + splits = torch.chunk(full_weight, self.tp_size, dim=0) + else: + splits = None + + # Scatter to all ranks. Each rank gets its sharded shape from its module. 
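+        # (e.g. with tp_size=4, a [4096, 1024] weight is scattered as four
+        # [1024, 1024] shards, matching the example in the class docstring.)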
+ return self.scatter_to_tp_ranks( + splits, target_param.shape, target_param.dtype, target_param.device + ) + + def megatron_to_hf( + self, megatron_weights: Optional[torch.Tensor], megatron_module: Optional[nn.Module] + ) -> Dict[str, torch.Tensor]: + """Gather from all TP ranks and concatenate.""" + # Handle cross-PP broadcast + megatron_weights = self.broadcast_from_pp_rank( + megatron_weights, cache_key=str(self.hf_param) + ) + + if megatron_weights is None: + return {} + + # Dequantize if needed + megatron_weights = self.maybe_dequantize(megatron_weights) + + if self.tp_size == 1: + full_weights = megatron_weights + else: + # Gather from all TP ranks + gathered = self.gather_from_tp_ranks(megatron_weights) + full_weights = torch.cat(gathered, dim=0) + + if self.is_expert: + return self.gather_from_ep_ranks(full_weights, megatron_module, self.hf_param) + + return {str(self.hf_param): full_weights} + + +class RowParallelMapping(MegatronParamMapping[torch.Tensor]): + """Mapping for **row-parallel** linear weights. + + Megatron shards row-parallel tensors along **dimension 1** (the *input* + dimension of a linear layer). + + **Forward path (external → Megatron)** + 1. Rank 0 validates that the *second* dimension is divisible by `tp_size`. + 2. Rank 0 splits the tensor with `torch.chunk(..., dim=1)` producing + `tp_size` equally-sized shards. + 3. The shards are **scattered** so that every TP rank receives exactly one + shard matching the shape of its local Megatron parameter. + + **Reverse path (Megatron → external)** + 1. The local Megatron parameter (which may live on any PP rank) is + broadcast to all PP ranks so that the gather step can be collective. + 2. All TP ranks **gather** their shard. + 3. Rank 0 concatenates the gathered list along dim 1 to reconstruct the + original unsharded weight and emits it under the external (HF) name. + """ + + def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) -> torch.Tensor: + """Split weight along dim 1 and distribute to TP ranks.""" + if self.tp_size == 1: + return hf_weights + + # Some parameters are named with global expert number, e.g. experts.weight15, + # normalize it to experts.weight0, note we are only use the shape, dtype, device info, + # not the actual value, so it is safe to do this. + normalized_param = self._normalize_expert_param_name(self.megatron_param) + _, target_param = get_module_and_param_from_name(megatron_module, normalized_param) + + # On rank 0, check for divisibility and split + if self.tp_rank == 0: + if hf_weights is None: + raise ValueError("hf_weights should not be None on rank 0") + + # For bias (1D), we still split along dim 0 + # For weight (2D), we split along dim 1 + if hf_weights.ndim == 1: + full_size = hf_weights.shape[0] + if full_size % self.tp_size != 0: + raise ValueError( + f"Cannot evenly split dimension 0 size {full_size} across {self.tp_size} TP ranks" + ) + splits = torch.chunk(hf_weights, self.tp_size, dim=0) + else: + assert hf_weights.ndim == 2 + full_size = hf_weights.shape[1] + if full_size % self.tp_size != 0: + raise ValueError( + f"Cannot evenly split dimension 0 size {full_size} across {self.tp_size} TP ranks" + ) + splits = torch.chunk(hf_weights, self.tp_size, dim=1) + + else: + splits = None + + # Scatter to all ranks. Each rank gets its sharded shape from its module. 
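+        # (each row-parallel shard keeps every output row but only
+        # input_features / tp_size columns, since the split above is along
+        # dim 1 for 2D weights.)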
+ return self.scatter_to_tp_ranks( + splits, target_param.shape, target_param.dtype, target_param.device + ) + + def megatron_to_hf( + self, megatron_weights: Optional[torch.Tensor], megatron_module: Optional[nn.Module] + ) -> Dict[str, torch.Tensor]: + """Gather from all TP ranks and concatenate.""" + # Handle cross-PP broadcast + megatron_weights = self.broadcast_from_pp_rank( + megatron_weights, cache_key=str(self.hf_param) + ) + + if megatron_weights is None: + return {} + + # Dequantize if needed + megatron_weights = self.maybe_dequantize(megatron_weights) + + if self.tp_size == 1: + full_weights = megatron_weights + else: + gathered = self.gather_from_tp_ranks(megatron_weights) + full_weights = torch.cat(gathered, dim=1) + + if self.is_expert: + return self.gather_from_ep_ranks(full_weights, megatron_module, self.hf_param) + + return {str(self.hf_param): full_weights} + + +class ReplicatedMapping(MegatronParamMapping[torch.Tensor]): + """Mapping for weights that are **fully replicated** across TP ranks. + + Examples: layer-norm scales, biases, router weights in MoE, etc. + + These tensors exist in exactly the same form on *every* TP rank, so the + mapping logic is trivial – but we still need to broadcast across TP ranks + during *load* (HF → Megatron) and ensure we do **not** emit duplicates + during *export* (Megatron → HF). + """ + + def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) -> torch.Tensor: + """Replicate weight to all TP ranks.""" + try: + target_device = megatron_module.weight.device + except AttributeError: + # the parameter may not be called "weight" + target_device = next(megatron_module.parameters()).device + hf_weights = hf_weights.to(device=target_device) + if self.tp_size == 1: + return hf_weights + + # TODO(yuya): router.weight is on device cpu, need to check. + if target_device.index != torch.cuda.current_device(): + hf_weights = hf_weights.to(torch.cuda.current_device()) + + # All ranks need the full weight + if self.tp_rank > 0: + # Create empty tensor of correct shape + hf_weights = torch.empty_like(hf_weights) + + # Broadcast from rank 0 to all TP ranks + return self.broadcast_tensor_to_tp_ranks(hf_weights, src_rank=0) + + def megatron_to_hf( + self, megatron_weights: Optional[torch.Tensor], megatron_module: Optional[nn.Module] + ) -> Dict[str, torch.Tensor]: + """Return weight only from rank 0 to avoid duplication.""" + # Handle cross-PP broadcast + megatron_weights = self.broadcast_from_pp_rank( + megatron_weights, cache_key=str(self.hf_param) + ) + + if megatron_weights is None: + return {} + + # Dequantize if needed + megatron_weights = self.maybe_dequantize(megatron_weights) + + if self.is_expert: + return self.gather_from_ep_ranks(megatron_weights, megatron_module, self.hf_param) + + return {str(self.hf_param): megatron_weights} + + +class AutoMapping(MegatronParamMapping[torch.Tensor]): + """ + Smart mapping that automatically detects and applies the correct parallelism strategy. + + This mapping eliminates the need to manually specify whether a layer is + column-parallel, row-parallel, or replicated. It examines the Megatron + module at runtime and delegates to the appropriate specialized mapping. + + **Detection strategy** + 1. Check module class name against a registry of known types + 2. If unknown, examine module attributes (tensor_model_parallel, partition_dim) + 3. 
Delegate to appropriate mapping: ColumnParallel, RowParallel, or Replicated + + This abstraction is particularly useful for model-agnostic code where you + don't know the parallelism type ahead of time, or when working with models + that mix different parallelism strategies. + + **Built-in module recognition** + - Column-parallel: `ColumnParallelLinear`, `VocabParallelEmbedding`, etc. + - Row-parallel: `RowParallelLinear`, `TERowParallelLinear` + - Replicated: `LayerNorm`, `RMSNorm`, and other normalization layers + + Example: + .. code-block:: python + + # Automatically handles any weight type + mapping = AutoMapping( + megatron_param="decoder.layers.*.mlp.linear_fc1.weight", + hf_param="model.layers.*.mlp.gate_proj.weight" + ) + + # Works with column-parallel layers + megatron_weights = mapping.hf_to_megatron(hf_weight, column_parallel_module) + + # Also works with normalization layers + norm_weight = mapping.hf_to_megatron(hf_norm, layer_norm_module) + + # Register custom module types + AutoMapping.register_module_type("MyCustomLinear", "column") + + Note: + If the parallelism type cannot be determined, the mapping will raise + a descriptive error suggesting how to fix the issue. + """ + + # Module type registry + _MODULE_TYPE_REGISTRY: Dict[str, set] = { + "column": { + "ColumnParallelLinear", + "TEColumnParallelLinear", + "TELayerNormColumnParallelLinear", + "TEColumnParallelGroupedLinear", + "VocabParallelEmbedding", + }, + "row": {"RowParallelLinear", "TERowParallelLinear", "TERowParallelGroupedLinear"}, + "replicated": { + # Normalization layers + "TENorm", + "FusedLayerNorm", + "WrappedTorchNorm", + "LayerNorm", + "RMSNorm", + "L2Norm", + # Other non-parallel modules + "IdentityOp", + "DotProductAttention", + "TEDotProductAttention", + "TopKRouter", + }, + } + + @classmethod + def register_module_type(cls, module_name: str, parallelism_type: str): + """Register a new module type for automatic parallelism detection. + + Args: + module_name (str): The name of the module class (e.g., + 'MyColumnLinear'). + parallelism_type (str): One of 'column', 'row', or 'replicated'. + """ + if parallelism_type not in cls._MODULE_TYPE_REGISTRY: + raise ValueError( + f"Invalid parallelism_type '{parallelism_type}'. 
" + f"Must be one of {list(cls._MODULE_TYPE_REGISTRY.keys())}" + ) + cls._MODULE_TYPE_REGISTRY[parallelism_type].add(module_name) + + def __init__(self, megatron_param: str, hf_param: str): + """Initialize TP-aware mapping.""" + super().__init__(megatron_param, hf_param) + + # Cache for detected parallelism type and delegate mapping + self._detected_type: Optional[str] = None + self._mapping: Optional[MegatronParamMapping[torch.Tensor]] = None + + def _get_or_create_mapping(self, parallelism_type: str) -> MegatronParamMapping[torch.Tensor]: + """Get or create the appropriate mapping for the given type.""" + if parallelism_type == "column": + return ColumnParallelMapping(self.megatron_param, self.hf_param) + elif parallelism_type == "row": + return RowParallelMapping(self.megatron_param, self.hf_param) + elif parallelism_type == "replicated": + return ReplicatedMapping(self.megatron_param, self.hf_param) + else: + raise ValueError(f"Unknown parallelism type: {parallelism_type}") + + def _detect_parallelism_type(self, module: nn.Module) -> str: + """Detect parallelism type from module.""" + module_type = type(module).__name__ + + # Handle fused modules like TELayerNormColumnParallelLinear + # These modules have both column-parallel weights (weight, bias) + # and replicated layer norm weights (layer_norm_weight, layer_norm_bias) + if module_type == "TELayerNormColumnParallelLinear": + # Check the actual parameter name to determine the correct parallelism type + if self.megatron_param and ( + self.megatron_param.endswith("layer_norm_weight") + or self.megatron_param.endswith("layer_norm_bias") + ): + return "replicated" + # All other parameters (weight, bias) are column-parallel + return "column" + + # Check registry first + for parallelism, types in self._MODULE_TYPE_REGISTRY.items(): + if module_type in types: + return parallelism + + # Fallback to inspecting module attributes + if hasattr(module, "tensor_model_parallel"): + if not module.tensor_model_parallel: + return "replicated" + + # Check partition dimension + partition_dim = getattr(module, "partition_dim", None) + if partition_dim == 0: + return "column" + elif partition_dim == 1: + return "row" + + # Fallback for normalization layers + if any(norm in module_type for norm in ["Norm", "Normalization"]): + return "replicated" + + # Check parallel_mode for TELinear + if module_type == "TELinear": + if module.parallel_mode == "column": + return "column" + elif module.parallel_mode == "row": + return "row" + else: + return "replicated" + + # Cannot determine - raise informative error + known_types = {p: sorted(list(t)) for p, t in self._MODULE_TYPE_REGISTRY.items()} + + raise ValueError( + f"Cannot determine parallelism type for module '{module_type}' " + f"at weight '{self.megatron_param}'.\n" + f"Please use an explicit mapping type (e.g., ColumnParallelMapping) " + f"or register the module type using:\n" + f" AutoMapping.register_module_type('{module_type}', 'column|row|replicated')\n\n" + f"Currently known module types:\n{json.dumps(known_types, indent=2)}" + ) + + def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) -> torch.Tensor: + """Delegate to appropriate mapping based on module type.""" + # Detect type and create delegate on first use + if self._mapping is None: + self._detected_type = self._detect_parallelism_type(megatron_module) + self._mapping = self._get_or_create_mapping(self._detected_type) + + return self._mapping.hf_to_megatron(hf_weights, megatron_module) + + def megatron_to_hf( + self, 
megatron_weights: Optional[torch.Tensor], megatron_module: Optional[nn.Module] + ) -> Dict[str, torch.Tensor]: + """Delegate to appropriate mapping based on module type.""" + # Need to determine type even if module is None (different PP rank) + assert self.megatron_param is not None, "`megatron_param` is required for AutoMapping." + + if self._mapping is None: + if megatron_module is not None: + self._detected_type = self._detect_parallelism_type(megatron_module) + # Broadcast to other ranks + self._detected_type = self.broadcast_obj_from_pp_rank( + self._detected_type, "detected_type" + ) + else: + # Receive from owning rank + self._detected_type = self.broadcast_obj_from_pp_rank(None, "detected_type") + self._mapping = self._get_or_create_mapping(self._detected_type) + + return self._mapping.megatron_to_hf(megatron_weights, megatron_module) + + +class QKVMapping(MegatronParamMapping[Dict[str, torch.Tensor]]): + """ + Mapping for interleaved Query/Key/Value attention projection weights. + + This mapping handles the conversion between separate Q, K, V matrices used in + standard transformers and Megatron's optimized interleaved format. The + interleaving pattern groups queries with their corresponding key-value pairs + to maximize GEMM efficiency during attention computation. + + **External format (HuggingFace)** + - Separate tensors: `q_proj`, `k_proj`, `v_proj` + - Each of shape `[hidden_size, hidden_size]` or `[hidden_size, head_dim * num_heads]` + + **Megatron format** + - Single interleaved tensor following grouped query attention (GQA) pattern + - Interleaving order: `[q1...qn, k1, v1, q1...qn, k2, v2, ...]` + - Where `n = num_attention_heads / num_query_groups` + + **Key features** + 1. Format conversion: Handles merging/splitting with proper interleaving + 2. Grouped Query Attention: Supports different numbers of Q and KV heads + 3. Tensor parallelism: Delegates to AutoMapping for distribution + + Example: + .. code-block:: python + + # Create mapping for attention weights + mapping = QKVMapping( + megatron_param="decoder.layers.*.self_attention.linear_qkv.weight", + q="model.layers.*.self_attn.q_proj.weight", + k="model.layers.*.self_attn.k_proj.weight", + v="model.layers.*.self_attn.v_proj.weight" + ) + + # Convert from HuggingFace to Megatron + qkv_weights = {"q": q_tensor, "k": k_tensor, "v": v_tensor} + megatron_qkv = mapping.hf_to_megatron(qkv_weights, megatron_module) + + # Convert from Megatron to HuggingFace + hf_weights = mapping.megatron_to_hf(megatron_qkv, megatron_module) + # Returns: {"q_proj.weight": ..., "k_proj.weight": ..., "v_proj.weight": ...} + + Note: + This mapping automatically handles both regular multi-head attention + (same number of Q, K, V heads) and grouped query attention (fewer + KV heads than Q heads) based on the model configuration. + """ + + def __init__(self, megatron_param: str, q: str, k: str, v: str): + """Initialize QKV mapping. + + Args: + megatron_param (str): Megatron QKV parameter name pattern. + q (str): Query weight name pattern. + k (str): Key weight name pattern. + v (str): Value weight name pattern. + """ + super().__init__(megatron_param, {"q": q, "k": k, "v": v}) + # Delegate all tensor-parallel logic to the smart TP-aware mapping so we + # do not hard-code the assumption that QKV projections are column-parallel. + # This keeps the format-handling (merge/split) concerns separate from + # TP/PP distribution mechanics. 
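+        # Note: both constructor arguments are the Megatron name; the delegate
+        # only handles tensor distribution, while the q/k/v HF names stored in
+        # self.hf_param are applied when splitting in megatron_to_hf.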
+ self._tp_mapping = AutoMapping(megatron_param, megatron_param) + + def hf_to_megatron( + self, hf_weights: Dict[str, torch.Tensor], megatron_module: nn.Module + ) -> torch.Tensor: + """Merge Q, K, V into interleaved format and distribute.""" + if self.tp_rank == 0: + config = self._get_config(megatron_module) + + # Check if we're dealing with biases (1D tensors) or hf_weights (2D tensors) + if hf_weights["q"].ndim == 1: + # For biases, use the bias-specific merge function + merged = merge_qkv_biases(config, hf_weights["q"], hf_weights["k"], hf_weights["v"]) + else: + # For hf_weights, use the standard merge function + merged = merge_qkv_weights( + config, hf_weights["q"], hf_weights["k"], hf_weights["v"] + ) + else: + merged = None + + # Delegate the actual sharding/broadcasting to the TP-aware mapping. + return self._tp_mapping.hf_to_megatron(merged, megatron_module) + + def megatron_to_hf( + self, megatron_weights: Optional[torch.Tensor], megatron_module: Optional[nn.Module] + ) -> Dict[str, torch.Tensor]: + """Gather QKV shards and split into Q, K, V.""" + # Dequantize if needed + if megatron_weights is not None: + megatron_weights = self.maybe_dequantize(megatron_weights) + + # ------------------------------------------------------------------ + # Broadcast / retrieve the transformer configuration so that every PP + # rank (also the ones that will early-return) participates in the + # collective communication. + # ------------------------------------------------------------------ + if megatron_module is None: + config = self.broadcast_obj_from_pp_rank(None, "qkv_config") + else: + config = self._get_config(megatron_module) + # create shallow copy and remove non-picklable objects with max depth=2 + config = remove_non_pickleables(config, max_depth=2) + config = self.broadcast_obj_from_pp_rank(config, "qkv_config") + + # Delegate TP/PP gathering. + packed_dict = self._tp_mapping.megatron_to_hf(megatron_weights, megatron_module) + + if not packed_dict: + return {} + + packed_qkv = next(iter(packed_dict.values())) + + # Check if we're dealing with biases (1D) or weights (2D) + if packed_qkv.ndim == 1: + # Split biases + q, k, v = split_qkv_biases(config, packed_qkv) + else: + # Split weights + q, k, v = split_qkv_weights(config, packed_qkv) + + return {self.hf_param["q"]: q, self.hf_param["k"]: k, self.hf_param["v"]: v} + + def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping": + """Return a new *resolved* QKVMapping instance.""" + resolved_megatron_param, resolved_hf_param = self._resolve_names(captures) + + return type(self)( + resolved_megatron_param, + resolved_hf_param["q"], + resolved_hf_param["k"], + resolved_hf_param["v"], + ) + + +class ConcatenatedQKVMapping(MegatronParamMapping[Dict[str, torch.Tensor]]): + """ + Mapping for interleaved Query/Key/Value attention projection weights. + + This mapping handles the conversion between Concatenated Q, K, V matrices used in + some transformers models and Megatron's optimized interleaved format. The + interleaving pattern groups queries with their corresponding key-value pairs + to maximize GEMM efficiency during attention computation. 
+ + **External format (HuggingFace)** + - One tensor with concatenated query, key, value: `qkv`, with shape + `[hidden_size, head_dim * num_heads + 2 * head_dim * num_query_groups]` + + **Megatron format** + - Single interleaved tensor following grouped query attention (GQA) pattern + - Interleaving order: `[q1...qn, k1, v1, q1...qn, k2, v2, ...]` + - Where `n = num_attention_heads / num_query_groups` + + **Key features** + 1. Format conversion: Handles merging/splitting with proper interleaving + 2. Grouped Query Attention: Supports different numbers of Q and KV heads + 3. Tensor parallelism: Delegates to AutoMapping for distribution + + Example: + .. code-block:: python + + # Create mapping for attention weights + mapping = QKVMapping( + megatron_param="decoder.layers.*.self_attention.linear_qkv.weight", + qkv="model.layers.*.self_attn.qkv.weight", + ) + + # Convert from HuggingFace to Megatron + megatron_qkv = mapping.hf_to_megatron(qkv_weights, megatron_module) + + # Convert from Megatron to HuggingFace + hf_weights = mapping.megatron_to_hf(megatron_qkv, megatron_module) + + Note: + This mapping automatically handles both regular multi-head attention + (same number of Q, K, V heads) and grouped query attention (fewer + KV heads than Q heads) based on the model configuration. + """ + + def __init__(self, megatron_param: str, hf_param: str): + """Initialize QKV mapping. + + Args: + megatron_param (str): Megatron interleaved QKV parameter name pattern. + hf_param (str): HF concatenated QKV parameter name pattern. + """ + super().__init__(megatron_param, hf_param) + # Delegate all tensor-parallel logic to the smart TP-aware mapping so we + # do not hard-code the assumption that QKV projections are column-parallel. + # This keeps the format-handling (merge/split) concerns separate from + # TP/PP distribution mechanics. + self._tp_mapping = AutoMapping(megatron_param, megatron_param) + + def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) -> torch.Tensor: + """Merge Q, K, V into interleaved format and distribute.""" + if self.tp_rank == 0: + config = self._get_config(megatron_module) + head_num = config.num_attention_heads + head_size = config.kv_channels + num_query_groups = config.num_query_groups + q, k, v = hf_weights.split( + [head_num * head_size, num_query_groups * head_size, num_query_groups * head_size], + dim=0, + ) + # Check if we're dealing with biases (1D tensors) or hf_weights (2D tensors) + if q.ndim == 1: + # For biases, use the bias-specific merge function + merged = merge_qkv_biases(config, q, k, v) + else: + # For hf_weights, use the standard merge function + merged = merge_qkv_weights(config, q, k, v) + else: + merged = None + + # Delegate the actual sharding/broadcasting to the TP-aware mapping. + return self._tp_mapping.hf_to_megatron(merged, megatron_module) + + def megatron_to_hf( + self, megatron_weights: Optional[torch.Tensor], megatron_module: Optional[nn.Module] + ) -> Dict[str, torch.Tensor]: + """Gather QKV shards and split into Q, K, V.""" + # Dequantize if needed + if megatron_weights is not None: + megatron_weights = self.maybe_dequantize(megatron_weights) + + # ------------------------------------------------------------------ + # Broadcast / retrieve the transformer configuration so that every PP + # rank (also the ones that will early-return) participates in the + # collective communication. 
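+        # The broadcast result is cached under the "qkv_config" key, so this
+        # collective only runs once per mapping instance.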
+ # ------------------------------------------------------------------ + if megatron_module is None: + config = self.broadcast_obj_from_pp_rank(None, "qkv_config") + else: + config = self._get_config(megatron_module) + # create shallow copy and remove non-picklable objects with max depth=2 + config = remove_non_pickleables(config, max_depth=2) + config = self.broadcast_obj_from_pp_rank(config, "qkv_config") + + # Delegate TP/PP gathering. + packed_dict = self._tp_mapping.megatron_to_hf(megatron_weights, megatron_module) + + if not packed_dict: + return {} + + packed_qkv = next(iter(packed_dict.values())) + + # Check if we're dealing with biases (1D) or weights (2D) + if packed_qkv.ndim == 1: + # Split biases + q, k, v = split_qkv_biases(config, packed_qkv) + else: + # Split weights + q, k, v = split_qkv_weights(config, packed_qkv) + + return {str(self.hf_param): torch.cat((q, k, v), dim=0)} + + def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping": + """Return a new *resolved* QKVMapping instance.""" + resolved_megatron_param, resolved_hf_param = self._resolve_names(captures) + + return type(self)(resolved_megatron_param, resolved_hf_param) + + +class GatedMLPMapping(MegatronParamMapping[Dict[str, torch.Tensor]]): + r"""Mapping for **gated-MLP** projection weights (SwiGLU / GeGLU). + + Checkpoint formats expose two independent matrices: + + - **G** – gate projection + - **U** – up projection + + Megatron concatenates them row-wise (`[G; U]`) so that a single GEMM can + produce both activations. + + **Responsibilities handled by this mapping** + 1. **Concatenate / split** – convert between `[G; U]` (Megatron) and the + separate `{G, U}` matrices (external). + 2. **Tensor-parallel distribution** – correctly splits gate and up + projections separately before concatenating corresponding shards, + ensuring each TP rank gets the proper [gate_shard; up_shard] format. + + **TP Distribution Strategy** + For tensor parallelism, this mapping: + - Splits gate and up matrices separately along output dimension (dim 0) + - Concatenates corresponding shards: [gate_shard_i; up_shard_i] for rank i + - This ensures each rank's concatenated tensor matches the expected shape + """ + + def __init__(self, megatron_param: str, gate: str, up: str): + """Initialize gated MLP mapping. + + Args: + megatron_param (str): Megatron MLP parameter name pattern. + gate (str): Gate projection weight name pattern. + up (str): Up projection weight name pattern. + """ + super().__init__(megatron_param, {"gate": gate, "up": up}) + + def hf_to_megatron( + self, hf_weights: Dict[str, torch.Tensor], megatron_module: nn.Module + ) -> torch.Tensor: + """Split gate and up separately, then concatenate corresponding shards.""" + # For single TP, just concatenate and return + if self.tp_size == 1: + return torch.cat([hf_weights["gate"], hf_weights["up"]], dim=0) + + # Get target parameter info from megatron module + # Some parameters are named with global expert number, e.g. experts.weight15, + # normalize it to experts.weight0, note we are only use the shape, dtype, device info, + # not the actual value, so it is safe to do this. 
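+        # Each TP rank ultimately holds a [gate_shard_i; up_shard_i] block, so
+        # gate and up are split separately below and re-concatenated per rank.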
+ normalized_param = self._normalize_expert_param_name(self.megatron_param) + _, target_param = get_module_and_param_from_name(megatron_module, normalized_param) + + # On rank 0, split gate and up separately, then concatenate corresponding pieces + if self.tp_rank == 0: + gate = hf_weights["gate"] + up = hf_weights["up"] + + # Verify shapes match + assert gate.shape == up.shape, "Gate and up weights must have the same shape" + + # Check divisibility for TP splitting + gate_output_size = gate.shape[0] + if gate_output_size % self.tp_size != 0: + raise ValueError( + f"Cannot evenly split gate dimension 0 size {gate_output_size} across {self.tp_size} TP ranks" + ) + + # Split gate and up separately along output dimension (dim 0) + # This works for both bias (1D) and weight (2D) tensors + gate_splits = torch.chunk(gate, self.tp_size, dim=0) + up_splits = torch.chunk(up, self.tp_size, dim=0) + + # Concatenate corresponding pieces: [gate_shard_i; up_shard_i] for each rank i + splits = [torch.cat([gate_splits[i], up_splits[i]], dim=0) for i in range(self.tp_size)] + else: + splits = None + + # Scatter the concatenated shards to each rank + return self.scatter_to_tp_ranks( + splits, target_param.shape, target_param.dtype, target_param.device + ) + + def megatron_to_hf( + self, megatron_weights: Optional[torch.Tensor], megatron_module: Optional[nn.Module] + ) -> Dict[str, torch.Tensor]: + """Gather concatenated shards and split into gate and up.""" + # Handle cross-PP broadcast first + megatron_weights = self.broadcast_from_pp_rank( + megatron_weights, cache_key=str(self.hf_param) + ) + + if megatron_weights is None: + return {} + + # Dequantize if needed + megatron_weights = self.maybe_dequantize(megatron_weights) + + # Handle TP gathering + if self.tp_size == 1: + # No TP, just split the concatenated tensor + fused_mlp = megatron_weights + gate, up = torch.chunk(fused_mlp, 2, dim=0) + + else: + # Gather shards from all TP ranks + gathered_shards = self.gather_from_tp_ranks(megatron_weights) + + # Split each shard back into gate and up parts + gate_parts = [] + up_parts = [] + for shard in gathered_shards: + # Each shard is [gate_shard; up_shard] concatenated along dim 0 + # This works for both bias (1D) and weight (2D) tensors + gate_shard, up_shard = torch.chunk(shard, 2, dim=0) + gate_parts.append(gate_shard) + up_parts.append(up_shard) + + # Concatenate all gate parts and all up parts separately + gate = torch.cat(gate_parts, dim=0) + up = torch.cat(up_parts, dim=0) + + if self.is_expert: + gathered_gate_weights_dict = self.gather_from_ep_ranks( + gate, megatron_module, self.hf_param["gate"] + ) + gathered_up_weights_dict = self.gather_from_ep_ranks( + up, megatron_module, self.hf_param["up"] + ) + return {**gathered_gate_weights_dict, **gathered_up_weights_dict} + + return {self.hf_param["gate"]: gate, self.hf_param["up"]: up} + + def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping": + """Return a new *resolved* GatedMLPMapping instance.""" + resolved_megatron_param, resolved_hf_param = self._resolve_names(captures) + + return type(self)( + resolved_megatron_param, resolved_hf_param["gate"], resolved_hf_param["up"] + ) + + +def merge_qkv_biases( + config: TransformerConfig, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor +) -> torch.Tensor: + """Merge separate Q, K, V bias vectors into Megatron's interleaved QKV format. + + Args: + config (TransformerConfig): Transformer configuration. + q (torch.Tensor): Query projection biases [hidden_size]. 
+ k (torch.Tensor): Key projection biases [kv_hidden_size]. + v (torch.Tensor): Value projection biases [kv_hidden_size]. + + Returns: + torch.Tensor: Interleaved QKV biases in Megatron format as 1D tensor. + """ + head_num = config.num_attention_heads + num_query_groups = config.num_query_groups + heads_per_group = head_num // num_query_groups + head_size = config.kv_channels or (config.hidden_size // head_num) + + # Reshape biases to expose head dimension + q = q.view(head_num, head_size) + k = k.view(num_query_groups, head_size) + v = v.view(num_query_groups, head_size) + + # Interleave in Megatron pattern: [q1...qn, k1, v1, q1...qn, k2, v2, ...] + qkv_biases = [] + for i in range(num_query_groups): + qkv_biases.append(q[i * heads_per_group : (i + 1) * heads_per_group, :]) + qkv_biases.append(k[i : i + 1, :]) + qkv_biases.append(v[i : i + 1, :]) + + # Concatenate and flatten back to 1D + qkv = torch.cat(qkv_biases) + return qkv.flatten() + + +def split_qkv_biases( + config: TransformerConfig, qkv: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Split Megatron's interleaved QKV bias into separate Q, K, V biases. + + Args: + config (TransformerConfig): Transformer configuration. + qkv (torch.Tensor): Interleaved QKV biases in Megatron format (1D + tensor). + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of (Q, K, V) bias vectors. + """ + head_num = config.num_attention_heads + num_query_groups = config.num_query_groups + heads_per_group = head_num // num_query_groups + head_size = config.kv_channels or (config.hidden_size // head_num) + qkv_total_dim = head_num + 2 * num_query_groups + + # Reshape to expose interleaved structure + qkv = qkv.reshape(qkv_total_dim, head_size) + + # Extract Q, K, V from interleaved pattern + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, heads_per_group + 2) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, heads_per_group + 2) + + q = qkv[q_slice].flatten() + k = qkv[k_slice].flatten() + v = qkv[v_slice].flatten() + + return q, k, v + + +def merge_qkv_weights( + provider: TransformerConfig, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor +) -> torch.Tensor: + """Merge separate Q, K, V weight matrices into Megatron's interleaved QKV format. + + Args: + provider (TransformerConfig): Model configuration provider. + q (torch.Tensor): Query projection weights [hidden_size, hidden_size] or + bias [hidden_size]. + k (torch.Tensor): Key projection weights [kv_hidden_size, hidden_size] + or bias [kv_hidden_size]. + v (torch.Tensor): Value projection weights [kv_hidden_size, + hidden_size] or bias [kv_hidden_size]. + + Returns: + torch.Tensor: Interleaved QKV weights in Megatron format. 
+ """ + head_num = provider.num_attention_heads + num_query_groups = provider.num_query_groups + heads_per_group = head_num // num_query_groups + head_size = provider.kv_channels or (provider.hidden_size // head_num) + hidden_size = provider.hidden_size + is_bias = q.ndim == 1 + + # Reshape to expose head dimension + if is_bias: + q_reshaped = q.view(head_num, head_size) + k_reshaped = k.view(num_query_groups, head_size) + v_reshaped = v.view(num_query_groups, head_size) + else: + q_reshaped = q.view(head_num, head_size, hidden_size) + k_reshaped = k.view(num_query_groups, head_size, hidden_size) + v_reshaped = v.view(num_query_groups, head_size, hidden_size) + + # Interleave in Megatron pattern: [q1...qn, k1, v1, q1...qn, k2, v2, ...] + qkv_weights = [] + for i in range(num_query_groups): + q_group = q_reshaped[i * heads_per_group : (i + 1) * heads_per_group] + k_group = k_reshaped[i : i + 1] + v_group = v_reshaped[i : i + 1] + qkv_weights.extend([q_group, k_group, v_group]) + + qkv = torch.cat(qkv_weights, dim=0) + + # Final reshape + if is_bias: + return qkv.reshape(-1) + else: + return qkv.reshape([-1, hidden_size]) + + +def split_qkv_weights( + provider: TransformerConfig, qkv: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Split Megatron's interleaved QKV tensor into separate Q, K, V matrices. + + Args: + provider (TransformerConfig): Model configuration provider. + qkv (torch.Tensor): Interleaved QKV weights in Megatron format. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of (Q, K, V) + weight matrices. + """ + head_num = provider.num_attention_heads + num_query_groups = provider.num_query_groups + heads_per_group = head_num // num_query_groups + head_size = provider.kv_channels or (provider.hidden_size // head_num) + qkv_total_dim = head_num + 2 * num_query_groups + is_bias = qkv.ndim == 1 + + if is_bias: + hidden_size = 1 + qkv_reshaped = qkv.view(qkv_total_dim, head_size) + else: + hidden_size = qkv.shape[-1] + qkv_reshaped = qkv.view(qkv_total_dim, head_size, hidden_size) + + # Extract Q, K, V from interleaved pattern + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, heads_per_group + 2) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, heads_per_group + 2) + + q = qkv_reshaped[q_slice] + k = qkv_reshaped[k_slice] + v = qkv_reshaped[v_slice] + + if is_bias: + q = q.reshape(-1) + k = k.reshape(-1) + v = v.reshape(-1) + else: + q = q.reshape(-1, hidden_size) + k = k.reshape(-1, hidden_size) + v = v.reshape(-1, hidden_size) + + return q, k, v diff --git a/flagscale/train/bridge/models/conversion/utils.py b/flagscale/train/bridge/models/conversion/utils.py new file mode 100644 index 0000000000..66d68aee66 --- /dev/null +++ b/flagscale/train/bridge/models/conversion/utils.py @@ -0,0 +1,287 @@ +# Copyright (c) 2025, BAAI. All rights reserved. 
+# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import copy +import functools +import re +import types + +from typing import Iterable, List, Optional, Tuple + +import torch + +from rich.table import Table +from transformers.configuration_utils import PretrainedConfig + +from megatron.core.transformer.module import MegatronModule +from megatron.core.utils import unwrap_model + + +def weights_verification_table(bridge, megatron_model) -> Table: + """ + Returns a table comparing weights between a Hugging Face model and a Megatron-LM model. + + Args: + bridge (AutoBridge): The bridge object containing model information. + megatron_model: The Megatron-LM model instance. + + Returns: + Table: A rich Table object with the comparison. + """ + table = Table(title="Hugging Face Weights Verification") + table.add_column("Weight Name", style="cyan") + table.add_column("Shape") + table.add_column("DType") + table.add_column("Device") + table.add_column("Matches Original", justify="center") + + # Check each weight against the original HF-model + for name, param in bridge.export_hf_weights(megatron_model, show_progress=True): + original_param = bridge.hf_pretrained.state[name] + table.add_row( + name, + str(tuple(param.shape)), + str(param.dtype).replace("torch.", ""), + str(param.device), + "āœ…" if torch.allclose(param, original_param.to(param.device), atol=1e-6) else "āŒ", + ) + + return table + + +def get_module_and_param_from_name( + models: MegatronModule | List[MegatronModule], param_name: str, vp_stage: Optional[int] = None +) -> Tuple[torch.nn.Module, torch.Tensor] | Tuple[torch.nn.Module, torch.Tensor, Tuple]: + """ + Get parameter from specific VP stage, ensuring that parameter + attributes are preserved. Supports both absolute and relative parameter names. + + Args: + models: List of Megatron model instances or a submodule + param_name: Dot-separated parameter name (can be absolute or relative to models) + vp_stage: Virtual pipeline stage index (None for single stage) + + Returns: + Tuple of (module, parameter) where module owns the parameter + + Raises: + ValueError: If vp_stage is out of range or parameter doesn't exist + + Examples: + Basic usage with full model: + >>> module, param = get_module_and_param_from_name( + ... models=full_model, + ... param_name="transformer.layers.0.attention.query.weight" + ... ) + + Usage with model list and VP stage: + >>> module, param = get_module_and_param_from_name( + ... models=[model1, model2, model3], + ... param_name="layers.0.mlp.dense.bias", + ... vp_stage=1 + ... ) + + Usage with submodule and relative path: + >>> linear_module = model.transformer.layers[0].mlp.dense + >>> module, param = get_module_and_param_from_name( + ... models=linear_module, + ... param_name="weight" + ... ) + + Usage with submodule and absolute path (automatic suffix matching): + >>> linear_module = model.transformer.layers[0].mlp.dense + >>> module, param = get_module_and_param_from_name( + ... models=linear_module, + ... param_name="transformer.layers.0.mlp.dense.weight" + ... ) + # Automatically matches "weight" suffix and returns the parameter + + Edge case with partial path matching: + >>> attention_module = model.transformer.layers[0].attention + >>> module, param = get_module_and_param_from_name( + ... models=attention_module, + ... param_name="layers.0.attention.query.weight" + ... 
) + # Matches "query.weight" suffix within the attention module + """ + + if isinstance(models, list): + if vp_stage is None: + model = models[0] + else: + if vp_stage >= len(models): + raise ValueError(f"VP stage {vp_stage} out of range (max: {len(models) - 1})") + model = models[vp_stage] + else: + model = models + + module = unwrap_model(model) + splitted_name = param_name.split(".") + + # Try to find the parameter using the given parts + def try_get_param(parts): + param = module + temp_module = module + + for i, part in enumerate(parts): + if not hasattr(param, part): + return None + param = getattr(param, part) + if i < len(parts) - 1: + temp_module = getattr(temp_module, part) + + return temp_module, param + + # First try the full parameter name (current behavior) + result = try_get_param(splitted_name) + if result is not None: + return result + + # If full name doesn't work, try suffixes of the parameter name + # This handles cases where models is a submodule but param_name is absolute + for start_idx in range(1, len(splitted_name)): + suffix_parts = splitted_name[start_idx:] + result = try_get_param(suffix_parts) + if result is not None: + return result + + # If no approach works, raise an error + raise ValueError(f"Parameter '{param_name}' not found in model at VP stage {vp_stage}") + + +def remove_non_pickleables(obj, max_depth: int = 2, current_depth: int = 0): + """Remove non-pickleable objects from a configuration object recursively. + + This utility function identifies and removes objects that cannot be pickled for + inter-process communication, including functions, bound methods, partial + functions, and other problematic callables. + + Args: + obj: The object to clean + max_depth: Maximum recursion depth (default: 2) + current_depth: Current recursion depth (internal use) + + Returns: + The cleaned object with non-pickleables removed + """ + + # Stop recursion if max depth reached + if current_depth >= max_depth: + return obj + + # Handle None + if obj is None: + return obj + + # Check if object is a problematic callable + if callable(obj): + # Allow classes/types but remove function objects, methods, partials + if isinstance(obj, type): + return obj + elif hasattr(obj, "__call__") and ( + isinstance(obj, (types.FunctionType, types.MethodType, functools.partial)) + or hasattr(obj, "__self__") + ): # bound methods + return None + + # Handle dataclass/object with attributes + if hasattr(obj, "__dict__"): + # Create a copy to avoid modifying the original + cleaned_obj = copy.copy(obj) + + for attr_name in list(vars(cleaned_obj).keys()): + attr_value = getattr(cleaned_obj, attr_name) + + # Recursively clean attribute + cleaned_value = remove_non_pickleables(attr_value, max_depth, current_depth + 1) + + # Set the cleaned value (or None if it was removed) + setattr(cleaned_obj, attr_name, cleaned_value) + + return cleaned_obj + + # Handle lists + elif isinstance(obj, list): + return [remove_non_pickleables(item, max_depth, current_depth + 1) for item in obj] + + # Handle tuples + elif isinstance(obj, tuple): + return tuple(remove_non_pickleables(item, max_depth, current_depth + 1) for item in obj) + + # Handle dictionaries + elif isinstance(obj, dict): + return { + key: remove_non_pickleables(value, max_depth, current_depth + 1) + for key, value in obj.items() + } + + # For primitive types and other safe objects, return as-is + return obj + + +def extract_sort_key(param_name: str): + """Extract sorting key based on layer and expert numbers.""" + + # Extract at most 2 numbers: 
layer number and expert number + # Pattern: *layers.d+.*d+ (layer number and potentially expert number) + numbers = [] + # Find layer number + layer_match = re.search(r"layers\.(\d+)", param_name) + if layer_match: + numbers.append(int(layer_match.group(1))) + # Find expert number after bias or weight + expert_match = re.search(r"(?:bias|weight)(\d+)", param_name) + if expert_match: + numbers.append(int(expert_match.group(1))) + # Pad to ensure consistent comparison (max 2 numbers) + while len(numbers) < 2: + numbers.append(-1) + numbers = numbers[:2] # Keep at most 2 numbers + return numbers, param_name + + +def get_causal_lm_class_via_auto_map( + model_name_or_path: str, config: PretrainedConfig +) -> type | None: + """Return CausalLM class via config.auto_map if available; otherwise None. + + If auto_map["AutoModelForCausalLM"] is present in the config, returns the dynamically loaded class. + Returns None when auto_map is absent or loading fails. Does not download weights. + """ + auto_map = getattr(config, "auto_map", None) + if auto_map and "AutoModelForCausalLM" in auto_map: + auto_map_class = auto_map["AutoModelForCausalLM"] + repo_id = model_name_or_path or getattr(config, "_name_or_path", None) + if not repo_id: + return None + try: + from transformers.dynamic_module_utils import get_class_from_dynamic_module + + return get_class_from_dynamic_module( + class_reference=auto_map_class, + pretrained_model_name_or_path=repo_id, + cache_dir=None, + force_download=False, + resume_download=True, + proxies=None, + use_auth_token=None, + revision=None, + local_files_only=False, + repo_id=repo_id, + ) + except Exception: + return None + + return None + + +def persistent_buffers(model: torch.nn.Module) -> Iterable[Tuple[str, torch.Tensor]]: + """Return an iterator over persistent module buffers, yielding both the name of the buffer as well as the buffer itself.""" + + for mod_prefix, mod in model.named_modules(): + # only local buffers; we'll add the prefix ourselves + for local_name, buffer in mod.named_buffers(recurse=False): + if local_name not in getattr(mod, "_non_persistent_buffers_set", set()): + full_name = f"{mod_prefix + '.' if mod_prefix else ''}{local_name}" + yield full_name, buffer diff --git a/flagscale/train/bridge/models/decorators/__init__.py b/flagscale/train/bridge/models/decorators/__init__.py new file mode 100644 index 0000000000..5d6d602f55 --- /dev/null +++ b/flagscale/train/bridge/models/decorators/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +from flagscale.train.bridge.models.decorators.dispatch import dispatch +from flagscale.train.bridge.models.decorators.torchrun import torchrun_main + +__all__ = ["dispatch", "torchrun_main"] diff --git a/flagscale/train/bridge/models/decorators/dispatch.py b/flagscale/train/bridge/models/decorators/dispatch.py new file mode 100644 index 0000000000..7e02855d66 --- /dev/null +++ b/flagscale/train/bridge/models/decorators/dispatch.py @@ -0,0 +1,348 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +"""Simplified dispatch system for Python, based on classes' typeclass implementation. + +This module provides a dispatch-based polymorphism system allowing extensible +behavior for different types using the `impl` decorator. 
+""" + +from functools import _find_impl # type: ignore +from typing import Any, Callable, Dict, Optional, TypeVar + +_SignatureType = TypeVar("_SignatureType", bound=Callable) + + +class _Dispatch: + """Internal dispatch representation with type-based routing logic.""" + + __slots__ = ("_signature", "_name", "_exact_types", "_dispatch_cache", "_doc", "_module") + + def __init__(self, signature: Callable) -> None: + self._signature = signature + self._name = signature.__name__ + self._exact_types: Dict[Any, Callable] = {} + self._dispatch_cache: Dict[Any, Callable] = {} + + # Extract docstring and module info for rich repr + self._doc = signature.__doc__ + self._module = signature.__module__ + + def __call__(self, instance: Any, *args, **kwargs) -> Any: + """Dispatch to the appropriate implementation based on instance type.""" + # Special case for tuple-based keys. + if isinstance(instance, tuple): + key = tuple(v if isinstance(v, (type, str)) else type(v) for v in instance) + + # Direct match + impl = self._exact_types.get(key) + if impl is not None: + # NOTE: This path is not cached for simplicity + return impl(instance, *args, **kwargs) + + # Subclass match for tuples of types + for registered_key, callback in self._exact_types.items(): + if ( + not isinstance(registered_key, tuple) + or len(registered_key) != len(key) + or not all(isinstance(t, type) for t in registered_key) + ): + continue + + try: + # For subclass checks, operate on the instance types only + key_types = tuple(v if isinstance(v, type) else type(v) for v in instance) + if all(issubclass(k, rk) for k, rk in zip(key_types, registered_key)): + # NOTE: not caching tuple subclass matches for simplicity + return callback(instance, *args, **kwargs) + except TypeError: + continue # issubclass can fail + + # Normalize both sides to names so tuples of types and/or strings can match. + def _name(obj): + return obj if isinstance(obj, str) else getattr(obj, "__name__", None) or str(obj) + + key_names = tuple(_name(v) for v in key) + for registered_key, callback in self._exact_types.items(): + if not isinstance(registered_key, tuple) or len(registered_key) != len(key): + continue + reg_names = tuple(_name(rk) for rk in registered_key) + if reg_names == key_names: + return callback(instance, *args, **kwargs) + + # No implementation found for this tuple, raise a specific error. + error_msg = self._format_no_implementation_error(instance) + raise NotImplementedError(error_msg) + + # For class dispatch, we use the class (or string of class name) itself as the key + if isinstance(instance, type): + cache_key = instance + instance_type = instance + elif isinstance(instance, str): + cache_key = instance + instance_type = str + else: + cache_key = type(instance) + instance_type = cache_key + + # Try cache + impl = self._dispatch_cache.get(cache_key) + if impl is None: + impl = self._dispatch(instance, instance_type) + if impl is None: + error_msg = self._format_no_implementation_error(instance) + raise NotImplementedError(error_msg) + self._dispatch_cache[cache_key] = impl + + return impl(instance, *args, **kwargs) + + def impl(self, *target_types: Any) -> Callable[[Callable], Callable]: + """Register an implementation for one or more types. 
+ + Usage: + @mydispatch.impl(int) # Register for a single type + @mydispatch.impl(int, str) # Register for multiple types + @mydispatch.impl((list, str)) # Register for a tuple of types as a key + """ + if not target_types: + raise ValueError( + "\nāœ— Missing argument to .impl()\n\n" + "You must specify at least one target type.\n\n" + "Examples:\n" + f" @{self._name}.impl(str) # Single type\n" + f" @{self._name}.impl(int, float) # Multiple types\n" + f" @{self._name}.impl((list, str)) # Tuple key\n" + ) + + def decorator(func: Callable) -> Callable: + if len(target_types) == 1: + # This handles both `@impl(int)` and `@impl((int, str))` + self._exact_types[target_types[0]] = func + else: + # This handles `@impl(int, str)` + for typ in target_types: + self._exact_types[typ] = func + + self._dispatch_cache.clear() + return func + + return decorator + + def __repr__(self) -> str: + """Rich representation showing all implementations.""" + # Build signature string + import inspect + + sig = inspect.signature(self._signature) + sig_str = f"{self._name}{sig}" + + lines = [f"Dispatch({sig_str})("] + + # Add regular implementations + for typ, impl in self._exact_types.items(): + if isinstance(typ, tuple): + type_name = ( + f"({', '.join(t.__name__ if hasattr(t, '__name__') else str(t) for t in typ)})" + ) + else: + type_name = typ.__name__ if hasattr(typ, "__name__") else str(typ) + impl_loc = self._format_location(impl) + lines.append(f" ({type_name}): {impl.__name__} at {impl_loc}") + + lines.append(")") + return "\n".join(lines) + + def _dispatch(self, instance: Any, instance_type: type) -> Optional[Callable]: + """Find the implementation for a given type. + + Fallback order: + 1) Exact type match + 2) issubclass match (when instance is a type) + 3) MRO-based match via functools._find_impl + 4) Name-based fallback: match by class __name__ for dynamically generated + classes (e.g., HF transformers auto_map dynamic modules) + """ + # Direct type match + impl = self._exact_types.get(instance_type, None) + if impl is not None: + return impl + + # For class dispatch, check issubclass relationships + if isinstance(instance, type): + for registered_type, callback in self._exact_types.items(): + if not isinstance(registered_type, type): + continue + try: + if issubclass(instance, registered_type): + return callback + except TypeError: + # issubclass can fail for some types + pass + + # Use functools._find_impl for MRO-based dispatch, only for single types + single_type_impls = {k: v for k, v in self._exact_types.items() if isinstance(k, type)} + impl = _find_impl(instance_type, single_type_impls) + if impl is not None: + return impl + + # Name-based fallback for dynamic HF classes and string registrations. 
+ def _name(obj): + return obj if isinstance(obj, str) else getattr(obj, "__name__", None) + + if isinstance(instance, str): + inst_name = instance + elif isinstance(instance, type): + inst_name = _name(instance) + else: + inst_name = _name(type(instance)) + + if inst_name: + for registered_type, callback in self._exact_types.items(): + reg_name = _name(registered_type) + if reg_name and str(reg_name) == inst_name: + return callback + + return None + + def _format_location(self, func: Callable) -> str: + """Format the location of a function for display.""" + try: + import inspect + + filename = inspect.getfile(func) + _, lineno = inspect.getsourcelines(func) + # Shorten the path to be more readable + import os + + filename = os.path.relpath(filename) + return f"{filename}:{lineno}" + except Exception: + return "" + + def _format_no_implementation_error(self, instance: Any) -> str: + """Format a helpful error message when no implementation is found.""" + type_name_for_header: str + type_name_for_suggestion: str + type_name_for_func: str + instance_type_hint: str + + if isinstance(instance, tuple): + instance_types = tuple(v if isinstance(v, type) else type(v) for v in instance) + type_names_str = ", ".join( + t.__qualname__ if hasattr(t, "__qualname__") else str(t) for t in instance_types + ) + type_name_for_header = f"tuple of types ({type_names_str})" + + suggestion_names = ", ".join( + t.__name__ if hasattr(t, "__name__") else str(t) for t in instance_types + ) + type_name_for_suggestion = f"({suggestion_names})" + type_name_for_func = "tuple" + instance_type_hint = f"Tuple[{', '.join(t.__name__ for t in instance_types)}]" + else: + instance_type = instance if isinstance(instance, type) else type(instance) + qualname = ( + instance_type.__qualname__ + if hasattr(instance_type, "__qualname__") + else str(instance_type) + ) + type_name_for_header = f"type '{qualname}'" + type_name_for_suggestion = ( + instance_type.__name__ if hasattr(instance_type, "__name__") else str(instance_type) + ) + type_name_for_func = type_name_for_suggestion.lower().replace(".", "_") + instance_type_hint = type_name_for_suggestion + + # Build error message + lines = [ + f"\nāœ— No implementation found for {type_name_for_header}", + "", + f"The dispatch function '{self._name}' has no implementation for this type.", + "", + ] + + # Add available implementations + if self._exact_types: + lines.append("Available implementations:") + + # Add registered types + sorted_keys = sorted(self._exact_types.keys(), key=str) + for typ in sorted_keys: + if isinstance(typ, tuple): + type_display = f"({', '.join(t.__name__ if hasattr(t, '__name__') else str(t) for t in typ)})" + else: + type_display = typ.__name__ if hasattr(typ, "__name__") else str(typ) + lines.append(f" • {type_display}") + else: + lines.append("No implementations registered yet.") + + # Generate help based on existing implementations + if self._exact_types: + # Get a sample implementation to show the pattern + _, sample_impl = next(iter(self._exact_types.items())) + + lines.extend( + [ + "", + "To add support for this type, register an implementation:", + f" @{self._name}.impl({type_name_for_suggestion})", + f" def _{self._name}_{type_name_for_func}(instance: {instance_type_hint}) -> ...:", + " # Your implementation here", + ] + ) + + # Try to extract parameter info from the sample implementation + import inspect + + try: + sig = inspect.signature(sample_impl) + params = list(sig.parameters.keys())[1:] # Skip first param (instance) + if params: + param_hints 
= ", ".join(params) + lines.append(f" # Expected parameters: {param_hints}") + except Exception: + pass + else: + lines.extend( + [ + "", + "To add support for this type:", + f" @{self._name}.impl({type_name_for_suggestion})", + f" def _{self._name}_{type_name_for_func}(instance: {instance_type_hint}, ...) -> ...:", + " # Your implementation here", + ] + ) + + return "\n".join(lines) + + +def dispatch(func: _SignatureType) -> _Dispatch: + """Create a new dispatch function from a signature. + + Args: + func: Function defining the dispatch signature and default behavior + + Returns: + A dispatch object that can be extended with implementations + + Example: + >>> @dispatch + ... def to_string(instance) -> str: + ... '''Convert instance to string representation.''' + ... + >>> @to_string.impl(int) + ... def _to_string_int(instance: int) -> str: + ... return str(instance) + ... + >>> @to_string.impl(list, tuple) + ... def _to_string_sequence(instance) -> str: + ... return ', '.join(map(str, instance)) + ... + >>> assert to_string(42) == "42" + >>> assert to_string([1, 2, 3]) == "1, 2, 3" + """ + return _Dispatch(func) + + +__all__ = ["dispatch"] diff --git a/flagscale/train/bridge/models/decorators/torchrun.py b/flagscale/train/bridge/models/decorators/torchrun.py new file mode 100644 index 0000000000..80fa77dcc5 --- /dev/null +++ b/flagscale/train/bridge/models/decorators/torchrun.py @@ -0,0 +1,42 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import os +import traceback + +from functools import wraps + +import torch + +from torch.distributed.elastic.multiprocessing.errors import record + + +def torchrun_main(fn): + """ + A decorator that wraps the main function of a torchrun script. It uses + the `torch.distributed.elastic.multiprocessing.errors.record` decorator + to record any exceptions and ensures that the distributed process group + is properly destroyed on successful completion. In case of an exception, + it prints the traceback and performs a hard exit, allowing torchrun to + terminate all other processes. + """ + recorded_fn = record(fn) + + @wraps(fn) + def wrapper(*args, **kwargs): + try: + return_value = recorded_fn(*args, **kwargs) + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + return return_value + except Exception: + # The 'record' decorator might only log the exception to a file. + # Print it to stderr as well to make sure it's visible. + traceback.print_exc() + # Use os._exit(1) for a hard exit. A regular sys.exit(1) might + # not be enough to terminate a process stuck in a bad C++ state + # (e.g., after a NCCL error), which can cause the job to hang. + os._exit(1) + + return wrapper diff --git a/flagscale/train/bridge/models/deepseek/__init__.py b/flagscale/train/bridge/models/deepseek/__init__.py new file mode 100644 index 0000000000..2e21ff79b5 --- /dev/null +++ b/flagscale/train/bridge/models/deepseek/__init__.py @@ -0,0 +1,31 @@ +# Copyright (c) 2025, BAAI. All rights reserved. 
+# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +from flagscale.train.bridge.models.deepseek.deepseek_provider import ( + DeepSeekModelProvider, + DeepSeekProvider, + DeepSeekV2LiteModelProvider, + DeepSeekV2LiteProvider, + DeepSeekV2ModelProvider, + DeepSeekV2Provider, + DeepSeekV3ModelProvider, + DeepSeekV3Provider, + MoonlightModelProvider16B, + MoonlightProvider, +) +from flagscale.train.bridge.models.deepseek.deepseek_v2_bridge import DeepSeekV2Bridge # noqa: F401 +from flagscale.train.bridge.models.deepseek.deepseek_v3_bridge import DeepSeekV3Bridge # noqa: F401 + +__all__ = [ + "DeepSeekModelProvider", + "DeepSeekV2LiteModelProvider", + "DeepSeekV2ModelProvider", + "DeepSeekV3ModelProvider", + "MoonlightModelProvider16B", + "DeepSeekProvider", + "DeepSeekV2LiteProvider", + "DeepSeekV2Provider", + "DeepSeekV3Provider", + "MoonlightProvider", +] diff --git a/flagscale/train/bridge/models/deepseek/common.py b/flagscale/train/bridge/models/deepseek/common.py new file mode 100644 index 0000000000..37945fc907 --- /dev/null +++ b/flagscale/train/bridge/models/deepseek/common.py @@ -0,0 +1,137 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +from flagscale.train.bridge.models.conversion.param_mapping import AutoMapping, GatedMLPMapping +from flagscale.train.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM + +try: + import apex # noqa: F401 + + HAVE_APEX = True +except ImportError: + HAVE_APEX = False + + +def get_common_configs(hf_pretrained: PreTrainedCausalLM) -> dict: + """ + Returns a dictionary of common configurations for the DeepSeek family of models. + """ + hf_config = hf_pretrained.config + + configs = {} + + if not HAVE_APEX: + configs["gradient_accumulation_fusion"] = False + + if hasattr(hf_config, "rope_scaling") and hf_config.rope_scaling is not None: + configs["rotary_scaling_factor"] = hf_config.rope_scaling["factor"] + configs["mscale"] = hf_config.rope_scaling["mscale"] + configs["mscale_all_dim"] = hf_config.rope_scaling["mscale_all_dim"] + else: + configs["rotary_scaling_factor"] = 1.0 + configs["mscale"] = 1.0 + configs["mscale_all_dim"] = 1.0 + + configs["num_layers"] = hf_config.num_hidden_layers + configs["hidden_size"] = hf_config.hidden_size + configs["ffn_hidden_size"] = hf_config.intermediate_size + configs["num_attention_heads"] = hf_config.num_attention_heads + configs["kv_channels"] = hf_config.num_key_value_heads + configs["q_lora_rank"] = hf_config.q_lora_rank + configs["num_moe_experts"] = hf_config.n_routed_experts + configs["moe_ffn_hidden_size"] = hf_config.moe_intermediate_size + configs["moe_shared_expert_intermediate_size"] = ( + hf_config.moe_intermediate_size * hf_config.n_shared_experts + ) + configs["moe_layer_freq"] = [0] * hf_config.first_k_dense_replace + [1] * ( + hf_config.num_hidden_layers - hf_config.first_k_dense_replace + ) + configs["moe_router_topk"] = hf_config.num_experts_per_tok + configs["moe_router_num_groups"] = hf_config.n_group + configs["moe_router_group_topk"] = hf_config.topk_group + configs["moe_router_topk_scaling_factor"] = hf_config.routed_scaling_factor + configs["kv_lora_rank"] = hf_config.kv_lora_rank + configs["qk_head_dim"] = hf_config.qk_nope_head_dim + configs["qk_pos_emb_head_dim"] = hf_config.qk_rope_head_dim + configs["v_head_dim"] = hf_config.v_head_dim + + # Ensure MLA is enabled + configs["multi_latent_attention"] = True + configs["generation_config"] = hf_pretrained.generation_config + 
configs["vocab_size"] = hf_config.vocab_size + configs["rotary_base"] = hf_config.rope_theta + configs["init_method_std"] = hf_config.initializer_range + configs["layernorm_epsilon"] = hf_config.rms_norm_eps + + return configs + + +def get_common_mapping_list() -> list: + """ + Returns a list of common parameter mappings for the DeepSeek family of models. + """ + param_mappings = { + # Embed + "embedding.word_embeddings.weight": "model.embed_tokens.weight", + # Attention + "decoder.layers.*.input_layernorm.weight": "model.layers.*.input_layernorm.weight", + "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight", + # Reference: https://github.com/NVIDIA/NeMo/blob/50cceb9c90ea1f440d1e14074fa13bd45f60a1c4/nemo/collections/llm/gpt/model/deepseek.py#L637-L650 + # In deepseek, HF weight `model.layers.*.post_attention_layernorm.weight` is mapped to the following mcore weights depending on the layer type: + # (a) `decoder.layers.*.pre_mlp_layernorm.weight`, if the layer is MoE + # (b) `decoder.layers.*.mlp.linear_fc1.layer_norm_weight`, if the layer is dense + "decoder.layers.*.pre_mlp_layernorm.weight": "model.layers.*.post_attention_layernorm.weight", + "decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight", + "decoder.layers.*.self_attention.linear_kv_down_proj.weight": "model.layers.*.self_attn.kv_a_proj_with_mqa.weight", + "decoder.layers.*.self_attention.linear_kv_up_proj.weight": "model.layers.*.self_attn.kv_b_proj.weight", + "decoder.layers.*.self_attention.linear_kv_up_proj.layer_norm_weight": "model.layers.*.self_attn.kv_a_layernorm.weight", + # Mcore local spec + "decoder.layers.*.self_attention.kv_layernorm.weight": "model.layers.*.self_attn.kv_a_layernorm.weight", + # Dense MLP + "decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight", + # MoE + "decoder.layers.*.mlp.router.weight": "model.layers.*.mlp.gate.weight", + "decoder.layers.*.mlp.experts.linear_fc2.weight*": "model.layers.*.mlp.experts.*.down_proj.weight", + "decoder.layers.*.mlp.shared_experts.linear_fc2.weight": "model.layers.*.mlp.shared_experts.down_proj.weight", + # LM Head + "decoder.final_layernorm.weight": "model.norm.weight", + "output_layer.weight": "lm_head.weight", + # MLA + "decoder.layers.*.self_attention.linear_q_down_proj.weight": "model.layers.*.self_attn.q_a_proj.weight", + "decoder.layers.*.self_attention.linear_q_up_proj.weight": "model.layers.*.self_attn.q_b_proj.weight", + "decoder.layers.*.self_attention.linear_q_up_proj.layer_norm_weight": "model.layers.*.self_attn.q_a_layernorm.weight", + # Mcore local spec + "decoder.layers.*.self_attention.q_layernorm.weight": "model.layers.*.self_attn.q_a_layernorm.weight", + # For models without MLA + "decoder.layers.*.self_attention.linear_q_proj.weight": "model.layers.*.self_attn.q_proj.weight", + } + + # TODO: mtp layers + + mapping_list = [] + # Convert each dictionary entry to AutoMapping(hf_param, megatron_param) + for megatron_param, hf_param in param_mappings.items(): + mapping_list.append(AutoMapping(megatron_param=megatron_param, hf_param=hf_param)) + + mapping_list.extend( + [ + GatedMLPMapping( + megatron_param="decoder.layers.*.mlp.linear_fc1.weight", + gate="model.layers.*.mlp.gate_proj.weight", + up="model.layers.*.mlp.up_proj.weight", + ), + GatedMLPMapping( + megatron_param="decoder.layers.*.mlp.experts.linear_fc1.weight*", + gate="model.layers.*.mlp.experts.*.gate_proj.weight", + up="model.layers.*.mlp.experts.*.up_proj.weight", + ), + 
GatedMLPMapping( + megatron_param="decoder.layers.*.mlp.shared_experts.linear_fc1.weight", + gate="model.layers.*.mlp.shared_experts.gate_proj.weight", + up="model.layers.*.mlp.shared_experts.up_proj.weight", + ), + ] + ) + + return mapping_list diff --git a/flagscale/train/bridge/models/deepseek/deepseek_provider.py b/flagscale/train/bridge/models/deepseek/deepseek_provider.py new file mode 100644 index 0000000000..bb38580423 --- /dev/null +++ b/flagscale/train/bridge/models/deepseek/deepseek_provider.py @@ -0,0 +1,309 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import warnings + +from dataclasses import dataclass, field +from functools import partial +from typing import TYPE_CHECKING, Callable, List, Optional, Union + +import torch +import torch.nn.functional as F + +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec + +from flagscale.train.bridge.models.gpt_provider import GPTModelProvider +from flagscale.train.bridge.models.transformer_config import MLATransformerConfig +from flagscale.train.bridge.utils.common_utils import get_rank_safe + +try: + import transformer_engine # type: ignore # noqa: F401 + + HAVE_TE = True +except (ImportError, ModuleNotFoundError): + HAVE_TE = False + +if TYPE_CHECKING: + from megatron.core.transformer import ModuleSpec + +if HAVE_TE: + from megatron.core.utils import is_te_min_version + + +@dataclass +class DeepSeekModelProvider(MLATransformerConfig, GPTModelProvider): + """ + Base config for DeepSeek V2 and V3 models. + """ + + transformer_layer_spec: Union["ModuleSpec", Callable[["GPTModelProvider"], "ModuleSpec"]] = ( + partial(get_gpt_decoder_block_spec, use_transformer_engine=HAVE_TE) + ) + + # Model + normalization: str = "RMSNorm" + activation_func: Callable = F.silu + gated_linear_unit: bool = True # swiglu + position_embedding_type: str = "rope" + add_bias_linear: bool = False + share_embeddings_and_output_weights: bool = False + num_attention_heads: int = 128 + kv_channels: int = 128 + max_position_embeddings: int = 4096 + seq_length: int = 4096 + rotary_base: float = 10000.0 + make_vocab_size_divisible_by: int = 3200 + mtp_num_layers: Optional[int] = None + mtp_loss_scaling_factor: Optional[float] = None + + # Regularization + attention_dropout: float = 0.0 + hidden_dropout: float = 0.0 + qk_layernorm: bool = True + + # MoE + moe_grouped_gemm: bool = True + moe_router_pre_softmax: bool = True + moe_token_dispatcher_type: str = "alltoall" + moe_router_load_balancing_type: str = "seq_aux_loss" + moe_shared_expert_overlap: bool = True + moe_router_dtype: Optional[str] = "fp32" + + # MLA + q_lora_rank: int = 1536 + kv_lora_rank: int = 512 + qk_head_dim: int = 128 + qk_pos_emb_head_dim: int = 64 + v_head_dim: int = 128 + rotary_scaling_factor: float = 40 + mscale: float = 1.0 + mscale_all_dim: float = 1.0 + + # Miscellaneous + init_method_std: float = 0.006 + layernorm_epsilon: float = 1e-6 + bf16: bool = True + params_dtype: torch.dtype = torch.bfloat16 + async_tensor_model_parallel_allreduce: bool = True + attention_softmax_in_fp32: bool = False + persist_layer_norm: bool = True + num_layers_in_first_pipeline_stage: Optional[int] = None + num_layers_in_last_pipeline_stage: Optional[int] = None + account_for_embedding_in_pipeline_split: bool = False + account_for_loss_in_pipeline_split: bool = False + + # MLA specific + multi_latent_attention: bool = True + + # fusions + apply_rope_fusion: bool = False + bias_activation_fusion: bool = True + 
bias_dropout_fusion: bool = True + masked_softmax_fusion: bool = True + cross_entropy_loss_fusion: bool = True + cross_entropy_fusion_impl: str = "te" + moe_permute_fusion: bool = is_te_min_version("2.1.0") if HAVE_TE else False + + +@dataclass +class DeepSeekV2ModelProvider(DeepSeekModelProvider): + """ + DeepSeek-V2 Model: https://github.com/deepseek-ai/DeepSeek-V2 + """ + + num_layers: int = 60 + hidden_size: int = 5120 + ffn_hidden_size: int = 12288 + num_moe_experts: int = 160 + moe_ffn_hidden_size: int = 1536 + moe_shared_expert_intermediate_size: int = 3072 # 1536 * 2 shared experts + moe_layer_freq: Union[int, List[int]] = field( + default_factory=lambda: [0] + [1] * 59 + ) # first layer is dense + moe_router_topk: int = 6 + moe_router_num_groups: int = 8 + moe_router_group_topk: int = 3 + moe_router_topk_scaling_factor: float = 16.0 + moe_aux_loss_coeff: float = 1e-3 + mscale: float = 0.707 + mscale_all_dim: float = 0.707 + vocab_size: int = 102400 + + +@dataclass +class DeepSeekV2LiteModelProvider(DeepSeekV2ModelProvider): + """ + DeepSeek-V2-Lite Model: https://github.com/deepseek-ai/DeepSeek-V2 + HuggingFace: https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite + """ + + num_layers: int = 27 + hidden_size: int = 2048 + ffn_hidden_size: int = 10944 + num_attention_heads: int = 16 + kv_channels: int = 16 + q_lora_rank: int = None + num_moe_experts: int = 64 + moe_ffn_hidden_size: int = 1408 + moe_shared_expert_intermediate_size: int = 2816 # 1408 * 2 shared experts + moe_layer_freq: Union[int, List[int]] = field( + default_factory=lambda: [0] + [1] * 26 + ) # first layer is dense + moe_router_topk: int = 6 + moe_router_num_groups: int = 1 + moe_router_group_topk: int = 1 + moe_router_topk_scaling_factor: float = 1.0 + vocab_size: int = 102400 + + +@dataclass +class DeepSeekV3ModelProvider(DeepSeekModelProvider): + """ + DeepSeek-V3 Model: https://github.com/deepseek-ai/DeepSeek-V3 + """ + + num_layers: int = 61 + hidden_size: int = 7168 + ffn_hidden_size: int = 18432 + num_moe_experts: int = 256 + moe_ffn_hidden_size: int = 2048 + moe_shared_expert_intermediate_size: int = 2048 # 2048 * 1 shared expert + moe_layer_freq: Union[int, List[int]] = field( + default_factory=lambda: [0] * 3 + [1] * 58 + ) # first three layers are dense + moe_router_topk: int = 8 + moe_router_num_groups: int = 8 + moe_router_group_topk: int = 4 + moe_router_topk_scaling_factor: float = 2.5 + make_vocab_size_divisible_by: int = 1280 + moe_router_score_function: str = "sigmoid" + moe_router_enable_expert_bias: bool = True + moe_router_bias_update_rate: float = 1e-3 + mscale: float = 1.0 + mscale_all_dim: float = 1.0 + vocab_size: int = 129280 + + +@dataclass +class MoonlightModelProvider16B(DeepSeekModelProvider): + """ + Moonlight-16B-A3B Model: https://github.com/moonshotai/Moonlight-16B-A3B + + Moonlight is based on DeepSeek-V3. 
+ """ + + max_position_embeddings: int = 4096 + num_layers: int = 27 + hidden_size: int = 2048 + ffn_hidden_size: int = 11264 + num_attention_heads: int = 16 + kv_channels: int = 16 + num_moe_experts: int = 64 + moe_ffn_hidden_size: int = 1408 + moe_shared_expert_intermediate_size: int = 2816 # 1408 * 2 shared expert + moe_layer_freq: Union[int, List[int]] = field( + default_factory=lambda: [0] * 1 + [1] * 26 + ) # first layer is dense + moe_router_topk: int = 6 + moe_router_num_groups: int = 1 + moe_router_group_topk: int = 1 + moe_router_topk_scaling_factor: float = 2.446 + moe_aux_loss_coeff: float = 0.001 + make_vocab_size_divisible_by: int = 1280 + moe_router_score_function: str = "sigmoid" + moe_router_enable_expert_bias: bool = True + rotary_scaling_factor: float = 1.0 + mscale: float = 1.0 + mscale_all_dim: float = 1.0 + rotary_base: float = 50000 + layernorm_epsilon: float = 1e-5 + q_lora_rank: int = None + init_method_std: float = 0.02 + moe_router_bias_update_rate: float = 1e-3 + rotary_percent: float = 1.0 + vocab_size: int = 163840 + + +# ----------------------------------------------------------------------------- +# Deprecated aliases (to be removed in a future release) +# ----------------------------------------------------------------------------- + + +def _warn_deprecated(old_cls: str, new_cls: str) -> None: + if get_rank_safe() == 0: + warnings.warn( + f"{old_cls} is deprecated and will be removed in a future release. Use {new_cls} instead.", + DeprecationWarning, + stacklevel=2, + ) + + +@dataclass +class DeepSeekProvider(DeepSeekModelProvider): + """Deprecated alias for ``DeepSeekModelProvider``. + + Deprecated: + This alias remains for backward compatibility and will be removed in a + future release. Import and use ``DeepSeekModelProvider`` instead. + """ + + def __post_init__(self) -> None: + _warn_deprecated("DeepSeekProvider", "DeepSeekModelProvider") + super().__post_init__() + + +@dataclass +class DeepSeekV2Provider(DeepSeekV2ModelProvider): + """Deprecated alias for ``DeepSeekV2ModelProvider``. + + Deprecated: + This alias remains for backward compatibility and will be removed in a + future release. Import and use ``DeepSeekV2ModelProvider`` instead. + """ + + def __post_init__(self) -> None: + _warn_deprecated("DeepSeekV2Provider", "DeepSeekV2ModelProvider") + super().__post_init__() + + +@dataclass +class DeepSeekV2LiteProvider(DeepSeekV2LiteModelProvider): + """Deprecated alias for ``DeepSeekV2LiteModelProvider``. + + Deprecated: + This alias remains for backward compatibility and will be removed in a + future release. Import and use ``DeepSeekV2LiteModelProvider`` instead. + """ + + def __post_init__(self) -> None: + _warn_deprecated("DeepSeekV2LiteProvider", "DeepSeekV2LiteModelProvider") + super().__post_init__() + + +@dataclass +class DeepSeekV3Provider(DeepSeekV3ModelProvider): + """Deprecated alias for ``DeepSeekV3ModelProvider``. + + Deprecated: + This alias remains for backward compatibility and will be removed in a + future release. Import and use ``DeepSeekV3ModelProvider`` instead. + """ + + def __post_init__(self) -> None: + _warn_deprecated("DeepSeekV3Provider", "DeepSeekV3ModelProvider") + super().__post_init__() + + +@dataclass +class MoonlightProvider(MoonlightModelProvider16B): + """Deprecated alias for ``MoonlightModelProvider16B``. + + Deprecated: + This alias remains for backward compatibility and will be removed in a + future release. Import and use ``MoonlightModelProvider16B`` instead. 
+ """ + + def __post_init__(self) -> None: + _warn_deprecated("MoonlightProvider", "MoonlightModelProvider16B") + super().__post_init__() diff --git a/flagscale/train/bridge/models/deepseek/deepseek_v2_bridge.py b/flagscale/train/bridge/models/deepseek/deepseek_v2_bridge.py new file mode 100644 index 0000000000..01edfbb493 --- /dev/null +++ b/flagscale/train/bridge/models/deepseek/deepseek_v2_bridge.py @@ -0,0 +1,48 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import torch + +from megatron.core.models.gpt.gpt_model import GPTModel + +from flagscale.train.bridge.models.conversion.mapping_registry import MegatronMappingRegistry +from flagscale.train.bridge.models.conversion.model_bridge import MegatronModelBridge +from flagscale.train.bridge.models.deepseek.common import ( + get_common_configs, + get_common_mapping_list, +) +from flagscale.train.bridge.models.deepseek.deepseek_provider import DeepSeekV2ModelProvider +from flagscale.train.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM + + +@MegatronModelBridge.register_bridge(source="DeepseekV2ForCausalLM", target=GPTModel) +class DeepSeekV2Bridge(MegatronModelBridge): + """ + Megatron Bridge for DeepSeek-V2. + + As a user you would not use this bridge directly, but through `AutoBridge`. + + Example: + >>> from flagscale.train.bridge import AutoBridge + >>> bridge = AutoBridge.from_hf_pretrained("deepseek-ai/DeepSeek-V2", trust_remote_code=True) + >>> provider = bridge.to_megatron_provider() + """ + + def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> DeepSeekV2ModelProvider: + hf_config = hf_pretrained.config + configs = get_common_configs(hf_pretrained) + + configs["fp16"] = self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16 + configs["bf16"] = self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16 + configs["params_dtype"] = self.dtype_from_hf(hf_config, default=torch.float32) + + configs["make_vocab_size_divisible_by"] = 3200 + configs["moe_aux_loss_coeff"] = hf_config.aux_loss_alpha + + provider = DeepSeekV2ModelProvider(**configs) + return provider + + def mapping_registry(self) -> MegatronMappingRegistry: + mapping_list = get_common_mapping_list() + return MegatronMappingRegistry(*mapping_list) diff --git a/flagscale/train/bridge/models/deepseek/deepseek_v3_bridge.py b/flagscale/train/bridge/models/deepseek/deepseek_v3_bridge.py new file mode 100644 index 0000000000..0a171981be --- /dev/null +++ b/flagscale/train/bridge/models/deepseek/deepseek_v3_bridge.py @@ -0,0 +1,64 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import torch + +from megatron.core.models.gpt.gpt_model import GPTModel + +from flagscale.train.bridge.models.conversion.mapping_registry import MegatronMappingRegistry +from flagscale.train.bridge.models.conversion.model_bridge import MegatronModelBridge +from flagscale.train.bridge.models.conversion.param_mapping import AutoMapping +from flagscale.train.bridge.models.deepseek.common import ( + get_common_configs, + get_common_mapping_list, +) +from flagscale.train.bridge.models.deepseek.deepseek_provider import DeepSeekV3ModelProvider +from flagscale.train.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM + + +@MegatronModelBridge.register_bridge(source="DeepseekV3ForCausalLM", target=GPTModel) +class DeepSeekV3Bridge(MegatronModelBridge): + """ + Megatron Bridge for DeepSeek-V3. 
+ + As a user you would not use this bridge directly, but through `AutoBridge`. + + Example: + >>> from flagscale.train.bridge import AutoBridge + >>> bridge = AutoBridge.from_hf_pretrained("deepseek-ai/DeepSeek-V3-Base", trust_remote_code=True) + >>> provider = bridge.to_megatron_provider() + """ + + def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> DeepSeekV3ModelProvider: + hf_config = hf_pretrained.config + configs = get_common_configs(hf_pretrained) + + configs["fp16"] = self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16 + configs["bf16"] = self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16 + configs["params_dtype"] = self.dtype_from_hf(hf_config, default=torch.float32) + + configs["make_vocab_size_divisible_by"] = 1280 + configs["moe_router_score_function"] = "sigmoid" + configs["moe_router_enable_expert_bias"] = True + # aux_loss_alpha is not set in all DSv3 HF configs + if hasattr(hf_config, "aux_loss_alpha"): + configs["moe_aux_loss_coeff"] = hf_config.aux_loss_alpha + + # TODO: mtp + + provider = DeepSeekV3ModelProvider(**configs) + return provider + + def mapping_registry(self) -> MegatronMappingRegistry: + mapping_list = get_common_mapping_list() + + param_mappings = { + # expert bias + "decoder.layers.*.mlp.router.expert_bias": "model.layers.*.mlp.gate.e_score_correction_bias" + } + + for megatron_param, hf_param in param_mappings.items(): + mapping_list.append(AutoMapping(megatron_param=megatron_param, hf_param=hf_param)) + + return MegatronMappingRegistry(*mapping_list) diff --git a/flagscale/train/bridge/models/gpt_full_te_layer_autocast_spec.py b/flagscale/train/bridge/models/gpt_full_te_layer_autocast_spec.py new file mode 100644 index 0000000000..7409349ce5 --- /dev/null +++ b/flagscale/train/bridge/models/gpt_full_te_layer_autocast_spec.py @@ -0,0 +1,347 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +from importlib.metadata import version +from typing import Any, Callable, Optional, Union + +import packaging +import torch + +from transformer_engine.pytorch import TransformerLayer + +from megatron.core import parallel_state, tensor_parallel +from megatron.core.fusions.fused_layer_norm import FusedLayerNorm +from megatron.core.transformer.cuda_graphs import CudaGraphManager +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_block import ( + TransformerBlockSubmodules, + get_num_layers_to_build, +) +from megatron.core.transformer.transformer_layer import BaseTransformerLayer +from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint + + +# Copied from nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py +class AutocastTransformerLayer(TransformerLayer): + """ + Wrapper of te.pytorch.TransformerLayer: a single transformerlayer + that takes input with size [s, b, h] and returns an output of + the same size. 
+ """ + + def __init__( + self, + hidden_size: int, + ffn_hidden_size: int, + layernorm_epsilon: float, + num_attention_heads: int, + init_method: Callable, + output_layer_init_method: Callable, + hidden_dropout: float, + attention_dropout: float, + layer_number: Optional[int] = None, + kv_channels: Optional[int] = None, + self_attn_mask_type: str = "causal", + tp_group: Optional[Any] = None, + tp_size: int = 1, + params_dtype: torch.dtype = torch.float32, + get_rng_state_tracker: Optional[Callable] = None, + fuse_wgrad_accumulation: bool = False, + seq_length: Optional[int] = None, + micro_batch_size: Optional[int] = None, + sequence_parallel: bool = False, + apply_residual_connection_post_layernorm: bool = False, + output_layernorm: bool = False, + layer_type: str = "encoder", + drop_path_rate: float = 0, + use_emha: bool = False, + ub_tp_comm_overlap: bool = False, + ub_bulk_wgrad: bool = True, + ub_bulk_dgrad: bool = True, + autocast_dtype: Any = 16, + zero_centered_gamma: bool = False, + device: str = "cuda", + **kwargs, + ) -> None: + transformer_layer_args = { + "hidden_size": hidden_size, + "ffn_hidden_size": ffn_hidden_size, + "layernorm_epsilon": layernorm_epsilon, + "num_attention_heads": num_attention_heads, + "init_method": init_method, + "output_layer_init_method": output_layer_init_method, + "hidden_dropout": hidden_dropout, + "attention_dropout": attention_dropout, + "layer_number": layer_number, + "kv_channels": kv_channels, + "self_attn_mask_type": self_attn_mask_type, + "tp_group": tp_group, + "tp_size": tp_size, + "params_dtype": params_dtype, + "get_rng_state_tracker": get_rng_state_tracker, + "fuse_wgrad_accumulation": fuse_wgrad_accumulation, + "seq_length": seq_length, + "micro_batch_size": micro_batch_size, + "sequence_parallel": sequence_parallel, + "apply_residual_connection_post_layernorm": apply_residual_connection_post_layernorm, + "output_layernorm": output_layernorm, + "layer_type": layer_type, + "drop_path_rate": drop_path_rate, + "set_parallel_mode": tp_size > 1, + "fuse_qkv_params": True, + "zero_centered_gamma": zero_centered_gamma, + "ub_tp_comm_overlap": ub_tp_comm_overlap, + "ub_bulk_wgrad": ub_bulk_wgrad, + "ub_bulk_dgrad": ub_bulk_dgrad, + "device": device, + } + te_version = packaging.version.Version(version("transformer-engine")) + if te_version > packaging.version.Version("1.5.0"): + for comm in ["ag", "rs"]: + ub_overlap_flag = "ub_overlap_" + comm + split_gemm_flag = "ub_split_" + comm + atomic_gemm_flag = "ub_atomic_gemm_" + comm + # Use old overlap flags if they were supplied instead + if ub_overlap_flag in kwargs: + transformer_layer_args[ub_overlap_flag] = kwargs[ub_overlap_flag] + else: + transformer_layer_args[ub_overlap_flag] = kwargs.get( + split_gemm_flag, True + ) or kwargs.get(atomic_gemm_flag, False) + if te_version > packaging.version.Version("1.6.0.dev0"): + transformer_layer_args["ub_overlap_rs_dgrad"] = kwargs.get( + "ub_overlap_rs_dgrad", False + ) + else: + transformer_layer_args["ub_split_ag"] = kwargs.get("ub_split_ag", True) + transformer_layer_args["ub_split_rs"] = kwargs.get("ub_split_rs", True) + transformer_layer_args["ub_atomic_gemm_ag"] = kwargs.get("ub_atomic_gemm_ag", False) + transformer_layer_args["ub_atomic_gemm_rs"] = kwargs.get("ub_atomic_gemm_rs", False) + super().__init__(**transformer_layer_args) + + # Dtype for forward pass + self.dtype = torch_dtype_from_precision(autocast_dtype) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor = None, + encoder_output: 
Optional[torch.Tensor] = None, + enc_dec_attn_mask: Optional[torch.Tensor] = None, + inference_params: Optional[Any] = None, + is_first_microbatch: Optional[bool] = None, + checkpoint_core_attention: Optional[bool] = False, + ) -> torch.Tensor: + """ + Perform a forward pass through the transformer layer. + """ + if self.dtype == torch.float32: + return super().forward( + hidden_states, + attention_mask, + encoder_output=encoder_output, + enc_dec_attn_mask=enc_dec_attn_mask, + inference_params=inference_params, + is_first_microbatch=is_first_microbatch, + checkpoint_core_attention=checkpoint_core_attention, + ) + with torch.autocast(device_type="cuda", dtype=self.dtype): + return super().forward( + hidden_states, + attention_mask=attention_mask, + encoder_output=encoder_output, + enc_dec_attn_mask=enc_dec_attn_mask, + inference_params=inference_params, + is_first_microbatch=is_first_microbatch, + checkpoint_core_attention=checkpoint_core_attention, + ) + + +class TETransformerLayerAutocast(MegatronModule, BaseTransformerLayer): # type: ignore + """ + A MegatronModule that wraps the AutocastTransformerLayer. + """ + + def __init__(self, config, layer_number=1, hidden_dropout=None, **kwargs): + super().__init__(config=config) + self.layer_number = layer_number + self._get_layer_offset() + + self.config = config + self.is_first_microbatch = True + precision = "bf16" if config.bf16 else 16 + + transformer_layer_args = { + "hidden_size": config.hidden_size, + "ffn_hidden_size": config.ffn_hidden_size, + "layernorm_epsilon": config.layernorm_epsilon, + "num_attention_heads": config.num_attention_heads, + "init_method": config.init_method, + "output_layer_init_method": config.output_layer_init_method, + "hidden_dropout": config.hidden_dropout, + "attention_dropout": config.attention_dropout, + "layer_number": layer_number + self._get_layer_offset(), + "kv_channels": config.kv_channels, + "tp_size": parallel_state.get_tensor_model_parallel_world_size(), + "params_dtype": config.params_dtype, + "get_rng_state_tracker": tensor_parallel.random.get_cuda_rng_tracker, + "fuse_wgrad_accumulation": config.gradient_accumulation_fusion, + "seq_length": None, # used for jit warmup + "micro_batch_size": None, # used for jit warmup + "sequence_parallel": config.sequence_parallel, + "apply_residual_connection_post_layernorm": config.apply_residual_connection_post_layernorm, + "autocast_dtype": precision, + "ub_tp_comm_overlap": config.tp_comm_overlap, + "ub_bulk_wgrad": config.tp_comm_bulk_wgrad, + "ub_bulk_dgrad": config.tp_comm_bulk_dgrad, + "zero_centered_gamma": config.layernorm_zero_centered_gamma, + "device": "cpu" if config.use_cpu_initialization else "cuda", + } + te_version = packaging.version.Version(version("transformer-engine")) + if te_version > packaging.version.Version("1.5.0"): + # Use old overlap flags if they were supplied instead + transformer_layer_args["ub_overlap_ag"] = ( + config.tp_comm_overlap_ag + if hasattr(config, "tp_comm_overlap_ag") + else config.tp_comm_split_ag or config.tp_comm_atomic_ag + ) + transformer_layer_args["ub_overlap_rs"] = ( + config.tp_comm_overlap_rs + if hasattr(config, "tp_comm_overlap_rs") + else config.tp_comm_split_rs or config.tp_comm_atomic_rs + ) + if te_version > packaging.version.Version("1.6.0.dev0"): + transformer_layer_args["ub_overlap_rs_dgrad"] = ( + config.tp_comm_overlap_rs_dgrad + if hasattr(config, "tp_comm_overlap_rs_dgrad") + else False + ) + else: + transformer_layer_args["ub_split_ag"] = config.tp_comm_split_ag + 
transformer_layer_args["ub_split_rs"] = config.tp_comm_split_rs + transformer_layer_args["ub_atomic_gemm_ag"] = config.tp_comm_atomic_ag + transformer_layer_args["ub_atomic_gemm_rs"] = config.tp_comm_atomic_rs + self.transformer_layer = AutocastTransformerLayer(**transformer_layer_args) + + if self.config.enable_cuda_graph and self.training: + assert ( + not config.cpu_offloading and config.recompute_granularity is None + ), "Cudagraphs not supported" + self.add_module("cudagraph_manager", CudaGraphManager(config)) + + # Called by MCore's TransformerBlock.forward + # megatron/core/transformer/transformer_block.py + def forward( + self, + hidden_states, + is_first_microbatch=None, + attention_mask=None, + context=None, + context_mask=None, + inference_params=None, + **kwargs, + ): + """Forward function of TETransformerLayerAutocast. Called by MCore's TransformerBlock.forward.""" + # Use is_first_microbatch argument during CUDA graph capture. Use self.is_first_microbatch otherwise. + hidden_states = self.transformer_layer.forward( + hidden_states, + attention_mask=attention_mask, + encoder_output=context, + enc_dec_attn_mask=context_mask, + inference_params=inference_params, + is_first_microbatch=( + is_first_microbatch if is_first_microbatch is not None else self.is_first_microbatch + ), + # checkpoint_core_attention, + ) + self.is_first_microbatch = False + context = None + + # External CUDA graph requires returned values to be Tensors + if ( + hasattr(self.config, "external_cuda_graph") + and self.config.external_cuda_graph + and self.training + ): + return hidden_states + return hidden_states, context + + def _get_layer_offset(self): + pipeline_rank = parallel_state.get_pipeline_model_parallel_rank() + + num_layers_per_pipeline_rank = ( + self.config.num_layers // parallel_state.get_pipeline_model_parallel_world_size() + ) + + if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: + vp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank() + vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size() + + total_num_layers = self.config.num_layers + num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size + total_virtual_chunks = total_num_layers // vp_size + offset = vp_rank * total_virtual_chunks + (pipeline_rank * num_layers_per_virtual_rank) + + else: + # Each stage gets a contiguous set of layers. + if parallel_state.get_pipeline_model_parallel_world_size() > 1: + offset = pipeline_rank * num_layers_per_pipeline_rank + else: + offset = 0 + + return offset + + def sharded_state_dict(self, prefix: str = "", sharded_offsets: tuple = (), metadata=None): + """Get the sharded state dict for the transformer layer.""" + TENSOR_PARALLEL_LAYERS_AXIS_MAP = { + "self_attention.layernorm_qkv.weight": 0, + "self_attention.layernorm_qkv.bias": 0, + "self_attention.proj.weight": 1, + "layernorm_mlp.fc1_weight": 0, + "layernorm_mlp.fc1_bias": 0, + "layernorm_mlp.fc2_weight": 1, + } + + state_dict = self.state_dict(prefix="", keep_vars=True) + sharded_state_dict = make_sharded_tensors_for_checkpoint( + state_dict, prefix, TENSOR_PARALLEL_LAYERS_AXIS_MAP, sharded_offsets + ) + + # TODO: we need to add sharded_state_dict_keys_map to the config. 
Like in TransformerLayer submodules config + # prefixed_map = { + # f'{prefix}{k}': f'{prefix}{v}' + # for k, v in self.config.sharded_state_dict_keys_map.items() + # } + + # if prefixed_map: + # apply_prefix_mapping(sharded_state_dict, prefixed_map) + + return sharded_state_dict + + def __call__(self, *args, **kwargs): + if hasattr(self, "cudagraph_manager"): + return self.cudagraph_manager(self, args, kwargs) + return super().__call__(*args, **kwargs) + + +# Use this spec to use the full Transformer layer from Transformer Engine +def get_gpt_full_te_layer_autocast_spec(transformer_config) -> ModuleSpec: + """Get the ModuleSpec for full Transformer layer from Transformer Engine.""" + num_layers = get_num_layers_to_build(transformer_config) + return TransformerBlockSubmodules( + layer_specs=[ModuleSpec(module=TETransformerLayerAutocast)] * num_layers, + layer_norm=FusedLayerNorm, + ) + + +def torch_dtype_from_precision(precision: Union[int, str]) -> torch.dtype: + """Mapping from precision types to corresponding PyTorch parameter datatype.""" + if precision in ("bf16", "bf16-mixed"): + return torch.bfloat16 + elif precision in (16, "16", "16-mixed"): + return torch.float16 + elif precision in (32, "32", "32-true"): + return torch.float32 + else: + raise ValueError(f"Could not parse the precision of `{precision}` to a valid torch.dtype") diff --git a/flagscale/train/bridge/models/gpt_provider.py b/flagscale/train/bridge/models/gpt_provider.py new file mode 100644 index 0000000000..9988c9fb36 --- /dev/null +++ b/flagscale/train/bridge/models/gpt_provider.py @@ -0,0 +1,430 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import contextlib +import inspect +import logging + +from dataclasses import dataclass, field +from functools import partial +from typing import Any, Callable, Literal, Optional, Union + +import torch + +from megatron.core import parallel_state +from megatron.core.models.gpt import GPTModel as MCoreGPTModel +from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_layer_local_spec, + get_gpt_layer_with_transformer_engine_spec, +) +from megatron.core.post_training.modelopt.gpt.model_specs import get_gpt_modelopt_spec +from megatron.core.transformer import ModuleSpec + +from flagscale.train.bridge.models.model_provider import ModelProviderMixin +from flagscale.train.bridge.models.transformer_config import TransformerConfig +from flagscale.train.bridge.utils import fusions +from flagscale.train.bridge.utils.vocab_utils import calculate_padded_vocab_size + +logger = logging.getLogger(__name__) + + +def transformer_engine_layer_spec(config: "GPTModelProvider") -> ModuleSpec: + """Create a Transformer Engine layer specification based on the provided config.""" + if ( + "use_te_op_fuser" + in inspect.signature(get_gpt_layer_with_transformer_engine_spec).parameters + ): + kwargs = {"use_te_op_fuser": config.use_transformer_engine_op_fuser} + else: + kwargs = {} + return get_gpt_layer_with_transformer_engine_spec( + num_experts=config.num_moe_experts, + moe_grouped_gemm=config.moe_grouped_gemm, + qk_layernorm=config.qk_layernorm, + fp8=bool(config.num_moe_experts and (config.fp8 is not None)), + **kwargs, + ) + + +def transformer_engine_full_layer_spec(config: "GPTModelProvider") -> ModuleSpec: + """Create a full Transformer Engine layer specification with autocast support. 
+ + Args: + config: GPT configuration object + + Returns: + ModuleSpec: Module specification for full TE layers + """ + from flagscale.train.bridge.models.gpt_full_te_layer_autocast_spec import ( + get_gpt_full_te_layer_autocast_spec, + ) + + return get_gpt_full_te_layer_autocast_spec(transformer_config=config) + + +def local_layer_spec(config: "GPTModelProvider") -> ModuleSpec: + """Create a local layer specification without Transformer Engine. + + Args: + config: GPT configuration object + + Returns: + ModuleSpec: Module specification for local implementation layers + """ + return get_gpt_layer_local_spec( + num_experts=config.num_moe_experts, + moe_grouped_gemm=config.moe_grouped_gemm, + qk_layernorm=config.qk_layernorm, + normalization=config.normalization, + ) + + +def quantization_layer_spec(config: "GPTModelProvider") -> ModuleSpec: + """Layer specification for quantization with ModelOpt.""" + return get_gpt_modelopt_spec( + config=config, + local_core_attention=False, + remap_te_layernorm=True, + real_quant_cfg="None", + use_arbitrary_attention_mask=True, + ) + + +def default_layer_spec(config: "GPTModelProvider") -> ModuleSpec: + """Determine the most appropriate layer specification based on availability.""" + if config.restore_modelopt_state: + return quantization_layer_spec(config) + elif config.use_transformer_engine_full_layer_spec: + return transformer_engine_full_layer_spec(config) + else: + return transformer_engine_layer_spec(config) + + +@dataclass +class GPTModelProvider(TransformerConfig, ModelProviderMixin[MCoreGPTModel]): + """Configuration and provider for Megatron Core GPT models. + + This class extends TransformerConfig with GPT-specific parameters and + provides a method to instantiate configured GPT models. + """ + + # Model configuration + fp16_lm_cross_entropy: bool = False + parallel_output: bool = True + share_embeddings_and_output_weights: bool = True + make_vocab_size_divisible_by: int = 128 + position_embedding_type: Literal["learned_absolute", "rope"] = "learned_absolute" + rotary_base: int = 10000 + rotary_percent: float = 1.0 + seq_len_interpolation_factor: Optional[float] = None + seq_length: int = 1024 + attention_softmax_in_fp32: bool = False + deallocate_pipeline_outputs: bool = True + scatter_embedding_sequence_parallel: bool = True + tp_only_amax_red: bool = False + tp_comm_overlap_cfg: Optional[Union[str, dict[str, Any]]] = None + """Config file when tp_comm_overlap is enabled.""" + + use_transformer_engine_full_layer_spec: bool = False + use_transformer_engine_op_fuser: bool = False + transformer_layer_spec: Union[ModuleSpec, Callable[["GPTModelProvider"], ModuleSpec]] = ( + default_layer_spec + ) + + generation_config: Optional[Any] = None + + # This represents the unpadded vocab size + # The padded vocab size is automatically calculated in the provide() method. + vocab_size: Optional[int] = None + # Set if the tokenizer provides the vocab size. 
In this case, the vocab size will be padded + # Controls whether vocab size should be padded for tensor parallelism + should_pad_vocab: bool = False + + # MoE / FP8 + num_moe_experts: Optional[int] = None + moe_grouped_gemm: bool = False + qk_layernorm: bool = False + fp8: Optional[str] = None + normalization: str = "LayerNorm" + + # Multi-token prediction + mtp_enabled: bool = False + + # Additional parameters that might be needed + init_model_with_meta_device: bool = False + use_te_rng_tracker: bool = False + enable_cuda_graph: bool = False + virtual_pipeline_model_parallel_size: Optional[int] = None + account_for_embedding_in_pipeline_split: bool = False + account_for_loss_in_pipeline_split: bool = False + + # Fusions + masked_softmax_fusion: bool = field(default_factory=fusions.can_enable_masked_softmax_fusion) + cross_entropy_loss_fusion: bool = True # Generally beneficial, no specific dependencies + gradient_accumulation_fusion: bool = field( + default_factory=fusions.can_enable_gradient_accumulation_fusion + ) + bias_activation_fusion: bool = ( + False # Disabled by default as it can interfere with certain architectures + ) + persist_layer_norm: bool = False + bias_dropout_fusion: bool = field(default_factory=fusions.can_enable_bias_dropout_fusion) + apply_rope_fusion: bool = field(default_factory=fusions.can_enable_apply_rope_fusion) + + # If True, restore the modelopt_state that contains quantization, sparsity, speculative decoding transformation state. + # When resuming modelopt_state, we also change the transformer_layer_spec to `megatron.core.post_training.modelopt.gpt.model_specs` which is a combination of local spec + TEDotProductAttention. + + restore_modelopt_state: bool = False + + def provide(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreGPTModel: + """Configure and instantiate a Megatron Core GPT model based on this configuration. + + Args: + pre_process: Whether to include pre-processing in the model, defaults to first pipeline stage + post_process: Whether to include post-processing in the model, defaults to last pipeline stage + vp_stage: Virtual pipeline stage + + Returns: + MCoreGPTModel: Configured Megatron Core GPT model instance + """ + # Validate fusion configurations + if not fusions.validate_rope_fusion_compatibility(self): + self.apply_rope_fusion = False + + if self.enable_cuda_graph: + assert getattr(self, "use_te_rng_tracker", False), ( + "Transformer engine's RNG tracker is required for cudagraphs, it can be " + "enabled with use_te_rng_tracker=True'." + ) + + vp_size = self.virtual_pipeline_model_parallel_size + is_pipeline_asymmetric = getattr( + self, "account_for_embedding_in_pipeline_split", False + ) or getattr(self, "account_for_loss_in_pipeline_split", False) + is_pipeline_asymmetric |= ( + getattr(self, "num_layers_in_first_pipeline_stage", None) + or getattr(self, "num_layers_in_last_pipeline_stage", None) + ) is not None + is_flexible_pp_layout = is_pipeline_asymmetric or ( + getattr(self, "pipeline_model_parallel_layout", None) is not None + ) + if vp_size and not is_flexible_pp_layout: + p_size = self.pipeline_model_parallel_size + assert ( + self.num_layers // p_size + ) % vp_size == 0, ( + "Make sure the number of model chunks is the same across all pipeline stages." 
+ ) + + transformer_layer_spec = self.transformer_layer_spec + if not isinstance(transformer_layer_spec, ModuleSpec): + # Check if the transformer_layer_spec function accepts vp_stage parameter + if "vp_stage" in inspect.signature(transformer_layer_spec).parameters: + transformer_layer_spec = transformer_layer_spec(self, vp_stage=vp_stage) + else: + transformer_layer_spec = transformer_layer_spec(self) + + assert self.vocab_size is not None, "vocab_size must be configured before calling provide()" + if self.should_pad_vocab: + padded_vocab_size = calculate_padded_vocab_size( + self.vocab_size, self.make_vocab_size_divisible_by, self.tensor_model_parallel_size + ) + else: + padded_vocab_size = self.vocab_size + + # Initialize model as meta data instead of allocating data on a device + model_init_device_context = contextlib.nullcontext + if self.init_model_with_meta_device: + model_init_device_context = partial(torch.device, device="meta") + + # Check if mtp_block_spec parameter is supported + kwargs = {} + if "mtp_block_spec" in inspect.signature(MCoreGPTModel.__init__).parameters: + kwargs["mtp_block_spec"] = mtp_block_spec(self, vp_stage=vp_stage) + + with model_init_device_context(): + model = MCoreGPTModel( + self, + transformer_layer_spec=transformer_layer_spec, + vocab_size=padded_vocab_size, + max_sequence_length=self.seq_length, + fp16_lm_cross_entropy=self.fp16_lm_cross_entropy, + parallel_output=self.parallel_output, + share_embeddings_and_output_weights=self.share_embeddings_and_output_weights, + position_embedding_type=self.position_embedding_type, + rotary_percent=self.rotary_percent, + rotary_base=self.rotary_base, + seq_len_interpolation_factor=self.seq_len_interpolation_factor, + pre_process=pre_process + or parallel_state.is_pipeline_first_stage(ignore_virtual=False, vp_stage=vp_stage), + post_process=post_process + or parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage), + scatter_embedding_sequence_parallel=self.scatter_embedding_sequence_parallel, + vp_stage=vp_stage, + **kwargs, + ) + + # If using full TE layer, need to set TP, CP group since the module call + # is not routed through megatron core, which normally handles passing the + # TP, CP group to the TE modules. + # Deep iterate but skip self to avoid infinite recursion. + if self.use_transformer_engine_full_layer_spec: + # Copied from: + # https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/transformer.py + if parallel_state.get_tensor_model_parallel_world_size() > 1: + for index, child in enumerate(model.modules()): + if index == 0: + continue + if hasattr(child, "set_tensor_parallel_group"): + tp_group = parallel_state.get_tensor_model_parallel_group() + child.set_tensor_parallel_group(tp_group) + + if parallel_state.get_context_parallel_world_size() > 1: + cp_stream = torch.cuda.Stream() + for index, child in enumerate(model.modules()): + if index == 0: + continue + if hasattr(child, "set_context_parallel_group"): + child.set_context_parallel_group( + parallel_state.get_context_parallel_group(), + parallel_state.get_context_parallel_global_ranks(), + cp_stream, + ) + + return model + + +def mtp_block_spec( + config: "GPTModelProvider", vp_stage: Optional[int] = None +) -> Optional[ModuleSpec]: + """Pass in the MTP block spec if model has MTP layers. 
+ + Args: + config: GPT configuration object + + Returns: + ModuleSpec: The MTP module specification + """ + if getattr(config, "mtp_num_layers", None): + from megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec + + if isinstance(config.transformer_layer_spec, Callable): + if "vp_stage" in inspect.signature(config.transformer_layer_spec).parameters: + spec = config.transformer_layer_spec(config, vp_stage=vp_stage) + else: + spec = config.transformer_layer_spec(config) + else: + spec = config.transformer_layer_spec + if hasattr(spec, "layer_specs") and len(spec.layer_specs) == 0: + # Get the decoder layer spec explicitly if no decoder layer in the last stage, + # Only happens with block spec (TransformerBlockSubmodules) when using MoE. + spec = default_layer_spec(config) + return get_gpt_mtp_block_spec(config, spec, use_transformer_engine=True, vp_stage=vp_stage) + else: + return None + + +@dataclass +class GPTProvider126M(GPTModelProvider): + """Configuration for a 126M parameter GPT model. + + Predefined configuration for a small GPT model with 12 layers, + 768 hidden size, and 12 attention heads. + """ + + seq_length: int = 2048 + num_layers: int = 12 + hidden_size: int = 768 + ffn_hidden_size: int = 3072 + num_attention_heads: int = 12 + bias_activation_fusion: bool = True + bias_dropout_add_fusion: bool = True + + +@dataclass +class GPTProvider5B(GPTModelProvider): + """Configuration for a 5B parameter GPT model. + + Predefined configuration for a medium-sized GPT model with 24 layers, + 4096 hidden size, and 32 attention heads. + """ + + seq_length: int = 2048 + num_layers: int = 24 + hidden_size: int = 4096 + ffn_hidden_size: int = 16384 + num_attention_heads: int = 32 + bias_activation_fusion: bool = True + bias_dropout_add_fusion: bool = True + + +@dataclass +class GPTProvider7B(GPTModelProvider): + """Configuration for a 7B parameter GPT model. + + Predefined configuration for a medium-sized GPT model with 32 layers, + 4096 hidden size, and 32 attention heads. + """ + + seq_length: int = 2048 + num_layers: int = 32 + hidden_size: int = 4096 + ffn_hidden_size: int = 10880 + num_attention_heads: int = 32 + bias_activation_fusion: bool = True + bias_dropout_add_fusion: bool = True + + +@dataclass +class GPTProvider20B(GPTModelProvider): + """Configuration for a 20B parameter GPT model. + + Predefined configuration for a large GPT model with 44 layers, + 6144 hidden size, and 48 attention heads. + """ + + seq_length: int = 2048 + num_layers: int = 44 + hidden_size: int = 6144 + ffn_hidden_size: int = 24576 + num_attention_heads: int = 48 + bias_activation_fusion: bool = True + bias_dropout_add_fusion: bool = True + + +@dataclass +class GPTProvider40B(GPTModelProvider): + """Configuration for a 40B parameter GPT model. + + Predefined configuration for a large GPT model with 48 layers, + 8192 hidden size, and 64 attention heads. + """ + + seq_length: int = 2048 + num_layers: int = 48 + hidden_size: int = 8192 + ffn_hidden_size: int = 32768 + num_attention_heads: int = 64 + bias_activation_fusion: bool = True + bias_dropout_add_fusion: bool = True + + +@dataclass +class GPTProvider175B(GPTModelProvider): + """Configuration for a 175B parameter GPT model. + + Predefined configuration for a massive GPT model with 96 layers, + 12288 hidden size, and 96 attention heads. 
+ """ + + seq_length: int = 2048 + num_layers: int = 96 + hidden_size: int = 12288 + ffn_hidden_size: int = 49152 + num_attention_heads: int = 96 + hidden_dropout: float = 0.0 + attention_dropout: float = 0.0 + bias_activation_fusion: bool = True + bias_dropout_add_fusion: bool = True + layernorm_zero_centered_gamma: bool = True diff --git a/flagscale/train/bridge/models/hf_pretrained/README.md b/flagscale/train/bridge/models/hf_pretrained/README.md new file mode 100644 index 0000000000..3cb199ffd7 --- /dev/null +++ b/flagscale/train/bridge/models/hf_pretrained/README.md @@ -0,0 +1,111 @@ +# HuggingFace pre-trained checkpoints + +Structured, type-safe classes for working with Hugging Face checkpoints. Each checkpoint type has its own structure - `flagscale.train.bridge.models.hf_pretrained` makes these contracts explicit and provides a clean interface. + +## Quick Start + +```python +from flagscale.train.bridge.models.hf_pretrained import PreTrainedCausalLM + +# Load any Hugging Face model with proper structure +model = PreTrainedCausalLM.from_hf_pretrained("gpt2") + +# See exactly what the checkpoint contains +print(model) +# PreTrainedCausalLM( +# (model): GPT2LMHeadModel [layers=12, hidden_size=768] +# (tokenizer): GPT2TokenizerFast [vocab_size=50257] +# (config): GPT2Config [model_type=gpt2] +# (generation_config): GenerationConfig [loaded] +# (parameters): 124,439,808 +# (device): cpu +# (dtype): torch.float32 +# ) + +# Use it naturally +text = model.encode("Hello world") +output = model.generate(text.input_ids, max_length=50) +result = model.decode(output[0]) +``` + +## Available Classes + +### PreTrainedCausalLM +For text generation models (GPT, LLaMA, etc.) + +```python +from flagscale.train.bridge.models.hf_pretrained import PreTrainedCausalLM + +# Type-safe loading with lazy evaluation +llama = PreTrainedCausalLM.from_hf_pretrained( + "meta-llama/Llama-2-7b-hf", + torch_dtype=torch.float16, + device="cuda" +) + +# Components load on demand +config = llama.config # Loads just config +tokenizer = llama.tokenizer # Loads just tokenizer +model = llama.model # Loads model weights +``` + +### PreTrainedVLM +For vision-language models (CLIP, LLaVA, etc.) + +```python +from flagscale.train.bridge.models.hf_pretrained import PreTrainedVLM + +vlm = PreTrainedVLM.from_hf_pretrained("llava-hf/llava-1.5-7b-hf") + +# Unified processing for images and text +inputs = vlm.process_images_and_text( + images=my_image, + text="What's in this image?" +) + +output = vlm.generate(**inputs) +``` + +## Key Features + +### šŸ” Transparent Inspection +See exactly what's in a checkpoint without loading everything: + +```python +model = PreTrainedCausalLM.from_hf_pretrained("microsoft/phi-2") +print(model) # Shows architecture, parameters, device, dtype +``` + +### šŸ’¾ Lazy Loading +Components load only when accessed, saving memory: + +```python +# Nothing loaded yet +model = PreTrainedCausalLM.from_hf_pretrained("gpt2") + +# Still nothing loaded - just returns the config +config = model.config + +# Now the model weights are loaded +outputs = model.generate(...) 
+``` + +### šŸŽÆ Type Safety +Full type hints for better IDE support: + +```python +from transformers import GPT2LMHeadModel + +gpt2: PreTrainedCausalLM[GPT2LMHeadModel] = PreTrainedCausalLM.from_hf_pretrained("gpt2") +# IDE knows exact model type for autocomplete +``` + +### šŸ”§ Unified State Dict Access +Access model weights consistently: + +```python +# Works for any model type +model.state["*.attention.*.weight"] # Get attention weights +model.state.regex(r".*\.bias$") # Find all biases +model.state.glob("*.layer.*.weight") # Pattern matching +``` diff --git a/flagscale/train/bridge/models/hf_pretrained/__init__.py b/flagscale/train/bridge/models/hf_pretrained/__init__.py new file mode 100644 index 0000000000..2846e28a23 --- /dev/null +++ b/flagscale/train/bridge/models/hf_pretrained/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +from flagscale.train.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM +from flagscale.train.bridge.models.hf_pretrained.vlm import PreTrainedVLM + +__all__ = ["PreTrainedCausalLM", "PreTrainedVLM"] diff --git a/flagscale/train/bridge/models/hf_pretrained/base.py b/flagscale/train/bridge/models/hf_pretrained/base.py new file mode 100644 index 0000000000..f6bd9d56a8 --- /dev/null +++ b/flagscale/train/bridge/models/hf_pretrained/base.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Mainly adapted from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import shutil + +from abc import ABC, abstractmethod +from fnmatch import fnmatch +from pathlib import Path +from typing import ClassVar, Dict, List, Optional, Union + +import torch + +from transformers import AutoConfig, PreTrainedModel + +from flagscale.train.bridge.models.hf_pretrained.state import ( + SafeTensorsStateSource, + StateDict, + StateSource, +) + + +class PreTrainedBase(ABC): + """ + Abstract base class for all pretrained models. + + This class provides a generic mechanism for managing model artifacts + (e.g., config, tokenizer) with lazy loading. Subclasses that are + decorated with `@dataclass` can define artifacts as fields with metadata + specifying a loader method. The `model` itself is handled via a + dedicated property that relies on the abstract `_load_model` method. + + Example: + @dataclass + class MyModel(PreTrainedBase): + config: AutoConfig = field( + init=False, + metadata=artifact(loader="_load_config") + ) + + def _load_model(self) -> "PreTrainedModel": + # Implementation for the loading logic + ... + """ + + model_name_or_path: Union[str, Path] + ARTIFACTS: ClassVar[List[str]] = [] + OPTIONAL_ARTIFACTS: ClassVar[List[str]] = [] + + def __init__(self, **kwargs): + self._state_dict_accessor: Optional[StateDict] = None + self.init_kwargs = kwargs + # Store the original source path for custom modeling file preservation + self._original_source_path: Optional[Union[str, Path]] = None + + def get_artifacts(self) -> Dict[str, str]: + """Get the artifacts dictionary mapping artifact names to their attribute names.""" + return {artifact: f"_{artifact}" for artifact in self.ARTIFACTS} + + def _copy_custom_modeling_files( + self, source_path: Union[str, Path], target_path: Union[str, Path] + ) -> None: + """Copy custom modeling files from source to target directory. + + This preserves custom modeling files that were used during model loading + with trust_remote_code=True, ensuring the saved model can be loaded properly. 
+ + Args: + source_path: Source directory containing custom modeling files + target_path: Target directory to copy files to + """ + source_path = Path(source_path) + target_path = Path(target_path) + + # Common custom modeling file patterns + custom_file_patterns = ["*.py", "*.json", "*.jpeg", "*.png", "*.jpg", "*.mp4"] + copied_files = [] + + # First, try to copy from local directory if it exists + if source_path.exists() and source_path.is_dir(): + for pattern in custom_file_patterns: + for file_path in source_path.glob(pattern): + if file_path.is_file(): + target_file = target_path / file_path.name + try: + shutil.copy2(file_path, target_file) + copied_files.append(file_path.name) + except (OSError, IOError): + # Silently skip files that can't be copied + pass + + # If no files were copied and source_path looks like a HuggingFace Hub ID, + # try to download the custom modeling files directly from the Hub + if not copied_files and "/" in str(source_path) and not source_path.exists(): + try: + from huggingface_hub import hf_hub_download, list_repo_files + + # Get list of Python files in the repository + repo_files = list_repo_files(str(source_path)) + print("repo_files: ", repo_files) + for file in repo_files: + # Check if it matches our custom file patterns + if any(fnmatch(file, pattern) for pattern in custom_file_patterns): + try: + downloaded_file = hf_hub_download( + repo_id=str(source_path), + filename=file, + local_dir=target_path, + local_dir_use_symlinks=False, + ) + copied_files.append(file) + except Exception as e: + print("Error downloading file: ", e, "Skipping file...") + # Silently skip files that can't be downloaded + pass + + except Exception as e: + print( + "Error downloading custom modeling files: ", + e, + "Skipping custom modeling files...", + ) + # If HuggingFace Hub operations fail, silently continue + pass + + return copied_files + + def save_artifacts(self, save_directory: Union[str, Path]): + """ + Saves all loaded, generic artifacts that have a `save_pretrained` method + to the specified directory. Note: This does not save the `model` attribute. + + If the model was loaded with trust_remote_code=True, this method will also + attempt to preserve any custom modeling files to ensure the saved model + can be loaded properly. 
+ """ + save_path = Path(save_directory) + save_path.mkdir(parents=True, exist_ok=True) + + _ = getattr(self, "config") # trigger lazy loading of config + if hasattr(self, "_config") and self._config is not None: + self._config.save_pretrained(save_path) + + # Iterate over required artifacts to save them in a predictable order + # for name in self.ARTIFACTS: + # # Access the public property to trigger lazy loading if needed + # artifact = getattr(self, name) + # attr_name = f"_{name}" + # if hasattr(self, attr_name): + # if artifact is not None and hasattr(artifact, "save_pretrained"): + # artifact.save_pretrained(save_path) + + # Iterate over optional artifacts - only save if they exist and have save_pretrained + for name in self.OPTIONAL_ARTIFACTS: + artifact = getattr(self, name, None) + if artifact is not None and hasattr(artifact, "save_pretrained"): + artifact.save_pretrained(save_path) + + # Preserve custom modeling files if trust_remote_code was used + if hasattr(self, 'trust_remote_code') and self.trust_remote_code: + # Try original source path first, then fallback to model_name_or_path + source_paths = [] + if hasattr(self, '_original_source_path') and self._original_source_path: + source_paths.append(self._original_source_path) + if hasattr(self, 'model_name_or_path') and self.model_name_or_path: + source_paths.append(self.model_name_or_path) + + for source_path in source_paths: + copied_files = self._copy_custom_modeling_files(source_path, save_path) + if copied_files: + # Successfully copied files, no need to try other paths + break + + @abstractmethod + def _load_model(self) -> PreTrainedModel: + """Subclasses must implement this to load the main model.""" + pass + + @abstractmethod + def _load_config(self) -> AutoConfig: + """Subclasses must implement this to load the model config.""" + pass + + @property + def model(self) -> PreTrainedModel: + """Lazily loads and returns the underlying model.""" + if not hasattr(self, "_model"): + self._model = self._load_model() + return self._model + + @model.setter + def model(self, value: PreTrainedModel): + """Manually set the model.""" + self._model = value + + @property + def config(self) -> AutoConfig: + """Lazy load and return the model config.""" + if not hasattr(self, "_config"): + self._config = self._load_config() + return self._config + + @config.setter + def config(self, value: AutoConfig): + """Set the config manually.""" + self._config = value + + @property + def state(self) -> StateDict: + """ + Get the state dict accessor for pandas-like querying. + + This accessor can be backed by either a fully loaded model in memory + or a ".safetensors" checkpoint on disk, enabling lazy loading of tensors. + + Examples: + model.state() # Get full state dict + model.state["key"] # Get single entry + model.state[["key1", "key2"]] # Get multiple entries + model.state["*.weight"] # Glob pattern + model.state.regex(r".*\\.bias$") # Regex pattern + """ + if self._state_dict_accessor is None: + source: Optional[Union[Dict[str, torch.Tensor], StateSource]] = None + # Prioritize the loaded model's state_dict if available + if hasattr(self, "_model") and self._model is not None: + source = self.model.state_dict() + elif hasattr(self, "model_name_or_path") and self.model_name_or_path: + source = SafeTensorsStateSource(self.model_name_or_path) + + if source is None: + raise ValueError( + "Cannot create StateDict accessor: model is not loaded and model_name_or_path is not set." 
+ ) + self._state_dict_accessor = StateDict(source) + return self._state_dict_accessor diff --git a/flagscale/train/bridge/models/hf_pretrained/causal_lm.py b/flagscale/train/bridge/models/hf_pretrained/causal_lm.py new file mode 100644 index 0000000000..b106530f40 --- /dev/null +++ b/flagscale/train/bridge/models/hf_pretrained/causal_lm.py @@ -0,0 +1,657 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Mainly adapted from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import sys + +from pathlib import Path +from typing import Dict, Generic, List, Optional, TypeVar, Union + +import torch + +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + GenerationConfig, + PreTrainedTokenizer, +) +from transformers.generation.utils import GenerateOutput + +from flagscale.train.bridge.models.hf_pretrained.base import PreTrainedBase +from flagscale.train.bridge.models.hf_pretrained.safe_config_loader import ( + safe_load_config_with_retry, +) + +# Python 3.12+ supports PEP 692 (TypedDict Unpack) +if sys.version_info >= (3, 12): + from typing import TypedDict, Unpack +else: + from typing_extensions import TypedDict, Unpack + + +CausalLMType = TypeVar("CausalLMType", bound=AutoModelForCausalLM) + + +class PreTrainedCausalLM(PreTrainedBase, Generic[CausalLMType]): + """ + A generic class for Pretrained Causal Language Models with lazy loading. + + Allows type-safe access to specific model implementations like LlamaForCausalLM. + + Examples: + Basic usage with lazy loading: + >>> from mbridge.pretrained import PreTrainedCausalLM + >>> # Create instance - no model loading happens yet + >>> model = PreTrainedCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf") + >>> # Components are loaded on first access + >>> config = model.config # Loads config + >>> tokenizer = model.tokenizer # Loads tokenizer + >>> # Generate text - model is loaded here + >>> inputs = model.encode("Hello, how are you?") + >>> outputs = model.generate(**inputs, max_length=50) + >>> print(model.decode(outputs[0], skip_special_tokens=True)) + + Using specific model types with type hints: + >>> from transformers import LlamaForCausalLM + >>> from mbridge.pretrained import PreTrainedCausalLM + >>> # Type-safe access to Llama-specific features + >>> llama_model: PreTrainedCausalLM[LlamaForCausalLM] = PreTrainedCausalLM.from_pretrained( + ... "meta-llama/Llama-2-7b-chat-hf", + ... torch_dtype=torch.float16, + ... device="cuda" + ... ) + >>> # Access Llama-specific attributes + >>> model_instance = llama_model.model # Type is LlamaForCausalLM + + Loading with custom configurations: + >>> # Load model with specific settings + >>> model = PreTrainedCausalLM.from_pretrained( + ... "gpt2", + ... device="cuda:0", + ... torch_dtype=torch.bfloat16, + ... attn_implementation="flash_attention_2", + ... load_in_8bit=True + ... ) + >>> # Override generation config + >>> from transformers import GenerationConfig + >>> model.generation_config = GenerationConfig( + ... max_length=100, + ... temperature=0.7, + ... top_p=0.9, + ... do_sample=True + ... 
) + + Manual component management: + >>> # Create empty instance + >>> model = PreTrainedCausalLM() + >>> # Manually set components + >>> from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM + >>> model.config = AutoConfig.from_pretrained("microsoft/phi-2") + >>> model.tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2") + >>> model.model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2") + >>> # Save all components + >>> model.save_artifacts("./my_model") + + Batch processing example: + >>> # Process multiple prompts + >>> prompts = [ + ... "The capital of France is", + ... "Machine learning is", + ... "Python programming language was created by" + ... ] + >>> # Encode all prompts + >>> inputs = model.encode(prompts, padding=True, truncation=True) + >>> # Generate completions + >>> outputs = model.generate(**inputs, max_new_tokens=20) + >>> # Decode results + >>> for i, output in enumerate(outputs): + ... print(f"Prompt {i+1}: {model.decode(output, skip_special_tokens=True)}") + """ + + ARTIFACTS = ["tokenizer"] + OPTIONAL_ARTIFACTS = ["generation_config"] + + def __init__( + self, + model_name_or_path: Optional[Union[str, Path]] = None, + device: Optional[Union[str, torch.device]] = None, + torch_dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + **kwargs, + ): + """ + Initialize a Pretrained Causal LM with lazy loading. + + Args: + model_name_or_path: HuggingFace model identifier or local path + device: Device to load model on (e.g., 'cuda', 'cpu') + torch_dtype: Data type to load model in (e.g., torch.float16) + trust_remote_code: Whether to trust remote code when loading + **kwargs: Additional arguments passed to from_pretrained methods + """ + self._model_name_or_path = model_name_or_path + # self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") + self.device = "cpu" + self.torch_dtype = torch_dtype + self.trust_remote_code = trust_remote_code + super().__init__(**kwargs) + # Store the original source path for custom modeling file preservation + if model_name_or_path and trust_remote_code: + self._original_source_path = model_name_or_path + + def _load_model(self) -> CausalLMType: + """Load the model.""" + if self.model_name_or_path is None: + raise ValueError("model_name_or_path must be provided to load model") + + model_kwargs = {"trust_remote_code": self.trust_remote_code, **self.init_kwargs} + if self.torch_dtype is not None: + model_kwargs["torch_dtype"] = self.torch_dtype + config = getattr(self, "_config", None) + if config is not None: + model_kwargs["config"] = config + + model = AutoModelForCausalLM.from_pretrained(self.model_name_or_path, **model_kwargs) + model = model.to(self.device) + + generation_config = getattr(self, "_generation_config", None) + if generation_config is not None and hasattr(model, "generation_config"): + model.generation_config = generation_config + return model + + def _load_config(self) -> AutoConfig: + """Load the model config with thread-safety protection.""" + if self.model_name_or_path is None: + raise ValueError("model_name_or_path must be provided to load config") + return safe_load_config_with_retry( + self.model_name_or_path, trust_remote_code=self.trust_remote_code, **self.init_kwargs + ) + + def _load_tokenizer(self) -> PreTrainedTokenizer: + """Load the tokenizer.""" + if self.model_name_or_path is None: + raise ValueError("model_name_or_path must be provided to load tokenizer") + tokenizer = AutoTokenizer.from_pretrained( + self.model_name_or_path, 
trust_remote_code=self.trust_remote_code, **self.init_kwargs + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + + def _load_generation_config(self) -> Optional[GenerationConfig]: + """Load the generation config.""" + if self.model_name_or_path is not None: + try: + return GenerationConfig.from_pretrained( + self.model_name_or_path, + trust_remote_code=self.trust_remote_code, + **self.init_kwargs, + ) + except Exception: + # Not all models have generation configs + pass + return None + + @property + def generation_config(self) -> Optional[GenerationConfig]: + """Lazy load and return the generation config.""" + if not hasattr(self, "_generation_config"): + self._generation_config = self._load_generation_config() + return self._generation_config + + @generation_config.setter + def generation_config(self, value: GenerationConfig): + """Set the generation config manually.""" + self._generation_config = value + # Update model's generation config if model is already loaded + model = getattr(self, "_model", None) + if model is not None and hasattr(model, "generation_config"): + model.generation_config = value + + @property + def tokenizer(self) -> PreTrainedTokenizer: + """Lazy load and return the tokenizer.""" + if not hasattr(self, "_tokenizer"): + self._tokenizer = self._load_tokenizer() + return self._tokenizer + + @tokenizer.setter + def tokenizer(self, value: PreTrainedTokenizer): + """Set the tokenizer manually.""" + self._tokenizer = value + + @property + def model_name_or_path(self) -> Optional[Union[str, Path]]: + """Return the model name or path.""" + return self._model_name_or_path + + @property + def has_model(self) -> bool: + """Check if model has been loaded.""" + return hasattr(self, "_model") and self._model is not None + + @property + def model(self) -> CausalLMType: + """Lazy load and return the underlying model.""" + return super().model + + @model.setter + def model(self, value: CausalLMType): + """Set the model manually and move it to the appropriate device.""" + self._model = value + if self._model is not None: + self._model = self._model.to(self.device) + + @classmethod + def from_pretrained( + cls, + model_name_or_path: Union[str, Path], + device: Optional[Union[str, torch.device]] = None, + torch_dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + **kwargs, + ) -> "PreTrainedCausalLM[CausalLMType]": + """ + Create a PreTrainedCausalLM instance for lazy loading. + + Args: + model_name_or_path: HuggingFace model identifier or local path + device: Device to load model on + torch_dtype: Data type to load model in + trust_remote_code: Whether to trust remote code + **kwargs: Additional arguments for from_pretrained methods + + Returns: + PreTrainedCausalLM instance configured for lazy loading + """ + return cls( + model_name_or_path=model_name_or_path, + device=device, + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, + **kwargs, + ) + + def generate( + self, input_ids: Optional[torch.LongTensor] = None, **kwargs: Unpack["GenerateKwargs"] + ) -> Union[torch.LongTensor, GenerateOutput]: + """ + Generate text using the underlying language model. + + This method forwards all arguments to the model's generate method, + supporting all generation strategies provided by the transformers library. + + Common parameters include: + inputs (torch.LongTensor, optional): Input token IDs. If not provided, + will generate from the beginning of sequence token. 
+ max_length (int, optional): Maximum length of generated sequence. + Defaults to model's max_length configuration. + min_length (int, optional): Minimum length of generated sequence. + max_new_tokens (int, optional): Maximum number of tokens to generate, + ignoring the number of tokens in the prompt. + do_sample (bool, optional): Whether to use sampling. Defaults to False + (greedy decoding). + temperature (float, optional): Temperature for sampling. Higher values + produce more random outputs. Typical range: 0.1-2.0. + top_p (float, optional): Nucleus sampling threshold. Only tokens with + cumulative probability up to top_p are considered. Range: 0.0-1.0. + top_k (int, optional): Only consider the top k tokens for sampling. + num_beams (int, optional): Number of beams for beam search. 1 means + no beam search. + repetition_penalty (float, optional): Penalty for repeating tokens. + Values > 1.0 discourage repetition. + pad_token_id (int, optional): ID of padding token. + eos_token_id (int or List[int], optional): ID(s) of end-of-sequence token(s). + use_cache (bool, optional): Whether to use past key values to speed up + generation. Defaults to True. + + Returns: + torch.LongTensor or transformers.generation.utils.GenerateOutput: + Generated token IDs. If return_dict_in_generate=True, returns a + GenerateOutput object containing generated sequences and additional + information like scores. + + Examples: + >>> # Basic generation + >>> model = PreTrainedCausalLM.from_pretrained("gpt2") + >>> inputs = model.encode("Hello, how are") + >>> outputs = model.generate(inputs["input_ids"], max_length=20) + >>> print(model.decode(outputs[0])) + + >>> # Generation with sampling + >>> outputs = model.generate( + ... inputs["input_ids"], + ... max_length=50, + ... do_sample=True, + ... temperature=0.8, + ... top_p=0.9 + ... ) + + >>> # Beam search + >>> outputs = model.generate( + ... inputs["input_ids"], + ... max_length=50, + ... num_beams=5, + ... early_stopping=True + ... ) + + Note: + For detailed documentation of all parameters, see the transformers + library documentation for generation methods. + """ + model = self.model # Ensures model is loaded + # Sync generation config if it has been set on the wrapper + generation_config = getattr(self, "_generation_config", None) + if generation_config is not None and hasattr(model, "generation_config"): + model.generation_config = generation_config + return model.generate(input_ids, **kwargs) + + def __call__(self, *args, **kwargs): + """Forward call to model.""" + return self.model(*args, **kwargs) + + def encode( + self, text: Union[str, List[str]], **kwargs: Unpack["EncodeKwargs"] + ) -> Dict[str, torch.Tensor]: + """ + Encode text into token IDs using the model's tokenizer. + + This method tokenizes input text and returns tensors ready for model input. + The output is automatically moved to the same device as the model. + + Args: + text (str or List[str]): Input text to encode. Can be a single string + or a list of strings for batch encoding. + **kwargs: Additional arguments passed to the tokenizer. Common options: + padding (bool or str, optional): Padding strategy. + - True or 'longest': Pad to longest sequence in batch + - 'max_length': Pad to max_length + - False or 'do_not_pad': No padding (default) + truncation (bool or str, optional): Truncation strategy. 
+ - True or 'longest_first': Truncate to max_length + - 'only_first': Truncate only first sequence (for pairs) + - False: No truncation + max_length (int, optional): Maximum length of returned sequences. + Defaults to model's max_length. + add_special_tokens (bool, optional): Whether to add special tokens + (e.g., [CLS], [SEP]). Defaults to True. + return_attention_mask (bool, optional): Whether to return attention + mask. Defaults to True. + return_token_type_ids (bool, optional): Whether to return token + type IDs (for models like BERT). Defaults to True if model + expects them. + + Returns: + Dict[str, torch.Tensor]: Dictionary containing: + - input_ids: Token IDs tensor of shape (batch_size, sequence_length) + - attention_mask: Attention mask tensor of same shape (if applicable) + - token_type_ids: Token type IDs tensor (if applicable) + Additional keys may be present depending on the tokenizer. + + Examples: + >>> model = PreTrainedCausalLM.from_pretrained("gpt2") + >>> # Single text encoding + >>> tokens = model.encode("Hello world!") + >>> print(tokens["input_ids"].shape) # torch.Size([1, 3]) + + >>> # Batch encoding with padding + >>> texts = ["Hello!", "How are you doing today?"] + >>> tokens = model.encode(texts, padding=True) + >>> print(tokens["input_ids"].shape) # torch.Size([2, 6]) + + >>> # Encoding with truncation + >>> tokens = model.encode( + ... "This is a very long text that might exceed the maximum length", + ... truncation=True, + ... max_length=10 + ... ) + + Note: + The returned tensors are on the same device as the model, ready + for immediate use in forward passes or generation. + """ + # Only set return_tensors default if not provided + if "return_tensors" not in kwargs: + kwargs["return_tensors"] = "pt" + + return self.tokenizer(text, **kwargs).to(self.device) + + def decode( + self, token_ids: Union[int, List[int], torch.Tensor], **kwargs: Unpack["DecodeKwargs"] + ) -> str: + """ + Decode token IDs back into text using the model's tokenizer. + + This method converts token IDs (from model output or encode method) + back into human-readable text. + + Args: + token_ids (int, List[int], or torch.Tensor): Token IDs to decode. + Can be: + - Single token ID (int) + - List of token IDs + - 1D tensor of token IDs + - 2D tensor (will decode the first sequence) + **kwargs: Additional arguments passed to the tokenizer's decode method: + skip_special_tokens (bool, optional): Whether to remove special + tokens (e.g., [PAD], [CLS], [SEP]) from output. Defaults to True. + clean_up_tokenization_spaces (bool, optional): Whether to clean up + tokenization artifacts (extra spaces, etc.). Defaults to True. + + Returns: + str: Decoded text string. + + Examples: + >>> model = PreTrainedCausalLM.from_pretrained("gpt2") + >>> # Encode and decode round-trip + >>> text = "Hello, world!" + >>> tokens = model.encode(text) + >>> decoded = model.decode(tokens["input_ids"][0]) + >>> print(decoded) # "Hello, world!" + + >>> # Decode generated tokens + >>> inputs = model.encode("The weather is") + >>> outputs = model.generate(inputs["input_ids"], max_length=10) + >>> decoded = model.decode(outputs[0]) + >>> print(decoded) # "The weather is nice today..." + + >>> # Decode without special tokens + >>> token_ids = [101, 7592, 1010, 2088, 999, 102] # BERT-style tokens + >>> decoded = model.decode(token_ids, skip_special_tokens=True) + >>> print(decoded) # "Hello, world!" 
+ + >>> # Decode keeping special tokens + >>> decoded = model.decode(token_ids, skip_special_tokens=False) + >>> print(decoded) # "[CLS] Hello, world! [SEP]" + + Note: + If a 2D tensor is provided (batch of sequences), only the first + sequence is decoded. For batch decoding, use tokenizer.batch_decode() + directly or iterate over the sequences. + """ + return self.tokenizer.decode(token_ids, **kwargs) + + def to(self, device: Union[str, torch.device]): + """Move model to specified device.""" + self.device = device + if self.has_model: + self._model = self._model.to(device) + return self + + def half(self): + """Convert model to half precision (float16).""" + if self.has_model: + self._model = self._model.half() + return self + + def float(self): + """Convert model to full precision (float32).""" + if self.has_model: + self._model = self._model.float() + return self + + def save_pretrained(self, save_directory: Union[str, Path]): + """ + Save all components (model, tokenizer, config, generation_config) to a directory. + + This method saves: + - Model weights and config + - Tokenizer files + - Generation config (if available) + + Args: + save_directory: Path to directory where components will be saved + """ + save_path = Path(save_directory) + save_path.mkdir(parents=True, exist_ok=True) + + # Save model if loaded + if hasattr(self, "_model") and self._model is not None: + self._model.save_pretrained(save_path) + + # Use the base class save_artifacts to save config and all artifacts + self.save_artifacts(save_path) + + @property + def dtype(self) -> Optional[torch.dtype]: + """Get model's dtype if loaded.""" + if self.has_model: + try: + return next(self.model.parameters()).dtype + except StopIteration: + return None + return None + + @property + def num_parameters(self) -> Optional[int]: + """Get total number of parameters if model is loaded.""" + if self.has_model: + return sum(p.numel() for p in self.model.parameters()) + return None + + def __repr__(self) -> str: + """Return a string representation of the PreTrainedCausalLM instance.""" + try: + # Access config to trigger lazy loading for a richer repr + _ = self.config + except Exception: + # If loading fails, repr shouldn't crash. 
+ pass + + lines = [f"{self.__class__.__name__}("] + for name, attr_name in sorted(self.get_artifacts().items()): + is_loaded = hasattr(self, attr_name) + artifact_instance = getattr(self, attr_name, None) if is_loaded else None + + type_name = "N/A" + details = "not loaded" + if is_loaded and artifact_instance is not None: + type_name = artifact_instance.__class__.__name__ + if name == "tokenizer": + vocab = getattr(artifact_instance, "vocab_size", "N/A") + details = f"vocab_size={vocab}" + elif name == "config": + m_type = getattr(artifact_instance, "model_type", "N/A") + details = f"model_type={m_type}" + else: + details = "loaded" + lines.append(f" ({name}): {type_name} [{details}]") + + # Manually add model repr + model_repr_content: str + if self.has_model: + model_class_name = self.model.__class__.__name__ + # Assuming self.config is loaded or available here due to earlier attempt + config = self.config + layers = getattr(config, "num_hidden_layers", "N/A") + hidden_size = getattr(config, "hidden_size", "N/A") + model_repr_content = ( + f"{model_class_name} [layers={layers}, hidden_size={hidden_size}, loaded]" + ) + elif "config" in self.__dict__: # Model not loaded, but config is + config = self.config + model_class_name_from_hf_config = "CausalLM" # Default + if hasattr(config, "architectures") and config.architectures: + model_class_name_from_hf_config = config.architectures[0] + elif getattr(config, "model_type", None): + mt = config.model_type + model_class_name_from_hf_config = f"{mt.capitalize()}Model" if mt else "CausalLM" + + details_parts = [] + if getattr(config, "num_hidden_layers", None) is not None: + details_parts.append(f"layers={config.num_hidden_layers}") + if getattr(config, "hidden_size", None) is not None: + details_parts.append(f"hidden_size={config.hidden_size}") + + details_str = ", ".join(details_parts) + status_suffix = "not loaded" + if details_str: + model_repr_content = ( + f"{model_class_name_from_hf_config}({details_str}) [{status_suffix}]" + ) + else: + model_repr_content = f"{model_class_name_from_hf_config} [{status_suffix}]" + else: # Model and Config also not loaded + model_repr_content = "AutoModelForCausalLM [not loaded]" + + lines.append(f" (model): {model_repr_content}") + + lines.sort() + + params_str = f"{self.num_parameters:,}" if self.num_parameters is not None else "N/A" + dtype_str = str(self.dtype).replace("torch.", "") if self.dtype is not None else "N/A" + lines.extend( + [ + f" (parameters): {params_str}", + f" (device): {str(self.device)}", + f" (dtype): {dtype_str}", + ")", + ] + ) + return "\n".join(lines) + + +# TypedDict definitions for method parameters +class GenerateKwargs(TypedDict, total=False): + """TypedDict for generate method parameters.""" + + attention_mask: Optional[torch.Tensor] + max_length: Optional[int] + max_new_tokens: Optional[int] + min_length: Optional[int] + do_sample: Optional[bool] + temperature: Optional[float] + top_k: Optional[int] + top_p: Optional[float] + repetition_penalty: Optional[float] + pad_token_id: Optional[int] + eos_token_id: Optional[Union[int, List[int]]] + bos_token_id: Optional[int] + num_beams: Optional[int] + num_return_sequences: Optional[int] + early_stopping: Optional[bool] + use_cache: Optional[bool] + return_dict_in_generate: Optional[bool] + output_scores: Optional[bool] + output_attentions: Optional[bool] + + +class EncodeKwargs(TypedDict, total=False): + """TypedDict for encode method parameters.""" + + padding: Union[bool, str] + truncation: Union[bool, str] + max_length: 
Optional[int] + add_special_tokens: bool + return_attention_mask: bool + return_token_type_ids: Optional[bool] + return_tensors: str + + +class DecodeKwargs(TypedDict, total=False): + """TypedDict for decode method parameters.""" + + skip_special_tokens: bool + clean_up_tokenization_spaces: bool diff --git a/flagscale/train/bridge/models/hf_pretrained/safe_config_loader.py b/flagscale/train/bridge/models/hf_pretrained/safe_config_loader.py new file mode 100644 index 0000000000..9d5e9490aa --- /dev/null +++ b/flagscale/train/bridge/models/hf_pretrained/safe_config_loader.py @@ -0,0 +1,136 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +""" +Thread-safe configuration loading utilities. + +This module provides utilities for safely loading HuggingFace model configurations +in multi-threaded environments, preventing race conditions that can occur when +multiple threads try to download and cache the same model simultaneously. +""" + +import hashlib +import os +import time + +from pathlib import Path +from typing import Union + +import filelock + +from transformers import AutoConfig +from transformers.configuration_utils import PretrainedConfig + + +def safe_load_config_with_retry( + path: Union[str, Path], + trust_remote_code: bool = False, + max_retries: int = 3, + base_delay: float = 1.0, + **kwargs, +) -> PretrainedConfig: + """ + Thread-safe and process-safe configuration loading with retry logic. + + This function prevents race conditions when multiple threads/processes + try to download and cache the same model configuration simultaneously. + Uses file locking (if filelock is available) to coordinate access across + processes. + + Args: + path: HuggingFace model ID or path to model directory + trust_remote_code: Whether to trust remote code when loading config + max_retries: Maximum number of retry attempts (default: 3) + base_delay: Base delay in seconds for exponential backoff (default: 1.0) + **kwargs: Additional arguments passed to AutoConfig.from_pretrained + + Returns: + PretrainedConfig: The loaded model configuration + + Raises: + ValueError: If config loading fails after all retries + + Environment Variables: + MEGATRON_CONFIG_LOCK_DIR: Override the directory where lock files are created. + Default: ~/.cache/huggingface/ + Useful for multi-node setups where a shared lock directory is needed. + + Example: + >>> config = safe_load_config_with_retry("meta-llama/Meta-Llama-3-8B") + >>> print(config.model_type) + + >>> # With custom retry settings + >>> config = safe_load_config_with_retry( + ... "gpt2", + ... max_retries=5, + ... base_delay=0.5, + ... trust_remote_code=True + ... 
) + + >>> # Multi-node setup with shared lock directory + >>> import os + >>> os.environ["MEGATRON_CONFIG_LOCK_DIR"] = "/shared/locks" + >>> config = safe_load_config_with_retry("meta-llama/Meta-Llama-3-8B") + """ + last_exception = None + + for attempt in range(max_retries + 1): + try: + # Use file locking for process-safe access + # Create a lock file based on the path hash to avoid conflicts + path_hash = hashlib.md5(str(path).encode()).hexdigest() + + # Allow override of lock directory via environment variable + # This is useful for multi-node setups where a shared lock directory is needed + lock_dir = os.getenv("MEGATRON_CONFIG_LOCK_DIR") + if lock_dir: + lock_file = Path(lock_dir) / f".megatron_config_lock_{path_hash}" + else: + lock_file = ( + Path.home() / ".cache" / "huggingface" / f".megatron_config_lock_{path_hash}" + ) + + lock_file.parent.mkdir(parents=True, exist_ok=True) + + with filelock.FileLock(str(lock_file) + ".lock", timeout=60): + return AutoConfig.from_pretrained( + path, trust_remote_code=trust_remote_code, **kwargs + ) + + except Exception as e: + last_exception = e + + # Don't retry on certain types of errors + error_msg = str(e).lower() + if any( + phrase in error_msg + for phrase in [ + "does not appear to have a file named config.json", + "repository not found", + "entry not found", + "401 client error", + "403 client error", + ] + ): + # Model doesn't exist or access denied, no point retrying + raise ValueError( + f"Failed to load configuration from {path}. " + f"Ensure the path is valid and contains a config.json file. " + f"Error: {e}" + ) from e + + if attempt < max_retries: + # Exponential backoff with jitter + delay = base_delay * (2**attempt) + (time.time() % 1) * 0.1 + time.sleep(delay) + else: + # Final attempt failed + break + + # All retries exhausted + raise ValueError( + f"Failed to load configuration from {path} after {max_retries + 1} attempts. " + f"This might be due to network issues or concurrent access conflicts. " + f"Last error: {last_exception}" + ) from last_exception diff --git a/flagscale/train/bridge/models/hf_pretrained/state.py b/flagscale/train/bridge/models/hf_pretrained/state.py new file mode 100644 index 0000000000..6de010b14d --- /dev/null +++ b/flagscale/train/bridge/models/hf_pretrained/state.py @@ -0,0 +1,850 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import fnmatch +import json +import re + +from abc import ABC, abstractmethod +from collections import defaultdict +from collections.abc import Mapping +from functools import lru_cache +from pathlib import Path +from typing import Dict, Iterable, List, Optional, Pattern, Tuple, Union, overload + +import torch + + +class StateDict(Mapping[str, torch.Tensor]): + """ + A state dict accessor that provides a unified interface for querying model + checkpoints. + + `StateDict` allows for efficient and flexible access to tensor data from + various sources, such as in-memory dictionaries or directories of + `.safetensors` files. A key feature is its ability to query and load only + the required tensors without loading the entire checkpoint into memory, + making it highly memory-efficient for large models. + + It supports a flexible, pandas-like querying interface that allows for + accessing tensors by exact name, a list of names, glob patterns, or regular + expressions. This makes it easy to inspect and manipulate model + checkpoints. 
+ + Examples: + >>> # Setup an example StateDict from an in-memory dictionary + >>> import torch + >>> import re + >>> d = { + ... "model.layer.0.weight": torch.randn(10, 10), + ... "model.layer.0.bias": torch.randn(10), + ... "model.layer.1.weight": torch.randn(10, 10), + ... "model.layer.1.bias": torch.randn(10), + ... } + >>> state = StateDict(d) + >>> + >>> # 1. Access a single tensor by exact key + >>> state["model.layer.0.weight"].shape + torch.Size([10, 10]) + >>> + >>> # 2. Access multiple tensors with a list of strings + >>> list(state[["model.layer.0.weight", "model.layer.1.weight"]].keys()) + ['model.layer.0.weight', 'model.layer.1.weight'] + >>> + >>> # 3. Access with a glob pattern + >>> sorted(list(state.glob("model.layer.*.bias").keys())) + ['model.layer.0.bias', 'model.layer.1.bias'] + >>> + >>> # 4. Access with a compiled regex pattern + >>> regex = re.compile(r"model\\\\.layer\\\\.0\\\\..*") + >>> sorted(list(state[regex].keys())) + ['model.layer.0.bias', 'model.layer.0.weight'] + + The same querying flexibility applies to checkpoints on disk. The following + is a conceptual example of using `StateDict` with a `SafetensorsStateSource` + to query a sharded checkpoint without loading all of it into memory. + + .. code-block:: python + + # Assume SafetensorsStateSource is available + # from flagscale.train.bridge.models.state import SafetensorsStateSource + + # Imagine a directory 'my_model_checkpoint/' with sharded weights. + state_from_disk = StateDict(SafetensorsStateSource('my_model_checkpoint/')) + + # You can query it just like the in-memory dictionary. Only the required + # tensors (e.g., all weight tensors) will be loaded from disk. + weights = state_from_disk.glob("model.layer.*.weight") + """ + + source: "StateSource" + + def __init__(self, source: Dict[str, torch.Tensor] | "StateSource"): + """ + Initializes the StateDict query accessor. + + Args: + source: The source of the tensor data. This can be a standard + Python dictionary mapping tensor names to `torch.Tensor` objects, + or an instance of a `StateSource` subclass (e.g., + `SafetensorsStateSource`) for more advanced, out-of-memory + access. + """ + if isinstance(source, dict): + source = DictStateSource(source) + + if not isinstance(source, StateSource): + raise TypeError(f"StateDict source must be a dict or a StateSource, got {type(source)}") + + self.source = source + + def _get_all_keys(self) -> List[str]: + """ + Get all available tensor keys from the underlying source. + """ + return self.source.get_all_keys() + + def _load_tensors(self, keys_to_load: List[str]) -> Dict[str, torch.Tensor]: + """ + Load specified tensors from the underlying source. + """ + return self.source.load_tensors(keys_to_load) + + def _match_keys(self, pattern: Union[str, Pattern]) -> List[str]: + """Match keys against a glob pattern or regex.""" + all_keys = self._get_all_keys() + + if isinstance(pattern, Pattern): + # Regex pattern + return [k for k in all_keys if pattern.search(k)] + elif "*" in pattern or "?" in pattern or "[" in pattern: + # Glob pattern + return [k for k in all_keys if fnmatch.fnmatch(k, pattern)] + else: + # Exact match + return [pattern] if pattern in all_keys else [] + + @overload + def __getitem__(self, key: str) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: ... + + @overload + def __getitem__(self, key: List[str]) -> Dict[str, torch.Tensor]: ... + + @overload + def __getitem__(self, key: Pattern) -> Dict[str, torch.Tensor]: ... 
+ + def __getitem__( + self, key: Union[str, List[str], Pattern] + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Accesses state dict entries using various key types. + + This method allows for retrieving tensors using: + - A single string for an exact key match. + - A list of strings for multiple exact key matches. + - A string with glob-style wildcards (`*`, `?`, `[]`). + - A compiled regular expression object. + + Args: + key: A single key string, a list of keys, a glob pattern string, or a + compiled regular expression. + + Returns: + - A single `torch.Tensor` if `key` is a string that matches exactly one key + and does not contain wildcards. + - A `Dict[str, torch.Tensor]` for all other cases (list of keys, glob + pattern, or regex), mapping the matched keys to their corresponding + tensors. + + Raises: + KeyError: If the key (or any key in a list) is not found, or if a + pattern matches no keys. + + Examples: + >>> d = { + ... "model.embed_tokens.weight": torch.randn(10, 1), + ... "model.layers.0.mlp.weight": torch.randn(10, 1), + ... "model.layers.0.self_attn.q_proj.weight": torch.randn(10, 1), + ... "lm_head.weight": torch.randn(10, 1), + ... } + >>> state = StateDict(d) + >>> + >>> # Exact match (returns a single tensor) + >>> tensor = state["model.embed_tokens.weight"] + >>> isinstance(tensor, torch.Tensor) + True + >>> + >>> # List of keys (returns a dict of tensors) + >>> tensors = state[["model.embed_tokens.weight", "lm_head.weight"]] + >>> sorted(tensors.keys()) + ['lm_head.weight', 'model.embed_tokens.weight'] + >>> + >>> # Glob pattern (returns a dict of tensors) + >>> layer_0_weights = state["model.layers.0.*.weight"] + >>> sorted(layer_0_weights.keys()) + ['model.layers.0.mlp.weight', 'model.layers.0.self_attn.q_proj.weight'] + >>> + >>> # Regex pattern (returns a dict of tensors) + >>> import re + >>> attn_weights = state[re.compile(r".*self_attn.*")] + >>> list(attn_weights.keys()) + ['model.layers.0.self_attn.q_proj.weight'] + """ + if isinstance(key, Pattern): + matched_keys = self._match_keys(key) + if not matched_keys: + raise KeyError(f"No keys match regex pattern: {key.pattern}") + return self._load_tensors(matched_keys) + elif isinstance(key, str): + if "*" in key or "?" in key or "[" in key: + matched_keys = self._match_keys(key) + if not matched_keys: + raise KeyError(f"No keys match pattern: {key}") + return self._load_tensors(matched_keys) + else: + if key not in self._get_all_keys(): + raise KeyError(f"Key not found: {key}") + return self._load_tensors([key])[key] + elif isinstance(key, list): + all_keys_set = set(self._get_all_keys()) + missing_keys = [k for k in key if k not in all_keys_set] + if missing_keys: + raise KeyError(f"Keys not found: {missing_keys}") + return self._load_tensors(key) + else: + raise TypeError(f"Key must be str, list of str, or compiled regex, got {type(key)}") + + def regex(self, pattern: str) -> Dict[str, torch.Tensor]: + """ + Queries the state dict with a regular expression pattern. + + This is a convenience method that compiles the pattern string and uses it + to retrieve all matching tensors. + + Args: + pattern: The regular expression string to match against tensor keys. + + Returns: + A dictionary mapping matching tensor names to their `torch.Tensor` objects. + + Examples: + >>> d = { + ... "model.layers.0.self_attn.weight": torch.randn(1, 1), + ... "model.layers.1.self_attn.weight": torch.randn(1, 1), + ... "model.layers.1.mlp.weight": torch.randn(1, 1) + ... 
} + >>> state = StateDict(d) + >>> # Get all attention-related weights + >>> attention_weights = state.regex(r"model\\.layers\\.\\d+\\.self_attn.*") + >>> sorted(attention_weights.keys()) + ['model.layers.0.self_attn.weight', 'model.layers.1.self_attn.weight'] + """ + return self[re.compile(pattern)] + + def glob(self, pattern: str) -> Dict[str, torch.Tensor]: + """ + Queries the state dict with a glob pattern. + + This is a convenience method for pattern matching using Unix shell-style + wildcards. + + Args: + pattern: The glob pattern string to match against tensor keys. + + Returns: + A dictionary mapping matching tensor names to their `torch.Tensor` objects. + + Examples: + >>> d = { + ... "model.layers.0.mlp.weight": torch.randn(1, 1), + ... "model.layers.0.mlp.bias": torch.randn(1, 1), + ... "model.layers.1.mlp.weight": torch.randn(1, 1) + ... } + >>> state = StateDict(d) + >>> # Get all mlp weights and biases from the first layer + >>> layer_0_mlp = state.glob("model.layers.0.mlp.*") + >>> sorted(layer_0_mlp.keys()) + ['model.layers.0.mlp.bias', 'model.layers.0.mlp.weight'] + """ + return self[pattern] + + def __call__(self) -> Dict[str, torch.Tensor]: + """ + Loads and returns the entire state dict as a dictionary. + + Note: + This method loads all tensors from the source into memory. For large + models, this can be memory-intensive. Prefer using pattern-based + or single-key lookups for more efficient access if you only need a + subset of the state dict. + + Returns: + A dictionary containing all tensor names and their corresponding + `torch.Tensor` objects. + """ + all_keys = self._get_all_keys() + return self._load_tensors(all_keys) + + def keys(self) -> List[str]: + """Get all state dict keys.""" + return self._get_all_keys() + + def items(self) -> List[tuple]: + """Get all state dict items.""" + return list(self().items()) + + def __contains__(self, key: str) -> bool: + """Check if a key exists in the state dict.""" + return key in self._get_all_keys() + + def __repr__(self) -> str: + """String representation.""" + try: + num_params = len(self) + return f"" + except Exception: + return "" + + def get(self, key: str, default=None) -> Optional[torch.Tensor]: + """ + Gets a tensor from the state dict. + Returns `default` if the key is not found. + Note: This method is for single key lookup and does not support patterns. + """ + if key in self._get_all_keys(): + return self._load_tensors([key])[key] + return default + + def __iter__(self) -> Iterable[str]: + """Iterate over state dict keys.""" + return iter(self.keys()) + + def __len__(self) -> int: + """Get number of entries in the state dict.""" + return len(self.keys()) + + def has_glob(self, pattern: str) -> bool: + """ + Efficiently checks if any tensor key matches the given glob pattern. + This is forwarded to the underlying StateSource which may have an + optimized implementation that avoids iterating over all keys. + + Args: + pattern: The glob pattern to match against tensor keys. + + Returns: + True if a matching key is found, False otherwise. + """ + return self.source.has_glob(pattern) + + +class StateSource(ABC, Mapping[str, torch.Tensor]): + """ + Abstract base class for a source of model state. + + This class defines a standard interface for `StateDict` to access tensor + data, abstracting away the details of how and where the data is stored. + Subclasses can implement loading from different storage backends, such as + in-memory dictionaries or files on disk. 
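Because `__call__()` loads every tensor while `glob`, `regex`, and `has_glob` touch only what matches, the access pattern matters for large checkpoints. A short usage sketch; an in-memory source is used here for brevity, and the memory savings apply to file-backed sources:

```python
import torch

from flagscale.train.bridge.models.hf_pretrained.state import StateDict

state = StateDict({
    "model.layers.0.mlp.weight": torch.zeros(4, 4),
    "model.layers.0.mlp.bias": torch.zeros(4),
    "lm_head.weight": torch.zeros(8, 4),
})

# Cheap existence probe: for file-backed sources no tensor data is materialized.
if state.has_glob("model.layers.*.mlp.*"):
    # Pattern access loads only the matching tensors.
    mlp_params = state.glob("model.layers.*.mlp.*")
    print(sorted(mlp_params))

# state() loads *every* tensor; reserve it for small checkpoints.
full = state()
print(len(full))
```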
This allows `StateDict` to handle + various checkpoint formats in a uniform way. + """ + + @abstractmethod + def get_all_keys(self) -> List[str]: + """Returns a list of all available tensor keys in the source.""" + pass + + @abstractmethod + def load_tensors(self, keys: List[str]) -> Dict[str, torch.Tensor]: + """Loads the specified tensors from the source.""" + pass + + def __getitem__(self, key: str) -> torch.Tensor: + """Loads a single tensor by key.""" + tensors = self.load_tensors([key]) + if key not in tensors: + raise KeyError(f"Key not found in source: {key}") + return tensors[key] + + def __iter__(self) -> Iterable[str]: + """Iterates over all tensor keys.""" + return iter(self.get_all_keys()) + + def __len__(self) -> int: + """Returns the total number of tensors in the source.""" + return len(self.get_all_keys()) + + def has_glob(self, pattern: str) -> bool: + """ + Checks if any tensor key matches the given glob pattern. + This default implementation is not efficient for all sources, as it may + load all keys. Subclasses should override this method if a more + performant implementation is available. + """ + import fnmatch + + for key in self.get_all_keys(): + if fnmatch.fnmatch(key, pattern): + return True + return False + + +class DictStateSource(StateSource): + """ + A state source backed by an in-memory Python dictionary. + + This is the simplest `StateSource` implementation. It's used when the entire + model state dict is already loaded into a dictionary in memory. + + Args: + state_dict: A dictionary mapping tensor names (str) to `torch.Tensor` objects. + """ + + def __init__(self, state_dict: Dict[str, torch.Tensor]): + self._dict = state_dict + self._keys_cache: Optional[List[str]] = None + + def get_all_keys(self) -> List[str]: + if self._keys_cache is None: + self._keys_cache = sorted(list(self._dict.keys())) + return self._keys_cache + + def load_tensors(self, keys: List[str]) -> Dict[str, torch.Tensor]: + return {key: self._dict[key] for key in keys if key in self._dict} + + +class SafeTensorsStateSource(StateSource): + """ + A state source backed by a directory of .safetensors files. + + This source is designed for efficiently loading tensors from checkpoints saved + in the Safetensors format, which is common for large models that are often + "sharded" into multiple files. + + It can handle two common scenarios: + 1. A directory containing multiple `.safetensors` files. + 2. A directory containing a `model.safetensors.index.json` file, which maps + tensor names to the specific `.safetensors` file they reside in. This is + the standard format used by Hugging Face Transformers. + + Using this source allows `StateDict` to query for tensor keys and load only + the necessary files and tensors from disk, avoiding high memory usage. + + Args: + path: The path to the directory containing the `.safetensors` files + and/or the index file. Can also be a Hugging Face Hub model ID. + """ + + def __init__(self, path: Union[str, Path]): + self.model_name_or_path = path + self._resolved_path_cache: Optional[Path] = None + self._keys_cache: Optional[List[str]] = None + self._key_to_filename_map_cache: Optional[Dict[str, str]] = None + + @property + def path(self) -> Path: + """ + The local path to the checkpoint files. + If the initial path is a Hugging Face Hub model ID, this property + will handle downloading the necessary files and return the local + cache path. 
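Any backend that can list keys and fetch tensors can plug into `StateDict` by implementing the two abstract methods above. A hypothetical source backed by a single `torch.save` checkpoint, as a sketch:

```python
from typing import Dict, List, Optional

import torch

from flagscale.train.bridge.models.hf_pretrained.state import StateDict, StateSource


class TorchFileStateSource(StateSource):
    """Hypothetical source backed by one torch-saved state dict on disk."""

    def __init__(self, path: str):
        self._path = path
        self._cache: Optional[Dict[str, torch.Tensor]] = None

    def _load_all(self) -> Dict[str, torch.Tensor]:
        if self._cache is None:
            # torch.load reads the whole file; a smarter source could read lazily.
            self._cache = torch.load(self._path, map_location="cpu")
        return self._cache

    def get_all_keys(self) -> List[str]:
        return sorted(self._load_all().keys())

    def load_tensors(self, keys: List[str]) -> Dict[str, torch.Tensor]:
        full = self._load_all()
        return {k: full[k] for k in keys if k in full}


# state = StateDict(TorchFileStateSource("checkpoint.pt"))
# weights = state.glob("model.layer.*.weight")
```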
+ """ + if self._resolved_path_cache is None: + self._resolved_path_cache = self._resolve_path(self.model_name_or_path) + return self._resolved_path_cache + + @property + def key_to_filename_map(self) -> Dict[str, str]: + """ + Provides a mapping from tensor keys to the safetensor filename they + are stored in. + + This map is constructed either from `model.safetensors.index.json` if + it exists, or by scanning all `.safetensors` files in the directory. + The result is cached for efficiency. + """ + if self._key_to_filename_map_cache is not None: + return self._key_to_filename_map_cache + + # First, try to load from the index file. + key_map = self._cached_get_key_to_filename_map(self.path) + if key_map: + self._key_to_filename_map_cache = key_map + return key_map + + # If no index, scan the directory. + import os + + from glob import glob as file_glob + + from safetensors import safe_open + + key_map = {} + safetensor_files = file_glob(str(self.path / "*.safetensors")) + for file_path in safetensor_files: + filename = os.path.basename(file_path) + try: + with safe_open(file_path, framework="pt", device="cpu") as f: + for key in f.keys(): + if key in key_map: + # This is an issue. Same key in multiple files, and no index. + # How to resolve ambiguity? Let's just warn and overwrite. Last one wins. + print( + f"Warning: duplicate key '{key}' found in '{filename}' and '{key_map[key]}'. Using '{filename}'." + ) + key_map[key] = filename + except Exception as e: + # Can be not a safetensor file, etc. + print(f"Warning: could not open {filename} as a safetensors file: {e}") + + self._key_to_filename_map_cache = key_map + return key_map + + @staticmethod + def _resolve_path(model_name_or_path: Union[str, Path]) -> Path: + """ + Resolves a model name or path to a local directory. + If the path is not a local directory, it is treated as a Hugging + Face Hub model ID, and the corresponding files are downloaded. + """ + local_path = Path(model_name_or_path) + if local_path.is_dir(): + return local_path + + try: + from huggingface_hub import snapshot_download + from huggingface_hub.utils import HfHubHTTPError + + # Not a local directory, so we assume it's a model ID + # on the Hugging Face Hub. + return Path( + snapshot_download( + repo_id=str(model_name_or_path), + allow_patterns=["*.safetensors", "model.safetensors.index.json"], + # Ignore other large files. + ignore_patterns=["*.bin", "*.pt", "*.pth"], + ) + ) + except (ImportError, HfHubHTTPError, ValueError): + # If huggingface_hub is not installed, or if it's not a + # valid model ID, we return the original path and let the + # subsequent logic handle the file not found error. 
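The index-driven path relies on the standard `model.safetensors.index.json` layout, whose `weight_map` maps tensor names to shard filenames. A sketch of reading that index and inverting it so each shard only needs to be opened once; the directory name is illustrative:

```python
import json
from collections import defaultdict
from pathlib import Path

# Assumed layout of a Hugging Face style sharded checkpoint directory:
#   model-00001-of-00002.safetensors
#   model-00002-of-00002.safetensors
#   model.safetensors.index.json  -> {"metadata": {...}, "weight_map": {tensor_name: filename}}
ckpt_dir = Path("my_model_checkpoint")  # illustrative path

index = json.loads((ckpt_dir / "model.safetensors.index.json").read_text())
weight_map = index["weight_map"]  # tensor name -> shard filename

# Invert the map so a loader can visit each shard exactly once for a given key subset.
file_to_keys = defaultdict(list)
for name, filename in weight_map.items():
    file_to_keys[filename].append(name)

print({f: len(keys) for f, keys in file_to_keys.items()})
```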
+ return local_path + + def get_all_keys(self) -> List[str]: + if self._keys_cache is not None: + return self._keys_cache + + from glob import glob as file_glob + + from safetensors import safe_open + + all_keys = set() + key_to_filename_map = self.key_to_filename_map + if key_to_filename_map: + all_keys.update(key_to_filename_map.keys()) + + if not all_keys: + safetensor_files = file_glob(str(self.path / "*.safetensors")) + if not safetensor_files and not key_to_filename_map: + raise FileNotFoundError( + f"No .safetensors files or index found in {self.model_name_or_path}" + ) + for safetensor_file in safetensor_files: + with safe_open(safetensor_file, framework="pt", device="cpu") as f: + all_keys.update(f.keys()) + + self._keys_cache = sorted(list(all_keys)) + return self._keys_cache + + def load_tensors(self, keys_to_load: List[str]) -> Dict[str, torch.Tensor]: + if not keys_to_load: + return {} + + from glob import glob as file_glob + + from safetensors import safe_open + + loaded_tensors = {} + remaining_keys = set(keys_to_load) + key_to_filename_map = self.key_to_filename_map + + if key_to_filename_map: + file_to_keys_map = defaultdict(list) + for key in list(remaining_keys): + if key in key_to_filename_map: + filename = key_to_filename_map[key] + file_to_keys_map[filename].append(key) + + for filename, keys_in_file in file_to_keys_map.items(): + file_path = self.path / filename + if file_path.exists(): + with safe_open(file_path, framework="pt", device="cpu") as f: + for key in keys_in_file: + if key in f.keys(): + loaded_tensors[key] = f.get_tensor(key) + remaining_keys.discard(key) + + if remaining_keys: + safetensor_files = file_glob(str(self.path / "*.safetensors")) + if not safetensor_files and not key_to_filename_map and not loaded_tensors: + raise FileNotFoundError( + f"No .safetensors files found in {self.model_name_or_path} to load keys: {remaining_keys}" + ) + for safetensor_file_path in safetensor_files: + if not remaining_keys: + break + with safe_open(safetensor_file_path, framework="pt", device="cpu") as f: + current_file_keys = f.keys() + for key in list(remaining_keys): + if key in current_file_keys: + loaded_tensors[key] = f.get_tensor(key) + remaining_keys.remove(key) + + if remaining_keys: + raise KeyError( + f"Keys not found in safetensors from {self.model_name_or_path}: {remaining_keys}" + ) + + return loaded_tensors + + def has_glob(self, pattern: str) -> bool: + """ + Efficiently checks if any tensor key matches the given glob pattern. + + This method avoids loading all tensor keys into memory at once. It scans + the checkpoint index or file headers and returns as soon as a match is + found. + + Args: + pattern: The glob pattern to match against tensor keys. + + Returns: + True if a matching key is found, False otherwise. + """ + import fnmatch + + from glob import glob as file_glob + + from safetensors import safe_open + + key_to_filename_map = self.key_to_filename_map + if key_to_filename_map: + for key in key_to_filename_map.keys(): + if fnmatch.fnmatch(key, pattern): + return True + return False + + # If no index map, scan the files directly. 
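Selective loading leans on `safetensors.safe_open`, which reads individual entries from a shard instead of the whole file. A minimal sketch with illustrative tensor and file names:

```python
from safetensors import safe_open

# Load only two tensors from one shard without reading the rest of the file.
wanted = ["model.layers.0.mlp.weight", "model.layers.0.mlp.bias"]  # illustrative names
loaded = {}
with safe_open("model-00001-of-00002.safetensors", framework="pt", device="cpu") as f:
    present = set(f.keys())
    for name in wanted:
        if name in present:
            # get_tensor reads just this entry from the file.
            loaded[name] = f.get_tensor(name)
```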
+ safetensor_files = file_glob(str(self.path / "*.safetensors")) + if not safetensor_files: + return False + + for safetensor_file in safetensor_files: + try: + with safe_open(safetensor_file, framework="pt", device="cpu") as f: + for key in f.keys(): + if fnmatch.fnmatch(key, pattern): + return True + except Exception: + # Ignore files that are not valid safetensors + continue + + return False + + def save_generator( + self, + generator: Iterable[Tuple[str, torch.Tensor]], + output_path: Union[str, Path], + strict: bool = True, + ): + """ + Saves tensors from a generator to `.safetensors` files, preserving the + original sharding structure in a memory-efficient, streaming fashion. + + This method reads the sharding information (which tensor belongs to which + file) from the source checkpoint. It then consumes a generator of tensors, + buffering them in memory only until a complete file shard can be written to + disk. This approach minimizes peak memory usage compared to collecting all + tensors first. + + If the original checkpoint had a `model.safetensors.index.json` file, a new + one will be created for the saved tensors. + + Args: + generator: An iterable of (tensor_name, tensor) tuples. + output_path: The directory where the new safetensor files and index + will be saved. + strict: If True (default), raises a KeyError if the generator + yields a tensor name not found in the original model's + sharding structure. If False, it prints a warning and + skips the tensor. + """ + # In a distributed environment, only rank 0 should write to disk. + # Other ranks must still exhaust the generator to participate in collectives. + is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized() + rank = torch.distributed.get_rank() if is_distributed else 0 + + if rank != 0: + # Other ranks must exhaust the generator to avoid hangs in collectives. + for _ in generator: + pass + return + + # Rank 0 proceeds with saving. + from safetensors.torch import save_file + + output_path = Path(output_path) + output_path.mkdir(parents=True, exist_ok=True) + + key_to_filename_map = self.key_to_filename_map + all_expected_keys = set(key_to_filename_map.keys()) + + if not key_to_filename_map: + buffered_tensors = dict(generator) + if buffered_tensors: + save_file(buffered_tensors, output_path / "model.safetensors") + return + + filename_to_keys_map = defaultdict(set) + for key, filename in key_to_filename_map.items(): + filename_to_keys_map[filename].add(key) + + files_to_save = dict(filename_to_keys_map) + buffered_tensors = {} + all_yielded_keys = set() + all_saved_keys = set() + + for name, tensor in generator: + all_yielded_keys.add(name) + if name not in all_expected_keys: + if strict: + raise KeyError( + f"Tensor '{name}' from generator not found in the original model structure. " + "To ignore, set strict=False." + ) + else: + print( + f"Warning: tensor '{name}' from generator not found in original model structure. Skipping." + ) + continue + + buffered_tensors[name] = tensor + + # Check if any file is complete and can be saved. + # Iterate over a copy of keys since we might modify the dict. + for filename in list(files_to_save.keys()): + keys_for_file = files_to_save[filename] + if keys_for_file.issubset(buffered_tensors.keys()): + # This shard is complete, save it. 
+ tensors_to_save = {key: buffered_tensors[key] for key in keys_for_file} + + output_file_path = output_path / filename + save_file(tensors_to_save, output_file_path) + + # Free memory by removing saved tensors from the buffer. + for key in keys_for_file: + del buffered_tensors[key] + + all_saved_keys.update(keys_for_file) + del files_to_save[filename] + + # --- Final Reporting --- + if files_to_save: + if strict: + print( + "Warning: The following files could not be saved because the generator did not yield all of their tensors:" + ) + else: + print( + "Warning: The following files are different from the source because the generator did not yield all " + "of their tensors. However they are still saved because strict=False." + ) + for filename, keys_for_file in files_to_save.items(): + missing_for_file = keys_for_file - all_yielded_keys + if missing_for_file: + print(f" - {filename}: missing {len(missing_for_file)} tensors:") + for key in sorted(list(missing_for_file)): + print(f" - {key}") + if not strict: + for filename in list(files_to_save.keys()): + keys_for_file = files_to_save[filename] + tensors_to_save = { + key: buffered_tensors[key] + for key in keys_for_file + if key in buffered_tensors + } + # missing_keys = set(keys_for_file) - tensors_to_save.keys() + # if missing_keys: + # print(f" - {filename}: missing {len(missing_keys)} tensors:") + # for key in sorted(list(missing_keys)): + # print(f" - {key}") + output_file_path = output_path / filename + save_file(tensors_to_save, output_file_path) + + # Free memory by removing saved tensors from the buffer. + for key in tensors_to_save.keys(): + del buffered_tensors[key] + + all_saved_keys.update(keys_for_file) + del files_to_save[filename] + + if buffered_tensors: + print( + f"Warning: {len(buffered_tensors)} tensors were yielded but not saved because their corresponding file shards were incomplete." + ) + + # Final check on whether all original tensors were written. + unsaved_keys = all_expected_keys - all_saved_keys + if not unsaved_keys: + extra_keys = all_yielded_keys - all_expected_keys + if extra_keys: + print( + f"\nSuccess: All tensors from the original checkpoint were written. " + f"({len(extra_keys)} extra tensors from generator were ignored as per strict=False)." + ) + else: + print("\nSuccess: All tensors from the original checkpoint were written.") + else: + print( + f"\nError: {len(unsaved_keys)} tensors from the original checkpoint were not written. See warnings above for details." + ) + + # Create index file for the saved shards. 
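The streaming save above boils down to buffering yielded tensors and flushing a shard as soon as all of its keys are present. A condensed sketch of that core loop, omitting the strict-mode and index handling; the function and variable names are illustrative:

```python
from collections import defaultdict

from safetensors.torch import save_file


def stream_save(generator, key_to_filename, out_dir):
    """Buffer tensors until a shard's key set is complete, then flush it to disk."""
    file_to_keys = defaultdict(set)
    for key, filename in key_to_filename.items():
        file_to_keys[filename].add(key)

    pending = dict(file_to_keys)
    buffered = {}
    for name, tensor in generator:
        buffered[name] = tensor
        for filename in list(pending):
            keys = pending[filename]
            if keys.issubset(buffered):
                # Flush the complete shard and release its tensors from the buffer.
                save_file({k: buffered.pop(k) for k in keys}, f"{out_dir}/{filename}")
                del pending[filename]
    return pending  # shards that never became complete
```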
+ original_index_file = self.path / "model.safetensors.index.json" + if original_index_file.exists(): + with open(original_index_file, "r") as f: + original_index_data = json.load(f) + + new_weight_map = {key: key_to_filename_map[key] for key in all_saved_keys} + + new_index_data = { + "metadata": original_index_data.get("metadata", {}), + "weight_map": new_weight_map, + } + + output_index_file = output_path / "model.safetensors.index.json" + if new_weight_map: + with open(output_index_file, "w") as f: + json.dump(new_index_data, f, indent=4) + + def _get_key_to_filename_map(self) -> Optional[Dict[str, str]]: + return self._cached_get_key_to_filename_map(self.path) + + @staticmethod + @lru_cache(maxsize=None) + def _cached_get_key_to_filename_map( + model_name_or_path: Union[str, Path] + ) -> Optional[Dict[str, str]]: + """Static, cached method to get the key-to-filename map.""" + index_file = Path(model_name_or_path) / "model.safetensors.index.json" + if index_file.exists(): + with open(index_file, "r") as f: + try: + index_data = json.load(f) + if "weight_map" in index_data and isinstance(index_data["weight_map"], dict): + return index_data["weight_map"] + except json.JSONDecodeError: + return None + return None diff --git a/flagscale/train/bridge/models/hf_pretrained/vlm.py b/flagscale/train/bridge/models/hf_pretrained/vlm.py new file mode 100644 index 0000000000..76cc99c5aa --- /dev/null +++ b/flagscale/train/bridge/models/hf_pretrained/vlm.py @@ -0,0 +1,603 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +from pathlib import Path +from typing import Any, Dict, Generic, List, Optional, TypeVar, Union + +import torch + +from transformers import ( + AutoConfig, + AutoImageProcessor, + AutoModel, + AutoProcessor, + AutoTokenizer, + GenerationConfig, + PreTrainedModel, + PreTrainedTokenizer, + ProcessorMixin, +) +from transformers.generation.utils import GenerateOutput + +from flagscale.train.bridge.models.hf_pretrained.base import PreTrainedBase +from flagscale.train.bridge.models.hf_pretrained.safe_config_loader import ( + safe_load_config_with_retry, +) + +# Type variable for generic model type +VLMType = TypeVar("VLMType", bound=PreTrainedModel) + + +class PreTrainedVLM(PreTrainedBase, Generic[VLMType]): + """ + A generic class for Pretrained Vision-Language Models with lazy loading. + + Allows type-safe access to specific VLM implementations like LlavaForConditionalGeneration. + + Examples: + Basic usage with image and text: + >>> from flagscale.train.bridge.models.hf_pretrained.vlm import PreTrainedVLM + >>> from PIL import Image + >>> + >>> # Create instance - no model loading happens yet + >>> vlm = PreTrainedVLM.from_pretrained("llava-hf/llava-1.5-7b-hf") + >>> + >>> # Load an image + >>> image = Image.open("cat.jpg") + >>> + >>> # Process image and text together - processor and model load here + >>> inputs = vlm.process_images_and_text( + ... images=image, + ... text="What do you see in this image?" + ... ) + >>> + >>> # Generate response + >>> outputs = vlm.generate(**inputs, max_new_tokens=100) + >>> print(vlm.decode(outputs[0], skip_special_tokens=True)) + + Batch processing with multiple images: + >>> # Process multiple images with questions + >>> images = [Image.open(f"image_{i}.jpg") for i in range(3)] + >>> questions = [ + ... "What is the main object in this image?", + ... "Describe the scene", + ... "What colors do you see?" + ... 
] + >>> + >>> # Process batch + >>> inputs = vlm.process_images_and_text( + ... images=images, + ... text=questions, + ... padding=True + ... ) + >>> + >>> # Generate responses + >>> outputs = vlm.generate(**inputs, max_new_tokens=50) + >>> for i, output in enumerate(outputs): + ... print(f"Image {i+1}: {vlm.decode(output, skip_special_tokens=True)}") + + Using specific VLM types with type hints: + >>> from transformers import LlavaForConditionalGeneration + >>> from flagscale.train.bridge.models.hf_pretrained.vlm import PreTrainedVLM + >>> + >>> # Type-safe access to Llava-specific features + >>> llava: PreTrainedVLM[LlavaForConditionalGeneration] = PreTrainedVLM.from_pretrained( + ... "llava-hf/llava-1.5-7b-hf", + ... torch_dtype=torch.float16, + ... device="cuda" + ... ) + >>> + >>> # Access model-specific attributes + >>> vision_tower = llava.model.vision_tower # Type-safe access + + Text-only generation (for multimodal models that support it): + >>> # Some VLMs can also work with text-only inputs + >>> text_inputs = vlm.encode_text("Explain what a neural network is.") + >>> outputs = vlm.generate(**text_inputs, max_length=100) + >>> print(vlm.decode(outputs[0], skip_special_tokens=True)) + + Custom preprocessing and generation: + >>> # Load with custom settings + >>> vlm = PreTrainedVLM.from_pretrained( + ... "Qwen/Qwen-VL-Chat", + ... trust_remote_code=True, + ... device_map="auto", + ... load_in_4bit=True + ... ) + >>> + >>> # Custom generation config + >>> from transformers import GenerationConfig + >>> vlm.generation_config = GenerationConfig( + ... max_new_tokens=200, + ... temperature=0.8, + ... top_p=0.95, + ... do_sample=True + ... ) + >>> + >>> # Process with custom parameters + >>> inputs = vlm.process_images_and_text( + ... images=image, + ... text="\\nDescribe this image in detail.", + ... max_length=512 + ... ) + + Manual component setup: + >>> # Create empty instance + >>> vlm = PreTrainedVLM() + >>> + >>> # Load components separately + >>> from transformers import AutoProcessor, AutoModel + >>> vlm.processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base") + >>> vlm.model = AutoModel.from_pretrained("microsoft/Florence-2-base") + >>> + >>> # Use for various vision tasks + >>> task_prompt = "" # Object detection task + >>> inputs = vlm.process_images_and_text(images=image, text=task_prompt) + >>> outputs = vlm.generate(**inputs) + + Conversational VLM usage: + >>> # Multi-turn conversation with images + >>> conversation = [] + >>> + >>> # First turn + >>> image1 = Image.open("chart.png") + >>> inputs = vlm.process_images_and_text( + ... images=image1, + ... text="What type of chart is this?" + ... ) + >>> response = vlm.generate(**inputs) + >>> conversation.append(("user", "What type of chart is this?")) + >>> conversation.append(("assistant", vlm.decode(response[0]))) + >>> + >>> # Follow-up question + >>> follow_up = "What is the highest value shown?" 
+ >>> # Format conversation history + new question + >>> full_prompt = format_conversation(conversation) + f"\\nUser: {follow_up}" + >>> inputs = vlm.process_images_and_text(images=image1, text=full_prompt) + >>> response = vlm.generate(**inputs) + """ + + ARTIFACTS = ["processor", "tokenizer", "image_processor"] + OPTIONAL_ARTIFACTS = ["generation_config"] + + def __init__( + self, + model_name_or_path: Optional[Union[str, Path]] = None, + device: Optional[Union[str, torch.device]] = None, + torch_dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + **kwargs, + ): + """ + Initialize a Pretrained VLM with lazy loading. + + Args: + model_name_or_path: HuggingFace model identifier or local path + device: Device to load model on (e.g., 'cuda', 'cpu') + torch_dtype: Data type to load model in (e.g., torch.float16) + trust_remote_code: Whether to trust remote code when loading + **kwargs: Additional arguments passed to component loaders + """ + self._model_name_or_path = model_name_or_path + self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") + self.torch_dtype = torch_dtype + self.trust_remote_code = trust_remote_code + super().__init__(**kwargs) + + def _load_model(self) -> VLMType: + """Lazy load and return the model.""" + if self.model_name_or_path is None: + raise ValueError("model_name_or_path must be provided to load model") + + model_kwargs = {"trust_remote_code": self.trust_remote_code, **self.init_kwargs} + + if self.torch_dtype is not None: + model_kwargs["torch_dtype"] = self.torch_dtype + + # Use provided config if already loaded + config = getattr(self, "_config", None) + if config is not None: + model_kwargs["config"] = config + + # Try AutoModel first for VLMs + model = AutoModel.from_pretrained(self.model_name_or_path, **model_kwargs) + + # Move to device + model = model.to(self.device) + + # Set generation config if available + generation_config = getattr(self, "_generation_config", None) + if generation_config is not None and hasattr(model, "generation_config"): + model.generation_config = generation_config + return model + + def _load_config(self) -> AutoConfig: + """Lazy load and return the model config with thread-safety protection.""" + if self.model_name_or_path is None: + raise ValueError("model_name_or_path must be provided to load config") + + return safe_load_config_with_retry( + self.model_name_or_path, trust_remote_code=self.trust_remote_code, **self.init_kwargs + ) + + def _load_processor(self) -> ProcessorMixin: + """Lazy load and return the processor.""" + if self.model_name_or_path is None: + raise ValueError("model_name_or_path must be provided to load processor") + + try: + return AutoProcessor.from_pretrained( + self.model_name_or_path, + trust_remote_code=self.trust_remote_code, + **self.init_kwargs, + ) + except Exception: + # Some VLMs might not have a processor, fall back to manual loading + raise ValueError( + f"Could not load processor for {self.model_name_or_path}. " + "This model might require manual processor setup." + ) + + def _load_tokenizer(self) -> Optional[PreTrainedTokenizer]: + """ + Lazy load and return the tokenizer. + For VLMs, the tokenizer might be included in the processor. 
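The loaders above follow one pattern: load a component on first access and prefer anything already bundled in the processor. A stripped-down sketch of that lazy-load-with-fallback idea, using the standard `transformers` auto classes; the class and model name are illustrative:

```python
from transformers import AutoProcessor, AutoTokenizer


class LazyComponents:
    """Minimal sketch of the lazy-load-with-fallback pattern used by PreTrainedVLM."""

    def __init__(self, model_name: str):
        self.model_name = model_name  # e.g. "llava-hf/llava-1.5-7b-hf" (illustrative)

    @property
    def processor(self):
        if not hasattr(self, "_processor"):
            self._processor = AutoProcessor.from_pretrained(self.model_name)
        return self._processor

    @property
    def tokenizer(self):
        if not hasattr(self, "_tokenizer"):
            # Prefer the tokenizer bundled inside the processor; fall back to AutoTokenizer.
            proc = getattr(self, "_processor", None)
            if proc is not None and hasattr(proc, "tokenizer"):
                self._tokenizer = proc.tokenizer
            else:
                self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        return self._tokenizer
```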
+ """ + # Check if tokenizer is available through processor first + processor = getattr(self, "_processor", None) + if processor is not None and hasattr(processor, "tokenizer"): + return processor.tokenizer + + # Try to load tokenizer separately + if self.model_name_or_path is not None: + try: + tokenizer = AutoTokenizer.from_pretrained( + self.model_name_or_path, + trust_remote_code=self.trust_remote_code, + **self.init_kwargs, + ) + + # Set padding token if not present + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + except Exception: + # Some VLMs include tokenizer only in processor + pass + return None + + def _load_image_processor(self) -> Optional[Any]: + """ + Lazy load and return the image processor. + For VLMs, the image processor might be included in the processor. + """ + # Check if image processor is available through processor first + processor = getattr(self, "_processor", None) + if processor is not None and hasattr(processor, "image_processor"): + return processor.image_processor + + # Try to load image processor separately + if self.model_name_or_path is not None: + try: + return AutoImageProcessor.from_pretrained( + self.model_name_or_path, + trust_remote_code=self.trust_remote_code, + **self.init_kwargs, + ) + except Exception: + # Some VLMs include image processor only in processor + pass + return None + + def _load_generation_config(self) -> Optional[GenerationConfig]: + """Lazy load and return the generation config.""" + if self.model_name_or_path is not None: + try: + return GenerationConfig.from_pretrained( + self.model_name_or_path, + trust_remote_code=self.trust_remote_code, + **self.init_kwargs, + ) + except Exception: + # Not all models have generation configs + pass + return None + + @property + def model_name_or_path(self) -> Optional[Union[str, Path]]: + """Return the model name or path.""" + return self._model_name_or_path + + @property + def model(self) -> VLMType: + """Lazy load and return the underlying model.""" + if not hasattr(self, "_model"): + self._model = self._load_model() + else: + # Ensure model is on the right device when accessed + if hasattr(self._model, "device") and hasattr(self._model.device, "type"): + current_device = str(self._model.device) + target_device = str(self.device) + if current_device != target_device: + self._model = self._model.to(self.device) + return self._model + + @model.setter + def model(self, value: VLMType): + """Set the model manually.""" + self._model = value + + @property + def processor(self) -> ProcessorMixin: + """Lazy load and return the processor.""" + if not hasattr(self, "_processor"): + self._processor = self._load_processor() + return self._processor + + @processor.setter + def processor(self, value: ProcessorMixin): + """Set the processor manually.""" + self._processor = value + + @property + def tokenizer(self) -> Optional[PreTrainedTokenizer]: + """Lazy load and return the tokenizer.""" + if not hasattr(self, "_tokenizer"): + self._tokenizer = self._load_tokenizer() + return self._tokenizer + + @tokenizer.setter + def tokenizer(self, value: PreTrainedTokenizer): + """Set the tokenizer manually.""" + self._tokenizer = value + + @property + def image_processor(self) -> Optional[Any]: + """Lazy load and return the image processor.""" + if not hasattr(self, "_image_processor"): + self._image_processor = self._load_image_processor() + return self._image_processor + + @image_processor.setter + def image_processor(self, value: Any): + """Set the image 
processor manually.""" + self._image_processor = value + + @property + def generation_config(self) -> Optional[GenerationConfig]: + """Lazy load and return the generation config.""" + if not hasattr(self, "_generation_config"): + self._generation_config = self._load_generation_config() + return self._generation_config + + @generation_config.setter + def generation_config(self, value: GenerationConfig): + """Set the generation config manually.""" + self._generation_config = value + # Update model's generation config if model is loaded + if ( + hasattr(self, "_model") + and self._model is not None + and hasattr(self._model, "generation_config") + ): + self._model.generation_config = value + + @property + def kwargs(self) -> Dict[str, Any]: + """Additional initialization kwargs.""" + return self.init_kwargs + + @classmethod + def from_pretrained( + cls, + model_name_or_path: Union[str, Path], + device: Optional[Union[str, torch.device]] = None, + torch_dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + **kwargs, + ) -> "PreTrainedVLM[VLMType]": + """ + Create a PreTrainedVLM instance for lazy loading. + + Args: + model_name_or_path: HuggingFace model identifier or local path + device: Device to load model on + torch_dtype: Data type to load model in + trust_remote_code: Whether to trust remote code + **kwargs: Additional arguments for from_pretrained methods + + Returns: + PreTrainedVLM instance configured for lazy loading + """ + return cls( + model_name_or_path=model_name_or_path, + device=device, + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, + **kwargs, + ) + + def generate(self, **kwargs) -> Union[torch.LongTensor, GenerateOutput]: + """ + Generate sequences using the model. + + Args: + **kwargs: Arguments for the generate method + + Returns: + Generated sequences + """ + return self.model.generate(**kwargs) + + def __call__(self, *args, **kwargs): + """Forward pass through the model.""" + return self.model(*args, **kwargs) + + def encode_text(self, text: Union[str, List[str]], **kwargs) -> Dict[str, torch.Tensor]: + """ + Encode text input using the tokenizer. + + Args: + text: Input text or list of texts + **kwargs: Additional tokenizer arguments + + Returns: + Encoded inputs ready for the model + """ + if self.tokenizer is None: + raise ValueError( + "No tokenizer available. Set tokenizer manually or ensure model has one." + ) + return self.tokenizer(text, return_tensors="pt", **kwargs).to(self.device) + + def decode(self, token_ids: torch.Tensor, **kwargs) -> str: + """ + Decode token IDs to text. + + Args: + token_ids: Token IDs to decode + **kwargs: Additional decoding arguments + + Returns: + Decoded text + """ + if self.tokenizer is None: + raise ValueError( + "No tokenizer available. Set tokenizer manually or ensure model has one." + ) + return self.tokenizer.decode(token_ids, **kwargs) + + def process_images_and_text( + self, images: Optional[Any] = None, text: Optional[Union[str, List[str]]] = None, **kwargs + ) -> Dict[str, torch.Tensor]: + """ + Process images and text together using the processor. 
+ + Args: + images: Input images + text: Input text + **kwargs: Additional processor arguments + + Returns: + Processed inputs ready for the model + """ + inputs = self.processor(images=images, text=text, return_tensors="pt", **kwargs) + # Move all tensors in the dict to the device + if isinstance(inputs, dict): + for key, value in inputs.items(): + if hasattr(value, "to"): + inputs[key] = value.to(self.device) + return inputs + + def save_pretrained(self, save_directory: Union[str, Path]): + """ + Save the model and all components to a directory. + + Args: + save_directory: Directory to save to + """ + save_path = Path(save_directory) + save_path.mkdir(parents=True, exist_ok=True) + + # Save model + if hasattr(self, "_model") and self._model is not None: + self._model.save_pretrained(save_path) + + # Save artifacts through base class + self.save_artifacts(save_path) + + def to(self, device: Union[str, torch.device]) -> "PreTrainedVLM[VLMType]": + """ + Move model to a device. + + Args: + device: Target device + + Returns: + Self for chaining + """ + self.device = device + if hasattr(self, "_model") and self._model is not None: + self._model = self._model.to(device) + return self + + def half(self) -> "PreTrainedVLM[VLMType]": + """ + Convert model to half precision. + + Returns: + Self for chaining + """ + if hasattr(self, "_model") and self._model is not None: + self._model = self._model.half() + self.torch_dtype = torch.float16 + return self + + def float(self) -> "PreTrainedVLM[VLMType]": + """ + Convert model to full precision. + + Returns: + Self for chaining + """ + if hasattr(self, "_model") and self._model is not None: + self._model = self._model.float() + self.torch_dtype = torch.float32 + return self + + @property + def dtype(self) -> Optional[torch.dtype]: + """Return the dtype of the model.""" + if hasattr(self, "_model") and self._model is not None: + return next(self._model.parameters()).dtype + return self.torch_dtype + + def num_parameters(self, only_trainable: bool = False) -> int: + """ + Get the number of parameters in the model. 
+ + Args: + only_trainable: Whether to count only trainable parameters + + Returns: + Number of parameters + """ + if not hasattr(self, "_model") or self._model is None: + return 0 + + if only_trainable: + return sum(p.numel() for p in self._model.parameters() if p.requires_grad) + return sum(p.numel() for p in self._model.parameters()) + + def __repr__(self) -> str: + """String representation.""" + parts = [f"{self.__class__.__name__}("] + + if self._model_name_or_path: + parts.append(f" model_name_or_path='{self._model_name_or_path}',") + + parts.append(f" device='{self.device}',") + + if self.torch_dtype: + parts.append(f" torch_dtype={self.torch_dtype},") + + if self.trust_remote_code: + parts.append(f" trust_remote_code={self.trust_remote_code},") + + # Show loaded components + loaded = [] + if hasattr(self, "_model") and self._model is not None: + loaded.append("model") + if hasattr(self, "_processor") and self._processor is not None: + loaded.append("processor") + if hasattr(self, "_tokenizer") and self._tokenizer is not None: + loaded.append("tokenizer") + if hasattr(self, "_config") and self._config is not None: + loaded.append("config") + + if loaded: + parts.append(f" loaded_components={loaded},") + + parts.append(")") + return "\n".join(parts) diff --git a/flagscale/train/bridge/models/model_provider.py b/flagscale/train/bridge/models/model_provider.py new file mode 100644 index 0000000000..7df1c60b04 --- /dev/null +++ b/flagscale/train/bridge/models/model_provider.py @@ -0,0 +1,710 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import abc +import os + +from pathlib import Path +from typing import Callable, Generic, TypedDict, TypeVar, Union + +try: + from typing import Unpack +except ImportError: + try: + from typing_extensions import Unpack + except ImportError: + from unittest.mock import MagicMock + + Unpack = MagicMock() + + +from typing import Callable + +import torch + +from megatron.core import parallel_state, tensor_parallel +from megatron.core.distributed import ( + DistributedDataParallel, + DistributedDataParallelConfig, + FullyShardedDataParallel, + TorchFullyShardedDataParallel, +) +from megatron.core.enums import ModelType +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.transformer.module import Float16Module, MegatronModule +from megatron.core.utils import get_model_config + +from flagscale.train.bridge.models.config import from_hf_pretrained, save_hf_pretrained +from flagscale.train.bridge.utils.common_utils import get_local_rank_preinit +from flagscale.train.bridge.utils.instantiate_utils import InstantiationMode + +try: + from megatron.core.fp8_utils import correct_amax_history_if_needed +except ImportError: + correct_amax_history_if_needed = None + + +ModelT = TypeVar("ModelT", bound=MegatronModule) + + +class ModelProviderMixin(abc.ABC, Generic[ModelT]): + """A mixin that implements the ModelProvider pattern for Megatron Bridge. + + The ModelProvider pattern solves ecosystem fragmentation by providing a standardized + way to instantiate models. This mixin provides a consistent `provide_distributed_model()` method + that handles the complexity of distributed training setup, along with HuggingFace-inspired + `.from_hf_pretrained()` and `.save_hf_pretrained()` for interoperability. + + For advanced customization, multiple hooks can be registered via `register_pre_wrap_hook` + and `register_post_wrap_hook`. 
These hooks allow modifying the model before and after + it's wrapped for distributed training (e.g., freezing layers, logging). The composed + hooks can be accessed via the `pre_wrap_hook` and `post_wrap_hook` properties. + + Subclasses must implement the `provide` method to define the model architecture. + """ + + CONFIG_NAME = "mhub_model.json" + DEFAULT_CONFIG_FORMAT = "json" + + @abc.abstractmethod + def provide( + self, + pre_process: bool | None = None, + post_process: bool | None = None, + vp_stage: int | None = None, + ) -> ModelT: + """Abstract method to provide the model instance. + + Subclasses must implement this method to return the specific Megatron model + (e.g., `GPTModel`) with its configuration. This method is called by `get_model` + to obtain the base model before it is wrapped for distributed training. + + Args: + pre_process (bool, optional): Whether to include the embedding layer (used with pipeline parallelism). + post_process (bool, optional): Whether to include the output layer (used with pipeline parallelism). + vp_stage (int, optional): The virtual pipeline stage of the model. + + Returns: + ModelT: The Megatron model instance. + """ + pass + + def provide_distributed_model( + self, + ddp_config: DistributedDataParallelConfig | None = None, + model_type=ModelType.encoder_or_decoder, + overlap_param_gather_with_optimizer_step: bool = False, + fp16: bool | None = None, + bf16: bool | None = None, + use_megatron_fsdp: bool = False, + use_torch_fsdp2: bool = False, + wrap_with_ddp: bool = True, + data_parallel_random_init: bool = True, + use_cpu_initialization: None | bool = False, + init_model_with_meta_device: bool | None = None, + pre_wrap_hook: ( + Union[ + Callable[[list[MegatronModule]], list[MegatronModule]], + list[Callable[[list[MegatronModule]], list[MegatronModule]]], + ] + | None + ) = None, + post_wrap_hook: Callable[[list[MegatronModule]], list[MegatronModule]] | None = None, + ) -> list[ModelT]: + """Instantiate and wrap the model for distributed training. + + This method retrieves the model from `provide` and sets up the distributed + environment, including data-parallel and model-parallel configurations. + It's the primary entry point for creating a model that's ready for use + in the Megatron ecosystem. + + Args: + ddp_config: Configuration for distributed data parallel. + model_type: Type of model (encoder, decoder, or both). + overlap_param_gather_with_optimizer_step: Whether to overlap param gathering. + fp16: Override FP16 setting. + bf16: Override BF16 setting. + use_megatron_fsdp: Use Megatron's Fully Sharded Data Parallel + use_torch_fsdp2: Use PyTorch FSDP2 instead of custom DDP. + wrap_with_ddp: Whether to wrap model with DDP. + data_parallel_random_init: Initialize parameters randomly across data parallel ranks. + use_cpu_initialization: Initialize model on CPU. + init_model_with_meta_device: Initialize model on meta device. + pre_wrap_hook: A single callable or list of callables to modify the model before it's wrapped. + If provided, this will override all hooks registered via `register_pre_wrap_hook`. + If a list is provided, hooks will be executed in order. + post_wrap_hook: A single callable to modify the model after it's wrapped. If provided, + this will override all hooks registered via `register_post_wrap_hook`. + + Returns: + A list containing the wrapped model instance. 
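Concrete providers only need to implement `provide()`; parallelism sizes are read off attributes by `initialize_model_parallel`. A toy sketch of that subclassing contract, in which a bare `nn.Linear` stands in for a real Megatron module such as `GPTModel`:

```python
import torch

from flagscale.train.bridge.models.model_provider import ModelProviderMixin


class ToyProvider(ModelProviderMixin):
    """Hypothetical provider sketch; a real one returns a configured MegatronModule
    (e.g. a GPTModel built from a TransformerConfig) instead of a bare nn.Linear."""

    tensor_model_parallel_size = 1
    pipeline_model_parallel_size = 1

    def provide(self, pre_process=None, post_process=None, vp_stage=None):
        # pre_process/post_process control embedding and output-layer placement under
        # pipeline parallelism; they are ignored by this toy stand-in.
        return torch.nn.Linear(16, 16)


provider = ToyProvider()
module = provider.provide()  # direct, single-process instantiation
# model_list = provider.provide_distributed_model(ddp_config=my_ddp_config, wrap_with_ddp=True)
```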
+ """ + if wrap_with_ddp and not ddp_config: + raise ValueError("ddp_config is required when wrap_with_ddp is True") + + if not torch.distributed.is_initialized(): + os.environ["RANK"] = os.environ.get("RANK", "0") + os.environ["WORLD_SIZE"] = os.environ.get("WORLD_SIZE", "1") + os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "localhost") + os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "12355") + torch.cuda.set_device(get_local_rank_preinit()) + torch.distributed.init_process_group("nccl") + + if not parallel_state.is_initialized(): + print("Model parallel not initialized, initializing...") + self.initialize_model_parallel(seed=0) + + # Convert list of hooks to a single composed callable + if isinstance(pre_wrap_hook, list): + + def composed_pre_wrap_hook(model: list[MegatronModule]) -> list[MegatronModule]: + for hook in pre_wrap_hook: + model = hook(model) + return model + + final_pre_wrap_hook = composed_pre_wrap_hook + else: + final_pre_wrap_hook = pre_wrap_hook or self.pre_wrap_hook + final_post_wrap_hook = post_wrap_hook or self.post_wrap_hook + + model = get_model( + self, + ddp_config=ddp_config, + model_type=model_type, + overlap_param_gather_with_optimizer_step=overlap_param_gather_with_optimizer_step, + fp16=fp16, + bf16=bf16, + use_megatron_fsdp=use_megatron_fsdp, + use_torch_fsdp2=use_torch_fsdp2, + wrap_with_ddp=wrap_with_ddp, + data_parallel_random_init=data_parallel_random_init, + use_cpu_initialization=use_cpu_initialization, + init_model_with_meta_device=init_model_with_meta_device, + pre_wrap_hook=final_pre_wrap_hook, + ) + + if final_post_wrap_hook: + _model = final_post_wrap_hook(model) + if _model is not None: + model = _model + + return model + + def initialize_model_parallel( + self, seed: int | None = None, seed_kwargs: dict | None = None, **model_parallel_kwargs + ) -> None: + """Initializes model parallelism and sets the random seed. + + This is a convenience method that sets up tensor, pipeline, and other + forms of model parallelism based on the attributes of the provider instance. + + Args: + seed: The random seed for model parallel RNG. + seed_kwargs: Additional arguments for `model_parallel_cuda_manual_seed`. + **model_parallel_kwargs: Additional arguments for `parallel_state.initialize_model_parallel`. + """ + if not torch.distributed.is_initialized(): + torch.cuda.set_device(get_local_rank_preinit()) + torch.distributed.init_process_group("nccl") + + parallel_state.initialize_model_parallel( + tensor_model_parallel_size=getattr(self, "tensor_model_parallel_size", 1), + pipeline_model_parallel_size=getattr(self, "pipeline_model_parallel_size", 1), + virtual_pipeline_model_parallel_size=getattr( + self, "virtual_pipeline_model_parallel_size", None + ), + context_parallel_size=getattr(self, "context_parallel_size", 1) or 1, + expert_model_parallel_size=getattr(self, "expert_model_parallel_size", 1) or 1, + expert_tensor_parallel_size=getattr(self, "expert_tensor_parallel_size", None), + **model_parallel_kwargs, + ) + if seed is not None: + model_parallel_cuda_manual_seed(seed, **(seed_kwargs or {})) + + @property + def meta_model(self) -> list[ModelT]: + """Returns the model instantiated on the meta device for inspection. + + This is useful for examining the model architecture without allocating + GPU memory. 
+ """ + return self(wrap_with_ddp=False, init_model_with_meta_device=True) + + @property + def pre_wrap_hook(self) -> Callable[[list[MegatronModule]], list[MegatronModule]] | None: + """A composed callable of all registered pre-wrap hooks. + + This read-only property returns a single function that executes all registered + pre-wrap hooks in order. The hook is applied before the model is passed to the DDP + wrapper and can be used for tasks like freezing layers or altering model structure. + + Use `register_pre_wrap_hook` to add a hook to the execution chain. + + Returns: + A callable that executes all registered pre-wrap hooks in order, or None if no + hooks are registered. + """ + if not hasattr(self, "_pre_wrap_hooks") or not self._pre_wrap_hooks: + return None + + def composed_hook(model: list[MegatronModule]) -> list[MegatronModule]: + for hook in self._pre_wrap_hooks: + model = hook(model) + return model + + return composed_hook + + def register_pre_wrap_hook( + self, hook: Callable[[list[MegatronModule]], list[MegatronModule]], prepend: bool = False + ) -> None: + """Registers a hook to be executed before the model is wrapped. + + The hook should be a callable that accepts a list of `MegatronModule` instances + and returns a (potentially modified) list of `MegatronModule` instances. + + Args: + hook: The hook to register. + prepend: If True, the hook is inserted at the beginning of the execution + chain. Otherwise, it is appended to the end. + """ + if not hasattr(self, "_pre_wrap_hooks"): + self._pre_wrap_hooks = [] + if prepend: + self._pre_wrap_hooks.insert(0, hook) + else: + self._pre_wrap_hooks.append(hook) + + @property + def post_wrap_hook(self) -> Callable[[list[MegatronModule]], list[MegatronModule]] | None: + """A composed callable of all registered post-wrap hooks. + + This read-only property returns a single function that executes all registered + post-wrap hooks in order. The hook is applied after the model has been wrapped by + DDP and is useful for tasks like logging or attaching custom attributes. + + Use `register_post_wrap_hook` to add a hook to the execution chain. + + Returns: + A callable that executes all registered post-wrap hooks in order, or None if no + hooks are registered. + """ + if not hasattr(self, "_post_wrap_hooks") or not self._post_wrap_hooks: + return None + + def composed_hook(model: list[MegatronModule]) -> list[MegatronModule]: + for hook in self._post_wrap_hooks: + model = hook(model) + return model + + return composed_hook + + def register_post_wrap_hook( + self, hook: Callable[[list[MegatronModule]], list[MegatronModule]], prepend: bool = False + ) -> None: + """Registers a hook to be executed after the model is wrapped. + + The hook should be a callable that accepts a list of `MegatronModule` instances + and returns a (potentially modified) list of `MegatronModule` instances. + + Args: + hook: The hook to register. + prepend: If True, the hook is inserted at the beginning of the execution + chain. Otherwise, it is appended to the end. + """ + if not hasattr(self, "_post_wrap_hooks"): + self._post_wrap_hooks = [] + if prepend: + self._post_wrap_hooks.insert(0, hook) + else: + self._post_wrap_hooks.append(hook) + + @classmethod + def from_hf_pretrained( + cls, + pretrained_model_name_or_path: str | Path, + trust_remote_code: bool = False, + mode: InstantiationMode | None = None, + config_name: str | None = None, + **kwargs, + ): + """Load a pretrained model configuration from a directory or HuggingFace Hub. 
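Hooks are plain callables that take and return the list of model chunks, so freezing or logging logic can be registered without touching the provider class. A sketch of one pre-wrap and one post-wrap hook, assuming a provider instance as in the earlier sketch; the parameter-name filter is illustrative:

```python
from megatron.core.transformer.module import MegatronModule


def freeze_embeddings(model: list[MegatronModule]) -> list[MegatronModule]:
    """Example pre-wrap hook: freeze any parameter whose name mentions 'embedding'."""
    for chunk in model:
        for name, param in chunk.named_parameters():
            if "embedding" in name:
                param.requires_grad = False
    return model


def log_wrapped(model: list[MegatronModule]) -> list[MegatronModule]:
    """Example post-wrap hook: report how many module chunks were produced."""
    print(f"wrapped model has {len(model)} chunk(s)")
    return model


# provider.register_pre_wrap_hook(freeze_embeddings)
# provider.register_post_wrap_hook(log_wrapped)
# model = provider.provide_distributed_model(ddp_config=ddp_config)
```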
+ + This method provides a HuggingFace-inspired interface for loading model + configurations, enabling interoperability. + + Args: + pretrained_model_name_or_path: The path to a local directory or a + HuggingFace model identifier. + trust_remote_code: Whether to trust remote code when loading. + mode: The instantiation mode (e.g., `LENIENT`). + config_name: The name of the configuration file (without extension). + **kwargs: Additional keyword arguments for `from_hf_pretrained`. + + Returns: + An instance of the model provider with the loaded configuration. + """ + if config_name is None: + config_name = cls.CONFIG_NAME.rsplit(".", 1)[0] + if mode is None: + mode = InstantiationMode.LENIENT + return from_hf_pretrained( + cls, + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + mode=mode, + config_name=config_name, + **kwargs, + ) + + def save_hf_pretrained( + self, + save_directory: str | Path, + config_format: str | None = None, + config_name: str | None = None, + **kwargs, + ): + """Save the model configuration to a directory. + + This method provides a HuggingFace-inspired interface for saving model + configurations, enabling interoperability. + + Args: + save_directory: The directory where the configuration will be saved. + config_format: The format for the configuration file (e.g., `json` or `yaml`). + config_name: The name of the configuration file (without extension). + **kwargs: Additional keyword arguments for `save_hf_pretrained`. + """ + if config_name is None: + config_name = self.CONFIG_NAME.rsplit(".", 1)[0] + if config_format is None: + config_format = self.DEFAULT_CONFIG_FORMAT + return save_hf_pretrained( + self, save_directory, config_format=config_format, config_name=config_name, **kwargs + ) + + +class GetModelKwargs(TypedDict, total=False): + """Keyword arguments for the `provide_distributed_model` method. + + Attributes: + ddp_config: Configuration for distributed data parallel. + model_type: Type of model (encoder, decoder, or both). + overlap_param_gather_with_optimizer_step: Whether to overlap param gathering. + fp16: Override FP16 setting. + bf16: Override BF16 setting. + use_megatron_fsdp: Use Megatron's Fully Sharded Data Parallel + use_torch_fsdp2: Use PyTorch FSDP2 instead of custom DDP. + wrap_with_ddp: Whether to wrap model with DDP. + data_parallel_random_init: Initialize parameters randomly across data parallel ranks. + use_cpu_initialization: Initialize model on CPU. + init_model_with_meta_device: Initialize model on meta device. + pre_wrap_hook: A single callable or list of callables that overrides all registered pre-wrap hooks. + post_wrap_hook: A single callable that overrides all registered post-wrap hooks. + """ + + ddp_config: DistributedDataParallelConfig | None + model_type: ModelType + overlap_param_gather_with_optimizer_step: bool + fp16: bool | None + bf16: bool | None + use_megatron_fsdp: bool + use_torch_fsdp2: bool + wrap_with_ddp: bool + data_parallel_random_init: bool + use_cpu_initialization: bool | None + init_model_with_meta_device: bool | None + pre_wrap_hook: ( + Union[ + Callable[[list[MegatronModule]], list[MegatronModule]], + list[Callable[[list[MegatronModule]], list[MegatronModule]]], + ] + | None + ) + post_wrap_hook: Callable[[list[MegatronModule]], list[MegatronModule]] | None + + +class ModelParallelKwargs(TypedDict, total=False): + """Model-parallel override kwargs. + + Attributes map to `TransformerConfig`/provider fields that control parallelism. + Only provided values are applied as overrides. 
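`save_hf_pretrained` and `from_hf_pretrained` round-trip only the provider configuration, not weights. A short sketch reusing the hypothetical `ToyProvider` from the earlier example, with an illustrative output directory:

```python
from flagscale.train.bridge.utils.instantiate_utils import InstantiationMode

# Persist only the provider *configuration* (not weights).
provider = ToyProvider()
provider.save_hf_pretrained("./exported_provider")  # writes mhub_model.json by default, per CONFIG_NAME

# Later (or on another machine), rebuild an equivalent provider from that directory.
restored = ToyProvider.from_hf_pretrained("./exported_provider", mode=InstantiationMode.LENIENT)
```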
+ """ + + tensor_model_parallel_size: int + pipeline_model_parallel_size: int + context_parallel_size: int + expert_model_parallel_size: int + expert_tensor_parallel_size: int + moe_extended_tp: bool + sequence_parallel: bool + virtual_pipeline_model_parallel_size: int | None + hierarchical_context_parallel_sizes: list[int] | None + + +def get_model( + model_provider: ModelProviderMixin, + ddp_config: DistributedDataParallelConfig, + model_type=ModelType.encoder_or_decoder, + overlap_param_gather_with_optimizer_step: bool = False, + fp16: bool | None = None, + bf16: bool | None = None, + use_megatron_fsdp: bool = False, + use_torch_fsdp2: bool = False, + wrap_with_ddp: bool = True, + data_parallel_random_init: bool = True, + use_cpu_initialization: None | bool = False, + init_model_with_meta_device: bool | None = None, + pre_wrap_hook: ( + Union[ + Callable[[list[MegatronModule]], list[MegatronModule]], + list[Callable[[list[MegatronModule]], list[MegatronModule]]], + ] + | None + ) = None, +) -> list[MegatronModule]: + """Create and configure a model for distributed training. + + This function handles the complete model creation pipeline including: + - Model instantiation with proper pipeline parallel configuration + - GPU memory allocation + - Mixed precision (FP16/BF16) wrapping + - Float8 tensor correction + - Distributed Data Parallel (DDP) wrapping + + Args: + model_provider: ModelProviderMixin instance that creates the model. + Uses the provide() method with optional pre_process(bool), post_process(bool), + vp_stage(int) arguments for pipeline parallelism + ddp_config: Configuration for distributed data parallel training + model_type: Type of model (encoder, decoder, or encoder_and_decoder) + overlap_param_gather_with_optimizer_step: Whether to overlap parameter + gathering with optimizer step for performance optimization + fp16: Enable FP16 mixed precision training. If None, uses model config + bf16: Enable BF16 mixed precision training. If None, uses model config + use_megatron_fsdp: Use Megatron's Fully Sharded Data Parallel + use_torch_fsdp2: Use PyTorch's Fully Sharded Data Parallel v2 + wrap_with_ddp: Whether to wrap the model with DDP + data_parallel_random_init: Whether to use random initialization for + data parallel ranks (vs broadcasting from rank 0) + use_cpu_initialization: Whether to initialize model on CPU to save GPU memory + init_model_with_meta_device: Whether to initialize the model on the meta device + pre_wrap_hook: A callable or list of callables that takes a list of `MegatronModule` + and returns a modified list, or `None` to clear the hook. If a list is provided, + hooks will be executed in order. + + Returns: + list[MegatronModule]: List of model modules. 
Contains multiple modules + when using virtual pipeline parallelism, otherwise a single module + """ + if fp16: + model_provider.fp16 = fp16 + if bf16: + model_provider.bf16 = bf16 + + model_provider.use_cpu_initialization = ( + use_cpu_initialization if use_cpu_initialization else False + ) + if init_model_with_meta_device: + model_provider.init_model_with_meta_device = True + with torch.device("meta"): + model = _create_model(model_provider, model_type) + else: + model = _create_model(model_provider, model_type) + + if pre_wrap_hook: + if isinstance(pre_wrap_hook, list): + # Execute hooks in order + for hook in pre_wrap_hook: + if not callable(hook): + raise RuntimeError("All elements in pre_wrap_hook list must be callable") + _model = hook(model) + if _model is not None: + model = _model + else: + if not callable(pre_wrap_hook): + raise RuntimeError("pre_wrap_hook must be a callable or a list of callables") + _model = pre_wrap_hook(model) + if _model is not None: + model = _model + + # Set tensor model parallel attributes if not set + # In case pre_wrap_hook augmented the model (e.g. adding PEFT adapters) + for model_module in model: + for param in model_module.parameters(): + tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param) + + _print_num_params(model) + + model_config = get_model_config(model[0]) + + # GPU allocation. + # For FSDP2, we don't allocate GPU memory here. We allocate GPU memory + # in the fully_shard function of FSDP2 instead. + if ( + not use_torch_fsdp2 + and not model_config.use_cpu_initialization + and not model_config.init_model_with_meta_device + ): + for model_module in model: + model_module.cuda(torch.cuda.current_device()) + + if model_config.fp16 or model_config.bf16: + model = [Float16Module(model_config, model_module) for model_module in model] + + if correct_amax_history_if_needed is not None: + correct_amax_history_if_needed(model) + + if wrap_with_ddp: + model = _ddp_wrap( + model, + data_parallel_random_init, + ddp_config, + overlap_param_gather_with_optimizer_step, + use_megatron_fsdp=use_megatron_fsdp, + use_torch_fsdp2=use_torch_fsdp2, + ) + + return model + + +def _create_model( + model_provider: ModelProviderMixin, model_type: ModelType +) -> list[MegatronModule]: + """Create model instances with appropriate pipeline parallel configuration. + + Handles virtual pipeline parallelism (VPP) by creating multiple model + instances when needed. Sets pre_process and post_process flags based on + pipeline parallel rank. + + Args: + model_provider: ModelProviderMixin instance that creates the model + model_type: ModelType enum indicating encoder, decoder, or both + + Returns: + list: List of model instances. 
Multiple instances for VPP, otherwise single + """ + + if ( + parallel_state.get_pipeline_model_parallel_world_size() > 1 + and parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None + ): + assert ( + model_type != ModelType.encoder_and_decoder + ), "Interleaved schedule not supported for model with both encoder and decoder" + model = [] + for i in range(parallel_state.get_virtual_pipeline_model_parallel_world_size()): + pre_process = parallel_state.is_pipeline_first_stage(ignore_virtual=False, vp_stage=i) + post_process = parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=i) + this_model = model_provider.provide( + pre_process=pre_process, post_process=post_process, vp_stage=i + ) + this_model.model_type = model_type + model.append(this_model) + else: + pre_process = parallel_state.is_pipeline_first_stage() + post_process = parallel_state.is_pipeline_last_stage() + if model_type == ModelType.encoder_and_decoder: + if parallel_state.get_pipeline_model_parallel_world_size() > 1: + rank = parallel_state.get_pipeline_model_parallel_rank() + first_decoder_rank = parallel_state.get_pipeline_model_parallel_decoder_start() + world_size = parallel_state.get_pipeline_model_parallel_world_size() + pre_process = rank == 0 or rank == first_decoder_rank + post_process = (rank == (first_decoder_rank - 1)) or (rank == (world_size - 1)) + model = model_provider.provide() + else: + model = model_provider.provide(pre_process=pre_process, post_process=post_process) + model.model_type = model_type + + if not isinstance(model, list): + model = [model] + + # Set tensor model parallel attributes if not set. + # Only parameters that are already tensor model parallel have these + # attributes set for them. We should make sure the default attributes + # are set for all params so the optimizer can use them. + for model_module in model: + for param in model_module.parameters(): + tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param) + + return model + + +def _ddp_wrap( + model: list[MegatronModule], + data_parallel_random_init: bool, + ddp_config: DistributedDataParallelConfig, + overlap_param_gather_with_optimizer_step: bool, + use_megatron_fsdp: bool = False, + use_torch_fsdp2: bool = False, +) -> list[MegatronModule]: + """Wrap model with Distributed Data Parallel (DDP) or Fully Sharded Data Parallel (FSDP). + + Args: + model: List of model modules to wrap + use_torch_fsdp2: Whether to use PyTorch FSDP v2 instead of DDP + data_parallel_random_init: Whether to broadcast parameters from rank 0 + ddp_config: Configuration for distributed data parallel + overlap_param_gather_with_optimizer_step: Whether to disable bucketing + for overlapping parameter gathering with optimizer step + + Returns: + list[MegatronModule]: List of DDP/FSDP wrapped model modules + """ + if use_megatron_fsdp: + DP = FullyShardedDataParallel + if use_torch_fsdp2: + raise ValueError( + "Using use_megatron_fsdp and use_torch_fsdp2 at the same time is not supported." + ) + elif use_torch_fsdp2: + DP = TorchFullyShardedDataParallel + else: + DP = DistributedDataParallel + + model = [ + DP( + config=get_model_config(model_chunk), + ddp_config=ddp_config, + module=model_chunk, + # Turn off bucketing for model_chunk 2 onwards, since communication for these + # model chunks is overlapped with compute anyway. 
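+            # Bucketing is likewise disabled when parameter all-gather is overlapped with
+            # the optimizer step (overlap_param_gather_with_optimizer_step below).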
+ disable_bucketing=(model_chunk_idx > 0) or overlap_param_gather_with_optimizer_step, + ) + for (model_chunk_idx, model_chunk) in enumerate(model) + ] + + # Broadcast params from data parallel src rank to other data parallel ranks. + if data_parallel_random_init: + for model_module in model: + model_module.broadcast_params() + + return model + + +def _print_num_params(model: list[MegatronModule]) -> None: + """Print the number of parameters in the model on rank 0. + + Only prints on data parallel rank 0 to avoid duplicate output. + Shows parameter count per (tensor parallel, pipeline parallel) rank. + + Args: + model: List of model modules to count parameters from + """ + if ( + parallel_state.get_data_parallel_rank() == 0 + and parallel_state.get_context_parallel_rank() == 0 + ): + print( + " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format( + parallel_state.get_tensor_model_parallel_rank(), + parallel_state.get_pipeline_model_parallel_rank(), + sum( + [ + sum([p.nelement() for p in model_module.parameters()]) + for model_module in model + ] + ), + ), + flush=True, + ) diff --git a/flagscale/train/bridge/models/qwen/__init__.py b/flagscale/train/bridge/models/qwen/__init__.py new file mode 100644 index 0000000000..b1391af905 --- /dev/null +++ b/flagscale/train/bridge/models/qwen/__init__.py @@ -0,0 +1,56 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +from flagscale.train.bridge.models.qwen.qwen2_bridge import Qwen2Bridge # noqa: F401 +from flagscale.train.bridge.models.qwen.qwen3_bridge import Qwen3Bridge # noqa: F401 +from flagscale.train.bridge.models.qwen.qwen3_moe_bridge import Qwen3MoEBridge # noqa: F401 +from flagscale.train.bridge.models.qwen.qwen_provider import ( + Qwen2ModelProvider, + Qwen2ModelProvider1P5B, + Qwen2ModelProvider7B, + Qwen2ModelProvider72B, + Qwen2ModelProvider500M, + Qwen3ModelProvider, + Qwen3ModelProvider1P7B, + Qwen3ModelProvider4B, + Qwen3ModelProvider8B, + Qwen3ModelProvider14B, + Qwen3ModelProvider32B, + Qwen3ModelProvider600M, + Qwen3MoEModelProvider, + Qwen3MoEModelProvider30B_A3B, + Qwen3MoEModelProvider235B_A22B, + Qwen25ModelProvider1P5B, + Qwen25ModelProvider3B, + Qwen25ModelProvider7B, + Qwen25ModelProvider14B, + Qwen25ModelProvider32B, + Qwen25ModelProvider72B, + Qwen25ModelProvider500M, +) + +__all__ = [ + "Qwen2ModelProvider", + "Qwen2ModelProvider500M", + "Qwen2ModelProvider1P5B", + "Qwen2ModelProvider7B", + "Qwen2ModelProvider72B", + "Qwen25ModelProvider500M", + "Qwen25ModelProvider1P5B", + "Qwen25ModelProvider3B", + "Qwen25ModelProvider7B", + "Qwen25ModelProvider14B", + "Qwen25ModelProvider32B", + "Qwen25ModelProvider72B", + "Qwen3ModelProvider", + "Qwen3ModelProvider600M", + "Qwen3ModelProvider1P7B", + "Qwen3ModelProvider4B", + "Qwen3ModelProvider8B", + "Qwen3ModelProvider14B", + "Qwen3ModelProvider32B", + "Qwen3MoEModelProvider", + "Qwen3MoEModelProvider30B_A3B", + "Qwen3MoEModelProvider235B_A22B", +] diff --git a/flagscale/train/bridge/models/qwen/qwen2_bridge.py b/flagscale/train/bridge/models/qwen/qwen2_bridge.py new file mode 100644 index 0000000000..4e1e8b1161 --- /dev/null +++ b/flagscale/train/bridge/models/qwen/qwen2_bridge.py @@ -0,0 +1,110 @@ +# Copyright (c) 2025, BAAI. All rights reserved. 
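+#
+# Note on the mapping convention used below (a restatement of the comments in
+# mapping_registry()): parameters are named Megatron-side -> HuggingFace-side, and a
+# "*" wildcard stands for the per-layer index on both sides, e.g.
+#   "decoder.layers.*.self_attention.linear_proj.weight"
+#       -> "model.layers.*.self_attn.o_proj.weight"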
+# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import torch + +from transformers import Qwen2ForCausalLM + +from megatron.core.models.gpt.gpt_model import GPTModel + +from flagscale.train.bridge.models.conversion.mapping_registry import MegatronMappingRegistry +from flagscale.train.bridge.models.conversion.model_bridge import MegatronModelBridge +from flagscale.train.bridge.models.conversion.param_mapping import ( + AutoMapping, + GatedMLPMapping, + QKVMapping, +) +from flagscale.train.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM +from flagscale.train.bridge.models.qwen.qwen_provider import Qwen2ModelProvider + + +@MegatronModelBridge.register_bridge(source=Qwen2ForCausalLM, target=GPTModel) +class Qwen2Bridge(MegatronModelBridge): + """ + Megatron Bridge for Qwen2 Causal LM. + + This bridge handles the conversion between HuggingFace Qwen2ForCausalLM + and Megatron-Core GPTModel formats, including weight mappings and + configuration translation. + + Example: + >>> from flagscale.train.bridge import AutoBridge + >>> bridge = AutoBridge.from_hf_pretrained("Qwen/Qwen2-7B") + >>> provider = bridge.to_megatron_provider() + """ + + def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> Qwen2ModelProvider: + hf_config = hf_pretrained.config + + provider = Qwen2ModelProvider( + num_layers=hf_config.num_hidden_layers, + hidden_size=hf_config.hidden_size, + ffn_hidden_size=hf_config.intermediate_size, + num_attention_heads=hf_config.num_attention_heads, + num_query_groups=hf_config.num_key_value_heads, + init_method_std=hf_config.initializer_range, + layernorm_epsilon=hf_config.rms_norm_eps, + gated_linear_unit=True, + make_vocab_size_divisible_by=self.make_vocab_size_divisible_by(hf_config.vocab_size), + rotary_base=hf_config.rope_theta, + share_embeddings_and_output_weights=getattr(hf_config, "tie_word_embeddings", False), + vocab_size=hf_config.vocab_size, + seq_length=hf_config.max_position_embeddings, + fp16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16), + bf16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16), + params_dtype=self.dtype_from_hf(hf_config, default=torch.float32), + generation_config=hf_pretrained.generation_config, + add_qkv_bias=True, # Qwen2 has bias in QKV projections + ) + + return provider + + def mapping_registry(self) -> MegatronMappingRegistry: + # Return MegatronMappingRegistry containing parameter mappings from Megatron to HF format + # First create simple 1:1 parameter mappings using a dictionary for readability + + # Dictionary maps Megatron parameter names -> HF parameter names + # Supports wildcard (*) patterns for layer-specific parameters + param_mappings = { + "embedding.word_embeddings.weight": "model.embed_tokens.weight", + "output_layer.weight": "lm_head.weight", + "decoder.final_layernorm.weight": "model.norm.weight", + "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight", + "decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight", + "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight", + "decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight", + } + + mapping_list = [] + # Convert each dictionary entry to AutoMapping(megatron_param, hf_param) + for megatron_param, hf_param in param_mappings.items(): + mapping_list.append(AutoMapping(megatron_param=megatron_param, hf_param=hf_param)) + + # Add special 
mappings that require parameter concatenation/transformation
+        mapping_list.extend(
+            [
+                # QKV: Combine separate Q, K, V matrices into single QKV matrix
+                QKVMapping(
+                    megatron_param="decoder.layers.*.self_attention.linear_qkv.weight",
+                    q="model.layers.*.self_attn.q_proj.weight",
+                    k="model.layers.*.self_attn.k_proj.weight",
+                    v="model.layers.*.self_attn.v_proj.weight",
+                ),
+                # QKV bias: Combine separate Q, K, V biases into single QKV bias (Qwen2 specific)
+                QKVMapping(
+                    megatron_param="decoder.layers.*.self_attention.linear_qkv.bias",
+                    q="model.layers.*.self_attn.q_proj.bias",
+                    k="model.layers.*.self_attn.k_proj.bias",
+                    v="model.layers.*.self_attn.v_proj.bias",
+                ),
+                # Gated MLP: Combine gate and up projection matrices into single FC1 matrix
+                GatedMLPMapping(
+                    megatron_param="decoder.layers.*.mlp.linear_fc1.weight",
+                    gate="model.layers.*.mlp.gate_proj.weight",
+                    up="model.layers.*.mlp.up_proj.weight",
+                ),
+            ]
+        )
+
+        return MegatronMappingRegistry(*mapping_list)
diff --git a/flagscale/train/bridge/models/qwen/qwen3_bridge.py b/flagscale/train/bridge/models/qwen/qwen3_bridge.py
new file mode 100644
index 0000000000..d0b6685ad9
--- /dev/null
+++ b/flagscale/train/bridge/models/qwen/qwen3_bridge.py
@@ -0,0 +1,106 @@
+# Copyright (c) 2025, BAAI. All rights reserved.
+#
+# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge
+
+import torch
+
+from transformers import Qwen3ForCausalLM
+
+from megatron.core.models.gpt.gpt_model import GPTModel
+
+from flagscale.train.bridge.models.conversion.mapping_registry import MegatronMappingRegistry
+from flagscale.train.bridge.models.conversion.model_bridge import MegatronModelBridge
+from flagscale.train.bridge.models.conversion.param_mapping import (
+    AutoMapping,
+    GatedMLPMapping,
+    QKVMapping,
+)
+from flagscale.train.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM
+from flagscale.train.bridge.models.qwen.qwen_provider import Qwen3ModelProvider
+
+
+@MegatronModelBridge.register_bridge(source=Qwen3ForCausalLM, target=GPTModel)
+class Qwen3Bridge(MegatronModelBridge):
+    """
+    Megatron Bridge for Qwen3 Causal LM.
+
+    This bridge handles the conversion between HuggingFace Qwen3ForCausalLM
+    and Megatron-Core GPTModel formats. Qwen3 differs from Qwen2 by using
+    QK layernorm.
+ + Example: + >>> from flagscale.train.bridge import AutoBridge + >>> bridge = AutoBridge.from_hf_pretrained("Qwen/Qwen3-1.7B") + >>> provider = bridge.to_megatron_provider() + """ + + def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> Qwen3ModelProvider: + hf_config = hf_pretrained.config + + provider = Qwen3ModelProvider( + num_layers=hf_config.num_hidden_layers, + hidden_size=hf_config.hidden_size, + ffn_hidden_size=hf_config.intermediate_size, + num_attention_heads=hf_config.num_attention_heads, + num_query_groups=hf_config.num_key_value_heads, + init_method_std=hf_config.initializer_range, + layernorm_epsilon=hf_config.rms_norm_eps, + gated_linear_unit=True, + make_vocab_size_divisible_by=self.make_vocab_size_divisible_by(hf_config.vocab_size), + rotary_base=hf_config.rope_theta, + share_embeddings_and_output_weights=getattr(hf_config, "tie_word_embeddings", False), + vocab_size=hf_config.vocab_size, + seq_length=hf_config.max_position_embeddings, + fp16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16), + bf16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16), + params_dtype=self.dtype_from_hf(hf_config, default=torch.float32), + generation_config=hf_pretrained.generation_config, + qk_layernorm=True, # Qwen3 uses QK layernorm + ) + + return provider + + def mapping_registry(self) -> MegatronMappingRegistry: + # Return MegatronMappingRegistry containing parameter mappings from Megatron to HF format + # First create simple 1:1 parameter mappings using a dictionary for readability + + # Dictionary maps Megatron parameter names -> HF parameter names + # Supports wildcard (*) patterns for layer-specific parameters + param_mappings = { + "embedding.word_embeddings.weight": "model.embed_tokens.weight", + "output_layer.weight": "lm_head.weight", + "decoder.final_layernorm.weight": "model.norm.weight", + "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight", + "decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight", + "decoder.layers.*.self_attention.q_layernorm.weight": "model.layers.*.self_attn.q_norm.weight", # Qwen3 specific + "decoder.layers.*.self_attention.k_layernorm.weight": "model.layers.*.self_attn.k_norm.weight", # Qwen3 specific + "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight", + "decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight", + } + + mapping_list = [] + # Convert each dictionary entry to AutoMapping(megatron_param, hf_param) + for megatron_param, hf_param in param_mappings.items(): + mapping_list.append(AutoMapping(megatron_param=megatron_param, hf_param=hf_param)) + + # Add special mappings that require parameter concatenation/transformation + mapping_list.extend( + [ + # QKV: Combine separate Q, K, V matrices into single QKV matrix + # Note: Qwen3 does NOT have bias in QKV projections (unlike Qwen2) + QKVMapping( + megatron_param="decoder.layers.*.self_attention.linear_qkv.weight", + q="model.layers.*.self_attn.q_proj.weight", + k="model.layers.*.self_attn.k_proj.weight", + v="model.layers.*.self_attn.v_proj.weight", + ), + # Gated MLP: Combine gate and up projection matrices into single FC1 matrix + GatedMLPMapping( + megatron_param="decoder.layers.*.mlp.linear_fc1.weight", + gate="model.layers.*.mlp.gate_proj.weight", + up="model.layers.*.mlp.up_proj.weight", + ), + ] + ) + + return MegatronMappingRegistry(*mapping_list) diff --git 
a/flagscale/train/bridge/models/qwen/qwen3_moe_bridge.py b/flagscale/train/bridge/models/qwen/qwen3_moe_bridge.py new file mode 100755 index 0000000000..9a4f2f2976 --- /dev/null +++ b/flagscale/train/bridge/models/qwen/qwen3_moe_bridge.py @@ -0,0 +1,113 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import torch + +from transformers import Qwen3MoeForCausalLM + +from megatron.core.models.gpt.gpt_model import GPTModel + +from flagscale.train.bridge.models.conversion.mapping_registry import MegatronMappingRegistry +from flagscale.train.bridge.models.conversion.model_bridge import MegatronModelBridge +from flagscale.train.bridge.models.conversion.param_mapping import ( + AutoMapping, + GatedMLPMapping, + QKVMapping, +) +from flagscale.train.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM +from flagscale.train.bridge.models.qwen.qwen_provider import Qwen3MoEModelProvider + + +@MegatronModelBridge.register_bridge(source=Qwen3MoeForCausalLM, target=GPTModel) +class Qwen3MoEBridge(MegatronModelBridge): + """ + Megatron Bridge for Qwen3 MoE Causal LM. + + This bridge handles the conversion between HuggingFace Qwen3MoeForCausalLM + and Megatron-Core GPTModel formats. Qwen3 MoE models use mixture of experts + architecture with QK layernorm. + + Example: + >>> from flagscale.train.bridge import AutoBridge + >>> bridge = AutoBridge.from_hf_pretrained("Qwen/Qwen3-235B-A22B") + >>> provider = bridge.to_megatron_provider() + """ + + def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> Qwen3MoEModelProvider: + hf_config = hf_pretrained.config + + provider = Qwen3MoEModelProvider( + num_layers=hf_config.num_hidden_layers, + hidden_size=hf_config.hidden_size, + ffn_hidden_size=hf_config.intermediate_size, + moe_ffn_hidden_size=hf_config.moe_intermediate_size, # Maps to moe_intermediate_size in HF + num_attention_heads=hf_config.num_attention_heads, + num_query_groups=hf_config.num_key_value_heads, + num_moe_experts=hf_config.num_experts, + moe_router_topk=hf_config.num_experts_per_tok, # Maps to num_experts_per_tok in HF + init_method_std=hf_config.initializer_range, + layernorm_epsilon=hf_config.rms_norm_eps, + gated_linear_unit=True, + make_vocab_size_divisible_by=self.make_vocab_size_divisible_by(hf_config.vocab_size), + rotary_base=hf_config.rope_theta, + share_embeddings_and_output_weights=getattr(hf_config, "tie_word_embeddings", False), + vocab_size=hf_config.vocab_size, + seq_length=hf_config.max_position_embeddings, + fp16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16), + bf16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16), + params_dtype=self.dtype_from_hf(hf_config, default=torch.float32), + generation_config=hf_pretrained.generation_config, + qk_layernorm=True, # Qwen3 MoE uses QK layernorm + moe_grouped_gemm=True, + ) + + return provider + + def mapping_registry(self) -> MegatronMappingRegistry: + # Return MegatronMappingRegistry containing parameter mappings from Megatron to HF format + # First create simple 1:1 parameter mappings using a dictionary for readability + + # Dictionary maps Megatron parameter names -> HF parameter names + # Supports wildcard (*) patterns for layer-specific parameters + param_mappings = { + "embedding.word_embeddings.weight": "model.embed_tokens.weight", + "output_layer.weight": "lm_head.weight", + "decoder.final_layernorm.weight": "model.norm.weight", + 
"decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight", + "decoder.layers.*.mlp.router.weight": "model.layers.*.mlp.gate.weight", + "decoder.layers.*.pre_mlp_layernorm.weight": "model.layers.*.post_attention_layernorm.weight", + "decoder.layers.*.self_attention.q_layernorm.weight": "model.layers.*.self_attn.q_norm.weight", + "decoder.layers.*.self_attention.k_layernorm.weight": "model.layers.*.self_attn.k_norm.weight", + "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight", + } + + mapping_list = [] + # Convert each dictionary entry to AutoMapping(megatron_param, hf_param) + for megatron_param, hf_param in param_mappings.items(): + mapping_list.append(AutoMapping(megatron_param=megatron_param, hf_param=hf_param)) + + # Add special mappings that require parameter concatenation/transformation + mapping_list.extend( + [ + # QKV: Combine separate Q, K, V matrices into single QKV matrix + # Note: Qwen3 MoE does NOT have bias in QKV projections + QKVMapping( + megatron_param="decoder.layers.*.self_attention.linear_qkv.weight", + q="model.layers.*.self_attn.q_proj.weight", + k="model.layers.*.self_attn.k_proj.weight", + v="model.layers.*.self_attn.v_proj.weight", + ), + GatedMLPMapping( + megatron_param="decoder.layers.*.mlp.experts.linear_fc1.weight*", + gate="model.layers.*.mlp.experts.*.gate_proj.weight", + up="model.layers.*.mlp.experts.*.up_proj.weight", + ), + AutoMapping( + megatron_param="decoder.layers.*.mlp.experts.linear_fc2.weight*", + hf_param="model.layers.*.mlp.experts.*.down_proj.weight", + ), + ] + ) + + return MegatronMappingRegistry(*mapping_list) diff --git a/flagscale/train/bridge/models/qwen/qwen_provider.py b/flagscale/train/bridge/models/qwen/qwen_provider.py new file mode 100644 index 0000000000..bb7211b81c --- /dev/null +++ b/flagscale/train/bridge/models/qwen/qwen_provider.py @@ -0,0 +1,393 @@ +# Copyright (c) 2025, BAAI. All rights reserved. 
+# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import logging + +from dataclasses import dataclass +from typing import Callable, Optional + +import torch +import torch.nn.functional as F + +from flagscale.train.bridge.models.gpt_provider import GPTModelProvider + +logger = logging.getLogger(__name__) + + +@dataclass +class Qwen2ModelProvider(GPTModelProvider): + """Base model provider for Qwen 2 Models.""" + + normalization: str = "RMSNorm" + activation_func: Callable = F.silu + gated_linear_unit: bool = True + add_bias_linear: bool = False + add_qkv_bias: bool = True + seq_length: int = 4096 + init_method_std: int = 0.02 + hidden_dropout: float = 0.0 + attention_dropout: float = 0.0 + vocab_size: int = 151936 + share_embeddings_and_output_weights: Optional[bool] = False + layernorm_epsilon: float = 1e-6 + rotary_base: float = 1000000.0 + position_embedding_type: str = "rope" + autocast_dtype: torch.dtype = torch.bfloat16 + params_dtype: torch.dtype = torch.bfloat16 + bf16: bool = True + + +# ============================================================================= +# Qwen 2 Model Providers +# ============================================================================= + + +@dataclass +class Qwen2ModelProvider500M(Qwen2ModelProvider): + """ + Config for Qwen 2 0.5B: https://huggingface.co/Qwen/Qwen2-0.5B + """ + + num_layers: int = 24 + hidden_size: int = 896 + num_attention_heads: int = 14 + num_query_groups: int = 2 + ffn_hidden_size: int = 4864 + share_embeddings_and_output_weights: bool = True + seq_length: int = 32768 + + +@dataclass +class Qwen2ModelProvider1P5B(Qwen2ModelProvider): + """ + Config for Qwen 2 1.5B: https://huggingface.co/Qwen/Qwen2-1.5B + """ + + num_layers: int = 28 + hidden_size: int = 1536 + num_attention_heads: int = 12 + num_query_groups: int = 2 + ffn_hidden_size: int = 8960 + seq_length: int = 32768 + share_embeddings_and_output_weights: bool = True + + +@dataclass +class Qwen2ModelProvider7B(Qwen2ModelProvider): + """ + Config for Qwen 2 7B: https://huggingface.co/Qwen/Qwen2-7B + """ + + num_layers: int = 28 + hidden_size: int = 3584 + num_attention_heads: int = 28 + num_query_groups: int = 4 + ffn_hidden_size: int = 18944 + vocab_size: int = 152064 + seq_length: int = 32768 + + +@dataclass +class Qwen2ModelProvider72B(Qwen2ModelProvider): + """ + Config for Qwen 2 72B: https://huggingface.co/Qwen/Qwen2-72B + """ + + num_layers: int = 80 + hidden_size: int = 8192 + num_attention_heads: int = 64 + num_query_groups: int = 8 + ffn_hidden_size: int = 29568 + vocab_size: int = 152064 + layernorm_epsilon: float = 1e-6 + seq_length: int = 32768 + + +# ============================================================================= +# Qwen 2.5 Model Providers +# ============================================================================= + + +@dataclass +class Qwen25ModelProvider500M(Qwen2ModelProvider): + """ + Config for Qwen 2.5 0.5B: https://huggingface.co/Qwen/Qwen2.5-0.5B + """ + + num_layers: int = 24 + hidden_size: int = 896 + num_attention_heads: int = 14 + num_query_groups: int = 2 + ffn_hidden_size: int = 4864 + share_embeddings_and_output_weights: bool = True + seq_length: int = 32768 + + +@dataclass +class Qwen25ModelProvider1P5B(Qwen2ModelProvider): + """ + Config for Qwen 2.5 1.5B: https://huggingface.co/Qwen/Qwen2.5-1.5B + """ + + num_layers: int = 28 + hidden_size: int = 1536 + num_attention_heads: int = 12 + num_query_groups: int = 2 + ffn_hidden_size: int = 8960 + seq_length: int = 32768 + 
share_embeddings_and_output_weights: bool = True + + +@dataclass +class Qwen25ModelProvider3B(Qwen2ModelProvider): + """ + Config for Qwen 2.5 3B: https://huggingface.co/Qwen/Qwen2.5-3B + """ + + num_layers: int = 36 + hidden_size: int = 2048 + num_attention_heads: int = 16 + num_query_groups: int = 2 + ffn_hidden_size: int = 11008 + vocab_size: int = 151936 + share_embeddings_and_output_weights: bool = True + seq_length: int = 32768 + + +@dataclass +class Qwen25ModelProvider7B(Qwen2ModelProvider): + """ + Config for Qwen 2.5 7B: https://huggingface.co/Qwen/Qwen2.5-7B + """ + + num_layers: int = 28 + hidden_size: int = 3584 + num_attention_heads: int = 28 + num_query_groups: int = 4 + ffn_hidden_size: int = 18944 + vocab_size: int = 152064 + seq_length: int = 32768 + + +@dataclass +class Qwen25ModelProvider14B(Qwen2ModelProvider): + """ + Config for Qwen 2.5 14B: https://huggingface.co/Qwen/Qwen2.5-14B + """ + + num_layers: int = 48 + hidden_size: int = 5120 + num_attention_heads: int = 40 + num_query_groups: int = 8 + ffn_hidden_size: int = 13824 + vocab_size: int = 152064 + layernorm_epsilon: float = 1e-6 + seq_length: int = 32768 + + +@dataclass +class Qwen25ModelProvider32B(Qwen2ModelProvider): + """ + Config for Qwen 2.5 32B: https://huggingface.co/Qwen/Qwen2.5-32B + """ + + num_layers: int = 64 + hidden_size: int = 5120 + num_attention_heads: int = 40 + num_query_groups: int = 8 + ffn_hidden_size: int = 27648 + vocab_size: int = 152064 + layernorm_epsilon: float = 1e-6 + seq_length: int = 32768 + + +@dataclass +class Qwen25ModelProvider72B(Qwen2ModelProvider): + """ + Config for Qwen 2.5 72B: https://huggingface.co/Qwen/Qwen2.5-72B + """ + + num_layers: int = 80 + hidden_size: int = 8192 + num_attention_heads: int = 64 + num_query_groups: int = 8 + ffn_hidden_size: int = 29568 + vocab_size: int = 152064 + layernorm_epsilon: float = 1e-6 + seq_length: int = 32768 + + +# ============================================================================= +# Qwen 3 Model Provider (based on GPTProvider) +# ============================================================================= + + +@dataclass +class Qwen3ModelProvider(GPTModelProvider): + """Base model provider for Qwen 3 Models.""" + + normalization: str = "RMSNorm" + activation_func: Callable = F.silu + gated_linear_unit: bool = True + add_bias_linear: bool = False + add_qkv_bias: bool = False + qk_layernorm: bool = True + kv_channels: Optional[int] = 128 + num_query_groups: int = 8 + seq_length: int = 40960 + init_method_std: int = 0.02 + hidden_dropout: float = 0.0 + attention_dropout: float = 0.0 + vocab_size: int = 151936 + share_embeddings_and_output_weights: Optional[bool] = False + layernorm_epsilon: float = 1e-6 + rotary_base: float = 1000000.0 + position_embedding_type: str = "rope" + autocast_dtype: torch.dtype = torch.bfloat16 + params_dtype: torch.dtype = torch.bfloat16 + bf16: bool = True + + +@dataclass +class Qwen3ModelProvider600M(Qwen3ModelProvider): + """ + Config for Qwen 3 0.6B: https://huggingface.co/Qwen/Qwen3-0.6B + """ + + num_layers: int = 28 + hidden_size: int = 1024 + num_attention_heads: int = 16 + ffn_hidden_size: int = 3072 + share_embeddings_and_output_weights: bool = True + + +@dataclass +class Qwen3ModelProvider1P7B(Qwen3ModelProvider): + """ + Config for Qwen 3 1.7B: https://huggingface.co/Qwen/Qwen3-1.7B + """ + + num_layers: int = 28 + hidden_size: int = 2048 + num_attention_heads: int = 16 + ffn_hidden_size: int = 6144 + share_embeddings_and_output_weights: bool = True + + +@dataclass +class 
Qwen3ModelProvider4B(Qwen3ModelProvider): + """ + Config for Qwen 3 4B: https://huggingface.co/Qwen/Qwen3-4B + """ + + num_layers: int = 36 + hidden_size: int = 2560 + num_attention_heads: int = 32 + ffn_hidden_size: int = 9728 + share_embeddings_and_output_weights: bool = True + + +@dataclass +class Qwen3ModelProvider8B(Qwen3ModelProvider): + """ + Config for Qwen 3 8B: https://huggingface.co/Qwen/Qwen3-8B + """ + + num_layers: int = 36 + hidden_size: int = 4096 + num_attention_heads: int = 32 + ffn_hidden_size: int = 12288 + + +@dataclass +class Qwen3ModelProvider14B(Qwen3ModelProvider): + """ + Config for Qwen 3 14B: https://huggingface.co/Qwen/Qwen3-14B + """ + + num_layers: int = 40 + hidden_size: int = 5120 + num_attention_heads: int = 40 + ffn_hidden_size: int = 17408 + + +@dataclass +class Qwen3ModelProvider32B(Qwen3ModelProvider): + """ + Config for Qwen 3 32B: https://huggingface.co/Qwen/Qwen3-32B + """ + + num_layers: int = 64 + hidden_size: int = 5120 + num_attention_heads: int = 64 + ffn_hidden_size: int = 25600 + + +# ============================================================================= +# Qwen 3 MoE Model Provider (based on GPTProvider) +# ============================================================================= + + +@dataclass +class Qwen3MoEModelProvider(GPTModelProvider): + """Base provider for Qwen 3 MoE Models.""" + + normalization: str = "RMSNorm" + activation_func: Callable = F.silu + gated_linear_unit: bool = True + add_bias_linear: bool = False + add_qkv_bias: bool = False + qk_layernorm: bool = True + kv_channels: Optional[int] = 128 + num_query_groups: int = 8 + seq_length: int = 40960 + init_method_std: int = 0.02 + hidden_dropout: float = 0.0 + attention_dropout: float = 0.0 + vocab_size: int = 151936 + share_embeddings_and_output_weights: Optional[bool] = False + layernorm_epsilon: float = 1e-6 + rotary_base: float = 1000000.0 + position_embedding_type: str = "rope" + autocast_dtype: torch.dtype = torch.bfloat16 + params_dtype: torch.dtype = torch.bfloat16 + bf16: bool = True + + # MoE specific parameters + num_moe_experts: int = 128 + moe_router_load_balancing_type: str = "aux_loss" + moe_aux_loss_coeff: float = 1e-3 + moe_router_topk: int = 8 + moe_router_pre_softmax: bool = False + moe_grouped_gemm: bool = True + moe_token_dispatcher_type: str = "alltoall" + moe_permute_fusion: bool = True + + +@dataclass +class Qwen3MoEModelProvider30B_A3B(Qwen3MoEModelProvider): + """ + Provider for Qwen 3 30B-A3B: https://huggingface.co/Qwen/Qwen3-30B-A3B + """ + + num_layers: int = 48 + hidden_size: int = 2048 + num_attention_heads: int = 32 + num_query_groups: int = 4 + ffn_hidden_size: int = 6144 + moe_ffn_hidden_size: int = 768 + + +@dataclass +class Qwen3MoEModelProvider235B_A22B(Qwen3MoEModelProvider): + """ + Provider for Qwen 3 235B-A22B: https://huggingface.co/Qwen/Qwen3-235B-A22B + """ + + num_layers: int = 94 + hidden_size: int = 4096 + num_attention_heads: int = 64 + num_query_groups: int = 4 + ffn_hidden_size: int = 12288 + moe_ffn_hidden_size: int = 1536 diff --git a/flagscale/train/bridge/models/transformer_config.py b/flagscale/train/bridge/models/transformer_config.py new file mode 100644 index 0000000000..4a3daf77fc --- /dev/null +++ b/flagscale/train/bridge/models/transformer_config.py @@ -0,0 +1,96 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +"""Bridge wrapper classes for Megatron Core transformer configurations. 
+ +These classes provide deferred post-initialization to support the Bridge configuration +override system while maintaining compatibility with Megatron Core's post_init behavior. +""" + +from dataclasses import dataclass + +from megatron.core.transformer.transformer_config import ( + MLATransformerConfig as MCoreMLATransformerConfig, + TransformerConfig as MCoreTransformerConfig, +) + + +@dataclass +class TransformerConfig(MCoreTransformerConfig): + """Megatron Core TransformerConfig with deferred post-init. + + This class inherits from Megatron Core's TransformerConfig but defers the + execution of post_init() until finalize() is explicitly called. This allows + for field modifications after construction but before computed fields are + calculated. + + Usage: + # Create config with deferred post-init + config = TransformerConfig(num_layers=32, hidden_size=4096) + + # Modify fields as needed + config.seq_length = 8192 + config.tensor_model_parallel_size = 2 + + # Finalize to compute derived fields + config.finalize() + """ + + def __post_init__(self) -> None: + """Skip MCore post_init during initial construction. + + The original post_init logic is deferred until finalize() is called. + This allows for field modifications after construction without + invalidating computed fields. + """ + pass + + def finalize(self) -> None: + """Execute the deferred MCore post-init logic. + + This method calls the original Megatron Core TransformerConfig.__post_init__() + to compute derived fields based on the current field values. It can be + called multiple times safely. + """ + MCoreTransformerConfig.__post_init__(self) + + +@dataclass +class MLATransformerConfig(TransformerConfig, MCoreMLATransformerConfig): + """Megatron Core MLATransformerConfig with deferred post-init. + + This class inherits from Megatron Core's MLATransformerConfig but defers the + execution of post_init() until finalize() is explicitly called. This allows + for field modifications after construction but before computed fields are + calculated. + + Usage: + # Create config with deferred post-init + config = MLATransformerConfig(num_layers=32, hidden_size=4096) + + # Modify fields as needed + config.q_lora_rank = 1536 + config.kv_lora_rank = 512 + + # Finalize to compute derived fields + config.finalize() + """ + + def __post_init__(self) -> None: + """Skip MCore post_init during initial construction. + + The original post_init logic is deferred until finalize() is called. + This allows for field modifications after construction without + invalidating computed fields. + """ + pass + + def finalize(self) -> None: + """Execute the deferred MCore post-init logic. + + This method calls the original Megatron Core MLATransformerConfig.__post_init__() + to compute derived fields based on the current field values. It can be + called multiple times safely. + """ + MCoreMLATransformerConfig.__post_init__(self) diff --git a/flagscale/train/bridge/utils/__init__.py b/flagscale/train/bridge/utils/__init__.py new file mode 100644 index 0000000000..3bfe2ab7d3 --- /dev/null +++ b/flagscale/train/bridge/utils/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge diff --git a/flagscale/train/bridge/utils/common_utils.py b/flagscale/train/bridge/utils/common_utils.py new file mode 100644 index 0000000000..de4e4e17e4 --- /dev/null +++ b/flagscale/train/bridge/utils/common_utils.py @@ -0,0 +1,147 @@ +# Copyright (c) 2025, BAAI. All rights reserved. 
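+#
+# Usage sketch: the rank helpers below are safe to call before torch.distributed is
+# initialized and fall back to the torchrun environment variables (RANK, WORLD_SIZE,
+# LOCAL_RANK), e.g.:
+#
+#   print_rank_0(f"launching on {get_world_size_safe()} process(es)")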
+# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import os +import types +import warnings + +import torch +import torch.distributed + +from megatron.core import DistributedDataParallel as DDP +from megatron.core.transformer.module import Float16Module + +try: + from megatron.core.distributed import TorchFullyShardedDataParallel as torch_FSDP + + ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, torch_FSDP, Float16Module) +except ImportError: + ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, Float16Module) + + +def get_rank_safe() -> int: + """Get the distributed rank safely, even if torch.distributed is not initialized. + + Returns: + The current process rank. + """ + # In megatron init, args.rank comes from the torchrun env var. + # Once init has been done, args.rank is updated to value of torch get_rank() + if torch.distributed.is_initialized(): + return torch.distributed.get_rank() + else: + return int(os.getenv("RANK", "0")) + + +def get_world_size_safe() -> int: + """Get the distributed world size safely, even if torch.distributed is not initialized. + + Returns: + The total number of processes in the distributed job. + """ + # In megatron init, args.world_size comes from the torchrun env var. + # Once init has been done, args.world_size is updated to value of torch get_world_size() + if torch.distributed.is_initialized(): + return torch.distributed.get_world_size() + else: + return int(os.getenv("WORLD_SIZE", "1")) + + +def get_last_rank() -> int: + """Get the last rank in the distributed group""" + if not torch.distributed.is_initialized(): + return 0 + return torch.distributed.get_world_size() - 1 + + +def get_local_rank_preinit() -> int: + """Get the local rank from the environment variable, intended for use before full init. + + Returns: + The local rank of the current process. + """ + return int(os.getenv("LOCAL_RANK", "0")) + + +def print_rank_0(message: str) -> None: + """Print a message only on global rank 0. + + Args: + message: The message string to print. + """ + rank = get_rank_safe() + if rank == 0: + print(message, flush=True) + + +def warn_rank_0(message): + """Warn only on rank 0.""" + rank = get_rank_safe() + if rank == 0: + warnings.warn(message) + + +def is_last_rank() -> bool: + """Check if the current rank is the last rank in the default process group. + + Returns: + True if the current rank is the last one, False otherwise. + """ + return torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1) + + +def print_rank_last(message: str) -> None: + """Print a message only on the last rank of the default process group. + + Args: + message: The message string to print. + """ + if torch.distributed.is_initialized(): + if is_last_rank(): + print(message, flush=True) + else: + print(message, flush=True) + + +def hook_hf_module_setattr_for_tp_grad_sync(module: torch.nn.Module) -> torch.nn.Module: + """Mark params for TP grad sync and hook __setattr__ on a module and its children. + + This ensures that all existing parameters under the provided module have the + attribute ``average_gradients_across_tp_domain=True`` and that any future + submodules assigned onto this module (or any of its current children) will + also have their parameters marked automatically. + + Args: + module: The root module (typically a Hugging Face module instance). + + Returns: + The same module instance for convenience. 
+ """ + if module is None: + return module + + # Mark all existing parameters recursively + for param in module.parameters(recurse=True): + setattr(param, "average_gradients_across_tp_domain", True) + + def _wrap_setattr(original_setattr): + def _wrapped(self, name, value): + original_setattr(name, value) + if isinstance(value, torch.nn.Module): + for p in value.parameters(recurse=True): + setattr(p, "average_gradients_across_tp_domain", True) + + return _wrapped + + # Hook __setattr__ on the module and all existing submodules to catch + # future dynamic assignments anywhere in the hierarchy. + for submodule in module.modules(): + if getattr(submodule, "_tp_grad_sync_setattr_wrapped", False): + continue + original_setattr = submodule.__setattr__ + wrapped = _wrap_setattr(original_setattr) + submodule.__setattr__ = types.MethodType(wrapped, submodule) + setattr(submodule, "_tp_grad_sync_setattr_wrapped", True) + + return module diff --git a/flagscale/train/bridge/utils/decorators.py b/flagscale/train/bridge/utils/decorators.py new file mode 100644 index 0000000000..437db3b4f6 --- /dev/null +++ b/flagscale/train/bridge/utils/decorators.py @@ -0,0 +1,26 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import functools +import logging +import warnings + +from typing import Any, Callable, TypeVar + +logger = logging.getLogger(__name__) + +# Define a TypeVar for generic return types +R = TypeVar("R") + + +def experimental_fn(func: Callable[..., R]) -> Callable[..., R]: + """Decorator to mark a function as experimental and issue a warning upon its call.""" + warning_message = f"Function '{func.__name__}' is experimental. APIs in this module are subject to change without notice." + + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> R: + warnings.warn(warning_message, stacklevel=2) + return func(*args, **kwargs) + + return wrapper diff --git a/flagscale/train/bridge/utils/fusions.py b/flagscale/train/bridge/utils/fusions.py new file mode 100644 index 0000000000..1f7d6f52a6 --- /dev/null +++ b/flagscale/train/bridge/utils/fusions.py @@ -0,0 +1,175 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +"""Fusion capability checks for Megatron models. + +This module provides functions to check if various fusion optimizations +can be enabled based on the current environment and dependencies. +""" + +import logging +import os + +from megatron.core.transformer.transformer_config import TransformerConfig + +logger = logging.getLogger(__name__) + +# Control whether to log warnings when fusions are disabled +# Set environment variable MEGATRON_SUPPRESS_FUSION_WARNINGS=1 to disable warnings +LOG_FUSION_DISABLE = os.environ.get("MEGATRON_SUPPRESS_FUSION_WARNINGS", "0") != "1" + + +def can_enable_apply_rope_fusion() -> bool: + """Check if RoPE (Rotary Position Embedding) fusion can be enabled. + + Returns: + bool: True if RoPE fusion is available and compatible. + """ + # Check for Transformer Engine availability + try: + import transformer_engine # noqa: F401 + + from megatron.core.utils import get_te_version, is_te_min_version + + if not is_te_min_version("2.2.0.dev0"): + if LOG_FUSION_DISABLE: + logger.warning( + "apply_rope_fusion requires Transformer Engine >= 2.2.0.dev0. " + f"Current version: {get_te_version()}. Fusion disabled." 
+ ) + return False + except ImportError: + if LOG_FUSION_DISABLE: + logger.warning( + "apply_rope_fusion requires Transformer Engine but it is not installed. Fusion disabled." + ) + return False + + # Check for RoPE fusion kernel availability + try: + from megatron.core.models.common.embeddings.rope_utils import ( + fused_apply_rotary_pos_emb, + fused_apply_rotary_pos_emb_thd, + ) + + if fused_apply_rotary_pos_emb is None and fused_apply_rotary_pos_emb_thd is None: + if LOG_FUSION_DISABLE: + logger.warning( + "apply_rope_fusion kernels are not available in megatron.core. Fusion disabled." + ) + return False + return True + except ImportError: + if LOG_FUSION_DISABLE: + logger.warning( + "apply_rope_fusion requires RoPE fusion kernels from megatron.core but they are not available. " + "Fusion disabled." + ) + return False + + +def can_enable_gradient_accumulation_fusion() -> bool: + """Check if gradient accumulation fusion can be enabled. + + Returns: + bool: True if gradient accumulation fusion is available. + """ + try: + import fused_weight_gradient_mlp_cuda # noqa: F401 + + return True + except ImportError: + if LOG_FUSION_DISABLE: + logger.warning( + "gradient_accumulation_fusion requires FusedLayerNorm from megatron.core.fusions " + "but it is not available. Fusion disabled." + ) + return False + + +def can_enable_bias_dropout_fusion() -> bool: + """Check if bias dropout fusion can be enabled. + + Returns: + bool: True if bias dropout fusion is available. + """ + try: + from megatron.core.fusions.fused_bias_dropout import ( # noqa: F401 + bias_dropout_add_fused_train, + ) + + return True + except ImportError: + if LOG_FUSION_DISABLE: + logger.warning( + "bias_dropout_fusion requires fused_bias_dropout from megatron.core.fusions " + "but it is not available. Fusion disabled." + ) + return False + + +def can_enable_masked_softmax_fusion() -> bool: + """Check if masked softmax fusion can be enabled. + + Returns: + bool: True if masked softmax fusion kernels are available. + """ + try: + # Try to import the CUDA kernels that are required for masked softmax fusion + import scaled_masked_softmax_cuda # noqa: F401 + import scaled_upper_triang_masked_softmax_cuda # noqa: F401 + + return True + except ImportError: + if LOG_FUSION_DISABLE: + logger.warning( + "masked_softmax_fusion requires CUDA kernels (scaled_masked_softmax_cuda) " + "but they are not available. This typically happens when Megatron-Core is not " + "built with CUDA extensions. Fusion disabled." + ) + return False + + +def validate_rope_fusion_compatibility(config: TransformerConfig) -> bool: + """Validate if RoPE fusion is compatible with the current model configuration. + + Args: + model_provider: The GPTModelProvider instance to validate. + + Returns: + bool: True if RoPE fusion is compatible, False otherwise. + """ + if not config.apply_rope_fusion: + return True + + # Check for multi_latent_attention incompatibility + if getattr(config, "multi_latent_attention", False): + if LOG_FUSION_DISABLE: + logger.warning( + "apply_rope_fusion for multi-latent attention only supports training. " + "It is experimental and may change in future versions." + ) + return True + + # Check TE version for rotary_interleaved + if getattr(config, "rotary_interleaved", False): + try: + from megatron.core.utils import get_te_version, is_te_min_version + + if not is_te_min_version("2.2.0.dev0"): + if LOG_FUSION_DISABLE: + logger.warning( + "apply_rope_fusion with rotary_interleaved requires TE >= 2.2.0.dev0. 
" + f"Current TE version: {get_te_version()}. Consider disabling apply_rope_fusion." + ) + return False + except ImportError: + if LOG_FUSION_DISABLE: + logger.warning( + "apply_rope_fusion with rotary_interleaved requires Transformer Engine. " + "Consider disabling apply_rope_fusion." + ) + return False + + return True diff --git a/flagscale/train/bridge/utils/import_utils.py b/flagscale/train/bridge/utils/import_utils.py new file mode 100644 index 0000000000..33d1dd4edf --- /dev/null +++ b/flagscale/train/bridge/utils/import_utils.py @@ -0,0 +1,409 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import importlib +import logging +import traceback + +from contextlib import contextmanager +from typing import Tuple + +import torch + +from packaging.version import Version as PkgVersion + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +logger.addHandler(logging.StreamHandler()) + +GPU_INSTALL_STRING = ( + """Install GPU packages via `pip install --extra-index-url """ + """https://pypi.nvidia.com nemo-curator[cuda12x]` +or use `pip install --extra-index-url https://pypi.nvidia.com ".[cuda12x]"` if installing from source""" +) +MISSING_NEMO_EXPORT_DEPLOY_MSG = ( + "nemo-export-deploy is not available. Please install it with `pip install nemo-export-deploy`." +) +MISSING_NVRX_MSG = "nvidia-resiliency-ext is not available. Please install it with `pip install nvidia-resiliency-ext`." +MISSING_NEMO_RUN_MSG = "nemo-run is not available. Please install it with `pip install nemo-run`." + + +class UnavailableError(Exception): + """Error thrown if a symbol is unavailable due to an issue importing it""" + + +@contextmanager +def null_decorator(*args, **kwargs): + """null_decorator""" + if len(kwargs) == 0 and len(args) == 1 and callable(args[0]): + return args[0] + else: + + def inner(func): + return func + + return inner + + +class UnavailableMeta(type): + """A metaclass for generating placeholder objects for unavailable symbols + + This metaclass allows errors to be deferred from import time to the time + that a symbol is actually used in order to streamline the usage of optional + dependencies. This is particularly useful for attempted imports of GPU-only + modules which will only be invoked if GPU-only functionality is + specifically used. + + If an attempt to import a symbol fails, this metaclass is used to generate + a class which stands in for that symbol. Any attempt to call the symbol + (instantiate the class) or access its attributes will throw an + UnavailableError exception. Furthermore, this class can be used in + e.g. isinstance checks, since it will (correctly) fail to match any + instance it is compared against. + + In addition to calls and attribute access, a number of dunder methods are + implemented so that other common usages of imported symbols (e.g. + arithmetic) throw an UnavailableError, but this is not guaranteed for + all possible uses. In such cases, other exception types (typically + TypeErrors) will be thrown instead. 
+ """ + + def __new__(meta, name, bases, dct): + if dct.get("_msg", None) is None: + dct["_msg"] = f"{name} could not be imported" + name = f"MISSING{name}" + return super(UnavailableMeta, meta).__new__(meta, name, bases, dct) + + def __call__(cls, *args, **kwargs): + raise UnavailableError(cls._msg) + + def __getattr__(cls, name): + # Special handling for unittest.mock which tries to access __func_ + # and other attributes during its operations + if name in ("__func__", "__wrapped__", "__name__", "__qualname__"): + raise AttributeError(f"'{cls.__name__}' has no attribute '{name}'") + raise UnavailableError(cls._msg) + + def __eq__(cls, other): + raise UnavailableError(cls._msg) + + def __lt__(cls, other): + raise UnavailableError(cls._msg) + + def __gt__(cls, other): + raise UnavailableError(cls._msg) + + def __le__(cls, other): + raise UnavailableError(cls._msg) + + def __ge__(cls, other): + raise UnavailableError(cls._msg) + + def __ne__(cls, other): + raise UnavailableError(cls._msg) + + def __abs__(cls): + raise UnavailableError(cls._msg) + + def __add__(cls, other): + raise UnavailableError(cls._msg) + + def __radd__(cls, other): + raise UnavailableError(cls._msg) + + def __iadd__(cls, other): + raise UnavailableError(cls._msg) + + def __floordiv__(cls, other): + raise UnavailableError(cls._msg) + + def __rfloordiv__(cls, other): + raise UnavailableError(cls._msg) + + def __ifloordiv__(cls, other): + raise UnavailableError(cls._msg) + + def __lshift__(cls, other): + raise UnavailableError(cls._msg) + + def __rlshift__(cls, other): + raise UnavailableError(cls._msg) + + def __mul__(cls, other): + raise UnavailableError(cls._msg) + + def __rmul__(cls, other): + raise UnavailableError(cls._msg) + + def __imul__(cls, other): + raise UnavailableError(cls._msg) + + def __ilshift__(cls, other): + raise UnavailableError(cls._msg) + + def __pow__(cls, other): + raise UnavailableError(cls._msg) + + def __rpow__(cls, other): + raise UnavailableError(cls._msg) + + def __ipow__(cls, other): + raise UnavailableError(cls._msg) + + def __rshift__(cls, other): + raise UnavailableError(cls._msg) + + def __rrshift__(cls, other): + raise UnavailableError(cls._msg) + + def __irshift__(cls, other): + raise UnavailableError(cls._msg) + + def __sub__(cls, other): + raise UnavailableError(cls._msg) + + def __rsub__(cls, other): + raise UnavailableError(cls._msg) + + def __isub__(cls, other): + raise UnavailableError(cls._msg) + + def __truediv__(cls, other): + raise UnavailableError(cls._msg) + + def __rtruediv__(cls, other): + raise UnavailableError(cls._msg) + + def __itruediv__(cls, other): + raise UnavailableError(cls._msg) + + def __divmod__(cls, other): + raise UnavailableError(cls._msg) + + def __rdivmod__(cls, other): + raise UnavailableError(cls._msg) + + def __neg__(cls): + raise UnavailableError(cls._msg) + + def __invert__(cls): + raise UnavailableError(cls._msg) + + def __hash__(cls): + raise UnavailableError(cls._msg) + + def __index__(cls): + raise UnavailableError(cls._msg) + + def __iter__(cls): + raise UnavailableError(cls._msg) + + def __delitem__(cls, name): + raise UnavailableError(cls._msg) + + def __setitem__(cls, name, value): + raise UnavailableError(cls._msg) + + def __enter__(cls, *args, **kwargs): + raise UnavailableError(cls._msg) + + def __get__(cls, *args, **kwargs): + raise UnavailableError(cls._msg) + + def __delete__(cls, *args, **kwargs): + raise UnavailableError(cls._msg) + + def __len__(cls): + raise UnavailableError(cls._msg) + + +def is_unavailable(obj): + """Helper to 
check if given symbol is actually a placeholder"""
+    return type(obj) is UnavailableMeta
+
+
+class UnavailableNullContext:
+    """A placeholder class for unavailable context managers.
+
+    This context manager will return a value which will throw an
+    UnavailableError if used in any way, but the context manager itself can be
+    safely invoked.
+    """
+
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def __enter__(self):
+        return UnavailableMeta(
+            "MissingContextValue",
+            (),
+            {"_msg": "Attempted to make use of placeholder context return value."},
+        )
+
+    def __exit__(self, *args, **kwargs):
+        pass
+
+
+def safe_import(module, *, msg=None, alt=None) -> Tuple[object, bool]:
+    """A function used to import modules that may not be available.
+
+    This function will attempt to import a module with the given name, but it
+    will not throw an ImportError if the module is not found. Instead, it will
+    return a placeholder object which will raise an exception only if used.
+
+    Args:
+        module (str): The name of the module to import.
+        msg (str, optional): An error message to be displayed if this module is used
+            after a failed import. Defaults to None.
+        alt (object, optional): A module to be used in place of the given module if it
+            fails to import. Defaults to None.
+
+    Returns:
+        tuple: A tuple containing two elements. The first element is the imported module,
+            the given alternate, or a class derived from UnavailableMeta. The second element
+            is a boolean indicating whether the intended import was successful.
+    """
+    try:
+        return importlib.import_module(module), True
+    except ImportError:
+        exception_text = traceback.format_exc()
+        logger.debug(f"Import of {module} failed with: {exception_text}")
+    except Exception:
+        exception_text = traceback.format_exc()
+        raise
+    if msg is None:
+        msg = f"{module} could not be imported"
+    if alt is None:
+        return UnavailableMeta(module.rsplit(".")[-1], (), {"_msg": msg}), False
+    else:
+        return alt, False
+
+
+def safe_import_from(
+    module, symbol, *, msg=None, alt=None, fallback_module=None
+) -> Tuple[object, bool]:
+    """A function used to import symbols from modules that may not be available.
+
+    This function will attempt to import a symbol with the given name from
+    the given module, but it will not throw an ImportError if the symbol is not
+    found. Instead, it will return a placeholder object which will raise an
+    exception only if used.
+
+    Args:
+        module (str): The name of the module in which the symbol is defined.
+        symbol (str): The name of the symbol to import.
+        msg (str, optional): An error message to be displayed if this symbol is used
+            after a failed import. Defaults to None.
+        alt (object, optional): An object to be used in place of the given symbol if it fails
+            to import. Defaults to None.
+        fallback_module (str, optional): Alternative name of the module in which the symbol is defined.
+            The function will first try to import using the `module` value and if that fails
+            will also try the `fallback_module`. Defaults to None.
+
+    Returns:
+        tuple: A tuple containing two elements. The first element is the imported symbol,
+            the given alternate, or a class derived from UnavailableMeta. The second element
+            is a boolean indicating whether the intended import was successful.
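+
+    Example (illustrative; assumes the optional apex dependency is the thing being guarded):
+        # Falls back to torch.nn.LayerNorm when apex is not installed.
+        FusedLayerNorm, HAVE_APEX = safe_import_from(
+            "apex.normalization", "FusedLayerNorm", alt=torch.nn.LayerNorm
+        )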
+ """ + try: + imported_module = importlib.import_module(module) + return getattr(imported_module, symbol), True + except ImportError: + exception_text = traceback.format_exc() + logger.debug(f"Import of {module} failed with: {exception_text}") + except AttributeError: + # if there is a fallback module try it. + if fallback_module is not None: + return safe_import_from(fallback_module, symbol, msg=msg, alt=alt, fallback_module=None) + exception_text = traceback.format_exc() + logger.info(f"Import of {symbol} from {module} failed with: {exception_text}") + except Exception: + exception_text = traceback.format_exc() + raise + if msg is None: + msg = f"{module}.{symbol} could not be imported" + if alt is None: + return UnavailableMeta(symbol, (), {"_msg": msg}), False + else: + return alt, False + + +def gpu_only_import(module, *, alt=None) -> Tuple[object, bool]: + """A function used to import modules required only in GPU installs. + + This function will attempt to import a module with the given name. + This function will attempt to import a symbol with the given name from + the given module, but it will not throw an ImportError if the symbol is not + found. Instead, it will return a placeholder object which will raise an + exception only if used with instructions on installing a GPU build. + + Args: + module (str): The name of the module to import. + alt (object, optional): A module to be used in place of the given module if it + fails to import in a non-GPU-enabled install. Defaults to None. + + Returns: + tuple: A tuple containing two elements. The first element is the imported module, + the given alternate, or a class derived from UnavailableMeta. The second element + is a boolean indicating whether the intended import was successful. + """ + + return safe_import( + module, + msg=f"{module} is not enabled in non GPU-enabled installations or environemnts. {GPU_INSTALL_STRING}", + alt=alt, + ) + + +def gpu_only_import_from(module, symbol, *, alt=None) -> Tuple[object, bool]: + """A function used to import symbols required only in GPU installs. + + This function will attempt to import a module with the given name. + This function will attempt to import a symbol with the given name from + the given module, but it will not throw an ImportError if the symbol is not + found. Instead, it will return a placeholder object which will raise an + exception only if used with instructions on installing a GPU build. + + Args: + module (str): The name of the module to import. + symbol (str): The name of the symbol to import. + alt (object, optional): An object to be used in place of the given symbol if it fails + to import in a non-GPU-enabled install. Defaults to None. + + Returns: + tuple: A tuple containing two elements. The first element is the imported symbol, + the given alternate, or a class derived from UnavailableMeta. The second element + is a boolean indicating whether the intended import was successful. + """ + return safe_import_from( + module, + symbol, + msg=f"{module}.{symbol} is not enabled in non GPU-enabled installations or environments. {GPU_INSTALL_STRING}", + alt=alt, + ) + + +def get_torch_version(): + """Returns the installed PyTorch version as a packaging.version.Version object. + + Handles potential exceptions during version parsing, returning a dummy version + ("0.0.0") if parsing fails (e.g., during documentation builds where torch + might not be fully imported or available). + + Returns: + packaging.version.Version: The parsed PyTorch version, or Version("0.0.0") on error. 
+ """ + try: + _torch_version = PkgVersion(torch.__version__) + except Exception: + # This is a WAR for building docs, where torch is not actually imported + _torch_version = PkgVersion("0.0.0") + return _torch_version + + +def is_torch_min_version(version, check_equality=True): + """Check if minimum version of `torch` is installed.""" + if check_equality: + return get_torch_version() >= PkgVersion(version) + return get_torch_version() > PkgVersion(version) diff --git a/flagscale/train/bridge/utils/instantiate_utils.py b/flagscale/train/bridge/utils/instantiate_utils.py new file mode 100644 index 0000000000..2bbf9a7eb3 --- /dev/null +++ b/flagscale/train/bridge/utils/instantiate_utils.py @@ -0,0 +1,418 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import copy +import functools +import logging + +from enum import Enum +from textwrap import dedent +from typing import Any, Callable, Sequence, Union + +from omegaconf import OmegaConf +from omegaconf._utils import is_structured_config + + +class InstantiationException(Exception): + """Custom exception type for instantiation errors.""" + + ... + + +class InstantiationMode(Enum): + """Enum for instantiation modes.""" + + STRICT = "strict" + LENIENT = "lenient" + + +class _Keys(str, Enum): + """Special keys in configs used by instantiate.""" + + TARGET = "_target_" + PARTIAL = "_partial_" + CALL = "_call_" + ARGS = "_args_" + + +def instantiate( + config: Any, *args: Any, mode: InstantiationMode = InstantiationMode.LENIENT, **kwargs: Any +) -> Any: + """Instantiate an object or callable from a config object. + + This function takes a configuration object (dictionary, list, OmegaConf config, + or Structured Config instance) and instantiates the target specified within it. + + The config object must contain: + _target_ (str): The fully qualified name of the class or callable to instantiate. + + The config object may also contain: + _args_ (list): Positional arguments for the target. + _partial_ (bool): If True, return a functools.partial object instead of calling + the target. Defaults to False. + _call_ (bool): If False, simply resolves and returns the target without calling it. + Defaults to True. + Additional keyword arguments to pass to the target. + + Args: + config: The configuration object describing the target and its parameters. + *args: Optional positional arguments that will override _args_ in the config + if provided. + mode: Instantiation mode (STRICT or LENIENT). In LENIENT mode (default), + errors during instantiation of parameters are logged as warnings, + and None is used instead. In STRICT mode, errors are raised. + **kwargs: Optional keyword arguments that will override parameters in the config. + Note: Dataclass instances in kwargs are treated as nested configs. + + Returns: + The instantiated object or the return value of the callable. + If config._partial_ is True, returns a functools.partial object. + If config._call_ is False, returns the resolved target callable/class itself. + Returns None if the input config is None. + + Raises: + InstantiationException: If the config is invalid, the target cannot be resolved, + or instantiation fails in STRICT mode. + TypeError: If the _partial_ flag is not a boolean. 
+ """ + + # Return None if config is None + if config is None: + return None + + if isinstance(config, (dict, list)): + config = _prepare_input_dict_or_list(config) + + kwargs = _prepare_input_dict_or_list(kwargs) + + # Structured Config always converted first to OmegaConf + if is_structured_config(config) or isinstance(config, (dict, list)): + config = OmegaConf.structured(config, flags={"allow_objects": True}) + + if OmegaConf.is_dict(config): + # Finalize config (convert targets to strings, merge with kwargs) + config_copy = copy.deepcopy(config) + config_copy._set_flag( + flags=["allow_objects", "struct", "readonly"], values=[True, False, False] + ) + config_copy._set_parent(config._get_parent()) + config = config_copy + + if kwargs: + config = OmegaConf.merge(config, kwargs) + + OmegaConf.resolve(config) + + _partial_ = config.pop(_Keys.PARTIAL, False) + + return instantiate_node(config, *args, partial=_partial_, mode=mode) + elif OmegaConf.is_list(config): + # Finalize config (convert targets to strings, merge with kwargs) + config_copy = copy.deepcopy(config) + config_copy._set_flag( + flags=["allow_objects", "struct", "readonly"], values=[True, False, False] + ) + config_copy._set_parent(config._get_parent()) + config = config_copy + + OmegaConf.resolve(config) + + _partial_ = kwargs.pop(_Keys.PARTIAL, False) + + if _partial_: + raise InstantiationException( + "The _partial_ keyword is not compatible with top-level list instantiation" + ) + + return instantiate_node(config, *args, partial=_partial_, mode=mode) + else: + raise InstantiationException( + dedent( + f"""\ + Cannot instantiate config of type {type(config).__name__}. + Top level config must be an OmegaConf DictConfig/ListConfig object, + a plain dict/list, or a Structured Config class or instance.""" + ) + ) + + +def instantiate_node( + node: Any, + *args: Any, + partial: bool = False, + mode: InstantiationMode = InstantiationMode.LENIENT, +) -> Any: + """Recursively instantiates a node within a configuration structure. + + This function handles the instantiation of individual nodes (dictionaries, + lists, or primitive values) within a larger configuration tree, typically + managed by OmegaConf. + + If the node is a dictionary containing a `_target_` key, it resolves and + instantiates the target callable/class using the other items in the + dictionary as keyword arguments. Nested nodes are recursively instantiated. + + If the node is a list, it recursively instantiates each item in the list. + + If the node is not an OmegaConf config node (e.g., a primitive type), it's + returned directly. + + Args: + node: The configuration node to instantiate (can be DictConfig, ListConfig, + or a primitive type). + *args: Positional arguments passed down from the top-level `instantiate` call, + used primarily for the final target call if the node is a dictionary + with `_target_`. + partial: Boolean flag indicating whether to return a `functools.partial` object + instead of calling the target. This can be overridden by a + `_partial_` key within the node itself. + mode: Instantiation mode (STRICT or LENIENT). Determines error handling + behavior for nested instantiations. + + Returns: + The instantiated object, list, or the original node if it wasn't a config. + Returns None if the input node is None or represents a None value in OmegaConf. + + Raises: + InstantiationException: If instantiation fails in STRICT mode, or if there are + issues like incompatible arguments or non-callable targets. 
+ TypeError: If a `_partial_` flag within the config is not a boolean. + """ + # Return None if config is None + if node is None or (OmegaConf.is_config(node) and node._is_none()): + return None + + if not OmegaConf.is_config(node): + return node + + # Override parent modes from config if specified + if OmegaConf.is_dict(node): + # using getitem instead of get(key, default) because OmegaConf will raise an exception + # if the key type is incompatible on get. + partial = node[_Keys.PARTIAL] if _Keys.PARTIAL in node else partial + + full_key = node._get_full_key(None) + + if not isinstance(partial, bool): + msg = f"Instantiation: _partial_ flag must be a bool, got {type(partial)}" + if node and full_key: + msg += f"\nfull_key: {full_key}" + raise TypeError(msg) + + if OmegaConf.is_list(node): + items = [instantiate_node(item, mode=mode) for item in node._iter_ex(resolve=True)] + + return items + elif OmegaConf.is_dict(node): + exclude_keys = set(item.value for item in _Keys if item != _Keys.ARGS) + if _is_target(node): + should_call_target = node.get("_call_", True) + _target_ = _resolve_target( + node.get(_Keys.TARGET), full_key, check_callable=should_call_target + ) + kwargs = {} + is_partial = node.get("_partial_", False) or partial + + if not should_call_target: + if len(set(node.keys()) - {"_target_", "_call_"}) != 0: + extra_keys = set(node.keys()) - {"_target_", "_call_"} + raise InstantiationException( + f"_call_ was set to False for target {_convert_target_to_string(_target_)}," + f" but extra keys were found: {extra_keys}" + ) + else: + return _target_ + + for key in node.keys(): + if key not in exclude_keys: + if OmegaConf.is_missing(node, key) and is_partial: + continue + value = node[key] + try: + value = instantiate_node(value, mode=mode) + except (ImportError, InstantiationException) as e: + if mode == InstantiationMode.STRICT: + raise InstantiationException( + f"Error instantiating {value} for key {full_key}.{key}: {e}" + ) from e + else: + value = None + logging.warning( + f"Error instantiating {value} for key {full_key}.{key}. " + f"Using None instead in lenient mode." + ) + kwargs[key] = _convert_node(value) + + assert callable(_target_) + return _call_target(_target_, partial, args, kwargs, full_key) + else: + dict_items = {} + for key, value in node.items(): + dict_items[key] = instantiate_node(value, mode=mode) + return dict_items + + else: + assert False, f"Unexpected config type : {type(node).__name__}" + + +def _locate(path: str) -> Any: + """ + Locate an object by name or dotted path, importing as necessary. + This function attempts to import modules starting from the most specific path + (back to front), making it possible to import objects where the final component + could be either a module or an attribute of the previous module. + """ + if path == "": + raise ImportError("Empty path") + from importlib import import_module + + parts = [part for part in path.split(".")] + for part in parts: + if not len(part): + raise ValueError( + f"Error loading '{path}': invalid dotstring." + + "\nRelative imports are not supported." 
+ ) + assert len(parts) > 0 + + # Try importing from the most specific path first (back to front) + for i in range(len(parts), 0, -1): + module_path = ".".join(parts[:i]) + try: + obj = import_module(module_path) + + # If this isn't the full path, get the remaining attributes + remaining_parts = parts[i:] + for part in remaining_parts: + try: + obj = getattr(obj, part) + except AttributeError as exc_attr: + raise ImportError( + f"Error loading '{path}':\n{repr(exc_attr)}" + + f"\nAre you sure that '{part}' is an attribute of '{module_path}'?" + ) from exc_attr + + # Successfully found the object + return obj + + except ModuleNotFoundError: + # Module not found, try a less specific path + continue + except Exception as exc_import: + # If we hit a different exception, it's likely an issue with the module itself + raise ImportError(f"Error loading '{path}':\n{repr(exc_import)}") from exc_import + + # If we've tried all paths and nothing worked, report failure with the base module + raise ImportError( + f"Error loading '{path}': Unable to import any module in the path. " + f"Are you sure that module '{parts[0]}' is installed?" + ) + + +def _is_target(x: Any) -> bool: + if isinstance(x, dict): + return "_target_" in x + if OmegaConf.is_dict(x): + return "_target_" in x + return False + + +def _call_target( + _target_: Callable[..., Any], + _partial_: bool, + args: tuple[Any, ...], + kwargs: dict[str, Any], + full_key: str, +) -> Any: + """Call target (type) with args and kwargs.""" + args, kwargs = _extract_pos_args(args, kwargs) + if _partial_: + try: + return functools.partial(_target_, *args, **kwargs) + except Exception as e: + msg = ( + f"Error in creating partial({_convert_target_to_string(_target_)}, ...) object:" + + f"\n{repr(e)}" + ) + if full_key: + msg += f"\nfull_key: {full_key}" + raise InstantiationException(msg) from e + else: + try: + return _target_(*args, **kwargs) + except Exception as e: + msg = f"Error in call to target '{_convert_target_to_string(_target_)}':\n{repr(e)}" + if full_key: + msg += f"\nfull_key: {full_key}" + raise InstantiationException(msg) from e + + +def _convert_target_to_string(t: Any) -> Any: + if callable(t): + return f"{t.__module__}.{t.__qualname__}" + else: + return t + + +def _prepare_input_dict_or_list(d: Union[dict[Any, Any], list[Any]]) -> Any: + res: Any + if isinstance(d, dict): + res = {} + for k, v in d.items(): + if k == "_target_": + v = _convert_target_to_string(d["_target_"]) + elif isinstance(v, (dict, list)): + v = _prepare_input_dict_or_list(v) + res[k] = v + elif isinstance(d, list): + res = [] + for v in d: + if isinstance(v, (list, dict)): + v = _prepare_input_dict_or_list(v) + res.append(v) + else: + assert False + return res + + +def _resolve_target( + target: Union[str, type, Callable[..., Any]], full_key: str, check_callable: bool = True +) -> Union[type, Callable[..., Any], object]: + """Resolve target string, type or callable into type or callable.""" + if isinstance(target, str): + try: + target = _locate(target) + except Exception as e: + msg = f"Error locating target '{target}'." 
+ if full_key: + msg += f"\nfull_key: {full_key}" + raise InstantiationException(msg) from e + if check_callable and not callable(target): + msg = f"Expected a callable target, got '{target}' of type '{type(target).__name__}'" + if full_key: + msg += f"\nfull_key: {full_key}" + raise InstantiationException(msg) + return target + + +def _extract_pos_args(input_args: Any, kwargs: Any) -> tuple[Any, Any]: + config_args = kwargs.pop(_Keys.ARGS, ()) + output_args = config_args + + if isinstance(config_args, Sequence): + if len(input_args) > 0: + output_args = input_args + else: + raise InstantiationException( + f"Unsupported _args_ type: '{type(config_args).__name__}'. value: '{config_args}'" + ) + + return output_args, kwargs + + +def _convert_node(node: Any) -> Any: + if OmegaConf.is_config(node): + node = OmegaConf.to_container(node, resolve=True) + + return node diff --git a/flagscale/train/bridge/utils/path_utils.py b/flagscale/train/bridge/utils/path_utils.py new file mode 100644 index 0000000000..0fe9c30ee8 --- /dev/null +++ b/flagscale/train/bridge/utils/path_utils.py @@ -0,0 +1,10 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +from pathlib import Path + + +def resolve_path(path: str) -> Path: + """Resolve a path to an absolute path.""" + return Path(path).expanduser().absolute().resolve() diff --git a/flagscale/train/bridge/utils/vocab_utils.py b/flagscale/train/bridge/utils/vocab_utils.py new file mode 100644 index 0000000000..86012a8874 --- /dev/null +++ b/flagscale/train/bridge/utils/vocab_utils.py @@ -0,0 +1,64 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import math + +from functools import lru_cache + +from flagscale.train.bridge.utils.common_utils import print_rank_0 + + +def calculate_padded_vocab_size( + vocab_size: int, + make_vocab_size_divisible_by: int, + tensor_model_parallel_size: int, + logging_enabled: bool = True, +) -> int: + """Calculate padded vocab size for tensor parallelism. + + This function pads the vocabulary size to ensure it's divisible by the required + multiple for efficient tensor parallel operations. 
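+
+    For example (illustrative numbers), with vocab_size=32000, make_vocab_size_divisible_by=128,
+    and tensor_model_parallel_size=8, the padding multiple is 128 * 8 = 1024 and the padded
+    size is ceil(32000 / 1024) * 1024 = 32768.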
+ + Args: + vocab_size: The original (unpadded) vocabulary size + make_vocab_size_divisible_by: Base divisibility requirement (e.g., 128) + tensor_model_parallel_size: Number of tensor parallel ranks + logging_enabled: Whether to log the padding information + + Returns: + int: The padded vocabulary size + """ + padded_size = _calculate_padded_vocab_size_cached( + vocab_size, make_vocab_size_divisible_by, tensor_model_parallel_size + ) + + # Handle logging separately to avoid affecting cache behavior + if logging_enabled: + print_rank_0( + " > padded vocab (size: {}) with {} dummy tokens (new size: {})".format( + vocab_size, padded_size - vocab_size, padded_size + ) + ) + + return padded_size + + +@lru_cache(maxsize=128) +def _calculate_padded_vocab_size_cached( + vocab_size: int, make_vocab_size_divisible_by: int, tensor_model_parallel_size: int +) -> int: + """Cached computation of padded vocab size.""" + if vocab_size <= 0: + raise ValueError(f"vocab_size must be positive, got {vocab_size}") + if make_vocab_size_divisible_by <= 0: + raise ValueError( + f"make_vocab_size_divisible_by must be positive, got {make_vocab_size_divisible_by}" + ) + if tensor_model_parallel_size <= 0: + raise ValueError( + f"tensor_model_parallel_size must be positive, got {tensor_model_parallel_size}" + ) + + multiple = make_vocab_size_divisible_by * tensor_model_parallel_size + return int(math.ceil(vocab_size / multiple) * multiple) diff --git a/flagscale/train/bridge/utils/yaml_utils.py b/flagscale/train/bridge/utils/yaml_utils.py new file mode 100644 index 0000000000..f38553d6a7 --- /dev/null +++ b/flagscale/train/bridge/utils/yaml_utils.py @@ -0,0 +1,203 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +import enum +import functools +import inspect + +from contextlib import contextmanager +from typing import Any, Generator, Optional + +import yaml + + +@contextmanager +def safe_yaml_representers() -> Generator[None, None, None]: + """ + Context manager for safely adding and removing custom YAML representers. + + Temporarily adds custom representers for functions, classes, and other objects + to the YAML SafeDumper, and restores the original representers when exiting + the context. 
+ + Usage: + with safe_yaml_representers(): + yaml_str = yaml.safe_dump(my_complex_object) + """ + # Save original representers + original_representers = yaml.SafeDumper.yaml_representers.copy() + original_multi_representers = yaml.SafeDumper.yaml_multi_representers.copy() + + try: + # Register custom representers + + # Partial representer + yaml.SafeDumper.add_representer(functools.partial, _partial_representer) + + # Enum representer + yaml.SafeDumper.add_multi_representer(enum.Enum, _enum_representer) + + # Function representer + yaml.SafeDumper.add_representer(type(lambda: ...), _function_representer) + yaml.SafeDumper.add_representer(type(object), _function_representer) + + # Try to add torch dtype representer if available + try: + import torch + + yaml.SafeDumper.add_representer(torch.dtype, _torch_dtype_representer) + except ModuleNotFoundError: + pass + + # Try to add GenerationConfig representer if available + try: + from transformers import GenerationConfig + + yaml.SafeDumper.add_representer(GenerationConfig, _generation_config_representer) + except ModuleNotFoundError: + pass + + # Try to add PretrainedConfig representer if available (generic for HF configs) + try: + from transformers import PretrainedConfig + + # Use multi-representer so subclasses of PretrainedConfig are also handled + yaml.SafeDumper.add_multi_representer(PretrainedConfig, _pretrained_config_representer) + except ModuleNotFoundError: + pass + + # General object representer + yaml.SafeDumper.add_multi_representer(object, _safe_object_representer) + + yield + finally: + # Restore original representers + yaml.SafeDumper.yaml_representers = original_representers + yaml.SafeDumper.yaml_multi_representers = original_multi_representers + + +def dump_dataclass_to_yaml(obj: Any, filename: Optional[str] = None) -> Optional[str]: + """Dump a dataclass object or other Python object to a YAML file or string. + + Uses safe representers to handle common types. + + Args: + obj: The object to dump. + filename: If provided, the path to the file where YAML should be written. + If None, returns the YAML string directly. + + Returns: + If filename is None, returns the YAML string representation of the object. + Otherwise, returns None. + """ + with safe_yaml_representers(): + if filename is not None: + with open(filename, "w+") as f: + yaml.safe_dump(obj, f) + else: + return yaml.safe_dump(obj) + + +def _function_representer(dumper, data): + """Represent functions in YAML.""" + value = { + "_target_": f"{inspect.getmodule(data).__name__}.{data.__qualname__}", # type: ignore + "_call_": False, + } + return dumper.represent_data(value) + + +def _torch_dtype_representer(dumper, data): + """Represent torch dtypes in YAML.""" + value = {"_target_": str(data), "_call_": False} + return dumper.represent_data(value) + + +def _safe_object_representer(dumper, data): + """ + General object representer for YAML. + + This function is a fallback for objects that don't have specific representers. + If the object has __qualname__ attr, + the _target_ is set to f"{inspect.getmodule(obj).__name__}.{obj.__qualname__}". + If the object does not have a __qualname__ attr, the _target_ is set from its __class__ attr. + The _call_ key is used to indicate whether the target should be called to create an instance. + + Args: + dumper (yaml.Dumper): The YAML dumper to use for serialization. + data (Any): The data to serialize. + + Returns: + The YAML representation of the data. 
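+
+    Example (illustrative):
+        A class such as pathlib.Path is dumped as a resolvable reference, roughly
+        {"_target_": "pathlib.Path", "_call_": False}, while an instance without a
+        __qualname__ attribute falls back to its class with "_call_": True.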
+ """ + try: + obj = data + target = f"{inspect.getmodule(obj).__name__}.{obj.__qualname__}" + call = False + except AttributeError: + obj = data.__class__ + target = f"{inspect.getmodule(obj).__name__}.{obj.__qualname__}" + call = True + + value = {"_target_": target, "_call_": call} # type: ignore + return dumper.represent_data(value) + + +def _partial_representer(dumper, data): + """Represent functools.partial objects in YAML.""" + # Get the underlying function + func = data.func + + # Create a dictionary representation + value = { + "_target_": f"{inspect.getmodule(func).__name__}.{func.__qualname__}", + "_partial_": True, + "_args_": list(data.args) if data.args else [], + } + + # Add keyword arguments if any exist + if data.keywords: + for k, v in data.keywords.items(): + value[k] = v + + return dumper.represent_data(value) + + +def _enum_representer(dumper, data): + """Represent enums in YAML.""" + # Create a dictionary representation + enum_class = data.__class__ + value = { + "_target_": f"{inspect.getmodule(enum_class).__name__}.{enum_class.__qualname__}", + "_call_": True, + "_args_": [data.value], + } + + return dumper.represent_data(value) + + +def _generation_config_representer(dumper, data): + """Represent transformers GenerationConfig objects in YAML.""" + cls = data.__class__ + value = { + "_target_": f"{inspect.getmodule(cls).__name__}.{cls.__qualname__}.from_dict", + "_call_": True, + "config_dict": data.to_dict(), + } + + return dumper.represent_data(value) + + +def _pretrained_config_representer(dumper, data): + """Represent transformers PretrainedConfig objects in YAML generically. + + Uses the class's from_dict/to_dict methods to ensure full round-trip of all fields. + """ + cls = data.__class__ + value = { + "_target_": f"{inspect.getmodule(cls).__name__}.{cls.__qualname__}.from_dict", + "_call_": True, + "config_dict": data.to_dict(), + } + return dumper.represent_data(value)