A minimal, easy-to-use neural network library for Rust with optional GPU acceleration. Designed for simplicity and performance.
- Simple API: Builder pattern for intuitive network construction
- Backend Agnostic: Generic over CPU/GPU backends - switch hardware without changing code
- GPU Acceleration: Optional wgpu backend (Vulkan, Metal, DX12, WebGPU) for 10-100x speedups
- Zero-Copy Training: Backend owns all memory, eliminating CPU-GPU copies in hot paths
- Flexible I/O: Trait-based input/output for type-safe data handling
- Custom Loss Functions: Transformer trait for domain-specific loss computation
- Batched GPU Training: Automatic command batching for maximum GPU utilization
use iax::prelude::*;
// 1. Define network architecture
let shape = NeuralShape::input(2)
    .layer(8, Activation::ReLU)
    .output(1, Activation::Sigmoid);
// 2. Create network
let mut nn = NeuralNetwork::new(shape);
// 3. Prepare training data
let data = TrainData::from_pairs(&[
    (&[0.0, 0.0], &[0.0]),
    (&[0.0, 1.0], &[1.0]),
    (&[1.0, 0.0], &[1.0]),
    (&[1.0, 1.0], &[0.0]),
]);
// 4. Configure and train
let config = TrainConfig::new()
    .learning_rate(0.01)
    .epochs(100)
    .batch_size(32)
    .momentum(0.9)
    .verbose(10);
let history = nn.train(&data, config);
// 5. Make predictions
let output = nn.predict(&[0.0, 1.0]);

Enable GPU support with the gpu feature:
[dependencies]
iax = { version = "0.1", features = ["gpu"] }

Then use the GPU backend:
use iax::prelude::*;
use iax::gpu::WgpuBackend;
// Initialize GPU backend
if let Some(backend) = WgpuBackend::new() {
    let mut nn = NeuralNetwork::with_backend(shape, backend);
    // Use train_batched for optimal GPU performance
    use iax::train::Loss;
    let history = nn.train_batched(&data, &config, &Loss::CrossEntropy);
}

GPU Performance: Expect 10-100x speedups on large networks. Small networks (<100 neurons) may be slower due to GPU overhead.
Add to your Cargo.toml:
[dependencies]
iax = "0.1"
# Optional: GPU acceleration
[features]
default = []
gpu = ["iax/gpu"]

- Windows: DirectX 12 compatible GPU
- macOS: Metal compatible GPU
- Linux: Vulkan drivers (e.g., Mesa for Intel/AMD, NVIDIA proprietary drivers)
- Web: WebGPU compatible browser
Use the builder pattern to define your network:
let shape = NeuralShape::input(784)       // Input layer: 784 features
    .layer(256, Activation::ReLU)         // Hidden layer 1: 256 neurons, ReLU
    .layer(128, Activation::ReLU)         // Hidden layer 2: 128 neurons, ReLU
    .output(10, Activation::Softmax);     // Output layer: 10 classes, Softmax

- Activation::Linear: Identity function
- Activation::ReLU: Rectified Linear Unit
- Activation::LeakyReLU(alpha): Leaky ReLU with configurable alpha
- Activation::Sigmoid: Sigmoid function
- Activation::Tanh: Hyperbolic tangent
- Activation::Softmax: Softmax (for classification output layers)
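For reference, the activations listed above correspond to the standard formulas sketched below. This is a standalone illustration using only the Rust standard library; the function names are hypothetical and are not part of the iax API, and the library's own implementation may differ in detail. Tanh is omitted because it is simply `f32::tanh`.

```rust
/// Rectified Linear Unit: max(0, x).
fn relu(x: f32) -> f32 {
    x.max(0.0)
}

/// Leaky ReLU: x for positive inputs, alpha * x otherwise.
fn leaky_relu(x: f32, alpha: f32) -> f32 {
    if x > 0.0 { x } else { alpha * x }
}

/// Logistic sigmoid: 1 / (1 + e^(-x)).
fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

/// Softmax over a slice. The maximum is subtracted before exponentiating
/// so that large logits do not overflow.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&z| (z - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}
```

Softmax is the only one of these that depends on the whole layer rather than a single value, which is why it is normally reserved for the output layer of a classifier.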
let config = TrainConfig::new()
    .learning_rate(0.01)      // Learning rate (default: 0.001)
    .epochs(100)              // Number of training epochs
    .batch_size(32)           // Mini-batch size (default: full batch)
    .full_batch()             // Use all samples per update
    .momentum(0.9)            // Momentum coefficient (0.0 = no momentum)
    .l2(0.0001)               // L2 regularization strength
    .validation_split(0.1)    // Fraction of data for validation
    .verbose(10);             // Print progress every N epochs

CPU Training (standard):
let history = nn.train(&data, config); // MSE loss
let history = nn.train_classifier(&data, config); // Cross-entropy loss
let history = nn.train_with_loss(&data, config, Loss::MSE); // Custom loss

GPU Batched Training (optimal for GPU):
use iax::train::Loss;
let history = nn.train_batched(&data, &config, &Loss::CrossEntropy);

With Custom Transformer:
let transformer = MyCustomTransformer;
let history = nn.train_with_transformer(&data, config, &transformer);

- Loss::MSE: Mean Squared Error (for regression)
- Loss::CrossEntropy: Cross-Entropy (for classification with softmax)
- Loss::BinaryCrossEntropy: Binary Cross-Entropy (for binary classification)
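The three built-in losses follow well-known definitions, sketched below in plain Rust for a single sample. The function names are illustrative and not part of the iax API. The reduction (mean over outputs for MSE and binary cross-entropy, sum for cross-entropy) and the small epsilon used to avoid `ln(0)` are assumptions of this sketch; the library may reduce or clamp differently.

```rust
const EPS: f32 = 1e-7;

/// Mean squared error, averaged over the output elements.
fn mse(predicted: &[f32], target: &[f32]) -> f32 {
    let n = predicted.len() as f32;
    predicted
        .iter()
        .zip(target)
        .map(|(&p, &t)| (p - t).powi(2))
        .sum::<f32>()
        / n
}

/// Cross-entropy between a probability vector (e.g. softmax output)
/// and a one-hot or soft target distribution.
fn cross_entropy(predicted: &[f32], target: &[f32]) -> f32 {
    -predicted
        .iter()
        .zip(target)
        .map(|(&p, &t)| t * (p + EPS).ln())
        .sum::<f32>()
}

/// Binary cross-entropy, averaged over independent sigmoid outputs.
fn binary_cross_entropy(predicted: &[f32], target: &[f32]) -> f32 {
    let n = predicted.len() as f32;
    -predicted
        .iter()
        .zip(target)
        .map(|(&p, &t)| {
            // Clamp so that neither logarithm sees exactly 0.
            let p = p.clamp(EPS, 1.0 - EPS);
            t * p.ln() + (1.0 - t) * (1.0 - p).ln()
        })
        .sum::<f32>()
        / n
}
```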
let data = TrainData::from_pairs(&[
    (&[0.0, 0.0], &[0.0]),
    (&[0.0, 1.0], &[1.0]),
    (&[1.0, 0.0], &[1.0]),
    (&[1.0, 1.0], &[0.0]),
]);

let inputs = vec![0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0]; // Flattened
let targets = vec![0.0, 1.0, 1.0, 0.0]; // Flattened
let data = TrainData::new(inputs, targets, 2, 1); // (inputs, targets, input_size, output_size)

See examples/mnist_csv.rs for CSV loading examples.
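To make the flattened layout concrete, the sketch below shows how a flat buffer plus `input_size` and `output_size` maps back to individual samples, assuming samples are stored contiguously one after another (row-major), as the XOR example above suggests. `split_samples` is a hypothetical helper written for illustration; it is not an iax function.

```rust
/// Splits flattened buffers into per-sample (input, target) slices.
/// Panics if a size is zero, if a buffer length is not a whole multiple
/// of its per-sample size, or if the two sample counts disagree.
fn split_samples<'a>(
    inputs: &'a [f32],
    targets: &'a [f32],
    input_size: usize,
    output_size: usize,
) -> Vec<(&'a [f32], &'a [f32])> {
    assert!(input_size > 0 && output_size > 0, "sizes must be non-zero");
    assert_eq!(inputs.len() % input_size, 0, "inputs is not a multiple of input_size");
    assert_eq!(targets.len() % output_size, 0, "targets is not a multiple of output_size");
    assert_eq!(
        inputs.len() / input_size,
        targets.len() / output_size,
        "sample counts must match"
    );
    inputs
        .chunks_exact(input_size)
        .zip(targets.chunks_exact(output_size))
        .collect()
}
```

With the XOR buffers above (`input_size = 2`, `output_size = 1`), this yields four samples, the second being `([0.0, 1.0], [1.0])`.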
GPU backends can estimate maximum batch size based on available memory:
if let Some(backend) = WgpuBackend::new() {
    let max_batch = backend.estimate_max_batch_size(&shape);
    println!("Maximum batch size: {}", max_batch);
    let limits = backend.limits();
    println!("GPU buffer limit: {}MB", limits.max_buffer_size / 1024 / 1024);
}

GPU operations are automatically batched for optimal performance. The backend accumulates multiple operations into a single command buffer, reducing CPU-GPU synchronization overhead.
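The heuristic behind `estimate_max_batch_size` is not described here. As a rough mental model only, one simple bound is the largest batch whose widest activation buffer still fits in a single GPU buffer. The sketch below computes that bound; `rough_max_batch` is a hypothetical function, and the assumption that one `f32` activation buffer per layer is the limiting factor may not match what the backend actually does.

```rust
/// Rough upper bound on batch size: the largest batch for which
/// `batch * widest_layer * size_of::<f32>()` fits in one buffer.
/// Returns 0 when there are no layers or the widest layer is empty.
fn rough_max_batch(layer_sizes: &[usize], max_buffer_bytes: u64) -> u64 {
    let widest = layer_sizes.iter().copied().max().unwrap_or(0) as u64;
    if widest == 0 {
        return 0;
    }
    let bytes_per_sample = widest * std::mem::size_of::<f32>() as u64;
    max_buffer_bytes / bytes_per_sample
}
```

For a 784-256-128-10 network and a 256 MB buffer limit, the widest layer needs 3136 bytes per sample, so this bound comes out at 85598 samples. Weights, gradients, and other buffers also consume memory, so a practical limit would be lower.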
cargo run --example xor

Demonstrates training a simple network to learn the XOR function.

Synthetic Data:
cargo run --example mnist --features gpu

Real CSV Data:
# First, download MNIST CSV files to examples/data/
cargo run --example mnist_csv --features gpu --release

Embedded Quick Test:
cargo run --example mnist_embedded --features gpu

cargo run --example benchmark --features gpu --release

Compares CPU vs GPU performance on various network sizes.

cargo run --example typed_io

Demonstrates type-safe input/output using the InputSpec/OutputSpec traits.
Implement the Transformer trait for custom loss computation:
use iax::io::Transformer;
struct MyTransformer;
impl Transformer for MyTransformer {
    fn loss(&self, predicted: &[f32], target: &[f32]) -> f32 {
        // Custom loss computation: sum of absolute errors (L1 loss)
        predicted.iter()
            .zip(target.iter())
            .map(|(p, t)| (p - t).abs())
            .sum()
    }
    fn gradient(&self, predicted: &[f32], target: &[f32]) -> Vec<f32> {
        // Custom gradient computation: sign of the error for each output
        predicted.iter()
            .zip(target.iter())
            .map(|(p, t)| if *p > *t { 1.0 } else { -1.0 })
            .collect()
    }
}

Use InputSpec and OutputSpec for compile-time type checking:
use iax::io::{InputSpec, OutputSpec, InputData, OutputData};
#[derive(Clone)]
struct MyInput {
    features: Vec<f32>,
}
impl InputData for MyInput {
    fn to_f32(&self) -> Vec<f32> {
        self.features.clone()
    }
}
impl InputSpec for MyInput {
    type Data = MyInput;
}
// Use with network
let input = MyInput { features: vec![0.0, 1.0] };
let output: Vec<f32> = nn.predict_typed(&input);

// Get weights for a specific layer (downloads to CPU)
if let Some(weights) = nn.get_layer_weights(1) {
    println!("Layer 1 weights: {:?}", weights);
}
// Set a specific weight
nn.set_weight(1, 0, 0, 0.5)?; // layer, from_neuron, to_neuron, value
// Save/load weights
let bytes = nn.save_weights();
nn.load_weights(&bytes)?;

- Use GPU for Large Networks: Networks with >1000 neurons benefit significantly from GPU acceleration
- Batch Size: Larger batches improve GPU utilization but require more memory
- Full Batch Training: On GPU, train_batched() with full batch is often faster than mini-batches
- Learning Rate: Start with 0.01-0.1 for full-batch training, 0.001-0.01 for mini-batch
- Momentum: Use 0.9 for most cases, helps with convergence
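To show how the learning rate, momentum, and L2 settings interact, here is a sketch of one parameter update using classical momentum with L2 weight decay. Whether iax uses exactly this formulation is not stated in this document, so treat the function below (a hypothetical `sgd_momentum_step`) as a textbook illustration rather than the library's optimizer.

```rust
/// One SGD step with classical momentum and L2 regularization:
///   v <- momentum * v - lr * (grad + l2 * w)
///   w <- w + v
/// All three slices are expected to have the same length.
fn sgd_momentum_step(
    weights: &mut [f32],
    velocity: &mut [f32],
    grads: &[f32],
    lr: f32,
    momentum: f32,
    l2: f32,
) {
    for ((w, v), &g) in weights.iter_mut().zip(velocity.iter_mut()).zip(grads) {
        *v = momentum * *v - lr * (g + l2 * *w);
        *w += *v;
    }
}
```

With momentum at 0.9, a constant gradient builds up a velocity approaching ten times a single plain step, which is one reason a smaller learning rate is often paired with high momentum.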
- Backend Trait: Abstract computation operations (Backend)
- NeuralNetwork: Generic over backend type
- Zero-Copy Design: Backend owns all buffers, no CPU-GPU copies during training
- Command Batching: GPU operations are batched for optimal throughput
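The general shape of a backend-generic design can be sketched as follows. Every name here (`ComputeBackend`, `CpuBackend`, `Model`, and their methods) is hypothetical and chosen for illustration; the real `Backend` trait in iax is not shown in this document and will have a different, larger surface. The point is only that the backend owns its buffer type and model code is written once against the trait.

```rust
/// Hypothetical backend abstraction: each backend defines and owns its buffer type.
trait ComputeBackend {
    type Buffer;
    /// Copy host data into a backend-owned buffer.
    fn upload(&self, data: &[f32]) -> Self::Buffer;
    /// Multiply every element in place, without leaving the backend.
    fn scale(&self, buf: &mut Self::Buffer, factor: f32);
    /// Copy a buffer back to host memory (needed only for inspection).
    fn download(&self, buf: &Self::Buffer) -> Vec<f32>;
}

/// CPU implementation: the buffer is just a Vec<f32>.
struct CpuBackend;

impl ComputeBackend for CpuBackend {
    type Buffer = Vec<f32>;

    fn upload(&self, data: &[f32]) -> Self::Buffer {
        data.to_vec()
    }

    fn scale(&self, buf: &mut Self::Buffer, factor: f32) {
        for x in buf.iter_mut() {
            *x *= factor;
        }
    }

    fn download(&self, buf: &Self::Buffer) -> Vec<f32> {
        buf.clone()
    }
}

/// Model code is generic over the backend and never touches raw memory.
struct Model<B: ComputeBackend> {
    backend: B,
    weights: B::Buffer,
}

impl<B: ComputeBackend> Model<B> {
    fn new(backend: B, initial: &[f32]) -> Self {
        let weights = backend.upload(initial);
        Self { backend, weights }
    }

    fn scale_weights(&mut self, factor: f32) {
        self.backend.scale(&mut self.weights, factor);
    }

    fn weights(&self) -> Vec<f32> {
        self.backend.download(&self.weights)
    }
}
```

A GPU backend would implement the same trait with a device buffer as `Buffer`, so `Model` code stays unchanged and data crosses to the host only in `download`.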
MIT
Contributions welcome! Please open an issue or pull request.