
Implement Warmup-Stable-Decay (WSD) Learning Rate Scheduler #2882

@bzantium

Description


What problem are you trying to solve?
Currently, MaxText only supports a Cosine decay schedule via create_learning_rate_schedule. While effective, Cosine decay requires the total training steps (learning_rate_schedule_steps) to be fixed in advance. This makes it difficult to extend training if the model hasn't converged, or to keep pre-training on an open-ended (effectively infinite) horizon.

Why is this problem important?
The Warmup-Stable-Decay (WSD) scheduler has become a standard in recent LLM pre-training.

Unlike Cosine, WSD maintains a stable (constant) learning rate after warmup. When a checkpoint needs to be finalized, a short decay phase is applied to converge it. This allows:

  1. Continuous Training: One "stable" checkpoint can be used to branch out into multiple decay experiments without restarting.
  2. Optimality: Research shows it achieves comparable or better loss than fixed cosine schedules.
  3. Flexibility: It decouples the learning rate schedule from the total number of steps during the bulk of training.
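
For intuition, the schedule shape is just a piecewise function. Below is a minimal, framework-free sketch; the function name wsd_lr and its arguments are placeholders for illustration, not existing MaxText config keys:

def wsd_lr(step, peak_lr, warmup_steps, stable_steps, decay_steps, final_fraction=0.0):
  """Illustrative WSD shape: linear warmup, constant plateau, linear decay."""
  if step < warmup_steps:
    return peak_lr * step / max(warmup_steps, 1)  # 1. warmup
  if step < warmup_steps + stable_steps:
    return peak_lr  # 2. stable
  progress = min((step - warmup_steps - stable_steps) / max(decay_steps, 1), 1.0)
  return peak_lr * (1.0 - (1.0 - final_fraction) * progress)  # 3. decay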

Describe your requested feature or solution
I propose adding a WSD option to create_learning_rate_schedule. This fits naturally into the current optax.join_schedules pattern used in maxtext/utils.py.

We can introduce new config parameters:

  • wsd_learning_rate_final_fraction: The final LR at the end of the decay phase, expressed as a fraction of the peak learning rate.
  • wsd_decay_steps_fraction: The fraction of learning_rate_schedule_steps spent in the decay phase (e.g., 0.1-0.2).
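
To make the mapping concrete, here is how the fractions would translate into phase lengths under some illustrative values (all numbers below are hypothetical; warmup_steps_fraction is the existing warmup knob referenced in the sketch further down):

# Hypothetical values, only to show how the fractions map to phase lengths.
learning_rate_schedule_steps = 100_000
warmup_steps_fraction = 0.06
wsd_decay_steps_fraction = 0.20
wsd_learning_rate_final_fraction = 0.10

warmup_steps = int(learning_rate_schedule_steps * warmup_steps_fraction)    # 6_000
decay_steps = int(learning_rate_schedule_steps * wsd_decay_steps_fraction)  # 20_000
stable_steps = learning_rate_schedule_steps - warmup_steps - decay_steps    # 74_000
# The LR decays from peak to 10% of peak over the final 20_000 steps.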

Proposed Implementation Logic:
The logic would mirror the existing implementation but with three distinct phases before the zero-padding:

  1. Warmup: Linear warmup to config.learning_rate.
  2. Stable: Constant schedule at config.learning_rate.
  3. Decay: Decay from config.learning_rate to config.learning_rate * wsd_learning_rate_final_fraction.
def create_learning_rate_schedule(config):
  # ... existing setup ...
  lr = config.learning_rate

  if config.lr_schedule_type == 'wsd':
    # WSD parameters
    decay_final_lr = lr * config.wsd_learning_rate_final_fraction

    warmup_steps = int(config.learning_rate_schedule_steps * config.warmup_steps_fraction)
    decay_steps = int(config.learning_rate_schedule_steps * config.wsd_decay_steps_fraction)
    stable_steps = config.learning_rate_schedule_steps - warmup_steps - decay_steps

    # 1. Warmup
    warmup_schedule = optax.linear_schedule(init_value=0.0, end_value=lr, transition_steps=warmup_steps)

    # 2. Stable
    stable_schedule = optax.constant_schedule(lr)

    # 3. Decay (e.g., linear or cosine down to the final fraction)
    decay_schedule = optax.linear_schedule(init_value=lr, end_value=decay_final_lr, transition_steps=decay_steps)

    pieces = [warmup_schedule, stable_schedule, decay_schedule]
    boundaries = [warmup_steps, warmup_steps + stable_steps]

    # Handle constant_zero_steps logic (existing logic)...
    constant_zero_steps = config.steps - config.learning_rate_schedule_steps
    if constant_zero_steps > 0:
      pieces.append(optax.constant_schedule(0.0))
      boundaries.append(warmup_steps + stable_steps + decay_steps)

    return optax.join_schedules(pieces, boundaries)

  # ... existing cosine logic ...
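
If this lands, usage could look roughly like the snippet below. The config stub and values are purely illustrative (in MaxText the values would come from the YAML config / pyconfig), and the spot-checked numbers assume the 100k-step example above:

from types import SimpleNamespace
import optax

# Illustrative config stub, mirroring the proposed keys above.
config = SimpleNamespace(
    lr_schedule_type='wsd',
    learning_rate=3e-4,
    learning_rate_schedule_steps=100_000,
    steps=100_000,  # no trailing constant-zero phase in this example
    warmup_steps_fraction=0.06,
    wsd_decay_steps_fraction=0.20,
    wsd_learning_rate_final_fraction=0.10,
)

schedule = create_learning_rate_schedule(config)
optimizer = optax.adamw(learning_rate=schedule)

# Spot-check the three phases.
print(schedule(0))       # ~0.0   (start of warmup)
print(schedule(50_000))  # ~3e-4  (stable plateau)
print(schedule(99_999))  # ~3e-5  (end of decay, 10% of peak)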

Describe alternatives you’ve considered
The main alternative is sticking with Cosine decay, but that limits our ability to produce flexible intermediate checkpoints for varying downstream compute budgets.

Additional context or examples
Relevant Papers:

Adoption:

  • Widely used in leading open source frameworks like Megatron-LM and verl.
