#!python
# -*- coding: utf-8 -*-
# Copyright(c) Ryuichiro Nakato <rnakato@iqb.u-tokyo.ac.jp>
# All rights reserved.

import os
import argparse
import numpy as np
import pandas as pd
from scipy import stats
import matplotlib.pyplot as plt
from custardpy.loadData import *
from custardpy.HiCmodule import JuicerMatrix
from custardpy.PlotModule import *
from custardpy.generateCmap import *
from custardpy.DirectionalRelativeFreq import *

# Number of base pairs in one megabase; used for distance defaults and axis scaling.
Mega = 10 ** 6
# Colour map built by generate_cmap from three hex anchors (blue, white, red).
cm = generate_cmap(["#1310cc", "#FFFFFF", "#d10a3f"])

def getCI(mat, alpha):
    """Return the per-column confidence interval of the mean, assuming a t-distribution.

    NaNs are replaced by zeros first. Infinite entries are then masked out, so
    the mean, the standard error and the degrees of freedom of each column are
    all based on the same set of finite values.

    Args:
        mat: pandas DataFrame; statistics are computed down each column.
        alpha: confidence level forwarded to scipy as ``confidence`` (e.g. 0.99).

    Returns:
        The ``(lower, upper)`` pair returned by ``scipy.stats.t.interval`` with
        one entry per column, or ``None`` when ``mat`` has no elements.
    """
    mat = mat.fillna(0)
    if mat.size == 0:
        print("Warning: matrix is empty after replacing NaNs with zeros.")
        return None

    # Boolean-mask indexing turns +/-inf into NaN, which mean/std/count skip.
    valid_values = mat[~np.isinf(mat)]

    n_valid = valid_values.count()
    mean = valid_values.mean()
    sem = valid_values.std() / np.sqrt(n_valid)

    # Confidence interval under a t-distribution. The degrees of freedom use
    # the number of finite values (not mat.count(), which would still include
    # the infinite entries that were excluded from the mean and the SEM).
    ci = stats.t.interval(confidence=alpha,   # confidence level
                          df=n_valid - 1,     # degrees of freedom
                          loc=mean,           # sample mean
                          scale=sem)          # standard error of the mean
    return ci

def getBedfromList(_list, binsize=None):
    """Merge bin start positions into contiguous ``[start, end)`` intervals.

    Each position opens an interval of width ``binsize``. A following position
    that equals the current interval end extends that interval; any other
    position closes it and opens a new one.

    Args:
        _list: sequence of bin start positions, processed in the given order.
        binsize: bin width. When omitted, the module-level ``resolution``
            (set by the script entry point) is used, as before.

    Returns:
        pandas DataFrame with columns ``start`` and ``end``, one row per merged
        interval; empty when ``_list`` is empty.
    """
    if binsize is None:
        # Backward-compatible fallback to the global defined by the script.
        binsize = resolution

    columns = ["start", "end"]
    if len(_list) == 0:
        return pd.DataFrame(columns=columns)

    intervals = []
    start = _list[0]
    end = start + binsize
    for pos in _list[1:]:
        if pos == end:
            # Adjacent bin: extend the current interval.
            end = pos + binsize
        else:
            # Gap: close the current interval and start a new one.
            intervals.append([start, end])
            start = pos
            end = pos + binsize
    intervals.append([start, end])

    return pd.DataFrame(intervals, columns=columns)

def addmaxval(bed, mat):
    """Annotate each region with the column mean of largest magnitude inside it.

    For every row of ``bed``, the columns of ``mat`` whose label lies in the
    half-open range ``[start, end)`` are averaged over rows; the mean with the
    largest absolute value is stored signed in ``maxval`` and unsigned in
    ``maxval_abs``.

    Args:
        bed: DataFrame with ``start`` and ``end`` columns. Modified in place.
        mat: DataFrame whose column labels are comparable with start/end.

    Returns:
        The same ``bed`` object with ``maxval`` and ``maxval_abs`` columns.
        Regions that match no column of ``mat`` keep the initial 0.0.
    """
    bed["maxval"] = 0.0
    bed["maxval_abs"] = 0.0
    cols = mat.columns
    for index, data in bed.iterrows():
        s = data["start"]
        e = data["end"]
        # Label slicing with .loc[:, s:e] would include the column labelled
        # `e`, i.e. the first bin *after* the region, so select explicitly
        # with a half-open mask instead.
        in_region = (cols >= s) & (cols < e)
        mean = mat.loc[:, in_region].mean()
        if mean.empty:
            continue
        absmean = mean.abs()
        maxidx = int(np.argmax(absmean.to_numpy()))
        bed.at[index, "maxval_abs"] = absmean.iloc[maxidx]
        bed.at[index, "maxval"] = mean.iloc[maxidx]
    return bed


def getDRFMatrix(Combined, labels, labels_treated, labels_control, startdistance=0, distance=2*Mega):
    """Build the DRF table and split it into treated and control rows.

    One row is produced per entry of ``Combined`` by calling
    ``DirectionalRelativeFreq(...).getarraydiff()``; rows are indexed by
    ``labels[1:]`` and columns by bin start (multiples of the module-level
    ``resolution``).

    Returns:
        ``(mat_treated, mat_control)``; ``mat_control`` is an empty DataFrame
        when ``labels_control`` is empty.
    """
    drf_rows = []
    for sample in Combined:
        drf = DirectionalRelativeFreq(sample, resolution, startdistance=startdistance, distance=distance)
        drf_rows.append(drf.getarraydiff())
    stacked = np.vstack(drf_rows)

    bin_starts = range(0, stacked.shape[1]*resolution, resolution)
    drf_table = pd.DataFrame(data=stacked, index=labels[1:], columns=bin_starts)

    mat_treated = drf_table.loc[labels_treated]
    mat_control = drf_table.loc[labels_control] if len(labels_control) > 0 else pd.DataFrame()

    return mat_treated, mat_control


def getDifferentialDRFregions(mat_treated, ci_treated, ci_control, thre):
    """Return the merged regions where the treated DRF is differential.

    A column is selected when the treated and control confidence intervals do
    not overlap (always true when ``ci_control`` is None) and the treated
    column mean is above ``thre`` or below ``-thre``. Selected column labels
    are merged into intervals by ``getBedfromList``.

    Returns:
        DataFrame with ``start`` and ``end`` columns; empty when no column is
        selected.
    """
    if ci_control is not None:
        # Treated is shifted to the negative side: its upper bound is below the control's lower bound.
        shifted_down = ci_control[0] - ci_treated[1] > 0
        # Treated is shifted to the positive side: its lower bound is above the control's upper bound.
        shifted_up = ci_treated[0] - ci_control[1] > 0
        separated = np.logical_or(shifted_down, shifted_up)
    else:
        separated = np.full(mat_treated.shape[1], True)

    column_mean = mat_treated.mean()
    beyond_thre = np.logical_or(column_mean > thre, column_mean < -thre)
    selected = np.logical_and(separated, beyond_thre)

    hits = list(mat_treated.columns[selected])
    if len(hits) == 0:
        return pd.DataFrame(columns=["start", "end"])

    return getBedfromList(hits)


def plot_DRFregion(mat_treated, mat_control, thre, xstart, xend, odir):
    """Plot treated/control DRF means with 99% CIs and return differential regions.

    The figure is written to ``<odir>/DRFdiff.<chr>.thre<thre>.pdf`` and then
    closed, so figures do not pile up in memory when the function is called
    once per chromosome.

    NOTE: the chromosome name is read from the module-level ``chr`` variable
    set by the script's main loop; it is not a parameter of this function.

    Args:
        mat_treated: DataFrame of treated samples (rows) by bin start (columns).
        mat_control: DataFrame of control samples, or an empty DataFrame.
        thre: threshold on the treated mean; drawn as dashed lines at +/-thre.
        xstart, xend: x-axis limits.
        odir: output directory for the PDF.

    Returns:
        DataFrame with columns ``chromosome``, ``start``, ``end``, ``maxval``
        and ``maxval_abs``.
    """
    if mat_control.empty:
        ci_control = None
    else:
        ci_control = getCI(mat_control, 0.99)
    ci_treated = getCI(mat_treated, 0.99)

    differentialDRFregions = getDifferentialDRFregions(mat_treated, ci_treated, ci_control, thre)

    figsize_x = max(int((xend-xstart)/2000000), 10)
    fig = plt.figure(figsize=(figsize_x, 3))
    try:
        plt.plot(mat_treated.mean(), label="Treated", color="blue")
        plt.fill_between(mat_treated.columns, ci_treated[0], ci_treated[1], facecolor="blue", alpha=0.2)

        if not mat_control.empty:
            plt.plot(mat_control.mean(), label="Control", color="black")
            plt.fill_between(mat_control.columns, ci_control[0], ci_control[1], facecolor="black", alpha=0.2)

        plt.hlines([thre,-thre], xstart, xend, "red", linestyles='dashed')

        # Mark each differential region with a bar at y = -1.5.
        for row in differentialDRFregions.itertuples():
            plt.hlines(-1.5, row[1], row[2], "purple")

        if xend > Mega*10:
            pltxticks_mega_subplot2grid(xstart, xend, xstart, xend)
        else:
            nxticks = max(int((xend - xstart)/Mega), 10)
            pltxticks_subplot2grid(xstart, xend, xstart, xend, nxticks)

        plt.xlim(xstart, xend)
        plt.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left')

        plt.tight_layout()

        prefix = odir + '/DRFdiff.' + chr + '.thre' + str(thre)
        plt.savefig(prefix + '.pdf')
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        plt.close(fig)

    differentialDRFregions = addmaxval(differentialDRFregions, mat_treated)

    differentialDRFregions["chromosome"] = chr
    differentialDRFregions = differentialDRFregions.reindex(columns=['chromosome', 'start', 'end', "maxval", "maxval_abs"])

    return differentialDRFregions

def loadJuicerData(dirname, type, resolution, *, normalizetype="RPM", chrom=None):
    """Create a JuicerMatrix for one sample and one chromosome.

    Args:
        dirname: sample directory containing ``Matrix/intrachromosomal/...``.
        type: normalization name used in the matrix file name.
        resolution: bin size; used in the path and passed to JuicerMatrix.
        normalizetype: first argument passed to JuicerMatrix.
        chrom: chromosome name used in the file name. When omitted, the
            module-level ``chr`` set by the script's main loop is used, as
            before.

    Returns:
        The JuicerMatrix object for
        ``observed.<type>.<chrom>.matrix.gz``.
    """
    if chrom is None:
        # Backward-compatible fallback to the global set by the script.
        chrom = chr
    return JuicerMatrix(normalizetype,
                        f"{dirname}/Matrix/intrachromosomal/{resolution}/observed.{type}.{chrom}.matrix.gz", resolution)


def getCombined(samplelist, labels, chr, norm, resolution):
    """Load every sample for chromosome ``chr`` and combine them.

    Samples are loaded in ``samplelist`` order and handed to
    ``make3dmatrixRatio``. Matrices are no longer keyed by label, so samples
    that share a label (for example several unlabeled inputs) each keep their
    own data instead of all resolving to the last one loaded.

    Args:
        samplelist: sample directories.
        labels: sample labels; kept for interface compatibility, not used to
            look up matrices.
        chr: chromosome name, forwarded explicitly to the loader.
        norm: normalization name used in the matrix file name.
        resolution: bin size.

    Returns:
        Whatever ``make3dmatrixRatio`` returns for the loaded samples.
    """
    alist = [loadJuicerData(dirname, norm, resolution, chrom=chr) for dirname in samplelist]
    return make3dmatrixRatio(alist)

if __name__ == '__main__':
    # Command-line entry point: for every chromosome in the genome table,
    # compute DRF matrices, plot them and collect the differential regions
    # into a single TSV file.
    parser = argparse.ArgumentParser()
    parser.add_argument("input",  help="Input sample '<Input directory>:<label>'", type=lambda x: x.split(':'), nargs='*')
    parser.add_argument("-c","--control",  help="Labels of negative control samples (separated by ',')", type=lambda x: x.split(','))
    parser.add_argument("--type", help="normalize type (default: SCALE)", type=str, default="SCALE")
    parser.add_argument("--gt", help="Genome table", type=str)
    parser.add_argument("--thre", help="threshold of differential DRF (default: 0.7)", type=float, default=0.7)
    parser.add_argument("-r", "--resolution", help="resolution (default: 25000)", type=int, default=25000)
    parser.add_argument("--distance_min", help="minimum distance of DRF (default: 500000)", type=int, default=500000)
    parser.add_argument("--distance_max", help="maximum distance of DRF (default: 2000000)", type=int, default=2000000)
    parser.add_argument("--odir", help="Output directory (default: diffDRFregions)", type=str, default="diffDRFregions/")
    args = parser.parse_args()

    # NOTE: `resolution` (and `chr` in the loop below) are read as module-level
    # globals by the helper functions above, so these names must not change.
    resolution = args.resolution
    normtype = args.type
    thre = args.thre

    # Validate arguments before touching them; exit with a non-zero status.
    if args.gt is None:
        print ("Error: specify genome_table file (--gt).")
        raise SystemExit(1)
    if len(args.input) == 0:
        print ("Error: specify input data.")
        raise SystemExit(1)

    genometable = pd.read_csv(args.gt, delimiter='\t', index_col=[0], header=None)

    samplelist = []
    labels = []
    for entry in args.input:
        samplelist.append(entry[0])
        labels.append(entry[1] if len(entry) > 1 else "")

    labels_first = labels[0]
    labels_control = args.control if args.control is not None else []
    # Ordered, de-duplicated difference so the sample order is reproducible
    # (a plain set difference has no defined order).
    control_set = set(labels_control)
    labels_treated = [label for label in dict.fromkeys(labels[1:]) if label not in control_set]

    print ("1st sample: " + labels_first)
    print("Treated samples: " + ', '.join(map(str, labels_treated)))
    print("Negative control samples: " + ', '.join(map(str, labels_control)))

    if not labels_treated:
        print ("Error: no treated samples (all samples after the first are controls).")
        raise SystemExit(1)

    odir = args.odir + "/"
    os.makedirs(odir, exist_ok=True)

    frames = []
    for row in genometable.itertuples():
        # itertuples puts the index (chromosome name) first, then the first column.
        chr = row[0]
        chrlen = row[1]

        if chr in ("chrY", "chrM"):
            continue
        print (chr)

        Combined = getCombined(samplelist, labels, chr, normtype, resolution)
        mat_treated, mat_control = getDRFMatrix(Combined, labels, labels_treated, labels_control, startdistance=args.distance_min, distance=args.distance_max)
        frames.append(plot_DRFregion(mat_treated, mat_control, thre, 0, chrlen, args.odir))

    out_columns = ["chromosome", "start", "end", "DRF maxval", "DRF maxval_abs"]
    if frames:
        differentialDRFregions_all = pd.concat(frames, axis=0)
        differentialDRFregions_all.columns = out_columns
    else:
        # No chromosome was processed: still write a header-only table.
        differentialDRFregions_all = pd.DataFrame(columns=out_columns)

    prefix = os.path.join(odir, 'DifferentialDRFregions.thre' + str(thre))
    differentialDRFregions_all.to_csv(prefix + '.tsv', sep="\t", header=True, index=False)
