From 20f5340b0e7fbc0d292e3248862c79e43e33430f Mon Sep 17 00:00:00 2001 From: Zhiyu Cheng Date: Mon, 9 Feb 2026 15:37:48 -0800 Subject: [PATCH 1/2] detect MTP, copy the original mtp.safetensors, update the index file Signed-off-by: Zhiyu Cheng --- examples/llm_ptq/example_utils.py | 109 ++++++++++----------- modelopt/torch/export/unified_export_hf.py | 62 ++++++++++++ 2 files changed, 116 insertions(+), 55 deletions(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 93687a8d0..f5f6ca092 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -28,7 +28,6 @@ import transformers from accelerate import infer_auto_device_map, init_empty_weights from accelerate.utils import get_max_memory -from safetensors.torch import load_file from transformers import ( AutoConfig, AutoModelForCausalLM, @@ -316,32 +315,36 @@ def get_processor( return None -def load_mtp_weights( - model: torch.nn.Module, model_path: str -) -> tuple[list[str], dict[str, torch.Tensor]]: - """Load MTP weights from the model checkpoint. +def load_mtp_weights_if_needed(model: torch.nn.Module, model_path: str) -> list[str]: + """Detect MTP weights in separate safetensors files (e.g., GLM-4.7). - Some models store additional layers in separate safetensors files with non-standard - names (e.g., mtp.safetensors). HuggingFace's from_pretrained() may not load these - files even though they're referenced in model.safetensors.index.json. + Some models store MTP (Multi-Token Prediction) layers in separate safetensors files + (e.g., mtp.safetensors) that are referenced in model.safetensors.index.json but + not loaded by HuggingFace transformers (because the model architecture doesn't + include these layers). - This function detects such cases and explicitly loads the missing weights. + This function: + 1. Detects non-standard safetensors files with weights not in the model + 2. Stores info about these files on the model for later export (model._mtp_files_info) + 3. Returns the layer prefixes (e.g., ["model.layers.92"]) for quantization exclusion + + Note: The weights are NOT loaded into the model (since the model architecture doesn't + support them), but we track them so they can be copied during export. Args: - model: The loaded model that may be missing weights + model: The loaded model model_path: Path to the model directory Returns: - List of layer prefixes that were loaded from non-standard safetensors files. + List of layer prefixes that contain MTP weights (e.g., ["model.layers.92"]). These layers should typically be excluded from quantization. - Empty list if no additional weights were loaded. - Dictionary of MTP weights that were not loaded into the model state dict. + Empty list if no MTP weights were found. """ model_path = Path(model_path) index_file = model_path / "model.safetensors.index.json" if not index_file.exists(): - return [], {} + return [] # Load the index to find all referenced safetensors files index = json.load(open(index_file)) @@ -353,58 +356,54 @@ def load_mtp_weights( mtp_weight_map.setdefault(v, []).append(k) if not mtp_weight_map: - return [], {} + return [] - def _extract_layer_prefixes(keys): - mtp_layer_prefixes = set() - for key in keys: - parts = key.split(".") - for i, part in enumerate(parts): - if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit(): - prefix = ".".join(parts[: i + 2]) - mtp_layer_prefixes.add(prefix) - break - - return mtp_layer_prefixes - - # Flatten mtp_weight_map.values() (list of list of str) to a single list of str - mtp_keys = [k for keys in mtp_weight_map.values() for k in keys] - mtp_layer_prefixes = _extract_layer_prefixes(mtp_keys) - - # Check which non-standard files exist and have missing weights + # Check which non-standard files exist and have weights not in the model model_state = model.state_dict() - total_loaded = 0 - - not_in_state_dict = {} + mtp_files_info = [] # Store info for export: [{source_path, filename, weight_map}] + mtp_layer_prefixes = [] - for filename, mtp_keys in mtp_weight_map.items(): + for filename in mtp_weight_map: filepath = model_path / filename if not filepath.exists(): continue - print(f"Loading {len(mtp_keys)} mtp weights from {filename}...") - weights = load_file(str(filepath), device="cpu") - weights = {k: v for k, v in weights.items() if k in mtp_keys} - # Load the MTP weights to the model state dict - in_state_dict = {k: weights[k] for k in weights if k in model_state} - not_in_state_dict = not_in_state_dict | { - k: weights[k] for k in weights if k not in model_state - } - - if in_state_dict: - model.load_state_dict(in_state_dict, strict=False) - total_loaded += len(in_state_dict) - - if total_loaded > 0: - print( - f"✓ Successfully loaded {total_loaded} MTP weights, " - f"{len(not_in_state_dict)} MTP weights not in model.state_dict" - ) + # Find keys that should be in this file + expected_keys = [k for k, v in index["weight_map"].items() if v == filename] + + # Check which are missing from the model (i.e., model doesn't have these modules) + missing_keys = [k for k in expected_keys if k not in model_state] + + # Extract layer prefixes from all expected keys + for key in expected_keys: + parts = key.split(".") + for i, part in enumerate(parts): + if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit(): + prefix = ".".join(parts[: i + 2]) # e.g., "model.layers.92" + if prefix not in mtp_layer_prefixes: + mtp_layer_prefixes.append(prefix) + break + + # If there are missing keys, the model architecture doesn't support these weights + # Store info for copying during export + if missing_keys: + file_weight_map = dict.fromkeys(expected_keys, filename) + mtp_files_info.append({ + "source_path": str(filepath), + "filename": filename, + "weight_map": file_weight_map, + }) + print(f"Found {len(expected_keys)} MTP weights in {filename} (will copy during export)") + + # Store MTP file info on the model for use during export + if mtp_files_info: + model._mtp_files_info = mtp_files_info + print(f"✓ Stored {len(mtp_files_info)} MTP file(s) info for export") if mtp_layer_prefixes: print(f"✓ Detected MTP layers to exclude from quantization: {mtp_layer_prefixes}") - return list(mtp_layer_prefixes), not_in_state_dict + return mtp_layer_prefixes def get_dtype(dtype): diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 5703f4515..a7961adc0 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -18,6 +18,7 @@ import collections.abc import json import re +import shutil import tempfile import warnings from builtins import ValueError @@ -954,6 +955,64 @@ def _export_diffusers_checkpoint( print(f"Export complete. Saved to: {export_dir}") +def _copy_mtp_files_if_needed(model: nn.Module, export_dir: Path) -> None: + """Copy MTP (Multi-Token Prediction) safetensors files if they exist. + + Some models like GLM-4.7 have MTP layers stored in separate safetensors files + (e.g., mtp.safetensors) that aren't part of the model's state_dict because + HuggingFace Transformers doesn't create the corresponding modules. + + This function copies those files to the export directory and updates the + model.safetensors.index.json to include the MTP weights. + + Args: + model: The model being exported (may have _mtp_files_info attribute) + export_dir: The export directory path + """ + mtp_files_info = getattr(model, "_mtp_files_info", None) + if not mtp_files_info: + return + + export_dir = Path(export_dir) + index_file = export_dir / "model.safetensors.index.json" + + # Load existing index if present + if index_file.exists(): + with open(index_file) as f: + index_data = json.load(f) + else: + # Create a basic index structure if it doesn't exist + index_data = {"metadata": {}, "weight_map": {}} + + # Copy each MTP file and update the index + for mtp_info in mtp_files_info: + source_path = Path(mtp_info["source_path"]) + filename = mtp_info["filename"] + weight_map = mtp_info["weight_map"] + + if not source_path.exists(): + print(f"Warning: MTP source file not found: {source_path}") + continue + + dest_path = export_dir / filename + + # Copy the file + print(f"Copying MTP file: {filename}") + shutil.copy2(source_path, dest_path) + + # Update the weight map in the index + for weight_name, file_name in weight_map.items(): + index_data["weight_map"][weight_name] = file_name + + print(f"✓ Copied {filename} with {len(weight_map)} weights") + + # Write updated index + with open(index_file, "w") as f: + json.dump(index_data, f, indent=2) + + print("✓ Updated model.safetensors.index.json with MTP weights") + + def export_hf_checkpoint( model: Any, dtype: torch.dtype | None = None, @@ -1019,6 +1078,9 @@ def export_hf_checkpoint( save_modelopt_state=save_modelopt_state, ) + # Copy MTP files if present (e.g., GLM-4.7 mtp.safetensors) + _copy_mtp_files_if_needed(model, export_dir) + original_config = f"{export_dir}/config.json" config_data = {} From ba5c7002a51c3750d035a9ecaaede9a211df9bf8 Mon Sep 17 00:00:00 2001 From: Zhiyu Cheng Date: Mon, 9 Feb 2026 16:30:27 -0800 Subject: [PATCH 2/2] update Signed-off-by: Zhiyu Cheng --- examples/llm_ptq/example_utils.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index f5f6ca092..78e6221d4 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -388,11 +388,13 @@ def load_mtp_weights_if_needed(model: torch.nn.Module, model_path: str) -> list[ # Store info for copying during export if missing_keys: file_weight_map = dict.fromkeys(expected_keys, filename) - mtp_files_info.append({ - "source_path": str(filepath), - "filename": filename, - "weight_map": file_weight_map, - }) + mtp_files_info.append( + { + "source_path": str(filepath), + "filename": filename, + "weight_map": file_weight_map, + } + ) print(f"Found {len(expected_keys)} MTP weights in {filename} (will copy during export)") # Store MTP file info on the model for use during export