#!/usr/bin/env python

"""
Set the POP cppdefs and write the config_cache.xml file (the latter is used for namelist generation)
"""
# pylint: disable=unused-wildcard-import, wildcard-import
# pylint: disable=wrong-import-position, invalid-name
# pylint: disable=too-many-locals, too-many-statements, too-many-branches

import os
import sys

# The CIME tools directory has to be on sys.path before the imports below run
# (presumably this is where standard_script_setup lives -- confirm), so the
# script aborts right away when CIMEROOT is not set in the environment.
CIMEROOT = os.environ.get("CIMEROOT")
if CIMEROOT is None:
    raise SystemExit("ERROR: must set CIMEROOT environment variable")
sys.path.append(os.path.join(CIMEROOT, "scripts", "Tools"))

from standard_script_setup import *

from CIME.utils import expect
from CIME.utils import run_cmd
from CIME.case import Case
from CIME.buildnml import parse_input

logger = logging.getLogger(__name__)

# Template for Buildconf/popconf/config_cache.xml.  The {placeholders} are
# filled in with str.format() by determine_tracer_count(); every placeholder
# used here must be passed as a keyword argument in that call.
_config_cache_template = """<?xml version="1.0"?>
<config_definition>
<commandline></commandline>
<entry id="comp_wav" value="{comp_wav}" list="" valid_values="">wave component</entry>
<entry id="irf_mode" value="{IRF_MODE}" list="" valid_values="NK_precond,offline_transport">IRF tracer module mode</entry>
<entry id="irf_nt" value="{IRF_NT}" list="" valid_values="">Number of IRF tracers</entry>
<entry id="marbl_nt" value="{MARBL_NT}" list="" valid_values="">Number of MARBL tracers</entry>
<entry id="ocn_bgc_config" value="{ocn_bgc_config}" list="" valid_values="">Version of BGC tunings to use</entry>
<entry id="ocn_grid" value="{ocn_grid}" list="" valid_values="">Ocean grid used for POP</entry>
<entry id="use_abio" value="{use_abio}" list="" valid_values="FALSE,TRUE">Use abio tracer module</entry>
<entry id="use_cfc" value="{use_cfc}" list="" valid_values="FALSE,TRUE">Use cfc tracer module</entry>
<entry id="use_ecosys" value="{use_ecosys}" list="" valid_values="FALSE,TRUE">Use ecosys tracer module</entry>
<entry id="use_iage" value="{use_iage}" list="" valid_values="FALSE,TRUE">Use ideal age tracer module</entry>
<entry id="use_irf" value="{use_irf}" list="" valid_values="FALSE,TRUE">Use IRF tracer module</entry>
<entry id="use_sf6" value="{use_sf6}" list="" valid_values="FALSE,TRUE">Use sf6 tracer module</entry>
</config_definition>
"""

###############################################################################
def determine_tracer_count(case, caseroot, srcroot, decomp_cppdefs):
###############################################################################
    """
    Determine the number of tracers, write config_cache.xml and build the POP cppdefs

    Arguments:
    case           -- case object; only get_value() and set_value() are used here
    caseroot       -- case root directory; config_cache.xml is written to
                      <caseroot>/Buildconf/popconf, which must already exist
    srcroot        -- source root, used to locate MARBL and the MARBL_scripts
                      python wrappers when the ecosys module is requested
    decomp_cppdefs -- decomposition CPP definitions; appended verbatim after the
                      fixed definitions, so the string should start with a space

    Returns the complete POP CPP definition string.

    Side effects: writes config_cache.xml and, when the ecosys module is
    requested, sets MARBL_NT in the case.

    Aborts through expect() on malformed or unknown OCN_TRACER_MODULES_OPT
    entries, on unknown modules in OCN_TRACER_MODULES, and when IRF_NT was not
    given and no default is known for the IRF_MODE / OCN_GRID combination.
    """
    MARBL_NT = 0
    IRF_NT = 0
    IRF_MODE = "NK_precond"

    #----------------------------------------------------
    # Parse OCN_TRACER_MODULES_OPT
    #----------------------------------------------------

    # split() without arguments ignores leading / trailing whitespace and treats
    # any run of whitespace as one separator; an unset value is treated as empty
    for module_opt in (case.get_value("OCN_TRACER_MODULES_OPT") or "").split():
        opt_fields = module_opt.split("=")
        # report a bare "name" (no "=") as a configuration error instead of
        # failing with an IndexError
        expect(len(opt_fields) > 1,
               "%s in OCN_TRACER_MODULES_OPT is not of the form name=value" % module_opt)
        varname, value = opt_fields[0], opt_fields[1]
        if varname == "IRF_NT":
            try:
                IRF_NT = int(value)
            except ValueError:
                expect(False, "%s is not a valid value for %s in OCN_TRACER_MODULES_OPT"
                       % (value, varname))
        elif varname == "IRF_MODE":
            expect(value in ("offline_transport", "NK_precond"),
                   "%s is not a valid value for %s in OCN_TRACER_MODULES_OPT"
                   % (value, varname))
            IRF_MODE = value
        else:
            expect(False, "%s is not a valid value in OCN_TRACER_MODULES_OPT" %(varname))

    #----------------------------------------------------
    # Parse OCN_TRACER_MODULES
    #     set use_* vars and compute NT
    #----------------------------------------------------

    # keys match the use_* placeholders of _config_cache_template
    use_flags = {"use_iage": "FALSE",
                 "use_cfc": "FALSE",
                 "use_sf6": "FALSE",
                 "use_abio": "FALSE",
                 "use_ecosys": "FALSE",
                 "use_irf": "FALSE"}

    # modules that always add the same number of tracers:
    # module name -> (use_* flag to switch on, number of tracers added)
    fixed_size_modules = {"iage": ("use_iage", 1),
                          "cfc": ("use_cfc", 2),
                          "sf6": ("use_sf6", 1),
                          "abio_dic_dic14": ("use_abio", 2)}

    # tracer count before any optional module is added
    NT = 2

    ocn_grid = case.get_value("OCN_GRID")

    for module in (case.get_value("OCN_TRACER_MODULES") or "").split():
        if module in fixed_size_modules:
            flag_name, module_nt = fixed_size_modules[module]
            use_flags[flag_name] = "TRUE"
            NT = NT + module_nt
        elif module == "ecosys":
            use_flags["use_ecosys"] = "TRUE"
            # call MARBL's tool to determine tracer count:
            # (i) Import MARBL_wrappers and generate MARBL_settings_for_POP object
            # Note that at this time we do not support different tracer counts in
            # different instances of a multi-instance run... so we use first instance
            # to set up config_cache.xml
            if case.get_value("NINST_OCN") > 1:
                input_file = "user_nl_marbl_0001"
            else:
                input_file = "user_nl_marbl"
            MARBL_dir = os.path.join(srcroot, "components", "pop", "externals", "MARBL")
            # MARBL_wrappers is only importable once its directory is on sys.path
            sys.path.append(os.path.join(srcroot, "components", "pop", "MARBL_scripts"))
            from MARBL_wrappers import MARBL_settings_for_POP
            MARBL_settings = MARBL_settings_for_POP(MARBL_dir, input_file, caseroot, ocn_grid,
                                                    case.get_value("RUN_TYPE"),
                                                    case.get_value("CONTINUE_RUN"),
                                                    case.get_value("OCN_BGC_CONFIG"))
            # (ii) get MARBL_NT
            MARBL_NT = MARBL_settings.get_MARBL_NT()
            NT = NT + MARBL_NT
            case.set_value("MARBL_NT", MARBL_NT)
        elif module == "IRF":
            use_flags["use_irf"] = "TRUE"
            # IRF_NT == 0 means it was not set through OCN_TRACER_MODULES_OPT,
            # so fall back to the built-in default for this mode / grid
            if IRF_NT == 0:
                if IRF_MODE == "NK_precond":
                    IRF_NT = 36
                else:
                    # IRF_MODE can only be offline_transport here (validated above);
                    # its default tracer count is only known for these grids
                    offline_transport_nt = {"gx3v7": 156, "gx1v6": 178}
                    expect(ocn_grid in offline_transport_nt,
                           "IRF_MODE %s has no default IRF_NT for OCN_GRID %s; "
                           "set IRF_NT in OCN_TRACER_MODULES_OPT" % (IRF_MODE, ocn_grid))
                    IRF_NT = offline_transport_nt[ocn_grid]
            NT = NT + IRF_NT
        else:
            expect(False, "module %s is not a valid value in OCN_TRACER_MODULES!" %module)

    #----------------------------------------------------
    # Determine tracer cppdefs and update pop_cppdefs
    #----------------------------------------------------

    tracer_cppdefs = " -DNT=%d -DMARBL_NT=%d -DIRF_NT=%d " % (NT, MARBL_NT, IRF_NT)

    other_cppdefs = ""
    if "tx0.1" in ocn_grid:
        other_cppdefs = other_cppdefs + " -D_HIRES "
    if case.get_value("OCN_ICE_FORCING") == 'inactive':
        other_cppdefs = other_cppdefs + " -DZERO_SEA_ICE_REF_SAL "
    if case.get_value("POP_TAVG_R8"):
        other_cppdefs = other_cppdefs + " -DTAVG_R8 "

    pop_cppdefs = "-DCCSMCOUPLED -DMARBL_TIMING_OPT=CIME" + decomp_cppdefs + \
                  other_cppdefs + tracer_cppdefs

    #----------------------------------------------------
    # Create config_cache.xml file
    #----------------------------------------------------

    config_cache_text = _config_cache_template.format(
        comp_wav=case.get_value("COMP_WAV"),
        IRF_MODE=IRF_MODE,
        IRF_NT=IRF_NT,
        MARBL_NT=MARBL_NT,
        ocn_bgc_config=case.get_value("OCN_BGC_CONFIG"),
        ocn_grid=ocn_grid,
        **use_flags)

    config_cache_path = os.path.join(caseroot, "Buildconf", "popconf", "config_cache.xml")
    with open(config_cache_path, 'w') as config_cache_file:
        config_cache_file.write(config_cache_text)

    return pop_cppdefs

###############################################################################
def buildcpp(case):
###############################################################################
    """
    Determine the CPP flags values needed to build the pop component

    Creates $CASEROOT/Buildconf/popconf if needed, regenerates the block
    decomposition settings when POP_AUTO_DECOMP is true, writes
    config_cache.xml through determine_tracer_count() and stores the resulting
    definitions in POP_CPPDEFS before flushing the case.

    Arguments:
    case -- case object; values are read with get_value() and updated with
            set_value(), then flush() is called

    Returns the result of case.set_value("POP_CPPDEFS", ...)
    (NOTE(review): presumably the CPP definition string as stored in the case --
    confirm against CIME's Case.set_value).

    Aborts through expect() when the decomposition script fails, prints
    nothing or too few values, or finds no decomposition, and when
    config_cache.xml was not created.
    """
    #----------------------------------------------------
    # create $CASEROOT/Buildconf/popconf if it does not exist
    #----------------------------------------------------

    caseroot = case.get_value("CASEROOT")
    srcroot = case.get_value("SRCROOT")

    popconf_dir = os.path.join(caseroot, "Buildconf", "popconf")
    if not os.path.exists(popconf_dir):
        os.makedirs(popconf_dir)

    #----------------------------------------------------
    # determine decomposition xml variables if POP_AUTO_DECOMP is true
    #----------------------------------------------------

    # - invoke generate_pop_decomp.pl and update the POP_* decomposition
    # settings of the case to reflect the configuration it reports
    if case.get_value("POP_AUTO_DECOMP"):
        nthrds_ocn = case.get_value("NTHRDS_OCN")
        ntasks = case.get_value("NTASKS_PER_INST_OCN")
        ocn_grid = case.get_value("OCN_GRID")

        cmd = os.path.join(srcroot, "components", "pop", "bld", "generate_pop_decomp.pl")
        command = "%s -ccsmroot %s -res %s -nproc %s -thrds %s -output %s "  \
                  % (cmd, srcroot, ocn_grid, ntasks, nthrds_ocn, "all")
        rc, out, err = run_cmd(command)

        # report the complete command line so that a failure can be reproduced by hand
        expect(rc == 0, "Command %s failed rc=%d\nout=%s\nerr=%s"%(command, rc, out, err))
        if out:
            logger.debug("     %s", out)
        if err:
            logger.info("     %s", err)

        # whitespace separated fields: a first field > 0 means a decomposition
        # was found, and fields 2-7 then carry the values stored below
        config = out.split()
        expect(len(config) > 0, "Command %s produced no output" % command)
        if int(config[0]) > 0:
            expect(len(config) >= 8,
                   "Command %s returned too few values: %s" % (command, out))
            case.set_value("POP_BLCKX", config[2])
            case.set_value("POP_BLCKY", config[3])
            case.set_value("POP_MXBLCKS", config[4])
            case.set_value("POP_DECOMPTYPE", config[5])
            case.set_value("POP_NX_BLOCKS", config[6])
            case.set_value("POP_NY_BLOCKS", config[7])
        else:
            expect(False, "pop decomp not set for %s on %s x %s procs."
                   % (ocn_grid, ntasks, nthrds_ocn))

    # set decomposition block sizes
    pop_blckx = case.get_value("POP_BLCKX")
    pop_blcky = case.get_value("POP_BLCKY")
    pop_mxblcks = case.get_value("POP_MXBLCKS")
    decomp_cppdefs = " -DBLCKX=%d -DBLCKY=%d -DMXBLCKS=%d" %(pop_blckx, pop_blcky, pop_mxblcks)

    #----------------------------------------------------
    # determine the tracer count cpp variables, create config_cache.xml
    # and update cppdefs accordingly
    #----------------------------------------------------

    pop_cppdefs = determine_tracer_count(case, caseroot, srcroot, decomp_cppdefs)

    # Verify that config_cache.xml exists
    expect(os.path.isfile(os.path.join(popconf_dir, "config_cache.xml")),
           "config_cache.xml is missing after configure call")

    pop_cppdefs = case.set_value("POP_CPPDEFS", pop_cppdefs)
    case.flush()

    return pop_cppdefs

###############################################################################
def _main_func():
    """
    Script entry point: open the case selected by the command line arguments,
    compute and store the POP cppdefs, then log them
    """
    case_dir = parse_input(sys.argv)
    with Case(case_dir) as case:
        cppdefs = buildcpp(case)
    logger.info("POP_CPPDEFS: %s", cppdefs)

# Only run when executed as a script, so that buildcpp() can also be imported
# and called from other python code without side effects.
if __name__ == "__main__":
    _main_func()
