#!/usr/bin/env python3
from matplotlib import pyplot as plt
import argparse
import sys
import jtutils
import csv

def readCL():
    """Parse the command line and return the options as a flat tuple.

    Returns:
        (infile, no_header, do_bin_plot, do_scatter_plot, do_line_plot,
        keep_outliers, outfile), in that order.
    """
    desc_string = """Quick graphing utility for csvs of one or two columns.
One column csvs are graphed as histograms, two column csvs are plotted as scatterplots (<= 1000 points) or binned scatterplots (>1000 points).
A binned scatterplot divides the points into 10 bins by x-value and plots the average (x,y) for each bin.
A third column, when present, is used to color the points of the scatterplot.
"""
    parser = argparse.ArgumentParser(description=desc_string, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("-f","--infile",default="/dev/stdin")
    parser.add_argument("-n","--no_header",action="store_true")
    parser.add_argument("-b","--do_bin_plot",action="store_true")
    parser.add_argument("-s","--do_scatter_plot",action="store_true")
    parser.add_argument("-l","--do_line_plot",action="store_true")
    #the flag keeps outliers; the outlier rule itself is quartile based, not sd based
    parser.add_argument("-k","--keep_outliers",action="store_true", help="keep outliers; by default rows with a value far outside the quartile range of their column are dropped")
    parser.add_argument("-o","--outfile")
    args = parser.parse_args()
    return args.infile, args.no_header, args.do_bin_plot, args.do_scatter_plot, args.do_line_plot, args.keep_outliers, args.outfile


def center_plot(pts, plt):
    """Pad the axis limits by 5% of the data range on each side.

    Args:
        pts: sequence of points; entries 0 and 1 of each point are read as x and y.
        plt: object exposing ``xlim(lo, hi)`` and ``ylim(lo, hi)`` (pyplot is
            what the callers in this file pass).

    Each axis is handled on its own: an axis whose values are all identical
    (for example when plotting one point) keeps its current limits, while the
    other axis is still padded. Nothing happens for an empty ``pts``.
    """
    if len(pts) == 0:
        return
    eps = 0.05
    axes = (
        ([p[0] for p in pts], plt.xlim),
        ([p[1] for p in pts], plt.ylim),
    )
    for values, set_limits in axes:
        lo = min(values)
        hi = max(values)
        span = hi - lo
        #don't reset an axis with an empty range: the padding would be zero
        #and the limits would collapse onto a single value
        if span != 0:
            set_limits(lo - eps * span, hi + eps * span)

def counter(l):
    """Count how many times each item of ``l`` occurs.

    Every item is converted to a tuple so it can be used as a dict key.
    Returns a dict mapping that tuple to its number of occurrences.
    """
    tallies = {}
    for item in l:
        key = tuple(item)
        if key in tallies:
            tallies[key] += 1
        else:
            tallies[key] = 1
    return tallies

#also can try context="fivethirtyeight"
def scatter(pts, context="ggplot", title=None, line=False, outfile=None, freq_sizes=False):
    """Draw two-column points as a line or a scatterplot, then save or show it.

    Args:
        pts: sequence of (x, y) pairs.
        context: matplotlib style name applied while the points are drawn.
        title: optional figure title.
        line: when true, draw a connected line and pad the axes via center_plot.
        outfile: when given, the figure is saved there; otherwise it is shown.
        freq_sizes: when true (scatter mode only), marker size grows with the
            number of times the same point occurs.
    """
    with plt.style.context(context):
        xs, ys = zip(*pts)
        if line:
            plt.plot(xs, ys)
            center_plot(pts, plt)
        elif freq_sizes:
            tallies = counter(pts)
            #offset by 15 so the minimum size is 16; a size of 1 is too small to see
            marker_sizes = [tallies[tuple(pt)] + 15 for pt in pts]
            plt.scatter(xs, ys, s=marker_sizes)
        else:
            plt.scatter(xs, ys)
    if title:
        plt.suptitle(title)
    if not outfile:
        plt.show()
    else:
        plt.savefig(outfile)

def discrete_cmap(N, base_cmap=None):
    """Create an N-bin discrete colormap from the specified input map.

    Adapted from https://gist.github.com/jakevdp/91077b0cae40f8f8244a

    Args:
        N: number of discrete colors wanted.
        base_cmap: a colormap name, a colormap instance, or None for the
            default colormap.

    Returns:
        A colormap named ``<base name><N>`` with N entries built from colors
        sampled evenly across the base map.
    """
    import numpy as np
    from matplotlib.colors import LinearSegmentedColormap
    # pyplot.get_cmap works for a string, None, or a colormap instance.
    # (matplotlib.cm.get_cmap, used here previously, was removed in Matplotlib 3.9.)
    base = plt.get_cmap(base_cmap)
    #jtrigg@20160319 bugfix: below line fails when N=1 so try max(N,2)
    color_list = base(np.linspace(0, 1, max(N,2)))
    cmap_name = base.name + str(N)
    # call from_list on the class: only LinearSegmentedColormap defines it, so
    # base.from_list would fail for a ListedColormap base such as 'viridis'
    return LinearSegmentedColormap.from_list(cmap_name, color_list, N)



def scatter_3d(pts, context="ggplot", title=None, outfile=None):
    """Scatter (x, y) points, using the third column to choose each point's color.

    With at most 5 distinct third-column values every value gets its own
    color and a legend entry; with more, the values are passed through
    process_color_col and drawn with the 'jet' colormap.

    Args:
        pts: sequence of (x, y, z) triples.
        context: matplotlib style name applied while the points are drawn.
        title: optional figure title.
        outfile: when given, the figure is saved there; otherwise it is shown.
    """
    with plt.style.context(context):
        xs, ys, zs = zip(*pts)
        z_vals = list(set(zs))
        N = len(z_vals)
        if N > 5:
            color_values = process_color_col(zs)
            plt.scatter(xs, ys, c=color_values, cmap='jet')
        else:
            #map z_vals to distinct floats
            z_val_bins = process_color_col(z_vals)
            for val, color_bin in zip(z_vals, z_val_bins):
                matching = [(px, py) for px, py, pz in pts if pz == val]
                x_subset, y_subset = zip(*matching)
                #colormap names: http://matplotlib.org/examples/color/colormaps_reference.html
                plt.scatter(x_subset, y_subset, label=val, color=discrete_cmap(N, 'rainbow')(color_bin))
            #shrink the axes to 80% width and place the legend outside the plot box
            ax = plt.subplot(111)
            box = ax.get_position()
            ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
            plt.legend(bbox_to_anchor=(1.3,0.7))
    if title:
        plt.suptitle(title)
    if not outfile:
        plt.show()
    else:
        plt.savefig(outfile)


def line_plot(pts, outfile=None):
    """Plot every column after the first as a line against the first column.

    Args:
        pts: non-empty list of rows with at least two entries each. The list
            is sorted in place before plotting.
        outfile: when given, the figure is saved there; otherwise it is shown.
    """
    #sort by the x-axis (should there be an option to skip this step?)
    pts.sort()
    assert(len(pts) > 0)
    width = len(pts[0])
    assert(width > 1)
    xs = [row[0] for row in pts]
    for col in range(1, width):
        ys = [row[col] for row in pts]
        #each column gets its own color from a discrete rainbow map
        line_color = discrete_cmap(width, 'rainbow')(col/float(width))
        plt.plot(xs, ys, color=line_color, ls="-")
    if not outfile:
        plt.show()
    else:
        plt.savefig(outfile)


def process_color_col(vals):
    """Convert a list of values eg ["black", "white", "white", "black"]
    to a list of floats [0,1,1,0] for use in matplotlib colormaps

    Args:
        vals: sequence of hashable values (strings when read from a csv).

    Returns:
        A list the same length as ``vals``. If every value passes
        ``jtutils.str_is_float`` the values are converted to float and
        rescaled to the 0-1 range (left at 0 when they are all equal).
        Otherwise each distinct value is replaced by an integer id, assigned
        in order of first appearance. An empty input gives an empty list.

    Raises:
        Exception: if the values are not all floats and there are more than
            5 distinct values.
    """
    vals = list(vals)
    if not vals:
        return []
    all_float = all(jtutils.str_is_float(v) for v in vals)
    if all_float:
        #normalize to 0-1
        nums = [float(v) for v in vals]
        minval = min(nums)
        shifted = [v - minval for v in nums]
        maxval = max(shifted)
        #jtrigg20160319 bugfix: if all values are the same don't try to normalize
        if maxval > 0:
            return [v / maxval for v in shifted]
        return shifted
    #dict.fromkeys keeps first-appearance order, so the ids are the same on
    #every run (iteration order of a set of strings is not)
    val_to_id = dict((v,i) for i,v in enumerate(dict.fromkeys(vals)))
    if len(val_to_id) > 5:
        raise Exception("ERROR: color (third) column must either be all floats or have <= 5 distinct values")
    return [val_to_id[v] for v in vals]


def bin_plot(pts, num_bins = 10, outfile=None):
    """Plot the average (x, y) of each of ``num_bins`` equal-count x-bins.

    Args:
        pts: list of (x, y) pairs; it is sorted in place by x value.
        num_bins: number of bins the sorted points are divided into.
        outfile: passed through to scatter (save there instead of showing).

    Raises:
        ValueError: if ``pts`` is empty.
    """
    if len(pts) == 0:
        raise ValueError("bin_plot requires at least one point")
    #sort pts by x value:
    pts.sort(key = lambda x: x[0])
    l = len(pts)
    outpts = []
    for i in range(int(num_bins)):
        break_pt1 = int(i*l/float(num_bins))
        break_pt2 = int((i+1)*l/float(num_bins))
        bin_pts = pts[break_pt1:break_pt2]
        #with fewer points than bins some bins receive nothing; skip them
        #rather than averaging an empty slice
        if not bin_pts:
            continue
        x_vals, y_vals = zip(*bin_pts)
        x_val = sum(x_vals) / float(len(x_vals))
        y_val = sum(y_vals) / float(len(y_vals))
        outpts.append((x_val,y_val))
    scatter(outpts, title="Bin plot", outfile=outfile)

def curve_fit(xy_pts, fn, plot_points=True, p0=(1.0, 1.0)):
    """Least-squares fit ``fn`` to the points, then plot and show the fitted curve.

    Args:
        xy_pts: iterable of (x, y) pairs.
        fn: function of the form f(x, p_1, p_2,.. p_k) for parameters p_i.
        plot_points: when true, also draw the data points.
        p0: initial guess for the parameters, one entry per p_i. The default
            keeps the original two-parameter guess; pass a different length
            to fit a function with another number of parameters.
    """
    import scipy.optimize
    import matplotlib.pyplot as plt
    import numpy
    x, y = zip(*xy_pts)
    popt, pcov = scipy.optimize.curve_fit(fn, x, y, p0)

    # Plot data
    if plot_points:
        plt.plot(x, y, 'or')

    # Plot fit curve on an even grid of 200 x values spanning the data
    fit_x = numpy.linspace(min(x), max(x), 200)
    plt.plot(fit_x, fn(fit_x, *popt), "--r")
    plt.show()


def ls_dist_curve_fit(pts, fn, cnt=25):
    """Least-squares fit ``fn`` to the density histogram of ``pts``.

    Args:
        pts: the sample values.
        fn: function handed to curve_fit, fitted to (bin midpoint, density).
        cnt: number of histogram bins.

    Draws the histogram, prints the number of bin edges, then lets curve_fit
    plot the fitted curve and show the figure.
    """
    import matplotlib.pyplot as plt
    n, bins, patches = plt.hist(pts, cnt, density=1, histtype='stepfilled', rwidth=0.8)
    #midpoint of every pair of consecutive bin edges; pairing the edges
    #directly keeps one midpoint per density value even when an edge is
    #exactly 0 (a truthiness filter on the upper edge would drop that bin)
    bin_mids = [(b1 + b2) / 2.0 for b1,b2 in zip(bins[:-1], bins[1:])]
    print(len(bins))
    xy_pts = zip(bin_mids,n)
    curve_fit(xy_pts,fn,plot_points=False)


def mle_dist_curve_fit(pts, fn, cnt=25):
    """Maximum likelihood fit of a distribution to ``pts``, plotted over a histogram.

    Based on
    http://glowingpython.blogspot.com/2012/07/distribution-fitting-with-scipy.html

    Args:
        pts: the sample values.
        fn: one of the names "normal", "exponential" or "power" (mapped to
            scipy.stats.norm, scipy.stats.expon and scipy.stats.pareto), or
            an object providing ``fit``, ``pdf`` and ``nnlf`` itself.
        cnt: number of histogram bins.

    Prints the fitted parameters and the negative log likelihood of the data
    under them, then shows the fitted pdf on top of the density histogram.
    """
    import scipy.stats
    import numpy
    import matplotlib.pyplot as plt

    fn_dict = {"normal":scipy.stats.norm,
               "exponential":scipy.stats.expon,
               "power":scipy.stats.pareto}
    if fn in fn_dict:
        fn = fn_dict[fn]

    param = fn.fit(pts) # distribution fitting
    print("Parameters: ",param)
    xmin = min(pts)
    xmax = max(pts)
    x = numpy.linspace(xmin,xmax,100)
    # fitted distribution
    pdf_fitted = fn.pdf(x,*param)
    # the likelihood is a property of the data, so evaluate it on pts and
    # not on the evenly spaced grid x that only exists for drawing the curve
    nll = fn.nnlf(param,numpy.asarray(pts, dtype=float))
    print("Negative log likelihood: ", nll)

    plt.title('Histogram + fit')
    plt.plot(x,pdf_fitted,'r-')
    plt.hist(pts,cnt,density=1,alpha=.3)
    plt.show()


def mle_fit_normal_curve(pts):
    """Fit a normal distribution to ``pts`` = [x] by maximum likelihood.

    Minimizes the negative log likelihood over (loc, scale) with the TNC
    method, starting from (0.0, 1.0) and keeping scale at or above 1e-10.

    Returns:
        The result object of scipy.optimize.minimize.
    """
    import scipy
    import scipy.optimize
    import scipy.stats

    def neg_llh(params, *args):
        #the sample arrives as the extra positional arguments
        mu, sigma = params
        total = sum([scipy.stats.norm.logpdf(v, mu, sigma) for v in args])
        return -1 * total

    start = (0.0, 1.0)
    limits = [(None, None), (1e-10, None)]
    return scipy.optimize.minimize(neg_llh, start, args=tuple(pts), method='TNC', bounds=limits)


def mle_fit_pareto_curve(pts):
    """Fit a pareto distribution to ``pts`` = [x] by maximum likelihood.

    pareto.pdf(x, b, loc, scale) = (b/scale) / ((x-loc)/scale)**(b+1)
    for (x-loc)/scale >= 1, b > 0

    Only b and scale are optimized; loc is tied to them as
    ``min(pts) - scale``. The search uses L-BFGS-B with an analytic gradient,
    starts at [1.0, 1.0] and keeps both parameters at or above 0.01.

    Returns:
        The result object of scipy.optimize.minimize (it is also printed).
    """
    import scipy
    import scipy.optimize
    import scipy.stats
    import math
    import numpy

    def neg_llh(params, *args):
        lowest, sample = args
        shape, scale = params
        #scale + loc <= min(x), taken with equality: loc = min_x - scale
        loc = lowest - scale
        llh = sum([scipy.stats.pareto.logpdf(v,shape,loc,scale) for v in sample])
        return -1 * llh

    def grad(params, *args):
        #partial derivatives of neg_llh with respect to b and scale
        lowest, sample = args
        shape, scale = params
        d_shape = -1 * sum([1/float(shape) - math.log((v-lowest)/float(scale) + 1) for v in sample])
        d_scale = -1 * sum([-1/float(scale) + (shape+1)* ((v-lowest)/float(scale)**2) / ((v-lowest)/float(scale) + 1) for v in sample])
        return numpy.array([d_shape, d_scale])

    min_x = min(pts)
    print("Fitting pareto...")
    lower_bounds = [(0.01,None),(0.01,None)]
    res = scipy.optimize.minimize(neg_llh, [1.0,1.0], jac=grad, args=(min_x,pts), method='L-BFGS-B', bounds=lower_bounds, options={"disp":True})
    print(res)
    return res


def plot_hist(pts,cnt=25,context="ggplot",outfile=None):
    """Draw a density histogram of the first column of ``pts``.

    Args:
        pts: sequence of rows; only entry 0 of each row is used.
        cnt: maximum number of bins (never more than the number of distinct values).
        context: matplotlib style name applied while the histogram is drawn.
        outfile: when given, the figure is saved there; otherwise it is shown.

    Raises:
        ValueError: if ``pts`` is empty.
    """
    pts = [x[0] for x in pts]
    if not pts:
        raise ValueError("plot_hist requires at least one value")
    x_cnt = len(set(pts))
    x_min = min(pts)
    x_max = max(pts)
    cnt = min(x_cnt,cnt) #don't have more buckets than x values

    #default x_range of the hist is [x_min, x_max]
    #increase this size a bit to make space in case some
    #bins are *centered* at x_min or x_max
    #(this happens in dice roll histogram with 6 bins)
    if cnt > 1:
        bin_width = (x_max - x_min) / (cnt - 1)
    else:
        #a single bin has no spacing between bin centers to measure (and
        #cnt - 1 would divide by zero): use the data range, or a unit-wide
        #bin when all values are identical
        bin_width = (x_max - x_min) or 1.0
    x_range = [x_min - 0.5 * bin_width, x_max + 0.5 * bin_width]

    with plt.style.context(context):
        n, bins, patches = plt.hist(pts, cnt, range=x_range, density=1)
        plt.setp(patches, 'facecolor', 'g', 'alpha', 0.75)
    #save or show, like the other plotting functions in this file
    if outfile:
        plt.savefig(outfile)
    else:
        plt.show()


def plot_images(*image_list):
    """Show the given images in grayscale on a square grid of subplots.

    The grid has ceil(sqrt(n)) rows and columns for n images, filled in the
    order the images were passed, with the axes hidden.
    """
    import math
    side = math.ceil(len(image_list) ** 0.5)
    for position, image in enumerate(image_list, 1):
        #subplot positions are 1-based
        plt.subplot(side, side, position)
        plt.imshow(image, cmap=plt.cm.gray)
        plt.axis("off")
    plt.show()


def readcsv(f):
    """Yield the data rows of a csv, as produced by ``_readcsv``.

    Args:
        f: either a path (str), which is opened and closed here, or an
            already open file-like object, which is read but left open.

    Raises:
        TypeError: if ``f`` is neither a str nor an object with a ``read``
            method. As this is a generator, the error surfaces when the
            first row is requested.
    """
    if isinstance(f, str):
        #newline="" lets the csv module do its own newline handling, as the
        #csv documentation asks for
        with open(f, newline="") as f_in:
            for r in _readcsv(f_in):
                yield r
    elif hasattr(f, "read"):
        #duck-typed check for an open file (the Python 2 `file` builtin is
        #not available on Python 3)
        for r in _readcsv(f):
            yield r
    else:
        raise TypeError("readcsv expects a path or an open file, got {}".format(type(f).__name__))

def _readcsv(f_in):
    """Yield each csv row after the header as an OrderedDict keyed by the header.

    Args:
        f_in: any iterable of csv lines accepted by ``csv.reader``.

    The first non-empty row is taken as the header; every later row is
    zipped with it, so extra fields beyond the header length are ignored.
    """
    #imported here because the module does not import collections itself
    from collections import OrderedDict
    header = None
    for line in csv.reader(f_in):
        if not header:
            header = line
        else:
            yield OrderedDict(zip(header,line))

def quartile_cutoffs(l):
    """Return the (25, 50, 75%ile) values of ``l`` as a tuple.

    The values are read from a sorted copy at positions n/4, n/2 and 3n/4,
    each rounded down, where n is the length of ``l``.
    """
    ordered = sorted(l)
    size = len(l)
    first = ordered[size // 4]
    middle = ordered[size // 2]
    third = ordered[(3 * size) // 4]
    return first, middle, third

def _drop_outliers(rows, num_cols, iqr_const=10):
    """Remove rows holding an outlier in any of their first ``num_cols`` columns.

    For each of those columns the allowed interval is
    [median - iqr_const * (median - q1), median + iqr_const * (q3 - median)],
    with the quartiles taken from ``quartile_cutoffs``.

    Args:
        rows: non-empty list of rows whose first ``num_cols`` entries are numbers.
        num_cols: number of leading columns to check; later columns (such as
            a color column holding strings) are never compared.
        iqr_const: determines where to set the outlier cutoff.

    Returns:
        (kept_rows, mins, maxs) where mins/maxs hold the per-column cutoffs.
    """
    mins = []
    maxs = []
    for column in zip(*[r[:num_cols] for r in rows]):
        q1, median, q3 = quartile_cutoffs(column)
        #a zero quartile distance gives no scale to judge outliers by (the
        #cutoff would sit on the median itself), so leave that side unbounded
        if median > q1:
            mins.append(median - iqr_const * (median - q1))
        else:
            mins.append(float("-inf"))
        if q3 > median:
            maxs.append(median + iqr_const * (q3 - median))
        else:
            maxs.append(float("inf"))
    #zip stops at the shorter sequence, so only the checked columns are compared
    kept = [r for r in rows if all(lo <= x <= hi for x, lo, hi in zip(r, mins, maxs))]
    return kept, mins, maxs


def main():
    """Read the csv named on the command line and draw the plot matching its shape."""
    infile, no_header, do_bin_plot, do_scatter_plot, do_line_plot, keep_outliers, outfile = readCL()

    #jtrigg avoid pandas for quicker loadtimes
    with open(infile) as f_in:
        if not no_header:
            #readline() returns "" at end of file; it never raises StopIteration
            hdr_line = f_in.readline()
            if not hdr_line:
                sys.stderr.write("ERROR: input has no rows. Exiting..." + "\n")
                sys.exit(-1)
            hdr = hdr_line.strip()
            #warnings go to stderr, like every other diagnostic below
            sys.stderr.write("WARNING: pgraph using first line of input, \"{hdr}\", as header. If file doesn't have a header use -n option.".format(hdr=hdr) + '\n')
        df = [l.strip().split(',') for l in f_in]

    l1 = len(df)
    if do_line_plot:
        df = [[float(x) for x in r] for r in df if all(jtutils.str_is_float(x) for x in r)]
    else:
        #only drop rows with non-floats in first two columns (third column used for coloring)
        df = [r for r in df if all(jtutils.str_is_float(x) for x in r[:2])]
        #convert first two columns to float
        df = [[float(x) for x in r[:2]] + r[2:] for r in df]
    l2 = len(df)

    if l2 == 0:
        sys.stderr.write("ERROR: input has no rows. Exiting..." + "\n")
        sys.exit(-1)

    if l1 != l2:
        sys.stderr.write("WARNING: dropped {} non float values".format(l1-l2) + '\n')

    #the flag do_line_plot is tested here, not the function line_plot, which
    #is always truthy and would disable outlier removal altogether
    if not keep_outliers and not do_line_plot:
        df, mins, maxs = _drop_outliers(df, 2)
        l3 = len(df)
        if l2 != l3:
            sys.stderr.write("WARNING: dropped {} outliers".format(l2-l3) + "\n")
            sys.stderr.write("outlier cutoffs: mins={} maxs={}".format(mins, maxs) + "\n")
        if l3 == 0:
            sys.stderr.write("ERROR: every row was dropped as an outlier. Use -k to keep outliers. Exiting..." + "\n")
            sys.exit(-1)

    if do_line_plot:
        line_plot(df, outfile=outfile)
        sys.exit(0)

    num_cols = len(df[0])
    if num_cols == 1:
        plot_hist(df, outfile=outfile)
    elif num_cols == 2:
        if do_bin_plot:
            bin_plot(df, outfile=outfile)
        elif do_scatter_plot:
            scatter(df, outfile=outfile, freq_sizes=True)
        elif len(df) > 1000:
            bin_plot(df, outfile=outfile)
        else:
            scatter(df, outfile=outfile, freq_sizes=True)
    elif num_cols == 3:
        #indicate third dimension by color
        scatter_3d(df, outfile=outfile)
    else:
        #three columns are supported (the third one colors the points)
        raise Exception("ERROR: >3 columns")


if __name__ == "__main__":
    main()
