#!/usr/bin/env python

from __future__ import absolute_import, print_function, division
from six.moves import range

import os.path
import sys

import math
import numpy as np
import matplotlib.mlab as mlab
from pylab import figure, show, text, close

from MDPlus.analysis.pca import Pczfile

# Navigation state shared with the key-press handler:
#   i, j   - indices of the currently selected PCs
#   option - screen being shown (0 is the help text, 1-4 are the plots)
#   c      - mode flag (animation on the track plot, heat map on the 2D
#            histogram)
#   resol  - histogram resolution level
#   subset - index of the subset currently displayed
i, j = 0, 1
option = 1
c = 0
resol = 1
subset = 0

# Key bindings used for navigation (note: 'next' shadows the builtin of the
# same name for the rest of this module):
up, down = 'i', 'k'
left, right = 'j', 'l'
plus, minus = '+', '-'
next, prev = 'o', 'u'

# Text shown on the help screen.
helptext = (
    "PCZPLOT v0.2\n\n"
    "Press '1' to plot the projection of a PC vs. snapshot number.\n"
    "    - use 'i' and 'k' keys to choose PC to plot.\n\n"
    "Press '2' to plot the histogram of a PC's projection.\n"
    "    - use 'j' and 'l' keys to choose PC to plot.\n"
    "    - use '+' and '-' keys to change the resolution.\n\n"
    "Press '3' to plot the track of one PC against another.\n"
    "    - use the 'a' key to animate the plot.\n\n"
    "Press '4' to plot 2D histograms of PCs.\n"
    "    - use 'j' and 'l' keys to choose PC to plot on the x-axis.\n"
    "    - use 'i' and 'k' keys to choose PC to plot on the y-axis.\n"
    "    - use the 'c' key to swap between heat map and contours.\n"
    "    - use '+' and '-' keys to change the resolution.\n\n"
    "Use 'u' and 'o' keys to change subset on any screen.\n"
    "Press 'q' to quit, and 'h' for this help text again."
)

def incr(oldvalue, increment, minvalue=None, maxvalue=None, skipvalue=None):
    """Step a counter, clamping it to a range and hopping over one value.

    Parameters
    ----------
    oldvalue : number
        Current value of the counter.
    increment : number
        Amount to add (negative to step backwards).
    minvalue, maxvalue : number or None
        Inclusive bounds for the result; ``None`` means unbounded on that
        side.
    skipvalue : number or None
        A value the result must not take; ``None`` means nothing is skipped.

    Returns
    -------
    number
        ``oldvalue + increment``, moved one further step past ``skipvalue``
        if it landed on it, then clamped to ``[minvalue, maxvalue]``.  If
        clamping puts the result back onto ``skipvalue`` it is stepped back
        by ``increment``.
    """
    newvalue = oldvalue + increment
    if skipvalue is not None and newvalue == skipvalue:
        newvalue = newvalue + increment
    # Bounds are optional: comparing against None is a TypeError on Python 3
    # and silently wrong on Python 2, so only clamp when a bound was given.
    if maxvalue is not None and newvalue > maxvalue:
        newvalue = maxvalue
    if minvalue is not None and newvalue < minvalue:
        newvalue = minvalue
    # Clamping may have landed on the forbidden value; back off one step.
    if skipvalue is not None and newvalue == skipvalue:
        newvalue = newvalue - increment

    return newvalue

def _subset_slice(proj):
    """Return the part of ``proj`` that belongs to the current subset.

    The bounds are coerced to ``int`` because they may have been loaded as
    floats, which are not valid slice indices.
    """
    return proj[int(i1[subset]):int(i2[subset])]


def press(event):
    """Handle a key-press event and redraw the screen it selects.

    Keys '1'-'4' choose a plot type, 'h' shows the help text and 'q' closes
    all figures.  The remaining keys (see ``helptext``) change the selected
    PCs, the subset, the histogram resolution or the display mode of the
    current screen.

    The handler reads and updates the module-level navigation state
    (``i``, ``j``, ``c``, ``option``, ``resol``, ``subset``) and draws on the
    module-level ``fig``/``ax`` using ``pcz``, ``prange``, ``i1``, ``i2``,
    ``subid`` and ``nsub``.

    Parameters
    ----------
    event : matplotlib key event
        Only ``event.key`` is used.
    """
    global i, j, c, option, resol, subset
    if event.key == '1':
        option = 1
    elif event.key == '2':
        option = 2
        resol = 1
    elif event.key == '3':
        option = 3
        resol = 1
    elif event.key == '4':
        option = 4
    elif event.key == 'h':
        option = 0
    elif event.key == 'q':
        close('all')
        # Every figure has just been closed: there is nothing left to redraw.
        return

    if option == 0:
        # show help text
        ax.clear()
        text(0.5, 0.5, helptext, horizontalalignment='center',
             verticalalignment='center', transform=ax.transAxes)
        fig.canvas.draw()

    elif option == 1:
        # Make a PC vs. snapshot plot
        if event.key == down:
            i = incr(i, -1, 0, pcz.nvecs - 1)
        elif event.key == up:
            i = incr(i, 1, 0, pcz.nvecs - 1)
        elif event.key == prev:
            subset = incr(subset, -1, 0, nsub - 1)
        elif event.key == next:
            subset = incr(subset, 1, 0, nsub - 1)
        ax.clear()
        ax.set_aspect('auto')
        proj = pcz.projs[i]
        # y-limits come from the whole projection so that the scale does not
        # change when switching between subsets
        ax.set_ylim(proj.min(), proj.max())
        ax.set_title(os.path.basename(pcz.filename) + '\n' + 'Projection '
                     + str(i) + ', Subset ' + subid[subset])
        ax.set_xlabel('Snapshot')
        ax.set_ylabel('Proj ' + str(i))
        ax.plot(_subset_slice(proj))
        fig.canvas.draw()

    elif option == 4:
        # show a 2D PC vs. PC plot as heat map or contours
        if event.key == left:
            i = incr(i, -1, 0, pcz.nvecs - 1, j)
        elif event.key == right:
            i = incr(i, 1, 0, pcz.nvecs - 1, j)
        elif event.key == down:
            j = incr(j, -1, 0, pcz.nvecs - 1, i)
        elif event.key == up:
            j = incr(j, 1, 0, pcz.nvecs - 1, i)
        elif event.key == 'c':
            # toggle between contours (0) and heat map (1)
            c = c + 1
            if c > 1:
                c = 0
        elif event.key == plus:
            resol = incr(resol, 1, 1, 5)
        elif event.key == minus:
            resol = incr(resol, -1, 1, 5)
        elif event.key == prev:
            subset = incr(subset, -1, 0, nsub - 1)
        elif event.key == next:
            subset = incr(subset, 1, 0, nsub - 1)

        ax.clear()
        proj1 = _subset_slice(pcz.projs[i])
        proj2 = _subset_slice(pcz.projs[j])
        hrange = [prange[j, :], prange[i, :]]
        nb = 15 + 5 * resol
        # PC j is histogrammed along the first axis of H (image rows, i.e.
        # the plot's y-axis) and PC i along the second (the x-axis).
        H, xedges, yedges = np.histogram2d(proj2, proj1, bins=(nb, nb),
                                           range=hrange)
        extent = [yedges[0], yedges[-1], xedges[0], xedges[-1]]
        ax.set_title(os.path.basename(pcz.filename) + '\n' +
                     'Proj ' + str(i) + ' vs. ' + str(j) + ' for subset '
                     + subid[subset])
        ax.set_xlabel('Proj ' + str(i))
        ax.set_ylabel('Proj ' + str(j))
        if c == 1:
            ax.imshow(H, extent=extent, interpolation='nearest',
                      origin='lower')
        else:
            ax.contour(H, 10, extent=extent, origin='lower')
        ax.set_aspect('auto')
        fig.canvas.draw()

    elif option == 2:
        # Make a histogram of a PC
        if event.key == left:
            i = incr(i, -1, 0, pcz.nvecs - 1)
        elif event.key == right:
            i = incr(i, 1, 0, pcz.nvecs - 1)
        elif event.key == minus:
            # step the resolution itself (not the PC index)
            resol = incr(resol, -1, 1, 5)
        elif event.key == plus:
            resol = incr(resol, 1, 1, 5)
        elif event.key == prev:
            subset = incr(subset, -1, 0, nsub - 1)
        elif event.key == next:
            subset = incr(subset, 1, 0, nsub - 1)
        ax.clear()
        ax.set_aspect('auto')
        ax.set_xlim(prange[i, 0], prange[i, 1])
        proj = _subset_slice(pcz.projs[i])
        ax.set_title(os.path.basename(pcz.filename) + '\n' +
                     'Histogram of projection ' + str(i))
        ax.set_xlabel('Proj' + str(i) + ', Subset ' + subid[subset])
        ax.set_ylabel('Frequency')
        # 'density' replaces the 'normed' keyword, which matplotlib removed.
        _, bins, _ = ax.hist(proj, bins=10 * 2 ** (resol - 1), density=True,
                             histtype='stepfilled')
        # Overlay a zero-mean Gaussian whose variance is the eigenvalue of
        # this PC (computed directly; mlab.normpdf no longer exists).
        sigma = math.sqrt(pcz.evals[i])
        mu = 0.0
        y = (np.exp(-0.5 * ((bins - mu) / sigma) ** 2)
             / (sigma * math.sqrt(2.0 * math.pi)))
        ax.plot(bins, y, 'r--', linewidth=1)
        ax.set_ylim(0, y.max() * 1.1)
        fig.canvas.draw()

    elif option == 3:
        # make a 2D plot of a PC vs. another
        # (choosing a different PC switches the animation off again)
        if event.key == left:
            c = 0
            i = incr(i, -1, 0, pcz.nvecs - 1, j)
        elif event.key == right:
            c = 0
            i = incr(i, 1, 0, pcz.nvecs - 1, j)
        elif event.key == down:
            c = 0
            j = incr(j, -1, 0, pcz.nvecs - 1, i)
        elif event.key == up:
            c = 0
            j = incr(j, 1, 0, pcz.nvecs - 1, i)
        elif event.key == prev:
            subset = incr(subset, -1, 0, nsub - 1)
        elif event.key == next:
            subset = incr(subset, 1, 0, nsub - 1)
        elif event.key == 'a':
            c = 1

        ax.clear()
        # i and j may coincide after visiting another screen; move j off i
        if j == i:
            j = j + 1
        if j >= pcz.nvecs:
            j = j - 2
        proj1 = _subset_slice(pcz.projs[i])
        proj2 = _subset_slice(pcz.projs[j])
        nf = pcz.nframes
        ax.set_title(os.path.basename(pcz.filename) + '\n' +
                     'Proj ' + str(i) + ' vs. ' + str(j) + ' for subset '
                     + subid[subset])
        ax.set_xlabel('Proj ' + str(i))
        ax.set_ylabel('Proj ' + str(j))
        line, = ax.plot(proj1, proj2)
        line.set_marker('*')
        # a marker stride of 0 is not a valid slice step, so never go below 1
        line.set_markevery(max(nf - 1, 1))
        if c == 0:
            fig.canvas.draw()
        else:
            # crude animation method, but good enough
            nfstep = nf // 200
            if nfstep < 1:
                nfstep = 1
            for k in np.arange(1, nf, nfstep):
                k = int(k)
                line.set_xdata(proj1[0:k])
                line.set_ydata(proj2[0:k])
                line.set_markevery(max(k - 1, 1))
                fig.canvas.draw()

# ---------------------------------------------------------------------------
# Script body: parse the command line, load the data and open the viewer.
# The names defined here (pcz, i1, i2, subid, nsub, prange, fig, ax) are
# module globals because the key-press handler reads them.
# ---------------------------------------------------------------------------
if len(sys.argv) < 2 or sys.argv[1] in ('-h', '--help'):
    print('usage: pyPczplot pczfile [indexfile]')
    sys.exit(1)

testfile = sys.argv[1]
pcz = Pczfile(testfile)
if len(sys.argv) > 2:
    # Optional index file, one subset per line: "<id> <first> <last>".
    # It is parsed in a single pass so that labels and bounds always stay
    # aligned; blank lines and '#' comments are ignored.
    subfile = sys.argv[2]
    subid = []
    i1 = []
    i2 = []
    with open(subfile, 'r') as f:
        for line in f:
            fields = line.split('#', 1)[0].split()
            if not fields:
                continue
            subid.append(fields[0])
            # bounds are kept as plain ints so they can be used as slice
            # indices; the last frame is inclusive in the file, hence + 1
            i1.append(int(float(fields[1])))
            i2.append(int(float(fields[2])) + 1)

    nsub = len(i1)
    if nsub == 0:
        print('error: no subsets found in ' + subfile)
        sys.exit(1)
else:
    # no index file: a single subset covering every frame
    i1 = [0]
    i2 = [pcz.nframes]
    subid = [str(0)]
    nsub = 1

evals = pcz.evals
# prange[k] holds the (min, max) of projection k over all frames; it is used
# to keep axis ranges fixed while browsing subsets.
prange = np.zeros((pcz.nvecs, 2))
for k in range(pcz.nvecs):
    proj = pcz.projs[k]
    prange[k, 0] = proj.min()
    prange[k, 1] = proj.max()

fig = figure()
ax = fig.add_subplot(111)
# begin by showing basic info over a plot of the eigenvalues
info = ('PCZPLOT v0.1\n\n'
        'File : '+pcz.filename+'\n'
        +str(pcz.natoms)+' atoms, '+str(pcz.nvecs)+ ' PCs, '
        +str(pcz.nframes)+' frames\n'
        +str(nsub)+' subsets\n\n'
        '(press \'h\' for help from any screen)')
text(0.5, 0.5, info, horizontalalignment='center',
     verticalalignment='center', transform=ax.transAxes)
ax.plot(evals)
ax.set_title(os.path.basename(pcz.filename) + '\n' + 'Eigenvalues')
ax.set_xlabel('PC')
ax.set_ylabel('Eigenvalue')
# from here on every key press is routed to press()
fig.canvas.mpl_connect('key_press_event', press)

show()
