README.md: 21 additions & 0 deletions
@@ -302,6 +302,27 @@ The trainer loads your configuration, initializes models, applies optimizations,
For LoRA training, the weights will be saved as `lora_weights.safetensors` in your output directory.
For full model fine-tuning, the weights will be saved as `model_weights.safetensors`.

### 🤗 Pushing Models to Hugging Face Hub

You can automatically push your trained models to the Hugging Face Hub by adding the following to your configuration YAML:

```yaml
hub:
  push_to_hub: true
  hub_model_id: "your-username/your-model-name" # Your HF username and desired repo name
```

Before pushing, make sure you:
1. Have a Hugging Face account
2. Are logged in via `huggingface-cli login` or have set the `HUGGING_FACE_HUB_TOKEN` environment variable (a quick check is sketched after this list)
3. Have write access to the specified repository (it will be created if it doesn't exist)
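
For example, here is a quick way to confirm your token is picked up before launching a long run. This is a sketch using `huggingface_hub` directly; the trainer does not run this check itself:

```python
# Minimal auth check, assuming huggingface_hub is installed
from huggingface_hub import whoami

# whoami() raises an error if no valid token is configured
# via `huggingface-cli login` or the environment variable
print(f"Authenticated as: {whoami()['name']}")
```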

The trainer will:
- Create a model card with training details and sample outputs
- Upload the model weights (both original and ComfyUI-compatible versions)
- Embed the sample videos as GIFs in the model card
- Include training configuration and prompts
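
After a successful push, the weights can be pulled back down with the standard Hub client. A minimal sketch (the repo id is a placeholder; the file names follow the defaults mentioned above):

```python
# Download the pushed LoRA weights from the Hub
from huggingface_hub import hf_hub_download

weights_path = hf_hub_download(
    repo_id="your-username/your-model-name",  # the hub_model_id from your config
    filename="lora_weights.safetensors",  # or "comfyui_lora_weights.safetensors" for the ComfyUI version
)
print(weights_path)
```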

---

## Fast and simple: Running the Complete Pipeline as one command
src/ltxv_trainer/config.py: 18 additions & 1 deletion
@@ -1,7 +1,7 @@
from pathlib import Path
from typing import Literal

from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator

from ltxv_trainer.model_loader import LtxvModelVersion
from ltxv_trainer.quantization import QuantizationOptions
@@ -246,6 +246,22 @@ class CheckpointsConfig(ConfigBaseModel):
    )


class HubConfig(ConfigBaseModel):
    """Configuration for Hugging Face Hub integration"""

    push_to_hub: bool = Field(default=False, description="Whether to push the model weights to the Hugging Face Hub")
    hub_model_id: str | None = Field(
        default=None, description="Hugging Face Hub repository ID (e.g., 'username/repo-name')"
    )

    @model_validator(mode="after")
    def validate_hub_config(self) -> "HubConfig":
        """Validate that hub_model_id is not None when push_to_hub is True."""
        if self.push_to_hub and not self.hub_model_id:
            raise ValueError("hub_model_id must be specified when push_to_hub is True")
        return self


class FlowMatchingConfig(ConfigBaseModel):
    """Configuration for flow matching training"""

@@ -271,6 +287,7 @@ class LtxvTrainerConfig(ConfigBaseModel):
    data: DataConfig = Field(default_factory=DataConfig)
    validation: ValidationConfig = Field(default_factory=ValidationConfig)
    checkpoints: CheckpointsConfig = Field(default_factory=CheckpointsConfig)
    hub: HubConfig = Field(default_factory=HubConfig)
    flow_matching: FlowMatchingConfig = Field(default_factory=FlowMatchingConfig)

    # General configuration
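
As a quick illustration of the new validator's behavior (a sketch, not part of this diff):

```python
from ltxv_trainer.config import HubConfig

HubConfig()  # OK: pushing is disabled by default
HubConfig(push_to_hub=True, hub_model_id="username/repo-name")  # OK

# Raises a pydantic ValidationError wrapping:
# "hub_model_id must be specified when push_to_hub is True"
HubConfig(push_to_hub=True)
```
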
src/ltxv_trainer/hub_utils.py: 234 additions & 0 deletions
@@ -0,0 +1,234 @@
import tempfile
from pathlib import Path

import imageio
from huggingface_hub import HfApi, create_repo
from loguru import logger

from ltxv_trainer.config import LtxvTrainerConfig
from ltxv_trainer.model_loader import try_parse_version
from scripts.convert_checkpoint import convert_checkpoint


def convert_video_to_gif(video_path: Path, output_path: Path) -> None:
    """Convert a video file to GIF format. Exceptions propagate so callers can handle failures."""
    reader = imageio.get_reader(str(video_path))
    try:
        fps = reader.get_meta_data()["fps"]

        # Write GIF file with infinite loop
        writer = imageio.get_writer(
            str(output_path),
            fps=min(fps, 15),  # Cap FPS at 15 for reasonable file size
            loop=0,  # 0 means infinite loop
        )
        try:
            for frame in reader:
                writer.append_data(frame)
        finally:
            writer.close()
    finally:
        reader.close()


def create_model_card(
    output_dir: str | Path,
    videos: list[Path] | None,
    config: LtxvTrainerConfig,
) -> Path | None:
    """Generate and save a model card for the trained model.

    Returns the path to the saved model card, or None if the template is missing.
    """
    repo_id = config.hub.hub_model_id
    pretrained_model_name_or_path = config.model.model_source
    validation_prompts = config.validation.prompts
    output_dir = Path(output_dir)
    template_path = Path(__file__).parent.parent.parent / "templates" / "model_card.md"

    if not template_path.exists():
        logger.warning("⚠️ Model card template not found, skipping model card creation")
        return None

    # Read the template
    template = template_path.read_text()

    # Get model name from repo_id
    model_name = repo_id.split("/")[-1]

    # Get base model information
    version = try_parse_version(pretrained_model_name_or_path)
    if version:
        base_model_link = version.safetensors_url
        base_model_name = str(version)
    else:
        base_model_link = f"https://huggingface.co/{pretrained_model_name_or_path}"
        base_model_name = pretrained_model_name_or_path

    # Format validation prompts and create grid layout
    prompts_text = ""
    sample_grid = []

    if validation_prompts and videos:
        prompts_text = "Example prompts used during validation:\n\n"

        # Create samples directory
        samples_dir = output_dir / "samples"
        samples_dir.mkdir(exist_ok=True, parents=True)

        # Process videos and create cells
        cells = []
        for i, (prompt, video) in enumerate(zip(validation_prompts, videos, strict=False)):
            if video.exists():
                # Add prompt to text section
                prompts_text += f"- `{prompt}`\n"

                # Convert video to GIF
                gif_path = samples_dir / f"sample_{i}.gif"
                try:
                    convert_video_to_gif(video, gif_path)

                    # Create grid cell with collapsible description
                    cell = (
                        f"![example{i + 1}](./samples/sample_{i}.gif)"
                        "<br>"
                        '<details style="max-width: 300px; margin: auto;">'
                        "<summary>Prompt</summary>"
                        f"{prompt}"
                        "</details>"
                    )
                    cells.append(cell)
                except Exception as e:
                    logger.error(f"Failed to process video {video}: {e}")

        # Calculate optimal grid dimensions
        num_cells = len(cells)
        if num_cells > 0:
            # Aim for a roughly square grid, with max 4 columns
            num_cols = min(4, num_cells)
            num_rows = (num_cells + num_cols - 1) // num_cols  # Ceiling division

            # Create grid rows
            for row in range(num_rows):
                start_idx = row * num_cols
                end_idx = min(start_idx + num_cols, num_cells)
                row_cells = cells[start_idx:end_idx]
                # Format the row with table markers and the exact number of cells
                formatted_row = "| " + " | ".join(row_cells) + " |"
                sample_grid.append(formatted_row)

    # Join grid rows; no table headers are needed for this layout
    grid_text = "\n".join(sample_grid) if sample_grid else ""

    # Fill in the template
    model_card_content = template.format(
        base_model=base_model_name,
        base_model_link=base_model_link,
        model_name=model_name,
        training_type="LoRA fine-tuning" if config.model.training_mode == "lora" else "Full model fine-tuning",
        training_steps=config.optimization.steps,
        learning_rate=config.optimization.learning_rate,
        batch_size=config.optimization.batch_size,
        validation_prompts=prompts_text,
        sample_grid=grid_text,
    )

    # Save the model card directly
    model_card_path = output_dir / "README.md"
    model_card_path.write_text(model_card_content)

    return model_card_path


def push_to_hub(weights_path: Path, sampled_videos_paths: list[Path] | None, config: LtxvTrainerConfig) -> None:
    """Push the trained weights to the Hugging Face Hub, along with a model card and sample videos."""
    if not config.hub.push_to_hub:
        return

    if not config.hub.hub_model_id:
        logger.warning("⚠️ HuggingFace hub_model_id not specified, skipping push to hub")
        return

    api = HfApi()

    # Try to create the repo if it doesn't exist
    try:
        create_repo(
            repo_id=config.hub.hub_model_id,
            repo_type="model",
            exist_ok=True,  # Don't raise an error if the repo exists
        )
    except Exception as e:
        logger.error(f"❌ Failed to create repository: {e}")
        return

    # Upload the original weights file; abort if this fails, since the rest depends on it
    try:
        api.upload_file(
            path_or_fileobj=str(weights_path),
            path_in_repo=weights_path.name,
            repo_id=config.hub.hub_model_id,
            repo_type="model",
        )
        logger.info(f"✅ Successfully pushed original LoRA weights to {config.hub.hub_model_id}")
    except Exception as e:
        logger.error(f"❌ Failed to push {weights_path.name} to HuggingFace Hub: {e}")
        return

    # Create a temporary directory for the files we want to upload
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)

        try:
            # Save the model card and converted sample GIFs to the temp directory
            create_model_card(
                output_dir=temp_path,
                videos=sampled_videos_paths,
                config=config,
            )

            # Upload the model card and samples directory
            api.upload_folder(
                folder_path=str(temp_path),  # Convert to string for compatibility
                repo_id=config.hub.hub_model_id,
                repo_type="model",
            )

            logger.info(f"✅ Successfully uploaded model card and sample videos to {config.hub.hub_model_id}")
        except Exception as e:
            logger.error(f"❌ Failed to save/upload model card and videos: {e}")

    # Convert and upload ComfyUI version
    try:
        # Create a temporary directory for the converted file
        with tempfile.TemporaryDirectory() as temp_dir:
            # Convert the weights to ComfyUI format
            comfy_path = Path(temp_dir) / f"{weights_path.stem}_comfy{weights_path.suffix}"

            convert_checkpoint(
                input_path=str(weights_path),
                to_comfy=True,
                output_path=str(comfy_path),
            )

            # Find the converted file
            converted_files = list(Path(temp_dir).glob("*.safetensors"))
            if not converted_files:
                logger.warning("⚠️ No converted ComfyUI weights found")
                return

            converted_file = converted_files[0]
            comfy_filename = f"comfyui_{weights_path.name}"

            # Upload the converted file
            api.upload_file(
                path_or_fileobj=str(converted_file),
                path_in_repo=comfy_filename,
                repo_id=config.hub.hub_model_id,
                repo_type="model",
            )
            logger.info(f"✅ Successfully pushed ComfyUI LoRA weights to {config.hub.hub_model_id}")

    except Exception as e:
        logger.error(f"❌ Failed to convert and push ComfyUI version: {e}")
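
For reference, a hedged sketch of invoking the helper outside the trainer (paths and the repo id are placeholders, and any required `LtxvTrainerConfig` fields beyond the shown defaults would need to be supplied):

```python
from pathlib import Path

from ltxv_trainer.config import LtxvTrainerConfig
from ltxv_trainer.hub_utils import push_to_hub

config = LtxvTrainerConfig.model_validate(
    {"hub": {"push_to_hub": True, "hub_model_id": "your-username/your-model-name"}}
)

push_to_hub(
    weights_path=Path("outputs/lora_weights.safetensors"),
    sampled_videos_paths=[Path("outputs/samples/sample_0.mp4")],
    config=config,
)
```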
src/ltxv_trainer/model_loader.py: 3 additions & 3 deletions
@@ -160,7 +160,7 @@ def load_vae(
"""
if isinstance(source, str): # noqa: SIM102
# Try to parse as version first
if version := _try_parse_version(source):
if version := try_parse_version(source):
source = version

if isinstance(source, LtxvModelVersion):
Expand Down Expand Up @@ -217,7 +217,7 @@ def load_transformer(
"""
if isinstance(source, str): # noqa: SIM102
# Try to parse as version first
if version := _try_parse_version(source):
if version := try_parse_version(source):
source = version

if isinstance(source, LtxvModelVersion):
Expand Down Expand Up @@ -285,7 +285,7 @@ def load_ltxv_components(
)


def _try_parse_version(source: str | Path) -> LtxvModelVersion | None:
def try_parse_version(source: str | Path) -> LtxvModelVersion | None:
    """
    Try to parse a string as an LtxvModelVersion.

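
The rename from `_try_parse_version` to `try_parse_version` makes the helper part of the module's public surface, so `hub_utils` can import it. A small sketch of its contract (the source string is a placeholder):

```python
from ltxv_trainer.model_loader import try_parse_version

version = try_parse_version("some-version-or-repo")  # placeholder source
if version is not None:
    print(version.safetensors_url)  # direct link used in the model card
else:
    print("Not a known LTXV version; treated as a repo id or local path")
```
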
src/ltxv_trainer/trainer.py: 8 additions & 2 deletions
@@ -46,6 +46,7 @@

from ltxv_trainer.config import LtxvTrainerConfig
from ltxv_trainer.datasets import PrecomputedDataset
from ltxv_trainer.hub_utils import push_to_hub
from ltxv_trainer.model_loader import load_ltxv_components
from ltxv_trainer.quantization import quantize_model
from ltxv_trainer.timestep_samplers import SAMPLERS
@@ -155,6 +156,8 @@ def train(  # noqa: PLR0912, PLR0915
        # Track when actual training starts (after compilation)
        actual_training_start = None

        sampled_videos_paths = None

        with Live(Panel(Group(train_progress, sample_progress)), refresh_per_second=2):
            task = train_progress.add_task(
                "Training",
@@ -165,7 +168,7 @@
            )

            if cfg.validation.interval:
                self._sample_videos(sample_progress)
                sampled_videos_paths = self._sample_videos(sample_progress)

            for step in range(cfg.optimization.steps):
                # Get next batch, reset the dataloader if needed
@@ -202,7 +205,6 @@

                if self._lr_scheduler is not None:
                    self._lr_scheduler.step()

                # Run validation if needed
                if (
                    cfg.validation.interval
@@ -291,6 +293,10 @@ def train(  # noqa: PLR0912, PLR0915
        if self._accelerator.is_main_process:
            saved_path = self._save_checkpoint()

            # Upload artifacts to hub if enabled
            if cfg.hub.push_to_hub:
                push_to_hub(saved_path, sampled_videos_paths, self._config)

        # Log the training statistics
        self._log_training_stats(stats)
