# -*- coding: utf-8 -*-
"""
FDMNES class "classes_fdmnes.py"
 functions for generating fdmnes files

By Dan Porter, PhD
Diamond
2018

Version 0.9
Last updated: 17/04/18

Version History:
17/04/18 0.9    Program created

@author: DGPorter
"""

import os, re
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # 3D plotting

from Dans_Diffraction import functions_general as fg
from Dans_Diffraction import functions_crystallography as fc


class Fdmnes:
    """
    FDMNES Create files and run program

    Holds the calculation options (executable path, output location, cluster
    radius, absorption edge, absorbing element, azimuthal reference and the
    list of reflections) and writes an FDMNES input file for the crystal
    object given at construction.
    """

    def __init__(self, xtl):
        """
        initialise
        :param xtl: crystal object; methods here read xtl.name, xtl.Cell
                    (UV, UVstar, lp) and xtl.Structure.get()
        """
        self.xtl = xtl

        # Options
        self.exe_path = 'c:/Program Files/FDMNES/fdmnes_64.exe'
        self.output_path = 'c:/Program Files/FDMNES/sim/new'
        self.output_name = 'out'
        self.comment = ''
        self.radius = 4.0
        self.edge = 'K'
        self.absorber = None
        self.azi_ref = [1, 0, 0]
        self.hkl_reflections = [[1, 0, 0]]

    def setup(self, exe_path=None, output_path=None, output_name=None, comment=None,
              radius=None, edge=None, absorber=None, azi_ref=None, hkl_reflections=None):
        """
        Set FDMNES Parameters. Arguments left as None keep their current value.
        The resulting options are printed.
        :param exe_path: Location of FDMNES executable, e.g. 'c:/FDMNES/fdmnes_win64.exe'
        :param output_path: Directory in which the input file is written
        :param output_name: Name of FDMNES output files (written to the 'Filout' card)
        :param comment: A comment written in the header of the input file
        :param radius: cluster radius in Angstroms (written to the 'Radius' card)
        :param edge: absorption threshold, e.g. 'K' (written to the 'Edge' card)
        :param absorber: element symbol of the absorbing atom; None uses the
                         first element in sorted order
        :param azi_ref: [h,k,l] reference reflection defining zero azimuth
        :param hkl_reflections: [h,k,l] or list of [h,k,l] written to the 'rxs' card
        :return: None
        """
        if exe_path is not None:
            self.exe_path = exe_path

        if output_path is not None:
            self.output_path = output_path

        if output_name is not None:
            self.output_name = output_name

        if comment is not None:
            self.comment = comment

        if radius is not None:
            self.radius = radius

        if edge is not None:
            self.edge = edge

        if absorber is not None:
            self.absorber = absorber

        if azi_ref is not None:
            self.azi_ref = azi_ref

        if hkl_reflections is not None:
            self.hkl_reflections = np.asarray(hkl_reflections).reshape(-1, 3)

        print('FDMNES Options')
        for name in ('exe_path', 'output_path', 'output_name', 'comment', 'radius',
                     'edge', 'absorber', 'azi_ref', 'hkl_reflections'):
            print('  {:>16s}: {}'.format(name, getattr(self, name)))
        print('')

    def write_runfile(self):
        """
        Write a basic FDMNES input file for editing.

        The file 'FDMNES_<xtl.name>.txt' is written into self.output_path,
        which is created if it does not exist. The 'Filout' card points to
        self.output_path/self.output_name. Atoms of the absorbing element are
        listed first in the Crystal section.
        :return: str, path of the file written
        :raises ValueError: if the structure has no atoms or the absorber is
                            not one of its elements
        """
        # Crystal structure and lattice parameters
        uvw, atom_type, label, occupancy, uiso, mxmymz = self.xtl.Structure.get()
        uvw = np.asarray(uvw, dtype=float)
        atom_type = np.asarray(atom_type)
        a, b, c, alpha, beta, gamma = self.xtl.Cell.lp()

        # Unique element types (sorted), their Z and each atom's type index
        types, typ_idx = np.unique(atom_type, return_inverse=True)
        if len(types) == 0:
            raise ValueError('Structure contains no atoms')
        Z = fc.atom_properties(types, 'Z')

        absorber = types[0] if self.absorber is None else self.absorber
        absorber_idx = np.where(atom_type == absorber)[0]
        nonabsorber_idx = np.where(atom_type != absorber)[0]
        if len(absorber_idx) == 0:
            raise ValueError('Absorber {} is not in the structure: {}'.format(absorber, list(types)))

        azi_ref = np.asarray(self.azi_ref).reshape(3)
        sl_ar, fdm_ar = self._azimuth_vectors(azi_ref)
        hkl_list = np.asarray(self.hkl_reflections).reshape(-1, 3)

        # Output locations
        directory = self.output_path
        if not os.path.isdir(directory):
            os.makedirs(directory)
        fname = os.path.join(directory, 'FDMNES_' + self.xtl.name + '.txt')
        outfile = os.path.join(directory, self.output_name)

        # Top matter
        lines = [
            '! FDMNES indata file\n',
            '! {}\n'.format(self.xtl.name),
        ]
        if self.comment:
            lines.append('! {}\n'.format(self.comment))
        lines += [
            '! indata file generated by Dans_Diffraction.classes_fdmnes\n',
            '! By Dan Porter, PhD\n',
            '\n',
            ' Filout\n',
            '   {}\n\n'.format(outfile),
            '  Range                        ! Energy range of calculation (eV). Energy of photoelectron relative to Fermi level.\n',
            ' -19. 0.1 31. \n\n',
            ' Radius                       ! Radius of the cluster where final state calculation is performed\n',
            '   {:3.1f}                        ! For a good calculation, this radius must be increased up to 6 or 7 Angstroems\n\n'.format(self.radius),
            ' Edge                         ! Threshold type\n',
            '  {}\n\n'.format(self.edge),
            ' SCF                          ! Self consistent solution\n',
            ' Green                        ! Muffin tin potential - faster\n',
            '! Quadrupole                   ! Allows quadrupolar E1E2 terms\n',
            '! magnetism                    ! performs magnetic calculations\n',
            ' Density                      ! Outputs the density of states as _sd1.txt\n',
            ' Spherical                    ! Outputs the spherical tensors as _sph_.txt\n',
            ' energpho                     ! output the energies in real terms\n',
            ' Convolution                  ! Performs the convolution\n\n',
        ]

        # Azimuthal reference
        lines += [
            ' Zero_azim                    ! Define basis vector for zero psi angle\n',
            '  {:8.6f} {:8.6f} {:8.6f}  '.format(fdm_ar[0], fdm_ar[1], fdm_ar[2]),
            '! Same as I16, Reciprocal ({} {} {}) in units of real SL. '.format(azi_ref[0], azi_ref[1], azi_ref[2]),
            '(would be ({:.3g} {:.3g} {:.3g}))\n\n'.format(sl_ar[0], sl_ar[1], sl_ar[2]),
        ]

        # Reflections: sigma-sigma and sigma-pi for each hkl
        lines.append(' rxs                          ! Resonant x-ray scattering at various peaks, peak given by: h k l sigma pi azimuth.\n')
        for h, k, l in hkl_list:
            lines.append(' {0} {1} {2}    1 1                 ! ({0} {1} {2}) sigma-sigma\n'.format(h, k, l))
            lines.append(' {0} {1} {2}    1 2                 ! ({0} {1} {2}) sigma-pi\n'.format(h, k, l))
        lines.append(' \n')

        # Atom types
        lines.append(' Atom ! s=0,p=1,d=2,f=3, must be neutral, get d states right by moving e to 2s and 2p sites\n')
        for n in range(len(types)):
            lines.append(' {:3.0f} 0 ! {}\n'.format(Z[n], types[n]))
        lines.append(' \n')

        # Unit cell and atomic coordinates, absorber first
        lines += [
            ' Crystal                      ! Periodic material description (unit cell)\n',
            ' {:9.5f} {:9.5f} {:9.5f} {:9.5f} {:9.5f} {:9.5f}\n'.format(a, b, c, alpha, beta, gamma),
            '! Coordinates - 1st atom is the absorber\n',
        ]
        for n in np.concatenate([absorber_idx, nonabsorber_idx]):
            lines.append('{0:2.0f} {1:20.15f} {2:20.15f} {3:20.15f} ! {4:-3.0f} {5:2s}\n'.format(
                typ_idx[n] + 1, uvw[n, 0], uvw[n, 1], uvw[n, 2], n, str(atom_type[n])))
        lines.append('\n')

        # End matter
        lines.append(' End')

        # 'with' guarantees the file is closed even if writing fails
        with open(fname, 'w') as fdfile:
            fdfile.writelines(lines)
        print('FDMNES file written to {}'.format(fname))
        return fname

    def _azimuth_vectors(self, hkl):
        """
        Compute the azimuthal reference vectors for reflection hkl.
        :param hkl: (1*3) array [h,k,l]
        :return: (sl_ar, fdm_ar) - hkl expressed against UVstar (Q*/UV*), and
                 against UV (Q*/UV) normalised to unit length
        :raises ValueError: if hkl gives a zero-length vector
        """
        UV = self.xtl.Cell.UV()
        UVs = self.xtl.Cell.UVstar()

        qvec = np.dot(hkl, UVs)  # presumably the Cartesian reciprocal vector
        sl_ar = np.dot(qvec, np.linalg.inv(UVs))  # Q*/UV*
        fdm_ar = np.dot(qvec, np.linalg.inv(UV))  # Q*/UV
        length = np.sqrt(np.sum(fdm_ar ** 2))
        if length == 0:
            raise ValueError('Azimuthal reference hkl must be non-zero, got {}'.format(hkl))
        return sl_ar, fdm_ar / length  # normalise length to 1

    def azimuthal_reference(self, hkl=None):
        """
        Generate the azimuthal reference
        :param hkl: (1*3) array [h,k,l], default [1, 0, 0]
        :return: (1*3) array, hkl expressed in units of UV, normalised to length 1
        """
        if hkl is None:
            hkl = [1, 0, 0]
        return self._azimuth_vectors(hkl)[1]

    def run_fdmnes(self):
        """
        Run the fdmnes
        :return: None
        """
        # TODO: not implemented yet - intentionally a no-op
        pass


class FdmnesAnalysis:
    """
    Create fdmnes object from *_scan_conv.txt file

    Usage:
        fdm = FdmnesAnalysis('dir/out_scan_conv.txt')

        fdm contains all calculated reflections and xanes spectra
    """

    def __init__(self, filename='out_scan_conv.txt'):
        """
        Load an FDMNES azimuthal-scan calculation and its XANES spectrum.

        Each reflection found in the scan file is also attached as an
        attribute named after the reflection label with brackets removed,
        e.g. 'I(100)ss' -> fdm.I100ss.
        :param filename: path of the '<name>_scan_conv.txt' file; the matching
                         '<name>_conv.txt' is read from the same directory
        """
        energy, angle, intensity = read_scan_conv(filename)

        filename = filename.replace('\\', '/')  # convert windows directories
        dirname, filetitle = os.path.split(filename)  # calculation directory
        calcname = dirname.split('/')[-1]  # calculation name
        outname = filetitle[:-len('_scan_conv.txt')]  # strip '_scan_conv.txt'
        convname = outname + '_conv.txt'

        # os.path.join avoids building '/<name>_conv.txt' (the filesystem root)
        # when filename has no directory component
        enxanes, Ixanes = read_conv(os.path.join(dirname, convname))
        self.xanes = xanes(enxanes, Ixanes, calcname)

        self.energy = energy
        self.angle = angle
        self.reflections = intensity
        self.reflist = list(intensity.keys())

        for ref in self.reflist:
            refname = ref.replace('(', '').replace(')', '')
            refobj = reflection(self.energy, self.angle, self.reflections[ref], ref, calcname)
            setattr(self, refname, refobj)


class reflection:
    """
    Single calculated reflection, as built by FdmnesAnalysis.

    Stores the energy axis, the azimuthal angle axis and the 2D intensity
    map of one reflection, and offers cuts and plots of that map.
    """

    def __init__(self, energy, angle, intensity, refname, calcname):
        """
        :param energy: energy values of the calculation
        :param angle: azimuthal angle values of the calculation
        :param intensity: intensity map indexed [energy, angle]
        :param refname: reflection label, e.g. 'I(100)ss'
        :param calcname: calculation name used in plot titles
        """
        self.energy, self.angle, self.intensity = energy, angle, intensity
        self.refname, self.calcname = refname, calcname

    def azi_cut(self, cutenergy=None):
        """Return the azimuthal cut at cutenergy (see module function azi_cut)."""
        return azi_cut(self.energy, self.intensity, cutenergy)

    def eng_cut(self, cutangle=None):
        """Return the energy cut at cutangle (see module function eng_cut)."""
        return eng_cut(self.angle, self.intensity, cutangle)

    def plot3D(self):
        """Surface plot of intensity against azimuthal angle and energy."""
        figure = plt.figure(figsize=[12, 10])
        axes = figure.add_subplot(111, projection='3d')

        angle_grid, energy_grid = np.meshgrid(self.angle, self.energy)
        axes.plot_surface(angle_grid, energy_grid, self.intensity,
                          rstride=3, cstride=3, cmap=plt.cm.coolwarm,
                          linewidth=0, antialiased=False)

        for set_label, text in ((axes.set_xlabel, 'Angle (DEG)'),
                                (axes.set_ylabel, 'Energy (eV)'),
                                (axes.set_zlabel, 'Intensity')):
            set_label(text, fontsize=18)
        plt.suptitle('{}\n{}'.format(self.calcname, self.refname), fontsize=21, fontweight='bold')

    def plot_azi(self, cutenergy='max'):
        """Plot intensity against angle at the energy chosen by cutenergy."""
        values = azi_cut(self.energy, self.intensity, cutenergy)

        plt.figure(figsize=[12, 10])
        plt.plot(self.angle, values)
        plt.xlabel('Angle (DEG)', fontsize=18)
        plt.ylabel('Intensity', fontsize=18)
        heading = '{}\n{} {} eV'.format(self.calcname, self.refname, cutenergy)
        plt.title(heading, fontsize=21, fontweight='bold')

    def plot_eng(self, cutangle='max'):
        """Plot intensity against energy at the angle chosen by cutangle."""
        values = eng_cut(self.angle, self.intensity, cutangle)

        plt.figure(figsize=[12, 10])
        plt.plot(self.energy, values)
        plt.xlabel('Energy (eV)', fontsize=18)
        plt.ylabel('Intensity', fontsize=18)
        heading = '{}\n{} {} Deg'.format(self.calcname, self.refname, cutangle)
        plt.title(heading, fontsize=21, fontweight='bold')


class xanes():
    """
    XANES spectrum used in class FdmnesAnalysis, with a plotting function
    """

    def __init__(self, energy, intensity, calcname):
        """
        :param energy: energy values of the spectrum
        :param intensity: intensity values of the spectrum
        :param calcname: calculation name used as the plot title
        """
        self.energy = energy
        self.intensity = intensity
        # Previously a bare 'self.calcname' access, which raised AttributeError
        self.calcname = calcname

    def plot(self):
        """Plot intensity against energy."""
        plt.figure(figsize=[12, 10])
        plt.plot(self.energy, self.intensity, lw=3)
        plt.title(self.calcname, fontsize=26, fontweight='bold', fontname='Times New Roman')
        plt.xlabel('Energy [eV]', fontsize=28, fontname='Times New Roman')
        plt.ylabel('Intensity [arb. units]', fontsize=28, fontname='Times New Roman')
        plt.xticks(fontsize=25, fontname='Times New Roman')
        plt.yticks(fontsize=25, fontname='Times New Roman')


############## FUNCTIONS ########################

def read_conv(filename='out_conv.txt', plot=False):
    """
    Reads fdmnes output file out_conv.txt, that gives the XANES spectra
      energy,intensity = read_conv(filename)

    :param filename: path of the *_conv.txt file; the first line is skipped
                     as a header, column 0 is energy and column 1 intensity
    :param plot: if True, also plot the spectrum, titled with the name of the
                 directory containing the file
    :return: energy, intensity - 1D arrays
    """

    filename = filename.replace('\\', '/')  # convert windows directories
    dirname, filetitle = os.path.split(filename)  # calculation directory
    calcname = dirname.split('/')[-1]  # calculation name

    # ndmin=2 keeps a single-row file 2D so the column slicing below works
    data = np.loadtxt(filename, skiprows=1, ndmin=2)
    energy = data[:, 0]
    spectrum = data[:, 1]

    if plot:
        plt.figure(figsize=[12, 10])
        plt.plot(energy, spectrum, lw=3)
        plt.title(calcname, fontsize=26, fontweight='bold', fontname='Times New Roman')
        plt.xlabel('Energy [eV]', fontsize=28, fontname='Times New Roman')
        plt.ylabel('Intensity [arb. units]', fontsize=28, fontname='Times New Roman')
        plt.xticks(fontsize=25, fontname='Times New Roman')
        plt.yticks(fontsize=25, fontname='Times New Roman')
    return energy, spectrum


def read_scan_conv(filename='out_scan_conv.txt', nangle=180):
    """
    Read FDMNES _scan_conv.txt files, return simulated azimuthal and energy values
    energy,angle,intensity = read_scan_conv(filename)
        filename = directory and name of _scan_conv.txt file
        nangle = number of angle lines per reflection block (default 180)
        energy = [nx1] array of energy values
        angle = [mx1] array on angle values
        intensity = {'I(100)ss'}[nxm] dict of arrays of simulated intensities for each reflection

    You can see all the available reflections with intensity.keys()

    Expected layout, repeated for each energy: a separator line, an energy
    line, then for each reflection a label line followed by nangle lines of
    'angle intensity'.
    Raises ValueError if no reflection labels are found in the file.
    """

    with open(filename) as f:
        # Determine reflections in file
        filetext = f.read()
        reftext = re.findall(r'I\(\-?\d+-?\d+?-?\d+?\)\w\w', filetext)  # find reflection strings

        refs = np.unique(reftext)  # remove duplicates
        Npeak = len(refs)
        if Npeak == 0:
            raise ValueError('No reflections found in {}'.format(filename))
        # Each label appears once per energy; '//' keeps the count an int
        Nenergy = len(reftext) // Npeak
        Nangle = int(nangle)

        # pre-define arrays
        storevals = {ref: np.zeros([Nenergy, Nangle]) for ref in refs}
        storeeng = np.zeros(Nenergy)
        storeang = np.zeros(Nangle)

        f.seek(0)  # return to start of file

        # Read file, line by line
        for E in range(Nenergy):
            f.readline()  # blank line
            storeeng[E] = float(f.readline().strip())  # energy
            for P in range(Npeak):
                peak = f.readline().strip()  # current reflection
                # read angle,Intensity lines
                vals = np.zeros([Nangle, 2])
                for m in range(Nangle):
                    vals[m, :] = [float(x) for x in f.readline().split()]

                storeang = vals[:, 0]  # store angle values
                storevals[peak][E, :] = vals[:, 1]  # store intensity values

    return storeeng, storeang, storevals


def azi_cut(storeeng, intensities, cutenergy=None):
    """
    Generate azimuthal cut at a particular energy
    cutintensity = azi_cut(storeeng,intensities,cutenergy)
        storeeng = [nx1] array of energies from read_scan_conv
        intensities = [nxm] array of simulated intensities for a single reflection (e.g. storevals['I(100)sp'])
        cutenergy = energy to take the cut at, will take value closest to cutenergy. In eV.
        cutenergy = 'max' or None - take the cut energy at the maximum intensity.
        cutintensity = [mx1] array of simulated intensity at this energy

    A warning is printed if the closest energy is more than 5 eV from cutenergy.

    e.g.
     energy,angle,intensities = read_scan_conv(filename)
     cutintensity = azi_cut(energy,intensities['I(100)sp'],cutenergy='max')
    """
    storeeng = np.asarray(storeeng)
    intensities = np.asarray(intensities)

    # None (the default) previously crashed on 'storeeng - None'; treat as 'max'
    if cutenergy is None or cutenergy == 'max':
        i, j = np.unravel_index(np.argmax(intensities), intensities.shape)
        print(' Highest value = {} at {} eV [{},{}]'.format(intensities[i, j], storeeng[i], i, j))
        cutenergy = storeeng[i]

    enpos = np.argmin(np.abs(storeeng - cutenergy))
    if np.abs(storeeng[enpos] - cutenergy) > 5:
        print("You havent choosen the right cutenergy. enpos = {} [{}]".format(enpos, storeeng[enpos]))

    return intensities[enpos, :]


def eng_cut(storeang, intensities, cutangle=None):
    """
    Generate energy cut at a particular azimuthal angle
    cutintensity = eng_cut(storeang,intensities,cutangle)
        storeang = [mx1] array of angles from read_scan_conv
        intensities = [nxm] array of simulated intensities for a single reflection (e.g. storevals['I(100)sp'])
        cutangle = angle to take the cut at, will take value closest to cutangle. In Deg.
        cutangle = 'max' or None - take the cut angle at the maximum intensity.
        cutintensity = [nx1] array of simulated intensity at this angle

    A warning is printed if the closest angle is more than 5 Deg from cutangle.

    e.g.
     energies,angles,intensities = read_scan_conv(filename)
     cutintensity = eng_cut(angles,intensities['I(100)sp'],cutangle=180)
    """
    storeang = np.asarray(storeang)
    intensities = np.asarray(intensities)

    # None (the default) previously crashed on 'storeang - None'; treat as 'max'
    if cutangle is None or cutangle == 'max':
        i, j = np.unravel_index(np.argmax(intensities), intensities.shape)
        print(' Highest value = {} at {} Deg [{},{}]'.format(intensities[i, j], storeang[j], i, j))
        cutangle = storeang[j]

    angpos = np.argmin(np.abs(storeang - cutangle))
    if np.abs(storeang[angpos] - cutangle) > 5:
        # previously referenced the undefined name 'enpos' (NameError)
        print("You havent choosen the right cutangle. angpos = {} [{}]".format(angpos, storeang[angpos]))
    return intensities[:, angpos]
