Source code for redback_jax.inference.likelihood

"""
Likelihood for redback-jax inference.

Usage::

    from redback_jax.inference import Likelihood, Prior, Uniform, NestedSampler

    prior = Prior([
        Uniform(58580, 58620, name='t0'),
        Uniform(0.05,  0.20,  name='f_nickel'),
        Uniform(0.8,   2.0,   name='mej'),
        Uniform(3000,  8000,  name='vej'),
    ])

    likelihood = Likelihood(
        model        = 'arnett_spectra',
        transient    = transient,
        fixed_params = {
            'redshift':          0.01,
            'lum_dist':          dl_cm,
            'temperature_floor': 5000.0,
            'kappa':             0.07,
            'kappa_gamma':       0.1,
        },
    )

    result = NestedSampler(likelihood, prior, outdir='results/').run(key)

The ``transient`` object must have ``.time``, ``.y``, ``.y_err``, ``.bands``.

``fixed_params`` must supply everything the model needs that is *not* in the
prior.  Free parameters automatically take precedence over fixed ones.

If ``'t0'`` is a free parameter, it is treated as an MJD explosion time and
used to shift ``transient.time`` into source-frame days automatically.
"""

import jax
import jax.numpy as jnp
from typing import Dict, Optional

from jax_supernovae.bandpasses import register_all_bandpasses
from jax_supernovae.timeseries import timeseries_multiband_flux

from redback_jax.models import get_model


class Likelihood:
    """Gaussian photometric likelihood using a spectra model pipeline.

    Parameters
    ----------
    model : str or callable
        Model name (e.g. ``'arnett_spectra'``) or a callable with signature
        ``f(redshift, lum_dist, vej, temperature_floor, **kwargs)`` returning
        a namedtuple ``(time, lambdas, spectra)``.
    transient : Transient
        Data container with ``.time``, ``.y``, ``.y_err``, ``.bands``.
    fixed_params : dict
        Parameters held fixed during inference — everything the model needs
        that is not a free parameter in the prior.
    t0_key : str or None, optional
        Name of the MJD explosion-time free parameter (default ``'t0'``).
        When present in the prior the likelihood converts ``transient.time``
        from observer-frame MJD to source-frame days automatically. Set to
        ``None`` if times are already in source-frame days.
    """

    def __init__(
        self,
        model,
        transient,
        fixed_params: Dict,
        t0_key: Optional[str] = 't0',
    ):
        register_all_bandpasses()

        # Accept either a registered model name or a ready-made callable.
        if isinstance(model, str):
            self._model_fn = get_model(model)
            self.model_name = model
        else:
            self._model_fn = model
            self.model_name = getattr(model, '__name__', repr(model))

        self.transient = transient
        # Copy so later changes to the caller's dict do not leak in here.
        self.fixed_params = dict(fixed_params)
        self.t0_key = t0_key

        self._obs_times = jnp.asarray(transient.time)
        self._obs_mags = jnp.asarray(transient.y)
        self._obs_errs = jnp.asarray(transient.y_err)

        bands_raw = list(transient.bands)
        # dict.fromkeys de-duplicates while keeping first-seen order.
        self._unique_bands = list(dict.fromkeys(bands_raw))
        self._bands_raw = bands_raw

        # Populated lazily by _build_bridges (needs the prior).
        self._obs_band_idx = None
        self._bridges = None
        self._band_to_idx = None
        self._dummy_out = None

        # Fallback redshift for the time shift when 'redshift' is not sampled.
        self._redshift_const = float(self.fixed_params.get('redshift', 0.0))

    def _dummy_model_kwargs(self, prior) -> Dict:
        """Return model kwargs with each free parameter at its prior midpoint.

        Starts from a copy of ``fixed_params`` and, for every distribution in
        ``prior.distributions`` whose name is not ``t0_key``, sets that name
        to ``0.5 * (low + high)``.  Free parameters override fixed ones, the
        same precedence the compiled likelihood uses.

        NOTE(review): relies on every distribution exposing ``.low`` and
        ``.high``; confirm before using priors without those attributes.
        """
        kwargs = dict(self.fixed_params)
        for dist in prior.distributions:
            if dist.name != self.t0_key:
                kwargs[dist.name] = 0.5 * (dist.low + dist.high)
        return kwargs

    def _build_bridges(self, prior):
        """Precompute bandpass bridges using prior midpoints for free params.

        Side effects: sets ``_dummy_out`` (one model evaluation at the prior
        midpoints), ``_bridges`` (one bridge per unique band),
        ``_band_to_idx`` (band name -> index into ``_unique_bands``) and
        ``_obs_band_idx`` (that index for every observation).
        """
        from jax_supernovae.bandpasses import get_bandpass
        from jax_supernovae.salt3 import precompute_bandflux_bridge

        # One dummy model call so static quantities can be read off its output.
        self._dummy_out = self._model_fn(**self._dummy_model_kwargs(prior))

        self._bridges = tuple(
            precompute_bandflux_bridge(get_bandpass(band))
            for band in self._unique_bands
        )

        self._band_to_idx = {
            band: idx for idx, band in enumerate(self._unique_bands)
        }
        self._obs_band_idx = jnp.array(
            [self._band_to_idx[band] for band in self._bands_raw]
        )

    def _make_log_likelihood(self, prior):
        """Return a JIT-compiled log-likelihood function ``(params,) -> scalar``.

        ``params`` is indexed positionally in the order of ``prior.names``.
        The returned value is ``-0.5 * sum(((obs - model) / err) ** 2)`` over
        all observations, evaluated in magnitudes.
        """
        self._build_bridges(prior)

        # Bind everything the traced function needs to plain locals so the
        # closure does not capture ``self``.
        model_fn = self._model_fn
        fixed_params = self.fixed_params
        obs_times = self._obs_times
        obs_mags = self._obs_mags
        obs_errs = self._obs_errs
        obs_band_idx = self._obs_band_idx
        bridges = self._bridges
        redshift_const = self._redshift_const
        t0_key = self.t0_key
        names = prior.names

        # Static scalars from the dummy run
        _zero_before = True
        _minphase = float(self._dummy_out.time[0])

        # zp=0 per observation: timeseries_multiband_flux normalises each flux
        # by the per-band AB zpbandflux so that -2.5*log10(result) = AB mag.
        _zps = jnp.zeros(len(self._obs_mags))

        @jax.jit
        def _log_like(params: jnp.ndarray) -> jnp.ndarray:
            param_dict = {n: params[i] for i, n in enumerate(names)}

            if t0_key is not None and t0_key in param_dict:
                t0 = param_dict.pop(t0_key)
                # Use the sampled redshift when it is a free parameter so the
                # time dilation matches the redshift handed to the model;
                # otherwise fall back to the fixed value.
                z = param_dict.get('redshift', redshift_const)
                t_source = (obs_times - t0) / (1.0 + z)
            else:
                t_source = obs_times

            # Free parameters take precedence over fixed ones.
            model_kwargs = {**fixed_params, **param_dict}
            out = model_fn(**model_kwargs)

            # timeseries_multiband_flux with zps=0 returns flux/zpbandflux per
            # band, so -2.5*log10 gives the correct AB magnitude directly.
            norm_fluxes = timeseries_multiband_flux(
                t_source,
                bridges,
                obs_band_idx,
                out.time,
                out.lambdas,
                out.spectra,
                1.0,
                _zero_before,
                _minphase,
                time_degree=1,
                zps=_zps,
                zpsys='ab',
            )

            # Guard against log10(0).  1e-100 is not representable in float32
            # (JAX's default precision) and would round to zero there, so use
            # the smallest positive normal number of the flux dtype whenever
            # that is larger.  In float64 the floor is still exactly 1e-100.
            flux_floor = max(1e-100, float(jnp.finfo(norm_fluxes.dtype).tiny))
            model_mags = -2.5 * jnp.log10(norm_fluxes + flux_floor)

            return -0.5 * jnp.sum(((obs_mags - model_mags) / obs_errs) ** 2)

        return _log_like

    def __repr__(self) -> str:
        return (
            f"Likelihood(model={self.model_name!r}, "
            f"n_obs={len(self._obs_mags)}, "
            f"bands={self._unique_bands})"
        )