# Source code for ligo.skymap.bayestar

#
# Copyright (C) 2013-2025  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""
Rapid sky localization with BAYESTAR [1]_.

References
----------
.. [1] Singer & Price, 2016. "Rapid Bayesian position reconstruction for
       gravitational-wave transients." PRD, 93, 024013.
       :doi:`10.1103/PhysRevD.93.024013`

"""

import inspect
import logging
import os
import sys
from textwrap import wrap

import lal
import lalsimulation
import numpy as np
from astropy import units as u
from astropy.table import Column, Table

from .. import core, distance, healpix_tree, moc, version
from ..core import antenna_factor, signal_amplitude_model
from ..core import log_posterior_toa_phoa_snr as _log_posterior_toa_phoa_snr
from ..io.events.base import Event
from ..io.fits import metadata_for_version_module
from ..io.hdf5 import write_samples
from ..kde import Clustered2Plus1DSkyKDE
from ..util.numpy import require_contiguous_aligned
from ..util.stopwatch import Stopwatch
from . import filter
from .ez_emcee import ez_emcee

# Public API of this module.
__all__ = (
    "derasterize",
    "localize",
    "rasterize",
    "antenna_factor",
    "signal_amplitude_model",
)

# Module-level logger used for progress and diagnostic messages below.
log = logging.getLogger("BAYESTAR")

# Default factor by which the log likelihood is rescaled.
# NOTE(review): empirical constant — provenance not visible in this file;
# confirm against the BAYESTAR calibration papers.
_RESCALE_LOGLIKELIHOOD = 0.83

# The C ufuncs require contiguous, aligned input arrays; wrap them so that
# arbitrary NumPy inputs are coerced to an acceptable memory layout first.
signal_amplitude_model = require_contiguous_aligned(signal_amplitude_model)
_log_posterior_toa_phoa_snr = require_contiguous_aligned(_log_posterior_toa_phoa_snr)


# Wrap so that ufunc parameter names are known
def log_posterior_toa_phoa_snr(
    ra,
    sin_dec,
    distance,
    u,
    twopsi,
    t,
    min_distance,
    max_distance,
    prior_distance_power,
    cosmology,
    gmst,
    sample_rate,
    epochs,
    snrs,
    responses,
    locations,
    horizons,
    rescale_loglikelihood=_RESCALE_LOGLIKELIHOOD,
):
    """Evaluate the log posterior for times, phases, and SNRs.

    Thin pass-through to the underlying ufunc; it exists only so that the
    parameter names (and the default for ``rescale_loglikelihood``) are
    visible to Python callers that pass arguments by keyword.
    """
    # Forward every argument positionally, in the order the ufunc expects.
    ufunc_args = (
        ra,
        sin_dec,
        distance,
        u,
        twopsi,
        t,
        min_distance,
        max_distance,
        prior_distance_power,
        cosmology,
        gmst,
        sample_rate,
        epochs,
        snrs,
        responses,
        locations,
        horizons,
        rescale_loglikelihood,
    )
    return _log_posterior_toa_phoa_snr(*ufunc_args)


# Wrap so that fixed parameter values are pulled from keyword arguments
def log_post(params, *args, **kwargs):
    """Posterior adapter: fixed parameters come from keyword arguments.

    Columns of ``params`` supply, in order, whichever of the sampled
    parameters are not pinned by a keyword argument; everything else is
    forwarded unchanged.
    """
    # Names of parameters, in the order the posterior expects them.
    keys = ("ra", "sin_dec", "distance", "u", "twopsi", "t")

    columns = list(params.T)
    remaining = dict(kwargs)

    # For each parameter, prefer a fixed keyword value; otherwise consume
    # the next column of the sample array.
    values = []
    for key in keys:
        if key in remaining:
            values.append(remaining.pop(key))
        else:
            values.append(columns.pop(0))

    return log_posterior_toa_phoa_snr(*values, *args, **remaining)


def localize_emcee(args, xmin, xmax, chain_dump=None):
    """Localize with MCMC sampling and return a multi-order HEALPix map."""
    # Gather posterior samples over the box [xmin, xmax].
    samples = ez_emcee(log_post, xmin, xmax, args=args, vectorize=True)

    # Undo the sampling reparametrization: sin(dec) -> dec,
    # cos(inclination) -> inclination.
    samples[:, 1] = np.arcsin(samples[:, 1])
    samples[:, 3] = np.arccos(samples[:, 3])

    # Optionally save posterior sample chain to file.
    if chain_dump:
        ndim = samples.shape[1]
        colnames = "ra dec distance inclination twopsi time".split()[:ndim]
        write_samples(
            Table(rows=samples, names=colnames, copy=False),
            chain_dump,
            path="/bayestar/posterior_samples",
            overwrite=True,
        )

    # Pass a random subset of 1000 points (ra, dec, distance only) to the
    # KDE, to save time.
    subset = np.random.permutation(samples)[:1000, :3]
    return Clustered2Plus1DSkyKDE(subset).as_healpix()


def condition(
    event,
    waveform="o2-uberbank",
    f_low=30.0,
    enable_snr_series=True,
    f_high_truncate=0.95,
):
    """Condition a candidate event for sky localization.

    Extract per-detector SNRs, responses, locations, and horizon distances,
    and normalize the SNR time series to a common sample rate, odd length,
    and a time span centered on the trigger times. Synthesizes zero-noise
    SNR series from template autocorrelations when none are provided.

    Parameters
    ----------
    event : `ligo.skymap.io.events.Event`
        The candidate event; provides one single-detector trigger per
        participating detector.
    waveform : str, optional
        The name of the waveform approximant.
    f_low : float, optional
        The low frequency cutoff.
    enable_snr_series : bool, optional
        If False, drop triggers that have no scalar SNR and ignore any
        SNR time series.
    f_high_truncate : float, optional
        Truncate the PSDs at this factor times the highest sampled
        frequency (passed to `filter.InterpolatedPSD`).

    Returns
    -------
    epoch : `lal.LIGOTimeGPS`
        Reference time: the weighted mean time of arrival.
    sample_rate : float
        The common sample rate of the SNR time series.
    toas : `numpy.ndarray`
        Time of the first sample of each SNR series, relative to ``epoch``.
    snrs : `numpy.ndarray`
        SNR series per detector as (magnitude, unwrapped phase) pairs.
    responses : `numpy.ndarray`
        Detector response tensors.
    locations : `numpy.ndarray`
        Barycentered detector locations in light-seconds.
    horizons : `numpy.ndarray`
        SNR=1 horizon distance for each detector.

    Raises
    ------
    ValueError
        If the event has no detectors, or if the SNR series have mixed
        sample rates, even or mixed lengths, or are not centered on the
        single-detector trigger times.
    """
    if len(event.singles) == 0:
        raise ValueError("Cannot localize an event with zero detectors.")

    singles = event.singles
    if not enable_snr_series:
        singles = [single for single in singles if single.snr is not None]

    ifos = [single.detector for single in singles]

    # Extract SNRs from table.
    snrs = np.ma.asarray(
        [np.ma.masked if single.snr is None else single.snr for single in singles]
    )

    # Look up physical parameters for detector.
    detectors = [lalsimulation.DetectorPrefixToLALDetector(str(ifo)) for ifo in ifos]
    responses = np.asarray([det.response for det in detectors])
    # Locations are converted to light-seconds by dividing by c.
    locations = np.asarray([det.location for det in detectors]) / lal.C_SI

    # Power spectra for each detector.
    psds = [single.psd for single in singles]
    psds = [
        filter.InterpolatedPSD(
            filter.abscissa(psd), psd.data.data, f_high_truncate=f_high_truncate
        )
        for psd in psds
    ]

    log.debug("calculating templates")
    H = filter.sngl_inspiral_psd(waveform, f_min=f_low, **event.template_args)

    log.debug("calculating noise PSDs")
    HS = [filter.signal_psd_series(H, S) for S in psds]

    # Signal models for each detector.
    log.debug("calculating Fisher matrix elements")
    signal_models = [filter.SignalModel(_) for _ in HS]

    # Get SNR=1 horizon distances for each detector.
    horizons = np.asarray(
        [signal_model.get_horizon_distance() for signal_model in signal_models]
    )

    # Weight each detector by the inverse variance of its time-of-arrival
    # estimate (Cramér–Rao bound at the measured SNR).
    weights = np.ma.asarray(
        [
            1 / np.square(signal_model.get_crb_toa_uncert(snr))
            for signal_model, snr in zip(signal_models, snrs)
        ]
    )

    # Center detector array.
    locations -= np.sum(locations * weights.reshape(-1, 1), axis=0) / np.sum(weights)

    if enable_snr_series:
        snr_series = [single.snr_series for single in singles]
        if all(s is None for s in snr_series):
            snr_series = None
    else:
        snr_series = None

    # Maximum barycentered arrival time error:
    # |distance from array barycenter to furthest detector| / c + 5 ms.
    # For LHO+LLO, this is 15.0 ms.
    # For an arbitrary terrestrial detector network, the maximum is 26.3 ms.
    max_abs_t = np.max(np.sqrt(np.sum(np.square(locations), axis=1))) + 0.005

    # Calculate a minimum arrival time prior for low-bandwidth signals.
    # This is important for early-warning events, for which the light travel
    # time between detectors may be only a tiny fraction of the template
    # autocorrelation time.
    #
    # The period of the autocorrelation function is 2 pi times the timing
    # uncertainty at SNR=1 (see Eq. (24) of [1]_). We want a quarter period:
    # there is a factor of a half because we want the first zero crossing,
    # and another factor of a half because the SNR time series will be cropped
    # to a two-sided interval.
    max_acor_t = max(
        0.5 * np.pi * signal_model.get_crb_toa_uncert(1)
        for signal_model in signal_models
    )
    max_abs_t = max(max_abs_t, max_acor_t)

    if snr_series is None:
        log.warning(
            "No SNR time series found, so we are creating a "
            "zero-noise SNR time series from the whitened template's "
            "autocorrelation sequence. The sky localization "
            "uncertainty may be underestimated."
        )

        acors, sample_rates = zip(*[filter.autocorrelation(_, max_abs_t) for _ in HS])
        sample_rate = sample_rates[0]
        deltaT = 1 / sample_rate
        nsamples = len(acors[0])
        # All detectors must agree on sample rate and autocorrelation length
        # before we mirror the one-sided sequence into a two-sided series.
        assert all(sample_rate == _ for _ in sample_rates)
        assert all(nsamples == len(_) for _ in acors)
        nsamples = nsamples * 2 - 1

        snr_series = []
        for acor, single in zip(acors, singles):
            series = lal.CreateCOMPLEX8TimeSeries(
                "fake SNR", 0, 0, deltaT, lal.StrainUnit, nsamples
            )
            # Center the synthetic series on the trigger time.
            series.epoch = single.time - 0.5 * (nsamples - 1) * deltaT
            # Mirror the autocorrelation to obtain a two-sided sequence.
            acor = np.concatenate((np.conj(acor[:0:-1]), acor))
            series.data.data = single.snr * filter.exp_i(single.phase) * acor
            snr_series.append(series)

    # Ensure that all of the SNR time series have the same sample rate.
    # FIXME: for now, the Python wrapper expects all of the SNR time series to
    # also be the same length.
    deltaT = snr_series[0].deltaT
    sample_rate = 1 / deltaT
    if any(deltaT != series.deltaT for series in snr_series):
        raise ValueError(
            "BAYESTAR does not yet support SNR time series with mixed sample rates"
        )

    # Ensure that all of the SNR time series have odd lengths.
    if any(len(series.data.data) % 2 == 0 for series in snr_series):
        raise ValueError("SNR time series must have odd lengths")

    # Trim time series to the desired length.
    max_abs_n = int(np.ceil(max_abs_t * sample_rate))
    desired_length = 2 * max_abs_n - 1
    for i, series in enumerate(snr_series):
        length = len(series.data.data)
        if length > desired_length:
            snr_series[i] = lal.CutCOMPLEX8TimeSeries(
                series, length // 2 + 1 - max_abs_n, desired_length
            )

    # FIXME: for now, the Python wrapper expects all of the SNR time series to
    # also be the same length.
    nsamples = len(snr_series[0].data.data)
    if any(nsamples != len(series.data.data) for series in snr_series):
        raise ValueError(
            "BAYESTAR does not yet support SNR time series of mixed lengths"
        )

    # Perform sanity checks that the middle sample of the SNR time series match
    # the sngl_inspiral records to the nearest sample (plus the smallest
    # representable LIGOTimeGPS difference of 1 nanosecond).
    for ifo, single, series in zip(ifos, singles, snr_series):
        shift = np.abs(
            0.5 * (nsamples - 1) * series.deltaT + float(series.epoch - single.time)
        )
        if shift >= deltaT + 1e-8:
            raise ValueError(
                "BAYESTAR expects the SNR time series to be "
                "centered on the single-detector trigger times, "
                "but {} was off by {} s".format(ifo, shift)
            )

    # Extract the TOAs in GPS nanoseconds from the SNR time series, assuming
    # that the trigger happened in the middle.
    toas_ns = [
        series.epoch.ns() + 1e9 * 0.5 * (len(series.data.data) - 1) * series.deltaT
        for series in snr_series
    ]

    # Collect all of the SNR series in one array.
    snr_series = np.vstack([series.data.data for series in snr_series])

    # Center times of arrival and compute GMST at mean arrival time.
    # Pre-center in integer nanoseconds to preserve precision of
    # initial datatype.
    epoch = sum(toas_ns) // len(toas_ns)
    toas = 1e-9 * (np.asarray(toas_ns) - epoch)
    mean_toa = np.average(toas, weights=weights)
    toas -= mean_toa
    epoch += int(np.round(1e9 * mean_toa))
    epoch = lal.LIGOTimeGPS(0, int(epoch))

    # Translate SNR time series back to time of first sample.
    toas -= 0.5 * (nsamples - 1) * deltaT

    # Convert complex SNRS to amplitude and phase
    snrs_abs = np.abs(snr_series)
    snrs_arg = filter.unwrap(np.angle(snr_series))
    snrs = np.stack((snrs_abs, snrs_arg), axis=-1)

    return epoch, sample_rate, toas, snrs, responses, locations, horizons


def condition_prior(
    horizons,
    min_distance=None,
    max_distance=None,
    prior_distance_power=None,
    cosmology=False,
):
    """Fill in defaults for the distance prior and validate them.

    Defaults: 0 Mpc minimum, SNR=4 horizon of the most sensitive detector
    as maximum, and p(r) ~ r^2 (uniform in volume). Raises `ValueError`
    for a negative power law that would be undefined at zero distance.
    """
    if cosmology:
        log.warning("Enabling cosmological prior. This feature is UNREVIEWED.")

    # If minimum distance is not specified, then default to 0 Mpc.
    min_distance = 0 if min_distance is None else min_distance

    # If maximum distance is not specified, then default to the SNR=4
    # horizon distance of the most sensitive detector.
    max_distance = max(horizons) / 4 if max_distance is None else max_distance

    # If prior_distance_power is not specified, then default to 2
    # (p(r) ~ r^2, uniform in volume).
    prior_distance_power = 2 if prior_distance_power is None else prior_distance_power

    # Raise an exception if 0 Mpc is the minimum effective distance and the
    # prior is of the form r**k for k<0
    if min_distance == 0 and prior_distance_power < 0:
        raise ValueError(
            f"Prior is a power law r^k with k={prior_distance_power}, "
            "undefined at min_distance=0"
        )

    return min_distance, max_distance, prior_distance_power, cosmology


def localize(
    event,
    waveform="o2-uberbank",
    f_low=30.0,
    min_distance=None,
    max_distance=None,
    prior_distance_power=None,
    cosmology=False,
    mcmc=False,
    chain_dump=None,
    enable_snr_series=True,
    f_high_truncate=0.95,
    rescale_loglikelihood=_RESCALE_LOGLIKELIHOOD,
    enable_jax=False,
):
    """Localize a compact binary signal using the BAYESTAR algorithm.

    Parameters
    ----------
    event : `ligo.skymap.io.events.Event`
        The event candidate.
    waveform : str, optional
        The name of the waveform approximant.
    f_low : float, optional
        The low frequency cutoff.
    min_distance, max_distance : float, optional
        The limits of integration over luminosity distance, in Mpc
        (default: determine automatically from detector sensitivity).
    prior_distance_power : int, optional
        The power of distance that appears in the prior
        (default: 2, uniform in volume).
    cosmology: bool, optional
        Set to enable a uniform in comoving volume prior (default: false).
    mcmc : bool, optional
        Set to use MCMC sampling rather than more accurate Gaussian
        quadrature.
    chain_dump : str, optional
        Save posterior samples to this filename if `mcmc` is set.
    enable_snr_series : bool, optional
        Set to False to disable SNR time series.
    f_high_truncate : float, optional
        Truncate the noise power spectral densities at this factor times the
        highest sampled frequency to suppress artifacts caused by incorrect
        PSD conditioning by some matched filter pipelines.
    rescale_loglikelihood : float, optional
        Factor by which the log likelihood is rescaled.
    enable_jax : bool, optional
        Set to use the JAX implementation of the sky map calculation.

    Returns
    -------
    skymap : `astropy.table.Table`
        A 3D sky map in multi-order HEALPix format.
    """

    # Hide event parameters, but show all other arguments
    def formatvalue(value):
        if isinstance(value, Event):
            return "=..."
        else:
            return "=" + repr(value)

    frame = inspect.currentframe()
    argstr = inspect.formatargvalues(
        *inspect.getargvalues(frame), formatvalue=formatvalue
    )

    stopwatch = Stopwatch()
    stopwatch.start()

    # Condition the event: extract SNRs, responses, locations, horizons.
    epoch, sample_rate, toas, snrs, responses, locations, horizons = condition(
        event,
        waveform=waveform,
        f_low=f_low,
        enable_snr_series=enable_snr_series,
        f_high_truncate=f_high_truncate,
    )

    # Fill in defaults for the distance prior.
    min_distance, max_distance, prior_distance_power, cosmology = condition_prior(
        horizons, min_distance, max_distance, prior_distance_power, cosmology
    )

    gmst = lal.GreenwichMeanSiderealTime(epoch)

    # Time and run sky localization.
    log.debug("starting computationally-intensive section")
    args = (
        min_distance,
        max_distance,
        prior_distance_power,
        cosmology,
        gmst,
        sample_rate,
        toas,
        snrs,
        responses,
        locations,
        horizons,
        rescale_loglikelihood,
    )

    if mcmc:
        max_abs_t = 2 * snrs.data.shape[1] / sample_rate
        skymap = localize_emcee(
            args=args,
            xmin=[0, -1, min_distance, -1, 0, 0],
            xmax=[2 * np.pi, 1, max_distance, 1, 2 * np.pi, 2 * max_abs_t],
            chain_dump=chain_dump,
        )
    elif enable_jax:
        # Import JAX localize
        from ..jaxcore.local import _MAX_NIFOS, _MAX_NSAMPLES, bsm_jax

        # Add padding to use the compiled function
        nifos = len(toas)
        nsamples = snrs.shape[1]
        snrs_padded = np.zeros((_MAX_NIFOS, _MAX_NSAMPLES, 2), dtype=np.float32)
        snrs_padded[:nifos, :nsamples, :] = snrs
        epochs_padded = np.zeros((_MAX_NIFOS,), dtype=np.float32)
        epochs_padded[:nifos] = toas
        responses_padded = np.zeros((_MAX_NIFOS, 3, 3), dtype=np.float32)
        responses_padded[:nifos] = responses
        locations_padded = np.zeros((_MAX_NIFOS, 3), dtype=np.float32)
        locations_padded[:nifos] = locations
        horizons_padded = np.zeros((_MAX_NIFOS,), dtype=np.float32)
        horizons_padded[:nifos] = horizons

        # Perform sky map calculation
        skymap, log_bci, log_bsn = bsm_jax(
            min_distance,
            max_distance,
            prior_distance_power,
            cosmology,
            gmst,
            nifos,
            nsamples,
            sample_rate,
            epochs_padded,
            snrs_padded,
            responses_padded,
            locations_padded,
            horizons_padded,
            rescale_loglikelihood,
        )

        # Handle NumPy conversion from JAX
        skymap = np.asarray(skymap)
        log_bci, log_bsn = float(log_bci), float(log_bsn)

        # Repack the plain 2D array into the structured dtype that the
        # downstream distance-moment code expects.
        dtype = [
            ("UNIQ", "i8"),
            ("PROBDENSITY", "f8"),
            ("DISTMEAN", "f8"),
            ("DISTSTD", "f8"),
        ]
        structured = np.zeros(len(skymap), dtype=dtype)
        structured["UNIQ"] = skymap[:, 0].astype(np.int64)
        structured["PROBDENSITY"] = skymap[:, 1]
        structured["DISTMEAN"] = skymap[:, 2]
        structured["DISTSTD"] = skymap[:, 3]
        skymap = structured
        skymap = Table(skymap, copy=False)
        skymap.meta["log_bci"] = log_bci
        skymap.meta["log_bsn"] = log_bsn
    else:
        skymap, log_bci, log_bsn = core.toa_phoa_snr(*args)
        skymap = Table(skymap, copy=False)
        skymap.meta["log_bci"] = log_bci
        skymap.meta["log_bsn"] = log_bsn

    # Convert distance moments to parameters
    try:
        distmean = skymap.columns.pop("DISTMEAN")
        diststd = skymap.columns.pop("DISTSTD")
    except KeyError:
        distmean, diststd, _ = distance.parameters_to_moments(
            skymap["DISTMU"], skymap["DISTSIGMA"]
        )
    else:
        skymap["DISTMU"], skymap["DISTSIGMA"], skymap["DISTNORM"] = (
            distance.moments_to_parameters(distmean, diststd)
        )

    # Add marginal distance moments
    good = np.isfinite(distmean) & np.isfinite(diststd)
    prob = (moc.uniq2pixarea(skymap["UNIQ"]) * skymap["PROBDENSITY"])[good]
    distmean = distmean[good]
    diststd = diststd[good]
    rbar = (prob * distmean).sum()
    r2bar = (prob * (np.square(diststd) + np.square(distmean))).sum()
    skymap.meta["distmean"] = rbar
    skymap.meta["diststd"] = np.sqrt(r2bar - np.square(rbar))

    stopwatch.stop()
    end_time = lal.GPSTimeNow()
    log.info("finished computationally-intensive section in %s", stopwatch)

    # Fill in metadata and return.
    program, _ = os.path.splitext(os.path.basename(sys.argv[0]))
    skymap.meta.update(metadata_for_version_module(version))
    skymap.meta["creator"] = "BAYESTAR"
    skymap.meta["origin"] = "LIGO/Virgo/KAGRA"
    skymap.meta["gps_time"] = float(epoch)
    skymap.meta["runtime"] = stopwatch.real
    skymap.meta["instruments"] = {single.detector for single in event.singles}
    skymap.meta["gps_creation_time"] = end_time
    skymap.meta["history"] = [
        "",
        "Generated by calling the following Python function:",
        *wrap("{}.{}{}".format(__name__, frame.f_code.co_name, argstr), 72),
        "",
        "This was the command line that started the program:",
        *wrap(" ".join([program] + sys.argv[1:]), 72),
    ]

    return skymap
def rasterize(skymap, order=None):
    """Convert a multi-order HEALPix sky map to a flat, fixed-order map.

    Parameters
    ----------
    skymap : `astropy.table.Table`
        A multi-order HEALPix sky map with a ``UNIQ`` column.
    order : int, optional
        Target HEALPix resolution order. If given and coarser than the
        native order, distance layers are correctly marginalized during
        the downsampling.

    Returns
    -------
    skymap : `astropy.table.Table`
        A flattened sky map with ``PROBDENSITY`` renamed to per-pixel
        ``PROB``.
    """
    orig_order, _ = moc.uniq2nest(skymap["UNIQ"].max())

    # Determine whether we need to do nontrivial downsampling.
    downsampling = (
        order is not None
        and 0 <= order < orig_order
        and "DISTMU" in skymap.dtype.fields.keys()
    )

    # If we are downsampling, then convert from distance parameters to
    # distance moments times probability density so that the averaging
    # that is automatically done by moc.rasterize() correctly marginalizes
    # the moments.
    if downsampling:
        skymap = Table(skymap, copy=True, meta=skymap.meta)
        probdensity = skymap["PROBDENSITY"]
        distmu = skymap.columns.pop("DISTMU")
        distsigma = skymap.columns.pop("DISTSIGMA")

        bad = ~(np.isfinite(distmu) & np.isfinite(distsigma))
        distmean, diststd, _ = distance.parameters_to_moments(distmu, distsigma)
        distmean[bad] = np.nan
        diststd[bad] = np.nan
        skymap["DISTMEAN"] = probdensity * distmean
        skymap["DISTVAR"] = probdensity * (
            np.square(diststd) + np.square(distmean)
        )

    skymap = Table(moc.rasterize(skymap, order=order), meta=skymap.meta, copy=False)

    # If we are downsampling, then convert back to distance parameters.
    if downsampling:
        distmean = skymap.columns.pop("DISTMEAN") / skymap["PROBDENSITY"]
        diststd = np.sqrt(
            skymap.columns.pop("DISTVAR") / skymap["PROBDENSITY"]
            - np.square(distmean)
        )
        skymap["DISTMU"], skymap["DISTSIGMA"], skymap["DISTNORM"] = (
            distance.moments_to_parameters(distmean, diststd)
        )

    # Convert probability density per steradian to probability per pixel.
    skymap.rename_column("PROBDENSITY", "PROB")
    skymap["PROB"] *= 4 * np.pi / len(skymap)
    skymap["PROB"].unit = u.pixel**-1
    return skymap


def derasterize(skymap):
    """Convert a flat HEALPix sky map to multi-order format.

    Parameters
    ----------
    skymap : `astropy.table.Table`
        A flat sky map with a per-pixel ``PROB`` column.

    Returns
    -------
    skymap : `astropy.table.Table`
        A multi-order sky map with a leading ``UNIQ`` column and
        ``PROBDENSITY`` per steradian, sorted by ``UNIQ``.
    """
    # Convert probability per pixel back to probability density per
    # steradian.
    skymap.rename_column("PROB", "PROBDENSITY")
    skymap["PROBDENSITY"] *= len(skymap) / (4 * np.pi)
    skymap["PROBDENSITY"].unit = u.steradian**-1

    # Merge pixels into a multi-resolution tree and compute NUNIQ indices.
    nside, _, ipix, _, _, value = zip(*healpix_tree.reconstruct_nested(skymap))
    nside = np.asarray(nside)
    ipix = np.asarray(ipix)
    value = np.stack(value)
    uniq = 4 * np.square(nside) + ipix

    # Rebuild the table, carrying over the column units.
    old_units = [column.unit for column in skymap.columns.values()]
    skymap = Table(value, meta=skymap.meta, copy=False)
    for old_unit, column in zip(old_units, skymap.columns.values()):
        column.unit = old_unit
    skymap.add_column(Column(uniq, name="UNIQ"), 0)
    skymap.sort("UNIQ")
    return skymap


def test():
    """Run BAYESTAR C unit tests.

    Examples
    --------
    >>> test()
    0
    """
    return int(core.test())


# The decorator was only needed at import time; remove it from the module
# namespace.
del require_contiguous_aligned