Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/diffusers/loaders/single_file_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
convert_hunyuan_video_transformer_to_diffusers,
convert_ldm_unet_checkpoint,
convert_ldm_vae_checkpoint,
convert_ltx2_audio_vae_to_diffusers,
convert_ltx2_transformer_to_diffusers,
convert_ltx2_vae_to_diffusers,
convert_ltx_transformer_checkpoint_to_diffusers,
convert_ltx_vae_checkpoint_to_diffusers,
convert_lumina2_to_diffusers,
Expand Down Expand Up @@ -176,6 +179,18 @@
"ZImageControlNetModel": {
"checkpoint_mapping_fn": convert_z_image_controlnet_checkpoint_to_diffusers,
},
"LTX2VideoTransformer3DModel": {
"checkpoint_mapping_fn": convert_ltx2_transformer_to_diffusers,
"default_subfolder": "transformer",
},
"AutoencoderKLLTX2Video": {
"checkpoint_mapping_fn": convert_ltx2_vae_to_diffusers,
"default_subfolder": "vae",
},
"AutoencoderKLLTX2Audio": {
"checkpoint_mapping_fn": convert_ltx2_audio_vae_to_diffusers,
"default_subfolder": "audio_vae",
},
}


Expand Down
170 changes: 169 additions & 1 deletion src/diffusers/loaders/single_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@
"model.diffusion_model.transformer_blocks.27.scale_shift_table",
"patchify_proj.weight",
"transformer_blocks.27.scale_shift_table",
"vae.per_channel_statistics.mean-of-means",
"vae.decoder.last_scale_shift_table", # 0.9.1, 0.9.5, 0.9.7, 0.9.8
"vae.decoder.up_blocks.9.res_blocks.0.conv1.conv.weight", # 0.9.0
],
"autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
"autoencoder-dc-sana": "encoder.project_in.conv.bias",
Expand Down Expand Up @@ -147,6 +148,11 @@
"net.pos_embedder.dim_spatial_range",
],
"flux2": ["model.diffusion_model.single_stream_modulation.lin.weight", "single_stream_modulation.lin.weight"],
"ltx2": [
"model.diffusion_model.av_ca_a2v_gate_adaln_single.emb.timestep_embedder.linear_1.weight",
"vae.per_channel_statistics.mean-of-means",
"audio_vae.per_channel_statistics.mean-of-means",
],
}

DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
Expand Down Expand Up @@ -228,6 +234,7 @@
"z-image-turbo-controlnet": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union"},
"z-image-turbo-controlnet-2.0": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.0"},
"z-image-turbo-controlnet-2.1": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.1"},
"ltx2-dev": {"pretrained_model_name_or_path": "Lightricks/LTX-2"},
}

# Use to configure model sample size when original config is provided
Expand Down Expand Up @@ -796,6 +803,9 @@ def infer_diffusers_model_type(checkpoint):
elif CHECKPOINT_KEY_NAMES["z-image-turbo-controlnet"] in checkpoint:
model_type = "z-image-turbo-controlnet"

elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx2"]):
model_type = "ltx2-dev"

else:
model_type = "v1"

Expand Down Expand Up @@ -3920,3 +3930,161 @@ def convert_z_image_controlnet_checkpoint_to_diffusers(checkpoint, config, **kwa
return converted_state_dict
else:
raise ValueError("Unknown Z-Image Turbo ControlNet type.")


def convert_ltx2_transformer_to_diffusers(checkpoint, **kwargs):
    """Remap an original LTX 2.0 transformer state dict to the diffusers key layout.

    The input ``checkpoint`` dict is drained (keys are popped) and a new dict with
    renamed keys is returned. Connector weights are dropped.
    """
    # Ordered substring replacements, applied sequentially to every key.
    # NOTE: the bare ``adaln_single`` / ``audio_adaln_single`` prefixes are handled
    # by a dedicated handler below, because they are substrings of the modulation
    # keys listed here and a plain replace would clobber them.
    rename_map = {
        # Transformer prefix
        "model.diffusion_model.": "",
        # Input patchify projections
        "patchify_proj": "proj_in",
        "audio_patchify_proj": "audio_proj_in",
        # Modulation parameters
        "av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
        "av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
        "av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
        "av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
        # Per-block cross-attention modulation tables
        "scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
        "scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
        # Attention QK norms
        "q_norm": "norm_q",
        "k_norm": "norm_k",
    }

    def _drop(key, state_dict):
        # Connector weights have no diffusers counterpart.
        state_dict.pop(key)

    def _remap_adaln(key, state_dict):
        # Only weight/bias tensors are renamed; anything else is left untouched.
        if ".weight" not in key and ".bias" not in key:
            return
        for old_prefix, new_prefix in (
            ("adaln_single.", "time_embed."),
            ("audio_adaln_single.", "audio_time_embed."),
        ):
            if key.startswith(old_prefix):
                state_dict[new_prefix + key[len(old_prefix):]] = state_dict.pop(key)

    special_handlers = {
        "video_embeddings_connector": _drop,
        "audio_embeddings_connector": _drop,
        "adaln_single": _remap_adaln,
    }

    # Drain the source dict so tensors are not kept alive twice.
    converted = {name: checkpoint.pop(name) for name in list(checkpoint)}

    # Straightforward 1:1 substring renames.
    for old_key in list(converted):
        new_key = old_key
        for src, dst in rename_map.items():
            new_key = new_key.replace(src, dst)
        converted[new_key] = converted.pop(old_key)

    # Special-case logic that a plain substring replace cannot express.
    for key in list(converted):
        for pattern, handler in special_handlers.items():
            if pattern in key:
                handler(key, converted)

    return converted


def convert_ltx2_vae_to_diffusers(checkpoint, **kwargs):
    """Remap an original LTX 2.0 video VAE state dict to the diffusers key layout.

    The input ``checkpoint`` dict is drained (keys are popped) and a new dict with
    renamed keys is returned. Unused per-channel statistics are dropped.
    """
    # Ordered substring replacements, applied sequentially to every key.
    rename_map = {
        # Video VAE prefix
        "vae.": "",
        # Encoder: interleaved original blocks fold into diffusers
        # down_blocks + their downsamplers; the last block is the mid block.
        "down_blocks.0": "down_blocks.0",
        "down_blocks.1": "down_blocks.0.downsamplers.0",
        "down_blocks.2": "down_blocks.1",
        "down_blocks.3": "down_blocks.1.downsamplers.0",
        "down_blocks.4": "down_blocks.2",
        "down_blocks.5": "down_blocks.2.downsamplers.0",
        "down_blocks.6": "down_blocks.3",
        "down_blocks.7": "down_blocks.3.downsamplers.0",
        "down_blocks.8": "mid_block",
        # Decoder: mirror image of the encoder layout.
        "up_blocks.0": "mid_block",
        "up_blocks.1": "up_blocks.0.upsamplers.0",
        "up_blocks.2": "up_blocks.0",
        "up_blocks.3": "up_blocks.1.upsamplers.0",
        "up_blocks.4": "up_blocks.1",
        "up_blocks.5": "up_blocks.2.upsamplers.0",
        "up_blocks.6": "up_blocks.2",
        # Common (all 3D ResNets)
        "res_blocks": "resnets",
        "per_channel_statistics.mean-of-means": "latents_mean",
        "per_channel_statistics.std-of-means": "latents_std",
    }

    def _drop(key, state_dict):
        # Statistics not used by the diffusers model.
        state_dict.pop(key)

    special_handlers = {
        "per_channel_statistics.channel": _drop,
        "per_channel_statistics.mean-of-stds": _drop,
    }

    # Drain the source dict so tensors are not kept alive twice.
    converted = {name: checkpoint.pop(name) for name in list(checkpoint)}

    # Straightforward 1:1 substring renames.
    for old_key in list(converted):
        new_key = old_key
        for src, dst in rename_map.items():
            new_key = new_key.replace(src, dst)
        converted[new_key] = converted.pop(old_key)

    # Special-case logic that a plain substring replace cannot express.
    for key in list(converted):
        for pattern, handler in special_handlers.items():
            if pattern in key:
                handler(key, converted)

    return converted


def convert_ltx2_audio_vae_to_diffusers(checkpoint, **kwargs):
    """Remap an original LTX 2.0 audio VAE state dict to the diffusers key layout.

    The input ``checkpoint`` dict is drained (keys are popped) and a new dict with
    renamed keys is returned.
    """
    # Ordered substring replacements, applied sequentially to every key.
    rename_map = {
        # Audio VAE prefix
        "audio_vae.": "",
        "per_channel_statistics.mean-of-means": "latents_mean",
        "per_channel_statistics.std-of-means": "latents_std",
    }

    # Drain the source dict so tensors are not kept alive twice.
    converted = {name: checkpoint.pop(name) for name in list(checkpoint)}

    # Straightforward 1:1 substring renames.
    for old_key in list(converted):
        new_key = old_key
        for src, dst in rename_map.items():
            new_key = new_key.replace(src, dst)
        converted[new_key] = converted.pop(old_key)

    return converted
Loading