diff --git a/README.md b/README.md
index dd61b50..c09fca0 100644
--- a/README.md
+++ b/README.md
@@ -302,6 +302,27 @@ The trainer loads your configuration, initializes models, applies optimizations,
 
 For LoRA training, the weights will be saved as `lora_weights.safetensors` in your output directory. For full model fine-tuning, the weights will be saved as `model_weights.safetensors`.
 
+### 🤗 Pushing Models to Hugging Face Hub
+
+You can automatically push your trained models to the Hugging Face Hub by adding the following to your configuration YAML:
+
+```yaml
+hub:
+  push_to_hub: true
+  hub_model_id: "your-username/your-model-name"  # Your HF username and desired repo name
+```
+
+Before pushing, make sure you:
+1. Have a Hugging Face account
+2. Are logged in via `huggingface-cli login` or have set the `HUGGING_FACE_HUB_TOKEN` environment variable
+3. Have write access to the specified repository (it will be created if it doesn't exist)
+
+The trainer will:
+- Create a model card with training details and sample outputs
+- Upload the model weights (both original and ComfyUI-compatible versions)
+- Push sample videos as GIFs in the model card
+- Include training configuration and prompts
+
 ---
 
 ## Fast and simple: Running the Complete Pipeline as one command
diff --git a/src/ltxv_trainer/config.py b/src/ltxv_trainer/config.py
index a99eaf5..62b2701 100644
--- a/src/ltxv_trainer/config.py
+++ b/src/ltxv_trainer/config.py
@@ -1,7 +1,7 @@
 from pathlib import Path
 from typing import Literal
 
-from pydantic import BaseModel, ConfigDict, Field, field_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
 
 from ltxv_trainer.model_loader import LtxvModelVersion
 from ltxv_trainer.quantization import QuantizationOptions
@@ -246,6 +246,22 @@ class CheckpointsConfig(ConfigBaseModel):
     )
 
 
+class HubConfig(ConfigBaseModel):
+    """Configuration for Hugging Face Hub integration"""
+
+    push_to_hub: bool = Field(default=False, description="Whether to push the model weights to the Hugging Face Hub")
+    hub_model_id: str | None = Field(
+        default=None, description="Hugging Face Hub repository ID (e.g., 'username/repo-name')"
+    )
+
+    @model_validator(mode="after")
+    def validate_hub_config(self) -> "HubConfig":
+        """Validate that hub_model_id is not None when push_to_hub is True."""
+        if self.push_to_hub and not self.hub_model_id:
+            raise ValueError("hub_model_id must be specified when push_to_hub is True")
+        return self
+
+
 class FlowMatchingConfig(ConfigBaseModel):
     """Configuration for flow matching training"""
 
@@ -271,6 +287,7 @@ class LtxvTrainerConfig(ConfigBaseModel):
     data: DataConfig = Field(default_factory=DataConfig)
     validation: ValidationConfig = Field(default_factory=ValidationConfig)
     checkpoints: CheckpointsConfig = Field(default_factory=CheckpointsConfig)
+    hub: HubConfig = Field(default_factory=HubConfig)
     flow_matching: FlowMatchingConfig = Field(default_factory=FlowMatchingConfig)
 
     # General configuration
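For reference, the `hub:` block from the README section above maps onto the new `HubConfig` model; a minimal standalone sketch of its validation behavior (assuming the package is importable as `ltxv_trainer` and pydantic v2 semantics):

```python
from pydantic import ValidationError

from ltxv_trainer.config import HubConfig

# Disabled by default: nothing is pushed unless explicitly enabled.
assert HubConfig().push_to_hub is False

# Valid: pushing is enabled and a target repository is given.
cfg = HubConfig(push_to_hub=True, hub_model_id="your-username/your-model-name")
print(cfg.hub_model_id)

# Invalid: enabling push_to_hub without hub_model_id trips validate_hub_config,
# which pydantic surfaces as a ValidationError at construction time.
try:
    HubConfig(push_to_hub=True)
except ValidationError as err:
    print(err)
```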
diff --git a/src/ltxv_trainer/hub_utils.py b/src/ltxv_trainer/hub_utils.py
new file mode 100644
index 0000000..ee7412b
--- /dev/null
+++ b/src/ltxv_trainer/hub_utils.py
@@ -0,0 +1,234 @@
+import tempfile
+from pathlib import Path
+from typing import List, Union
+
+import imageio
+from huggingface_hub import HfApi, create_repo
+from loguru import logger
+
+from ltxv_trainer.config import LtxvTrainerConfig
+from ltxv_trainer.model_loader import try_parse_version
+from scripts.convert_checkpoint import convert_checkpoint
+
+
+def convert_video_to_gif(video_path: Path, output_path: Path) -> None:
+    """Convert a video file to GIF format."""
+    try:
+        # Read the video file
+        reader = imageio.get_reader(str(video_path))
+        fps = reader.get_meta_data()["fps"]
+
+        # Write GIF file with infinite loop
+        writer = imageio.get_writer(
+            str(output_path),
+            fps=min(fps, 15),  # Cap FPS at 15 for reasonable file size
+            loop=0,  # 0 means infinite loop
+        )
+
+        for frame in reader:
+            writer.append_data(frame)
+
+        writer.close()
+        reader.close()
+    except Exception as e:
+        logger.error(f"Failed to convert video to GIF: {e}")
+        return None
+
+
+def create_model_card(
+    output_dir: Union[str, Path],
+    videos: List[Path],
+    config: LtxvTrainerConfig,
+) -> Path:
+    """Generate and save a model card for the trained model."""
+
+    repo_id = config.hub.hub_model_id
+    pretrained_model_name_or_path = config.model.model_source
+    validation_prompts = config.validation.prompts
+    output_dir = Path(output_dir)
+    template_path = Path(__file__).parent.parent.parent / "templates" / "model_card.md"
+
+    if not template_path.exists():
+        logger.warning("⚠️ Model card template not found, using default template")
+        return
+
+    # Read the template
+    template = template_path.read_text()
+
+    # Get model name from repo_id
+    model_name = repo_id.split("/")[-1]
+
+    # Get base model information
+    version = try_parse_version(pretrained_model_name_or_path)
+    if version:
+        base_model_link = version.safetensors_url
+        base_model_name = str(version)
+    else:
+        base_model_link = f"https://huggingface.co/{pretrained_model_name_or_path}"
+        base_model_name = pretrained_model_name_or_path
+
+    # Format validation prompts and create grid layout
+    prompts_text = ""
+    sample_grid = []
+
+    if validation_prompts and videos:
+        prompts_text = "Example prompts used during validation:\n\n"
+
+        # Create samples directory
+        samples_dir = output_dir / "samples"
+        samples_dir.mkdir(exist_ok=True, parents=True)
+
+        # Process videos and create cells
+        cells = []
+        for i, (prompt, video) in enumerate(zip(validation_prompts, videos, strict=False)):
+            if video.exists():
+                # Add prompt to text section
+                prompts_text += f"- `{prompt}`\n"
+
+                # Convert video to GIF
+                gif_path = samples_dir / f"sample_{i}.gif"
+                try:
+                    convert_video_to_gif(video, gif_path)
+
+                    # Create grid cell with collapsible description
+                    cell = (
+                        f"![example{i + 1}](./samples/sample_{i}.gif)"
+                        "<br>"
+                        "<details>"
+                        f"<summary>Prompt</summary>"
+                        f"{prompt}"
+                        "</details>"
+                    )
+                    cells.append(cell)
+                except Exception as e:
+                    logger.error(f"Failed to process video {video}: {e}")
+
+        # Calculate optimal grid dimensions
+        num_cells = len(cells)
+        if num_cells > 0:
+            # Aim for a roughly square grid, with max 4 columns
+            num_cols = min(4, num_cells)
+            num_rows = (num_cells + num_cols - 1) // num_cols  # Ceiling division
+
+            # Create grid rows
+            for row in range(num_rows):
+                start_idx = row * num_cols
+                end_idx = min(start_idx + num_cols, num_cells)
+                row_cells = cells[start_idx:end_idx]
+                # Properly format the row with table markers and exact number of cells
+                formatted_row = "| " + " | ".join(row_cells) + " |"
+                sample_grid.append(formatted_row)
+
+    # Join grid rows with just the content, no headers needed
+    grid_text = "\n".join(sample_grid) if sample_grid else ""
+
+    # Fill in the template
+    model_card_content = template.format(
+        base_model=base_model_name,
+        base_model_link=base_model_link,
+        model_name=model_name,
+        training_type="LoRA fine-tuning" if config.model.training_mode == "lora" else "Full model fine-tuning",
+        training_steps=config.optimization.steps,
+        learning_rate=config.optimization.learning_rate,
+        batch_size=config.optimization.batch_size,
+        validation_prompts=prompts_text,
+        sample_grid=grid_text,
+    )
+
+    # Save the model card directly
+    model_card_path = output_dir / "README.md"
+    model_card_path.write_text(model_card_content)
+
+    return model_card_path
+
+
+def push_to_hub(weights_path: Path, sampled_videos_paths: List[Path], config: LtxvTrainerConfig) -> None:
+    """Push the trained LoRA weights to HuggingFace Hub."""
+    if not config.hub.push_to_hub:
+        return
+
+    if not config.hub.hub_model_id:
+        logger.warning("⚠️ HuggingFace hub_model_id not specified, skipping push to hub")
+        return
+
+    api = HfApi()
+
+    # Try to create repo if it doesn't exist
+    try:
+        create_repo(
+            repo_id=config.hub.hub_model_id,
+            repo_type="model",
+            exist_ok=True,  # Don't raise error if repo exists
+        )
+    except Exception as e:
+        logger.error(f"❌ Failed to create repository: {e}")
+        return
+
+    # Upload the original weights file
+    try:
+        api.upload_file(
+            path_or_fileobj=str(weights_path),
+            path_in_repo=weights_path.name,
+            repo_id=config.hub.hub_model_id,
+            repo_type="model",
+        )
+    except Exception as e:
+        logger.error(f"❌ Failed to push {weights_path.name} to HuggingFace Hub: {e}")
+
+    # Create a temporary directory for the files we want to upload
+    with tempfile.TemporaryDirectory() as temp_dir:
+        temp_path = Path(temp_dir)
+
+        try:
+            # Save model card and copy videos to temp directory
+            create_model_card(
+                output_dir=temp_path,
+                videos=sampled_videos_paths,
+                config=config,
+            )
+
+            # Upload the model card and samples directory
+            api.upload_folder(
+                folder_path=str(temp_path),  # Convert to string for compatibility
+                repo_id=config.hub.hub_model_id,
+                repo_type="model",
+            )
+
+            logger.info(f"✅ Successfully uploaded model card and sample videos to {config.hub.hub_model_id}")
+        except Exception as e:
+            logger.error(f"❌ Failed to save/upload model card and videos: {e}")
+
+    logger.info(f"✅ Successfully pushed original LoRA weights to {config.hub.hub_model_id}")
+
+    # Convert and upload ComfyUI version
+    try:
+        # Create a temporary directory for the converted file
+        with tempfile.TemporaryDirectory() as temp_dir:
+            # Convert the weights to ComfyUI format
+            comfy_path = Path(temp_dir) / f"{weights_path.stem}_comfy{weights_path.suffix}"
+
+            convert_checkpoint(
+                input_path=str(weights_path),
+                to_comfy=True,
+                output_path=str(comfy_path),
+            )
+
+            # Find the converted file
+            converted_files = list(Path(temp_dir).glob("*.safetensors"))
+            if not converted_files:
+                logger.warning("⚠️ No converted ComfyUI weights found")
+                return
+
+            converted_file = converted_files[0]
+            comfy_filename = f"comfyui_{weights_path.name}"
+
+            # Upload the converted file
+            api.upload_file(
+                path_or_fileobj=str(converted_file),
+                path_in_repo=comfy_filename,
+                repo_id=config.hub.hub_model_id,
+                repo_type="model",
+            )
+            logger.info(f"✅ Successfully pushed ComfyUI LoRA weights to {config.hub.hub_model_id}")
+
+    except Exception as e:
+        logger.error(f"❌ Failed to convert and push ComfyUI version: {e}")
diff --git a/src/ltxv_trainer/model_loader.py b/src/ltxv_trainer/model_loader.py
index 15aacb0..6c49e9d 100644
--- a/src/ltxv_trainer/model_loader.py
+++ b/src/ltxv_trainer/model_loader.py
@@ -160,7 +160,7 @@ def load_vae(
     """
     if isinstance(source, str):  # noqa: SIM102
         # Try to parse as version first
-        if version := _try_parse_version(source):
+        if version := try_parse_version(source):
             source = version
 
     if isinstance(source, LtxvModelVersion):
@@ -217,7 +217,7 @@ def load_transformer(
     """
     if isinstance(source, str):  # noqa: SIM102
         # Try to parse as version first
-        if version := _try_parse_version(source):
+        if version := try_parse_version(source):
             source = version
 
     if isinstance(source, LtxvModelVersion):
@@ -285,7 +285,7 @@ def load_ltxv_components(
     )
 
 
-def _try_parse_version(source: str | Path) -> LtxvModelVersion | None:
+def try_parse_version(source: str | Path) -> LtxvModelVersion | None:
     """
     Try to parse a string as an LtxvModelVersion.
diff --git a/src/ltxv_trainer/trainer.py b/src/ltxv_trainer/trainer.py
index d4f08f6..bb10a16 100644
--- a/src/ltxv_trainer/trainer.py
+++ b/src/ltxv_trainer/trainer.py
@@ -46,6 +46,7 @@
 from ltxv_trainer.config import LtxvTrainerConfig
 from ltxv_trainer.datasets import PrecomputedDataset
+from ltxv_trainer.hub_utils import push_to_hub
 from ltxv_trainer.model_loader import load_ltxv_components
 from ltxv_trainer.quantization import quantize_model
 from ltxv_trainer.timestep_samplers import SAMPLERS
@@ -155,6 +156,8 @@ def train(  # noqa: PLR0912, PLR0915
         # Track when actual training starts (after compilation)
         actual_training_start = None
 
+        sampled_videos_paths = None
+
         with Live(Panel(Group(train_progress, sample_progress)), refresh_per_second=2):
             task = train_progress.add_task(
                 "Training",
@@ -165,7 +168,7 @@
             )
 
             if cfg.validation.interval:
-                self._sample_videos(sample_progress)
+                sampled_videos_paths = self._sample_videos(sample_progress)
 
             for step in range(cfg.optimization.steps):
                 # Get next batch, reset the dataloader if needed
@@ -202,7 +205,6 @@
                 if self._lr_scheduler is not None:
                     self._lr_scheduler.step()
-
                 # Run validation if needed
                 if (
                     cfg.validation.interval
@@ -291,6 +293,10 @@
         if self._accelerator.is_main_process:
             saved_path = self._save_checkpoint()
 
+            # Upload artifacts to hub if enabled
+            if cfg.hub.push_to_hub:
+                push_to_hub(saved_path, sampled_videos_paths, self._config)
+
         # Log the training statistics
         self._log_training_stats(stats)
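The trainer invokes `push_to_hub` once the final checkpoint is saved; a hypothetical manual re-push of an existing run is sketched below (the YAML path, output paths, and PyYAML-based config loading are illustrative assumptions, not part of this patch):

```python
from pathlib import Path

import yaml

from ltxv_trainer.config import LtxvTrainerConfig
from ltxv_trainer.hub_utils import push_to_hub

# Load the same YAML used for training (path is a placeholder).
config = LtxvTrainerConfig(**yaml.safe_load(Path("configs/my_run.yaml").read_text()))

# push_to_hub is a no-op unless config.hub.push_to_hub is true.
push_to_hub(
    weights_path=Path("outputs/my_run/lora_weights.safetensors"),
    sampled_videos_paths=sorted(Path("outputs/my_run/samples").glob("*.mp4")),
    config=config,
)
```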
diff --git a/templates/model_card.md b/templates/model_card.md
new file mode 100644
index 0000000..b70682e
--- /dev/null
+++ b/templates/model_card.md
@@ -0,0 +1,42 @@
+# {model_name}
+
+This is a fine-tuned version of [`{base_model}`]({base_model_link}) trained on custom data.
+
+## Model Details
+
+- **Base Model:** [`{base_model}`]({base_model_link})
+- **Training Type:** {training_type}
+- **Training Steps:** {training_steps}
+- **Learning Rate:** {learning_rate}
+- **Batch Size:** {batch_size}
+
+## Sample Outputs
+
+| | | | |
+|:---:|:---:|:---:|:---:|
+{sample_grid}
+
+## Usage
+
+This model is designed to be used with the LTXV (Lightricks Text-to-Video) pipeline.
+
+### 🔌 Using Trained LoRAs in ComfyUI
+To use the trained LoRA in ComfyUI:
+1. Copy your ComfyUI-converted LoRA weights (the `comfyui_*.safetensors` file) to the `models/loras` folder in your ComfyUI installation.
+2. In your ComfyUI workflow:
+   - Add the "LTXV LoRA Selector" node to choose your LoRA file
+   - Connect it to the "LTXV LoRA Loader" node to apply the LoRA to your generation
+
+You can find reference Text-to-Video (T2V) and Image-to-Video (I2V) workflows in the [official LTXV ComfyUI repository](https://github.com/Lightricks/ComfyUI-LTXVideo).
+
+### Example Prompts
+
+{validation_prompts}
+
+
+This model inherits the license of the base model ([`{base_model}`]({base_model_link})).
+
+## Acknowledgments
+
+- Base model by [Lightricks](https://huggingface.co/Lightricks)
+- Training infrastructure: [LTX-Video-Trainer](https://github.com/Lightricks/ltx-video-trainer)
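For reference, `create_model_card` fills these placeholders with `str.format`; a minimal sketch with dummy values (everything below is illustrative, not real training results):

```python
from pathlib import Path

template = Path("templates/model_card.md").read_text()

# Dummy values purely for illustration; create_model_card derives the real ones
# from LtxvTrainerConfig and the sampled validation videos.
card = template.format(
    model_name="my-ltxv-lora",
    base_model="<base-model-name>",
    base_model_link="https://huggingface.co/<base-model-repo>",
    training_type="LoRA fine-tuning",
    training_steps=100,
    learning_rate=0.0002,
    batch_size=1,
    validation_prompts="Example prompts used during validation:\n\n- `a sample prompt`\n",
    sample_grid="| ![example1](./samples/sample_0.gif) |",
)

Path("README_preview.md").write_text(card)
```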