62 changes: 62 additions & 0 deletions docs/batch_size_scheduler/QUICKSTART.md
@@ -0,0 +1,62 @@
# Batch Size Scheduler - Quick Start

## TL;DR

Add `[batch_size_scheduler]` to your config TOML:

```toml
# Linear rampup (recommended for large models)
[batch_size_scheduler]
mode = "linear"
start_batch_size = 1024 # Start small
rampup_samples = 1000000000 # Ramp over 1B samples

[training]
global_batch_size = 4096 # Target batch size
```

That's it! Old configs without `[batch_size_scheduler]` continue to work unchanged.

## Quick Reference

| Mode | Config | Behavior |
|------|--------|----------|
| **Constant** | `mode = "constant"` or omit section | Fixed batch size |
| **Linear** | `mode = "linear"` | Smooth ramp: start → end |
| **Increment** | `mode = "increment"` | Step-wise: start, start+inc, start+2*inc, ... |

## Formula

```
Linear:    batch_size = start + (consumed_samples / rampup_samples) * (end - start)
Increment: batch_size = start + floor(consumed_samples / samples_per_step) * increment
           where samples_per_step = rampup_samples / ((end - start) / increment)

Both modes cap at end (= global_batch_size) once rampup_samples have been consumed.
```
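
A minimal Python sketch of these formulas (illustration of the math only, not the torchtitan implementation; the real scheduler presumably also keeps the result a multiple of `local_batch_size × data_parallel_degree`):

```python
import math

# Sketch of the two rampup formulas above; `end` is the target global_batch_size.
def linear_batch_size(consumed_samples, start, end, rampup_samples):
    if consumed_samples >= rampup_samples:
        return end
    return int(start + (consumed_samples / rampup_samples) * (end - start))

def increment_batch_size(consumed_samples, start, end, increment, rampup_samples):
    samples_per_step = rampup_samples / ((end - start) / increment)
    bs = start + math.floor(consumed_samples / samples_per_step) * increment
    return min(bs, end)

print(linear_batch_size(500_000_000, 1024, 4096, 1_000_000_000))           # 2560
print(increment_batch_size(500_000_000, 1024, 4096, 1024, 1_000_000_000))  # 2048
```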

## Common Configurations

### DeepSeek-V3 Style
```toml
[batch_size_scheduler]
mode = "linear"
start_batch_size = 3072
rampup_samples = 114501953      # ≈ 469B tokens / 4096 seq_len
```

### Megatron Style
```toml
[batch_size_scheduler]
mode = "increment"
start_batch_size = 1024
increment = 1024
rampup_samples = 244141         # ≈ 1B tokens / 4096 seq_len
```

## Testing

```bash
# Unit tests
python docs/batch_size_scheduler/test_unit.py

# Integration tests
bash docs/batch_size_scheduler/test_batch_size_scheduler.sh
```
190 changes: 190 additions & 0 deletions docs/batch_size_scheduler/README.md
@@ -0,0 +1,190 @@
# Batch Size Scheduler

Dynamic batch size scheduling for torchtitan, enabling batch size warmup/rampup during training.

## Overview

Batch size warmup is a technique used by large-scale training runs (DeepSeek-V3, Megatron, etc.) to improve training stability and efficiency. Instead of starting with the full batch size, training begins with a smaller batch size and gradually increases it.

**Key Design Principle**: The batch size scheduler is **completely orthogonal** to the data stages and the LR scheduler:

```
BatchSizeScheduler: f(consumed_samples) → batch_size
DataStageManager: f(current_step) → dataloader
LRScheduler: f(current_step) → learning_rate
```

They don't know about each other. The training loop coordinates them independently.
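
To make the orthogonality concrete, a hypothetical training-loop sketch (names such as `batch_size_schedule`, `lr_schedule`, and `data_stage_for_step` are placeholders, not torchtitan's actual APIs):

```python
def train_loop(total_steps, local_batch_size, dp_degree,
               batch_size_schedule, lr_schedule, data_stage_for_step):
    consumed_samples = 0
    for step in range(1, total_steps + 1):
        global_batch_size = batch_size_schedule(consumed_samples)  # f(consumed_samples)
        lr = lr_schedule(step)                                      # f(current_step)
        dataloader = data_stage_for_step(step)                      # f(current_step)

        # The scheduled batch size only changes how many micro-batches are accumulated.
        grad_accum_steps = global_batch_size // (local_batch_size * dp_degree)
        for _ in range(grad_accum_steps):
            batch = next(dataloader)
            # ... forward / backward on `batch` with learning rate `lr` ...
        consumed_samples += global_batch_size
```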

## Schedule Modes

### 1. Constant (Default)

Fixed batch size throughout training. This is the default behavior and maintains backward compatibility with existing configs.

```toml
[batch_size_scheduler]
mode = "constant"
# Or simply omit the [batch_size_scheduler] section entirely
```

### 2. Linear Rampup

Smooth interpolation from `start_batch_size` to `global_batch_size` over `rampup_samples`.

**Used by**: DeepSeek-V3 (3072 → 15360 over 469B tokens)

```toml
[batch_size_scheduler]
mode = "linear"
start_batch_size = 1024
rampup_samples = 1000000000 # 1B samples

[training]
global_batch_size = 4096 # Target batch size
```

**Behavior**:
```
samples: 0 -------- 500M -------- 1B -------- 2B
batch_size: 1024 ----- 2560 ------- 4096 ------ 4096
(linear interpolation) (constant)
```

### 3. Increment Rampup

Step-wise increments at regular intervals (Megatron style).

**Used by**: Megatron-LM

```toml
[batch_size_scheduler]
mode = "increment"
start_batch_size = 1024
increment = 1024
rampup_samples = 1000000000

[training]
global_batch_size = 4096
```

**Behavior**:
```
samples: 0 ---- 333M ---- 666M ---- 1B ---- 2B
batch_size: 1024 2048 3072 4096 4096
(step increases) (constant)
```
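
Putting numbers on the diagram above (a quick arithmetic check, assuming the plateau length is `rampup_samples` divided by the number of increments):

```python
start, end, increment, rampup_samples = 1024, 4096, 1024, 1_000_000_000

# (4096 - 1024) / 1024 = 3 increments, so ~333M samples per plateau.
samples_per_step = rampup_samples // ((end - start) // increment)

for consumed in (0, 400_000_000, 700_000_000, 1_000_000_000, 2_000_000_000):
    bs = min(end, start + (consumed // samples_per_step) * increment)
    print(f"{consumed:>13,} samples -> batch size {bs}")
# prints 1024, 2048, 3072, 4096, 4096
```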

## Configuration Reference

```toml
[batch_size_scheduler]
mode = "constant" # "constant", "linear", or "increment"
start_batch_size = 0 # Starting batch size (0 = use global_batch_size)
rampup_samples = 0 # Samples over which to ramp up (0 = no rampup)
increment = 0 # For "increment" mode (0 = auto, uses start_batch_size)
```

### Constraints

- `start_batch_size` must be divisible by `local_batch_size × data_parallel_degree`
- `global_batch_size` must be divisible by `local_batch_size × data_parallel_degree`
- `increment` must be divisible by `local_batch_size × data_parallel_degree`
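
A minimal sketch of these divisibility checks (illustrative only; torchtitan's actual config validation may differ):

```python
def check_divisibility(cfg, local_batch_size, data_parallel_degree):
    unit = local_batch_size * data_parallel_degree  # smallest schedulable global batch
    for name in ("start_batch_size", "global_batch_size", "increment"):
        value = cfg.get(name, 0)
        if value and value % unit != 0:
            raise ValueError(
                f"{name}={value} must be divisible by "
                f"local_batch_size * data_parallel_degree = {unit}"
            )

check_divisibility(
    {"start_batch_size": 1024, "global_batch_size": 4096, "increment": 1024},
    local_batch_size=4,
    data_parallel_degree=8,
)  # passes: every value is a multiple of 32
```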

## How It Works

### Architecture

```
┌─────────────────────────────────────────────────────────────┐
│ Training Loop │
│ │
│ consumed_samples ──► BatchSizeManager ──► gradient_accum │
│ │ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ Scheduler │ │
│ │ (stateless) │ │
│ └─────────────────┘ │
│ │ │
│ ┌─────────────────┼─────────────────┐ │
│ ▼ ▼ ▼ │
│ ConstantBatchSize LinearRampup IncrementRampup │
└─────────────────────────────────────────────────────────────┘
```
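
In code form, the diagram could look roughly like this (a sketch based on the diagram; the class and method names are assumptions, not the actual torchtitan interface):

```python
from abc import ABC, abstractmethod

class BatchSizeSchedule(ABC):
    """Stateless mapping: consumed_samples -> global batch size."""

    @abstractmethod
    def batch_size(self, consumed_samples: int) -> int: ...

class ConstantBatchSize(BatchSizeSchedule):
    def __init__(self, global_batch_size: int):
        self.global_batch_size = global_batch_size

    def batch_size(self, consumed_samples: int) -> int:
        return self.global_batch_size

class BatchSizeManager:
    """Queried by the training loop; turns the scheduled batch size into grad accum steps."""

    def __init__(self, schedule: BatchSizeSchedule, local_batch_size: int, dp_degree: int):
        self.schedule = schedule
        self.unit = local_batch_size * dp_degree

    def grad_accum_steps(self, consumed_samples: int) -> int:
        return self.schedule.batch_size(consumed_samples) // self.unit
```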

## Checkpointing

The scheduler is **stateless**; only `consumed_samples` needs to be checkpointed:

```python
# Saved in checkpoint
{
"step": 1000,
"ntokens_seen": 4194304000,
"consumed_samples": 1024000, # ← This is all the scheduler needs
}
```

On resume, the scheduler automatically computes the correct batch size from `consumed_samples`.

**Backward Compatibility**: Old checkpoints without `consumed_samples` default to 0.
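
A sketch of the resume path (hypothetical code that just re-applies the schedule, reusing the `BatchSizeSchedule` sketch from the architecture section):

```python
def batch_size_after_resume(checkpoint_state: dict, schedule) -> int:
    # Old checkpoints without consumed_samples fall back to 0.
    consumed_samples = checkpoint_state.get("consumed_samples", 0)
    # The batch size is recomputed from consumed_samples, never stored directly.
    return schedule.batch_size(consumed_samples)
```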

## Example Configurations

### DeepSeek-V3 Style

```toml
[batch_size_scheduler]
mode = "linear"
start_batch_size = 3072
rampup_samples = 114501953      # ≈ 469B tokens / 4096 seq_len

[training]
global_batch_size = 15360
local_batch_size = 4
seq_len = 4096
```

### Megatron Style

```toml
[batch_size_scheduler]
mode = "increment"
start_batch_size = 1024
increment = 1024
rampup_samples = 244141         # ≈ 1B tokens / 4096 seq_len

[training]
global_batch_size = 4096
local_batch_size = 4
seq_len = 4096
```

### Quick Test (Debug Model)

```toml
[batch_size_scheduler]
mode = "linear"
start_batch_size = 8
rampup_samples = 1000

[training]
global_batch_size = 32
local_batch_size = 4
steps = 100
```

## Logging

When the batch size changes, you'll see:

```
[INFO] Batch size changed: 1024 -> 2048, grad_accum_steps=2
```

At initialization:

```
[INFO] Batch size scheduler: linear rampup 1024 -> 4096 over 1000000 samples
```
48 changes: 48 additions & 0 deletions docs/batch_size_scheduler/debug_checkpoint_resume.toml
@@ -0,0 +1,48 @@
# Batch Size Scheduler Test Config - Checkpoint Resume Test
#
# Tests that batch size scheduler state is correctly restored from checkpoint.
# Run twice: first to create checkpoint, second to resume and verify batch size.
#
# Usage (create checkpoint):
# torchrun --nproc_per_node=1 -m torchtitan.train --job.config_file docs/batch_size_scheduler/debug_checkpoint_resume.toml --training.steps 30
#
# Usage (resume from checkpoint):
# torchrun --nproc_per_node=1 -m torchtitan.train --job.config_file docs/batch_size_scheduler/debug_checkpoint_resume.toml --checkpoint.initial_load_path ./outputs/batch_size_scheduler_ckpt/step-30

[job]
description = "Batch size scheduler test - checkpoint resume"
dump_folder = "./outputs/batch_size_scheduler_ckpt"

[batch_size_scheduler]
mode = "linear"
start_batch_size = 8
rampup_samples = 1000

[training]
dataset = "c4_test"
local_batch_size = 4
global_batch_size = 32
seq_len = 256
steps = 60
max_norm = 1.0

[model]
name = "llama3"
flavor = "debugmodel"
hf_assets_path = "./tests/assets/tokenizer"

[optimizer]
name = "AdamW"
lr = 1e-4

[lr_scheduler]
warmup_steps = 10

[metrics]
log_freq = 5
enable_tensorboard = false

[checkpoint]
enable = true
interval = 30
folder = "checkpoints"
42 changes: 42 additions & 0 deletions docs/batch_size_scheduler/debug_constant.toml
@@ -0,0 +1,42 @@
# Batch Size Scheduler Test Config - Constant (Debug Model)
#
# Tests constant batch size (default behavior, backward compatible).
# This config should behave identically to configs without [batch_size_scheduler].
#
# Usage:
# torchrun --nproc_per_node=1 -m torchtitan.train --job.config_file docs/batch_size_scheduler/debug_constant.toml

[job]
description = "Batch size scheduler test - constant (baseline)"
dump_folder = "./outputs/batch_size_scheduler_constant"

[batch_size_scheduler]
mode = "constant"
# start_batch_size and rampup_samples default to 0, meaning no rampup

[training]
dataset = "c4_test"
local_batch_size = 4
global_batch_size = 32
seq_len = 256
steps = 100
max_norm = 1.0

[model]
name = "llama3"
flavor = "debugmodel"
hf_assets_path = "./tests/assets/tokenizer"

[optimizer]
name = "AdamW"
lr = 1e-4

[lr_scheduler]
warmup_steps = 10

[metrics]
log_freq = 10
enable_tensorboard = false

[checkpoint]
enable = false
44 changes: 44 additions & 0 deletions docs/batch_size_scheduler/debug_increment.toml
@@ -0,0 +1,44 @@
# Batch Size Scheduler Test Config - Increment Rampup (Debug Model)
#
# Tests Megatron-style increment batch size rampup with a small debug model.
# Batch size ramps from 8 to 32 in increments of 8 over 400 samples.
#
# Usage:
# torchrun --nproc_per_node=1 -m torchtitan.train --job.config_file docs/batch_size_scheduler/debug_increment.toml

[job]
description = "Batch size scheduler test - increment rampup"
dump_folder = "./outputs/batch_size_scheduler_increment"

[batch_size_scheduler]
mode = "increment"
start_batch_size = 8
increment = 8
rampup_samples = 400

[training]
dataset = "c4_test"
local_batch_size = 4
global_batch_size = 32
seq_len = 256
steps = 100
max_norm = 1.0

[model]
name = "llama3"
flavor = "debugmodel"
hf_assets_path = "./tests/assets/tokenizer"

[optimizer]
name = "AdamW"
lr = 1e-4

[lr_scheduler]
warmup_steps = 10

[metrics]
log_freq = 10
enable_tensorboard = false

[checkpoint]
enable = false