"""
This script will concatenate arrays of the era5 variables to take composites later
"""

import numpy as np
import xarray as xr
import pandas as pd


"""
There will be 2 loops, one for each year's dataset, and one for each storm in that 
particular year. Define the sizes of the dimensions according to the year, and use
appropriate conditional statements to keep the lengths and concatenations straight
"""
#AAW number of storms in each year, one entry per year for 1980--2016 (37 years).
#It starts as NaN and is filled in by the sizing pass below. The filling pass later
#uses its cumulative sums to find where each year's storms start in the concatenated arrays.
nstorms_yrs = np.full(37, np.nan)
yrcount = 0 #AAW index into nstorms_yrs for the current year, starting at 0

#Sizing pass: read only the dimensions of each year's file. Every output array is then
#allocated once at its final size. Growing the arrays with np.concatenate every year
#(the previous approach) re-copied all earlier years each time, which is quadratic in
#the number of years.
for y in range(1980, 2017):
    #ds = xr.open_dataset(r'/home/cdirkes/era5/varbudget_output/test_varbudget.era5.'+str(y)+'.nc')
    #AAW new data; the with-block closes the file as soon as its shape has been read
    with xr.open_dataset(r'/home/awing/era5/test_varbudget.era5.'+str(y)+'.nc') as ds:
        #h is one of the 4D arrays, laid out (lat, lon, storm, track)
        nlat_y, nlon_y, nstorms, ntracks_y = ds.h.shape[:4]

    if y == 1980:
        #the first year fixes the grid and track dimensions that every later year must share
        nlat, nlon, ntracks = nlat_y, nlon_y, ntracks_y
    elif (nlat_y, nlon_y, ntracks_y) != (nlat, nlon, ntracks):
        #storms are stacked along their own axis, so all other dimensions must agree
        raise ValueError(
            f'year {y} has (nlat, nlon, ntracks) = {(nlat_y, nlon_y, ntracks_y)}, '
            f'expected {(nlat, nlon, ntracks)}'
        )

    nstorms_yrs[yrcount] = nstorms #AAW store the number of storms in this year in the corresponding index in nstorms_yrs
    yrcount = yrcount + 1

#length of the concatenated storm axis = total number of storms across all years
nstorms_total = int(np.sum(nstorms_yrs))

#Every output starts as all-NaN, so storms or track times that are never filled stay missing.
#Gridded fields: (lat, lon, storm, track)
h_comp = np.full((nlat, nlon, nstorms_total, ntracks), np.nan)
SEF_comp = np.full((nlat, nlon, nstorms_total, ntracks), np.nan)
LW_comp = np.full((nlat, nlon, nstorms_total, ntracks), np.nan)
SW_comp = np.full((nlat, nlon, nstorms_total, ntracks), np.nan)
hSEF_comp = np.full((nlat, nlon, nstorms_total, ntracks), np.nan)
hLW_comp = np.full((nlat, nlon, nstorms_total, ntracks), np.nan)
hSW_comp = np.full((nlat, nlon, nstorms_total, ntracks), np.nan)
varh_comp = np.full((nlat, nlon, nstorms_total, ntracks), np.nan)

#Per-track scalars: (storm, track)
wind_comp = np.full((nstorms_total, ntracks), np.nan)
centerlat_comp = np.full((nstorms_total, ntracks), np.nan)
centerlon_comp = np.full((nstorms_total, ntracks), np.nan)

#Coordinates of the storm-centred box: (lat, storm, track) and (lon, storm, track)
latbox_comp = np.full((nlat, nstorms_total, ntracks), np.nan)
lonbox_comp = np.full((nlon, nstorms_total, ntracks), np.nan)
    
yrcount = 0 #AAW restart the index into nstorms_yrs for the year being filled

#(output array, input variable name) pairs. The storm and track axes are the last two
#axes of every output array, so one [..., storm, tracks] index fits all of them.
#Assigning through these references fills the module-level arrays in place.
fields = (
    (h_comp, 'h'), (SEF_comp, 'SEF'), (LW_comp, 'LW'), (SW_comp, 'SW'),
    (hSEF_comp, 'hSEF'), (hLW_comp, 'hLW'), (hSW_comp, 'hSW'), (varh_comp, 'varh'),
    (wind_comp, 'wind'), (centerlat_comp, 'centerlat'), (centerlon_comp, 'centerlon'),
    (latbox_comp, 'latitude'), (lonbox_comp, 'longitude'),
)

#Filling pass: copy each storm's track, from its first time up to and including its
#peak wind (iLMI), into the concatenated arrays. Tracks are right-aligned, so the peak
#always lands in the last track slot (ntracks-1) and earlier slots stay NaN.
for y in range(1980, 2017):
    #total number of storms in ALL earlier years = offset of this year's first storm on
    #the concatenated storm axis (0 for 1980); computed once per year
    nstorms_prev = int(np.sum(nstorms_yrs[0:yrcount]))

    #ds = xr.open_dataset(r'/home/cdirkes/era5/varbudget_output/test_varbudget.era5.'+str(y)+'.nc')
    #TODO: landmask version would also read
    #r'/home/cdirkes/era5/landmask/landmask_output.'+str(y)+'.nc'
    with xr.open_dataset(r'/home/awing/era5/test_varbudget.era5.'+str(y)+'.nc') as ds:
        #i is the storm index within this year
        for i in range(np.size(ds.h, 2)):
            wind = ds.wind.isel(nstorms=i).values
            #skip storms whose first wind value is NaN (presumably unused storm slots)
            if np.isnan(wind[0]):
                continue
            #index of the first occurrence of the maximum wind, ignoring NaNs
            iLMI = int(np.nanargmax(wind))

            tracks_in = slice(0, iLMI + 1)
            tracks_out = slice(ntracks - iLMI - 1, ntracks)
            for comp, name in fields:
                comp[..., nstorms_prev + i, tracks_out] = ds[name].isel(nstorms=i, ntracks=tracks_in)

    #advance the year index once all storms of this year have been copied
    yrcount = yrcount + 1
"""
outside the loop we will now take the composite averages over the 3rd dimension.
will be easier to write out 7 variables into a single netCDF file so that I
only have to run the long part of this script once
"""
#Wrap each NumPy array as a labelled DataArray using the public xr.DataArray constructor
#(same class as xr.core.dataarray.DataArray, without relying on xarray's internal layout).
#The dimension names match the order the arrays were allocated in.
dims_field = ['nlat', 'nlon', 'nstorms', 'ntracks']   # gridded fields
dims_track = ['nstorms', 'ntracks']                   # one value per storm and track time

h_comp = xr.DataArray(h_comp, dims=dims_field)
SEF_comp = xr.DataArray(SEF_comp, dims=dims_field)
LW_comp = xr.DataArray(LW_comp, dims=dims_field)
SW_comp = xr.DataArray(SW_comp, dims=dims_field)
hSEF_comp = xr.DataArray(hSEF_comp, dims=dims_field)
hLW_comp = xr.DataArray(hLW_comp, dims=dims_field)
hSW_comp = xr.DataArray(hSW_comp, dims=dims_field)
varh_comp = xr.DataArray(varh_comp, dims=dims_field)

wind_comp = xr.DataArray(wind_comp, dims=dims_track)
centerlat_comp = xr.DataArray(centerlat_comp, dims=dims_track)
centerlon_comp = xr.DataArray(centerlon_comp, dims=dims_track)

latbox_comp = xr.DataArray(latbox_comp, dims=['nlat', 'nstorms', 'ntracks'])
lonbox_comp = xr.DataArray(lonbox_comp, dims=['nlon', 'nstorms', 'ntracks'])


#Attach metadata to every concatenated array before writing: units, a descriptive
#long_name, and the fill value and grid description shared by all of them.
#NOTE(review): SEF, LW, SW and the h-weighted terms carry the same J/m^2 units as h;
#confirm these against the varbudget input files.
common_attrs = {'_FillValue': -9999, 'GridType': '0.25 x 0.25 deg Grid'}

for comp, units, long_name in (
    (h_comp, 'J/m^2', 'concatenated array of h across all storms'),
    (SEF_comp, 'J/m^2', 'concatenated array of SEF across all storms'),
    (LW_comp, 'J/m^2', 'concatenated array of LW across all storms'),
    (SW_comp, 'J/m^2', 'concatenated array of SW across all storms'),
    (hSEF_comp, 'J/m^2', 'concatenated array of hSEF across all storms'),
    (hLW_comp, 'J/m^2', 'concatenated array of hLW across all storms'),
    (hSW_comp, 'J/m^2', 'concatenated array of hSW across all storms'),
    #wind was mislabelled J/m^2 (copied from h); assumes ds.wind is in m/s -- confirm
    (wind_comp, 'm/s', 'concatenated array of wind across all storms'),
    (centerlat_comp, 'degrees', 'concatenated array of center latitude across all storms'),
    (centerlon_comp, 'degrees', 'concatenated array of center longitude across all storms'),
    #no landmask is applied when building varh in this script, so the name no longer claims one
    (varh_comp, 'J^2/m^4', 'concatenated array of varh across all storms'),
    (latbox_comp, 'degrees', 'concatenated array of 10 deg latbox across all storms'),
    (lonbox_comp, 'degrees', 'concatenated array of 10 deg lonbox across all storms'),
):
    #same key order as before: units, long_name, _FillValue, GridType
    comp.attrs.update(units=units, long_name=long_name, **common_attrs)

#Bundle every concatenated array into one Dataset, keyed by its netCDF variable name
#(insertion order = variable order in the file), and write it out in one go.
concat_vars = {
    'h_concat': h_comp,
    'SEF_concat': SEF_comp,
    'LW_concat': LW_comp,
    'SW_concat': SW_comp,
    'hSEF_concat': hSEF_comp,
    'hLW_concat': hLW_comp,
    'hSW_concat': hSW_comp,
    'wind_concat': wind_comp,
    'centerlat_concat': centerlat_comp,
    'centerlon_concat': centerlon_comp,
    'varh_concat': varh_comp,
    'latbox_concat': latbox_comp,
    'lonbox_concat': lonbox_comp,
}
composite_ds = xr.Dataset(concat_vars)

#previous output location:
#composite_ds.to_netcdf('/home/cdirkes/era5/landmask/vars/test.era5.lmconcats.nc', 'w', format='NETCDF4')
#mode='w' overwrites any existing file at this path
composite_ds.to_netcdf('/home/awing/AAW.era5.lmconcats.nc', mode='w', format='NETCDF4')



