JAX/Flax implementation of cosmological emulators with automatic JIT compilation.
pip install -e .

import jaxace
import jax.numpy as jnp
# Cosmology
cosmo = jaxace.w0waCDMCosmology(
ln10As=3.044, ns=0.9649, h=0.6736,
omega_b=0.02237, omega_c=0.1200,
m_nu=0.06, w0=-1.0, wa=0.0
)
# Background functions
z = jnp.array([0.0, 0.5, 1.0])
growth = jaxace.D_z_from_cosmo(z, cosmo)
distance = jaxace.r_z_from_cosmo(z, cosmo)
# Neural network emulator (nn_dict, weights, and input_data are assumed to be loaded beforehand)
emulator = jaxace.init_emulator(nn_dict, weights, jaxace.FlaxEmulator)
output = emulator(input_data)  # Auto-JIT + batch detection

- Background cosmology (growth, distances, Hubble)
- Neural network emulators with auto-JIT
- Massive neutrinos and dark energy support
- Full JAX integration (grad, vmap, jit)
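
To illustrate what the JAX integration makes possible, here is a minimal sketch using a stand-in function rather than the jaxace API. `hubble_ratio` is a hypothetical helper, not part of the package: it evaluates E(z) = H(z)/H0 for a flat w0waCDM model, assuming the CPL form w(a) = w0 + wa(1 - a) and ignoring radiation and massive neutrinos. The same three transformations should apply to any function written in pure JAX.

```python
import jax
import jax.numpy as jnp


def hubble_ratio(z, omega_m, w0, wa):
    """Stand-in E(z) = H(z)/H0 for flat w0waCDM (hypothetical helper, not jaxace code)."""
    a = 1.0 / (1.0 + z)
    # Dark energy density relative to today for w(a) = w0 + wa * (1 - a)
    de = a ** (-3.0 * (1.0 + w0 + wa)) * jnp.exp(-3.0 * wa * (1.0 - a))
    return jnp.sqrt(omega_m * (1.0 + z) ** 3 + (1.0 - omega_m) * de)


# Matter fraction from baryons + CDM only, using the parameter values shown earlier
omega_m = (0.02237 + 0.1200) / 0.6736**2

# jit: compile once, then reuse for inputs of the same shape
fast_E = jax.jit(hubble_ratio)
z = jnp.array([0.0, 0.5, 1.0])
E = fast_E(z, omega_m, -1.0, 0.0)

# grad: derivative of E at z = 1 with respect to w0 (argument index 2)
dE_dw0 = jax.grad(hubble_ratio, argnums=2)(1.0, omega_m, -1.0, 0.0)

# vmap: evaluate the same redshifts for a batch of w0 values
w0_batch = jnp.array([-1.1, -1.0, -0.9])
E_batch = jax.vmap(hubble_ratio, in_axes=(None, None, 0, None))(z, omega_m, w0_batch, 0.0)
```

Here `E` has one entry per redshift, `dE_dw0` is a scalar sensitivity, and `E_batch` has shape `(3, 3)` with one row per `w0` value. Gradients like `dE_dw0` are what gradient-based samplers and Fisher forecasts rely on.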