Description
What problem are you trying to solve?
Currently, MaxText only supports a Cosine decay schedule via `create_learning_rate_schedule`. While effective, Cosine decay requires the total training steps (`learning_rate_schedule_steps`) to be fixed in advance. This makes it difficult to extend training if the model hasn't converged, or to continue pre-training indefinitely (open-ended training horizons).
Why is this problem important?
The Warmup-Stable-Decay (WSD) scheduler has become a standard in recent LLM pre-training.
Unlike Cosine, WSD holds the learning rate constant (stable) after warmup; when a checkpoint is chosen to converge from, only a short decay phase is applied from that point. This allows:
- Continuous Training: One "stable" checkpoint can be used to branch out into multiple decay experiments without restarting.
- Optimality: Research shows it achieves comparable or better loss than fixed cosine schedules.
- Flexibility: It decouples the learning rate schedule from the total number of steps during the bulk of training.
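For concreteness, with peak learning rate $\eta_{\max}$, warmup length $W$, stable length $S$, decay length $D$ (all in steps), and final fraction $f$, one common WSD shape is the following (assuming a linear decay phase; the decay shape itself is a design choice):

$$
\eta(t) =
\begin{cases}
\eta_{\max}\,\dfrac{t}{W} & 0 \le t < W \\
\eta_{\max} & W \le t < W + S \\
\eta_{\max}\left(1 - (1 - f)\,\dfrac{t - W - S}{D}\right) & W + S \le t \le W + S + D
\end{cases}
$$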
Describe your requested feature or solution
I propose adding a WSD option to `create_learning_rate_schedule`. This fits naturally into the current `optax.join_schedules` pattern used in `maxtext/utils.py`.
We can introduce new config parameters:
- `wsd_learning_rate_final_fraction`: determines the final LR at the end of the decay phase, as a fraction of the peak LR.
- `wsd_decay_steps_fraction`: determines the fraction of `learning_rate_schedule_steps` used for the decay phase (e.g., 10-20%).
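As a rough illustration with hypothetical numbers, here is how these fractions would map to concrete phase lengths, mirroring the step computation in the sketch below:

```python
# Hypothetical values; the wsd_* keys are the proposed additions,
# warmup_steps_fraction is reused from the existing schedule config.
learning_rate_schedule_steps = 100_000
warmup_steps_fraction = 0.01             # 1% warmup
wsd_decay_steps_fraction = 0.15          # proposed: 15% of the schedule spent decaying
wsd_learning_rate_final_fraction = 0.1   # proposed: decay to 10% of the peak LR

warmup_steps = int(learning_rate_schedule_steps * warmup_steps_fraction)    # 1,000
decay_steps = int(learning_rate_schedule_steps * wsd_decay_steps_fraction)  # 15,000
stable_steps = learning_rate_schedule_steps - warmup_steps - decay_steps    # 84,000
```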
Proposed Implementation Logic:
The logic would mirror the existing implementation but with three distinct phases before the zero-padding:
- Warmup: Linear warmup to `config.learning_rate`.
- Stable: Constant schedule at `config.learning_rate`.
- Decay: Decay from `config.learning_rate` to `config.learning_rate * wsd_learning_rate_final_fraction`.
```python
import optax


def create_learning_rate_schedule(config):
  # ... existing setup ...
  lr = config.learning_rate

  if config.lr_schedule_type == 'wsd':
    # WSD parameters
    decay_final_lr = lr * config.wsd_learning_rate_final_fraction
    warmup_steps = int(config.learning_rate_schedule_steps * config.warmup_steps_fraction)
    decay_steps = int(config.learning_rate_schedule_steps * config.wsd_decay_steps_fraction)
    stable_steps = config.learning_rate_schedule_steps - warmup_steps - decay_steps

    # 1. Warmup: linear ramp from 0 to the peak learning rate
    warmup_schedule = optax.linear_schedule(init_value=0.0, end_value=lr, transition_steps=warmup_steps)
    # 2. Stable: hold the peak learning rate constant
    stable_schedule = optax.constant_schedule(lr)
    # 3. Decay (e.g., linear or cosine to the final fraction)
    decay_schedule = optax.linear_schedule(init_value=lr, end_value=decay_final_lr, transition_steps=decay_steps)

    pieces = [warmup_schedule, stable_schedule, decay_schedule]
    boundaries = [warmup_steps, warmup_steps + stable_steps]

    # Handle constant_zero_steps logic (existing logic)...
    constant_zero_steps = config.steps - config.learning_rate_schedule_steps
    if constant_zero_steps > 0:
      pieces.append(optax.constant_schedule(0.0))
      boundaries.append(warmup_steps + stable_steps + decay_steps)

    return optax.join_schedules(pieces, boundaries)

  # ... existing cosine logic ...
```
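As a quick sanity check of the phase boundaries (hypothetical values, and assuming the `# ...` placeholders above are filled in), the joined schedule can be probed at a few step counts:

```python
from types import SimpleNamespace

# Hypothetical config stub; in MaxText these values would come from the YAML config.
config = SimpleNamespace(
    learning_rate=3e-4,
    lr_schedule_type='wsd',
    learning_rate_schedule_steps=10_000,
    warmup_steps_fraction=0.01,            # 100 warmup steps
    wsd_decay_steps_fraction=0.2,          # 2,000 decay steps
    wsd_learning_rate_final_fraction=0.1,  # decay to 3e-5
    steps=10_000,
)

schedule = create_learning_rate_schedule(config)
print(schedule(0))       # ~0.0  (start of warmup)
print(schedule(100))     # 3e-4  (peak, end of warmup)
print(schedule(5_000))   # 3e-4  (stable phase)
print(schedule(10_000))  # 3e-5  (end of decay)
```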
Describe alternatives you've considered
Sticking with Cosine decay limits our ability to produce flexible intermediate checkpoints for varying downstream compute budgets.
Additional context or examples
Relevant Papers:
- MiniCPM: MiniCPM: Unveiling the Potential of Small Language Models with Scalable Training Strategies
- WSD Analysis: Understanding Warmup-Stable-Decay Learning Rates: A River Valley Loss Landscape Perspective
Adoption:
- Widely used in leading open source frameworks like Megatron-LM and verl.