Summary
When training with multiple GPUs (e.g. 4) and the default split_batches=False, the learning rate schedule finishes too early. For example, with a configured total of 10k training steps, the LR reaches its end value after about 2.5k steps instead of the intended 10k steps.
Cause
In accelerate/scheduler.py, AcceleratedScheduler.step() behaves as follows when split_batches=False:
Each call to scheduler.step() triggers num_processes inner scheduler steps (e.g. 4 steps when using 4 GPUs).
The underlying scheduler (e.g. LinearLR) is still created with total_iters=steps (e.g. 10,000).
After 10,000 / 4 = 2,500 optimization steps, the inner scheduler has therefore already taken 10,000 steps and the LR has decayed to its end value.
In effect, the schedule spans steps / num_processes optimization steps instead of the configured steps, as the sketch below demonstrates.
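The effect can be reproduced with plain PyTorch, independent of Accelerate, by stepping a scheduler num_processes times per optimization step. A minimal sketch with illustrative values (10k steps, 4 processes):

import torch
from torch.optim.lr_scheduler import LinearLR

total_steps, num_processes = 10_000, 4  # illustrative values

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=1.0)
scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=0.1, total_iters=total_steps)

for step in range(total_steps):
    optimizer.step()
    # Mimic AcceleratedScheduler with split_batches=False: num_processes inner steps per call.
    for _ in range(num_processes):
        scheduler.step()
    if step == total_steps // num_processes:  # ~2,500 optimization steps in
        print(step, scheduler.get_last_lr())  # already at the end value (~0.1), 7,500 steps early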
Expected behavior
LR should decay over the full steps optimization steps (e.g. 10k), regardless of num_processes, when the user configures steps as the total number of optimizer steps.
Workaround
When creating the scheduler, set its total length to steps * num_processes (e.g. total_iters=40000 for 10k steps on 4 GPUs). With Accelerate then stepping the inner scheduler num_processes times per optimization step, the LR correctly decays over the intended 10k optimization steps, as sketched below.
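A minimal sketch of this workaround at the scheduler's creation site, assuming the scheduler is created before accelerator.prepare() and stepped once per optimization step (the model and optimizer here are placeholders):

import torch
from accelerate import Accelerator
from torch.optim.lr_scheduler import LinearLR

accelerator = Accelerator()  # split_batches left at its default (False)
steps = 10_000               # intended number of optimization steps

model = torch.nn.Linear(8, 8)                               # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)  # placeholder optimizer

# Scale the schedule so the wrapper's num_processes inner steps per call cancel out.
scheduler = LinearLR(
    optimizer,
    start_factor=1.0,
    end_factor=0.1,
    total_iters=steps * accelerator.num_processes,  # e.g. 40,000 on 4 GPUs
)

model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler)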
Environment
accelerate (version with AcceleratedScheduler in scheduler.py)
Multi-GPU DDP, split_batches=False (default)
Fixed number of optimization steps (e.g. steps=10000), not derived from dataloader length
Relevant code
accelerate/scheduler.py (lines 72–82): when split_batches is False, the wrapper runs self.scheduler.step() num_processes times per call.
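For reference, a rough paraphrase of that branch (a sketch, not the verbatim accelerate source; guard clauses around gradient accumulation and skipped optimizer steps are omitted):

from accelerate.state import AcceleratorState

class AcceleratedSchedulerSketch:
    """Simplified stand-in for accelerate's AcceleratedScheduler wrapper."""

    def __init__(self, scheduler, split_batches=False):
        self.scheduler = scheduler
        self.split_batches = split_batches

    def step(self, *args, **kwargs):
        if self.split_batches:
            # One inner step per call; the configured schedule length is used as-is.
            self.scheduler.step(*args, **kwargs)
        else:
            # num_processes inner steps per call: a scheduler built with total_iters=steps
            # is exhausted after steps / num_processes optimization steps.
            for _ in range(AcceleratorState().num_processes):
                self.scheduler.step(*args, **kwargs)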
Suggested fix
import torch
from torch.optim.lr_scheduler import (
    CosineAnnealingLR,
    CosineAnnealingWarmRestarts,
    LinearLR,
    LRScheduler,
    PolynomialLR,
    StepLR,
)


def _create_scheduler(self, optimizer: torch.optim.Optimizer) -> LRScheduler | None:
    """Create learning rate scheduler based on config."""
    scheduler_type = self._config.optimization.scheduler_type
    steps = self._config.optimization.steps
    # Copy so that pop() below does not mutate the config's scheduler_params.
    params = dict(self._config.optimization.scheduler_params or {})

    # When split_batches=False (default), Accelerate's AcceleratedScheduler steps the inner
    # scheduler num_processes times per one call to step(), so we need
    # total_iters = steps * num_processes for the LR to decay over the intended `steps`
    # optimization steps (not steps / num_processes).
    num_processes = self._accelerator.num_processes
    scheduler_steps = steps * num_processes

    if scheduler_type is None:
        return None
    if scheduler_type == "linear":
        scheduler = LinearLR(
            optimizer,
            start_factor=params.pop("start_factor", 1.0),
            end_factor=params.pop("end_factor", 0.1),
            total_iters=scheduler_steps,
            **params,
        )
    elif scheduler_type == "cosine":
        scheduler = CosineAnnealingLR(
            optimizer,
            T_max=scheduler_steps,
            eta_min=params.pop("eta_min", 0),  # pop, not get, so eta_min is not passed twice via **params
            **params,
        )
    elif scheduler_type == "cosine_with_restarts":
        scheduler = CosineAnnealingWarmRestarts(
            optimizer,
            T_0=params.pop("T_0", scheduler_steps // 4),  # First restart cycle length
            T_mult=params.pop("T_mult", 1),  # Multiplicative factor for cycle lengths
            eta_min=params.pop("eta_min", 5e-5),
            **params,
        )
    elif scheduler_type == "polynomial":
        scheduler = PolynomialLR(
            optimizer,
            total_iters=scheduler_steps,
            power=params.pop("power", 1.0),
            **params,
        )
    elif scheduler_type == "step":
        scheduler = StepLR(
            optimizer,
            step_size=params.pop("step_size", scheduler_steps // 2),
            gamma=params.pop("gamma", 0.1),
            **params,
        )
    elif scheduler_type == "constant":
        scheduler = None
    else:
        raise ValueError(f"Unknown scheduler type: {scheduler_type}")
    return scheduler