From 4f61dd7710a6477a595879a2ea86e4bf1c6b9a62 Mon Sep 17 00:00:00 2001
From: ynankani
Date: Fri, 20 Feb 2026 05:34:43 -0800
Subject: [PATCH 1/2] Add support for exporting ComfyUI-compatible checkpoints
 for diffusion models (e.g., LTX-2)

Signed-off-by: ynankani
---
 modelopt/torch/export/unified_export_hf.py | 136 ++++++++++++++++++---
 1 file changed, 117 insertions(+), 19 deletions(-)

diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index ca80cb450..983e18c53 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -28,7 +28,7 @@
 import torch
 import torch.nn as nn
-from safetensors.torch import save_file
+from safetensors.torch import save_file, load_file, safe_open
 
 try:
     import diffusers
@@ -111,20 +111,108 @@ def _is_enabled_quantizer(quantizer):
     return False
 
 
+def _merge_diffusion_transformer_with_non_transformer_components(
+    diffusion_transformer_state_dict: dict[str, torch.Tensor],
+    merged_base_safetensor_path: str,
+) -> tuple[dict[str, torch.Tensor], dict[str, str]]:
+    """Merge diffusion transformer weights with non-transformer components from a safetensors file.
+
+    Non-transformer components (VAE, vocoder, text encoders) and embeddings-connector
+    weights are taken from the base checkpoint. Transformer keys are prefixed with
+    'model.diffusion_model.' for ComfyUI compatibility.
+
+    Args:
+        diffusion_transformer_state_dict: The diffusion transformer state dict (already on CPU).
+        merged_base_safetensor_path: Path to the full base model safetensors file containing
+            all components (transformer, VAE, vocoder, etc.).
+
+    Returns:
+        Tuple of (merged_state_dict, base_metadata) where base_metadata is the original
+        safetensors metadata from the base checkpoint.
+    """
+
+    base_state = load_file(merged_base_safetensor_path)
+
+    non_transformer_prefixes = [
+        'vae.', 'audio_vae.', 'vocoder.', 'text_embedding_projection.',
+        'text_encoders.', 'first_stage_model.', 'cond_stage_model.', 'conditioner.',
+    ]
+    correct_prefix = 'model.diffusion_model.'
+    strip_prefixes = ['diffusion_model.', 'transformer.', '_orig_mod.', 'model.', 'velocity_model.']
+
+    base_non_transformer = {k: v for k, v in base_state.items()
+                            if any(k.startswith(p) for p in non_transformer_prefixes)}
+    base_connectors = {k: v for k, v in base_state.items()
+                       if 'embeddings_connector' in k and k.startswith(correct_prefix)}
+
+    prefixed = {}
+    for k, v in diffusion_transformer_state_dict.items():
+        clean_k = k
+        for prefix in strip_prefixes:
+            if clean_k.startswith(prefix):
+                clean_k = clean_k[len(prefix):]
+                break
+        prefixed[f"{correct_prefix}{clean_k}"] = v
+
+    merged = dict(base_non_transformer)
+    merged.update(base_connectors)
+    merged.update(prefixed)
+    with safe_open(merged_base_safetensor_path, framework="pt", device="cpu") as f:
+        base_metadata = f.metadata() or {}
+
+    del base_state
+    return merged, base_metadata
+
+
 def _save_component_state_dict_safetensors(
-    component: nn.Module, component_export_dir: Path
+    component: nn.Module,
+    component_export_dir: Path,
+    merged_base_safetensor_path: str | None = None,
+    hf_quant_config: dict | None = None
 ) -> None:
+    """Save component state dict as safetensors with optional base checkpoint merge.
+
+    Args:
+        component: The nn.Module to save.
+        component_export_dir: Directory to save model.safetensors and config.json.
+        merged_base_safetensor_path: If provided, merge with non-transformer components
+            from this base safetensors file.
+ hf_quant_config: If provided, embed quantization config in safetensors metadata + and per-layer _quantization_metadata for ComfyUI. + """ cpu_state_dict = {k: v.detach().contiguous().cpu() for k, v in component.state_dict().items()} - save_file(cpu_state_dict, str(component_export_dir / "model.safetensors")) - with open(component_export_dir / "config.json", "w") as f: - json.dump( - { - "_class_name": type(component).__name__, - "_export_format": "safetensors_state_dict", - }, - f, - indent=4, + metadata: dict[str, str] = {} + metadata_full: dict[str, str] = {} + if merged_base_safetensor_path is not None: + cpu_state_dict, metadata_full = _merge_diffusion_transformer_with_non_transformer_components( + cpu_state_dict, merged_base_safetensor_path ) + metadata["_export_format"] = "safetensors_state_dict" + metadata["_class_name"] = type(component).__name__ + + if hf_quant_config is not None: + metadata_full["quantization_config"] = json.dumps(hf_quant_config) + + # Build per-layer _quantization_metadata for ComfyUI + quant_algo = hf_quant_config.get("quant_algo", "unknown").lower() + layer_metadata = {} + for k in cpu_state_dict: + if k.endswith(".weight_scale") or k.endswith(".weight_scale_2"): + layer_name = k.rsplit(".", 1)[0] + if layer_name.endswith(".weight"): + layer_name = layer_name.rsplit(".", 1)[0] + if layer_name not in layer_metadata: + layer_metadata[layer_name] = {"format": quant_algo} + metadata_full["_quantization_metadata"] = json.dumps({ + "format_version": "1.0", + "layers": layer_metadata, + }) + + metadata_full.update(metadata) + save_file(cpu_state_dict, str(component_export_dir / "model.safetensors"), metadata=metadata_full if merged_base_safetensor_path is not None else None) + + with open(component_export_dir / "config.json", "w") as f: + json.dump(metadata, f, indent=4) def _collect_shared_input_modules( @@ -807,6 +895,7 @@ def _export_diffusers_checkpoint( dtype: torch.dtype | None, export_dir: Path, components: list[str] | None, + merged_base_safetensor_path: str | None = None, max_shard_size: int | str = "10GB", ) -> None: """Internal: Export diffusion(-like) model/pipeline checkpoint. @@ -821,6 +910,8 @@ def _export_diffusers_checkpoint( export_dir: The directory to save the exported checkpoint. components: Optional list of component names to export. Only used for pipelines. If None, all components are exported. + merged_base_safetensor_path: If provided, merge the exported transformer with + non-transformer components from this base safetensors file. max_shard_size: Maximum size of each shard file. If the model exceeds this size, it will be sharded into multiple files and a .safetensors.index.json will be created. Use smaller values like "5GB" or "2GB" to force sharding. 
@@ -879,7 +970,8 @@ def _export_diffusers_checkpoint( # Step 5: Build quantization config quant_config = get_quant_config(component, is_modelopt_qlora=False) - + hf_quant_config = convert_hf_quant_config_format(quant_config) if quant_config else None + # Step 6: Save the component # - diffusers ModelMixin.save_pretrained does NOT accept state_dict parameter # - for non-diffusers modules (e.g., LTX-2 transformer), fall back to torch.save @@ -888,12 +980,14 @@ def _export_diffusers_checkpoint( component.save_pretrained(component_export_dir, max_shard_size=max_shard_size) else: with hide_quantizers_from_state_dict(component): - _save_component_state_dict_safetensors(component, component_export_dir) - + _save_component_state_dict_safetensors( + component, + component_export_dir, + merged_base_safetensor_path, + hf_quant_config, + ) # Step 7: Update config.json with quantization info - if quant_config is not None: - hf_quant_config = convert_hf_quant_config_format(quant_config) - + if hf_quant_config is not None: config_path = component_export_dir / "config.json" if config_path.exists(): with open(config_path) as file: @@ -905,7 +999,7 @@ def _export_diffusers_checkpoint( elif hasattr(component, "save_pretrained"): component.save_pretrained(component_export_dir, max_shard_size=max_shard_size) else: - _save_component_state_dict_safetensors(component, component_export_dir) + _save_component_state_dict_safetensors(component, component_export_dir, merged_base_safetensor_path) print(f" Saved to: {component_export_dir}") @@ -985,6 +1079,7 @@ def export_hf_checkpoint( save_modelopt_state: bool = False, components: list[str] | None = None, extra_state_dict: dict[str, torch.Tensor] | None = None, + merged_base_safetensor_path: str | None = None, ): """Export quantized HuggingFace model checkpoint (transformers or diffusers). @@ -1002,6 +1097,9 @@ def export_hf_checkpoint( components: Only used for diffusers pipelines. Optional list of component names to export. If None, all quantized components are exported. extra_state_dict: Extra state dictionary to add to the exported model. + merged_base_safetensor_path: If provided, merge the exported diffusion transformer + with non-transformer components (VAE, vocoder, etc.) from this base safetensors + file. Only used for diffusion model exports (e.g., LTX-2). 
""" export_dir = Path(export_dir) export_dir.mkdir(parents=True, exist_ok=True) @@ -1010,7 +1108,7 @@ def export_hf_checkpoint( if HAS_DIFFUSERS: is_diffusers_obj = is_diffusers_object(model) if is_diffusers_obj: - _export_diffusers_checkpoint(model, dtype, export_dir, components) + _export_diffusers_checkpoint(model, dtype, export_dir, components, merged_base_safetensor_path) return # Transformers model export From 69107c0e0a5c0418e1eceee9d9d26ed93513c956 Mon Sep 17 00:00:00 2001 From: ynankani Date: Fri, 20 Feb 2026 06:07:40 -0800 Subject: [PATCH 2/2] Add support for export comfyui compatible checkpoint for diffusion model(e.g., LTX-2) Signed-off-by: ynankani --- modelopt/torch/export/unified_export_hf.py | 83 ++++++++++++++-------- 1 file changed, 53 insertions(+), 30 deletions(-) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 983e18c53..bd6df260c 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -28,7 +28,7 @@ import torch import torch.nn as nn -from safetensors.torch import save_file, load_file, safe_open +from safetensors.torch import load_file, safe_open, save_file try: import diffusers @@ -130,27 +130,38 @@ def _merge_diffusion_transformer_with_non_transformer_components( Tuple of (merged_state_dict, base_metadata) where base_metadata is the original safetensors metadata from the base checkpoint. """ - base_state = load_file(merged_base_safetensor_path) non_transformer_prefixes = [ - 'vae.', 'audio_vae.', 'vocoder.', 'text_embedding_projection.', - 'text_encoders.', 'first_stage_model.', 'cond_stage_model.', 'conditioner.', + "vae.", + "audio_vae.", + "vocoder.", + "text_embedding_projection.", + "text_encoders.", + "first_stage_model.", + "cond_stage_model.", + "conditioner.", ] - correct_prefix = 'model.diffusion_model.' - strip_prefixes = ['diffusion_model.', 'transformer.', '_orig_mod.', 'model.', 'velocity_model.'] + correct_prefix = "model.diffusion_model." + strip_prefixes = ["diffusion_model.", "transformer.", "_orig_mod.", "model.", "velocity_model."] - base_non_transformer = {k: v for k, v in base_state.items() - if any(k.startswith(p) for p in non_transformer_prefixes)} - base_connectors = {k: v for k, v in base_state.items() - if 'embeddings_connector' in k and k.startswith(correct_prefix)} + base_non_transformer = { + k: v + for k, v in base_state.items() + if any(k.startswith(p) for p in non_transformer_prefixes) + } + base_connectors = { + k: v + for k, v in base_state.items() + if "embeddings_connector" in k and k.startswith(correct_prefix) + } prefixed = {} for k, v in diffusion_transformer_state_dict.items(): clean_k = k for prefix in strip_prefixes: if clean_k.startswith(prefix): - clean_k = clean_k[len(prefix):] + clean_k = clean_k[len(prefix) :] break prefixed[f"{correct_prefix}{clean_k}"] = v @@ -165,10 +176,10 @@ def _merge_diffusion_transformer_with_non_transformer_components( def _save_component_state_dict_safetensors( - component: nn.Module, - component_export_dir: Path, - merged_base_safetensor_path: str | None = None, - hf_quant_config: dict | None = None + component: nn.Module, + component_export_dir: Path, + merged_base_safetensor_path: str | None = None, + hf_quant_config: dict | None = None, ) -> None: """Save component state dict as safetensors with optional base checkpoint merge. 
@@ -184,10 +195,12 @@ def _save_component_state_dict_safetensors( metadata: dict[str, str] = {} metadata_full: dict[str, str] = {} if merged_base_safetensor_path is not None: - cpu_state_dict, metadata_full = _merge_diffusion_transformer_with_non_transformer_components( - cpu_state_dict, merged_base_safetensor_path + cpu_state_dict, metadata_full = ( + _merge_diffusion_transformer_with_non_transformer_components( + cpu_state_dict, merged_base_safetensor_path + ) ) - metadata["_export_format"] = "safetensors_state_dict" + metadata["_export_format"] = "safetensors_state_dict" metadata["_class_name"] = type(component).__name__ if hf_quant_config is not None: @@ -197,20 +210,26 @@ def _save_component_state_dict_safetensors( quant_algo = hf_quant_config.get("quant_algo", "unknown").lower() layer_metadata = {} for k in cpu_state_dict: - if k.endswith(".weight_scale") or k.endswith(".weight_scale_2"): + if k.endswith((".weight_scale", ".weight_scale_2")): layer_name = k.rsplit(".", 1)[0] if layer_name.endswith(".weight"): layer_name = layer_name.rsplit(".", 1)[0] if layer_name not in layer_metadata: layer_metadata[layer_name] = {"format": quant_algo} - metadata_full["_quantization_metadata"] = json.dumps({ - "format_version": "1.0", - "layers": layer_metadata, - }) + metadata_full["_quantization_metadata"] = json.dumps( + { + "format_version": "1.0", + "layers": layer_metadata, + } + ) metadata_full.update(metadata) - save_file(cpu_state_dict, str(component_export_dir / "model.safetensors"), metadata=metadata_full if merged_base_safetensor_path is not None else None) - + save_file( + cpu_state_dict, + str(component_export_dir / "model.safetensors"), + metadata=metadata_full if merged_base_safetensor_path is not None else None, + ) + with open(component_export_dir / "config.json", "w") as f: json.dump(metadata, f, indent=4) @@ -971,7 +990,7 @@ def _export_diffusers_checkpoint( # Step 5: Build quantization config quant_config = get_quant_config(component, is_modelopt_qlora=False) hf_quant_config = convert_hf_quant_config_format(quant_config) if quant_config else None - + # Step 6: Save the component # - diffusers ModelMixin.save_pretrained does NOT accept state_dict parameter # - for non-diffusers modules (e.g., LTX-2 transformer), fall back to torch.save @@ -981,8 +1000,8 @@ def _export_diffusers_checkpoint( else: with hide_quantizers_from_state_dict(component): _save_component_state_dict_safetensors( - component, - component_export_dir, + component, + component_export_dir, merged_base_safetensor_path, hf_quant_config, ) @@ -999,7 +1018,9 @@ def _export_diffusers_checkpoint( elif hasattr(component, "save_pretrained"): component.save_pretrained(component_export_dir, max_shard_size=max_shard_size) else: - _save_component_state_dict_safetensors(component, component_export_dir, merged_base_safetensor_path) + _save_component_state_dict_safetensors( + component, component_export_dir, merged_base_safetensor_path + ) print(f" Saved to: {component_export_dir}") @@ -1108,7 +1129,9 @@ def export_hf_checkpoint( if HAS_DIFFUSERS: is_diffusers_obj = is_diffusers_object(model) if is_diffusers_obj: - _export_diffusers_checkpoint(model, dtype, export_dir, components, merged_base_safetensor_path) + _export_diffusers_checkpoint( + model, dtype, export_dir, components, merged_base_safetensor_path + ) return # Transformers model export
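
A minimal usage sketch of the new export path, under stated assumptions: `pipe` is
an LTX-2(-like) pipeline already quantized with ModelOpt, both file paths are
placeholders, `export_hf_checkpoint` is assumed to be importable from
`modelopt.torch.export`, and the transformer component is assumed to land in
`<export_dir>/transformer/model.safetensors`. None of these names are fixed by the
patch itself.

    import json

    from safetensors.torch import safe_open

    from modelopt.torch.export import export_hf_checkpoint

    # "pipe" is a diffusion pipeline that was quantized earlier with ModelOpt
    # (assumption); both paths below are placeholders.
    export_hf_checkpoint(
        pipe,
        export_dir="./ltx2_comfyui_export",
        components=["transformer"],
        merged_base_safetensor_path="/path/to/ltx2_full_base.safetensors",
    )

    # The merged checkpoint is self-describing: quantization info is stored in
    # the safetensors header, so a ComfyUI-side loader can recover the per-layer
    # quantization format without a sidecar config file.
    ckpt = "./ltx2_comfyui_export/transformer/model.safetensors"
    with safe_open(ckpt, framework="pt", device="cpu") as f:
        meta = f.metadata()
    quant_meta = json.loads(meta["_quantization_metadata"])
    print(quant_meta["format_version"], len(quant_meta["layers"]))

Embedding `_quantization_metadata` in the safetensors header, rather than only in
config.json, keeps the merged file self-contained, which suits ComfyUI's
single-file checkpoint loading.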