51 changes: 26 additions & 25 deletions src/AiDotNet.csproj
@@ -99,31 +99,6 @@
<Folder Include="WaveletFunctions\" />
</ItemGroup>

<!-- Source generator for auto-generating YAML configuration mappings -->
<ItemGroup>
<ProjectReference Include="AiDotNet.Generators\AiDotNet.Generators.csproj"
OutputItemType="Analyzer"
ReferenceOutputAssembly="false" />
</ItemGroup>

<!-- Exclude the source generator project files from this project's compilation -->
<ItemGroup>
<Compile Remove="AiDotNet.Generators\**\*.cs" />
<EmbeddedResource Remove="AiDotNet.Generators\**\*" />
<None Remove="AiDotNet.Generators\**\*" />
</ItemGroup>

<!-- Emit generated source files to disk for inspection/debugging -->
<PropertyGroup>
<EmitCompilerGeneratedFiles>true</EmitCompilerGeneratedFiles>
<CompilerGeneratedFilesOutputPath>Generated</CompilerGeneratedFilesOutputPath>
</PropertyGroup>

<!-- Exclude emitted generated files from compilation (they're already in-memory from the generator) -->
<ItemGroup>
<Compile Remove="Generated\**\*.cs" />
</ItemGroup>

<!-- Exclude the AiDotNet.Serving project files from this project -->
<ItemGroup>
<Compile Remove="AiDotNet.Serving\**\*.cs" />
@@ -151,6 +126,32 @@
<Compile Remove="Polyfills\LanguageFeaturePolyfills.cs" />
</ItemGroup>

<!-- Source generator for auto-generating YAML configuration mappings -->
<ItemGroup>
<ProjectReference Include="AiDotNet.Generators\AiDotNet.Generators.csproj"
OutputItemType="Analyzer"
ReferenceOutputAssembly="false"
SetTargetFramework="TargetFramework=netstandard2.0" />
</ItemGroup>

<!-- Exclude the source generator project files from this project's compilation -->
<ItemGroup>
<Compile Remove="AiDotNet.Generators\**\*.cs" />
<EmbeddedResource Remove="AiDotNet.Generators\**\*" />
<None Remove="AiDotNet.Generators\**\*" />
</ItemGroup>

<!-- Emit generated source files to disk for inspection/debugging (Debug only) -->
<PropertyGroup Condition="'$(Configuration)'=='Debug'">
<EmitCompilerGeneratedFiles>true</EmitCompilerGeneratedFiles>
<CompilerGeneratedFilesOutputPath>Generated</CompilerGeneratedFilesOutputPath>
</PropertyGroup>

<!-- Exclude emitted generated files from compilation (they're already in-memory from the generator) -->
<ItemGroup>
<Compile Remove="Generated\**\*.cs" />
</ItemGroup>

<!-- AiDotNet.Tensors NuGet package (spun out to separate repo) -->
<ItemGroup>
<PackageReference Include="AiDotNet.Tensors" Version="0.7.0" />
81 changes: 80 additions & 1 deletion src/AiModelBuilder.cs
@@ -164,6 +164,10 @@ public partial class AiModelBuilder<T, TInput, TOutput> : IAiModelBuilder<T, TInput, TOutput>
private ICommunicationBackend<T>? _distributedBackend;
private DistributedStrategy _distributedStrategy = DistributedStrategy.DDP;
private IShardingConfiguration<T>? _distributedConfiguration;
private IPipelinePartitionStrategy<T>? _pipelinePartitionStrategy;
private IPipelineSchedule? _pipelineSchedule;
private ActivationCheckpointConfig? _pipelineCheckpointConfig;
private int _pipelineMicroBatchCount = 1;
private ICrossValidator<T, TInput, TOutput>? _crossValidator;
private AgentConfiguration<T>? _agentConfig;
private AgentAssistanceOptions _agentOptions = AgentAssistanceOptions.Default;
@@ -1762,7 +1766,12 @@ private async Task<AiModelResult<T, TInput, TOutput>> BuildSupervisedInternalAsync
new DistributedTraining.ZeRO3Model<T, TInput, TOutput>(_model, shardingConfig),
new DistributedTraining.ZeRO3Optimizer<T, TInput, TOutput>(optimizer, shardingConfig)),
DistributedStrategy.PipelineParallel => CreateDistributedPair(
new DistributedTraining.PipelineParallelModel<T, TInput, TOutput>(_model, shardingConfig),
new DistributedTraining.PipelineParallelModel<T, TInput, TOutput>(
_model, shardingConfig,
microBatchSize: _pipelineMicroBatchCount,
partitionStrategy: _pipelinePartitionStrategy,
schedule: _pipelineSchedule,
checkpointConfig: _pipelineCheckpointConfig),
new DistributedTraining.PipelineParallelOptimizer<T, TInput, TOutput>(optimizer, shardingConfig)),
DistributedStrategy.TensorParallel => CreateDistributedPair(
new DistributedTraining.TensorParallelModel<T, TInput, TOutput>(_model, shardingConfig),
@@ -3790,6 +3799,10 @@ public IAiModelBuilder<T, TInput, TOutput> ConfigureMetaLearning(IMetaLearner<T,
///
/// You just train as normal - the distributed magic happens behind the scenes!
/// </para>
/// <para>
/// For pipeline parallelism, call <see cref="ConfigurePipelineParallelism"/> after this method
/// to customize scheduling, partitioning, and activation checkpointing.
/// </para>
/// </remarks>
public IAiModelBuilder<T, TInput, TOutput> ConfigureDistributedTraining(
ICommunicationBackend<T>? backend = null,
@@ -3802,6 +3815,72 @@ public IAiModelBuilder<T, TInput, TOutput> ConfigureDistributedTraining(
return this;
}

/// <summary>
/// Configures pipeline-specific options for pipeline parallel training.
/// </summary>
/// <param name="schedule">
/// Pipeline execution schedule. If null, uses GPipeSchedule.
/// Use <see cref="DistributedTraining.OneForwardOneBackwardSchedule"/> for reduced pipeline bubble (~12-15% vs ~50%).
/// </param>
/// <param name="partitionStrategy">
/// Strategy for partitioning layers across pipeline stages.
/// If null, uses uniform partitioning. Use <see cref="DistributedTraining.LoadBalancedPartitionStrategy{T}"/>
/// to balance computational cost across stages.
/// </param>
/// <param name="checkpointConfig">
/// Activation checkpointing configuration.
/// If null, checkpointing is disabled. Enable to reduce memory from O(L) to O(sqrt(L)).
/// </param>
/// <param name="microBatchCount">
/// Number of micro-batches to split the full batch into for pipeline execution.
/// Higher values reduce pipeline bubble but increase memory. Default: 1.
/// </param>
/// <returns>This builder instance for method chaining.</returns>
/// <remarks>
/// <para>
/// Call this after <see cref="ConfigureDistributedTraining"/> with
/// <c>DistributedStrategy.PipelineParallel</c> to customize pipeline scheduling,
/// partitioning, activation checkpointing, and micro-batch count.
/// </para>
/// <para>
/// <b>For Beginners:</b> This method fine-tunes how pipeline parallelism works.
/// You only need to call it if you want to change the defaults (GPipe schedule,
/// uniform partitioning, no checkpointing, 1 micro-batch).
/// </para>
/// <para>
/// <b>Example:</b>
/// <code>
/// var result = builder
/// .ConfigureModel(myModel)
/// .ConfigureDistributedTraining(strategy: DistributedStrategy.PipelineParallel)
/// .ConfigurePipelineParallelism(
/// schedule: new OneForwardOneBackwardSchedule(),
/// partitionStrategy: new LoadBalancedPartitionStrategy&lt;double&gt;(estimatedLayerSize: 1024),
/// checkpointConfig: new ActivationCheckpointConfig { Enabled = true },
/// microBatchCount: 8)
/// .Build(xTrain, yTrain);
/// </code>
/// </para>
/// </remarks>
public IAiModelBuilder<T, TInput, TOutput> ConfigurePipelineParallelism(
IPipelineSchedule? schedule = null,
IPipelinePartitionStrategy<T>? partitionStrategy = null,
ActivationCheckpointConfig? checkpointConfig = null,
int microBatchCount = 1)
{
if (microBatchCount <= 0)
{
throw new ArgumentOutOfRangeException(nameof(microBatchCount),
$"Micro-batch count must be at least 1, but was {microBatchCount}.");
}

_pipelineSchedule = schedule;
_pipelinePartitionStrategy = partitionStrategy;
_pipelineCheckpointConfig = checkpointConfig;
_pipelineMicroBatchCount = microBatchCount;
return this;
}

/// <summary>
/// Enables AI agent assistance during the model building process.
/// </summary>
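
As context for the microBatchCount guidance above: for a GPipe-style schedule with p pipeline stages and m micro-batches, the standard estimate is that the pipeline sits idle (the "bubble") for (p - 1)/(m + p - 1) of each training step, which is why the docs recommend raising microBatchCount above its default of 1. Below is a minimal sketch of that arithmetic only; the helper name and the m >= 4p rule of thumb are illustrative assumptions, not AiDotNet APIs.

using System;

static class PipelineBubbleSketch
{
    // GPipe-style bubble estimate: (p - 1) idle "slots" out of (m + p - 1) total,
    // so more micro-batches (m) shrink the idle fraction for a fixed stage count (p).
    static double BubbleFraction(int stages, int microBatches) =>
        (stages - 1) / (double)(microBatches + stages - 1);

    static void Main()
    {
        Console.WriteLine(BubbleFraction(4, 1));  // 0.75  -- the default of 1 micro-batch leaves 4 stages 75% idle
        Console.WriteLine(BubbleFraction(4, 8));  // ~0.27 -- the microBatchCount: 8 example from the docs above
        Console.WriteLine(BubbleFraction(4, 16)); // ~0.16 -- rule of thumb m >= 4p keeps the bubble under ~20%
    }
}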
140 changes: 140 additions & 0 deletions src/DistributedTraining/ActivationCheckpointConfig.cs
@@ -0,0 +1,140 @@
using System;

namespace AiDotNet.DistributedTraining;

/// <summary>
/// Configuration for activation checkpointing in pipeline parallel training.
/// </summary>
/// <remarks>
/// <para>
/// Activation checkpointing (also called gradient checkpointing) trades compute for memory
/// by only storing activations at checkpoint layers during the forward pass. Intermediate
/// activations are recomputed from the nearest checkpoint during the backward pass.
/// </para>
/// <para><b>For Beginners:</b> During training, the forward pass must save intermediate results
/// (activations) so the backward pass can compute gradients. For very deep models, storing all
/// these activations uses enormous amounts of memory.
///
/// Activation checkpointing is like taking notes at chapter boundaries instead of every page:
/// - Without checkpointing: Save every activation (lots of memory, no recomputation)
/// - With checkpointing: Save every Nth activation, recompute the rest (less memory, more compute)
///
/// Memory savings: O(L) → O(sqrt(L)) where L = number of layers.
/// For 100 layers, this reduces memory from 100 activations to ~10 activations.
///
/// The trade-off is ~33% more compute time, but this enables training models that otherwise
/// wouldn't fit in memory.
/// </para>
/// <para><b>Reference:</b> Chen et al., "Training Deep Nets with Sublinear Memory Cost", 2016.
/// https://arxiv.org/abs/1604.06174</para>
/// </remarks>
public class ActivationCheckpointConfig
{
private int _checkpointEveryNLayers = 10;
private int _maxActivationsInMemory;

/// <summary>
/// Gets or sets whether activation checkpointing is enabled.
/// </summary>
/// <remarks>
/// <para><b>For Beginners:</b> Set this to true to enable memory savings. Default is false
/// (no checkpointing, standard behavior).</para>
/// </remarks>
public bool Enabled { get; set; }

/// <summary>
/// Gets or sets how often to save a checkpoint (every N layers).
/// </summary>
/// <remarks>
/// <para><b>For Beginners:</b> Lower values save more activations (more memory, less recomputation).
/// Higher values save fewer (less memory, more recomputation).
///
/// Optimal value is approximately sqrt(total_layers) for minimum total cost.
/// For a 100-layer model, checkpointing every 10 layers is a good default.
///
/// Default: 10 layers between checkpoints.</para>
/// </remarks>
/// <exception cref="ArgumentOutOfRangeException">Thrown when value is less than 1.</exception>
public int CheckpointEveryNLayers
{
get => _checkpointEveryNLayers;
set
{
if (value < 1)
{
throw new ArgumentOutOfRangeException(nameof(CheckpointEveryNLayers),
$"CheckpointEveryNLayers must be at least 1, but was {value}. " +
"A value of 0 would cause division-by-zero in interval-based checkpointing.");
}
_checkpointEveryNLayers = value;
}
}

/// <summary>
/// Gets or sets the recomputation strategy to use during the backward pass.
/// </summary>
/// <remarks>
/// <para><b>For Beginners:</b>
/// - Selective: Only recompute activations that are needed and not checkpointed (recommended)
/// - Full: Recompute all non-checkpointed activations from the previous checkpoint
/// - None: Don't recompute, equivalent to no checkpointing (for testing/debugging)
/// </para>
/// </remarks>
public RecomputeStrategy RecomputeStrategy { get; set; } = RecomputeStrategy.Selective;

/// <summary>
/// Gets or sets the maximum number of activations to keep in memory simultaneously.
/// </summary>
/// <remarks>
/// <para><b>For Beginners:</b> This caps how many activations are stored at once.
/// Set to 0 for no limit (uses CheckpointEveryNLayers to determine storage).
/// A non-zero value overrides CheckpointEveryNLayers by dynamically adjusting
/// the checkpoint frequency to stay within the memory budget.</para>
/// </remarks>
/// <exception cref="ArgumentOutOfRangeException">Thrown when value is negative.</exception>
public int MaxActivationsInMemory
{
get => _maxActivationsInMemory;
set
{
if (value < 0)
{
throw new ArgumentOutOfRangeException(nameof(MaxActivationsInMemory),
$"MaxActivationsInMemory must be non-negative, but was {value}. " +
"Use 0 for no limit.");
}
_maxActivationsInMemory = value;
}
}

/// <summary>
/// Gets or sets whether to checkpoint the very first layer's input.
/// </summary>
/// <remarks>
/// <para><b>For Beginners:</b> The first layer's input is always needed for the backward pass.
/// If true, it's saved as a checkpoint. If false, the caller must ensure the input is
/// available during the backward pass (which is usually the case).</para>
/// </remarks>
public bool CheckpointFirstLayer { get; set; } = true;
}

/// <summary>
/// Strategy for recomputing activations during the backward pass.
/// </summary>
public enum RecomputeStrategy
{
/// <summary>
/// Only recompute activations that are needed for the current backward step.
/// This is the most memory-efficient but requires careful bookkeeping.
/// </summary>
Selective,

/// <summary>
/// Recompute all activations between the two nearest checkpoints during backward.
/// Simpler implementation but may do slightly more work than necessary.
/// </summary>
Full,

/// <summary>
/// No recomputation. Equivalent to disabled checkpointing. Useful for debugging.
/// </summary>
None
}