From 45a2975769420fcca3b0d4a373d8ee9e9255743d Mon Sep 17 00:00:00 2001 From: LittleNyima Date: Wed, 14 Jan 2026 18:54:39 +0800 Subject: [PATCH] Bugfix: fix memory allocation issues during multi-GPU training --- packages/ltx-trainer/src/ltx_trainer/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/ltx-trainer/src/ltx_trainer/trainer.py b/packages/ltx-trainer/src/ltx_trainer/trainer.py index d7a7a594..62a3b1fc 100644 --- a/packages/ltx-trainer/src/ltx_trainer/trainer.py +++ b/packages/ltx-trainer/src/ltx_trainer/trainer.py @@ -79,9 +79,9 @@ def __init__(self, trainer_config: LtxTrainerConfig) -> None: if IS_MAIN_PROCESS: print_config(trainer_config) self._training_strategy = get_training_strategy(self._config.training_strategy) + self._setup_accelerator() self._cached_validation_embeddings = self._load_text_encoder_and_cache_embeddings() self._load_models() - self._setup_accelerator() self._collect_trainable_params() self._load_checkpoint() self._prepare_models_for_training() @@ -351,7 +351,7 @@ def _load_text_encoder_and_cache_embeddings(self) -> list[CachedPromptEmbeddings self._text_encoder = load_text_encoder( checkpoint_path=self._config.model.model_path, gemma_model_path=self._config.model.text_encoder_path, - device="cuda", + device=self._accelerator.device, dtype=torch.bfloat16, )