#!python
# -*- coding: utf-8 -*-
# Copyright(c) Ryuichiro Nakato <rnakato@iqb.u-tokyo.ac.jp>
# All rights reserved.

import os
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
from scipy.cluster.hierarchy import linkage, fcluster, leaves_list

def getbedGraph(file, label):
    """Load one bedGraph-style table as a single-column DataFrame.

    The file is read as headerless, tab-separated text.  Column 1 (the
    second column) becomes the index, columns 0 and 2 are discarded and the
    one remaining column is renamed to *label*, so the file must have exactly
    four columns (presumably chrom, start, end, score -- confirm against the
    producer of these files).
    """
    table = pd.read_csv(file, sep="\t", header=None, index_col=[1])
    scores = table.drop(columns=[0, 2])
    scores.columns = [label]
    return scores

def get_boundary(boundaryfile):
    """Read the boundary intervals from a headerless, tab-separated BED file.

    Only the first three columns are loaded, so BED files that carry extra
    fields (name, score, strand, ...) are accepted as well; previously any
    file with more than three columns failed when the column names were
    assigned.

    Returns a DataFrame with the columns "chromosome", "start" and "end".
    """
    boundary = pd.read_csv(boundaryfile, sep="\t", header=None, usecols=[0, 1, 2])
    boundary.columns = ["chromosome", "start", "end"]
    return boundary

def get_samples(dirs, labels, chr, type, resolution):
    """Collect the insulation-score track of every sample for one chromosome.

    For each entry of *dirs* the bedGraph path is built from the directory,
    normalisation *type*, *resolution* and chromosome name *chr* (the
    ".500k." part of the file name is fixed), printed, and loaded with
    getbedGraph() under the label at the same position in *labels*.

    Returns a DataFrame with one column per sample; it is empty when *dirs*
    is empty.
    """
    merged = pd.DataFrame()

    for position, dir in enumerate(dirs):
        bedgraph = f"{dir}/InsulationScore/{type}/{resolution}/Insulationscore.{chr}.500k.{resolution}.bedGraph"
        print(bedgraph)
        track = getbedGraph(bedgraph, labels[position])
        # Column-wise concatenation aligns the tracks on their index.
        merged = pd.concat([merged, track], axis=1)

    return merged
    

def _save_insulation_plot(x, tracks, boundary, chrlen, width, boundary_y, outfile):
    """Draw one line per (label, values) pair plus the boundary segments and save to *outfile*."""
    step = 1000000
    plt.figure(figsize=(width, 3))
    for label, values in tracks:
        plt.plot(x, values, label=label)

    # itertuples() puts the index at position 0, so positions 2 and 3 are the
    # second and third boundary columns (start and end).
    for row in boundary.itertuples():
        plt.hlines([boundary_y], row[2], row[3], "blue")

    plt.xlim(0, chrlen)
    plt.xticks(np.arange(0, int(chrlen/step)*step, step))
    plt.legend(bbox_to_anchor=(1.01, 1.0), loc='upper left')
    plt.tight_layout()
    plt.savefig(outfile)
    plt.close()


def plot_insulation_score(samples, boundary, chr, chrlen, labels, odir):
    """Write the per-chromosome insulation-score plots as PDF files.

    Two figures are saved under the prefix *odir*:

    * "Insulation_score.<chr>.pdf": the score of every sample in *labels*,
      with the boundary intervals drawn as blue segments at y = 0.
    * "Insulation_score.diff.<chr>.pdf": each sample after the first minus
      the first sample, with the boundary intervals drawn at y = -0.3.

    Parameters:
        samples:  DataFrame with one column per label; its index is the x axis.
        boundary: DataFrame whose second and third columns are the interval
                  start and end.
        chr:      chromosome name, used only in the output file names.
        chrlen:   chromosome length; sets the x range and the figure width.
        labels:   sample labels; labels[0] is the reference for the diff plot.
        odir:     output prefix, concatenated directly with the file name.
    """
    # One inch per 2 Mb.  Clamp to 1 so that sequences shorter than 2 Mb do
    # not request a zero-width figure.
    width = max(1, int(chrlen/2000000))

    raw_tracks = [(sample, samples[sample]) for sample in labels]
    _save_insulation_plot(samples.index, raw_tracks, boundary, chrlen, width, 0,
                          odir + "Insulation_score." + chr + ".pdf")

    diff_tracks = [(sample, samples[sample] - samples[labels[0]]) for sample in labels[1:]]
    _save_insulation_plot(samples.index, diff_tracks, boundary, chrlen, width, -.3,
                          odir + "Insulation_score.diff." + chr + ".pdf")

def annotate_boundary(d, diffsamples, labels, thre=0.13, thre_insu=-0.13, nonboundary_score=0.8):
    """Add a 'status' column to *d* that classifies every boundary.

    The rules are evaluated in this order for each row:

    * 'Gain'         - at least ``num_half`` samples have a difference
                       <= *thre_insu*;
    * 'Loss'         - at least ``num_half`` samples have a difference
                       > *thre*;
    * 'Non-boundary' - the score of the control sample (labels[0]) is
                       >= *nonboundary_score*;
    * 'Robust'       - everything else.

    ``num_half`` is half the number of columns of *diffsamples*, rounded down
    but never less than one.  Without that lower bound a single comparison
    sample gave ``num_half == 0`` and every boundary was reported as 'Gain'.

    Parameters:
        d:           DataFrame of boundaries containing the column labels[0];
                     modified in place.
        diffsamples: per-boundary score differences against the control, one
                     column per non-control sample, indexed like *d*.
        labels:      sample labels; only labels[0] is used here.
        thre, thre_insu, nonboundary_score: classification thresholds.

    Returns:
        *d* itself, with the 'status' column filled in.
    """
    num_half = max(1, int(diffsamples.shape[1] * 0.5))

    # Use the rows positionally when both frames share the same index;
    # otherwise look them up by label as the per-row version did.
    if diffsamples.index.equals(d.index):
        diff = diffsamples
    else:
        diff = diffsamples.loc[d.index]

    # NaN compares False, so missing values never count as support.
    is_gain = ((diff <= thre_insu).sum(axis=1) >= num_half).to_numpy()
    is_loss = ((diff > thre).sum(axis=1) >= num_half).to_numpy()
    is_nonboundary = (d[labels[0]] >= nonboundary_score).to_numpy()

    # np.select takes the first matching condition, which preserves the
    # Gain > Loss > Non-boundary priority of the original if/elif chain.
    d['status'] = np.select(
        [is_gain, is_loss, is_nonboundary],
        ['Gain', 'Loss', 'Non-boundary'],
        default='Robust',
    )

    return d


def get_averagedIS_for_boundaries(samples, df, boundary):
    """Average the scores in *df* over every boundary interval.

    For each row of *boundary* the rows of *df* whose index lies in
    [start, end) are averaged per column.  The averages are attached to the
    boundary table, the result is indexed by the boundary start, and only the
    boundaries whose mean over the *samples* columns exceeds 0.3 are returned.
    """
    averaged = pd.DataFrame(index=samples)
    for row in boundary.itertuples():
        # Position 0 of the tuple is the index; 2 and 3 are start and end.
        start, end = row[2], row[3]
        in_window = (df.index >= start) & (df.index < end)
        window_mean = df[in_window].mean(axis=0)
        averaged = pd.concat([averaged, window_mean], axis=1)

    # One row per boundary, one column per sample.
    averaged = averaged.T
    averaged.index = boundary.index

    annotated = pd.concat([boundary, averaged], axis=1)
    annotated.index = boundary["start"]

    keep = annotated.loc[:, samples].mean(axis=1) > 0.3
    return annotated[keep]

def calculate_mean_within_boundaries(df, start, end):
    """Return the per-column mean of the rows of *df* whose index is in [start, end)."""
    in_window = (df.index >= start) & (df.index < end)
    return df[in_window].mean()

def get_averaged_insulation_score_for_boundaries(samples, df, boundary):
    """Average the scores in *df* over every boundary interval.

    For each row of *boundary* the per-column mean of *df* inside
    [start, end) is computed with calculate_mean_within_boundaries().  The
    averages are attached to the boundary table, the result is indexed by the
    boundary start, and only the boundaries whose mean over the *samples*
    columns exceeds 0.3 are kept.

    An empty *boundary* table (for example a chromosome without any
    boundary) yields an empty frame with the same columns instead of the
    "No objects to concatenate" error that pd.concat raises on an empty list.

    Parameters:
        samples:  column labels of *df* used for the final filter.
        df:       scores indexed by genomic position, one column per sample.
        boundary: DataFrame whose second and third columns are start and end
                  and which has a "start" column.

    Returns:
        DataFrame holding the boundary columns followed by the averaged
        score of every column of *df*.
    """
    # itertuples() puts the index at position 0; 2 and 3 are start and end.
    results = [
        calculate_mean_within_boundaries(df, row[2], row[3])
        for row in boundary.itertuples()
    ]

    if results:
        all_df = pd.concat(results, axis=1).T
    else:
        all_df = pd.DataFrame(columns=df.columns, dtype=float)

    all_df.index = boundary.index
    all_df = pd.concat([boundary, all_df], axis=1)
    all_df.index = boundary["start"]
    all_df = all_df[all_df.loc[:, samples].mean(axis=1) > 0.3]

    return all_df


def annotate_boundary_chromosome(chr, chrlen, samples, boundary, labels):
    """Plot, average and classify the boundaries of a single chromosome.

    The insulation-score plots are written first, then the scores are
    averaged per boundary, the difference of every sample against the first
    label is computed and the boundaries are classified.

    NOTE: `odir` is not a parameter of this function; it is looked up in
    module scope when the function runs.

    Returns:
        (df_boundary, diffsamples): the annotated boundary table and the
        per-boundary differences, one column per label after the first.
    """
    plot_insulation_score(samples, boundary, chr, chrlen, labels, odir)

    df_boundary = get_averaged_insulation_score_for_boundaries(labels, samples, boundary)

    # Difference of each remaining sample against the first (reference) one.
    deltas = [df_boundary[name] - df_boundary[labels[0]] for name in labels[1:]]

    diffsamples = pd.DataFrame()
    for delta in deltas:
        diffsamples = pd.concat([diffsamples, delta], axis=1)
    diffsamples.columns = labels[1:]

    annotated = annotate_boundary(df_boundary, diffsamples, labels)
    return annotated, diffsamples


def plot_heatmap_diffsamples(df_boundary_all, diff_boundary_all, num_clusters):
    """Cluster the boundaries by their differential scores and plot heatmaps.

    Two PDF files are written under the module-level output prefix `odir`
    (it is not a parameter of this function):

    * "heatmap.diff.correlation.pdf": clustered heatmap of the correlation
      between the columns of *diff_boundary_all*;
    * "heatmap.diff.clustered.pdf": the differential scores, rows grouped
      into *num_clusters* Ward clusters, next to a strip showing the cluster
      of every row.

    Neither input frame is modified.  The cluster id is used only to group
    the rows; it no longer takes part in the column ordering, where it used
    to be clustered as if it were one more sample.

    Parameters:
        df_boundary_all:   annotated boundaries, one row per boundary.
        diff_boundary_all: differential scores, rows in the same order as
                           *df_boundary_all*.
        num_clusters:      maximum number of row clusters.

    Returns:
        A copy of *df_boundary_all* with a 'cluster_labels' column, sorted by
        cluster (rows keep their input order within a cluster).
    """
    # clustermap opens its own figure, which becomes the current one, so no
    # plt.figure() call is needed (it would leave an empty figure open).
    sns.clustermap(diff_boundary_all.corr(), cmap="bwr")
    plt.savefig(odir + "heatmap.diff.correlation.pdf", bbox_inches="tight")
    plt.close()

    # Row clustering on the differential scores only.
    row_linkage = linkage(diff_boundary_all, 'ward')
    cluster_labels = fcluster(row_linkage, num_clusters, criterion='maxclust')

    # mergesort is stable: rows keep their input order inside each cluster.
    clustered = diff_boundary_all.assign(cluster_labels=cluster_labels)
    clustered = clustered.sort_values('cluster_labels', kind='mergesort')
    sorted_labels = clustered['cluster_labels'].to_numpy()

    # Column (sample) ordering, computed without the cluster id column.
    heat_values = clustered.drop('cluster_labels', axis=1)
    if heat_values.shape[1] > 1:  # linkage needs at least two observations
        column_linkage = linkage(heat_values.T, 'ward')
        heat_values = heat_values.iloc[:, leaves_list(column_linkage)]

    plt.figure()
    gs = gridspec.GridSpec(1, 2, width_ratios=[1, 12])

    ax0 = plt.subplot(gs[0])
    ax0.matshow(sorted_labels.reshape(-1, 1), cmap='Accent', aspect='auto')
    ax0.set_xticks([])
    ax0.set_yticks([])
    ax0.set_title("Cluster")
    for cluster_id in np.unique(cluster_labels):
        indices = np.flatnonzero(sorted_labels == cluster_id)
        # matshow centres row i on y = i, so the middle of a contiguous run
        # of rows is the mean of its first and last position.
        center_index = (indices[0] + indices[-1]) / 2
        ax0.text(-1, center_index, f"{cluster_id}", va='center', ha='right', color='black')

    ax1 = plt.subplot(gs[1])
    sns.heatmap(heat_values, cmap='bwr_r', yticklabels=False, cbar_kws={'label': 'Value'}, ax=ax1)
    ax1.set_title("Heatmap (differential insulation score)")
    plt.savefig(odir + "heatmap.diff.clustered.pdf", bbox_inches="tight")
    plt.close()

    # cluster_labels is in the input row order, which both frames share.
    annotated = df_boundary_all.assign(cluster_labels=cluster_labels)
    return annotated.sort_values('cluster_labels', kind='mergesort')

def annotate_boundary_genome(boundary, dirs, labels, resolution, genometable):
    """Annotate the boundaries of every chromosome listed in *genometable*.

    The index of *genometable* holds the chromosome names and its first
    column the chromosome lengths.  "chrY" and "chrM" are skipped.  For every
    other chromosome the name is printed, the sample tracks are loaded and
    the boundaries on that chromosome are annotated.

    NOTE: `type` is not a parameter of this function; it is looked up in
    module scope when the function runs.

    Returns:
        (annotated, differences): the per-chromosome results stacked row-wise
        in genome-table order; both are empty when no chromosome is
        processed.
    """
    skipped = ("chrY", "chrM")
    annotated_all = pd.DataFrame()
    diff_all = pd.DataFrame()

    for entry in genometable.itertuples():
        # Position 0 is the index (name), position 1 the first column (length).
        chrom, length = entry[0], entry[1]

        if chrom in skipped:
            continue
        print (chrom)

        chrom_boundary = boundary[boundary["chromosome"] == chrom]
        tracks = get_samples(dirs, labels, chrom, type, resolution)
        annotated, diffs = annotate_boundary_chromosome(chrom, length, tracks, chrom_boundary, labels)
        annotated_all = pd.concat([annotated_all, annotated], axis=0)
        diff_all = pd.concat([diff_all, diffs], axis=0)

    return annotated_all, diff_all

if __name__ == '__main__':
    # Command-line entry point.  Every input is "<directory>:<label>"; the
    # first input is the control that all other samples are compared with.
    parser = argparse.ArgumentParser()

    def tp(x):
        # "<dir>:<label>" -> ["<dir>", "<label>"]; the label part is optional.
        return x.split(':')

    parser.add_argument("input",  help="<Input directory>:<label>", type=tp, nargs='*')
    parser.add_argument("--type", help="normalize type (default: SCALE)", type=str, default="SCALE")
    parser.add_argument("--boundary", help="Boundary file (BED format)", type=str)
    parser.add_argument("--gt", help="Genome table", type=str)
    parser.add_argument("-r", "--resolution", help="resolution (default: 25000)", type=int, default=25000)
    parser.add_argument("--ncluster", help="number of cluster (default: 4)", type=int, default=4)
    parser.add_argument("--odir", help="Output directory (default: output_boundary_clustering)", type=str, default="output_boundary_clustering/")
    args = parser.parse_args()

    dirs = []
    labels = []
    for entry in args.input:
        dirs.append(entry[0])
        labels.append(entry[1] if len(entry) > 1 else "")

    # Validate everything before touching any file.  SystemExit with a
    # message prints it to stderr and exits with status 1; the bare exit()
    # used before reported success (status 0) to the calling shell.
    if not dirs:
        raise SystemExit("Error: specify input data (<Input directory>:<label>).")

    if len(dirs) < 2:
        # The differential analysis compares every sample with the first one.
        raise SystemExit("Error: specify at least two inputs (the first one is the control).")

    if len(set(labels)) != len(labels):
        # Labels become DataFrame column names and must be distinct.
        raise SystemExit("Error: the labels of the inputs must be unique.")

    if args.boundary is None:
        raise SystemExit("Error: specify boundary file (--boundary).")

    if args.gt is None:
        raise SystemExit("Error: specify genome_table file (--gt).")

    boundary = get_boundary(args.boundary)
    genometable = pd.read_csv(args.gt, delimiter='\t', index_col=[0], header=None)

    resolution = args.resolution
    # `type` and `odir` must stay module-level names: annotate_boundary_genome
    # reads `type`, and annotate_boundary_chromosome and
    # plot_heatmap_diffsamples read `odir`, from module scope.
    type = args.type
    num_cluster = args.ncluster
    odir = args.odir + "/"
    os.makedirs(odir, exist_ok=True)

    df_boundary_all, diff_boundary_all = annotate_boundary_genome(boundary, dirs, labels, resolution, genometable)

    df_boundary_all = plot_heatmap_diffsamples(df_boundary_all, diff_boundary_all, num_cluster)

    df_boundary_all.to_csv(odir + "Annotated_boundaries.tsv", sep="\t", index=False)
    