Source code for redback_jax.models.kilonova

"""
JAX-based kilonova light-curve models.

References:
    Metzger 2017: https://ui.adsabs.harvard.edu/abs/2017LRR....20....3M/abstract
    Barnes & Kasen 2013/2016: thermalisation efficiency
    Kasen & Bildsten 2010, Yu et al. 2013: magnetar-boosted kilonova
"""

import math as _math

import jax
import jax.numpy as jnp
from jax import jit

from redback_jax.utils.citation_wrapper import citation_wrapper
from redback_jax.interaction_processes import barnes_kasen_16_thermalisation

# Physical constants in CGS. Plain Python floats are used on purpose: JAX
# treats them as weakly-typed scalars, so they do not force float64
# promotion the way astropy quantities would.
_SOLAR_MASS = 1.989e33          # grams
_SPEED_OF_LIGHT = 2.998e10      # centimetres per second
_DAY_TO_S = 24.0 * 60.0 * 60.0  # seconds in one day (86400)

# Base-10 logarithms used by the log-space magnetar spin-down model.
_LOG10_EROT_COEFF = _math.log10(2.6e52)  # rotational-energy normalisation (erg)
_LOG10_TP_COEFF = _math.log10(1.3e5)     # spin-down time normalisation (s)
_LOG10_2_FLOAT = _math.log10(2.0)        # the factor 2 in L = 2 E_rot / t_p
_LOG10_MSUN = _math.log10(_SOLAR_MASS)   # log10 of grams per solar mass


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

@jit
def _electron_fraction_from_kappa(kappa):
    """
    Map a gray opacity onto a representative electron fraction Ye.

    Piecewise constant in ``kappa`` (cm^2/g):
    kappa < 1 -> 0.4, 1 <= kappa < 5 -> 0.35, 5 <= kappa < 20 -> 0.25,
    otherwise (including NaN) -> 0.1.
    """
    # Build from the highest threshold down; each later where() overrides
    # the previous result for the lower-opacity band.
    ye = jnp.where(kappa < 20.0, 0.25, 0.1)
    ye = jnp.where(kappa < 5.0, 0.35, ye)
    return jnp.where(kappa < 1.0, 0.4, ye)


@jit
def _rprocess_heating_rate(t_days, e_th):
    """
    Specific r-process heating rate in erg/s/g, already scaled by ``e_th``.

    Before ``t_transition`` (1.3 s) an arctan-smoothed plateau of amplitude
    4e18 is used; afterwards the heating follows 2.1e10 * t_days**-1.3.

    :param t_days: time since merger in days
    :param e_th: thermalisation efficiency, broadcastable against ``t_days``
    :return: heating rate per unit mass (erg/s/g)
    """
    t_transition = 1.3  # seconds
    width = 0.11        # seconds
    t_sec = t_days * _DAY_TO_S

    # Late-time power law; t_days is floored so the negative power stays finite.
    late = 2.1e10 * e_th * jnp.maximum(t_days, 1e-10) ** (-1.3)

    # Early-time smoothed plateau.
    plateau = 0.5 - jnp.arctan((t_sec - t_transition) / width) / jnp.pi
    early = 4.0e18 * jnp.power(plateau, 1.3) * e_th

    is_late = t_sec > t_transition
    return jnp.where(is_late, late, early)


@jit
def _magnetar_log10_lbol(t_sec, log10_p0_ms, log10_bp, mass_ns, theta_pb):
    """
    log10 of the dipole spin-down luminosity, computed entirely in log space.

    Evaluates L = 2 E_rot / t_p / (1 + t / t_p)**2 without ever forming
    E_rot (~1e52 erg) directly, so the result is float32-safe.

    :param t_sec: time in seconds
    :param log10_p0_ms: log10 of the initial spin period in milliseconds
    :param log10_bp: log10 of the polar field in units of 10^14 G
    :param mass_ns: neutron-star mass in solar masses
    :param theta_pb: angle between spin and magnetic axes, in radians
    :return: log10 of the luminosity in erg/s
    """
    log10_mass_term = 1.5 * jnp.log10(mass_ns / 1.4)
    # sin^2 is floored so an aligned rotator does not produce log10(0).
    sin_sq = jnp.maximum(jnp.sin(theta_pb) ** 2, 1e-10)

    log10_e_rot = _LOG10_EROT_COEFF + log10_mass_term - 2.0 * log10_p0_ms
    log10_t_sd = (_LOG10_TP_COEFF - 2.0 * log10_bp + 2.0 * log10_p0_ms
                  + log10_mass_term
                  - jnp.log10(sin_sq))
    t_sd = jnp.power(10.0, log10_t_sd)

    return (_LOG10_2_FLOAT + log10_e_rot - log10_t_sd
            - 2.0 * jnp.log10(1.0 + t_sec / t_sd))


# ---------------------------------------------------------------------------
# Metzger kilonova (200-shell ODE via jax.lax.scan)
# Reference: Metzger 2017, redback _metzger_kilonova_model
# ---------------------------------------------------------------------------

@citation_wrapper('https://ui.adsabs.harvard.edu/abs/2017LRR....20....3M/abstract')
def metzger_kilonova_bolometric(time, mej, vej, beta, kappa, vmax=0.7,
                                neutron_precursor=True):
    """
    Metzger (2017) bolometric kilonova light curve.

    Public entry point; the work is done by ``_metzger_kilonova_scan``, which
    evolves 200 velocity shells with Barnes+16 thermalisation using a
    forward-Euler step inside ``jax.lax.scan``.

    :param time: source-frame times in days; strictly increasing, at least 2 points
    :param mej: ejecta mass in solar masses
    :param vej: minimum ejecta velocity as a fraction of c
    :param beta: slope of the mass-velocity power law (M ∝ v^{-beta})
    :param kappa: gray opacity in cm^2/g
    :param vmax: maximum ejecta velocity as a fraction of c (default 0.7)
    :param neutron_precursor: include free-neutron decay heating and opacity
        in the outer layers (default True)
    :return: log10 of the bolometric luminosity (erg/s), one value per ``time`` entry
    """
    return _metzger_kilonova_scan(
        time, mej, vej, beta, kappa,
        vmax=vmax,
        neutron_precursor=neutron_precursor,
    )
def _metzger_kilonova_scan(time, mej, vej, beta, kappa, vmax, neutron_precursor):
    """
    Evolve the Metzger (2017) multi-shell kilonova with ``jax.lax.scan``.

    The ejecta is split into ``mass_len`` velocity grid points between ``vej``
    and ``vmax``. ``m_arr`` is the power-law mass ``mej * (v / vej)**(-beta)``
    at each grid velocity, and ``dm`` is the mass between neighbouring points.
    For every shell except the outermost, the energy ``E`` is advanced with a
    forward-Euler step and clipped at zero:

        E <- E + (heating - E / t - L_shell) * dt
        L_shell = E / (t_diff + t * v / c)

    The outermost shell's energy is carried along unchanged.

    Energies are stored divided by ``E_scale = 10**log10_E_scale`` (set from
    the largest initial energy, floored at 1e30 erg) so that intermediates stay
    O(1) in float32. The result is assembled in log10 space.

    :param time: source-frame times in days (strictly increasing, >= 2 points)
    :param mej: ejecta mass in solar masses
    :param vej: minimum ejecta velocity in units of c
    :param beta: velocity power-law slope
    :param kappa: gray opacity in cm^2/g
    :param vmax: maximum ejecta velocity in units of c
    :param neutron_precursor: Python bool, resolved at trace time; include
        free-neutron heating and opacity
    :return: log10 bolometric luminosity with the same length as ``time``.
        The scan yields ``len(time) - 1`` values (luminosity at ``time[:-1]``),
        and the last one is repeated for the final time.
    """
    mass_len = 200
    fp = time.dtype  # dtype for the explicitly typed constants below

    t_sec = time * _DAY_TO_S
    dt = jnp.diff(t_sec)  # (T-1,)

    # Barnes+16 thermalisation. Time is floored away from zero so the
    # log1p(x) / x term stays finite.
    av, bv, dv = barnes_kasen_16_thermalisation(mej, vej)
    e_th = 0.36 * (jnp.exp(-av * time)
                   + jnp.log1p(2.0 * bv * jnp.maximum(time, 1e-10) ** dv)
                   / (2.0 * bv * jnp.maximum(time, 1e-10) ** dv))
    edotr = _rprocess_heating_rate(time, e_th)  # (T,) erg/s/g

    vel = jnp.linspace(vej, vmax, mass_len)
    m_arr = mej * (vel / vej) ** (-beta)  # solar masses, (S,)
    v_m = vel * _SPEED_OF_LIGHT           # cm/s, (S,)
    dm = jnp.abs(jnp.diff(m_arr))         # solar masses, (S-1,)

    tau_neutron = 900.0  # free-neutron decay time, seconds
    if neutron_precursor:
        Ye = _electron_fraction_from_kappa(kappa)
        # Metzger 2014 eq. 7: transition mass m_n ~ 1e-4 Msun separates
        # neutron-dominated outer layers (m << m_n) from r-process inner layers
        neutron_mass = 1e-4 * _SOLAR_MASS
        Xn0 = ((1.0 - 2.0 * Ye) * 2.0
               * jnp.arctan(neutron_mass / (m_arr * _SOLAR_MASS)) / jnp.pi)
        Xr = 1.0 - Xn0

    # Initial energies in Msun*(cm/s)^2 (~1e17, fits float32).
    E0 = 0.5 * m_arr * v_m ** 2  # (S,)
    # log10 of the energy scale in erg (max(E0) converted to erg), floored at 30.
    log10_E_scale = jnp.maximum(
        jnp.log10(jnp.maximum(jnp.max(E0), jnp.array(1e-30, dtype=fp)))
        + _LOG10_MSUN,
        jnp.array(30.0, dtype=fp))
    # Converts Msun-based quantities to normalised energy: Msun[g] / E_scale.
    msun_per_E = jnp.power(jnp.array(10.0, dtype=fp),
                           jnp.array(_LOG10_MSUN, dtype=fp) - log10_E_scale)
    E0_n = E0 * msun_per_E  # normalised, O(1)

    def _step(E_n, inputs):
        """One Euler step; returns (new energies, normalised total luminosity)."""
        t_i, dt_i, edotr_i = inputs
        if neutron_precursor:
            Xn_t = Xn0 * jnp.exp(-t_i / tau_neutron)
            # Metzger 2014: ėn = 3.2e14 * Xn (linear, not quadratic)
            edotn = 3.2e14 * Xn_t
            # Electron scattering off protons from decayed neutrons;
            # 1 - Xn(t) - Xr equals Xn0 - Xn(t).
            kappa_n = 0.4 * (1.0 - Xn_t - Xr)
            kap = kappa_n + kappa * Xr
        else:
            edotn = jnp.zeros(mass_len)
            kap = kappa * jnp.ones(mass_len)

        # Diffusion time of each inner shell (seconds).
        td_v = (kap[:-1] * m_arr[:-1] * _SOLAR_MASS * 3.0
                / (4.0 * jnp.pi * v_m[:-1] * _SPEED_OF_LIGHT * t_i * beta))
        # Shell luminosity in normalised units (erg/s / E_scale). t*v/c is R/c.
        lum_n = E_n[:-1] / (td_v + t_i * v_m[:-1] / _SPEED_OF_LIGHT)
        # Heating per shell: erg/s/g * dm[Msun] * (g/Msun) / E_scale.
        heat_n = (edotr_i + edotn[:-1]) * dm * msun_per_E
        # Euler update: heating minus the E/t expansion term and radiated power.
        E_new_inner = E_n[:-1] + (heat_n - E_n[:-1] / t_i - lum_n) * dt_i
        E_new = jnp.concatenate([
            jnp.maximum(E_new_inner, jnp.array(0.0, dtype=fp)),
            E_n[-1:],
        ])
        # NOTE(review): units need confirming. heat_n already includes the
        # shell mass (dm * solar_mass), so lum_n is a per-shell luminosity in
        # erg/s / E_scale, and sum(lum_n) * E_scale would already be erg/s.
        # Weighting by dm (in Msun) again makes the output scale with dm.
        # The initial energy also uses the cumulative m_arr rather than dm.
        # Both are left as-is to preserve current outputs. Verify the intended
        # convention against redback's _metzger_kilonova_model before changing.
        L_n_total = jnp.sum(lum_n * dm)
        return E_new, L_n_total

    # Luminosity at time[:-1] is taken from the energies before each step.
    _, L_n_arr = jax.lax.scan(_step, E0_n, (t_sec[:-1], dt, edotr[:-1]))

    # log10(L) = log10(L_n) + log10_E_scale; never exponentiate back to erg/s,
    # which could overflow float32.
    log10_L_n = jnp.log10(jnp.maximum(L_n_arr, jnp.array(1e-30, dtype=fp)))
    log10_L = log10_L_n + log10_E_scale
    # Repeat the last value so the output lines up with every entry of time.
    return jnp.concatenate([log10_L, log10_L[-1:]])


# ---------------------------------------------------------------------------
# Magnetar-boosted kilonova (200-shell ODE + dipole spin-down injection)
# Reference: Yu et al. 2013
# ---------------------------------------------------------------------------
@citation_wrapper('https://ui.adsabs.harvard.edu/abs/2013ApJ...776L..40Y/abstract')
def magnetar_boosted_kilonova_bolometric(time, mej, vej, beta, kappa, p0, bp,
                                         mass_ns, theta_pb,
                                         thermalisation_efficiency=1.0,
                                         vmax=0.7, neutron_precursor=True):
    """
    Kilonova bolometric light curve with magnetar spin-down heating.

    Uses the same multi-shell model as the Metzger kilonova. In addition, a
    dipole spin-down luminosity, scaled by ``thermalisation_efficiency``, is
    deposited into the innermost shell. Delegates to
    ``_magnetar_kilonova_scan``.

    :param time: source-frame times in days; strictly increasing, at least 2 points
    :param mej: ejecta mass in solar masses
    :param vej: minimum ejecta velocity as a fraction of c
    :param beta: slope of the mass-velocity power law
    :param kappa: gray opacity in cm^2/g
    :param p0: initial spin period in milliseconds
    :param bp: polar magnetic field in units of 10^14 G
    :param mass_ns: neutron-star mass in solar masses
    :param theta_pb: angle between spin and magnetic axes in radians
    :param thermalisation_efficiency: fraction of the spin-down power that is
        thermalised (default 1.0)
    :param vmax: maximum ejecta velocity as a fraction of c (default 0.7)
    :param neutron_precursor: include free-neutron decay heating and opacity
        (default True)
    :return: log10 of the bolometric luminosity (erg/s), one value per ``time`` entry
    """
    return _magnetar_kilonova_scan(
        time, mej, vej, beta, kappa, p0, bp, mass_ns, theta_pb,
        th_eff=thermalisation_efficiency,
        vmax=vmax,
        neutron_precursor=neutron_precursor,
    )
def _magnetar_kilonova_scan(time, mej, vej, beta, kappa, p0, bp, mass_ns,
                            theta_pb, th_eff, vmax, neutron_precursor):
    """
    Multi-shell kilonova scan with dipole spin-down power added to shell 0.

    Shares the shell model of the Metzger scan. Each inner shell's energy is
    advanced by forward Euler, clipped at zero, and the outermost shell is
    carried unchanged. In addition, ``th_eff`` times the magnetar luminosity
    is injected into the innermost shell at every step.

    Energies are stored divided by ``E_scale = 10**log10_E_scale``, where
    ``log10_E_scale`` is the log10 magnetar luminosity at ``time[0]`` (floored
    at 30). The result is assembled in log10 space so nothing is exponentiated
    back to erg/s.

    :param time: source-frame times in days (strictly increasing, >= 2 points)
    :param mej: ejecta mass in solar masses
    :param vej: minimum ejecta velocity in units of c
    :param beta: velocity power-law slope
    :param kappa: gray opacity in cm^2/g
    :param p0: initial spin period in milliseconds (floored at 1e-10)
    :param bp: polar field in units of 10^14 G (floored at 1e-10)
    :param mass_ns: neutron-star mass in solar masses
    :param theta_pb: spin / magnetic-axis angle in radians
    :param th_eff: fraction of the spin-down luminosity deposited as heat
    :param vmax: maximum ejecta velocity in units of c
    :param neutron_precursor: Python bool, resolved at trace time; include
        free-neutron heating and opacity
    :return: log10 bolometric luminosity with the same length as ``time``;
        the final entry repeats the previous one.
    """
    mass_len = 200
    fp = time.dtype  # dtype for the explicitly typed constants below

    t_sec = time * _DAY_TO_S
    dt = jnp.diff(t_sec)

    # Barnes+16 thermalisation, with time floored away from zero.
    av, bv, dv = barnes_kasen_16_thermalisation(mej, vej)
    e_th = 0.36 * (jnp.exp(-av * time)
                   + jnp.log1p(2.0 * bv * jnp.maximum(time, 1e-10) ** dv)
                   / (2.0 * bv * jnp.maximum(time, 1e-10) ** dv))
    edotr = _rprocess_heating_rate(time, e_th)  # erg/s/g

    # Magnetar luminosity, kept in log10 (float32-safe).
    log10_p0 = jnp.log10(jnp.maximum(p0, jnp.array(1e-10, dtype=fp)))
    log10_bp = jnp.log10(jnp.maximum(bp, jnp.array(1e-10, dtype=fp)))
    log10_L_mag = _magnetar_log10_lbol(t_sec, log10_p0, log10_bp, mass_ns, theta_pb)

    vel = jnp.linspace(vej, vmax, mass_len)
    m_arr = mej * (vel / vej) ** (-beta)  # solar masses
    v_m = vel * _SPEED_OF_LIGHT           # cm/s
    dm = jnp.abs(jnp.diff(m_arr))         # solar masses

    tau_neutron = 900.0  # free-neutron decay time, seconds
    if neutron_precursor:
        Ye = _electron_fraction_from_kappa(kappa)
        neutron_mass = 1e-4 * _SOLAR_MASS
        Xn0 = ((1.0 - 2.0 * Ye) * 2.0
               * jnp.arctan(neutron_mass / (m_arr * _SOLAR_MASS)) / jnp.pi)
        Xr = 1.0 - Xn0

    E0 = 0.5 * m_arr * v_m ** 2  # Msun*(cm/s)^2, safe in float32
    # The initial magnetar luminosity sets the normalisation scale.
    log10_E_scale = jnp.maximum(log10_L_mag[0], jnp.array(30.0, dtype=fp))
    # Converts Msun-based quantities to normalised energy: Msun[g] / E_scale.
    msun_per_E = jnp.power(jnp.array(10.0, dtype=fp),
                           jnp.array(_LOG10_MSUN, dtype=fp) - log10_E_scale)
    E0_n = E0 * msun_per_E

    def _step(E_n, inputs):
        """One Euler step; returns (new energies, normalised total luminosity)."""
        t_i, dt_i, edotr_i, log10_L_mag_i = inputs
        if neutron_precursor:
            Xn_t = Xn0 * jnp.exp(-t_i / tau_neutron)
            edotn = 3.2e14 * Xn_t
            # 1 - Xn(t) - Xr equals Xn0 - Xn(t): protons from decayed neutrons.
            kappa_n = 0.4 * (1.0 - Xn_t - Xr)
            kap = kappa_n + kappa * Xr
        else:
            edotn = jnp.zeros(mass_len)
            kap = kappa * jnp.ones(mass_len)

        td_v = (kap[:-1] * m_arr[:-1] * _SOLAR_MASS * 3.0
                / (4.0 * jnp.pi * v_m[:-1] * _SPEED_OF_LIGHT * t_i * beta))
        lum_n = E_n[:-1] / (td_v + t_i * v_m[:-1] / _SPEED_OF_LIGHT)
        heat_n = (edotr_i + edotn[:-1]) * dm * msun_per_E

        # Magnetar injection (normalised): L_mag / E_scale =
        # 10^(log10_L_mag - log10_E_scale), deposited only into shell 0.
        L_mag_n = jnp.power(jnp.array(10.0, dtype=fp), log10_L_mag_i - log10_E_scale)
        mag_n = jnp.concatenate([
            jnp.array([th_eff * L_mag_n], dtype=fp),
            jnp.zeros(mass_len - 2, dtype=fp),
        ])

        E_new_inner = E_n[:-1] + (heat_n + mag_n - E_n[:-1] / t_i - lum_n) * dt_i
        E_new = jnp.concatenate([
            jnp.maximum(E_new_inner, jnp.array(0.0, dtype=fp)),
            E_n[-1:],
        ])
        # NOTE(review): units need confirming. heat_n and mag_n are extensive
        # (erg/s / E_scale per shell), so sum(lum_n) * E_scale would already be
        # erg/s. Weighting by dm (Msun) again means the magnetar energy leaking
        # from shell 0 is scaled by dm[0]. This is left as-is to preserve
        # current outputs; verify the intended convention before changing it.
        L_n_total = jnp.sum(lum_n * dm)
        return E_new, L_n_total

    _, L_n_arr = jax.lax.scan(
        _step, E0_n, (t_sec[:-1], dt, edotr[:-1], log10_L_mag[:-1]))

    log10_L_n = jnp.log10(jnp.maximum(L_n_arr, jnp.array(1e-30, dtype=fp)))
    log10_L = log10_L_n + log10_E_scale
    # Repeat the last value so the output lines up with every entry of time.
    return jnp.concatenate([log10_L, log10_L[-1:]])