111 changes: 56 additions & 55 deletions examples/llm_ptq/example_utils.py
@@ -28,7 +28,6 @@
import transformers
from accelerate import infer_auto_device_map, init_empty_weights
from accelerate.utils import get_max_memory
from safetensors.torch import load_file
from transformers import (
AutoConfig,
AutoModelForCausalLM,
@@ -316,32 +315,36 @@ def get_processor(
return None


def load_mtp_weights(
model: torch.nn.Module, model_path: str
) -> tuple[list[str], dict[str, torch.Tensor]]:
"""Load MTP weights from the model checkpoint.
def load_mtp_weights_if_needed(model: torch.nn.Module, model_path: str) -> list[str]:
"""Detect MTP weights in separate safetensors files (e.g., GLM-4.7).

Some models store additional layers in separate safetensors files with non-standard
names (e.g., mtp.safetensors). HuggingFace's from_pretrained() may not load these
files even though they're referenced in model.safetensors.index.json.
Some models store MTP (Multi-Token Prediction) layers in separate safetensors files
(e.g., mtp.safetensors) that are referenced in model.safetensors.index.json but
not loaded by HuggingFace transformers (because the model architecture doesn't
include these layers).

This function detects such cases and explicitly loads the missing weights.
This function:
1. Detects non-standard safetensors files with weights not in the model
2. Stores info about these files on the model for later export (model._mtp_files_info)
3. Returns the layer prefixes (e.g., ["model.layers.92"]) for quantization exclusion

Note: The weights are NOT loaded into the model (since the model architecture doesn't
support them), but we track them so they can be copied during export.

Args:
model: The loaded model that may be missing weights
model: The loaded model
model_path: Path to the model directory

Returns:
List of layer prefixes that were loaded from non-standard safetensors files.
List of layer prefixes that contain MTP weights (e.g., ["model.layers.92"]).
These layers should typically be excluded from quantization.
Empty list if no additional weights were loaded.
Dictionary of MTP weights that were not loaded into the model state dict.
Empty list if no MTP weights were found.
"""
model_path = Path(model_path)
index_file = model_path / "model.safetensors.index.json"

if not index_file.exists():
return [], {}
return []

# Load the index to find all referenced safetensors files
index = json.load(open(index_file))
@@ -353,58 +356,56 @@ def load_mtp_weights(
mtp_weight_map.setdefault(v, []).append(k)

if not mtp_weight_map:
return [], {}

def _extract_layer_prefixes(keys):
mtp_layer_prefixes = set()
for key in keys:
parts = key.split(".")
for i, part in enumerate(parts):
if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
prefix = ".".join(parts[: i + 2])
mtp_layer_prefixes.add(prefix)
break

return mtp_layer_prefixes

# Flatten mtp_weight_map.values() (list of list of str) to a single list of str
mtp_keys = [k for keys in mtp_weight_map.values() for k in keys]
mtp_layer_prefixes = _extract_layer_prefixes(mtp_keys)
return []

# Check which non-standard files exist and have missing weights
# Check which non-standard files exist and have weights not in the model
model_state = model.state_dict()
total_loaded = 0
mtp_files_info = [] # Store info for export: [{source_path, filename, weight_map}]
mtp_layer_prefixes = []

not_in_state_dict = {}

for filename, mtp_keys in mtp_weight_map.items():
for filename in mtp_weight_map:
filepath = model_path / filename
if not filepath.exists():
continue

print(f"Loading {len(mtp_keys)} mtp weights from {filename}...")
weights = load_file(str(filepath), device="cpu")
weights = {k: v for k, v in weights.items() if k in mtp_keys}
# Load the MTP weights to the model state dict
in_state_dict = {k: weights[k] for k in weights if k in model_state}
not_in_state_dict = not_in_state_dict | {
k: weights[k] for k in weights if k not in model_state
}

if in_state_dict:
model.load_state_dict(in_state_dict, strict=False)
total_loaded += len(in_state_dict)

if total_loaded > 0:
print(
f"✓ Successfully loaded {total_loaded} MTP weights, "
f"{len(not_in_state_dict)} MTP weights not in model.state_dict"
)
# Find keys that should be in this file
expected_keys = [k for k, v in index["weight_map"].items() if v == filename]

# Check which are missing from the model (i.e., model doesn't have these modules)
missing_keys = [k for k in expected_keys if k not in model_state]

# Extract layer prefixes from all expected keys
for key in expected_keys:
parts = key.split(".")
for i, part in enumerate(parts):
if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
prefix = ".".join(parts[: i + 2]) # e.g., "model.layers.92"
if prefix not in mtp_layer_prefixes:
mtp_layer_prefixes.append(prefix)
break

# If there are missing keys, the model architecture doesn't support these weights
# Store info for copying during export
if missing_keys:
file_weight_map = dict.fromkeys(expected_keys, filename)
mtp_files_info.append(
{
"source_path": str(filepath),
"filename": filename,
"weight_map": file_weight_map,
}
)
print(f"Found {len(expected_keys)} MTP weights in {filename} (will copy during export)")

# Store MTP file info on the model for use during export
if mtp_files_info:
model._mtp_files_info = mtp_files_info
print(f"✓ Stored {len(mtp_files_info)} MTP file(s) info for export")

if mtp_layer_prefixes:
print(f"✓ Detected MTP layers to exclude from quantization: {mtp_layer_prefixes}")

return list(mtp_layer_prefixes), not_in_state_dict
return mtp_layer_prefixes


def get_dtype(dtype):
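The prefix extraction in the new loop above reduces full tensor names to their owning decoder layer. A standalone sketch of that walk, assuming keys shaped like safetensors index entries (the weight names below are hypothetical stand-ins, not keys from a real checkpoint):

    def extract_layer_prefix(key: str) -> str | None:
        """Reduce e.g. "model.layers.92.mlp.gate_proj.weight" to "model.layers.92"."""
        parts = key.split(".")
        for i, part in enumerate(parts):
            if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
                return ".".join(parts[: i + 2])
        return None  # key does not live under a numbered layer

    assert extract_layer_prefix("model.layers.92.mlp.gate_proj.weight") == "model.layers.92"
    assert extract_layer_prefix("lm_head.weight") is None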
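And a hedged sketch of how a caller might consume the return value. The config wiring here is illustrative only: FP8_DEFAULT_CFG and the wildcard-disable idiom exist in modelopt's quantization configs, but this is not the PR's actual call site, and the checkpoint path is hypothetical.

    import copy

    import modelopt.torch.quantization as mtq
    from example_utils import load_mtp_weights_if_needed  # examples/llm_ptq
    from transformers import AutoModelForCausalLM

    model_path = "path/to/glm-4.7-checkpoint"  # hypothetical local path
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto")

    # Returns e.g. ["model.layers.92"]; also stashes model._mtp_files_info for export.
    mtp_prefixes = load_mtp_weights_if_needed(model, model_path)

    # Hypothetical wiring: exclude the MTP layers via wildcard patterns, mirroring
    # the "*lm_head*": {"enable": False} idiom modelopt configs already use.
    quant_cfg = copy.deepcopy(mtq.FP8_DEFAULT_CFG)
    for prefix in mtp_prefixes:
        quant_cfg["quant_cfg"][f"*{prefix}*"] = {"enable": False}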
62 changes: 62 additions & 0 deletions modelopt/torch/export/unified_export_hf.py
@@ -18,6 +18,7 @@
import collections.abc
import json
import re
import shutil
import tempfile
import warnings
from builtins import ValueError
@@ -954,6 +955,64 @@ def _export_diffusers_checkpoint(
print(f"Export complete. Saved to: {export_dir}")


def _copy_mtp_files_if_needed(model: nn.Module, export_dir: Path) -> None:
"""Copy MTP (Multi-Token Prediction) safetensors files if they exist.

Some models like GLM-4.7 have MTP layers stored in separate safetensors files
(e.g., mtp.safetensors) that aren't part of the model's state_dict because
HuggingFace Transformers doesn't create the corresponding modules.

This function copies those files to the export directory and updates the
model.safetensors.index.json to include the MTP weights.

Args:
model: The model being exported (may have _mtp_files_info attribute)
export_dir: The export directory path
"""
mtp_files_info = getattr(model, "_mtp_files_info", None)
if not mtp_files_info:
return

export_dir = Path(export_dir)
index_file = export_dir / "model.safetensors.index.json"

# Load existing index if present
if index_file.exists():
with open(index_file) as f:
index_data = json.load(f)
else:
# Create a basic index structure if it doesn't exist
index_data = {"metadata": {}, "weight_map": {}}

# Copy each MTP file and update the index
for mtp_info in mtp_files_info:
source_path = Path(mtp_info["source_path"])
filename = mtp_info["filename"]
weight_map = mtp_info["weight_map"]

if not source_path.exists():
print(f"Warning: MTP source file not found: {source_path}")
continue

dest_path = export_dir / filename

# Copy the file
print(f"Copying MTP file: {filename}")
shutil.copy2(source_path, dest_path)

# Update the weight map in the index
for weight_name, file_name in weight_map.items():
index_data["weight_map"][weight_name] = file_name

print(f"✓ Copied {filename} with {len(weight_map)} weights")

# Write updated index
with open(index_file, "w") as f:
json.dump(index_data, f, indent=2)

print("✓ Updated model.safetensors.index.json with MTP weights")


def export_hf_checkpoint(
model: Any,
dtype: torch.dtype | None = None,
@@ -1019,6 +1078,9 @@ def export_hf_checkpoint(
save_modelopt_state=save_modelopt_state,
)

# Copy MTP files if present (e.g., GLM-4.7 mtp.safetensors)
_copy_mtp_files_if_needed(model, export_dir)

original_config = f"{export_dir}/config.json"
config_data = {}

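For context, _copy_mtp_files_if_needed leaves the export directory with an index of roughly this shape. A minimal sketch, assuming a two-shard export; the tensor and shard names below are illustrative, not taken from an actual GLM-4.7 checkpoint:

    import json

    index_data = {
        "metadata": {},
        "weight_map": {
            # Regular shards written by the exporter:
            "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
            # MTP entries merged in from model._mtp_files_info, pointing at the
            # file copied verbatim into the export directory:
            "model.layers.92.eh_proj.weight": "mtp.safetensors",
        },
    }

    with open("model.safetensors.index.json", "w") as f:
        json.dump(index_data, f, indent=2)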