155 changes: 138 additions & 17 deletions modelopt/torch/export/unified_export_hf.py
@@ -28,7 +28,7 @@

import torch
import torch.nn as nn
from safetensors.torch import save_file
from safetensors.torch import load_file, safe_open, save_file

try:
import diffusers
@@ -111,21 +111,128 @@ def _is_enabled_quantizer(quantizer):
return False


def _merge_diffusion_transformer_with_non_transformer_components(
Contributor comment:
For now, this seems to work only for LTX2.

Are these mapping relationships hard-coded? If so, we should move this logic into a model-dependent function, for example:

model_type = LTX2
merge_function[LTX2](...)
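
A minimal sketch of that registry, assuming the LTX2-specific logic stays in the helper this PR adds; MERGE_FUNCTIONS, MergeFn, and merge_transformer_for_model are illustrative names, not part of the PR:

# Illustrative only: a model-type -> merge-function registry, as suggested above.
from collections.abc import Callable

import torch

MergeFn = Callable[
    [dict[str, torch.Tensor], str], tuple[dict[str, torch.Tensor], dict[str, str]]
]

# Hypothetical registry; LTX2 would reuse the helper introduced in this PR.
MERGE_FUNCTIONS: dict[str, MergeFn] = {
    "ltx2": _merge_diffusion_transformer_with_non_transformer_components,
}


def merge_transformer_for_model(
    model_type: str,
    state_dict: dict[str, torch.Tensor],
    base_safetensor_path: str,
) -> tuple[dict[str, torch.Tensor], dict[str, str]]:
    """Dispatch to the model-specific merge routine; fail loudly for unknown model types."""
    if model_type not in MERGE_FUNCTIONS:
        raise ValueError(f"No merge function registered for model type {model_type!r}")
    return MERGE_FUNCTIONS[model_type](state_dict, base_safetensor_path)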


diffusion_transformer_state_dict: dict[str, torch.Tensor],
merged_base_safetensor_path: str,
) -> tuple[dict[str, torch.Tensor], dict[str, str]]:
"""Merge diffusion transformer weights with non-transformer components from a safetensors file.

Non-transformer components (VAE, vocoder, text encoders) and embeddings connectors are
taken from the base checkpoint. Transformer keys are prefixed with 'model.diffusion_model.'
for ComfyUI compatibility.

Args:
diffusion_transformer_state_dict: The diffusion transformer state dict (already on CPU).
merged_base_safetensor_path: Path to the full base model safetensors file containing
all components (transformer, VAE, vocoder, etc.).

Returns:
Tuple of (merged_state_dict, base_metadata) where base_metadata is the original
safetensors metadata from the base checkpoint.
"""
base_state = load_file(merged_base_safetensor_path)

non_transformer_prefixes = [
"vae.",
"audio_vae.",
"vocoder.",
"text_embedding_projection.",
"text_encoders.",
"first_stage_model.",
"cond_stage_model.",
"conditioner.",
]
correct_prefix = "model.diffusion_model."
strip_prefixes = ["diffusion_model.", "transformer.", "_orig_mod.", "model.", "velocity_model."]

base_non_transformer = {
k: v
for k, v in base_state.items()
if any(k.startswith(p) for p in non_transformer_prefixes)
}
base_connectors = {
k: v
for k, v in base_state.items()
if "embeddings_connector" in k and k.startswith(correct_prefix)
}

prefixed = {}
for k, v in diffusion_transformer_state_dict.items():
clean_k = k
for prefix in strip_prefixes:
if clean_k.startswith(prefix):
clean_k = clean_k[len(prefix) :]
break
prefixed[f"{correct_prefix}{clean_k}"] = v

merged = dict(base_non_transformer)
merged.update(base_connectors)
merged.update(prefixed)
with safe_open(merged_base_safetensor_path, framework="pt", device="cpu") as f:
base_metadata = f.metadata() or {}

del base_state
return merged, base_metadata


def _save_component_state_dict_safetensors(
component: nn.Module, component_export_dir: Path
component: nn.Module,
component_export_dir: Path,
merged_base_safetensor_path: str | None = None,
hf_quant_config: dict | None = None,
) -> None:
"""Save component state dict as safetensors with optional base checkpoint merge.

Args:
component: The nn.Module to save.
component_export_dir: Directory to save model.safetensors and config.json.
merged_base_safetensor_path: If provided, merge with non-transformer components
from this base safetensors file.
hf_quant_config: If provided, embed quantization config in safetensors metadata
and per-layer _quantization_metadata for ComfyUI.
"""
cpu_state_dict = {k: v.detach().contiguous().cpu() for k, v in component.state_dict().items()}
save_file(cpu_state_dict, str(component_export_dir / "model.safetensors"))
with open(component_export_dir / "config.json", "w") as f:
json.dump(
metadata: dict[str, str] = {}
metadata_full: dict[str, str] = {}
if merged_base_safetensor_path is not None:
cpu_state_dict, metadata_full = (
_merge_diffusion_transformer_with_non_transformer_components(
Contributor comment:

As noted above, this function should be model-dependent

merge_function[model_type](...)
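
At this call site, the dispatch could then look roughly like the following (model_type and merge_transformer_for_model are carried over from the sketch above and are not part of this PR):

if merged_base_safetensor_path is not None:
    cpu_state_dict, metadata_full = merge_transformer_for_model(
        model_type, cpu_state_dict, merged_base_safetensor_path
    )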

cpu_state_dict, merged_base_safetensor_path
)
)
metadata["_export_format"] = "safetensors_state_dict"
metadata["_class_name"] = type(component).__name__

if hf_quant_config is not None:
Contributor comment:

We should add more checks to make this safer, for example:
if hf_quant_config is not None and merged_base_safetensor_path is not None:
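
One possible tightening, sketched as an assumption about what "more checks" could mean rather than a concrete request from the review (it reuses pathlib.Path, which the module already imports):

if hf_quant_config is not None and merged_base_safetensor_path is not None:
    if not Path(merged_base_safetensor_path).is_file():
        raise FileNotFoundError(
            f"Base safetensors checkpoint not found: {merged_base_safetensor_path}"
        )
    # ... existing metadata-building body continues as below ...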

metadata_full["quantization_config"] = json.dumps(hf_quant_config)

# Build per-layer _quantization_metadata for ComfyUI
quant_algo = hf_quant_config.get("quant_algo", "unknown").lower()
layer_metadata = {}
for k in cpu_state_dict:
if k.endswith((".weight_scale", ".weight_scale_2")):
layer_name = k.rsplit(".", 1)[0]
if layer_name.endswith(".weight"):
layer_name = layer_name.rsplit(".", 1)[0]
if layer_name not in layer_metadata:
layer_metadata[layer_name] = {"format": quant_algo}
metadata_full["_quantization_metadata"] = json.dumps(
{
"_class_name": type(component).__name__,
"_export_format": "safetensors_state_dict",
},
f,
indent=4,
"format_version": "1.0",
"layers": layer_metadata,
}
)

metadata_full.update(metadata)
save_file(
cpu_state_dict,
str(component_export_dir / "model.safetensors"),
metadata=metadata_full if merged_base_safetensor_path is not None else None,
)

with open(component_export_dir / "config.json", "w") as f:
json.dump(metadata, f, indent=4)


def _collect_shared_input_modules(
model: nn.Module,
@@ -807,6 +914,7 @@ def _export_diffusers_checkpoint(
dtype: torch.dtype | None,
export_dir: Path,
components: list[str] | None,
merged_base_safetensor_path: str | None = None,
max_shard_size: int | str = "10GB",
) -> None:
"""Internal: Export diffusion(-like) model/pipeline checkpoint.
@@ -821,6 +929,8 @@
export_dir: The directory to save the exported checkpoint.
components: Optional list of component names to export. Only used for pipelines.
If None, all components are exported.
merged_base_safetensor_path: If provided, merge the exported transformer with
non-transformer components from this base safetensors file.
max_shard_size: Maximum size of each shard file. If the model exceeds this size,
it will be sharded into multiple files and a .safetensors.index.json will be
created. Use smaller values like "5GB" or "2GB" to force sharding.
Expand Down Expand Up @@ -879,6 +989,7 @@ def _export_diffusers_checkpoint(

# Step 5: Build quantization config
quant_config = get_quant_config(component, is_modelopt_qlora=False)
hf_quant_config = convert_hf_quant_config_format(quant_config) if quant_config else None

# Step 6: Save the component
# - diffusers ModelMixin.save_pretrained does NOT accept state_dict parameter
@@ -888,12 +999,14 @@
component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
else:
with hide_quantizers_from_state_dict(component):
_save_component_state_dict_safetensors(component, component_export_dir)

_save_component_state_dict_safetensors(
component,
component_export_dir,
merged_base_safetensor_path,
hf_quant_config,
)
# Step 7: Update config.json with quantization info
if quant_config is not None:
hf_quant_config = convert_hf_quant_config_format(quant_config)

if hf_quant_config is not None:
config_path = component_export_dir / "config.json"
if config_path.exists():
with open(config_path) as file:
@@ -905,7 +1018,9 @@
elif hasattr(component, "save_pretrained"):
component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
else:
_save_component_state_dict_safetensors(component, component_export_dir)
_save_component_state_dict_safetensors(
component, component_export_dir, merged_base_safetensor_path
)

print(f" Saved to: {component_export_dir}")

@@ -985,6 +1100,7 @@ def export_hf_checkpoint(
save_modelopt_state: bool = False,
components: list[str] | None = None,
extra_state_dict: dict[str, torch.Tensor] | None = None,
merged_base_safetensor_path: str | None = None,
):
"""Export quantized HuggingFace model checkpoint (transformers or diffusers).

@@ -1002,6 +1118,9 @@
components: Only used for diffusers pipelines. Optional list of component names
to export. If None, all quantized components are exported.
extra_state_dict: Extra state dictionary to add to the exported model.
merged_base_safetensor_path: If provided, merge the exported diffusion transformer
with non-transformer components (VAE, vocoder, etc.) from this base safetensors
file. Only used for diffusion model exports (e.g., LTX-2).
"""
export_dir = Path(export_dir)
export_dir.mkdir(parents=True, exist_ok=True)
@@ -1010,7 +1129,9 @@
if HAS_DIFFUSERS:
is_diffusers_obj = is_diffusers_object(model)
if is_diffusers_obj:
_export_diffusers_checkpoint(model, dtype, export_dir, components)
_export_diffusers_checkpoint(
model, dtype, export_dir, components, merged_base_safetensor_path
)
return

# Transformers model export