# Source code for redback_jax.inference.mcmc_sampler

"""
MCMC sampler for redback-jax using BlackJAX NUTS.

Usage::

    from redback_jax.inference import Prior, Uniform, Likelihood, MCMCSampler
    import jax

    prior = Prior([
        Uniform(58580, 58620, name='t0'),
        Uniform(0.05,  0.20,  name='f_nickel'),
        Uniform(0.8,   2.0,   name='mej'),
        Uniform(3000,  8000,  name='vej'),
    ])

    likelihood = Likelihood(
        model='arnett_spectra',
        transient=transient,
        fixed_params=fixed,
    )

    sampler = MCMCSampler(likelihood, prior, n_warmup=500, n_samples=2000)
    result  = sampler.run(jax.random.PRNGKey(0))
    result.summary()
"""

import jax
import jax.numpy as jnp
import numpy as np

try:
    import blackjax
    HAS_BLACKJAX = True
except ImportError:
    HAS_BLACKJAX = False

try:
    import tqdm as _tqdm
    HAS_TQDM = True
except ImportError:
    HAS_TQDM = False


class MCMCResult:
    """Container for MCMC results.

    Attributes
    ----------
    samples : dict
        Posterior samples as ``{name: array}`` with all chains concatenated
        chain-major (all draws of chain 0, then chain 1, ...), shape
        ``(n_chains * n_samples,)``.
    samples_per_chain : dict
        Samples per chain as ``{name: array}``, shape ``(n_chains, n_samples)``.
    param_names : list of str
        Ordered parameter names; controls the row order of :meth:`summary`.
    n_chains : int
        Number of chains.
    n_samples : int
        Number of post-warmup samples per chain.

    Parameters
    ----------
    samples_per_chain : dict
        ``{name: array}`` where every array has the same 2-D shape
        ``(n_chains, n_samples)``.
    param_names : list of str
        Parameter names, in display order.

    Raises
    ------
    ValueError
        If ``samples_per_chain`` is empty, or its arrays are not all 2-D
        with one common shape.
    """

    def __init__(self, samples_per_chain, param_names):
        if not samples_per_chain:
            raise ValueError("samples_per_chain must contain at least one parameter")

        shapes = {name: tuple(jnp.shape(arr)) for name, arr in samples_per_chain.items()}
        first_shape = next(iter(shapes.values()))
        if len(first_shape) != 2:
            raise ValueError(
                f"expected 2-D (n_chains, n_samples) arrays, got shape {first_shape}"
            )
        mismatched = [name for name, shape in shapes.items() if shape != first_shape]
        if mismatched:
            # n_chains / n_samples are read from the first array only, so an
            # inconsistent dict would silently report wrong counts.
            raise ValueError(
                f"all parameters must share shape {first_shape}; "
                f"mismatched: {mismatched}"
            )

        self.param_names = param_names
        self.n_chains, self.n_samples = first_shape
        self.samples_per_chain = samples_per_chain
        # Row-major flatten: chain 0's draws first, then chain 1's, ...
        self.samples = {
            name: arr.reshape(-1) for name, arr in samples_per_chain.items()
        }

    def summary(self):
        """Print a parameter summary table.

        One row per name in ``param_names`` with the mean, standard deviation
        (``jnp.std`` default, i.e. ``ddof=0``) and 16th/84th percentiles of
        the pooled samples.
        """
        header = f"{'Param':<14} {'Mean':>12} {'Std':>10} {'q16':>10} {'q84':>10}"
        print(f"\n{header}")
        # Derive the rule from the header so the two always have equal width.
        print("-" * len(header))
        for name in self.param_names:
            s = self.samples[name]
            mu = float(jnp.mean(s))
            std = float(jnp.std(s))
            q16 = float(jnp.percentile(s, 16))
            q84 = float(jnp.percentile(s, 84))
            print(f"{name:<14} {mu:>12.4f} {std:>10.4f} {q16:>10.4f} {q84:>10.4f}")

    def __repr__(self) -> str:
        total = self.n_chains * self.n_samples
        return (
            f"MCMCResult(n_chains={self.n_chains}, "
            f"n_samples_per_chain={self.n_samples}, total={total})"
        )
class MCMCSampler:
    """BlackJAX NUTS sampler with a clean redback-style interface.

    The log-posterior is ``log_likelihood + log_prior``, evaluated directly in
    the original parameter space. No reparameterisation or reflection is
    applied: outside the prior support ``log_prior`` is ``-inf``, so the
    log-posterior is ``-inf`` there and such proposals are never accepted.

    During warmup, BlackJAX's windowed adaptation tunes the step size and a
    diagonal inverse mass matrix independently for each chain. This matters
    because parameters can live on very different scales (in the module
    example ``t0`` spans 58580-58620 while ``f_nickel`` spans 0.05-0.20), and
    a single fixed step size with an identity mass matrix mixes poorly.
    Warmup draws are discarded; chains run sequentially.

    Parameters
    ----------
    likelihood : Likelihood
        A :class:`~redback_jax.inference.Likelihood` instance.
    prior : Prior
        Composite prior object.
    n_warmup : int, optional
        Number of warmup (adaptation) steps per chain (default 500). With
        ``0`` no adaptation is done and NUTS uses ``step_size`` with an
        identity inverse mass matrix.
    n_samples : int, optional
        Number of post-warmup samples per chain (default 2000).
    n_chains : int, optional
        Number of independent chains (default 4).
    step_size : float, optional
        Initial NUTS step size for adaptation (default 0.05); used as the
        fixed step size when ``n_warmup == 0``.
    verbose : bool, optional
        Print progress messages (default True).

    Raises
    ------
    ImportError
        If blackjax is not installed.
    ValueError
        If ``n_chains < 1``, ``n_warmup < 0``, ``n_samples < 0`` or
        ``step_size <= 0``.

    Notes
    -----
    Adaptation uses the blackjax >= 1.0 ``window_adaptation`` API, whose
    ``run`` returns ``((state, parameters), info)``.

    Examples
    --------
    >>> sampler = MCMCSampler(likelihood, prior, n_warmup=500, n_samples=2000)
    >>> result = sampler.run(jax.random.PRNGKey(0))
    >>> result.summary()
    """

    def __init__(
        self,
        likelihood,
        prior,
        n_warmup: int = 500,
        n_samples: int = 2000,
        n_chains: int = 4,
        step_size: float = 0.05,
        verbose: bool = True,
    ):
        if not HAS_BLACKJAX:
            raise ImportError(
                "blackjax is required for MCMC sampling.\n"
                "Install with: pip install blackjax"
            )
        if n_chains < 1:
            raise ValueError(f"n_chains must be >= 1, got {n_chains}")
        if n_warmup < 0 or n_samples < 0:
            raise ValueError(
                f"n_warmup and n_samples must be >= 0, got {n_warmup} and {n_samples}"
            )
        if not step_size > 0:
            raise ValueError(f"step_size must be positive, got {step_size}")

        self.likelihood = likelihood
        self.prior = prior
        self.n_warmup = n_warmup
        self.n_samples = n_samples
        self.n_chains = n_chains
        self.step_size = step_size
        self.verbose = verbose

        log_prior_fn = prior.log_prob_fn()
        log_like_fn = likelihood._make_log_likelihood(prior)

        def _log_posterior(params):
            # log_prior_fn returns -inf outside the prior support, so the sum
            # is -inf there too (provided the likelihood is finite).
            return log_like_fn(params) + log_prior_fn(params)

        self._log_posterior = _log_posterior

    def run(self, key: jax.Array) -> MCMCResult:
        """Run MCMC.

        Each chain starts from an independent prior draw, is adapted for
        ``n_warmup`` steps, and then draws ``n_samples`` post-warmup samples.

        Parameters
        ----------
        key : jax.Array
            JAX random key.

        Returns
        -------
        MCMCResult
            Post-warmup samples for every chain.
        """
        # Draw initial positions from the prior.
        key, init_key = jax.random.split(key)
        init_positions = self.prior.sample_n(init_key, self.n_chains)  # (n_chains, n_params)

        if self.verbose:
            print(f"MCMC: {self.n_chains} chains, {self.n_warmup} warmup + "
                  f"{self.n_samples} samples, device: {jax.devices()[0]}")

        all_chains = []
        for chain_idx in range(self.n_chains):
            key, warmup_key, sample_key = jax.random.split(key, 3)
            state, step_fn = self._warmup(warmup_key, init_positions[chain_idx])
            all_chains.append(self._sample_chain(step_fn, state, sample_key))
            if self.verbose:
                print(f"  Chain {chain_idx + 1}/{self.n_chains} done")

        stacked = jnp.stack(all_chains, axis=0)  # (n_chains, n_samples, n_params)
        samples_per_chain = {
            name: stacked[:, :, i] for i, name in enumerate(self.prior.names)
        }
        return MCMCResult(samples_per_chain, self.prior.names)

    def _warmup(self, key, position):
        """Warm up one chain from ``position``; return ``(state, step_fn)``."""
        if self.n_warmup == 0:
            # Nothing to adapt: use the user-supplied step size and a unit
            # diagonal inverse mass matrix.
            nuts = blackjax.nuts(
                self._log_posterior,
                self.step_size,
                jnp.ones(self.prior.n_params),
            )
            return nuts.init(position), nuts.step

        # Windowed adaptation of step size and diagonal inverse mass matrix.
        adaptation = blackjax.window_adaptation(
            blackjax.nuts,
            self._log_posterior,
            is_mass_matrix_diagonal=True,
            initial_step_size=self.step_size,
        )
        (state, parameters), _ = adaptation.run(key, position, num_steps=self.n_warmup)
        # Rebuild the kernel with the adapted step size / mass matrix.
        nuts = blackjax.nuts(self._log_posterior, **parameters)
        return state, nuts.step

    def _sample_chain(self, step_fn, state, key):
        """Take ``n_samples`` NUTS steps from ``state``; return stacked positions."""

        def one_step(carry, step_key):
            new_state, _ = step_fn(step_key, carry)
            return new_state, new_state.position

        step_keys = jax.random.split(key, self.n_samples)
        # lax.scan compiles the whole loop instead of dispatching per step.
        _, positions = jax.lax.scan(one_step, state, step_keys)
        return positions  # (n_samples, n_params)