Transient Analysis Tutorial

This tutorial demonstrates how to use the JAX-friendly redback_jax.Transient and redback_jax.Spectrum classes for electromagnetic transient analysis.

Basic Transient Usage

Creating a Transient Object

The most basic way to create a transient is to provide time and flux data:

import numpy as np
import jax.numpy as jnp
from redback_jax import Transient

# Generate synthetic lightcurve data
time = np.linspace(0, 20, 30)
flux = 5 * np.exp(-time/8) + 0.2 * np.random.randn(30)
flux_err = 0.1 * np.abs(flux) + 0.05

# Create transient object
transient = Transient(
    time=time,
    y=flux,
    y_err=flux_err,
    data_mode='flux',
    name='My Transient',
    redshift=0.05
)

The Transient class automatically converts input arrays to JAX arrays for efficient computation.
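As a rough illustration of what this conversion amounts to, the sketch below converts a NumPy array with jnp.asarray. That the class uses this exact call internally is an assumption made for illustration, not something confirmed by the library.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Assumption: the conversion behaves like jnp.asarray applied to each input.
time_np = np.linspace(0, 20, 30)   # NumPy array, float64
time_jax = jnp.asarray(time_np)    # JAX array

print(isinstance(time_jax, jax.Array))
# JAX defaults to 32-bit floats unless the jax_enable_x64 option is turned on.
print(time_jax.dtype)
```

The dtype is worth checking when time stamps are large numbers such as MJD values, since 32-bit floats carry only about seven significant digits.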

Data Modes

The transient class supports several data modes:

  • 'flux': Flux measurements (erg cm⁻² s⁻¹)

  • 'magnitude': Magnitude measurements

  • 'luminosity': Luminosity measurements (erg s⁻¹)

  • 'flux_density': Flux density measurements (mJy)

  • 'counts': Count measurements

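# Magnitude data (synthetic values, reusing the time array from above)
magnitude = 18.0 + 0.1 * time
magnitude_err = 0.05 * np.ones_like(time)

magnitude_transient = Transient(
    time=time,
    y=magnitude,
    y_err=magnitude_err,
    data_mode='magnitude'
)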
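To show how two of these modes relate, here is a minimal sketch that converts between magnitudes and flux densities in mJy. It assumes the magnitudes are on the AB system (zero point 3631 Jy), which the tutorial does not state, and the helper names are hypothetical rather than part of redback_jax.

```python
import jax.numpy as jnp

# AB zero point of 3631 Jy, expressed in mJy. Assumes AB magnitudes.
AB_ZERO_POINT_MJY = 3631.0e3

def ab_mag_to_flux_density_mjy(mag):
    """Hypothetical helper: AB magnitude -> flux density in mJy."""
    return AB_ZERO_POINT_MJY * 10.0 ** (-0.4 * mag)

def flux_density_mjy_to_ab_mag(flux_mjy):
    """Hypothetical helper: flux density in mJy -> AB magnitude."""
    return -2.5 * jnp.log10(flux_mjy / AB_ZERO_POINT_MJY)

mags = jnp.array([16.0, 18.5, 20.0])
flux_mjy = ab_mag_to_flux_density_mjy(mags)
mags_back = flux_density_mjy_to_ab_mag(flux_mjy)  # round trip recovers mags
```

Whichever representation is passed to the class, data_mode should describe the values as they are actually stored in y.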

Loading Data from Files

You can load transient data directly from CSV or text files:

# Load from CSV file
transient = Transient.from_data_file(
    'lightcurve_data.csv',
    data_mode='flux',
    name='SN2023A'
)

# Specify custom column names
transient = Transient.from_data_file(
    'data.txt',
    data_mode='magnitude',
    time_col='mjd',
    y_col='mag',
    y_err_col='mag_err'
)
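For files that need custom parsing, the columns can also be read by hand and passed to the constructor. The sketch below is not how from_data_file is implemented; it only shows an equivalent manual route using NumPy, with made-up file contents whose column names match the example above.

```python
import io
import numpy as np
import jax.numpy as jnp

# Hypothetical file contents; in practice pass a file path to genfromtxt.
csv_text = """mjd,mag,mag_err
59000.0,18.2,0.05
59001.0,18.0,0.04
59002.5,18.4,0.06
"""

# names=True reads the header row and exposes columns by name
table = np.genfromtxt(io.StringIO(csv_text), delimiter=",", names=True)

# Shift times to start at zero while still in float64, then convert
time = jnp.asarray(table["mjd"] - table["mjd"][0])
mag = jnp.asarray(table["mag"])
mag_err = jnp.asarray(table["mag_err"])
```

The resulting arrays can then be supplied as time, y, and y_err with data_mode='magnitude'.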

Plotting Data

The transient class provides simple plotting functionality:

import matplotlib.pyplot as plt

# Basic plot
fig, ax = plt.subplots(figsize=(10, 6))
transient.plot_data(axes=ax, color='blue', alpha=0.7)
plt.show()

# Plot without error bars
transient.plot_data(show_errors=False, color='red')

Multi-band Photometry

For multi-band photometric data, include band information:

# Multi-band data
time = np.repeat(np.linspace(0, 10, 10), 3)  # 3 bands × 10 epochs
bands = np.tile(['g', 'r', 'i'], 10)
magnitude = np.random.uniform(16, 20, 30)  # Random magnitudes

transient = Transient(
    time=time,
    y=magnitude,
    data_mode='magnitude',
    bands=bands,
    name='Multi-band SN'
)

# Plot different bands with different colors
colors = {'g': 'green', 'r': 'red', 'i': 'orange'}
fig, ax = plt.subplots(figsize=(10, 6))

for band in ['g', 'r', 'i']:
    mask = transient.bands == band
    time_band = transient.time[mask]
    mag_band = transient.y[mask]

    ax.scatter(np.asarray(time_band), np.asarray(mag_band),
              color=colors[band], label=f'{band} band')

ax.set_xlabel(transient.xlabel)
ax.set_ylabel(transient.ylabel)
ax.invert_yaxis()  # Magnitudes increase downward
ax.legend()
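The same masking pattern works for per-band statistics. The sketch below uses plain arrays as stand-ins for transient.bands and transient.y, and it assumes the band labels are kept as a NumPy array, since JAX arrays cannot hold strings.

```python
import numpy as np
import jax.numpy as jnp

# Stand-in arrays laid out like the example above (not a Transient object)
bands = np.tile(np.array(["g", "r", "i"]), 4)
mags = jnp.asarray(np.linspace(17.0, 19.0, bands.size))

brightest = {}
for band in ("g", "r", "i"):
    mask = bands == band  # NumPy boolean mask
    # The smallest magnitude is the brightest observation in that band
    brightest[band] = float(jnp.min(mags[mask]))

print(brightest)
```

Boolean-mask indexing like this produces arrays whose size depends on the data, so it belongs outside functions compiled with jax.jit.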

Model Fitting and Plotting

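You can overplot a model on your data. The example below assumes that transient holds flux data, such as the object created in the first example: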

# Define a simple exponential model
def exponential_model(t, amplitude, decay_time, offset):
    return amplitude * jnp.exp(-t / decay_time) + offset

# Model parameters
model_params = {
    'amplitude': 3.0,
    'decay_time': 8.0,
    'offset': 0.5
}

# Plot data with model
fig, ax = plt.subplots(figsize=(10, 6))
transient.plot_data(axes=ax, label='Data')
transient.plot_model(
    model_func=exponential_model,
    model_params=model_params,
    axes=ax,
    color='red',
    label='Model'
)

Working with JAX Arrays

All data in the transient object are stored as JAX arrays, enabling efficient computation:

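# Access data as JAX arrays
# (assumes transient is a flux-mode object with y_err set, as in the first example)
peak_index = jnp.argmax(transient.y)    # index of the brightest point
peak_time = transient.time[peak_index]  # time at which the peak occurs
peak_flux = jnp.max(transient.y)
mean_time = jnp.mean(transient.time)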

# JAX operations work directly
log_flux = jnp.log10(transient.y)
flux_gradient = jnp.gradient(transient.y)

# Use JAX transformations
import jax

def compute_chi_squared(model_params):
    model_flux = exponential_model(transient.time, **model_params)
    residuals = (transient.y - model_flux) / transient.y_err
    return jnp.sum(residuals**2)

# Get gradients with respect to parameters
chi2_grad = jax.grad(compute_chi_squared)
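The gradient function can be called with the same dictionary of parameters and returns a dictionary of derivatives with matching keys. The following self-contained sketch uses noise-free synthetic arrays in place of the transient attributes and takes a few plain gradient-descent steps. The step size is hand-picked for this toy problem and is not a recommended fitting method.

```python
import jax
import jax.numpy as jnp

def exponential_model(t, amplitude, decay_time, offset):
    return amplitude * jnp.exp(-t / decay_time) + offset

# Synthetic stand-ins for transient.time, transient.y, and transient.y_err
time = jnp.linspace(0.0, 20.0, 30)
flux = exponential_model(time, 5.0, 8.0, 0.0)
flux_err = 0.1 * flux + 0.05

def compute_chi_squared(params):
    model_flux = exponential_model(time, **params)
    residuals = (flux - model_flux) / flux_err
    return jnp.sum(residuals**2)

chi2_grad = jax.grad(compute_chi_squared)

params = {'amplitude': 3.0, 'decay_time': 8.0, 'offset': 0.5}
learning_rate = 1e-5  # deliberately tiny; chosen by hand for this example

fitted = params
for _ in range(50):
    grads = chi2_grad(fitted)  # dict with the same keys as the parameters
    fitted = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, fitted, grads
    )

print(compute_chi_squared(params), compute_chi_squared(fitted))
```

Because jax.grad treats the dictionary as a pytree, no manual packing of parameters into a vector is needed.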

Spectrum Analysis

The Spectrum class handles spectroscopic data.

Creating a Spectrum

from redback_jax import Spectrum

# Create spectrum data
wavelength = np.linspace(4000, 7000, 200)  # Angstroms
flux_density = np.random.random(200)  # Flux density
flux_err = 0.1 * flux_density

spectrum = Spectrum(
    wavelength=wavelength,
    flux_density=flux_density,
    flux_density_err=flux_err,
    name='SN Spectrum',
    time='Day 5'
)

Plotting Spectra

# Plot spectrum
fig, ax = plt.subplots(figsize=(12, 6))
spectrum.plot_data(axes=ax, color='purple', alpha=0.8)
plt.show()

Working with JAX in Spectra

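# Spectrum operations with JAX
# (jnp.trapz is deprecated in recent JAX releases; use trapezoid instead)
from jax.scipy.integrate import trapezoid

peak_wavelength = spectrum.wavelength[jnp.argmax(spectrum.flux_density)]
total_flux = trapezoid(spectrum.flux_density, spectrum.wavelength)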

# Smooth spectrum with SciPy (operates on NumPy arrays, hence the conversion)
from scipy import ndimage
smoothed_flux = ndimage.gaussian_filter1d(
    np.asarray(spectrum.flux_density), sigma=2
)
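The SciPy call returns a NumPy array, so the result is no longer differentiable or JIT-compilable. If the smoothing needs to stay inside JAX, one option is a direct convolution with a Gaussian kernel. The function below is a hypothetical helper, not part of redback_jax, and it pads the edges by repeating the end values, which differs from SciPy's default reflect mode.

```python
import jax.numpy as jnp

def gaussian_smooth(flux, sigma=2.0, truncate=4.0):
    """Smooth a 1-D array with a normalized Gaussian kernel using JAX only."""
    radius = int(truncate * sigma + 0.5)
    offsets = jnp.arange(-radius, radius + 1)
    kernel = jnp.exp(-0.5 * (offsets / sigma) ** 2)
    kernel = kernel / jnp.sum(kernel)
    # Edge padding keeps the output length equal to the input length
    padded = jnp.pad(flux, radius, mode="edge")
    return jnp.convolve(padded, kernel, mode="valid")

# Synthetic stand-in for spectrum.wavelength and spectrum.flux_density
wavelength = jnp.linspace(4000.0, 7000.0, 200)
flux_density = 1.0 + 0.1 * jnp.sin(wavelength / 50.0)
smoothed = gaussian_smooth(flux_density, sigma=2.0)
```

Here sigma is measured in samples, matching the meaning of sigma in the SciPy call above.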

Advanced Usage

Performance Tips

  1. Use JAX arrays from the start: If your data is already in JAX format, pass it directly to avoid unnecessary conversions.

  2. Batch operations: Process multiple transients using JAX’s vectorization capabilities.

  3. JIT compilation: Use jax.jit to compile model functions for faster evaluation.

import jax

# JIT-compile model for speed
@jax.jit
def fast_exponential_model(t, amplitude, decay_time, offset):
    return amplitude * jnp.exp(-t / decay_time) + offset
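The batching tip can be sketched with jax.vmap, which evaluates the model for several parameter sets in one call. The parameter values below are arbitrary examples.

```python
import jax
import jax.numpy as jnp

def exponential_model(t, amplitude, decay_time, offset):
    return amplitude * jnp.exp(-t / decay_time) + offset

# Map over the three parameter arrays while sharing one time grid
batched_model = jax.jit(
    jax.vmap(exponential_model, in_axes=(None, 0, 0, 0))
)

time = jnp.linspace(0.0, 20.0, 30)
amplitudes = jnp.array([3.0, 5.0, 7.0])
decay_times = jnp.array([6.0, 8.0, 10.0])
offsets = jnp.array([0.0, 0.2, 0.5])

# One row per parameter set, one column per time sample
lightcurves = batched_model(time, amplitudes, decay_times, offsets)
print(lightcurves.shape)
```

Setting in_axes to None for the time argument tells vmap to reuse the same grid for every parameter set instead of expecting a batch of grids.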

Custom Data Processing

Since all data are JAX arrays, you can easily implement custom processing:

def calculate_color(transient_g, transient_r):
    """Calculate g-r color from two transients."""
    # Interpolate r-band to g-band times
    r_interp = jnp.interp(transient_g.time, transient_r.time, transient_r.y)
    color = transient_g.y - r_interp
    return color

def phase_fold(transient, period, epoch):
    """Phase-fold a transient lightcurve."""
    phase = jnp.mod(transient.time - epoch, period) / period
    return phase
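A small usage sketch for the two helpers follows. The Lightcurve tuple is a hypothetical stand-in that exposes only the time and y attributes the helpers read, so the example runs without constructing real transient objects.

```python
from typing import NamedTuple
import jax.numpy as jnp

class Lightcurve(NamedTuple):
    """Hypothetical stand-in with the .time and .y attributes used above."""
    time: jnp.ndarray
    y: jnp.ndarray

def calculate_color(transient_g, transient_r):
    r_interp = jnp.interp(transient_g.time, transient_r.time, transient_r.y)
    return transient_g.y - r_interp

def phase_fold(transient, period, epoch):
    return jnp.mod(transient.time - epoch, period) / period

g_band = Lightcurve(time=jnp.array([0.0, 2.0, 4.0]),
                    y=jnp.array([18.0, 18.5, 19.0]))
r_band = Lightcurve(time=jnp.array([0.0, 4.0]),
                    y=jnp.array([17.5, 18.5]))

color = calculate_color(g_band, r_band)             # 0.5 mag at every g epoch
phase = phase_fold(g_band, period=3.0, epoch=1.0)   # values in [0, 1)
```

Note that jnp.interp expects the r-band times to be sorted in increasing order, and it holds the end values constant for g-band times outside the r-band range.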

Error Handling

The transient classes include validation to catch common errors:

# This will raise a ValueError
try:
    bad_transient = Transient(
        time=np.array([1, 2, 3]),
        y=np.array([1, 2]),  # Wrong length!
        data_mode='flux'
    )
except ValueError as e:
    print(f"Error: {e}")

# This will also raise a ValueError
try:
    bad_mode = Transient(
        time=np.array([1, 2, 3]),
        y=np.array([1, 2, 3]),
        data_mode='invalid_mode'  # Invalid data mode!
    )
except ValueError as e:
    print(f"Error: {e}")

Next Steps

  • Consult the API Reference for complete documentation of each class and method

  • Check out the example scripts in the examples/ directory

  • Learn about JAX-based model fitting in the models tutorial