#!/usr/bin/env python

import os
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import fitsio
import kcorrect.kcorrect
import kcorrect.response
import kcorrect.utils

# This code tests against the M/L figure from Blanton & Roweis (the
# "2007 paper" referred to further down).

# (1) We instantiate responseDict. When a Kcorrect object (such as the
# KcorrectSDSS instance created in step 2) is initialized, the
# responses it needs are loaded into this dictionary, which is a
# singleton. This gives us access to the response functions and their
# properties; later in the script it is used to look up the vega2ab
# offsets of the Bessell B and V bands.
print("Instantiate responseDict", flush=True)
responseDict = kcorrect.response.ResponseDict()

# (2) We create an instance of KcorrectSDSS, which is a subclass
# of Kcorrect. We set abcorrect to False, because we apply the AB
# correction ourselves below (with kcorrect.utils.sdss_ab_correct).
print("Initialize KcorrectSDSS object", flush=True)
k = kcorrect.kcorrect.KcorrectSDSS(abcorrect=False)

# (3) Load the test catalogs from the DR4 LSS samples that ship with
# kcorrect, and keep only galaxies at z < 0.2, which is the cut the
# 2007 paper used.
print("Reading in file", flush=True)

testdir = os.path.join(kcorrect.KCORRECT_DIR, 'data', 'test')
postfile = os.path.join(testdir, 'post_catalog.dr4bsafe1.fits')
imfile = os.path.join(testdir, 'object_sdss_imaging.fits')
post = fitsio.read(postfile)
im = fitsio.read(imfile)

# Take the first 10000 galaxies that pass the redshift cut. The indices
# are deliberately not shuffled, so the selection is the same on every
# run. OBJECT_POSITION then picks the imaging row of each kept galaxy.
ikeep = np.nonzero(post['Z'] < 0.2)[0][:10000]
post = post[ikeep]
im = im[post['OBJECT_POSITION']]

# (4)(a) Turn the catalog Petrosian fluxes into AB maggies. A kcorrect
# utility applies the slight SDSS-to-AB corrections; the catalog values
# are in nanomaggies, so fluxes scale by 1e-9 and inverse variances by
# 1e+18.
maggies, ivar = kcorrect.utils.sdss_ab_correct(
    maggies=im['PETROFLUX'], ivar=im['PETROFLUX_IVAR'])
maggies, ivar = maggies * 1.e-9, ivar * 1.e+18

# (4)(b) Correct for Galactic extinction. This is deliberately simple:
# it uses the (old-ish) SFD-based extinction values for the SDSS bands
# stored in the SDSS catalog. Brightening the flux by 10**(0.4 A)
# shrinks the inverse variance by the square of that factor.
extinction = im['EXTINCTION']
maggies = maggies * 10.**(0.4 * extinction)
ivar = ivar * 10.**(- 0.8 * extinction)

# (4)(c) Impose a per-band floor on the errors with another kcorrect
# utility, reflecting our calibration uncertainties and our general
# unwillingness to take seriously photometric errors <~ 1% (for the
# purposes here!). The floor array has one entry per band.
floor = np.array([0.05, 0.02, 0.02, 0.02, 0.03], dtype=np.float32)
ivar = kcorrect.utils.error_floor(floor=floor, maggies=maggies, ivar=ivar)

# (4)(d) Zero out catastrophically bad entries after the fact; it is
# better to catch these earlier, and the steps above should already
# guarantee nbad=0. An entry is bad when its flux or inverse variance
# is non-finite, or its inverse variance is negative.
bad = ~(np.isfinite(maggies) & np.isfinite(ivar)) | (ivar < 0)
ibad = np.nonzero(bad)
print("nbad = {n} (should be 0)".format(n=ibad[0].size), flush=True)
maggies[ibad] = 0.
ivar[ibad] = 0.

# (5) Fit the kcorrect template coefficients to the SDSS photometry.
# Expect a couple of "NNLS quitting on iteration count" messages here;
# they are explained in the fit_coeffs() docstring and should be
# harmless.
print("Calculate coefficients", flush=True)
redshift = post['Z']
coeffs = k.fit_coeffs(maggies=maggies, ivar=ivar, redshift=redshift)

# (6) Synthesize B and V magnitudes with a second Kcorrect object whose
# input bands are SDSS g, r and whose output bands are Bessell B, V.
# The same thing can be done directly with the Response objects and
# their "project" methods, once the spectrum has been computed from the
# template coefficients.
print("Calculate B and V related quantities", flush=True)
kbv = kcorrect.kcorrect.Kcorrect(responses=['sdss_g0', 'sdss_r0'],
                                 responses_out=['bessell_B', 'bessell_V'])

# Columns 1 and 2 of the flux arrays are passed as the two input bands
# declared for kbv.
gr = slice(1, 3)
derived_bv = kbv.derived(redshift=redshift, coeffs=coeffs)
absmag_bv = kbv.absmag(redshift=redshift, coeffs=coeffs,
                       maggies=maggies[:, gr], ivar=ivar[:, gr])
kc = kbv.kcorrect(redshift=redshift, coeffs=coeffs)

# Print the first ten rows of the inputs and results as a quick sanity
# check: redshift, raw fluxes and inverse variances, M/L in the second
# output band, template coefficients and K-corrections.
head = slice(0, 10)
print(redshift[head])
print(im['PETROFLUX'][head])
print(im['PETROFLUX_IVAR'][head])
print(derived_bv['mtol'][head, 1])
print(coeffs[head, :])
print(kc[head, :])

# Finally a bunch of plots. First: log10(M/L_V) versus B-V, compared
# with the linear relation of Bell & de Jong (2001).
print("Make M/L versus B-V plot", flush=True)

pngfile = 'kcorrect-default-v4.lss.mtolv.png'

# Set the default figure size globally (not just for this figure), so
# every figure created afterwards in this script is 5x5 inches too.
matplotlib.rcParams['figure.figsize'] = [5., 5.]
fig, ax = plt.subplots()

# Note that to compare to Bell & De Jong (2001) you need to be in Vega!
# Only galaxies with a positive V-band M/L can go on a log axis; the
# "ok" mask and the Vega "bmv" colors are reused by the mass plot.
ok = derived_bv['mtol'][:, 1] > 0.
vega2ab_b = responseDict['bessell_B'].vega2ab
vega2ab_v = responseDict['bessell_V'].vega2ab
ab2vega_factor = 10.**(0.4 * vega2ab_v)
log_mtolv = np.log10(derived_bv['mtol'][ok, 1] / ab2vega_factor)
bmv = absmag_bv[ok, 0] - absmag_bv[ok, 1]
bmv = bmv - vega2ab_b + vega2ab_v

ax.scatter(bmv, log_mtolv, s=1, alpha=0.2)

# Comparison from Bell & De Jong (2001) Table 1: a straight line, so
# two end points well outside the plotted range are enough.
bmv_l = np.array([-2., 2.], dtype=np.float32)
mtolv_l = -0.734 + 1.404 * bmv_l
ax.plot(bmv_l, mtolv_l, linewidth=3, label='Bell & de Jong (2001)',
        color='black')

ax.set_xlim([0.1, 1.1])
ax.set_ylim([-0.4, 0.5])
ax.set_xlabel('$B-V$')
ax.set_ylabel('$\\log_{10} M / L_V$ (solar units)')

fig.tight_layout()
fig.savefig(pngfile, dpi=150)
# Close rather than clear: plt.clf() would leave this figure open and
# registered with pyplot while the script creates the next one.
plt.close(fig)

# Second plot: log10 of the mass in derived_bv['mremain'] (labelled as
# stellar mass) versus V-band absolute magnitude, with the points
# colored by B-V.
pngfile = 'kcorrect-default-v4.lss.masses.png'

fig, ax = plt.subplots()

# Reuses the "ok" mask and the Vega B-V colors ("bmv") computed for the
# M/L plot, so the arrays passed to scatter all have matching lengths.
log_stellar_mass = np.log10(derived_bv['mremain'][ok])
ax.scatter(absmag_bv[ok, 1], log_stellar_mass, s=1, c=bmv)
ax.set_xlabel('$M_V$')
ax.set_ylabel('$\\log_{10} M_\\ast$ (solar masses)')
# NOTE: no colorbar is drawn for the B-V color coding. Adding one needs
# the collection returned by ax.scatter as the mappable, i.e.
# fig.colorbar(<scatter result>, ax=ax); fig.colorbar(ax=ax) alone fails.

fig.tight_layout()
fig.savefig(pngfile, dpi=150)
# Close rather than clear, so the figure is released once it is saved.
plt.close(fig)
