"""
JAX-friendly shock-powered transient models.
References:
Piro 2021 (shock cooling): https://ui.adsabs.harvard.edu/abs/2021ApJ...909..209P/abstract
Piro & Kollmeier 2018 (shocked cocoon): https://ui.adsabs.harvard.edu/abs/2018ApJ...855..103P/abstract
"""
import math as _math
import jax.numpy as jnp
from jax import jit
from redback_jax.utils.citation_wrapper import citation_wrapper
# Physical constants as Python floats (avoids astropy float64 promotion)
_SOLAR_MASS = 1.989e33 # g
_SPEED_OF_LIGHT = 2.998e10 # cm/s
_KM_CGS = 1.0e5 # cm/km
_DAY_TO_S = 86400.0 # s/day
# Diffusion constant for shocked cocoon (Python float, no JAX promotion):
# diff_const = Msun / (4*pi * c * km_cgs)
_DIFF_CONST = _SOLAR_MASS / (4.0 * _math.pi * _SPEED_OF_LIGHT * _KM_CGS)
_LOG10_DIFF_CONST = _math.log10(_DIFF_CONST)
_LOG10_SOLAR_MASS = _math.log10(_SOLAR_MASS)
_LOG10_KM_CGS = _math.log10(_KM_CGS)
_LOG10_DAY_TO_S = _math.log10(_DAY_TO_S)
_LOG10_C_CGS = _math.log10(_SPEED_OF_LIGHT)
_LOG10_E = _math.log10(_math.e)
# ---------------------------------------------------------------------------
# Shock cooling (Piro 2021)
# ---------------------------------------------------------------------------
[docs]
@citation_wrapper('https://ui.adsabs.harvard.edu/abs/2021ApJ...909..209P/abstract')
@jit
def shock_cooling_bolometric(time, log10_mass, log10_radius, log10_energy, nn, delta, kappa):
"""
Bolometric shock-cooling light curve following Piro (2021).
All large quantities (mass, radius, energy) are passed as log10 to stay
float32-safe throughout; all intermediate computations stay in log10 space.
:param time: source-frame time in days
:param log10_mass: log10 envelope mass in solar masses
:param log10_radius: log10 envelope radius in cm
:param log10_energy: log10 explosion energy in erg
:param nn: outer density power-law slope
:param delta: inner density power-law slope
:param kappa: opacity in cm^2/g
:return: log10 of bolometric luminosity in erg/s
"""
fp = time.dtype
n = nn
kk_pow = (n - 3.0) * (3.0 - delta) / (4.0 * _math.pi * (n - delta))
log10_kappa = jnp.log10(jnp.maximum(jnp.asarray(kappa, dtype=fp), jnp.array(1e-30, dtype=fp)))
# log10(mass in grams)
log10_mass_g = log10_mass + jnp.array(_LOG10_SOLAR_MASS, dtype=fp)
# log10(vt): vt^2 = coeff * E/m => log10(vt) = 0.5*(log10(coeff) + log10_E - log10_m)
log10_vt_coeff = jnp.log10(jnp.array(
abs((n - 5.0) * (5.0 - delta) / ((n - 3.0) * (3.0 - delta)) * 2.0), dtype=fp))
log10_vt = 0.5 * (log10_vt_coeff + log10_energy - log10_mass_g)
# log10(td in days): td^2 = (3*kappa*kk_pow*m) / ((n-1)*vt*c)
log10_kk_pow = jnp.log10(jnp.array(abs(kk_pow), dtype=fp))
log10_td_const = (jnp.array(_math.log10(3.0), dtype=fp) + log10_kappa + log10_kk_pow
- jnp.log10(jnp.array(abs(n - 1.0), dtype=fp))
- jnp.array(_LOG10_C_CGS, dtype=fp)
- jnp.array(_LOG10_DAY_TO_S, dtype=fp))
log10_td = 0.5 * (log10_td_const + log10_mass_g - log10_vt)
td = jnp.power(jnp.array(10.0, dtype=fp), log10_td) # days
t = jnp.maximum(time, jnp.array(1.0 / _DAY_TO_S, dtype=fp))
# log10_prefactor = log10(pi*(n-1)/(3*|n-5|) * c / kappa) + log10_R + 2*log10_vt
log10_prefactor = (jnp.log10(jnp.array(_math.pi * abs(n - 1.0) / (3.0 * abs(n - 5.0)), dtype=fp))
+ jnp.array(_LOG10_C_CGS, dtype=fp)
- log10_kappa
+ log10_radius
+ 2.0 * log10_vt)
log10_t = jnp.log10(jnp.maximum(t, jnp.array(1e-30, dtype=fp)))
log10_td_val = jnp.log10(jnp.maximum(td, jnp.array(1e-30, dtype=fp)))
# Pre-peak: lbol = prefactor * (td/t)^(4/(n-2))
log10_lbol_pre = log10_prefactor + (4.0 / (n - 2.0)) * (log10_td_val - log10_t)
# Post-peak: lbol = prefactor * exp(-0.5*(t^2/td^2 - 1))
exponent = jnp.clip(-0.5 * (t ** 2 / jnp.maximum(td, jnp.array(1e-30, dtype=fp)) ** 2 - 1.0),
jnp.array(-87.0, dtype=fp), jnp.array(87.0, dtype=fp))
log10_lbol_post = log10_prefactor + exponent * jnp.array(_LOG10_E, dtype=fp)
return jnp.where(t < td, log10_lbol_pre, log10_lbol_post)
# ---------------------------------------------------------------------------
# Shocked cocoon (Piro & Kollmeier 2018) — full log10-space rewrite
# ---------------------------------------------------------------------------
[docs]
@citation_wrapper('https://ui.adsabs.harvard.edu/abs/2018ApJ...855..103P/abstract')
@jit
def shocked_cocoon_bolometric(time, mej, vej, eta, tshock,
shocked_fraction, cos_theta_cocoon, kappa):
"""
Bolometric light curve of a shocked jet cocoon (Piro & Kollmeier 2018).
All large intermediate quantities computed in log10 space for float32 safety.
:param time: source-frame time in days
:param mej: ejecta mass in solar masses
:param vej: ejecta velocity in units of c (speed of light)
:param eta: ejecta density power-law slope
:param tshock: shock time in seconds
:param shocked_fraction: fraction of ejecta mass that is shocked
:param cos_theta_cocoon: cosine of cocoon opening half-angle
:param kappa: gray opacity in cm^2/g
:return: log10 of bolometric luminosity in erg/s
"""
fp = time.dtype
# Cast scalar inputs to fp
mej_f = jnp.asarray(mej, dtype=fp)
vej_f = jnp.asarray(vej, dtype=fp)
eta_f = jnp.asarray(eta, dtype=fp)
ts_f = jnp.asarray(tshock, dtype=fp)
sf_f = jnp.asarray(shocked_fraction, dtype=fp)
ctc_f = jnp.asarray(cos_theta_cocoon, dtype=fp)
kappa_f = jnp.asarray(kappa, dtype=fp)
# vej_kms = vej * c / km_cgs (km/s)
# log10(vej_kms) = log10(vej) + log10(c/km_cgs)
_log10_c_over_km = _LOG10_C_CGS - _LOG10_KM_CGS
log10_vej_kms = jnp.log10(jnp.maximum(vej_f, jnp.array(1e-30, dtype=fp))) + jnp.array(_log10_c_over_km, dtype=fp)
# shocked_mass = mej * shocked_fraction (solar masses)
log10_shocked_mass = (jnp.log10(jnp.maximum(mej_f, jnp.array(1e-30, dtype=fp)))
+ jnp.log10(jnp.maximum(sf_f, jnp.array(1e-30, dtype=fp))))
# tau_diff^2 = _DIFF_CONST * kappa * shocked_mass / vej_kms (days^2 after /day_s^2)
# log10(tau_diff) = 0.5*(log10_DIFF_CONST + log10_kappa + log10_shocked_mass - log10_vej_kms - 2*log10_day_s)
log10_tau_diff = 0.5 * (jnp.array(_LOG10_DIFF_CONST, dtype=fp)
+ jnp.log10(jnp.maximum(kappa_f, jnp.array(1e-30, dtype=fp)))
+ log10_shocked_mass
- log10_vej_kms
- jnp.array(2.0 * _LOG10_DAY_TO_S, dtype=fp))
tau_diff = jnp.power(jnp.array(10.0, dtype=fp), log10_tau_diff) # days
# t_thin = sqrt(c_kms / vej_kms) * tau_diff
# log10(t_thin) = 0.5*(log10_c_over_km - log10_vej_kms) + log10_tau_diff
log10_t_thin = (0.5 * (jnp.array(_log10_c_over_km, dtype=fp) - log10_vej_kms)
+ log10_tau_diff)
t_thin = jnp.power(jnp.array(10.0, dtype=fp), log10_t_thin) # days
# rshock = tshock * c (cm)
# log10(rshock) = log10(tshock) + log10(c)
log10_rshock = (jnp.log10(jnp.maximum(ts_f, jnp.array(1e-30, dtype=fp)))
+ jnp.array(_LOG10_C_CGS, dtype=fp))
# theta = arccos(cos_theta_cocoon)
theta = jnp.arccos(jnp.clip(ctc_f, jnp.array(-1.0, dtype=fp), jnp.array(1.0, dtype=fp)))
# l0 = (theta^2/2)^(1/3) * shocked_mass_g * vej_cm_s * rshock / tau_diff_s^2
# Compute log10(l0) to avoid overflow
# log10((theta^2/2)^(1/3)) = (2*log10(theta) - log10(2)) / 3
log10_theta_fac = (2.0 * jnp.log10(jnp.maximum(theta, jnp.array(1e-10, dtype=fp)))
- jnp.array(_math.log10(2.0), dtype=fp)) / 3.0
# shocked_mass in grams: log10 = log10_shocked_mass + log10_SOLAR_MASS
log10_shocked_mass_g = log10_shocked_mass + jnp.array(_LOG10_SOLAR_MASS, dtype=fp)
# vej_kms in cm/s: log10 = log10_vej_kms + log10_KM_CGS
log10_vej_cms = log10_vej_kms + jnp.array(_LOG10_KM_CGS, dtype=fp)
# tau_diff in seconds: log10 = log10_tau_diff + log10_DAY_TO_S
log10_tau_diff_s = log10_tau_diff + jnp.array(_LOG10_DAY_TO_S, dtype=fp)
log10_l0 = (log10_theta_fac
+ log10_shocked_mass_g
+ log10_vej_cms
+ log10_rshock
- 2.0 * log10_tau_diff_s)
# lbol = l0 * (t/tau_diff)^(-4/(eta+2)) * (1 + tanh(t_thin - t)) / 2
# log10_lbol = log10_l0 + (-4/(eta+2)) * (log10_t - log10_tau_diff)
# + log10((1 + tanh(t_thin - t))/2) -- this last factor is < 1, safe in linear
t_safe = jnp.maximum(time, jnp.array(1e-30, dtype=fp))
log10_t = jnp.log10(t_safe)
power_exp = -4.0 / (eta_f + 2.0)
log10_lbol = log10_l0 + power_exp * (log10_t - log10_tau_diff)
# Apply thin-shell tapering in log10 space: log10(lbol * taper) = log10_lbol + log10(taper)
# taper = (1 + tanh(t_thin - t)) / 2 is in (0, 1], safe to log10
taper = (1.0 + jnp.tanh(t_thin - time)) / 2.0
log10_taper = jnp.log10(jnp.maximum(taper, jnp.array(1e-30, dtype=fp)))
return log10_lbol + log10_taper
# ---------------------------------------------------------------------------
# Shock cooling + Arnett (Ni/Co decay) combined bolometric model
# ---------------------------------------------------------------------------
[docs]
@citation_wrapper('https://ui.adsabs.harvard.edu/abs/2021ApJ...909..209P/abstract, '
'https://ui.adsabs.harvard.edu/abs/1982ApJ...253..785A/abstract')
@jit
def shock_cooling_and_arnett_bolometric(time, log10_mass, log10_radius, log10_energy,
nn, delta, f_nickel, mej, vej,
kappa, kappa_gamma):
"""
Combined bolometric light curve: shock cooling (Piro 2021) + Ni/Co decay (Arnett 1982).
The two components are summed in luminosity space (log-sum-exp), which is float32-safe.
:param time: source-frame time in days
:param log10_mass: log10 envelope mass in solar masses
:param log10_radius: log10 envelope radius in cm
:param log10_energy: log10 explosion energy in erg
:param nn: outer density power-law slope
:param delta: inner density power-law slope
:param f_nickel: nickel mass fraction
:param mej: total ejecta mass in solar masses
:param vej: ejecta velocity in km/s
:param kappa: optical opacity in cm^2/g
:param kappa_gamma: gamma-ray opacity in cm^2/g
:return: log10 of bolometric luminosity in erg/s
"""
from redback_jax.models.supernova_models import arnett_bolometric
fp = time.dtype
log10_sc = shock_cooling_bolometric(time, log10_mass, log10_radius, log10_energy,
nn, delta, kappa)
log10_ar = arnett_bolometric(time, f_nickel=f_nickel, mej=mej, vej=vej,
kappa=kappa, kappa_gamma=kappa_gamma)
# log10(L1 + L2) = log10_max + log10(1 + 10^(log10_min - log10_max))
log10_max = jnp.maximum(log10_sc, log10_ar)
log10_min = jnp.minimum(log10_sc, log10_ar)
return log10_max + jnp.log10(
jnp.array(1.0, dtype=fp)
+ jnp.power(jnp.array(10.0, dtype=fp), log10_min - log10_max)
)