#!python

from __future__ import print_function
import argparse
import copy as _copy

import numpy as np
import matplotlib

import viscid
from viscid import logger
from viscid import vutil
from viscid.plot import vpyplot as vlt

# plt.rc('axes', color_cycle=list("brcykmg"))

def _ensure_value(namespace, name, value):
    """Return attribute ``name`` of ``namespace``, defaulting it first.

    If the attribute is missing or is None it is set to ``value``
    before being read back. Any other existing value (including falsy
    ones such as 0 or '') is left untouched.
    """
    is_unset = getattr(namespace, name, None) is None
    if is_unset:
        setattr(namespace, name, value)
    return getattr(namespace, name)

class AddPlotOpts(argparse.Action):
    """argparse action that attaches plot options to a plot variable.

    The action looks at the list stored under ``self.dest`` on the
    namespace. If that list exists, ``values`` replaces the
    ``"plot_opts"`` entry of its most recent item. If no list has been
    stored yet, ``values`` is saved on the namespace as
    ``global_popts`` instead.
    """
    def __call__(self, parser, namespace, values, option_string=None):
        plot_vars = getattr(namespace, self.dest, None)
        if plot_vars is not None:
            # options belong to the variable that was added last
            latest = plot_vars[-1]
            logger.info("setting {0} plotopts to {1}".format(latest[0], values))
            latest[1]["plot_opts"] = values
        else:
            # no variable seen yet: remember these as global options
            setattr(namespace, "global_popts", values)
        return None


class AddPlotVar(argparse.Action):
    """argparse action that appends a plot variable to ``self.dest``.

    Every use adds an entry ``[values, {'plot_opts': ''}]``. The list
    is shallow-copied before the append, so a list object already
    stored on the namespace is not mutated in place.
    """
    def __call__(self, parser, namespace, values, option_string=None):
        existing = _ensure_value(namespace, self.dest, [])
        updated = _copy.copy(existing)
        new_entry = [values, {'plot_opts': ''}]
        updated.append(new_entry)
        setattr(namespace, self.dest, updated)

def split_floats(arg_str):
    """Convert a comma-separated string such as ``"1,2.5"`` to floats.

    Returns a list with one float per comma-separated piece; a piece
    that ``float`` cannot parse raises ValueError.
    """
    pieces = arg_str.split(',')
    return list(map(float, pieces))


def main():
    """Plot time series of single-point values taken from a viscid file.

    Parses the command line, loads the input file(s) with
    ``viscid.load_file`` and, for every time step selected by ``-t``,
    extracts one value per plot variable. Each variable must be sliced
    down to exactly one element. The resulting series are plotted
    against time, either one subplot per variable or, with ``--one``,
    all on the same axes. The figure is then shown or saved.

    Raises:
        ValueError: if the ``-t`` selection yields no time steps.
        RuntimeError: if a sliced variable does not contain exactly
            one element.
    """
    # NOTE: __doc__ here resolves to the module-level docstring, not to
    # this function's docstring.
    # conflict_handler='resolve' lets the options added below replace
    # same-named options registered earlier (presumably by the vutil
    # helpers -- confirm against viscid.vutil).
    parser = argparse.ArgumentParser(description=__doc__,
                                     conflict_handler='resolve')
    vutil.add_animate_arguments(parser)
    vutil.add_mpl_output_arguments(parser)
    parser.add_argument("-t", default=":", help="times to plot in slice "
                        "notation: ex : for all 60.0: for 60 mins on, and "
                        ":120 for the first 120mins")
    parser.add_argument("-x", default=None, help="limits on time axis")
    parser.add_argument("-y", default=None, help="limits on y axis")
    parser.add_argument("--slice", default=None,
                        help="spatial slice applied to all fields")
    # -p and -o deliberately share dest="plot_vars": AddPlotOpts edits
    # the entry most recently appended by AddPlotVar.
    parser.add_argument("-p", "--var", dest="plot_vars", action=AddPlotVar,
                        help="plot variable")
    parser.add_argument("-o", "--popts", dest="plot_vars", action=AddPlotOpts,
                        help="add plot options to most recent var on the "
                        "command line. if no preceeding vars, popts are "
                        "applied to all plots")
    parser.add_argument("--one", action="store_true",
                        help="plot all vars on same plot")
    parser.add_argument("--nofname", action="store_true",
                        help="do not put the filename in the title")
    parser.add_argument("--timeformat", default="",
                        help="style for time axis: 'ut', 'hms', 'dhms', "
                        "'.02f', '' (default)")
    parser.add_argument('--rotateticklabels', '--rl', action='store_true',
                        help="tilt xtick labels")
    parser.add_argument('file', nargs='+', help='input file')
    args = vutil.common_argparse(parser)

    # global_popts only exists on args if -o was given before any -p
    # (see AddPlotOpts); otherwise fall back to no global options.
    global_popts = getattr(args, "global_popts", "")

    # No -p given: use two default variables, each entry being
    # [variable spec, {"plot_opts": options string}].
    if getattr(args, "plot_vars", None) is None:
        args.plot_vars = [["pp,x=8.0f,y=0.0f,z=0.0f", {"plot_opts":""}],
                          ["by,x=8.0f,y=0.0f,z=0.0f", {"plot_opts":""}]]

    # Without an output prefix there is no file name to save to, so
    # show the figure instead (args.prefix / args.show are presumably
    # defined by add_mpl_output_arguments -- confirm).
    if args.prefix is None:
        args.show = True

    file_ = viscid.load_file(args.file)
    # First pass over the selected time steps: collect the time values
    # so the per-variable result arrays can be preallocated.
    t = np.array([grid.time for grid in file_.iter_times(args.t)])
    if len(t) == 0:
        raise ValueError("Time slice didn't yield any times.")
    plot_names = [None for _ in range(len(args.plot_vars))]
    plot_arrs = [np.zeros_like(t) for _ in range(len(args.plot_vars))]

    # Second pass: for time step i and variable j, extract one value.
    for i, grid in enumerate(file_.iter_times(args.t)):
        for j, pvar in enumerate(args.plot_vars):
            # A variable spec has the form "name[=equation][,slice]":
            # everything after the first comma is the slice, and an
            # '=' in the remaining part starts an equation.
            pvname, slc, eqn = pvar[0], '', ''
            if ',' in pvname:
                split_pvname = pvname.split(',')
                pvname, slc = split_pvname[0], ','.join(split_pvname[1:])
            if '=' in pvname:
                split_pvname = pvname.split('=')
                pvname, eqn = split_pvname[0], '=' + '='.join(split_pvname[1:])

            # Labels are determined once, from the first time step.
            if i == 0:
                # FIXME: if pvar[0] is an equation, it does
                # the calculation on the whole grid just to
                # get the name of the resulting field, but the
                # call to get_field will return a scalar, so...
                try:
                    if eqn:
                        plot_names[j] = pvname
                    else:
                        plot_names[j] = grid[pvname].blocks[0].pretty_name
                except AttributeError:
                    # no blocks / pretty_name available: use the raw name
                    plot_names[j] = pvname

            # Combine the global --slice with the per-variable slice;
            # the global one comes first.
            if args.slice and slc:
                _slc = args.slice + "," + slc
            elif slc:
                _slc = slc
            else:
                _slc = args.slice
            val = grid.get_field(pvname + eqn, slc=_slc)

            # Only a single value can be stored per (variable, time).
            if val.size > 1:
                raise RuntimeError("you didn't slice away enough")
            elif val.size == 0:
                raise RuntimeError("you sliced away too much?")

            plot_arrs[j][i] = val

    n_pvars = len(args.plot_vars)

    # wrap fields?
    # Keyword arguments shared by every plot; the global plot options
    # are merged into this dict by plot_opts_to_kwargs.
    all_plot_kwargs = dict(nolabels=True)
    vlt.plot_opts_to_kwargs(global_popts, all_plot_kwargs)

    for j, plot_var in enumerate(args.plot_vars):
        # Wrap the series and its time axis in a viscid field for vlt.plot.
        ts_as_fld = viscid.arrays2field(plot_arrs[j], t, name=plot_names[j])

        # Start from a copy of the shared kwargs so per-variable options
        # do not leak into the other plots.
        plot_kwargs = all_plot_kwargs.copy()
        plot_kwargs['label'] = plot_names[j]
        vlt.plot_opts_to_kwargs(plot_var[1]['plot_opts'], plot_kwargs)

        # One stacked subplot per variable unless --one was requested.
        if not args.one:
            vlt.plt.subplot2grid((n_pvars, 1), (j, 0))
            vlt.plt.ylabel(plot_names[j])

        vlt.plot(ts_as_fld, **plot_kwargs)

        # format the time axis on the last plot
        if j == n_pvars - 1:
            def timeTicks(t, pos):  # pylint: disable=unused-argument
                return file_.format_time(style=args.timeformat, time=t)
            formatter = matplotlib.ticker.FuncFormatter(timeTicks)
            vlt.plt.gca().xaxis.set_major_formatter(formatter)
            if args.rotateticklabels:
                vlt.plt.gcf().autofmt_xdate()
        else:
            # every plot except the last gets blank x tick labels
            vlt.plt.gca().set_xticklabels([])

    if args.one:
        vlt.plt.legend(loc=0)

    # Title: axes title for the single-axes layout, figure-level
    # suptitle for the stacked layout.
    if not args.nofname:
        if args.one:
            vlt.plt.title(file_.fname)
        else:
            vlt.plt.suptitle(file_.fname)

    if not args.rotateticklabels:
        vlt.plt.xlabel("time")
    # -x / -y limits are given as underscore-separated floats and are
    # applied through pyplot, i.e. to the current axes.
    if args.x is not None:
        vlt.plt.xlim(*[float(x) for x in args.x.split('_')])
    if args.y is not None:
        vlt.plt.ylim(*[float(y) for y in args.y.split('_')])

    # plt.tight_layout()

    if args.show:
        vlt.plt.show()
    else:
        # saved as "<prefix>.<format>"
        vlt.plt.savefig("{0}.{1}".format(args.prefix, args.format))

# Run the command-line entry point only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()

##
## EOF
##
