#!python
# -*- coding: utf-8 -*-

import argparse
import matplotlib.pyplot as plt

from custardpy.HiCmodule import JuicerMatrix
from custardpy.plotHiCfeature_module import *

def get_samples(dirs, chr, type, resolution):
    """Create one JuicerMatrix per input directory.

    For every directory in *dirs* two file paths are assembled:

    * ``<dir>/Matrix/intrachromosomal/<resolution>/observed.<type>.<chr>.matrix.gz``
    * ``<dir>/Eigen/<resolution>/eigen.<type>.<chr>.txt.gz``

    and handed to ``JuicerMatrix("RPM", <matrix path>, resolution,
    eigenfile=<eigen path>)``.

    Args:
        dirs: iterable of sample directory paths (strings).
        chr: chromosome name used in the file names.
        type: normalization type used in the file names.
        resolution: resolution; used as a path component and passed
            unchanged to JuicerMatrix.

    Returns:
        A list of JuicerMatrix objects, in the same order as *dirs*.
    """
    def _locate(sample_dir, *parts):
        # The on-disk layout is spelled with plain "/" separators.
        return "/".join((sample_dir,) + parts)

    def _load(sample_dir):
        res = str(resolution)
        matrix_path = _locate(sample_dir, "Matrix", "intrachromosomal", res,
                              ".".join(("observed", type, chr, "matrix", "gz")))
        eigen_path = _locate(sample_dir, "Eigen", res,
                             ".".join(("eigen", type, chr, "txt", "gz")))
        return JuicerMatrix("RPM", matrix_path, resolution, eigenfile=eigen_path)

    return [_load(sample_dir) for sample_dir in dirs]

def _parse_input(value):
    """Split one positional ``<Input directory>:<label>`` argument on ':'."""
    return value.split(':')


def _build_parser():
    """Create the argparse parser holding every option of this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("input",  help="<Input directory>:<label>", type=_parse_input, nargs='*')
    parser.add_argument("-o", "--output", help="Output prefix", type=str, default="output")
    parser.add_argument("-c", "--chr", help="chromosome", type=str)
    parser.add_argument("--type", help="normalize type (default: SCALE)", type=str, default="SCALE")
    parser.add_argument("--distance", help="distance for DI (default: 500000)", type=int, default=500000)
    parser.add_argument("-r", "--resolution", help="resolution (default: 25000)", type=int, default=25000)
    parser.add_argument("-s", "--start", help="start bp (default: 0)", type=int, default=0)
    parser.add_argument("-e", "--end",   help="end bp (default: 1000000)", type=int, default=1000000)
    parser.add_argument("--multi",       help="plot MultiInsulation Score", action='store_true')
    parser.add_argument("--multidiff",   help="plot differential MultiInsulation Score", action='store_true')
    parser.add_argument("--compartment", help="plot Compartment (eigen)", action='store_true')
    parser.add_argument("--di",    help="plot Directionality Index", action='store_true')
    parser.add_argument("--drf",   help="plot Directional Relative Frequency", action='store_true')
    parser.add_argument("--drf_right",  help="(with --drf) plot DirectionalRelativeFreq (Right)", action='store_true')
    parser.add_argument("--drf_left",   help="(with --drf) plot DirectionalRelativeFreq (Left)", action='store_true')
    parser.add_argument("--triangle_ratio_multi",   help="plot Triangle ratio multi", action='store_true')
    parser.add_argument("--output_logfc_matrix",   help="(with --triangle_ratio_multi) output logfoldchange matrix", action='store_true')
    parser.add_argument("-d", "--vizdistancemax", help="max distance in heatmap", type=int, default=0)
    parser.add_argument("--v4c",   help="plot virtual 4C from Hi-C data", action='store_true')
    parser.add_argument("--vmax", help="max value of color bar (default: 50)", type=int, default=50)
    parser.add_argument("--vmin", help="min value of color bar (default: 0)", type=int, default=0)
    parser.add_argument("--vmax_ratio", help="max value of color bar for logratio (default: 1)", type=int, default=1)
    parser.add_argument("--vmin_ratio", help="min value of color bar for logratio (default: -1)", type=int, default=-1)
    parser.add_argument("--anchor", help="(for --v4c) anchor point", type=int, default=500000)
    parser.add_argument("--xsize", help="xsize for figure (default: max(length/2M, 10))", type=int, default=0)
    return parser


def _split_inputs(inputs):
    """Separate parsed ``[directory, label, ...]`` lists into two lists.

    An input given without ``:<label>`` gets the empty string as its label.
    """
    dirs = [fields[0] for fields in inputs]
    labels = [fields[1] if len(fields) > 1 else "" for fields in inputs]
    return dirs, labels


def main():
    """Command-line entry point: plot a Hi-C map with one feature track.

    Parses the command line, loads one sample per positional
    ``<Input directory>:<label>`` argument via ``get_samples``, calls
    ``plot_HiC_Map`` and ``plot_PC1`` for the first sample, then calls exactly
    one feature-plotting function and saves the figure to ``<output>.pdf``.

    When several feature flags are given, only the first one in this order is
    plotted: ``--drf``, ``--triangle_ratio_multi``, ``--di``,
    ``--compartment``, ``--v4c``, ``--multi``, ``--multidiff``.  With none of
    them, ``plot_single_insulation_score`` is called.

    Usage errors (no input, fewer than two samples for a comparative plot,
    missing chromosome, non-positive resolution, end not greater than start)
    print a message to stderr and terminate with exit status 1.
    """
    parser = _build_parser()
    args = parser.parse_args()

    dirs, labels = _split_inputs(args.input)

    # parser.exit() writes the message to stderr and raises SystemExit with
    # the given status, so every usage error yields a non-zero exit code.
    if not dirs:
        parser.exit(1, "Error: specify at least one input (<Input directory>:<label>).\n")
    needs_multiple_samples = args.drf or args.multidiff or args.triangle_ratio_multi
    if needs_multiple_samples and len(dirs) < 2:
        parser.exit(1, "Error: --drf|--multidiff|--triangle_ratio_multi requires >= 2 samples.\n")
    if args.chr is None:
        parser.exit(1, "Error: specify chromosome (-c).\n")

    chrom = args.chr
    resolution = args.resolution
    norm_type = args.type
    if resolution <= 0:
        # Guard the divisions below against ZeroDivisionError / negative bins.
        parser.exit(1, "Error: resolution (-r) must be a positive integer.\n")

    figstart = args.start
    figend = args.end
    length = figend - figstart
    if length <= 0:
        parser.exit(1, "Error: end (-e) must be greater than start (-s).\n")

    # Bin indices of the plotted window at the requested resolution.
    s = int(figstart / resolution)
    e = int(figend / resolution)
    binnum = e - s
    vmax = args.vmax
    vmin = args.vmin

    print("chr: {}, resolution: {}, width: {}, {} bins.".format(chrom, resolution, length, binnum))

    samples = get_samples(dirs, chrom, norm_type, resolution)
    nsample = len(samples)

    # Grid layout: row counts per panel kind and column spans per element.
    nrow_heatmap = 2
    nrow_eigen = 1
    nrow_feature = 1
    colspan_plot = 24
    colspan_colorbar = 1
    colspan_legend = 6
    colspan_full = colspan_plot + colspan_legend

    ### Plot
    figsize_x = set_figsize_x(args.xsize, figstart, figend)
    # NOTE(review): the row count is chosen with a different flag precedence
    # than the plot dispatch below (e.g. --drf together with --multi sizes the
    # figure for --multi but plots --drf) -- confirm this is intended.
    if args.multi or args.multidiff or args.v4c:
        nrow = nrow_heatmap + nrow_eigen + nsample * nrow_feature
    elif args.triangle_ratio_multi:
        nrow = nrow_heatmap + nrow_eigen + (nsample - 1) * (nrow_heatmap + nrow_feature * 2)
    else:
        nrow = nrow_heatmap + nrow_eigen + nrow_feature * 2
    figsize_y = nrow * 2
    plt.figure(figsize=(figsize_x, figsize_y))

    nrow_now = 0

    # Hi-C Map (first sample only, with TAD and loop files from its directory)
    tadfile = dirs[0] + "/TAD/" + norm_type + "/" + str(resolution) + "_blocks.bedpe"
    loopfile = dirs[0] + "/loops/" + norm_type + "/merged_loops.bedpe"

    plot_HiC_Map(nrow, nrow_now, nrow_heatmap, samples[0], labels[0], chrom,
                 norm_type, resolution, vmax, vmin, figstart, figend, args.vizdistancemax,
                 colspan_plot, colspan_colorbar, colspan_full,
                 tadfile=tadfile, loopfile=loopfile)
    nrow_now += nrow_heatmap

    # Compartment (PC1 of the first sample)
    plot_PC1(nrow, nrow_now, nrow_eigen, samples[0], labels[0],
             s, e, colspan_plot, colspan_full)
    nrow_now += nrow_eigen

    # Directional Relative Frequency
    if args.drf:
        plot_directional_relative_frequency(samples, labels, nrow, nrow_now, nrow_feature,
                                            s, e, figstart, figend, resolution,
                                            args.drf_right, args.drf_left,
                                            colspan_plot, colspan_colorbar, colspan_legend, colspan_full)
    # TriangleRatioMulti
    elif args.triangle_ratio_multi:
        plot_triangle_ratio_multi(samples, labels, nrow, nrow_now, nrow_heatmap, nrow_feature,
                                  s, e, figstart, figend, args.vizdistancemax, resolution,
                                  args.vmin_ratio, args.vmax_ratio,
                                  colspan_plot, colspan_colorbar, colspan_legend, colspan_full,
                                  args.output_logfc_matrix, args.output)
    # Directionality Index
    elif args.di:
        plot_directionality_index(samples, labels, nrow, nrow_now, nrow_feature,
                                  s, e, figstart, figend, args.distance,
                                  colspan_plot, colspan_colorbar, colspan_legend, colspan_full)
    # Compartment heatmap
    elif args.compartment:
        plot_compartment_heatmap(samples, labels, nrow, nrow_now, nrow_feature,
                                 s, e, figstart, figend,
                                 colspan_plot, colspan_colorbar, colspan_legend, colspan_full)
    # Virtual 4C
    elif args.v4c:
        plot_virtual_4c(samples, labels, nrow, nrow_now, nrow_feature,
                        s, e, figstart, figend, args.anchor, resolution, vmin, vmax,
                        colspan_plot, colspan_full)
    # Multi Insulation score
    elif args.multi:
        plot_multi_insulation_score(samples, labels, nrow, nrow_now, nrow_feature,
                                    figstart, figend, s, e,
                                    colspan_plot, colspan_colorbar, colspan_full)
    # Differential Multi Insulation score
    elif args.multidiff:
        plot_differential_multi_insulation_score(samples, labels, nrow, nrow_now, nrow_feature,
                                                 figstart, figend, s, e,
                                                 colspan_plot, colspan_colorbar, colspan_full)
    # Single Insulation score (default when no feature flag is given)
    else:
        plot_single_insulation_score(samples, labels, nrow, nrow_now, nrow_feature,
                                     figstart, figend, s, e,
                                     colspan_plot, colspan_colorbar, colspan_legend, colspan_full,
                                     heatmap=True, lineplot=True)

    plt.subplots_adjust(hspace=0.5)
    plt.savefig(args.output + ".pdf")

# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main()
