
nic-barbara/R2DN

Robust Neural Networks in JAX

This repository contains the code associated with our upcoming paper R2DN: Scalable Parameterization of Contracting and Lipschitz Recurrent Deep Networks (Barbara, Wang, & Manchester, 2025).
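To make the two properties in the title concrete, here is a minimal JAX sketch of a toy recurrent network, x_{t+1} = tanh(A x_t + B u_t), y_t = C x_{t+1}. It is a deliberately simple, conservative construction and not the R2DN parameterization from the paper. Rescaling A so that ||A||_2 < 1 makes the state update contracting, so trajectories started from different initial states converge. Because tanh is 1-Lipschitz, the same constraint also gives a Lipschitz (incremental l2-gain) bound of ||C|| ||B|| / (1 - ||A||) from input sequences to output sequences. The function names (make_params, simulate, gain_bound) are hypothetical and are not part of robustnn/.

```python
import jax
import jax.numpy as jnp


def spectral_norm(m):
    # Largest singular value, i.e. the induced 2-norm of the matrix.
    return jnp.linalg.norm(m, ord=2)


def make_params(key, nx=8, nu=2, ny=3, a_norm=0.9):
    # Hypothetical toy RNN: x_{t+1} = tanh(A x_t + B u_t), y_t = C x_{t+1}.
    ka, kb, kc = jax.random.split(key, 3)
    a = jax.random.normal(ka, (nx, nx))
    a = a_norm * a / spectral_norm(a)  # enforce ||A||_2 = a_norm < 1
    b = jax.random.normal(kb, (nx, nu))
    c = jax.random.normal(kc, (ny, nx))
    return a, b, c


def simulate(params, x0, us):
    # Roll the RNN over an input sequence `us` of shape (T, nu).
    a, b, c = params

    def step(x, u):
        x_next = jnp.tanh(a @ x + b @ u)
        return x_next, (x_next, c @ x_next)

    _, (xs, ys) = jax.lax.scan(step, x0, us)
    return xs, ys  # shapes (T, nx) and (T, ny)


def gain_bound(params):
    # tanh is 1-Lipschitz, so ||dx_{t+1}|| <= ||A|| ||dx_t|| + ||B|| ||du_t||.
    # Summing the resulting geometric series bounds the l2 gain from u to y.
    a, b, c = params
    return spectral_norm(c) * spectral_norm(b) / (1.0 - spectral_norm(a))


key = jax.random.PRNGKey(0)
k_par, k_x, k_u1, k_u2 = jax.random.split(key, 4)
params = make_params(k_par)
T, nx, nu = 50, 8, 2

# Contraction: same input, different initial states -> states converge.
us = jax.random.normal(k_u1, (T, nu))
x0_a = jax.random.normal(k_x, (nx,))
x0_b = jnp.zeros(nx)
xs_a, _ = simulate(params, x0_a, us)
xs_b, _ = simulate(params, x0_b, us)
d0 = jnp.linalg.norm(x0_a - x0_b)
dT = jnp.linalg.norm(xs_a[-1] - xs_b[-1])
print("state distance:", float(d0), "->", float(dT))

# Lipschitz bound: same initial state, perturbed input sequence.
us2 = us + 0.1 * jax.random.normal(k_u2, (T, nu))
_, ys1 = simulate(params, x0_b, us)
_, ys2 = simulate(params, x0_b, us2)
# Frobenius norm over the whole (T, dim) array = l2 norm of the signal.
ratio = jnp.linalg.norm(ys1 - ys2) / jnp.linalg.norm(us - us2)
print("empirical gain:", float(ratio), "<= bound:", float(gain_bound(params)))
```

The empirical gain is typically well below the bound, which illustrates how conservative a plain norm constraint on A can be.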

Included are JAX implementations of each of the following robust neural models:

Robust neural models are included in the robustnn/ directory. Scripts used to generate the results in the paper are in the examples/ directory.

Installation and Usage

To install the required dependencies and run the code, open a terminal in the root directory of this repository and enter the following commands.

./install.sh
./run.sh

These commands create a Python virtual environment, run all the experiments, process the results, and reproduce the figures from the paper.

A Note on Dependencies

All code was developed and tested on Ubuntu 22.04 with CUDA 12.4 and Python 3.10.12.

Requirements were generated with pipreqs. The install.sh script assumes that JAX will run on an NVIDIA GPU with CUDA 12 already installed. If no GPU is available, remove the [cuda12_pip] option from the JAX installation command. If your GPU does not use CUDA, or uses a different CUDA version, edit the installation command accordingly.
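After running ./install.sh, a quick way to confirm which Python and JAX versions are installed, and whether JAX picked up the GPU (or fell back to the CPU because the [cuda12_pip] option was removed), is to query the default backend. This is a minimal sketch using standard JAX calls; it assumes the virtual environment created by install.sh is active.

```python
import sys

import jax

# Compare against the tested configuration (Python 3.10.12).
print("Python:", sys.version.split()[0])
print("JAX:", jax.__version__)

# "gpu" if JAX found a CUDA device, otherwise typically "cpu".
backend = jax.default_backend()
devices = jax.devices()
print("Default backend:", backend)
print("Devices:", devices)
```

If the backend reports "cpu" on a machine with an NVIDIA GPU, the JAX installation command most likely does not match the installed CUDA version.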

About

JAX Implementations of Robust Neural Networks
