#!python
""" Convert a legacy Py-ART grid NetCDF file to a modern NetCDF file. """

import argparse
import warnings

import netCDF4


def _transfer_var(dset, name, var, dimensions=None):
    """Copy the a variable to a different Dataset with a given name."""
    if dimensions is None:
        dimensions = var.dimensions
    new_var = dset.createVariable(name, var.dtype, dimensions)
    for ncattr in var.ncattrs():
        new_var.setncattr(ncattr, var.getncattr(ncattr))
    new_var[:] = var[:]


def _convert_grid(dset_legacy, dset_modern, verbose=False):
    """
    Write the contents of a legacy grid Dataset into a modern grid Dataset.

    Dimensions, axes variables and fields are transferred under their modern
    names, then default projection variables and a Conventions attribute are
    added. Neither Dataset is closed by this function.

    Parameters
    ----------
    dset_legacy : netCDF4.Dataset
        Open legacy grid Dataset to read from.
    dset_modern : netCDF4.Dataset
        Open, writable Dataset which receives the converted content.
    verbose : bool, optional
        True to print each variable and field as it is processed.

    """
    legacy_dims = dset_legacy.dimensions
    legacy_vars = dset_legacy.variables

    # transfer dimensions, time is created with a size of None
    dset_modern.createDimension("time", None)
    for legacy_dim, modern_dim in (("nz", "z"), ("ny", "y"), ("nx", "x")):
        dset_modern.createDimension(modern_dim, len(legacy_dims[legacy_dim]))

    # transfer axes variables
    variable_mappings = {
        "time": "time",
        "x_disp": "x",
        "y_disp": "y",
        "z_disp": "z",
        "lat": "origin_latitude",
        "lon": "origin_longitude",
        "alt": "origin_altitude",
    }
    dim_mapping = {"nx": "x", "ny": "y", "nz": "z", "time": "time"}
    for legacy_varname, modern_varname in variable_mappings.items():
        if verbose:
            print("Variable:", legacy_varname, "->", modern_varname)
        legacy_var = legacy_vars[legacy_varname]
        dims = [dim_mapping[dim] for dim in legacy_var.dimensions]
        _transfer_var(dset_modern, modern_varname, legacy_var, dims)

    # transfer fields: every variable which is not a known axes variable and
    # which has the shape (1, nz, ny, nx)
    field_shape = tuple(len(legacy_dims[d]) for d in ("nz", "ny", "nx"))
    field_shape_with_time = (1,) + field_shape
    axes_keys = {
        "time",
        "time_start",
        "time_end",
        "base_time",
        "time_offset",
        "z_disp",
        "y_disp",
        "x_disp",
        "alt",
        "lat",
        "lon",
        "z",
        "lev",
        "y",
        "x",
    }
    for field in legacy_vars:
        if field in axes_keys:
            continue
        if verbose:
            print("Field:", field)
        legacy_var = legacy_vars[field]
        if legacy_var.shape != field_shape_with_time:
            warnings.warn("Field %s skipped due to incorrect shape" % (field))
            continue
        _transfer_var(dset_modern, field, legacy_var, ("time", "z", "y", "x"))

    # set a default projections variable
    projection_var = dset_modern.createVariable("projection", "c", ())
    projection_var.setncattr("_include_lon_0_lat_0", "true")
    projection_var.setncattr("proj", "pyart_aeqd")

    # set a default projection coordinate system, the projection origin is
    # taken from the first element of the legacy lat and lon variables
    proj = dset_modern.createVariable("ProjectionCoordinateSystem", "c", ())
    proj.setncattr("grid_mapping_name", "azimuthal_equidistant")
    proj.setncattr("semi_major_axis", 6370997.0)
    proj.setncattr("inverse_flattening", 298.25)
    proj.setncattr("longitude_of_prime_meridian", 0.0)
    proj.setncattr("false_easting", 0.0)
    proj.setncattr("false_northing", 0.0)
    proj.setncattr("latitude_of_projection_origin", legacy_vars["lat"][0])
    proj.setncattr("longitude_of_projection_origin", legacy_vars["lon"][0])
    proj.setncattr("_CoordinateTransformType", "Projection")
    proj.setncattr("_CoordinateAxes", "x y z time")
    proj.setncattr("_CoordinateAxesTypes", "GeoX GeoY Height Time")

    # Add Conventions if not already present
    if "Conventions" not in dset_modern.ncattrs():
        dset_modern.setncattr("Conventions", "PyART_GRID-1.1")


def main(argv=None):
    """
    Convert a legacy Py-ART grid netCDF file named on the command line.

    Parameters
    ----------
    argv : list of str, optional
        Command line arguments to parse. None, the default, lets argparse
        read the arguments from sys.argv.

    """
    # parse the arguments
    parser = argparse.ArgumentParser(
        description="Convert a legacy Py-ART grid netCDF file"
    )
    parser.add_argument("infile", type=str, help="legacy netCDF grid file to convert")
    parser.add_argument("outfile", type=str, help="filename of new netCDF grid file")
    parser.add_argument(
        "-v",
        "--verb",
        dest="verbose",
        action="store_true",
        help="Verbose mode, print out debugging messages.",
    )
    args = parser.parse_args(argv)

    if args.verbose:
        print("Converting:", args.infile, "-->", args.outfile)

    # The with statements close both files even when the conversion raises.
    # The legacy file is opened first so that a bad input path does not
    # create an output file.
    with netCDF4.Dataset(args.infile, "r") as dset_legacy:
        with netCDF4.Dataset(args.outfile, "w") as dset_modern:
            _convert_grid(dset_legacy, dset_modern, args.verbose)


# Run the conversion only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
