NNX-Control

High-performance JAX control environments with end-to-end PPO training in a single file.

Features

  • Pure JAX Implementation: Fully jittable, vmappable, and scannable control environments with no Gym/Gymnax dependencies
  • JAX-Native Rendering: Hardware-accelerated rendering using pure JAX operations
  • NNX-Based PPO: Proximal Policy Optimization implemented with Flax NNX (not Flax Linen) for modern JAX neural networks
  • One-File Implementation: Complete environment, training loop, and rendering in a single, self-contained file
  • Parallel Training: Vectorized rollouts across multiple parallel environments for maximum throughput

Supported Environments

Currently supports:

  • CartPole-v1: Classic cart-pole balancing task with pure-JAX physics simulation

Performance

The implementation leverages JAX's compilation and parallelization capabilities:

  • JIT compilation for fast environment steps and policy updates
  • vmap for batched environment rollouts (64 parallel envs)
  • scan for efficient sequential operations in training loops
  • Pure-JAX rendering with hardware acceleration
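
The jit/vmap/scan pattern above can be sketched as follows. The toy `step` function here is illustrative, not the repository's actual CartPole dynamics; the shapes (64 parallel environments, 500 timesteps) match the configuration described later in this README.

```python
import jax
import jax.numpy as jnp

# Toy step function standing in for the environment physics
# (illustrative only, not the repository's CartPole dynamics).
def step(state, action):
    new_state = state + 0.01 * (2.0 * action - 1.0)
    reward = 1.0
    return new_state, reward

# vmap batches the step across 64 parallel environments,
# and jit compiles the whole batched step once.
batched_step = jax.jit(jax.vmap(step))

states = jnp.zeros(64)
actions = jnp.ones(64)
next_states, rewards = batched_step(states, actions)

# scan unrolls a 500-step rollout without a Python loop,
# collecting the per-step rewards along the way.
def rollout_step(states, _):
    next_states, rewards = batched_step(states, jnp.ones(64))
    return next_states, rewards

final_states, reward_trace = jax.lax.scan(rollout_step, states, None, length=500)
```

Because the loop body is a pure function, `jax.lax.scan` traces it once and reuses the compiled step for all 500 iterations.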

Training Results

[Figure: training curve showing average reward per episode]

[Figure: trained policy evaluated on 4 parallel environments]

Usage

import jax
from cartpole import CartPole

# Create environment
env = CartPole()
params = env.default_params

# Reset environment
key = jax.random.PRNGKey(0)
obs, state = env.reset(key, params)

# Step environment
action = 1  # 0 = left, 1 = right
obs_next, state_next, reward, done, info = env.step(key, state, action, params)

# Render state (returns RGB array)
rgb_array = env.render(state, params)

Training is fully integrated; just run:

python cartpole.py

Related Projects

For a similar high-performance gridworld environment, check out NNX-Gridworld: a fast, jittable, vmapped gridworld with JAX-based rendering and a one-file PPO for both vision-based tasks (with JAX-rendered observations) and state-based tasks.

Technical Details

Environment

  • Pure functional API with no mutable state
  • Physics simulation using Euler integration
  • Auto-reset on episode termination for continuous training
  • Compatible with standard RL benchmarking protocols
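
The pure-functional step with auto-reset can be sketched as below. The 1-D dynamics, force magnitude, and position limit are simplified placeholders; the real environment uses the standard CartPole-v1 physics.

```python
import jax
import jax.numpy as jnp

DT = 0.02       # Euler integration timestep (illustrative)
X_LIMIT = 2.4   # termination bound on position (illustrative)

def step(key, state, action):
    x, x_dot = state
    force = jnp.where(action == 1, 10.0, -10.0)
    # Explicit Euler integration: x_{t+1} = x_t + dt * x_dot_t
    x_new = x + DT * x_dot
    x_dot_new = x_dot + DT * force
    done = jnp.abs(x_new) > X_LIMIT
    # Auto-reset: on termination, swap in a freshly sampled state so a
    # vmapped training loop never has to branch on episode boundaries.
    reset_state = jax.random.uniform(key, (2,), minval=-0.05, maxval=0.05)
    next_state = jnp.where(done, reset_state, jnp.stack([x_new, x_dot_new]))
    return next_state, 1.0, done

key = jax.random.PRNGKey(0)
state = jnp.zeros(2)
next_state, reward, done = jax.jit(step)(key, state, 1)
```

Since `step` takes the state as an argument and returns the next state, it has no mutable internals and composes directly with `jit`, `vmap`, and `scan`.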

PPO Implementation

  • Actor-critic architecture with separate networks
  • Generalized Advantage Estimation (GAE)
  • Clipped surrogate objective (ε=0.2)
  • Value function loss with coefficient weighting
  • 4 update iterations per rollout
  • Normalized advantages and returns for training stability
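
The GAE computation above can be written as a reverse `jax.lax.scan`, a common pure-JAX formulation. Function and parameter names here are illustrative, not the repository's exact API.

```python
import jax
import jax.numpy as jnp

def compute_gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    # Walk the trajectory backwards, accumulating
    # A_t = delta_t + gamma * lam * (1 - done_t) * A_{t+1},
    # where delta_t = r_t + gamma * V_{t+1} * (1 - done_t) - V_t.
    def backward(carry, xs):
        gae, next_value = carry
        reward, value, done = xs
        not_done = 1.0 - done
        delta = reward + gamma * next_value * not_done - value
        gae = delta + gamma * lam * not_done * gae
        return (gae, value), gae

    _, advantages = jax.lax.scan(
        backward, (0.0, last_value), (rewards, values, dones), reverse=True
    )
    returns = advantages + values
    return advantages, returns

rewards = jnp.ones(5)
values = jnp.zeros(5)
dones = jnp.zeros(5)
adv, ret = jax.jit(compute_gae)(rewards, values, dones, 0.0)
```

With `reverse=True` the scan consumes the trajectory back-to-front but still emits outputs aligned with the inputs, so no explicit flipping is needed.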

Architecture

  • Feed-forward networks: 4 → 64 → 64 → {2, 1} (actor/critic outputs)
  • ReLU activations
  • Adam optimizer (lr=3e-4)
  • Batch size: 64 parallel environments × 500 timesteps

Requirements

  • JAX
  • Flax (NNX)
  • Optax
  • Matplotlib (for plotting)
  • Imageio (for GIF generation)
