From 19a444db131b6398fb6d71498c1c76e7fb5052b4 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 13 Feb 2026 11:09:00 -0500 Subject: [PATCH 1/5] refactor _inner_training_loop to smaller methods --- src/transformers/trainer.py | 486 +++++++++++++++++------------- src/transformers/trainer_utils.py | 1 + 2 files changed, 281 insertions(+), 206 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 55418a68f718..4cebeaed0c31 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1453,6 +1453,44 @@ def _inner_training_loop( max_steps, ) = self.set_initial_training_values(args, train_dataloader, total_train_batch_size) + model, train_dataloader = self._setup_training(args, max_steps, resume_from_checkpoint, train_dataloader) + + epochs_trained, steps_trained_in_current_epoch, start_time = self._init_loop_state( + args=args, + model=model, + num_update_steps_per_epoch=num_update_steps_per_epoch, + num_train_epochs=num_train_epochs, + max_steps=max_steps, + total_train_batch_size=total_train_batch_size, + num_examples=num_examples, + len_dataloader=len_dataloader, + train_dataloader=train_dataloader, + resume_from_checkpoint=resume_from_checkpoint, + trial=trial, + ignore_keys_for_eval=ignore_keys_for_eval, + ) + + for epoch in range(epochs_trained, num_train_epochs): + self._run_epoch( + model=model, + epoch=epoch, + train_dataloader=train_dataloader, + len_dataloader=len_dataloader, + args=args, + trial=trial, + ignore_keys_for_eval=ignore_keys_for_eval, + start_time=start_time, + resume_from_checkpoint=resume_from_checkpoint, + epochs_trained=epochs_trained, + steps_trained_in_current_epoch=steps_trained_in_current_epoch, + ) + if self.control.should_training_stop: + break + + return self._finalize_training(model, trial, num_train_samples, start_time) + + def _setup_training(self, args, max_steps, resume_from_checkpoint, train_dataloader): + """Create optimizer, wrap model, load checkpoint. Returns (wrapped_model, train_dataloader).""" if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: if self.args.n_gpu > 1: # nn.DataParallel(model) replicates the model, creating new variables and module @@ -1486,7 +1524,6 @@ def _inner_training_loop( cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) ] ) - self.state.is_hyper_param_search = trial is not None self.state.train_batch_size = self._train_batch_size # Compute absolute values for logging, eval, and save if given as ratio @@ -1574,6 +1611,26 @@ def _inner_training_loop( # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc. + return model, train_dataloader + + def _init_loop_state( + self, + args, + model, + num_update_steps_per_epoch, + num_train_epochs, + max_steps, + total_train_batch_size, + num_examples, + len_dataloader, + train_dataloader, + resume_from_checkpoint, + trial, + ignore_keys_for_eval, + ): + """Initialize training loop state. Returns (epochs_trained, steps_trained_in_current_epoch, start_time).""" + self.state.is_hyper_param_search = trial is not None + # Train! 
logger.info("***** Running training *****") logger.info(f" Num examples = {num_examples:,}") @@ -1623,240 +1680,257 @@ def _inner_training_loop( self.state.init_training_references(self, max_steps, num_train_epochs, trial) # tr_loss is a tensor to avoid synchronization of TPUs through .item() - tr_loss = torch.tensor(0.0, device=args.device) + self._tr_loss = torch.tensor(0.0, device=args.device) # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses self._total_loss_scalar = 0.0 self._globalstep_last_logged = self.state.global_step model.zero_grad() - grad_norm: float | None = None - learning_rate = None + self._grad_norm: float | None = None + self._learning_rate = None self.control = self.callback_handler.on_train_begin(args, self.state, self.control) if args.eval_on_start: self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True) - for epoch in range(epochs_trained, num_train_epochs): - epoch_dataloader = train_dataloader + return epochs_trained, steps_trained_in_current_epoch, start_time - steps_in_epoch = ( - len(epoch_dataloader) - if len_dataloader is not None - else args.max_steps * args.gradient_accumulation_steps - ) - self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) - - step = -1 - rng_to_sync = False - - # Handle resumption from checkpoint - if epoch == epochs_trained and resume_from_checkpoint is not None: - if steps_trained_in_current_epoch > 0 and not args.ignore_data_skip: - epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch) - step = steps_trained_in_current_epoch - 1 - rng_to_sync = True - elif steps_trained_in_current_epoch == 0: - self._load_rng_state(resume_from_checkpoint) + def _run_epoch( + self, + model, + epoch, + train_dataloader, + len_dataloader, + args, + trial, + ignore_keys_for_eval, + start_time, + resume_from_checkpoint, + epochs_trained, + steps_trained_in_current_epoch, + ): + """Run one full pass over the dataloader.""" + epoch_dataloader = train_dataloader - if hasattr(epoch_dataloader, "set_epoch"): - epoch_dataloader.set_epoch(epoch) - - epoch_iterator = iter(epoch_dataloader) - # We chunkify the epoch iterator into gradient accumulation steps `n` batches - remainder = steps_in_epoch % args.gradient_accumulation_steps - if remainder == 0: - remainder = args.gradient_accumulation_steps - update_step = -1 - total_updates = steps_in_epoch // args.gradient_accumulation_steps + int( - remainder < args.gradient_accumulation_steps - ) - for _ in range(total_updates): - update_step += 1 - num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder - batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, args.device) - # Store the number of batches for current gradient accumulation - # This is used to correctly scale the loss when the last accumulation step has fewer batches - self.current_gradient_accumulation_steps = len(batch_samples) - for i, inputs in enumerate(batch_samples): - step += 1 - do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch - # Since we perform prefetching, we need to manually set sync_gradients - self.accelerator.gradient_state._set_sync_gradients(do_sync_step) - - if self.args.include_num_input_tokens_seen != "no": - main_input_name = getattr(self.model, "main_input_name", "input_ids") - if main_input_name not in inputs: - logger.warning( - "Tried to track the number of tokens seen, 
however the current model is " - "not configured properly to know what item is the input. To fix this, add " - "a `main_input_name` attribute to the model class you are using." - ) - else: - if self.args.include_num_input_tokens_seen == "non_padding": - if "attention_mask" in inputs: - input_tokens = inputs["attention_mask"].sum() - elif ( - self.processing_class is not None - and hasattr(self.processing_class, "pad_token_id") - and self.processing_class.pad_token_id is not None - ): - input_tokens = ( - inputs[main_input_name] != self.processing_class.pad_token_id - ).sum() - else: - logger.warning( - "Could not determine method to count non-padding tokens, falling back to counting all tokens." - ) - input_tokens = inputs[main_input_name].numel() + steps_in_epoch = ( + len(epoch_dataloader) + if len_dataloader is not None + else args.max_steps * args.gradient_accumulation_steps + ) + self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) + + step = -1 + rng_to_sync = False + + # Handle resumption from checkpoint + if epoch == epochs_trained and resume_from_checkpoint is not None: + if steps_trained_in_current_epoch > 0 and not args.ignore_data_skip: + epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch) + step = steps_trained_in_current_epoch - 1 + rng_to_sync = True + elif steps_trained_in_current_epoch == 0: + self._load_rng_state(resume_from_checkpoint) + + if hasattr(epoch_dataloader, "set_epoch"): + epoch_dataloader.set_epoch(epoch) + + epoch_iterator = iter(epoch_dataloader) + # We chunkify the epoch iterator into gradient accumulation steps `n` batches + remainder = steps_in_epoch % args.gradient_accumulation_steps + if remainder == 0: + remainder = args.gradient_accumulation_steps + update_step = -1 + total_updates = steps_in_epoch // args.gradient_accumulation_steps + int( + remainder < args.gradient_accumulation_steps + ) + for _ in range(total_updates): + update_step += 1 + num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder + batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, args.device) + # Store the number of batches for current gradient accumulation + # This is used to correctly scale the loss when the last accumulation step has fewer batches + self.current_gradient_accumulation_steps = len(batch_samples) + for i, inputs in enumerate(batch_samples): + step += 1 + do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch + # Since we perform prefetching, we need to manually set sync_gradients + self.accelerator.gradient_state._set_sync_gradients(do_sync_step) + + if self.args.include_num_input_tokens_seen != "no": + main_input_name = getattr(self.model, "main_input_name", "input_ids") + if main_input_name not in inputs: + logger.warning( + "Tried to track the number of tokens seen, however the current model is " + "not configured properly to know what item is the input. To fix this, add " + "a `main_input_name` attribute to the model class you are using." 
+ ) + else: + if self.args.include_num_input_tokens_seen == "non_padding": + if "attention_mask" in inputs: + input_tokens = inputs["attention_mask"].sum() + elif ( + self.processing_class is not None + and hasattr(self.processing_class, "pad_token_id") + and self.processing_class.pad_token_id is not None + ): + input_tokens = ( + inputs[main_input_name] != self.processing_class.pad_token_id + ).sum() else: + logger.warning( + "Could not determine method to count non-padding tokens, falling back to counting all tokens." + ) input_tokens = inputs[main_input_name].numel() + else: + input_tokens = inputs[main_input_name].numel() - input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64) - self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).sum().item() - - if rng_to_sync: - self._load_rng_state(resume_from_checkpoint) - rng_to_sync = False - - if step % args.gradient_accumulation_steps == 0: - self.control = self.callback_handler.on_step_begin(args, self.state, self.control) + input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64) + self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).sum().item() - # We sync the gradients in the following cases: 1. sync_each_batch set to True 2. Using deepspeed 3. when we are at the last batch sample - if ( - self.accelerator.gradient_state.plugin_kwargs.get("sync_each_batch", False) - or self.accelerator.distributed_type == DistributedType.DEEPSPEED - or i == len(batch_samples) - 1 - ): - sync_context = contextlib.nullcontext - else: - sync_context = functools.partial(self.accelerator.no_sync, model=model) - with sync_context(): - tr_loss_step = self.training_step(model, inputs, num_items_in_batch) - - if ( - args.logging_nan_inf_filter - and not is_torch_xla_available() - and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) - ): - # if loss is nan or inf simply add the average of previous logged losses - tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) - else: - if tr_loss.device != tr_loss_step.device: - raise ValueError( - f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}" - ) - tr_loss = tr_loss + tr_loss_step + if rng_to_sync: + self._load_rng_state(resume_from_checkpoint) + rng_to_sync = False + + if step % args.gradient_accumulation_steps == 0: + self.control = self.callback_handler.on_step_begin(args, self.state, self.control) + + # We sync the gradients in the following cases: 1. sync_each_batch set to True 2. Using deepspeed 3. 
when we are at the last batch sample + if ( + self.accelerator.gradient_state.plugin_kwargs.get("sync_each_batch", False) + or self.accelerator.distributed_type == DistributedType.DEEPSPEED + or i == len(batch_samples) - 1 + ): + sync_context = contextlib.nullcontext + else: + sync_context = functools.partial(self.accelerator.no_sync, model=model) + with sync_context(): + tr_loss_step = self.training_step(model, inputs, num_items_in_batch) + + if ( + args.logging_nan_inf_filter + and not is_torch_xla_available() + and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) + ): + # if loss is nan or inf simply add the average of previous logged losses + self._tr_loss += self._tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) + else: + if self._tr_loss.device != tr_loss_step.device: + raise ValueError( + f"Calculated loss must be on the original device: {self._tr_loss.device} but device in use is {tr_loss_step.device}" + ) + self._tr_loss += tr_loss_step - self.current_flos += float(self.floating_point_ops(inputs)) + self.current_flos += float(self.floating_point_ops(inputs)) - if do_sync_step: - # Since we perform prefetching, we need to manually set sync_gradients to True - self.accelerator.gradient_state._set_sync_gradients(True) + if do_sync_step: + # Since we perform prefetching, we need to manually set sync_gradients to True + self.accelerator.gradient_state._set_sync_gradients(True) - # Gradient clipping - if args.max_grad_norm is not None and args.max_grad_norm > 0: - if is_sagemaker_mp_enabled() and args.fp16: - _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm) - else: - grad_norm_context = contextlib.nullcontext - if self.is_tp_enabled: - from torch.distributed._tensor.experimental import implicit_replication - - grad_norm_context = implicit_replication - with grad_norm_context(): - _grad_norm = self.accelerator.clip_grad_norm_( - model.parameters(), - args.max_grad_norm, - ) + # Gradient clipping + if args.max_grad_norm is not None and args.max_grad_norm > 0: + if is_sagemaker_mp_enabled() and args.fp16: + _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm) + else: + grad_norm_context = contextlib.nullcontext + if self.is_tp_enabled: + from torch.distributed._tensor.experimental import implicit_replication + + grad_norm_context = implicit_replication + with grad_norm_context(): + _grad_norm = self.accelerator.clip_grad_norm_( + model.parameters(), + args.max_grad_norm, + ) + + if self.accelerator.distributed_type == DistributedType.DEEPSPEED: + self._grad_norm = model.get_global_grad_norm() + # In some cases the grad norm may not return a float + if hasattr(self._grad_norm, "item"): + self._grad_norm = self._grad_norm.item() + else: + self._grad_norm = _grad_norm + + self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control) + + context = contextlib.nullcontext + if self.is_tp_enabled: + from torch.distributed._tensor.experimental import implicit_replication + + context = implicit_replication + + with context(): + self.optimizer.step() + + self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control) + + # get learning rate before update + self._learning_rate = self._get_learning_rate() + + if not self.accelerator.optimizer_step_was_skipped: + # Delay optimizer scheduling until metrics are generated + if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + self.lr_scheduler.step() + + model.zero_grad() + self.state.global_step += 1 + 
self.state.epoch = epoch + (step + 1) / steps_in_epoch + self.control = self.callback_handler.on_step_end(args, self.state, self.control) + self._maybe_log_save_evaluate( + self._tr_loss, + self._grad_norm, + model, + trial, + epoch, + ignore_keys_for_eval, + start_time, + learning_rate=self._learning_rate, + ) + else: + self.control = self.callback_handler.on_substep_end(args, self.state, self.control) - if self.accelerator.distributed_type == DistributedType.DEEPSPEED: - grad_norm = model.get_global_grad_norm() - # In some cases the grad norm may not return a float - if hasattr(grad_norm, "item"): - grad_norm = grad_norm.item() - else: - grad_norm = _grad_norm - - self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control) - - context = contextlib.nullcontext - if self.is_tp_enabled: - from torch.distributed._tensor.experimental import implicit_replication - - context = implicit_replication - - with context(): - self.optimizer.step() - - self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control) - - # get leaning rate before update - learning_rate = self._get_learning_rate() - - if not self.accelerator.optimizer_step_was_skipped: - # Delay optimizer scheduling until metrics are generated - if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): - self.lr_scheduler.step() - - model.zero_grad() - self.state.global_step += 1 - self.state.epoch = epoch + (step + 1) / steps_in_epoch - self.control = self.callback_handler.on_step_end(args, self.state, self.control) - self._maybe_log_save_evaluate( - tr_loss, - grad_norm, - model, - trial, - epoch, - ignore_keys_for_eval, - start_time, - learning_rate=learning_rate, - ) - else: - self.control = self.callback_handler.on_substep_end(args, self.state, self.control) - - # PyTorch/XLA relies on the data loader to insert the mark_step for - # each step. Since we are breaking the loop early, we need to manually - # insert the mark_step here. - if self.control.should_epoch_stop or self.control.should_training_stop: - if is_torch_xla_available(): - xm.mark_step() - break - # We also need to break out of the nested loop + # PyTorch/XLA relies on the data loader to insert the mark_step for + # each step. Since we are breaking the loop early, we need to manually + # insert the mark_step here. if self.control.should_epoch_stop or self.control.should_training_stop: if is_torch_xla_available(): xm.mark_step() break - if step < 0: - logger.warning( - "There seems not to be a single sample in your epoch_iterator, stopping training at step" - f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" - f" num_steps ({max_steps}) higher than the number of available samples." - ) - self.control.should_training_stop = True - - self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) - self._maybe_log_save_evaluate( - tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=learning_rate - ) - - if DebugOption.TPU_METRICS_DEBUG in self.args.debug: + # We also need to break out of the nested loop + if self.control.should_epoch_stop or self.control.should_training_stop: if is_torch_xla_available(): - # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) - xm.master_print(met.metrics_report()) - else: - logger.warning( - "You enabled PyTorch/XLA debug metrics but you don't have a TPU " - "configured. Check your training configuration if this is unexpected." 
- ) - if self.control.should_training_stop: + xm.mark_step() break + if step < 0: + logger.warning( + "There seems not to be a single sample in your epoch_iterator, stopping training at step" + f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" + f" num_steps ({self.state.max_steps}) higher than the number of available samples." + ) + self.control.should_training_stop = True + + self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) + self._maybe_log_save_evaluate( + self._tr_loss, self._grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, + learning_rate=self._learning_rate, + ) + + if DebugOption.TPU_METRICS_DEBUG in self.args.debug: + if is_torch_xla_available(): + # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) + xm.master_print(met.metrics_report()) + else: + logger.warning( + "You enabled PyTorch/XLA debug metrics but you don't have a TPU " + "configured. Check your training configuration if this is unexpected." + ) + def _finalize_training(self, model, trial, num_train_samples, start_time): + """Finalize training: metrics, best-model loading, cleanup. Returns TrainOutput.""" logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") - if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: + if self.args.load_best_model_at_end and self.state.best_model_checkpoint is not None: self._load_best_model() # add remaining tr_loss - self._total_loss_scalar += tr_loss.item() + self._total_loss_scalar += self._tr_loss.item() effective_global_step = max(self.state.global_step, 0.001) # Avoid ZeroDivisionError train_loss = self._total_loss_scalar / effective_global_step @@ -1888,7 +1962,7 @@ def _inner_training_loop( logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") shutil.rmtree(checkpoint, ignore_errors=True) - self.control = self.callback_handler.on_train_end(args, self.state, self.control) + self.control = self.callback_handler.on_train_end(self.args, self.state, self.control) # Wait for the checkpoint to be uploaded. self._finish_current_push() diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index 46582e4069c8..aa8717dc8e90 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -612,6 +612,7 @@ class TrainerMemoryTracker: "__init__": "init", "train": "train", "_inner_training_loop": "train", + "_finalize_training": "train", "evaluate": "eval", "predict": "test", } From ff12b75c1ec14ffe35ff2558909454fbf9eae27b Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 13 Feb 2026 12:31:39 -0500 Subject: [PATCH 2/5] Refactor dataset/sampler/dataloader to a DataProducer --- src/transformers/data_producer.py | 356 ++++++++++++++++++++++ src/transformers/trainer.py | 481 ++++++++++++++++++++++++------ 2 files changed, 740 insertions(+), 97 deletions(-) create mode 100644 src/transformers/data_producer.py diff --git a/src/transformers/data_producer.py b/src/transformers/data_producer.py new file mode 100644 index 000000000000..835da50802ec --- /dev/null +++ b/src/transformers/data_producer.py @@ -0,0 +1,356 @@ +# Copyright 2020-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +DataProducer protocol for online/async training with the HuggingFace Trainer. + +A ``DataProducer`` generates fresh training data each rollout round, +enabling online RL methods (PPO, GRPO, REINFORCE, online DPO) and +curriculum learning without any changes to the core training loop. + +Quick start:: + + from transformers import Trainer, TrainingArguments + from transformers.data_producer import BaseDataProducer, ProducerConfig, RolloutDataset + + class MyProducer(BaseDataProducer): + def __init__(self, prompts, reward_fn): + super().__init__(ProducerConfig(mini_epochs=2, max_rollouts=100)) + self.prompts = prompts + self.reward_fn = reward_fn + + def produce(self, model, global_step, **kwargs): + completions = model.generate(self.prompts, max_new_tokens=256) + rewards = self.reward_fn(completions) + return RolloutDataset(prompts=self.prompts, completions=completions, rewards=rewards) + + trainer = Trainer( + model=model, + args=TrainingArguments(output_dir="./out", max_steps=5000), + data_producer=MyProducer(prompts, reward_fn), + ) + trainer.train() +""" + +from __future__ import annotations + +import logging +import threading +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +import torch +from torch.utils.data import Dataset + +from .trainer_callback import TrainerCallback + +if TYPE_CHECKING: + from torch import nn + + +logger = logging.getLogger(__name__) + + +# ====================================================================== +# Configuration +# ====================================================================== + + +@dataclass +class ProducerConfig: + """Configuration for a :class:`DataProducer`. + + Args: + mini_epochs: Number of training passes over each produced dataset. + Use ``>1`` to amortise expensive generation across multiple + gradient updates (common in PPO/GRPO). + max_rollouts: Maximum number of produce-then-train rounds. Set to + ``None`` for unlimited (training stops at ``args.max_steps``). + async_prefetch: If ``True``, the next dataset is produced in a + background thread while the current one is being trained on. + eval_during_produce: If ``True``, switch the model to eval mode + during ``produce()`` (common for generation). + empty_cache_before_produce: Call ``torch.cuda.empty_cache()`` before + ``produce()`` to free memory for generation. + empty_cache_after_produce: Call ``torch.cuda.empty_cache()`` after + ``produce()`` to free memory for training. + """ + + mini_epochs: int = 1 + max_rollouts: int | None = None + async_prefetch: bool = False + eval_during_produce: bool = True + empty_cache_before_produce: bool = False + empty_cache_after_produce: bool = False + + +# ====================================================================== +# Protocol / base classes +# ====================================================================== + + +class DataProducer(ABC): + """Abstract protocol for online data production. + + Subclasses must implement :meth:`produce` and provide a :attr:`config`. 
+ + The Trainer calls ``produce(model, global_step, ...)`` each rollout + round to obtain a fresh ``Dataset`` for training. + """ + + config: ProducerConfig + + @abstractmethod + def produce( + self, + model: nn.Module, + global_step: int, + processing_class: Any = None, + accelerator: Any = None, + args: Any = None, + **kwargs, + ) -> Dataset: + """Generate a fresh training dataset. + + Args: + model: The current model (may be in eval mode if + ``config.eval_during_produce`` is set). + global_step: Current training step. + processing_class: Tokenizer / processor from the Trainer. + accelerator: The Accelerate accelerator. + args: TrainingArguments. + **kwargs: Reserved for future use. + + Returns: + A ``torch.utils.data.Dataset``. Prefer map-style datasets + (with ``__len__``) when ``mini_epochs > 1``. + """ + ... + + +class BaseDataProducer(DataProducer): + """Convenience base class with lifecycle hooks. + + Subclasses only need to implement :meth:`produce`. Optional hooks: + + - ``on_rollout_begin(global_step)`` — called before each ``produce()``. + - ``on_rollout_end(dataset, global_step)`` — called after each ``produce()``. + """ + + def __init__(self, config: ProducerConfig | None = None): + self.config = config or ProducerConfig() + + def on_rollout_begin(self, global_step: int) -> None: + """Hook called before ``produce()``. Override for logging/setup.""" + pass + + def on_rollout_end(self, dataset: Dataset, global_step: int) -> None: + """Hook called after ``produce()``. Override for logging/cleanup.""" + pass + + +# ====================================================================== +# Async wrapper +# ====================================================================== + + +class AsyncDataProducer(DataProducer): + """Wraps a :class:`DataProducer` for background-thread data generation. + + While the Trainer trains on the current dataset, the next dataset is + produced in a background thread. ``produce()`` blocks until the + prefetched result is ready, then kicks off the *next* prefetch. 
+ + Usage:: + + producer = AsyncDataProducer(MyProducer(...)) + # or: set config.async_prefetch = True and the Trainer wraps it automatically + """ + + def __init__(self, inner: DataProducer): + self._inner = inner + self.config = inner.config + self._prefetch_thread: threading.Thread | None = None + self._prefetch_result: Dataset | None = None + self._prefetch_error: BaseException | None = None + + def produce( + self, + model: nn.Module, + global_step: int, + processing_class: Any = None, + accelerator: Any = None, + args: Any = None, + **kwargs, + ) -> Dataset: + # If there's a prefetched result, use it + if self._prefetch_thread is not None: + self._prefetch_thread.join() + self._prefetch_thread = None + if self._prefetch_error is not None: + raise self._prefetch_error + result = self._prefetch_result + self._prefetch_result = None + else: + # First call — produce synchronously + result = self._inner.produce( + model=model, + global_step=global_step, + processing_class=processing_class, + accelerator=accelerator, + args=args, + **kwargs, + ) + + # Start prefetching the next dataset + self._start_prefetch(model, global_step, processing_class, accelerator, args, **kwargs) + return result + + def _start_prefetch(self, model, global_step, processing_class, accelerator, args, **kwargs): + def _worker(): + try: + self._prefetch_result = self._inner.produce( + model=model, + global_step=global_step, + processing_class=processing_class, + accelerator=accelerator, + args=args, + **kwargs, + ) + except BaseException as e: + self._prefetch_error = e + + self._prefetch_error = None + self._prefetch_thread = threading.Thread(target=_worker, daemon=True) + self._prefetch_thread.start() + + # Forward lifecycle hooks + def on_rollout_begin(self, global_step: int) -> None: + if hasattr(self._inner, "on_rollout_begin"): + self._inner.on_rollout_begin(global_step=global_step) + + def on_rollout_end(self, dataset: Dataset, global_step: int) -> None: + if hasattr(self._inner, "on_rollout_end"): + self._inner.on_rollout_end(dataset=dataset, global_step=global_step) + + +# ====================================================================== +# Callback +# ====================================================================== + + +class DataProducerCallback(TrainerCallback): + """Trainer callback that forwards lifecycle events to a DataProducer. + + Automatically added by the Trainer when a ``data_producer`` is set. + """ + + def __init__(self, data_producer: DataProducer): + self.data_producer = data_producer + + def on_train_begin(self, args, state, control, **kwargs): + """Log that online training is starting.""" + logger.info("DataProducerCallback: online training started.") + return control + + def on_train_end(self, args, state, control, **kwargs): + """Clean up async resources if any.""" + if isinstance(self.data_producer, AsyncDataProducer): + if self.data_producer._prefetch_thread is not None: + self.data_producer._prefetch_thread.join(timeout=5.0) + return control + + +# ====================================================================== +# Reference datasets +# ====================================================================== + + +class RolloutDataset(Dataset): + """Simple dataset for RL rollout data (prompts, completions, rewards). + + Each item is a dict with keys ``"prompt"``, ``"completion"``, + ``"reward"``, plus any extra fields. 
+ + Usage:: + + ds = RolloutDataset( + prompts=prompt_ids, # list[Tensor] or Tensor + completions=comp_ids, # list[Tensor] or Tensor + rewards=reward_values, # list[float] or Tensor + extras={"advantages": adv_tensor}, + ) + """ + + def __init__( + self, + prompts: list | torch.Tensor, + completions: list | torch.Tensor, + rewards: list | torch.Tensor, + extras: dict[str, Any] | None = None, + ): + self.prompts = prompts + self.completions = completions + self.rewards = rewards + self.extras = extras or {} + assert len(prompts) == len(completions) == len(rewards) + + def __len__(self) -> int: + return len(self.prompts) + + def __getitem__(self, idx: int) -> dict[str, Any]: + item = { + "prompt": self.prompts[idx], + "completion": self.completions[idx], + "reward": self.rewards[idx], + } + for k, v in self.extras.items(): + item[k] = v[idx] + return item + + +class PreferencePairDataset(Dataset): + """Simple dataset for preference pair data (prompt, chosen, rejected). + + Each item is a dict with keys ``"prompt"``, ``"chosen"``, ``"rejected"``. + + Usage:: + + ds = PreferencePairDataset( + prompts=prompt_ids, + chosen=chosen_ids, + rejected=rejected_ids, + ) + """ + + def __init__( + self, + prompts: list | torch.Tensor, + chosen: list | torch.Tensor, + rejected: list | torch.Tensor, + ): + self.prompts = prompts + self.chosen = chosen + self.rejected = rejected + assert len(prompts) == len(chosen) == len(rejected) + + def __len__(self) -> int: + return len(self.prompts) + + def __getitem__(self, idx: int) -> dict[str, Any]: + return { + "prompt": self.prompts[idx], + "chosen": self.chosen[idx], + "rejected": self.rejected[idx], + } diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 4cebeaed0c31..a343694c319d 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -28,7 +28,9 @@ import tempfile import time import warnings +from abc import ABC, abstractmethod from collections.abc import Callable, Iterator, Mapping +from dataclasses import dataclass, field from functools import partial from pathlib import Path from typing import TYPE_CHECKING, Any @@ -55,6 +57,7 @@ from . import __version__ from .configuration_utils import PreTrainedConfig from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator +from .data_producer import AsyncDataProducer, DataProducer, DataProducerCallback, ProducerConfig from .debug_utils import DebugOption, DebugUnderflowOverflow from .feature_extraction_sequence_utils import SequenceFeatureExtractor from .feature_extraction_utils import FeatureExtractionMixin @@ -247,6 +250,265 @@ FSDP_MODEL_NAME = "pytorch_model_fsdp" +# ====================================================================== +# Internal data structures for the unified training loop +# ====================================================================== + + +@dataclass +class _TrainingPlan: + """Everything computed *before* optimizer/model setup. + + Both epoch sources produce one of these so the rest of the pipeline + can be completely source-agnostic. + """ + + max_steps: int + num_train_epochs: int + num_train_samples: int + num_update_steps_per_epoch: int | None # None for online (no fixed epoch size) + total_train_batch_size: int + num_examples: int + len_dataloader: int | None + # The "reference" dataloader used for callback setup and SP adapter. 
+ initial_dataloader: DataLoader + + +@dataclass +class _EpochSpec: + """What the unified loop needs for one pass over a dataloader.""" + + epoch: float + dataloader: DataLoader + len_dataloader: int | None = None + # Resume from checkpoint attributes - set after a checkpoint + resume_from_checkpoint: str | None = None + steps_to_skip: int = 0 + is_resume_epoch: bool = False + + +class _EpochSource(ABC): + """Internal abstraction that yields epoch specs for the training loop. + + Two implementations: + - ``_StaticEpochSource``: wraps a fixed ``train_dataset`` / DataLoader + - ``_OnlineEpochSource``: wraps a ``DataProducer`` + + The Trainer only ever sees the base class, so the training loop is + completely source-agnostic. + """ + + @abstractmethod + def compute_plan(self, trainer: "Trainer", args: "TrainingArguments") -> _TrainingPlan: + """Create dataloaders, compute max_steps, etc. Called once.""" + ... + + def post_model_setup(self, model: nn.Module, trainer: "Trainer") -> None: + """Hook called after ``_setup_training`` wraps the model. + + Use for operations that need the prepared model (e.g. SP adapter). + Default: no-op. + """ + pass + + @abstractmethod + def log_banner(self, plan: _TrainingPlan, args: "TrainingArguments", model: nn.Module, trainer: "Trainer") -> None: + """Print the various logging info before training starts.""" + ... + + @abstractmethod + def iter_epochs( + self, + model: nn.Module, + trainer: "Trainer", + epochs_trained: int, + steps_trained_in_current_epoch: int, + resume_from_checkpoint: str | None, + ) -> Iterator[_EpochSpec]: + """Yield one ``_EpochSpec`` per training pass over a dataloader.""" + ... + + +class _StaticEpochSource(_EpochSource): + """Wraps the existing ``train_dataset`` -> DataLoader -> Sampler pipeline. + + This is the default path when no ``data_producer`` is set. 
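+
+    ``iter_epochs`` yields one :class:`_EpochSpec` per epoch; only the first
+    epoch of a resumed run carries ``resume_from_checkpoint`` and
+    ``steps_to_skip``, so checkpoint batch skipping happens exactly once.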
+ """ + + def __init__(self): + self.train_dataloader: DataLoader | None = None + self.num_train_epochs: int = 1 + self.num_examples: int = 0 + self.len_dataloader: int | None = None + + def compute_plan(self, trainer, args): + dl = trainer.get_train_dataloader() + if trainer.is_fsdp_xla_v2_enabled: + dl = tpu_spmd_dataloader(dl) + self.train_dataloader = dl + + total_train_batch_size = trainer.get_total_train_batch_size(args) + + ( + num_train_epochs, + num_update_steps_per_epoch, + num_examples, + num_train_samples, + epoch_based, + len_dataloader, + max_steps, + ) = trainer.set_initial_training_values(args, dl, total_train_batch_size) + + self.num_train_epochs = num_train_epochs + self.num_examples = num_examples + self.len_dataloader = len_dataloader + + return _TrainingPlan( + max_steps=max_steps, + num_train_epochs=num_train_epochs, + num_train_samples=num_train_samples, + num_update_steps_per_epoch=num_update_steps_per_epoch, + total_train_batch_size=total_train_batch_size, + num_examples=num_examples, + len_dataloader=len_dataloader, + initial_dataloader=dl, + ) + + def post_model_setup(self, model, trainer): + # Apply Ulysses/SP dataloader adapter + pc = getattr(trainer.accelerator, "parallelism_config", None) + if pc is not None and pc.sp_backend == "deepspeed" and pc.sp_enabled: + self.train_dataloader = trainer.accelerator.deepspeed_ulysses_dl_adapter(self.train_dataloader, model) + + def log_banner(self, plan, args, model, trainer): + logger.info("***** Running training *****") + logger.info(f" Num examples = {plan.num_examples:,}") + logger.info(f" Num Epochs = {plan.num_train_epochs:,}") + logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size:,}") + if args.per_device_train_batch_size != trainer._train_batch_size: + logger.info(f" Training with DataParallel so batch size has been adjusted to: {trainer._train_batch_size:,}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {plan.total_train_batch_size:,}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {plan.max_steps:,}") + logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}") + + def iter_epochs(self, model, trainer, epochs_trained, steps_trained_in_current_epoch, resume_from_checkpoint): + for epoch in range(epochs_trained, self.num_train_epochs): + is_first = epoch == epochs_trained + has_resume = is_first and resume_from_checkpoint is not None + yield _EpochSpec( + epoch=epoch, + dataloader=self.train_dataloader, + len_dataloader=self.len_dataloader, + resume_from_checkpoint=resume_from_checkpoint if has_resume else None, + steps_to_skip=steps_trained_in_current_epoch if is_first else 0, + is_resume_epoch=has_resume, + ) + + +class _OnlineEpochSource(_EpochSource): + """Wraps a ``DataProducer`` to yield produce->train->produce->train epochs. + + Each rollout round calls ``DataProducer.produce(model, step)`` to + generate a fresh dataset, wraps it in a DataLoader using the Trainer's + existing collator/sampler infrastructure, then yields ``mini_epochs`` + passes over that data. 
+ """ + + def __init__(self, data_producer: DataProducer): + self.data_producer = data_producer + self.config = data_producer.config + self.initial_dataset: Dataset | None = None + self.initial_dataloader: DataLoader | None = None + + def compute_plan(self, trainer, args): + total_train_batch_size = trainer.get_total_train_batch_size(args) + + # Produce initial dataset to size the training plan + self.initial_dataset = trainer._produce_data(trainer.model) + + # IterableDataset + mini_epochs > 1 is almost certainly a bug + if isinstance(self.initial_dataset, IterableDataset) and self.config.mini_epochs > 1: + logger.warning( + "DataProducer returned an IterableDataset with " + f"mini_epochs={self.config.mini_epochs}. Each mini-epoch " + "will see different data because IterableDataset creates " + "a fresh iterator each pass. Use a map-style Dataset " + "(e.g. RolloutDataset) if you want multiple passes over " + "the same rollout data." + ) + + self.initial_dataloader = trainer._get_online_dataloader(self.initial_dataset) + len_dl = len(self.initial_dataloader) if has_length(self.initial_dataloader) else None + + # Compute max_steps + if args.max_steps > 0: + max_steps = args.max_steps + elif self.config.max_rollouts is not None and len_dl is not None: + steps_per_mini_epoch = max(len_dl // args.gradient_accumulation_steps, 1) + max_steps = steps_per_mini_epoch * self.config.mini_epochs * self.config.max_rollouts + else: + raise ValueError( + "When using a DataProducer, you must set either " + "`args.max_steps` or `producer_config.max_rollouts` " + "(and use a map-style Dataset so its length is known)." + ) + + return _TrainingPlan( + max_steps=max_steps, + num_train_epochs=max_steps, # not epoch-based + num_train_samples=max_steps * total_train_batch_size, + num_update_steps_per_epoch=None, + total_train_batch_size=total_train_batch_size, + num_examples=len(self.initial_dataset) if has_length(self.initial_dataset) else 0, + len_dataloader=len_dl, + initial_dataloader=self.initial_dataloader, + ) + + def log_banner(self, plan, args, model, trainer): + logger.info("***** Running online training *****") + logger.info(f" Mini-epochs per rollout = {self.config.mini_epochs}") + logger.info(f" Max rollouts = {self.config.max_rollouts or 'unlimited (use max_steps)'}") + logger.info(f" Async prefetch = {self.config.async_prefetch}") + logger.info(f" Total train batch size = {plan.total_train_batch_size:,}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {plan.max_steps:,}") + logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}") + + def iter_epochs(self, model, trainer, epochs_trained, steps_trained_in_current_epoch, resume_from_checkpoint): + rollout = 0 + current_dataset = self.initial_dataset + + while True: + # Check stopping before each rollout + if trainer.state.global_step >= trainer.state.max_steps: + return + if self.config.max_rollouts is not None and rollout >= self.config.max_rollouts: + return + if trainer.control.should_training_stop: + return + + # Produce (first round already produced in compute_plan) + if rollout > 0: + current_dataset = trainer._produce_data(model) + dl = trainer._get_online_dataloader(current_dataset) + len_dl = len(dl) if has_length(dl) else None + + # Yield mini_epochs passes over this rollout's data + for mini in range(self.config.mini_epochs): + if trainer.state.global_step >= trainer.state.max_steps: + return + if 
trainer.control.should_training_stop: + return + yield _EpochSpec( + epoch=rollout + mini / self.config.mini_epochs, + dataloader=dl, + len_dataloader=len_dl, + ) + + rollout += 1 + + @requires( backends=( "torch", @@ -379,6 +641,7 @@ def __init__( optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), optimizer_cls_and_kwargs: tuple[type[torch.optim.Optimizer], dict[str, Any]] | None = None, preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, + data_producer: DataProducer | None = None, ): # Init flow: # 1. Args & seed – defaults, determinism @@ -539,6 +802,16 @@ def __init__( self.processing_class = processing_class self.neftune_noise_alpha = args.neftune_noise_alpha + # Online / async RL support + self.data_producer = data_producer + if data_producer is not None: + if not isinstance(data_producer, DataProducer): + raise TypeError( + f"`data_producer` must be a DataProducer instance, got {type(data_producer)}" + ) + if data_producer.config.async_prefetch: + self.data_producer = AsyncDataProducer(data_producer) + # Callables self.compute_loss_func = compute_loss_func self.compute_metrics = compute_metrics @@ -565,6 +838,8 @@ def __init__( callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler ) self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) + if self.data_producer is not None: + self.add_callback(DataProducerCallback(self.data_producer)) # ---- 9. Hub & output --------------------------------------------------------- self.hub_model_id = None # Set by init_hf_repo() when push_to_hub is enabled @@ -1407,6 +1682,68 @@ def train( ignore_keys_for_eval=ignore_keys_for_eval, ) + def _create_epoch_source(self) -> _EpochSource: + """Return the appropriate epoch source for this training run.""" + if getattr(self, "data_producer", None) is not None: + return _OnlineEpochSource(self.data_producer) + return _StaticEpochSource() + + @torch.no_grad() + def _produce_data(self, model: nn.Module) -> Dataset: + """Call the DataProducer to generate a fresh training dataset. + + Handles eval/train mode switching and CUDA cache clearing per + the producer's config. + """ + producer = self.data_producer + config = producer.config + + if hasattr(producer, "on_rollout_begin"): + producer.on_rollout_begin(global_step=self.state.global_step) + + if config.empty_cache_before_produce and torch.cuda.is_available(): + torch.cuda.empty_cache() + + was_training = model.training + if config.eval_during_produce: + model.eval() + + dataset = producer.produce( + model=model, + global_step=self.state.global_step, + processing_class=self.processing_class, + accelerator=self.accelerator, + args=self.args, + ) + + if config.eval_during_produce and was_training: + model.train() + + if config.empty_cache_after_produce and torch.cuda.is_available(): + torch.cuda.empty_cache() + + if hasattr(producer, "on_rollout_end"): + producer.on_rollout_end( + dataset=dataset, + global_step=self.state.global_step, + ) + + return dataset + + def _get_online_dataloader(self, dataset: Dataset) -> DataLoader: + """Create a DataLoader for a produced dataset. + + Reuses the Trainer's existing collator, sampler, and accelerator + preparation. 
+ """ + return self._get_dataloader( + dataset=dataset, + description="OnlineTraining", + batch_size=self._train_batch_size, + sampler_fn=self._get_train_sampler, + is_training=True, + ) + def _inner_training_loop( self, batch_size: int | None = None, @@ -1432,65 +1769,43 @@ def _inner_training_loop( self.args.per_device_train_batch_size = original_bs self.state.train_batch_size = self._train_batch_size logger.debug(f"Currently training with a batch size of: {self._train_batch_size}") - # Data loader and number of training steps - train_dataloader = self.get_train_dataloader() - if self.is_fsdp_xla_v2_enabled: - train_dataloader = tpu_spmd_dataloader(train_dataloader) - # Setting up training control variables: - # number of training epochs: num_train_epochs - # number of training steps per epoch: num_update_steps_per_epoch - # total number of training steps to execute: max_steps - total_train_batch_size = self.get_total_train_batch_size(args) + # Compute training plan (source-specific: static dataset or DataProducer) + source = self._create_epoch_source() + plan = source.compute_plan(self, args) - ( - num_train_epochs, - num_update_steps_per_epoch, - num_examples, - num_train_samples, - epoch_based, - len_dataloader, - max_steps, - ) = self.set_initial_training_values(args, train_dataloader, total_train_batch_size) + # Setup optimizer, model, checkpoint (shared) + model = self._setup_training(args, plan.max_steps, resume_from_checkpoint) + + # Post-model-setup hook (e.g. SP adapter for static source) + source.post_model_setup(model, self) - model, train_dataloader = self._setup_training(args, max_steps, resume_from_checkpoint, train_dataloader) + # Logging banner (source-specific) + source.log_banner(plan, args, model, self) + # Initialize loop state (shared) epochs_trained, steps_trained_in_current_epoch, start_time = self._init_loop_state( args=args, model=model, - num_update_steps_per_epoch=num_update_steps_per_epoch, - num_train_epochs=num_train_epochs, - max_steps=max_steps, - total_train_batch_size=total_train_batch_size, - num_examples=num_examples, - len_dataloader=len_dataloader, - train_dataloader=train_dataloader, + plan=plan, + train_dataloader=plan.initial_dataloader, resume_from_checkpoint=resume_from_checkpoint, trial=trial, ignore_keys_for_eval=ignore_keys_for_eval, ) - for epoch in range(epochs_trained, num_train_epochs): - self._run_epoch( - model=model, - epoch=epoch, - train_dataloader=train_dataloader, - len_dataloader=len_dataloader, - args=args, - trial=trial, - ignore_keys_for_eval=ignore_keys_for_eval, - start_time=start_time, - resume_from_checkpoint=resume_from_checkpoint, - epochs_trained=epochs_trained, - steps_trained_in_current_epoch=steps_trained_in_current_epoch, - ) + # === THE UNIFIED LOOP === + for spec in source.iter_epochs( + model, self, epochs_trained, steps_trained_in_current_epoch, resume_from_checkpoint, + ): + self._run_epoch(model, spec, args, trial, ignore_keys_for_eval, start_time) if self.control.should_training_stop: break - return self._finalize_training(model, trial, num_train_samples, start_time) + return self._finalize_training(model, trial, plan.num_train_samples, start_time) - def _setup_training(self, args, max_steps, resume_from_checkpoint, train_dataloader): - """Create optimizer, wrap model, load checkpoint. Returns (wrapped_model, train_dataloader).""" + def _setup_training(self, args, max_steps, resume_from_checkpoint): + """Create optimizer, wrap model, load checkpoint. 
Returns the wrapped model.""" if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: if self.args.n_gpu > 1: # nn.DataParallel(model) replicates the model, creating new variables and module @@ -1574,11 +1889,6 @@ def _setup_training(self, args, max_steps, resume_from_checkpoint, train_dataloa # Create scheduler now that the optimizer won't change anymore self.create_scheduler(num_training_steps=max_steps) - # since DataLoader was Accelerate prepared w/o a model arg in the same call, we now have to complete the DL wrapping for ALST/UlyssesSP, after model has been prepared - pc = getattr(self.accelerator, "parallelism_config", None) - if pc is not None and pc.sp_backend == "deepspeed" and pc.sp_enabled: - train_dataloader = self.accelerator.deepspeed_ulysses_dl_adapter(train_dataloader, model) - if self.is_fsdp_enabled: self.model = self.model_wrapped = model # Fix `got mixed torch.Tensor and DTensor` error in model.generate() for FSDP2 with LoRA @@ -1611,18 +1921,13 @@ def _setup_training(self, args, max_steps, resume_from_checkpoint, train_dataloa # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc. - return model, train_dataloader + return model def _init_loop_state( self, args, model, - num_update_steps_per_epoch, - num_train_epochs, - max_steps, - total_train_batch_size, - num_examples, - len_dataloader, + plan: _TrainingPlan, train_dataloader, resume_from_checkpoint, trial, @@ -1631,18 +1936,6 @@ def _init_loop_state( """Initialize training loop state. Returns (epochs_trained, steps_trained_in_current_epoch, start_time).""" self.state.is_hyper_param_search = trial is not None - # Train! - logger.info("***** Running training *****") - logger.info(f" Num examples = {num_examples:,}") - logger.info(f" Num Epochs = {num_train_epochs:,}") - logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}") - if self.args.per_device_train_batch_size != self._train_batch_size: - logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}") - logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_train_batch_size:,}") - logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") - logger.info(f" Total optimization steps = {max_steps:,}") - logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}") - self.state.epoch = 0 start_time = time.time() self.initial_num_input_tokens_seen_for_session = self.state.num_input_tokens_seen @@ -1656,17 +1949,22 @@ def _init_loop_state( self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) compare_trainer_and_checkpoint_args(self.args, self.state) self._load_callback_state() - epochs_trained = int(self.state.global_step // num_update_steps_per_epoch) - if not args.ignore_data_skip: - steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) - steps_trained_in_current_epoch *= args.gradient_accumulation_steps + if plan.num_update_steps_per_epoch is not None: + epochs_trained = int(self.state.global_step // plan.num_update_steps_per_epoch) + if not args.ignore_data_skip: + steps_trained_in_current_epoch = self.state.global_step % plan.num_update_steps_per_epoch + steps_trained_in_current_epoch *= args.gradient_accumulation_steps + else: + steps_trained_in_current_epoch = 0 else: + # Online path: no fixed epoch size, so no epoch/step skipping + epochs_trained = 0 steps_trained_in_current_epoch = 0 logger.info(" Continuing training from checkpoint, will skip to saved global_step") logger.info(f" Continuing training from epoch {epochs_trained}") logger.info(f" Continuing training from global step {self.state.global_step}") - if not args.ignore_data_skip: + if not args.ignore_data_skip and plan.num_update_steps_per_epoch is not None: logger.info( f" Will skip the first {epochs_trained} epochs then the first" f" {steps_trained_in_current_epoch} batches in the first epoch." 
@@ -1677,7 +1975,7 @@ def _init_loop_state( setattr(self.callback_handler, attr, getattr(self, attr)) self.callback_handler.train_dataloader = train_dataloader - self.state.init_training_references(self, max_steps, num_train_epochs, trial) + self.state.init_training_references(self, plan.max_steps, plan.num_train_epochs, trial) # tr_loss is a tensor to avoid synchronization of TPUs through .item() self._tr_loss = torch.tensor(0.0, device=args.device) @@ -1694,22 +1992,11 @@ def _init_loop_state( return epochs_trained, steps_trained_in_current_epoch, start_time - def _run_epoch( - self, - model, - epoch, - train_dataloader, - len_dataloader, - args, - trial, - ignore_keys_for_eval, - start_time, - resume_from_checkpoint, - epochs_trained, - steps_trained_in_current_epoch, - ): - """Run one full pass over the dataloader.""" - epoch_dataloader = train_dataloader + def _run_epoch(self, model, spec: _EpochSpec, args, trial, ignore_keys_for_eval, start_time): + """Run one full pass over a dataloader described by *spec*.""" + epoch = spec.epoch + epoch_dataloader = spec.dataloader + len_dataloader = spec.len_dataloader steps_in_epoch = ( len(epoch_dataloader) @@ -1722,16 +2009,16 @@ def _run_epoch( rng_to_sync = False # Handle resumption from checkpoint - if epoch == epochs_trained and resume_from_checkpoint is not None: - if steps_trained_in_current_epoch > 0 and not args.ignore_data_skip: - epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch) - step = steps_trained_in_current_epoch - 1 + if spec.is_resume_epoch and spec.resume_from_checkpoint is not None: + if spec.steps_to_skip > 0 and not args.ignore_data_skip: + epoch_dataloader = skip_first_batches(epoch_dataloader, spec.steps_to_skip) + step = spec.steps_to_skip - 1 rng_to_sync = True - elif steps_trained_in_current_epoch == 0: - self._load_rng_state(resume_from_checkpoint) + elif spec.steps_to_skip == 0: + self._load_rng_state(spec.resume_from_checkpoint) if hasattr(epoch_dataloader, "set_epoch"): - epoch_dataloader.set_epoch(epoch) + epoch_dataloader.set_epoch(int(epoch)) epoch_iterator = iter(epoch_dataloader) # We chunkify the epoch iterator into gradient accumulation steps `n` batches @@ -1787,7 +2074,7 @@ def _run_epoch( self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).sum().item() if rng_to_sync: - self._load_rng_state(resume_from_checkpoint) + self._load_rng_state(spec.resume_from_checkpoint) rng_to_sync = False if step % args.gradient_accumulation_steps == 0: From af3c4db15f08e7cb1e419c4b0092495a3b7a1b87 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 13 Feb 2026 12:36:08 -0500 Subject: [PATCH 3/5] commit the tests too --- tests/trainer/test_data_producer.py | 580 ++++++++++++++++++++++++++++ 1 file changed, 580 insertions(+) create mode 100644 tests/trainer/test_data_producer.py diff --git a/tests/trainer/test_data_producer.py b/tests/trainer/test_data_producer.py new file mode 100644 index 000000000000..9dad01d76494 --- /dev/null +++ b/tests/trainer/test_data_producer.py @@ -0,0 +1,580 @@ +# Copyright 2020-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for the DataProducer protocol and online training support.""" + +import tempfile +import unittest + +import numpy as np +import torch +from torch import nn +from torch.utils.data import Dataset, IterableDataset + +from transformers import Trainer, TrainingArguments +from transformers.data_producer import ( + AsyncDataProducer, + BaseDataProducer, + DataProducer, + DataProducerCallback, + PreferencePairDataset, + ProducerConfig, + RolloutDataset, +) + + +# --------------------------------------------------------------------------- +# Test fixtures +# --------------------------------------------------------------------------- + + +class SimpleDataset(Dataset): + """A minimal dataset that returns (input_x, labels) pairs.""" + + def __init__(self, length=16, seed=42): + rng = np.random.RandomState(seed) + self.x = rng.normal(size=(length,)).astype(np.float32) + self.y = (2.0 * self.x + 3.0 + rng.normal(scale=0.1, size=(length,))).astype(np.float32) + + def __len__(self): + return len(self.x) + + def __getitem__(self, idx): + return {"input_x": self.x[idx], "labels": self.y[idx]} + + +class SimpleIterableDataset(IterableDataset): + """An IterableDataset that yields a fixed number of items.""" + + def __init__(self, length=16, seed=42): + self.length = length + self.seed = seed + + def __iter__(self): + rng = np.random.RandomState(self.seed) + for _ in range(self.length): + x = np.float32(rng.normal()) + y = np.float32(2.0 * x + 3.0) + yield {"input_x": x, "labels": y} + + +class RegressionModel(nn.Module): + """A trivial y = ax + b model for testing.""" + + def __init__(self): + super().__init__() + self.a = nn.Parameter(torch.tensor(0.0)) + self.b = nn.Parameter(torch.tensor(0.0)) + self.config = None + + def forward(self, input_x, labels=None, **kwargs): + y = input_x * self.a + self.b + if labels is None: + return (y,) + loss = nn.functional.mse_loss(y, labels) + return (loss, y) + + +class CountingProducer(BaseDataProducer): + """A DataProducer that counts how many times produce() is called.""" + + def __init__(self, config=None, dataset_size=16): + super().__init__(config or ProducerConfig()) + self.call_count = 0 + self.global_steps_seen = [] + self.dataset_size = dataset_size + + def produce(self, model, global_step, **kwargs): + self.call_count += 1 + self.global_steps_seen.append(global_step) + return SimpleDataset(length=self.dataset_size) + + +class IterableProducer(BaseDataProducer): + """A DataProducer that returns an IterableDataset.""" + + def __init__(self, config=None, dataset_size=16): + super().__init__(config or ProducerConfig()) + self.dataset_size = dataset_size + + def produce(self, model, global_step, **kwargs): + return SimpleIterableDataset(length=self.dataset_size) + + +class LifecycleTrackingProducer(BaseDataProducer): + """Tracks lifecycle hook calls.""" + + def __init__(self, config=None): + super().__init__(config or ProducerConfig(max_rollouts=2)) + self.rollout_begins = [] + self.rollout_ends = [] + self.produce_calls = [] + + def on_rollout_begin(self, global_step): + self.rollout_begins.append(global_step) + + def on_rollout_end(self, 
dataset, global_step): + self.rollout_ends.append(global_step) + + def produce(self, model, global_step, **kwargs): + self.produce_calls.append(global_step) + return SimpleDataset(length=8) + + +def _make_trainer(model=None, data_producer=None, max_steps=10, **kwargs): + """Helper to create a Trainer with a DataProducer.""" + if model is None: + model = RegressionModel() + with tempfile.TemporaryDirectory() as tmp: + args = TrainingArguments( + output_dir=tmp, + max_steps=max_steps, + per_device_train_batch_size=4, + logging_steps=1, + save_strategy="no", + report_to="none", + use_cpu=True, + **kwargs, + ) + trainer = Trainer( + model=model, + args=args, + data_producer=data_producer, + ) + yield trainer + + +# --------------------------------------------------------------------------- +# Unit tests: data_producer.py classes +# --------------------------------------------------------------------------- + + +class TestProducerConfig(unittest.TestCase): + def test_defaults(self): + config = ProducerConfig() + self.assertEqual(config.mini_epochs, 1) + self.assertIsNone(config.max_rollouts) + self.assertFalse(config.async_prefetch) + self.assertTrue(config.eval_during_produce) + self.assertFalse(config.empty_cache_before_produce) + self.assertFalse(config.empty_cache_after_produce) + + def test_custom_values(self): + config = ProducerConfig(mini_epochs=3, max_rollouts=50, async_prefetch=True) + self.assertEqual(config.mini_epochs, 3) + self.assertEqual(config.max_rollouts, 50) + self.assertTrue(config.async_prefetch) + + +class TestRolloutDataset(unittest.TestCase): + def test_basic(self): + prompts = [torch.tensor([1, 2, 3])] * 4 + completions = [torch.tensor([4, 5, 6])] * 4 + rewards = [1.0, 0.5, 0.8, 0.3] + ds = RolloutDataset(prompts=prompts, completions=completions, rewards=rewards) + self.assertEqual(len(ds), 4) + item = ds[0] + self.assertIn("prompt", item) + self.assertIn("completion", item) + self.assertIn("reward", item) + + def test_with_extras(self): + prompts = [torch.tensor([1])] * 3 + completions = [torch.tensor([2])] * 3 + rewards = [1.0, 0.5, 0.8] + extras = {"advantages": [0.1, 0.2, 0.3]} + ds = RolloutDataset(prompts=prompts, completions=completions, rewards=rewards, extras=extras) + item = ds[1] + self.assertAlmostEqual(item["advantages"], 0.2) + + def test_length_mismatch_raises(self): + with self.assertRaises(AssertionError): + RolloutDataset(prompts=[1, 2], completions=[1], rewards=[1, 2]) + + +class TestPreferencePairDataset(unittest.TestCase): + def test_basic(self): + prompts = [torch.tensor([1])] * 3 + chosen = [torch.tensor([2])] * 3 + rejected = [torch.tensor([3])] * 3 + ds = PreferencePairDataset(prompts=prompts, chosen=chosen, rejected=rejected) + self.assertEqual(len(ds), 3) + item = ds[0] + self.assertIn("prompt", item) + self.assertIn("chosen", item) + self.assertIn("rejected", item) + + def test_length_mismatch_raises(self): + with self.assertRaises(AssertionError): + PreferencePairDataset(prompts=[1, 2], chosen=[1], rejected=[1, 2]) + + +class TestBaseDataProducer(unittest.TestCase): + def test_default_config(self): + producer = CountingProducer() + self.assertIsNotNone(producer.config) + self.assertEqual(producer.config.mini_epochs, 1) + + def test_custom_config(self): + config = ProducerConfig(mini_epochs=3) + producer = CountingProducer(config=config) + self.assertEqual(producer.config.mini_epochs, 3) + + def test_lifecycle_hooks_are_noop(self): + producer = CountingProducer() + # Should not raise + producer.on_rollout_begin(global_step=0) + 
producer.on_rollout_end(dataset=SimpleDataset(), global_step=0) + + +class TestAsyncDataProducer(unittest.TestCase): + def test_wraps_inner(self): + inner = CountingProducer(config=ProducerConfig(max_rollouts=5)) + async_producer = AsyncDataProducer(inner) + self.assertEqual(async_producer.config.max_rollouts, 5) + + def test_first_call_synchronous(self): + inner = CountingProducer(config=ProducerConfig(max_rollouts=5)) + async_producer = AsyncDataProducer(inner) + model = RegressionModel() + dataset = async_producer.produce(model=model, global_step=0) + self.assertIsNotNone(dataset) + self.assertEqual(inner.call_count, 2) # 1 sync + 1 prefetch started + + def test_forwards_lifecycle_hooks(self): + inner = LifecycleTrackingProducer() + async_producer = AsyncDataProducer(inner) + async_producer.on_rollout_begin(global_step=5) + self.assertEqual(inner.rollout_begins, [5]) + + +class TestDataProducerCallback(unittest.TestCase): + def test_is_trainer_callback(self): + from transformers.trainer_callback import TrainerCallback + + producer = CountingProducer() + callback = DataProducerCallback(producer) + self.assertIsInstance(callback, TrainerCallback) + + +# --------------------------------------------------------------------------- +# Integration tests: Trainer + DataProducer +# --------------------------------------------------------------------------- + + +class TestTrainerWithDataProducer(unittest.TestCase): + def test_invalid_data_producer_type(self): + """Passing a non-DataProducer should raise TypeError.""" + model = RegressionModel() + with tempfile.TemporaryDirectory() as tmp: + args = TrainingArguments( + output_dir=tmp, max_steps=1, report_to="none", use_cpu=True + ) + with self.assertRaises(TypeError): + Trainer(model=model, args=args, data_producer="not a producer") + + def test_basic_online_training(self): + """DataProducer with max_rollouts=3 should train successfully.""" + producer = CountingProducer( + config=ProducerConfig(max_rollouts=3), + dataset_size=8, + ) + model = RegressionModel() + with tempfile.TemporaryDirectory() as tmp: + args = TrainingArguments( + output_dir=tmp, + max_steps=6, + per_device_train_batch_size=4, + logging_steps=1, + save_strategy="no", + report_to="none", + use_cpu=True, + ) + trainer = Trainer(model=model, args=args, data_producer=producer) + result = trainer.train() + + # 3 rollouts: 1 in compute_plan + 2 in iter_epochs + self.assertEqual(producer.call_count, 3) + self.assertEqual(result.global_step, 6) + + def test_mini_epochs(self): + """mini_epochs=2 should yield 2 training passes per rollout.""" + producer = CountingProducer( + config=ProducerConfig(mini_epochs=2, max_rollouts=2), + dataset_size=8, + ) + model = RegressionModel() + with tempfile.TemporaryDirectory() as tmp: + args = TrainingArguments( + output_dir=tmp, + max_steps=8, # 2 steps/epoch × 2 mini_epochs × 2 rollouts = 8 + per_device_train_batch_size=4, + logging_steps=1, + save_strategy="no", + report_to="none", + use_cpu=True, + ) + trainer = Trainer(model=model, args=args, data_producer=producer) + result = trainer.train() + + # 2 rollouts: 1 in compute_plan + 1 in iter_epochs + self.assertEqual(producer.call_count, 2) + self.assertEqual(result.global_step, 8) + + def test_max_steps_stops_training(self): + """Training should stop at max_steps even if max_rollouts allows more.""" + producer = CountingProducer( + config=ProducerConfig(max_rollouts=100), + dataset_size=8, + ) + model = RegressionModel() + with tempfile.TemporaryDirectory() as tmp: + args = TrainingArguments( + 
output_dir=tmp, + max_steps=4, + per_device_train_batch_size=4, + logging_steps=1, + save_strategy="no", + report_to="none", + use_cpu=True, + ) + trainer = Trainer(model=model, args=args, data_producer=producer) + result = trainer.train() + self.assertEqual(result.global_step, 4) + + def test_lifecycle_hooks_called(self): + """on_rollout_begin and on_rollout_end should be called for each produce().""" + producer = LifecycleTrackingProducer( + config=ProducerConfig(max_rollouts=2), + ) + model = RegressionModel() + with tempfile.TemporaryDirectory() as tmp: + args = TrainingArguments( + output_dir=tmp, + max_steps=4, + per_device_train_batch_size=4, + logging_steps=1, + save_strategy="no", + report_to="none", + use_cpu=True, + ) + trainer = Trainer(model=model, args=args, data_producer=producer) + trainer.train() + + # 2 produce calls: 1 in compute_plan + 1 in iter_epochs + self.assertEqual(len(producer.rollout_begins), 2) + self.assertEqual(len(producer.rollout_ends), 2) + self.assertEqual(len(producer.produce_calls), 2) + + def test_eval_during_produce(self): + """Model should be in eval mode during produce() if config says so.""" + model_modes = [] + + class ModeTrackingProducer(BaseDataProducer): + def produce(self, model, global_step, **kwargs): + model_modes.append(model.training) + return SimpleDataset(length=8) + + producer = ModeTrackingProducer( + config=ProducerConfig(max_rollouts=2, eval_during_produce=True) + ) + model = RegressionModel() + with tempfile.TemporaryDirectory() as tmp: + args = TrainingArguments( + output_dir=tmp, + max_steps=4, + per_device_train_batch_size=4, + logging_steps=1, + save_strategy="no", + report_to="none", + use_cpu=True, + ) + trainer = Trainer(model=model, args=args, data_producer=producer) + trainer.train() + + # Model should have been in eval mode during produce + for mode in model_modes: + self.assertFalse(mode, "Model should be in eval mode during produce()") + + def test_no_eval_during_produce(self): + """Model should stay in training mode if eval_during_produce=False.""" + model_modes = [] + + class ModeTrackingProducer(BaseDataProducer): + def produce(self, model, global_step, **kwargs): + model_modes.append(model.training) + return SimpleDataset(length=8) + + producer = ModeTrackingProducer( + config=ProducerConfig(max_rollouts=2, eval_during_produce=False) + ) + model = RegressionModel() + with tempfile.TemporaryDirectory() as tmp: + args = TrainingArguments( + output_dir=tmp, + max_steps=4, + per_device_train_batch_size=4, + logging_steps=1, + save_strategy="no", + report_to="none", + use_cpu=True, + ) + trainer = Trainer(model=model, args=args, data_producer=producer) + trainer.train() + + # All calls after the first (compute_plan, model not yet in train mode) + # should have model in training mode + # The first produce is in compute_plan before model.train(), so skip it + for mode in model_modes[1:]: + self.assertTrue(mode, "Model should be in training mode during produce()") + + def test_async_prefetch_wrapping(self): + """Setting async_prefetch=True should wrap the producer.""" + producer = CountingProducer( + config=ProducerConfig(max_rollouts=2, async_prefetch=True), + dataset_size=8, + ) + model = RegressionModel() + with tempfile.TemporaryDirectory() as tmp: + args = TrainingArguments( + output_dir=tmp, + max_steps=4, + per_device_train_batch_size=4, + logging_steps=1, + save_strategy="no", + report_to="none", + use_cpu=True, + ) + trainer = Trainer(model=model, args=args, data_producer=producer) + 
self.assertIsInstance(trainer.data_producer, AsyncDataProducer) + + def test_no_data_producer_uses_static_path(self): + """Without data_producer, the static training path should work.""" + model = RegressionModel() + ds = SimpleDataset(length=16) + with tempfile.TemporaryDirectory() as tmp: + args = TrainingArguments( + output_dir=tmp, + num_train_epochs=2, + per_device_train_batch_size=4, + logging_steps=1, + save_strategy="no", + report_to="none", + use_cpu=True, + ) + trainer = Trainer(model=model, args=args, train_dataset=ds) + result = trainer.train() + self.assertGreater(result.global_step, 0) + + def test_requires_max_steps_or_max_rollouts(self): + """Without max_steps or max_rollouts, should raise ValueError.""" + producer = CountingProducer( + config=ProducerConfig(max_rollouts=None), + dataset_size=8, + ) + model = RegressionModel() + with tempfile.TemporaryDirectory() as tmp: + args = TrainingArguments( + output_dir=tmp, + max_steps=-1, + per_device_train_batch_size=4, + save_strategy="no", + report_to="none", + use_cpu=True, + ) + trainer = Trainer(model=model, args=args, data_producer=producer) + with self.assertRaises(ValueError): + trainer.train() + + def test_iterable_dataset_warning(self): + """IterableDataset with mini_epochs > 1 should log a warning.""" + import logging + + producer = IterableProducer( + config=ProducerConfig(mini_epochs=2, max_rollouts=1), + dataset_size=8, + ) + model = RegressionModel() + with tempfile.TemporaryDirectory() as tmp: + args = TrainingArguments( + output_dir=tmp, + max_steps=4, + per_device_train_batch_size=4, + logging_steps=1, + save_strategy="no", + report_to="none", + use_cpu=True, + ) + trainer = Trainer(model=model, args=args, data_producer=producer) + with self.assertLogs("transformers.trainer", level=logging.WARNING) as cm: + trainer.train() + warning_found = any("IterableDataset" in msg and "mini_epochs" in msg for msg in cm.output) + self.assertTrue(warning_found, "Expected warning about IterableDataset + mini_epochs") + + def test_produce_receives_kwargs(self): + """produce() should receive processing_class and accelerator.""" + received_kwargs = {} + + class KwargsTrackingProducer(BaseDataProducer): + def produce(self, model, global_step, processing_class=None, accelerator=None, args=None, **kwargs): + received_kwargs["processing_class"] = processing_class + received_kwargs["accelerator"] = accelerator + received_kwargs["args"] = args + return SimpleDataset(length=8) + + producer = KwargsTrackingProducer(config=ProducerConfig(max_rollouts=1)) + model = RegressionModel() + with tempfile.TemporaryDirectory() as tmp: + args = TrainingArguments( + output_dir=tmp, + max_steps=2, + per_device_train_batch_size=4, + logging_steps=1, + save_strategy="no", + report_to="none", + use_cpu=True, + ) + trainer = Trainer(model=model, args=args, data_producer=producer, processing_class="test_tokenizer") + trainer.train() + + self.assertEqual(received_kwargs["processing_class"], "test_tokenizer") + self.assertIsNotNone(received_kwargs["accelerator"]) + self.assertIsNotNone(received_kwargs["args"]) + + def test_loss_decreases_with_online_training(self): + """Online training should produce decreasing loss over steps.""" + producer = CountingProducer( + config=ProducerConfig(max_rollouts=5), + dataset_size=16, + ) + model = RegressionModel() + with tempfile.TemporaryDirectory() as tmp: + args = TrainingArguments( + output_dir=tmp, + max_steps=20, + per_device_train_batch_size=4, + learning_rate=0.1, + logging_steps=5, + save_strategy="no", + 
report_to="none", + use_cpu=True, + ) + trainer = Trainer(model=model, args=args, data_producer=producer) + result = trainer.train() + self.assertEqual(result.global_step, 20) + # Loss should be finite + self.assertTrue(np.isfinite(result.training_loss)) + + +if __name__ == "__main__": + unittest.main() From a2bfec724b6e4fba331d8d8c694c6af9f9aef891 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 13 Feb 2026 12:59:01 -0500 Subject: [PATCH 4/5] also handle eval and test datasets with dataproducers --- src/transformers/trainer.py | 60 +++++++- tests/trainer/test_data_producer.py | 221 ++++++++++++++++++++++++++++ 2 files changed, 273 insertions(+), 8 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index a343694c319d..65710e55f0da 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -642,6 +642,8 @@ def __init__( optimizer_cls_and_kwargs: tuple[type[torch.optim.Optimizer], dict[str, Any]] | None = None, preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, data_producer: DataProducer | None = None, + eval_data_producer: DataProducer | None = None, + test_data_producer: DataProducer | None = None, ): # Init flow: # 1. Args & seed – defaults, determinism @@ -812,6 +814,18 @@ def __init__( if data_producer.config.async_prefetch: self.data_producer = AsyncDataProducer(data_producer) + self.eval_data_producer = eval_data_producer + if eval_data_producer is not None and not isinstance(eval_data_producer, DataProducer): + raise TypeError( + f"`eval_data_producer` must be a DataProducer instance, got {type(eval_data_producer)}" + ) + + self.test_data_producer = test_data_producer + if test_data_producer is not None and not isinstance(test_data_producer, DataProducer): + raise TypeError( + f"`test_data_producer` must be a DataProducer instance, got {type(test_data_producer)}" + ) + # Callables self.compute_loss_func = compute_loss_func self.compute_metrics = compute_metrics @@ -911,9 +925,16 @@ def _validate_args(self) -> None: " boolean argument which will be triggered after the last batch of the eval set to signal that the" " summary statistics should be returned by the function." ) - if args.eval_strategy is not None and args.eval_strategy != "no" and self.eval_dataset is None: + if ( + args.eval_strategy is not None + and args.eval_strategy != "no" + and self.eval_dataset is None + and getattr(self, "eval_data_producer", None) is None + ): raise ValueError( - f"You have set `args.eval_strategy` to {args.eval_strategy} but you didn't pass an `eval_dataset` to `Trainer`. Either set `args.eval_strategy` to `no` or pass an `eval_dataset`. " + f"You have set `args.eval_strategy` to {args.eval_strategy} but you didn't pass an `eval_dataset` " + f"or `eval_data_producer` to `Trainer`. Either set `args.eval_strategy` to `no` or pass an " + f"`eval_dataset` or `eval_data_producer`." ) if args.save_strategy == SaveStrategy.BEST or args.load_best_model_at_end: if args.metric_for_best_model is None: @@ -1689,13 +1710,17 @@ def _create_epoch_source(self) -> _EpochSource: return _StaticEpochSource() @torch.no_grad() - def _produce_data(self, model: nn.Module) -> Dataset: - """Call the DataProducer to generate a fresh training dataset. + def _produce_data(self, model: nn.Module, producer: DataProducer | None = None) -> Dataset: + """Call a DataProducer to generate a fresh dataset. + + Args: + model: The current model (may be switched to eval mode). + producer: The producer to call. 
Defaults to ``self.data_producer``. Handles eval/train mode switching and CUDA cache clearing per the producer's config. """ - producer = self.data_producer + producer = producer or self.data_producer config = producer.config if hasattr(producer, "on_rollout_begin"): @@ -2897,6 +2922,11 @@ def evaluate( # handle multiple eval datasets override = eval_dataset is not None eval_dataset = eval_dataset if override else self.eval_dataset + + # Fall back to eval_data_producer if no static dataset + if eval_dataset is None and getattr(self, "eval_data_producer", None) is not None: + eval_dataset = self._produce_data(self.model, producer=self.eval_data_producer) + if isinstance(eval_dataset, dict): metrics = {} for eval_dataset_name, _eval_dataset in eval_dataset.items(): @@ -3160,7 +3190,10 @@ def evaluation_loop( return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples) def predict( - self, test_dataset: Dataset, ignore_keys: list[str] | None = None, metric_key_prefix: str = "test" + self, + test_dataset: Dataset | None = None, + ignore_keys: list[str] | None = None, + metric_key_prefix: str = "test", ) -> PredictionOutput: """ Run prediction and returns predictions and potential metrics. @@ -3169,9 +3202,11 @@ def predict( will also return metrics, like in `evaluate()`. Args: - test_dataset (`Dataset`): + test_dataset (`Dataset`, *optional*): Dataset to run the predictions on. If it is an `datasets.Dataset`, columns not accepted by the - `model.forward()` method are automatically removed. Has to implement the method `__len__` + `model.forward()` method are automatically removed. Has to implement the method `__len__`. + If not provided and a ``test_data_producer`` was set on the Trainer, a fresh dataset will be + produced automatically. ignore_keys (`list[str]`, *optional*): A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions. @@ -3194,6 +3229,15 @@ def predict( - metrics (`dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained labels). """ + # Fall back to test_data_producer if no explicit dataset + if test_dataset is None and getattr(self, "test_data_producer", None) is not None: + test_dataset = self._produce_data(self.model, producer=self.test_data_producer) + if test_dataset is None: + raise ValueError( + "`predict()` requires either a `test_dataset` argument or " + "a `test_data_producer` set on the Trainer." 
+ ) + # memory metrics - must set up as early as possible self._memory_tracker.start() diff --git a/tests/trainer/test_data_producer.py b/tests/trainer/test_data_producer.py index 9dad01d76494..29d98baa7ce4 100644 --- a/tests/trainer/test_data_producer.py +++ b/tests/trainer/test_data_producer.py @@ -576,5 +576,226 @@ def test_loss_decreases_with_online_training(self): self.assertTrue(np.isfinite(result.training_loss)) +# --------------------------------------------------------------------------- +# Integration tests: eval_data_producer & test_data_producer +# --------------------------------------------------------------------------- + + +class TestTrainerWithEvalDataProducer(unittest.TestCase): + def test_eval_data_producer_basic(self): + """eval_data_producer.produce() should be called during evaluate().""" + eval_producer = CountingProducer(config=ProducerConfig(), dataset_size=8) + model = RegressionModel() + with tempfile.TemporaryDirectory() as tmp: + args = TrainingArguments( + output_dir=tmp, + max_steps=1, + per_device_train_batch_size=4, + eval_strategy="no", + save_strategy="no", + report_to="none", + use_cpu=True, + ) + train_ds = SimpleDataset(length=8) + trainer = Trainer( + model=model, + args=args, + train_dataset=train_ds, + eval_data_producer=eval_producer, + ) + metrics = trainer.evaluate() + self.assertEqual(eval_producer.call_count, 1) + self.assertIn("eval_loss", metrics) + + def test_eval_data_producer_during_training(self): + """eval_data_producer should be called at eval steps during training.""" + eval_producer = CountingProducer(config=ProducerConfig(), dataset_size=8) + model = RegressionModel() + with tempfile.TemporaryDirectory() as tmp: + args = TrainingArguments( + output_dir=tmp, + max_steps=4, + per_device_train_batch_size=4, + eval_strategy="steps", + eval_steps=2, + save_strategy="no", + report_to="none", + use_cpu=True, + ) + train_ds = SimpleDataset(length=16) + trainer = Trainer( + model=model, + args=args, + train_dataset=train_ds, + eval_data_producer=eval_producer, + ) + trainer.train() + # eval at step 2 and step 4 = 2 calls + self.assertEqual(eval_producer.call_count, 2) + + def test_explicit_eval_dataset_overrides_producer(self): + """Passing eval_dataset to evaluate() should override eval_data_producer.""" + eval_producer = CountingProducer(config=ProducerConfig(), dataset_size=8) + model = RegressionModel() + explicit_ds = SimpleDataset(length=8) + with tempfile.TemporaryDirectory() as tmp: + args = TrainingArguments( + output_dir=tmp, + max_steps=1, + per_device_train_batch_size=4, + save_strategy="no", + report_to="none", + use_cpu=True, + ) + train_ds = SimpleDataset(length=8) + trainer = Trainer( + model=model, + args=args, + train_dataset=train_ds, + eval_data_producer=eval_producer, + ) + metrics = trainer.evaluate(eval_dataset=explicit_ds) + # Producer should NOT have been called since explicit dataset was provided + self.assertEqual(eval_producer.call_count, 0) + self.assertIn("eval_loss", metrics) + + def test_static_eval_dataset_takes_priority_over_producer(self): + """self.eval_dataset should take priority over eval_data_producer.""" + eval_producer = CountingProducer(config=ProducerConfig(), dataset_size=8) + model = RegressionModel() + static_eval_ds = SimpleDataset(length=8) + with tempfile.TemporaryDirectory() as tmp: + args = TrainingArguments( + output_dir=tmp, + max_steps=1, + per_device_train_batch_size=4, + save_strategy="no", + report_to="none", + use_cpu=True, + ) + train_ds = SimpleDataset(length=8) + trainer = Trainer( 
+ model=model, + args=args, + train_dataset=train_ds, + eval_dataset=static_eval_ds, + eval_data_producer=eval_producer, + ) + metrics = trainer.evaluate() + # Producer should NOT have been called since self.eval_dataset exists + self.assertEqual(eval_producer.call_count, 0) + self.assertIn("eval_loss", metrics) + + def test_invalid_eval_data_producer_type(self): + """Passing a non-DataProducer as eval_data_producer should raise TypeError.""" + model = RegressionModel() + with tempfile.TemporaryDirectory() as tmp: + args = TrainingArguments( + output_dir=tmp, max_steps=1, report_to="none", use_cpu=True + ) + with self.assertRaises(TypeError): + Trainer(model=model, args=args, eval_data_producer="not a producer") + + def test_eval_strategy_accepts_eval_data_producer(self): + """eval_strategy should not raise when eval_data_producer is set but eval_dataset is None.""" + eval_producer = CountingProducer(config=ProducerConfig(), dataset_size=8) + model = RegressionModel() + with tempfile.TemporaryDirectory() as tmp: + args = TrainingArguments( + output_dir=tmp, + max_steps=2, + per_device_train_batch_size=4, + eval_strategy="steps", + eval_steps=1, + save_strategy="no", + report_to="none", + use_cpu=True, + ) + train_ds = SimpleDataset(length=8) + # Should NOT raise ValueError about missing eval_dataset + trainer = Trainer( + model=model, + args=args, + train_dataset=train_ds, + eval_data_producer=eval_producer, + ) + trainer.train() + self.assertGreater(eval_producer.call_count, 0) + + +class TestTrainerWithTestDataProducer(unittest.TestCase): + def test_test_data_producer_basic(self): + """test_data_producer.produce() should be called during predict().""" + test_producer = CountingProducer(config=ProducerConfig(), dataset_size=8) + model = RegressionModel() + with tempfile.TemporaryDirectory() as tmp: + args = TrainingArguments( + output_dir=tmp, + max_steps=1, + per_device_train_batch_size=4, + save_strategy="no", + report_to="none", + use_cpu=True, + ) + train_ds = SimpleDataset(length=8) + trainer = Trainer( + model=model, + args=args, + train_dataset=train_ds, + test_data_producer=test_producer, + ) + output = trainer.predict() + self.assertEqual(test_producer.call_count, 1) + self.assertIsNotNone(output.predictions) + + def test_explicit_test_dataset_overrides_producer(self): + """Passing test_dataset to predict() should override test_data_producer.""" + test_producer = CountingProducer(config=ProducerConfig(), dataset_size=8) + model = RegressionModel() + explicit_ds = SimpleDataset(length=8) + with tempfile.TemporaryDirectory() as tmp: + args = TrainingArguments( + output_dir=tmp, + max_steps=1, + per_device_train_batch_size=4, + save_strategy="no", + report_to="none", + use_cpu=True, + ) + train_ds = SimpleDataset(length=8) + trainer = Trainer( + model=model, + args=args, + train_dataset=train_ds, + test_data_producer=test_producer, + ) + output = trainer.predict(test_dataset=explicit_ds) + # Producer should NOT have been called + self.assertEqual(test_producer.call_count, 0) + self.assertIsNotNone(output.predictions) + + def test_predict_raises_without_dataset_or_producer(self): + """predict() with no test_dataset and no test_data_producer should raise.""" + model = RegressionModel() + with tempfile.TemporaryDirectory() as tmp: + args = TrainingArguments( + output_dir=tmp, max_steps=1, report_to="none", use_cpu=True + ) + train_ds = SimpleDataset(length=8) + trainer = Trainer(model=model, args=args, train_dataset=train_ds) + with self.assertRaises(ValueError): + trainer.predict() + 
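# Illustrative usage sketch — a minimal, hypothetical example of how the three
# producer hooks exercised by these tests fit together. It relies on the imports
# and the RegressionModel / SimpleDataset fixtures defined at the top of this
# file; FreshDataProducer and run_sketch are made-up names for illustration only,
# not part of the patch.


class FreshDataProducer(BaseDataProducer):
    """Regenerates a small synthetic dataset whenever the Trainer asks for data."""

    def produce(self, model, global_step, **kwargs):
        # A real producer would run rollouts / generation with `model` here.
        return SimpleDataset(length=8, seed=global_step)


def run_sketch():
    model = RegressionModel()
    with tempfile.TemporaryDirectory() as tmp:
        args = TrainingArguments(
            output_dir=tmp,
            max_steps=4,
            per_device_train_batch_size=4,
            save_strategy="no",
            report_to="none",
            use_cpu=True,
        )
        trainer = Trainer(
            model=model,
            args=args,
            data_producer=FreshDataProducer(ProducerConfig(max_rollouts=2)),
            eval_data_producer=FreshDataProducer(ProducerConfig()),
            test_data_producer=FreshDataProducer(ProducerConfig()),
        )
        trainer.train()     # fresh training data produced at the start of each rollout
        trainer.evaluate()  # no eval_dataset given, so eval_data_producer supplies one
        trainer.predict()   # no test_dataset given, so test_data_producer supplies one
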
+ def test_invalid_test_data_producer_type(self): + """Passing a non-DataProducer as test_data_producer should raise TypeError.""" + model = RegressionModel() + with tempfile.TemporaryDirectory() as tmp: + args = TrainingArguments( + output_dir=tmp, max_steps=1, report_to="none", use_cpu=True + ) + with self.assertRaises(TypeError): + Trainer(model=model, args=args, test_data_producer=42) + + if __name__ == "__main__": unittest.main() From 5f0c75e5fd0c97cfffdc99f2d6801c8d7d368019 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 13 Feb 2026 16:31:24 -0500 Subject: [PATCH 5/5] chore: lint --- src/transformers/data_producer.py | 3 ++- .../models/metaclip_2/convert_metaclip_2_to_hf.py | 1 - src/transformers/trainer.py | 4 ++-- tests/trainer/test_data_producer.py | 1 - 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/transformers/data_producer.py b/src/transformers/data_producer.py index 835da50802ec..b3ccf6cb1304 100644 --- a/src/transformers/data_producer.py +++ b/src/transformers/data_producer.py @@ -47,7 +47,7 @@ def produce(self, model, global_step, **kwargs): import logging import threading from abc import ABC, abstractmethod -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import TYPE_CHECKING, Any import torch @@ -55,6 +55,7 @@ def produce(self, model, global_step, **kwargs): from .trainer_callback import TrainerCallback + if TYPE_CHECKING: from torch import nn diff --git a/src/transformers/models/metaclip_2/convert_metaclip_2_to_hf.py b/src/transformers/models/metaclip_2/convert_metaclip_2_to_hf.py index ae3a682fdb58..6db36199dca3 100644 --- a/src/transformers/models/metaclip_2/convert_metaclip_2_to_hf.py +++ b/src/transformers/models/metaclip_2/convert_metaclip_2_to_hf.py @@ -25,7 +25,6 @@ # Import MetaCLIP modules from src.mini_clip.factory import create_model_and_transforms - from transformers import ( AutoTokenizer, CLIPImageProcessor, diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 65710e55f0da..d8808eb97b19 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -30,7 +30,7 @@ import warnings from abc import ABC, abstractmethod from collections.abc import Callable, Iterator, Mapping -from dataclasses import dataclass, field +from dataclasses import dataclass from functools import partial from pathlib import Path from typing import TYPE_CHECKING, Any @@ -57,7 +57,7 @@ from . import __version__ from .configuration_utils import PreTrainedConfig from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator -from .data_producer import AsyncDataProducer, DataProducer, DataProducerCallback, ProducerConfig +from .data_producer import AsyncDataProducer, DataProducer, DataProducerCallback from .debug_utils import DebugOption, DebugUnderflowOverflow from .feature_extraction_sequence_utils import SequenceFeatureExtractor from .feature_extraction_utils import FeatureExtractionMixin diff --git a/tests/trainer/test_data_producer.py b/tests/trainer/test_data_producer.py index 29d98baa7ce4..7d9d46fe8cdb 100644 --- a/tests/trainer/test_data_producer.py +++ b/tests/trainer/test_data_producer.py @@ -25,7 +25,6 @@ from transformers.data_producer import ( AsyncDataProducer, BaseDataProducer, - DataProducer, DataProducerCallback, PreferencePairDataset, ProducerConfig,