High-performance JAX control environments with end-to-end PPO training in a single file.
- Pure JAX Implementation: Fully JITtable, vmappable, and scannable control environments with no Gym/Gymnax dependencies
- JAX-Native Rendering: Hardware-accelerated rendering using pure JAX operations
- NNX-Based PPO: Proximal Policy Optimization implemented with Flax NNX (not Flax Linen) for modern JAX neural networks
- One-File Implementation: Complete environment, training loop, and rendering in a single, self-contained file
- Parallel Training: Vectorized rollouts across multiple parallel environments for maximum throughput
Currently supports:
- CartPole-v1: Classic cart-pole balancing task with pure-JAX physics simulation
The implementation leverages JAX's compilation and parallelization capabilities:
- JIT compilation for fast environment steps and policy updates
- `vmap` for batched environment rollouts (64 parallel envs; see the sketch below)
- `scan` for efficient sequential operations in training loops
- Pure-JAX rendering with hardware acceleration
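The snippet below is a minimal sketch of how these pieces compose, assuming the `reset`/`step` signatures shown in the usage example further down; the exact batching strategy in `cartpole.py` may differ.

```python
import jax
import jax.numpy as jnp
from cartpole import CartPole

env = CartPole()
params = env.default_params

# One PRNG key per parallel environment (64, matching the training setup).
keys = jax.random.split(jax.random.PRNGKey(0), 64)

# vmap the pure reset/step functions over the batch of envs, then JIT the result.
batched_reset = jax.jit(jax.vmap(env.reset, in_axes=(0, None)))
batched_step = jax.jit(jax.vmap(env.step, in_axes=(0, 0, 0, None)))

obs, states = batched_reset(keys, params)        # obs: (64, 4)
actions = jnp.ones(64, dtype=jnp.int32)          # push right in every env (illustrative)
obs, states, rewards, dones, _ = batched_step(keys, states, actions, params)

# lax.scan drives a sequential rollout; here 500 steps with the same fixed action.
def rollout_step(states, step_keys):
    _, states, rewards, _, _ = batched_step(step_keys, states, actions, params)
    return states, rewards

# A (500, 64) grid of per-step, per-env keys.
step_keys = jax.vmap(lambda k: jax.random.split(k, 64))(jax.random.split(jax.random.PRNGKey(1), 500))
_, reward_trace = jax.lax.scan(rollout_step, states, step_keys)  # reward_trace: (500, 64)
```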
import jax
from cartpole import CartPole
# Create environment
env = CartPole()
params = env.default_params
# Reset environment
key = jax.random.PRNGKey(0)
obs, state = env.reset(key, params)
# Step environment
action = 1 # 0 = left, 1 = right
obs_next, state_next, reward, done, info = env.step(key, state, action, params)
# Render state (returns RGB array)
rgb_array = env.render(state, params)

Training is fully integrated - just run:

python cartpole.py

For a similar high-performance gridworld environment implementation, check out NNX-Gridworld - a super-fast, JITtable, vmapped gridworld with JAX-based rendering and a one-file PPO for both vision-based (with JAX-rendered observations) and state-based tasks.
- Pure functional API with no mutable state
- Physics simulation using Euler integration (sketched below)
- Auto-reset on episode termination for continuous training
- Compatible with standard RL benchmarking protocols
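A minimal sketch of one explicit-Euler update is shown below, assuming the standard CartPole-v1 constants (gravity 9.8, cart mass 1.0, pole mass 0.1, half-pole length 0.5, force magnitude 10.0, Δt = 0.02); the state layout and constant names in `cartpole.py` may differ.

```python
import jax.numpy as jnp

def euler_step(x, x_dot, theta, theta_dot, action):
    """One Euler update of the cart-pole state, using standard CartPole-v1 constants."""
    gravity, masscart, masspole, length, force_mag, tau = 9.8, 1.0, 0.1, 0.5, 10.0, 0.02
    total_mass = masscart + masspole
    polemass_length = masspole * length

    force = jnp.where(action == 1, force_mag, -force_mag)  # 0 = push left, 1 = push right
    costheta, sintheta = jnp.cos(theta), jnp.sin(theta)

    # Accelerations from the classic cart-pole equations of motion.
    temp = (force + polemass_length * theta_dot**2 * sintheta) / total_mass
    theta_acc = (gravity * sintheta - costheta * temp) / (
        length * (4.0 / 3.0 - masspole * costheta**2 / total_mass)
    )
    x_acc = temp - polemass_length * theta_acc * costheta / total_mass

    # Euler integration: advance positions and velocities with the current derivatives.
    x = x + tau * x_dot
    x_dot = x_dot + tau * x_acc
    theta = theta + tau * theta_dot
    theta_dot = theta_dot + tau * theta_acc
    return x, x_dot, theta, theta_dot
```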
- Actor-critic architecture with separate networks
- Generalized Advantage Estimation (GAE)
- Clipped surrogate objective (ε=0.2; sketched below)
- Value function loss with coefficient weighting
- 4 update iterations per rollout
- Normalized advantages and returns for training stability
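The GAE recursion and clipped loss can be written compactly with `lax.scan`; the sketch below is illustrative rather than a copy of the training code, and γ = 0.99, λ = 0.95, and the value-loss coefficient 0.5 are assumed defaults, not values taken from `cartpole.py`.

```python
import jax
import jax.numpy as jnp

def gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation over a (T, N) rollout via a backward scan."""
    def step(carry, inp):
        next_adv, next_value = carry
        reward, value, done = inp
        not_done = 1.0 - done
        delta = reward + gamma * next_value * not_done - value
        adv = delta + gamma * lam * not_done * next_adv
        return (adv, value), adv

    _, advantages = jax.lax.scan(
        step,
        (jnp.zeros_like(last_value), last_value),
        (rewards, values, dones),
        reverse=True,
    )
    return advantages, advantages + values  # advantages, returns

def ppo_loss(log_probs, old_log_probs, advantages, value_preds, returns,
             clip_eps=0.2, vf_coef=0.5):
    """Clipped surrogate objective plus coefficient-weighted value loss."""
    # Normalize advantages for training stability.
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    ratio = jnp.exp(log_probs - old_log_probs)
    unclipped = ratio * advantages
    clipped = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    policy_loss = -jnp.minimum(unclipped, clipped).mean()
    value_loss = jnp.square(value_preds - returns).mean()
    return policy_loss + vf_coef * value_loss
```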
- Feed-forward networks: 4 → 64 → 64 → {2, 1} (actor/critic outputs; sketched below)
- ReLU activations
- Adam optimizer (lr=3e-4)
- Batch size: 64 parallel environments × 500 timesteps
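A sketch of the actor-critic modules in Flax NNX, matching the sizes listed above; class and attribute names are illustrative and need not match `cartpole.py`, and depending on the Flax version `nnx.Optimizer` may also take a `wrt` argument.

```python
import optax
from flax import nnx

class MLP(nnx.Module):
    """Two hidden layers of 64 units with ReLU activations."""
    def __init__(self, din, dout, rngs: nnx.Rngs):
        self.fc1 = nnx.Linear(din, 64, rngs=rngs)
        self.fc2 = nnx.Linear(64, 64, rngs=rngs)
        self.out = nnx.Linear(64, dout, rngs=rngs)

    def __call__(self, x):
        x = nnx.relu(self.fc1(x))
        x = nnx.relu(self.fc2(x))
        return self.out(x)

class ActorCritic(nnx.Module):
    """Separate actor (2 logits) and critic (1 value) networks over the 4-d observation."""
    def __init__(self, rngs: nnx.Rngs):
        self.actor = MLP(4, 2, rngs)
        self.critic = MLP(4, 1, rngs)

    def __call__(self, obs):
        return self.actor(obs), self.critic(obs).squeeze(-1)

model = ActorCritic(nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(3e-4))  # Adam, lr = 3e-4
```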
- JAX
- Flax (NNX)
- Optax
- Matplotlib (for plotting)
- Imageio (for GIF generation)

