API Reference

Main Package

Redback-JAX: A lightweight JAX-only version of the redback electromagnetic transient analysis package.

This package provides JAX-based implementations for electromagnetic transient modeling and Bayesian inference, focusing on performance and automatic differentiation capabilities.

redback_jax.__getattr__(name)[source]

Lazy imports to avoid enabling JAX x64 at package load time.
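The lazy-import pattern described here can be sketched with a module-level `__getattr__` hook (PEP 562). This is only an illustration of the mechanism, not the package's actual code: `demo_pkg` and the `_LAZY` table are hypothetical, and stdlib modules stand in for the JAX-dependent submodules.

```python
import importlib
import types

# Hypothetical map from public attribute name to the module that provides it.
_LAZY = {"sqrt": "math", "dumps": "json"}


def _lazy_getattr(name):
    """Import the providing module only when the attribute is first requested."""
    if name in _LAZY:
        module = importlib.import_module(_LAZY[name])
        return getattr(module, name)
    raise AttributeError(f"module 'demo_pkg' has no attribute {name!r}")


# Stand-in for a package's __init__.py: the hook lives in the module namespace.
demo_pkg = types.ModuleType("demo_pkg")
demo_pkg.__getattr__ = _lazy_getattr

root = demo_pkg.sqrt(9.0)  # the lookup falls through to _lazy_getattr
```

Because nothing heavy is imported until an attribute is requested, side effects such as switching on 64-bit mode can be deferred until a caller actually needs them.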

Models

JAX-based transient models for electromagnetic counterparts.

redback_jax.models.arnett_bolometric(time, f_nickel, mej, *, vej=None, kappa=None, kappa_gamma=None)[source]

Bolometric Arnett (1982) light curve with Ni/Co decay engine + diffusion.

Parameters:
  • time – time in days

  • f_nickel – fraction of nickel mass

  • mej – total ejecta mass in solar masses

  • kappa – opacity in cm^2/g (required)

  • kappa_gamma – gamma-ray opacity in cm^2/g (required)

  • vej – ejecta velocity in km/s (required)

Returns:

log10 of bolometric luminosity in erg/s

redback_jax.models.arnett_with_features_cosmology(f_nickel, mej, *, redshift=0.0, cosmo_H0=PLANCK18_H0, cosmo_Om0=PLANCK18_OM0, vej=None, kappa=None, kappa_gamma=None, temperature_floor=None, features=NO_SED_FEATURES)[source]

Arnett model with cosmological luminosity distance calculation.

Parameters:
  • redshift – source redshift

  • f_nickel – fraction of nickel mass

  • mej – total ejecta mass in solar masses

  • cosmo_H0 – Hubble constant (km/s/Mpc)

  • cosmo_Om0 – matter density parameter

  • kappa – opacity in cm^2/g (required)

  • kappa_gamma – gamma-ray opacity in cm^2/g (required)

  • vej – ejecta velocity in km/s (required)

  • temperature_floor – floor temperature in K

  • features – SEDFeatures object

Returns:

namedtuple(time, lambdas, spectra)

redback_jax.models.blackbody_to_flux_density(temperature, r_photosphere, dl, frequency)[source]

Convert a blackbody temperature and photosphere radius into an observed flux density.

Parameters:
  • temperature – effective temperature in kelvin

  • r_photosphere – photosphere radius in cm

  • dl – luminosity distance in cm

  • frequency – frequency to calculate in Hz

Returns:

flux density in erg/s/Hz/cm^2
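As a rough guide to what this function computes, the sketch below evaluates the standard expression for a spherical blackbody, F_nu = pi * B_nu(T) * (R / d_L)^2, in cgs units. It is assumed, not confirmed, that the library uses this form; the name `blackbody_flux_density` and the constant names are introduced here for illustration.

```python
import math

# Physical constants in cgs units.
PLANCK = 6.62607015e-27         # erg s
BOLTZMANN = 1.380649e-16        # erg / K
SPEED_OF_LIGHT = 2.99792458e10  # cm / s


def blackbody_flux_density(temperature, r_photosphere, dl, frequency):
    """Return pi * B_nu(T) * (R / d_L)^2 in erg/s/Hz/cm^2 (assumed form)."""
    x = PLANCK * frequency / (BOLTZMANN * temperature)
    # pi * B_nu: flux per unit area leaving the photosphere surface.
    surface_flux = (
        2.0 * math.pi * PLANCK * frequency**3 / SPEED_OF_LIGHT**2 / math.expm1(x)
    )
    # Geometric dilution from the photosphere to the observer.
    return surface_flux * (r_photosphere / dl) ** 2


flux = blackbody_flux_density(1.0e4, 1.0e15, 3.0e26, 6.0e14)
```

Using `expm1` keeps the Planck factor accurate in the Rayleigh-Jeans regime, where h*nu is much smaller than k*T.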

redback_jax.models.magnetar_powered_bolometric(time, p0, bp, mass_ns, theta_pb, mej, kappa, kappa_gamma, vej)[source]

Bolometric light curve of a magnetar-powered supernova (Arnett diffusion).

Parameters:
  • time – source-frame time in days

  • p0 – initial spin period in milliseconds

  • bp – polar B-field in units of 10^14 G

  • mass_ns – NS mass in solar masses

  • theta_pb – spin–B-field angle in radians

  • mej – ejecta mass in solar masses

  • kappa – optical opacity in cm^2/g

  • kappa_gamma – gamma-ray opacity in cm^2/g

  • vej – ejecta velocity in km/s

Returns:

log10 of bolometric luminosity in erg/s

redback_jax.models.magnetar_nickel_bolometric(time, f_nickel, mej, p0, bp, mass_ns, theta_pb, kappa, kappa_gamma, vej)[source]

Bolometric light curve powered by both a magnetar and Ni/Co radioactive decay (Arnett diffusion). The two luminosity sources are added before diffusion.

Reference: Gomez et al. 2018 (https://ui.adsabs.harvard.edu/abs/2018ApJS..236....6G/abstract)

Parameters:
  • time – source-frame time in days

  • f_nickel – nickel mass fraction (M_Ni = f_nickel * mej)

  • mej – total ejecta mass in solar masses

  • p0 – initial spin period in milliseconds

  • bp – polar B-field in units of 10^14 G

  • mass_ns – NS mass in solar masses

  • theta_pb – spin–B-field angle in radians

  • kappa – optical opacity in cm^2/g

  • kappa_gamma – gamma-ray opacity in cm^2/g

  • vej – ejecta velocity in km/s

Returns:

log10 of bolometric luminosity in erg/s

redback_jax.models.csm_interaction_bolometric(time, mej, csm_mass, vej, eta, rho, kappa, r0, nn=12, delta=1, efficiency=0.5)[source]

Bolometric CSM-interaction light curve (Chevalier 1982 shocks + diffusion).

Parameters:
  • time – source-frame time in days

  • mej – ejecta mass in solar masses

  • csm_mass – CSM mass in solar masses

  • vej – ejecta velocity in km/s

  • eta – CSM density profile exponent

  • rho – CSM density amplitude in g/cm^3

  • kappa – opacity in cm^2/g

  • r0 – inner CSM radius in AU

  • nn – ejecta density power-law slope (default 12)

  • delta – inner ejecta density slope (default 1)

  • efficiency – kinetic-to-luminosity efficiency (default 0.5)

Returns:

log10 of bolometric luminosity in erg/s

redback_jax.models.shock_cooling_bolometric(time, log10_mass, log10_radius, log10_energy, nn, delta, kappa)[source]

Bolometric shock-cooling light curve following Piro (2021).

All large quantities (mass, radius, energy) are passed as log10 to stay float32-safe throughout; all intermediate computations stay in log10 space.

Parameters:
  • time – source-frame time in days

  • log10_mass – log10 envelope mass in solar masses

  • log10_radius – log10 envelope radius in cm

  • log10_energy – log10 explosion energy in erg

  • nn – outer density power-law slope

  • delta – inner density power-law slope

  • kappa – opacity in cm^2/g

Returns:

log10 of bolometric luminosity in erg/s

redback_jax.models.shocked_cocoon_bolometric(time, mej, vej, eta, tshock, shocked_fraction, cos_theta_cocoon, kappa)[source]

Bolometric light curve of a shocked jet cocoon (Piro & Kollmeier 2018). All large intermediate quantities are computed in log10 space for float32 safety.

Parameters:
  • time – source-frame time in days

  • mej – ejecta mass in solar masses

  • vej – ejecta velocity in units of c (speed of light)

  • eta – ejecta density power-law slope

  • tshock – shock time in seconds

  • shocked_fraction – fraction of ejecta mass that is shocked

  • cos_theta_cocoon – cosine of cocoon opening half-angle

  • kappa – gray opacity in cm^2/g

Returns:

log10 of bolometric luminosity in erg/s

redback_jax.models.shock_cooling_and_arnett_bolometric(time, log10_mass, log10_radius, log10_energy, nn, delta, f_nickel, mej, vej, kappa, kappa_gamma)[source]

Combined bolometric light curve: shock cooling (Piro 2021) + Ni/Co decay (Arnett 1982).

The two components are summed in luminosity space (log-sum-exp), which is float32-safe.

Parameters:
  • time – source-frame time in days

  • log10_mass – log10 envelope mass in solar masses

  • log10_radius – log10 envelope radius in cm

  • log10_energy – log10 explosion energy in erg

  • nn – outer density power-law slope

  • delta – inner density power-law slope

  • f_nickel – nickel mass fraction

  • mej – total ejecta mass in solar masses

  • vej – ejecta velocity in km/s

  • kappa – optical opacity in cm^2/g

  • kappa_gamma – gamma-ray opacity in cm^2/g

Returns:

log10 of bolometric luminosity in erg/s
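The log-sum-exp combination mentioned for this model can be sketched as follows. The helper `log10_add` is a hypothetical name; it adds two luminosities that are each given as log10 values, without ever forming the linear luminosities, which is what keeps values such as 1e50 erg/s from overflowing float32.

```python
import math


def log10_add(log10_a, log10_b):
    """Return log10(10**log10_a + 10**log10_b) without leaving log space."""
    hi = max(log10_a, log10_b)
    lo = min(log10_a, log10_b)
    # Factor out the larger term; the remaining exponent is <= 0, so
    # 10 ** (lo - hi) lies in (0, 1] and cannot overflow.
    return hi + math.log10(1.0 + 10.0 ** (lo - hi))


equal_sum = log10_add(43.0, 43.0)     # two equal components
dominated_sum = log10_add(45.0, 30.0)  # one component dominates
```

When one component dominates by many orders of magnitude, the result collapses to the larger input, as expected for a sum of luminosities.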

redback_jax.models.tde_analytical_bolometric(time, log10_l0, t_0_turn, mej, vej, kappa, kappa_gamma)[source]

Bolometric TDE light curve: t^{-5/3} fallback engine + Arnett diffusion.

Parameters:
  • time – source-frame time in days

  • log10_l0 – log10 of bolometric luminosity at 1 second in erg/s

  • t_0_turn – turn-on time in days

  • mej – ejecta mass in solar masses

  • vej – ejecta velocity in km/s

  • kappa – optical opacity in cm^2/g

  • kappa_gamma – gamma-ray opacity in cm^2/g

Returns:

log10 of bolometric luminosity in erg/s

redback_jax.models.metzger_kilonova_bolometric(time, mej, vej, beta, kappa, vmax=0.7, neutron_precursor=True)[source]

Bolometric kilonova light curve (Metzger 2017) with 200 shells and Barnes+16 thermalisation, solved via a sequential Euler ODE with jax.lax.scan.

Parameters:
  • time – source-frame time in days (must be strictly increasing, ≥2 points)

  • mej – ejecta mass in solar masses

  • vej – minimum ejecta velocity in units of c

  • beta – velocity power-law slope (M ∝ v^{-beta})

  • kappa – gray opacity in cm^2/g

  • vmax – maximum ejecta velocity in units of c (default 0.7)

  • neutron_precursor – include neutron precursor emission (default True)

Returns:

log10 of bolometric luminosity in erg/s
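To show why the time grid must be strictly increasing with at least two points, here is a sketch of a sequential explicit Euler integration written in the carry/output style of a scan body. It is plain Python rather than `jax.lax.scan`, and it integrates a toy decay equation dy/dt = -y, not the kilonova energy equation; `rhs`, `euler_step` and `integrate` are illustrative names.

```python
import math


def rhs(t, y):
    """Toy right-hand side dy/dt = -y (stand-in for the physical ODE)."""
    return -y


def euler_step(carry, step):
    """One explicit Euler update, shaped like a scan body: (carry, x) -> (carry, y)."""
    t, dt = step
    y_next = carry + dt * rhs(t, carry)
    return y_next, y_next


def integrate(time, y0):
    """Integrate over the grid `time`; step sizes come from consecutive differences."""
    if len(time) < 2 or any(b <= a for a, b in zip(time, time[1:])):
        raise ValueError("time must be strictly increasing with at least 2 points")
    carry, outputs = y0, [y0]
    for t, t_next in zip(time, time[1:]):
        carry, out = euler_step(carry, (t, t_next - t))
        outputs.append(out)
    return outputs


grid = [i * 0.001 for i in range(1001)]
solution = integrate(grid, 1.0)
```

Each step depends on the previous carry, so the loop is inherently sequential; a non-increasing grid would give zero or negative step sizes.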

redback_jax.models.magnetar_boosted_kilonova_bolometric(time, mej, vej, beta, kappa, p0, bp, mass_ns, theta_pb, thermalisation_efficiency=1.0, vmax=0.7, neutron_precursor=True)[source]

Bolometric kilonova light curve with magnetar spin-down energy injection.

Parameters:
  • time – source-frame time in days (strictly increasing, ≥2 points)

  • mej – ejecta mass in solar masses

  • vej – minimum ejecta velocity in units of c

  • beta – velocity power-law slope

  • kappa – gray opacity in cm^2/g

  • p0 – initial spin period in milliseconds

  • bp – polar B-field in units of 10^14 G

  • mass_ns – neutron star mass in solar masses

  • theta_pb – angle between spin and B-field axes in radians

  • thermalisation_efficiency – fraction of magnetar luminosity thermalised (default 1.0)

  • vmax – maximum ejecta velocity in units of c (default 0.7)

  • neutron_precursor – include neutron precursor emission (default True)

Returns:

log10 of bolometric luminosity in erg/s

redback_jax.models.make_spectra_model(bolometric_fn)[source]

Wrap a bolometric model function to produce a full spectra model.

The returned function has signature:

spectra_model(redshift, lum_dist, vej, temperature_floor,
              features=NO_SED_FEATURES, **bolometric_kwargs)
-> namedtuple(time, lambdas, spectra)
Parameters:

bolometric_fn (callable) – Any function f(time_days, **kwargs) -> log10_lbol (log10 erg/s). time_days must be its first positional argument.

Returns:

A spectra model with the same photosphere/SED pipeline.

Return type:

callable

redback_jax.models.arnett_spectra(redshift, lum_dist, vej, temperature_floor, features=NO_SED_FEATURES, **bolometric_kwargs)

Spectra model wrapping arnett_bolometric.

Parameters:
  • redshift – source redshift

  • lum_dist – luminosity distance in cm

  • vej – ejecta velocity in km/s (photosphere)

  • temperature_floor – floor temperature in K

  • features – SEDFeatures (default NO_SED_FEATURES)

  • **bolometric_kwargs – forwarded to the bolometric function

Returns:

namedtuple with fields time (days), lambdas (Angstrom), spectra (erg/s/cm^2/Angstrom)

redback_jax.models.magnetar_powered_spectra(redshift, lum_dist, vej, temperature_floor, features=NO_SED_FEATURES, **bolometric_kwargs)

Spectra model wrapping magnetar_powered_bolometric.

Parameters:
  • redshift – source redshift

  • lum_dist – luminosity distance in cm

  • vej – ejecta velocity in km/s (photosphere)

  • temperature_floor – floor temperature in K

  • features – SEDFeatures (default NO_SED_FEATURES)

  • **bolometric_kwargs – forwarded to the bolometric function

Returns:

namedtuple with fields time (days), lambdas (Angstrom), spectra (erg/s/cm^2/Angstrom)

redback_jax.models.magnetar_nickel_spectra(redshift, lum_dist, vej, temperature_floor, features=NO_SED_FEATURES, **bolometric_kwargs)

Spectra model wrapping magnetar_nickel_bolometric.

Parameters:
  • redshift – source redshift

  • lum_dist – luminosity distance in cm

  • vej – ejecta velocity in km/s (photosphere)

  • temperature_floor – floor temperature in K

  • features – SEDFeatures (default NO_SED_FEATURES)

  • **bolometric_kwargs – forwarded to the bolometric function

Returns:

namedtuple with fields time (days), lambdas (Angstrom), spectra (erg/s/cm^2/Angstrom)

redback_jax.models.csm_interaction_spectra(redshift, lum_dist, vej, temperature_floor, features=NO_SED_FEATURES, **bolometric_kwargs)

Spectra model wrapping csm_interaction_bolometric.

Parameters:
  • redshift – source redshift

  • lum_dist – luminosity distance in cm

  • vej – ejecta velocity in km/s (photosphere)

  • temperature_floor – floor temperature in K

  • features – SEDFeatures (default NO_SED_FEATURES)

  • **bolometric_kwargs – forwarded to the bolometric function

Returns:

namedtuple with fields time (days), lambdas (Angstrom), spectra (erg/s/cm^2/Angstrom)

redback_jax.models.shock_cooling_spectra(redshift, lum_dist, vej, temperature_floor, features=NO_SED_FEATURES, **bolometric_kwargs)

Spectra model wrapping shock_cooling_bolometric.

Parameters:
  • redshift – source redshift

  • lum_dist – luminosity distance in cm

  • vej – ejecta velocity in km/s (photosphere)

  • temperature_floor – floor temperature in K

  • features – SEDFeatures (default NO_SED_FEATURES)

  • **bolometric_kwargs – forwarded to the bolometric function

Returns:

namedtuple with fields time (days), lambdas (Angstrom), spectra (erg/s/cm^2/Angstrom)

redback_jax.models.shocked_cocoon_spectra(redshift, lum_dist, vej, temperature_floor, features=NO_SED_FEATURES, **bolometric_kwargs)

Spectra model wrapping shocked_cocoon_bolometric.

Parameters:
  • redshift – source redshift

  • lum_dist – luminosity distance in cm

  • vej – ejecta velocity in km/s (photosphere)

  • temperature_floor – floor temperature in K

  • features – SEDFeatures (default NO_SED_FEATURES)

  • **bolometric_kwargs – forwarded to the bolometric function

Returns:

namedtuple with fields time (days), lambdas (Angstrom), spectra (erg/s/cm^2/Angstrom)

redback_jax.models.shock_cooling_and_arnett_spectra(redshift, lum_dist, vej, temperature_floor, features=NO_SED_FEATURES, **bolometric_kwargs)

Spectra model wrapping shock_cooling_and_arnett_bolometric.

Parameters:
  • redshift – source redshift

  • lum_dist – luminosity distance in cm

  • vej – ejecta velocity in km/s (photosphere)

  • temperature_floor – floor temperature in K

  • features – SEDFeatures (default NO_SED_FEATURES)

  • **bolometric_kwargs – forwarded to the bolometric function

Returns:

namedtuple with fields time (days), lambdas (Angstrom), spectra (erg/s/cm^2/Angstrom)

redback_jax.models.tde_analytical_spectra(redshift, lum_dist, vej, temperature_floor, features=NO_SED_FEATURES, **bolometric_kwargs)

Spectra model wrapping tde_analytical_bolometric.

Parameters:
  • redshift – source redshift

  • lum_dist – luminosity distance in cm

  • vej – ejecta velocity in km/s (photosphere)

  • temperature_floor – floor temperature in K

  • features – SEDFeatures (default NO_SED_FEATURES)

  • **bolometric_kwargs – forwarded to the bolometric function

Returns:

namedtuple with fields time (days), lambdas (Angstrom), spectra (erg/s/cm^2/Angstrom)

redback_jax.models.metzger_kilonova_spectra(redshift, lum_dist, vej, temperature_floor, features=NO_SED_FEATURES, **bolometric_kwargs)

Spectra model wrapping metzger_kilonova_bolometric.

Parameters:
  • redshift – source redshift

  • lum_dist – luminosity distance in cm

  • vej – ejecta velocity in km/s (photosphere)

  • temperature_floor – floor temperature in K

  • features – SEDFeatures (default NO_SED_FEATURES)

  • **bolometric_kwargs – forwarded to the bolometric function

Returns:

namedtuple with fields time (days), lambdas (Angstrom), spectra (erg/s/cm^2/Angstrom)

redback_jax.models.magnetar_boosted_kilonova_spectra(redshift, lum_dist, vej, temperature_floor, features=NO_SED_FEATURES, **bolometric_kwargs)

Spectra model wrapping magnetar_boosted_kilonova_bolometric.

Parameters:
  • redshift – source redshift

  • lum_dist – luminosity distance in cm

  • vej – ejecta velocity in km/s (photosphere)

  • temperature_floor – floor temperature in K

  • features – SEDFeatures (default NO_SED_FEATURES)

  • **bolometric_kwargs – forwarded to the bolometric function

Returns:

namedtuple with fields time (days), lambdas (Angstrom), spectra (erg/s/cm^2/Angstrom)

class redback_jax.models.SEDFeatures(rest_wavelengths, sigmas, amplitudes, t_starts, t_ends, rise_times=2.0 * 24.0 * 3600.0, fall_times=5.0 * 24.0 * 3600.0)[source]

Bases: object

A class representing one or more spectral features in the SED. It implements the JAX PyTree interface, so it can be used in JIT-compiled functions.

Parameters:
  • rest_wavelengths – Central wavelengths in Angstroms

  • sigmas – Gaussian widths in Angstroms

  • amplitudes – Amplitudes as a fraction of the continuum; negative values are absorption, positive values are emission (e.g., -0.4 = 40% absorption)

  • t_starts – Start time in seconds

  • t_ends – End time in seconds

  • rise_times – Rise times in seconds (default: 2 days)

  • fall_times – Fall times in seconds (default: 5 days)
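To make the amplitude convention concrete, the sketch below applies a single Gaussian feature multiplicatively to a continuum value. The multiplicative form, the function name `apply_gaussian_feature` and the `time_factor` argument are assumptions made for illustration; the library's own implementation may differ in detail.

```python
import math


def apply_gaussian_feature(continuum, wavelength, rest_wavelength, sigma,
                           amplitude, time_factor=1.0):
    """Scale the continuum by 1 + amplitude * time_factor * Gaussian profile."""
    profile = math.exp(-0.5 * ((wavelength - rest_wavelength) / sigma) ** 2)
    return continuum * (1.0 + amplitude * time_factor * profile)


# A 40% absorption feature centred at 6150 Angstrom with a 50 Angstrom width.
at_centre = apply_gaussian_feature(1.0, 6150.0, 6150.0, 50.0, -0.4)
far_away = apply_gaussian_feature(1.0, 9000.0, 6150.0, 50.0, -0.4)
```

At the line centre the flux drops to 60% of the continuum, and far from the line the continuum is unchanged, which matches the "-0.4 = 40% absorption" convention.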

classmethod from_feature_list(feature_list)[source]

Create SEDFeatures from a list of feature dictionaries.

Parameters:

feature_list – List of dictionaries, each with keys: ‘rest_wavelength’, ‘sigma’, ‘amplitude’, ‘t_start’, ‘t_end’, ‘rise_time’, ‘fall_time’

Returns:

SEDFeatures object

tree_flatten()[source]
classmethod tree_unflatten(aux_data, children)[source]
calculate_smooth_evolution(time)[source]

Calculate smooth transitions for a set of features.

Parameters:

time – Time array in seconds (source frame)

Returns:

time_factors array in [0, 1] representing the evolution of the feature over time

redback_jax.models.apply_sed_feature(features, base_flux, frequency, time)[source]

Apply spectral features in a fully vectorized manner.

Parameters:
  • features – SEDFeatures object

  • base_flux – base flux as a 2-d array (time, wavelength) in erg/s/Hz/cm^2

  • frequency – frequency to calculate in Hz, in the source frame. Must be the same length as the time array or a single number.

  • time – time array in seconds (source frame).

Returns:

modified flux_density as a 2-d array (time, wavelength) in erg/s/Hz/cm^2

redback_jax.models.register_model(name, fn)[source]

Register a model function under name.

Parameters:

name (str)

redback_jax.models.get_model(name)[source]

Retrieve a registered model by name, or raise KeyError.

Parameters:

name (str)
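A name-to-callable registry of this kind can be sketched with a plain dictionary. The names `register` and `lookup` and the `_MODELS` table are hypothetical stand-ins, not the library's internals.

```python
# Hypothetical registry: maps a model name to the callable that implements it.
_MODELS = {}


def register(name, fn):
    """Store `fn` under `name`, replacing any earlier entry with that name."""
    _MODELS[name] = fn


def lookup(name):
    """Return the registered callable, or raise KeyError for unknown names."""
    try:
        return _MODELS[name]
    except KeyError:
        raise KeyError(f"unknown model {name!r}; registered: {sorted(_MODELS)}") from None


register("double", lambda x: 2 * x)
doubled = lookup("double")(3)
```

Looking models up by string is what lets a likelihood accept either a model name or a callable.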

redback_jax.models.load_plugins()[source]

Discover and load all redback_jax.models entry-point plugins.

Utilities

Utility functions for redback-jax.

redback_jax.utils.luminosity_distance_cm(redshift, H0=PLANCK18_H0, Om0=PLANCK18_OM0)[source]

Luminosity distance in cm.

Parameters:
  • redshift (float) – Source redshift.

  • H0 (float, optional) – Hubble constant in km/s/Mpc (default: Planck18).

  • Om0 (float, optional) – Matter density parameter (default: Planck18).

Returns:

Luminosity distance in cm.

Return type:

float
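For orientation, the quantity returned here can be approximated with a short numerical integration. The sketch assumes a flat Lambda-CDM cosmology with radiation neglected, and it assumes the values 67.66 km/s/Mpc and 0.30966 for the Planck18 defaults; the function name and the Simpson-rule integration are illustrative choices, not the library's implementation.

```python
import math

SPEED_OF_LIGHT_KM_S = 299792.458
CM_PER_MPC = 3.0856775814913673e24
H0_ASSUMED = 67.66     # km/s/Mpc, assumed value of PLANCK18_H0
OM0_ASSUMED = 0.30966  # assumed value of PLANCK18_OM0


def luminosity_distance_cm_sketch(redshift, H0=H0_ASSUMED, Om0=OM0_ASSUMED, n_steps=1000):
    """D_L = (1 + z) * (c / H0) * integral_0^z dz' / E(z') for flat Lambda-CDM."""

    def inv_e(z):
        return 1.0 / math.sqrt(Om0 * (1.0 + z) ** 3 + (1.0 - Om0))

    # Composite Simpson rule on n_steps (even) intervals.
    h = redshift / n_steps
    total = inv_e(0.0) + inv_e(redshift)
    for i in range(1, n_steps):
        total += (4.0 if i % 2 else 2.0) * inv_e(i * h)
    comoving_mpc = SPEED_OF_LIGHT_KM_S / H0 * total * h / 3.0
    return (1.0 + redshift) * comoving_mpc * CM_PER_MPC


dl_cm = luminosity_distance_cm_sketch(0.01)
```

At very low redshift the result approaches the Hubble-law distance c*z/H0, which gives a quick sanity check on the units.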

Examples

>>> from redback_jax.utils import luminosity_distance_cm
>>> dl_cm = luminosity_distance_cm(0.01)

Inference

JAX-based Bayesian inference tools using BlackJAX.

Clean API

from redback_jax.inference import Prior, Uniform, LogUniform, Gaussian
from redback_jax.inference import SpectralLikelihood, NestedSampler, NSResult

Legacy API (kept for backward compatibility)

from redback_jax.inference import (
    SamplerResult, create_uniform_prior, create_gaussian_likelihood,
    run_nested_sampling, run_mcmc, fit_transient, HAS_BLACKJAX,
)

class redback_jax.inference.Uniform(minimum, maximum, *, name)[source]

Bases: _Distribution

Uniform prior between minimum and maximum.

Parameters:
  • minimum (float) – Lower bound of the support.

  • maximum (float) – Upper bound of the support.

  • name (str) – Parameter name (used as dict key in likelihood calls).

Examples

>>> import jax.numpy as jnp
>>> from redback_jax.inference import Uniform
>>> p = Uniform(0.05, 0.2, name='f_nickel')
>>> p.log_prob(jnp.array(0.1))   # log(1 / (0.2 - 0.05))
property low: float
property high: float
sample(key)[source]
Parameters:

key (Array)

Return type:

Array

log_prob(value)[source]
Parameters:

value (Array)

Return type:

Array
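The log-density quoted in the example, log(1 / (maximum - minimum)), can be reproduced with a small stand-alone class. `UniformSketch` is a hypothetical illustration: it uses `random.Random` in place of a JAX key, and returning -inf outside the support is an assumption about the out-of-bounds behaviour.

```python
import math
import random


class UniformSketch:
    """Stand-alone illustration of a uniform prior on [minimum, maximum]."""

    def __init__(self, minimum, maximum, *, name):
        self.minimum, self.maximum, self.name = minimum, maximum, name

    def log_prob(self, value):
        """Constant log-density inside the support, -inf outside (assumed)."""
        if self.minimum <= value <= self.maximum:
            return -math.log(self.maximum - self.minimum)
        return -math.inf

    def sample(self, rng):
        """Draw one value; `rng` is a random.Random, standing in for a JAX key."""
        return rng.uniform(self.minimum, self.maximum)


p = UniformSketch(0.05, 0.2, name="f_nickel")
inside = p.log_prob(0.1)
draw = p.sample(random.Random(0))
```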

class redback_jax.inference.LogUniform(minimum, maximum, *, name)[source]

Bases: _Distribution

Log-uniform (Jeffreys) prior between minimum and maximum.

The density is proportional to 1/x, normalised over [minimum, maximum].

Parameters:
  • minimum (float) – Lower bound of the support (must be > 0).

  • maximum (float) – Upper bound of the support (must be > 0).

  • name (str) – Parameter name.

Examples

>>> from redback_jax.inference import LogUniform
>>> p = LogUniform(1e-2, 1e2, name='kappa')
property low: float
property high: float
sample(key)[source]
Parameters:

key (Array)

Return type:

Array

log_prob(value)[source]
Parameters:

value (Array)

Return type:

Array
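A density proportional to 1/x normalises to p(x) = 1 / (x * ln(maximum / minimum)), and samples follow from the inverse CDF x = minimum * (maximum / minimum)^u with u uniform on [0, 1]. The sketch below works through both; `LogUniformSketch` and `from_unit` are hypothetical names, and -inf outside the support is an assumption.

```python
import math


class LogUniformSketch:
    """Stand-alone illustration of a log-uniform prior on [minimum, maximum]."""

    def __init__(self, minimum, maximum, *, name):
        if not 0.0 < minimum < maximum:
            raise ValueError("need 0 < minimum < maximum")
        self.minimum, self.maximum, self.name = minimum, maximum, name
        # log of the normalisation constant ln(maximum / minimum)
        self._log_norm = math.log(math.log(maximum / minimum))

    def log_prob(self, value):
        """log p(x) = -ln(x) - ln(ln(maximum / minimum)) inside the support."""
        if not self.minimum <= value <= self.maximum:
            return -math.inf
        return -math.log(value) - self._log_norm

    def from_unit(self, u):
        """Inverse CDF: map u in [0, 1] to a log-uniform draw."""
        return self.minimum * (self.maximum / self.minimum) ** u


kappa_prior = LogUniformSketch(1e-2, 1e2, name="kappa")
midpoint = kappa_prior.from_unit(0.5)  # geometric mean of the bounds
```

Equal intervals in u map to equal intervals in log x, so each decade between the bounds receives the same prior mass.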

class redback_jax.inference.Gaussian(mu, sigma, *, name, minimum=-jnp.inf, maximum=jnp.inf)[source]

Bases: _Distribution

Gaussian prior (truncated to finite support by the sampler’s hard bounds).

Parameters:
  • mu (float) – Mean of the Gaussian.

  • sigma (float) – Standard deviation of the Gaussian.

  • name (str) – Parameter name.

  • minimum (float) – Lower hard bound (default -jnp.inf).

  • maximum (float) – Upper hard bound (default jnp.inf).

property low: float
property high: float
sample(key)[source]
Parameters:

key (Array)

Return type:

Array

log_prob(value)[source]
Parameters:

value (Array)

Return type:

Array

class redback_jax.inference.Prior(distributions)[source]

Bases: object

A composite prior built from a list of 1-D distributions.

Parameters:

distributions (list of _Distribution) – One distribution per free parameter. The list order determines the column order of the parameter vector passed to the likelihood.

Examples

>>> import jax
>>> from redback_jax.inference import Prior, Uniform
>>> prior = Prior([
...     Uniform(58580, 58620, name='t0'),
...     Uniform(0.05, 0.20,   name='f_nickel'),
...     Uniform(0.8,  2.0,    name='mej'),
...     Uniform(3000, 8000,   name='vej'),
... ])
>>> particles = prior.sample_n(jax.random.PRNGKey(0), 100)  # (100, 4)
>>> log_p = prior.log_prob(particles[0])                    # scalar
sample(key)[source]

Draw one sample; returns a dict {name: scalar}.

Parameters:

key (Array)

Return type:

dict

sample_n(key, n)[source]

Draw n samples; returns an array of shape (n, n_params).

Parameters:
  • key (Array)

  • n (int)

Return type:

Array

log_prob(params)[source]

Evaluate the joint log-prior for a parameter vector of shape (n_params,).

This is JAX-traceable and can be used inside @jax.jit.

Parameters:

params (Array)

Return type:

Array

log_prob_fn()[source]

Return a pure JAX-traceable function params -> log_prior.

params_to_dict(params)[source]

Convert a parameter vector (n_params,) to a name-keyed dict.

Parameters:

params (Array)

Return type:

dict

dict_to_params(d)[source]

Convert a name-keyed dict to a parameter vector.

Parameters:

d (dict)

Return type:

Array
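The relationship between the ordered parameter vector and the name-keyed dict can be sketched with two small helpers. These are free functions taking an explicit `names` list, written for illustration only; in the class the order comes from the list of distributions.

```python
def params_to_dict(names, params):
    """Pair each position of the parameter vector with its name."""
    return dict(zip(names, params))


def dict_to_params(names, d):
    """Build the vector in the order given by `names`, whatever the dict order."""
    return [d[name] for name in names]


names = ["t0", "f_nickel", "mej", "vej"]
vector = [58600.0, 0.1, 1.4, 5000.0]

as_dict = params_to_dict(names, vector)
round_trip = dict_to_params(names, as_dict)
```

The list order fixes the column order, so a dict built in any key order converts back to the same vector.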

class redback_jax.inference.Likelihood(model, transient, fixed_params, t0_key='t0')[source]

Bases: object

Gaussian photometric likelihood using a spectra model pipeline.

Parameters:
  • model (str or callable) – Model name (e.g. 'arnett_spectra') or a callable with signature f(redshift, lum_dist, vej, temperature_floor, **kwargs) returning a namedtuple (time, lambdas, spectra).

  • transient (Transient) – Data container with .time, .y, .y_err, .bands.

  • fixed_params (dict) – Parameters held fixed during inference — everything the model needs that is not a free parameter in the prior.

  • t0_key (str or None, optional) – Name of the MJD explosion-time free parameter (default 't0'). When present in the prior the likelihood converts transient.time from observer-frame MJD to source-frame days automatically. Set to None if times are already in source-frame days.
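The time conversion controlled by t0_key can be illustrated as follows. The sketch assumes the usual relation t_source = (MJD - t0) / (1 + z); the function name is hypothetical and the library's exact handling is not confirmed here.

```python
def mjd_to_source_frame_days(mjd, t0, redshift):
    """Shift by the explosion epoch, then undo cosmological time dilation."""
    return [(t - t0) / (1.0 + redshift) for t in mjd]


observed_mjd = [58600.0, 58611.0, 58622.0]
source_days = mjd_to_source_frame_days(observed_mjd, t0=58600.0, redshift=0.1)
```

Because t0 is a free parameter in the prior, the conversion is re-evaluated for every likelihood call rather than applied once to the data.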

class redback_jax.inference.NestedSampler(likelihood, prior, outdir='results/', n_live=125, n_delete=20, num_mcmc_steps_multiplier=5, termination_dlogz=-3.0, verbose=True)[source]

Bases: object

BlackJAX nested sampler with a clean redback-style interface.

Parameters:
  • likelihood (Likelihood) – A Likelihood instance.

  • prior (Prior) – Composite prior object.

  • outdir (str, optional) – Directory for output files. Created if it does not exist. Set to None to disable file output.

  • n_live (int, optional) – Number of live points (default 125).

  • n_delete (int, optional) – Number of points to remove per iteration (default 20).

  • num_mcmc_steps_multiplier (int, optional) – MCMC steps per iteration = n_params × multiplier (default 5).

  • termination_dlogz (float, optional) – Stop when logZ_live - logZ < termination_dlogz (default -3).

  • verbose (bool, optional) – Show a tqdm progress bar (default True).
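Two of these settings are simple arithmetic rules, sketched below with hypothetical helper names: the number of MCMC steps per iteration is n_params × multiplier, and the run stops once logZ_live - logZ falls below termination_dlogz.

```python
def mcmc_steps_per_iteration(n_params, multiplier=5):
    """MCMC steps per nested-sampling iteration."""
    return n_params * multiplier


def should_terminate(logz_live, logz, termination_dlogz=-3.0):
    """True once the live points' evidence is negligible next to the total so far."""
    return logz_live - logz < termination_dlogz


steps = mcmc_steps_per_iteration(4)      # four free parameters
stop = should_terminate(-105.0, -100.0)  # difference of -5 is below -3
keep_going = should_terminate(-101.0, -100.0)
```

With the default of -3, sampling ends when the evidence remaining in the live points is below roughly e^-3, about 5%, of the evidence accumulated so far.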

Examples

>>> import jax
>>> from redback_jax.inference import NestedSampler
>>> sampler = NestedSampler(likelihood, prior, outdir='results/')
>>> result  = sampler.run(jax.random.PRNGKey(42))
>>> result.summary()
run(key)[source]

Run nested sampling.

Parameters:

key (jax.Array) – JAX random key.

Returns:

Posterior samples and evidence estimate.

Return type:

NSResult

plot_corner(result, truth=None, filename=None, **kwargs)[source]

Make a corner plot using anesthetic.

Parameters:
  • result (NSResult) – Output of run().

  • truth (dict, optional) – True parameter values to mark on the plot.

  • filename (str, optional) – Path to save the figure. Defaults to {outdir}/corner.png.

class redback_jax.inference.NSResult(logZ, samples, dead, log_weights, param_names)[source]

Bases: object

Container for nested sampling results.

logZ

Log evidence estimate.

Type:

float

samples

Posterior samples as {name: jnp.ndarray}.

Type:

dict

dead

Raw dead-point pytree from BlackJAX (for expert use).

Type:

object

log_weights

Log importance weights (shape (n_dead,)).

Type:

jnp.ndarray

param_names

Ordered parameter names.

Type:

list of str

summary()[source]

Print a parameter summary table.

class redback_jax.inference.MCMCSampler(likelihood, prior, n_warmup=500, n_samples=2000, n_chains=4, step_size=0.05, verbose=True)[source]

Bases: object

BlackJAX NUTS sampler with a clean redback-style interface.

The log-posterior is log_likelihood + log_prior, evaluated in the original parameter space. A reflected boundary is used to keep samples inside the prior support.
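One common way to implement a reflected boundary is to fold an out-of-range proposal back into the interval, as if it bounced off the walls. The sketch below shows that folding for a single parameter; it is an assumed scheme for illustration, and the sampler's own boundary handling may differ.

```python
def reflect(x, low, high):
    """Fold x back into [low, high] by mirroring at each boundary."""
    width = high - low
    # Position within one full back-and-forth period of length 2 * width.
    y = (x - low) % (2.0 * width)
    return low + (y if y <= width else 2.0 * width - y)


overshoot = reflect(1.2, 0.0, 1.0)    # bounces off the upper wall
undershoot = reflect(-0.3, 0.0, 1.0)  # bounces off the lower wall
unchanged = reflect(0.5, 0.0, 1.0)
```

Reflection keeps every proposal inside the prior support, so the log-prior never has to be evaluated where it is -inf.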

Parameters:
  • likelihood (Likelihood) – A Likelihood instance.

  • prior (Prior) – Composite prior object.

  • n_warmup (int, optional) – Number of warmup (adaptation) steps per chain (default 500).

  • n_samples (int, optional) – Number of post-warmup samples per chain (default 2000).

  • n_chains (int, optional) – Number of independent chains (default 4).

  • step_size (float, optional) – Initial NUTS step size (default 0.05).

  • verbose (bool, optional) – Show a progress bar (default True).

Examples

>>> import jax
>>> from redback_jax.inference import MCMCSampler
>>> sampler = MCMCSampler(likelihood, prior, n_warmup=500, n_samples=2000)
>>> result  = sampler.run(jax.random.PRNGKey(0))
>>> result.summary()
run(key)[source]

Run MCMC.

Parameters:

key (jax.Array) – JAX random key.

Return type:

MCMCResult

class redback_jax.inference.MCMCResult(samples_per_chain, param_names)[source]

Bases: object

Container for MCMC results.

samples

Posterior samples as {name: jnp.ndarray} — shape (n_chains * n_samples,).

Type:

dict

samples_per_chain

Samples per chain as {name: jnp.ndarray} — shape (n_chains, n_samples).

Type:

dict

param_names

Ordered parameter names.

Type:

list of str

n_chains

Number of chains.

Type:

int

n_samples

Number of post-warmup samples per chain.

Type:

int

summary()[source]

Print a parameter summary table.

class redback_jax.inference.SamplerResult(samples, log_likelihoods, log_weights, log_evidence, log_evidence_error, metadata)[source]

Bases: NamedTuple

Results from nested sampling run.

Parameters:
samples

Dictionary mapping parameter names to sample arrays

Type:

dict

log_likelihoods

Log likelihood values for each sample

Type:

jnp.ndarray

log_weights

Log weights for each sample (for nested sampling)

Type:

jnp.ndarray

log_evidence

Log evidence estimate

Type:

float

log_evidence_error

Error on log evidence estimate

Type:

float

metadata

Additional metadata from the sampling run

Type:

dict

redback_jax.inference.create_uniform_prior(prior_bounds)[source]

Create a uniform prior function from parameter bounds.

Parameters:

prior_bounds (dict) – Dictionary mapping parameter names to (low, high) bounds

Returns:

Function that transforms unit hypercube to parameter space

Return type:

callable
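A transform from the unit hypercube to parameter space is a linear rescaling per parameter, low + u * (high - low). The sketch below builds such a function from a bounds dict; `make_uniform_transform` is a hypothetical name, and returning a dict keyed by parameter name is an assumption about the output format.

```python
def make_uniform_transform(prior_bounds):
    """Return a function mapping unit-cube coordinates to named parameters."""
    names = list(prior_bounds)

    def transform(unit_cube):
        params = {}
        for name, u in zip(names, unit_cube):
            low, high = prior_bounds[name]
            params[name] = low + u * (high - low)
        return params

    return transform


transform = make_uniform_transform({"mej": (0.8, 2.0), "vej": (3000.0, 8000.0)})
point = transform([0.5, 0.0])
```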

redback_jax.inference.create_gaussian_likelihood(model_fn, observed_data, errors, reduce_fn=None)[source]

Create a Gaussian likelihood function.

Parameters:
  • model_fn (callable) – Function that takes parameter dict and returns model predictions

  • observed_data (jnp.ndarray) – Observed data array

  • errors (jnp.ndarray) – Error array (standard deviations)

  • reduce_fn (callable, optional) – Function to reduce data (e.g., for rescaling errors)

Returns:

Log-likelihood function

Return type:

callable
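The returned log-likelihood can be pictured as a closure over the data. The sketch below uses the standard independent-Gaussian form, ln L = -1/2 * sum[((y - m) / sigma)^2 + ln(2 * pi * sigma^2)], and omits reduce_fn; the function names and the toy linear model are illustrative assumptions.

```python
import math


def make_gaussian_log_likelihood(model_fn, observed_data, errors):
    """Return params -> Gaussian log-likelihood for independent data points."""

    def log_likelihood(params):
        predicted = model_fn(params)
        total = 0.0
        for y, m, sigma in zip(observed_data, predicted, errors):
            total += -0.5 * (((y - m) / sigma) ** 2 + math.log(2.0 * math.pi * sigma**2))
        return total

    return log_likelihood


times = [0.0, 1.0, 2.0]


def line_model(params):
    """Toy model used only for this example."""
    return [params["slope"] * t + params["offset"] for t in times]


loglike = make_gaussian_log_likelihood(line_model, [1.0, 3.0, 5.0], [0.1, 0.1, 0.1])
at_truth = loglike({"slope": 2.0, "offset": 1.0})
off_truth = loglike({"slope": 2.5, "offset": 1.0})
```

Keeping the ln(2 * pi * sigma^2) term matters when comparing evidences, even though it does not affect where the maximum lies.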

redback_jax.inference.run_nested_sampling(loglikelihood_fn, prior_bounds, n_particles=500, num_mcmc_steps=20, max_iterations=100, rng_key=None, verbose=True)[source]

Run Sequential Monte Carlo (SMC) sampling using BlackJAX.

This uses adaptive tempered SMC which provides evidence estimates similar to nested sampling. The algorithm gradually increases the temperature from the prior to the posterior while tracking the normalizing constant.

Parameters:
  • loglikelihood_fn (callable) – Log-likelihood function that takes parameter dict

  • prior_bounds (dict) – Dictionary mapping parameter names to (low, high) bounds

  • n_particles (int, optional) – Number of particles (default: 500)

  • num_mcmc_steps (int, optional) – Number of MCMC steps per iteration (default: 20)

  • max_iterations (int, optional) – Maximum number of temperature steps (default: 100)

  • rng_key (jax.Array, optional) – JAX random key (default: None, will create one)

  • verbose (bool, optional) – Print progress information (default: True)

Returns:

Results from the SMC sampling run with evidence estimate

Return type:

SamplerResult

Raises:

ImportError – If blackjax is not installed

redback_jax.inference.run_mcmc(loglikelihood_fn, prior_bounds, n_samples=10000, n_warmup=1000, n_chains=4, step_size=0.01, rng_key=None, verbose=True)[source]

Run MCMC sampling using BlackJAX’s NUTS sampler.

Parameters:
  • loglikelihood_fn (callable) – Log-likelihood function that takes parameter dict

  • prior_bounds (dict) – Dictionary mapping parameter names to (low, high) bounds

  • n_samples (int, optional) – Number of samples to draw (default: 10000)

  • n_warmup (int, optional) – Number of warmup/burnin steps (default: 1000)

  • n_chains (int, optional) – Number of parallel chains (default: 4)

  • step_size (float, optional) – Initial step size for NUTS (default: 0.01)

  • rng_key (jax.Array, optional) – JAX random key (default: None, will create one)

  • verbose (bool, optional) – Print progress information (default: True)

Returns:

Results from the MCMC run

Return type:

SamplerResult

Raises:

ImportError – If blackjax is not installed

redback_jax.inference.fit_transient(transient, model_fn, prior_bounds, sampler='nested', sampler_kwargs=None, rng_key=None, verbose=True)[source]

Fit a transient model to observational data.

This is the main high-level interface for parameter inference, following the redback API style.

Parameters:
  • transient (Transient) – Transient object with observational data

  • model_fn (callable) – Model function that takes parameter dict and returns model predictions. Should be compatible with the transient’s data structure.

  • prior_bounds (dict) – Dictionary mapping parameter names to (low, high) bounds

  • sampler (str, optional) – Sampler to use: “nested” for nested sampling, “mcmc” for NUTS (default: “nested”)

  • sampler_kwargs (dict, optional) – Additional keyword arguments for the sampler

  • rng_key (jax.Array, optional) – JAX random key

  • verbose (bool, optional) – Print progress information

Returns:

Results from the sampling run

Return type:

SamplerResult

Examples

>>> from redback_jax import Transient
>>> from redback_jax.inference import fit_transient
>>> from redback_jax.sources import PrecomputedSpectraSource
>>> import jax.numpy as jnp
>>>
>>> # Create transient data
>>> transient = Transient(
...     name='test_sn',
...     times=jnp.array([0, 5, 10, 15, 20]),
...     magnitudes=jnp.array([18.0, 17.5, 17.0, 17.5, 18.0]),
...     magnitude_errors=jnp.array([0.1, 0.1, 0.1, 0.1, 0.1]),
...     bands=['g'] * 5
... )
>>>
>>> # Create model function
>>> source = PrecomputedSpectraSource.from_arnett_model(...)
>>> bridges, band_to_idx = source.prepare_bridges(['g'])
>>> band_indices = jnp.array([0, 0, 0, 0, 0])
>>>
>>> def model_fn(params):
...     return source.bandmag(params, None, transient.times,
...                           band_indices=band_indices,
...                           bridges=bridges,
...                           unique_bands=['g'])
>>>
>>> # Define priors
>>> prior_bounds = {
...     'amplitude': (0.1, 10.0),
... }
>>>
>>> # Run inference
>>> result = fit_transient(transient, model_fn, prior_bounds)
>>> print(f"Log evidence: {result.log_evidence:.2f}")
redback_jax.inference.to_anesthetic_samples(result)[source]

Convert SamplerResult to anesthetic NestedSamples.

Parameters:

result (SamplerResult) – Results from nested sampling

Returns:

Anesthetic samples object with posterior samples

Return type:

anesthetic.NestedSamples or anesthetic.Samples

Raises:

ImportError – If anesthetic is not installed

redback_jax.inference.summarize_result(result)[source]

Summarize sampling results with posterior statistics.

Parameters:

result (SamplerResult) – Results from sampling

Returns:

Dictionary with parameter statistics (mean, std, median, etc.)

Return type:

dict
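A summary of this kind can be sketched with the standard library. The statistic keys follow the description above (mean, std, median); the function name `summarize` and the use of the sample standard deviation are assumptions.

```python
import statistics


def summarize(samples):
    """Return {parameter: {"mean", "std", "median"}} from posterior samples."""
    return {
        name: {
            "mean": statistics.fmean(values),
            "std": statistics.stdev(values),  # sample standard deviation
            "median": statistics.median(values),
        }
        for name, values in samples.items()
    }


stats = summarize({"mej": [1.0, 2.0, 3.0, 4.0]})
```

For weighted samples, such as raw nested-sampling output, the weights would need to be applied before computing these statistics.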