"""Linear drift rate measurements using a 2D auto-correlation analysis
and Monte Carlo resampling.
"""
import copy
import os
import warnings
import matplotlib.pyplot as plt
import numpy as np
import scipy.optimize
import scipy.signal
from sarjana.signal import find_burst
from dfdt.ac_mc_drift import gauss_2d
def _clean_intensity(dedispersed_intensity):
    """Return a masked, per-channel mean-subtracted copy of a waterfall.

    The input array is never modified. Channel 0 is masked, samples more
    than 3 sigma from their channel mean are masked (set to NaN), and each
    channel's NaN-aware mean is then subtracted.
    """
    # work on a copy so the caller's array is not silently modified; NaN
    # masking also requires a floating-point dtype
    data = np.array(dedispersed_intensity, copy=True)
    if not np.issubdtype(data.dtype, np.floating):
        data = data.astype(float)

    # mask out top channel (if not already masked)
    data[0, ...] = np.nan

    # mask out all outliers 3 sigma away from the channel mean; both
    # thresholds use statistics computed before any outlier is removed
    channel_means = np.nanmean(data, axis=1)[:, np.newaxis]
    channel_stds = np.nanstd(data, axis=1)[:, np.newaxis]
    outliers = (data < channel_means - 3 * channel_stds) | (
        data > channel_means + 3 * channel_stds
    )
    data[outliers] = np.nan

    # subtract mean (can also try median)
    return data - np.nanmean(data, axis=1)[:, np.newaxis]


def _masked_autocorrelation(waterfall):
    """Return the full 2D autocorrelation of *waterfall*.

    The zero-lag row and column are set to NaN so that they are excluded
    from the subsequent fit.
    """
    ac2d = scipy.signal.correlate2d(
        waterfall, waterfall, mode="full", boundary="fill", fillvalue=0
    )
    # a 'full' autocorrelation has odd length along each axis, with zero
    # lag at the centre index
    ac2d[ac2d.shape[0] // 2, :] = np.nan
    ac2d[:, ac2d.shape[1] // 2] = np.nan
    return ac2d


def _lag_axis(length, step):
    """Return lag values for one axis of a 'full' correlation.

    Zero lag is placed at index ``length // 2`` -- the same index that
    ``_masked_autocorrelation`` masks -- so for odd *length* the values run
    symmetrically from ``-(length // 2) * step`` to ``+(length // 2) * step``.
    """
    return (np.arange(length) - length // 2) * step


def ac_mc_drift(
    dedispersed_intensity,
    dm_uncertainty,
    source,
    eventid,
    ds,
    sub_factor=64,
    dm_trials=10,
    mc_trials=10,
    detection_confidence=99.73,
    uncertainty_confidence=68.0,
    plot_result=True,
    plot_all=False,
    peak=None,
    width=None,
    fdir="./results/",
):
    """Measure a linear drift rate with a 2D auto-correlation method.

    The waterfall is cleaned (channel 0 and >3 sigma outliers masked,
    per-channel mean subtracted), subbanded, cut around the burst peak and
    auto-correlated. A 2D Gaussian (``gauss_2d``) is fitted to the
    normalised auto-correlation, and the drift rate is derived from the
    fitted rotation angle.

    Parameters
    ----------
    dedispersed_intensity : array_like
        Dedispersed 2D waterfall; axis 0 is frequency channel, axis 1 is
        time. It is not modified. The number of channels must be divisible
        by ``sub_factor``.
    dm_uncertainty : float
        Statistical uncertainty on DM, in pc cm-3. Currently unused.
    source : str
        Source name, used in progress messages.
    eventid : str
        CHIME/FRB event ID, used in progress messages.
    ds : :obj:DynamicSpectrum
        Object holding intensity data parameters; ``dt_s``, ``df_mhz`` and
        ``nchan`` are read.
    sub_factor : int
        Factor to subband intensity data by, 64 by default.
    dm_trials : int
        Number of DM trials. Currently unused.
    mc_trials : int
        Number of Monte Carlo trials. Currently unused.
    detection_confidence : float
        Detection confidence interval in percent. Currently unused.
    uncertainty_confidence : float
        Uncertainty confidence interval in percent. Currently unused.
    plot_result : bool
        Plot analysis results. Currently unused.
    plot_all : bool
        Plot all resampled 2D autocorrelations (for debugging). Currently
        unused.
    peak : int, optional
        Pulse peak position index. If ``peak`` or ``width`` is None, both
        are determined with ``find_burst``.
    width : int, optional
        Pulse width, as a factor of sampling time.
    fdir : str
        Output directory; created if it does not exist.

    Returns
    -------
    dfdt_data : float
        Linear drift rate from data, in MHz/ms; NaN if the 2D Gaussian fit
        fails.
    """
    os.makedirs(fdir, exist_ok=True)

    print("{} -- {} -- Analyzing..".format(source, eventid))

    intensity = _clean_intensity(dedispersed_intensity)

    # calculate drift rate from data at best known DM
    if peak is None or width is None:
        ts = np.nansum(intensity, axis=0)
        peak, width, _ = find_burst(ts, width_factor=4)
    peak = int(peak)

    window = 100
    # increase window for wide bursts
    while width > 0.5 * window:
        window += 100

    # subband: average each group of sub_factor adjacent channels
    sub = np.nanmean(
        intensity.reshape(-1, sub_factor, intensity.shape[1]), axis=1
    )
    median = np.nanmedian(sub)
    sub[(sub == 0.0) | np.isnan(sub)] = median

    # clamp at 0: a negative start index would wrap around to the end of
    # the array and yield an empty (or wrong) burst window
    start = max(peak - window // 2, 0)
    waterfall = sub[..., start : peak + window // 2].copy()

    # select noise before the burst; if there are not enough samples
    # before it, wrap around and take the missing part from the end
    noise_start = peak - 3 * window // 2
    noise_stop = peak - window // 2
    if noise_start < 0:
        noise_waterfall = np.roll(sub, -noise_start, axis=1)[
            ..., : noise_stop - noise_start
        ]
    else:
        noise_waterfall = sub[..., noise_start:noise_stop].copy()

    ac2d = _masked_autocorrelation(waterfall)
    scaling_factor = np.nanmax(ac2d)
    scaled_ac2d = ac2d / scaling_factor
    # the noise autocorrelation is normalised by the *signal* peak
    scaled_noise_ac2d = (
        _masked_autocorrelation(noise_waterfall) / scaling_factor
    )

    # lag axes: time lags in ms, frequency lags in MHz per subband
    dts = _lag_axis(ac2d.shape[1], ds.dt_s * 1e3)
    dfs = _lag_axis(
        ac2d.shape[0],
        ds.df_mhz * (ds.nchan / intensity.shape[0]) * sub_factor,
    )

    # construct data model
    valid = ~np.isnan(scaled_ac2d)
    x, y = np.meshgrid(dfs, dts, indexing="ij")
    p0 = (
        np.nanmax(scaled_ac2d),
        45,
        200,
        0.2,
        np.nanmedian(scaled_noise_ac2d),
    )
    try:
        p1, _ = scipy.optimize.curve_fit(
            gauss_2d, (x[valid], y[valid]), scaled_ac2d[valid], p0=p0
        )
    except Exception as err:  # best effort: report NaN instead of crashing
        warnings.warn(
            "{} -- {} -- 2D Gaussian fit failed ({!r}); "
            "df/dt set to NaN".format(source, eventid, err),
            stacklevel=2,
        )
        dfdt_data = np.nan
    else:
        # p1[1], p1[2] are sigma_x, sigma_y and p1[-2] is theta (same
        # parameter order as p0)
        theta = p1[-2] % (2 * np.pi)
        # rotate theta for drift rate calculation by 90 deg
        if p1[2] > p1[1]:
            theta -= np.pi / 2.0
        # bring theta into (-pi, pi]
        if theta > np.pi:
            theta -= 2 * np.pi
        dfdt_data = 1.0 / np.tan(-theta)

    print("{} -- {} -- df/dt (data) = {:.2f} MHz/ms".format(
        source, eventid, dfdt_data))

    return dfdt_data