#!/usr/bin/env python

"""Namelist creator for CIME's data atmosphere model.
"""

# Typically ignore this.
# pylint: disable=invalid-name

# Disable these because this is our standard setup
# pylint: disable=wildcard-import,unused-wildcard-import,wrong-import-position

import os, sys, glob

# Resolve the CIME root as five directory levels above the directory holding
# this file, then put its scripts/Tools directory on the module search path
# (presumably where standard_script_setup, imported below, lives).
_CIMEROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..","..","..","..","..")
sys.path.append(os.path.join(_CIMEROOT, "scripts", "Tools"))

from standard_script_setup import *
from CIME.case import Case
from CIME.nmlgen import NamelistGenerator
from CIME.utils import expect, get_model, safe_copy
from CIME.buildnml import create_namelist_infile, parse_input
from CIME.XML.files import Files

# Module-level logger. `logging` is not imported by name in this file, so it
# is presumably supplied by the wildcard import of standard_script_setup.
logger = logging.getLogger(__name__)

# pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements
####################################################################################
def _create_namelists(case, confdir, inst_string, infile, nmlgen, data_list_path):
####################################################################################
    """Write out the namelist and the stream text files for this component.

    Most arguments are the same as those for `NamelistGenerator`.

    case           -- case object; queried with `get_value` for the DATM and
                      grid settings used below.
    confdir        -- directory in which the output files (`datm_in` and the
                      `datm.streams.txt.*` files) are placed.
    inst_string    -- suffix used to distinguish the stream files of
                      different instances (may be empty).
    infile         -- namelist infile(s) handed to `nmlgen.init_defaults`.
    nmlgen         -- namelist generator that holds the defaults, creates the
                      stream files and writes the final namelist.
    data_list_path -- path of the input data list that is passed along when
                      the stream files and the namelist are written.

    Incompatible combinations of settings are reported through `expect`.
    """

    #----------------------------------------------------
    # Get a bunch of information from the case.
    #----------------------------------------------------
    datm_mode = case.get_value("DATM_MODE")
    datm_topo = case.get_value("DATM_TOPO")
    datm_presaero = case.get_value("DATM_PRESAERO")
    datm_co2_tseries = case.get_value("DATM_CO2_TSERIES")
    atm_grid = case.get_value("ATM_GRID")
    grid = case.get_value("GRID")
    clm_usrdat_name = case.get_value("CLM_USRDAT_NAME")

    #----------------------------------------------------
    # Check for incompatible options.
    #----------------------------------------------------
    if "CLM" in datm_mode and case.get_value("COMP_LND") == "clm":
        expect(datm_presaero != "none",
               "A DATM_MODE for CLM is incompatible with DATM_PRESAERO=none.")
        expect(datm_topo != "none",
               "A DATM_MODE for CLM is incompatible with DATM_TOPO=none.")
    # A CLM_USRDAT grid needs a user data name: the check must fail when the
    # name is missing (None), empty or "UNSET", as the message states.
    expect(grid != "CLM_USRDAT" or clm_usrdat_name not in (None, "", "UNSET"),
           "GRID=CLM_USRDAT and CLM_USRDAT_NAME is NOT set.")

    #----------------------------------------------------
    # Log some settings.
    #----------------------------------------------------
    logger.debug("DATM mode is {}".format(datm_mode))
    logger.debug("DATM grid is {}".format(atm_grid))
    logger.debug("DATM presaero mode is {}".format(datm_presaero))
    logger.debug("DATM topo mode is {}".format(datm_topo))

    #----------------------------------------------------
    # Create configuration information.
    #----------------------------------------------------
    config = {}
    config['grid'] = grid
    config['atm_grid'] = atm_grid
    config['datm_mode'] = datm_mode
    config['datm_co2_tseries'] = datm_co2_tseries
    config['datm_presaero'] = datm_presaero
    config['cime_model'] = get_model()

    #----------------------------------------------------
    # Initialize namelist defaults
    #----------------------------------------------------
    nmlgen.init_defaults(infile, config)

    #----------------------------------------------------
    # Construct the list of streams.
    #----------------------------------------------------
    streams = nmlgen.get_streams()
    #
    # This disable is required because nmlgen.get_streams
    # may return a string or a list.  See issue #877 in ESMCI/cime
    #
    #pylint: disable=no-member
    if datm_presaero != "none":
        streams.append("presaero.{}".format(datm_presaero))

    if datm_topo != "none":
        streams.append("topo.{}".format(datm_topo))

    if datm_co2_tseries != "none":
        streams.append("co2tseries.{}".format(datm_co2_tseries))

    # Add bias correction stream if given in namelist. The value is appended
    # unconditionally; null entries are filtered out in the loop below.
    bias_correct = nmlgen.get_value("bias_correct")
    streams.append(bias_correct)

    # Add all anomaly forcing streams given in namelist.
    anomaly_forcing = nmlgen.get_value("anomaly_forcing")
    streams += anomaly_forcing

    #----------------------------------------------------
    # For each stream, create stream text file and update
    # shr_strdata_nml group and input data list.
    #----------------------------------------------------
    # The case root does not change between streams, so look it up once.
    caseroot = case.get_value("CASEROOT")
    for stream in streams:

        # Ignore null values.
        if stream is None or stream in ("NULL", ""):
            continue

        inst_stream = stream + inst_string
        logger.debug("DATM stream is {}".format(inst_stream))
        stream_path = os.path.join(confdir, "datm.streams.txt." + inst_stream)
        nmlgen.create_stream_file_and_update_shr_strdata_nml(config, caseroot,
                                                             stream, stream_path,
                                                             data_list_path)

    #----------------------------------------------------
    # Create `shr_strdata_nml` namelist group.
    #----------------------------------------------------
    # set per-stream variables
    nmlgen.create_shr_strdata_nml()

    # Determine model domain filename (in datm_in)
    if "CPLHIST" in datm_mode:
        # In CPLHIST modes the domain file comes from its own case variable
        # rather than from ATM_DOMAIN_PATH/ATM_DOMAIN_FILE.
        datm_cplhist_domain_file = case.get_value("DATM_CPLHIST_DOMAIN_FILE")
        if datm_cplhist_domain_file == 'null':
            logger.info("   ....  Obtaining DATM model domain info from first stream file: {}".format(streams[0]))
        else:
            logger.info("   ....  Obtaining DATM model domain info from stream {}".format(streams[0]))
        nmlgen.add_default("domainfile", value=datm_cplhist_domain_file)
    else:
        atm_domain_file = case.get_value("ATM_DOMAIN_FILE")
        atm_domain_path = case.get_value("ATM_DOMAIN_PATH")
        if atm_domain_file != "UNSET":
            full_domain_path = os.path.join(atm_domain_path, atm_domain_file)
            nmlgen.add_default("domainfile", value=full_domain_path)
        else:
            # No domain file configured: record the literal 'null'.
            nmlgen.add_default("domainfile", value='null')

    #----------------------------------------------------
    # Finally, write out all the namelists.
    #----------------------------------------------------
    namelist_file = os.path.join(confdir, "datm_in")
    nmlgen.write_output_file(namelist_file, data_list_path, groups=['datm_nml','shr_strdata_nml'])

###############################################################################
def buildnml(case, caseroot, compname):
###############################################################################
    """Build the datm namelist and the required stream text files.

    case     -- case object; queried for RUNDIR, the instance count and the
                case root, and passed on to the namelist generator.
    caseroot -- case root directory; `Buildconf/datmconf`, `SourceMods/src.datm`
                and the `user_nl_datm*` files are looked up beneath it.
    compname -- component name; must be "datm".

    For every instance a namelist infile is created from the instance's
    `user_nl_` file, the namelist and stream files are generated in the
    configuration directory and, if the run directory exists, copied there.

    Raises AttributeError if `compname` is not "datm". Missing directories or
    files that are required are reported through `expect`.
    """

    # Build the component namelist and required stream txt files
    if compname != "datm":
        raise AttributeError("datm buildnml called for component {!r}".format(compname))

    rundir = case.get_value("RUNDIR")
    ninst = case.get_value("NINST_ATM")
    if ninst is None:
        # Fall back to the generic instance count.
        ninst = case.get_value("NINST")

    # Determine configuration directory
    confdir = os.path.join(caseroot,"Buildconf",compname + "conf")
    if not os.path.isdir(confdir):
        os.makedirs(confdir)

    #----------------------------------------------------
    # Construct the namelist generator
    #----------------------------------------------------
    # Determine directory for user modified namelist_definitions.xml and namelist_defaults.xml
    user_xml_dir = os.path.join(caseroot, "SourceMods", "src." + compname)
    expect (os.path.isdir(user_xml_dir),
            "user_xml_dir {} does not exist ".format(user_xml_dir))

    # NOTE: User definition *replaces* existing definition.
    files = Files()
    definition_file = [files.get_value("NAMELIST_DEFINITION_FILE", {"component":"datm"})]

    user_definition = os.path.join(user_xml_dir, "namelist_definition_datm.xml")
    if os.path.isfile(user_definition):
        definition_file = [user_definition]
    for file_ in definition_file:
        expect(os.path.isfile(file_), "Namelist XML file {} not found!".format(file_))

    # Create the namelist generator object - independent of instance
    nmlgen = NamelistGenerator(case, definition_file, files=files)

    #----------------------------------------------------
    # Clear out old data.
    #----------------------------------------------------
    data_list_path = os.path.join(case.get_case_root(), "Buildconf",
                                  "datm.input_data_list")
    if os.path.exists(data_list_path):
        os.remove(data_list_path)

    #----------------------------------------------------
    # Loop over instances
    #----------------------------------------------------
    for inst_counter in range(1, ninst+1):
        # determine instance string (empty for a single-instance case)
        inst_string = ""
        if ninst > 1:
            inst_string = '_' + '{:04d}'.format(inst_counter)

        # If multi-instance case does not have restart file, use
        # single-case restart for each instance
        rpointer = "rpointer." + compname
        if (os.path.isfile(os.path.join(rundir,rpointer)) and
            (not os.path.isfile(os.path.join(rundir,rpointer + inst_string)))):
            safe_copy(os.path.join(rundir, rpointer),
                      os.path.join(rundir, rpointer + inst_string))

        # create namelist output infile using user_nl_file as input
        user_nl_file = os.path.join(caseroot, "user_nl_" + compname + inst_string)
        expect(os.path.isfile(user_nl_file),
               "Missing required user_nl_file {} ".format(user_nl_file))
        # The same infile path is reused (and rewritten) for every instance.
        infile = os.path.join(confdir, "namelist_infile")
        create_namelist_infile(case, user_nl_file, infile)
        namelist_infile = [infile]

        # create namelist and stream file(s) data component
        _create_namelists(case, confdir, inst_string, namelist_infile, nmlgen, data_list_path)

        # copy namelist files and stream text files, to rundir
        if os.path.isdir(rundir):
            filename = compname + "_in"
            file_src  = os.path.join(confdir, filename)
            file_dest = os.path.join(rundir, filename)
            if inst_string:
                # The copy in rundir carries the instance suffix.
                file_dest += inst_string
            safe_copy(file_src,file_dest)

            for txtfile in glob.glob(os.path.join(confdir, "*txt*")):
                safe_copy(txtfile, rundir)

###############################################################################
def _main_func():
    """Script entry point: build the datm namelists for the case named on the command line."""
    # parse_input turns the command-line arguments into the case root.
    case_dir = parse_input(sys.argv)
    with Case(case_dir) as open_case:
        buildnml(open_case, case_dir, "datm")


# Run the namelist build only when this file is executed as a script, so the
# module can also be imported for its buildnml function.
if __name__ == "__main__":
    _main_func()
