"""
This is the new way of selecting the latbox and lonbox from the ERA5
reanalysis data set. For each TempestExtremes track point, find the index of
the grid point nearest to the track lat and lon, then select a box extending
a fixed number of grid points below and above that center in each direction,
giving a latbox and lonbox that are each approximately 10 degrees wide.
"""
import numpy as np 
import pandas as pd
import xarray as xr
import sys
import netCDF4

# Path of the track file that is read with pd.read_csv (tab-separated, no
# header row) further down in this script.
tempest_filepath = r'/home/awing/ibtracs/trajectories.txt.ERA5'

def boxavg(thing,lat,lon):
    """
    Calculate the cosine(latitude)-weighted box average of a 2-D field.
    e.g., h_avg=boxavg(h,lat,lon)

    Important: thing, lat and lon must already be restricted to the box of
    interest, with thing shaped (len(lat), len(lon)).

    Parameters
    ----------
    thing : array-like, shape (len(lat), len(lon))
        Field to average. NaN entries are ignored (they contribute neither
        to the weighted sum nor to the sum of weights).
    lat : array-like
        Latitudes of the rows of `thing`, in degrees.
    lon : array-like
        Longitudes of the columns of `thing`; only its length is used.

    Returns
    -------
    float
        Weighted average, or NaN if the box holds no non-NaN value.
    """
    values = np.asarray(thing, dtype=float)
    # Weight of every grid cell: cos(lat) repeated along the longitude axis.
    coslat_values = np.transpose(np.tile(np.cos(np.deg2rad(lat)), (len(lon), 1)))
    weighted = values * coslat_values
    # Only cells with data count towards the normalisation. Test for NaN
    # explicitly: the former `weighted/weighted` trick also discarded cells
    # whose value is exactly zero, which biased the average.
    valid = ~np.isnan(weighted)
    weight_sum = np.sum(coslat_values[valid])
    if weight_sum == 0:
        # Empty box or all-NaN box: no average is defined.
        return np.nan
    return np.sum(weighted[valid]) / weight_sum

start_year = 1980 # beginning of most TempestExtremes tracks is 1980
end_year = 2018 # depends on how far the dataset goes
#----------Get the data organized to ultimately group by storm-----------#

# Read the raw track file. Rows whose first field is 'start' are treated as
# storm headers; the rows that follow a header are that storm's track points.
df = pd.read_csv(tempest_filepath, sep='\t', header=None,
                 names=['flag', 'num_timesteps', 'dummy', 'lon', 'lat',
                        'pressure (hPa)', 'windspeed (m/s)', 'dummy2',
                        'year', 'month', 'day', 'hour'])
# drop columns we don't need
df = df.drop(columns=['dummy', 'dummy2'])

# Header rows: one per storm, 'num_timesteps' holds the number of track points.
df_starts = df[df['flag'] == 'start']

# Tag every row of a storm (header + its num_timesteps track points) with a
# sequential storm ID starting at 1. The IDs are arbitrary for each dataset.
storm_id = 1
for idx, num_steps in zip(df_starts.index, df_starts['num_timesteps'].values):
    # .loc slicing is label-inclusive, so idx + num_steps is the last track
    # point of this storm (idx + num_steps + 1 would be the next header).
    df.loc[idx:idx + num_steps, 'stormid'] = storm_id
    storm_id += 1
n_storms = storm_id - 1

# The header rows and their bookkeeping columns are no longer needed now
# that every track point carries its storm ID.
df = df[df['flag'] != 'start']
df = df.drop(columns=['flag', 'num_timesteps']).reset_index(drop=True)

# Divide the pressure column by 100 (Pa -> hPa, as the column name implies).
df['pressure (hPa)'] = df['pressure (hPa)'] / 100.

# Build the timestamp from the integer date components rather than by
# parsing a free-form 'Y M D HH' string, which depends on the date parser's
# guesswork.
date_parts = df[['year', 'month', 'day', 'hour']].astype(int)
df['datetime'] = pd.to_datetime(date_parts)

# Keep the date columns as strings (hour zero-padded to two digits), as the
# rest of the script has always had them. Whole-column assignment is used
# because writing strings into a numeric column through .loc is an
# incompatible-dtype assignment that newer pandas versions warn about.
for column in ('year', 'month', 'day'):
    df[column] = date_parts[column].astype(str)
df['hour'] = date_parts['hour'].astype(str).str.zfill(2)

#------------Now create a list of storm ID's by year for looping through
#    and matching to the data----------#

# One (initially empty) list of storm IDs per year covered by the dataset.
ibtracs_by_year = {year: [] for year in range(start_year, end_year + 1)}

# Group once by storm instead of filtering the whole frame for every storm.
# groupby iterates the storm IDs in ascending order and keeps row order
# within each storm, so first/last are the first/last rows of the track.
for storm, storm_datetimes in df.groupby('stormid')['datetime']:
    storm = int(storm)
    # show the progress in the command line
    sys.stdout.write(f'\r{storm}/{n_storms}')

    first_year = storm_datetimes.iloc[0].year
    last_year = storm_datetimes.iloc[-1].year

    # A storm that starts in one year and ends in another is listed under
    # both years. Years outside [start_year, end_year] are ignored, which
    # also covers storms straddling either end of the period.
    if last_year != first_year and last_year in ibtracs_by_year:
        ibtracs_by_year[last_year].append(storm)
    if first_year in ibtracs_by_year:
        ibtracs_by_year[first_year].append(storm)

# list of the years, in the same order as the dictionary keys
year_list = list(ibtracs_by_year)

# -----------------------------------------------------------------------------------
# This is where we start our main loops to extract all the data.
#
# For every year, every storm of that year and every track time of that storm,
# a box of ERA5 fields centered on the track point is extracted and stored in
# arrays dimensioned (nlat, nlon, nstorms, ntracks). One NetCDF file is written
# per year.
# -----------------------------------------------------------------------------------

filepath = r'/mars/tank3/parfitt/era5/forecasts/RAD/redownloaded_fluxes/' #contains rad fluxes #AAW new path
filepathsef = r'/mars/tank3/parfitt/era5/forecasts/RAD/lhf_shf/' #contains surface fluxes #AAW new path
hpath = '/huracan/tank2/columbia/reanalysis/era5/2D/h/'

# nlat and nlon will be whatever grid spacing makes a 10x10 degree box;
# for ERA5, lat and lon both have 0.25 degree grid spacing
nlat = 41
nlon = 41

# ntracks is the maximum possible number of "observations" along the track of a storm
ntracks = max(df_starts['num_timesteps'])

# The grid is read once: it doesn't matter which flux file we choose because
# they all have the same lat/lon spacing.
with xr.open_dataset(filepath + '1980-01.nc') as dslat:
    latvar = np.array(dslat.latitude)
    lonvar = np.array(dslat.longitude)
idxlat = pd.Index(latvar)
idxlon = pd.Index(lonvar)

# The land-sea mask does not depend on the storm either, so load it once.
with xr.open_dataset(r'/mars/tank3/parfitt/era5/landmask/land-sea_mask.nc') as dslm:
    lm = np.array(dslm['lsm'].isel(time=0))
    idxlatlm = pd.Index(np.array(dslm.latitude))
    idxlonlm = pd.Index(np.array(dslm.longitude))


def _nan_array(*shape):
    """Return a float array of the given shape filled with NaN."""
    return np.full(shape, np.nan)


def _labelled(values, dims, units, long_name, grid_type=None):
    """Wrap `values` in a DataArray carrying the output-file metadata.

    `_FillValue` and `GridType` are only attached when `grid_type` is given.
    """
    attrs = {'units': units, 'long_name': long_name}
    if grid_type is not None:
        attrs['_FillValue'] = -9999
        attrs['GridType'] = grid_type
    return xr.DataArray(values, dims=dims, attrs=attrs)


def _flux_file_and_steps(when, first_flux_year=1980):
    """Choose the monthly flux file and time indices for one track time.

    Returns ``(filename, steps)``. ``steps`` is ``(-6, -5)`` when the values
    are taken from the end of the previous month's file (00Z on the first of
    a month), and ``None`` when the indices have to be found by matching
    `when` against the time axis of the file of the same month.
    Returns ``None`` for times that cannot be served: 00Z on 1 January of
    `first_flux_year` (there is no earlier file) and hours other than
    00/06/12/18Z.
    """
    if when.hour == 0 and when.day == 1:
        if when.month != 1:
            # 00Z on the first of a month other than January: previous month.
            return '%s-%s.nc' % (when.strftime('%Y'), str(when.month - 1).zfill(2)), (-6, -5)
        if when.year != first_flux_year:
            # 00Z on 1 January: December of the PREVIOUS year.
            return '%s-12.nc' % (when.year - 1), (-6, -5)
        return None
    if when.hour in (0, 6, 12, 18):
        return when.strftime('%Y-%m.nc'), None
    return None


def _load_box_fluxes(filename, when, steps, lat_slice, lon_slice):
    """Read the six flux fields for one box and one track time.

    Each field is the sum of two consecutive time steps divided by 2*3600;
    the two surface heat fluxes are additionally sign-flipped.
    NOTE(review): this presumably turns hourly accumulations into a mean
    rate in W/m^2 - confirm against the file contents.

    Returns ``(lhf, shf, ssw, slw, tsw, tlw)`` as numpy arrays, or ``None``
    when `steps` is None and `when` (or the step after it) is not present on
    the file's time axis.
    """
    with xr.open_dataset(filepathsef + filename) as dssef, \
            xr.open_dataset(filepath + filename) as ds:
        if steps is None:
            # select the index where the track time and the file time match
            file_times = pd.to_datetime(ds.time.values).to_pydatetime()
            matches = [i for i, file_time in enumerate(file_times) if file_time == when]
            if not matches or matches[-1] + 1 >= min(ds.time.size, dssef.time.size):
                return None
            steps = (matches[-1], matches[-1] + 1)

        def box_rate(variable, divisor):
            first, second = (
                variable.isel(time=step, latitude=lat_slice, longitude=lon_slice)
                for step in steps
            )
            # Convert to numpy while the file is still open.
            return np.asarray((first + second) / divisor)

        lhf = box_rate(dssef['slhf'], -3600 * 2)
        shf = box_rate(dssef['sshf'], -3600 * 2)
        ssw = box_rate(ds['ssr'], 3600 * 2)
        slw = box_rate(ds['str'], 3600 * 2)
        tsw = box_rate(ds['tsr'], 3600 * 2)
        tlw = box_rate(ds['ttr'], 3600 * 2)
    return lhf, shf, ssw, slw, tsw, tlw


def _land_mask(latbox, lonbox):
    """Return an array that is NaN over land and 0 elsewhere for the box.

    Every box point is matched to the nearest land-sea-mask grid point; the
    point counts as land when the mask value there exceeds 0.2. Adding the
    result to a field therefore blanks out the land points.
    """
    mask = np.zeros((len(latbox), len(lonbox)))
    if mask.size:
        rows = idxlatlm.get_indexer(latbox, method='nearest')
        cols = idxlonlm.get_indexer(lonbox, method='nearest')
        mask[lm[np.ix_(rows, cols)] > 0.2] = np.nan
    return mask


box_dims = ['nlat', 'nlon', 'nstorms', 'ntracks']
track_dims = ['nstorms', 'ntracks']

print('\n')
for current_year in range(1980,1990): #range(1980,2017)

    storms_this_year = ibtracs_by_year[current_year]
    # nstorms = number of storms in that year
    nstorms = len(storms_this_year)

    # NaN-filled output arrays; entries that are never written stay NaN.
    # Box fields are (nlat, nlon, nstorms, ntracks).
    (SWanomsave, LWanomsave, SEFanomsave,
     SWsave, LWsave, SEFsave, lhfsave, shfsave,
     hsave, varhsave, hSWsave, hLWsave, hSEFsave) = (
        _nan_array(nlat, nlon, nstorms, ntracks) for _ in range(13))
    boxlat = _nan_array(nlat, nstorms, ntracks)
    boxlon = _nan_array(nlon, nstorms, ntracks)
    (yearsave, monthsave, daysave, hoursave, windsave,
     centerlatsave, centerlonsave) = (
        _nan_array(nstorms, ntracks) for _ in range(7))

    for s, storm in enumerate(storms_this_year): # start looping through every storm
        sys.stdout.write(f'\r{s}/{nstorms}')

        ds_storm = df[df['stormid']==storm]
        # track times as datetime.datetime, truncated to the minute to avoid
        # weird rounding errors when comparing with the file times
        time_list = [tstamp.to_pydatetime().replace(second=0, microsecond=0)
                     for tstamp in pd.to_datetime(ds_storm['datetime'].values)]

        track_lat = ds_storm['lat'].values
        track_lon = ds_storm['lon'].values
        wind = ds_storm['windspeed (m/s)'].values

        # Loop through each time and assign the flux variables to the TC
        for t, when in enumerate(time_list):

            # time steps of a multi-year storm that fall in another year are
            # handled when that year is processed
            if when.year != current_year:
                continue

            selection = _flux_file_and_steps(when)
            if selection is None:
                # Previously such a step crashed or silently reused the
                # fluxes of the preceding step; leave it as NaN instead.
                print(f'\nskipping {when}: no flux data defined for this time')
                continue
            filename, steps = selection

            # index of the grid point nearest to the storm center
            ilat = int(idxlat.get_indexer([track_lat[t]], method='nearest')[0])
            ilon = int(idxlon.get_indexer([track_lon[t]], method='nearest')[0])

            # 20 grid points either side of the center -> 41 x 41 box.
            # NOTE(review): a center closer than 20 points to the start of an
            # axis gives a negative slice start, which Python slicing
            # interprets from the end (typically an empty box); a center near
            # the end of an axis gives a truncated box.
            lat_slice = slice(ilat - 20, ilat + 21)
            lon_slice = slice(ilon - 20, ilon + 21)
            latbox = latvar[lat_slice]
            lonbox = lonvar[lon_slice]

            fluxes = _load_box_fluxes(filename, when, steps, lat_slice, lon_slice)
            if fluxes is None:
                print(f'\nskipping {when}: time not found in {filename}')
                continue
            lhf, shf, ssw, slw, tsw, tlw = fluxes

            LW = tlw - slw
            SW = tsw - ssw
            SEF = lhf + shf

            # read in h from huracan
            hfilename = 'era5.h.%s.nc' % when.strftime('%Y%m%d%H')
            with xr.open_dataset(hpath + hfilename) as dsh:
                h = np.array(dsh['h'].isel(latitude=lat_slice, longitude=lon_slice))

            lsm = _land_mask(latbox, lonbox)

            # Box averages are taken over the full box, before the land
            # points are blanked out.
            havg = boxavg(h, latbox, lonbox)
            SWavg = boxavg(SW, latbox, lonbox)
            LWavg = boxavg(LW, latbox, lonbox)
            SEFavg = boxavg(SEF, latbox, lonbox)

            # blank out land points (lsm is NaN over land, 0 elsewhere)
            h = h + lsm
            SEF = SEF + lsm
            SW = SW + lsm
            LW = LW + lsm

            hanom = h - havg
            SWanom = SW - SWavg
            LWanom = LW - LWavg
            SEFanom = SEF - SEFavg

            # The box may be smaller than nlat x nlon, so only the part that
            # exists is written; the remainder stays NaN.
            box = (slice(0, len(latbox)), slice(0, len(lonbox)), s, t)

            hsave[box] = h
            varhsave[box] = hanom * hanom

            # NOTE(review): the three anomaly arrays are filled but are not
            # part of the dataset written below.
            SWanomsave[box] = SWanom
            LWanomsave[box] = LWanom
            SEFanomsave[box] = SEFanom

            SWsave[box] = SW
            LWsave[box] = LW
            SEFsave[box] = SEF
            # lhf and shf are stored without the land mask applied
            lhfsave[box] = lhf
            shfsave[box] = shf

            hSWsave[box] = hanom * SWanom
            hLWsave[box] = hanom * LWanom
            hSEFsave[box] = hanom * SEFanom

            windsave[s,t] = wind[t]

            boxlat[0:len(latbox),s,t] = latbox
            boxlon[0:len(lonbox),s,t] = lonbox

            centerlatsave[s,t] = track_lat[t]
            centerlonsave[s,t] = track_lon[t]

            yearsave[s,t] = when.year
            monthsave[s,t] = when.month
            daysave[s,t] = when.day
            hoursave[s,t] = when.hour

    # Convert everything to DataArrays with units and attributes. This is
    # done once per year, after all storms are processed, so it also works
    # for a year in which no time step was processed.
    # NOTE(review): the GridType strings of the lat/lon variables quote
    # 0.5 / 0.625 degree spacing while the comment above says ERA5 is
    # 0.25 degree; they are kept unchanged here - confirm and update.
    yearsave = _labelled(yearsave, track_dims, 'year of specific storm', 'year')
    monthsave = _labelled(monthsave, track_dims, 'month of specific storm', 'month')
    daysave = _labelled(daysave, track_dims, 'day of specific storm', 'day')
    hoursave = _labelled(hoursave, track_dims, '6hr accumulation centered around hour', 'hour, UTC')

    SWsave = _labelled(SWsave, box_dims, 'W/m^2', 'Net shortwave flux', 'Lat/lon grid')
    LWsave = _labelled(LWsave, box_dims, 'W/m^2', 'Net longwave flux', 'Lat/lon Grid')
    SEFsave = _labelled(SEFsave, box_dims, 'W/m^2', 'Surface enthalpy flux', 'Lat/lon Grid')
    lhfsave = _labelled(lhfsave, box_dims, 'W/m^2', 'latent heat flux', 'Lat/lon Grid')
    shfsave = _labelled(shfsave, box_dims, 'W/m^2', 'sensible heat flux', 'Lat/lon Grid')

    SWanomsave = xr.DataArray(SWanomsave, dims=box_dims)
    LWanomsave = xr.DataArray(LWanomsave, dims=box_dims)
    SEFanomsave = xr.DataArray(SEFanomsave, dims=box_dims)

    hsave = _labelled(hsave, box_dims, 'J/m^2', 'column integrated frozen moist static energy', 'Lat/lon Grid')
    varhsave = _labelled(varhsave, box_dims, 'J^2/m^4', 'variance of column integrated FMSE', 'Lat/lon Grid')
    hSWsave = _labelled(hSWsave, box_dims, 'J^2 m^-4 s^-1', 'product of hanom and SWanom', 'Lat/lon Grid')
    hLWsave = _labelled(hLWsave, box_dims, 'J^2 m^-4 s^-1', 'product of hanom and LWanom', 'Lat/lon Grid')
    hSEFsave = _labelled(hSEFsave, box_dims, 'J^2 m^-4 s^-1', 'product of hanom and SEFanom', 'Lat/lon Grid')

    windsave = _labelled(windsave, track_dims, 'm/s', 'maximum wind speed', 'Lat/lon Grid')

    boxlat = _labelled(boxlat, ['nlat', 'nstorms', 'ntracks'], 'Degrees', 'Latitude', '0.5 deg Latitude Spacing')
    boxlon = _labelled(boxlon, ['nlon', 'nstorms', 'ntracks'], 'Degrees', 'Longitude', '0.625 deg Longitude Spacing')
    centerlatsave = _labelled(centerlatsave, track_dims, 'Degrees', 'Center latitude of storm', '0.5 deg Latitude Spacing')
    centerlonsave = _labelled(centerlonsave, track_dims, 'Degrees', 'Center longitude of storm', '0.625 Longitude Spacing')

    # Write the variables to one netcdf file per year (inside the loop over
    # the years, but outside the loops over storms and track times).
    varbudget_ds = xr.Dataset({
        'SW': SWsave, 'LW': LWsave, 'SEF': SEFsave, 'wind': windsave,
        'latitude': boxlat, 'longitude': boxlon,
        'centerlat': centerlatsave, 'centerlon': centerlonsave,
        'year': yearsave, 'month': monthsave, 'day': daysave, 'hour': hoursave,
        'h': hsave, 'varh': varhsave, 'hSW': hSWsave, 'hLW': hLWsave,
        'hSEF': hSEFsave, 'lhf': lhfsave, 'shf': shfsave,
    })
    #varbudget_ds.to_netcdf('/home/cdirkes/era5/varbudget_output/test_varbudget.era5.'+str(current_year)+'.nc', 'w', format='NETCDF4')
    varbudget_ds.to_netcdf('/home/awing/era5/test_varbudget.era5.'+str(current_year)+'.nc', 'w', format='NETCDF4') #AAW save
    print('\n')

