diff --git a/README.md b/README.md
index dd61b50..c09fca0 100644
--- a/README.md
+++ b/README.md
@@ -302,6 +302,27 @@ The trainer loads your configuration, initializes models, applies optimizations,
For LoRA training, the weights will be saved as `lora_weights.safetensors` in your output directory.
For full model fine-tuning, the weights will be saved as `model_weights.safetensors`.
+### 🤗 Pushing Models to Hugging Face Hub
+
+You can automatically push your trained models to the Hugging Face Hub by adding the following to your configuration YAML:
+
+```yaml
+hub:
+  push_to_hub: true
+  hub_model_id: "your-username/your-model-name"  # Your HF username and desired repo name
+```
+
+Before pushing, make sure you:
+1. Have a Hugging Face account
+2. Are logged in via `huggingface-cli login` or have set the `HUGGING_FACE_HUB_TOKEN` environment variable (or log in programmatically, as shown below)
+3. Have write access to the specified repository (it will be created if it doesn't exist)
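+
+If you prefer to authenticate programmatically, a minimal sketch using the `huggingface_hub` package (the same library the trainer uses for uploads) looks like this:
+
+```python
+from huggingface_hub import login
+
+# Either pass a token explicitly, or omit it and rely on the
+# HUGGING_FACE_HUB_TOKEN environment variable / a previous CLI login.
+login(token="hf_your_token_here")  # placeholder token
+```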
+
+The trainer will:
+- Create a model card with training details and sample outputs
+- Upload the model weights (both original and ComfyUI-compatible versions)
+- Embed sample videos as GIFs in the model card
+- Include key training settings and validation prompts
+
---
## Fast and simple: Running the Complete Pipeline as one command
diff --git a/src/ltxv_trainer/config.py b/src/ltxv_trainer/config.py
index a99eaf5..62b2701 100644
--- a/src/ltxv_trainer/config.py
+++ b/src/ltxv_trainer/config.py
@@ -1,7 +1,7 @@
from pathlib import Path
from typing import Literal
-from pydantic import BaseModel, ConfigDict, Field, field_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from ltxv_trainer.model_loader import LtxvModelVersion
from ltxv_trainer.quantization import QuantizationOptions
@@ -246,6 +246,22 @@ class CheckpointsConfig(ConfigBaseModel):
)
+class HubConfig(ConfigBaseModel):
+ """Configuration for Hugging Face Hub integration"""
+
+ push_to_hub: bool = Field(default=False, description="Whether to push the model weights to the Hugging Face Hub")
+ hub_model_id: str | None = Field(
+ default=None, description="Hugging Face Hub repository ID (e.g., 'username/repo-name')"
+ )
+
+ @model_validator(mode="after")
+ def validate_hub_config(self) -> "HubConfig":
+ """Validate that hub_model_id is not None when push_to_hub is True."""
+ if self.push_to_hub and not self.hub_model_id:
+ raise ValueError("hub_model_id must be specified when push_to_hub is True")
+ return self
+
+
class FlowMatchingConfig(ConfigBaseModel):
"""Configuration for flow matching training"""
@@ -271,6 +287,7 @@ class LtxvTrainerConfig(ConfigBaseModel):
data: DataConfig = Field(default_factory=DataConfig)
validation: ValidationConfig = Field(default_factory=ValidationConfig)
checkpoints: CheckpointsConfig = Field(default_factory=CheckpointsConfig)
+ hub: HubConfig = Field(default_factory=HubConfig)
flow_matching: FlowMatchingConfig = Field(default_factory=FlowMatchingConfig)
# General configuration
diff --git a/src/ltxv_trainer/hub_utils.py b/src/ltxv_trainer/hub_utils.py
new file mode 100644
index 0000000..ee7412b
--- /dev/null
+++ b/src/ltxv_trainer/hub_utils.py
@@ -0,0 +1,234 @@
+import tempfile
+from pathlib import Path
+
+import imageio
+from huggingface_hub import HfApi, create_repo
+from loguru import logger
+
+from ltxv_trainer.config import LtxvTrainerConfig
+from ltxv_trainer.model_loader import try_parse_version
+from scripts.convert_checkpoint import convert_checkpoint
+
+
+def convert_video_to_gif(video_path: Path, output_path: Path) -> None:
+ """Convert a video file to GIF format."""
+ try:
+ # Read the video file
+ reader = imageio.get_reader(str(video_path))
+ fps = reader.get_meta_data()["fps"]
+
+ # Write GIF file with infinite loop
+ writer = imageio.get_writer(
+ str(output_path),
+ fps=min(fps, 15), # Cap FPS at 15 for reasonable file size
+ loop=0, # 0 means infinite loop
+ )
+
+ for frame in reader:
+ writer.append_data(frame)
+
+ writer.close()
+ reader.close()
+ except Exception as e:
+ logger.error(f"Failed to convert video to GIF: {e}")
+ return None
+
+
+def create_model_card(
+    output_dir: str | Path,
+    videos: list[Path] | None,
+    config: LtxvTrainerConfig,
+) -> Path | None:
+ """Generate and save a model card for the trained model."""
+
+ repo_id = config.hub.hub_model_id
+ pretrained_model_name_or_path = config.model.model_source
+ validation_prompts = config.validation.prompts
+ output_dir = Path(output_dir)
+ template_path = Path(__file__).parent.parent.parent / "templates" / "model_card.md"
+
+    if not template_path.exists():
+        logger.warning("⚠️ Model card template not found, skipping model card creation")
+        return None
+
+ # Read the template
+ template = template_path.read_text()
+
+ # Get model name from repo_id
+ model_name = repo_id.split("/")[-1]
+
+ # Get base model information
+ version = try_parse_version(pretrained_model_name_or_path)
+ if version:
+ base_model_link = version.safetensors_url
+ base_model_name = str(version)
+ else:
+ base_model_link = f"https://huggingface.co/{pretrained_model_name_or_path}"
+ base_model_name = pretrained_model_name_or_path
+
+ # Format validation prompts and create grid layout
+ prompts_text = ""
+ sample_grid = []
+
+ if validation_prompts and videos:
+ prompts_text = "Example prompts used during validation:\n\n"
+
+ # Create samples directory
+ samples_dir = output_dir / "samples"
+ samples_dir.mkdir(exist_ok=True, parents=True)
+
+ # Process videos and create cells
+ cells = []
+ for i, (prompt, video) in enumerate(zip(validation_prompts, videos, strict=False)):
+ if video.exists():
+ # Add prompt to text section
+ prompts_text += f"- `{prompt}`\n"
+
+ # Convert video to GIF
+ gif_path = samples_dir / f"sample_{i}.gif"
+ try:
+ convert_video_to_gif(video, gif_path)
+
+ # Create grid cell with collapsible description
+                cell = (
+                    # GIF preview followed by the prompt in a collapsible <details> block
+                    f'<img src="samples/sample_{i}.gif" alt="sample_{i}"/><br>'
+                    "<details>"
+                    "<summary>Prompt</summary>"
+                    f"{prompt}"
+                    "</details>"
+                )
+ cells.append(cell)
+ except Exception as e:
+ logger.error(f"Failed to process video {video}: {e}")
+
+ # Calculate optimal grid dimensions
+ num_cells = len(cells)
+ if num_cells > 0:
+ # Aim for a roughly square grid, with max 4 columns
+ num_cols = min(4, num_cells)
+ num_rows = (num_cells + num_cols - 1) // num_cols # Ceiling division
+
+ # Create grid rows
+ for row in range(num_rows):
+ start_idx = row * num_cols
+ end_idx = min(start_idx + num_cols, num_cells)
+ row_cells = cells[start_idx:end_idx]
+ # Properly format the row with table markers and exact number of cells
+ formatted_row = "| " + " | ".join(row_cells) + " |"
+ sample_grid.append(formatted_row)
+
+ # Join grid rows with just the content, no headers needed
+ grid_text = "\n".join(sample_grid) if sample_grid else ""
+
+ # Fill in the template
+ model_card_content = template.format(
+ base_model=base_model_name,
+ base_model_link=base_model_link,
+ model_name=model_name,
+ training_type="LoRA fine-tuning" if config.model.training_mode == "lora" else "Full model fine-tuning",
+ training_steps=config.optimization.steps,
+ learning_rate=config.optimization.learning_rate,
+ batch_size=config.optimization.batch_size,
+ validation_prompts=prompts_text,
+ sample_grid=grid_text,
+ )
+
+ # Save the model card directly
+ model_card_path = output_dir / "README.md"
+ model_card_path.write_text(model_card_content)
+
+ return model_card_path
+
+
+def push_to_hub(weights_path: Path, sampled_videos_paths: list[Path] | None, config: LtxvTrainerConfig) -> None:
+    """Push the trained weights, model card, and sample videos to the Hugging Face Hub."""
+ if not config.hub.push_to_hub:
+ return
+
+ if not config.hub.hub_model_id:
+ logger.warning("⚠️ HuggingFace hub_model_id not specified, skipping push to hub")
+ return
+
+ api = HfApi()
+
+ # Try to create repo if it doesn't exist
+ try:
+ create_repo(
+ repo_id=config.hub.hub_model_id,
+ repo_type="model",
+ exist_ok=True, # Don't raise error if repo exists
+ )
+ except Exception as e:
+ logger.error(f"❌ Failed to create repository: {e}")
+ return
+
+ # Upload the original weights file
+    try:
+        api.upload_file(
+            path_or_fileobj=str(weights_path),
+            path_in_repo=weights_path.name,
+            repo_id=config.hub.hub_model_id,
+            repo_type="model",
+        )
+        logger.info(f"✅ Successfully pushed original LoRA weights to {config.hub.hub_model_id}")
+    except Exception as e:
+        logger.error(f"❌ Failed to push {weights_path.name} to HuggingFace Hub: {e}")
+
+ # Create a temporary directory for the files we want to upload
+ with tempfile.TemporaryDirectory() as temp_dir:
+ temp_path = Path(temp_dir)
+
+ try:
+ # Save model card and copy videos to temp directory
+ create_model_card(
+ output_dir=temp_path,
+ videos=sampled_videos_paths,
+ config=config,
+ )
+
+ # Upload the model card and samples directory
+ api.upload_folder(
+ folder_path=str(temp_path), # Convert to string for compatibility
+ repo_id=config.hub.hub_model_id,
+ repo_type="model",
+ )
+
+ logger.info(f"✅ Successfully uploaded model card and sample videos to {config.hub.hub_model_id}")
+ except Exception as e:
+ logger.error(f"❌ Failed to save/upload model card and videos: {e}")
+
+ # Convert and upload ComfyUI version
+ try:
+ # Create a temporary directory for the converted file
+ with tempfile.TemporaryDirectory() as temp_dir:
+ # Convert the weights to ComfyUI format
+ comfy_path = Path(temp_dir) / f"{weights_path.stem}_comfy{weights_path.suffix}"
+
+ convert_checkpoint(
+ input_path=str(weights_path),
+ to_comfy=True,
+ output_path=str(comfy_path),
+ )
+
+ # Find the converted file
+ converted_files = list(Path(temp_dir).glob("*.safetensors"))
+ if not converted_files:
+ logger.warning("⚠️ No converted ComfyUI weights found")
+ return
+
+ converted_file = converted_files[0]
+ comfy_filename = f"comfyui_{weights_path.name}"
+
+ # Upload the converted file
+ api.upload_file(
+ path_or_fileobj=str(converted_file),
+ path_in_repo=comfy_filename,
+ repo_id=config.hub.hub_model_id,
+ repo_type="model",
+ )
+ logger.info(f"✅ Successfully pushed ComfyUI LoRA weights to {config.hub.hub_model_id}")
+
+ except Exception as e:
+ logger.error(f"❌ Failed to convert and push ComfyUI version: {e}")
diff --git a/src/ltxv_trainer/model_loader.py b/src/ltxv_trainer/model_loader.py
index 15aacb0..6c49e9d 100644
--- a/src/ltxv_trainer/model_loader.py
+++ b/src/ltxv_trainer/model_loader.py
@@ -160,7 +160,7 @@ def load_vae(
"""
if isinstance(source, str): # noqa: SIM102
# Try to parse as version first
- if version := _try_parse_version(source):
+ if version := try_parse_version(source):
source = version
if isinstance(source, LtxvModelVersion):
@@ -217,7 +217,7 @@ def load_transformer(
"""
if isinstance(source, str): # noqa: SIM102
# Try to parse as version first
- if version := _try_parse_version(source):
+ if version := try_parse_version(source):
source = version
if isinstance(source, LtxvModelVersion):
@@ -285,7 +285,7 @@ def load_ltxv_components(
)
-def _try_parse_version(source: str | Path) -> LtxvModelVersion | None:
+def try_parse_version(source: str | Path) -> LtxvModelVersion | None:
"""
Try to parse a string as an LtxvModelVersion.
diff --git a/src/ltxv_trainer/trainer.py b/src/ltxv_trainer/trainer.py
index d4f08f6..bb10a16 100644
--- a/src/ltxv_trainer/trainer.py
+++ b/src/ltxv_trainer/trainer.py
@@ -46,6 +46,7 @@
from ltxv_trainer.config import LtxvTrainerConfig
from ltxv_trainer.datasets import PrecomputedDataset
+from ltxv_trainer.hub_utils import push_to_hub
from ltxv_trainer.model_loader import load_ltxv_components
from ltxv_trainer.quantization import quantize_model
from ltxv_trainer.timestep_samplers import SAMPLERS
@@ -155,6 +156,8 @@ def train( # noqa: PLR0912, PLR0915
# Track when actual training starts (after compilation)
actual_training_start = None
+ sampled_videos_paths = None
+
with Live(Panel(Group(train_progress, sample_progress)), refresh_per_second=2):
task = train_progress.add_task(
"Training",
@@ -165,7 +168,7 @@ def train( # noqa: PLR0912, PLR0915
)
if cfg.validation.interval:
- self._sample_videos(sample_progress)
+ sampled_videos_paths = self._sample_videos(sample_progress)
for step in range(cfg.optimization.steps):
# Get next batch, reset the dataloader if needed
@@ -202,7 +205,6 @@ def train( # noqa: PLR0912, PLR0915
if self._lr_scheduler is not None:
self._lr_scheduler.step()
-
# Run validation if needed
if (
cfg.validation.interval
@@ -291,6 +293,10 @@ def train( # noqa: PLR0912, PLR0915
if self._accelerator.is_main_process:
saved_path = self._save_checkpoint()
+ # Upload artifacts to hub if enabled
+ if cfg.hub.push_to_hub:
+ push_to_hub(saved_path, sampled_videos_paths, self._config)
+
# Log the training statistics
self._log_training_stats(stats)
diff --git a/templates/model_card.md b/templates/model_card.md
new file mode 100644
index 0000000..b70682e
--- /dev/null
+++ b/templates/model_card.md
@@ -0,0 +1,42 @@
+# {model_name}
+
+This is a fine-tuned version of [`{base_model}`]({base_model_link}) trained on custom data.
+
+## Model Details
+
+- **Base Model:** [`{base_model}`]({base_model_link})
+- **Training Type:** {training_type}
+- **Training Steps:** {training_steps}
+- **Learning Rate:** {learning_rate}
+- **Batch Size:** {batch_size}
+
+## Sample Outputs
+
+| | | | |
+|:---:|:---:|:---:|:---:|
+{sample_grid}
+
+## Usage
+
+This model is designed to be used with the LTXV (Lightricks Text-to-Video) pipeline.
+
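+### 🧨 Using Trained LoRAs with Diffusers
+
+A minimal sketch for loading the LoRA with the 🤗 Diffusers `LTXPipeline` (the base model repo, repository ID, weight filename, and generation settings below are placeholders; adjust them to your setup):
+
+```python
+import torch
+from diffusers import LTXPipeline
+from diffusers.utils import export_to_video
+
+# Load the base LTX-Video pipeline and apply the trained LoRA weights
+pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16).to("cuda")
+pipe.load_lora_weights("your-username/your-model-name", weight_name="lora_weights.safetensors")
+
+# Generate a short clip and save it as a video file
+video = pipe(prompt="your prompt here", num_frames=97, num_inference_steps=30).frames[0]
+export_to_video(video, "output.mp4", fps=24)
+```
+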
+### 🔌 Using Trained LoRAs in ComfyUI
+To use the trained LoRA in ComfyUI:
+1. Copy the ComfyUI-compatible LoRA weights (the `comfyui_*.safetensors` file) into the `models/loras` folder of your ComfyUI installation.
+2. In your ComfyUI workflow:
+ - Add the "LTXV LoRA Selector" node to choose your LoRA file
+ - Connect it to the "LTXV LoRA Loader" node to apply the LoRA to your generation
+
+You can find reference Text-to-Video (T2V) and Image-to-Video (I2V) workflows in the [official LTXV ComfyUI repository](https://github.com/Lightricks/ComfyUI-LTXVideo).
+
+### Example Prompts
+
+{validation_prompts}
+
+## License
+
+This model inherits the license of the base model ([`{base_model}`]({base_model_link})).
+
+## Acknowledgments
+
+- Base model by [Lightricks](https://huggingface.co/Lightricks)
+- Training infrastructure: [LTX-Video-Trainer](https://github.com/Lightricks/ltx-video-trainer)