Transient Analysis Tutorial
This tutorial demonstrates how to use the JAX-friendly redback_jax.Transient and redback_jax.Spectrum classes for electromagnetic transient analysis.
Basic Transient Usage
Creating a Transient Object
The most basic way to create a transient is to provide time and flux data:
import numpy as np
import jax.numpy as jnp
from redback_jax import Transient
# Generate synthetic lightcurve data
time = np.linspace(0, 20, 30)
flux = 5 * np.exp(-time/8) + 0.2 * np.random.randn(30)
flux_err = 0.1 * np.abs(flux) + 0.05
# Create transient object
transient = Transient(
    time=time,
    y=flux,
    y_err=flux_err,
    data_mode='flux',
    name='My Transient',
    redshift=0.05
)
The Transient class automatically converts input arrays to JAX arrays for efficient computation.
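What that conversion amounts to can be sketched with plain JAX, independent of the Transient class. The use of jnp.asarray here is an assumption about the general approach, not a description of the class internals.

```python
import numpy as np
import jax.numpy as jnp

# NumPy input, as a user would typically provide it
time_np = np.linspace(0.0, 20.0, 30)

# Assumed equivalent of the conversion step: wrap the data as a JAX array
time_jax = jnp.asarray(time_np)
print(type(time_jax).__name__, time_jax.dtype, time_jax.shape)

# Converting back is explicit, which helps with libraries that expect NumPy
time_back = np.asarray(time_jax)
```

JAX uses 32-bit floats unless 64-bit mode is enabled, so converted values may differ from the NumPy input at the level of single-precision rounding.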
Data Modes
The transient class supports several data modes:
'flux': Flux measurements (erg cm⁻² s⁻¹)
'magnitude': Magnitude measurements
'luminosity': Luminosity measurements (erg s⁻¹)
'flux_density': Flux density measurements (mJy)
'counts': Count measurements
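# Magnitude data (synthetic values for illustration)
magnitude = 17.0 + 0.1 * time
magnitude_err = np.full_like(time, 0.05)
magnitude_transient = Transient(
    time=time,
    y=magnitude,
    y_err=magnitude_err,
    data_mode='magnitude'
)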
Loading Data from Files
You can load transient data directly from CSV or text files:
# Load from CSV file
transient = Transient.from_data_file(
    'lightcurve_data.csv',
    data_mode='flux',
    name='SN2023A'
)
# Specify custom column names
transient = Transient.from_data_file(
    'data.txt',
    data_mode='magnitude',
    time_col='mjd',
    y_col='mag',
    y_err_col='mag_err'
)
Plotting Data
The transient class provides simple plotting functionality:
import matplotlib.pyplot as plt
# Basic plot
fig, ax = plt.subplots(figsize=(10, 6))
transient.plot_data(axes=ax, color='blue', alpha=0.7)
plt.show()
# Plot without error bars
transient.plot_data(show_errors=False, color='red')
Multi-band Photometry
For multi-band photometric data, include band information:
# Multi-band data
time = np.repeat(np.linspace(0, 10, 10), 3) # 3 bands × 10 epochs
bands = np.tile(['g', 'r', 'i'], 10)
magnitude = np.random.uniform(16, 20, 30) # Random magnitudes
transient = Transient(
    time=time,
    y=magnitude,
    data_mode='magnitude',
    bands=bands,
    name='Multi-band SN'
)
# Plot different bands with different colors
colors = {'g': 'green', 'r': 'red', 'i': 'orange'}
fig, ax = plt.subplots(figsize=(10, 6))
for band in ['g', 'r', 'i']:
    mask = transient.bands == band
    time_band = transient.time[mask]
    mag_band = transient.y[mask]
    ax.scatter(np.asarray(time_band), np.asarray(mag_band),
               color=colors[band], label=f'{band} band')
ax.set_xlabel(transient.xlabel)
ax.set_ylabel(transient.ylabel)
ax.invert_yaxis() # Magnitudes increase downward
ax.legend()
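The same per-band masking can drive simple summary statistics. The sketch below works on standalone arrays rather than a Transient object, so the variable names are illustrative only. It finds the brightest epoch in each band.

```python
import numpy as np
import jax.numpy as jnp

time = jnp.repeat(jnp.linspace(0.0, 10.0, 10), 3)
# Band labels stay in NumPy because JAX arrays cannot hold strings
bands = np.tile(['g', 'r', 'i'], 10)
magnitude = jnp.linspace(16.0, 20.0, 30)  # deterministic stand-in data

brightest = {}
for band in ['g', 'r', 'i']:
    mask = bands == band
    # Brightest epoch corresponds to the smallest magnitude
    idx = jnp.argmin(magnitude[mask])
    brightest[band] = (float(time[mask][idx]), float(magnitude[mask][idx]))

print(brightest)
```

Boolean-mask indexing produces arrays whose length depends on the data, so it works in eager mode as shown but not inside a jax.jit-compiled function.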
Model Fitting and Plotting
You can overplot models on your data:
# Define a simple exponential model
def exponential_model(t, amplitude, decay_time, offset):
    return amplitude * jnp.exp(-t / decay_time) + offset
# Model parameters
model_params = {
    'amplitude': 3.0,
    'decay_time': 8.0,
    'offset': 0.5
}
# Plot data with model
fig, ax = plt.subplots(figsize=(10, 6))
transient.plot_data(axes=ax, label='Data')
transient.plot_model(
    model_func=exponential_model,
    model_params=model_params,
    axes=ax,
    color='red',
    label='Model'
)
Working with JAX Arrays
All data in the transient object are stored as JAX arrays, enabling efficient computation:
# Access data as JAX arrays
peak_index = jnp.argmax(transient.y)  # index of the brightest point
peak_time = transient.time[peak_index]
peak_flux = jnp.max(transient.y)
mean_time = jnp.mean(transient.time)
# JAX operations work directly
log_flux = jnp.log10(transient.y)
flux_gradient = jnp.gradient(transient.y)
# Use JAX transformations
import jax
def compute_chi_squared(model_params):
    model_flux = exponential_model(transient.time, **model_params)
    residuals = (transient.y - model_flux) / transient.y_err
    return jnp.sum(residuals**2)
# Get gradients with respect to parameters
chi2_grad = jax.grad(compute_chi_squared)
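To show what calling such a gradient function returns, here is a self-contained sketch on synthetic arrays. The data below stand in for transient.time, transient.y and transient.y_err and are not taken from a Transient object.

```python
import jax
import jax.numpy as jnp

def exponential_model(t, amplitude, decay_time, offset):
    return amplitude * jnp.exp(-t / decay_time) + offset

# Synthetic, noise-free stand-in data
t = jnp.linspace(0.0, 20.0, 30)
y = exponential_model(t, 5.0, 8.0, 0.0)
y_err = 0.1 * jnp.abs(y) + 0.05

def chi_squared(params):
    residuals = (y - exponential_model(t, **params)) / y_err
    return jnp.sum(residuals**2)

# Parameter values must be floats so that jax.grad can differentiate them
params = {'amplitude': 3.0, 'decay_time': 8.0, 'offset': 0.5}

# value_and_grad returns the chi-squared and a dict of gradients with the same keys
value, grads = jax.value_and_grad(chi_squared)(params)
print(value, grads)
```

Because the parameters are passed as a dict, the gradient comes back as a dict with matching keys, which makes it straightforward to hand to an optimizer that works on JAX pytrees.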
Spectrum Analysis
The Spectrum class handles spectroscopic data:
Creating a Spectrum
from redback_jax import Spectrum
# Create spectrum data
wavelength = np.linspace(4000, 7000, 200) # Angstroms
flux_density = np.random.random(200) # Flux density
flux_err = 0.1 * flux_density
spectrum = Spectrum(
    wavelength=wavelength,
    flux_density=flux_density,
    flux_density_err=flux_err,
    name='SN Spectrum',
    time='Day 5'
)
Plotting Spectra
# Plot spectrum
fig, ax = plt.subplots(figsize=(12, 6))
spectrum.plot_data(axes=ax, color='purple', alpha=0.8)
plt.show()
Working with JAX in Spectra
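# Spectrum operations with JAX
from jax.scipy.integrate import trapezoid
peak_wavelength = spectrum.wavelength[jnp.argmax(spectrum.flux_density)]
# Trapezoidal integration (replaces the deprecated jnp.trapz)
total_flux = trapezoid(spectrum.flux_density, spectrum.wavelength)
# Smooth the spectrum with SciPy; convert to NumPy first, since SciPy is not JAX-aware
from scipy import ndimage
smoothed_flux = ndimage.gaussian_filter1d(
    np.asarray(spectrum.flux_density), sigma=2
)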
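If the smoothing needs to stay inside JAX, for example so it can be differentiated, a Gaussian filter can be written with jnp.convolve. The helper below, gaussian_smooth, is a hypothetical sketch and is not part of redback_jax. Its edge handling and kernel truncation are chosen to resemble SciPy's defaults, but the results are not guaranteed to match exactly.

```python
import jax.numpy as jnp

def gaussian_smooth(flux, sigma=2.0, truncate=4.0):
    """Smooth a 1D array with a normalized Gaussian kernel using JAX only."""
    radius = int(truncate * sigma + 0.5)
    offsets = jnp.arange(-radius, radius + 1)
    kernel = jnp.exp(-0.5 * (offsets / sigma) ** 2)
    kernel = kernel / jnp.sum(kernel)
    # Mirror the edges so the output has the same length as the input
    padded = jnp.pad(flux, radius, mode='symmetric')
    return jnp.convolve(padded, kernel, mode='valid')

# Flat continuum with a single narrow spike as test input
flux = jnp.ones(200).at[100].set(5.0)
smoothed = gaussian_smooth(flux, sigma=2.0)
print(smoothed.shape, float(smoothed[100]))
```

The kernel radius is computed with Python's int, so sigma and truncate must be ordinary Python numbers rather than traced values if the function is wrapped in jax.jit.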
Advanced Usage
Performance Tips
Use JAX arrays from the start: If your data is already in JAX format, pass it directly to avoid unnecessary conversions.
Batch operations: Process multiple transients using JAX’s vectorization capabilities.
JIT compilation: Use jax.jit to compile model functions for faster evaluation.
import jax
# JIT-compile model for speed
@jax.jit
def fast_exponential_model(t, amplitude, decay_time, offset):
    return amplitude * jnp.exp(-t / decay_time) + offset
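The batch-operations tip can be illustrated with jax.vmap. This sketch evaluates the exponential model for several parameter sets at once. It assumes all light curves share one time grid, and the three parameter sets are made up for illustration.

```python
import jax
import jax.numpy as jnp

def exponential_model(t, amplitude, decay_time, offset):
    return amplitude * jnp.exp(-t / decay_time) + offset

t = jnp.linspace(0.0, 20.0, 30)

# One entry per hypothetical transient
amplitudes = jnp.array([3.0, 5.0, 7.0])
decay_times = jnp.array([8.0, 4.0, 12.0])
offsets = jnp.array([0.5, 0.0, 0.2])

# Map over the parameter arrays (axis 0) while sharing the time grid (None)
batched_model = jax.jit(jax.vmap(exponential_model, in_axes=(None, 0, 0, 0)))
light_curves = batched_model(t, amplitudes, decay_times, offsets)
print(light_curves.shape)  # (3, 30)
```

Transients sampled at different epochs would need padding or interpolation onto a common grid before they can be batched this way.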
Custom Data Processing
Since all data are JAX arrays, you can easily implement custom processing:
def calculate_color(transient_g, transient_r):
    """Calculate g-r color from two transients."""
    # Interpolate r-band to g-band times
    r_interp = jnp.interp(transient_g.time, transient_r.time, transient_r.y)
    color = transient_g.y - r_interp
    return color
def phase_fold(transient, period, epoch):
    """Phase-fold a transient lightcurve."""
    phase = jnp.mod(transient.time - epoch, period) / period
    return phase
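A quick way to try these helpers is with lightweight stand-in objects that expose only .time and .y. Real Transient objects are assumed to work the same way, since the functions use nothing else. The values below are invented for illustration.

```python
from types import SimpleNamespace
import jax.numpy as jnp

def calculate_color(transient_g, transient_r):
    r_interp = jnp.interp(transient_g.time, transient_r.time, transient_r.y)
    return transient_g.y - r_interp

def phase_fold(transient, period, epoch):
    return jnp.mod(transient.time - epoch, period) / period

# Stand-ins for two single-band transients (hypothetical magnitudes)
g_band = SimpleNamespace(time=jnp.array([0.0, 2.0, 4.0]),
                         y=jnp.array([18.0, 17.5, 18.2]))
r_band = SimpleNamespace(time=jnp.array([0.0, 1.0, 3.0, 5.0]),
                         y=jnp.array([17.6, 17.3, 17.4, 18.0]))

color = calculate_color(g_band, r_band)            # approx. [0.4, 0.15, 0.5]
phase = phase_fold(g_band, period=3.0, epoch=1.0)  # approx. [0.667, 0.333, 0.0]
print(color, phase)
```

jnp.interp holds the end values constant outside the r-band time range, so g-band epochs that fall outside that range give unreliable colors.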
Error Handling
The transient classes include validation to catch common errors:
# This will raise a ValueError
try:
    bad_transient = Transient(
        time=np.array([1, 2, 3]),
        y=np.array([1, 2]),  # Wrong length!
        data_mode='flux'
    )
except ValueError as e:
    print(f"Error: {e}")
# This will also raise a ValueError
try:
    bad_mode = Transient(
        time=np.array([1, 2, 3]),
        y=np.array([1, 2, 3]),
        data_mode='invalid_mode'  # Invalid data mode!
    )
except ValueError as e:
    print(f"Error: {e}")
Next Steps
Explore the API Reference documentation for complete method references
Check out the example scripts in the examples/ directory
Learn about JAX-based model fitting in the models tutorial