#! /usr/bin/env python

"""
Convert a grid file from v1 to v2.
"""

import argparse, sys, os
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

from standard_script_setup import *
from CIME.utils import expect
from CIME.XML.generic_xml import GenericXML
import xml.etree.ElementTree as ET

from collections import OrderedDict

###############################################################################
def parse_command_line(args, description):
###############################################################################
    """
    Parse the command line of this script.

    args is the full argument vector with the program name first (args[0] is
    only used to build the usage text); description is the text passed to
    argparse as the program description.

    Returns the value of the single positional argument, the v1 file path.
    """
    program = os.path.basename(args[0])
    usage_text = """\n{0} <V1-GRID-FILE>
OR
{0} --help
""".format(program)

    cli = argparse.ArgumentParser(usage=usage_text,
                                  description=description,
                                  formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Let CIME register its standard logging options on the parser
    CIME.utils.setup_standard_logging_options(cli)

    cli.add_argument("v1file", help="v1 file path")

    parsed = CIME.utils.parse_args_and_handle_standard_logging_options(args, cli)
    return parsed.v1file

###############################################################################
def convert_gridmaps(v1file_obj, v2file_obj):
###############################################################################
    """
    Copy the <gridmaps> section of the v1 file into the v2 file.

    Each v1 <gridmap> block is recreated in v2 with the same attributes. Every
    child of a v1 block becomes a v2 <map> element whose "name" attribute is
    the v1 child's name and whose text is the v1 child's text.
    """
    v1_section = v1file_obj.get_child(name="gridmaps")

    # Read phase: collect (attributes, [(child name, child text), ...]) per block
    collected = []
    for v1_block in v1file_obj.get_children(name="gridmap", root=v1_section):
        block_attribs = v1file_obj.attrib(v1_block)
        entries = [(v1file_obj.name(node), v1file_obj.text(node))
                   for node in v1file_obj.get_children(root=v1_block)]
        collected.append((block_attribs, entries))

    # Write phase: rebuild the section in the v2 layout
    v2_section = v2file_obj.make_child("gridmaps")
    for block_attribs, entries in collected:
        v2_block = v2file_obj.make_child("gridmap", attributes=block_attribs, root=v2_section)
        for map_name, map_text in entries:
            v2file_obj.make_child("map", attributes={"name": map_name}, root=v2_block, text=map_text)

###############################################################################
def convert_domains(v1file_obj, v2file_obj):
###############################################################################
    """
    Convert the <domains> section of the v1 file into the v2 file.

    For every v1 <domain>, nx, ny and the description are copied over; the
    <support> text is used as the description when <desc> is empty. The v1
    <file> and <path> children each carry exactly one attribute whose key has
    the form "<component>_<suffix>" and whose value is a mask value. They are
    merged into v2 <file> elements holding the joined path and file name, with
    a "grid" attribute ("|"-joined component names) and, when the mask value
    is neither "reg" nor the domain's own name, a "mask" attribute.

    Runs under both Python 2 and Python 3: dict keys are compared through
    list(), and items() is used instead of iteritems() / items()[0].

    The following are reported through expect(): a domain whose attributes are
    not exactly ["name"], a file/path child without exactly one attribute, an
    unexpected child element, and a file that maps to more than one mask.
    """
    known_children = ["nx", "ny", "file", "path", "desc", "support"]
    default_path = "$DIN_LOC_ROOT/share/domains"

    domain_data = [] # (name, nx, ny, {component->mask->file}, {component->mask->path}, desc)

    v1domains = v1file_obj.get_child(name="domains")
    for domain_block in v1file_obj.get_children(name="domain", root=v1domains):
        attrib = v1file_obj.attrib(domain_block)
        # list(dict) yields the keys on Python 2 and 3 alike; a Python 3 keys
        # view compared directly with a list is never equal.
        expect(list(attrib) == ["name"],
               "Unexpected attribs: {}".format(attrib))

        name = attrib["name"]

        desc = v1file_obj.get_element_text("desc", root=domain_block)
        sup  = v1file_obj.get_element_text("support", root=domain_block)
        nx   = v1file_obj.get_element_text("nx", root=domain_block)
        ny   = v1file_obj.get_element_text("ny", root=domain_block)

        # Fall back to the support text when there is no description
        if sup and not desc:
            desc = sup

        file_masks, path_masks = OrderedDict(), OrderedDict()

        for child_name, masks in [("file", file_masks), ("path", path_masks)]:
            for child in v1file_obj.get_children(name=child_name, root=domain_block):
                child_attrib = v1file_obj.attrib(child)
                expect(len(child_attrib) == 1, "Bad {} attrib: {}".format(child_name, child_attrib))
                # Single entry; next(iter(...)) works where items() is a view
                mask_key, mask_value = next(iter(child_attrib.items()))

                # Only the part before the "_" is kept as the component name
                component, _ = mask_key.split("_")
                masks.setdefault(component, OrderedDict())[mask_value] = v1file_obj.text(child)

        for child in v1file_obj.get_children(root=domain_block):
            expect(v1file_obj.name(child) in known_children,
                   "Unhandled child of domain '{}'".format(v1file_obj.name(child)))

        domain_data.append( (name, nx, ny, file_masks, path_masks, desc) )

    v2domains = v2file_obj.make_child("domains")

    for name, nx, ny, file_masks, path_masks, desc in domain_data:
        attribs = {"name":name} if name else {}
        domain_block = v2file_obj.make_child("domain", attributes=attribs, root=v2domains)

        v2file_obj.make_child("nx", root=domain_block, text=nx)
        v2file_obj.make_child("ny", root=domain_block, text=ny)

        # full file path -> v2 mask value -> [components sharing that file]
        file_to_attrib = OrderedDict()
        for component, mask_values in file_masks.items():
            for mask_value, filename in mask_values.items():
                if filename is None:
                    continue

                # Use the default location when v1 gives no matching <path>
                try:
                    path = path_masks[component][mask_value]
                except KeyError:
                    path = default_path

                fullfile = os.path.join(path, filename)
                # "reg" and the domain's own name are written without a mask
                v2_mask = "" if mask_value in ["reg", name] else mask_value
                file_to_attrib.setdefault(fullfile, OrderedDict()).setdefault(v2_mask, []).append(component)

        for fullfile, masks in file_to_attrib.items():
            expect(len(masks) == 1, "Bad mask")
            # masks holds exactly one entry once the check above has passed
            mask, components = next(iter(masks.items()))

            attrib = {"grid": "|".join(components)}
            if mask:
                attrib["mask"] = mask

            v2file_obj.make_child("file", attributes=attrib, root=domain_block, text=fullfile)

        if desc:
            v2file_obj.make_child("desc", root=domain_block, text=desc)

###############################################################################
def convert_grids(v1file_obj, v2file_obj):
###############################################################################
    """
    Convert the <grids> section of the v1 file into the v2 file.

    Every v1 <grid> becomes a v2 <model_grid>. Its "alias" attribute is the v1
    alias, or the short name when there is no alias; its "compset" attribute
    is copied from v1 when present. The v1 long name is split on "_" into
    "<type>%<value>" pieces (a fragment without "%" is appended to the value
    of the preceding piece); each piece becomes a <grid name=...> child,
    except the "m" type which becomes the single <mask> child.

    The v1 <support> child is accepted but is not carried over to v2.

    Runs under both Python 2 and Python 3: the attribute check uses a set
    comparison rather than comparing dict keys with lists.

    The following are reported through expect(): attributes other than
    "compset", an unexpected child element, a missing long name, a long name
    that does not start with a "<type>%" piece, an unknown type letter, and
    more than one mask piece.
    """
    ctype_map = {"a":"atm", "l":"lnd", "oi":"ocnice", "r":"rof", "m":"mask", "g":"glc", "w":"wav"}
    known_children = ["lname", "sname", "alias", "support"]

    grid_data = [] # (compset, lname, sname, alias)

    v1grids = v1file_obj.get_child(name="grids")
    for grid_block in v1file_obj.get_children(name="grid", root=v1grids):
        attrib = v1file_obj.attrib(grid_block)

        # Allowed: no attributes at all, or only "compset". A set comparison
        # gives the same answer on Python 2 and 3, unlike keys() vs. a list.
        expect(set(attrib) <= {"compset"},
               "Unexpected attribs: {}".format(attrib))
        compset = attrib.get("compset")

        lname = v1file_obj.get_element_text("lname", root=grid_block)
        sname = v1file_obj.get_element_text("sname", root=grid_block)
        alias = v1file_obj.get_element_text("alias", root=grid_block)

        for child in v1file_obj.get_children(root=grid_block):
            expect(v1file_obj.name(child) in known_children,
                   "Unhandled child of grid '{}'".format(v1file_obj.name(child)))

        # The long name drives the whole v2 entry, so it must be present
        expect(lname is not None,
               "Grid (alias '{}', sname '{}') has no lname".format(alias, sname))

        grid_data.append((compset, lname, sname, alias))

    v2grids = v2file_obj.make_child("grids")

    # TODO: How to leverage model_grid_defaults

    for compset, lname, sname, alias in grid_data:
        v2_alias = alias if alias else sname
        attribs = {"alias":v2_alias} if v2_alias else {}
        if compset:
            attribs["compset"] = compset
        v2grid = v2file_obj.make_child("model_grid", attributes=attribs, root=v2grids)

        # Values may themselves contain "_": a fragment without "%" is glued
        # back onto the value of the piece before it.
        pieces = []
        for raw_piece in lname.split("_"):
            if "%" in raw_piece:
                pieces.append(raw_piece)
            else:
                expect(len(pieces) > 0,
                       "Bad lname '{}': does not start with a '<type>%' piece".format(lname))
                pieces[-1] += ("_" + raw_piece)

        mask = None
        for piece in pieces:
            ctype, data = piece.split("%")
            ctype = ctype.strip()
            expect(ctype in ctype_map,
                   "Unknown grid type '{}' in lname '{}'".format(ctype, lname))
            cname = ctype_map[ctype]
            if cname == "mask":
                expect(mask is None, "Multiple masks")
                mask = data
            else:
                v2file_obj.make_child("grid", attributes={"name":cname}, text=data, root=v2grid)

        # The mask is written last, after all component grids
        if mask is not None:
            v2file_obj.make_child("mask", text=mask, root=v2grid)

###############################################################################
def convert_to_v2(v1file):
###############################################################################
    """
    Convert the v1 grid file at path v1file and write the v2 XML to stdout.

    The v2 document is built with root element "grid_data" and attribute
    version="2.0", then filled section by section.
    """
    source = GenericXML(infile=v1file, read_only=True)
    target = GenericXML(infile="out.xml", read_only=False, root_name_override="grid_data", root_attrib_override={"version":"2.0"})

    # Sections are converted in this order: grids, domains, gridmaps
    for convert_section in (convert_grids, convert_domains, convert_gridmaps):
        convert_section(source, target)

    target.write(outfile=sys.stdout)

###############################################################################
def _main_func(description):
###############################################################################
    """
    Script entry point: parse sys.argv (description is used for the help text)
    and convert the v1 file named on the command line.
    """
    convert_to_v2(parse_command_line(sys.argv, description))

# Run the conversion only when executed as a script; the module docstring
# doubles as the command-line description.
if __name__ == "__main__":
    _main_func(__doc__)
