17 changes: 17 additions & 0 deletions src/ltxv_trainer/config.py
@@ -234,6 +234,22 @@ class CheckpointsConfig(ConfigBaseModel):
    )


class HubConfig(ConfigBaseModel):
    """Configuration for Hugging Face Hub integration"""

    push_to_hub: bool = Field(
        default=False,
        description="Whether to push the model weights to the Hugging Face Hub",
    )
    hub_model_id: str | None = Field(
        default=None,
        description="Hugging Face Hub repository ID (e.g., 'username/repo-name')",
    )
    hub_token: str | None = Field(
        default=None,
        description="Hugging Face token. If None, will use the token from the Hugging Face CLI",
    )

class FlowMatchingConfig(ConfigBaseModel):
"""Configuration for flow matching training"""

@@ -259,6 +275,7 @@ class LtxvTrainerConfig(ConfigBaseModel):
    data: DataConfig = Field(default_factory=DataConfig)
    validation: ValidationConfig = Field(default_factory=ValidationConfig)
    checkpoints: CheckpointsConfig = Field(default_factory=CheckpointsConfig)
    hub: HubConfig = Field(default_factory=HubConfig)
    flow_matching: FlowMatchingConfig = Field(default_factory=FlowMatchingConfig)

    # General configuration
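For context, a minimal usage sketch of the new `HubConfig` (the field values here are illustrative, not from this PR):

```python
from ltxv_trainer.config import HubConfig

# Illustrative values: hub_model_id may be omitted, in which case the trainer
# falls back to the output directory name; hub_token=None falls back to the
# token cached by `huggingface-cli login`.
hub = HubConfig(
    push_to_hub=True,
    hub_model_id="username/my-ltxv-lora",
    hub_token=None,
)
```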
22 changes: 21 additions & 1 deletion src/ltxv_trainer/trainer.py
@@ -14,6 +14,7 @@
from accelerate.utils import set_seed
from diffusers import LTXPipeline
from diffusers.utils import export_to_video
from huggingface_hub import create_repo, upload_folder
from loguru import logger
from peft import LoraConfig, get_peft_model_state_dict
from peft.tuners.tuners_utils import BaseTunerLayer
@@ -50,7 +51,7 @@
from ltxv_trainer.model_loader import load_ltxv_components
from ltxv_trainer.quantization import quantize_model
from ltxv_trainer.timestep_samplers import SAMPLERS
from ltxv_trainer.utils import get_gpu_memory_gb, save_model_card

# Disable irrelevant warnings from transformers
os.environ["TOKENIZERS_PARALLELISM"] = "true"
@@ -286,6 +287,25 @@ def train(  # noqa: PLR0912, PLR0915
        if self._accelerator.is_main_process:
            saved_path = self._save_checkpoint()

            # Upload artifacts to the Hub if enabled
            if cfg.hub.push_to_hub:
                repo_id = cfg.hub.hub_model_id or Path(cfg.output_dir).name
                repo_id = create_repo(token=cfg.hub.hub_token, repo_id=repo_id, exist_ok=True).repo_id
**Member:** Please add the base model parameter here. You can either point to the general LTXV repo or go fine-grained and allow the user to set a specific version. Either works for me.

**Contributor (author):** Added a model card now :), wdyt?
                video_filenames = sampled_videos_paths if sampled_videos_paths else []

                save_model_card(
                    output_dir=cfg.output_dir,
                    repo_id=repo_id,
                    pretrained_model_name_or_path=cfg.model.model_source,
                    videos=video_filenames,
                    validation_prompts=cfg.validation.prompts,
                )

                upload_folder(
                    repo_id=repo_id,
                    folder_path=Path(cfg.output_dir),
                    token=cfg.hub.hub_token,
                )

        # Log the training statistics
        self._log_training_stats(stats)

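For reference, a minimal sketch of the upload flow added above, isolated from the trainer (the output directory and repo name are illustrative):

```python
from pathlib import Path

from huggingface_hub import create_repo, upload_folder

output_dir = "outputs/my-lora-run"  # hypothetical training output directory

# Derive the repo ID from the output directory name, as the trainer does when
# no hub_model_id is configured; exist_ok=True avoids failing on re-runs.
repo_id = create_repo(repo_id=Path(output_dir).name, exist_ok=True).repo_id

# Upload the whole output directory, including the generated README.md model card.
upload_folder(repo_id=repo_id, folder_path=output_dir)
```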
149 changes: 149 additions & 0 deletions src/ltxv_trainer/utils.py
@@ -1,13 +1,21 @@
import io
import os
import subprocess
from pathlib import Path
from typing import List, Union

import numpy as np
import torch
from diffusers.utils import export_to_video
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from loguru import logger
from PIL import ExifTags, Image, ImageCms, ImageOps
from PIL.Image import Image as PilImage


def get_gpu_memory_gb(device: torch.device) -> float:
    """Get current GPU memory usage in GB using nvidia-smi"""
    try:
@@ -69,3 +77,144 @@ def open_image_as_srgb(image_path: str | Path | io.BytesIO) -> PilImage:
    srgb_img.info["icc_profile"] = srgb_profile_data

    return srgb_img


def save_model_card(
    output_dir: str,
    repo_id: str,
    pretrained_model_name_or_path: str,
    videos: Union[List[str], List[PilImage], List[np.ndarray]],
    validation_prompts: List[str],
    fps: int = 30,
) -> None:
    """Build and save a Hugging Face model card (README.md) for a trained LoRA."""
    widget_dict = []
    if videos is not None and len(videos) > 0:
        for i, (video, validation_prompt) in enumerate(zip(videos, validation_prompts)):
            if not isinstance(video, str):
                export_to_video(video, os.path.join(output_dir, f"final_video_{i}.mp4"), fps=fps)
            widget_dict.append(
                {
                    "text": validation_prompt if validation_prompt else " ",
                    "output": {"url": video if isinstance(video, str) else f"final_video_{i}.mp4"},
                }
            )
    # Fall back to the base LTX-Video repo if the source is not a known Hub model.
    if pretrained_model_name_or_path not in ["Lightricks/LTX-Video", "Lightricks/LTX-Video-0.9.5"]:
        pretrained_model_name_or_path = "Lightricks/LTX-Video"

    model_description = f"""
# LoRA Finetune

<Gallery />

## Model description

This is a LoRA fine-tune of `{pretrained_model_name_or_path}`.

The model was trained using [`LTX-Video Community Trainer`](https://github.com/Lightricks/LTX-Video-Trainer).

## Download model

[Download LoRA](https://huggingface.co/{repo_id}/tree/main) in the Files & Versions tab.

## Usage

### Using Trained LoRAs with `diffusers`:
Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) to be installed.

Text-to-Video generation using the trained LoRA:
```python
import torch
from diffusers import LTXPipeline
from diffusers.utils import export_to_video

pipe = LTXPipeline.from_pretrained("{pretrained_model_name_or_path}", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("{repo_id}", adapter_name="ltxv-lora")
pipe.set_adapters(["ltxv-lora"], [0.75])
pipe.to("cuda")

prompt = "{validation_prompts[0]}"
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
video = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
width=704,
height=480,
num_inference_steps=50,
).frames[0]
export_to_video(video, "output.mp4", fps=24)
```

For Image-to-Video:
```python
import torch
from diffusers import LTXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

pipe = LTXImageToVideoPipeline.from_pretrained("{pretrained_model_name_or_path}", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("{repo_id}", adapter_name="ltxv-lora")
pipe.set_adapters(["ltxv-lora"], [0.75])
pipe.to("cuda")

image = load_image(
"url_to_your_image",
)
prompt = "{validation_prompts[0]}"
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

video = pipe(
image=image,
prompt=prompt,
negative_prompt=negative_prompt,
width=704,
height=480,
num_inference_steps=50,
).frames[0]
export_to_video(video, "output.mp4", fps=24)
```

### 🔌 Using Trained LoRAs in ComfyUI

After training your LoRA, you can use it in ComfyUI by following these steps:

1. Copy your trained LoRA weights (`.safetensors` file) to the `models/loras` folder in your ComfyUI installation.

2. Install the ComfyUI-LTXVideoLoRA custom node:

```bash
# In the root folder of your ComfyUI installation
cd custom_nodes
git clone https://github.com/dorpxam/ComfyUI-LTXVideoLoRA
pip install -r ComfyUI-LTXVideoLoRA/requirements.txt
```

3. In your ComfyUI workflow:
- Add the "LTXV LoRA Selector" node to choose your LoRA file
- Connect it to the "LTXV LoRA Loader" node to apply the LoRA to your generation

You can find reference Text-to-Video (T2V) and Image-to-Video (I2V) workflows in the [official LTXV ComfyUI repository](https://github.com/Lightricks/ComfyUI-LTXVideo).

For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers.
"""

    model_card = load_or_create_model_card(
        repo_id_or_path=repo_id,
        from_training=True,
        base_model=pretrained_model_name_or_path,
        model_description=model_description,
        widget=widget_dict,
    )
    tags = [
        "text-to-video",
        "image-to-video",
        "ltx-video",
        "diffusers",
        "lora",
        "template:sd-lora",
    ]

    model_card = populate_model_card(model_card, tags=tags)
    model_card.save(os.path.join(output_dir, "README.md"))
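
A quick usage sketch of `save_model_card` (the paths, repo ID, and prompt are illustrative; assumes the referenced video file already exists in `output_dir`):

```python
from ltxv_trainer.utils import save_model_card

save_model_card(
    output_dir="outputs/my-lora-run",  # hypothetical run directory
    repo_id="username/my-ltxv-lora",  # hypothetical Hub repo
    pretrained_model_name_or_path="Lightricks/LTX-Video",
    videos=["sample_0.mp4"],  # string paths are linked as-is in the card's gallery
    validation_prompts=["A red fox running through snowy woods"],
)
# outputs/my-lora-run/README.md now contains the populated model card.
```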