from __future__ import print_function, division

import argparse
import os.path

from astropy.io import fits
import matplotlib
matplotlib.rcParams["font.family"] = "serif"
# The family is "serif", so the preferred face belongs in the serif list
# (setting "font.sans-serif" had no effect). Existing entries stay as
# fallbacks in case Times New Roman is not installed.
matplotlib.rcParams["font.serif"] = ["Times New Roman"] + list(matplotlib.rcParams["font.serif"])
# matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.patches import Arc
from numba import njit
import numpy as np
from scipy.interpolate import interp1d, RegularGridInterpolator
from scipy.optimize import curve_fit

from astrobits.correlation import radialautocorrelation


@njit()
def radialaverage(data, rs, bins):
    """Bin ``data`` by radius and return the per-bin mean, spread and count.

    Parameters
    ----------
    data : ndarray
        Values to average; flattened before use. Non-finite values are skipped.
    rs : ndarray
        Radius of each element of ``data``; flattened, so it must have the same
        number of elements as ``data``.
    bins : 1D ndarray
        Bin edges as for ``np.digitize`` (assumed increasing). Radii below
        ``bins[0]`` or at/above ``bins[-1]`` fall outside every bin and are
        ignored.

    Returns
    -------
    mus : ndarray of length ``len(bins) - 1``
        Mean of the finite values in each bin; NaN for empty bins.
    sigmas : ndarray of length ``len(bins) - 1``
        Population standard deviation in each bin; NaN for empty bins.
    Ns : ndarray of length ``len(bins) - 1``
        Number of finite values that fell in each bin (as floats).
    """
    data = data.flatten()
    rs = rs.flatten()

    # np.digitize returns 0 below bins[0] and len(bins) at/above bins[-1];
    # both indices lie outside every bin and are skipped below.
    idxs = np.digitize(rs, bins)

    nbins = len(bins) - 1
    sums = np.zeros(nbins)
    sums2 = np.zeros(nbins)
    Ns = np.zeros(nbins)
    for i, d in zip(idxs, data):
        if not np.isfinite(d) or i == 0 or i == len(bins):
            continue

        sums[i - 1] += d
        sums2[i - 1] += d**2
        Ns[i - 1] += 1

    # Empty bins get NaN explicitly rather than relying on 0/0 semantics.
    # The variance E[d^2] - E[d]^2 can round to a tiny negative number
    # (e.g. for a bin of identical values), which would make the standard
    # deviation NaN, so it is clamped at zero.
    mus = np.full(nbins, np.nan)
    sigmas = np.full(nbins, np.nan)
    for j in range(nbins):
        if Ns[j] > 0:
            mu = sums[j] / Ns[j]
            mus[j] = mu
            sigmas[j] = np.sqrt(max(sums2[j] / Ns[j] - mu**2, 0.0))

    return mus, sigmas, Ns


# Command-line interface: the stacked FITS image to analyse, the unit of its
# pixel values (selects the plotting scale factor and axis label), and a flag
# that skips the 2x upscaling step.
parser = argparse.ArgumentParser()
for flags, options in (
    (("image",), {}),
    (("--unit",), {"choices": ["Jy/beam", "K", "counts"], "required": True}),
    (("--noupscale",), {"action": "store_true"}),
):
    parser.add_argument(*flags, **options)
args = parser.parse_args()

# Output files are named after the input image, minus directory and extension.
imagename = os.path.basename(args.image)
prefix = os.path.splitext(imagename)[0]
print("Prefix:", prefix)

# Load the stacked image from the primary HDU. The pixel data are copied into
# a native-endian float64 array while the file is open, and the file is
# closed afterwards (it was previously left open); the header object remains
# usable after closing.
with fits.open(args.image) as hdul:
    hdu = hdul[0]
    if hdu.data is None:
        raise ValueError("Primary HDU of %s contains no image data" % args.image)
    window = np.array(hdu.data, dtype=np.float64)
window = window.T  # Convert to [x, y] ordering, and native numpy array

# Unit: rescale pixel values into the unit used for plotting and record the
# matching axis label. Raw strings keep LaTeX backslashes literal ("\m" is an
# invalid escape sequence that newer Pythons warn about).
if args.unit == "Jy/beam":
    window *= 1E6  # -> uJy / beam
    ppunit = r"$\mu$Jy beam$^{-1}$"
elif args.unit == "K":
    window *= 1E3  # -> mK
    ppunit = "mK"
elif args.unit == "counts":
    # Counts are plotted as-is; no rescaling needed.
    ppunit = "counts ks$^{-1}$ arcmin$^{-2}$"

if not args.noupscale:
    print("Upscaling...")
    # Resample onto a grid with 0.5 px spacing spanning the original pixel
    # centres, by linear interpolation. indexing="ij" keeps the query grid in
    # the same [axis 0, axis 1] order as ``window``. The previous default
    # 'xy' meshgrid paired the axis-1 range with interpolator axis 0, which
    # only works for square images (identical result for square input).
    fine0 = np.arange(0, window.shape[0] - 1 + 0.1, 0.5)
    fine1 = np.arange(0, window.shape[1] - 1 + 0.1, 0.5)
    grid0, grid1 = np.meshgrid(fine0, fine1, indexing="ij")
    interpolator = RegularGridInterpolator(
        (np.arange(window.shape[0]), np.arange(window.shape[1])), window
    )
    points = np.stack([grid0.ravel(), grid1.ravel()], axis=-1)
    window = interpolator(points).reshape(grid0.shape)
    # Peak positions are pixel indices, so they double with the grid density.
    hdu.header["PEAK1"] *= 2
    hdu.header["PEAK2"] *= 2
    print("Done")

# Centre of the pair in [x, y] = [axis 0, axis 1] pixel coordinates, and the
# half-separation of the two peaks in pixels, which is the unit ("normalised
# distance") used for all offsets below. xcenter indexes axis 0, so it is
# derived from shape[0] (identical to the previous shape[1] for square images).
# NOTE(review): the +1 on xcenter is kept from the original; presumably it
# matches how the stack was centred — confirm.
xcenter, ycenter = window.shape[0] // 2 + 1, window.shape[1] // 2
scalepx = (hdu.header["PEAK2"] - hdu.header["PEAK1"]) / 2

# Per-pixel index grids with the same shape as ``window``: xs is the axis-0
# index and ys the axis-1 index. indexing="ij" is required for that; the
# previous default 'xy' meshgrid produced transposed grids that only matched
# ``window`` for square images (values are identical in the square case).
xs, ys = np.meshgrid(range(window.shape[0]), range(window.shape[1]), indexing="ij")
rs = np.sqrt((xs - xcenter)**2 + (ys - ycenter)**2)

# Subtract out background: mean of all pixels more than two scale units from
# the centre, ignoring NaNs.
idx = rs > 2 * scalepx
bkg = np.nanmean(window[idx])
window -= bkg

# ypeak = int(hdu.header["PEAK1"])
# oneD = window[xcenter, ypeak - 200:ypeak + 201]
# plt.plot(oneD)
# plt.plot(oneD[::-1])
# plt.show()

# plt.plot(oneD - oneD[::-1])
# plt.show()
# exit()


# Model peaks
# Build an azimuthally symmetric model of each peak from its mean radial
# profile, measured only in the half-plane on the peak's outer side, and sum
# the two models into ``modelsum``. The model is evaluated at every pixel's
# radius, so it is applied on both sides of each peak.
# Radial bins 1 px wide, edges at half-integers, centred on integer radii
# 0, 1, 2, ... up to xcenter.
bins = np.arange(-0.5, xcenter, 1)
mids = 0.5 * (bins[:-1] + bins[1:])
modelsum = np.zeros_like(window)
# Marks pixels within 0.2 scale units of either peak (blanked later when the
# noise is estimated).
peakregions = np.zeros(window.shape, dtype=bool)

plt.figure("1D model")
plt.axhline(0)
# Both peaks sit on axis-0 index ``xcenter``; their axis-1 positions come
# from the PEAK1/PEAK2 header keywords.
for npeak, ypeak in enumerate([int(hdu.header["PEAK1"]), int(hdu.header["PEAK2"])]):
    print("Modelling peak:", xcenter, ypeak, "Value:", window[xcenter, ypeak] + bkg, "Value (- bkg):", window[xcenter, ypeak])

    # Brightest pixel in a 200 x 200 px box around the nominal peak position.
    # NOTE(review): the slice start goes negative (wraps or empties) if the
    # peak lies within 100 px of an edge.
    print("Actual peak value:", window[xcenter - 100:xcenter + 100, ypeak - 100:ypeak + 100].max() + bkg)

    # Radius and polar angle of every pixel relative to this peak
    # (xs / ys are the axis-0 / axis-1 index grids built earlier);
    # thetas = arctan2(axis-1 offset, axis-0 offset).
    rs = np.sqrt((xs - xcenter)**2 + (ys - ypeak)**2)
    thetas = np.arctan2(ys - ypeak, xs - xcenter)

    peakregions[rs < 0.2 * scalepx] = True

    # Select the half-plane on this peak's outer side: axis-1 index <= ypeak
    # for the first peak, >= ypeak for the second (boundary line included;
    # the +/-pi terms catch boundary pixels where arctan2 returns +/-pi).
    # Presumably assumes PEAK1 < PEAK2, i.e. scalepx > 0 — confirm.
    if npeak == 0:
        idx = np.any([thetas <= 0, thetas == np.pi], axis=0)
    elif npeak == 1:
        idx = np.any([thetas >= 0, thetas == -np.pi], axis=0)

    exterior = window[idx]
    exteriorrs = rs[idx]

    # window[idx] = 0
    # plt.imshow(window, origin="bottom")
    # plt.show()

    # Mean radial profile of the selected half-plane, one value per bin.
    mus, _, Ns = radialaverage(exterior, exteriorrs, bins)
    print(mus[:5])

    # Blank out mus where Ns suddenly begins to decrease
    # (from the first bin holding the maximum count onwards; presumably where
    # the annuli start running off the image edge).
    idx = np.argmax(Ns)
    mus[idx:] = np.nan

    plt.plot(mids, mus)
    # plt.show()

    # Interpolating model...
    # Linear interpolation over the finite bins only; radii outside that
    # range give NaN, which is then zeroed so the model vanishes there.
    interpmodel = interp1d(mids[np.isfinite(mus)], mus[np.isfinite(mus)], bounds_error=False, fill_value=np.nan)
    print("Interp model peak: ", interpmodel(0))

    model = interpmodel(rs)
    model[np.isnan(model)] = 0

    # FWHM in normalised units: sample the profile on [0, scalepx], find the
    # radius closest to half the r=0 value, express it as a fraction of
    # scalepx via the parallel linspace(0, 1), and double it.
    # NOTE(review): np.argmin returns the first NaN if the sampled profile
    # contains any, which would make this value meaningless.
    print("Peak ", npeak + 1, " FWHM: ", 2  * np.linspace(0, 1, 10000)[
        np.argmin(
            np.abs(
                interpmodel(np.linspace(0, scalepx, 10000)) - 0.5 * interpmodel(0)
            )
        )
    ])

    modelsum += model

    # # Output for stackpsfs.py
    # if npeak == 0:
    #     xs = np.linspace(0, 2 * scalepx, 570)
    #     model1d = interp1d(mids[np.isfinite(mus)], mus[np.isfinite(mus)], bounds_error=False, fill_value=np.nan)(xs)
    #     np.save("output/model1d-maxdist15Mpc.npy", model1d)
    #     exit()


# Residual after subtracting both peak models.
delta = window - modelsum

# Noise estimate: blank the peak cores (peakregions) and a +/-1 scale-unit
# square about the centre to NaN, crop to +/-2.5 scale units about the
# centre, and take the NaN-ignoring standard deviation.
deltablanked = delta.copy()
deltablanked[peakregions] = np.nan
deltablanked[xcenter - int(1 * scalepx):xcenter + int(1 * scalepx), ycenter - int(1 * scalepx):ycenter + int(1 * scalepx)] = np.nan
deltablanked = deltablanked[xcenter - int(2.5 * scalepx):xcenter + int(2.5 * scalepx), ycenter - int(2.5 * scalepx):ycenter + int(2.5 * scalepx)]
stderr = np.nanstd(deltablanked)
print("Estimated standard error:", stderr)

# Residual along the line through both peaks: axis-0 index ``xcenter``, the
# same slice that is saved below. Offsets are measured from ``ycenter``,
# the centre along that line. (This previously used row ``ycenter`` and
# measured from ``xcenter``, i.e. a different row and a shifted origin for
# square stacks.)
peakline = delta[xcenter, :]
print("Peak 1D value:", peakline.max(), "at x =", (peakline.argmax() - ycenter) / scalepx, " or normalised to the noise:", peakline.max() / stderr)

with open("output/" + prefix + "-noise.txt", "w") as f:
    print(stderr, file=f)

np.save("output/" + prefix + "-1Dwindow.npy", window[xcenter, :])
np.save("output/" + prefix + "-1Dmodelsum.npy", modelsum[xcenter, :])

# Plot noise characteristics: histogram of the blanked residual (top panel)
# and, further below, its radial autocorrelation (bottom panel).
plt.figure(figsize=(5, 5))
plt.subplot(2, 1, 1)
bins = np.linspace(-4 * stderr, 4 * stderr, 100)
# Values listed in the order the label names them (previously min came first).
print("Max and min values in residual:", np.nanmax(delta), np.nanmin(delta))

# Brightest residual pixel within +/-1.5 scale units of the centre, reported
# as an offset in normalised units.
# NOTE(review): here axis 0 is centred on ycenter and axis 1 on xcenter, the
# reverse of the [x, y] convention used elsewhere — confirm which is intended.
inner0 = int(ycenter - 1.5 * scalepx)
inner1 = int(xcenter - 1.5 * scalepx)
deltainner = delta[inner0:int(ycenter + 1.5 * scalepx), inner1:int(xcenter + 1.5 * scalepx)]
rmaxidx = np.array(np.unravel_index(deltainner.argmax(), deltainner.shape)) + [inner0, inner1]
rmax = delta[rmaxidx[0], rmaxidx[1]]
print("Max value of inner residuals is ", rmax, " (", rmax / stderr, " sigma) at ", (rmaxidx - [ycenter, xcenter]) / scalepx)
mids = (bins[:-1] + bins[1:]) / 2

# Only finite pixels enter the histogram (blanked regions are NaN).
ns, _, _ = plt.hist(deltablanked[np.isfinite(deltablanked)], bins=bins, density=True, color="dodgerblue", alpha=1, log=True)

def gaussian(xs, A, x0, sigma):
    """Evaluate an unnormalised Gaussian at ``xs`` (scalar or array).

    ``A`` is the peak height, ``x0`` the centre and ``sigma`` the standard
    deviation.
    """
    exponent = -(xs - x0)**2 / (2 * sigma**2)
    return A * np.exp(exponent)

# Fit a Gaussian to the normalised histogram, seeded with the tallest bin,
# zero mean and the estimated noise level, and overlay it dashed.
initialguess = (max(ns), 0, stderr)
popt, pcov = curve_fit(gaussian, mids, ns, p0=initialguess)
print(popt)
fittedcurve = gaussian(mids, *popt)
plt.plot(mids, fittedcurve, color="black", linestyle="dashed")
lowedge, highedge = min(mids), max(mids)
plt.xlim([lowedge, highedge])
plt.xlabel("Pixel value [%s]" % ppunit)
plt.ylabel("Normalised counts")
# Fixed floor for the log-scaled counts axis.
plt.ylim(bottom=1E-5)

# Calculate autocorrelation of the blanked residual. After mean subtraction
# the blanked (NaN) pixels are zeroed, and which pixels are valid is passed
# separately via the ``Ns`` weight map.
deltablanked -= np.nanmean(deltablanked)
sigma2 = np.nanstd(deltablanked)**2
Ns = np.isfinite(deltablanked).astype(float)
deltablanked[~np.isfinite(deltablanked)] = 0
bins = np.arange(-0.5, 400, 1)
mids = (bins[:-1] + bins[1:]) / 2
mus, _, _, _ = radialautocorrelation(deltablanked, bins, sigma2=sigma2, Ns=Ns)

# Half width at half maximum: the first lag where the autocorrelation
# (presumably normalised to 1 at zero lag via sigma2 — confirm) drops below
# 0.5, linearly interpolated between the two bracketing bins.
# interp1d(mus, mids) sorted the samples by ``mus``, so any later lag whose
# value happened to lie near 0.5 could corrupt the result, and NaNs broke it.
below = np.flatnonzero(mus < 0.5)
if below.size == 0 or below[0] == 0:
    raise ValueError("Autocorrelation does not cross 0.5 within the sampled lags")
ihalf = below[0]
# np.interp needs increasing xp: mus[ihalf] < 0.5 <= mus[ihalf - 1].
hwhm = np.interp(0.5, [mus[ihalf], mus[ihalf - 1]], [mids[ihalf], mids[ihalf - 1]])
fwhm = (hwhm * 2) / scalepx  # Convert to normalised units
fwhm = np.sqrt(fwhm**2 / 2)  # Gaussian convolved with a Gaussian results in exp(-r^2 / (2 * (sigma^2 + sigma^2)))
print("FWHM:", fwhm)

# Bottom panel: radial autocorrelation against lag in pixels, with ticks
# labelled in normalised units (every 0.05, limited to the plotted 0-400 px
# range) and dotted guides marking the half-maximum point.
plt.subplot(2, 1, 2)
plt.plot(mids, mus)
candidates = np.arange(0, 1.01, 0.05)
inrange = candidates * scalepx < 400
xlabels = candidates[inrange]
xvals = xlabels * scalepx
plt.xticks(xvals, ["%g" % label for label in xlabels])
plt.vlines(hwhm, 0, 0.5, colors="black", linestyles="dotted")
plt.hlines(0.5, 0, hwhm, colors="black", linestyles="dotted")
plt.xlim([0, 400])
plt.ylim([0, 1])
plt.xlabel("Radial offset [normalised distance]")
plt.ylabel("Autocorrelation")

plt.tight_layout()
plt.savefig("output/" + prefix + "-noise.pdf")
plt.savefig("output/" + prefix + "-noise.png")

# Three-panel example figure: stacked image (with dashed arcs about each
# peak), the two-peak model, and the residual. Each panel is zoomed to
# +/-2.5 normalised units with ticks labelled in those units. The first two
# panels share a colour scale; the residual spans +/-5 sigma.
modeltickpos = [ycenter + i * scalepx for i in range(-3, 4)]
modelzoom = [ycenter - 2.5 * scalepx, ycenter + 2.5 * scalepx]
modelvmax = np.nanpercentile(window, 99.9)
modelpanels = [(window, modelvmax), (modelsum, modelvmax), (delta, 5 * stderr)]

plt.figure(figsize=(14, 3.5))
for npanel, (image, vmax) in enumerate(modelpanels, start=1):
    panelax = plt.subplot(1, 3, npanel)
    plt.imshow(image, vmin=-5 * stderr, vmax=vmax, origin="lower", cmap="plasma")
    plt.xticks(modeltickpos, range(-3, 4))
    plt.yticks(modeltickpos, range(-3, 4))
    plt.grid()

    if npanel == 1:
        ax = panelax
        # Draw model arcs about each peak, over the outer half only (right of
        # PEAK2, left of PEAK1). Arc takes a width/height, so ``distance`` is
        # a diameter. angle/theta1/theta2 are passed by keyword: positional
        # use was deprecated in matplotlib 3.6 and is rejected by newer
        # releases. The step is clamped to >= 1 so range() cannot get step 0.
        arcstep = max(int(scalepx / 2), 1)
        for distance in range(arcstep, ycenter * 2, arcstep):
            ax.add_artist(Arc(
                (hdu.header["PEAK2"], ycenter), distance, distance,
                angle=0, theta1=-90, theta2=90,
                color="white", alpha=1, linestyle="dashed",
            ))
            ax.add_artist(Arc(
                (hdu.header["PEAK1"], ycenter), distance, distance,
                angle=0, theta1=90, theta2=270,
                color="white", alpha=1, linestyle="dashed",
            ))

    plt.xlim(modelzoom)
    plt.ylim(modelzoom)
    plt.colorbar(label="[%s]" % ppunit)

plt.savefig("output/" + prefix + "-examplemodel.pdf")
plt.savefig("output/" + prefix + "-examplemodel.png", transparent=True)

# Before/after figure in the original (background-included) units.
# Panel 1: stacked image on a 0.01-99.99 percentile stretch, outlining the
# band averaged for the transverse profile (orange) and the strip averaged
# for the longitudinal profile (limegreen). Panels 2-4: stacked image,
# model and residual on a shared background +/- 5 sigma stretch.
tickpos = [ycenter + i * scalepx for i in range(-3, 4)]
zoom = [ycenter - 2.5 * scalepx, ycenter + 2.5 * scalepx]
stacked = window + bkg
lo, hi = bkg - 5 * stderr, bkg + 5 * stderr
panels = (
    (stacked, np.nanpercentile(stacked, 0.01), np.nanpercentile(stacked, 99.99)),
    (stacked, lo, hi),
    (modelsum + bkg, lo, hi),
    (delta + bkg, lo, hi),
)

plt.figure(figsize=(9, 7))
for npanel, (image, vmin, vmax) in enumerate(panels, start=1):
    plt.subplot(2, 2, npanel)
    plt.imshow(image, vmin=vmin, vmax=vmax, origin="lower", cmap="plasma")
    plt.xticks(tickpos, range(-3, 4))
    plt.yticks(tickpos, range(-3, 4))
    if npanel == 1:
        plt.fill_between([xcenter - 3 * scalepx, xcenter + 3 * scalepx], ycenter - 0.2 * scalepx, ycenter + 0.2 * scalepx, facecolor="none", edgecolor="orange", linewidth=1.25, linestyle="dashed")
        plt.fill_betweenx([ycenter - 3 * scalepx, ycenter + 3 * scalepx], xcenter - 0.95 * scalepx, xcenter + 0.95 * scalepx, facecolor="none", edgecolor="limegreen", linewidth=1.25, linestyle="dashed")
    plt.xlim(zoom)
    plt.ylim(zoom)
    plt.colorbar(label="[%s]" % ppunit)
plt.savefig("output/" + prefix + "-beforeafter.pdf")
plt.savefig("output/" + prefix + "-beforeafter.png", transparent=True)
# plt.show()

def _set_profile_xaxis():
    """Label the x axis in normalised units about ycenter, show +/-3 units, add a grid."""
    plt.xticks([ycenter + i * scalepx for i in range(-3, 4)], range(-3, 4))
    plt.xlim([ycenter - 3 * scalepx, ycenter + 3 * scalepx])
    plt.grid()


def _draw_fwhm_bar():
    """Draw a scale bar ``fwhm`` normalised units long near the top-left of the current axes."""
    ymin, ymax = plt.ylim()
    left = ycenter - 2.9 * scalepx
    right = ycenter - (2.9 - fwhm) * scalepx
    plt.hlines(ymin + 0.9 * (ymax - ymin), left, right, color="black")
    plt.vlines(left, ymin + 0.87 * (ymax - ymin), ymin + 0.93 * (ymax - ymin), color="black")
    plt.vlines(right, ymin + 0.87 * (ymax - ymin), ymin + 0.93 * (ymax - ymin), color="black")


# Four stacked 1D profiles, each with a noise-FWHM scale bar:
#   1. stacked data and model along the peak line (axis-0 index xcenter),
#   2. residual along the peak line in units of the per-pixel noise,
#   3. residual averaged over a narrow band of axis-0 rows,
#   4. residual averaged over axis-1 columns within +/-0.95 units of xcenter.
# Axis labels use raw/escaped strings so the LaTeX backslashes are not
# invalid escape sequences (runtime text unchanged).
plt.figure(figsize=(9, 6))

plt.subplot(4, 1, 1)
plt.plot(window[xcenter, :] + bkg, color="dodgerblue", label="Stacked")
plt.plot(modelsum[xcenter, :] + bkg, color="red", label="Model")
_set_profile_xaxis()
plt.legend()
_draw_fwhm_bar()
plt.ylabel("Stacked\n[%s]" % ppunit)
plt.xlabel("x [normalised]")

plt.subplot(4, 1, 2)
plt.plot(delta[xcenter, :] / stderr, color="black")
_set_profile_xaxis()
_draw_fwhm_bar()
plt.ylabel(r"Residual $(\Delta / \sigma)$")
plt.xlabel("x [normalised]")

plt.subplot(4, 1, 3)
# Noise of a band-averaged profile: average bands of half-height
# width * scalepx placed 1-3 scale units either side of the centre, over
# +/-2 scale units along axis 1, and take the spread of all resulting values.
width = 0.2
offsets = list(range(int(-3 * scalepx), int(-1 * scalepx))) + list(range(int(1 * scalepx), int(3 * scalepx)))
noisesamples = [
    np.mean(delta[int(ycenter + yoffset - width * scalepx):int(ycenter + yoffset + width * scalepx), xcenter - int(2 * scalepx):xcenter + int(2 * scalepx)], axis=0)
    for yoffset in offsets
]
widenoise = np.std(noisesamples)
wideprofile = np.mean(delta[int(ycenter - width * scalepx):int(ycenter + width * scalepx), :], axis=0)
print("Widenoise:", widenoise)
# Positions along wideprofile are measured from ycenter, the origin of the
# plotted x axis and of the transverse-peak report below (the overall peak
# was previously measured from xcenter, inconsistently).
print("Wide profile peaks at:", wideprofile.max(), "at x =", (wideprofile.argmax() - ycenter) / scalepx, " or normalised to noise:", wideprofile.max() / widenoise)

# Strongest point within +/-1 scale unit of the centre.
tstart = int(ycenter - 1 * scalepx)
tmaxidx = tstart + wideprofile[tstart:int(ycenter + 1 * scalepx)].argmax()
print("Transverse profile peaks at:", wideprofile[tmaxidx], "at x =", (tmaxidx - ycenter) / scalepx, " or normalised to noise:", wideprofile[tmaxidx] / widenoise)

plt.plot(wideprofile / widenoise, color="orange")
_set_profile_xaxis()
_draw_fwhm_bar()
plt.ylabel("Transverse\nMean $(\\Delta / \\sigma)$")
plt.xlabel("x [normalised]")

plt.subplot(4, 1, 4)
# Mean over the columns between the peaks; noise from the same kind of mean
# taken over columns 1.1-3 scale units either side.
longitudinalmean = np.mean(delta[:, int(xcenter - 0.95 * scalepx):int(xcenter + 0.95 * scalepx)], axis=1)
longitudinalnoise = np.std(
    np.concatenate([
        np.mean(delta[:, int(xcenter - 3 * scalepx):int(xcenter - 1.1 * scalepx)], axis=1),
        np.mean(delta[:, int(xcenter + 1.1 * scalepx):int(xcenter + 3 * scalepx)], axis=1),
    ])
)
print("Longitudinalnoise:", longitudinalnoise)
lstart = int(ycenter - 3 * scalepx)
lmaxidx = lstart + longitudinalmean[lstart:int(ycenter + 3 * scalepx)].argmax()
print("Longitudinal profile peaks at:", longitudinalmean[lmaxidx], "at y =", (lmaxidx - ycenter) / scalepx, " or normalised to noise:", longitudinalmean[lmaxidx] / longitudinalnoise)

plt.plot(longitudinalmean / longitudinalnoise, color="limegreen")
_set_profile_xaxis()
_draw_fwhm_bar()
plt.ylabel("Longitudinal\nMean $(\\Delta / \\sigma)$")
plt.xlabel("y [normalised]")

plt.tight_layout()
plt.savefig("output/" + prefix + "-modelled.pdf")
plt.savefig("output/" + prefix + "-modelled.png", transparent=False)
plt.show()
