import os
from pathlib import Path

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

import dfdt

from sarjana.handlers import ParquetWaterfall

# Waterfall array (.npy) for event 23891929, resolved relative to the
# directory named by the DATAPATH environment variable. If DATAPATH is not
# set, os.getenv returns None and Path() raises TypeError.
burstpath = Path(os.getenv('DATAPATH')) / '23891929_DM348.8_waterfall.npy'
"""Linear drift rate measurements using a 2D auto-correlation analysis
and Monte Carlo resampling.

"""

import copy
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import scipy.optimize
import scipy.signal

from sarjana.signal import find_burst
from dfdt.ac_mc_drift import gauss_2d

def ac_mc_drift(
    dedispersed_intensity,
    dm_uncertainty,
    source,
    eventid,
    ds,
    sub_factor=64,
    dm_trials=10,
    mc_trials=10,
    detection_confidence=99.73,
    uncertainty_confidence=68.0,
    plot_result=True,
    plot_all=False,
    peak=None,
    width=None,
    fdir="./results/",
):
    """Measure the linear drift rate of a burst with a 2D auto-correlation.

    The waterfall is cleaned (first channel and 3-sigma outliers masked,
    per-channel mean subtracted), subbanded, cut to a window around the
    burst, auto-correlated, and a 2D Gaussian (``gauss_2d``) is fitted to
    the normalised auto-correlation. The drift rate is derived from the
    fitted orientation angle.

    The input array is never modified; all masking is done on a private
    floating-point copy.

    Parameters
    ----------
    dedispersed_intensity : array_like
        Dedispersed 2D waterfall. Axis 0 is treated as the channel axis
        and axis 1 as the time axis.
    dm_uncertainty : float
        Statistical uncertainty on DM, in pc cm-3. Accepted for interface
        compatibility; not used by this implementation.
    source : str
        Source name, used in the progress messages.
    eventid : str
        Event ID, used in the progress messages.
    ds : :obj:DynamicSpectrum
        Object holding intensity data parameters. Only the attributes
        ``dt_s``, ``df_mhz`` and ``nchan`` are read here.
    sub_factor : int
        Factor to subband intensity data by, 64 by default. Must divide
        the number of channels exactly.
    dm_trials : int
        Accepted for interface compatibility; not used by this
        implementation.
    mc_trials : int
        Accepted for interface compatibility; not used by this
        implementation.
    detection_confidence : float
        Accepted for interface compatibility; not used by this
        implementation.
    uncertainty_confidence : float
        Accepted for interface compatibility; not used by this
        implementation.
    plot_result : bool
        Accepted for interface compatibility; not used by this
        implementation.
    plot_all : bool
        Accepted for interface compatibility (debugging option); not used
        by this implementation.
    peak : int, optional
        Pulse peak position index. None by default, in which case it is
        determined with ``find_burst``.
    width : int, optional
        Pulse width, as a factor of sampling time. None by default, in
        which case it is determined with ``find_burst``.
    fdir : str
        Output directory. It is created if it does not exist; nothing is
        written to it by this implementation.

    Returns
    -------
    dfdt_data : float
        Linear drift rate from data, in MHz/ms (the lag axes used for the
        fit are built in MHz and ms). NaN if the Gaussian fit failed.

    Raises
    ------
    ValueError
        If the waterfall is not 2D, if the number of channels is not a
        multiple of ``sub_factor``, or if the window around ``peak`` does
        not overlap the data.

    """
    os.makedirs(fdir, exist_ok=True)

    print("{} -- {} -- Analyzing..".format(source, eventid))

    # Work on a private copy so the caller's array is left untouched; the
    # copy must be able to hold NaN, so non-float input is promoted.
    intensity = np.array(dedispersed_intensity, copy=True, subok=True)
    if not np.issubdtype(intensity.dtype, np.inexact):
        intensity = intensity.astype(float)

    if intensity.ndim != 2:
        raise ValueError(
            "dedispersed_intensity must be 2D, got {} dimension(s)".format(
                intensity.ndim
            )
        )
    n_chan, n_samp = intensity.shape
    if sub_factor < 1 or n_chan % sub_factor != 0:
        raise ValueError(
            "number of channels ({}) is not a multiple of sub_factor "
            "({})".format(n_chan, sub_factor)
        )

    # mask out top channel (if not already masked)
    intensity[0, ...] = np.nan

    # mask out all outliers 3 sigma away from the channel mean; both
    # thresholds are derived from the same (pre-masking) statistics
    channel_means = np.nanmean(intensity, axis=1, keepdims=True)
    channel_stds = np.nanstd(intensity, axis=1, keepdims=True)
    outliers = (intensity < channel_means - 3 * channel_stds) | (
        intensity > channel_means + 3 * channel_stds
    )
    intensity[outliers] = np.nan

    # subtract mean (can also try median)
    intensity = intensity - np.nanmean(intensity, axis=1, keepdims=True)

    if peak is None or width is None:
        ts = np.nansum(intensity, axis=0)
        peak, width, _ = find_burst(ts, width_factor=4)

    window = 100

    # increase window for wide bursts
    while width > 0.5 * window:
        window += 100

    # subband, then replace empty/masked subband samples by the median
    sub = np.nanmean(intensity.reshape(-1, sub_factor, n_samp), axis=1)
    median = np.nanmedian(sub)
    sub[sub == 0.0] = median
    sub[np.isnan(sub)] = median

    # Clamp the window start at 0: a negative start index would wrap around
    # to the end of the array and select an empty or unrelated slice.
    half_window = window // 2
    burst_start = max(peak - half_window, 0)
    waterfall = sub[..., burst_start : peak + half_window].copy()
    if waterfall.shape[1] == 0:
        raise ValueError(
            "burst window around peak={} does not overlap the "
            "{} available time samples".format(peak, n_samp)
        )

    # select noise before (and after) the burst (if necessary)
    noise_window = (peak - 3 * window // 2, peak - half_window)

    if noise_window[0] < 0:
        # not enough samples before the burst: roll samples from the end of
        # the waterfall in front so the noise window starts at index 0
        difference = abs(noise_window[0])
        noise_window = (
            noise_window[0] + difference,
            noise_window[1] + difference,
        )
        noise_waterfall = np.roll(sub, difference, axis=1)[
            ..., noise_window[0] : noise_window[1]
        ].copy()
    else:
        noise_waterfall = sub[..., noise_window[0] : noise_window[1]].copy()

    ac2d = scipy.signal.correlate2d(
        waterfall, waterfall, mode="full", boundary="fill", fillvalue=0
    )

    # blank the zero-lag row and column before normalising and fitting
    ac2d[ac2d.shape[0] // 2, :] = np.nan
    ac2d[:, ac2d.shape[1] // 2] = np.nan

    scaling_factor = np.nanmax(ac2d)
    scaled_ac2d = ac2d / scaling_factor

    noise_ac2d = scipy.signal.correlate2d(
        noise_waterfall, noise_waterfall, mode="full", boundary="fill",
        fillvalue=0
    )

    noise_ac2d[noise_ac2d.shape[0] // 2, :] = np.nan
    noise_ac2d[:, noise_ac2d.shape[1] // 2] = np.nan

    # the noise auto-correlation is scaled by the burst's peak value so
    # both share the same normalisation
    scaled_noise_ac2d = noise_ac2d / scaling_factor

    # lag axes: time lags in ms, frequency lags in MHz
    dts = (
        np.arange(-ac2d.shape[1] / 2 + 1, ac2d.shape[1] / 2 + 1)
        * ds.dt_s * 1e3
    )
    dfs = (
        np.arange(-ac2d.shape[0] / 2 + 1, ac2d.shape[0] / 2 + 1)
        * ds.df_mhz
        * (ds.nchan / n_chan)
        * sub_factor
    )

    # construct data model
    nanmask = np.isnan(scaled_ac2d)
    x, y = np.meshgrid(dfs, dts, indexing="ij")

    # initial guess; from the way the result is read back below, the
    # parameter order is (amplitude, sigma_x, sigma_y, theta, offset)
    p0 = np.nanmax(scaled_ac2d), 45, 200, 0.2, np.nanmedian(scaled_noise_ac2d)
    try:
        p1, _ = scipy.optimize.curve_fit(
            gauss_2d, (x[~nanmask], y[~nanmask]),
            scaled_ac2d[~nanmask], p0=p0
        )
    except (RuntimeError, ValueError, TypeError):
        # the fit did not converge or the input was unusable (e.g. too few
        # finite points): report an unconstrained measurement
        dfdt_data = np.nan
    else:
        # wrap theta into [0, 2 pi)
        theta = p1[-2] % (2 * np.pi)

        # rotate theta for drift rate calculation by 90 deg
        sigma_x = p1[1]
        sigma_y = p1[2]
        if sigma_y > sigma_x:
            theta -= np.pi / 2.

        # map angles above pi onto the negative branch
        if theta > np.pi:
            theta -= 2 * np.pi
        dfdt_data = 1.0 / np.tan(-theta)

    print("{} -- {} -- df/dt (data) = {:.2f} MHz/ms".format(
        source, eventid, dfdt_data))

    return dfdt_data
# Parquet waterfall for FRB20181224E, resolved under $DATAPATH/raw/wfall.
newburstpath = Path(os.getenv('DATAPATH'), 'raw', 'wfall', 'FRB20181224E_waterfall.h5.parquet')

# NOTE(review): `dedispersed_intensity` is loaded from `burstpath` but is not
# passed to the drift-rate call below -- confirm whether it is still needed.
dedispersed_intensity = np.load(burstpath)
burst = ParquetWaterfall(newburstpath)
wfall = burst.wfall

# burst parameters
# NOTE(review): `source` and `eventid` are not used by the call below, which
# passes `burst.eventname` for both -- confirm.
dm_uncertainty = 0.2  # pc cm-3
source = "R3"
eventid = "23891929"

# instrument parameters
dt_s = 0.00098304
df_mhz = 0.0244140625
nchan = 16384
freq_bottom_mhz = 400.1953125
freq_top_mhz = 800.1953125

# `ds` is built from the hard-coded instrument parameters above; `DS` is built
# from the attributes of the loaded ParquetWaterfall. Only `DS` is passed on.
# NOTE(review): `DS` takes its channel count from `wfall.shape[1]`, whereas the
# ac_mc_drift defined in this file treats axis 0 as the channel axis -- confirm
# the axis order of `burst.wfall`.
ds = dfdt.DynamicSpectrum(dt_s, df_mhz, nchan, freq_bottom_mhz, freq_top_mhz)
DS = dfdt.DynamicSpectrum(burst.dt, np.diff(burst.plot_freq)[0], burst.wfall.shape[1], burst.plot_freq.min(), burst.plot_freq.max())
# NOTE(review): this calls `dfdt.ac_mc_drift` from the dfdt package, not the
# `ac_mc_drift` function defined in this file -- confirm which one is intended.
dfdt_data = dfdt.ac_mc_drift(
    wfall, dm_uncertainty, burst.eventname, burst.eventname, DS,
    dm_trials=100, mc_trials=100
)
# Output of the call above:
#   FRB20181224E -- FRB20181224E -- Analyzing..
#   ValueError: All-NaN slice encountered