CosmologicalEmulators/jaxace

jaxace

JAX/Flax implementation of cosmological emulators with automatic JIT compilation.

Installation

From the root of a local clone of the repository, install in editable mode:

pip install -e .

Usage

import jaxace
import jax.numpy as jnp

# Cosmology
cosmo = jaxace.w0waCDMCosmology(
    ln10As=3.044, ns=0.9649, h=0.6736,
    omega_b=0.02237, omega_c=0.1200,
    m_nu=0.06, w0=-1.0, wa=0.0
)

# Background functions
z = jnp.array([0.0, 0.5, 1.0])
growth = jaxace.D_z_from_cosmo(z, cosmo)
distance = jaxace.r_z_from_cosmo(z, cosmo)

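# Neural network emulator
# nn_dict, weights and input_data are placeholders: load or define them before running.
emulator = jaxace.init_emulator(nn_dict, weights, jaxace.FlaxEmulator)
output = emulator(input_data)  # JIT-compiled automatically; batched inputs are detected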
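
To illustrate what "auto-JIT + batch detection" can look like in plain JAX, here is a minimal sketch. It is not jaxace's implementation: `ToyEmulator`, `init_toy_params` and `forward_single` are hypothetical names, and the tiny randomly initialised MLP only stands in for a trained emulator. The idea is to compile one function for single inputs and a `vmap`-ed one for batches, then dispatch on the number of input dimensions.

```python
import jax
import jax.numpy as jnp


def init_toy_params(key, n_in=3, n_hidden=8, n_out=2):
    """Random weights for a tiny two-layer MLP (stand-in for a trained emulator)."""
    k1, k2 = jax.random.split(key)
    return {
        "W1": jax.random.normal(k1, (n_in, n_hidden)) * 0.1,
        "b1": jnp.zeros(n_hidden),
        "W2": jax.random.normal(k2, (n_hidden, n_out)) * 0.1,
        "b2": jnp.zeros(n_out),
    }


def forward_single(params, x):
    """Forward pass for ONE input vector of shape (n_in,)."""
    h = jnp.tanh(x @ params["W1"] + params["b1"])
    return h @ params["W2"] + params["b2"]


# One compiled function for single inputs, one for batches (vmap over axis 0 of x).
_single = jax.jit(forward_single)
_batched = jax.jit(jax.vmap(forward_single, in_axes=(None, 0)))


class ToyEmulator:
    """Hypothetical wrapper: picks the single or batched path from the input rank."""

    def __init__(self, params):
        self.params = params

    def __call__(self, x):
        x = jnp.asarray(x)
        if x.ndim == 1:
            return _single(self.params, x)
        if x.ndim == 2:
            return _batched(self.params, x)
        raise ValueError(f"expected a 1D or 2D input, got ndim={x.ndim}")


toy = ToyEmulator(init_toy_params(jax.random.PRNGKey(0)))
single_out = toy(jnp.ones(3))       # shape (2,)
batch_out = toy(jnp.ones((5, 3)))   # shape (5, 2)
```

Dispatching on `ndim` keeps the call site identical for one parameter vector or many, and each path is traced and compiled only once per input shape.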

Features

  • Background cosmology (growth, distances, Hubble)
  • Neural network emulators with auto-JIT
  • Massive neutrinos and dark energy support
  • Full JAX integration (grad, vmap, jit)
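
As a rough illustration of how `grad`, `vmap` and `jit` compose on a background quantity, the sketch below uses a toy flat w0waCDM model written directly in JAX. It does not call jaxace: `E_z` and `comoving_distance` are hypothetical helpers, radiation and massive neutrinos are ignored, and distances are dimensionless (units of c/H0).

```python
import jax
import jax.numpy as jnp


def E_z(z, Om, w0, wa):
    """Dimensionless Hubble rate H(z)/H0 for flat w0waCDM (CPL dark energy)."""
    de = (1.0 + z) ** (3.0 * (1.0 + w0 + wa)) * jnp.exp(-3.0 * wa * z / (1.0 + z))
    return jnp.sqrt(Om * (1.0 + z) ** 3 + (1.0 - Om) * de)


def comoving_distance(z, Om, w0, wa, n=512):
    """Integral of dz'/E(z') from 0 to z, using the trapezoidal rule on n points."""
    grid = jnp.linspace(0.0, 1.0, n) * z
    y = 1.0 / E_z(grid, Om, w0, wa)
    return jnp.sum(0.5 * (y[1:] + y[:-1]) * jnp.diff(grid))


# vmap over redshift (cosmological parameters shared), then jit the result.
distance_batch = jax.jit(jax.vmap(comoving_distance, in_axes=(0, None, None, None)))

# Derivative of the distance with respect to Om (argument index 1).
dD_dOm = jax.jit(jax.grad(comoving_distance, argnums=1))

zs = jnp.array([0.0, 0.5, 1.0])
distances = distance_batch(zs, 0.3, -1.0, 0.0)   # shape (3,)
sensitivity = dD_dOm(1.0, 0.3, -1.0, 0.0)        # negative: more matter, shorter distance
```

The same pattern (write the scalar function once, then wrap it with `vmap`, `grad` and `jit`) is what makes differentiable, batched likelihood pipelines straightforward in JAX.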
