#%%
from re import S
import numpy as np
import xarray as xr
import pandas as pd
import matplotlib.pyplot as plt

###NOW WORK WITH BINNING REANALYSIS DATA#################################################

#Functions
def boxavg(thing,lat,lon):
    """Return the cos(latitude)-weighted mean of a 2-D field, ignoring NaNs.

    Parameters
    ----------
    thing : array_like, shape (len(lat), len(lon))
        Field to average. Grid points that are NaN contribute neither to the
        weighted sum nor to the sum of weights.
    lat : array_like
        Latitudes (degrees) of the rows of ``thing``.
    lon : array_like
        Longitudes of the columns of ``thing``; only its length is used.

    Returns
    -------
    numpy.float64
        Area-weighted box average, or NaN when every grid point is NaN.
    """
    thing = np.asarray(thing)
    # One weight per grid point: cos(lat) repeated across each row of the box.
    coslat_values = np.transpose(np.tile(np.cos(np.deg2rad(np.asarray(lat))),(len(lon),1)))
    thing1 = thing*coslat_values
    # Build the weight mask explicitly from NaNs. Dividing the field by itself
    # (x/x) would also turn exact zeros into NaN and wrongly drop them from the
    # weights, biasing the average away from zero.
    valid = ~np.isnan(thing1)
    weight_sum = np.sum(np.where(valid, coslat_values, 0.0))
    # An all-NaN box gives 0/0 -> NaN; silence the expected warning.
    with np.errstate(invalid='ignore', divide='ignore'):
        average = np.nansum(thing1)/weight_sum

    return average

#Range of max wind speeds covered by the wind bins (bins are 3 wide in the binning step)
minbin = 0
maxbin = 66

#Open the reanalysis data, which already has its years concatenated and is trimmed after LMI.
#Choose the reanalysis to run through the code by setting save_name to one of the keys below.
save_name = 'CFSR'

#Input file and (lat, lon) grid spacing for each supported reanalysis
reanalysis_configs = {
    'CFSR':   (r'/home/awing/tc-pod/Dirkes2022/AAW.cfsr.lmconcats.nc', 0.5, 0.5),
    'ERA5':   (r'/home/awing/tc-pod/Dirkes2022/AAW.era5.lmconcats.nc', 0.25, 0.25),
    'JRA55':  (r'/home/awing/tc-pod/Dirkes2022/AAW.native.jra55.lmconcats.nc', 0.56162167, 0.5625),
    'MERRA2': (r'/home/awing/tc-pod/Dirkes2022/AAW.merra2.lmconcats.nc', 0.5, 0.625),
    'ERAINT': (r'/home/cdirkes/eraint/landmask/vars/eraint.lmconcats.nc', 0.701517, 0.703125),
}
datapath, latres, lonres = reanalysis_configs[save_name]
reanalysisdata = xr.open_dataset(datapath,cache=False).load()

if save_name == 'ERAINT':
    #ERAINT box coordinates are listed explicitly instead of being generated from the spacing
    lats = np.array([-4.9122619, -4.2105102, -3.5087585, -2.8070068, -2.1052551, -1.4035034, -0.7017517,
                     0, 0.7017517, 1.4035034, 2.1052551, 2.8070068, 3.5087585, 4.2105102, 4.9122619])
    lons = np.array([-4.921875, -4.21875, -3.515625, -2.8125, -2.109375, -1.40625, -0.703125,
                     0, 0.703125, 1.40625, 2.109375, 2.8125, 3.515625, 4.21875, 4.921875])
else:
    #Box coordinates from -5 towards 5 at the reanalysis' grid spacing
    lats = np.arange(-5,5+latres,latres)
    lons = np.arange(-5,5+lonres,lonres)

#Number of 3-wide bins up to maxbin (NOTE(review): not referenced anywhere in this script)
rawmaxbin = maxbin/3

#Reorder dimensions so the reanalysis data match the model's (nstorms, ntracks, nlat, nlon) layout;
#the swap_dims mapping is the identity, so the dimension names stay the same
reanalysisdata = reanalysisdata.transpose('nstorms','ntracks','nlat','nlon')
reanalysisdata = reanalysisdata.swap_dims({dim: dim for dim in ('nstorms','ntracks','nlat','nlon')})

h_all = np.asarray(reanalysisdata.h_concat)
latbox_all = np.asarray(reanalysisdata.latbox_concat)
lonbox_all = np.asarray(reanalysisdata.lonbox_concat)

#Grid points per snapshot; dividing a grid-point count by this converts it into a snapshot count
count_denom = len(latbox_all[0, 0]) * len(lonbox_all[0, 0])

#'hanom': each snapshot's h_concat minus that snapshot's cos(lat)-weighted box average
hanom_var = np.full(h_all.shape, np.nan)
for s, t in np.ndindex(*h_all.shape[:2]):
    hsnap = h_all[s, t]
    hanom_var[s, t] = hsnap - boxavg(hsnap, latbox_all[s, t], lonbox_all[s, t])
reanalysisdata['hanom'] = (['nstorms','ntracks','nlat','nlon'], hanom_var.copy())

#Get list of vars to loop through to NaN all variables at timesteps where the TC is outside of 30 N/S
#(captured here, so variables added below are not part of this list)
vars = list(reanalysisdata.keys())
#Tag and bin the reanalysis data
maxwinds = reanalysisdata.wind_concat
winds_list = []

nsnap_storms = len(maxwinds)
nsnap_tracks = len(maxwinds[0])
latbox_all = np.asarray(reanalysisdata.latbox_concat)
lonbox_all = np.asarray(reanalysisdata.lonbox_concat)

#Box average of varh_concat for every snapshot. It is the normalization factor for every
#feedback variable below and does not depend on which variable is being normalized, so it
#is computed once here rather than once per variable.
varh_all = np.asarray(reanalysisdata.varh_concat)
boxav_varh = np.full((nsnap_storms, nsnap_tracks), np.nan)
for s in range(nsnap_storms):
    for t in range(nsnap_tracks):
        boxav_varh[s, t] = boxavg(varh_all[s, t], latbox_all[s, t], lonbox_all[s, t])

#For each feedback variable (name starting with 'h' or 'varh') add:
#  norm<var>        - the field divided by that snapshot's box-averaged varh_concat
#  boxav_<var>      - box average of the raw field per snapshot
#  boxav_norm_<var> - box average of the normalized field per snapshot
for var in vars:
    if var.startswith(('h', 'varh')):
        rawvar = np.array(reanalysisdata[var])
        normvar = rawvar.copy()
        boxavvar = np.full((nsnap_storms, nsnap_tracks), np.nan)
        boxavnormvar = np.full((nsnap_storms, nsnap_tracks), np.nan)
        for s in range(nsnap_storms):
            for t in range(nsnap_tracks):
                normvar[s, t] = normvar[s, t]/boxav_varh[s, t]
                boxavvar[s, t] = boxavg(rawvar[s, t], latbox_all[s, t], lonbox_all[s, t])
                boxavnormvar[s, t] = boxavg(normvar[s, t], latbox_all[s, t], lonbox_all[s, t])
        reanalysisdata['norm'+var] = (['nstorms','ntracks','nlat','nlon'], normvar)
        reanalysisdata['boxav_'+var] = (['nstorms','ntracks'], boxavvar)
        reanalysisdata['boxav_norm_'+var] = (['nstorms','ntracks'], boxavnormvar)

#For every storm (s) and track step (t):
#  1. if the TC center latitude is poleward of 30 (N or S), overwrite that snapshot with NaN in
#     every variable listed in `vars` (modifying reanalysisdata in place);
#  2. build a 'vmax' field shaped like the h_concat snapshot and filled with that step's max wind,
#     so every grid point carries its snapshot's wind tag for binning.
#NOTE(review): `vars` was captured before the norm*/boxav_* variables were added, so those
#derived variables are not NaN'd here. wind_concat is in `vars`; if the in-place assignment below
#propagates to `maxwinds` (taken from reanalysisdata.wind_concat), vmax is NaN at these steps and
#they fall outside every wind bin -- confirm.
for s in range(0,len(maxwinds)):
    #Progress indicator, overwritten in place on the same terminal line
    print(s,end='\r')
    vmax_indiv_list = []
    for t in range(0,len(maxwinds[s])):
        #First check and NaN all variables at timesteps where the TC center is outside 30 N/S.
        #Indexing assumes the first two dims are (nstorms, ntracks); anything that is not 2-D or
        #3-D is treated as 4-D (nstorms, ntracks, nlat, nlon).
        if(reanalysisdata.centerlat_concat[s][t]>30 or reanalysisdata.centerlat_concat[s][t]<-30):
            for var in vars:
                if(reanalysisdata[var].ndim==2):
                    reanalysisdata[var][s,t] = np.nan
                elif(reanalysisdata[var].ndim==3):
                    reanalysisdata[var][s,t,:] = np.nan
                else:
                    reanalysisdata[var][s,t,:,:] = np.nan
        #Get max wind at this step (read after the NaN-ing above) to tag the snapshot for binning
        vmax_sel = maxwinds[s,t].values
        #Only the shape/coords/dtype of the h_concat snapshot are used; values come from vmax_sel
        vmax = xr.full_like(reanalysisdata.h_concat[s,t],float(vmax_sel)).rename('vmax')
        vmax_indiv_list.append(vmax)
    #Stack this storm's per-step tag fields along a new 'ntracks' dimension
    vmax_indiv_array = xr.concat(vmax_indiv_list,dim='ntracks')
    #Collect the storm's vmax tag; storms are concatenated along 'nstorms' afterwards
    winds_list.append(vmax_indiv_array)

#Update reanalysis data with the vmax tag created above (storms stacked along 'nstorms')
winds_array = xr.concat(winds_list,dim='nstorms')
reanalysisupdated = xr.merge([reanalysisdata,winds_array])

#Flatten every box-average variable (nstorms, ntracks) and the max wind into a single
#'newsteps' dimension so each snapshot becomes one sample for binning.
#reshape(nsteps) always yields a 1-D array (np.squeeze would collapse a single step to 0-d,
#which no longer fits the ['newsteps'] dimension) and still raises if a variable has the wrong size.
nsteps = len(reanalysisdata.nstorms)*len(reanalysisdata.ntracks)
newvars = list(reanalysisupdated.keys())
for var in newvars:
    if var.startswith('boxav'):
        reanalysisupdated['new_'+var] = (['newsteps'], np.array(reanalysisupdated[var]).reshape(nsteps))

reanalysisupdated['new_maxwind'] = (['newsteps'], np.array(reanalysisupdated['wind_concat']).reshape(nsteps))

#Bin snapshots according to max wind speed: bins of width binwidth starting at minbin, below maxbin
binwidth = 3
bins = np.arange(minbin,maxbin,binwidth)
#Sample size (number of snapshots) for each bin
bins_count = np.zeros(len(bins))
vmaxtags = reanalysisupdated.vmax
onedmaxwind = reanalysisupdated.new_maxwind
#Bin-index fields start all-NaN so points outside every bin (below minbin, at/above maxbin, or NaN
#wind) can never be mistaken for a bin index during compositing. Starting from the raw winds
#would let, e.g., a wind of exactly 5.0 below minbin=9 later match bin index 5.
vmax2 = xr.full_like(vmaxtags, np.nan, dtype=float)
onedvmax = xr.full_like(onedmaxwind, np.nan, dtype=float)
for b, lowerbin in enumerate(bins):
    print(lowerbin)
    upperbin = lowerbin+binwidth
    inbin = (vmaxtags>=lowerbin)&(vmaxtags<upperbin)
    #Grid points in the bin divided by grid points per snapshot gives the number of snapshots
    bins_count[b] = np.count_nonzero(inbin.values)/count_denom
    vmax2 = xr.where(inbin, b, vmax2)
    onedvmax = xr.where((onedmaxwind>=lowerbin)&(onedmaxwind<upperbin), b, onedvmax)
bin_ds = xr.Dataset(data_vars=dict(bins=(['nstorms','ntracks','nlat','nlon'],vmax2.values)))
onedbin_ds = xr.Dataset(data_vars=dict(newbins=(['newsteps'],onedvmax.values)))
ds = xr.merge([reanalysisupdated,bin_ds,onedbin_ds])
#Keep the bin indices as coordinates so they are not composited as data variables
ds = ds.set_coords(['bins','newbins'])

#Get the mean of each bin for one composite image.
#Bin indices and bin-centre labels are derived from minbin/maxbin and the 3-wide bins used in the
#binning step (not hard-coded), so they stay consistent with bins_count if those limits change.
binlabels = np.arange(minbin,maxbin,3) + 1.5
bins = np.arange(len(binlabels))
dsbins = ds['bins'].values
dsnewbins = ds['newbins'].values
binmeans = {}
binboxavstdevs = {}
for var_name, da in ds.data_vars.items():
    print(var_name)
    dvar = da.values
    if dvar.ndim == 4:
        #Mean lat/lon map over all (storm, track) snapshots whose vmax falls in each bin
        binmeans[var_name] = (['bin','lat','lon'],
                              [np.nanmean(np.where(dsbins==b, dvar, np.nan), axis=(0,1)) for b in bins])
    elif dvar.ndim == 1:
        #Mean and standard deviation over all flattened snapshots (newsteps) in each bin
        avg_bin_list = []
        stdev_bin_list = []
        for b in bins:
            inbin = np.where(dsnewbins==b, dvar, np.nan)
            avg_bin_list.append(np.nanmean(inbin, axis=0))
            stdev_bin_list.append(np.nanstd(inbin, axis=0))
        binmeans[var_name] = (['bin'], avg_bin_list)
        binboxavstdevs[var_name] = (['bin'], stdev_bin_list)

#Attach bin-centre labels, box lat/lon coordinates and per-bin sample sizes to the bin means,
#plus the raw max wind, which is neither binned nor composited
binmeans.update({
    'bin': (['bin'],binlabels),
    'lat': (['lat'],lats),
    'lon': (['lon'],lons),
    'bincounts': ('bin',bins_count),
    'maxwind': (['nstorms','ntracks'],np.array(reanalysisdata['wind_concat'])),
})
#binmeans['minSLP'] = (['nstorms','ntracks'],np.array(reanalysisdata['minSLP']))

#The box-average standard deviations carry the same bin labels and sample sizes
binboxavstdevs.update({
    'bin': (['bin'],binlabels),
    'bincounts': ('bin',bins_count),
})

#NOTE(review): this description mentions Min. MSLP, but minSLP is commented out above
reanalysisbinavgdata = xr.Dataset(data_vars=binmeans, attrs={'description':'Mean Binned Data and Original Variables (Wind and Min. MSLP)'})
reanalysisbinboxavstdevdata = xr.Dataset(data_vars=binboxavstdevs, attrs={'description':'Standard Deviations of Binned Box Averaged Variables'})

outdir = '/home/jstarr2/TC-MSE-Var_POD/obsdata/'
reanalysisbinavgdata.to_netcdf(outdir+save_name+'_Binned_Composites_with_normFeedbacks.nc')
reanalysisbinboxavstdevdata.to_netcdf(outdir+save_name+'_Binned_STDEVS_of_BoxAvgs.nc')