#!/usr/bin/env python
# coding: utf-8

# # Code to calculate column-integrated moist static energy from 3D data and write out to 2D netcdf file

# In[1]:


import sys
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
from cartopy import config
import cartopy.feature as cfeature
from cartopy.vector_transform import vector_scalar_to_grid
from matplotlib.axes import Axes
import scipy as sp
import xarray as xr
import math
import datetime as dt
import cfgrib


# In[2]:


def read_data_box(latbox,lonbox):
    """
    Pull the fields needed for the moist static energy calculation at a single
    time, restricted to a latitude/longitude box.

    The already-opened datasets `dst`, `dsq`, `dsz`, `dsp` and the timestamp
    `targettime` are read from module scope (they are set by the driver loop).
    All four datasets are closed before returning.

    latbox, lonbox : label selectors handed to `.sel` for the 'latitude' and
                     'longitude' dimensions
    returns        : tabs, qv, z, ps, lat, lon
    e.g., tabs,qv,z,ps,lat,lon = read_data_box(latbox,lonbox)
    """
    #(dataset, variable name) pairs, in the order the fields are returned:
    #temperature (K), specific humidity (kg/kg), geopotential height (gpm), surface pressure (Pa)
    sources = ((dst, 't'), (dsq, 'q'), (dsz, 'gh'), (dsp, 'sp'))

    #Coordinates of the box, taken from the temperature dataset
    lat = dst['latitude'].sel(latitude=latbox)
    lon = dst['longitude'].sel(longitude=lonbox)

    #Same region/time selection applied to every field
    tabs, qv, z, ps = (
        handle[name].sel(latitude=latbox, longitude=lonbox, time=targettime)
        for handle, name in sources
    )

    for handle, _ in sources:
        handle.close()

    return tabs, qv, z, ps, lat, lon


# In[3]:


def read_data():
    """
    Pull the fields needed for the moist static energy calculation at a single
    time over the whole domain.

    The already-opened datasets `dst`, `dsq`, `dsz`, `dsp` and the timestamp
    `targettime` are read from module scope (they are set by the driver loop).
    All four datasets are closed before returning.

    returns : tabs, qv, z, ps, lat, lon
    e.g., tabs,qv,z,ps,lat,lon = read_data()
    """
    #(dataset, variable name) pairs, in the order the fields are returned:
    #temperature (K), specific humidity (kg/kg), geopotential height (gpm), surface pressure (Pa)
    sources = ((dst, 't'), (dsq, 'q'), (dsz, 'gh'), (dsp, 'sp'))

    #Full-domain coordinates, taken from the temperature dataset
    lat = dst['latitude']
    lon = dst['longitude']

    #Select the requested time from every field
    tabs, qv, z, ps = (handle[name].sel(time=targettime) for handle, name in sources)

    for handle, _ in sources:
        handle.close()

    return tabs, qv, z, ps, lat, lon


# In[4]:


def calc_pressure(ps):
    """
    Calculate pressure on the hybrid model levels from surface pressure.

    Interface pressures are phalf = hya + hyb*ps for the 61 coefficient pairs
    tabulated below; index 0 is the surface (hya=0, hyb=1) and index 60 is the
    model top (hya=hyb=0). Mid-point pressures for the 59 lowest layers are
        exp( (p_lo*ln(p_lo) - p_up*ln(p_up)) / (p_lo - p_up) - 1 )
    with p_lo/p_up the bounding interface pressures; the top layer, whose upper
    interface is 0 Pa, uses half of its lower interface pressure.

    Parameters
    ----------
    ps : 2-D array-like (latitude x longitude)
        Surface pressure in Pa, e.g. the `ps` returned by read_data.

    Returns
    -------
    plevm : xr.DataArray, dims ('hybrid','latitude','longitude'), hybrid = 1..60
        Pressure at layer mid-points in hPa.
    plevi : xr.DataArray, dims ('ilev','latitude','longitude'), ilev = 1..61
        Pressure at layer interfaces in hPa.

    Raises
    ------
    ValueError
        If `ps` is not two-dimensional.

    Notes
    -----
    The latitude/longitude coordinates are read from the module-level `lat`
    and `lon` set by the driver loop, so read_data must have been run first.
    """
    #Hybrid coefficients on the 61 layer interfaces, ordered surface -> model top (hya in Pa)
    hya = np.array([
           0.000000000000000000,0.000000000000000000,0.000000000000000000,0.000000000000000000,0.000000000000000000,
           0.000000000000000000,0.000000000000000000,0.000000000000000000,133.051011276943000000,364.904148871589000000,
           634.602716447362000000,959.797167291774000000,1347.680041655150000000,1790.907395951100000000,2294.841689948500000000,
           2847.484777711760000000,3468.871488118640000000,4162.956462969160000000,4891.880832504910000000,5671.824239804080000000,
           6476.712996385320000000,7297.469894720490000000,8122.159791249150000000,8914.082201062340000000,9656.181910501640000000,
           10329.436177774600000000,10912.638444238700000000,11369.647830843200000000,11695.371597470000000000,
           11861.253087394800000000,11855.434316349300000000,11663.355365580300000000,11285.404064494200000000,
           10729.949405567900000000,10014.615053510700000000,9167.247035833100000000,8226.244907704420000000,
           7201.568980298280000000,6088.673008533920000000,4950.000000000000000000,4000.000000000000000000,
           3230.000000000000000000,2610.000000000000000000,2105.000000000000000000,1700.000000000000000000,
           1370.000000000000000000,1105.000000000000000000,893.000000000000000000,720.000000000000000000,581.000000000000000000,
           469.000000000000000000,377.000000000000000000,301.000000000000000000,237.000000000000000000,182.000000000000000000,
           136.000000000000000000,97.000000000000000000,65.000000000000000000,39.000000000000000000,20.000000000000000000,
           0.000000000000000000])
    hyb = np.array([
           1.000000000000000000,0.997000000000000000,0.994000000000000000,0.989000000000000000,0.982000000000000000,
           0.972000000000000000,0.960000000000000000,0.946000000000000000,0.926669489887231000,0.904350958511284000,
           0.879653972835526000,0.851402028327082000,0.819523199583449000,0.785090926040489000,0.748051583100515000,
           0.709525152222882000,0.668311285118814000,0.624370435370308000,0.580081191674951000,0.534281757601959000,
           0.488232870036147000,0.442025301052795000,0.395778402087509000,0.350859177989377000,0.307438180894984000,
           0.265705638222254000,0.225873615557613000,0.189303521691568000,0.155046284025300000,0.124387469126052000,
           0.096445656836507500,0.072366446344196600,0.052145959355057800,0.035700505944321400,0.022853849464893500,
           0.013327529641668900,0.006737550922955820,0.002484310197017220,0.000113269914660783,0.000000000000000000,
           0.000000000000000000,0.000000000000000000,0.000000000000000000,0.000000000000000000,0.000000000000000000,
           0.000000000000000000,0.000000000000000000,0.000000000000000000,0.000000000000000000,0.000000000000000000,
           0.000000000000000000,0.000000000000000000,0.000000000000000000,0.000000000000000000,0.000000000000000000,
           0.000000000000000000,0.000000000000000000,0.000000000000000000,0.000000000000000000,0.000000000000000000,
           0.000000000000000000])

    ps2d = np.asarray(ps)
    if ps2d.ndim != 2:
        raise ValueError(f"ps must be 2-D (latitude x longitude), got shape {ps2d.shape}")

    #Interface pressures (Pa), lev x lat x lon. Broadcasting replaces explicit tiling.
    phalf = hya[:, np.newaxis, np.newaxis] + hyb[:, np.newaxis, np.newaxis]*ps2d[np.newaxis, :, :]
    #Layer thickness (Pa), positive because pressure decreases with index
    dphalf = phalf[:-1] - phalf[1:]

    #Calculate pressure on full (mid-point) levels, one per layer
    nlayers = hya.size - 1 #60
    pfull = np.empty((nlayers,) + ps2d.shape)
    p_lo = phalf[:-2]  #lower (higher-pressure) interface of layers 1..59
    p_up = phalf[1:-1] #upper interface of the same layers; never 0, so log is safe
    pfull[:-1] = np.exp((1/dphalf[:-1])*(p_lo*np.log(p_lo) - p_up*np.log(p_up)) - 1)
    #Top layer: its upper interface is 0 Pa, so the log formula cannot be used
    pfull[-1] = 0.5*phalf[-2]

    mlevels = np.arange(1, nlayers + 1)
    ilevels = np.arange(1, nlayers + 2)

    #Convert Pa -> hPa BEFORE wrapping: arithmetic on a DataArray discards attrs,
    #which previously left the returned arrays without their 'units' attribute.
    plevm = xr.DataArray(pfull/100, dims=['hybrid','latitude','longitude'],
                         coords={'hybrid':mlevels,'latitude':lat,'longitude':lon},
                         attrs={'units':'hPa'}) #midpoint hybrid levels (60 x lat x lon)
    plevi = xr.DataArray(phalf/100, dims=['ilev','latitude','longitude'],
                         coords={'ilev':ilevels,'latitude':lat,'longitude':lon},
                         attrs={'units':'hPa'}) #interface levels (61 x lat x lon)

    return plevm,plevi


# In[5]:


### pressure level interpolation (adapted from GEOCAT)

import cf_xarray
import metpy.interpolate
#import numpy as np
#import xarray as xr

# Mandatory pressure levels (mb); default target levels for interp_hybrid_to_pressure.
# The values are kept in mb here (the former mb -> Pa conversion is disabled).
__pres_lev_mandatory__ = np.array(
    [1000, 925, 850, 700, 500, 400, 300, 250, 200, 150, 100,
     70, 50, 30, 20, 10, 7, 5, 3, 2, 1],
    dtype=np.float32,
)

def interp_hybrid_to_pressure(data,
                              plevm,
                              new_levels=__pres_lev_mandatory__,
                              lev_dim=None,
                              method='linear'):
    """Interpolate data from hybrid-sigma levels to isobaric levels.

    Notes
    -----
    ACKNOWLEDGEMENT: We'd like to thank to Brian Medeiros (https://github.com/brianpm), Matthew Long
    (https://github.com/matt-long), and Deepak Cherian (https://github.com/dcherian) at NCAR for their
    great contributions since the code implemented here is mostly based on their work.

    Unlike the routine this was adapted from, the pressure on the hybrid levels is
    not computed here; it is supplied by the caller through `plevm`.

    Parameters
    ----------
    data : xarray.DataArray
        Multidimensional data array on hybrid-sigma levels, with a `lev_dim` dimension.
    plevm : xarray.DataArray
        Pressure on the mid-point (hybrid-sigma) levels. Must have the same dimensions
        and shape as `data` (any dimension order), in the same units as `new_levels`.
    new_levels : array-like
        One-dimensional output pressure levels. If not given, the mandatory
        list of 21 pressure levels is used.
    lev_dim : str, optional
        Name of the level dimension in `data`. If omitted, it is looked up through
        cf_xarray as ``data.cf["vertical"]``.
    method : str
        Interpolation method; either "linear" or "log". Defaults to "linear".

    Returns
    -------
    xarray.DataArray
        Dask-backed array with the dimensions of `data`, except that `lev_dim` is
        replaced by "plev" (coordinate = `new_levels`). Coordinates of `data` that do
        not span `lev_dim` are carried over.

    Raises
    ------
    ValueError
        If the vertical dimension cannot be determined or is not a dimension of
        `data`, or if `method` is not "linear" or "log".
    """
    # Accept lists/tuples as well as arrays; `.size` is needed further down
    new_levels = np.atleast_1d(new_levels)

    # Determine the level dimension and then the interpolation axis
    if lev_dim is None:
        try:
            lev_dim = data.cf["vertical"].name
        except Exception as err:
            # Keep the original failure attached for debugging
            raise ValueError(
                "Unable to determine vertical dimension name. Please specify the name via `lev_dim` argument."
            ) from err

    if lev_dim not in data.dims:
        raise ValueError(
            f"Vertical dimension {lev_dim!r} is not a dimension of data {tuple(data.dims)}."
        )
    interp_axis = data.dims.index(lev_dim)

    # Pressure at the hybrid levels, put in the same dimension order as data
    pressure = plevm.transpose(*data.dims)

    # Define interpolation function
    if method == 'linear':
        func_interpolate = metpy.interpolate.interpolate_1d
    elif method == 'log':
        func_interpolate = metpy.interpolate.log_interpolate_1d
    else:
        raise ValueError(f'Unknown interpolation method: {method}. '
                         f'Supported methods are: "log" and "linear".')

    def _vertical_remap(data, pressure):
        """Interpolate one block of data from its pressures to new_levels."""
        return func_interpolate(new_levels, pressure, data, axis=interp_axis)

    ###############################################################################
    # Workaround
    #
    # `metpy.interpolate.interpolate_1d` had a "no implementation found for
    # 'numpy.apply_along_axis'" issue when the input is an xarray.DataArray with
    # more than 3 dimensions (e.g. 4th dim of `time`). So instead of
    # xarray.apply_ufunc, dask's map_blocks is applied to the underlying arrays,
    # which ensures only the NumPy interface of metpy is used.

    # If an unchunked xarray input is given, chunk it as a single block
    if data.chunks is None:
        data = data.chunk(dict(zip(data.dims, data.shape)))

    # Chunk pressure equal to data's chunks
    pressure = pressure.chunk(data.chunks)

    # Output chunking: as the input, except a single chunk of the new levels
    # along the interpolation axis
    out_chunks = list(data.chunks)
    out_chunks[interp_axis] = (new_levels.size,)
    out_chunks = tuple(out_chunks)

    from dask.array.core import map_blocks
    output = map_blocks(
        _vertical_remap,
        data.data,
        pressure.data,
        chunks=out_chunks,
        dtype=data.dtype,
        drop_axis=[interp_axis],
        new_axis=[interp_axis],
    )

    # End of Workaround
    ###############################################################################

    # Same dimension order as the input, with the level dimension renamed to "plev"
    dims = [
        "plev" if i == interp_axis else data.dims[i] for i in range(data.ndim)
    ]
    output = xr.DataArray(output, dims=dims)

    # Carry over every coordinate that does not span the (now removed) level
    # dimension, and always label the new axis with the target levels
    coords = {
        name: coord for (name, coord) in data.coords.items()
        if lev_dim not in coord.dims
    }
    coords["plev"] = new_levels

    return output.assign_coords(coords)


# In[6]:


def compute_mse(tabs,qv,z,plevm,plev,pbot,ptop):
    """
    Compute column-integrated moist static energy
    e.g., h = compute_mse(tabs,qv,z,plevm,plev,pbot,ptop)

    The three fields are interpolated from hybrid levels to the pressure levels
    `plev`, moist static energy cp*T + g*z + Lv*q is formed on those levels, and
    it is summed over the layers between `pbot` and `ptop`, weighted by dp/g.
    Each layer uses the value at its upper (lower-pressure) level.

    Parameters
    ----------
    tabs, qv, z : xr.DataArray with dims ('hybrid', latitude, longitude)
        Temperature (K), specific humidity (kg/kg) and geopotential height (gpm).
    plevm : xr.DataArray
        Pressure (hPa) on the hybrid mid-point levels, same shape as the fields.
    plev : 1-D array-like
        Target pressure levels (hPa), ordered from the bottom (highest pressure) up.
    pbot, ptop : float
        Bottom and top of the integration (hPa). Both must be entries of `plev`.

    Returns
    -------
    xr.DataArray (latitude x longitude): the column integral in J/m^2.

    Raises
    ------
    ValueError
        If `pbot` or `ptop` does not occur exactly once in `plev`, or if `pbot`
        does not come before `ptop` in `plev`.
    """
    #Define constants
    cp = 1.00464e3 #specific heat of dry air at constant pressure (J/kg/K)
    g = 9.8        #gravitational acceleration (m/s^2)
    Lv = 2.501e6   #latent heat of vaporization (J/kg)

    #Locate the integration bounds first so bad input fails before the expensive interpolation.
    #(np.int, used here previously, was removed in NumPy 1.24.)
    plev = np.asarray(plev, dtype=float)
    bot_hits = np.flatnonzero(plev == pbot)
    top_hits = np.flatnonzero(plev == ptop)
    if bot_hits.size != 1 or top_hits.size != 1:
        raise ValueError(f"pbot={pbot} and ptop={ptop} must each occur exactly once in plev")
    ibot = int(bot_hits[0]) #index corresponding to pbot
    itop = int(top_hits[0]) #index corresponding to ptop
    if ibot >= itop:
        raise ValueError("pbot must come before ptop in plev (levels ordered bottom to top)")

    #interpolate to pressure levels
    tabs_p = interp_hybrid_to_pressure(tabs, plevm, new_levels=plev, lev_dim='hybrid').load()
    qv_p = interp_hybrid_to_pressure(qv, plevm, new_levels=plev, lev_dim='hybrid').load()
    z_p = interp_hybrid_to_pressure(z, plevm, new_levels=plev, lev_dim='hybrid').load()

    #Compute moist static energy
    mse = cp*tabs_p + g*z_p + Lv*qv_p

    #Layer thicknesses between pbot and ptop. Levels run from bottom to top, so
    #lower-minus-upper gives positive dp. Converted from hPa to Pa.
    dp = (plev[ibot:itop] - plev[ibot+1:itop+1])*100

    #Do vertical integral over the leading (plev) dimension, e.g. from 925 to 1 for pbot=950,ptop=1.
    #DataArray.sum skips NaN by default, matching the np.sum call used here before.
    weighted = mse[ibot+1:itop+1,:,:]*dp[:, np.newaxis, np.newaxis]
    h = weighted.sum(dim=mse.dims[0])/g

    return h


# In[7]:


def write_to_file(h,filedir,year,mm,dd,hour,pbot=950,ptop=1):
    """
    Write column-integrated moist static energy (already defined as a data array) to a netcdf file

    The file is named <filedir>jra55.h.<YYYY><MM><DD><HH>.nc and holds a single
    variable 'h'.

    Parameters
    ----------
    h : xr.DataArray
        Column-integrated moist static energy. The caller's object is not modified.
    filedir : str
        Output location, used as a plain prefix (include the trailing '/').
    year, mm, dd : int
        Date of the field; month and day are zero-padded to two digits.
    hour : str
        Two-character hour string, e.g. '06'.
    pbot, ptop : number, optional
        Integration bounds (hPa) recorded in the file's 'note' attribute.
        The defaults reproduce the previously hard-coded text.
    """
    #assign_attrs returns a new object, so the array passed in keeps its own attrs
    h = h.assign_attrs({
        'units': 'J/m^2',
        'long_name': 'column-integrated moist static energy',
        '_FillValue': -9999,
        'GridType': 'Gaussian Latitude/Longitude Grid',
    })

    note = f'column integral from {pbot:g} hPa to {ptop:g} hPa'
    hds = xr.Dataset({'h':h}, attrs={'note':note})

    outfile = f"{filedir}jra55.h.{year}{mm:02d}{dd:02d}{hour}.nc"
    hds.to_netcdf(outfile, format='NETCDF4')


# In[8]:


##### Declare subdomain
latbox = slice(25,15)
lonbox = slice(260,270)

#pressure levels to interpolate to
plev = np.array([1000., 975., 950., 925., 900., 875., 850., 825., 800., 775., 750., 700., 
                 650., 600., 550., 500., 450., 400., 350., 300., 250., 300., 250., 225., 
                 200., 175., 150., 125., 100., 70., 50., 30., 20., 10., 7., 5., 3., 2., 1.])

#What level to integrate from (hPa) and to
pbot = 950 
ptop = 1


# ## Read 3D data

# In[9]:


import calendar #stdlib: month lengths with the full Gregorian leap-year rule

hour = ['00','06','12','18']
filebase = '/gpfs/fs1/collections/rda/data/ds628.0/anl_mdl/' #where to read files from
filebasep = '/gpfs/fs1/collections/rda/data/ds628.0/anl_surf/' #where to read surface pressure files from
filedir = '/glade/p/univ/ufsu0014/jra55/2D/h/' #where to output files to
years = np.arange(2010,2011,1) #define range of years


def _mdl_file(tag, year, mm, first_day, last_day):
    """Path of the model-level file for variable `tag` that spans first_day..last_day of the month."""
    stamp = f"{year}{mm:02d}"
    return (f"{filebase}{year}/anl_mdl.{tag}.reg_tl319."
            f"{stamp}{first_day:02d}00_{stamp}{last_day:02d}18")


def _surf_pres_file(year, mm, ndays):
    """Path of the surface pressure file: one per year before 2014, one per month from 2014 on."""
    prefix = f"{filebasep}{year}/anl_surf.001_pres.reg_tl319."
    if year < 2014:
        return f"{prefix}{year}010100_{year}123118"
    stamp = f"{year}{mm:02d}"
    return f"{prefix}{stamp}0100_{stamp}{ndays:02d}18"


#NOTE: the loop stays at module level on purpose. read_data/read_data_box and
#calc_pressure read dst, dsq, dsz, dsp, targettime, lat and lon as module globals.
for year in (int(y) for y in years): #loop over years
    for mm in range(1,13): #month 1 through 12
        ndays = calendar.monthrange(year, mm)[1] #28, 29, 30 or 31
        for dd in range(1, ndays+1):
            #Model-level files cover days 1-10, 11-20 and 21-end of month
            if dd <= 10:
                first_day, last_day = 1, 10
            elif dd <= 20:
                first_day, last_day = 11, 20
            else:
                first_day, last_day = 21, ndays

            for hh in hour:
                print(mm,dd,hh)
                #Open datasets. They are reopened for every time step because
                #read_data closes them after each read.
                dst = xr.open_dataset(_mdl_file('011_tmp', year, mm, first_day, last_day), engine='cfgrib')
                dsz = xr.open_dataset(_mdl_file('007_hgt', year, mm, first_day, last_day), engine='cfgrib')
                dsq = xr.open_dataset(_mdl_file('051_spfh', year, mm, first_day, last_day), engine='cfgrib')
                dsp = xr.open_dataset(_surf_pres_file(year, mm, ndays), engine='cfgrib')

                #Format the time we want
                targettime = f"{year}-{mm:02d}-{dd:02d}T{hh}:00:00.000000000"
                #Read in data (for the subdomain use read_data_box(latbox,lonbox) instead)
                tabs,qv,z,ps,lat,lon = read_data()
                #Calculate pressure
                plevm,plevi = calc_pressure(ps)
                #Compute column-integrated moist static energy
                h = compute_mse(tabs,qv,z,plevm,plev,pbot,ptop)
                #Write h out to netcdf file
                write_to_file(h,filedir,year,mm,dd,hh)


