diff --git a/src/ltxv_trainer/config.py b/src/ltxv_trainer/config.py
index c1606a6..27096c2 100644
--- a/src/ltxv_trainer/config.py
+++ b/src/ltxv_trainer/config.py
@@ -234,6 +234,23 @@ class CheckpointsConfig(ConfigBaseModel):
     )
 
 
+class HubConfig(ConfigBaseModel):
+    """Configuration for Hugging Face Hub integration"""
+
+    push_to_hub: bool = Field(
+        default=False,
+        description="Whether to push the trained model weights to the Hugging Face Hub",
+    )
+    hub_model_id: str | None = Field(
+        default=None,
+        description="Hugging Face Hub repository ID (e.g. 'username/repo-name'). Defaults to the output directory name.",
+    )
+    hub_token: str | None = Field(
+        default=None,
+        description="Hugging Face token. If None, the token saved by the Hugging Face CLI is used.",
+    )
+
+
 class FlowMatchingConfig(ConfigBaseModel):
     """Configuration for flow matching training"""
 
@@ -259,6 +276,7 @@ class LtxvTrainerConfig(ConfigBaseModel):
     data: DataConfig = Field(default_factory=DataConfig)
     validation: ValidationConfig = Field(default_factory=ValidationConfig)
     checkpoints: CheckpointsConfig = Field(default_factory=CheckpointsConfig)
+    hub: HubConfig = Field(default_factory=HubConfig)
     flow_matching: FlowMatchingConfig = Field(default_factory=FlowMatchingConfig)
 
     # General configuration
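A minimal sketch of how the new `hub` section is meant to be used, assuming the package is importable as `ltxv_trainer`; the repository id below is a placeholder. `hub_model_id` falls back to the output directory name and `hub_token` falls back to the token stored by `huggingface-cli login` (see the trainer change below):

```python
from ltxv_trainer.config import HubConfig

# Placeholder values for illustration only.
hub = HubConfig(
    push_to_hub=True,
    hub_model_id="username/my-ltxv-lora",  # optional: defaults to the output directory name
    hub_token=None,  # optional: defaults to the locally stored Hugging Face CLI token
)

print(hub.push_to_hub, hub.hub_model_id)
```
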
nvidia-smi""" try: @@ -69,3 +77,144 @@ def open_image_as_srgb(image_path: str | Path | io.BytesIO) -> PilImage: srgb_img.info["icc_profile"] = srgb_profile_data return srgb_img + + +def save_model_card( + output_dir: str, + repo_id: str, + pretrained_model_name_or_path: str, + videos: Union[List[str], Union[List[PilImage.Image], List[np.ndarray]]], + validation_prompts: List[str], + fps: int = 30, +) -> None: + widget_dict = [] + if videos is not None and len(videos) > 0: + for i, (video, validation_prompt) in enumerate(zip(videos, validation_prompts)): + if not isinstance(video, str): + export_to_video(video, os.path.join(output_dir, f"final_video_{i}.mp4"), fps=fps) + widget_dict.append( + { + "text": validation_prompt if validation_prompt else " ", + "output": {"url": video if isinstance(video, str) else f"final_video_{i}.mp4"}, + } + ) + if pretrained_model_name_or_path not in ["Lightricks/LTX-Video", "Lightricks/LTX-Video-0.9.5"]: + pretrained_model_name_or_path = "Lightricks/LTX-Video" + + model_description = f""" +# LoRA Finetune + + + +## Model description + +This is a lora finetune of model: `{pretrained_model_name_or_path}`. + +The model was trained using [`LTX-Video Community Trainer`](https://github.com/Lightricks/LTX-Video-Trainer). + +## Download model + +[Download LoRA]({repo_id}/tree/main) in the Files & Versions tab. + +## Usage + +### Using Trained LoRAs with `diffusers`: +Requires the [๐Ÿงจ Diffusers library](https://github.com/huggingface/diffusers) installed. + +Text-to-Video generation using the trained LoRA: +```python +import torch +from diffusers import LTXPipeline +from diffusers.utils import export_to_video + +pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16) +pipe.load_lora_weights("{repo_id}", adapter_name="ltxv-lora") +pipe.set_adapters(["ltxv-lora"], [0.75]) +pipe.to("cuda") + +prompt = "{validation_prompts[0]}" +negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" +video = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + width=704, + height=480, + num_inference_steps=50, +).frames[0] +export_to_video(video, "output.mp4", fps=24) +``` + +For Image-to-Video: +```python +import torch +from diffusers import LTXImageToVideoPipeline +from diffusers.utils import export_to_video, load_image + +pipe = LTXImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16) +pipe.load_lora_weights("{repo_id}", adapter_name="ltxv-lora") +pipe.set_adapters(["ltxv-lora"], [0.75]) +pipe.to("cuda") + +image = load_image( + "url_to_your_image", +) +prompt = "{validation_prompts[0]}" +negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + +video = pipe( + image=image, + prompt=prompt, + negative_prompt=negative_prompt, + width=704, + height=480, + num_inference_steps=50, +).frames[0] +export_to_video(video, "output.mp4", fps=24) +``` + +### ๐Ÿ”Œ Using Trained LoRAs in ComfyUI + +After training your LoRA, you can use it in ComfyUI by following these steps: + +1. Copy your trained LoRA weights (`.safetensors` file) to the `models/loras` folder in your ComfyUI installation. + +2. Install the ComfyUI-LTXVideoLoRA custom node: + + ```bash + # In the root folder of your ComfyUI installation + cd custom_nodes + git clone https://github.com/dorpxam/ComfyUI-LTXVideoLoRA + pip install -r ComfyUI-LTXVideoLoRA/requirements.txt + ``` + +3. 
diff --git a/src/ltxv_trainer/utils.py b/src/ltxv_trainer/utils.py
index dbbaf71..b69d89a 100644
--- a/src/ltxv_trainer/utils.py
+++ b/src/ltxv_trainer/utils.py
@@ -1,13 +1,17 @@
 import io
+import os
 import subprocess
 from pathlib import Path
 
+import numpy as np
 import torch
+from diffusers.utils import export_to_video
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from loguru import logger
 from PIL import ExifTags, Image, ImageCms, ImageOps
 from PIL.Image import Image as PilImage
 
 
 def get_gpu_memory_gb(device: torch.device) -> float:
     """Get current GPU memory usage in GB using nvidia-smi"""
     try:
@@ -69,3 +73,143 @@ def open_image_as_srgb(image_path: str | Path | io.BytesIO) -> PilImage:
         srgb_img.info["icc_profile"] = srgb_profile_data
 
     return srgb_img
+
+
+def save_model_card(
+    output_dir: str,
+    repo_id: str,
+    pretrained_model_name_or_path: str,
+    videos: list[str] | list[PilImage] | list[np.ndarray],
+    validation_prompts: list[str],
+    fps: int = 30,
+) -> None:
+    """Build a model card for the trained LoRA and save it as README.md in the output directory."""
+    widget_dict = []
+    if videos is not None and len(videos) > 0:
+        for i, (video, validation_prompt) in enumerate(zip(videos, validation_prompts)):
+            if not isinstance(video, str):
+                export_to_video(video, os.path.join(output_dir, f"final_video_{i}.mp4"), fps=fps)
+            widget_dict.append(
+                {
+                    "text": validation_prompt if validation_prompt else " ",
+                    "output": {"url": video if isinstance(video, str) else f"final_video_{i}.mp4"},
+                }
+            )
+
+    if pretrained_model_name_or_path not in ["Lightricks/LTX-Video", "Lightricks/LTX-Video-0.9.5"]:
+        pretrained_model_name_or_path = "Lightricks/LTX-Video"
+
+    model_description = f"""
+# LoRA Finetune
+
+<Gallery />
+
+## Model description
+
+This is a LoRA fine-tune of the `{pretrained_model_name_or_path}` model.
+
+The model was trained using the [LTX-Video Community Trainer](https://github.com/Lightricks/LTX-Video-Trainer).
+
+## Download model
+
+[Download the LoRA]({repo_id}/tree/main) from the Files & Versions tab.
+
+## Usage
+
+### Using the trained LoRA with `diffusers`
+
+Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) to be installed.
+
+Text-to-Video generation with the trained LoRA:
+
+```python
+import torch
+from diffusers import LTXPipeline
+from diffusers.utils import export_to_video
+
+pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
+pipe.load_lora_weights("{repo_id}", adapter_name="ltxv-lora")
+pipe.set_adapters(["ltxv-lora"], [0.75])
+pipe.to("cuda")
+
+prompt = "{validation_prompts[0]}"
+negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
+video = pipe(
+    prompt=prompt,
+    negative_prompt=negative_prompt,
+    width=704,
+    height=480,
+    num_inference_steps=50,
+).frames[0]
+export_to_video(video, "output.mp4", fps=24)
+```
+
+For Image-to-Video generation:
+
+```python
+import torch
+from diffusers import LTXImageToVideoPipeline
+from diffusers.utils import export_to_video, load_image
+
+pipe = LTXImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
+pipe.load_lora_weights("{repo_id}", adapter_name="ltxv-lora")
+pipe.set_adapters(["ltxv-lora"], [0.75])
+pipe.to("cuda")
+
+image = load_image("url_to_your_image")
+prompt = "{validation_prompts[0]}"
+negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
+
+video = pipe(
+    image=image,
+    prompt=prompt,
+    negative_prompt=negative_prompt,
+    width=704,
+    height=480,
+    num_inference_steps=50,
+).frames[0]
+export_to_video(video, "output.mp4", fps=24)
+```
+
+### 🔌 Using the trained LoRA in ComfyUI
+
+After training your LoRA, you can use it in ComfyUI as follows:
+
+1. Copy your trained LoRA weights (`.safetensors` file) to the `models/loras` folder of your ComfyUI installation.
+
+2. Install the ComfyUI-LTXVideoLoRA custom node:
+
+   ```bash
+   # In the root folder of your ComfyUI installation
+   cd custom_nodes
+   git clone https://github.com/dorpxam/ComfyUI-LTXVideoLoRA
+   pip install -r ComfyUI-LTXVideoLoRA/requirements.txt
+   ```
+
+3. In your ComfyUI workflow:
+   - Add the "LTXV LoRA Selector" node to choose your LoRA file
+   - Connect it to the "LTXV LoRA Loader" node to apply the LoRA to your generation
+
+You can find reference Text-to-Video (T2V) and Image-to-Video (I2V) workflows in the [official LTXV ComfyUI repository](https://github.com/Lightricks/ComfyUI-LTXVideo).
+
+For more details, including weighting, merging, and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters).
+"""
+
+    model_card = load_or_create_model_card(
+        repo_id_or_path=repo_id,
+        from_training=True,
+        base_model=pretrained_model_name_or_path,
+        model_description=model_description,
+        widget=widget_dict,
+    )
+    tags = [
+        "text-to-video",
+        "image-to-video",
+        "ltx-video",
+        "diffusers",
+        "lora",
+        "template:sd-lora",
+    ]
+
+    model_card = populate_model_card(model_card, tags=tags)
+    model_card.save(os.path.join(output_dir, "README.md"))
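For completeness, a minimal standalone invocation of the new helper, with placeholder paths, repository id, and prompt; in the trainer it is driven by the sampled validation videos and prompts:

```python
from ltxv_trainer.utils import save_model_card

# Placeholders for illustration; the trainer passes its own output_dir, repo id,
# model source, sampled validation videos, and validation prompts.
save_model_card(
    output_dir="outputs/my_lora",
    repo_id="username/my-ltxv-lora",
    pretrained_model_name_or_path="Lightricks/LTX-Video",
    videos=["outputs/my_lora/samples/sample_0.mp4"],
    validation_prompts=["A red fox walking through fresh snow at dawn"],
    fps=30,
)
# Writes outputs/my_lora/README.md with the widget gallery, tags, and usage snippets.
```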