#!python
""" ggcm_diff: Diff the fields from two xdmf datasets. Exit value is number of
fields with differences. """

from __future__ import print_function
import argparse
import sys
import logging

import numpy as np

from viscid import readers
from viscid import vutil
from viscid.calculator import calc

# NOTE(review): module-level verbosity value; nothing in the visible code
# reads it -- possibly a leftover, confirm before removing.
verb = 0

def _max_abs(values):
    """Return the largest magnitude in `values`, ignoring NaNs.

    Returns 0.0 when `values` is empty or contains only NaNs.
    """
    magnitudes = np.abs(np.asarray(values))
    magnitudes = magnitudes[~np.isnan(magnitudes)]
    if magnitudes.size == 0:
        return 0.0
    return magnitudes.max()


def _max_relative_diff(diff_data, ref_data):
    """Return max(|diff_data|) / max(|ref_data|), ignoring NaNs.

    Magnitudes are used on both sides so that negative differences or a
    negative-valued reference can not hide or flip the result. If the
    reference is identically zero (or the difference is infinite) the
    result is 0.0 when there is no difference at all and inf otherwise,
    instead of dividing by zero.
    """
    max_diff = _max_abs(diff_data)
    max_ref = _max_abs(ref_data)
    if max_ref == 0.0 or np.isinf(max_diff):
        return 0.0 if max_diff == 0.0 else float("inf")
    return max_diff / max_ref


def _plot_comparison(flda, fldb, diff):
    """Show a 3x3 grid: columns are A, B and the difference; rows are the
    x=0, y=0 and z=0 slices. Panels in one row share their axes."""
    # imported lazily so that matplotlib is only required with --show
    from viscid.plot import vpyplot as vlt
    from matplotlib import pyplot as plt

    columns = ((flda, "File A"), (fldb, "File B"), (diff, "Difference"))
    for row, plane in enumerate(("x=0", "y=0", "z=0")):
        row_ax = None
        for col, (fld, title) in enumerate(columns):
            panel = 3 * row + col + 1
            if col == 0:
                row_ax = plt.subplot(3, 3, panel)
            else:
                plt.subplot(3, 3, panel, sharex=row_ax, sharey=row_ax)
            if row == 0:
                plt.title(title)
            vlt.plot(fld, plane, show=False)
    plt.show()


def main():
    """Compare the fields of two datasets at one time and report differences.

    Parses the command line, loads both files, activates the requested time
    in each, and compares every requested field (or every field of file a
    when none are named). A field counts as different when its maximum
    relative difference, max|b - a| / max|a|, exceeds --tol. Fields that
    exist in only one of the two files also count as different.

    Returns:
        int: the number of fields found to differ (0 if none).
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('-t', type=float, default=60.0,
                        help='time to plot')
    parser.add_argument("-p", "--tol", type=float, default=0.0,
                        help="tolerance for relative difference")
    parser.add_argument("--show", "--plot", action="store_true",
                        help="plot differences (requires matplotlib)")
    parser.add_argument('files', metavar="file", nargs=2,
                        help='input files (2 please)')
    parser.add_argument('fields', metavar="field", nargs="*",
                        help="fields to difference (don't give any to check "
                             "all)")
    args = vutil.common_argparse(parser)

    logging.info("a = {0}".format(args.files[0]))
    logging.info("b = {0}".format(args.files[1]))

    files = readers.load_files(args.files)
    fa = files[0]
    fb = files[1]
    fa.activate_time(args.t)
    fb.activate_time(args.t)

    nr_different = 0

    if args.fields:
        fldnames = []
        for fldname in args.fields:
            if fldname in fa:
                fldnames.append(fldname)
            else:
                logging.warning("file a has no field '{0}'".format(fldname))
                nr_different += 1
    else:
        # None means "iterate over every field in file a"
        fldnames = None

    for flda in fa.iter_fields(named=fldnames):
        # Only the lookup is guarded: a KeyError raised while diffing or
        # plotting must not be misreported as a missing field.
        try:
            fldb_handle = fb[flda.name]
        except KeyError:
            logging.warning("file b has no field '{0}'".format(flda.name))
            nr_different += 1
            continue

        with fldb_handle as fldb:
            diff = calc.diff(fldb, flda)
            reldiff = _max_relative_diff(diff.data, flda.data)
            if reldiff > args.tol:
                logging.warning("fld '{0}' not equal, max relative diff = {1}"
                                "".format(flda.name, reldiff))
                nr_different += 1
                if args.show:
                    _plot_comparison(flda, fldb, diff)
            else:
                logging.info("field '{0}' equal".format(flda.name))

    # now see if b has any fields that a didn't
    if fldnames is None:
        for fldb in fb.iter_fields():
            try:
                with fa[fldb.name]:
                    pass
            except KeyError:
                logging.warning("file a has no field '{0}'".format(fldb.name))
                nr_different += 1

    if nr_different:
        slug = "field is" if nr_different == 1 else "fields are"
        logging.warning("In total, {0} {1} different"
                        "".format(nr_different, slug))
    else:
        logging.info("super, no differences found")
    return nr_different

if __name__ == "__main__":
    # Exit statuses are truncated to 8 bits on POSIX, so e.g. 256 differing
    # fields would wrap around to 0 and look like success; saturate at 255.
    exit_code = main()
    sys.exit(min(exit_code, 255))

##
## EOF
##
