# Reinforcement Learning
AiDotNet provides 80+ reinforcement learning agents for training AI in interactive environments.

Reinforcement learning (RL) trains agents to make decisions by interacting with environments and learning from rewards:

```
Agent → Action → Environment → State, Reward → Agent (learns)
```
## Quick Start

```csharp
using AiDotNet.ReinforcementLearning;

// Create environment
var env = new CartPoleEnvironment<float>();

// Create agent
var agent = new DQNAgent<float>(new DQNConfig<float>
{
    StateSize = env.ObservationSpace.Shape[0],
    ActionSize = env.ActionSpace.N,
    HiddenLayers = new[] { 128, 128 },
    LearningRate = 1e-3f,
    Gamma = 0.99f,
    EpsilonStart = 1.0f,
    EpsilonEnd = 0.01f,
    EpsilonDecay = 0.995f,
    ReplayBufferSize = 100000,
    BatchSize = 64
});

// Training loop
for (int episode = 0; episode < 500; episode++)
{
    var state = env.Reset();
    float totalReward = 0;
    while (!env.Done)
    {
        var action = agent.SelectAction(state);
        var (nextState, reward, done, info) = env.Step(action);
        agent.Remember(state, action, reward, nextState, done);
        agent.Learn();
        state = nextState;
        totalReward += reward;
    }
    Console.WriteLine($"Episode {episode}: Reward = {totalReward}");
}
```
## DQN (Deep Q-Network)

```csharp
var config = new DQNConfig<float>
{
    StateSize = 4,
    ActionSize = 2,
    HiddenLayers = new[] { 128, 128 },

    // Learning parameters
    LearningRate = 1e-3f,
    Gamma = 0.99f, // Discount factor

    // Exploration
    EpsilonStart = 1.0f,
    EpsilonEnd = 0.01f,
    EpsilonDecay = 0.995f,

    // Experience replay
    ReplayBufferSize = 100000,
    BatchSize = 64,

    // Target network
    TargetUpdateFrequency = 100
};
var agent = new DQNAgent<float>(config);
```
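With these defaults, exploration starts fully random and anneals toward near-greedy behavior. A minimal standalone sketch of the implied schedule, assuming epsilon is multiplied by `EpsilonDecay` once per episode (the exact point at which AiDotNet applies the decay may differ):

```csharp
// Hypothetical sketch of the epsilon-greedy schedule implied by the config above.
float epsilon = 1.0f;              // EpsilonStart
const float epsilonEnd = 0.01f;    // EpsilonEnd
const float epsilonDecay = 0.995f; // EpsilonDecay

for (int episode = 0; episode < 1000; episode++)
{
    // Act randomly with probability epsilon, greedily otherwise, then decay.
    epsilon = Math.Max(epsilonEnd, epsilon * epsilonDecay);
}
// epsilon hits the 0.01 floor after about 919 decays (ln 0.01 / ln 0.995 ≈ 919).
```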
## Double DQN

```csharp
var agent = new DoubleDQNAgent<float>(new DoubleDQNConfig<float>
{
    StateSize = 4,
    ActionSize = 2,
    // ... same as DQN
});
```
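The only change from DQN is the bootstrap target: the online network picks the next action and the target network evaluates it, which counteracts Q-value overestimation:

```math
y = r + \gamma \, Q_{\theta^-}\!\big(s', \arg\max_{a'} Q_\theta(s', a')\big)
```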
## Dueling DQN

```csharp
var agent = new DuelingDQNAgent<float>(new DuelingDQNConfig<float>
{
    StateSize = 4,
    ActionSize = 2,
    ValueStreamLayers = new[] { 128 },
    AdvantageStreamLayers = new[] { 128 },
    // ... other config
});
```
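The value stream estimates how good the state is, the advantage stream how much better each action is than average; the two are recombined into Q-values with the mean advantage subtracted for identifiability:

```math
Q(s, a) = V(s) + A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a'} A(s, a')
```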
## Rainbow

Combines all DQN improvements:

```csharp
var agent = new RainbowAgent<float>(new RainbowConfig<float>
{
    StateSize = 4,
    ActionSize = 2,

    // Distributional RL
    NumAtoms = 51,
    VMin = -10,
    VMax = 10,

    // Prioritized Experience Replay
    PriorityAlpha = 0.6f,
    PriorityBetaStart = 0.4f,

    // Noisy Networks (replaces epsilon-greedy)
    NoisyNetSigma = 0.5f,

    // N-step returns
    NSteps = 3
});
```
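With `NumAtoms = 51` and value bounds [-10, 10], the distributional head predicts a categorical distribution over 51 evenly spaced return values (the C51 support). A standalone sketch of that support:

```csharp
// Sketch: the fixed support of the C51 return distribution (not AiDotNet API).
const int numAtoms = 51;
const float vMin = -10f, vMax = 10f;
float deltaZ = (vMax - vMin) / (numAtoms - 1); // atom spacing = 0.4

var support = new float[numAtoms];
for (int i = 0; i < numAtoms; i++)
    support[i] = vMin + i * deltaZ; // -10.0, -9.6, ..., 9.6, 10.0
```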
## REINFORCE

```csharp
var agent = new REINFORCEAgent<float>(new REINFORCEConfig<float>
{
    StateSize = 4,
    ActionSize = 2,
    HiddenLayers = new[] { 128 },
    LearningRate = 1e-3f,
    Gamma = 0.99f,
    EntropyCoefficient = 0.01f
});
```
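REINFORCE is the classic Monte Carlo policy gradient: after each episode it nudges the policy toward actions in proportion to the discounted return $G_t$ that followed them:

```math
\nabla_\theta J(\theta) = \mathbb{E}_{\pi_\theta}\!\left[ \sum_t \nabla_\theta \log \pi_\theta(a_t \mid s_t) \, G_t \right]
```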
## PPO (Proximal Policy Optimization)

Most popular policy gradient method:

```csharp
var agent = new PPOAgent<float>(new PPOConfig<float>
{
    StateSize = 4,
    ActionSize = 2,
    HiddenLayers = new[] { 256, 256 },

    // PPO-specific
    ClipRatio = 0.2f,
    ValueCoefficient = 0.5f,
    EntropyCoefficient = 0.01f,

    // GAE
    Gamma = 0.99f,
    Lambda = 0.95f,

    // Training
    LearningRate = 3e-4f,
    NumEpochs = 10,
    MiniBatchSize = 64,
    NumEnvs = 8 // Parallel environments
});
```
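`ClipRatio` is the ε in PPO's clipped surrogate objective, which keeps the importance ratio $r_t(\theta) = \pi_\theta(a_t \mid s_t) / \pi_{\theta_\text{old}}(a_t \mid s_t)$ close to 1 so repeated epochs on the same batch cannot push the policy too far:

```math
L^{\mathrm{CLIP}}(\theta) = \mathbb{E}_t\!\left[ \min\!\big( r_t(\theta)\,\hat{A}_t,\; \operatorname{clip}(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon)\,\hat{A}_t \big) \right]
```

`Gamma` and `Lambda` define the GAE advantage estimate $\hat{A}_t = \sum_{l \ge 0} (\gamma\lambda)^l \delta_{t+l}$ with $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$; λ trades bias against variance.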
## A2C / A3C

```csharp
// Advantage Actor-Critic
var agent = new A2CAgent<float>(new A2CConfig<float>
{
    StateSize = 4,
    ActionSize = 2,
    HiddenLayers = new[] { 256, 256 },
    LearningRate = 7e-4f,
    ValueCoefficient = 0.5f,
    EntropyCoefficient = 0.01f,
    MaxGradNorm = 0.5f,
    NumEnvs = 8
});

// Asynchronous (multi-threaded)
var a3cAgent = new A3CAgent<float>(new A3CConfig<float>
{
    // ... same config
    NumWorkers = 8
});
```
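`ValueCoefficient` and `EntropyCoefficient` weight the terms of the standard shared actor-critic loss (the exact composition in AiDotNet is assumed here): a policy-gradient term, a value-regression term, and an entropy bonus that discourages premature convergence to a deterministic policy:

```math
L = L_{\text{policy}} + c_v \, L_{\text{value}} - c_H \, \mathcal{H}(\pi)
```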
## TRPO

```csharp
var agent = new TRPOAgent<float>(new TRPOConfig<float>
{
    StateSize = 4,
    ActionSize = 2,
    MaxKL = 0.01f, // KL divergence constraint
    Damping = 0.1f,
    LineSearchSteps = 10
});
```
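`MaxKL` is the trust-region radius δ: TRPO maximizes the surrogate objective subject to a hard KL constraint between old and new policies, solved with conjugate gradient (stabilized by `Damping`) and a backtracking line search (`LineSearchSteps`):

```math
\max_\theta \; \mathbb{E}_t\!\big[ r_t(\theta)\,\hat{A}_t \big] \quad \text{s.t.} \quad \mathbb{E}_t\!\big[ D_{\mathrm{KL}}\big( \pi_{\theta_\text{old}}(\cdot \mid s_t) \,\big\|\, \pi_\theta(\cdot \mid s_t) \big) \big] \le \delta
```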
## SAC (Soft Actor-Critic)

State-of-the-art for continuous control:

```csharp
var agent = new SACAgent<float>(new SACConfig<float>
{
    StateSize = 8,
    ActionSize = 4, // Continuous actions
    ActionBounds = (-1f, 1f),
    HiddenLayers = new[] { 256, 256 },

    // SAC-specific
    Alpha = 0.2f, // Temperature (entropy coefficient)
    AutoTuneAlpha = true, // Automatic temperature tuning
    TargetEntropy = -4f, // -dim(A)

    // Learning
    LearningRate = 3e-4f,
    Gamma = 0.99f,
    Tau = 0.005f, // Soft target update
    BatchSize = 256,
    ReplayBufferSize = 1000000
});
```
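SAC maximizes return plus policy entropy, with the temperature α trading the two off. With `AutoTuneAlpha`, α is adjusted online so the policy's entropy tracks `TargetEntropy`; the common heuristic is −dim(A), hence −4 for this 4-dimensional action space:

```math
J(\pi) = \sum_t \mathbb{E}_{(s_t, a_t) \sim \pi}\!\big[ r(s_t, a_t) + \alpha \, \mathcal{H}\big( \pi(\cdot \mid s_t) \big) \big]
```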
## TD3 (Twin Delayed DDPG)

```csharp
var agent = new TD3Agent<float>(new TD3Config<float>
{
    StateSize = 8,
    ActionSize = 4,
    ActionBounds = (-1f, 1f),
    HiddenLayers = new[] { 256, 256 },

    // TD3-specific
    PolicyDelay = 2, // Update policy every 2 critic updates
    TargetNoise = 0.2f,
    NoiseClip = 0.5f,
    ExplorationNoise = 0.1f,
    LearningRate = 3e-4f,
    Gamma = 0.99f,
    Tau = 0.005f
});
```
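`TargetNoise` (σ) and `NoiseClip` (c) implement target policy smoothing: clipped Gaussian noise is added to the target action so the critics cannot exploit narrow peaks in the Q-function:

```math
a' = \mu_{\theta^-}(s') + \operatorname{clip}(\epsilon, -c, c), \qquad \epsilon \sim \mathcal{N}(0, \sigma^2)
```

The smoothed action is then clipped to `ActionBounds` before the target critics evaluate it.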
## DDPG

```csharp
var agent = new DDPGAgent<float>(new DDPGConfig<float>
{
    StateSize = 8,
    ActionSize = 4,
    ActionBounds = (-1f, 1f),
    ActorLearningRate = 1e-4f,
    CriticLearningRate = 1e-3f,
    OUNoiseTheta = 0.15f,
    OUNoiseSigma = 0.2f
});
```
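`OUNoiseTheta` and `OUNoiseSigma` parameterize Ornstein-Uhlenbeck exploration noise, a mean-reverting random walk that yields temporally correlated exploration (useful for physical control). A hypothetical standalone sketch of the common unit-time-step discretization:

```csharp
// Sketch of discretized OU noise (not AiDotNet API): x += theta * (mu - x) + sigma * N(0, 1).
var rng = new Random();
const float theta = 0.15f, sigma = 0.2f, mu = 0f;
float x = 0f;

float NextNoise()
{
    // Standard normal sample via the Box-Muller transform
    float u1 = 1f - (float)rng.NextDouble();
    float u2 = (float)rng.NextDouble();
    float gaussian = MathF.Sqrt(-2f * MathF.Log(u1)) * MathF.Cos(2f * MathF.PI * u2);

    x += theta * (mu - x) + sigma * gaussian;
    return x;
}
```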
## Model-Based RL (World Models)

```csharp
var worldModel = new WorldModel<float>(new WorldModelConfig<float>
{
    StateSize = 8,
    ActionSize = 4,
    LatentSize = 32,
    HiddenSize = 256,
    SequenceLength = 16
});

// Train world model
await worldModel.TrainAsync(trajectories);

// Plan using learned model
var agent = new ModelPredictiveControlAgent<float>(
    worldModel: worldModel,
    planningHorizon: 15,
    numSimulations: 1000);
```
## Dreamer

```csharp
var agent = new DreamerAgent<float>(new DreamerConfig<float>
{
    StateSize = 64,
    ActionSize = 4,
    ImageObservation = true,

    // World model
    LatentSize = 30,
    RecurrentStateSize = 200,

    // Imagination
    ImaginationHorizon = 15,
    NumImaginations = 100,

    // Training
    ModelLearningRate = 6e-4f,
    ActorLearningRate = 8e-5f,
    CriticLearningRate = 8e-5f
});
```
## MADDPG (Multi-Agent RL)

```csharp
var agents = new MADDPGAgents<float>(new MADDPGConfig<float>
{
    NumAgents = 3,
    StateSize = 8,
    ActionSize = 2,
    SharedCritic = true,
    LearningRate = 1e-3f
});

// Training with multiple agents
var states = env.Reset();
while (!env.Done)
{
    var actions = agents.SelectActions(states);
    var (nextStates, rewards, dones, infos) = env.Step(actions);
    agents.Remember(states, actions, rewards, nextStates, dones);
    agents.Learn();
    states = nextStates;
}
```
## QMIX

```csharp
var agent = new QMIXAgent<float>(new QMIXConfig<float>
{
    NumAgents = 5,
    StateSize = 10,
    ActionSize = 5,
    MixingEmbedSize = 32,
    HypernetEmbedSize = 64
});
```
## Built-in Environments

```csharp
// Classic control
var cartPole = new CartPoleEnvironment<float>();
var mountainCar = new MountainCarEnvironment<float>();
var pendulum = new PendulumEnvironment<float>();
var acrobot = new AcrobotEnvironment<float>();

// Continuous control
var halfCheetah = new HalfCheetahEnvironment<float>();
var humanoid = new HumanoidEnvironment<float>();
var ant = new AntEnvironment<float>();

// Atari (requires ROM)
var breakout = new AtariEnvironment<float>("Breakout-v4");
var pong = new AtariEnvironment<float>("Pong-v4");
```

## Custom Environments

```csharp
public class MyEnvironment : Environment<float>
{
    public override Space ObservationSpace => new BoxSpace(low: -10, high: 10, shape: new[] { 4 });
    public override Space ActionSpace => new DiscreteSpace(3);

    public override float[] Reset()
    {
        // Reset environment and return initial state
        return new float[] { 0, 0, 0, 0 };
    }

    public override (float[] state, float reward, bool done, Dictionary<string, object> info) Step(int action)
    {
        // Execute action and return next state, reward, done flag
        var nextState = ComputeNextState(action);
        var reward = ComputeReward();
        var done = IsTerminal();
        return (nextState, reward, done, new Dictionary<string, object>());
    }
}
```
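`ComputeNextState`, `ComputeReward`, and `IsTerminal` are placeholders for your own dynamics. A hypothetical way to fill them in (a 1-D agent that must walk to position +5), just to show the expected shapes:

```csharp
// Hypothetical members to paste into MyEnvironment; not part of the AiDotNet API.
private float _position;

private float[] ComputeNextState(int action)
{
    _position += action - 1; // actions: 0 = left, 1 = stay, 2 = right
    return new float[] { _position, 0, 0, 0 };
}

private float ComputeReward() => _position >= 5f ? 1f : -0.01f; // small step penalty
private bool IsTerminal() => _position >= 5f || _position <= -5f; // reached goal or wandered off
```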
## Vectorized Environments

```csharp
// Run multiple environments in parallel
var envs = new VectorizedEnvironment<float>(
    makeEnv: () => new CartPoleEnvironment<float>(),
    numEnvs: 8);

var states = envs.Reset(); // Shape: [8, 4]
var (nextStates, rewards, dones, infos) = envs.Step(actions); // actions shape: [8]
```

## Training Callbacks

```csharp
var callbacks = new CallbackList<float>
{
    new EpisodeLoggerCallback(logEvery: 10),
    new TensorBoardCallback(logDir: "./logs"),
    new CheckpointCallback(saveEvery: 100, path: "./checkpoints"),
    new EarlyStoppingCallback(patience: 50, minDelta: 1.0f)
};
await agent.TrainAsync(env, numEpisodes: 1000, callbacks: callbacks);
```

## Reward Shaping

```csharp
var shapedEnv = new RewardShapingWrapper<float>(
    env: baseEnv,
    shapingFunction: (state, action, nextState, reward) =>
    {
        // Add potential-based shaping
        var shaping = Potential(nextState) - Potential(state);
        return reward + shaping;
    });
```
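Shaping with a potential difference (in its discounted form, $F = \gamma\,\Phi(s') - \Phi(s)$) provably leaves the optimal policy unchanged, so it is a safe way to densify sparse rewards; the undiscounted difference above is a close approximation for γ near 1. `Potential` is your own heuristic; a hypothetical example for a goal-reaching task:

```csharp
// Hypothetical potential function: negative distance to a goal, so states
// closer to the goal have higher potential and progress is rewarded.
static float Potential(float[] state)
{
    const float goalX = 5f, goalY = 5f;
    float dx = state[0] - goalX;
    float dy = state[1] - goalY;
    return -MathF.Sqrt(dx * dx + dy * dy);
}
```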
## Curriculum Learning

```csharp
var curriculum = new CurriculumLearning<float>(
    initialDifficulty: 0.1f,
    maxDifficulty: 1.0f,
    successThreshold: 0.8f,
    windowSize: 100);

// Automatically increases difficulty as agent improves
var env = curriculum.WrapEnvironment(baseEnv);
```

## Saving and Loading

```csharp
// Save agent
await agent.SaveAsync("./agent_checkpoint");

// Load agent
var loadedAgent = await DQNAgent<float>.LoadAsync("./agent_checkpoint");

// Save just the policy network
await agent.SavePolicyAsync("./policy.aidotnet");
```

## Algorithm Comparison

| Algorithm | Type | Action Space | Sample Efficiency | Stability |
|---|---|---|---|---|
| DQN | Value | Discrete | Medium | High |
| Rainbow | Value | Discrete | High | High |
| PPO | Policy | Both | Low | Very High |
| SAC | Actor-Critic | Continuous | High | High |
| TD3 | Actor-Critic | Continuous | High | High |
| DDPG | Actor-Critic | Continuous | Medium | Medium |
| A2C | Policy | Both | Low | Medium |
## See Also

- Neural Networks - Build custom network architectures
- Distributed Training - Scale training across GPUs
- Optimizers - Training optimization