import numpy as np
#from numpy import ones, zeros, empty
from flopy.mbase import Package
from flopy.utils import util_2d,util_3d

class ModflowSwi2(Package):
    """Salt Water Intrusion (SWI2) package class.

    Collects the SWI2 input data, registers the package (and its output
    files) with the parent model, and writes the SWI2 input file.

    Parameters
    ----------
    model : model object
        Parent model; must provide ``nrow_ncol_nlay_nper`` and
        ``add_package``.
    nsrf : int
        Number of surfaces; one ``zeta`` array is built per surface.
    istrat : int
        If 1, ``nu`` holds ``nsrf + 1`` values, otherwise ``nsrf + 2``.
    nobs : int
        Number of observation points. When > 0, an extra output file with
        extension ``zobs`` on unit ``iswiobs`` is registered and dataset 8
        is written.
    iswizt, iswibd, iswiobs : int
        Unit numbers of the output files registered with extensions
        ``zta``, ``swb`` (by default) and ``zobs``.
    fsssopt, adaptive : bool
        When truthy, the FSSSOPT / ADAPTIVE options are written on dataset 1.
    nsolver, iprsol, mutsol : int
        Dataset 2a values.
    solver2params : dict
        Dataset 2b values, written only when ``nsolver == 2``. Keys missing
        from the supplied dict fall back to the values in the signature.
    toeslope, tipslope, alpha, beta : float
        Dataset 3a values; ``alpha`` and ``beta`` are written only when
        ``alpha`` is not None.
    nadptmx, nadptmn, adptfct :
        Dataset 3b values, written only when ``adaptive`` is truthy.
    nu :
        Dataset 4 values (converted with ``util_2d``).
    zeta : sequence
        ``nsrf`` entries, each converted to a (nlay, nrow, ncol) array with
        ``util_3d``.
    ssz, isource :
        Dataset 6 / dataset 7 values (converted with ``util_3d``).
    obsnam : sequence of str
        Observation names for dataset 8.
    obslrc : sequence
        One (layer, row, column) triple per observation for dataset 8.
    extension : list of str
        Extensions for the SWI2 input file and the two binary output files.
    unit_number : int
        Unit number of the SWI2 input file.
    npln : int, optional
        Deprecated alias for ``nsrf``.
    """

    # Dataset 2b defaults. User-supplied solver2params are merged on top of a
    # copy of these, so a partial dict does not fail later in write_file.
    _solver2_defaults = {'mxiter': 100, 'iter1': 20, 'npcond': 1,
                         'zclose': 1e-3, 'rclose': 1e-4, 'relax': 1.0,
                         'nbpol': 2, 'damp': 1.0, 'dampt': 1.0}

    def __init__(self, model, nsrf=1, istrat=1, nobs=0, iswizt=55, iswibd=56,
                 iswiobs=0, fsssopt=False, adaptive=False,
                 nsolver=1, iprsol=0, mutsol=3,
                 solver2params={'mxiter': 100, 'iter1': 20, 'npcond': 1,
                                'zclose': 1e-3, 'rclose': 1e-4, 'relax': 1.0,
                                'nbpol': 2, 'damp': 1.0, 'dampt': 1.0},
                 toeslope=0.05, tipslope=0.05, alpha=None, beta=0.1,
                 nadptmx=1, nadptmn=1, adptfct=1.0,
                 nu=0.025, zeta=[], ssz=[], isource=0,
                 obsnam=[], obslrc=[],
                 extension=['swi2', 'zta', 'swb'], unit_number=29,
                 npln=None):
        name = ['SWI2', 'DATA(BINARY)', 'DATA(BINARY)']
        units = [unit_number, iswizt, iswibd]
        extra = ['', 'REPLACE', 'REPLACE']
        if nobs > 0:
            # Register the observation output file. A new extension list is
            # built (rather than appended to) so the shared default argument
            # is never mutated across instances.
            extension = list(extension) + ['zobs']
            name.append('DATA')
            units.append(iswiobs)
            extra.append('')
        # Call ancestor's init to set self.parent, extension, name and unit number
        Package.__init__(self, model, extension=extension, name=name,
                         unit_number=units, extra=extra)
        nrow, ncol, nlay, nper = self.parent.nrow_ncol_nlay_nper
        self.heading = ('# Salt Water Intrusion (SWI2) package file for '
                        'MODFLOW-2005, generated by Flopy.')

        self.fsssopt, self.adaptive = fsssopt, adaptive

        if npln is not None:
            print('npln keyword is deprecated. use the nsrf keyword')
            nsrf = npln
        self.nsrf, self.istrat, self.nobs = nsrf, istrat, nobs
        self.iswizt, self.iswibd, self.iswiobs = iswizt, iswibd, iswiobs

        self.nsolver, self.iprsol, self.mutsol = nsolver, iprsol, mutsol

        # Copy so later edits to self.solver2params never touch the caller's
        # dict or the shared signature default.
        params = dict(self._solver2_defaults)
        params.update(solver2params)
        self.solver2params = params

        self.toeslope, self.tipslope = toeslope, tipslope
        self.alpha, self.beta = alpha, beta
        self.nadptmx, self.nadptmn, self.adptfct = nadptmx, nadptmn, adptfct

        # istrat == 1 needs nsrf + 1 nu values; otherwise nsrf + 2.
        nnu = self.nsrf + 1 if self.istrat == 1 else self.nsrf + 2
        self.nu = util_2d(model, (nnu,), np.float32, nu, name='nu')

        # One 3-D zeta array per surface.
        self.zeta = [util_3d(model, (nlay, nrow, ncol), np.float32, zeta[i],
                             name='zeta_' + str(i + 1))
                     for i in range(self.nsrf)]
        self.ssz = util_3d(model, (nlay, nrow, ncol), np.float32, ssz,
                           name='ssz')
        # builtin int: np.int was only an alias for it and no longer exists
        # in recent NumPy releases.
        self.isource = util_3d(model, (nlay, nrow, ncol), int, isource,
                               name='isource')

        self.obsnam = obsnam
        self.obslrc = obslrc

        self.parent.add_package(self)

    def __repr__(self):
        return 'Salt Water Intrusion (SWI2) package class'

    def write_file(self):
        """Write the SWI2 input file to ``self.file_name[0]``.

        Datasets 2b, 3b and 8 are written only when ``nsolver == 2``,
        ``adaptive`` is truthy and ``nobs > 0`` respectively.
        """
        nrow, ncol, nlay, nper = self.parent.nrow_ncol_nlay_nper
        # `with` guarantees the file is closed even if a write fails.
        with open(self.file_name[0], 'w') as f_swi:
            # NOTE(review): original author questioned whether SWI2 accepts
            # a heading line here — confirm against the SWI2 input spec.
            f_swi.write('%s\n' % self.heading)

            # dataset 1
            f_swi.write('#--Dataset 1\n')
            f_swi.write('%10i%10i%10i%10i%10i%10i' % (
                self.nsrf, self.istrat, self.nobs,
                self.iswizt, self.iswibd, self.iswiobs))
            if self.fsssopt:
                f_swi.write('    FSSSOPT')
            if self.adaptive:
                f_swi.write('   ADAPTIVE')
            f_swi.write('\n')

            # dataset 2a
            f_swi.write('#--Dataset 2a\n')
            f_swi.write('%10i%10i%10i\n' % (self.nsolver, self.iprsol,
                                            self.mutsol))

            # dataset 2b
            if self.nsolver == 2:
                p = self.solver2params
                f_swi.write('#--Dataset 2b\n')
                f_swi.write('%10i%10i%10i%14.6e%14.6e%14.6e%10i%14.6e%14.6e\n'
                            % (p['mxiter'], p['iter1'], p['npcond'],
                               p['zclose'], p['rclose'], p['relax'],
                               p['nbpol'], p['damp'], p['dampt']))

            # dataset 3a
            f_swi.write('#--Dataset 3a\n')
            f_swi.write('%14.6e%14.6e' % (self.toeslope, self.tipslope))
            if self.alpha is not None:
                f_swi.write('%14.6e%14.6e' % (self.alpha, self.beta))
            f_swi.write('\n')

            # dataset 3b
            if self.adaptive:
                f_swi.write('#--Dataset 3b\n')
                f_swi.write('%10i%10i%14.6e\n' % (self.nadptmx, self.nadptmn,
                                                  self.adptfct))

            # dataset 4
            f_swi.write('#--Dataset 4\n')
            f_swi.write(self.nu.get_file_entry())

            # dataset 5: every layer of each surface, surface by surface
            f_swi.write('#--Dataset 5\n')
            for isur in range(self.nsrf):
                for ilay in range(nlay):
                    f_swi.write(self.zeta[isur][ilay].get_file_entry())

            # dataset 6
            f_swi.write('#--Dataset 6\n')
            f_swi.write(self.ssz.get_file_entry())

            # dataset 7
            f_swi.write('#--Dataset 7\n')
            f_swi.write(self.isource.get_file_entry())

            # dataset 8: one "name layer row column" record per observation
            if self.nobs > 0:
                f_swi.write('#--Dataset 8\n')
                for i in range(self.nobs):
                    lay, row, col = self.obslrc[i]
                    f_swi.write('%s%10i%10i%10i\n' % (self.obsnam[i],
                                                      lay, row, col))
    
