@@ -1,8 +1,18 @@
diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py
index 5ff62f74c..3c0571b2d 100644
index 5ff62f74c..46bd78b24 100644
--- a/megatron/core/transformer/transformer_config.py
+++ b/megatron/core/transformer/transformer_config.py
@@ -317,6 +317,15 @@ class TransformerConfig(ModelParallelConfig):
@@ -212,6 +212,9 @@ class TransformerConfig(ModelParallelConfig):
moe_deepep_num_sms: int = 20
"""Number of SMs to use for DeepEP."""

+ untie_embeddings_and_output_weights: bool = False
+ """The model's input word embedding matrix and the output layer's weight matrix are tied"""
+
####################
# initialization
####################
@@ -317,6 +320,15 @@ class TransformerConfig(ModelParallelConfig):
the number of transformer layers to recompute within each pipeline stage. Must be None for
'selective' activation checkpointing."""

@@ -18,7 +28,7 @@ index 5ff62f74c..3c0571b2d 100644
distribute_saved_activations: Optional[bool] = None
"""If True, distribute recomputed activations across the model parallel group."""

@@ -417,6 +426,12 @@ class TransformerConfig(ModelParallelConfig):
@@ -417,6 +429,12 @@ class TransformerConfig(ModelParallelConfig):
together with fp4 mode (i.e., TransformerConfig.fp4 is not None). Note that not all parameters
will be converted to fp4; for example, biases will remain unchanged."""

@@ -31,7 +41,7 @@ index 5ff62f74c..3c0571b2d 100644
####################
# MoE related
####################
@@ -644,6 +659,9 @@ class TransformerConfig(ModelParallelConfig):
@@ -644,6 +662,9 @@ class TransformerConfig(ModelParallelConfig):
config_logger_dir: str = ""
"""When non-empty, dumps entry-point configs to config_logger_dir"""

@@ -41,7 +51,7 @@ index 5ff62f74c..3c0571b2d 100644
flash_decode: bool = False
""" Use the optimized flash decoding kernel during inference. """

@@ -705,6 +723,31 @@ class TransformerConfig(ModelParallelConfig):
@@ -705,6 +726,31 @@ class TransformerConfig(ModelParallelConfig):
"""Transformer implementation to use.
Options are 'transformer_engine' for Transformer Engine and 'local' for MCore."""

@@ -73,7 +83,7 @@ index 5ff62f74c..3c0571b2d 100644
def __post_init__(self):
"""Python dataclass method that is used to modify attributes after initialization.
See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more
@@ -1481,6 +1524,9 @@ class TransformerConfig(ModelParallelConfig):
@@ -1481,6 +1527,9 @@ class TransformerConfig(ModelParallelConfig):
f"the number of layers ({self.num_layers})"
)

@@ -83,7 +93,7 @@ index 5ff62f74c..3c0571b2d 100644

@dataclass
class MLATransformerConfig(TransformerConfig):
@@ -1569,3 +1615,4 @@ class MLATransformerConfig(TransformerConfig):
@@ -1569,3 +1618,4 @@ class MLATransformerConfig(TransformerConfig):
assert (
self.apply_rope_fusion is False
), "Rope Fusion is not compatible with caching latents"

Large diffs are not rendered by default.

@@ -1,5 +1,5 @@
diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py
index 104fa6882..722859bf6 100644
index 104fa6882..7326bea6f 100644
--- a/megatron/training/checkpointing.py
+++ b/megatron/training/checkpointing.py
@@ -286,12 +286,15 @@ def read_metadata(tracker_filename):
@@ -20,7 +20,31 @@ index 104fa6882..722859bf6 100644
torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX)
max_iter = iters_cuda[0].item()

@@ -692,6 +695,28 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
@@ -473,6 +476,23 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
# Save dataloader state if the dataloader supports it (currently only Megatron Energon).
maybe_save_dataloader_state(train_data_iterator, iteration, getattr(args, "dataloader_save", None))

+ # save hf format model weight
+ hf_checkpoint_name = get_checkpoint_name(save_dir, iteration, release=False, pipeline_parallel=pipeline_parallel,
+ tensor_rank=tensor_rank, pipeline_rank=pipeline_rank, expert_parallel=expert_parallel, expert_rank=expert_rank, return_base_dir=True)
+ if args.save_hf and hasattr(args,'hf_config_path'):
+ assert args.hf_config_path is not None, "hf_config_path should not be None"
+ #use megatron bridge
+ from flagscale.train.bridge.models import AutoBridge
+ from flagscale.train.bridge.models.hf_pretrained.safe_config_loader import safe_load_config_with_retry
+ from transformers import AutoConfig
+ #Load the HF model from config
+ config_load=args.hf_config_path
+ config = safe_load_config_with_retry(config_load, trust_remote_code=False)
+ bridge=AutoBridge.from_hf_config(config)
+ #Save the HF model weights in the corresponding iteration's safetensor folder.
+ safe_save=os.path.join(hf_checkpoint_name, 'safetensor')
+ bridge.save_hf_pretrained(model=model,path=safe_save)
+
# Save distributed optimizer's custom parameter state.
if (
args.use_distributed_optimizer
@@ -692,6 +712,28 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
if not torch.distributed.is_initialized() \
or is_last_rank():
def wandb_finalize_fn():
@@ -49,7 +73,7 @@ index 104fa6882..722859bf6 100644
wandb_utils.on_save_checkpoint_success(checkpoint_name, get_checkpoint_tracker_filename(save_dir), save_dir, iteration)
if args.async_save:
assert async_save_request is not None
@@ -774,9 +799,7 @@ def maybe_save_dataloader_state(train_iterator, iteration, dataloader_save_path)
@@ -774,9 +816,7 @@ def maybe_save_dataloader_state(train_iterator, iteration, dataloader_save_path)

torch.distributed.barrier(group=mpu.get_data_parallel_group())

@@ -60,7 +84,7 @@ index 104fa6882..722859bf6 100644
torch.distributed.barrier(group=mpu.get_data_parallel_group())

dataloader_save_dict = {}
@@ -1239,6 +1262,10 @@ def load_args_from_checkpoint(
@@ -1239,6 +1279,10 @@ def load_args_from_checkpoint(
checkpoint_args, 'add_bias_linear', not getattr(checkpoint_args, 'disable_bias_linear')
)

@@ -71,7 +95,7 @@ index 104fa6882..722859bf6 100644
def _set_arg(arg_name, old_arg_name=None, force=False):
if not force and getattr(args, arg_name, None) is not None:
return
@@ -1274,6 +1301,8 @@ def load_args_from_checkpoint(
@@ -1274,6 +1318,8 @@ def load_args_from_checkpoint(
_set_arg('add_qkv_bias', force=True)
_set_arg('squared_relu', force=True)
_set_arg('swiglu', force=True)
@@ -80,7 +104,25 @@ index 104fa6882..722859bf6 100644
_set_arg('untie_embeddings_and_output_weights', force=True)
_set_arg('apply_layernorm_1p', force=True)
_set_arg('normalization', force=True)
@@ -1432,6 +1461,14 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load',
@@ -1347,6 +1393,17 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load',
args = get_args()
load_dir = getattr(args, load_arg)

+ # load hf format
+ if args.load_hf:
+ # use megatron bridge
+ from flagscale.train.bridge.models import AutoBridge
+ bridge=AutoBridge.from_hf_pretrained(load_dir)
+ bridge.load_hf_weights(ddp_model)
+ # no optimizer weight
+ iteration=0
+ num_floating_point_operations_so_far=0
+ return iteration, num_floating_point_operations_so_far
+
# Check for model-opt format loading
if hasattr(args, 'load_model_opt_format') and args.load_model_opt_format:
print_rank_0(f'Loading checkpoint using ModelOpt format from {load_dir}')
@@ -1432,6 +1489,14 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load',
mismatch_msg = "(TP, PP) mismatch after resume ({} vs {} from checkpoint)".format(
run_tp_pp, ckpt_tp_pp
)
@@ -95,15 +137,15 @@ index 104fa6882..722859bf6 100644

# Determine if RNG state will be loaded
if (ckpt_tp_pp == run_tp_pp and not release and not args.finetune and not args.no_load_rng
@@ -1468,6 +1505,7 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load',
@@ -1468,6 +1533,7 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load',
ckpt_tp_pp != run_tp_pp
and sharded_sd_metadata['distrib_optim_sharding_type']
not in DistributedOptimizer.checkpoint_fully_reshardable_formats
+ and convert_to_ep ########## FlagScale Added ##########
):
raise RuntimeError(f"{mismatch_msg}: not supported for DistributedOptimizer with sharding type"
f" {sharded_sd_metadata['distrib_optim_sharding_type']}."
@@ -1481,7 +1519,7 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load',
@@ -1481,7 +1547,7 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load',
gen_sd_optim = None
gen_sd_opt_param_scheduler = None

@@ -112,7 +154,7 @@ index 104fa6882..722859bf6 100644
model_sd_kwargs = dict(metadata=sharded_sd_metadata)

# Determine if rerun state will be loaded
@@ -1829,3 +1867,4 @@ def load_biencoder_checkpoint(model, only_query_model=False,
@@ -1829,3 +1895,4 @@ def load_biencoder_checkpoint(model, only_query_model=False,
print(' successfully loaded {}'.format(checkpoint_name))

return model
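
Taken together, the save_hf and load_hf branches above implement an HF round-trip driven by AutoBridge. Below is a hedged sketch of that flow, assuming the bridge methods behave as their use in the hunks suggests; the helper name and all paths are illustrative, and AutoConfig.from_pretrained stands in for the patch's safe_load_config_with_retry.

import os

from transformers import AutoConfig
from flagscale.train.bridge.models import AutoBridge


def hf_round_trip(model, ddp_model, hf_config_path, hf_checkpoint_name, hf_load_dir):
    """Hypothetical helper mirroring the new save_hf / load_hf code paths."""
    # Export: build a bridge from an HF config and write the model weights into the
    # iteration's 'safetensor' sub-folder, as in the save_checkpoint hunk above.
    config = AutoConfig.from_pretrained(hf_config_path, trust_remote_code=False)
    bridge = AutoBridge.from_hf_config(config)
    bridge.save_hf_pretrained(model=model, path=os.path.join(hf_checkpoint_name, "safetensor"))

    # Import: rebuild the bridge from an HF checkpoint directory and load its weights into
    # the Megatron model; no optimizer state is restored, so the iteration counter resets.
    bridge = AutoBridge.from_hf_pretrained(hf_load_dir)
    bridge.load_hf_weights(ddp_model)
    return 0, 0  # (iteration, num_floating_point_operations_so_far), as in load_checkpoint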
@@ -0,0 +1,14 @@
diff --git a/megatron/training/yaml_arguments.py b/megatron/training/yaml_arguments.py
index 390d503ee..ae1f3069d 100644
--- a/megatron/training/yaml_arguments.py
+++ b/megatron/training/yaml_arguments.py
@@ -409,7 +409,8 @@ def core_transformer_config_from_yaml(args, transfomer_key = "language_model"):
# Hardcoded
kw_args['deallocate_pipeline_outputs'] = True
kw_args['pipeline_dtype'] = kw_args['params_dtype']
- kw_args['batch_p2p_comm'] = not args.overlap_p2p_comm
+ kw_args['batch_p2p_comm'] = not args.overlap_p2p_comm
+ kw_args['untie_embeddings_and_output_weights'] = args.untie_embeddings_and_output_weights

assert args.activation_func in ["swiglu","squaredrelu","gelu"], f"{args.activation_func} is not a supported activation function"
if args.activation_func == "swiglu":
8 changes: 8 additions & 0 deletions flagscale/train/bridge/__init__.py
@@ -0,0 +1,8 @@
# Copyright (c) 2025, BAAI. All rights reserved.
#
# Mainly adapted from: https://github.com/NVIDIA-NeMo/Megatron-Bridge
"""Megatron Bridge - A component of the Megatron ecosystem."""

from flagscale.train.bridge.models.conversion.auto_bridge import AutoBridge

__all__ = ["AutoBridge"]
99 changes: 99 additions & 0 deletions flagscale/train/bridge/models/__init__.py
@@ -0,0 +1,99 @@
# Copyright (c) 2025, BAAI. All rights reserved.
#
# Copied from: https://github.com/NVIDIA-NeMo/Megatron-Bridge

from flagscale.train.bridge.models.conversion.auto_bridge import AutoBridge
from flagscale.train.bridge.models.conversion.mapping_registry import MegatronMappingRegistry
from flagscale.train.bridge.models.conversion.model_bridge import MegatronModelBridge
from flagscale.train.bridge.models.conversion.param_mapping import (
AutoMapping,
ColumnParallelMapping,
GatedMLPMapping,
MegatronParamMapping,
QKVMapping,
ReplicatedMapping,
RowParallelMapping,
)
from flagscale.train.bridge.models.deepseek import (
DeepSeekModelProvider,
DeepSeekProvider,
DeepSeekV2LiteModelProvider,
DeepSeekV2LiteProvider,
DeepSeekV2ModelProvider,
DeepSeekV2Provider,
DeepSeekV3ModelProvider,
DeepSeekV3Provider,
MoonlightModelProvider16B,
MoonlightProvider,
)
from flagscale.train.bridge.models.gpt_provider import GPTModelProvider
from flagscale.train.bridge.models.qwen import (
Qwen2ModelProvider,
Qwen2ModelProvider1P5B,
Qwen2ModelProvider7B,
Qwen2ModelProvider72B,
Qwen2ModelProvider500M,
Qwen3ModelProvider,
Qwen3ModelProvider1P7B,
Qwen3ModelProvider4B,
Qwen3ModelProvider8B,
Qwen3ModelProvider14B,
Qwen3ModelProvider32B,
Qwen3ModelProvider600M,
Qwen3MoEModelProvider,
Qwen3MoEModelProvider30B_A3B,
Qwen3MoEModelProvider235B_A22B,
Qwen25ModelProvider1P5B,
Qwen25ModelProvider3B,
Qwen25ModelProvider7B,
Qwen25ModelProvider14B,
Qwen25ModelProvider32B,
Qwen25ModelProvider72B,
Qwen25ModelProvider500M,
)

__all__ = [
"AutoBridge",
"MegatronMappingRegistry",
"MegatronModelBridge",
"ColumnParallelMapping",
"GatedMLPMapping",
"MegatronParamMapping",
"QKVMapping",
"ReplicatedMapping",
"RowParallelMapping",
"AutoMapping",
"GPTModelProvider",
"Qwen2ModelProvider",
"Qwen2ModelProvider500M",
"Qwen2ModelProvider1P5B",
"Qwen2ModelProvider7B",
"Qwen2ModelProvider72B",
"Qwen25ModelProvider500M",
"Qwen25ModelProvider1P5B",
"Qwen25ModelProvider3B",
"Qwen25ModelProvider7B",
"Qwen25ModelProvider14B",
"Qwen25ModelProvider32B",
"Qwen25ModelProvider72B",
"Qwen3ModelProvider",
"Qwen3ModelProvider600M",
"Qwen3ModelProvider1P7B",
"Qwen3ModelProvider4B",
"Qwen3ModelProvider8B",
"Qwen3ModelProvider14B",
"Qwen3ModelProvider32B",
"Qwen3MoEModelProvider",
"Qwen3MoEModelProvider30B_A3B",
"Qwen3MoEModelProvider235B_A22B",
"DeepSeekModelProvider",
"DeepSeekProvider",
"DeepSeekV2LiteModelProvider",
"DeepSeekV2LiteProvider",
"DeepSeekV2ModelProvider",
"DeepSeekV2Provider",
"DeepSeekV3ModelProvider",
"DeepSeekV3Provider",
"MoonlightModelProvider16B",
"MoonlightProvider",
]
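
The models package re-exports the bridge plus the per-family model providers, so downstream code needs a single import path. A minimal usage sketch, assuming the provider dataclasses can be instantiated with their preset defaults as in upstream Megatron-Bridge:

from flagscale.train.bridge.models import AutoBridge, Qwen3ModelProvider8B

# Provider classes bundle a preset model configuration for one checkpoint family;
# AutoBridge (used in the checkpointing hunks above) handles HF <-> Megatron conversion.
provider = Qwen3ModelProvider8B()
print(type(provider).__name__)  # -> Qwen3ModelProvider8B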