diff --git a/packages/ltx-trainer/scripts/process_captions.py b/packages/ltx-trainer/scripts/process_captions.py index 49cab23..0af9a47 100755 --- a/packages/ltx-trainer/scripts/process_captions.py +++ b/packages/ltx-trainer/scripts/process_captions.py @@ -13,6 +13,7 @@ import json import os +import gc from pathlib import Path from typing import Any @@ -332,6 +333,11 @@ def compute_captions_embeddings( # noqa: PLR0913 logger.info(f"Processed {len(dataset):,} captions. Embeddings saved to {output_path}") + if device.startswith("cuda"): + del text_encoder + gc.collect() + torch.cuda.empty_cache() + @app.command() def main( # noqa: PLR0913