#!/usr/bin/env python

import os
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import fitsio
import kcorrect

# Band labels, in the column order this script uses when indexing the
# K-correction and flux arrays (column 0 is labelled u, column 4 is z).
SDSS_BANDS = ['u', 'g', 'r', 'i', 'z']

# x-axis window applied to every panel of every figure.
REDSHIFT_LIMITS = [- 0.01, 0.21]

# y-axis window used for all magnitude and color residual panels.
RESIDUAL_LIMITS = [-0.1, 0.1]


def _plot_vs_redshift(redshift, panels, pngfile):
    """Save a vertical stack of scatter panels that share a redshift axis.

    Parameters
    ----------
    redshift : array-like
        x values, shared by every panel.
    panels : sequence of (ylabel, values, ylim) tuples
        One entry per panel, top to bottom: the y-axis label, the y
        values to scatter against ``redshift``, and the y-axis limits.
    pngfile : str
        Path of the PNG file to write.
    """
    # squeeze=False always yields a 2-D array of Axes, so a single-panel
    # request is handled the same way as a multi-panel one.
    fig, grid = plt.subplots(len(panels), 1, sharex=True, squeeze=False,
                             gridspec_kw={'hspace': 0})
    axes = grid[:, 0]

    for axis, (ylabel, values, ylim) in zip(axes, panels):
        axis.scatter(redshift, values, s=1, alpha=0.5)
        axis.set_ylabel(ylabel)
        axis.set_ylim(ylim)
        axis.set_xlim(REDSHIFT_LIMITS)

    # Only the bottom panel carries the shared x label.
    axes[-1].set_xlabel('redshift')

    fig.savefig(pngfile, dpi=150)
    # Release the figure; otherwise pyplot keeps every figure created by
    # this script alive until the interpreter exits.
    plt.close(fig)


print("Initialize KcorrectSDSS object", flush=True)
k = kcorrect.kcorrect.KcorrectSDSS()

print("Reading in file", flush=True)
gstfile = os.path.join(kcorrect.KCORRECT_DIR,
                       'data', 'test', 'gst_tests_small.fits')
gst = fitsio.read(gstfile)
redshift = gst['Z']

print("Fit coefficients", flush=True)
# Fluxes are scaled up by 10**(0.4 * EXTINCTION); the inverse variance is
# scaled by the inverse square of that same factor, hence the -0.8.
# NOTE(review): presumably a Galactic extinction correction -- confirm.
maggies = gst['PETROFLUX'] * 10.**(0.4 * gst['EXTINCTION'])
ivar = gst['PETROFLUX_IVAR'] * 10.**(- 0.8 * gst['EXTINCTION'])
coeffs = k.fit_coeffs(maggies=maggies, ivar=ivar,
                      redshift=redshift)

print("Calculate K-corrections", flush=True)
kc = k.kcorrect(redshift=redshift, coeffs=coeffs)

print("Make plot", flush=True)
# Applies to every figure created below.
matplotlib.rcParams['figure.figsize'] = [5., 7.]

kc_limits = {'u': [-0.2, 1.01],
             'g': [-0.1, 0.85],
             'r': [-0.12, 0.36],
             'i': [-0.29, 0.29],
             'z': [-0.18, 0.29]}
_plot_vs_redshift(
    redshift,
    [('$K_{' + band + '}$', kc[:, i], kc_limits[band])
     for i, band in enumerate(SDSS_BANDS)],
    'kcorrect-default-v4.sdss.png')

print("Calculate K-corrections to 0.1", flush=True)
kc = k.kcorrect(redshift=redshift, coeffs=coeffs, band_shift=0.1)

print("Make plot", flush=True)
kc_shifted_limits = {'u': [-0.44, 0.46],
                     'g': [-0.42, 0.42],
                     'r': [-0.33, 0.12],
                     'i': [-0.24, 0.29],
                     'z': [-0.28, 0.06]}
_plot_vs_redshift(
    redshift,
    [('$K_{^{0.1}' + band + '}$', kc[:, i], kc_shifted_limits[band])
     for i, band in enumerate(SDSS_BANDS)],
    'kcorrect-default-v4.sdss.0.1.png')

print("Reconstruct maggies", flush=True)
rm = k.reconstruct(redshift=redshift, coeffs=coeffs)

# Observed fluxes passed through kcorrect's SDSS AB correction, compared
# against the reconstructed fluxes in both residual figures below.
om = kcorrect.utils.sdss_ab_correct(maggies=maggies)

print("Make mag residual plot", flush=True)
_plot_vs_redshift(
    redshift,
    [('$\\Delta m_{' + band + '}$',
      - 2.5 * np.log10(om[:, i] / rm[:, i]),
      RESIDUAL_LIMITS)
     for i, band in enumerate(SDSS_BANDS)],
    'kcorrect-default-v4.sdss.deltam.png')

print("Make color residual plot", flush=True)
color_panels = []
for i, (blue, red) in enumerate(zip(SDSS_BANDS[:-1], SDSS_BANDS[1:])):
    # Color from adjacent columns, for observed (oc) and reconstructed (rc)
    # fluxes; the panel shows observed minus reconstructed.
    oc = - 2.5 * np.log10(om[:, i] / om[:, i + 1])
    rc = - 2.5 * np.log10(rm[:, i] / rm[:, i + 1])
    color_panels.append(('$\\Delta (' + blue + '-' + red + ')$',
                         oc - rc,
                         RESIDUAL_LIMITS))
_plot_vs_redshift(redshift, color_panels,
                  'kcorrect-default-v4.sdss.deltac.png')
