Examples
Bolometric inference with BlackJAX
Fit an Arnett model to observed bolometric luminosities using NUTS via BlackJAX.
Because all bolometric models return log10_lbol, the likelihood is computed
in log10 space — this is both float32-safe and numerically well-conditioned.
import jax
jax.config.update("jax_enable_x64", False)
import jax.numpy as jnp
import blackjax
from redback_jax.models.supernova_models import arnett_bolometric

# Simulated observations: the Arnett model plus 0.05 dex Gaussian noise.
time = jnp.linspace(5.0, 60.0, 30, dtype=jnp.float32)
true_params = dict(f_nickel=0.4, mej=1.2, vej=9000.0, kappa=0.1, kappa_gamma=10.0)
log10_lbol_true = arnett_bolometric(time, **true_params)
log10_lbol_obs = log10_lbol_true + 0.05 * jax.random.normal(jax.random.PRNGKey(0), time.shape)
sigma = jnp.full_like(time, 0.05)

# Only f_nickel and mej are sampled. The other parameters are held at the
# values used for the simulation (taken from true_params so they cannot drift).
fixed_params = {k: true_params[k] for k in ("vej", "kappa", "kappa_gamma")}


# Log-likelihood in log10 space
@jax.jit
def log_likelihood(params):
    log10_lbol = arnett_bolometric(time, **params, **fixed_params)
    return -0.5 * jnp.sum(((log10_lbol_obs - log10_lbol) / sigma) ** 2)


# Log-posterior: flat prior on 0 < f_nickel <= 1 and mej > 0.
def log_density(params):
    f_nickel, mej = params["f_nickel"], params["mej"]
    in_support = (f_nickel > 0.0) & (f_nickel <= 1.0) & (mej > 0.0)
    # Outside the support, evaluate the model at a harmless point instead.
    # Otherwise NaNs from unphysical inputs could leak into the gradient
    # through jnp.where, even though that branch is not selected.
    safe_params = {
        "f_nickel": jnp.where(in_support, f_nickel, 0.5),
        "mej": jnp.where(in_support, mej, 1.0),
    }
    return jnp.where(in_support, log_likelihood(safe_params), -jnp.inf)


# NUTS via BlackJAX. Window adaptation tunes the step size and mass matrix,
# then lax.scan draws the posterior samples.
warmup_key, sample_key = jax.random.split(jax.random.PRNGKey(1))
initial_position = {
    "f_nickel": jnp.asarray(0.3, dtype=jnp.float32),
    "mej": jnp.asarray(1.0, dtype=jnp.float32),
}
warmup = blackjax.window_adaptation(blackjax.nuts, log_density)
(state, nuts_params), _ = warmup.run(warmup_key, initial_position, num_steps=500)
nuts_step = blackjax.nuts(log_density, **nuts_params).step


def one_step(state, key):
    state, _ = nuts_step(key, state)
    return state, state.position


_, samples = jax.lax.scan(one_step, state, jax.random.split(sample_key, 1000))
# samples["f_nickel"] and samples["mej"] each hold 1000 posterior draws.
Photometry fitting with the inference API
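Use the Prior / Likelihood / NestedSampler / MCMCSampler
API for end-to-end Bayesian photometric fitting. The Likelihood class
handles bandflux integration internally and is JIT-safe. The snippet
assumes that a `transient` object holding the observed photometry
(`time` in MJD, `y` in AB mag, `y_err`, `bands`) has already been loaded.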
import jax
from redback_jax.inference import Prior, Uniform, Likelihood, NestedSampler, MCMCSampler
from redback_jax.utils import luminosity_distance_cm

REDSHIFT = 0.01
DL_CM = luminosity_distance_cm(REDSHIFT)

# Uniform priors on the four sampled parameters.
prior = Prior([
    Uniform(58580, 58620, name='t0'),
    Uniform(0.05, 0.30, name='f_nickel'),
    Uniform(0.5, 3.0, name='mej'),
    Uniform(3000, 12000, name='vej'),
])

# Parameters held constant during the fit.
fixed = dict(
    redshift=REDSHIFT,
    lum_dist=DL_CM,
    temperature_floor=5000.0,
    kappa=0.07,
    kappa_gamma=0.1,
)

# `transient` must already be defined. It provides transient.time (MJD),
# transient.y (AB mag), transient.y_err and transient.bands.
likelihood = Likelihood(model='arnett_spectra', transient=transient, fixed_params=fixed)

# Nested sampling
nested = NestedSampler(likelihood, prior, outdir='results/')
ns_result = nested.run(jax.random.PRNGKey(0))
ns_result.summary()

# Or MCMC with NUTS
mcmc = MCMCSampler(likelihood, prior, n_warmup=500, n_samples=2000, n_chains=4)
mcmc_result = mcmc.run(jax.random.PRNGKey(1))
mcmc_result.summary()
Kilonova bolometric fitting
The kilonova models use an energy-normalised ODE scan for float32 safety.
They also return log10_lbol:
import jax
jax.config.update("jax_enable_x64", False)
import jax.numpy as jnp
from redback_jax.models.kilonova import metzger_kilonova_bolometric

time = jnp.linspace(0.5, 20.0, 50, dtype=jnp.float32)

kilonova_params = {
    "mej": 0.05,   # solar masses
    "vej": 0.2,    # fraction of c
    "beta": 3.0,   # velocity profile slope
    "kappa": 1.0,  # cm^2/g
}
log10_lbol = metzger_kilonova_bolometric(time, **kilonova_params)
# Typical range: log10_lbol ~ [39, 42]
Magnetar-boosted kilonova (reusing `time` from the snippet above):
from redback_jax.models.kilonova import magnetar_boosted_kilonova_bolometric

# Reuses `time` from the kilonova snippet above.
ejecta = dict(mej=0.05, vej=0.2, beta=3.0, kappa=1.0)
magnetar = dict(
    p0=1.0,        # spin period in ms
    bp=1.0,        # B-field in units of 1e14 G
    mass_ns=1.4,   # neutron star mass in solar masses
    theta_pb=0.0,  # spin-B field angle in radians
)
log10_lbol = magnetar_boosted_kilonova_bolometric(time, **ejecta, **magnetar)
Shock-powered models
Shock cooling and cocoon models also return log10_lbol. Note that
shock_cooling_bolometric takes log10 inputs for the large physical
quantities (mass, radius and energy):
import jax
jax.config.update("jax_enable_x64", False)
import jax.numpy as jnp
from redback_jax.models.shock_powered_models import (
    shock_cooling_bolometric,
    shocked_cocoon_bolometric,
)

time = jnp.linspace(0.1, 10.0, 50, dtype=jnp.float32)

# Shock cooling. Inputs are log10(mass/Msun), log10(radius/cm), log10(energy/erg).
log10_lbol_cooling = shock_cooling_bolometric(
    time,
    log10_mass=-2.0,    # 0.01 solar masses
    log10_radius=13.0,  # 1e13 cm
    log10_energy=51.0,  # 1e51 erg
    nn=10.0,            # outer density power-law slope
    delta=1.1,          # inner density power-law slope
    kappa=0.2,          # opacity in cm^2/g
)

# Shocked cocoon. It gets its own name so the shock-cooling light curve
# above is not overwritten and the two can be compared.
log10_lbol_cocoon = shocked_cocoon_bolometric(
    time,
    mej=0.01,
    vej=0.1,       # fraction of c
    eta=2.0,
    tshock=1.0,    # seconds
    shocked_fraction=0.5,
    cos_theta_cocoon=0.5,
    kappa=0.1,
)