API Reference
Main Package
Redback-JAX: A lightweight JAX-only version of the redback electromagnetic transient analysis package.
This package provides JAX-based implementations for electromagnetic transient modeling and Bayesian inference, focusing on performance and automatic differentiation capabilities.
Models
JAX-based transient models for electromagnetic counterparts.
- redback_jax.models.arnett_bolometric(time, f_nickel, mej, *, vej=None, kappa=None, kappa_gamma=None)
Bolometric Arnett (1982) light curve with Ni/Co decay engine + diffusion.
- Parameters:
time – time in days
f_nickel – fraction of nickel mass
mej – total ejecta mass in solar masses
kappa – opacity in cm^2/g (required)
kappa_gamma – gamma-ray opacity in cm^2/g (required)
vej – ejecta velocity in km/s (required)
- Returns:
log10 of bolometric luminosity in erg/s
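The Ni/Co decay engine that feeds the diffusion step can be sketched on its own. This is a minimal illustration that assumes the commonly used decay parameters (tau_Ni = 8.8 d, tau_Co = 111.3 d, eps_Ni = 3.9e10 erg/s/g, eps_Co = 6.78e9 erg/s/g). Whether arnett_bolometric uses exactly these constants is not stated here, and the helper log10_nico_heating is hypothetical. The sketch omits diffusion and gamma-ray leakage (kappa_gamma). It works in log10 because a heating rate of order 1e43 erg/s already overflows float32.

```python
import jax.numpy as jnp

MSUN_G = 1.989e33                 # solar mass in grams
TAU_NI, TAU_CO = 8.8, 111.3       # assumed e-folding times in days
EPS_NI, EPS_CO = 3.9e10, 6.78e9   # assumed specific heating rates in erg/s/g


def log10_nico_heating(time_days, f_nickel, mej):
    """Hypothetical helper: log10 of the Ni/Co decay heating rate in erg/s.

    Engine luminosity only: no diffusion and no gamma-ray leakage.
    """
    # log10 of the nickel mass in grams, with M_Ni = f_nickel * mej (mej in Msun)
    log10_mni = jnp.log10(f_nickel * mej * MSUN_G)
    # The specific heating rate in erg/s/g stays well inside the float32 range
    rate = (EPS_NI - EPS_CO) * jnp.exp(-time_days / TAU_NI) + EPS_CO * jnp.exp(-time_days / TAU_CO)
    return log10_mni + jnp.log10(rate)


times = jnp.array([0.0, 10.0, 50.0, 100.0])
log10_heating = log10_nico_heating(times, f_nickel=0.1, mej=1.0)
```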
- redback_jax.models.arnett_with_features_cosmology(f_nickel, mej, *, redshift=0.0, cosmo_H0=PLANCK18_H0, cosmo_Om0=PLANCK18_OM0, vej=None, kappa=None, kappa_gamma=None, temperature_floor=None, features=NO_SED_FEATURES)
Arnett model with cosmological luminosity distance calculation.
- Parameters:
redshift – source redshift
f_nickel – fraction of nickel mass
mej – total ejecta mass in solar masses
cosmo_H0 – Hubble constant (km/s/Mpc)
cosmo_Om0 – matter density parameter
kappa – opacity in cm^2/g (required)
kappa_gamma – gamma-ray opacity in cm^2/g (required)
vej – ejecta velocity in km/s (required)
temperature_floor – floor temperature in K
features – SEDFeatures object
- Returns:
namedtuple(time, lambdas, spectra)
- redback_jax.models.blackbody_to_flux_density(temperature, r_photosphere, dl, frequency)
General blackbody flux-density formula for a spherical photosphere of radius r_photosphere observed at luminosity distance dl.
- Parameters:
temperature – effective temperature in kelvin
r_photosphere – photosphere radius in cm
dl – luminosity distance in cm
frequency – frequency at which to evaluate the flux density, in Hz
- Returns:
flux_density in erg/s/Hz/cm^2
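For reference, the textbook flux density of a uniform spherical blackbody of radius R at luminosity distance d_L is F_nu = pi B_nu(T) (R/d_L)^2 = 2 pi h nu^3 R^2 / (c^2 d_L^2) / (exp(h nu / k T) - 1). The sketch below evaluates this in log10 so it stays finite in float32, where d_L^2 at 100 Mpc would overflow. It assumes that blackbody_to_flux_density implements this standard formula, and the function name below is hypothetical.

```python
import jax.numpy as jnp

H_PLANCK = 6.62607015e-27    # erg s
K_BOLTZ = 1.380649e-16       # erg / K
C_LIGHT = 2.99792458e10      # cm / s
LN10 = jnp.log(10.0)


def log10_blackbody_flux_density(temperature, r_photosphere, dl, frequency):
    """Hypothetical helper: log10 F_nu in erg/s/Hz/cm^2 for a spherical blackbody."""
    x = H_PLANCK * frequency / (K_BOLTZ * temperature)
    # Stable log(exp(x) - 1): expm1 for moderate x, x + log1p(-exp(-x)) for large x.
    # The clamp keeps the unused branch finite.
    log_expm1 = jnp.where(
        x < 20.0,
        jnp.log(jnp.expm1(jnp.minimum(x, 20.0))),
        x + jnp.log1p(-jnp.exp(-x)),
    )
    return (
        jnp.log10(2.0 * jnp.pi * H_PLANCK)
        + 3.0 * jnp.log10(frequency)
        + 2.0 * jnp.log10(r_photosphere)
        - 2.0 * jnp.log10(C_LIGHT)
        - 2.0 * jnp.log10(dl)
        - log_expm1 / LN10
    )
```

To recover the linear flux density, take 10 ** result once it is safely inside the float32 range.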
- redback_jax.models.magnetar_powered_bolometric(time, p0, bp, mass_ns, theta_pb, mej, kappa, kappa_gamma, vej)
Bolometric light curve of a magnetar-powered supernova (Arnett diffusion).
- Parameters:
time – source-frame time in days
p0 – initial spin period in milliseconds
bp – polar B-field in units of 10^14 G
mass_ns – NS mass in solar masses
theta_pb – spin–B-field angle in radians
mej – ejecta mass in solar masses
kappa – optical opacity in cm^2/g
kappa_gamma – gamma-ray opacity in cm^2/g
vej – ejecta velocity in km/s
- Returns:
log10 of bolometric luminosity in erg/s
- redback_jax.models.magnetar_nickel_bolometric(time, f_nickel, mej, p0, bp, mass_ns, theta_pb, kappa, kappa_gamma, vej)
Bolometric light curve powered by both a magnetar and Ni/Co radioactive decay (Arnett diffusion). The two luminosity sources are added before diffusion.
Reference: Gomez et al. 2018 (https://ui.adsabs.harvard.edu/abs/2018ApJS..236….6G/abstract)
- Parameters:
time – source-frame time in days
f_nickel – nickel mass fraction (M_Ni = f_nickel * mej)
mej – total ejecta mass in solar masses
p0 – initial spin period in milliseconds
bp – polar B-field in units of 10^14 G
mass_ns – NS mass in solar masses
theta_pb – spin–B-field angle in radians
kappa – optical opacity in cm^2/g
kappa_gamma – gamma-ray opacity in cm^2/g
vej – ejecta velocity in km/s
- Returns:
log10 of bolometric luminosity in erg/s
- redback_jax.models.csm_interaction_bolometric(time, mej, csm_mass, vej, eta, rho, kappa, r0, nn=12, delta=1, efficiency=0.5)
Bolometric CSM-interaction light curve (Chevalier 1982 shocks + diffusion).
- Parameters:
time – source-frame time in days
mej – ejecta mass in solar masses
csm_mass – CSM mass in solar masses
vej – ejecta velocity in km/s
eta – CSM density profile exponent
rho – CSM density amplitude in g/cm^3
kappa – opacity in cm^2/g
r0 – inner CSM radius in AU
nn – ejecta density power-law slope (default 12)
delta – inner ejecta density slope (default 1)
efficiency – kinetic-to-luminosity efficiency (default 0.5)
- Returns:
log10 of bolometric luminosity in erg/s
- redback_jax.models.shock_cooling_bolometric(time, log10_mass, log10_radius, log10_energy, nn, delta, kappa)
Bolometric shock-cooling light curve following Piro (2021).
Mass, radius and energy are passed as log10 values, and all intermediate computations are carried out in log10 space, so the model stays float32-safe throughout.
- Parameters:
time – source-frame time in days
log10_mass – log10 envelope mass in solar masses
log10_radius – log10 envelope radius in cm
log10_energy – log10 explosion energy in erg
nn – outer density power-law slope
delta – inner density power-law slope
kappa – opacity in cm^2/g
- Returns:
log10 of bolometric luminosity in erg/s
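The float32 concern is concrete: the largest finite float32 is about 3.4e38, so an explosion energy of 1e51 erg cannot be stored linearly. The snippet below illustrates combining quantities in log10 instead. The kinetic-energy velocity scale v = sqrt(2E/M) is only an example of such a combination. It is not claimed to be a step of the Piro (2021) implementation.

```python
import numpy as np
import jax.numpy as jnp

MSUN_G = 1.989e33  # solar mass in grams

# 1e51 erg is far beyond the largest finite float32 value (~3.4e38)
float32_max = float(np.finfo(np.float32).max)
energy_overflows = 1e51 > float32_max

# The same quantities in log10 are small, well-conditioned float32 numbers
log10_energy = jnp.float32(51.0)                       # E = 1e51 erg
log10_mass = jnp.log10(jnp.float32(0.01 * MSUN_G))     # M = 0.01 Msun in grams

# v = sqrt(2 E / M), combined entirely in log10 space (result is log10 of cm/s)
log10_velocity = 0.5 * (jnp.log10(2.0) + log10_energy - log10_mass)
```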
- redback_jax.models.shocked_cocoon_bolometric(time, mej, vej, eta, tshock, shocked_fraction, cos_theta_cocoon, kappa)
Bolometric light curve of a shocked jet cocoon (Piro & Kollmeier 2018). All large intermediate quantities are computed in log10 space for float32 safety.
- Parameters:
time – source-frame time in days
mej – ejecta mass in solar masses
vej – ejecta velocity in units of c (speed of light)
eta – ejecta density power-law slope
tshock – shock time in seconds
shocked_fraction – fraction of ejecta mass that is shocked
cos_theta_cocoon – cosine of cocoon opening half-angle
kappa – gray opacity in cm^2/g
- Returns:
log10 of bolometric luminosity in erg/s
- redback_jax.models.shock_cooling_and_arnett_bolometric(time, log10_mass, log10_radius, log10_energy, nn, delta, f_nickel, mej, vej, kappa, kappa_gamma)
Combined bolometric light curve: shock cooling (Piro 2021) + Ni/Co decay (Arnett 1982).
The two components are summed as luminosities by applying log-sum-exp to their log10 values, which avoids forming the linear luminosities and keeps the computation float32-safe.
- Parameters:
time – source-frame time in days
log10_mass – log10 envelope mass in solar masses
log10_radius – log10 envelope radius in cm
log10_energy – log10 explosion energy in erg
nn – outer density power-law slope
delta – inner density power-law slope
f_nickel – nickel mass fraction
mej – total ejecta mass in solar masses
vej – ejecta velocity in km/s
kappa – optical opacity in cm^2/g
kappa_gamma – gamma-ray opacity in cm^2/g
- Returns:
log10 of bolometric luminosity in erg/s
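Two components stored as log10 luminosities can be summed without leaving log space, because log10(10^a + 10^b) = logaddexp(a ln 10, b ln 10) / ln 10. A minimal sketch of that identity follows. The helper name log10_add is hypothetical, and whether shock_cooling_and_arnett_bolometric uses jnp.logaddexp specifically is an assumption.

```python
import jax.numpy as jnp

LN10 = jnp.log(10.0)


def log10_add(log10_a, log10_b):
    """log10(10**a + 10**b) without ever forming 10**a or 10**b."""
    return jnp.logaddexp(log10_a * LN10, log10_b * LN10) / LN10


# e.g. a shock-cooling component at 1e43 erg/s plus Ni/Co decay at 1e42 erg/s
log10_total = log10_add(jnp.float32(43.0), jnp.float32(42.0))
```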
- redback_jax.models.tde_analytical_bolometric(time, log10_l0, t_0_turn, mej, vej, kappa, kappa_gamma)
Bolometric TDE light curve: t^{-5/3} fallback engine + Arnett diffusion.
- Parameters:
time – source-frame time in days
log10_l0 – log10 of bolometric luminosity at 1 second in erg/s
t_0_turn – turn-on time in days
mej – ejecta mass in solar masses
vej – ejecta velocity in km/s
kappa – optical opacity in cm^2/g
kappa_gamma – gamma-ray opacity in cm^2/g
- Returns:
log10 of bolometric luminosity in erg/s
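Because log10_l0 is defined at t = 1 s, the t^{-5/3} fallback engine has a simple log-space form: log10 L(t) = log10_l0 - (5/3) log10(t / 1 s). The sketch below evaluates only that power law. It does not reproduce how the model treats times before t_0_turn or the subsequent Arnett diffusion, and the helper name is hypothetical.

```python
import jax.numpy as jnp

DAY_S = 86400.0  # seconds per day


def log10_fallback_engine(time_days, log10_l0):
    """Hypothetical helper: log10 of L0 * (t / 1 s)^(-5/3), with t given in days."""
    return log10_l0 - (5.0 / 3.0) * jnp.log10(time_days * DAY_S)


log10_engine = log10_fallback_engine(jnp.array([10.0, 100.0]), log10_l0=52.0)
```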
- redback_jax.models.metzger_kilonova_bolometric(time, mej, vej, beta, kappa, vmax=0.7, neutron_precursor=True)
Bolometric kilonova light curve (Metzger 2017) with 200 ejecta shells and Barnes et al. (2016) thermalisation, integrated with a sequential Euler ODE scheme implemented with jax.lax.scan.
- Parameters:
time – source-frame time in days (must be strictly increasing, ≥2 points)
mej – ejecta mass in solar masses
vej – minimum ejecta velocity in units of c
beta – velocity power-law slope (M ∝ v^{-beta})
kappa – gray opacity in cm^2/g
vmax – maximum ejecta velocity in units of c (default 0.7)
neutron_precursor – include neutron precursor emission (default True)
- Returns:
log10 of bolometric luminosity in erg/s
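The requirement for a strictly increasing time array with at least two points follows from how sequential forward-Euler integration consumes the grid: each step advances the state by dt = t[i+1] - t[i]. The generic jax.lax.scan pattern is sketched below on a toy decay equation dy/dt = -y. It is not the kilonova shell model itself, and euler_scan is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def euler_scan(rhs, y0, time):
    """Forward-Euler integration of dy/dt = rhs(t, y) on the grid `time`.

    Returns the state at every grid point, starting with y0.
    """
    def step(y, t_pair):
        t_prev, t_next = t_pair
        y_next = y + (t_next - t_prev) * rhs(t_prev, y)
        return y_next, y_next  # (new carry, per-step output)

    # Consecutive (t_prev, t_next) pairs: needs >= 2 strictly increasing points
    _, ys = jax.lax.scan(step, y0, (time[:-1], time[1:]))
    return jnp.concatenate([jnp.asarray(y0)[None], ys])


time = jnp.linspace(0.0, 1.0, 1001)
y = euler_scan(lambda t, state: -state, jnp.float32(1.0), time)  # exact: exp(-t)
```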
- redback_jax.models.magnetar_boosted_kilonova_bolometric(time, mej, vej, beta, kappa, p0, bp, mass_ns, theta_pb, thermalisation_efficiency=1.0, vmax=0.7, neutron_precursor=True)
Bolometric kilonova light curve with magnetar spin-down energy injection.
- Parameters:
time – source-frame time in days (strictly increasing, ≥2 points)
mej – ejecta mass in solar masses
vej – minimum ejecta velocity in units of c
beta – velocity power-law slope
kappa – gray opacity in cm^2/g
p0 – initial spin period in milliseconds
bp – polar B-field in units of 10^14 G
mass_ns – neutron star mass in solar masses
theta_pb – angle between spin and B-field axes in radians
thermalisation_efficiency – fraction of magnetar luminosity thermalised (default 1.0)
vmax – maximum ejecta velocity in units of c (default 0.7)
neutron_precursor – include neutron precursor emission (default True)
- Returns:
log10 of bolometric luminosity in erg/s
- redback_jax.models.make_spectra_model(bolometric_fn)
Wrap a bolometric model function to produce a full spectra model.
The returned function has signature:
spectra_model(redshift, lum_dist, vej, temperature_floor, features=NO_SED_FEATURES, **bolometric_kwargs) -> namedtuple(time, lambdas, spectra)
- Parameters:
bolometric_fn (callable) – Any function f(time_days, **kwargs) -> log10_lbol (log10 erg/s); time_days must be its first positional argument.
- Returns:
A spectra model with the same photosphere/SED pipeline.
- Return type:
callable
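Any function that satisfies the contract above can be wrapped. As an illustration, here is a hypothetical toy bolometric model (a pure exponential decline, not part of redback-jax) written to that contract. It takes time_days first and the remaining parameters as keywords, and it returns log10 L in erg/s without leaving log space.

```python
import jax
import jax.numpy as jnp


def exponential_bolometric(time_days, *, log10_l_peak, t_decay):
    """Hypothetical toy model: L(t) = L_peak * exp(-t / t_decay), returned as log10 erg/s."""
    # log10 L = log10 L_peak - t / (t_decay * ln 10); no linear luminosity is formed
    return log10_l_peak - time_days / (t_decay * jnp.log(10.0))


# Because it is pure jax.numpy, it is differentiable with respect to its parameters
dlogl_dtdecay = jax.grad(
    lambda td: exponential_bolometric(20.0, log10_l_peak=43.0, t_decay=td)
)(10.0)
```

With redback_jax installed, this function could then be passed as make_spectra_model(exponential_bolometric). The returned spectra model would be called with the wrapper's own arguments (redshift, lum_dist, vej, temperature_floor), with log10_l_peak and t_decay supplied as **bolometric_kwargs.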
- redback_jax.models.arnett_spectra(redshift, lum_dist, vej, temperature_floor, features=NO_SED_FEATURES, **bolometric_kwargs)
Spectra model wrapping arnett_bolometric.
- Parameters:
redshift – source redshift
lum_dist – luminosity distance in cm
vej – ejecta velocity in km/s (photosphere)
temperature_floor – floor temperature in K
features – SEDFeatures (default NO_SED_FEATURES)
**bolometric_kwargs – forwarded to the bolometric function
- Returns:
namedtuple with fields time (days), lambdas (Angstrom), spectra (erg/s/cm^2/Angstrom)
- redback_jax.models.magnetar_powered_spectra(redshift, lum_dist, vej, temperature_floor, features=NO_SED_FEATURES, **bolometric_kwargs)
Spectra model wrapping magnetar_powered_bolometric.
- Parameters:
redshift – source redshift
lum_dist – luminosity distance in cm
vej – ejecta velocity in km/s (photosphere)
temperature_floor – floor temperature in K
features – SEDFeatures (default NO_SED_FEATURES)
**bolometric_kwargs – forwarded to the bolometric function
- Returns:
namedtuple with fields time (days), lambdas (Angstrom), spectra (erg/s/cm^2/Angstrom)
- redback_jax.models.magnetar_nickel_spectra(redshift, lum_dist, vej, temperature_floor, features=NO_SED_FEATURES, **bolometric_kwargs)
Spectra model wrapping magnetar_nickel_bolometric.
- Parameters:
redshift – source redshift
lum_dist – luminosity distance in cm
vej – ejecta velocity in km/s (photosphere)
temperature_floor – floor temperature in K
features – SEDFeatures (default NO_SED_FEATURES)
**bolometric_kwargs – forwarded to the bolometric function
- Returns:
namedtuple with fields time (days), lambdas (Angstrom), spectra (erg/s/cm^2/Angstrom)
- redback_jax.models.csm_interaction_spectra(redshift, lum_dist, vej, temperature_floor, features=NO_SED_FEATURES, **bolometric_kwargs)
Spectra model wrapping csm_interaction_bolometric.
- Parameters:
redshift – source redshift
lum_dist – luminosity distance in cm
vej – ejecta velocity in km/s (photosphere)
temperature_floor – floor temperature in K
features – SEDFeatures (default NO_SED_FEATURES)
**bolometric_kwargs – forwarded to the bolometric function
- Returns:
namedtuple with fields time (days), lambdas (Angstrom), spectra (erg/s/cm^2/Angstrom)
- redback_jax.models.shock_cooling_spectra(redshift, lum_dist, vej, temperature_floor, features=NO_SED_FEATURES, **bolometric_kwargs)
Spectra model wrapping shock_cooling_bolometric.
- Parameters:
redshift – source redshift
lum_dist – luminosity distance in cm
vej – ejecta velocity in km/s (photosphere)
temperature_floor – floor temperature in K
features – SEDFeatures (default NO_SED_FEATURES)
**bolometric_kwargs – forwarded to the bolometric function
- Returns:
namedtuple with fields time (days), lambdas (Angstrom), spectra (erg/s/cm^2/Angstrom)
- redback_jax.models.shocked_cocoon_spectra(redshift, lum_dist, vej, temperature_floor, features=NO_SED_FEATURES, **bolometric_kwargs)
Spectra model wrapping shocked_cocoon_bolometric.
- Parameters:
redshift – source redshift
lum_dist – luminosity distance in cm
vej – ejecta velocity in km/s (photosphere)
temperature_floor – floor temperature in K
features – SEDFeatures (default NO_SED_FEATURES)
**bolometric_kwargs – forwarded to the bolometric function
- Returns:
namedtuple with fields time (days), lambdas (Angstrom), spectra (erg/s/cm^2/Angstrom)
- redback_jax.models.shock_cooling_and_arnett_spectra(redshift, lum_dist, vej, temperature_floor, features=NO_SED_FEATURES, **bolometric_kwargs)
Spectra model wrapping shock_cooling_and_arnett_bolometric.
- Parameters:
redshift – source redshift
lum_dist – luminosity distance in cm
vej – ejecta velocity in km/s (photosphere)
temperature_floor – floor temperature in K
features – SEDFeatures (default NO_SED_FEATURES)
**bolometric_kwargs – forwarded to the bolometric function
- Returns:
namedtuple with fields time (days), lambdas (Angstrom), spectra (erg/s/cm^2/Angstrom)
- redback_jax.models.tde_analytical_spectra(redshift, lum_dist, vej, temperature_floor, features=NO_SED_FEATURES, **bolometric_kwargs)
Spectra model wrapping tde_analytical_bolometric.
- Parameters:
redshift – source redshift
lum_dist – luminosity distance in cm
vej – ejecta velocity in km/s (photosphere)
temperature_floor – floor temperature in K
features – SEDFeatures (default NO_SED_FEATURES)
**bolometric_kwargs – forwarded to the bolometric function
- Returns:
namedtuple with fields time (days), lambdas (Angstrom), spectra (erg/s/cm^2/Angstrom)
- redback_jax.models.metzger_kilonova_spectra(redshift, lum_dist, vej, temperature_floor, features=NO_SED_FEATURES, **bolometric_kwargs)
Spectra model wrapping metzger_kilonova_bolometric.
- Parameters:
redshift – source redshift
lum_dist – luminosity distance in cm
vej – ejecta velocity in km/s (photosphere)
temperature_floor – floor temperature in K
features – SEDFeatures (default NO_SED_FEATURES)
**bolometric_kwargs – forwarded to the bolometric function
- Returns:
namedtuple with fields time (days), lambdas (Angstrom), spectra (erg/s/cm^2/Angstrom)
- redback_jax.models.magnetar_boosted_kilonova_spectra(redshift, lum_dist, vej, temperature_floor, features=NO_SED_FEATURES, **bolometric_kwargs)
Spectra model wrapping magnetar_boosted_kilonova_bolometric.
- Parameters:
redshift – source redshift
lum_dist – luminosity distance in cm
vej – ejecta velocity in km/s (photosphere)
temperature_floor – floor temperature in K
features – SEDFeatures (default NO_SED_FEATURES)
**bolometric_kwargs – forwarded to the bolometric function
- Returns:
namedtuple with fields time (days), lambdas (Angstrom), spectra (erg/s/cm^2/Angstrom)
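The spectra wrappers take a photospheric velocity vej and a temperature_floor. The original redback's temperature-floor photosphere uses the following convention for this pair, and these wrappers may mirror it (an assumption). The photosphere expands as R = vej * t, and its temperature follows from L = 4 pi sigma R^2 T^4. Once T would drop below the floor, it is pinned at the floor and R shrinks so that L is conserved. A float32-safe sketch with a hypothetical function name:

```python
import jax.numpy as jnp

SIGMA_SB = 5.670374419e-5   # Stefan-Boltzmann constant, erg / (s cm^2 K^4)
DAY_S = 86400.0             # seconds per day
KM_CM = 1.0e5               # centimetres per kilometre
LOG10_4PI_SIGMA = jnp.log10(4.0 * jnp.pi * SIGMA_SB)


def temperature_floor_photosphere(time_days, log10_lbol, vej_kms, temperature_floor):
    """Hypothetical helper: (temperature in K, radius in cm) of the photosphere.

    Luminosity arithmetic stays in log10 because L ~ 1e42-1e44 erg/s overflows float32.
    """
    log10_r = jnp.log10(vej_kms * KM_CM * time_days * DAY_S)
    # L = 4 pi sigma R^2 T^4  =>  log10 T = (log10 L - log10(4 pi sigma) - 2 log10 R) / 4
    log10_t = 0.25 * (log10_lbol - LOG10_4PI_SIGMA - 2.0 * log10_r)
    log10_floor = jnp.log10(temperature_floor)
    below = log10_t < log10_floor
    # At the floor, shrink the radius so that L = 4 pi sigma R^2 T_floor^4 still holds
    log10_r_floor = 0.5 * (log10_lbol - LOG10_4PI_SIGMA - 4.0 * log10_floor)
    temperature = 10.0 ** jnp.where(below, log10_floor, log10_t)
    radius = 10.0 ** jnp.where(below, log10_r_floor, log10_r)
    return temperature, radius


time_days = jnp.array([5.0, 20.0, 60.0])
temperature, radius = temperature_floor_photosphere(
    time_days, jnp.array([42.5, 42.0, 41.0]), 10000.0, 5000.0
)
```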
- class redback_jax.models.SEDFeatures(rest_wavelengths, sigmas, amplitudes, t_starts, t_ends, rise_times=2.0 * 24.0 * 3600.0, fall_times=5.0 * 24.0 * 3600.0)
Bases: object
A class representing one or more spectral features in the SED. It implements the JAX PyTree interface, so it can be passed to JIT-compiled functions.
- Parameters:
rest_wavelengths – Central wavelengths in Angstroms
sigmas – Gaussian widths in Angstroms
amplitudes – Amplitudes as a fraction of the continuum (negative = absorption, positive = emission; e.g., -0.4 = 40% absorption)
t_starts – Start times in seconds
t_ends – End times in seconds
rise_times – Rise times in seconds (default: 2 days)
fall_times – Fall times in seconds (default: 5 days)
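Making a class a JAX PyTree lets jax.jit flatten it into array leaves and rebuild it inside traced code. The sketch below shows the pattern with jax.tree_util.register_pytree_node_class. The class is a hypothetical, wavelength-only stand-in with no time window. It treats each amplitude as a multiplicative factor 1 + A exp(-(lambda - lambda_0)^2 / 2 sigma^2) on the continuum, an assumption consistent with the -0.4 = 40% absorption example above. This does not describe how apply_sed_feature is implemented.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class GaussianFeatures:
    """Hypothetical, wavelength-only stand-in for SEDFeatures (no time window)."""

    def __init__(self, rest_wavelengths, sigmas, amplitudes):
        self.rest_wavelengths = jnp.asarray(rest_wavelengths)
        self.sigmas = jnp.asarray(sigmas)
        self.amplitudes = jnp.asarray(amplitudes)

    def tree_flatten(self):
        # Array leaves that JAX may trace; no static auxiliary data is needed
        return (self.rest_wavelengths, self.sigmas, self.amplitudes), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild without calling __init__, so tracers pass through untouched
        obj = object.__new__(cls)
        obj.rest_wavelengths, obj.sigmas, obj.amplitudes = children
        return obj


@jax.jit
def feature_factor(features, wavelengths):
    """Multiplicative factor 1 + sum_i A_i exp(-(lambda - lambda_i)^2 / (2 sigma_i^2))."""
    offsets = wavelengths[:, None] - features.rest_wavelengths[None, :]
    profiles = jnp.exp(-0.5 * (offsets / features.sigmas[None, :]) ** 2)
    return 1.0 + jnp.sum(features.amplitudes[None, :] * profiles, axis=1)


features = GaussianFeatures([6355.0], [100.0], [-0.4])
factor = feature_factor(features, jnp.array([6355.0, 9000.0]))
```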
- redback_jax.models.apply_sed_feature(features, base_flux, frequency, time)
Apply spectral features to a base flux array (fully vectorized).
- Parameters:
features – SEDFeatures object
base_flux – 2-D array (time, wavelength) of flux density in erg/s/Hz/cm^2
frequency – source-frame frequency in Hz; must be the same length as the time array or a single number
time – time array in seconds (source frame)
- Returns:
modified flux_density as a 2-d array (time, wavelength) in erg/s/Hz/cm^2
- redback_jax.models.register_model(name, fn)
Register a model function under name.
- Parameters:
name (str) – name under which the model is registered
fn (callable) – model function to register
Utilities
Utility functions for redback-jax.
Inference
JAX-based Bayesian inference tools using BlackJAX.
Clean API
from redback_jax.inference import Prior, Uniform, LogUniform, Gaussian
from redback_jax.inference import SpectralLikelihood, NestedSampler, NSResult
Legacy API (kept for backward compatibility)
from redback_jax.inference import (
    SamplerResult, create_uniform_prior, create_gaussian_likelihood,
    run_nested_sampling, run_mcmc, fit_transient, HAS_BLACKJAX,
)
- class redback_jax.inference.Uniform(minimum, maximum, *, name)
Bases: _Distribution
Uniform prior between minimum and maximum.
- Parameters:
minimum – lower bound of the support
maximum – upper bound of the support
name (str) – parameter name (keyword-only)
Examples
>>> p = Uniform(0.05, 0.2, name='f_nickel')
>>> p.log_prob(jnp.array(0.1))  # log(1 / (0.2 - 0.05))
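- class redback_jax.inference.LogUniform(minimum, maximum, *, name)
Bases: _Distribution
Log-uniform (Jeffreys) prior between minimum and maximum.
The density is proportional to 1/x, normalised over [minimum, maximum].
- Parameters:
minimum – lower bound of the support (must be positive)
maximum – upper bound of the support
name (str) – parameter name (keyword-only)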
Examples
>>> p = LogUniform(1e-2, 1e2, name='kappa')
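The two normalised densities above, written out: a uniform prior has log p(x) = -log(max - min) inside the support, and the log-uniform prior p(x) = 1 / (x ln(max/min)) has log p(x) = -log x - log ln(max/min). The standalone sketch below uses hypothetical function names. Returning -inf outside the support is an assumption about how out-of-bounds points are treated.

```python
import jax.numpy as jnp


def uniform_log_prob(x, minimum, maximum):
    """log of 1 / (maximum - minimum) inside the support, -inf outside."""
    inside = (x >= minimum) & (x <= maximum)
    return jnp.where(inside, -jnp.log(maximum - minimum), -jnp.inf)


def log_uniform_log_prob(x, minimum, maximum):
    """log of 1 / (x ln(maximum / minimum)) inside the support, -inf outside."""
    inside = (x >= minimum) & (x <= maximum)
    # Substitute a valid value before taking log(x) so x <= 0 cannot produce NaN
    safe_x = jnp.where(inside, x, minimum)
    log_density = -jnp.log(safe_x) - jnp.log(jnp.log(maximum / minimum))
    return jnp.where(inside, log_density, -jnp.inf)
```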
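- class redback_jax.inference.Gaussian(mu, sigma, *, name, minimum=-jnp.inf, maximum=jnp.inf)
Bases: _Distribution
Gaussian prior (truncated to finite support by the sampler’s hard bounds).
- Parameters:
mu – mean
sigma – standard deviation
name (str) – parameter name (keyword-only)
minimum – lower hard bound (default -inf)
maximum – upper hard bound (default +inf)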
- class redback_jax.inference.Prior(distributions)
Bases: object
A composite prior built from a list of 1-D distributions.
- Parameters:
distributions (list of _Distribution) – One distribution per free parameter. The list order determines the column order of the parameter vector passed to the likelihood.
Examples
>>> prior = Prior([
...     Uniform(58580, 58620, name='t0'),
...     Uniform(0.05, 0.20, name='f_nickel'),
...     Uniform(0.8, 2.0, name='mej'),
...     Uniform(3000, 8000, name='vej'),
... ])
>>> particles = prior.sample_n(jax.random.PRNGKey(0), 100)  # (100, 4)
>>> log_p = prior.log_prob(particles[0])  # scalar
- log_prob(params)
Evaluate the joint log-prior for a parameter vector of shape (n_params,). This is JAX-traceable and can be used inside @jax.jit.
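A composite prior made of independent 1-D factors has a log-density equal to the sum of the factors' log-densities, and sampling draws each column independently. The sketch below is restricted to uniform factors and reuses the bounds from the Prior example above. LOWER, UPPER, sample_n and joint_log_prob are hypothetical names, not the Prior class itself.

```python
import jax
import jax.numpy as jnp

# Bounds mirroring the Prior example, in column order t0, f_nickel, mej, vej
LOWER = jnp.array([58580.0, 0.05, 0.8, 3000.0])
UPPER = jnp.array([58620.0, 0.20, 2.0, 8000.0])


def sample_n(key, n):
    """Draw n parameter vectors; column j is uniform on [LOWER[j], UPPER[j]]."""
    u = jax.random.uniform(key, (n, LOWER.shape[0]))
    return LOWER + u * (UPPER - LOWER)


@jax.jit
def joint_log_prob(params):
    """Sum of independent 1-D uniform log-densities; -inf outside the box."""
    inside = jnp.all((params >= LOWER) & (params <= UPPER))
    return jnp.where(inside, -jnp.sum(jnp.log(UPPER - LOWER)), -jnp.inf)


particles = sample_n(jax.random.PRNGKey(0), 100)        # shape (100, 4)
log_p = jax.vmap(joint_log_prob)(particles)             # one value per particle
```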
- class redback_jax.inference.Likelihood(model, transient, fixed_params, t0_key='t0')
Bases: object
Gaussian photometric likelihood using a spectra model pipeline.
- Parameters:
model (str or callable) – Model name (e.g. 'arnett_spectra') or a callable with signature f(redshift, lum_dist, vej, temperature_floor, **kwargs) returning a namedtuple(time, lambdas, spectra).
transient (Transient) – Data container with .time, .y, .y_err and .bands attributes.
fixed_params (dict) – Parameters held fixed during inference: everything the model needs that is not a free parameter in the prior.
t0_key (str or None, optional) – Name of the MJD explosion-time free parameter (default 't0'). When this parameter is present in the prior, the likelihood automatically converts transient.time from observer-frame MJD to source-frame days. Set to None if times are already in source-frame days.
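Here is a minimal sketch of the two ingredients described above. The first converts observer-frame MJD to source-frame days, assuming the standard (1 + z) time-dilation factor. The second is a Gaussian log-likelihood over the photometry. The function names are hypothetical, and the model prediction y_model is taken as given rather than produced by a spectra model.

```python
import jax.numpy as jnp


def to_source_frame_days(time_mjd, t0_mjd, redshift):
    """Observer-frame MJD -> source-frame days since t0, assuming (1 + z) time dilation."""
    return (time_mjd - t0_mjd) / (1.0 + redshift)


def gaussian_log_likelihood(y, y_model, y_err):
    """Sum of independent Gaussian log-densities with standard deviations y_err."""
    residual = (y - y_model) / y_err
    return -0.5 * jnp.sum(residual**2 + jnp.log(2.0 * jnp.pi * y_err**2))


time_mjd = jnp.array([58611.0, 58622.0])
t_source = to_source_frame_days(time_mjd, t0_mjd=58600.0, redshift=0.1)
log_l = gaussian_log_likelihood(
    jnp.array([18.0, 18.5]), jnp.array([18.0, 18.5]), jnp.array([0.1, 0.1])
)
```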
- class redback_jax.inference.NestedSampler(likelihood, prior, outdir='results/', n_live=125, n_delete=20, num_mcmc_steps_multiplier=5, termination_dlogz=-3.0, verbose=True)
Bases: object
BlackJAX nested sampler with a clean redback-style interface.
- Parameters:
likelihood (Likelihood) – A Likelihood instance.
prior (Prior) – Composite prior object.
outdir (str, optional) – Directory for output files; created if it does not exist. Set to None to disable file output.
n_live (int, optional) – Number of live points (default 125).
n_delete (int, optional) – Number of points to remove per iteration (default 20).
num_mcmc_steps_multiplier (int, optional) – The number of MCMC steps per iteration is n_params × multiplier (default 5).
termination_dlogz (float, optional) – Stop when logZ_live - logZ < termination_dlogz (default -3).
verbose (bool, optional) – Show a tqdm progress bar (default True).
Examples
>>> sampler = NestedSampler(likelihood, prior, outdir='results/')
>>> result = sampler.run(jax.random.PRNGKey(42))
>>> result.summary()
- class redback_jax.inference.NSResult(logZ, samples, dead, log_weights, param_names)
Bases: object
Container for nested sampling results.
- log_weights
Log importance weights (shape (n_dead,)).
- Type:
jnp.ndarray
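Log importance weights are usually turned into posterior estimates by normalising them with a softmax. The sketch below assumes log_weights is unnormalised, and its helper names are hypothetical. For multi-parameter samples, the same weights apply to each parameter column.

```python
import jax
import jax.numpy as jnp


def posterior_mean(samples, log_weights):
    """Importance-weighted mean of a 1-D array of dead-point samples."""
    weights = jax.nn.softmax(log_weights)  # normalised so they sum to 1
    return jnp.sum(weights * samples)


def effective_sample_size(log_weights):
    """Kish effective sample size, 1 / sum(w_i^2) for normalised weights."""
    weights = jax.nn.softmax(log_weights)
    return 1.0 / jnp.sum(weights**2)


def resample_equal_weight(key, samples, log_weights, n):
    """Draw n equally weighted posterior samples (with replacement)."""
    idx = jax.random.categorical(key, log_weights, shape=(n,))
    return samples[idx]
```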
- class redback_jax.inference.MCMCSampler(likelihood, prior, n_warmup=500, n_samples=2000, n_chains=4, step_size=0.05, verbose=True)
Bases: object
BlackJAX NUTS sampler with a clean redback-style interface.
The log-posterior is log_likelihood + log_prior, evaluated in the original parameter space. A reflected boundary is used to keep samples inside the prior support.
- Parameters:
likelihood (Likelihood) – A Likelihood instance.
prior (Prior) – Composite prior object.
n_warmup (int, optional) – Number of warmup (adaptation) steps per chain (default 500).
n_samples (int, optional) – Number of post-warmup samples per chain (default 2000).
n_chains (int, optional) – Number of independent chains (default 4).
step_size (float, optional) – Initial NUTS step size (default 0.05).
verbose (bool, optional) – Show a progress bar (default True).
Examples
>>> sampler = MCMCSampler(likelihood, prior, n_warmup=500, n_samples=2000)
>>> result = sampler.run(jax.random.PRNGKey(0))
>>> result.summary()
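A common way to implement a reflecting boundary is to fold any proposed value back into [low, high] by mirroring it at the edges. The sketch below shows only that folding map. How MCMCSampler applies reflection internally (for example, to positions only or also to momenta) is not specified here, and reflect_into is a hypothetical name.

```python
import jax.numpy as jnp


def reflect_into(x, low, high):
    """Mirror x at the boundaries until it lies in [low, high].

    Uses the period-2*width folding map, so values far outside the interval
    are handled in one step rather than by repeated reflections.
    """
    width = high - low
    y = jnp.mod(x - low, 2.0 * width)  # in [0, 2 * width); jnp.mod follows the divisor's sign
    return low + jnp.where(y > width, 2.0 * width - y, y)


x = jnp.array([0.5, 1.2, -0.3, 2.5])
reflected = reflect_into(x, 0.0, 1.0)
```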
- class redback_jax.inference.MCMCResult(samples_per_chain, param_names)
Bases: object
Container for MCMC results.
- samples_per_chain
Samples per chain as {name: jnp.ndarray}, each array of shape (n_chains, n_samples).
- Type:
dict
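Because samples_per_chain keeps the chain axis, it supports convergence checks across chains. The sketch below computes the classic (non-split) Gelman-Rubin statistic for one parameter. gelman_rubin is a hypothetical helper, and whether MCMCResult.summary() reports such a diagnostic is not stated here.

```python
import jax
import jax.numpy as jnp


def gelman_rubin(chains):
    """Potential scale reduction R-hat for an array of shape (n_chains, n_samples)."""
    n = chains.shape[1]
    within = jnp.mean(jnp.var(chains, axis=1, ddof=1))          # W: mean within-chain variance
    between_over_n = jnp.var(jnp.mean(chains, axis=1), ddof=1)  # B / n: variance of chain means
    var_hat = (n - 1) / n * within + between_over_n
    return jnp.sqrt(var_hat / within)


key = jax.random.PRNGKey(0)
mixed = jax.random.normal(key, (4, 2000))                 # four well-mixed chains
stuck = mixed + jnp.array([0.0, 0.0, 0.0, 3.0])[:, None]  # one chain stuck elsewhere
```

Once R-hat is close to 1, the chains for a parameter can be pooled with samples_per_chain[name].reshape(-1).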
- class redback_jax.inference.SamplerResult(samples, log_likelihoods, log_weights, log_evidence, log_evidence_error, metadata)
Bases: NamedTuple
Results from a nested sampling run.
- Parameters:
- log_likelihoods
Log likelihood values for each sample
- Type:
jnp.ndarray
- log_weights
Log weights for each sample (for nested sampling)
- Type:
jnp.ndarray
- redback_jax.inference.create_uniform_prior(prior_bounds)
Create a uniform prior function from parameter bounds.
- Parameters:
prior_bounds (dict) – Dictionary mapping parameter names to (low, high) bounds
- Returns:
Function that maps a point in the unit hypercube to parameter space
- Return type:
callable
- redback_jax.inference.create_gaussian_likelihood(model_fn, observed_data, errors, reduce_fn=None)
Create a Gaussian likelihood function.
- Parameters:
model_fn (callable) – Function that takes parameter dict and returns model predictions
observed_data (jnp.ndarray) – Observed data array
errors (jnp.ndarray) – Error array (standard deviations)
reduce_fn (callable, optional) – Function to reduce data (e.g., for rescaling errors)
- Returns:
Log-likelihood function
- Return type:
callable
- redback_jax.inference.run_nested_sampling(loglikelihood_fn, prior_bounds, n_particles=500, num_mcmc_steps=20, max_iterations=100, rng_key=None, verbose=True)
Run Sequential Monte Carlo (SMC) sampling using BlackJAX.
This uses adaptive tempered SMC, which, like nested sampling, provides an estimate of the evidence. The algorithm gradually raises the inverse temperature (the exponent applied to the likelihood) from 0 (prior) to 1 (posterior) while tracking the normalizing constant.
- Parameters:
loglikelihood_fn (callable) – Log-likelihood function that takes parameter dict
prior_bounds (dict) – Dictionary mapping parameter names to (low, high) bounds
n_particles (int, optional) – Number of particles (default: 500)
num_mcmc_steps (int, optional) – Number of MCMC steps per iteration (default: 20)
max_iterations (int, optional) – Maximum number of temperature steps (default: 100)
rng_key (jax.Array, optional) – JAX random key (default: None, will create one)
verbose (bool, optional) – Print progress information (default: True)
- Returns:
Results from the SMC sampling run with evidence estimate
- Return type:
SamplerResult
- Raises:
ImportError – If blackjax is not installed
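The evidence bookkeeping in tempered SMC can be illustrated in isolation. Moving the inverse temperature from lambda_k to lambda_{k+1} multiplies each particle's weight by L_i^(lambda_{k+1} - lambda_k), and the log-evidence accumulates the log ratio of successive weighted sums. The sketch below uses a fixed schedule and omits the resampling and MCMC moves that the adaptive algorithm performs between steps. As a result, the increments here telescope to a plain Monte Carlo estimate. tempered_log_evidence is a hypothetical name. With a N(0, 1) prior and a N(x; 0, 1) likelihood, the exact answer is log Z = -0.5 log(4 pi) ≈ -1.2655.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def tempered_log_evidence(log_likelihoods, schedule):
    """Accumulate log Z over a fixed inverse-temperature schedule from 0 to 1.

    Particles are drawn from the prior and only reweighted here; the
    resampling and MCMC move steps of a full SMC sampler are omitted.
    """
    log_w = jnp.zeros_like(log_likelihoods)  # equal weights under the prior
    log_z = 0.0
    for lam_prev, lam_next in zip(schedule[:-1], schedule[1:]):
        increment = (lam_next - lam_prev) * log_likelihoods
        # log of the weighted mean of L^(d lambda): ratio of successive normalisers
        log_z = log_z + logsumexp(log_w + increment) - logsumexp(log_w)
        log_w = log_w + increment
    return log_z


x = jax.random.normal(jax.random.PRNGKey(0), (100_000,))  # draws from a N(0, 1) prior
log_l = -0.5 * x**2 - 0.5 * jnp.log(2.0 * jnp.pi)          # log N(x; 0, 1) likelihood
log_z = tempered_log_evidence(log_l, [0.0, 0.1, 0.3, 0.6, 1.0])
```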
- redback_jax.inference.run_mcmc(loglikelihood_fn, prior_bounds, n_samples=10000, n_warmup=1000, n_chains=4, step_size=0.01, rng_key=None, verbose=True)
Run MCMC sampling using BlackJAX’s NUTS sampler.
- Parameters:
loglikelihood_fn (callable) – Log-likelihood function that takes parameter dict
prior_bounds (dict) – Dictionary mapping parameter names to (low, high) bounds
n_samples (int, optional) – Number of samples to draw (default: 10000)
n_warmup (int, optional) – Number of warmup/burnin steps (default: 1000)
n_chains (int, optional) – Number of parallel chains (default: 4)
step_size (float, optional) – Initial step size for NUTS (default: 0.01)
rng_key (jax.Array, optional) – JAX random key (default: None, will create one)
verbose (bool, optional) – Print progress information (default: True)
- Returns:
Results from the MCMC run
- Return type:
SamplerResult
- Raises:
ImportError – If blackjax is not installed
- redback_jax.inference.fit_transient(transient, model_fn, prior_bounds, sampler='nested', sampler_kwargs=None, rng_key=None, verbose=True)
Fit a transient model to observational data.
This is the main high-level interface for parameter inference, following the redback API style.
- Parameters:
transient (Transient) – Transient object with observational data
model_fn (callable) – Model function that takes parameter dict and returns model predictions. Should be compatible with the transient’s data structure.
prior_bounds (dict) – Dictionary mapping parameter names to (low, high) bounds
sampler (str, optional) – Sampler to use: “nested” for nested sampling, “mcmc” for NUTS (default: “nested”)
sampler_kwargs (dict, optional) – Additional keyword arguments for the sampler
rng_key (jax.Array, optional) – JAX random key
verbose (bool, optional) – Print progress information
- Returns:
Results from the sampling run
- Return type:
SamplerResult
Examples
>>> from redback_jax import Transient
>>> from redback_jax.sources import PrecomputedSpectraSource
>>> import jax.numpy as jnp
>>>
>>> # Create transient data
>>> transient = Transient(
...     name='test_sn',
...     times=jnp.array([0, 5, 10, 15, 20]),
...     magnitudes=jnp.array([18.0, 17.5, 17.0, 17.5, 18.0]),
...     magnitude_errors=jnp.array([0.1, 0.1, 0.1, 0.1, 0.1]),
...     bands=['g'] * 5
... )
>>>
>>> # Create model function
>>> source = PrecomputedSpectraSource.from_arnett_model(...)
>>> bridges, band_to_idx = source.prepare_bridges(['g'])
>>> band_indices = jnp.array([0, 0, 0, 0, 0])
>>>
>>> def model_fn(params):
...     return source.bandmag(params, None, transient.times,
...                           band_indices=band_indices,
...                           bridges=bridges,
...                           unique_bands=['g'])
>>>
>>> # Define priors
>>> prior_bounds = {
...     'amplitude': (0.1, 10.0),
... }
>>>
>>> # Run inference
>>> result = fit_transient(transient, model_fn, prior_bounds)
>>> print(f"Log evidence: {result.log_evidence:.2f}")
- redback_jax.inference.to_anesthetic_samples(result)
Convert SamplerResult to anesthetic NestedSamples.
- Parameters:
result (SamplerResult) – Results from nested sampling
- Returns:
Anesthetic samples object with posterior samples
- Return type:
anesthetic.NestedSamples or anesthetic.Samples
- Raises:
ImportError – If anesthetic is not installed
- redback_jax.inference.summarize_result(result)
Summarize sampling results with posterior statistics.
- Parameters:
result (SamplerResult) – Results from sampling
- Returns:
Dictionary with parameter statistics (mean, std, median, etc.)
- Return type:
dict