#!/usr/bin/env python

"""
Queries key CIME shell commands (mpirun and batch submission).

To force a certain mpirun command, use:
   ./xmlchange MPI_RUN_COMMAND=$your_cmd

Example:
   ./xmlchange MPI_RUN_COMMAND='mpiexec -np 16 --some-flag'

To force a certain qsub command, use:
   ./xmlchange --subgroup=case.run BATCH_COMMAND_FLAGS=$your_flags

Example:
   ./xmlchange --subgroup=case.run BATCH_COMMAND_FLAGS='--some-flag --other-flag'
"""

from standard_script_setup import *

from CIME.case import Case
from CIME.utils import set_logger_indent

# Module-level logger named after this module. `logging` is not imported by
# name in this file; presumably it is provided by the star import from
# standard_script_setup -- TODO confirm.
logger = logging.getLogger(__name__)

###############################################################################
def parse_command_line(args, description):
###############################################################################
    """
    Parse this script's command line.

    args is the raw argument list to parse (the caller in this file passes
    sys.argv) and description is the text argparse shows as the program
    description in its help output.

    Returns a (caseroot, job) tuple. caseroot defaults to the current working
    directory and job is None when -j/--job was not supplied.
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        description=description)

    # Let CIME add its standard logging options to the parser before the
    # script-specific arguments are declared.
    CIME.utils.setup_standard_logging_options(cli)

    cli.add_argument(
        "caseroot",
        nargs="?",
        default=os.getcwd(),
        help="Case directory to query.\n"
        "Default is current directory.")

    cli.add_argument(
        "-j", "--job",
        default=None,
        help="The job you want to print.\n"
        "Default is case.run (or case.test if this is a test).")

    parsed = CIME.utils.parse_args_and_handle_standard_logging_options(args, cli)

    return parsed.caseroot, parsed.job

###############################################################################
def _main_func(description):
###############################################################################
    """
    Print a preview of the case layout, batch submission and mpirun commands.

    description is forwarded to parse_command_line as the help text. The case
    directory and (optional) job name come from sys.argv.

    Output sections, in order:
      CASE INFO  - node/task/thread counts read from the case object.
      BATCH INFO - for every (job, command) pair returned by a dry-run
                   submit_jobs call: the environment loaded for the job, the
                   resolved submit command and, for case.run / case.test
                   only, the resolved mpirun command.
    """
    caseroot, job = parse_command_line(sys.argv, description)
    # Disable log records at INFO level and below for the rest of the run.
    logging.disable(logging.INFO)

    # (format template, case attribute) pairs, printed in this order.
    summary_lines = (
        ("  nodes: {}", "num_nodes"),
        ("  total tasks: {}", "total_tasks"),
        ("  tasks per node: {}", "tasks_per_node"),
        ("  thread count: {}", "thread_count"),
    )
    # Only these jobs get an MPIRUN section.
    mpirun_jobs = ("case.run", "case.test")

    with Case(caseroot, read_only=False) as case:
        print("CASE INFO:")
        for template, attr_name in summary_lines:
            print(template.format(getattr(case, attr_name)))
        print("")

        print("BATCH INFO:")
        # No job given on the command line: ask the case for its first job.
        job = job or case.get_first_job()
        set_logger_indent("      ")
        dry_run_cmds = case.submit_jobs(dry_run=True, job=job)
        env_batch = case.get_env('batch')

        for job_name, submit_cmd in dry_run_cmds:
            print("  FOR JOB: {}".format(job_name))
            print("    ENV:")
            case.load_env(job=job_name, reset=True, verbose=True)

            omp_threads = os.environ.get("OMP_NUM_THREADS")
            if omp_threads is not None:
                print("      Setting Environment OMP_NUM_THREADS={}".format(omp_threads))
            print("")

            print("    SUBMIT CMD:")
            print("      {}".format(case.get_resolved_value(submit_cmd)))
            print("")

            if job_name not in mpirun_jobs:
                continue

            # get_job_overrides must run after case.load_env above because
            # the command may reference environment variables.
            job_overrides = env_batch.get_job_overrides(job_name, case)
            print("    MPIRUN (job={}):".format(job_name))
            print("      {}".format(case.get_resolved_value(job_overrides["mpirun"])))
            print("")


# Run only when executed as a script. The module docstring is passed along so
# that it becomes the argparse description shown in the help output.
if __name__ == "__main__":
    _main_func(__doc__)
