Quick Start
Installation
git clone https://github.com/nikhil-sarin/redback-jax.git
cd redback-jax
pip install -e .
Float32 mode (recommended for GPU)
Redback-JAX is designed to run in float32. Disable x64 at the very top of your script, before importing redback_jax or creating any JAX arrays:
import jax
jax.config.update("jax_enable_x64", False)
Do not import jax_supernovae (jax-bandflux) before this call — it
enables x64 at module level. Redback-JAX lazy-imports bandflux components to
avoid this.
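You can check that the setting took effect by looking at the default dtype JAX assigns to a new array. This assumes nothing imported later switches x64 back on. The helper name float32_mode_active is hypothetical and not part of Redback-JAX:

```python
import jax

jax.config.update("jax_enable_x64", False)

import jax.numpy as jnp


def float32_mode_active() -> bool:
    """Return True if JAX currently turns Python floats into float32 arrays."""
    # Hypothetical helper, not a Redback-JAX function.
    return jnp.asarray(1.0).dtype == jnp.float32


# Call this after all imports (including any bandflux code) to catch a
# module that re-enabled x64 at import time.
print(float32_mode_active())  # True
```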
Bolometric light curves
All bolometric functions return log10(L), with L in erg/s, rather than linear luminosity.
This is deliberate: physical luminosities span roughly 10³⁸–10⁴⁵ erg/s, which
routinely exceeds the float32 maximum of ~3.4×10³⁸. Working in log10 space keeps
every value representable in float32 on GPU.
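The overflow is easy to reproduce. The sketch below uses NumPy's float32 type and a peak luminosity of 10⁴³ erg/s, chosen purely for illustration. It also shows that products and quotients of luminosities become sums and differences once everything is in log10:

```python
import numpy as np

FLOAT32_MAX = float(np.finfo(np.float32).max)  # about 3.4e38
L_peak = 1e43  # erg/s, an illustrative supernova-scale luminosity
print(L_peak > FLOAT32_MAX)  # True

# Casting the linear value to float32 overflows to inf.
with np.errstate(over="ignore"):
    linear_f32 = np.float32(L_peak)
print(np.isinf(linear_f32))  # True

# In log10 the same quantity is an ordinary float32 number.
log10_f32 = np.float32(np.log10(L_peak))
print(log10_f32)  # 43.0

# Products and quotients become sums and differences, e.g. the flux at
# d = 7e26 cm: log10 F = log10 L - log10(4 pi) - 2 log10 d
log10_d = np.float32(np.log10(7e26))
log10_flux = log10_f32 - np.float32(np.log10(4.0 * np.pi)) - np.float32(2.0) * log10_d
print(float(log10_flux))  # about -11.789, i.e. F ~ 1.6e-12 erg/s/cm^2
```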
import jax
jax.config.update("jax_enable_x64", False)
import jax.numpy as jnp
from redback_jax.models.supernova_models import arnett_bolometric
time = jnp.linspace(1.0, 100.0, 200, dtype=jnp.float32)
log10_lbol = arnett_bolometric(
time,
f_nickel=0.5,
mej=1.0, # solar masses
vej=10000.0, # km/s
kappa=0.1, # cm^2/g
kappa_gamma=10.0,
)
# log10_lbol ~ [41, 43] (physical range, float32-safe)
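Sometimes you need linear luminosities, for example to plot them or to integrate the radiated energy. Do that conversion on the host in float64 rather than exponentiating in float32. In this sketch, a made-up log10_lbol curve stands in for the arnett_bolometric output:

```python
import numpy as np

# Stand-in for the model output: float32 log10(L / erg s^-1) on a day grid.
# With the real model, use np.asarray(arnett_bolometric(...)) instead.
time_days = np.linspace(1.0, 100.0, 200)
log10_lbol = (42.5 - 0.01 * np.abs(time_days - 20.0)).astype(np.float32)

# Upcast before exponentiating: 10**42.5 is far above the float32 maximum.
lbol = 10.0 ** log10_lbol.astype(np.float64)

# Radiated energy (erg) by the trapezoid rule, converting days to seconds.
t_sec = time_days * 86400.0
e_rad = np.sum(0.5 * (lbol[1:] + lbol[:-1]) * np.diff(t_sec))
print(f"{e_rad:.2e} erg")
```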
Fitting bolometric data
Compare model and data in log10 space:
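import numpy as np
import jax.numpy as jnp
from redback_jax.models.supernova_models import arnett_bolometric
# time (days), observed_lbol and sigma_lbol (erg/s) are your data, as NumPy float64 arrays.
# Take log10 in float64 first: linear L above ~3.4e38 erg/s becomes inf in float32.
log10_lbol_obs = jnp.asarray(np.log10(observed_lbol), dtype=jnp.float32)
# Propagate errors to first order: sigma_{log10 L} = sigma_L / (L * ln10)
log10_lbol_err = jnp.asarray(sigma_lbol / (observed_lbol * np.log(10.0)), dtype=jnp.float32)
# params: dict of arnett_bolometric keyword arguments (f_nickel, mej, vej, ...)
def log_likelihood(params):
    log10_lbol_model = arnett_bolometric(time, **params)
    return -0.5 * jnp.sum(((log10_lbol_obs - log10_lbol_model) / log10_lbol_err)**2)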
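The likelihood is a pure JAX function of a parameter dict, so jax.jit and jax.grad apply to it directly. The sketch below replaces arnett_bolometric with a hypothetical toy model, toy_log10_lbol (a linear decline in log10 L), and uses synthetic data, so it runs without Redback-JAX. With the real model, pass its keyword arguments in the same dict:

```python
import jax

jax.config.update("jax_enable_x64", False)

import numpy as np
import jax.numpy as jnp


def toy_log10_lbol(time, log10_l0, slope):
    # Hypothetical stand-in for arnett_bolometric: log10 L declines linearly.
    return log10_l0 - slope * time


# Synthetic data built in float64 on the host (5% luminosity errors).
time = jnp.linspace(1.0, 100.0, 50, dtype=jnp.float32)
true_lbol = 10.0 ** (43.0 - 0.01 * np.asarray(time, dtype=np.float64))
sigma_lbol = 0.05 * true_lbol

log10_lbol_obs = jnp.asarray(np.log10(true_lbol), dtype=jnp.float32)
log10_lbol_err = jnp.asarray(sigma_lbol / (true_lbol * np.log(10.0)), dtype=jnp.float32)


@jax.jit
def log_likelihood(params):
    model = toy_log10_lbol(time, **params)
    return -0.5 * jnp.sum(((log10_lbol_obs - model) / log10_lbol_err) ** 2)


params = {"log10_l0": jnp.float32(42.8), "slope": jnp.float32(0.01)}
grads = jax.grad(log_likelihood)(params)  # dict with the same keys as params
print(grads["log10_l0"] > 0)  # True: raising log10_l0 improves the fit
```

The gradient dict has the same structure as params, so it can be passed straight to a gradient-based optimiser or sampler.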
Spectra and photometry
make_spectra_model wraps any bolometric model into a full SED pipeline
(photosphere → blackbody → observer-frame flux density):
import jax
jax.config.update("jax_enable_x64", False)
import jax.numpy as jnp
from redback_jax.models.supernova_models import arnett_bolometric
from redback_jax.models.spectra_model import make_spectra_model
arnett_spectra = make_spectra_model(arnett_bolometric)
output = arnett_spectra(
redshift=0.05,
lum_dist=7e26, # cm (~230 Mpc)
temperature_floor=3000.0,
# bolometric kwargs:
f_nickel=0.5, mej=1.0,
vej=10000.0, kappa=0.1, kappa_gamma=10.0,
)
output.time # observer-frame times (days)
output.lambdas # wavelengths (Angstrom)
output.spectra # (n_times, n_lambda) erg/s/cm^2/Angstrom
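The blackbody step of such a pipeline can be done entirely in log10 space. The sketch below illustrates this under textbook assumptions and is not Redback-JAX's internal code. It assumes a single-temperature blackbody photosphere, with L_λ = L·πB_λ(T)/(σT⁴) and F_λ(λ_obs) = L_λ(λ_obs/(1+z)) / [4π d_L² (1+z)]. The name log10_blackbody_flux is hypothetical, and temperature_floor is not modelled:

```python
import math

import jax

jax.config.update("jax_enable_x64", False)

import jax.numpy as jnp

# CGS constants (Python floats; only well-scaled combinations reach float32)
H = 6.62607015e-27       # Planck constant, erg s
C = 2.99792458e10        # speed of light, cm/s
K_B = 1.380649e-16       # Boltzmann constant, erg/K
SIGMA_SB = 5.670374e-5   # Stefan-Boltzmann constant, erg/cm^2/s/K^4
LN10 = math.log(10.0)


def log10_blackbody_flux(log10_lbol, temperature, lambda_obs_aa, redshift, log10_dl):
    """Observer-frame log10 F_lambda (erg/s/cm^2/Angstrom) of a blackbody photosphere.

    Uses L_lambda = L * pi * B_lambda(T) / (sigma T^4) and
    F_lambda(lam_obs) = L_lambda(lam_obs / (1 + z)) / (4 pi d_L^2 (1 + z)).
    """
    lam_rest_cm = lambda_obs_aa * 1e-8 / (1.0 + redshift)
    x = (H * C / K_B) / (lam_rest_cm * temperature)  # h c / (lambda k T)
    # ln(e^x - 1) = x + ln(-expm1(-x)): stable for both small and large x
    ln_denominator = x + jnp.log(-jnp.expm1(-x))
    log10_b_cm = (math.log10(2.0 * H * C**2) - 5.0 * jnp.log10(lam_rest_cm)
                  - ln_denominator / LN10)           # B_lambda per cm
    log10_l_lambda = (log10_lbol + math.log10(math.pi) + log10_b_cm
                      - math.log10(SIGMA_SB) - 4.0 * jnp.log10(temperature)
                      - 8.0)                         # per cm -> per Angstrom
    return (log10_l_lambda - math.log10(4.0 * math.pi)
            - 2.0 * log10_dl - jnp.log10(1.0 + redshift))


# Usage: L = 1e43 erg/s, T = 1e4 K, z = 0.05, d_L = 7e26 cm
lam = jnp.linspace(1000.0, 20000.0, 500)  # observer-frame Angstrom
log10_f = log10_blackbody_flux(43.0, 1.0e4, lam, 0.05, math.log10(7e26))
```

To check the normalisation, integrate 10**log10_f over a wide enough wavelength range. The result should recover L/(4π d_L²).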
For bandflux fitting, pass the spectra grid to PrecomputedSpectraSource:
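import jax.numpy as jnp
from redback_jax.sources import PrecomputedSpectraSource
source = PrecomputedSpectraSource(
phases=output.time,
wavelengths=output.lambdas,
flux_grid=output.spectra,
)
bridges, band_to_idx = source.prepare_bridges(['ztfg', 'ztfr'])
# observed_bands: band name of each observation, e.g. ['ztfg', 'ztfr', 'ztfg', ...]
# obs_times: observation times, in the same units as output.time (days)
band_indices = jnp.array([band_to_idx[b] for b in observed_bands])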
model_fluxes = source.bandflux(
{'amplitude': 1.0}, None, obs_times,
band_indices=band_indices, bridges=bridges,
unique_bands=['ztfg', 'ztfr'],
)
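For intuition about what a bandflux measures, the sketch below computes a photon-counting bandflux, ∫ F_λ T(λ) λ/(hc) dλ in photons/s/cm². This is the convention used by sncosmo. It is an assumption here that source.bandflux uses exactly this convention and these units. The helper photon_bandflux and the top-hat filter are hypothetical:

```python
import jax

jax.config.update("jax_enable_x64", False)

import jax.numpy as jnp

HC_ERG_AA = 6.62607015e-27 * 2.99792458e18  # h * c in erg * Angstrom


def photon_bandflux(wave_aa, flux_lambda, transmission):
    """Photons/s/cm^2 through a filter: trapezoid integral of F_lambda * T * lambda / (h c)."""
    integrand = flux_lambda * transmission * wave_aa / HC_ERG_AA
    return jnp.sum(0.5 * (integrand[1:] + integrand[:-1]) * jnp.diff(wave_aa))


# Flat spectrum through a hypothetical top-hat filter spanning 4000-5500 Angstrom.
wave = jnp.linspace(4000.0, 5500.0, 1501, dtype=jnp.float32)
flux = jnp.full_like(wave, 1e-15)  # erg/s/cm^2/Angstrom
trans = jnp.ones_like(wave)
result = photon_bandflux(wave, flux, trans)
print(result)  # ~0.359 photons/s/cm^2
```

For a flat F_λ and a top-hat filter, the integral has the closed form F_λ(λ₂² − λ₁²)/(2hc). This makes the example easy to check by hand.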
Available models
Supernovae
- arnett_bolometric: Ni/Co decay + Arnett diffusion (Arnett 1982)
- magnetar_powered_bolometric: dipole spin-down + Arnett diffusion
- csm_interaction_bolometric: forward/reverse shocks + CSM diffusion
TDE
- tde_analytical_bolometric: t^(-5/3) fallback + Arnett diffusion. Note: the parameter is
  log10_l0 (not l0), because the linear value (~10⁴³ erg/s) overflows float32.
Shock-powered
- shock_cooling_bolometric: Piro 2021. Parameters: log10_mass, log10_radius, log10_energy
  (log10 of the mass in solar masses, the radius in cm and the energy in erg, respectively),
  plus nn (outer density slope), delta (inner density slope) and kappa (opacity in cm²/g).
- shocked_cocoon_bolometric: Piro & Kollmeier 2018
Kilonova
- metzger_kilonova_bolometric: r-process ODE, 200 shells (Metzger 2017)
- magnetar_boosted_kilonova_bolometric: r-process + magnetar injection
All models are compiled with @jax.jit and support jax.grad and jax.vmap.
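The sketch below shows a Ni/Co decay heating rate evaluated in log10 so it stays float32-safe, and how jax.vmap and jax.grad apply to a function of this kind. The per-solar-mass constants (6.45×10⁴³ and 1.45×10⁴³ erg/s, e-folding times 8.8 and 111.3 days) are assumed values, as used in MOSFiT's Ni/Co engine. Redback-JAX's own constants may differ, and its Arnett diffusion step is not reproduced. log10_nico_heating is a hypothetical name:

```python
import math

import jax

jax.config.update("jax_enable_x64", False)

import jax.numpy as jnp

# Assumed per-solar-mass constants; log10 taken in float64 on the host,
# because jnp.log10(6.45e43) would first cast 6.45e43 to float32 (= inf).
LOG10_NI_LUM = math.log10(6.45e43)  # erg/s per Msun of 56Ni
LOG10_CO_LUM = math.log10(1.45e43)  # erg/s per Msun of 56Ni (via 56Co)
NI_LIFE, CO_LIFE = 8.8, 111.3       # e-folding times, days
LN10 = math.log(10.0)


@jax.jit
def log10_nico_heating(time, f_nickel, mej):
    """log10 of Ni/Co decay heating (erg/s) for f_nickel * mej solar masses of 56Ni."""
    # log(a e^(-t/ta) + b e^(-t/tb)) via logaddexp of the two natural-log terms
    ln_ni = LOG10_NI_LUM * LN10 - time / NI_LIFE
    ln_co = LOG10_CO_LUM * LN10 - time / CO_LIFE
    return jnp.log10(f_nickel * mej) + jnp.logaddexp(ln_ni, ln_co) / LN10


time = jnp.linspace(1.0, 100.0, 100, dtype=jnp.float32)

# vmap over a grid of nickel fractions -> one light curve per row
f_grid = jnp.linspace(0.1, 0.5, 5)
curves = jax.vmap(lambda f: log10_nico_heating(time, f, 1.0))(f_grid)
print(curves.shape)  # (5, 100)

# grad of the day-10 value w.r.t. ejecta mass: d log10 L / d mej = 1 / (mej ln 10)
dlogl_dmej = jax.grad(lambda m: log10_nico_heating(jnp.float32(10.0), 0.5, m))(1.0)
print(dlogl_dmej)  # about 0.434 at mej = 1
```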
Next steps
See Examples for complete inference examples
See API Reference for full function signatures