#!python
import argparse
import pandas as pd
import numpy as np
import sys
import collections
import Levenshtein
# from leven import levenshtein
import jtutils

def readCL():
    """Parse the command line and return the options as a flat tuple.

    Up to two positional column names are accepted.  Supplying more than two
    is rejected with a usage error (previously this path crashed with an
    UnboundLocalError because neither col1 nor col2 was ever assigned).

    Returns:
        A 9-tuple ``(infile, col1, col2, no_header, delimiter,
        correlation_threshold, cluster, cluster_thresh, cluster_match)``.
        ``infile`` is ``sys.stdin`` unless ``-f`` was given; ``col1`` and
        ``col2`` are ``None`` when the corresponding positional argument is
        absent; ``cluster_match`` is ``None`` when the option is not used.

    Raises:
        SystemExit: raised by argparse on invalid arguments, including more
            than two column names.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("col",nargs="*")
    parser.add_argument("-f","--infile", default=sys.stdin)
    parser.add_argument("-n","--no_header",action="store_true")
    parser.add_argument("-d","--delimiter",default=",")
    parser.add_argument("-c","--cluster",action="store_true")
    parser.add_argument("--cluster_thresh",help="threshold to cluster strings", default=0.5, type=float)
    parser.add_argument('--cluster_match', nargs="*", help="""Force pairs of strings to cluster together. For example: --cluster_match "Scott Adams","Dilbert Blog" """)
    parser.add_argument("--correlation_threshold",default=0.3, type=float, help="When viewing correlation matrix of multiple columns, only display correlations of a minimum absolute value")
    args = parser.parse_args()

    if len(args.col) > 2:
        # Fail with a usage message instead of falling through with
        # col1/col2 unbound.
        parser.error("expected at most two column names, got {}: {}".format(
            len(args.col), " ".join(args.col)))

    # Pad with None so missing positional arguments come back as None.
    col1, col2 = (list(args.col) + [None, None])[:2]

    return args.infile, col1, col2, args.no_header, args.delimiter, args.correlation_threshold, args.cluster, args.cluster_thresh, args.cluster_match


def str_is_float(var):
    """Return True if ``float(var)`` succeeds, False otherwise.

    Only the errors a failed conversion can raise are caught, so signals
    such as KeyboardInterrupt are no longer swallowed.
    """
    try:
        float(var)
    except (TypeError, ValueError, OverflowError):
        return False
    return True


def one_col(col):
    """Write the value counts of ``col`` to stdout as CSV.

    The output has two columns with the fixed header ``val,cnt``; rows are
    emitted in the order produced by ``Series.value_counts``.
    """
    counts = col.value_counts()
    table = counts.reset_index()
    table.to_csv(sys.stdout, index=False, header=["val", "cnt"])


def two_col(col1,col2):
    """Print a comparison report for two columns ("a" = col1, "b" = col2).

    The report always contains the sizes of the three regions of the Venn
    diagram of the columns' distinct values.  When every entry of both
    columns converts to float, standard deviations, the correlation and the
    two simple linear regression lines are printed as well.  When both
    columns have fewer than 15 distinct values a cross-tabulation follows.
    """
    uniq_a = set(col1.unique())
    uniq_b = set(col2.unique())
    only_a = str(len(uniq_a - uniq_b))
    shared = str(len(uniq_a & uniq_b))
    only_b = str(len(uniq_b - uniq_a))
    print("venn diagram breakdown")
    print("|a-b|: " + only_a)
    print("|a^b|: " + shared)
    print("|b-a|: " + only_b)
    print("----")

    # Numeric statistics only make sense when both columns are fully numeric.
    both_numeric = all(col1.apply(str_is_float)) and all(col2.apply(str_is_float))
    if both_numeric:
        a_vals = col1.astype(float)
        b_vals = col2.astype(float)
        sd_a = a_vals.std()
        sd_b = b_vals.std()
        r = a_vals.corr(b_vals)
        print("Statistics:")
        print("sd1: {:f}".format(sd_a))
        print("sd2: {:f}".format(sd_b))
        print("R: {:f}".format(r))
        print("----")

        # Slope of each regression line is r scaled by the ratio of the
        # standard deviations; the intercept makes the line pass through
        # the point of means.
        slope_b_on_a = r * sd_b / float(sd_a)
        slope_a_on_b = r * sd_a / float(sd_b)
        mean_a = a_vals.mean()
        mean_b = b_vals.mean()
        icpt_b_on_a = mean_b - slope_b_on_a * mean_a
        icpt_a_on_b = mean_a - slope_a_on_b * mean_b
        print("Linear regression:")
        print("b = {}a + {}".format(slope_b_on_a, icpt_b_on_a))
        print("a = {}b + {}".format(slope_a_on_b, icpt_a_on_b))
        print("---")

    few_a = len(uniq_a) < 15
    few_b = len(uniq_b) < 15
    if few_a and few_b:
        print(pd.crosstab(col1,col2))


def score_cluster(cluster, val, similarity_fn):
    """Return the mean similarity between ``val`` and the head of ``cluster``.

    Only the first three members of the cluster are compared, to keep the
    cost per cluster constant.  Returns None for an empty cluster.
    """
    scores = []
    for member in cluster[:3]:
        scores.append(similarity_fn(val, member))
    if scores:
        return sum(scores) / float(len(scores))
    return None

def quick_cluster(strings, similarity_fn, threshold=0.5, max_comparisons=1000000):
    """Greedily group similar strings using a pairwise similarity function.

    Fast but (very) low-quality clustering intended for gathering
    misspellings and other near-duplicates:
      1) it avoids comparing every string against every other string;
      2) it needs only a similarity function, not a feature matrix.

    Each string is scored against every existing cluster (see
    ``score_cluster``).  It joins the best-scoring cluster when that score
    is at least ``threshold``; otherwise it starts a new cluster.  The result
    therefore depends on the order of ``strings``.

    If ``len(clusters) * len(strings)`` exceeds ``max_comparisons`` the pass
    is abandoned, the threshold is lowered by 0.01 and clustering restarts.
    The restart is skipped while there is only one cluster: lowering the
    threshold cannot reduce the cluster count further, and restarting in
    that state previously recursed without end for inputs longer than the
    budget.

    Args:
        strings: sized, re-iterable collection of strings to cluster.
        similarity_fn: callable ``(a, b) -> number``; higher means more alike.
        threshold: minimum cluster score required to join a cluster.
        max_comparisons: budget for ``len(clusters) * len(strings)``.

    Returns:
        A list of clusters, each a list of the input strings.
    """
    n_strings = len(strings)
    while True:
        clusters = []
        for s in strings:
            cluster_scores = [score_cluster(c, s, similarity_fn) for c in clusters]
            if not cluster_scores or max(cluster_scores) < threshold:
                clusters.append([s])
            else:
                clusters[jtutils.argmax(cluster_scores)].append(s)
            if len(clusters) > 1 and len(clusters) * n_strings > max_comparisons:
                # Too many clusters to finish in reasonable time; retry below
                # with a lower threshold.
                break
        else:
            # Pass completed within budget.
            return clusters
        threshold = threshold - 0.01
        print("Threshold too high -> will take too long to compute. Reducing threshold: " + str(threshold))

def post_process_clusters(clusters, similarity_fn, force_matches):
    """Merge clusters that the user asked to be forced together.

    Each entry of ``force_matches`` is a string of the form
    ``"first,second"``, e.g. ``"Scott Adams,Dilbert Blog"``.  For each entry
    the cluster scoring highest (via ``score_cluster``) against ``first`` and
    the one scoring highest against ``second`` are located; the second is
    appended to the first and removed.

    Differences from the previous implementation:
      * if both strings resolve to the same cluster nothing is merged
        (previously that cluster was duplicated and then discarded);
      * merged-away clusters are removed immediately, so they can no longer
        be selected by a later pair;
      * the caller's ``clusters`` list is left untouched.

    Args:
        clusters: list of clusters (lists of strings); empty ones are dropped.
        similarity_fn: callable ``(a, b) -> number`` passed to score_cluster.
        force_matches: iterable of ``"first,second"`` strings, or None/empty.

    Returns:
        A new list containing the non-empty clusters after merging.

    Raises:
        ValueError: if an entry does not split into exactly two parts on ",".
    """
    merged = [c for c in clusters if c]

    # Validate every pair before touching anything.
    pairs = []
    for entry in force_matches or ():
        parts = entry.split(",")
        if len(parts) != 2:
            raise ValueError(
                "cluster match must be two comma-separated strings, got: {!r}".format(entry))
        pairs.append(parts)

    for match1, match2 in pairs:
        if not merged:
            break
        best_cluster1 = jtutils.argmax([score_cluster(c, match1, similarity_fn) for c in merged])
        best_cluster2 = jtutils.argmax([score_cluster(c, match2, similarity_fn) for c in merged])
        if best_cluster1 == best_cluster2:
            # Already in the same cluster; nothing to merge.
            continue
        merged[best_cluster1] = merged[best_cluster1] + merged[best_cluster2]
        del merged[best_cluster2]
    return merged

# def simple_similarity_fn(str1, str2):
#     #1 for identical strings
#     #0 for completely different strings
#     edit_distance = levenshtein(str1.lower(), str2.lower())
#     similarity = 1 - (edit_distance / max(float(len(str1)), float(len(str2))))
#     return max(similarity,0)

# Migrated from the leven package to the python-Levenshtein library to support Python >= 3.12.6.
def simple_similarity_fn(str1, str2):
    """Return a case-insensitive similarity score for two strings.

    The score is ``1 - edit_distance / length_of_longer_string``, computed
    on the lowercased strings: 1 for identical strings, 0 for completely
    different strings.  Two empty strings are treated as identical instead
    of dividing by zero.  Lengths are taken after lowercasing so that they
    match the text the edit distance was measured on.
    """
    lowered1 = str1.lower()
    lowered2 = str2.lower()
    longest = max(len(lowered1), len(lowered2))
    if longest == 0:
        return 1.0
    edit_distance = Levenshtein.distance(lowered1, lowered2)
    similarity = 1 - (edit_distance / float(longest))
    return max(similarity, 0)

def cluster_column(col1, threshold, force_matches):
    """Cluster the non-blank values of a column and print each cluster.

    Values are converted to strings, blank ones are skipped, and the rest
    are grouped with ``quick_cluster`` and ``post_process_clusters``.  One
    ``collections.Counter`` is printed per cluster, largest cluster first.
    """
    words = []
    for raw in col1.values:
        text = str(raw)
        if text.strip():
            words.append(text)

    grouped = quick_cluster(words, simple_similarity_fn, threshold)
    grouped = post_process_clusters(grouped, simple_similarity_fn, force_matches)

    by_size = sorted(grouped, key=len, reverse=True)
    for members in by_size:
        print(collections.Counter(members))


if __name__ == "__main__":
    # Script entry point: read the table, then pick a report based on how
    # many columns the input has and which columns were named on the
    # command line.
    (infile, col1, col2, no_header, delimiter, correlation_threshold,
     cluster, cluster_thresh, cluster_match) = readCL()
    jtutils.fix_broken_pipe()

    # Every cell is read as a string.  Empty cells come back as NaN by
    # default, so they are mapped back to empty strings.
    if no_header:
        dat = jtutils.pd_read_csv(infile, dtype="object", header=None, delimiter=delimiter).fillna('')
    else:
        dat = jtutils.pd_read_csv(infile, dtype="object", delimiter=delimiter).fillna('')
        hdr = ",".join(dat.columns)
        sys.stderr.write("WARNING: ptable using first line of input, \"{hdr}\", as header. If file doesn't have a header use -n option.".format(hdr=hdr) + "\n")

    n_columns = len(dat.columns)
    if n_columns == 1:
        # Single-column input: cluster it on request, otherwise count values.
        if cluster:
            cluster_column(dat.iloc[:, 0], cluster_thresh, cluster_match)
        else:
            one_col(dat.iloc[:, 0])
    elif col1 and not col2:
        one_col(dat[col1])
    elif n_columns == 2:
        two_col(dat.iloc[:, 0], dat.iloc[:, 1])
    elif col1 and col2:
        two_col(dat[col1], dat[col2])
    else:
        # No usable column selection: print a correlation table of the
        # numeric columns.
        numeric = dat.apply(pd.to_numeric, errors='coerce')

        def non_numeric_col(c):
            # A column counts as non-numeric when more than 80% of its
            # entries failed numeric conversion.
            null_share = c.isnull().sum() / float(len(c))
            return null_share > 0.8

        flags = numeric.apply(non_numeric_col)
        non_numerics = [name for name, flag in zip(numeric.columns, flags) if flag]
        corr_table = numeric.drop(non_numerics, axis=1).corr()

        # Blank out correlations whose absolute value does not exceed the
        # threshold.
        # NOTE(review): applymap is reported as deprecated in recent pandas
        # releases in favour of DataFrame.map -- confirm against the pandas
        # version in use.
        corr_table = corr_table.applymap(lambda r: r if abs(r) > correlation_threshold else "")

        corr_table.to_csv(sys.stdout)
