# Source code for redback_jax.models.spectra_model

"""
Generic factory for converting any bolometric model into a spectra model.

Usage::

    from redback_jax.models.spectra_model import make_spectra_model
    from redback_jax.models import magnetar_powered_bolometric

    magnetar_powered_spectra = make_spectra_model(magnetar_powered_bolometric)

    out = magnetar_powered_spectra(
        redshift=0.1,
        lum_dist=dl_cm,
        vej=10000.0,
        temperature_floor=3000.0,
        # remaining kwargs forwarded verbatim to the bolometric function
        # (``vej`` is also passed on if the bolometric function declares it):
        p0=2.0, bp=1.0, mass_ns=1.4, theta_pb=0.3,
        mej=1.0, kappa=0.1, kappa_gamma=10.0,
    )

Intermediate quantities stay in log10 / temperature space for float32 safety.
The bolometric function is expected to return log10 luminosity (log10 erg/s),
which is passed straight to the photosphere calculation, so the linear
luminosity is never formed. If the bolometric function itself overflows
float32 the returned spectra will contain NaN/Inf, but for typical parameter
ranges float32 is safe.
"""

import math as _math
from collections import namedtuple

import jax
import jax.numpy as jnp
from jax import jit

from redback_jax.conversions import calc_kcorrected_properties, lambda_to_nu
from redback_jax.models.sed_features import NO_SED_FEATURES, apply_sed_feature
from redback_jax.photosphere import compute_temperature_floor_log10

# CGS physical constants, kept as plain Python floats; the spectra pipeline
# casts them to the traced array dtype at the point of use.
_H: float = 6.626e-27      # Planck constant [erg s]
_C: float = 2.998e10       # speed of light [cm/s]
_KB: float = 1.381e-16     # Boltzmann constant [erg/K]
_C_ANG: float = 2.998e18   # speed of light [Angstrom/s]

# Precomputed log10 combinations for the log-space blackbody flux.
# NOTE(review): _LOG10_H_OVER_KB is not referenced in this module's visible code.
_LOG10_2PI_H: float = _math.log10(2.0 * _math.pi * _H)
_LOG10_C2: float = _math.log10(_C ** 2)
_LOG10_H_OVER_KB: float = _math.log10(_H / _KB)


def make_spectra_model(bolometric_fn):
    """
    Wrap a bolometric model function to produce a full spectra model.

    The returned function has signature::

        spectra_model(redshift, lum_dist, vej, temperature_floor,
                      features=NO_SED_FEATURES, **bolometric_kwargs)
            -> namedtuple(time, lambdas, spectra)

    Parameters
    ----------
    bolometric_fn : callable
        Any function ``f(time_days, **kwargs) -> log10_lbol`` (log10 erg/s).
        ``time_days`` must be its first positional argument. If its signature
        declares a ``vej`` parameter, the spectra model's ``vej`` is forwarded
        to it as well. Callables without ``__name__`` (e.g.
        ``functools.partial`` objects) are accepted.

    Returns
    -------
    callable
        A spectra model with the same photosphere/SED pipeline, named
        ``<bolometric name>_spectra``.
    """
    import inspect as _inspect

    # Decide once, at wrap time, whether ``vej`` must be forwarded to the
    # bolometric function. Callables whose signature cannot be introspected
    # are treated as not declaring ``vej`` rather than failing here.
    try:
        _bolo_accepts_vej = 'vej' in _inspect.signature(bolometric_fn).parameters
    except (TypeError, ValueError):
        _bolo_accepts_vej = False

    # functools.partial objects have no ``__name__``: fall back to the wrapped
    # function's name, then to the callable's type name.
    bolo_name = getattr(bolometric_fn, '__name__', None)
    if bolo_name is None:
        bolo_name = getattr(getattr(bolometric_fn, 'func', None), '__name__',
                            type(bolometric_fn).__name__)

    def spectra_model(redshift, lum_dist, vej, temperature_floor,
                      features=NO_SED_FEATURES, **bolometric_kwargs):
        return _spectra_model_impl(
            bolometric_fn, redshift, lum_dist, vej, temperature_floor,
            features, bolometric_kwargs, _bolo_accepts_vej,
        )

    # Use real newline escapes so help()/Sphinx render one entry per line.
    spectra_model.__doc__ = (
        f"Spectra model wrapping ``{bolo_name}``.\n\n"
        "Args:\n"
        "    redshift: source redshift\n"
        "    lum_dist: luminosity distance in cm\n"
        "    vej: ejecta velocity in km/s (photosphere)\n"
        "    temperature_floor: floor temperature in K\n"
        "    features: SEDFeatures (default NO_SED_FEATURES)\n"
        "    **bolometric_kwargs: forwarded to the bolometric function\n\n"
        "Returns:\n"
        "    namedtuple with fields ``time`` (days), ``lambdas`` (Angstrom), "
        "``spectra`` (erg/s/cm^2/Angstrom)\n"
    )
    spectra_model.__name__ = bolo_name + "_spectra"
    return spectra_model
# Output record returned by every generated spectra model. Built once at module
# level (same typename and fields as before) instead of on every call.
_SpectraOutput = namedtuple('output', ['time', 'lambdas', 'spectra'])


def _log10_expm1(x):
    """Return ``log10(expm1(x))`` for ``x > 0`` without forming ``expm1(x)``.

    Uses ``expm1(x) = exp(x) * (1 - exp(-x))``, i.e.
    ``ln(expm1(x)) = x + ln(-expm1(-x))``. The factor ``-expm1(-x)`` lies in
    (0, 1): it stays accurate for tiny ``x`` and tends to 1 for large ``x``.
    That makes the result finite for any positive ``x`` in float32, with no
    upper clip.
    """
    return (x + jnp.log(-jnp.expm1(-x))) / _math.log(10.0)


def _spectra_model_impl(bolometric_fn, redshift, lum_dist, vej, temperature_floor,
                        features, bolometric_kwargs, bolo_accepts_vej=False):
    """Inner implementation — log10-space SED pipeline for float32 safety.

    Parameters
    ----------
    bolometric_fn : callable
        ``f(time_days, **kwargs) -> log10_lbol`` (log10 erg/s).
    redshift : float
        Source redshift.
    lum_dist : float
        Luminosity distance in cm (clamped to >= 1 before taking the log).
    vej : float
        Ejecta velocity for the photosphere. It is also passed to
        ``bolometric_fn`` when ``bolo_accepts_vej`` is true, unless
        ``bolometric_kwargs`` already contains ``vej``.
    temperature_floor : float
        Floor temperature in K for the photosphere.
    features : SEDFeatures
        Forwarded to ``apply_sed_feature``.
    bolometric_kwargs : dict
        Extra keyword arguments for ``bolometric_fn``; the dict is not mutated.
    bolo_accepts_vej : bool
        Whether ``bolometric_fn`` declares a ``vej`` parameter.

    Returns
    -------
    namedtuple
        ``output(time, lambdas, spectra)``:
        - ``time``: observer-frame epochs in days, shape (3000,).
        - ``lambdas``: observer-frame wavelengths in Angstrom, shape (100,).
        - ``spectra``: erg/s/cm^2/Angstrom, shape (Ntime, Nfreq) before the
          SED feature; presumably preserved by ``apply_sed_feature`` — confirm.
    """
    # Fixed grids: 100 log-spaced observer-frame wavelengths and 3000
    # log-spaced rest-frame epochs, stretched by (1 + z) into the observer frame.
    lambda_observer_frame = jnp.geomspace(100.0, 60000.0, 100)
    time_temp = jnp.geomspace(0.1, 3000.0, 3000)  # days
    time_observer_frame = time_temp * (1.0 + redshift)

    frequency, time = calc_kcorrected_properties(
        frequency=lambda_to_nu(lambda_observer_frame),
        redshift=redshift,
        time=time_observer_frame,
    )

    # Bolometric luminosity in log10 erg/s (returned directly by all bolometric fns).
    # If the bolometric function also accepts vej (e.g. arnett_bolometric uses it
    # for diffusion), forward it — unless the caller already supplied it explicitly.
    if bolo_accepts_vej:
        bolometric_kwargs = {'vej': vej, **bolometric_kwargs}
    log10_lbol = bolometric_fn(time, **bolometric_kwargs)

    # Temperature and log10(radius) — both float32-safe
    T_ph, log10_r_ph = compute_temperature_floor_log10(
        time=time,
        log10_luminosity=log10_lbol,
        vej=vej,
        temperature_floor=temperature_floor,
    )

    # Blackbody flux density in log10 space:
    #   F_nu = 2*pi*h*nu^3 * R^2 / (dl^2 * c^2) / expm1(h*nu / (kB*T))
    #   log10(F_nu) = log10(2pi*h) + 3*log10(nu) + 2*log10_r_ph
    #                 - 2*log10(dl) - 2*log10(c) - log10(expm1(h*nu/(kB*T)))
    fp = time.dtype
    nu = frequency.astype(fp)  # (Nfreq,)
    dl = jnp.asarray(lum_dist, dtype=fp)
    log10_dl = jnp.log10(jnp.maximum(dl, jnp.array(1.0, dtype=fp)))

    # x = h*nu / (kB*T): T_ph is (Ntime,), nu is (Nfreq,) -> x is (Ntime, Nfreq).
    # Only the lower bound is clamped, which keeps log(-expm1(-x)) finite.
    # Large x is handled exactly by _log10_expm1, so the Wien tail keeps its
    # exp(-x) cutoff instead of turning into an artificial nu^3 power law.
    x = (_H / _KB) * nu[None, :] / jnp.maximum(T_ph[:, None], jnp.array(1.0, dtype=fp))
    x = jnp.maximum(x, jnp.array(1e-10, dtype=fp))

    log10_Fnu = (jnp.array(_LOG10_2PI_H, dtype=fp)
                 + 3.0 * jnp.log10(nu[None, :])
                 + 2.0 * log10_r_ph[:, None]
                 - 2.0 * log10_dl
                 - jnp.array(_LOG10_C2, dtype=fp)
                 - _log10_expm1(x))  # (Ntime, Nfreq)

    spectral_flux_density = jnp.power(jnp.array(10.0, dtype=fp), log10_Fnu)  # erg/s/Hz/cm^2
    spectral_flux_density = apply_sed_feature(features, spectral_flux_density, frequency, time)

    # Convert erg/s/Hz/cm^2 -> erg/s/cm^2/Angstrom, then correct for bandwidth stretching.
    lam = lambda_observer_frame.astype(fp)
    spectra = spectral_flux_density * jnp.array(_C_ANG, dtype=fp) / (lam[None, :] ** 2)
    spectra = spectra * jnp.asarray(1.0 + redshift, dtype=fp)

    return _SpectraOutput(
        time=time_observer_frame,
        lambdas=lambda_observer_frame,
        spectra=spectra,
    )