This repository provides an introduction to JAX, a high-performance numerical computing library from Google Research.
It is meant for practitioners who are already familiar with NumPy and PyTorch.
- `00_from_zero_to_noob.ipynb`: An introductory tutorial notebook that covers the essentials:
  - NumPy-style array operations
  - Automatic differentiation and optimization
  - Automatic vectorization with vmap
  - Introduction to PyTrees
  - Just-In-Time compilation with jit
  - Random numbers
  - Introduction to Flax
  - Recurrent nets with scan
- `01_meta_learning_maml.ipynb`: Meta learning with MAML, implementation sketch
- `02_meta_learning_hypernet.ipynb`: Meta learning with Hypernets, implementation sketch
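As a rough preview of several topics from the first notebook, the sketch below combines `jax.grad`, `jax.vmap`, `jax.jit`, PyTrees, explicit PRNG keys and `jax.lax.scan` on a toy least-squares problem. The functions `loss`, `model_loss` and `step`, as well as all numbers, are illustrative assumptions and not code taken from the notebook; Flax is not shown.

```python
import jax
import jax.numpy as jnp

# NumPy-style arrays: jnp mirrors much of the numpy API
x = jnp.linspace(0.0, 1.0, 5)

# Toy scalar loss for a single weight w (hypothetical example, target slope 2.0)
def loss(w, x):
    return jnp.mean((w * x - 2.0 * x) ** 2)

# Automatic differentiation: gradient of loss with respect to its first argument
grad_loss = jax.grad(loss)

# Vectorization: map grad_loss over a batch of w values, sharing the same x
batched_grad = jax.vmap(grad_loss, in_axes=(0, None))

# Just-in-time compilation of the batched gradient with XLA
fast_batched_grad = jax.jit(batched_grad)

ws = jnp.array([0.0, 1.0, 2.0, 3.0])
print(fast_batched_grad(ws, x))  # analytically 0.75 * (w - 2) for each w

# PyTrees: parameters can be nested containers; grad returns the same structure
params = {"w": jnp.array(1.0), "b": jnp.array(0.0)}

def model_loss(params, x, y):
    pred = params["w"] * x + params["b"]
    return jnp.mean((pred - y) ** 2)

grads = jax.grad(model_loss)(params, x, 2.0 * x)
# One plain gradient-descent step, applied leaf by leaf
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

# Random numbers: explicit PRNG keys that are split rather than mutated
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
noise = jax.random.normal(subkey, (3,))

# Recurrence with scan: carry a running sum across the elements of x
def step(carry, xi):
    carry = carry + xi
    return carry, carry  # (new carry, per-step output)

final, partials = jax.lax.scan(step, jnp.zeros(()), x)
```

Since the mean of `x**2` here is 0.375, the batched gradient should come out close to `[-1.5, -0.75, 0.0, 0.75]`, and `partials` matches `jnp.cumsum(x)`.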
Worked-out examples are available in the `gallery` folder.
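`01_meta_learning_maml.ipynb` contains the repository's own MAML sketch. For orientation only, here is an independent minimal version, assuming a toy family of 1-D linear regression tasks `y = a * x` fitted with a single scalar weight, one inner gradient step, and illustrative learning rates. The helper names (`adapt`, `maml_loss`, `meta_step`, `sample_tasks`) are hypothetical. The meta-gradient is obtained by differentiating through the inner update with `jax.grad`, and `jax.vmap` batches the loss over tasks.

```python
import jax
import jax.numpy as jnp

INNER_LR = 0.1   # illustrative inner-loop (adaptation) step size
OUTER_LR = 0.05  # illustrative outer-loop (meta) step size

def predict(w, x):
    return w * x

def mse(w, x, y):
    return jnp.mean((predict(w, x) - y) ** 2)

def adapt(w, x_support, y_support):
    # One inner-loop gradient step on the task's support set
    return w - INNER_LR * jax.grad(mse)(w, x_support, y_support)

def maml_loss(w, x_s, y_s, x_q, y_q):
    # Loss on the query set after adapting the shared initialization w
    return mse(adapt(w, x_s, y_s), x_q, y_q)

def batched_maml_loss(w, xs_s, ys_s, xs_q, ys_q):
    # Same initialization w for every task, task data batched along axis 0
    per_task = jax.vmap(maml_loss, in_axes=(None, 0, 0, 0, 0))
    return jnp.mean(per_task(w, xs_s, ys_s, xs_q, ys_q))

@jax.jit
def meta_step(w, xs_s, ys_s, xs_q, ys_q):
    # Differentiates through adapt(), i.e. second-order MAML
    loss, g = jax.value_and_grad(batched_maml_loss)(w, xs_s, ys_s, xs_q, ys_q)
    return w - OUTER_LR * g, loss

def sample_tasks(key, n_tasks=8, n_points=10):
    # Each task is y = a * x with its own slope a drawn from [-2, 2]
    k_a, k_s, k_q = jax.random.split(key, 3)
    slopes = jax.random.uniform(k_a, (n_tasks, 1), minval=-2.0, maxval=2.0)
    xs_s = jax.random.uniform(k_s, (n_tasks, n_points), minval=-1.0, maxval=1.0)
    xs_q = jax.random.uniform(k_q, (n_tasks, n_points), minval=-1.0, maxval=1.0)
    return xs_s, slopes * xs_s, xs_q, slopes * xs_q

key = jax.random.PRNGKey(0)
w = jnp.array(0.5)
for _ in range(100):
    key, subkey = jax.random.split(key)
    w, meta_loss = meta_step(w, *sample_tasks(subkey))
```

A scalar weight keeps the example readable; in practice the same pattern applies to a full parameter PyTree, with `jax.tree_util.tree_map` performing the inner and outer updates.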
To get started:
- Make sure you have Python installed.
- Install the dependencies: `pip install jax jaxlib jupyter`
- Launch Jupyter Notebook: `jupyter notebook`
- Open `00_from_zero_to_noob.ipynb` to begin learning JAX.
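After installing, one possible sanity check (run in a notebook cell or a Python session) is to import JAX, list the devices it detects, and evaluate a small expression. The exact device list depends on your installation.

```python
import jax
import jax.numpy as jnp

# Installed version and the devices JAX can use
# (a plain `pip install jax jaxlib` typically gives CPU only)
print(jax.__version__)
print(jax.devices())

# A tiny computation: 0*0 + 1*1 + 2*2
x = jnp.arange(3.0)
result = jnp.dot(x, x)
print(result)  # 5.0
```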