#!/usr/bin/env python

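"""POP physics cycle postrun script.

Takes the case root directory as its single command-line argument and updates
the case's POP physics-cycle settings after a model run.
"""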

import os, shutil, sys

# The CIME tools directory is found through the CIMEROOT environment variable
# and appended to sys.path ahead of the imports that follow; abort when the
# variable is not set.
try:
    _CIMEROOT = os.environ["CIMEROOT"]
except KeyError:
    raise SystemExit("ERROR: must set CIMEROOT environment variable")
sys.path.append(os.path.join(_CIMEROOT, "scripts", "Tools"))

from standard_script_setup import *
from CIME.case import Case
from CIME.utils import expect

# Module-level logger. `logging` is not imported by name in this file, so it is
# presumably supplied by the star import from standard_script_setup -- TODO confirm.
logger = logging.getLogger(__name__)

###############################################################################
def phys_cycle_postrun(caseroot):
###############################################################################
    """Update the POP physics-cycle settings of a case after a model run.

    The number of months covered by the run (derived from STOP_OPTION and
    STOP_N) is added to POP_PHYS_CYCLE_MONTHS_RUN_SINCE_CYCLE.

    * While the total is below POP_PHYS_CYCLE_YEARS_IN_CYCLE * 12, the case is
      set to continue (CONTINUE_RUN=TRUE) with
      POP_PASSIVE_TRACER_RESTART_OVERRIDE set to "none".
    * When the total equals the cycle length, the cycle is reset:
      CONTINUE_RUN is set to FALSE, the month counter is zeroed,
      POP_PASSIVE_TRACER_RESTART_OVERRIDE is pointed at the file named on the
      first line of rpointer.ocn.restart in RUNDIR, RUN_STARTDATE and the
      DATM/DROF _CPLHIST_YR_ALIGN variables (when defined) are advanced by the
      cycle length in years, and, unless RUN_TYPE is "startup", the POP
      rpointer files are copied back from Buildconf/popconf/*.orig.

    Nothing is changed when POP_PHYS_CYCLE_YEARS_IN_CYCLE is <= 0.

    Args:
        caseroot: path of the case directory to open with ``Case``.

    Returns:
        None.

    Raises:
        Whatever ``expect`` raises when POP_PHYS_CYCLE_YEARS_IN_CYCLE is
        undefined, STOP_OPTION begins with neither "nmonth" nor "nyear", the
        month counter overshoots the cycle length, rpointer.ocn.restart has an
        empty first line, or a required ``.orig`` rpointer file is missing.
    """

    # Attach the stdout handler only once, so that calling this function more
    # than once in a process does not duplicate every log line.
    has_stdout_handler = any(
        isinstance(handler, logging.StreamHandler)
        and getattr(handler, "stream", None) is sys.stdout
        for handler in logger.handlers
    )
    if not has_stdout_handler:
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.INFO)
        logger.addHandler(ch)

    with Case(caseroot) as case:
        years_in_cycle = case.get_value("POP_PHYS_CYCLE_YEARS_IN_CYCLE")
        expect(years_in_cycle is not None,
               "POP_PHYS_CYCLE_YEARS_IN_CYCLE should be defined when %s is called" % sys.argv[0])
        if years_in_cycle <= 0:
            return

        # -------------------------------------------------------------------------
        # Determine if model has run the length of POP physics run cycle length
        # -------------------------------------------------------------------------

        run_type    = case.get_value("RUN_TYPE")
        stop_option = case.get_value("STOP_OPTION")
        stop_n      = case.get_value("STOP_N")
        caseroot    = case.get_value("CASEROOT")
        rundir      = case.get_value("RUNDIR")

        months_run_since_cycle = case.get_value("POP_PHYS_CYCLE_MONTHS_RUN_SINCE_CYCLE")

        # phys_cycle_prerun is expected to have restricted stop_option to
        # "nmonth*" or "nyear*"; verify it here so that an unexpected value
        # gives a clear error instead of an unbound-variable failure.
        stop_in_months = stop_option.startswith("nmonth")
        stop_in_years = stop_option.startswith("nyear")
        expect(stop_in_months or stop_in_years,
               "STOP_OPTION=%s is not supported, it must begin with nmonth or nyear" % stop_option)
        months_this_run = stop_n if stop_in_months else stop_n * 12

        months_run_since_cycle_new = months_run_since_cycle + months_this_run
        case.set_value("POP_PHYS_CYCLE_MONTHS_RUN_SINCE_CYCLE", months_run_since_cycle_new)
        logger.info("pop_phys_cycle_months_run_since_cycle is now %d", months_run_since_cycle_new)

        months_in_cycle = years_in_cycle * 12
        if months_run_since_cycle_new < months_in_cycle:
            # Cycle not finished yet: keep running from the model's own restarts.
            case.set_value("CONTINUE_RUN", "TRUE")
            case.set_value("POP_PASSIVE_TRACER_RESTART_OVERRIDE", "none")
            case.flush()
            return

        expect(months_run_since_cycle_new <= months_in_cycle,
               "POP_PHYS_CYCLE_MONTHS_RUN_SINCE_CYCLE has overshot POP_PHYS_CYCLE_MONTHS_IN_CYCLE=%d"
               % months_in_cycle)

        # -------------------------------------------------------------------------
        # Model has run the length of POP physics run cycle length
        # Reset model physics, keeping restart file for tracers to continue to evolve
        # -------------------------------------------------------------------------

        logger.info("resetting model physics")
        case.set_value("CONTINUE_RUN", "FALSE")
        case.set_value("POP_PHYS_CYCLE_MONTHS_RUN_SINCE_CYCLE", 0)

        rpointer_path = os.path.join(rundir, "rpointer.ocn.restart")
        with open(rpointer_path, 'r') as f:
            # readline() keeps the trailing newline; strip it so that it does
            # not end up inside the stored path.
            restart_filename = f.readline().strip()
        expect(restart_filename != "",
               "first line of %s is empty, cannot determine restart file" % rpointer_path)
        case.set_value("POP_PASSIVE_TRACER_RESTART_OVERRIDE",
                       os.path.join(rundir, restart_filename))

        # -------------------------------------------------------------------------
        # Compute new RUN_STARTDATE
        # -------------------------------------------------------------------------

        # Only the leading four-character year is advanced; the rest of the
        # date string is carried over unchanged.
        run_startdate = case.get_value("RUN_STARTDATE")
        new_year = int(run_startdate[:4]) + years_in_cycle
        case.set_value("RUN_STARTDATE", "%04d" % new_year + run_startdate[4:])

        # -------------------------------------------------------------------------
        # Compute new _CPLHIST_YR_ALIGN variables
        # -------------------------------------------------------------------------

        for align_varname in ("DATM_CPLHIST_YR_ALIGN", "DROF_CPLHIST_YR_ALIGN"):
            align_val = case.get_value(align_varname)
            if align_val is not None:
                case.set_value(align_varname, align_val + years_in_cycle)

        case.flush()

        # -------------------------------------------------------------------------
        # re-prestage POP rpointer files if necessary
        # -------------------------------------------------------------------------

        if run_type != "startup":
            for rpointer_filename in ("rpointer.ocn.restart", "rpointer.ocn.ovf"):
                sfile = os.path.join(caseroot, "Buildconf", "popconf", rpointer_filename + ".orig")
                dfile = os.path.join(rundir, rpointer_filename)
                expect(os.path.isfile(sfile), "required pop file %s is missing" % sfile)
                shutil.copy(sfile, dfile)

###############################################################################

if __name__ == "__main__":
    # Exactly one command-line argument is accepted: the case root directory.
    expect(len(sys.argv[1:]) == 1, "caseroot is a required input argument")
    caseroot = sys.argv[-1]
    phys_cycle_postrun(caseroot)
