from numpy import empty,zeros,ones,where
from flopy.mbase import Package

class ModflowMnw2(Package):
    """Multi-node well 2 (MNW2) package class.

    Stores per-well and per-stress-period MNW2 input and writes it to an MNW2
    input file for MODFLOW.

    NOTE: This implementation does not allow well loss parameters
    {Rw, Rskin, Kskin, B, C, P, CWC, PP} to vary along the length of a given
    well; each is a single value per well. It also does not currently support
    data sections 2e, 2f, 2g, 2h, or 4b as defined in the data input
    instructions for the MNW2 package. Because those sections are never
    written, wells should use PUMPLOC = 0, Qlimit = 0 and PUMPCAP = 0. The
    ``capmult`` and ``cprime`` arguments are accepted for interface
    compatibility but are not stored or written.
    """

    def __init__(self, model, mnwmax=0, iwl2cb=-1, mnwprnt=0, aux=None,
                 wellid=None, nnodes=None, losstype=None, pumploc=0, qlimit=0, ppflag=0, pumpcap=0,
                 lay_row_col=None, ztop_zbotm_row_col=None, rw=0, rskin=0, kskin=0, b=0, c=0, p=0, cwc=0, pp=1,
                 itmp=0,
                 wellid_qdes=None, capmult=0, cprime=0,
                 extension='mnw2', unitnumber=34):
        """Build the MNW2 package and register it with *model*.

        Parameters
        ----------
        model : parent model; must expose ``nrow_ncol_nlay_nper``.
        mnwmax : maximum number of multi-node wells to be simulated.
        iwl2cb : flag and unit number for cell-by-cell output.
        mnwprnt : verbosity flag.
        aux : optional iterable of auxiliary parameter names (written as
            ``AUX name`` on data set 1).
        wellid, nnodes, losstype, pumploc, qlimit, ppflag, pumpcap,
        rw, rskin, kskin, b, c, p, cwc, pp : per-well values
            (shape (MNWMAX,)), copied into arrays via ``assignarray_old``.
        lay_row_col : list (length MNWMAX) of (NNODES, 3) arrays of
            lay, row, col for wells with NNODES > 0.
        ztop_zbotm_row_col : list (length MNWMAX) of (abs(NNODES), 4) arrays
            of ztop, zbotm, row, col for wells with NNODES < 0.
        itmp : number of wells simulated in each stress period (shape (NPER,)).
        wellid_qdes : list (length NPER) of arrays whose rows are
            [WELLID, Qdes] for the wells active in that period.
        capmult, cprime : accepted but unused (see class NOTE).

        Raises
        ------
        AssertionError
            If a WELLID contains a space, a LOSSTYPE is not recognised, the
            first ITMP is negative, or any ITMP exceeds MNWMAX.
        """
        # Call ancestor's init to set self.parent, extension, name, and unit number
        Package.__init__(self, model, extension, 'MNW2', unitnumber)
        self.url = 'mnw2.htm'
        self.nper = self.parent.nrow_ncol_nlay_nper[-1]
        self.heading = '# Multi-node well 2 (MNW2) file for MODFLOW, generated by Flopy'
        self.mnwmax = mnwmax            # maximum number of multi-node wells to be simulated
        self.iwl2cb = iwl2cb            # flag and unit number
        self.mnwprnt = mnwprnt          # verbosity flag
        self.aux = aux                  # list of optional auxiliary parameters

        # These inputs are kept as given: their shapes vary from well to well
        # and from period to period, so they are not pre-formatted here.
        self.lay_row_col = lay_row_col                # NNODES > 0: per-well (NNODES, 3) lay/row/col
        self.ztop_zbotm_row_col = ztop_zbotm_row_col  # NNODES < 0: per-well (abs(NNODES), 4) ztop/zbotm/row/col
        self.wellid_qdes = wellid_qdes                # per-period rows of [WELLID, Qdes]

        # Per-well arrays, shape (MNWMAX,)
        self.wellid = empty(self.mnwmax, dtype='S25')      # well ids
        self.nnodes = zeros(self.mnwmax, dtype='int32')    # number of nodes (or -number of open intervals)
        self.losstype = empty(self.mnwmax, dtype='S25')    # head loss type
        self.pumploc = zeros(self.mnwmax, dtype='int32')   # pump intake location flag
        self.qlimit = zeros(self.mnwmax, dtype='int32')    # water-level pumping constraint flag
        self.ppflag = zeros(self.mnwmax, dtype='int32')    # partial penetration correction flag
        self.pumpcap = zeros(self.mnwmax, dtype='int32')   # lift-adjusted discharge flag
        self.rw = zeros(self.mnwmax, dtype='float32')
        self.rskin = zeros(self.mnwmax, dtype='float32')
        self.kskin = zeros(self.mnwmax, dtype='float32')
        self.b = zeros(self.mnwmax, dtype='float32')
        self.c = zeros(self.mnwmax, dtype='float32')
        self.p = zeros(self.mnwmax, dtype='float32')
        self.cwc = zeros(self.mnwmax, dtype='float32')
        self.pp = zeros(self.mnwmax, dtype='float32')
        # Per-stress-period array, shape (NPER,)
        self.itmp = zeros(self.nper, dtype='int32')        # number of wells simulated per period

        # Assign values to arrays
        self.assignarray_old(self.wellid, wellid)
        self.assignarray_old(self.nnodes, nnodes)
        self.assignarray_old(self.losstype, losstype)
        self.assignarray_old(self.pumploc, pumploc)
        self.assignarray_old(self.qlimit, qlimit)
        self.assignarray_old(self.ppflag, ppflag)
        self.assignarray_old(self.pumpcap, pumpcap)
        self.assignarray_old(self.rw, rw)
        self.assignarray_old(self.rskin, rskin)
        self.assignarray_old(self.kskin, kskin)
        self.assignarray_old(self.b, b)
        self.assignarray_old(self.c, c)
        self.assignarray_old(self.p, p)
        self.assignarray_old(self.cwc, cwc)
        self.assignarray_old(self.pp, pp)
        self.assignarray_old(self.itmp, itmp)

        # Input format checks. Text values are decoded first because the
        # 'S25' arrays hold bytes on Python 3.
        lossTypes = ['NONE', 'THIEM', 'SKIN', 'GENERAL', 'SPECIFYcwc']
        for i in range(self.mnwmax):
            wellid_i = self._as_str(self.wellid[i])
            losstype_i = self._as_str(self.losstype[i])
            assert ' ' not in wellid_i, 'WELLID (%s) must not contain spaces' % wellid_i
            assert losstype_i in lossTypes, 'LOSSTYPE (%s) must be one of the following: NONE, THIEM, SKIN, GENERAL, or SPECIFYcwc' % losstype_i
        if self.nper > 0:  # itmp is empty when there are no stress periods
            assert self.itmp[0] >= 0, 'ITMP must be greater than or equal to zero for the first time step.'
            assert self.itmp.max() <= self.mnwmax, 'ITMP cannot exceed maximum number of wells to be simulated.'

        self.parent.add_package(self)

    @staticmethod
    def _as_str(value):
        """Return *value* as text, decoding bytes (e.g. 'S25' array items)."""
        if isinstance(value, bytes):
            return value.decode('utf-8')
        return str(value)

    def write_file(self):
        """Write the MNW2 input file to ``self.file_name[0]``.

        Writes data set 1, data sets 2a-2d for every well, then data sets 3
        and 4a for every stress period. The file is closed even if a write
        fails.
        """
        with open(self.file_name[0], 'w') as f_mnw2:
            # Header
            f_mnw2.write('%s\n' % self.heading)

            # Section 1 - MNWMAX, IWL2CB, MNWPRNT {OPTION}
            auxParamString = ''
            if self.aux is not None:
                auxParamString = ''.join('AUX %s ' % param for param in self.aux)
            f_mnw2.write('%10i%10i%10i %s\n' % (self.mnwmax,
                                                self.iwl2cb,
                                                self.mnwprnt,
                                                auxParamString))

            # Section 2 - once for each well
            for i in range(self.mnwmax):
                self._write_well(f_mnw2, i)

            # Section 3 - once for each stress period. The set of known well
            # ids is built once so section 4a can validate each WELLID.
            known_wellids = set(self._as_str(w) for w in self.wellid)
            for per in range(self.nper):
                self._write_stress_period(f_mnw2, per, known_wellids)

    def _write_well(self, f_mnw2, i):
        """Write data sets 2a-2d for the well at index *i*."""
        wellid = self._as_str(self.wellid[i])
        losstype = self._as_str(self.losstype[i])
        nnodes = self.nnodes[i]

        # Section 2a - WELLID, NNODES
        f_mnw2.write('%s%10i\n' % (wellid, nnodes))

        # Section 2b - LOSSTYPE, PUMPLOC, Qlimit, PPFLAG, PUMPCAP
        f_mnw2.write('%s %10i%10i%10i%10i\n' % (losstype,
                                               self.pumploc[i],
                                               self.qlimit[i],
                                               self.ppflag[i],
                                               self.pumpcap[i]))

        # Section 2c - {Rw, Rskin, Kskin, B, C, P, CWC}; nothing for NONE
        if losstype == 'THIEM':
            f_mnw2.write('%10f\n' % (self.rw[i]))
        elif losstype == 'SKIN':
            f_mnw2.write('%10f %10f %10f\n' % (self.rw[i],
                                               self.rskin[i],
                                               self.kskin[i]))
        elif losstype == 'GENERAL':
            f_mnw2.write('%10f %10f %10f %10f\n' % (self.rw[i],
                                                    self.b[i],
                                                    self.c[i],
                                                    self.p[i]))
        elif losstype == 'SPECIFYcwc':
            f_mnw2.write('%10f\n' % (self.cwc[i]))

        # Section 2d - one line per node (2d-1) or per open interval (2d-2)
        if nnodes > 0:
            # Section 2d-1 - LAY, ROW, COL {PP}. MNW2 reads PP on every node
            # line when PPFLAG > 0 and NNODES > 0; the per-well PP is repeated.
            lrc = self.lay_row_col[i]
            write_pp = self.ppflag[i] > 0
            for n in range(nnodes):
                line = '%10i%10i%10i' % (lrc[n, 0], lrc[n, 1], lrc[n, 2])
                if write_pp:
                    line += ' %10f' % self.pp[i]
                f_mnw2.write(line + '\n')
        elif nnodes < 0:
            # Section 2d-2 - Ztop, Zbotm, ROW, COL
            zrc = self.ztop_zbotm_row_col[i]
            for n in range(-nnodes):
                f_mnw2.write('%10f %10f %10i %10i\n' % (zrc[n, 0],
                                                        zrc[n, 1],
                                                        zrc[n, 2],
                                                        zrc[n, 3]))

    def _write_stress_period(self, f_mnw2, per, known_wellids):
        """Write data set 3 (ITMP) and data set 4a rows for stress period *per*.

        Raises AssertionError if a WELLID in ``wellid_qdes[per]`` is not in
        *known_wellids*.
        """
        itmp = self.itmp[per]
        f_mnw2.write('%10i\n' % itmp)

        # Section 4 - once for each well simulated in this period
        # (range() is empty when ITMP <= 0)
        qdes_rows = self.wellid_qdes[per] if itmp > 0 else None
        for j in range(itmp):
            # Section 4a - WELLID Qdes
            wellid = self._as_str(qdes_rows[j, 0])
            assert wellid in known_wellids, 'WELLID for pumping well is not present in "wellid" array'
            f_mnw2.write('%s %10f\n' % (wellid, qdes_rows[j, 1]))

