diff --git a/.github/workflows/sonarcloud.yml b/.github/workflows/sonarcloud.yml
index 94c494207..b1cc0c841 100644
--- a/.github/workflows/sonarcloud.yml
+++ b/.github/workflows/sonarcloud.yml
@@ -125,7 +125,7 @@ jobs:
run: dotnet restore
- name: Build (Release)
- run: dotnet build -c Release --no-restore
+ run: dotnet build -c Release --no-restore -p:UseSharedCompilation=false
- name: Upload build artifacts
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v4
@@ -464,7 +464,7 @@ jobs:
& "${{ runner.temp }}\scanner\dotnet-sonarscanner" begin @params
- name: Build (Release)
- run: dotnet build -c Release --no-restore
+ run: dotnet build -c Release --no-restore -p:UseSharedCompilation=false
- name: End SonarCloud analysis
if: github.event_name != 'pull_request' || github.event.pull_request.changed_files <= 250
diff --git a/src/AiDotNet.csproj b/src/AiDotNet.csproj
index 627277839..25d238f43 100644
--- a/src/AiDotNet.csproj
+++ b/src/AiDotNet.csproj
@@ -99,31 +99,6 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- true
- Generated
-
-
-
-
-
-
-
@@ -151,6 +126,32 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ true
+ Generated
+
+
+
+
+
+
+
diff --git a/src/AiModelBuilder.cs b/src/AiModelBuilder.cs
index b04193aaf..19183d459 100644
--- a/src/AiModelBuilder.cs
+++ b/src/AiModelBuilder.cs
@@ -164,6 +164,10 @@ public partial class AiModelBuilder : IAiModelBuilder? _distributedBackend;
private DistributedStrategy _distributedStrategy = DistributedStrategy.DDP;
private IShardingConfiguration? _distributedConfiguration;
+ private IPipelinePartitionStrategy? _pipelinePartitionStrategy;
+ private IPipelineSchedule? _pipelineSchedule;
+ private ActivationCheckpointConfig? _pipelineCheckpointConfig;
+ private int _pipelineMicroBatchSize = 1;
private ICrossValidator? _crossValidator;
private AgentConfiguration? _agentConfig;
private AgentAssistanceOptions _agentOptions = AgentAssistanceOptions.Default;
@@ -1762,7 +1766,12 @@ private async Task> BuildSupervisedInternalAsy
new DistributedTraining.ZeRO3Model(_model, shardingConfig),
new DistributedTraining.ZeRO3Optimizer(optimizer, shardingConfig)),
DistributedStrategy.PipelineParallel => CreateDistributedPair(
- new DistributedTraining.PipelineParallelModel(_model, shardingConfig),
+ new DistributedTraining.PipelineParallelModel(
+ _model, shardingConfig,
+ microBatchSize: _pipelineMicroBatchSize,
+ partitionStrategy: _pipelinePartitionStrategy,
+ schedule: _pipelineSchedule,
+ checkpointConfig: _pipelineCheckpointConfig),
new DistributedTraining.PipelineParallelOptimizer(optimizer, shardingConfig)),
DistributedStrategy.TensorParallel => CreateDistributedPair(
new DistributedTraining.TensorParallelModel(_model, shardingConfig),
@@ -3790,6 +3799,10 @@ public IAiModelBuilder ConfigureMetaLearning(IMetaLearner
+ ///
+ /// For pipeline parallelism, call ConfigurePipelineParallelism after this method
+ /// to customize scheduling, partitioning, and activation checkpointing.
+ ///
///
public IAiModelBuilder ConfigureDistributedTraining(
ICommunicationBackend? backend = null,
@@ -3802,6 +3815,72 @@ public IAiModelBuilder ConfigureDistributedTraining(
return this;
}
+ ///
+ /// Configures pipeline-specific options for pipeline parallel training.
+ ///
+ ///
+ /// Pipeline execution schedule. If null, uses GPipeSchedule.
+ /// Use OneForwardOneBackwardSchedule (1F1B) to reduce the pipeline bubble (~12-15% vs ~50% for GPipe).
+ ///
+ ///
+ /// Strategy for partitioning layers across pipeline stages.
+ /// If null, uses uniform partitioning. Use LoadBalancedPartitionStrategy
+ /// to balance computational cost across stages.
+ ///
+ ///
+ /// Activation checkpointing configuration.
+ /// If null, checkpointing is disabled. Enable to reduce memory from O(L) to O(sqrt(L)).
+ ///
+ ///
+ /// Number of micro-batches to split the full batch into for pipeline execution.
+ /// Higher values reduce pipeline bubble but increase memory. Default: 1.
+ ///
+ /// This builder instance for method chaining.
+ ///
+ ///
+ /// Call this after ConfigureDistributedTraining with
+ /// DistributedStrategy.PipelineParallel to customize pipeline scheduling,
+ /// partitioning, activation checkpointing, and micro-batch count.
+ ///
+ ///
+ /// For Beginners: This method fine-tunes how pipeline parallelism works.
+ /// You only need to call it if you want to change the defaults (GPipe schedule,
+ /// uniform partitioning, no checkpointing, 1 micro-batch).
+ ///
+ ///
+ /// Example:
+ ///
+ /// var result = builder
+ /// .ConfigureModel(myModel)
+ /// .ConfigureDistributedTraining(strategy: DistributedStrategy.PipelineParallel)
+ /// .ConfigurePipelineParallelism(
+ /// schedule: new OneForwardOneBackwardSchedule(),
+ /// partitionStrategy: new LoadBalancedPartitionStrategy<double>(estimatedLayerSize: 1024),
+ /// checkpointConfig: new ActivationCheckpointConfig { Enabled = true },
+ /// microBatchCount: 8)
+ /// .Build(xTrain, yTrain);
+ ///
+ ///
+ ///
+ public IAiModelBuilder ConfigurePipelineParallelism(
+ IPipelineSchedule? schedule = null,
+ IPipelinePartitionStrategy? partitionStrategy = null,
+ ActivationCheckpointConfig? checkpointConfig = null,
+ int microBatchCount = 1)
+ {
+ if (microBatchCount <= 0)
+ {
+ throw new ArgumentOutOfRangeException(nameof(microBatchCount),
+ $"Micro-batch count must be at least 1, but was {microBatchCount}.");
+ }
+
+ _pipelineSchedule = schedule;
+ _pipelinePartitionStrategy = partitionStrategy;
+ _pipelineCheckpointConfig = checkpointConfig;
+ _pipelineMicroBatchSize = microBatchCount;
+ return this;
+ }
+
///
/// Enables AI agent assistance during the model building process.
///
diff --git a/src/DistributedTraining/ActivationCheckpointConfig.cs b/src/DistributedTraining/ActivationCheckpointConfig.cs
new file mode 100644
index 000000000..ecced5a15
--- /dev/null
+++ b/src/DistributedTraining/ActivationCheckpointConfig.cs
@@ -0,0 +1,140 @@
+namespace AiDotNet.DistributedTraining;
+
+///
+/// Configuration for activation checkpointing in pipeline parallel training.
+///
+///
+///
+/// Activation checkpointing (also called gradient checkpointing) trades compute for memory
+/// by only storing activations at checkpoint layers during the forward pass. Intermediate
+/// activations are recomputed from the nearest checkpoint during the backward pass.
+///
+/// For Beginners: During training, the forward pass must save intermediate results
+/// (activations) so the backward pass can compute gradients. For very deep models, storing all
+/// these activations uses enormous amounts of memory.
+///
+/// Activation checkpointing is like taking notes at chapter boundaries instead of every page:
+/// - Without checkpointing: Save every activation (lots of memory, no recomputation)
+/// - With checkpointing: Save every Nth activation, recompute the rest (less memory, more compute)
+///
+/// Memory savings: O(L) → O(sqrt(L)) where L = number of layers.
+/// For 100 layers, this reduces memory from 100 activations to ~10 activations.
+///
+/// The trade-off is ~33% more compute time, but this enables training models that otherwise
+/// wouldn't fit in memory.
+///
+/// Reference: Chen et al., "Training Deep Nets with Sublinear Memory Cost", 2016.
+/// https://arxiv.org/abs/1604.06174
+///
+public class ActivationCheckpointConfig
+{
+ private int _checkpointEveryNLayers = 10;
+ private int _maxActivationsInMemory;
+
+ ///
+ /// Gets or sets whether activation checkpointing is enabled.
+ ///
+ ///
+ /// For Beginners: Set this to true to enable memory savings. Default is false
+ /// (no checkpointing, standard behavior).
+ ///
+ public bool Enabled { get; set; }
+
+ ///
+ /// Gets or sets how often to save a checkpoint (every N layers).
+ ///
+ ///
+ /// For Beginners: Lower values save more activations (more memory, less recomputation).
+ /// Higher values save fewer (less memory, more recomputation).
+ ///
+ /// Optimal value is approximately sqrt(total_layers) for minimum total cost.
+ /// For a 100-layer model, checkpointing every 10 layers is a good default.
+ ///
+ /// Default: 10 layers between checkpoints.
+ ///
+ /// Thrown when value is less than 1.
+ public int CheckpointEveryNLayers
+ {
+ get => _checkpointEveryNLayers;
+ set
+ {
+ if (value < 1)
+ {
+ throw new ArgumentOutOfRangeException(nameof(CheckpointEveryNLayers),
+ $"CheckpointEveryNLayers must be at least 1, but was {value}. " +
+ "A value of 0 would cause division-by-zero in interval-based checkpointing.");
+ }
+ _checkpointEveryNLayers = value;
+ }
+ }
+
+ ///
+ /// Gets or sets the recomputation strategy to use during the backward pass.
+ ///
+ ///
+ /// For Beginners:
+ /// - Selective: Only recompute activations that are needed and not checkpointed (recommended)
+ /// - Full: Recompute all non-checkpointed activations from the previous checkpoint
+ /// - None: Don't recompute, equivalent to no checkpointing (for testing/debugging)
+ ///
+ ///
+ public RecomputeStrategy RecomputeStrategy { get; set; } = RecomputeStrategy.Selective;
+
+ ///
+ /// Gets or sets the maximum number of activations to keep in memory simultaneously.
+ ///
+ ///
+ /// For Beginners: This caps how many activations are stored at once.
+ /// Set to 0 for no limit (uses CheckpointEveryNLayers to determine storage).
+ /// A non-zero value overrides CheckpointEveryNLayers by dynamically adjusting
+ /// the checkpoint frequency to stay within the memory budget.
+ ///
+ /// Thrown when value is negative.
+ public int MaxActivationsInMemory
+ {
+ get => _maxActivationsInMemory;
+ set
+ {
+ if (value < 0)
+ {
+ throw new ArgumentOutOfRangeException(nameof(MaxActivationsInMemory),
+ $"MaxActivationsInMemory must be non-negative, but was {value}. " +
+ "Use 0 for no limit.");
+ }
+ _maxActivationsInMemory = value;
+ }
+ }
+
+ ///
+ /// Gets or sets whether to checkpoint the very first layer's input.
+ ///
+ ///
+ /// For Beginners: The first layer's input is always needed for the backward pass.
+ /// If true, it's saved as a checkpoint. If false, the caller must ensure the input is
+ /// available during the backward pass (which is usually the case).
+ ///
+ public bool CheckpointFirstLayer { get; set; } = true;
+}
+
+///
+/// Strategy for recomputing activations during the backward pass.
+///
+public enum RecomputeStrategy
+{
+ ///
+ /// Only recompute activations that are needed for the current backward step.
+ /// This is the most memory-efficient but requires careful bookkeeping.
+ ///
+ Selective,
+
+ ///
+ /// Recompute all activations between the two nearest checkpoints during backward.
+ /// Simpler implementation but may do slightly more work than necessary.
+ ///
+ Full,
+
+ ///
+ /// No recomputation. Equivalent to disabled checkpointing. Useful for debugging.
+ ///
+ None
+}
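The interaction between CheckpointEveryNLayers and the sqrt(L) guidance above can be made concrete. Below is a minimal sketch (illustrative only, not part of the diff) for an assumed 100-layer model; note that PipelineParallelModel currently accepts only RecomputeStrategy.None when checkpointing is enabled (see the constructor check later in this diff).

```csharp
// Minimal sketch, assuming a 100-layer model; not part of the diff.
var checkpointConfig = new ActivationCheckpointConfig
{
    Enabled = true,
    // ~sqrt(100) = 10, matching the documented default interval
    CheckpointEveryNLayers = (int)Math.Ceiling(Math.Sqrt(100)),
    // Only RecomputeStrategy.None is currently accepted by PipelineParallelModel
    RecomputeStrategy = RecomputeStrategy.None,
    MaxActivationsInMemory = 0 // 0 = no explicit cap; the interval alone controls storage
};
```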
diff --git a/src/DistributedTraining/GPipeSchedule.cs b/src/DistributedTraining/GPipeSchedule.cs
new file mode 100644
index 000000000..fb8810a61
--- /dev/null
+++ b/src/DistributedTraining/GPipeSchedule.cs
@@ -0,0 +1,102 @@
+using AiDotNet.Interfaces;
+
+namespace AiDotNet.DistributedTraining;
+
+///
+/// Implements the GPipe scheduling strategy: all forward passes first, then all backward passes.
+///
+///
+///
+/// GPipe is the simplest pipeline schedule. It executes all forward micro-batches sequentially
+/// through the pipeline, storing all activations, then executes all backward micro-batches
+/// in reverse order.
+///
+/// For Beginners: GPipe is the straightforward approach:
+///
+/// 1. Push ALL micro-batches through the forward pass (left to right through stages)
+/// 2. Then push ALL micro-batches through the backward pass (right to left)
+///
+/// This creates a "bubble" where stages are idle during pipeline fill and drain.
+/// With P stages and M micro-batches, the bubble fraction is approximately (P-1)/(P-1+M).
+///
+/// For 4 stages and 4 micro-batches:
+///
+/// Stage 0: F0 F1 F2 F3 __ __ __ __ __ __ B3 B2 B1 B0
+/// Stage 1: __ F0 F1 F2 F3 __ __ __ __ B3 B2 B1 B0 __
+/// Stage 2: __ __ F0 F1 F2 F3 __ __ B3 B2 B1 B0 __ __
+/// Stage 3: __ __ __ F0 F1 F2 F3 B3 B2 B1 B0 __ __ __
+///
+///
+/// The underscores represent idle time (bubble).
+///
+/// Reference: Huang et al., "GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism", 2019.
+/// https://arxiv.org/abs/1811.06965
+///
+public class GPipeSchedule : IPipelineSchedule
+{
+ ///
+ public string Name => "GPipe";
+
+ ///
+ public int VirtualStagesPerRank => 1;
+
+ ///
+ public IReadOnlyList GetSchedule(int stageId, int numStages, int numMicroBatches)
+ {
+ if (numStages <= 0)
+ {
+ throw new ArgumentException("Number of stages must be positive.", nameof(numStages));
+ }
+
+ if (numMicroBatches <= 0)
+ {
+ throw new ArgumentException("Number of micro-batches must be positive.", nameof(numMicroBatches));
+ }
+
+ if (stageId < 0 || stageId >= numStages)
+ {
+ throw new ArgumentOutOfRangeException(nameof(stageId),
+ $"Stage ID must be between 0 and {numStages - 1}.");
+ }
+
+ var ops = new List();
+
+ // All forward passes
+ for (int m = 0; m < numMicroBatches; m++)
+ {
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.Forward,
+ MicroBatchIndex = m,
+ IsWarmup = m < stageId,
+ IsCooldown = false
+ });
+ }
+
+ // All backward passes (in reverse micro-batch order)
+ for (int m = numMicroBatches - 1; m >= 0; m--)
+ {
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.Backward,
+ MicroBatchIndex = m,
+ IsWarmup = false,
+ IsCooldown = m >= numMicroBatches - stageId
+ });
+ }
+
+ return ops;
+ }
+
+ ///
+ public double EstimateBubbleFraction(int numStages, int numMicroBatches)
+ {
+ if (numStages <= 1 || numMicroBatches <= 0)
+ {
+ return 0.0;
+ }
+
+ // GPipe bubble fraction: (P-1) / (P-1+M) where P = stages, M = micro-batches
+ return (double)(numStages - 1) / (numStages - 1 + numMicroBatches);
+ }
+}
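As a rough usage sketch (illustrative only, not part of the diff), the schedule for stage 0 of a 4-stage, 4-micro-batch pipeline reproduces the timeline above, and the bubble estimate follows the (P-1)/(P-1+M) formula:

```csharp
// Illustrative sketch; uses only the types introduced in this diff.
var gpipe = new GPipeSchedule();
var ops = gpipe.GetSchedule(stageId: 0, numStages: 4, numMicroBatches: 4);

foreach (var op in ops)
{
    // Prints Forward for micro-batches 0..3, then Backward for 3..0
    Console.WriteLine($"{op.Type} micro-batch {op.MicroBatchIndex}");
}

// (P-1)/(P-1+M) = 3/7, roughly 43% idle time for this small configuration
double bubble = gpipe.EstimateBubbleFraction(numStages: 4, numMicroBatches: 4);
```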
diff --git a/src/DistributedTraining/Interleaved1F1BSchedule.cs b/src/DistributedTraining/Interleaved1F1BSchedule.cs
new file mode 100644
index 000000000..5d77b64c9
--- /dev/null
+++ b/src/DistributedTraining/Interleaved1F1BSchedule.cs
@@ -0,0 +1,189 @@
+using AiDotNet.Interfaces;
+
+namespace AiDotNet.DistributedTraining;
+
+///
+/// Implements the Interleaved 1F1B pipeline schedule with multiple virtual stages per rank.
+///
+///
+///
+/// Interleaved 1F1B assigns V non-contiguous model chunks ("virtual stages") to each rank.
+/// Rank i holds chunks {i, i+P, i+2P, ...} where P is the number of physical ranks.
+/// This reduces the pipeline bubble by a factor of V compared to standard 1F1B.
+///
+///
+/// When a microbatch is ready for multiple local virtual stages, Interleaved 1F1B
+/// prioritizes the earlier microbatch (depth-first ordering). This is in contrast
+/// to Looped BFS which prioritizes the earlier stage.
+///
+/// For Beginners: Standard 1F1B gives each GPU one big chunk of the model.
+/// Interleaved 1F1B gives each GPU V smaller, evenly-spaced chunks instead.
+///
+/// Example with 4 GPUs, V=2 (8 total chunks):
+/// - GPU 0: chunks 0 and 4
+/// - GPU 1: chunks 1 and 5
+/// - GPU 2: chunks 2 and 6
+/// - GPU 3: chunks 3 and 7
+///
+/// This means each microbatch visits each GPU twice (once for each chunk), creating more
+/// opportunities to interleave work and reduce idle time. The bubble shrinks from
+/// ~(P-1)/(2M+P-1) to ~(P-1)/(2MV+P-1).
+///
+/// Used in production by Megatron-LM v2 and NVIDIA NeMo.
+///
+/// Reference: Narayanan et al., "Efficient Large-Scale Language Model Training
+/// on GPU Clusters Using Megatron-LM", SC 2021. https://arxiv.org/abs/2104.04473
+///
+public class Interleaved1F1BSchedule : IPipelineSchedule
+{
+ private readonly int _virtualStagesPerRank;
+
+ ///
+ /// Creates a new Interleaved 1F1B schedule.
+ ///
+ ///
+ /// Number of model chunks per rank. Default is 2.
+ /// Higher values reduce bubble but increase communication.
+ /// Must be at least 2 (otherwise use standard 1F1B).
+ ///
+ public Interleaved1F1BSchedule(int virtualStagesPerRank = 2)
+ {
+ if (virtualStagesPerRank < 2)
+ {
+ throw new ArgumentOutOfRangeException(nameof(virtualStagesPerRank),
+ "Interleaved schedule requires at least 2 virtual stages per rank. " +
+ "Use OneForwardOneBackwardSchedule for single-stage scheduling.");
+ }
+
+ _virtualStagesPerRank = virtualStagesPerRank;
+ }
+
+ ///
+ public string Name => "Interleaved-1F1B";
+
+ ///
+ public int VirtualStagesPerRank => _virtualStagesPerRank;
+
+ ///
+ public IReadOnlyList GetSchedule(int stageId, int numStages, int numMicroBatches)
+ {
+ if (numStages <= 0)
+ {
+ throw new ArgumentException("Number of stages must be positive.", nameof(numStages));
+ }
+
+ if (numMicroBatches <= 0)
+ {
+ throw new ArgumentException("Number of micro-batches must be positive.", nameof(numMicroBatches));
+ }
+
+ if (stageId < 0 || stageId >= numStages)
+ {
+ throw new ArgumentOutOfRangeException(nameof(stageId),
+ $"Stage ID must be between 0 and {numStages - 1}.");
+ }
+
+ var ops = new List();
+ int totalVirtualStages = numStages * _virtualStagesPerRank;
+
+ // Each rank handles V virtual stages. Virtual stage IDs for rank stageId:
+ // stageId, stageId + numStages, stageId + 2*numStages, ...
+ // In the interleaved schedule, microbatches flow through all virtual stages.
+
+ // Warmup: number of forward passes before steady state begins
+ // For interleaved, warmup is proportional to (totalVirtualStages - rank's first virtual stage - 1)
+ int numWarmupForwards = Math.Min(
+ totalVirtualStages - 1 - stageId,
+ numMicroBatches * _virtualStagesPerRank);
+
+ int totalForwards = numMicroBatches * _virtualStagesPerRank;
+ int totalBackwards = totalForwards;
+ int forwardsDone = 0;
+ int backwardsDone = 0;
+
+ // Phase 1: Warmup - forwards across virtual stages in depth-first order
+ // (prioritize earlier microbatch over earlier virtual stage)
+ for (int i = 0; i < numWarmupForwards && forwardsDone < totalForwards; i++)
+ {
+ // Depth-first: cycle through virtual stages for each microbatch
+ int vStage = forwardsDone % _virtualStagesPerRank;
+ int microBatch = forwardsDone / _virtualStagesPerRank;
+
+ if (microBatch < numMicroBatches)
+ {
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.Forward,
+ MicroBatchIndex = microBatch,
+ VirtualStageIndex = vStage,
+ IsWarmup = true,
+ IsCooldown = false
+ });
+ forwardsDone++;
+ }
+ }
+
+ // Phase 2: Steady state - alternating forward and backward
+ while (forwardsDone < totalForwards || backwardsDone < totalBackwards)
+ {
+ // One forward (if available)
+ if (forwardsDone < totalForwards)
+ {
+ int vStage = forwardsDone % _virtualStagesPerRank;
+ int microBatch = forwardsDone / _virtualStagesPerRank;
+
+ if (microBatch < numMicroBatches)
+ {
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.Forward,
+ MicroBatchIndex = microBatch,
+ VirtualStageIndex = vStage,
+ IsWarmup = false,
+ IsCooldown = false
+ });
+ forwardsDone++;
+ }
+ }
+
+ // One backward (if available)
+ if (backwardsDone < totalBackwards)
+ {
+ int vStage = backwardsDone % _virtualStagesPerRank;
+ int microBatch = backwardsDone / _virtualStagesPerRank;
+
+ if (microBatch < numMicroBatches)
+ {
+ bool isCooldown = forwardsDone >= totalForwards;
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.Backward,
+ MicroBatchIndex = microBatch,
+ VirtualStageIndex = _virtualStagesPerRank - 1 - vStage, // Backward visits in reverse
+ IsWarmup = false,
+ IsCooldown = isCooldown
+ });
+ backwardsDone++;
+ }
+ }
+ }
+
+ return ops;
+ }
+
+ ///
+ public double EstimateBubbleFraction(int numStages, int numMicroBatches)
+ {
+ if (numStages <= 1 || numMicroBatches <= 0)
+ {
+ return 0.0;
+ }
+
+ // Interleaved 1F1B bubble: (P-1) / (2*M*V + P - 1)
+ // V times smaller than standard 1F1B
+ long p = numStages;
+ long m = numMicroBatches;
+ long v = _virtualStagesPerRank;
+ return (double)(p - 1) / (2 * m * v + p - 1);
+ }
+}
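To see what the interleaving buys, the bubble formulas implemented here and in OneForwardOneBackwardSchedule can be compared directly (illustrative sketch, not part of the diff):

```csharp
// Illustrative sketch: 4 stages, 8 micro-batches, V = 2 chunks per rank.
var standard1F1B = new OneForwardOneBackwardSchedule();
var interleaved = new Interleaved1F1BSchedule(virtualStagesPerRank: 2);

double standardBubble = standard1F1B.EstimateBubbleFraction(4, 8);   // (4-1)/(2*8 + 3)   ~= 0.16
double interleavedBubble = interleaved.EstimateBubbleFraction(4, 8); // (4-1)/(2*8*2 + 3) ~= 0.09
```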
diff --git a/src/DistributedTraining/LoadBalancedPartitionStrategy.cs b/src/DistributedTraining/LoadBalancedPartitionStrategy.cs
new file mode 100644
index 000000000..346fe6255
--- /dev/null
+++ b/src/DistributedTraining/LoadBalancedPartitionStrategy.cs
@@ -0,0 +1,313 @@
+using AiDotNet.Interfaces;
+
+namespace AiDotNet.DistributedTraining;
+
+///
+/// Partitions model parameters across pipeline stages using estimated computational cost per layer.
+///
+///
+///
+/// Instead of dividing parameters uniformly, this strategy uses a cost function to estimate
+/// the computational load for each parameter group (layer). It then assigns parameters to stages
+/// so that each stage has roughly equal total cost, reducing pipeline bubble overhead.
+///
+/// For Beginners: Imagine an assembly line where some tasks take much longer than others.
+/// If you assign tasks purely by count, some workers finish early and wait while others are still busy.
+/// This strategy assigns tasks by estimated time, so all workers finish at roughly the same time.
+///
+/// For neural networks, attention layers are much more expensive than simple normalization layers,
+/// so this strategy gives fewer attention layers to each stage to balance the workload.
+///
+/// The cost function estimates FLOPs (floating point operations) for a block of parameters:
+/// - Dense/linear layers: ~2 * inputSize * outputSize FLOPs
+/// - Attention: ~4 * seqLen * d_model FLOPs
+/// - LayerNorm: ~5 * d_model FLOPs
+///
+/// Since we don't have layer-level metadata in the parameter vector, costs are estimated from
+/// parameter counts using the heuristic that FLOPs scale as paramCount^1.5
+/// (for a square dense layer, params = n^2 while FLOPs grow as ~2*n^3).
+///
+/// Reference: Megatron-LM layer assignment algorithm, NVIDIA 2020.
+///
+/// The numeric type for operations.
+public class LoadBalancedPartitionStrategy : IPipelinePartitionStrategy
+{
+ private readonly Func? _costEstimator;
+ private readonly int[] _layerBoundaries;
+ private readonly bool _isAutoDetect;
+
+ ///
+ /// Creates a load-balanced partition strategy with explicit layer boundaries and optional cost estimator.
+ ///
+ ///
+ /// Array of parameter indices where each layer starts, in strictly increasing order.
+ /// All values must be non-negative. For example, if a model has 3 layers
+ /// with 100, 200, and 150 parameters respectively, pass [0, 100, 300].
+ /// The total parameter count is inferred as layerBoundaries[last] + size of last layer.
+ /// For Beginners: This tells the partitioner where each layer's parameters begin
+ /// in the flat parameter vector. You can get these from your model's layer structure.
+ ///
+ ///
+ /// Optional function that estimates the computational cost of a layer given its parameter count.
+ /// If null, cost is estimated as parameterCount^(3/2) which approximates the relationship
+ /// between matrix sizes and FLOP counts for dense layers.
+ /// For Beginners: This function converts "number of parameters" into "how long
+ /// this layer takes to compute." The default assumes dense matrix multiplication.
+ ///
+ /// Thrown when layerBoundaries is null, empty,
+ /// contains negative values, or is not strictly increasing.
+ public LoadBalancedPartitionStrategy(int[] layerBoundaries, Func? costEstimator = null)
+ {
+ if (layerBoundaries is null || layerBoundaries.Length == 0)
+ {
+ throw new ArgumentException("Layer boundaries must be provided and non-empty.", nameof(layerBoundaries));
+ }
+
+ // Validate all boundaries are non-negative and strictly increasing
+ if (layerBoundaries[0] < 0)
+ {
+ throw new ArgumentException(
+ $"Layer boundary at index 0 is negative ({layerBoundaries[0]}). All boundaries must be non-negative.",
+ nameof(layerBoundaries));
+ }
+
+ for (int i = 1; i < layerBoundaries.Length; i++)
+ {
+ if (layerBoundaries[i] < 0)
+ {
+ throw new ArgumentException(
+ $"Layer boundary at index {i} is negative ({layerBoundaries[i]}). All boundaries must be non-negative.",
+ nameof(layerBoundaries));
+ }
+
+ if (layerBoundaries[i] <= layerBoundaries[i - 1])
+ {
+ throw new ArgumentException(
+ $"Layer boundaries must be strictly increasing, but boundary[{i}]={layerBoundaries[i]} " +
+ $"<= boundary[{i - 1}]={layerBoundaries[i - 1]}.",
+ nameof(layerBoundaries));
+ }
+ }
+
+ _layerBoundaries = layerBoundaries;
+ _costEstimator = costEstimator;
+ _isAutoDetect = false;
+ }
+
+ ///
+ /// Creates a load-balanced partition strategy that auto-detects layer boundaries
+ /// using a fixed layer size estimate.
+ ///
+ ///
+ /// Estimated average number of parameters per layer.
+ /// For Beginners: If you know your model has ~1000 parameters per layer,
+ /// pass 1000 here and the partitioner will create synthetic layer boundaries.
+ ///
+ /// Optional cost estimator function.
+ /// Thrown when estimatedLayerSize is not positive.
+ public LoadBalancedPartitionStrategy(int estimatedLayerSize, Func? costEstimator = null)
+ {
+ if (estimatedLayerSize <= 0)
+ {
+ throw new ArgumentException("Estimated layer size must be positive.", nameof(estimatedLayerSize));
+ }
+
+ _layerBoundaries = new[] { estimatedLayerSize };
+ _costEstimator = costEstimator;
+ _isAutoDetect = true;
+ }
+
+ ///
+ public (int StartIndex, int Size)[] ComputePartition(int totalParameters, int numStages)
+ {
+ if (totalParameters <= 0)
+ {
+ throw new ArgumentException("Total parameters must be positive.", nameof(totalParameters));
+ }
+
+ if (numStages <= 0)
+ {
+ throw new ArgumentException("Number of stages must be positive.", nameof(numStages));
+ }
+
+ // Build layer sizes from boundaries
+ var layerSizes = BuildLayerSizes(totalParameters);
+ var layerCosts = ComputeLayerCosts(layerSizes);
+
+ // Use dynamic programming to find the optimal partition that minimizes
+ // the maximum cost across all stages (minimize pipeline bubble)
+ var assignment = OptimalPartition(layerSizes, layerCosts, numStages);
+
+ return assignment;
+ }
+
+ private int[] BuildLayerSizes(int totalParameters)
+ {
+ if (_isAutoDetect)
+ {
+ // Auto-detect mode: use estimated layer size to create synthetic boundaries
+ int estimatedLayerSize = _layerBoundaries[0];
+ int numLayers = Math.Max(1, totalParameters / estimatedLayerSize);
+ var sizes = new int[numLayers];
+ int baseSize = totalParameters / numLayers;
+ int remainder = totalParameters % numLayers;
+
+ for (int i = 0; i < numLayers; i++)
+ {
+ sizes[i] = baseSize + (i < remainder ? 1 : 0);
+ }
+
+ return sizes;
+ }
+
+ // Explicit boundaries mode: compute sizes from consecutive boundary differences
+ if (_layerBoundaries[_layerBoundaries.Length - 1] > totalParameters)
+ {
+ throw new ArgumentException(
+ $"Last layer boundary ({_layerBoundaries[_layerBoundaries.Length - 1]}) exceeds " +
+ $"total parameters ({totalParameters}).",
+ nameof(totalParameters));
+ }
+
+ var layerSizes = new int[_layerBoundaries.Length];
+ for (int i = 0; i < _layerBoundaries.Length; i++)
+ {
+ int start = _layerBoundaries[i];
+ int end = (i + 1 < _layerBoundaries.Length) ? _layerBoundaries[i + 1] : totalParameters;
+ layerSizes[i] = end - start;
+ }
+
+ return layerSizes;
+ }
+
+ private double[] ComputeLayerCosts(int[] layerSizes)
+ {
+ var costs = new double[layerSizes.Length];
+
+ for (int i = 0; i < layerSizes.Length; i++)
+ {
+ // Default heuristic: cost scales as paramCount^1.5
+ // For a square weight matrix of dimension n: params = n^2, FLOPs = 2*n^3 = 2*(params)^1.5.
+ // This is a reasonable approximation for dense/linear layers.
+ costs[i] = _costEstimator is not null
+ ? _costEstimator(layerSizes[i])
+ : Math.Pow(layerSizes[i], 1.5);
+ }
+
+ return costs;
+ }
+
+ ///
+ /// Uses dynamic programming to find the partition of layers into stages
+ /// that minimizes the maximum stage cost (min-max partitioning).
+ ///
+ private (int StartIndex, int Size)[] OptimalPartition(int[] layerSizes, double[] layerCosts, int numStages)
+ {
+ int numLayers = layerSizes.Length;
+
+ if (numStages >= numLayers)
+ {
+ // More stages than layers: assign one layer per stage, remaining stages get empty shards
+ return AssignOneLayerPerStage(layerSizes, numStages);
+ }
+
+ // Prefix sums for parameter sizes and costs
+ var paramPrefix = new long[numLayers + 1];
+ var costPrefix = new double[numLayers + 1];
+
+ for (int i = 0; i < numLayers; i++)
+ {
+ paramPrefix[i + 1] = paramPrefix[i] + layerSizes[i];
+ costPrefix[i + 1] = costPrefix[i] + layerCosts[i];
+ }
+
+ // dp[s][l] = minimum of maximum stage cost when assigning layers 0..l-1 to stages 0..s-1
+ var dp = new double[numStages + 1][];
+ var splitPoint = new int[numStages + 1][];
+
+ for (int s = 0; s <= numStages; s++)
+ {
+ dp[s] = new double[numLayers + 1];
+ splitPoint[s] = new int[numLayers + 1];
+ for (int i = 0; i < dp[s].Length; i++)
+ {
+ dp[s][i] = double.MaxValue;
+ }
+ }
+
+ dp[0][0] = 0.0;
+
+ // Base case: one stage gets all layers up to l
+ for (int l = 1; l <= numLayers; l++)
+ {
+ dp[1][l] = costPrefix[l];
+ splitPoint[1][l] = 0;
+ }
+
+ // Fill DP table
+ for (int s = 2; s <= numStages; s++)
+ {
+ for (int l = s; l <= numLayers; l++)
+ {
+ // Try all possible split points for the last stage
+ for (int k = s - 1; k < l; k++)
+ {
+ double lastStageCost = costPrefix[l] - costPrefix[k];
+ double candidate = Math.Max(dp[s - 1][k], lastStageCost);
+
+ if (candidate < dp[s][l])
+ {
+ dp[s][l] = candidate;
+ splitPoint[s][l] = k;
+ }
+ }
+ }
+ }
+
+ // Backtrack to find optimal partition
+ var stageEndLayers = new int[numStages];
+ int currentLayer = numLayers;
+
+ for (int s = numStages; s >= 1; s--)
+ {
+ stageEndLayers[s - 1] = currentLayer;
+ currentLayer = splitPoint[s][currentLayer];
+ }
+
+ // Convert layer assignments to parameter partitions
+ var partitions = new (int StartIndex, int Size)[numStages];
+ int layerStart = 0;
+
+ for (int s = 0; s < numStages; s++)
+ {
+ int layerEnd = stageEndLayers[s];
+ int paramStart = (int)paramPrefix[layerStart];
+ int paramSize = (int)(paramPrefix[layerEnd] - paramPrefix[layerStart]);
+ partitions[s] = (paramStart, paramSize);
+ layerStart = layerEnd;
+ }
+
+ return partitions;
+ }
+
+ private static (int StartIndex, int Size)[] AssignOneLayerPerStage(int[] layerSizes, int numStages)
+ {
+ var partitions = new (int StartIndex, int Size)[numStages];
+ int currentStart = 0;
+
+ for (int i = 0; i < numStages; i++)
+ {
+ if (i < layerSizes.Length)
+ {
+ partitions[i] = (currentStart, layerSizes[i]);
+ currentStart += layerSizes[i];
+ }
+ else
+ {
+ // Empty stage (more stages than layers)
+ partitions[i] = (currentStart, 0);
+ }
+ }
+
+ return partitions;
+ }
+}
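A short sketch of the explicit-boundaries constructor (illustrative only, not part of the diff). The cost-estimator delegate is assumed here to map a layer's parameter count to a double-valued cost, consistent with how ComputeLayerCosts uses it above:

```csharp
// Illustrative sketch: 3 layers of 100, 200, and 150 parameters split across 2 stages.
var strategy = new LoadBalancedPartitionStrategy<double>(
    layerBoundaries: new[] { 0, 100, 300 },
    costEstimator: paramCount => Math.Pow(paramCount, 1.5));

var partitions = strategy.ComputePartition(totalParameters: 450, numStages: 2);
// partitions[s] is the (StartIndex, Size) parameter range owned by stage s.
// With this power-law cost model the min-max DP assigns layers 0-1 to stage 0 and
// layer 2 to stage 1, i.e. (0, 300) and (300, 150).
```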
diff --git a/src/DistributedTraining/LoopedBFSSchedule.cs b/src/DistributedTraining/LoopedBFSSchedule.cs
new file mode 100644
index 000000000..f87a4daf7
--- /dev/null
+++ b/src/DistributedTraining/LoopedBFSSchedule.cs
@@ -0,0 +1,186 @@
+using AiDotNet.Interfaces;
+
+namespace AiDotNet.DistributedTraining;
+
+///
+/// Implements the Looped BFS (Breadth-First Schedule) pipeline schedule with multiple virtual stages per rank.
+///
+///
+///
+/// Looped BFS, like Interleaved 1F1B, assigns V non-contiguous model chunks ("virtual stages")
+/// to each rank. Rank i holds chunks {i, i+P, i+2P, ...} where P is the number of physical ranks.
+///
+///
+/// The key difference from Interleaved 1F1B is the scheduling priority:
+/// - Interleaved 1F1B (Depth-First): Prioritizes the earlier microbatch. If microbatch 0
+/// is ready for virtual stages 0 and 1, it runs stage 0 for microbatch 0 first.
+/// - Looped BFS (Breadth-First): Prioritizes the earlier virtual stage. If microbatches 0
+/// and 1 are ready for virtual stage 0, it processes them both before moving to stage 1.
+///
+/// For Beginners: Imagine a factory with two assembly stations per worker (V=2).
+/// Depth-first (Interleaved 1F1B) means: finish one product at both stations before starting the next.
+/// Breadth-first (Looped BFS) means: run all products through station 1, then all through station 2.
+///
+/// Looped BFS tends to have slightly higher pipeline utilization in some configurations because
+/// it minimizes the number of times data needs to cross between physical ranks. However, it
+/// may have higher peak memory usage since more microbatches are in flight at each virtual stage.
+///
+/// Example with 4 GPUs, V=2 (8 total chunks):
+/// - GPU 0: chunks 0 and 4
+/// - GPU 1: chunks 1 and 5
+/// - GPU 2: chunks 2 and 6
+/// - GPU 3: chunks 3 and 7
+///
+/// Looped BFS processes ALL microbatches through chunks 0-3 first (loop 1),
+/// then ALL microbatches through chunks 4-7 (loop 2).
+///
+/// Reference: Lamy-Poirier, "Breadth-First Pipeline Parallelism", 2022.
+/// https://arxiv.org/abs/2211.05953
+///
+public class LoopedBFSSchedule : IPipelineSchedule
+{
+ private readonly int _virtualStagesPerRank;
+
+ ///
+ /// Creates a new Looped BFS schedule.
+ ///
+ ///
+ /// Number of model chunks per rank. Default is 2.
+ /// Higher values reduce bubble but increase communication.
+ /// Must be at least 2 (otherwise use standard 1F1B).
+ ///
+ public LoopedBFSSchedule(int virtualStagesPerRank = 2)
+ {
+ if (virtualStagesPerRank < 2)
+ {
+ throw new ArgumentOutOfRangeException(nameof(virtualStagesPerRank),
+ "Looped BFS requires at least 2 virtual stages per rank. " +
+ "Use OneForwardOneBackwardSchedule for single-stage scheduling.");
+ }
+
+ _virtualStagesPerRank = virtualStagesPerRank;
+ }
+
+ ///
+ public string Name => "Looped-BFS";
+
+ ///
+ public int VirtualStagesPerRank => _virtualStagesPerRank;
+
+ ///
+ public IReadOnlyList GetSchedule(int stageId, int numStages, int numMicroBatches)
+ {
+ if (numStages <= 0)
+ {
+ throw new ArgumentException("Number of stages must be positive.", nameof(numStages));
+ }
+
+ if (numMicroBatches <= 0)
+ {
+ throw new ArgumentException("Number of micro-batches must be positive.", nameof(numMicroBatches));
+ }
+
+ if (stageId < 0 || stageId >= numStages)
+ {
+ throw new ArgumentOutOfRangeException(nameof(stageId),
+ $"Stage ID must be between 0 and {numStages - 1}.");
+ }
+
+ var ops = new List();
+
+ // Looped BFS: process all microbatches through each virtual stage loop before moving
+ // to the next virtual stage. Within each loop, use 1F1B-style scheduling.
+ //
+ // Loop structure:
+ // for vStage in 0..V-1:
+ // warmup forwards for this vStage
+ // steady-state 1F1B for this vStage
+ // cooldown backwards for this vStage
+
+ for (int vStage = 0; vStage < _virtualStagesPerRank; vStage++)
+ {
+ // Within each loop, apply 1F1B scheduling for this virtual stage
+ int numWarmupForwards = Math.Min(numStages - 1 - stageId, numMicroBatches);
+ int numSteadyState = Math.Max(0, numMicroBatches - numWarmupForwards);
+
+ // Phase 1: Warmup - forward passes only
+ int forwardIdx = 0;
+ for (int i = 0; i < numWarmupForwards; i++)
+ {
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.Forward,
+ MicroBatchIndex = forwardIdx,
+ VirtualStageIndex = vStage,
+ IsWarmup = true,
+ IsCooldown = false
+ });
+ forwardIdx++;
+ }
+
+ // Phase 2: Steady state - alternating 1F1B
+ int backwardIdx = 0;
+ for (int i = 0; i < numSteadyState; i++)
+ {
+ // Forward
+ if (forwardIdx < numMicroBatches)
+ {
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.Forward,
+ MicroBatchIndex = forwardIdx,
+ VirtualStageIndex = vStage,
+ IsWarmup = false,
+ IsCooldown = false
+ });
+ forwardIdx++;
+ }
+
+ // Backward
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.Backward,
+ MicroBatchIndex = backwardIdx,
+ VirtualStageIndex = vStage,
+ IsWarmup = false,
+ IsCooldown = false
+ });
+ backwardIdx++;
+ }
+
+ // Phase 3: Cooldown - remaining backward passes
+ while (backwardIdx < numMicroBatches)
+ {
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.Backward,
+ MicroBatchIndex = backwardIdx,
+ VirtualStageIndex = vStage,
+ IsWarmup = false,
+ IsCooldown = true
+ });
+ backwardIdx++;
+ }
+ }
+
+ return ops;
+ }
+
+ ///
+ public double EstimateBubbleFraction(int numStages, int numMicroBatches)
+ {
+ if (numStages <= 1 || numMicroBatches <= 0)
+ {
+ return 0.0;
+ }
+
+ // Looped BFS has approximately the same bubble as Interleaved 1F1B
+ // but the communication pattern differs. The bubble is roughly:
+ // (P-1) / (2*M*V + P - 1)
+ // Same asymptotic behavior as Interleaved 1F1B.
+ long p = numStages;
+ long m = numMicroBatches;
+ long v = _virtualStagesPerRank;
+ return (double)(p - 1) / (2 * m * v + p - 1);
+ }
+}
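The depth-first vs breadth-first distinction is easiest to see by generating both schedules for the same configuration (illustrative sketch, not part of the diff):

```csharp
// Illustrative sketch: rank 0 of a 2-rank pipeline, V = 2, 4 micro-batches.
IPipelineSchedule bfs = new LoopedBFSSchedule(virtualStagesPerRank: 2);
IPipelineSchedule dfs = new Interleaved1F1BSchedule(virtualStagesPerRank: 2);

// Looped BFS holds VirtualStageIndex fixed while MicroBatchIndex advances
// (all micro-batches through chunk 0, then all through chunk 1);
// Interleaved 1F1B advances VirtualStageIndex before moving to the next micro-batch.
var bfsOps = bfs.GetSchedule(stageId: 0, numStages: 2, numMicroBatches: 4);
var dfsOps = dfs.GetSchedule(stageId: 0, numStages: 2, numMicroBatches: 4);
```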
diff --git a/src/DistributedTraining/OneForwardOneBackwardSchedule.cs b/src/DistributedTraining/OneForwardOneBackwardSchedule.cs
new file mode 100644
index 000000000..e95b1b555
--- /dev/null
+++ b/src/DistributedTraining/OneForwardOneBackwardSchedule.cs
@@ -0,0 +1,149 @@
+using AiDotNet.Interfaces;
+
+namespace AiDotNet.DistributedTraining;
+
+///
+/// Implements the 1F1B (One-Forward-One-Backward) pipeline schedule.
+///
+///
+///
+/// The 1F1B schedule interleaves forward and backward passes to minimize pipeline bubble
+/// and memory usage. It has three phases:
+///
+/// 1. Warmup: Each stage executes forward passes to fill the pipeline.
+/// Stage i performs (numStages - 1 - i) forward passes before steady state.
+///
+/// 2. Steady State: Each stage alternates between one forward and one backward pass.
+/// This keeps all stages busy and limits memory usage to at most (numStages) activations.
+///
+/// 3. Cooldown: Remaining backward passes drain the pipeline.
+///
+/// For Beginners: Instead of doing ALL forward passes then ALL backward passes (GPipe),
+/// 1F1B interleaves them. This is like a factory where each worker handles their current item
+/// and immediately starts the return processing, rather than waiting for all items to pass through.
+///
+/// Benefits:
+/// - Reduces pipeline bubble from ~50% to ~12-15%
+/// - Limits peak memory to (numStages) stored activations instead of (numMicroBatches)
+/// - More efficient for large numbers of micro-batches
+///
+/// Example with 4 stages and 8 micro-batches:
+///
+/// Stage 0: F0 F1 F2 F3 B0 F4 B1 F5 B2 F6 B3 F7 B4 B5 B6 B7
+/// Stage 1: F0 F1 F2 B0 F3 B1 F4 B2 F5 B3 F6 B4 F7 B5 B6 B7
+/// Stage 2: F0 F1 B0 F2 B1 F3 B2 F4 B3 F5 B4 F6 B5 F7 B6 B7
+/// Stage 3: F0 B0 F1 B1 F2 B2 F3 B3 F4 B4 F5 B5 F6 B6 F7 B7
+///
+///
+/// Reference: Narayanan et al., "PipeDream: Generalized Pipeline Parallelism for DNN Training", SOSP 2019.
+/// https://arxiv.org/abs/1806.03377
+///
+public class OneForwardOneBackwardSchedule : IPipelineSchedule
+{
+ ///
+ public string Name => "1F1B";
+
+ ///
+ public int VirtualStagesPerRank => 1;
+
+ ///
+ public IReadOnlyList GetSchedule(int stageId, int numStages, int numMicroBatches)
+ {
+ if (numStages <= 0)
+ {
+ throw new ArgumentException("Number of stages must be positive.", nameof(numStages));
+ }
+
+ if (numMicroBatches <= 0)
+ {
+ throw new ArgumentException("Number of micro-batches must be positive.", nameof(numMicroBatches));
+ }
+
+ if (stageId < 0 || stageId >= numStages)
+ {
+ throw new ArgumentOutOfRangeException(nameof(stageId),
+ $"Stage ID must be between 0 and {numStages - 1}.");
+ }
+
+ var ops = new List();
+
+ // Number of warmup forward passes for this stage
+ // Earlier stages need more warmup to fill the pipeline
+ int numWarmupForwards = Math.Min(numStages - 1 - stageId, numMicroBatches);
+
+ // Number of steady-state 1F1B pairs
+ int numSteadyState = Math.Max(0, numMicroBatches - numWarmupForwards);
+
+ // Phase 1: Warmup - only forward passes
+ int forwardIdx = 0;
+ for (int i = 0; i < numWarmupForwards; i++)
+ {
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.Forward,
+ MicroBatchIndex = forwardIdx,
+ IsWarmup = true,
+ IsCooldown = false
+ });
+ forwardIdx++;
+ }
+
+ // Phase 2: Steady state - alternating 1F1B
+ int backwardIdx = 0;
+ for (int i = 0; i < numSteadyState; i++)
+ {
+ // One forward
+ if (forwardIdx < numMicroBatches)
+ {
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.Forward,
+ MicroBatchIndex = forwardIdx,
+ IsWarmup = false,
+ IsCooldown = false
+ });
+ forwardIdx++;
+ }
+
+ // One backward
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.Backward,
+ MicroBatchIndex = backwardIdx,
+ IsWarmup = false,
+ IsCooldown = false
+ });
+ backwardIdx++;
+ }
+
+ // Phase 3: Cooldown - only backward passes
+ while (backwardIdx < numMicroBatches)
+ {
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.Backward,
+ MicroBatchIndex = backwardIdx,
+ IsWarmup = false,
+ IsCooldown = true
+ });
+ backwardIdx++;
+ }
+
+ return ops;
+ }
+
+ ///
+ public double EstimateBubbleFraction(int numStages, int numMicroBatches)
+ {
+ if (numStages <= 1 || numMicroBatches <= 0)
+ {
+ return 0.0;
+ }
+
+ // 1F1B bubble fraction: (P-1) / (2*M + P - 1) where P = stages, M = micro-batches
+ // This is approximately half of GPipe's bubble for large M
+ long p = numStages;
+ long m = numMicroBatches;
+ return (double)(p - 1) / (2 * m + p - 1);
+ }
+}
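One property worth noting (illustrative sketch, not part of the diff): the last stage has zero warmup, so its schedule degenerates to a strict F/B alternation, matching the bottom row of the diagram above:

```csharp
// Illustrative sketch: stage 3 of 4 with 8 micro-batches.
var schedule1F1B = new OneForwardOneBackwardSchedule();
var lastStageOps = schedule1F1B.GetSchedule(stageId: 3, numStages: 4, numMicroBatches: 8);
// lastStageOps: Forward(0), Backward(0), Forward(1), Backward(1), ..., Forward(7), Backward(7)
```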
diff --git a/src/DistributedTraining/PipelineParallelModel.cs b/src/DistributedTraining/PipelineParallelModel.cs
index f745f0f78..3d8327bc5 100644
--- a/src/DistributedTraining/PipelineParallelModel.cs
+++ b/src/DistributedTraining/PipelineParallelModel.cs
@@ -10,7 +10,7 @@ namespace AiDotNet.DistributedTraining;
///
///
/// Strategy Overview:
-/// Pipeline Parallelism (GPipe-style) divides the model vertically into stages, with each process
+/// Pipeline Parallelism divides the model vertically into stages, with each process
/// owning specific layers. Input mini-batches are divided into micro-batches that flow through
/// the pipeline stages sequentially. This enables training models too large to fit on a single device
/// while maintaining good hardware utilization through micro-batch pipelining.
@@ -24,47 +24,32 @@ namespace AiDotNet.DistributedTraining;
/// flow through the pipeline like cars on an assembly line. While Process 1 is working on micro-batch 1,
/// Process 0 can start on micro-batch 2.
///
-/// Use Cases:
-/// - Very deep models that don't fit on a single GPU
-/// - When model depth (layers) >> width (parameters per layer)
-/// - Transformer models with many layers
-/// - Complementary to data parallelism (can combine them)
-///
-/// Trade-offs:
-/// - Memory: Excellent for deep models - each rank stores only its layers
-/// - Communication: Low - only activations passed between adjacent stages
-/// - Complexity: High - requires micro-batching, careful scheduling, pipeline bubble overhead
-/// - Best for: Very deep models, limited per-device memory
-/// - Limitation: Pipeline "bubble" (idle time) reduces efficiency, typically ~12-25% for GPipe
-///
-/// Implementation Note:
-/// This implementation provides GPipe-style pipeline parallelism with gradient-based backward pass.
-/// The forward pass sends activations between adjacent stages, and the backward pass communicates
-/// gradients in the reverse direction. Gradients are accumulated across stages and applied to
-/// parameters after the backward pass completes.
-///
-/// Gradient Approximation: Since IFullModel.Train() combines gradient computation and parameter
-/// updates into a single operation, gradients are approximated as parameter differences
-/// (params_before - params_after). This captures the complete parameter update including learning
-/// rate and optimizer state. For access to raw gradients before optimizer application, extend
-/// this class or use an optimizer that exposes gradients via IGradientBasedOptimizer.
-///
-/// For production use with specific models, consider:
-/// 1. Model-specific layer partitioning strategies (e.g., balance compute load across stages)
-/// 2. Micro-batch scheduling to reduce pipeline bubbles
-/// 3. Activation checkpointing to reduce memory usage
-///
-///
-/// Example:
-///
-/// var model = new DeepNeuralNetwork<double>(...); // 100 layers
-/// var backend = new InMemoryCommunicationBackend<double>(rank: 0, worldSize: 4);
-/// var config = new ShardingConfiguration<double>(backend);
-///
-/// // Rank 0: layers 0-24, Rank 1: layers 25-49, Rank 2: layers 50-74, Rank 3: layers 75-99
-/// var pipelineModel = new PipelineParallelModel<double, Tensor<double>, Tensor<double>>(
-/// model, config, microBatchSize: 4);
-///
+/// Supported Features (Issue #463):
+///
+/// -
+/// 7 Pipeline Schedules: GPipe, 1F1B, ZB-H1, ZB-H2, ZB-V, Interleaved 1F1B, Looped BFS.
+/// Zero Bubble schedules decompose backward into BackwardInput + BackwardWeight for optimal throughput.
+///
+/// -
+/// Virtual Stages: Multi-stage schedules (Interleaved 1F1B, Looped BFS, ZB-V) assign
+/// multiple non-contiguous model chunks per rank, reducing pipeline bubble by factor V.
+///
+/// -
+/// Micro-Batch Slicing: Input is automatically sliced into micro-batches that flow
+/// through the pipeline independently.
+///
+/// -
+/// Backward Decomposition: If the wrapped model implements IPipelineDecomposableModel,
+/// BackwardInput and BackwardWeight are truly decomposed. Otherwise, a compatible emulation is used.
+///
+/// -
+/// Activation Checkpointing: Trade compute for memory by recomputing activations from
+/// checkpoints during the backward pass.
+///
+/// -
+/// Load-Balanced Partitioning: Balance compute across stages via dynamic programming.
+///
+///
///
///
/// The numeric type
@@ -73,19 +58,109 @@ namespace AiDotNet.DistributedTraining;
public class PipelineParallelModel : ShardedModelBase
{
private readonly int _microBatchSize;
+ private readonly IPipelinePartitionStrategy? _partitionStrategy;
+ private readonly IPipelineSchedule _schedule;
+ private readonly ActivationCheckpointConfig _checkpointConfig;
private int _stageId;
private int _numStages;
+ private int _virtualStagesPerRank;
+
+ // Total virtual stages across all ranks
+ private int _totalVirtualStages;
+
+ // Parameter ranges for each virtual stage this rank owns.
+ // For single-stage schedules (V=1): one entry mapping to the full shard.
+ // For multi-stage schedules (V>1): V entries for non-contiguous model chunks.
+ // Key = local virtual stage index (0..V-1), Value = (StartIndex, Size) in full param vector.
+ private readonly Dictionary _virtualStagePartitions = new();
+
+ // Activation storage for checkpointing.
+ // Key format: (microBatchIndex * _virtualStagesPerRank + virtualStageIndex) for uniqueness.
+ private readonly Dictionary> _checkpointedActivations = new();
+
+ // Cached state from BackwardInput for later use by BackwardWeight (Zero Bubble B/W decomposition).
+ // Key format: (microBatchIndex * _virtualStagesPerRank + virtualStageIndex).
+ private readonly Dictionary _cachedBackwardState = new();
+
+ // Cached weight gradients from BackwardInput for fallback accumulation when model
+ // does not support IPipelineDecomposableModel (emulated B/W split).
+ private readonly Dictionary> _cachedWeightGradients = new();
+
+ // Whether the wrapped model supports true B/W decomposition
+ private bool _supportsDecomposedBackward;
+
+ // Communication tag ranges to prevent collisions between forward activations,
+ // backward gradients, and predict-time messages.
+ private const int ActivationTagBase = 0;
+ private const int GradientTagBase = 1_000_000;
+ private const int PredictTagBase = 2_000_000;
+
+ ///
+ /// Gets the pipeline schedule used by this model.
+ ///
+ ///
+ /// This property is internal. Configure the schedule via AiModelBuilder methods
+ /// (e.g., ConfigurePipelineParallelism) rather than accessing this directly.
+ ///
+ internal IPipelineSchedule Schedule => _schedule;
+
+ ///
+ /// Gets the activation checkpoint configuration.
+ ///
+ ///
+ /// This property is internal. Configure checkpointing via AiModelBuilder methods
+ /// rather than accessing this directly.
+ ///
+ internal ActivationCheckpointConfig CheckpointConfig => _checkpointConfig;
+
+ ///
+ /// Gets the partition strategy, or null if using uniform partitioning.
+ ///
+ ///
+ /// This property is internal. Configure the partition strategy via AiModelBuilder methods
+ /// rather than accessing this directly.
+ ///
+ internal IPipelinePartitionStrategy? PartitionStrategy => _partitionStrategy;
+
+ ///
+ /// Gets the estimated pipeline bubble fraction for the current configuration.
+ ///
+ public double EstimatedBubbleFraction
+ {
+ get
+ {
+ if (_numStages <= 0)
+ {
+ throw new InvalidOperationException(
+ "EstimatedBubbleFraction cannot be computed before sharding is initialized.");
+ }
+
+ return _schedule.EstimateBubbleFraction(_numStages, _microBatchSize);
+ }
+ }
///
/// Creates a new Pipeline Parallel model.
///
- /// The model to split into pipeline stages
- /// Configuration for sharding and communication
- /// Size of micro-batches for pipeline execution (default: 1)
+ /// The model to split into pipeline stages.
+ /// Configuration for sharding and communication.
+ /// Number of micro-batches to split the input into (default: 1).
+ ///
+ /// Strategy for partitioning parameters across stages. If null, uses uniform partitioning.
+ ///
+ ///
+ /// Pipeline execution schedule. If null, uses GPipeSchedule.
+ ///
+ ///
+ /// Activation checkpointing configuration. If null, checkpointing is disabled.
+ ///
public PipelineParallelModel(
IFullModel wrappedModel,
IShardingConfiguration config,
- int microBatchSize = 1)
+ int microBatchSize = 1,
+ IPipelinePartitionStrategy? partitionStrategy = null,
+ IPipelineSchedule? schedule = null,
+ ActivationCheckpointConfig? checkpointConfig = null)
: base(wrappedModel, config)
{
if (microBatchSize < 1)
@@ -95,7 +170,21 @@ public PipelineParallelModel(
}
_microBatchSize = microBatchSize;
- // Note: _stageId and _numStages are set in OnBeforeInitializeSharding which is called by lazy initialization
+ _partitionStrategy = partitionStrategy;
+ _schedule = schedule ?? new GPipeSchedule();
+ _checkpointConfig = checkpointConfig ?? new ActivationCheckpointConfig();
+
+ // Activation checkpointing recomputation strategies (Selective, Full) require
+ // layer-level forward pass decomposition that is not yet implemented.
+ // Only interval-based checkpoint storage is currently functional.
+ if (_checkpointConfig.Enabled &&
+ _checkpointConfig.RecomputeStrategy != RecomputeStrategy.None)
+ {
+ throw new NotImplementedException(
+ $"Activation checkpointing with RecomputeStrategy.{_checkpointConfig.RecomputeStrategy} " +
+ "is not yet implemented. Use RecomputeStrategy.None to enable checkpoint storage " +
+ "without recomputation, or disable checkpointing entirely.");
+ }
}
///
@@ -105,28 +194,162 @@ protected override void OnBeforeInitializeSharding()
{
_stageId = Config.CommunicationBackend.Rank;
_numStages = Config.CommunicationBackend.WorldSize;
+ _virtualStagesPerRank = _schedule.VirtualStagesPerRank;
+ _totalVirtualStages = _numStages * _virtualStagesPerRank;
+ _supportsDecomposedBackward = WrappedModel is IPipelineDecomposableModel;
}
///
- /// Initializes pipeline parallelism by partitioning parameters into stages.
+ /// Initializes pipeline parallelism by partitioning parameters into stages,
+ /// including virtual stage partitions for multi-stage schedules.
///
protected override void InitializeSharding()
{
var fullParameters = WrappedModel.GetParameters();
int totalParams = fullParameters.Length;
- // Divide parameters into pipeline stages
- // Each stage owns a contiguous chunk of parameters (representing layers)
- int baseShardSize = totalParams / _numStages;
- int remainder = totalParams % _numStages;
+ _virtualStagePartitions.Clear();
+
+ if (_virtualStagesPerRank > 1)
+ {
+ // Multi-stage schedule: partition into totalVirtualStages chunks,
+ // then assign V non-contiguous chunks to this rank.
+ // Rank i gets virtual stages: i, i+P, i+2P, ...
+ (int StartIndex, int Size)[] vsPartitions;
+
+ if (_partitionStrategy is not null)
+ {
+ // Use the configured partition strategy for load-balanced partitioning
+ // across all virtual stages (not just physical stages)
+ vsPartitions = _partitionStrategy.ComputePartition(totalParams, _totalVirtualStages);
+
+ if (vsPartitions is null || vsPartitions.Length != _totalVirtualStages)
+ {
+ throw new InvalidOperationException(
+ $"Partition strategy returned {(vsPartitions is null ? "null" : $"{vsPartitions.Length} partitions")} " +
+ $"but expected exactly {_totalVirtualStages} partitions for {_virtualStagesPerRank} virtual stages per rank.");
+ }
+
+ // Validate bounds for all virtual stage partitions
+ for (int vs = 0; vs < _totalVirtualStages; vs++)
+ {
+ var (start, size) = vsPartitions[vs];
+ if (start < 0 || size < 0 || start + size > totalParams)
+ {
+ throw new InvalidOperationException(
+ $"Partition strategy returned invalid partition for virtual stage {vs}: " +
+ $"StartIndex={start}, Size={size}, but total parameters is {totalParams}.");
+ }
+ }
+ }
+ else
+ {
+ // Uniform partitioning
+ vsPartitions = new (int StartIndex, int Size)[_totalVirtualStages];
+ int baseChunkSize = totalParams / _totalVirtualStages;
+ int remainder = totalParams % _totalVirtualStages;
+ int offset = 0;
+ for (int vs = 0; vs < _totalVirtualStages; vs++)
+ {
+ int size = baseChunkSize + (vs < remainder ? 1 : 0);
+ vsPartitions[vs] = (offset, size);
+ offset += size;
+ }
+ }
+
+ // Assign this rank's virtual stages
+ int totalShardSize = 0;
+ for (int v = 0; v < _virtualStagesPerRank; v++)
+ {
+ int globalVirtualStageId = _stageId + v * _numStages;
+ if (globalVirtualStageId < _totalVirtualStages)
+ {
+ var partition = vsPartitions[globalVirtualStageId];
+ _virtualStagePartitions[v] = partition;
+ totalShardSize += partition.Size;
+ }
+ }
+
+ // The shard for base class is the union of all virtual stage parameters.
+ // Use the first virtual stage's start as the shard start.
+ if (_virtualStagePartitions.Count > 0)
+ {
+ ShardStartIndex = _virtualStagePartitions[0].StartIndex;
+ ShardSize = totalShardSize;
+ }
+ else
+ {
+ ShardStartIndex = 0;
+ ShardSize = 0;
+ }
+ }
+ else
+ {
+ // Single-stage schedule: standard partitioning
+ if (_partitionStrategy is not null)
+ {
+ var partitions = _partitionStrategy.ComputePartition(totalParams, _numStages);
+
+ if (partitions is null || partitions.Length != _numStages)
+ {
+ throw new InvalidOperationException(
+ $"Partition strategy returned {(partitions is null ? "null" : $"{partitions.Length} partitions")} " +
+ $"but expected exactly {_numStages} partitions.");
+ }
+
+ var stagePartition = partitions[_stageId];
+ if (stagePartition.StartIndex < 0 || stagePartition.Size < 0 ||
+ stagePartition.StartIndex + stagePartition.Size > totalParams)
+ {
+ throw new InvalidOperationException(
+ $"Partition strategy returned invalid partition for stage {_stageId}: " +
+ $"StartIndex={stagePartition.StartIndex}, Size={stagePartition.Size}, " +
+ $"but total parameters is {totalParams}.");
+ }
- ShardSize = baseShardSize + (_stageId < remainder ? 1 : 0);
- ShardStartIndex = _stageId * baseShardSize + Math.Min(_stageId, remainder);
+ ShardStartIndex = stagePartition.StartIndex;
+ ShardSize = stagePartition.Size;
+ }
+ else
+ {
+ int baseShardSize = totalParams / _numStages;
+ int leftover = totalParams % _numStages;
- // Extract this stage's parameters
- var shardData = new T[ShardSize];
- Array.Copy(fullParameters.ToArray(), ShardStartIndex, shardData, 0, ShardSize);
- LocalShard = new Vector(shardData);
+ ShardSize = baseShardSize + (_stageId < leftover ? 1 : 0);
+ ShardStartIndex = _stageId * baseShardSize + Math.Min(_stageId, leftover);
+ }
+
+ _virtualStagePartitions[0] = (ShardStartIndex, ShardSize);
+ }
+
+ // Extract this stage's parameters (union of all virtual stage params)
+ if (ShardSize > 0)
+ {
+ var shardData = new T[ShardSize];
+ if (_virtualStagesPerRank > 1)
+ {
+ // For multi-stage: gather non-contiguous chunks
+ int destOffset = 0;
+ var paramArray = fullParameters.ToArray();
+ for (int v = 0; v < _virtualStagesPerRank; v++)
+ {
+ if (_virtualStagePartitions.TryGetValue(v, out var partition))
+ {
+ Array.Copy(paramArray, partition.StartIndex, shardData, destOffset, partition.Size);
+ destOffset += partition.Size;
+ }
+ }
+ }
+ else
+ {
+ Array.Copy(fullParameters.ToArray(), ShardStartIndex, shardData, 0, ShardSize);
+ }
+ LocalShard = new Vector(shardData);
+ }
+ else
+ {
+ LocalShard = new Vector(0);
+ }
CachedFullParameters = null;
}
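The chunk-to-rank mapping used above is simple modular striping; a small sketch (illustrative only, not part of the diff) for P = 4 ranks and V = 2 virtual stages per rank:

```csharp
// Illustrative sketch: rank 1 owns global chunks 1 and 5.
int stageId = 1, numStages = 4, virtualStagesPerRank = 2;
for (int v = 0; v < virtualStagesPerRank; v++)
{
    int globalVirtualStageId = stageId + v * numStages; // v=0 -> 1, v=1 -> 5
    Console.WriteLine($"rank {stageId}: local virtual stage {v} -> global chunk {globalVirtualStageId}");
}
```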
@@ -134,8 +357,26 @@ protected override void InitializeSharding()
///
public override void Train(TInput input, TOutput expectedOutput)
{
- // GPipe-style pipeline parallel training with gradient-based backward pass
- // Strategy: Forward pass sends activations, backward pass sends gradients
+ // Pipeline parallel training using the configured schedule
+ var scheduleOps = _schedule.GetSchedule(_stageId, _numStages, _microBatchSize);
+
+ // Validate schedule output: externally injectable schedules may emit invalid indices
+ foreach (var op in scheduleOps)
+ {
+ if (op.MicroBatchIndex < 0 || op.MicroBatchIndex >= _microBatchSize)
+ {
+ throw new InvalidOperationException(
+ $"Schedule '{_schedule.Name}' emitted MicroBatchIndex={op.MicroBatchIndex} " +
+ $"but valid range is [0, {_microBatchSize - 1}].");
+ }
+
+ if (op.VirtualStageIndex < 0 || op.VirtualStageIndex >= _virtualStagesPerRank)
+ {
+ throw new InvalidOperationException(
+ $"Schedule '{_schedule.Name}' emitted VirtualStageIndex={op.VirtualStageIndex} " +
+ $"but valid range is [0, {_virtualStagesPerRank - 1}].");
+ }
+ }
// Gather full parameters before training
var fullParams = GatherFullParameters();
@@ -144,131 +385,607 @@ public override void Train(TInput input, TOutput expectedOutput)
// Save parameters BEFORE training to compute gradients
var parametersBefore = new Vector(fullParams.ToArray());
- // Determine actual input for this stage
- TInput stageInput = input;
+ // Accumulated weight gradients across all micro-batches
+ Vector? accumulatedGradients = null;
- // FORWARD PASS: Receive activations from previous stage
- if (_stageId > 0)
+ // Slice input and targets into micro-batches
+ var microBatches = SliceInputIntoMicroBatches(input);
+ var microBatchTargets = SliceTargetIntoMicroBatches(expectedOutput);
+
+ // Track activations per (microBatch, virtualStage) for backward pass
+ var forwardInputs = new Dictionary();
+ var forwardOutputs = new Dictionary();
+
+ // Clear state from previous iteration
+ _checkpointedActivations.Clear();
+ _cachedBackwardState.Clear();
+ _cachedWeightGradients.Clear();
+
+ foreach (var op in scheduleOps)
{
- // Protocol: First receive 1-element size header, then receive activations
- // This prevents size mismatches when stage output size differs from input size
- Vector sizeHeader = Config.CommunicationBackend.Receive(_stageId - 1, count: 1, tag: 0);
- int activationSize = NumOps.ToInt32(sizeHeader[0]);
+ int opKey = GetOperationKey(op.MicroBatchIndex, op.VirtualStageIndex);
- Vector receivedActivations = Config.CommunicationBackend.Receive(_stageId - 1, activationSize, tag: 0);
+ if (op.Type == PipelineOperationType.Forward)
+ {
+ ExecuteForward(op, microBatches, forwardInputs, forwardOutputs, opKey);
+ }
+ else if (op.Type == PipelineOperationType.Backward)
+ {
+ // Combined backward: compute all gradients and communicate in one step.
+ // Used by traditional schedules (GPipe, 1F1B).
+ var microBatchInput = RetrieveMicroBatchInput(opKey, forwardInputs, microBatches, op);
+ var microBatchTarget = GetMicroBatchTarget(op.MicroBatchIndex, microBatchTargets, expectedOutput);
- // For intermediate stages, convert received activations to TInput type WITHOUT using
- // the original input as reference (which would have the wrong shape for non-first stages).
- // Use ConversionsHelper to centralize conversion logic and avoid code duplication.
- stageInput = ConversionsHelper.ConvertVectorToInputWithoutReference(receivedActivations);
+ var gradientVector = WrappedModel.ComputeGradients(microBatchInput, microBatchTarget);
+
+ ReceiveAndAccumulateDownstreamGradients(gradientVector, op.MicroBatchIndex, op.VirtualStageIndex);
+ SendGradientsUpstream(gradientVector, op.MicroBatchIndex, op.VirtualStageIndex);
+ accumulatedGradients = AccumulateGradients(accumulatedGradients, gradientVector);
+
+ FreeNonCheckpointedActivations(opKey, forwardInputs, forwardOutputs);
+ }
+ else if (op.Type == PipelineOperationType.BackwardInput)
+ {
+ // Zero Bubble B step: compute activation gradients (critical path).
+ // Upstream stage is waiting for these gradients.
+ var microBatchInput = RetrieveMicroBatchInput(opKey, forwardInputs, microBatches, op);
+ var microBatchTarget = GetMicroBatchTarget(op.MicroBatchIndex, microBatchTargets, expectedOutput);
+
+ if (_supportsDecomposedBackward)
+ {
+ // True decomposition: compute only activation gradients
+ var decomposable = (IPipelineDecomposableModel)WrappedModel;
+ var (activationGrads, cachedState) = decomposable.ComputeActivationGradients(
+ microBatchInput, microBatchTarget);
+
+ ReceiveAndAccumulateDownstreamGradients(activationGrads, op.MicroBatchIndex, op.VirtualStageIndex);
+ SendGradientsUpstream(activationGrads, op.MicroBatchIndex, op.VirtualStageIndex);
+
+ // Cache state for BackwardWeight to avoid redundant computation
+ _cachedBackwardState[opKey] = cachedState;
+ }
+ else
+ {
+ // Emulated decomposition: compute full gradients now, send activation grads upstream,
+ // cache weight gradients for BackwardWeight step to accumulate later.
+ var fullGradients = WrappedModel.ComputeGradients(microBatchInput, microBatchTarget);
+
+ ReceiveAndAccumulateDownstreamGradients(fullGradients, op.MicroBatchIndex, op.VirtualStageIndex);
+ SendGradientsUpstream(fullGradients, op.MicroBatchIndex, op.VirtualStageIndex);
+
+ // Cache the weight gradients for the W step
+ _cachedWeightGradients[opKey] = fullGradients;
+ }
+ }
+ else if (op.Type == PipelineOperationType.BackwardWeight)
+ {
+ // Zero Bubble W step: compute weight gradients (fills bubbles).
+ // No other stage depends on this - can be deferred.
+ Vector weightGradients;
+
+ if (_supportsDecomposedBackward)
+ {
+ // True decomposition: compute only weight gradients
+ var decomposable = (IPipelineDecomposableModel)WrappedModel;
+ var microBatchInput = RetrieveMicroBatchInput(opKey, forwardInputs, microBatches, op);
+ var microBatchTarget = GetMicroBatchTarget(op.MicroBatchIndex, microBatchTargets, expectedOutput);
+
+ _cachedBackwardState.TryGetValue(opKey, out var cachedState);
+ weightGradients = decomposable.ComputeWeightGradients(
+ microBatchInput, microBatchTarget, cachedState);
+ _cachedBackwardState.Remove(opKey);
+ }
+ else
+ {
+ // Emulated: use cached gradients from BackwardInput step
+ if (_cachedWeightGradients.TryGetValue(opKey, out var cached))
+ {
+ weightGradients = cached;
+ _cachedWeightGradients.Remove(opKey);
+ }
+ else
+ {
+ // Fallback: recompute full gradients
+ var microBatchInput = RetrieveMicroBatchInput(opKey, forwardInputs, microBatches, op);
+ var microBatchTarget = GetMicroBatchTarget(op.MicroBatchIndex, microBatchTargets, expectedOutput);
+ weightGradients = WrappedModel.ComputeGradients(microBatchInput, microBatchTarget);
+ }
+ }
+
+ accumulatedGradients = AccumulateGradients(accumulatedGradients, weightGradients);
+ FreeNonCheckpointedActivations(opKey, forwardInputs, forwardOutputs);
+ }
+ }
+
+ // Apply accumulated gradients averaged across micro-batches
+ if (accumulatedGradients is not null)
+ {
+ T microBatchCount = NumOps.FromDouble(_microBatchSize);
+ for (int i = 0; i < accumulatedGradients.Length; i++)
+ {
+ accumulatedGradients[i] = NumOps.Divide(accumulatedGradients[i], microBatchCount);
+ }
+
+ WrappedModel.SetParameters(parametersBefore);
+ WrappedModel.ApplyGradients(accumulatedGradients, Config.LearningRate);
+ }
+
+ // Extract this stage's parameter shard
+ var updatedParams = WrappedModel.GetParameters();
+ UpdateLocalShardFromFull(updatedParams);
+ InvalidateCache();
+
+ // Clean up all activation/gradient storage
+ _checkpointedActivations.Clear();
+ _cachedBackwardState.Clear();
+ _cachedWeightGradients.Clear();
+
+ // Synchronize parameters across stages for consistency
+ if (Config.AutoSyncGradients)
+ {
+ SynchronizeGradients();
}
+ }
- // Compute true gradients using the model's gradient computation
- // This provides accurate gradients before optimizer updates are applied
- var gradientVector = WrappedModel.ComputeGradients(stageInput, expectedOutput);
+ ///
+ /// Executes a forward operation, handling virtual stage routing and activation checkpointing.
+ ///
+ private void ExecuteForward(
+ PipelineOperation op,
+ Dictionary microBatches,
+ Dictionary forwardInputs,
+ Dictionary forwardOutputs,
+ int opKey)
+ {
+ var stageInput = GetStageInput(microBatches, op.MicroBatchIndex, op.VirtualStageIndex, forwardOutputs);
+
+ // Checkpoint activation if configured
+ if (ShouldCheckpointActivation(opKey))
+ {
+ var inputVector = ConversionsHelper.ConvertToVector(stageInput);
+ _checkpointedActivations[opKey] = inputVector;
+ }
- // Predict stage output for forward pass communication
+ forwardInputs[opKey] = stageInput;
+
+ // Forward pass through the model
var stageOutput = WrappedModel.Predict(stageInput);
+ forwardOutputs[opKey] = stageOutput;
- // FORWARD PASS: Send activations to next stage
- if (_stageId < _numStages - 1)
+ // Send activations to the next stage in the pipeline
+ SendActivationsForward(stageOutput, op.MicroBatchIndex, op.VirtualStageIndex);
+ }
+
+ ///
+ /// Slices input into micro-batches by converting to a vector and dividing evenly.
+ /// With a micro-batch size of 1 the full input becomes the single micro-batch; inputs that cannot be converted to a vector cause an InvalidOperationException.
+ ///
+ private Dictionary SliceInputIntoMicroBatches(TInput fullData)
+ {
+ var slices = new Dictionary();
+
+ if (_microBatchSize <= 1)
+ {
+ slices[0] = fullData;
+ return slices;
+ }
+
+ // Convert to vector for slicing
+ Vector fullVector;
+ try
+ {
+ fullVector = ConversionsHelper.ConvertToVector(fullData);
+ }
+ catch (InvalidOperationException)
+ {
+ throw new InvalidOperationException(
+ $"Cannot slice input of type {typeof(TInput).Name} into micro-batches. " +
+ "The input must be convertible to a vector for pipeline parallel training with micro-batches > 1.");
+ }
+
+ int totalElements = fullVector.Length;
+ int microBatchElements = totalElements / _microBatchSize;
+
+ if (microBatchElements <= 0)
+ {
+ throw new InvalidOperationException(
+ $"Cannot slice {totalElements} elements into {_microBatchSize} micro-batches. " +
+ $"Reduce pipelineMicroBatchSize to at most {totalElements}.");
+ }
+
+ var fullArray = fullVector.ToArray();
+ for (int i = 0; i < _microBatchSize; i++)
+ {
+ int startIdx = i * microBatchElements;
+ int size = (i == _microBatchSize - 1)
+ ? totalElements - startIdx // Last slice gets remainder
+ : microBatchElements;
+
+ var sliceData = new T[size];
+ Array.Copy(fullArray, startIdx, sliceData, 0, size);
+ var sliceVector = new Vector(sliceData);
+
+ slices[i] = ConversionsHelper.ConvertVectorToInputWithoutReference(sliceVector);
+ }
+
+ return slices;
+ }
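The remainder handling above is easiest to see with concrete numbers. A minimal standalone sketch (plain integers rather than the library's vector type; variable names are illustrative):

// Sketch of the slicing arithmetic: the last micro-batch absorbs the remainder.
int totalElements = 10, microBatchCount = 3;
int microBatchElements = totalElements / microBatchCount;   // 3
for (int i = 0; i < microBatchCount; i++)
{
    int startIdx = i * microBatchElements;
    int size = (i == microBatchCount - 1) ? totalElements - startIdx : microBatchElements;
    System.Console.WriteLine($"micro-batch {i}: start {startIdx}, size {size}");
}
// Prints (0,3), (3,3), (6,4).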
+
+ ///
+ /// Slices target output into micro-batches by converting to a vector and dividing evenly.
+ /// With a micro-batch size of 1 the full target becomes the single micro-batch; targets that cannot be converted to a vector cause an InvalidOperationException.
+ ///
+ private Dictionary SliceTargetIntoMicroBatches(TOutput fullTarget)
+ {
+ var slices = new Dictionary();
+
+ if (_microBatchSize <= 1)
+ {
+ slices[0] = fullTarget;
+ return slices;
+ }
+
+ Vector fullVector;
+ try
+ {
+ fullVector = ConversionsHelper.ConvertToVector(fullTarget);
+ }
+ catch (InvalidOperationException)
+ {
+ throw new InvalidOperationException(
+ $"Cannot slice target of type {typeof(TOutput).Name} into micro-batches. " +
+ "The target must be convertible to a vector for pipeline parallel training with micro-batches > 1.");
+ }
+
+ int totalElements = fullVector.Length;
+ int microBatchElements = totalElements / _microBatchSize;
+
+ if (microBatchElements <= 0)
+ {
+ throw new InvalidOperationException(
+ $"Cannot slice {totalElements} target elements into {_microBatchSize} micro-batches. " +
+ $"Reduce pipelineMicroBatchSize to at most {totalElements}.");
+ }
+
+ var fullArray = fullVector.ToArray();
+ for (int i = 0; i < _microBatchSize; i++)
+ {
+ int startIdx = i * microBatchElements;
+ int size = (i == _microBatchSize - 1)
+ ? totalElements - startIdx
+ : microBatchElements;
+
+ var sliceData = new T[size];
+ Array.Copy(fullArray, startIdx, sliceData, 0, size);
+ var sliceVector = new Vector(sliceData);
+
+ // Convert back via input conversion (TOutput and TInput use the same underlying mechanism)
+ slices[i] = ConversionsHelper.ConvertVectorToInputWithoutReference(sliceVector);
+ }
+
+ return slices;
+ }
+
+ ///
+ /// Gets a unique key for a (microBatchIndex, virtualStageIndex) combination.
+ ///
+ private int GetOperationKey(int microBatchIndex, int virtualStageIndex)
+ {
+ return microBatchIndex * _virtualStagesPerRank + virtualStageIndex;
+ }
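A quick standalone illustration of this key encoding; the virtual-stage count of 2 is an assumed value (matching ZB-V), whereas the real code reads _virtualStagesPerRank:

int virtualStagesPerRank = 2;
int Key(int microBatch, int virtualStage) => microBatch * virtualStagesPerRank + virtualStage;

System.Console.WriteLine(Key(0, 0)); // 0
System.Console.WriteLine(Key(0, 1)); // 1
System.Console.WriteLine(Key(3, 0)); // 6 -- the first key of micro-batch 3, i.e. the microBatchStartKey used when searching for checkpoints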
+
+ ///
+ /// Gets the input for this stage, receiving from previous stage if needed.
+ /// For multi-stage schedules, routes based on virtual stage index.
+ ///
+ private TInput GetStageInput(
+ Dictionary microBatches, int microBatchIndex, int virtualStageIndex,
+ Dictionary? forwardOutputs = null)
+ {
+ // For virtual stage 0 of this rank, receive from the previous rank's last virtual stage
+ // For subsequent virtual stages, use the forward output from this rank's previous virtual stage
+ bool isFirstVirtualStageOnRank = virtualStageIndex == 0;
+
+ if (isFirstVirtualStageOnRank && _stageId > 0)
+ {
+ // Receive from previous rank (its last virtual stage's output)
+ int tag = ComputeForwardTag(microBatchIndex, virtualStageIndex);
+ Vector sizeHeader = Config.CommunicationBackend.Receive(
+ _stageId - 1, count: 1, tag: tag);
+ int activationSize = NumOps.ToInt32(sizeHeader[0]);
+
+ Vector receivedActivations = Config.CommunicationBackend.Receive(
+ _stageId - 1, activationSize, tag: tag);
+
+ return ConversionsHelper.ConvertVectorToInputWithoutReference(receivedActivations);
+ }
+
+ if (isFirstVirtualStageOnRank && microBatches.TryGetValue(microBatchIndex, out var microBatch))
+ {
+ // First stage, first virtual stage: use the micro-batch input directly
+ return microBatch;
+ }
+
+ // For non-first virtual stages on this rank: use the forward output from the
+ // previous virtual stage on the same micro-batch.
+ if (!isFirstVirtualStageOnRank && forwardOutputs is not null)
+ {
+ int prevVStageKey = GetOperationKey(microBatchIndex, virtualStageIndex - 1);
+ if (forwardOutputs.TryGetValue(prevVStageKey, out var prevOutput))
+ {
+ // Convert the previous virtual stage's output to an input for the next stage
+ var outputVector = ConversionsHelper.ConvertToVector(prevOutput);
+ return ConversionsHelper.ConvertVectorToInputWithoutReference(outputVector);
+ }
+ }
+
+ // Should not reach here in normal operation
+ throw new InvalidOperationException(
+ $"No input available for micro-batch {microBatchIndex}, virtual stage {virtualStageIndex}. " +
+ (isFirstVirtualStageOnRank
+ ? "Expected micro-batch input was not found."
+ : $"Forward output from virtual stage {virtualStageIndex - 1} was not found. " +
+ "Ensure the schedule processes virtual stages in order."));
+ }
+
+ ///
+ /// Gets the target for a specific micro-batch.
+ ///
+ private TOutput GetMicroBatchTarget(int microBatchIndex, Dictionary microBatchTargets, TOutput fullTarget)
+ {
+ if (microBatchTargets.TryGetValue(microBatchIndex, out var target))
+ {
+ return target;
+ }
+ return fullTarget;
+ }
+
+ ///
+ /// Sends activations to the next stage in the pipeline.
+ /// For multi-stage schedules, only sends when transitioning between ranks.
+ ///
+ private void SendActivationsForward(TOutput stageOutput, int microBatchIndex, int virtualStageIndex)
+ {
+ // Only send to next rank when this is the last virtual stage on this rank
+ bool isLastVirtualStageOnRank = virtualStageIndex == _virtualStagesPerRank - 1;
+
+ if (isLastVirtualStageOnRank && _stageId < _numStages - 1)
{
Vector activationsToSend = ConversionsHelper.ConvertToVector(stageOutput);
+ int tag = ComputeForwardTag(microBatchIndex, 0); // Next rank receives at vStage 0
- // Protocol: First send 1-element size header, then send activations
- // This allows receiver to know the exact size of incoming activations
var sizeHeader = new Vector(new[] { NumOps.FromDouble(activationsToSend.Length) });
- Config.CommunicationBackend.Send(sizeHeader, _stageId + 1, tag: 0);
- Config.CommunicationBackend.Send(activationsToSend, _stageId + 1, tag: 0);
+ Config.CommunicationBackend.Send(sizeHeader, _stageId + 1, tag: tag);
+ Config.CommunicationBackend.Send(activationsToSend, _stageId + 1, tag: tag);
}
+ }
- // BACKWARD PASS: Gradient communication
- // Gradients flow backward through the pipeline (opposite direction of activations)
- if (_stageId < _numStages - 1)
+ ///
+ /// Computes a unique communication tag for forward pass activations.
+ /// Tags are in the range [ActivationTagBase, GradientTagBase).
+ ///
+ private int ComputeForwardTag(int microBatchIndex, int virtualStageIndex)
+ {
+ return ActivationTagBase + microBatchIndex * (_virtualStagesPerRank + 1) + virtualStageIndex;
+ }
+
+ ///
+ /// Computes a unique communication tag for backward pass gradients.
+ /// Tags are in the range [GradientTagBase, PredictTagBase).
+ ///
+ private int ComputeBackwardTag(int microBatchIndex, int virtualStageIndex)
+ {
+ return GradientTagBase + microBatchIndex * (_virtualStagesPerRank + 1) + virtualStageIndex;
+ }
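The two helpers keep forward and backward traffic in disjoint tag ranges while giving every (micro-batch, virtual stage) pair its own tag. A standalone sketch with assumed base constants (the real ActivationTagBase/GradientTagBase values are defined elsewhere in this class):

const int ActivationTagBase = 1000;   // assumed value for illustration
const int GradientTagBase = 2000;     // assumed value for illustration
int virtualStagesPerRank = 2;

int ForwardTag(int mb, int vs) => ActivationTagBase + mb * (virtualStagesPerRank + 1) + vs;
int BackwardTag(int mb, int vs) => GradientTagBase + mb * (virtualStagesPerRank + 1) + vs;

System.Console.WriteLine(ForwardTag(0, 0));  // 1000
System.Console.WriteLine(ForwardTag(1, 1));  // 1004
System.Console.WriteLine(BackwardTag(1, 1)); // 2004 -- never collides with a forward tag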
+
+ ///
+ /// Determines whether an activation should be checkpointed based on configuration.
+ ///
+ private bool ShouldCheckpointActivation(int opKey)
+ {
+ if (!_checkpointConfig.Enabled)
+ {
+ return false;
+ }
+
+ // MaxActivationsInMemory > 0 overrides interval-based checkpointing
+ if (_checkpointConfig.MaxActivationsInMemory > 0)
+ {
+ return _checkpointedActivations.Count < _checkpointConfig.MaxActivationsInMemory;
+ }
+
+ // CheckpointFirstLayer: always checkpoint opKey 0 if enabled
+ if (_checkpointConfig.CheckpointFirstLayer && opKey == 0)
+ {
+ return true;
+ }
+
+ // Interval-based checkpointing (CheckpointEveryNLayers validated >= 1 in setter)
+ return _checkpointConfig.CheckpointEveryNLayers > 0
+ && opKey % _checkpointConfig.CheckpointEveryNLayers == 0;
+ }
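How the three knobs interact is easier to see in isolation. A standalone sketch with the configuration fields inlined as locals (sample values chosen for illustration; the real code reads them from _checkpointConfig):

bool enabled = true;
int maxActivationsInMemory = 0;   // 0 => fall through to interval-based checkpointing
bool checkpointFirstLayer = true;
int checkpointEveryNLayers = 2;
int checkpointedCount = 0;        // stands in for _checkpointedActivations.Count

bool ShouldCheckpoint(int opKey)
{
    if (!enabled) return false;
    if (maxActivationsInMemory > 0) return checkpointedCount < maxActivationsInMemory;
    if (checkpointFirstLayer && opKey == 0) return true;
    return checkpointEveryNLayers > 0 && opKey % checkpointEveryNLayers == 0;
}

for (int key = 0; key < 6; key++)
    System.Console.WriteLine($"opKey {key}: {ShouldCheckpoint(key)}");
// With these settings, keys 0, 2 and 4 are checkpointed.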
+
+ ///
+ /// Retrieves the input for a micro-batch from cache, checkpoint, or recomputes it.
+ /// Implements activation checkpointing recomputation when enabled.
+ ///
+ private TInput RetrieveMicroBatchInput(
+ int opKey,
+ Dictionary forwardInputs,
+ Dictionary microBatches,
+ PipelineOperation op)
+ {
+ // Check if input is still cached from forward pass
+ if (forwardInputs.TryGetValue(opKey, out var cachedInput))
+ {
+ return cachedInput;
+ }
+
+ // Check activation checkpoints
+ if (_checkpointConfig.Enabled && _checkpointedActivations.TryGetValue(opKey, out var checkpointedVector))
+ {
+ // Found a checkpoint - recompute from it if needed
+ var recomputedInput = ConversionsHelper.ConvertVectorToInputWithoutReference(checkpointedVector);
+
+ // If the checkpoint is for this exact operation, return directly
+ return recomputedInput;
+ }
+
+ // Check if there's a nearby checkpoint to recompute from
+ // NOTE: Currently unreachable because the constructor rejects RecomputeStrategy != None.
+ // This is infrastructure for future recompute support (Selective/Full strategies).
+ if (_checkpointConfig.Enabled && _checkpointConfig.RecomputeStrategy != RecomputeStrategy.None)
+ {
+ // Find the nearest earlier checkpoint within the SAME micro-batch.
+ // opKey = microBatchIndex * _virtualStagesPerRank + virtualStageIndex,
+ // so the current micro-batch's first key is microBatchIndex * _virtualStagesPerRank.
+ int microBatchStartKey = op.MicroBatchIndex * _virtualStagesPerRank;
+ int nearestCheckpointKey = -1;
+ for (int searchKey = opKey - 1; searchKey >= microBatchStartKey; searchKey--)
+ {
+ if (_checkpointedActivations.ContainsKey(searchKey))
+ {
+ nearestCheckpointKey = searchKey;
+ break;
+ }
+ }
+
+ if (nearestCheckpointKey >= 0)
+ {
+ // Recompute forward from the nearest checkpoint to reconstruct the needed activation
+ var checkpointVector = _checkpointedActivations[nearestCheckpointKey];
+ var recomputeInput = ConversionsHelper.ConvertVectorToInputWithoutReference(checkpointVector);
+
+ // Run forward passes from checkpoint to target, recomputing activations
+ TInput currentInput = recomputeInput;
+ for (int step = nearestCheckpointKey; step < opKey; step++)
+ {
+ var stepOutput = WrappedModel.Predict(currentInput);
+ currentInput = ConversionsHelper.ConvertVectorToInputWithoutReference(
+ ConversionsHelper.ConvertToVector(stepOutput));
+ }
+
+ return currentInput;
+ }
+ }
+
+ // Fallback: use the original micro-batch input
+ return GetStageInput(microBatches, op.MicroBatchIndex, op.VirtualStageIndex);
+ }
+
+ ///
+ /// Receives gradients from the downstream (next) stage and accumulates them.
+ /// For multi-stage schedules, handles virtual stage routing.
+ ///
+ private void ReceiveAndAccumulateDownstreamGradients(
+ Vector gradientVector, int microBatchIndex, int virtualStageIndex)
+ {
+ // Only receive from next rank when this is the last virtual stage on this rank
+ bool isLastVirtualStageOnRank = virtualStageIndex == _virtualStagesPerRank - 1;
+
+ if (isLastVirtualStageOnRank && _stageId < _numStages - 1)
{
- // Non-last stages receive gradient contributions from next stage
- Vector nextStageGradients = Config.CommunicationBackend.Receive(_stageId + 1, gradientVector.Length, tag: 1);
+ int tag = ComputeBackwardTag(microBatchIndex, virtualStageIndex);
+ Vector nextStageGradients = Config.CommunicationBackend.Receive(
+ _stageId + 1, gradientVector.Length, tag: tag);
- // Accumulate gradients: local gradients + gradients from downstream stages
for (int i = 0; i < gradientVector.Length; i++)
{
gradientVector[i] = NumOps.Add(gradientVector[i], nextStageGradients[i]);
}
}
+ }
- if (_stageId > 0)
+ ///
+ /// Sends gradients to the upstream (previous) stage.
+ /// For multi-stage schedules, handles virtual stage routing.
+ ///
+ private void SendGradientsUpstream(Vector gradientVector, int microBatchIndex, int virtualStageIndex)
+ {
+ // Only send to previous rank when this is the first virtual stage on this rank
+ bool isFirstVirtualStageOnRank = virtualStageIndex == 0;
+
+ if (isFirstVirtualStageOnRank && _stageId > 0)
{
- // Non-first stages send accumulated gradients to previous stage
- Config.CommunicationBackend.Send(gradientVector, _stageId - 1, tag: 1);
+ int tag = ComputeBackwardTag(microBatchIndex, _virtualStagesPerRank - 1);
+ Config.CommunicationBackend.Send(gradientVector, _stageId - 1, tag: tag);
}
+ }
- // Apply accumulated gradients to parameters using the configured learning rate
- // In pipeline parallelism, we use a simple SGD-style update: θ = θ - lr * gradients
- // For more sophisticated optimization, wrap this model with a gradient-based optimizer
- WrappedModel.SetParameters(parametersBefore);
- WrappedModel.ApplyGradients(gradientVector, Config.LearningRate);
+ ///
+ /// Accumulates gradients across micro-batches.
+ ///
+ private Vector AccumulateGradients(Vector? accumulated, Vector newGradients)
+ {
+ if (accumulated is null)
+ {
+ // Clone to avoid mutating the original
+ var copy = new T[newGradients.Length];
+ for (int i = 0; i < newGradients.Length; i++)
+ {
+ copy[i] = newGradients[i];
+ }
+ return new Vector(copy);
+ }
- // Extract this stage's parameter shard
- var updatedParams = WrappedModel.GetParameters();
- UpdateLocalShardFromFull(updatedParams);
- InvalidateCache();
+ if (accumulated.Length != newGradients.Length)
+ {
+ throw new InvalidOperationException(
+ $"Gradient length mismatch: accumulated has {accumulated.Length} elements " +
+ $"but new gradients have {newGradients.Length} elements.");
+ }
- // Synchronize parameters across stages for consistency
- if (Config.AutoSyncGradients)
+ for (int i = 0; i < accumulated.Length; i++)
{
- SynchronizeGradients();
+ accumulated[i] = NumOps.Add(accumulated[i], newGradients[i]);
+ }
+
+ return accumulated;
+ }
+
+ ///
+ /// Frees non-checkpointed activations to save memory.
+ ///
+ private void FreeNonCheckpointedActivations(
+ int opKey, Dictionary forwardInputs, Dictionary forwardOutputs)
+ {
+ if (!_checkpointedActivations.ContainsKey(opKey))
+ {
+ forwardInputs.Remove(opKey);
+ forwardOutputs.Remove(opKey);
}
}
///
public override TOutput Predict(TInput input)
{
- // Pipeline forward pass for inference
- // Activations flow through stages sequentially
-
var fullParams = GatherFullParameters();
WrappedModel.SetParameters(fullParams);
- // Determine actual input for this stage
TInput stageInput = input;
- // FORWARD PASS: Receive activations from previous stage
if (_stageId > 0)
{
- // Protocol: First receive 1-element size header, then receive activations
- // This prevents size mismatches when stage output size differs from input size
- Vector sizeHeader = Config.CommunicationBackend.Receive(_stageId - 1, count: 1, tag: 10);
+ int tag = PredictTagBase;
+ Vector sizeHeader = Config.CommunicationBackend.Receive(_stageId - 1, count: 1, tag: tag);
int activationSize = NumOps.ToInt32(sizeHeader[0]);
- Vector receivedActivations = Config.CommunicationBackend.Receive(_stageId - 1, activationSize, tag: 10);
-
- // For intermediate stages, convert received activations to TInput type WITHOUT using
- // the original input as reference (which would have the wrong shape for non-first stages).
- // Use ConversionsHelper to centralize conversion logic and avoid code duplication.
+ Vector receivedActivations = Config.CommunicationBackend.Receive(_stageId - 1, activationSize, tag: tag);
stageInput = ConversionsHelper.ConvertVectorToInputWithoutReference(receivedActivations);
}
- // Process through this stage's layers
TOutput stageOutput = WrappedModel.Predict(stageInput);
- // FORWARD PASS: Send activations to next stage
if (_stageId < _numStages - 1)
{
- // Non-last stages send their output to next stage
+ int tag = PredictTagBase;
Vector activationsToSend = ConversionsHelper.ConvertToVector(stageOutput);
- // Protocol: First send 1-element size header, then send activations
- // This allows receiver to know the exact size of incoming activations
var sizeHeader = new Vector(new[] { NumOps.FromDouble(activationsToSend.Length) });
- Config.CommunicationBackend.Send(sizeHeader, _stageId + 1, tag: 10);
- Config.CommunicationBackend.Send(activationsToSend, _stageId + 1, tag: 10);
-
- // Intermediate stages must still return a value
- // Return the stage output (caller should only use output from last stage)
- return stageOutput;
+ Config.CommunicationBackend.Send(sizeHeader, _stageId + 1, tag: tag);
+ Config.CommunicationBackend.Send(activationsToSend, _stageId + 1, tag: tag);
}
- // Last stage returns the final prediction
return stageOutput;
}
@@ -283,6 +1000,12 @@ public override ModelMetadata GetModelMetadata()
metadata.SetProperty("StageId", _stageId);
metadata.SetProperty("NumStages", _numStages);
metadata.SetProperty("MicroBatchSize", _microBatchSize);
+ metadata.SetProperty("Schedule", _schedule.Name);
+ metadata.SetProperty("VirtualStagesPerRank", _virtualStagesPerRank);
+ metadata.SetProperty("EstimatedBubbleFraction", EstimatedBubbleFraction);
+ metadata.SetProperty("ActivationCheckpointing", _checkpointConfig.Enabled);
+ metadata.SetProperty("PartitionStrategy", _partitionStrategy?.GetType().Name ?? "Uniform");
+ metadata.SetProperty("SupportsDecomposedBackward", _supportsDecomposedBackward);
return metadata;
}
@@ -290,7 +1013,8 @@ public override ModelMetadata GetModelMetadata()
public override IFullModel WithParameters(Vector parameters)
{
return new PipelineParallelModel(
- WrappedModel.WithParameters(parameters), Config, _microBatchSize);
+ WrappedModel.WithParameters(parameters), Config, _microBatchSize,
+ _partitionStrategy, _schedule, _checkpointConfig);
}
///
@@ -304,6 +1028,10 @@ public override byte[] Serialize()
writer.Write(Config.AutoSyncGradients);
writer.Write(Config.MinimumParameterGroupSize);
writer.Write(Config.EnableGradientCompression);
+ writer.Write(_schedule.Name);
+ writer.Write(_checkpointConfig.Enabled);
+ writer.Write(_checkpointConfig.CheckpointEveryNLayers);
+ writer.Write(_virtualStagesPerRank);
var modelData = WrappedModel.Serialize();
writer.Write(modelData.Length);
writer.Write(modelData);
@@ -318,9 +1046,13 @@ public override void Deserialize(byte[] data)
int savedWorldSize = reader.ReadInt32();
int savedRank = reader.ReadInt32();
int savedMicroBatchSize = reader.ReadInt32();
- reader.ReadBoolean();
- reader.ReadInt32();
- reader.ReadBoolean();
+ reader.ReadBoolean(); // AutoSyncGradients
+ reader.ReadInt32(); // MinimumParameterGroupSize
+ reader.ReadBoolean(); // EnableGradientCompression
+ reader.ReadString(); // Schedule name (informational)
+ reader.ReadBoolean(); // Checkpointing enabled
+ reader.ReadInt32(); // CheckpointEveryNLayers
+ reader.ReadInt32(); // VirtualStagesPerRank (informational)
if (savedWorldSize != WorldSize)
throw new InvalidOperationException($"World size mismatch: {savedWorldSize} vs {WorldSize}");
@@ -368,6 +1100,8 @@ public override void LoadModel(string filePath)
///
public override IFullModel Clone()
{
- return new PipelineParallelModel(WrappedModel.Clone(), Config, _microBatchSize);
+ return new PipelineParallelModel(
+ WrappedModel.Clone(), Config, _microBatchSize,
+ _partitionStrategy, _schedule, _checkpointConfig);
}
}
diff --git a/src/DistributedTraining/UniformPartitionStrategy.cs b/src/DistributedTraining/UniformPartitionStrategy.cs
new file mode 100644
index 000000000..aa0c86672
--- /dev/null
+++ b/src/DistributedTraining/UniformPartitionStrategy.cs
@@ -0,0 +1,49 @@
+using AiDotNet.Interfaces;
+
+namespace AiDotNet.DistributedTraining;
+
+///
+/// Divides model parameters evenly across pipeline stages.
+///
+///
+///
+/// This is the simplest partitioning strategy: each stage gets approximately the same
+/// number of parameters. When the total isn't evenly divisible, earlier stages get one
+/// extra parameter each.
+///
+/// For Beginners: This is the default strategy. It splits the model like cutting
+/// a cake into equal slices. It works well when all layers have similar computational cost,
+/// but can cause imbalance when some layers (like attention) are much heavier than others.
+///
+///
+/// The numeric type for operations.
+public class UniformPartitionStrategy<T> : IPipelinePartitionStrategy<T>
+{
+ ///
+ public (int StartIndex, int Size)[] ComputePartition(int totalParameters, int numStages)
+ {
+ if (totalParameters <= 0)
+ {
+ throw new ArgumentException("Total parameters must be positive.", nameof(totalParameters));
+ }
+
+ if (numStages <= 0)
+ {
+ throw new ArgumentException("Number of stages must be positive.", nameof(numStages));
+ }
+
+ var partitions = new (int StartIndex, int Size)[numStages];
+ int baseSize = totalParameters / numStages;
+ int remainder = totalParameters % numStages;
+ int currentStart = 0;
+
+ for (int i = 0; i < numStages; i++)
+ {
+ int size = baseSize + (i < remainder ? 1 : 0);
+ partitions[i] = (currentStart, size);
+ currentStart += size;
+ }
+
+ return partitions;
+ }
+}
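A minimal usage sketch of the strategy above (the generic argument is illustrative; the output follows directly from the division logic):

using AiDotNet.DistributedTraining;

var strategy = new UniformPartitionStrategy<double>();
var partitions = strategy.ComputePartition(totalParameters: 10, numStages: 3);

foreach (var (start, size) in partitions)
    System.Console.WriteLine($"start {start}, size {size}");
// 10 parameters over 3 stages: (0,4), (4,3), (7,3) -- the remainder goes to the earliest stages.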
diff --git a/src/DistributedTraining/ZeroBubbleH1Schedule.cs b/src/DistributedTraining/ZeroBubbleH1Schedule.cs
new file mode 100644
index 000000000..27e7eb6b1
--- /dev/null
+++ b/src/DistributedTraining/ZeroBubbleH1Schedule.cs
@@ -0,0 +1,169 @@
+using AiDotNet.Interfaces;
+
+namespace AiDotNet.DistributedTraining;
+
+///
+/// Implements the Zero Bubble H1 (ZB-H1) pipeline schedule.
+///
+///
+///
+/// ZB-H1 splits the backward pass into two independent computations:
+/// - B (BackwardInput): Computes activation gradients (dL/dInput) - on the critical path.
+/// - W (BackwardWeight): Computes weight gradients (dL/dWeights) - can be deferred.
+///
+/// By deferring W to fill pipeline bubbles, ZB-H1 reduces the bubble to approximately
+/// one-third of 1F1B's bubble while maintaining the same peak memory footprint.
+///
+/// For Beginners: In standard 1F1B, the backward pass computes both activation and
+/// weight gradients together. ZB-H1 splits this into two steps. The activation gradient (B)
+/// must be done quickly (the previous stage is waiting), but the weight gradient (W) can wait.
+/// By scheduling W during idle time, we reduce wasted time by ~67% compared to 1F1B.
+///
+/// Think of it like a car wash: the "rinse" (B) must happen right after soap, but "waxing" (W)
+/// can be done whenever there's a free slot.
+///
+/// Reference: Qi et al., "Zero Bubble Pipeline Parallelism", ICLR 2024 Spotlight.
+/// https://arxiv.org/abs/2401.10241
+///
+public class ZeroBubbleH1Schedule : IPipelineSchedule
+{
+ ///
+ public string Name => "ZB-H1";
+
+ ///
+ public int VirtualStagesPerRank => 1;
+
+ ///
+ public IReadOnlyList<PipelineOperation> GetSchedule(int stageId, int numStages, int numMicroBatches)
+ {
+ if (numStages <= 0)
+ {
+ throw new ArgumentException("Number of stages must be positive.", nameof(numStages));
+ }
+
+ if (numMicroBatches <= 0)
+ {
+ throw new ArgumentException("Number of micro-batches must be positive.", nameof(numMicroBatches));
+ }
+
+ if (stageId < 0 || stageId >= numStages)
+ {
+ throw new ArgumentOutOfRangeException(nameof(stageId),
+ $"Stage ID must be between 0 and {numStages - 1}.");
+ }
+
+ var ops = new List<PipelineOperation>();
+
+ // ZB-H1 follows 1F1B structure but splits backward into B + W
+ // Key constraint: maintain same number of in-flight micro-batches as 1F1B
+ // (i.e., at most numStages micro-batches stored at once)
+
+ int numWarmupForwards = Math.Min(numStages - 1 - stageId, numMicroBatches);
+ int numSteadyState = Math.Max(0, numMicroBatches - numWarmupForwards);
+
+ // Phase 1: Warmup - forward passes only (same as 1F1B)
+ int forwardIdx = 0;
+ for (int i = 0; i < numWarmupForwards; i++)
+ {
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.Forward,
+ MicroBatchIndex = forwardIdx,
+ IsWarmup = true,
+ IsCooldown = false
+ });
+ forwardIdx++;
+ }
+
+ // Phase 2: Steady state - 1F-1B-1W pattern
+ // For each steady-state step: one Forward, one BackwardInput, and
+ // schedule BackwardWeight for the micro-batch that completed B earliest.
+ int backwardInputIdx = 0;
+ int backwardWeightIdx = 0;
+
+ for (int i = 0; i < numSteadyState; i++)
+ {
+ // Forward
+ if (forwardIdx < numMicroBatches)
+ {
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.Forward,
+ MicroBatchIndex = forwardIdx,
+ IsWarmup = false,
+ IsCooldown = false
+ });
+ forwardIdx++;
+ }
+
+ // BackwardInput (B) - on the critical path
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.BackwardInput,
+ MicroBatchIndex = backwardInputIdx,
+ IsWarmup = false,
+ IsCooldown = false
+ });
+ backwardInputIdx++;
+
+ // BackwardWeight (W) - fills bubbles, scheduled for earlier micro-batch
+ // ZB-H1 constraint: W starts only after enough B steps to maintain
+ // the same in-flight count as 1F1B
+ if (backwardWeightIdx < backwardInputIdx && backwardWeightIdx < numMicroBatches)
+ {
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.BackwardWeight,
+ MicroBatchIndex = backwardWeightIdx,
+ IsWarmup = false,
+ IsCooldown = false
+ });
+ backwardWeightIdx++;
+ }
+ }
+
+ // Phase 3: Cooldown - remaining B and W passes
+ while (backwardInputIdx < numMicroBatches)
+ {
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.BackwardInput,
+ MicroBatchIndex = backwardInputIdx,
+ IsWarmup = false,
+ IsCooldown = true
+ });
+ backwardInputIdx++;
+ }
+
+ // Drain remaining W passes
+ while (backwardWeightIdx < numMicroBatches)
+ {
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.BackwardWeight,
+ MicroBatchIndex = backwardWeightIdx,
+ IsWarmup = false,
+ IsCooldown = true
+ });
+ backwardWeightIdx++;
+ }
+
+ return ops;
+ }
+
+ ///
+ public double EstimateBubbleFraction(int numStages, int numMicroBatches)
+ {
+ if (numStages <= 1 || numMicroBatches <= 0)
+ {
+ return 0.0;
+ }
+
+ // ZB-H1 bubble is approximately 1/3 of 1F1B's bubble
+ // 1F1B bubble: (P-1) / (2*M + P - 1)
+ // ZB-H1 bubble: ~(P-1) / (3*M + P - 1)
+ long p = numStages;
+ long m = numMicroBatches;
+ return (double)(p - 1) / (3 * m + p - 1);
+ }
+}
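For a feel of the numbers, a standalone sketch evaluating the closed-form estimates quoted in the comments above for 4 stages and 8 micro-batches (the GPipe figure uses the commonly quoted (P-1)/(M+P-1) under the same convention):

int p = 4, m = 8;   // 4 stages, 8 micro-batches

double gpipe  = (double)(p - 1) / (m + p - 1);       // ~0.273
double oneF1B = (double)(p - 1) / (2 * m + p - 1);   // ~0.158
double zbH1   = (double)(p - 1) / (3 * m + p - 1);   // ~0.111

System.Console.WriteLine($"GPipe {gpipe:F3}  1F1B {oneF1B:F3}  ZB-H1 {zbH1:F3}");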
diff --git a/src/DistributedTraining/ZeroBubbleH2Schedule.cs b/src/DistributedTraining/ZeroBubbleH2Schedule.cs
new file mode 100644
index 000000000..c86f18c83
--- /dev/null
+++ b/src/DistributedTraining/ZeroBubbleH2Schedule.cs
@@ -0,0 +1,182 @@
+using AiDotNet.Interfaces;
+
+namespace AiDotNet.DistributedTraining;
+
+///
+/// Implements the Zero Bubble H2 (ZB-H2) pipeline schedule.
+///
+///
+///
+/// ZB-H2 achieves true zero pipeline bubble by allowing more in-flight micro-batches
+/// than 1F1B, trading peak memory for throughput. Like ZB-H1, it splits backward into
+/// BackwardInput (B) and BackwardWeight (W), but schedules more aggressively.
+///
+/// For Beginners: ZB-H2 is the "maximum throughput" variant. It allows more
+/// micro-batches to be in progress simultaneously (using more memory) to completely
+/// eliminate idle time. If you have enough GPU memory, ZB-H2 gives the best possible
+/// pipeline utilization.
+///
+/// The tradeoff:
+/// - ZB-H1: Same memory as 1F1B, ~1/3 bubble
+/// - ZB-H2: More memory than 1F1B, ~0% bubble (zero idle time)
+///
+/// Reference: Qi et al., "Zero Bubble Pipeline Parallelism", ICLR 2024 Spotlight.
+/// https://arxiv.org/abs/2401.10241
+///
+public class ZeroBubbleH2Schedule : IPipelineSchedule
+{
+ ///
+ public string Name => "ZB-H2";
+
+ ///
+ public int VirtualStagesPerRank => 1;
+
+ ///
+ public IReadOnlyList<PipelineOperation> GetSchedule(int stageId, int numStages, int numMicroBatches)
+ {
+ if (numStages <= 0)
+ {
+ throw new ArgumentException("Number of stages must be positive.", nameof(numStages));
+ }
+
+ if (numMicroBatches <= 0)
+ {
+ throw new ArgumentException("Number of micro-batches must be positive.", nameof(numMicroBatches));
+ }
+
+ if (stageId < 0 || stageId >= numStages)
+ {
+ throw new ArgumentOutOfRangeException(nameof(stageId),
+ $"Stage ID must be between 0 and {numStages - 1}.");
+ }
+
+ var ops = new List<PipelineOperation>();
+
+ // ZB-H2 allows more warmup forwards than 1F1B to fill the pipeline more aggressively.
+ // The key difference from ZB-H1: we allow up to (numStages - 1) additional in-flight
+ // micro-batches, which uses more memory but fills all bubbles.
+
+ // Extended warmup: later stages (higher stageId) get fewer warmup forwards
+ // because their inputs arrive later in the pipeline.
+ // Stage 0 gets up to numStages warmup forwards, stage (numStages-1) gets 1.
+ int numWarmupForwards = Math.Min(numStages - stageId, numMicroBatches);
+
+ // Phase 1: Extended warmup - more forward passes to fill pipeline completely
+ int forwardIdx = 0;
+ for (int i = 0; i < numWarmupForwards; i++)
+ {
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.Forward,
+ MicroBatchIndex = forwardIdx,
+ IsWarmup = true,
+ IsCooldown = false
+ });
+ forwardIdx++;
+ }
+
+ // Phase 2: Steady state - interleave F, B, W to maintain zero bubble
+ int backwardInputIdx = 0;
+ int backwardWeightIdx = 0;
+ int steadyStateCount = Math.Max(0, numMicroBatches - numWarmupForwards);
+
+ for (int i = 0; i < steadyStateCount; i++)
+ {
+ // BackwardInput (B) first - critical path
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.BackwardInput,
+ MicroBatchIndex = backwardInputIdx,
+ IsWarmup = false,
+ IsCooldown = false
+ });
+ backwardInputIdx++;
+
+ // Forward for next micro-batch
+ if (forwardIdx < numMicroBatches)
+ {
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.Forward,
+ MicroBatchIndex = forwardIdx,
+ IsWarmup = false,
+ IsCooldown = false
+ });
+ forwardIdx++;
+ }
+
+ // BackwardWeight (W) - fills any remaining time
+ if (backwardWeightIdx < numMicroBatches)
+ {
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.BackwardWeight,
+ MicroBatchIndex = backwardWeightIdx,
+ IsWarmup = false,
+ IsCooldown = false
+ });
+ backwardWeightIdx++;
+ }
+ }
+
+ // Phase 3: Cooldown - drain remaining B and W
+ while (backwardInputIdx < numMicroBatches)
+ {
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.BackwardInput,
+ MicroBatchIndex = backwardInputIdx,
+ IsWarmup = false,
+ IsCooldown = true
+ });
+ backwardInputIdx++;
+
+ // Interleave W during cooldown
+ if (backwardWeightIdx < numMicroBatches)
+ {
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.BackwardWeight,
+ MicroBatchIndex = backwardWeightIdx,
+ IsWarmup = false,
+ IsCooldown = true
+ });
+ backwardWeightIdx++;
+ }
+ }
+
+ // Final W drain
+ while (backwardWeightIdx < numMicroBatches)
+ {
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.BackwardWeight,
+ MicroBatchIndex = backwardWeightIdx,
+ IsWarmup = false,
+ IsCooldown = true
+ });
+ backwardWeightIdx++;
+ }
+
+ return ops;
+ }
+
+ ///
+ public double EstimateBubbleFraction(int numStages, int numMicroBatches)
+ {
+ if (numStages <= 1 || numMicroBatches <= 0)
+ {
+ return 0.0;
+ }
+
+ // ZB-H2 achieves near-zero bubble when numMicroBatches >= numStages
+ // For insufficient micro-batches, there's still some residual bubble
+ if (numMicroBatches >= numStages)
+ {
+ return 0.0;
+ }
+
+ // Fallback estimate for small M
+ return (double)((long)numStages - numMicroBatches) / (3L * numMicroBatches + numStages);
+ }
+}
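A hedged usage sketch that enumerates the schedule for one stage and tallies the operation types (it only assumes the types introduced in this diff; the exact interleaving is internal to the schedule):

using System.Linq;
using AiDotNet.DistributedTraining;

var schedule = new ZeroBubbleH2Schedule();
var ops = schedule.GetSchedule(stageId: 1, numStages: 4, numMicroBatches: 8);

System.Console.WriteLine($"estimated bubble: {schedule.EstimateBubbleFraction(4, 8)}"); // 0 once micro-batches >= stages
foreach (var group in ops.GroupBy(o => o.Type))
    System.Console.WriteLine($"{group.Key}: {group.Count()}");
// Expect 8 Forward, 8 BackwardInput and 8 BackwardWeight entries.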
diff --git a/src/DistributedTraining/ZeroBubbleVSchedule.cs b/src/DistributedTraining/ZeroBubbleVSchedule.cs
new file mode 100644
index 000000000..44aabdaa7
--- /dev/null
+++ b/src/DistributedTraining/ZeroBubbleVSchedule.cs
@@ -0,0 +1,262 @@
+using AiDotNet.Interfaces;
+
+namespace AiDotNet.DistributedTraining;
+
+///
+/// Implements the Zero Bubble V (ZB-V) pipeline schedule with 2 virtual stages per rank.
+///
+///
+///
+/// ZB-V combines the backward decomposition of ZB-H1/H2 with the virtual stage concept of
+/// Interleaved 1F1B, using exactly V=2 virtual stages per rank. Each rank processes two
+/// non-contiguous model chunks, creating a V-shaped execution pattern that achieves zero
+/// pipeline bubble with the same peak memory as standard 1F1B.
+///
+///
+/// The V-shape comes from the execution pattern on each rank:
+/// - First half: Forward passes fill from top to bottom (forward through virtual stage 0)
+/// - Middle: V-shaped transition from forward to backward
+/// - Second half: Backward passes drain from bottom to top (backward through virtual stage 1)
+///
+/// For Beginners: ZB-V is the best of both worlds:
+/// - Like Interleaved 1F1B: uses 2 model chunks per GPU to reduce bubble
+/// - Like ZB-H1: splits backward into B (activation gradients) and W (weight gradients)
+/// - Unlike ZB-H2: does NOT use extra memory (same as 1F1B)
+///
+/// The result is zero pipeline bubble with no extra memory cost. The tradeoff is slightly
+/// more communication (each microbatch crosses each GPU twice) and implementation complexity.
+///
+/// Example with 4 GPUs (8 total virtual stages):
+/// - GPU 0: virtual stages 0 and 4
+/// - GPU 1: virtual stages 1 and 5
+/// - GPU 2: virtual stages 2 and 6
+/// - GPU 3: virtual stages 3 and 7
+///
+/// Each microbatch flows: 0->1->2->3->4->5->6->7 (visiting each GPU twice).
+///
+/// Reference: Qi et al., "Zero Bubble Pipeline Parallelism", ICLR 2024 Spotlight.
+/// https://arxiv.org/abs/2401.10241
+///
+public class ZeroBubbleVSchedule : IPipelineSchedule
+{
+ ///
+ public string Name => "ZB-V";
+
+ ///
+ public int VirtualStagesPerRank => 2;
+
+ ///
+ public IReadOnlyList<PipelineOperation> GetSchedule(int stageId, int numStages, int numMicroBatches)
+ {
+ if (numStages <= 0)
+ {
+ throw new ArgumentException("Number of stages must be positive.", nameof(numStages));
+ }
+
+ if (numMicroBatches <= 0)
+ {
+ throw new ArgumentException("Number of micro-batches must be positive.", nameof(numMicroBatches));
+ }
+
+ if (stageId < 0 || stageId >= numStages)
+ {
+ throw new ArgumentOutOfRangeException(nameof(stageId),
+ $"Stage ID must be between 0 and {numStages - 1}.");
+ }
+
+ var ops = new List<PipelineOperation>();
+
+ // ZB-V uses exactly 2 virtual stages per rank (V=2).
+ // Virtual stage IDs for rank stageId: stageId (chunk 0) and stageId + numStages (chunk 1).
+ //
+ // The schedule interleaves F/B/W operations across both virtual stages:
+ // - Forward on virtual stage 0 (chunk 0)
+ // - Forward on virtual stage 1 (chunk 1)
+ // - BackwardInput on virtual stage 1 (chunk 1, reverse order)
+ // - BackwardInput on virtual stage 0 (chunk 0, reverse order)
+ // - BackwardWeight fills any remaining gaps
+
+ // Warmup: forwards across both virtual stages
+ // Number of warmup forwards scales with position in pipeline
+ int warmupForwardsPerChunk = Math.Min(numStages - 1 - stageId, numMicroBatches);
+
+ int forwardCount0 = 0; // Forward count for virtual stage 0
+ int forwardCount1 = 0; // Forward count for virtual stage 1
+ int backwardInputCount0 = 0;
+ int backwardInputCount1 = 0;
+ int backwardWeightCount0 = 0;
+ int backwardWeightCount1 = 0;
+
+ // Phase 1: Warmup - interleaved forwards across both virtual stages
+ // Depth-first: complete a microbatch through both chunks before starting next
+ for (int i = 0; i < warmupForwardsPerChunk && forwardCount0 < numMicroBatches; i++)
+ {
+ // Forward on chunk 0
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.Forward,
+ MicroBatchIndex = forwardCount0,
+ VirtualStageIndex = 0,
+ IsWarmup = true,
+ IsCooldown = false
+ });
+ forwardCount0++;
+
+ // Forward on chunk 1 for the same microbatch (if chunk 0 output is ready)
+ if (forwardCount1 < forwardCount0 && forwardCount1 < numMicroBatches)
+ {
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.Forward,
+ MicroBatchIndex = forwardCount1,
+ VirtualStageIndex = 1,
+ IsWarmup = true,
+ IsCooldown = false
+ });
+ forwardCount1++;
+ }
+ }
+
+ // Phase 2: Steady state - F0, F1, B1, B0, W interleaving
+ // Continue until all forwards and backwards are complete
+ while (forwardCount0 < numMicroBatches ||
+ forwardCount1 < numMicroBatches ||
+ backwardInputCount0 < numMicroBatches ||
+ backwardInputCount1 < numMicroBatches)
+ {
+ bool isCooldown = forwardCount0 >= numMicroBatches && forwardCount1 >= numMicroBatches;
+
+ // Forward on chunk 0 (if available)
+ if (forwardCount0 < numMicroBatches)
+ {
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.Forward,
+ MicroBatchIndex = forwardCount0,
+ VirtualStageIndex = 0,
+ IsWarmup = false,
+ IsCooldown = false
+ });
+ forwardCount0++;
+ }
+
+ // Forward on chunk 1 (if chunk 0 has produced output for this microbatch)
+ if (forwardCount1 < forwardCount0 && forwardCount1 < numMicroBatches)
+ {
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.Forward,
+ MicroBatchIndex = forwardCount1,
+ VirtualStageIndex = 1,
+ IsWarmup = false,
+ IsCooldown = false
+ });
+ forwardCount1++;
+ }
+
+ // BackwardInput on chunk 1 (reverse order - B step, critical path)
+ if (backwardInputCount1 < forwardCount1 && backwardInputCount1 < numMicroBatches)
+ {
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.BackwardInput,
+ MicroBatchIndex = backwardInputCount1,
+ VirtualStageIndex = 1,
+ IsWarmup = false,
+ IsCooldown = isCooldown
+ });
+ backwardInputCount1++;
+ }
+
+ // BackwardInput on chunk 0 (after chunk 1's B is done for this microbatch)
+ if (backwardInputCount0 < backwardInputCount1 && backwardInputCount0 < numMicroBatches)
+ {
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.BackwardInput,
+ MicroBatchIndex = backwardInputCount0,
+ VirtualStageIndex = 0,
+ IsWarmup = false,
+ IsCooldown = isCooldown
+ });
+ backwardInputCount0++;
+ }
+
+ // BackwardWeight (W) - fills bubbles, process whichever chunk has pending W
+ if (backwardWeightCount1 < backwardInputCount1 && backwardWeightCount1 < numMicroBatches)
+ {
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.BackwardWeight,
+ MicroBatchIndex = backwardWeightCount1,
+ VirtualStageIndex = 1,
+ IsWarmup = false,
+ IsCooldown = isCooldown
+ });
+ backwardWeightCount1++;
+ }
+
+ if (backwardWeightCount0 < backwardInputCount0 && backwardWeightCount0 < numMicroBatches)
+ {
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.BackwardWeight,
+ MicroBatchIndex = backwardWeightCount0,
+ VirtualStageIndex = 0,
+ IsWarmup = false,
+ IsCooldown = isCooldown
+ });
+ backwardWeightCount0++;
+ }
+ }
+
+ // Phase 3: Drain remaining BackwardWeight operations
+ while (backwardWeightCount1 < numMicroBatches)
+ {
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.BackwardWeight,
+ MicroBatchIndex = backwardWeightCount1,
+ VirtualStageIndex = 1,
+ IsWarmup = false,
+ IsCooldown = true
+ });
+ backwardWeightCount1++;
+ }
+
+ while (backwardWeightCount0 < numMicroBatches)
+ {
+ ops.Add(new PipelineOperation
+ {
+ Type = PipelineOperationType.BackwardWeight,
+ MicroBatchIndex = backwardWeightCount0,
+ VirtualStageIndex = 0,
+ IsWarmup = false,
+ IsCooldown = true
+ });
+ backwardWeightCount0++;
+ }
+
+ return ops;
+ }
+
+ ///
+ public double EstimateBubbleFraction(int numStages, int numMicroBatches)
+ {
+ if (numStages <= 1 || numMicroBatches <= 0)
+ {
+ return 0.0;
+ }
+
+ // ZB-V achieves zero bubble when numMicroBatches >= numStages
+ // Same as ZB-H2 but with 1F1B-equivalent memory
+ if (numMicroBatches >= numStages)
+ {
+ return 0.0;
+ }
+
+ // For insufficient micro-batches, small residual bubble
+ // With V=2 virtual stages, the bubble is reduced compared to ZB-H1
+ return (double)((long)numStages - numMicroBatches) / (6L * numMicroBatches + numStages);
+ }
+}
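The GPU-to-chunk mapping described in the remarks (each rank owning virtual stages stageId and stageId + numStages) as a standalone sketch:

int numStages = 4;   // 4 ranks => 8 virtual stages with V = 2
for (int rank = 0; rank < numStages; rank++)
    System.Console.WriteLine($"rank {rank}: virtual stages {rank} and {rank + numStages}");
// rank 0: 0 and 4, rank 1: 1 and 5, rank 2: 2 and 6, rank 3: 3 and 7
// A micro-batch therefore flows 0->1->2->3->4->5->6->7, visiting every rank twice.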
diff --git a/src/Interfaces/IAiModelBuilder.cs b/src/Interfaces/IAiModelBuilder.cs
index 77189f95f..36fb94ad2 100644
--- a/src/Interfaces/IAiModelBuilder.cs
+++ b/src/Interfaces/IAiModelBuilder.cs
@@ -773,6 +773,28 @@ IAiModelBuilder ConfigureDistributedTraining(
DistributedStrategy strategy = DistributedStrategy.DDP,
IShardingConfiguration? configuration = null);
+ ///
+ /// Configures pipeline-specific options for pipeline parallel training.
+ ///
+ ///
+ /// Call this after ConfigureDistributedTraining with
+ /// DistributedStrategy.PipelineParallel to customize pipeline scheduling,
+ /// partitioning, activation checkpointing, and micro-batch count.
+ /// For Beginners: This method fine-tunes how pipeline parallelism works.
+ /// You only need to call it if you want to change the defaults (GPipe schedule,
+ /// uniform partitioning, no checkpointing, 1 micro-batch).
+ ///
+ /// Pipeline schedule. Null = GPipeSchedule.
+ /// Partition strategy. Null = uniform.
+ /// Activation checkpointing config. Null = disabled.
+ /// Number of micro-batches to split the full batch into. Default: 1.
+ /// This builder instance for method chaining.
+ IAiModelBuilder ConfigurePipelineParallelism(
+ IPipelineSchedule? schedule = null,
+ IPipelinePartitionStrategy? partitionStrategy = null,
+ ActivationCheckpointConfig? checkpointConfig = null,
+ int microBatchCount = 1);
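A hedged usage sketch of the new hook; it assumes an existing builder instance on which ConfigureDistributedTraining has the signature shown earlier in this diff, and leaves the partitioning and checkpointing arguments at their documented defaults:

builder
    .ConfigureDistributedTraining(strategy: DistributedStrategy.PipelineParallel)
    .ConfigurePipelineParallelism(
        schedule: new ZeroBubbleH1Schedule(),   // default would be GPipe
        microBatchCount: 8);                    // partitioning stays uniform, checkpointing stays off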
+
///
/// Configures the cross-validation strategy for model evaluation.
///
diff --git a/src/Interfaces/IPipelineDecomposableModel.cs b/src/Interfaces/IPipelineDecomposableModel.cs
new file mode 100644
index 000000000..04b2471d5
--- /dev/null
+++ b/src/Interfaces/IPipelineDecomposableModel.cs
@@ -0,0 +1,66 @@
+namespace AiDotNet.Interfaces;
+
+///
+/// Interface for models that support decomposing the backward pass into separate
+/// activation gradient and weight gradient computations. This enables Zero Bubble
+/// pipeline schedules (ZB-H1, ZB-H2, ZB-V) to overlap weight gradient computation
+/// with other pipeline stages.
+///
+///
+///
+/// Standard backward passes compute both dL/dInput (activation gradients) and dL/dWeights
+/// (weight gradients) together. This interface allows splitting them:
+///
+///
+/// -
+/// BackwardInput (B): Computes dL/dInput - needed by the upstream stage (critical path).
+///
+/// -
+/// BackwardWeight (W): Computes dL/dWeights - can be deferred to fill pipeline bubbles.
+///
+///
+/// For Beginners: Most models compute all gradients at once. This interface lets
+/// advanced pipeline schedules split that work into two parts: one that's urgent (the upstream
+/// stage is waiting for it) and one that can wait (filling idle time in the pipeline).
+///
+/// If your model doesn't implement this interface, pipeline schedules will automatically
+/// fall back to computing both gradient types together (which still works, just can't
+/// fill bubbles as effectively).
+/// Reference: Qi et al., "Zero Bubble Pipeline Parallelism", ICLR 2024 Spotlight.
+/// https://arxiv.org/abs/2401.10241
+///
+/// The numeric type used for calculations.
+/// The input data type.
+/// The output/target data type.
+public interface IPipelineDecomposableModel<T, TInput, TOutput>
+{
+ ///
+ /// Computes only the activation gradients (dL/dInput) for the backward pass.
+ /// This is on the critical path: the upstream pipeline stage needs these gradients
+ /// to continue its own backward pass.
+ ///
+ /// The input data that was used in the forward pass.
+ /// The expected output for loss computation.
+ ///
+ /// A tuple containing:
+ /// - activationGradients: The gradient of the loss with respect to the input (dL/dInput),
+ /// used to send gradients upstream in the pipeline.
+ /// - cachedState: An opaque state object that can be passed to ComputeWeightGradients
+ /// to avoid redundant computation. May be null if no caching is needed.
+ ///
+ (Vector<T> activationGradients, object? cachedState) ComputeActivationGradients(
+ TInput input, TOutput target);
+
+ ///
+ /// Computes only the weight gradients (dL/dWeights) for the backward pass.
+ /// This is NOT on the critical path and can be deferred to fill pipeline bubbles.
+ ///
+ /// The input data that was used in the forward pass.
+ /// The expected output for loss computation.
+ ///
+ /// Optional cached state from ComputeActivationGradients to avoid
+ /// redundant forward pass computation. If null, the forward pass will be recomputed.
+ ///
+ /// The gradient of the loss with respect to the model's weights (dL/dWeights).
+ Vector<T> ComputeWeightGradients(TInput input, TOutput target, object? cachedState);
+}
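The split this interface formalizes is plain chain-rule bookkeeping. A toy, self-contained sketch for a single linear unit y = w*x with squared-error loss (not the library's API, just the two gradient pieces):

double x = 2.0, w = 3.0, target = 5.0;

double y = w * x;                 // forward: y = 6
double dLdy = 2 * (y - target);   // upstream gradient: 2 * (6 - 5) = 2

// B step (ComputeActivationGradients): dL/dx = dL/dy * w -- the upstream stage is waiting on this.
double dLdx = dLdy * w;           // 6

// W step (ComputeWeightGradients): dL/dw = dL/dy * x -- nothing downstream needs it, so it can be deferred.
double dLdw = dLdy * x;           // 4

System.Console.WriteLine($"dL/dx = {dLdx}, dL/dw = {dLdw}");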
diff --git a/src/Interfaces/IPipelinePartitionStrategy.cs b/src/Interfaces/IPipelinePartitionStrategy.cs
new file mode 100644
index 000000000..44407fb74
--- /dev/null
+++ b/src/Interfaces/IPipelinePartitionStrategy.cs
@@ -0,0 +1,33 @@
+namespace AiDotNet.Interfaces;
+
+///
+/// Defines a strategy for partitioning model parameters across pipeline stages.
+///
+///
+/// For Beginners: When splitting a neural network across multiple devices (pipeline parallelism),
+/// you need to decide which layers go on which device. This interface defines that decision.
+///
+/// The default (uniform) strategy just divides parameters evenly, but this can lead to
+/// imbalanced workloads because some layers (like attention) are much more expensive than
+/// others (like layer normalization). A load-balanced strategy can account for this.
+///
+///
+/// The numeric type for operations.
+public interface IPipelinePartitionStrategy<T>
+{
+ ///
+ /// Computes the partition boundaries for the given number of stages.
+ ///
+ ///
+ /// For Beginners: This returns an array describing where each stage's parameters
+ /// start and how many parameters it owns. For example, with 1000 total parameters and 4 stages,
+ /// a uniform partition might return: [(0, 250), (250, 250), (500, 250), (750, 250)].
+ ///
+ /// Total number of parameters in the model.
+ /// Number of pipeline stages to partition across.
+ ///
+ /// An array of (startIndex, size) tuples, one per stage, describing each stage's
+ /// parameter shard boundaries.
+ ///
+ (int StartIndex, int Size)[] ComputePartition(int totalParameters, int numStages);
+}
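To show the contract, a hypothetical custom strategy implementing the interface above; the 40% front-load is an arbitrary illustration, not a real cost model and not part of this PR:

using AiDotNet.Interfaces;

// Toy strategy that gives the first stage ~40% of parameters and splits the rest evenly.
public class FrontHeavyPartitionStrategy<T> : IPipelinePartitionStrategy<T>
{
    public (int StartIndex, int Size)[] ComputePartition(int totalParameters, int numStages)
    {
        if (numStages <= 1)
        {
            return new (int StartIndex, int Size)[] { (0, totalParameters) };
        }

        var partitions = new (int StartIndex, int Size)[numStages];
        int first = (int)(totalParameters * 0.4);
        int rest = totalParameters - first;
        int baseSize = rest / (numStages - 1);
        int remainder = rest % (numStages - 1);

        partitions[0] = (0, first);
        int start = first;
        for (int i = 1; i < numStages; i++)
        {
            int size = baseSize + (i - 1 < remainder ? 1 : 0);
            partitions[i] = (start, size);
            start += size;
        }
        return partitions;
    }
}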
diff --git a/src/Interfaces/IPipelineSchedule.cs b/src/Interfaces/IPipelineSchedule.cs
new file mode 100644
index 000000000..f26b05e8e
--- /dev/null
+++ b/src/Interfaces/IPipelineSchedule.cs
@@ -0,0 +1,171 @@
+namespace AiDotNet.Interfaces;
+
+///
+/// Defines a scheduling strategy for pipeline parallel training.
+///
+///
+///
+/// Pipeline schedules determine the order in which forward and backward passes execute
+/// across micro-batches and stages. Different schedules trade off memory usage, pipeline
+/// bubble overhead, and implementation complexity.
+///
+///
+/// Schedules fall into two categories:
+/// - Single-stage: Each rank owns one contiguous model chunk (GPipe, 1F1B, ZB-H1, ZB-H2).
+/// - Multi-stage: Each rank owns V non-contiguous chunks ("virtual stages")
+/// (Interleaved 1F1B, Looped BFS, ZB-V).
+///
+/// For Beginners: In pipeline parallelism, multiple stages process data like an
+/// assembly line. A "schedule" decides the order of operations to keep all stages as busy
+/// as possible and minimize idle time ("pipeline bubbles").
+///
+/// Think of it like coordinating workers on an assembly line:
+/// - GPipe: Worker 1 finishes ALL items, then Worker 2 starts ALL items (simple but slow)
+/// - 1F1B: Workers alternate between forward and backward steps (more complex but faster)
+/// - Zero Bubble: Workers split backward into two parts, using the flexible part to fill gaps
+///
+///
+public interface IPipelineSchedule
+{
+ ///
+ /// Gets the name of the scheduling strategy for diagnostics.
+ ///
+ string Name { get; }
+
+ ///
+ /// Gets the number of virtual stages (model chunks) each rank holds.
+ ///
+ ///
+ /// For Beginners: Most schedules assign one chunk of the model to each rank
+ /// (VirtualStagesPerRank = 1). Advanced schedules like Interleaved 1F1B and ZB-V assign
+ /// multiple non-contiguous chunks to each rank to reduce pipeline bubbles.
+ ///
+ int VirtualStagesPerRank { get; }
+
+ ///
+ /// Generates the sequence of operations for a given stage in the pipeline.
+ ///
+ ///
+ /// For Beginners: This returns a list of instructions for a specific stage,
+ /// telling it when to do forward passes, backward passes, and which micro-batch to work on.
+ ///
+ /// The pipeline stage index (0-based).
+ /// Total number of pipeline stages.
+ /// Number of micro-batches per mini-batch.
+ /// Ordered sequence of pipeline operations for this stage.
+ IReadOnlyList<PipelineOperation> GetSchedule(int stageId, int numStages, int numMicroBatches);
+
+ ///
+ /// Estimates the pipeline bubble fraction for this schedule.
+ ///
+ ///
+ /// For Beginners: The bubble fraction is the percentage of time that stages are idle
+ /// (waiting for data). Lower is better. GPipe has ~(numStages-1)/numMicroBatches bubble.
+ /// 1F1B reduces this significantly. Zero Bubble schedules approach 0%.
+ ///
+ /// Total number of pipeline stages.
+ /// Number of micro-batches per mini-batch.
+ /// Estimated fraction of total time spent in pipeline bubbles (0.0 to 1.0).
+ double EstimateBubbleFraction(int numStages, int numMicroBatches);
+}
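// A minimal sketch of a GPipe-style schedule, assuming a hypothetical NaiveGPipeSchedule
// class (the PR's actual GPipeSchedule may differ). Each stage runs all forward passes,
// then all backward passes in reverse order. The bubble estimate uses the standard
// (numStages - 1) / (numMicroBatches + numStages - 1) formula, which reduces to the
// ~(numStages - 1) / numMicroBatches approximation above when micro-batches dominate.
using System.Collections.Generic;
using AiDotNet.Interfaces;

public sealed class NaiveGPipeSchedule : IPipelineSchedule
{
    public string Name => "NaiveGPipe";

    // One contiguous model chunk per rank; no virtual stages.
    public int VirtualStagesPerRank => 1;

    public IReadOnlyList<PipelineOperation> GetSchedule(int stageId, int numStages, int numMicroBatches)
    {
        var ops = new List<PipelineOperation>();

        // Fill phase: forward every micro-batch in order.
        for (int mb = 0; mb < numMicroBatches; mb++)
        {
            ops.Add(new PipelineOperation
            {
                Type = PipelineOperationType.Forward,
                MicroBatchIndex = mb,
                IsWarmup = mb < numStages - 1 // rough heuristic: pipeline is still filling
            });
        }

        // Drain phase: backward every micro-batch, most recently forwarded first.
        for (int mb = numMicroBatches - 1; mb >= 0; mb--)
        {
            ops.Add(new PipelineOperation
            {
                Type = PipelineOperationType.Backward,
                MicroBatchIndex = mb,
                IsCooldown = mb < numStages - 1 // rough heuristic: pipeline is draining
            });
        }

        return ops;
    }

    public double EstimateBubbleFraction(int numStages, int numMicroBatches)
        => (numStages - 1.0) / (numMicroBatches + numStages - 1.0);
}

// With 4 stages and 16 micro-batches, EstimateBubbleFraction(4, 16) is 3/19, or roughly 16%
// of total time spent idle under this schedule.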
+
+///
+/// Represents a single operation in the pipeline schedule.
+///
+///
+/// For Beginners: This is one instruction in the schedule, like
+/// "do forward pass on micro-batch #3" or "do backward pass on micro-batch #1".
+///
+/// Zero Bubble schedules split the backward pass into two operations:
+/// BackwardInput (compute activation gradients, on the critical path) and
+/// BackwardWeight (compute weight gradients, can fill bubbles). Traditional
+/// schedules use the combined Backward type.
+///
+///
+public class PipelineOperation
+{
+ ///
+ /// Gets the type of pipeline operation (Forward, Backward, BackwardInput, or BackwardWeight).
+ ///
+ public PipelineOperationType Type { get; init; }
+
+ ///
+ /// Gets the micro-batch index this operation works on.
+ ///
+ public int MicroBatchIndex { get; init; }
+
+ ///
+ /// Gets whether this is a warmup operation (part of pipeline fill phase).
+ ///
+ ///
+ /// For Beginners: During warmup, the pipeline is "filling up" - not all stages
+ /// are busy yet. After warmup, the pipeline runs at full utilization.
+ ///
+ public bool IsWarmup { get; init; }
+
+ ///
+ /// Gets whether this is a cooldown operation (part of pipeline drain phase).
+ ///
+ public bool IsCooldown { get; init; }
+
+ ///
+ /// Gets the virtual stage index for multi-stage schedules (0-based within this rank).
+ ///
+ ///
+ /// For Beginners: In multi-stage schedules like Interleaved 1F1B, each rank
+ /// holds multiple model chunks. This index tells which chunk to run this operation on.
+ /// For single-stage schedules, this is always 0.
+ ///
+ public int VirtualStageIndex { get; init; }
+}
+
+///
+/// Types of pipeline operations.
+///
+///
+///
+/// Traditional schedules (GPipe, 1F1B) use Forward and Backward.
+/// Zero Bubble schedules decompose Backward into BackwardInput + BackwardWeight
+/// to enable filling pipeline bubbles with weight gradient computation.
+///
+/// Reference: Qi et al., "Zero Bubble Pipeline Parallelism", ICLR 2024.
+/// https://arxiv.org/abs/2401.10241
+///
+public enum PipelineOperationType
+{
+ ///
+ /// Forward pass through the stage's layers.
+ ///
+ Forward,
+
+ ///
+ /// Combined backward pass (gradient computation) through the stage's layers.
+ /// Used by traditional schedules (GPipe, 1F1B) that don't split the backward pass.
+ ///
+ Backward,
+
+ ///
+ /// Backward pass computing only activation gradients (dL/dInput).
+ /// This is on the critical path - the upstream stage needs these gradients.
+ /// Used by Zero Bubble schedules (ZB-H1, ZB-H2, ZB-V).
+ ///
+ ///
+ /// For Beginners: This computes how much the loss changes when the input
+ /// to this stage changes. The previous stage needs this information to continue its
+ /// own backward pass, so it must be done promptly.
+ ///
+ BackwardInput,
+
+ ///
+ /// Backward pass computing only weight gradients (dL/dWeights).
+ /// This is NOT on the critical path - no other stage depends on it.
+ /// Can be deferred to fill pipeline bubbles.
+ /// Used by Zero Bubble schedules (ZB-H1, ZB-H2, ZB-V).
+ ///
+ ///
+ /// For Beginners: This computes how much the loss changes when the weights
+ /// of this stage change. Since no other stage needs this information, it can be computed
+ /// later to fill idle time (bubbles) in the pipeline.
+ ///
+ BackwardWeight
+}
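// A minimal sketch of how a stage runner might consume these operation types, assuming
// hypothetical forwardStep/backwardStep/backwardInputStep/backwardWeightStep callbacks.
// The schedule has already decided where each operation sits, so the runner simply
// dispatches on the operation type; with a Zero Bubble schedule, the BackwardWeight
// operations arrive at positions where the stage would otherwise be idle.
using System;
using System.Collections.Generic;
using AiDotNet.Interfaces;

public static class StageRunnerSketch
{
    public static void Run(
        IReadOnlyList<PipelineOperation> schedule,
        Action<int> forwardStep,
        Action<int> backwardStep,
        Action<int> backwardInputStep,
        Action<int> backwardWeightStep)
    {
        foreach (var op in schedule)
        {
            switch (op.Type)
            {
                case PipelineOperationType.Forward:
                    forwardStep(op.MicroBatchIndex);
                    break;

                case PipelineOperationType.Backward:
                    // Traditional schedules (GPipe, 1F1B): activation and weight gradients together.
                    backwardStep(op.MicroBatchIndex);
                    break;

                case PipelineOperationType.BackwardInput:
                    // Critical path: the upstream stage is blocked until these gradients are sent back.
                    backwardInputStep(op.MicroBatchIndex);
                    break;

                case PipelineOperationType.BackwardWeight:
                    // Off the critical path: the schedule places these in otherwise-idle slots,
                    // so executing them here fills the bubble.
                    backwardWeightStep(op.MicroBatchIndex);
                    break;
            }
        }
    }
}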