# Source code for redback_jax.models.sed_features

"""
JAX-friendly classes for supernova modeling.
"""

from jax import jit
import jax.numpy as jnp
from jax.tree_util import register_pytree_node

from redback_jax.constants import *


class SEDFeatures:
    """A class representing one or more spectral features in the SED.

    This implements the PyTree interface for JAX, so instances can be passed
    through JIT-compiled functions. Every per-feature parameter is stored as a
    1-d array; length-1 arrays broadcast against longer ones, so a scalar value
    is shared by all features.

    :param rest_wavelengths: Central wavelengths in Angstroms
    :param sigmas: Gaussian widths in Angstroms
    :param amplitudes: Amplitudes as a fraction of the continuum
        (negative=absorption, positive=emission; e.g. -0.4 = 40% absorption)
    :param t_starts: Start time in seconds
    :param t_ends: End time in seconds
    :param rise_times: Rise times in seconds (default: 2 days)
    :param fall_times: Fall times in seconds (default: 5 days)
    """

    # Default ramp durations in seconds, shared by __init__ and from_feature_list.
    DEFAULT_RISE_TIME = 2.0 * 24.0 * 3600.0  # 2 days
    DEFAULT_FALL_TIME = 5.0 * 24.0 * 3600.0  # 5 days

    def __init__(
        self,
        rest_wavelengths,
        sigmas,
        amplitudes,
        t_starts,
        t_ends,
        rise_times=DEFAULT_RISE_TIME,
        fall_times=DEFAULT_FALL_TIME,
    ):
        self.rest_wavelengths = jnp.atleast_1d(rest_wavelengths)
        self.sigmas = jnp.atleast_1d(sigmas)
        self.amplitudes = jnp.atleast_1d(amplitudes)
        self.t_starts = jnp.atleast_1d(t_starts)
        self.t_ends = jnp.atleast_1d(t_ends)
        self.rise_times = jnp.atleast_1d(rise_times)
        self.fall_times = jnp.atleast_1d(fall_times)

    @classmethod
    def from_feature_list(cls, feature_list):
        """Create SEDFeatures from a list of feature dictionaries.

        :param feature_list: Iterable of dictionaries, each with keys
            'rest_wavelength', 'sigma', 'amplitude', 't_start', 't_end' and
            optionally 'rise_time' and 'fall_time' (defaulting to 2 and 5 days).
        :return: SEDFeatures object
        :raises KeyError: if a feature is missing one of the required keys
        """
        # Materialize once so generators can be scanned for every field.
        rows = list(feature_list)

        def required(key):
            return jnp.array([row[key] for row in rows])

        def optional(key, default):
            return jnp.array([row.get(key, default) for row in rows])

        return cls(
            required("rest_wavelength"),
            required("sigma"),
            required("amplitude"),
            required("t_start"),
            required("t_end"),
            optional("rise_time", cls.DEFAULT_RISE_TIME),
            optional("fall_time", cls.DEFAULT_FALL_TIME),
        )

    def tree_flatten(self):
        """Return the PyTree leaves (all parameter arrays) and no aux data."""
        children = (
            self.rest_wavelengths,
            self.sigmas,
            self.amplitudes,
            self.t_starts,
            self.t_ends,
            self.rise_times,
            self.fall_times,
        )
        aux_data = None
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild an instance from the leaves produced by ``tree_flatten``.

        ``__init__`` is deliberately bypassed: JAX transformations may unflatten
        with placeholder leaves (e.g. ``None`` or ``object()``) on which the
        ``jnp.atleast_1d`` conversion in ``__init__`` would fail.
        """
        obj = object.__new__(cls)
        (
            obj.rest_wavelengths,
            obj.sigmas,
            obj.amplitudes,
            obj.t_starts,
            obj.t_ends,
            obj.rise_times,
            obj.fall_times,
        ) = children
        return obj

    @jit
    def calculate_smooth_evolution(self, time):
        """Calculate the smooth time envelope of every feature.

        Each feature is active on ``[t_start, t_end)``. Inside that window the
        envelope follows a tanh ramp up during the ``rise_time`` after
        ``t_start``, a tanh ramp down during the ``fall_time`` before ``t_end``,
        and is exactly 1.0 in between. If the two ramps overlap (a feature
        shorter than ``rise_time + fall_time``), the smaller ramp value is used,
        so the envelope never exceeds 1 and never extends outside the window.

        :param time: Time in seconds (source frame); a scalar or 1-d array
        :return: time_factors array of shape (n_features, n_times) in [0, 1]
        """
        # Column vectors (n_features, 1) broadcast against a row of times (1, n_times).
        t_starts = jnp.atleast_1d(self.t_starts)[:, None]
        t_ends = jnp.atleast_1d(self.t_ends)[:, None]
        rise_times = jnp.atleast_1d(self.rise_times)[:, None]
        fall_times = jnp.atleast_1d(self.fall_times)[:, None]
        time_grid = jnp.atleast_1d(time)[None, :]

        in_window = (time_grid >= t_starts) & (time_grid < t_ends)
        in_rise = in_window & (time_grid < t_starts + rise_times)
        in_fall = in_window & (time_grid >= t_ends - fall_times)

        # A non-positive ramp duration never has an active phase; use a dummy
        # denominator so 0/0 cannot inject NaNs (which would also poison
        # gradients flowing through jnp.where).
        safe_rise = jnp.where(rise_times > 0, rise_times, 1.0)
        safe_fall = jnp.where(fall_times > 0, fall_times, 1.0)
        x_rise = (time_grid - t_starts) / safe_rise
        x_fall = (t_ends - time_grid) / safe_fall

        # Outside its own phase each ramp is neutral (1.0), so the minimum picks
        # the active ramp, or 1.0 on the plateau.
        rise_factors = jnp.where(in_rise, 0.5 * (1.0 + jnp.tanh(6.0 * (x_rise - 0.5))), 1.0)
        fall_factors = jnp.where(in_fall, 0.5 * (1.0 + jnp.tanh(6.0 * (x_fall - 0.5))), 1.0)

        return jnp.where(in_window, jnp.minimum(rise_factors, fall_factors), 0.0)
# Register SEDFeatures with JAX so instances can be passed to jit-compiled
# functions such as SEDFeatures.calculate_smooth_evolution and apply_sed_feature.
register_pytree_node(
    SEDFeatures,
    SEDFeatures.tree_flatten,
    SEDFeatures.tree_unflatten,
)

# A constant "no features" object. Only amplitudes=0.0 really matters: with a
# zero amplitude the multiplicative feature factor is exactly 1.0, so the flux
# is left unchanged. The remaining values are arbitrary placeholders.
NO_SED_FEATURES = SEDFeatures(
    rest_wavelengths=100.0,
    sigmas=1.0,
    amplitudes=0.0,
    t_starts=0.0,
    t_ends=10.0,
    rise_times=1.0,
    fall_times=2.0,
)
@jit
def apply_sed_feature(features, base_flux, frequency, time):
    """Apply spectral features to a base SED, completely vectorized.

    The flux is multiplied by ``1 + sum_i A_i * T_i(t) * G_i(lambda)``. For
    feature ``i``, ``A_i`` is its amplitude, ``T_i`` is its time envelope from
    ``features.calculate_smooth_evolution``, and ``G_i`` is a unit-height
    Gaussian in wavelength centred on its rest wavelength with width ``sigma``.

    :param features: SEDFeatures object
    :param base_flux: 2-d array (time, wavelength) in erg/s/Hz/cm^2
    :param frequency: frequency in Hz (source frame). Either a single number,
        or a 1-d array whose length matches the last (wavelength) axis of
        ``base_flux``.
    :param time: time in seconds (source frame). Either a single number, or a
        1-d array whose length matches the first (time) axis of ``base_flux``.
    :return: modified flux density as a 2-d array (time, wavelength) in
        erg/s/Hz/cm^2
    """
    frequency = jnp.atleast_1d(frequency)
    time = jnp.atleast_1d(time)

    # NOTE(review): assumes speed_of_light is in cm/s (cgs), so c / nu is in cm
    # and the 1e8 factor converts to Angstrom -- confirm in redback_jax.constants.
    wavelength_angstrom = speed_of_light / frequency * 1e8

    # Gaussian line profiles, shape (n_features, n_freq).
    wl_diff = wavelength_angstrom[None, :] - features.rest_wavelengths[:, None]
    gaussian_profiles = jnp.exp(-0.5 * (wl_diff / features.sigmas[:, None]) ** 2)

    # Time envelopes, shape (n_features, n_times).
    time_factors = features.calculate_smooth_evolution(time)

    # Broadcast per-feature contributions to (n_features, n_times, n_freq), then
    # sum over features to get a (n_times, n_freq) multiplicative factor that
    # lines up with base_flux's (time, wavelength) layout.
    feature_contributions = (
        features.amplitudes[:, None, None]
        * time_factors[:, :, None]
        * gaussian_profiles[:, None, :]
    )
    total_feature_factor = 1.0 + jnp.sum(feature_contributions, axis=0)
    return base_flux * total_feature_factor