"""
JAX-friendly classes for supernova modeling.
"""
import math as _math
import os as _os
from collections import namedtuple
import numpy as _np
from jax import jit
import jax.numpy as jnp
from scipy.interpolate import RegularGridInterpolator as _RGI
from wcosmo import wcosmo
from redback_jax.utils.citation_wrapper import citation_wrapper
from redback_jax.utils.cosmology import PLANCK18_H0, PLANCK18_OM0, MPC_TO_CM
from redback_jax.conversions import calc_kcorrected_properties, lambda_to_nu
from redback_jax.interaction_processes import (
diffusion_convert_luminosity,
csm_diffusion_convert_luminosity,
)
from redback_jax.models.sed_features import NO_SED_FEATURES, apply_sed_feature
from redback_jax.photosphere import compute_temperature_floor_log10
# ---------------------------------------------------------------------------
# Physical constants — Python floats (not astropy, avoids float64 promotion)
# ---------------------------------------------------------------------------
_SOLAR_MASS = 1.989e33 # g
_SPEED_OF_LIGHT = 2.998e10 # cm/s
_KM_CGS = 1.0e5 # cm/km
_DAY_TO_S = 86400.0 # s/day
_AU_CGS = 1.496e13 # cm/AU
_MPC_TO_CM = MPC_TO_CM # cm/Mpc (from redback_jax.utils.cosmology)
# Log10 of key constants (float32-safe pre-computation)
_LOG10_MSUN = _math.log10(_SOLAR_MASS)
_LOG10_CCGS = _math.log10(_SPEED_OF_LIGHT)
_LOG10_KM_CGS = _math.log10(_KM_CGS)
# Magnetar log10 constants
# erot = 2.6e52 * (mass_ns/1.4)^1.5 * p0^-2 [erg]
# tp = 1.3e5 * bp^-2 * p0^2 * (mass_ns/1.4)^1.5 / sin^2(theta_pb) [s]
_LOG10_EROT_COEFF = _math.log10(2.6e52)
_LOG10_TP_COEFF = _math.log10(1.3e5)
_LOG10_2_FLOAT = _math.log10(2.0)
# ---------------------------------------------------------------------------
# CSM table — loaded once at import time as static numpy arrays.
# Columns: eta, nn, Bf, Br, AA (300 rows = 10 eta × 30 nn)
# ---------------------------------------------------------------------------
_CSM_TABLE_PATH = _os.path.join(_os.path.dirname(_os.path.dirname(__file__)),
'tables', 'csm_table.txt')
_csm_eta_raw, _csm_nn_raw, _csm_bf_raw, _csm_br_raw, _csm_aa_raw = _np.loadtxt(
_CSM_TABLE_PATH, delimiter=',', unpack=True)
_csm_eta_unique = _np.unique(_csm_eta_raw) # 10 values 0–2
_csm_nn_unique = _np.unique(_csm_nn_raw) # 30 values 6–14
# Redback reshape: (10,30).T → (30,10); grid axes are (nn, eta)
_csm_AA_grid = _np.reshape(_csm_aa_raw, (10, 30)).T # (30, 10)
_csm_Bf_grid = _np.reshape(_csm_bf_raw, (10, 30)).T # (30, 10)
_csm_Br_grid = _np.reshape(_csm_br_raw, (10, 30)).T # (30, 10)
_csm_AA_interp = _RGI((_csm_nn_unique, _csm_eta_unique), _csm_AA_grid,
bounds_error=False, fill_value=None)
_csm_Bf_interp = _RGI((_csm_nn_unique, _csm_eta_unique), _csm_Bf_grid,
bounds_error=False, fill_value=None)
_csm_Br_interp = _RGI((_csm_nn_unique, _csm_eta_unique), _csm_Br_grid,
bounds_error=False, fill_value=None)
[docs]
def blackbody_to_flux_density(temperature, r_photosphere, dl, frequency):
"""
A general blackbody_to_flux_density formula
:param temperature: effective temperature in kelvin
:param r_photosphere: photosphere radius in cm
:param dl: luminosity_distance in cm
:param frequency: frequency to calculate in Hz
:return: flux_density in erg/s/Hz/cm^2
"""
# Use Python float constants to avoid astropy float64 promotion
_h = 6.626e-27 # erg s
_c = 2.998e10 # cm/s
_kB = 1.381e-16 # erg/K
num = 2.0 * jnp.pi * _h * frequency ** 3 * r_photosphere ** 2
denom = dl ** 2 * _c ** 2
frac = 1.0 / jnp.expm1((_h * frequency) / (_kB * temperature))
return num / denom * frac
@jit
def _nickelcobalt_log10_engine(time, f_nickel, mej):
"""Ni/Co decay engine — returns log10(L) in erg/s (float32-safe).
:param time: time in days
:param f_nickel: fraction of nickel mass
:param mej: total ejecta mass in solar masses
:return: log10 of bolometric luminosity in erg/s
"""
_log10_ni = _math.log10(6.45e43)
_log10_co = _math.log10(1.45e43)
ni56_life = 8.8 # days
co56_life = 111.3 # days
fp = time.dtype
log10_mni = jnp.log10(jnp.maximum(f_nickel * mej, jnp.array(1e-30, dtype=fp)))
log10_a = jnp.array(_log10_ni, dtype=fp) + (-time / ni56_life) * jnp.array(_math.log10(_math.e), dtype=fp)
log10_b = jnp.array(_log10_co, dtype=fp) + (-time / co56_life) * jnp.array(_math.log10(_math.e), dtype=fp)
log10_max = jnp.maximum(log10_a, log10_b)
log10_sum = log10_max + jnp.log10(
jnp.power(jnp.array(10.0, dtype=fp), log10_a - log10_max)
+ jnp.power(jnp.array(10.0, dtype=fp), log10_b - log10_max))
return log10_mni + log10_sum
[docs]
@citation_wrapper('https://ui.adsabs.harvard.edu/abs/1982ApJ...253..785A/abstract')
@jit
def arnett_bolometric(time, f_nickel, mej, *, vej=None, kappa=None, kappa_gamma=None):
"""
Bolometric Arnett (1982) light curve with Ni/Co decay engine + diffusion.
:param time: time in days
:param f_nickel: fraction of nickel mass
:param mej: total ejecta mass in solar masses
:param kappa: opacity in cm^2/g (required)
:param kappa_gamma: gamma-ray opacity in cm^2/g (required)
:param vej: ejecta velocity in km/s (required)
:return: log10 of bolometric luminosity in erg/s
"""
dense_times = jnp.linspace(0.01, time[-1] + 100.0, 1000)
log10_dense_lbols = _nickelcobalt_log10_engine(dense_times, f_nickel, mej)
_, log10_lum = diffusion_convert_luminosity(
time=time, dense_times=dense_times, log10_luminosity=log10_dense_lbols,
mej=mej, kappa=kappa, kappa_gamma=kappa_gamma, vej=vej)
return log10_lum
@citation_wrapper('https://ui.adsabs.harvard.edu/abs/1982ApJ...253..785A/abstract')
@jit
def arnett_with_features_lum_dist(
f_nickel, mej, *, redshift=0.0, lum_dist=None,
vej=None, kappa=None, kappa_gamma=None,
temperature_floor=None, features=NO_SED_FEATURES,
):
"""
Arnett model with spectra — SED has time-evolving spectral features.
:param redshift: source redshift
:param f_nickel: fraction of nickel mass
:param mej: total ejecta mass in solar masses
:param lum_dist: luminosity distance in cm
:param kappa: opacity in cm^2/g (required)
:param kappa_gamma: gamma-ray opacity in cm^2/g (required)
:param vej: ejecta velocity in km/s (required)
:param temperature_floor: floor temperature in K
:param features: SEDFeatures object
:return: namedtuple(time, lambdas, spectra)
"""
lambda_observer_frame = jnp.geomspace(100.0, 60000.0, 100)
time_temp = jnp.geomspace(0.1, 3000.0, 3000) # days
time_observer_frame = time_temp * (1.0 + redshift)
frequency, time = calc_kcorrected_properties(
frequency=lambda_to_nu(lambda_observer_frame),
redshift=redshift,
time=time_observer_frame,
)
# Compute log10(lbol) directly in log10 space (float32-safe)
dense_times = jnp.linspace(0.01, time[-1] + 100.0, 1000)
log10_ni = _math.log10(6.45e43)
log10_co = _math.log10(1.45e43)
ni56_life = 8.8; co56_life = 111.3
log10_a = log10_ni + (-dense_times / ni56_life) * _math.log10(_math.e)
log10_b = log10_co + (-dense_times / co56_life) * _math.log10(_math.e)
log10_max_ab = jnp.maximum(log10_a, log10_b)
log10_sum = log10_max_ab + jnp.log10(
jnp.power(10.0, log10_a - log10_max_ab)
+ jnp.power(10.0, log10_b - log10_max_ab))
log10_mni = jnp.log10(jnp.maximum(f_nickel * mej, 1e-30))
log10_dense = log10_mni + log10_sum
_, log10_lbol = diffusion_convert_luminosity(
time=time, dense_times=dense_times, log10_luminosity=log10_dense,
mej=mej, kappa=kappa, kappa_gamma=kappa_gamma, vej=vej)
T_ph, log10_r_ph = compute_temperature_floor_log10(
time=time, log10_luminosity=log10_lbol,
vej=vej, temperature_floor=temperature_floor)
fp = time.dtype
nu = frequency.astype(fp)
dl = jnp.asarray(lum_dist, dtype=fp)
log10_dl = jnp.log10(jnp.maximum(dl, jnp.array(1.0, dtype=fp)))
_H = 6.626e-27; _KB = 1.381e-16; _C = 2.998e10
x = (_H / _KB) * nu[None, :] / jnp.maximum(T_ph[:, None], jnp.array(1.0, dtype=fp))
x = jnp.clip(x, jnp.array(1e-10, dtype=fp), jnp.array(80.0, dtype=fp))
log10_2pi_h = jnp.array(_math.log10(2.0 * _math.pi * _H), dtype=fp)
log10_c2 = jnp.array(_math.log10(_C ** 2), dtype=fp)
log10_Fnu = (log10_2pi_h
+ 3.0 * jnp.log10(nu[None, :])
+ 2.0 * log10_r_ph[:, None]
- 2.0 * log10_dl
- log10_c2
- jnp.log10(jnp.expm1(x)))
spectral_flux_density = jnp.power(jnp.array(10.0, dtype=fp), log10_Fnu)
spectral_flux_density = apply_sed_feature(
features, spectral_flux_density, frequency, time)
lam = lambda_observer_frame.astype(fp)
spectra = spectral_flux_density * jnp.array(2.998e18, dtype=fp) / (lam[None, :] ** 2)
spectra = spectra * jnp.asarray(1.0 + redshift, dtype=fp)
return namedtuple('output', ['time', 'lambdas', 'spectra'])(
time=time_observer_frame, lambdas=lambda_observer_frame, spectra=spectra)
[docs]
@citation_wrapper('https://ui.adsabs.harvard.edu/abs/1982ApJ...253..785A/abstract')
def arnett_with_features_cosmology(
f_nickel, mej, *, redshift=0.0, cosmo_H0=PLANCK18_H0, cosmo_Om0=PLANCK18_OM0,
vej=None, kappa=None, kappa_gamma=None,
temperature_floor=None, features=NO_SED_FEATURES,
):
"""
Arnett model with cosmological luminosity distance calculation.
:param redshift: source redshift
:param f_nickel: fraction of nickel mass
:param mej: total ejecta mass in solar masses
:param cosmo_H0: Hubble constant (km/s/Mpc)
:param cosmo_Om0: matter density parameter
:param kappa: opacity in cm^2/g (required)
:param kappa_gamma: gamma-ray opacity in cm^2/g (required)
:param vej: ejecta velocity in km/s (required)
:param temperature_floor: floor temperature in K
:param features: SEDFeatures object
:return: namedtuple(time, lambdas, spectra)
"""
dl = wcosmo.luminosity_distance(redshift, cosmo_H0, cosmo_Om0).value * _MPC_TO_CM
return arnett_with_features_lum_dist(
f_nickel=f_nickel, mej=mej, redshift=redshift, lum_dist=dl,
vej=vej, kappa=kappa, kappa_gamma=kappa_gamma,
temperature_floor=temperature_floor, features=features)
# ---------------------------------------------------------------------------
# Magnetar-powered supernova
# Reference: Kasen & Bildsten 2010, Inserra et al. 2013, Yu et al. 2017
# ---------------------------------------------------------------------------
@jit
def _magnetar_log10_lbol(time_days, log10_p0_ms, log10_bp, mass_ns, theta_pb):
"""
Dipole spin-down log10 luminosity (float32-safe). Returns log10(L) in erg/s.
:param time_days: source-frame time in days
:param log10_p0_ms: log10 of initial spin period in milliseconds
:param log10_bp: log10 of polar B-field in units of 10^14 G
:param mass_ns: NS mass in solar masses
:param theta_pb: spin–B-field angle in radians
:return: log10 of luminosity in erg/s
"""
t_s = time_days * _DAY_TO_S
log10_mass_ratio = jnp.log10(mass_ns / 1.4)
log10_erot = (_LOG10_EROT_COEFF + 1.5 * log10_mass_ratio - 2.0 * log10_p0_ms)
log10_tp = (_LOG10_TP_COEFF - 2.0 * log10_bp + 2.0 * log10_p0_ms
+ 1.5 * log10_mass_ratio
- jnp.log10(jnp.maximum(jnp.sin(theta_pb) ** 2, 1e-10)))
tp = jnp.power(10.0, log10_tp)
log10_L = (_LOG10_2_FLOAT + log10_erot - log10_tp
- 2.0 * jnp.log10(1.0 + 2.0 * t_s / tp))
return log10_L
@citation_wrapper('https://ui.adsabs.harvard.edu/abs/2006ApJ...648L..51S/abstract')
@jit
def _basic_magnetar_engine(time_days, p0, bp, mass_ns, theta_pb):
"""
Dipole spin-down — returns log10(L) in erg/s.
:param time_days: source-frame time in days
:param p0: initial spin period in milliseconds
:param bp: polar B-field in units of 10^14 G
:param mass_ns: NS mass in solar masses
:param theta_pb: spin–B-field angle in radians
:return: log10 of luminosity in erg/s
"""
return _magnetar_log10_lbol(
time_days,
jnp.log10(jnp.maximum(p0, 1e-10)),
jnp.log10(jnp.maximum(bp, 1e-10)),
mass_ns, theta_pb)
[docs]
@citation_wrapper('https://ui.adsabs.harvard.edu/abs/2017ApJ...850...55N/abstract')
@jit
def magnetar_powered_bolometric(time, p0, bp, mass_ns, theta_pb,
mej, kappa, kappa_gamma, vej):
"""
Bolometric light curve of a magnetar-powered supernova (Arnett diffusion).
:param time: source-frame time in days
:param p0: initial spin period in milliseconds
:param bp: polar B-field in units of 10^14 G
:param mass_ns: NS mass in solar masses
:param theta_pb: spin–B-field angle in radians
:param mej: ejecta mass in solar masses
:param kappa: optical opacity in cm^2/g
:param kappa_gamma: gamma-ray opacity in cm^2/g
:param vej: ejecta velocity in km/s
:return: log10 of bolometric luminosity in erg/s
"""
dense_times = jnp.linspace(0.01, time[-1] + 100.0, 1000)
log10_p0 = jnp.log10(jnp.maximum(p0, 1e-10))
log10_bp = jnp.log10(jnp.maximum(bp, 1e-10))
log10_dense_lbols = _magnetar_log10_lbol(dense_times, log10_p0, log10_bp, mass_ns, theta_pb)
_, log10_lbol = diffusion_convert_luminosity(
time=time, dense_times=dense_times, log10_luminosity=log10_dense_lbols,
kappa=kappa, kappa_gamma=kappa_gamma, mej=mej, vej=vej)
return log10_lbol
# ---------------------------------------------------------------------------
# CSM interaction
# Reference: Chevalier & Fransson 1994, Chatzopoulos et al. 2013,
# Villar et al. 2017, Jacobson-Galan et al. 2020
# ---------------------------------------------------------------------------
def _get_csm_coefficients(nn, eta):
"""
Lookup AA, Bf, Br from the pre-loaded CSM table via scipy interpolation.
Runs *outside* JIT (called with concrete Python/numpy scalars).
"""
pt = _np.array([[nn, eta]])
AA = float(_csm_AA_interp(pt)[0])
Bf = float(_csm_Bf_interp(pt)[0])
Br = float(_csm_Br_interp(pt)[0])
return AA, Bf, Br
@jit
def _csm_engine(time, mej, csm_mass, vej, eta, rho, kappa, r0, nn, AA, Bf, Br,
delta, efficiency):
"""
JAX CSM interaction engine (Chevalier 1982 forward/reverse shocks).
nn, AA, Bf, Br are concrete floats — passed in from the static table lookup.
Uses log10 arithmetic for Esn and g_n to stay float32-safe.
:param time: source-frame time in days
:param mej: ejecta mass in solar masses
:param csm_mass: CSM mass in solar masses
:param vej: ejecta velocity in km/s
:param eta: CSM density profile exponent
:param rho: CSM density amplitude in g/cm^3
:param kappa: opacity in cm^2/g
:param r0: inner CSM radius in AU
:param nn: ejecta density power-law slope (concrete float)
:param AA, Bf, Br: CSM shock coefficients (concrete floats from table)
:param delta: inner ejecta density slope
:param efficiency: kinetic-to-luminosity conversion efficiency
:return: (lbol, r_photosphere, mass_csm_threshold)
"""
mej_g = mej * _SOLAR_MASS
csm_mass_g = csm_mass * _SOLAR_MASS
r0_cm = r0 * _AU_CGS
vej_cms = vej * _KM_CGS
# Esn = 3 * vej^2 * mej / 10 [erg] — computed in log10 to avoid overflow
log10_Esn = (jnp.log10(3.0 / 10.0) + 2.0 * jnp.log10(vej_cms)
+ jnp.log10(mej_g))
Esn = jnp.power(10.0, log10_Esn)
ti = 1.0 # seconds offset
qq = rho * r0_cm ** eta
radius_csm = ((3.0 - eta) / (4.0 * jnp.pi * qq) * csm_mass_g
+ r0_cm ** (3.0 - eta)) ** (1.0 / (3.0 - eta))
r_photosphere = jnp.abs(
(-2.0 * (1.0 - eta) / (3.0 * kappa * qq)
+ radius_csm ** (1.0 - eta)) ** (1.0 / (1.0 - eta)))
mass_csm_threshold = jnp.abs(
4.0 * jnp.pi * qq / (3.0 - eta)
* (r_photosphere ** (3.0 - eta) - r0_cm ** (3.0 - eta)))
# g_n in log10 to avoid overflow
# g_n = 1/(4pi*(nn-delta)) * [2*(5-delta)*(nn-5)*Esn]^((nn-3)/2) / [(3-delta)*(nn-3)*mej_g]^((nn-5)/2)
log10_g_n = (- jnp.log10(4.0 * jnp.pi * (nn - delta))
+ ((nn - 3.0) / 2.0) * jnp.log10(jnp.maximum(
2.0 * (5.0 - delta) * (nn - 5.0) * Esn, 1e-300))
- ((nn - 5.0) / 2.0) * jnp.log10(
(3.0 - delta) * (nn - 3.0) * mej_g))
g_n = jnp.power(10.0, log10_g_n)
t_FS = (jnp.abs(
(3.0 - eta) * qq ** ((3.0 - nn) / (nn - eta))
* (AA * g_n) ** ((eta - 3.0) / (nn - eta))
/ (4.0 * jnp.pi * Bf ** (3.0 - eta))
) ** ((nn - eta) / ((nn - 3.0) * (3.0 - eta)))
* mass_csm_threshold ** ((nn - eta) / ((nn - 3.0) * (3.0 - eta))))
t_RS = (vej_cms / (Br * (AA * g_n / qq) ** (1.0 / (nn - eta)))
* (1.0 - (3.0 - nn) * mej_g
/ (4.0 * jnp.pi * vej_cms ** (3.0 - nn) * g_n))
** (1.0 / (3.0 - nn))) ** ((nn - eta) / (eta - 3.0))
t_s = time * _DAY_TO_S + ti
exp_FS = (2.0 * nn + 6.0 * eta - nn * eta - 15.0) / (nn - eta)
lbol_FS = (2.0 * jnp.pi / (nn - eta) ** 3
* g_n ** ((5.0 - eta) / (nn - eta))
* qq ** ((nn - 5.0) / (nn - eta))
* (nn - 3.0) ** 2 * (nn - 5.0)
* Bf ** (5.0 - eta)
* AA ** ((5.0 - eta) / (nn - eta))
* t_s ** exp_FS)
lbol_RS = (2.0 * jnp.pi
* (AA * g_n / qq) ** ((5.0 - nn) / (nn - eta))
* Br ** (5.0 - nn) * g_n
* ((3.0 - eta) / (nn - eta)) ** 3
* t_s ** exp_FS)
lbol_FS = jnp.where(t_FS - t_s > 0, lbol_FS, 0.0)
lbol_RS = jnp.where(t_RS - t_s > 0, lbol_RS, 0.0)
lbol = efficiency * (lbol_FS + lbol_RS)
return lbol, r_photosphere, mass_csm_threshold
[docs]
@citation_wrapper('https://ui.adsabs.harvard.edu/abs/2018ApJS..236....6G/abstract')
@jit
def magnetar_nickel_bolometric(time, f_nickel, mej, p0, bp, mass_ns, theta_pb,
kappa, kappa_gamma, vej):
"""
Bolometric light curve powered by both a magnetar and Ni/Co radioactive decay
(Arnett diffusion). The two luminosity sources are added before diffusion.
Reference: Gomez et al. 2018 (https://ui.adsabs.harvard.edu/abs/2018ApJS..236....6G/abstract)
:param time: source-frame time in days
:param f_nickel: nickel mass fraction (M_Ni = f_nickel * mej)
:param mej: total ejecta mass in solar masses
:param p0: initial spin period in milliseconds
:param bp: polar B-field in units of 10^14 G
:param mass_ns: NS mass in solar masses
:param theta_pb: spin–B-field angle in radians
:param kappa: optical opacity in cm^2/g
:param kappa_gamma: gamma-ray opacity in cm^2/g
:param vej: ejecta velocity in km/s
:return: log10 of bolometric luminosity in erg/s
"""
dense_times = jnp.linspace(0.01, time[-1] + 100.0, 1000)
# Ni/Co decay engine in log10 space
log10_nickel = _nickelcobalt_log10_engine(dense_times, f_nickel, mej)
# Magnetar spin-down engine in log10 space
log10_p0 = jnp.log10(jnp.maximum(p0, jnp.array(1e-10, dtype=time.dtype)))
log10_bp = jnp.log10(jnp.maximum(bp, jnp.array(1e-10, dtype=time.dtype)))
log10_mag = _magnetar_log10_lbol(dense_times, log10_p0, log10_bp, mass_ns, theta_pb)
# Add the two luminosity sources in log10 space (logsumexp-style, float32-safe)
fp = time.dtype
log10_max = jnp.maximum(log10_nickel, log10_mag)
log10_combined = log10_max + jnp.log10(
jnp.power(jnp.array(10.0, dtype=fp), log10_nickel - log10_max)
+ jnp.power(jnp.array(10.0, dtype=fp), log10_mag - log10_max))
_, log10_lbol = diffusion_convert_luminosity(
time=time, dense_times=dense_times, log10_luminosity=log10_combined,
kappa=kappa, kappa_gamma=kappa_gamma, mej=mej, vej=vej)
return log10_lbol
[docs]
@citation_wrapper('https://ui.adsabs.harvard.edu/abs/2013ApJ...773...76C/abstract,'
'https://ui.adsabs.harvard.edu/abs/2017ApJ...849...70V/abstract,'
'https://ui.adsabs.harvard.edu/abs/2020RNAAS...4...16J/abstract')
def csm_interaction_bolometric(time, mej, csm_mass, vej, eta, rho, kappa, r0,
nn=12, delta=1, efficiency=0.5):
"""
Bolometric CSM-interaction light curve (Chevalier 1982 shocks + diffusion).
:param time: source-frame time in days
:param mej: ejecta mass in solar masses
:param csm_mass: CSM mass in solar masses
:param vej: ejecta velocity in km/s
:param eta: CSM density profile exponent
:param rho: CSM density amplitude in g/cm^3
:param kappa: opacity in cm^2/g
:param r0: inner CSM radius in AU
:param nn: ejecta density power-law slope (default 12)
:param delta: inner ejecta density slope (default 1)
:param efficiency: kinetic-to-luminosity efficiency (default 0.5)
:return: log10 of bolometric luminosity in erg/s
"""
AA, Bf, Br = _get_csm_coefficients(nn, eta)
dense_times_jnp = jnp.linspace(0.1, time[-1] + 100.0, 1000)
_nn, _AA, _Bf, _Br, _delta, _eff = float(nn), AA, Bf, Br, float(delta), float(efficiency)
@jit
def _engine_and_diffuse(time, dense_times):
dense_lbols, r_phot, mass_csm_thresh = _csm_engine(
dense_times, mej, csm_mass, vej, eta, rho, kappa, r0,
_nn, _AA, _Bf, _Br, _delta, _eff)
log10_dense = jnp.log10(jnp.maximum(dense_lbols, jnp.array(1e-30, dtype=dense_lbols.dtype)))
return csm_diffusion_convert_luminosity(
time=time, dense_times=dense_times, log10_luminosity=log10_dense,
kappa=kappa, r_photosphere=r_phot, mass_csm_threshold=mass_csm_thresh)
return _engine_and_diffuse(time, dense_times_jnp)