#! /usr/bin/env python
import sys
import string
import subprocess
from getpass import getuser
import numpy as np
import netCDF4 as netcdf4
import xarray as xr

def mprint(mstr):
    """Print ``mstr`` followed by a newline on standard output.

    Single print entry point for the script that works unchanged under
    both Python 2 and Python 3.

    Args:
        mstr: the message to print (the script passes strings).
    """
    # A single parenthesised argument is valid in both dialects: Python 2
    # treats it as the print statement applied to a parenthesised
    # expression, Python 3 as a function call.  A ``print mstr`` statement
    # anywhere in the file is a SyntaxError on Python 3, even inside a
    # branch that would never run there.
    print(mstr)

'''
#------------------------------------------------------------------#
#---------------------  Instructions  -----------------------------#
#------------------------------------------------------------------#
After creating a single-point surface data file from a global
surface data file, use this script to overwrite some of its fields
with site-specific data.
'''

# User settings: site name, input/output locations and derived file names.
myname = getuser()
#--  Create single point CLM surface data file
create_surfdata = True

#--  specify site from which to extract data  ----------
sitename='US-Ha1'

#--  Site level data directory  ------------------------
site_dir='../PTCLM/PTCLM_sitedata/'

#--  Specify original file
fsurf = '/glade/scratch/'+myname+'/single_point/surfdata_0.9x1.25_16pfts_CMIP6_simyr1850_287.8_42.5_c170706.nc'

#--  Output directory  ---------------------------------
dir_output='/glade/scratch/'+myname+'/single_point/'

#-- Specify new file name  -----------------------------
# Date tag (yymmdd) taken from the system `date` command.  The command is
# given as an argument list, so no shell is involved, and
# universal_newlines=True makes the captured output a text string on
# Python 3 as well (a bytes result cannot be concatenated into the file
# name below).
timetag = subprocess.check_output(['date', '+%y%m%d'],
                                  universal_newlines=True).strip()
fsurf2   = dir_output + 'surfdata_0.9x1.25_16pfts_CMIP6_simyr2000_'+sitename+'_site.c'+timetag+'.nc'

#== End User Mods  =====================================

#site_code,pft_f1,pft_c1,pft_f2,pft_c2,pft_f3,pft_c3,pft_f4,pft_c4,pft_f5,pft_c5
pftfile =site_dir+'PTCLMDATA_pftdata.txt'
#site_code,name,state,lon,lat,elev,startyear,endyear,alignyear,timestep,campaign
sitefile=site_dir+'PTCLMDATA_sitedata.txt'
#site_code,soil_depth,n_layers,layer_depth,layer_sand%,layer_clay%
soilfile=site_dir+'PTCLMDATA_soildata.txt'

#--  Raw datafiles  ------------------------
# namelist that lists the raw dataset files (mksrf_* entries)
rawdatafile = '../mksurfdata_map/mksurfdata_map.namelist'

mstr='Checking for data for site: '+sitename
mprint(mstr)

#--  Read raw datafiles  ------------------------
# Namelist entries picked up from `rawdatafile`.  Each value is stored below
# in a module-level variable named after the key without its 'mksrf_' prefix.
raw_keys = (
    'mksrf_fsoitex', 'mksrf_forganic', 'mksrf_flakwat', 'mksrf_fwetlnd',
    'mksrf_fmax', 'mksrf_fglacier', 'mksrf_fglacierregion', 'mksrf_fvocef',
    'mksrf_furbtopo', 'mksrf_fgdp', 'mksrf_fpeat', 'mksrf_fsoildepth',
    'mksrf_fabm', 'mksrf_ftopostats', 'mksrf_fvic', 'mksrf_fch4',
    'mksrf_furban', 'mksrf_fvegtyp', 'mksrf_fhrvtyp', 'mksrf_fsoicol',
    'mksrf_flai',
)

# Collect "key = 'value'" lines.  partition() splits on the first '=' only,
# so a value that itself contains '=' is kept whole, and a line without any
# '=' is skipped instead of raising IndexError.  If a key occurs more than
# once, the last occurrence wins.
raw_paths = {}
with open(rawdatafile, 'r') as t1:
    for tmp in t1:
        key, sep, value = tmp.partition('=')
        key = key.strip()
        if sep and key in raw_keys:
            # drop surrounding whitespace, then the enclosing single quotes
            raw_paths[key] = value.strip().strip("'")

# A key that is absent from the namelist yields None.
fsoitex        = raw_paths.get('mksrf_fsoitex')
forganic       = raw_paths.get('mksrf_forganic')
flakwat        = raw_paths.get('mksrf_flakwat')
fwetlnd        = raw_paths.get('mksrf_fwetlnd')
fmax           = raw_paths.get('mksrf_fmax')
fglacier       = raw_paths.get('mksrf_fglacier')
fglacierregion = raw_paths.get('mksrf_fglacierregion')
fvocef         = raw_paths.get('mksrf_fvocef')
furbtopo       = raw_paths.get('mksrf_furbtopo')
fgdp           = raw_paths.get('mksrf_fgdp')
fpeat          = raw_paths.get('mksrf_fpeat')
fsoildepth     = raw_paths.get('mksrf_fsoildepth')
fabm           = raw_paths.get('mksrf_fabm')
ftopostats     = raw_paths.get('mksrf_ftopostats')
fvic           = raw_paths.get('mksrf_fvic')
fch4           = raw_paths.get('mksrf_fch4')
furban         = raw_paths.get('mksrf_furban')
fvegtyp        = raw_paths.get('mksrf_fvegtyp')
fhrvtyp        = raw_paths.get('mksrf_fhrvtyp')
fsoicol        = raw_paths.get('mksrf_fsoicol')
flai           = raw_paths.get('mksrf_flai')
       
#Open site file and extract data
# Rows are comma separated; the first field is the site code and the
# remaining fields are read positionally below.
site_found = False
with open(sitefile, 'r') as t1:
    for tmp in t1:
       x=tmp.split(',')
       if x[0] == sitename:
          site_found = True
          site_header = tmp
          name     = x[1]
          state    = x[2]
          plon     = np.float32(x[3])
          plat     = np.float32(x[4])
          elev     = np.float32(x[5])
          startyear= np.int32(x[6])
          endyear  = np.int32(x[7])
          alignyear= np.int32(x[8])
          timestep = np.float32(x[9])
          campaign = x[10]
          # stop scanning at the first matching row (a bare `exit` name is
          # just an expression and does not leave the loop)
          break

if not site_found:
   mprint('No site matching '+sitename+' was found')
   # nothing below can work without the site coordinates: terminate with a
   # non-zero exit status
   sys.exit(1)
else:
   mprint(site_header)

# convert longitude to east if needed
if plon < 0:
   plon+=360.0

#--  Open pft file and extract data  ---------------------------------
# Rows are comma separated: the site code followed by five
# (pft_f, pft_c) pairs, read positionally below.
site_found = False
with open(pftfile, 'r') as t1:
    for tmp in t1:
       x=tmp.split(',')
       if x[0] == sitename:
          site_found = True
          pft_f1= np.int32(x[1])
          pft_c1= np.int32(x[2])
          pft_f2= np.int32(x[3])
          pft_c2= np.int32(x[4])
          pft_f3= np.int32(x[5])
          pft_c3= np.int32(x[6])
          pft_f4= np.int32(x[7])
          pft_c4= np.int32(x[8])
          pft_f5= np.int32(x[9])
          pft_c5= np.int32(x[10])
          # stop scanning at the first matching row (a bare `exit` name is
          # just an expression and does not leave the loop)
          break

if not site_found:
   # the pft values are needed to build the surface data file, so fail here
   # with a clear message rather than later with a NameError
   mprint('No pft data matching '+sitename+' was found in '+pftfile)
   sys.exit(1)


#--  Open soil file and extract data  ---------------------------------
# Rows are comma separated: the site code followed by soil depth, number of
# layers, layer depth, layer sand % and layer clay %, read positionally.
site_found = False
with open(soilfile, 'r') as t1:
    for tmp in t1:
       x=tmp.split(',')
       if x[0] == sitename:
          site_found = True
          soil_depth     = np.float32(x[1])
          n_layers       = np.int32(x[2])
          layer_depth    = np.float32(x[3])
          layer_sand_pct = np.float32(x[4])
          layer_clay_pct = np.float32(x[5])
          # stop scanning at the first matching row (a bare `exit` name is
          # just an expression and does not leave the loop)
          break

if not site_found:
   # the soil values are not used further down in this script, so a missing
   # entry is reported but is not fatal
   mprint('No soil data matching '+sitename+' was found in '+soilfile)

#--  create surface data file  --------------------------------------
##2
# Starting from the single-point surface dataset `fsurf`, overwrite selected
# fields with values taken from the raw datasets at the grid cell nearest to
# the site (plon/plat) and write the result to `fsurf2`.
if create_surfdata:
   f1  = xr.open_dataset(fsurf)
   # expand dimensions
   #f1 = f1.expand_dims(['lsmlat','lsmlon'])

   # create 1d variables
   lon0=np.asarray(f1['LONGXY'][0,:])
   lon=xr.DataArray(lon0,name='lon',dims='lsmlon',coords={'lsmlon':lon0})
   lat0=np.asarray(f1['LATIXY'][:,0])
   lat=xr.DataArray(lat0,name='lat',dims='lsmlat',coords={'lsmlat':lat0})
   #f2=f1.assign({'lon':lon,'lat':lat})
   # the above doesn't work now
   f2=f1.assign()
   f2['lon'] = lon
   f2['lat'] = lat

   # make gridcell entirely natural vegetation *or* entirely crop
   # (fraction, cover type) pairs as read from the pft file
   pft_pairs = [(pft_f1,pft_c1),(pft_f2,pft_c2),(pft_f3,pft_c3),
                (pft_f4,pft_c4),(pft_f5,pft_c5)]
   # Each cover type is only considered together with its own fraction.
   # Testing "any fraction set" and "any type < 15" independently lets the
   # type of a zero-fraction slot decide the outcome, so a site whose only
   # non-zero fraction is a crop type (>= 15) would also be flagged as
   # natural vegetation and rejected below.
   has_natveg = any(f != 0 and c < 15 for f, c in pft_pairs)
   has_crop   = any(f != 0 and c >= 15 for f, c in pft_pairs)
   new_pct_natveg = 100. if has_natveg else 0.
   new_pct_crop   = 100. if has_crop else 0.

   # exactly one of the two must be 100
   if new_pct_natveg == new_pct_crop:
      if has_natveg:
         mprint('both natveg and crop set to 100, exiting')
      else:
         mprint('neither natveg nor crop set to 100, exiting')
      sys.exit(1)

   #--  Remove non-vegetated landunits  ---------------------------------
   f2['PCT_NATVEG']  = new_pct_natveg
   f2['PCT_CROP']    = new_pct_crop
   f2['PCT_LAKE']    = 0.
   f2['PCT_WETLAND'] = 0.
   f2['PCT_URBAN']   = 0.
   f2['PCT_GLACIER'] = 0.

   #--  Overwrite global data with raw data  ----------------------------
   f2['LONGXY'] = plon
   f2['LATIXY'] = plat

   #--  Soil texture
   r1  = xr.open_dataset(fsoitex)
   # longitude in the [-180,180] convention, used for most raw datasets below
   plonc = plon
   if plonc > 180.0:
      plonc -= 360.0

   # set coordinates (seems to require 'lon' and 'lat' to recognize...
   r1=r1.rename({'LON':'lon','LAT':'lat'})
   # set_coords returns a new dataset; the in-place keyword form is not
   # accepted by current xarray releases
   r1=r1.set_coords(['lon','lat'])
   # extract gridcell closest to plonc/plat
   r2  = r1.sel(lon=plonc,lat=plat,method='nearest')
   # extract sand/clay profiles for given mapunit
   mapunit = np.int32(r2['MAPUNITS'])

   f2['PCT_SAND'][:,0,0] = np.asarray(r2['PCT_SAND'][:,mapunit])
   f2['PCT_CLAY'][:,0,0] = np.asarray(r2['PCT_CLAY'][:,mapunit])

   r1.close() ; r2.close()

   #--  Organic
   r1  = xr.open_dataset(forganic)
   r1=r1.rename({'LON':'lon','LAT':'lat'})
   r1=r1.set_coords(['lon','lat'])
   # extract gridcell closest to plonc/plat
   r2  = r1.sel(lon=plonc,lat=plat,method='nearest')
   f2['ORGANIC'][:,0,0] =  np.asarray(r2['ORGANIC'])
   r1.close() ; r2.close()

   #--  Fmax
   r1  = xr.open_dataset(fmax)
   # extract gridcell closest to plonc/plat
   r2  = r1.sel(lon=plonc,lat=plat,method='nearest')
   f2['FMAX'] =  np.asarray(r2['FMAX'])
   r1.close() ; r2.close()

   #--  VOC
   r1  = xr.open_dataset(fvocef)
   # extract gridcell closest to plonc/plat
   r2  = r1.sel(lon=plonc,lat=plat,method='nearest')
   # surface-file variable <- raw-file variable
   for dst, src in (('EF1_BTR','ef_btr'),('EF1_CRP','ef_crp'),
                    ('EF1_FDT','ef_fdt'),('EF1_FET','ef_fet'),
                    ('EF1_GRS','ef_grs'),('EF1_SHR','ef_shr')):
      f2[dst] =  np.asarray(r2[src])
   r1.close() ; r2.close()

   #--  GDP
   r1  = xr.open_dataset(fgdp)
   # extract gridcell closest to plonc/plat
   r2  = r1.sel(lon=plonc,lat=plat,method='nearest')
   f2['gdp'] =  np.asarray(r2['gdp'])
   r1.close() ; r2.close()

   #--  Peat
   r1  = xr.open_dataset(fpeat)
   # extract gridcell closest to plonc/plat
   r2  = r1.sel(lon=plonc,lat=plat,method='nearest')
   f2['peatf'] =  np.asarray(r2['peatf'])
   r1.close() ; r2.close()

   #--  Soil Depth
   r1  = xr.open_dataset(fsoildepth)
   # create 1d variables
   lon0=np.asarray(r1['LONGXY'][0,:])
   lon=xr.DataArray(lon0,name='lon',dims='lon',coords={'lon':lon0})
   lat0=np.asarray(r1['LATIXY'][:,0])
   lat=xr.DataArray(lat0,name='lat',dims='lat',coords={'lat':lat0})
   r1['lon'] = lon
   r1['lat'] = lat
   r1=r1.rename({'lsmlon':'lon','lsmlat':'lat'})
   # extract gridcell closest to plonc/plat
   r2  = r1.sel(lon=plonc,lat=plat,method='nearest')
   f2['zbedrock'] =  np.asarray(r2['Avg_Depth_Median'])
   #f2['zbedrock'] =  np.asarray(r2['Avg_Depth_Mean'])
   r1.close() ; r2.close()

   #--  ABM
   r1  = xr.open_dataset(fabm)
   # extract gridcell closest to plonc/plat
   r2  = r1.sel(lon=plonc,lat=plat,method='nearest')
   f2['abm'] =  np.asarray(r2['abm'])
   r1.close() ; r2.close()

   #--  SLOPE
   r1  = xr.open_dataset(ftopostats)
   rlon=np.asarray(r1['LONGITUDE'])
   rlat=np.asarray(r1['LATITUDE'])
   # extract gridcell closest to plon/plat (data are 1-d) (lon [0,360])
   # index of the smallest squared lon/lat distance to the site
   k1 = np.argmin(np.power(rlon - plon,2)+np.power(rlat - plat,2))
   f2['SLOPE'] = np.asarray(r1['SLOPE'][k1])
   # no r2 is created in this section, so only r1 needs closing
   r1.close()

   #--  VIC
   r1  = xr.open_dataset(fvic)
   r1=r1.rename({'lsmlon':'lon','lsmlat':'lat'})
   # extract gridcell closest to plonc/plat
   r2  = r1.sel(lon=plonc,lat=plat,method='nearest')
   for vname in ('Ws','Dsmax','Ds'):
      f2[vname] =  np.asarray(r2[vname])
   r1.close() ; r2.close()

   #--  Methane
   r1  = xr.open_dataset(fch4)
   # create 1d variables
   lon0=np.asarray(r1['LONGXY'][0,:])
   lon=xr.DataArray(lon0,name='lon',dims='lon',coords={'lon':lon0})
   lat0=np.asarray(r1['LATIXY'][:,0])
   lat=xr.DataArray(lat0,name='lat',dims='lat',coords={'lat':lat0})
   r1['lon'] = lon
   r1['lat'] = lat
   r1=r1.rename({'lsmlon':'lon','lsmlat':'lat'})
   # extract gridcell closest to plon/plat (this file is [0,360]
   r2  = r1.sel(lon=plon,lat=plat,method='nearest')
   for vname in ('P3','ZWT0','F0'):
      f2[vname] =  np.asarray(r2[vname])
   r1.close() ; r2.close()

   #--  Soil Color
   r1  = xr.open_dataset(fsoicol)
   r1=r1.rename({'LON':'lon','LAT':'lat'})
   r1=r1.set_coords(['lon','lat'])
   # extract gridcell closest to plonc/plat
   r2  = r1.sel(lon=plonc,lat=plat,method='nearest')
   f2['SOIL_COLOR'] =  np.asarray(r2['SOIL_COLOR'])
   r1.close() ; r2.close()

   #--  LAI / Height
   r1  = xr.open_dataset(flai)
   r1=r1.rename({'LON':'lon','LAT':'lat'})
   r1=r1.set_coords(['lon','lat'])
   # extract gridcell closest to plonc/plat
   r2  = r1.sel(lon=plonc,lat=plat,method='nearest')

   # ignoring crop, i.e. excluding index 15
   for vname in ('MONTHLY_HEIGHT_BOT','MONTHLY_HEIGHT_TOP',
                 'MONTHLY_LAI','MONTHLY_SAI'):
      f2[vname][:,0:15,0,0] =  np.asarray(r2[vname][:,0:15])
   r1.close() ; r2.close()

   #--  Close output file
   # mode 'w' overwrites file
   f2.to_netcdf(path=fsurf2, mode='w')
   mprint('created file '+fsurf2)
   f1.close(); f2.close()
   mprint('successfully created surface data file: '+fsurf2)

