From b7098916a81d878f6638e9b36a6cf17f5bc41bdc Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Thu, 15 Jan 2026 20:16:21 +0000 Subject: [PATCH] add initial implementation for batch size scheduler --- docs/batch_size_scheduler/QUICKSTART.md | 62 +++++ docs/batch_size_scheduler/README.md | 190 +++++++++++++++ .../debug_checkpoint_resume.toml | 48 ++++ docs/batch_size_scheduler/debug_constant.toml | 42 ++++ .../batch_size_scheduler/debug_increment.toml | 44 ++++ docs/batch_size_scheduler/debug_linear.toml | 43 ++++ .../test_batch_size_scheduler.sh | 67 ++++++ .../test_checkpoint_resume.sh | 65 ++++++ docs/batch_size_scheduler/test_unit.py | 217 ++++++++++++++++++ torchtitan/components/batch_size_scheduler.py | 195 ++++++++++++++++ torchtitan/config/job_config.py | 39 ++++ torchtitan/train.py | 78 +++++-- 12 files changed, 1076 insertions(+), 14 deletions(-) create mode 100644 docs/batch_size_scheduler/QUICKSTART.md create mode 100644 docs/batch_size_scheduler/README.md create mode 100644 docs/batch_size_scheduler/debug_checkpoint_resume.toml create mode 100644 docs/batch_size_scheduler/debug_constant.toml create mode 100644 docs/batch_size_scheduler/debug_increment.toml create mode 100644 docs/batch_size_scheduler/debug_linear.toml create mode 100755 docs/batch_size_scheduler/test_batch_size_scheduler.sh create mode 100755 docs/batch_size_scheduler/test_checkpoint_resume.sh create mode 100644 docs/batch_size_scheduler/test_unit.py create mode 100644 torchtitan/components/batch_size_scheduler.py diff --git a/docs/batch_size_scheduler/QUICKSTART.md b/docs/batch_size_scheduler/QUICKSTART.md new file mode 100644 index 0000000000..5f0841a8f1 --- /dev/null +++ b/docs/batch_size_scheduler/QUICKSTART.md @@ -0,0 +1,62 @@ +# Batch Size Scheduler - Quick Start + +## TL;DR + +Add `[batch_size_scheduler]` to your config TOML: + +```toml +# Linear rampup (recommended for large models) +[batch_size_scheduler] +mode = "linear" +start_batch_size = 1024 # Start small +rampup_samples = 1000000000 # Ramp over 1B samples + +[training] +global_batch_size = 4096 # Target batch size +``` + +That's it! Old configs without `[batch_size_scheduler]` continue to work unchanged. + +## Quick Reference + +| Mode | Config | Behavior | +|------|--------|----------| +| **Constant** | `mode = "constant"` or omit section | Fixed batch size | +| **Linear** | `mode = "linear"` | Smooth ramp: start → end | +| **Increment** | `mode = "increment"` | Step-wise: start, start+inc, start+2*inc, ... 
| + +## Formula + +``` +Linear: batch_size = start + (consumed_samples / rampup_samples) * (end - start) +Increment: batch_size = start + floor(consumed_samples / samples_per_step) * increment +``` + +## Common Configurations + +### DeepSeek-V3 Style +```toml +[batch_size_scheduler] +mode = "linear" +start_batch_size = 3072 +rampup_samples = 114746093750 # 469B tokens / 4096 seq_len +``` + +### Megatron Style +```toml +[batch_size_scheduler] +mode = "increment" +start_batch_size = 1024 +increment = 1024 +rampup_samples = 244140625 # 1B tokens / 4096 seq_len +``` + +## Testing + +```bash +# Unit tests +python docs/batch_size_scheduler/test_unit.py + +# Integration tests +bash docs/batch_size_scheduler/test_batch_size_scheduler.sh +``` diff --git a/docs/batch_size_scheduler/README.md b/docs/batch_size_scheduler/README.md new file mode 100644 index 0000000000..61f6e722c5 --- /dev/null +++ b/docs/batch_size_scheduler/README.md @@ -0,0 +1,190 @@ +# Batch Size Scheduler + +Dynamic batch size scheduling for torchtitan, enabling batch size warmup/rampup during training. + +## Overview + +Batch size warmup is a technique used by large-scale training runs (DeepSeek-V3, Megatron, etc.) to improve training stability and efficiency. Instead of starting with the full batch size, training begins with a smaller batch size and gradually increases it. + +**Key Design Principle**: The batch size scheduler is **completely orthogonal** to data stages and LR scheduler: + +``` +BatchSizeScheduler: f(consumed_samples) → batch_size +DataStageManager: f(current_step) → dataloader +LRScheduler: f(current_step) → learning_rate +``` + +They don't know about each other. The training loop coordinates them independently. + +## Schedule Modes + +### 1. Constant (Default) + +Fixed batch size throughout training. This is the default behavior and maintains backward compatibility with existing configs. + +```toml +[batch_size_scheduler] +mode = "constant" +# Or simply omit the [batch_size_scheduler] section entirely +``` + +### 2. Linear Rampup + +Smooth interpolation from `start_batch_size` to `global_batch_size` over `rampup_samples`. + +**Used by**: DeepSeek-V3 (3072 → 15360 over 469B tokens) + +```toml +[batch_size_scheduler] +mode = "linear" +start_batch_size = 1024 +rampup_samples = 1000000000 # 1B samples + +[training] +global_batch_size = 4096 # Target batch size +``` + +**Behavior**: +``` +samples: 0 -------- 500M -------- 1B -------- 2B +batch_size: 1024 ----- 2560 ------- 4096 ------ 4096 + (linear interpolation) (constant) +``` + +### 3. Increment Rampup + +Step-wise increments at regular intervals (Megatron style). 
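+
+The target is reached in `(global_batch_size - start_batch_size) / increment` equal
+steps, each applied after an equal share of `rampup_samples` (the result is capped at
+`global_batch_size`):
+
+```
+num_increments   = (global_batch_size - start_batch_size) / increment
+samples_per_step = rampup_samples / num_increments
+batch_size       = start_batch_size + floor(consumed_samples / samples_per_step) * increment
+```
+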
+ +**Used by**: Megatron-LM + +```toml +[batch_size_scheduler] +mode = "increment" +start_batch_size = 1024 +increment = 1024 +rampup_samples = 1000000000 + +[training] +global_batch_size = 4096 +``` + +**Behavior**: +``` +samples: 0 ---- 333M ---- 666M ---- 1B ---- 2B +batch_size: 1024 2048 3072 4096 4096 + (step increases) (constant) +``` + +## Configuration Reference + +```toml +[batch_size_scheduler] +mode = "constant" # "constant", "linear", or "increment" +start_batch_size = 0 # Starting batch size (0 = use global_batch_size) +rampup_samples = 0 # Samples over which to ramp up (0 = no rampup) +increment = 0 # For "increment" mode (0 = auto, uses start_batch_size) +``` + +### Constraints + +- `start_batch_size` must be divisible by `local_batch_size × data_parallel_degree` +- `global_batch_size` must be divisible by `local_batch_size × data_parallel_degree` +- `increment` must be divisible by `local_batch_size × data_parallel_degree` + +## How It Works + +### Architecture + +``` +┌─────────────────────────────────────────────────────────────┐ +│ Training Loop │ +│ │ +│ consumed_samples ──► BatchSizeManager ──► gradient_accum │ +│ │ │ +│ ▼ │ +│ ┌─────────────────┐ │ +│ │ Scheduler │ │ +│ │ (stateless) │ │ +│ └─────────────────┘ │ +│ │ │ +│ ┌─────────────────┼─────────────────┐ │ +│ ▼ ▼ ▼ │ +│ ConstantBatchSize LinearRampup IncrementRampup │ +└─────────────────────────────────────────────────────────────┘ +``` + +## Checkpointing + +The scheduler is **stateless** - only `consumed_samples` needs to be checkpointed: + +```python +# Saved in checkpoint +{ + "step": 1000, + "ntokens_seen": 4194304000, + "consumed_samples": 1024000, # ← This is all the scheduler needs +} +``` + +On resume, the scheduler automatically computes the correct batch size from `consumed_samples`. + +**Backward Compatibility**: Old checkpoints without `consumed_samples` default to 0. + +## Example Configurations + +### DeepSeek-V3 Style + +```toml +[batch_size_scheduler] +mode = "linear" +start_batch_size = 3072 +rampup_samples = 114746093750 # 469B tokens / 4096 seq_len + +[training] +global_batch_size = 15360 +local_batch_size = 4 +seq_len = 4096 +``` + +### Megatron Style + +```toml +[batch_size_scheduler] +mode = "increment" +start_batch_size = 1024 +increment = 1024 +rampup_samples = 244140625 # 1B tokens / 4096 seq_len + +[training] +global_batch_size = 4096 +local_batch_size = 4 +seq_len = 4096 +``` + +### Quick Test (Debug Model) + +```toml +[batch_size_scheduler] +mode = "linear" +start_batch_size = 8 +rampup_samples = 1000 + +[training] +global_batch_size = 32 +local_batch_size = 4 +steps = 100 +``` + +## Logging + +When batch size changes, you'll see: + +``` +[INFO] Batch size changed: 1024 -> 2048, grad_accum_steps=2 +``` + +At initialization: + +``` +[INFO] Batch size scheduler: linear rampup 1024 -> 4096 over 1000000 samples +``` diff --git a/docs/batch_size_scheduler/debug_checkpoint_resume.toml b/docs/batch_size_scheduler/debug_checkpoint_resume.toml new file mode 100644 index 0000000000..2d949ab6e4 --- /dev/null +++ b/docs/batch_size_scheduler/debug_checkpoint_resume.toml @@ -0,0 +1,48 @@ +# Batch Size Scheduler Test Config - Checkpoint Resume Test +# +# Tests that batch size scheduler state is correctly restored from checkpoint. +# Run twice: first to create checkpoint, second to resume and verify batch size. 
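+# On resume, consumed_samples is restored from the trainer state, so the scheduler
+# continues the batch size rampup from where it stopped instead of restarting at
+# start_batch_size.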
+# +# Usage (create checkpoint): +# torchrun --nproc_per_node=1 -m torchtitan.train --job.config_file docs/batch_size_scheduler/debug_checkpoint_resume.toml --training.steps 30 +# +# Usage (resume from checkpoint): +# torchrun --nproc_per_node=1 -m torchtitan.train --job.config_file docs/batch_size_scheduler/debug_checkpoint_resume.toml --checkpoint.initial_load_path ./outputs/batch_size_scheduler_ckpt/step-30 + +[job] +description = "Batch size scheduler test - checkpoint resume" +dump_folder = "./outputs/batch_size_scheduler_ckpt" + +[batch_size_scheduler] +mode = "linear" +start_batch_size = 8 +rampup_samples = 1000 + +[training] +dataset = "c4_test" +local_batch_size = 4 +global_batch_size = 32 +seq_len = 256 +steps = 60 +max_norm = 1.0 + +[model] +name = "llama3" +flavor = "debugmodel" +hf_assets_path = "./tests/assets/tokenizer" + +[optimizer] +name = "AdamW" +lr = 1e-4 + +[lr_scheduler] +warmup_steps = 10 + +[metrics] +log_freq = 5 +enable_tensorboard = false + +[checkpoint] +enable = true +interval = 30 +folder = "checkpoints" diff --git a/docs/batch_size_scheduler/debug_constant.toml b/docs/batch_size_scheduler/debug_constant.toml new file mode 100644 index 0000000000..13d08d8317 --- /dev/null +++ b/docs/batch_size_scheduler/debug_constant.toml @@ -0,0 +1,42 @@ +# Batch Size Scheduler Test Config - Constant (Debug Model) +# +# Tests constant batch size (default behavior, backward compatible). +# This config should behave identically to configs without [batch_size_scheduler]. +# +# Usage: +# torchrun --nproc_per_node=1 -m torchtitan.train --job.config_file docs/batch_size_scheduler/debug_constant.toml + +[job] +description = "Batch size scheduler test - constant (baseline)" +dump_folder = "./outputs/batch_size_scheduler_constant" + +[batch_size_scheduler] +mode = "constant" +# start_batch_size and rampup_samples default to 0, meaning no rampup + +[training] +dataset = "c4_test" +local_batch_size = 4 +global_batch_size = 32 +seq_len = 256 +steps = 100 +max_norm = 1.0 + +[model] +name = "llama3" +flavor = "debugmodel" +hf_assets_path = "./tests/assets/tokenizer" + +[optimizer] +name = "AdamW" +lr = 1e-4 + +[lr_scheduler] +warmup_steps = 10 + +[metrics] +log_freq = 10 +enable_tensorboard = false + +[checkpoint] +enable = false diff --git a/docs/batch_size_scheduler/debug_increment.toml b/docs/batch_size_scheduler/debug_increment.toml new file mode 100644 index 0000000000..bed394d0c1 --- /dev/null +++ b/docs/batch_size_scheduler/debug_increment.toml @@ -0,0 +1,44 @@ +# Batch Size Scheduler Test Config - Increment Rampup (Debug Model) +# +# Tests Megatron-style increment batch size rampup with a small debug model. +# Batch size ramps from 8 to 32 in increments of 8 over 400 samples. 
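+# With (32 - 8) / 8 = 3 increments spread over 400 samples, the batch size should
+# step up roughly every ~133 samples: 8 -> 16 -> 24 -> 32.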
+# +# Usage: +# torchrun --nproc_per_node=1 -m torchtitan.train --job.config_file docs/batch_size_scheduler/debug_increment.toml + +[job] +description = "Batch size scheduler test - increment rampup" +dump_folder = "./outputs/batch_size_scheduler_increment" + +[batch_size_scheduler] +mode = "increment" +start_batch_size = 8 +increment = 8 +rampup_samples = 400 + +[training] +dataset = "c4_test" +local_batch_size = 4 +global_batch_size = 32 +seq_len = 256 +steps = 100 +max_norm = 1.0 + +[model] +name = "llama3" +flavor = "debugmodel" +hf_assets_path = "./tests/assets/tokenizer" + +[optimizer] +name = "AdamW" +lr = 1e-4 + +[lr_scheduler] +warmup_steps = 10 + +[metrics] +log_freq = 10 +enable_tensorboard = false + +[checkpoint] +enable = false diff --git a/docs/batch_size_scheduler/debug_linear.toml b/docs/batch_size_scheduler/debug_linear.toml new file mode 100644 index 0000000000..967b31f66f --- /dev/null +++ b/docs/batch_size_scheduler/debug_linear.toml @@ -0,0 +1,43 @@ +# Batch Size Scheduler Test Config - Linear Rampup (Debug Model) +# +# Tests linear batch size rampup with a small debug model. +# Batch size ramps from 8 to 32 over 500 samples. +# +# Usage: +# torchrun --nproc_per_node=1 -m torchtitan.train --job.config_file docs/batch_size_scheduler/debug_linear.toml + +[job] +description = "Batch size scheduler test - linear rampup" +dump_folder = "./outputs/batch_size_scheduler_linear" + +[batch_size_scheduler] +mode = "linear" +start_batch_size = 8 +rampup_samples = 500 + +[training] +dataset = "c4_test" +local_batch_size = 4 +global_batch_size = 32 +seq_len = 256 +steps = 100 +max_norm = 1.0 + +[model] +name = "llama3" +flavor = "debugmodel" +hf_assets_path = "./tests/assets/tokenizer" + +[optimizer] +name = "AdamW" +lr = 1e-4 + +[lr_scheduler] +warmup_steps = 10 + +[metrics] +log_freq = 10 +enable_tensorboard = false + +[checkpoint] +enable = false diff --git a/docs/batch_size_scheduler/test_batch_size_scheduler.sh b/docs/batch_size_scheduler/test_batch_size_scheduler.sh new file mode 100755 index 0000000000..105900caa7 --- /dev/null +++ b/docs/batch_size_scheduler/test_batch_size_scheduler.sh @@ -0,0 +1,67 @@ +#!/bin/bash +# Test script for batch size scheduler feature +# +# This script runs a series of tests to verify the batch size scheduler works correctly. +# +# Usage: +# cd /path/to/torchtitan +# bash docs/batch_size_scheduler/test_batch_size_scheduler.sh +# +# Requirements: +# - Single GPU available +# - torchtitan installed in current environment + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "$SCRIPT_DIR/../.." 
&& pwd)" +CONFIG_DIR="$SCRIPT_DIR" + +echo "========================================" +echo "Batch Size Scheduler Test Suite" +echo "========================================" +echo "Repository: $REPO_ROOT" +echo "" + +cd "$REPO_ROOT" + +# Clean up previous test outputs +rm -rf ./outputs/batch_size_scheduler_* + +echo "" +echo "========================================" +echo "Test 1: Constant Batch Size (Baseline)" +echo "========================================" +echo "Expected: Batch size stays at 32 throughout training" +echo "" + +torchrun --nproc_per_node=1 -m torchtitan.train \ + --job.config_file "$CONFIG_DIR/debug_constant.toml" \ + 2>&1 | grep -E "(Batch size|grad_accum|Training starts|Training completed)" + +echo "" +echo "========================================" +echo "Test 2: Linear Rampup" +echo "========================================" +echo "Expected: Batch size ramps from 8 to 32 over 500 samples" +echo "" + +torchrun --nproc_per_node=1 -m torchtitan.train \ + --job.config_file "$CONFIG_DIR/debug_linear.toml" \ + 2>&1 | grep -E "(Batch size|grad_accum|Training starts|Training completed)" + +echo "" +echo "========================================" +echo "Test 3: Increment Rampup (Megatron-style)" +echo "========================================" +echo "Expected: Batch size steps 8 -> 16 -> 24 -> 32 over 400 samples" +echo "" + +torchrun --nproc_per_node=1 -m torchtitan.train \ + --job.config_file "$CONFIG_DIR/debug_increment.toml" \ + 2>&1 | grep -E "(Batch size|grad_accum|Training starts|Training completed)" + +echo "" +echo "========================================" +echo "All tests completed!" +echo "========================================" diff --git a/docs/batch_size_scheduler/test_checkpoint_resume.sh b/docs/batch_size_scheduler/test_checkpoint_resume.sh new file mode 100755 index 0000000000..b18f778d6d --- /dev/null +++ b/docs/batch_size_scheduler/test_checkpoint_resume.sh @@ -0,0 +1,65 @@ +#!/bin/bash +# Test checkpoint resume for batch size scheduler +# +# This script verifies that batch size state is correctly restored from checkpoint. +# +# Usage: +# cd /path/to/torchtitan +# bash docs/batch_size_scheduler/test_checkpoint_resume.sh +# +# Requirements: +# - Single GPU available +# - torchtitan installed in current environment + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "$SCRIPT_DIR/../.." 
&& pwd)" +CONFIG_DIR="$SCRIPT_DIR" +OUTPUT_DIR="$REPO_ROOT/outputs/batch_size_scheduler_ckpt" + +echo "========================================" +echo "Batch Size Scheduler Checkpoint Resume Test" +echo "========================================" +echo "" + +cd "$REPO_ROOT" + +# Clean up previous test outputs +rm -rf "$OUTPUT_DIR" + +echo "Step 1: Run training for 30 steps and create checkpoint" +echo "========================================" +echo "" + +torchrun --nproc_per_node=1 -m torchtitan.train \ + --job.config_file "$CONFIG_DIR/debug_checkpoint_resume.toml" \ + --training.steps 30 \ + 2>&1 | grep -E "(Batch size|grad_accum|Training starts|Training completed|consumed_samples|checkpoint)" + +echo "" +echo "Step 2: Resume from checkpoint and continue to step 60" +echo "========================================" +echo "Expected: Batch size should continue from where it left off" +echo "" + +# Find the checkpoint directory +CKPT_DIR=$(ls -d "$OUTPUT_DIR/checkpoints/step-"* 2>/dev/null | head -1) + +if [ -z "$CKPT_DIR" ]; then + echo "ERROR: Checkpoint not found in $OUTPUT_DIR/checkpoints/" + exit 1 +fi + +echo "Found checkpoint: $CKPT_DIR" +echo "" + +torchrun --nproc_per_node=1 -m torchtitan.train \ + --job.config_file "$CONFIG_DIR/debug_checkpoint_resume.toml" \ + --checkpoint.initial_load_path "$CKPT_DIR" \ + 2>&1 | grep -E "(Batch size|grad_accum|Training starts|Training completed|consumed_samples)" + +echo "" +echo "========================================" +echo "Checkpoint resume test completed!" +echo "========================================" diff --git a/docs/batch_size_scheduler/test_unit.py b/docs/batch_size_scheduler/test_unit.py new file mode 100644 index 0000000000..7ac630423a --- /dev/null +++ b/docs/batch_size_scheduler/test_unit.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python3 +""" +Unit tests for batch size scheduler. 
+ +Usage: + python docs/batch_size_scheduler/test_unit.py +""" + +import sys + + +def test_constant_batch_size(): + """Test constant batch size scheduler.""" + from torchtitan.components.batch_size_scheduler import ( + BatchSizeState, + ConstantBatchSize, + ) + + scheduler = ConstantBatchSize(batch_size=4096) + + # Should always return the same batch size + for samples in [0, 1000, 1000000, 1000000000]: + state = BatchSizeState(consumed_samples=samples) + assert scheduler.get_batch_size(state) == 4096, f"Failed at {samples}" + + print(" ConstantBatchSize: PASSED") + + +def test_linear_rampup(): + """Test linear rampup scheduler.""" + from torchtitan.components.batch_size_scheduler import BatchSizeState, LinearRampup + + scheduler = LinearRampup( + start_batch_size=1024, end_batch_size=4096, rampup_samples=1000000 + ) + + # At 0%: should be start_batch_size + state = BatchSizeState(consumed_samples=0) + assert scheduler.get_batch_size(state) == 1024, "Failed at 0%" + + # At 50%: should be halfway + state = BatchSizeState(consumed_samples=500000) + expected = 1024 + 0.5 * (4096 - 1024) # 2560 + assert scheduler.get_batch_size(state) == int( + expected + ), f"Failed at 50%: got {scheduler.get_batch_size(state)}, expected {int(expected)}" + + # At 100%: should be end_batch_size + state = BatchSizeState(consumed_samples=1000000) + assert scheduler.get_batch_size(state) == 4096, "Failed at 100%" + + # After 100%: should stay at end_batch_size + state = BatchSizeState(consumed_samples=2000000) + assert scheduler.get_batch_size(state) == 4096, "Failed after 100%" + + print(" LinearRampup: PASSED") + + +def test_increment_rampup(): + """Test increment rampup scheduler.""" + from torchtitan.components.batch_size_scheduler import ( + BatchSizeState, + IncrementRampup, + ) + + scheduler = IncrementRampup( + start_batch_size=1024, + end_batch_size=4096, + increment=1024, + rampup_samples=1000000, + ) + + # 3 increments needed: 1024 -> 2048 -> 3072 -> 4096 + # samples_per_increment = 1000000 / 3 = 333333.33 + + # At 0%: should be start_batch_size + state = BatchSizeState(consumed_samples=0) + assert scheduler.get_batch_size(state) == 1024, "Failed at 0%" + + # Just before first increment + state = BatchSizeState(consumed_samples=333332) + assert scheduler.get_batch_size(state) == 1024, "Failed just before first increment" + + # After first increment + state = BatchSizeState(consumed_samples=333334) + assert scheduler.get_batch_size(state) == 2048, "Failed after first increment" + + # At 100%: should be end_batch_size + state = BatchSizeState(consumed_samples=1000000) + assert scheduler.get_batch_size(state) == 4096, "Failed at 100%" + + # After 100%: should stay at end_batch_size + state = BatchSizeState(consumed_samples=2000000) + assert scheduler.get_batch_size(state) == 4096, "Failed after 100%" + + print(" IncrementRampup: PASSED") + + +def test_batch_size_manager(): + """Test batch size manager alignment and grad accum.""" + from torchtitan.components.batch_size_scheduler import ( + BatchSizeManager, + BatchSizeState, + LinearRampup, + ) + + scheduler = LinearRampup( + start_batch_size=1024, end_batch_size=4096, rampup_samples=1000000 + ) + manager = BatchSizeManager( + scheduler=scheduler, micro_batch_size=4, data_parallel_size=2 + ) + # unit = 4 * 2 = 8 + + # Test alignment + state = BatchSizeState(consumed_samples=0) + assert manager.get_batch_size(state) == 1024, "Alignment failed" + assert manager.get_grad_accum_steps(state) == 1024 // 8, "Grad accum failed" + + # Test did_change + changed, 
old, new = manager.did_change(state) + assert not changed, "Should not detect change on first call" + assert old == new == 1024, "Old and new should be equal" + + # Move to 50% + state = BatchSizeState(consumed_samples=500000) + changed, old, new = manager.did_change(state) + assert changed, "Should detect change" + assert old == 1024, "Old should be 1024" + assert new == 2560, "New should be 2560" + + print(" BatchSizeManager: PASSED") + + +def test_build_batch_size_manager(): + """Test factory function with config.""" + from dataclasses import dataclass + + from torchtitan.components.batch_size_scheduler import ( + BatchSizeState, + build_batch_size_manager, + ) + + # Mock config + @dataclass + class MockBSConfig: + mode: str = "linear" + start_batch_size: int = 1024 + rampup_samples: int = 1000000 + increment: int = 0 + + @dataclass + class MockTrainingConfig: + global_batch_size: int = 4096 + local_batch_size: int = 4 + + @dataclass + class MockJobConfig: + batch_size_scheduler: MockBSConfig = None + training: MockTrainingConfig = None + + def __post_init__(self): + if self.batch_size_scheduler is None: + self.batch_size_scheduler = MockBSConfig() + if self.training is None: + self.training = MockTrainingConfig() + + config = MockJobConfig() + manager = build_batch_size_manager(config, dp_degree=2) + + state = BatchSizeState(consumed_samples=0) + assert manager.get_batch_size(state) == 1024, "Factory failed" + + # Test constant mode (default) + config = MockJobConfig( + batch_size_scheduler=MockBSConfig(mode="constant", start_batch_size=0) + ) + manager = build_batch_size_manager(config, dp_degree=2) + + state = BatchSizeState(consumed_samples=500000) + assert manager.get_batch_size(state) == 4096, "Constant mode failed" + + print(" build_batch_size_manager: PASSED") + + +def main(): + print("=" * 50) + print("Batch Size Scheduler Unit Tests") + print("=" * 50) + print() + + try: + test_constant_batch_size() + test_linear_rampup() + test_increment_rampup() + test_batch_size_manager() + test_build_batch_size_manager() + + print() + print("=" * 50) + print("All tests PASSED!") + print("=" * 50) + return 0 + + except AssertionError as e: + print(f"\nFAILED: {e}") + return 1 + except Exception as e: + print(f"\nERROR: {e}") + import traceback + + traceback.print_exc() + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/torchtitan/components/batch_size_scheduler.py b/torchtitan/components/batch_size_scheduler.py new file mode 100644 index 0000000000..6e5b654049 --- /dev/null +++ b/torchtitan/components/batch_size_scheduler.py @@ -0,0 +1,195 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Batch Size Scheduler - Orthogonal to Data Stages and LR Scheduler. + +Handles dynamic batch size scheduling based on consumed_samples only. +No knowledge of data stages, datasets, or learning rate. + +Supports three modes: +- constant: Fixed batch size (default, backward compatible) +- linear: Smooth interpolation (DeepSeek-V3 style) +- increment: Step-wise increments (Megatron style) +""" + +from dataclasses import dataclass, field + +from torchtitan.tools.logging import logger + + +@dataclass +class BatchSizeState: + """State for batch size computation. Owned by training loop.""" + + consumed_samples: int = 0 + + +class BatchSizeScheduler: + """Base class for batch size schedulers. 
Stateless - all state in BatchSizeState.""" + + def get_batch_size(self, state: BatchSizeState) -> int: + raise NotImplementedError + + def get_name(self) -> str: + raise NotImplementedError + + +@dataclass +class ConstantBatchSize(BatchSizeScheduler): + """Fixed batch size throughout training.""" + + batch_size: int + + def get_batch_size(self, state: BatchSizeState) -> int: + return self.batch_size + + def get_name(self) -> str: + return "constant" + + +@dataclass +class LinearRampup(BatchSizeScheduler): + """ + Smooth linear interpolation from start to end. + Used by: DeepSeek-V3 (3072 -> 15360 over 469B tokens) + """ + + start_batch_size: int + end_batch_size: int + rampup_samples: int + + def get_batch_size(self, state: BatchSizeState) -> int: + if state.consumed_samples >= self.rampup_samples: + return self.end_batch_size + progress = state.consumed_samples / self.rampup_samples + return int( + self.start_batch_size + + progress * (self.end_batch_size - self.start_batch_size) + ) + + def get_name(self) -> str: + return "linear" + + +@dataclass +class IncrementRampup(BatchSizeScheduler): + """ + Megatron-style step-wise increments at regular intervals. + """ + + start_batch_size: int + end_batch_size: int + increment: int + rampup_samples: int + _samples_per_increment: float = field(init=False, default=0.0) + + def __post_init__(self): + diff = self.end_batch_size - self.start_batch_size + num_increments = diff // self.increment if self.increment > 0 else 0 + self._samples_per_increment = ( + self.rampup_samples / num_increments if num_increments > 0 else float("inf") + ) + + def get_batch_size(self, state: BatchSizeState) -> int: + if state.consumed_samples >= self.rampup_samples: + return self.end_batch_size + steps = int(state.consumed_samples / self._samples_per_increment) + return min(self.start_batch_size + steps * self.increment, self.end_batch_size) + + def get_name(self) -> str: + return "increment" + + +class BatchSizeManager: + """ + Wraps scheduler. Handles alignment and grad accum computation. + + NOT responsible for: tracking consumed_samples, data loading, LR. + """ + + def __init__( + self, + scheduler: BatchSizeScheduler, + micro_batch_size: int, + data_parallel_size: int, + ): + self.scheduler = scheduler + self.micro_batch_size = micro_batch_size + self.data_parallel_size = data_parallel_size + self._unit = micro_batch_size * data_parallel_size + self._last_batch_size: int | None = None + + def get_batch_size(self, state: BatchSizeState) -> int: + """Get aligned global batch size.""" + raw = self.scheduler.get_batch_size(state) + return (raw // self._unit) * self._unit + + def get_grad_accum_steps(self, state: BatchSizeState) -> int: + """Get gradient accumulation steps.""" + return self.get_batch_size(state) // self._unit + + def did_change(self, state: BatchSizeState) -> tuple[bool, int, int]: + """Check if batch size changed. 
Returns (changed, old, new).""" + current = self.get_batch_size(state) + old = self._last_batch_size + changed = old is not None and current != old + self._last_batch_size = current + return changed, old if old is not None else current, current + + +def build_batch_size_manager( + job_config, + dp_degree: int, +) -> BatchSizeManager: + """Factory to build manager from config.""" + bs_cfg = job_config.batch_size_scheduler + training = job_config.training + + # Compute target global batch size (same logic as train.py) + target = training.global_batch_size + if target < 0: + target = training.local_batch_size * dp_degree + + micro_batch_size = training.local_batch_size + + # Build appropriate scheduler based on mode + # Default to constant if no rampup configured (backward compatible) + if ( + bs_cfg.mode == "constant" + or bs_cfg.start_batch_size <= 0 + or bs_cfg.rampup_samples <= 0 + ): + scheduler = ConstantBatchSize(batch_size=target) + logger.info(f"Batch size scheduler: constant at {target}") + elif bs_cfg.mode == "linear": + scheduler = LinearRampup( + start_batch_size=bs_cfg.start_batch_size, + end_batch_size=target, + rampup_samples=bs_cfg.rampup_samples, + ) + logger.info( + f"Batch size scheduler: linear rampup {bs_cfg.start_batch_size} -> {target} " + f"over {bs_cfg.rampup_samples} samples" + ) + elif bs_cfg.mode == "increment": + increment = ( + bs_cfg.increment if bs_cfg.increment > 0 else bs_cfg.start_batch_size + ) + scheduler = IncrementRampup( + start_batch_size=bs_cfg.start_batch_size, + end_batch_size=target, + increment=increment, + rampup_samples=bs_cfg.rampup_samples, + ) + logger.info( + f"Batch size scheduler: increment rampup {bs_cfg.start_batch_size} -> {target} " + f"(increment={increment}) over {bs_cfg.rampup_samples} samples" + ) + else: + raise ValueError(f"Unknown batch_size_scheduler.mode: {bs_cfg.mode}") + + return BatchSizeManager(scheduler, micro_batch_size, dp_degree) diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 8a19466d63..ec30ce1d78 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -290,6 +290,44 @@ class LRScheduler: """ +@dataclass +class BatchSizeScheduler: + """ + Batch size scheduling configuration. + + Orthogonal to data stages and LR scheduler - only controls batch size + based on consumed_samples. When mode is 'constant' or start_batch_size is 0, + uses fixed global_batch_size (backward compatible with existing configs). + """ + + mode: Literal["constant", "linear", "increment"] = "constant" + """ + Schedule mode: + - 'constant': Fixed batch size (default, backward compatible) + - 'linear': Smooth interpolation from start to end (DeepSeek-V3 style) + - 'increment': Step-wise increments at regular intervals (Megatron style) + """ + + start_batch_size: int = 0 + """ + Starting batch size for rampup. 0 means use global_batch_size (no rampup). + Must be divisible by local_batch_size * data_parallel_degree. + """ + + rampup_samples: int = 0 + """ + Number of samples over which to ramp up to global_batch_size. + 0 means no rampup (constant batch size). + """ + + increment: int = 0 + """ + Batch size increment for 'increment' mode. + 0 means auto-compute (use start_batch_size as increment). + Must be divisible by local_batch_size * data_parallel_degree. 
+ """ + + @dataclass class Training: dataset: str = "c4_test" @@ -1204,6 +1242,7 @@ class JobConfig: model: Model = field(default_factory=Model) optimizer: Optimizer = field(default_factory=Optimizer) lr_scheduler: LRScheduler = field(default_factory=LRScheduler) + batch_size_scheduler: BatchSizeScheduler = field(default_factory=BatchSizeScheduler) training: Training = field(default_factory=Training) parallelism: Parallelism = field(default_factory=Parallelism) deepep: DeepEP = field(default_factory=DeepEP) diff --git a/torchtitan/train.py b/torchtitan/train.py index cde676431c..a5000be29c 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -15,6 +15,11 @@ from torch.distributed.elastic.multiprocessing.errors import record import torchtitan.protocols.train_spec as train_spec_module +from torchtitan.components.batch_size_scheduler import ( + BatchSizeManager, + BatchSizeState, + build_batch_size_manager, +) from torchtitan.components.checkpoint import CheckpointManager from torchtitan.components.dataloader import DataloaderExhaustedError from torchtitan.components.ft import FTManager, maybe_semi_sync_training @@ -54,6 +59,7 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful): # non-swappable training components checkpointer: CheckpointManager ft_manager: FTManager + batch_size_manager: BatchSizeManager # runtime utilities device: torch.device @@ -66,6 +72,7 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful): # additional training states step: int ntokens_seen: int + consumed_samples: int # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html @record @@ -184,24 +191,29 @@ def __init__(self, job_config: JobConfig): job_config, parallel_dims=parallel_dims, ft_manager=self.ft_manager ) - # verify batch sizes - global_batch_size = job_config.training.global_batch_size - if global_batch_size < 0: - # This global batch size results in 1 gradient accumulation - # step. - global_batch_size = job_config.training.local_batch_size * dp_degree - assert global_batch_size > 0 + # Build batch size manager for dynamic batch size scheduling + self.batch_size_manager = build_batch_size_manager(job_config, dp_degree) + + # Get initial batch size and validate + initial_state = BatchSizeState(consumed_samples=0) + global_batch_size = self.batch_size_manager.get_batch_size(initial_state) + + # Validate final (target) batch size from config + target_batch_size = job_config.training.global_batch_size + if target_batch_size < 0: + target_batch_size = job_config.training.local_batch_size * dp_degree + assert target_batch_size > 0 assert ( - global_batch_size % (job_config.training.local_batch_size * dp_degree) == 0 + target_batch_size % (job_config.training.local_batch_size * dp_degree) == 0 ), ( f"global batch size must be multiple of local batch size times " - f"data-parallel degree ({global_batch_size} " + f"data-parallel degree ({target_batch_size} " f"% ({job_config.training.local_batch_size} * {dp_degree}) != 0)" ) - # calculate gradient accumulation steps - self.gradient_accumulation_steps = global_batch_size // ( - job_config.training.local_batch_size * dp_degree + # Calculate initial gradient accumulation steps + self.gradient_accumulation_steps = self.batch_size_manager.get_grad_accum_steps( + initial_state ) assert self.gradient_accumulation_steps > 0 self.loss_fn = rescale_accumulated_loss( @@ -309,6 +321,7 @@ def __init__(self, job_config: JobConfig): # These attributes must be initialized before checkpoint loading. 
self.step = 0 self.ntokens_seen = 0 + self.consumed_samples = 0 self.checkpointer = CheckpointManager( dataloader=self.dataloader, @@ -675,6 +688,20 @@ def forward_backward_step( def train_step( self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]] ): + # Check for batch size changes (dynamic batch size scheduling) + bs_state = BatchSizeState(consumed_samples=self.consumed_samples) + changed, old_bs, new_bs = self.batch_size_manager.did_change(bs_state) + if changed: + self.gradient_accumulation_steps = ( + self.batch_size_manager.get_grad_accum_steps(bs_state) + ) + # Update loss rescaling for new accumulation steps + self.loss_fn.accumulation_steps = self.gradient_accumulation_steps + logger.info( + f"Batch size changed: {old_bs} -> {new_bs}, " + f"grad_accum_steps={self.gradient_accumulation_steps}" + ) + self.optimizers.zero_grad() # Save the current step learning rate for logging lr = self.lr_schedulers.schedulers[0].get_last_lr()[0] @@ -704,6 +731,9 @@ def train_step( self.optimizers.step() self.lr_schedulers.step() + # Update consumed samples for batch size scheduler + self.consumed_samples += self.batch_size_manager.get_batch_size(bs_state) + # Reduce the data collected over gradient accumulation steps. loss = torch.sum(torch.stack(accumulated_losses)) @@ -784,7 +814,9 @@ def train(self): ), ), ): - first_step_save = self.step == 0 and job_config.checkpoint.enable_first_step_checkpoint + first_step_save = ( + self.step == 0 and job_config.checkpoint.enable_first_step_checkpoint + ) if first_step_save: self.checkpointer.save(1, False) @@ -837,11 +869,29 @@ def should_continue_training(self) -> bool: return self.step < self.job_config.training.steps def state_dict(self) -> dict[str, Any]: - return {"step": self.step, "ntokens_seen": self.ntokens_seen} + return { + "step": self.step, + "ntokens_seen": self.ntokens_seen, + "consumed_samples": self.consumed_samples, + } def load_state_dict(self, state_dict: dict[str, Any]): self.step = state_dict["step"] self.ntokens_seen = state_dict["ntokens_seen"] + # Backward compatible: old checkpoints may not have consumed_samples + self.consumed_samples = state_dict.get("consumed_samples", 0) + + # Sync batch size state with restored consumed_samples + bs_state = BatchSizeState(consumed_samples=self.consumed_samples) + self.gradient_accumulation_steps = self.batch_size_manager.get_grad_accum_steps( + bs_state + ) + # Update loss rescaling to match current batch size + self.loss_fn.accumulation_steps = self.gradient_accumulation_steps + # Initialize batch size manager's last_batch_size to prevent spurious change detection + self.batch_size_manager._last_batch_size = ( + self.batch_size_manager.get_batch_size(bs_state) + ) def close(self) -> None: if self.checkpointer:
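Illustrative usage (not part of the patch): a minimal sketch of how the schedulers and manager added in `torchtitan/components/batch_size_scheduler.py` compose, using the DeepSeek-V3- and Megatron-style numbers from the README; the `data_parallel_size=8` and the sample counts in the loop are arbitrary choices for illustration.

```python
# Sketch only: inspect the schedule produced by the schedulers in this patch.
from torchtitan.components.batch_size_scheduler import (
    BatchSizeManager,
    BatchSizeState,
    IncrementRampup,
    LinearRampup,
)

# DeepSeek-V3-style linear rampup (numbers from the README example).
linear = LinearRampup(
    start_batch_size=3072, end_batch_size=15360, rampup_samples=114_746_093_750
)

# Megatron-style increment rampup (numbers from the README example).
increment = IncrementRampup(
    start_batch_size=1024,
    end_batch_size=4096,
    increment=1024,
    rampup_samples=244_140_625,
)

# The manager aligns the raw batch size to micro_batch_size * data_parallel_size
# and converts it into gradient-accumulation steps; dp=8 is an arbitrary example.
manager = BatchSizeManager(scheduler=increment, micro_batch_size=4, data_parallel_size=8)

for consumed in (0, 100_000_000, 200_000_000, 300_000_000):
    state = BatchSizeState(consumed_samples=consumed)
    print(
        f"consumed={consumed:>12,}  "
        f"linear={linear.get_batch_size(state):>6}  "
        f"increment={manager.get_batch_size(state):>5}  "
        f"grad_accum={manager.get_grad_accum_steps(state):>4}"
    )
```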