Add support for exporting ComfyUI-compatible checkpoints for diffusion models (e.g., LTX-2) #911
@@ -28,7 +28,7 @@
 import torch
 import torch.nn as nn
-from safetensors.torch import save_file
+from safetensors.torch import load_file, safe_open, save_file

 try:
     import diffusers
@@ -111,21 +111,128 @@ def _is_enabled_quantizer(quantizer):
     return False


+def _merge_diffusion_transformer_with_non_transformer_components(
+    diffusion_transformer_state_dict: dict[str, torch.Tensor],
+    merged_base_safetensor_path: str,
+) -> tuple[dict[str, torch.Tensor], dict[str, str]]:
+    """Merge diffusion transformer weights with non-transformer components from a safetensors file.
+
+    Non-transformer components (VAE, vocoder, text encoders) and embeddings connectors are
+    taken from the base checkpoint. Transformer keys are prefixed with 'model.diffusion_model.'
+    for ComfyUI compatibility.
+
+    Args:
+        diffusion_transformer_state_dict: The diffusion transformer state dict (already on CPU).
+        merged_base_safetensor_path: Path to the full base model safetensors file containing
+            all components (transformer, VAE, vocoder, etc.).
+
+    Returns:
+        Tuple of (merged_state_dict, base_metadata), where base_metadata is the original
+        safetensors metadata from the base checkpoint.
+    """
+    base_state = load_file(merged_base_safetensor_path)
+
+    non_transformer_prefixes = [
+        "vae.",
+        "audio_vae.",
+        "vocoder.",
+        "text_embedding_projection.",
+        "text_encoders.",
+        "first_stage_model.",
+        "cond_stage_model.",
+        "conditioner.",
+    ]
+    correct_prefix = "model.diffusion_model."
+    strip_prefixes = ["diffusion_model.", "transformer.", "_orig_mod.", "model.", "velocity_model."]
+
+    base_non_transformer = {
+        k: v
+        for k, v in base_state.items()
+        if any(k.startswith(p) for p in non_transformer_prefixes)
+    }
+    base_connectors = {
+        k: v
+        for k, v in base_state.items()
+        if "embeddings_connector" in k and k.startswith(correct_prefix)
+    }
+
+    prefixed = {}
+    for k, v in diffusion_transformer_state_dict.items():
+        clean_k = k
+        for prefix in strip_prefixes:
+            if clean_k.startswith(prefix):
+                clean_k = clean_k[len(prefix) :]
+                break
+        prefixed[f"{correct_prefix}{clean_k}"] = v
+
+    merged = dict(base_non_transformer)
+    merged.update(base_connectors)
+    merged.update(prefixed)
+    with safe_open(merged_base_safetensor_path, framework="pt", device="cpu") as f:
+        base_metadata = f.metadata() or {}
+
+    del base_state
+    return merged, base_metadata
+
+
 def _save_component_state_dict_safetensors(
-    component: nn.Module, component_export_dir: Path
+    component: nn.Module,
+    component_export_dir: Path,
+    merged_base_safetensor_path: str | None = None,
+    hf_quant_config: dict | None = None,
 ) -> None:
+    """Save component state dict as safetensors with optional base checkpoint merge.
+
+    Args:
+        component: The nn.Module to save.
+        component_export_dir: Directory to save model.safetensors and config.json.
+        merged_base_safetensor_path: If provided, merge with non-transformer components
+            from this base safetensors file.
+        hf_quant_config: If provided, embed quantization config in safetensors metadata
+            and per-layer _quantization_metadata for ComfyUI.
+    """
     cpu_state_dict = {k: v.detach().contiguous().cpu() for k, v in component.state_dict().items()}
-    save_file(cpu_state_dict, str(component_export_dir / "model.safetensors"))
-    with open(component_export_dir / "config.json", "w") as f:
-        json.dump(
-            {
-                "_class_name": type(component).__name__,
-                "_export_format": "safetensors_state_dict",
-            },
-            f,
-            indent=4,
-        )
+    metadata: dict[str, str] = {}
+    metadata_full: dict[str, str] = {}
+    if merged_base_safetensor_path is not None:
+        cpu_state_dict, metadata_full = (
+            _merge_diffusion_transformer_with_non_transformer_components(
Contributor

As noted above, this function should be model-dependent.
+                cpu_state_dict, merged_base_safetensor_path
+            )
+        )
+    metadata["_export_format"] = "safetensors_state_dict"
+    metadata["_class_name"] = type(component).__name__
+
+    if hf_quant_config is not None:
Contributor

Should add more checks to make it safer.
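One possible reading of this request, as a minimal validation sketch (the helper name and the specific checks are illustrative only, not part of the PR):

import warnings

def _validate_hf_quant_config(hf_quant_config: dict) -> None:
    # Hypothetical guard before embedding the config into safetensors metadata.
    if not isinstance(hf_quant_config, dict):
        raise TypeError(f"hf_quant_config must be a dict, got {type(hf_quant_config).__name__}")
    if "quant_algo" not in hf_quant_config:
        warnings.warn("hf_quant_config has no 'quant_algo'; per-layer format falls back to 'unknown'")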
+        metadata_full["quantization_config"] = json.dumps(hf_quant_config)
+
+        # Build per-layer _quantization_metadata for ComfyUI
+        quant_algo = hf_quant_config.get("quant_algo", "unknown").lower()
+        layer_metadata = {}
+        for k in cpu_state_dict:
+            if k.endswith((".weight_scale", ".weight_scale_2")):
+                layer_name = k.rsplit(".", 1)[0]
+                if layer_name.endswith(".weight"):
+                    layer_name = layer_name.rsplit(".", 1)[0]
+                if layer_name not in layer_metadata:
+                    layer_metadata[layer_name] = {"format": quant_algo}
+        metadata_full["_quantization_metadata"] = json.dumps(
+            {
+                "format_version": "1.0",
+                "layers": layer_metadata,
+            }
+        )
+
+    metadata_full.update(metadata)
+    save_file(
+        cpu_state_dict,
+        str(component_export_dir / "model.safetensors"),
+        metadata=metadata_full if merged_base_safetensor_path is not None else None,
+    )
+
+    with open(component_export_dir / "config.json", "w") as f:
+        json.dump(metadata, f, indent=4)


 def _collect_shared_input_modules(
     model: nn.Module,
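To make the key remapping in this hunk concrete, here is a minimal standalone sketch of the prefix stripping and re-prefixing performed by _merge_diffusion_transformer_with_non_transformer_components (the state-dict keys are hypothetical, not taken from a real LTX-2 checkpoint):

import torch

strip_prefixes = ["diffusion_model.", "transformer.", "_orig_mod.", "model.", "velocity_model."]
correct_prefix = "model.diffusion_model."

# Hypothetical quantized transformer state dict as it might come out of modelopt.
state_dict = {
    "transformer.blocks.0.attn.to_q.weight": torch.zeros(4, 4),
    "transformer.blocks.0.attn.to_q.weight_scale": torch.ones(1),
}

remapped = {}
for k, v in state_dict.items():
    clean_k = k
    for prefix in strip_prefixes:
        if clean_k.startswith(prefix):
            clean_k = clean_k[len(prefix):]  # strip only the first matching prefix
            break
    remapped[correct_prefix + clean_k] = v

print(sorted(remapped))
# ['model.diffusion_model.blocks.0.attn.to_q.weight',
#  'model.diffusion_model.blocks.0.attn.to_q.weight_scale']

Following the per-layer metadata logic above, with hf_quant_config = {"quant_algo": "FP8"} the _quantization_metadata block for this example would record the layer "model.diffusion_model.blocks.0.attn.to_q" with {"format": "fp8"}.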
|
|
@@ -807,6 +914,7 @@ def _export_diffusers_checkpoint(
     dtype: torch.dtype | None,
     export_dir: Path,
     components: list[str] | None,
+    merged_base_safetensor_path: str | None = None,
     max_shard_size: int | str = "10GB",
 ) -> None:
     """Internal: Export diffusion(-like) model/pipeline checkpoint.
@@ -821,6 +929,8 @@ def _export_diffusers_checkpoint(
         export_dir: The directory to save the exported checkpoint.
         components: Optional list of component names to export. Only used for pipelines.
             If None, all components are exported.
+        merged_base_safetensor_path: If provided, merge the exported transformer with
+            non-transformer components from this base safetensors file.
         max_shard_size: Maximum size of each shard file. If the model exceeds this size,
             it will be sharded into multiple files and a .safetensors.index.json will be
             created. Use smaller values like "5GB" or "2GB" to force sharding.
@@ -879,6 +989,7 @@ def _export_diffusers_checkpoint(
         # Step 5: Build quantization config
         quant_config = get_quant_config(component, is_modelopt_qlora=False)
+        hf_quant_config = convert_hf_quant_config_format(quant_config) if quant_config else None

         # Step 6: Save the component
         # - diffusers ModelMixin.save_pretrained does NOT accept state_dict parameter
@@ -888,12 +999,14 @@ def _export_diffusers_checkpoint(
                 component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
             else:
                 with hide_quantizers_from_state_dict(component):
-                    _save_component_state_dict_safetensors(component, component_export_dir)
+                    _save_component_state_dict_safetensors(
+                        component,
+                        component_export_dir,
+                        merged_base_safetensor_path,
+                        hf_quant_config,
+                    )
             # Step 7: Update config.json with quantization info
-            if quant_config is not None:
-                hf_quant_config = convert_hf_quant_config_format(quant_config)
+            if hf_quant_config is not None:
                 config_path = component_export_dir / "config.json"
                 if config_path.exists():
                     with open(config_path) as file:
@@ -905,7 +1018,9 @@ def _export_diffusers_checkpoint(
         elif hasattr(component, "save_pretrained"):
             component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
         else:
-            _save_component_state_dict_safetensors(component, component_export_dir)
+            _save_component_state_dict_safetensors(
+                component, component_export_dir, merged_base_safetensor_path
+            )

         print(f"  Saved to: {component_export_dir}")
@@ -985,6 +1100,7 @@ def export_hf_checkpoint(
     save_modelopt_state: bool = False,
     components: list[str] | None = None,
     extra_state_dict: dict[str, torch.Tensor] | None = None,
+    merged_base_safetensor_path: str | None = None,
 ):
     """Export quantized HuggingFace model checkpoint (transformers or diffusers).
@@ -1002,6 +1118,9 @@ def export_hf_checkpoint(
         components: Only used for diffusers pipelines. Optional list of component names
             to export. If None, all quantized components are exported.
         extra_state_dict: Extra state dictionary to add to the exported model.
+        merged_base_safetensor_path: If provided, merge the exported diffusion transformer
+            with non-transformer components (VAE, vocoder, etc.) from this base safetensors
+            file. Only used for diffusion model exports (e.g., LTX-2).
     """
     export_dir = Path(export_dir)
     export_dir.mkdir(parents=True, exist_ok=True)
@@ -1010,7 +1129,9 @@ def export_hf_checkpoint(
     if HAS_DIFFUSERS:
         is_diffusers_obj = is_diffusers_object(model)
         if is_diffusers_obj:
-            _export_diffusers_checkpoint(model, dtype, export_dir, components)
+            _export_diffusers_checkpoint(
+                model, dtype, export_dir, components, merged_base_safetensor_path
+            )
             return

     # Transformers model export
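For orientation, a hedged sketch of how the new parameter would reach this code path from the public entry point. The import path and pipeline setup are assumptions; only the keyword arguments shown in this diff are taken from the PR:

from modelopt.torch.export import export_hf_checkpoint  # import path assumed

# `pipe` is assumed to be an already-quantized diffusers pipeline (e.g., LTX-2);
# the base file is the full original checkpoint containing VAE, vocoder, etc.
export_hf_checkpoint(
    pipe,
    export_dir="./ltx2_comfyui_export",  # hypothetical output directory
    components=["transformer"],
    merged_base_safetensor_path="/path/to/ltx2_base.safetensors",  # hypothetical path
)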
Contributor

For now, this seems to work only for LTX-2. Are these mapping relationships hard-coded? If so, we should move this logic into a model-dependent function, for example:
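(The inline code example that accompanied this comment is not preserved in this view. A hypothetical sketch of the suggested shape, with the registry name and structure purely illustrative; the LTX-2 prefixes are copied from the hunk above:)

def get_comfyui_merge_rules(model_type: str) -> dict:
    # Hypothetical per-model registry replacing the prefixes hard-coded in this PR.
    rules = {
        "ltx2": {
            "non_transformer_prefixes": [
                "vae.", "audio_vae.", "vocoder.", "text_embedding_projection.",
                "text_encoders.", "first_stage_model.", "cond_stage_model.", "conditioner.",
            ],
            "transformer_prefix": "model.diffusion_model.",
            "strip_prefixes": [
                "diffusion_model.", "transformer.", "_orig_mod.", "model.", "velocity_model.",
            ],
        },
    }
    if model_type not in rules:
        raise ValueError(f"No ComfyUI merge rules registered for {model_type!r}")
    return rules[model_type]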
Contributor

And this function needs to be moved to https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt/torch/export/diffusers_utils.py