#!/usr/bin/env python
import argparse
import pandas as pd
import sys
import pcsv.pindent

def readCL():
    usagestr = "%prog"
    parser = argparse.ArgumentParser()
    parser.add_argument("-f","--infile",default=sys.stdin)
    parser.add_argument("-g","--group_by",help="csv list of columns to aggregate by", default=[])
    parser.add_argument("-a","--agg_fns",help="function to aggregate by. Options: 'sum','mean','max','strmax','min','strmin','std','cnt','pctile_73','val0','val1'", nargs="*", default=[])
    parser.add_argument("-c","--agg_cols",help="columns to aggregate")
    parser.add_argument("-d","--delimiter",default=",")
    parser.add_argument("--lam", nargs="*", default=[], help="Instead of using --agg_cols and --agg_fns specify a custom aggregation function. The function takes a dataframe of all the rows in the group and returns a value. For example, to find the rows with TMAX closest to 15: --lam 'return x.apply(lambda r: (r[\"TMAX\"] - 15)**2,axis=1).min()")
    parser.add_argument("--append",help="Append aggregate value to *all* existing rows, instead of creating a new table with one row per group. Useful to create a variable for each row indicating how it differs from the aggregate", action="store_true")
    parser.add_argument("--filter",help="Same as append, but drop all rows except those whose value matches the aggregate value.",action="store_true")
    args = parser.parse_args()

    #the duplicate pivot uses two primitive functions, val0 and val1
    if "dup" in args.agg_fns:
        args.agg_fns.remove("dup")
        args.agg_fns.append("val0")
        args.agg_fns.append("val1")

    if args.group_by:
        args.group_by = args.group_by.split(",")

    if args.agg_cols:
        args.agg_cols = args.agg_cols.split(",")

    for i,fn_string in enumerate(args.lam):
        fn_string = "def lam{i}(x): {fn_string}".format(**vars())
        exec(pcsv.pindent(fn_string))

    lam = []
    for i,_ in enumerate(args.lam):
        lam.append(vars()["lam"+str(i)])

    return args.infile, args.group_by, args.agg_fns, args.agg_cols, lam, args.append, args.filter, args.delimiter

def sum_fn(array):
    return sum([float(x) for x in array])
def mean_fn(array):
    return sum([float(x) for x in array]) / len([float(x) for x in array])
def max_fn(array):
    return max([float(x) for x in array])
def min_fn(array):
    return min([float(x) for x in array])
def strmax_fn(array):
    return max([str(x) for x in array])
def strmin_fn(array):
    return min([str(x) for x in array])
def std_fn(array):
    import numpy
    return numpy.std([float(x) for x in array])
def median_fn(array):
    import numpy
    return numpy.median([float(x) for x in array])
def pctile_fn(array, pctile):
    #pctile in range [0,100]
    import numpy
    return numpy.percentile([float(x) for x in array], pctile)
def cnt_fn(array):
    return len(array)
def group_concat_fn(array):
    return str(array.tolist())
def val0_fn(array):
    if len(array)>0:
        return list(array)[0]
    else:
        return ""
def val1_fn(array):
    if len(array)>1:
        return list(array)[1]
    else:
        return ""

def aggstr_to_fn(agg_str):
    lookup = {"sum":sum_fn,
              "mean":mean_fn,
              "max":max_fn,
              "strmax":strmax_fn,
              "min":min_fn,
              "strmin":strmin_fn,
              "median":median_fn,
              "std":std_fn,
              "cnt":cnt_fn,
              "group_concat":group_concat_fn,
              "val0":val0_fn,
              "val1":val1_fn}

    if agg_str in lookup:
        return lookup[agg_str]
    elif agg_str.startswith("pctile"):
        #agg_str = pctile_95 or pctile_5
        def fn(array):
            return pctile_fn(array, pctile)
        #rename function because function name is printed later
        fn.__name__ = "pctile_"+str(pctile) + "_fn"
        return fn
    else:
        # sys.stderr.write("WARNING: interpreting -a argument as a python lambda\n")
        # return eval(agg_str)
        raise Exception("ERROR: unknown aggregate string")

if __name__ == "__main__":
    infile, group_by_cols, agg_str_list, agg_cols, lambda_fn_list, append, do_filter, delimiter = readCL()
    df = pd.read_csv(infile, delimiter=delimiter)

    #special treatment for cnt
    if not agg_cols and group_by_cols and agg_str_list == ["cnt"]:
        df_groups = df.groupby(group_by_cols)
        df_out = df_groups.size()
        # print dir(df_out)
        df_out = pd.DataFrame({"_".join(group_by_cols) + '_cnt' : df.groupby( group_by_cols ).size()}).reset_index()
        # df_out = df_out.rename(columns={'$a': 'a', '$b': 'b'})
        # print pd.DataFrame(df_out)
        # df_out.name = "cnt"
        # df_out.to_csv(sys.stdout, header=True)
    else:
        if not agg_cols: #aggregate all columns except group_by columns
            agg_cols = [d for d in df.columns.values if not d in group_by_cols]

        agg_fn_list = [aggstr_to_fn(a) for a in agg_str_list]


        if lambda_fn_list:
            #use custom lambdas instead of agg columns and agg functions
            df_groups = df.groupby(group_by_cols, as_index=True)
            #apply as here:
            #http://stackoverflow.com/questions/15259547/conditional-sums-for-pandas-aggregate
            def apply_fn(r):
                return pd.Series(dict([("lambda_"+str(i),lam(r)) for i,lam in enumerate(lambda_fn_list)]))
            df_out = df_groups.apply(apply_fn).reset_index()
        elif group_by_cols:
            #apply agg functions groupwise
            df_groups = df.groupby(group_by_cols, as_index=False)
            agg_dict = dict((c, agg_fn_list) for c in agg_cols)
            df_out = df_groups.agg(agg_dict)
            #rename multiindex
            df_out.columns = [c[0] + "_" + (c[1].split("_",1))[0] if c[1] else c[0] for c in df_out.columns]
        else:
            #no group_by selected -> apply lambdas to the entire df
            df_out = pd.DataFrame()
            for c in agg_cols:
                for agg_fn in agg_fn_list:
                    agg_str = agg_fn.__name__.rsplit("_",1)[0]
                    df_out.loc["0",c+"_"+agg_str] = df.loc[:,[c]].apply(agg_fn).values[0]

    if append or do_filter:
        output_cols = df_out.columns
        if not group_by_cols:
            for i,c in enumerate(output_cols):
                df[c] = df_out.iloc[0,i]
            df_out = df
        else:
            df_out = df.merge(df_out, on=group_by_cols, how='left') #, suffixes = ('', '_' + agg_str_list[0]))

        if do_filter:
            def filter_fn(r):
                return [r[c] for c in group_by_cols + agg_cols] == [r[c] for c in output_cols]
            df_out = df_out[df_out.apply(filter_fn,axis=1)]

    df_out.to_csv(sys.stdout, index=False)
