#!/usr/bin/env python

import os
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import fitsio
import kcorrect.kcorrect
import kcorrect.response
import kcorrect.utils

# This code serves as an example of how the Kcorrect class can be
# subclassed and used for specific purposes.

# (1) Create the ResponseDict. It is a singleton: when the KcorrectGST
# object below is initialized, the responses it needs are loaded into
# this dictionary, and the rest of this script uses it to reach the
# response functions and their properties.
print("Instantiate responseDict", flush=True)
responseDict = kcorrect.response.ResponseDict()

# (2) Build a KcorrectGST object. KcorrectGST is a subclass of Kcorrect
# whose only job is to set the default responses; a subclass like it
# can carry whatever default settings you want, or overload methods
# when something more complicated is needed.
print("Initialize KcorrectGST object", flush=True)

k = kcorrect.kcorrect.KcorrectGST()

# The same thing can be done without a subclass, by handing the
# responses to Kcorrect directly, e.g.:
#
# k = kcorrect.kcorrect.Kcorrect(responses=['galex_FUV', 'galex_NUV',
#     'sdss_u0', 'sdss_g0', 'sdss_r0', 'sdss_i0', 'sdss_z0', 'twomass_J',
#     'twomass_H', 'twomass_Ks'])

# (3) Load the test catalog shipped with kcorrect; it holds one
# particular combination of GALEX, SDSS, and 2MASS photometry.
print("Reading in file", flush=True)

test_dir = os.path.join(kcorrect.KCORRECT_DIR, 'data', 'test')
gstfile = os.path.join(test_dir, 'gst_tests_small.fits')
gst = fitsio.read(gstfile)

# (4) Build the maggies (and their inverse variances) in all bands, on
# an AB system. These must end up Galactic-extinction corrected, which
# is done in step (4)(d).
print("Make maggies", flush=True)

# One row per catalog entry, one column per band. Everything starts at
# zero, so any band that is never filled in keeps ivar=0 (no weight).
maggies = np.zeros((len(gst), 10), dtype=np.float32)
ivar = np.zeros_like(maggies)

# (4)(a) GALEX maggies (columns 0 and 1). This follows instructions
# from the GALEX documentation for these particular quantities. Only
# good values are written; bad values are left with ivar=0 and a
# finite (zero) maggies value.

# Zero points, and the matching multiplicative flux -> maggies factors.
fuv_zp = 18.82
nuv_zp = 20.08
fuv_scale = 10.**(- 0.4 * fuv_zp)
nuv_scale = 10.**(- 0.4 * nuv_zp)

# FUV: a measurement is usable only if neither the flux nor its error
# is one of the flag values (-99 or 0).
fuv_flux = gst['FUV_FLUX_AUTO']
fuv_flux_err = gst['FUV_FLUXERR_AUTO']
ok_fuv = ((fuv_flux != -99.) & (fuv_flux != 0.)
          & (fuv_flux_err != -99.) & (fuv_flux_err != 0.))
maggies[ok_fuv, 0] = fuv_flux[ok_fuv] * fuv_scale
ivar[ok_fuv, 0] = 1. / (fuv_flux_err[ok_fuv] * fuv_scale)**2

# NUV: same selection and scaling, with the NUV zero point.
nuv_flux = gst['NUV_FLUX_AUTO']
nuv_flux_err = gst['NUV_FLUXERR_AUTO']
ok_nuv = ((nuv_flux != -99.) & (nuv_flux != 0.)
          & (nuv_flux_err != -99.) & (nuv_flux_err != 0.))
maggies[ok_nuv, 1] = nuv_flux[ok_nuv] * nuv_scale
ivar[ok_nuv, 1] = 1. / (nuv_flux_err[ok_nuv] * nuv_scale)**2

# (4)(b) SDSS maggies (columns 2-6). A kcorrect utility applies the
# slight corrections needed to put the SDSS fluxes on AB. The catalog
# quantities are in nanomaggies, so the fluxes are scaled by 1e-9 and
# the inverse variances by the inverse square of that, 1e+18.
nmgy_to_maggies = 1.e-9
nmgy_ivar_to_maggies_ivar = 1.e+18
sdss_maggies, sdss_ivar = kcorrect.utils.sdss_ab_correct(
    maggies=gst['MODELFLUX'], ivar=gst['MODELFLUX_IVAR'])
maggies[:, 2:7] = sdss_maggies * nmgy_to_maggies
ivar[:, 2:7] = sdss_ivar * nmgy_ivar_to_maggies_ivar

# (4)(c) 2MASS maggies on AB (columns 7-9). The Vega->AB offset comes
# from the responseDict entry for each filter. Entries with a zero
# magnitude or a non-positive magnitude error are skipped, so they
# keep ivar=0.
twomass_bands = ((7, 'J', 'twomass_J'),
                 (8, 'H', 'twomass_H'),
                 (9, 'K', 'twomass_Ks'))
for iband, prefix, response_name in twomass_bands:
    mag = gst[prefix + '_M_EXT']
    mag_err = gst[prefix + '_MSIG_EXT']
    good = (mag != 0.) & (mag_err > 0.)

    # Catalog magnitude plus the Vega->AB offset, then to linear maggies.
    ab_mag = mag[good] + responseDict[response_name].vega2ab
    maggies[good, iband] = 10.**(-0.4 * ab_mag)

    # A magnitude error sigma_m corresponds to a flux error of
    # flux * ln(10) / 2.5 * sigma_m; ivar is its inverse square.
    err_per_mag = maggies[good, iband] * np.log(10.) / 2.5
    ivar[good, iband] = 1. / (err_per_mag * mag_err[good])**2

# (4)(d) Correct for Galactic extinction. This is deliberately simple:
# it uses some (old-ish) extinction coefficients together with the SFD
# extinction that the SDSS catalog provides for the SDSS bands.
red_fac = np.array([8.29, 8.18, 5.155, 3.793, 2.751,
                    2.086, 1.479, 0.902, 0.576, 0.367], dtype=np.float32)
# Normalize to the coefficient of band index 4, so each entry is the
# extinction in that band relative to the reference band.
red_fac = red_fac / red_fac[4]

# Catalog extinction used as the reference value for every band
# (presumably the SDSS r-band entry, matching band index 4 -- confirm
# against the catalog's column ordering).
ref_extinction = gst['EXTINCTION'][:, 2]
for i, fac in enumerate(red_fac):
    # Brighten the flux by the extinction in magnitudes; the inverse
    # variance scales by the inverse square of the same factor.
    maggies[:, i] *= 10.**(0.4 * ref_extinction * fac)
    ivar[:, i] *= 10.**(- 0.8 * ref_extinction * fac)

# (4)(e) Impose a floor on the errors in each band with another
# kcorrect utility. This reflects our calibration uncertainties and
# our general unwillingness to take seriously photometric errors
# <~ 1% (for the purposes here!).
# Floors per band: 0.03 for the two GALEX bands, 0.02 for the five
# SDSS bands, 0.04 for the three 2MASS bands.
floor = np.array([0.03] * 2 + [0.02] * 5 + [0.04] * 3, dtype=np.float32)
ivar = kcorrect.utils.error_floor(floor=floor, maggies=maggies, ivar=ivar)

# (4)(f) Fix catastrophically bad cases after the fact; it is better
# to catch these earlier. The steps above should already guarantee
# nbad=0 here, so this is a safety net plus a diagnostic print.
# An entry is bad if either its flux or its inverse variance is
# non-finite (NaN or +/-inf), or if its inverse variance is negative.
bad = ~np.isfinite(maggies) | ~np.isfinite(ivar) | (ivar < 0)
ibad = np.where(bad)
print("nbad = {n} (should be 0)".format(n=len(ibad[0])), flush=True)
# Zero both arrays at the bad positions so those entries carry no
# weight in the fit.
maggies[ibad] = 0.
ivar[ibad] = 0.

# The same redshift column is passed to every call below.
redshift = gst['Z']

# (5) Fit the coefficients of the kcorrect templates. A couple of
# "NNLS quitting on iteration count" errors are reported here; they
# are explained in the fit_coeffs() docstring and should be harmless.
print("Calculate coefficients", flush=True)
coeffs = k.fit_coeffs(redshift=redshift, maggies=maggies, ivar=ivar)

# (6) Calculate the K-corrections from the coefficients.
print("Calculate K-corrections", flush=True)
kc = k.kcorrect(redshift=redshift, coeffs=coeffs)

# (7) Get the derived parameters and the absolute magnitudes for each
# galaxy. The derived quantities are just properties of the fit
# coefficients, whereas the absolute magnitudes are calculated from
# the apparent magnitudes using the K-corrections.
print("Calculate derived parameters", flush=True)
derived = k.derived(redshift=redshift, coeffs=coeffs)
absmag = k.absmag(redshift=redshift, coeffs=coeffs,
                  maggies=maggies, ivar=ivar)

# (8) Synthesize B and V magnitudes with a second Kcorrect object that
# takes SDSS g and r as input and outputs Bessell B and V. The same
# can be done directly with the Response objects and their "project"
# methods, if you calculate the spectrum from the template
# coefficients.
print("Calculate B and V related quantities", flush=True)
kbv = kcorrect.kcorrect.Kcorrect(responses=['sdss_g0', 'sdss_r0'],
                                 responses_out=['bessell_B', 'bessell_V'])
derived_bv = kbv.derived(redshift=redshift, coeffs=coeffs)
# Columns 3 and 4 of maggies/ivar are the SDSS g and r bands.
g_and_r = slice(3, 5)
absmag_bv = kbv.absmag(redshift=redshift, coeffs=coeffs,
                       maggies=maggies[:, g_and_r],
                       ivar=ivar[:, g_and_r])

# Finally a bunch of plots.

# Plot 1: K-correction versus redshift, one panel per band.
print("Make plot", flush=True)
pngfile = 'kcorrect-default-v4.gst.png'

# Default figure size for this and the following multi-panel figures.
matplotlib.rcParams['figure.figsize'] = [5., 7.]

# 5x2 grid of panels with a shared x-axis and no gaps between panels.
fig, ax = plt.subplots(5, 2, sharex=True,
                       gridspec_kw=dict(hspace=0, wspace=0))

# y-axis limits per band; the key order matches the column order of kc.
bandranges = {'F': [-0.5, 1.31],
              'N': [-0.5, 1.31],
              'u': [-0.2, 1.51],
              'g': [-0.1, 1.05],
              'r': [-0.12, 0.46],
              'i': [-0.29, 0.32],
              'z': [-0.18, 0.29],
              'J': [-0.18, 0.19],
              'H': [-0.18, 0.19],
              'K': [-0.58, 0.09]}

for i, (band, yrange) in enumerate(bandranges.items()):
    # Fill the grid row by row: band i goes to (i // 2, i % 2).
    row, col = divmod(i, 2)
    panel = ax[row, col]
    panel.scatter(gst['Z'], kc[:, i], s=1, alpha=0.1)
    panel.set_ylabel('$K_{' + band + '}$')
    panel.set_ylim(yrange)
    panel.set_xlim([- 0.01, 0.26])
    if col == 1:
        # Right-hand column: move label and ticks to the outer edge
        # so they do not collide with the left-hand panel.
        panel.yaxis.set_label_position('right')
        panel.yaxis.tick_right()

# Only the bottom row carries an x-axis label.
for bottom_panel in (ax[4, 0], ax[4, 1]):
    bottom_panel.set_xlabel('redshift')

plt.tight_layout()
plt.savefig(pngfile, dpi=150)

# Reconstruct the model maggies in each band from the fit coefficients.
print("Reconstruct maggies", flush=True)
rm = k.reconstruct(redshift=gst['Z'], coeffs=coeffs)

# Plot 2: observed-minus-model magnitude residual versus redshift,
# one panel per band.
print("Make mag residual plot", flush=True)
pngfile = 'kcorrect-default-v4.gst.deltam.png'

fig, ax = plt.subplots(5, 2, sharex=True,
                       gridspec_kw=dict(hspace=0, wspace=0))

# y-axis limits per band; the key order matches the column order of
# maggies and rm.
bandranges = {'F': [-1.3, 1.3],
              'N': [-1.1, 1.1],
              'u': [-0.4, 0.4],
              'g': [-0.2, 0.2],
              'r': [-0.15, 0.15],
              'i': [-0.15, 0.15],
              'z': [-0.15, 0.15],
              'J': [-0.5, 0.5],
              'H': [-0.5, 0.5],
              'K': [-0.5, 0.5]}

for i, (band, yrange) in enumerate(bandranges.items()):
    # A magnitude difference needs both fluxes to be positive.
    ok = (maggies[:, i] > 0.) & (rm[:, i] > 0.)
    flux_ratio = maggies[ok, i] / rm[ok, i]
    diff = - 2.5 * np.log10(flux_ratio)

    row, col = divmod(i, 2)
    panel = ax[row, col]
    panel.scatter(gst['Z'][ok], diff, s=1, alpha=0.1)
    panel.set_ylabel('$\\Delta m_{' + band + '}$')
    panel.set_ylim(yrange)
    panel.set_xlim([- 0.01, 0.26])
    if col == 1:
        # Right-hand column: label and ticks on the outer edge.
        panel.yaxis.set_label_position('right')
        panel.yaxis.tick_right()

# Only the bottom row carries an x-axis label.
for bottom_panel in (ax[4, 0], ax[4, 1]):
    bottom_panel.set_xlabel('redshift')

plt.tight_layout()
plt.savefig(pngfile, dpi=150)

# Plot 3: observed-minus-model color residual versus redshift, one
# panel per pair of adjacent bands.
print("Make color residual plot", flush=True)
pngfile = 'kcorrect-default-v4.gst.deltac.png'

fig, ax = plt.subplots(5, 2, sharex=True,
                       gridspec_kw=dict(hspace=0, wspace=0.))

# y-axis limits per color; entry i is the color between band columns
# i and i + 1.
bandranges = {'F-N': [-1.3, 1.3],
              'N-u': [-1.1, 1.1],
              'u-g': [-0.7, 0.7],
              'g-r': [-0.16, 0.16],
              'r-i': [-0.15, 0.15],
              'i-z': [-0.22, 0.22],
              'z-J': [-0.65, 0.65],
              'J-H': [-0.9, 0.9],
              'H-K': [-1.1, 1.1]}

for i, (band, yrange) in enumerate(bandranges.items()):
    # Nine colors in a ten-panel grid: shift every color by one slot so
    # the unused panel is the top-left one, which is removed below.
    row, col = divmod(i + 1, 2)
    blue, red = i, i + 1

    # A color needs positive observed and model fluxes in both bands.
    ok = ((maggies[:, blue] > 0.) & (maggies[:, red] > 0.) &
          (rm[:, blue] > 0.) & (rm[:, red] > 0.))
    observed_color = - 2.5 * np.log10(maggies[ok, blue] / maggies[ok, red])
    model_color = - 2.5 * np.log10(rm[ok, blue] / rm[ok, red])
    diff = observed_color - model_color

    panel = ax[row, col]
    panel.scatter(gst['Z'][ok], diff, s=1, alpha=0.1)
    panel.set_ylabel('$\\Delta (' + band + ')$')
    panel.set_ylim(yrange)
    panel.set_xlim([- 0.01, 0.26])
    if col == 1:
        # Right-hand column: label and ticks on the outer edge.
        panel.yaxis.set_label_position('right')
        panel.yaxis.tick_right()

# Only the bottom row carries an x-axis label.
for bottom_panel in (ax[4, 0], ax[4, 1]):
    bottom_panel.set_xlabel('redshift')
# Drop the empty top-left panel.
fig.delaxes(ax[0, 0])

plt.tight_layout()
plt.savefig(pngfile, dpi=150)

# Plot 4: V-band mass-to-light ratio versus B-V color, compared with
# the Bell & de Jong (2001) relation.
print("Make M/L versus B-V plot", flush=True)

pngfile = 'kcorrect-default-v4.gst.mtolv.png'

# Single square panel for this figure.
matplotlib.rcParams['figure.figsize'] = [5. , 5.]
fig, ax = plt.subplots()

# Note that to compare to Bell & De Jong (2001) you need to be in Vega!
# Column 1 of the B/V outputs is the V band; keep only galaxies with a
# positive M/L so the logarithm below is defined.
ok = derived_bv['mtol'][:, 1] > 0.
print(derived_bv['mtol'][:, 1])
# Rescale M/L by the flux factor that corresponds to the V-band
# Vega->AB magnitude offset.
ab2vega_factor = 10.**(0.4 * responseDict['bessell_V'].vega2ab)
log_mtolv = np.log10(derived_bv['mtol'][ok, 1] / ab2vega_factor)
# B-V from the absolute magnitudes, with each band's Vega->AB offset
# removed so that the color is on the Vega system.
bmv = absmag_bv[ok, 0] - absmag_bv[ok, 1]
bmv = bmv - responseDict['bessell_B'].vega2ab + responseDict['bessell_V'].vega2ab

ax.scatter(bmv, log_mtolv, s=1, alpha=0.2)

# Comparison from Bell & De Jong (2001) Table 1: a straight line in
# log10(M/L_V) versus B-V, drawn between B-V = -2 and +2 so that it
# spans the whole plotted range.
bmv_l = -2. + 4. * np.arange(2, dtype=np.float32)
mtolv_l = -0.734 + 1.404 * bmv_l
ax.plot(bmv_l, mtolv_l, linewidth=3, label='Bell & de Jong (2001)')

ax.set_xlim([0.1, 1.1])
ax.set_ylim([-0.4, 0.5])
ax.set_xlabel('$B-V$')
ax.set_ylabel('$\\log_{10} M / L_V$ (solar units)')
# The comparison line carries a label, which is only drawn once a
# legend is requested. A fixed location avoids the slow 'best' search
# over the many scatter points.
ax.legend(loc='upper left')

plt.tight_layout()
plt.savefig(pngfile, dpi=150)
