#!python
# -*- coding: utf-8 -*-

import argparse
import sys

import matplotlib.pyplot as plt
import seaborn as sns

from custardpy.HiCmodule import JuicerMatrix
from custardpy.DirectionalityIndex import getDirectionalityIndexOfMultiSample
from custardpy.InsulationScore import getInsulationScoreOfMultiSample
from custardpy.generateCmap import *
from custardpy.loadData import *
from custardpy.PlotModule import *
from custardpy.DirectionalRelativeFreq import *

def get_samples(dirs, chr, type, resolution):
    """Load one RPM-normalised JuicerMatrix per input directory.

    For each directory ``d`` the matrix is read from
    ``d/Matrix/intrachromosomal/<resolution>/observed.<type>.<chr>.matrix.gz``
    and the compartment eigenvector from
    ``d/Eigen/<resolution>/eigen.<type>.<chr>.txt.gz``.

    Returns:
        A list of JuicerMatrix objects, in the same order as ``dirs``.
    """
    res_dir = str(resolution)

    def _load(sample_dir):
        observed = "/".join([sample_dir, "Matrix", "intrachromosomal", res_dir,
                             ".".join(["observed", type, chr, "matrix", "gz"])])
        eigen = "/".join([sample_dir, "Eigen", res_dir,
                          ".".join(["eigen", type, chr, "txt", "gz"])])
        return JuicerMatrix("RPM", observed, resolution, eigenfile=eigen)

    return [_load(sample_dir) for sample_dir in dirs]

def set_figsize_x(xsize, figstart, figend):
    """Return the figure width in inches.

    A non-zero ``xsize`` is returned unchanged.  When ``xsize`` is 0 the width
    is derived from the window: one inch per 2,000,000 units of
    ``figend - figstart`` (truncated), but never less than 10.
    """
    if xsize != 0:
        return xsize
    span = figend - figstart
    return max(int(span / 2000000), 10)

def plot_HiC_Map(nrow, nrow_now, nrow_heatmap, sample, label, dirname,
                 type, resolution, vmax, vmin, figstart, figend, distance_max,
                 colspan_plot, colspan_colorbar, colspan_full, chrom=None):
    """Draw the triangular Hi-C contact map of one sample with TADs and loops.

    TADs are read from ``<dirname>/TAD/<type>/<resolution>_blocks.bedpe`` and
    loops from ``<dirname>/loops/<type>/merged_loops.bedpe``.  Both paths are
    printed, and both are loaded restricted to ``[figstart, figend]``.

    Args:
        nrow: total number of rows in the subplot2grid layout.
        nrow_now: first grid row of this panel.
        nrow_heatmap: number of grid rows the panel spans.
        sample: matrix object whose ``getmatrix()`` is drawn.
        label: sample label, shown as ``"Contact map (<label>)"``.
        dirname: sample directory containing ``TAD/`` and ``loops/``.
        type: normalization type used in the TAD/loop paths.
        resolution: bin size (same units as figstart/figend).
        vmax, vmin: colour-scale limits.
        figstart, figend: window to display.
        distance_max: forwarded to drawHeatmapTriangle_subplot2grid.
        colspan_plot, colspan_colorbar, colspan_full: grid column layout; the
            colour bar is placed at column ``colspan_plot + 1``.
        chrom: chromosome used to filter TADs and loops.  Optional for
            backward compatibility.  When omitted, the module-level name
            ``chr`` is used as before.  That name is Python's built-in
            ``chr()`` function unless a chromosome string has been bound to
            it at module scope, so callers should pass ``chrom`` explicitly.
    """
    if chrom is None:
        # Legacy lookup, kept so existing positional callers behave as before.
        chrom = chr

    # load TADs
    tadfile = dirname + "/TAD/" + type + "/" + str(resolution) + "_blocks.bedpe"
    print(tadfile)
    tads = loadTADs(tadfile, chrom, start=figstart, end=figend)

    # load loops
    loopfile = dirname + "/loops/" + type + "/merged_loops.bedpe"
    print(loopfile)
    loops = loadloops(loopfile, chrom, start=figstart, end=figend)

    # plot Hi-C
    heatmap_ax  = plt.subplot2grid((nrow, colspan_full), (nrow_now, 0),
                                    rowspan=nrow_heatmap, colspan=colspan_plot)
    colorbar_ax = plt.subplot2grid((nrow, colspan_full), (nrow_now, colspan_plot + 1),
                                    rowspan=nrow_heatmap, colspan=colspan_colorbar)
    drawHeatmapTriangle_subplot2grid(sample.getmatrix(), resolution, figstart=figstart, figend=figend,
                                     tads=tads, loops=loops, vmax=vmax, vmin=vmin, distance_max=distance_max,
                                     label="Contact map (" + label + ")",
                                     xticks=True, heatmap_ax=heatmap_ax, colorbar_ax=colorbar_ax)

def plot_PC1(nrow, nrow_now, nrow_eigen, sample, label, s, e, colspan_plot, colspan_full):
    """Draw the compartment PC1 track of one sample as a black line.

    The track occupies ``nrow_eigen`` grid rows starting at ``nrow_now``.  It
    spans ``colspan_plot`` columns and shows x-range ``[s, e]`` with x ticks
    hidden.  Drawing goes through the pyplot state machine: the new subplot
    becomes the current axes.
    """
    grid_shape = (nrow, colspan_full)
    plt.subplot2grid(grid_shape, (nrow_now, 0), rowspan=nrow_eigen, colspan=colspan_plot)

    pc1 = sample.getEigen()
    plt.plot(pc1, color="black")
    plt.xlim([s, e])
    xtickoff_ax()

    plt.title("Compartment PC1 (" + label + ")")

def plot_legend_in_subplot2grid(ax_origin, ax_legend):
    """Show the legend of ``ax_origin`` centred inside ``ax_legend``.

    ``ax_legend`` is used purely as a legend holder: its axis decorations are
    switched off before the legend is drawn.
    """
    ax_legend.axis('off')
    handles, names = ax_origin.get_legend_handles_labels()
    ax_legend.legend(handles, names, loc='center')

def plot_xy_axis_and_title_of_feature_heatmap(ax, labels, title):
    """Hide x ticks, label each heatmap row with ``labels``, and set ``title``.

    Row ``i`` of the heatmap gets ``labels[i]`` as its y tick label.
    """
    xtickoff_ax(ax=ax)
    # Set positions and labels in two calls: ``set_yticks(ticks, labels)``
    # only accepts labels on matplotlib >= 3.5.  On older versions the second
    # positional argument is ``minor``, so the labels would be silently lost.
    ax.set_yticks(np.arange(len(labels)))
    ax.set_yticklabels(labels)
    ax.set_title(title)

def get_drf_array(sample, resolution, drf_right, drf_left):
    """Return one directional-relative-frequency profile for ``sample``.

    Selection:
        * ``drf_right`` set       -> ``getarrayplus()``  (right side)
        * only ``drf_left`` set   -> ``getarrayminus()`` (left side)
        * neither set             -> ``getarraydiff()``  (right - left)
    ``drf_right`` takes precedence if both flags are set.
    """
    drf = DirectionalRelativeFreq(sample, resolution)
    if not drf_right and not drf_left:
        return drf.getarraydiff()
    return drf.getarrayplus() if drf_right else drf.getarrayminus()

def get_drf_matrix(samples, resolution, drf_right, drf_left, *, smooth_median_filter=3):
    """Stack one DRF profile per enrichment matrix into a 2-D array.

    The matrices come from ``make3dmatrixRatio(samples, smooth_median_filter)``.
    Each is reduced with get_drf_array() using the same direction flags, and
    the resulting profiles become the rows of the returned array.
    """
    enrichment = make3dmatrixRatio(samples, smooth_median_filter)
    return np.vstack([get_drf_array(ratio, resolution, drf_right, drf_left)
                      for ratio in enrichment])

def plot_directional_relative_frequency(samples, labels,  nrow, nrow_now, nrow_feature,
                                        s, e, figstart, figend, resolution,
                                        drf_right, drf_left,
                                        colspan_plot, colspan_colorbar, colspan_legend, colspan_full):
    """Draw DRF of every non-reference sample as a heatmap plus a line plot.

    Row ``k`` of the DRF matrix is labelled ``labels[k + 1]``; ``labels[0]``
    is the reference sample.  The heatmap occupies ``nrow_feature`` rows
    starting at ``nrow_now``, and the line plot and its legend occupy the
    next ``nrow_feature`` rows.
    """
    drf_matrix = get_drf_matrix(samples, resolution, drf_right, drf_left)
    grid_shape = (nrow, colspan_full)

    # --- heatmap: one row per non-reference sample, columns s..e ---
    map_ax = plt.subplot2grid(grid_shape, (nrow_now, 0),
                              rowspan=nrow_feature, colspan=colspan_plot)
    cbar_ax = plt.subplot2grid(grid_shape, (nrow_now, colspan_plot + 1),
                               rowspan=nrow_feature, colspan=colspan_colorbar)
    diverging = generate_cmap(['#1310cc', '#FFFFFF', '#d10a3f'])
    image = map_ax.imshow(drf_matrix[:, s:e], cmap=diverging, aspect="auto")
    plot_xy_axis_and_title_of_feature_heatmap(map_ax, labels[1:], "Directional Relative Frequency")
    plt.colorbar(image, cax=cbar_ax)

    # --- line plot of the same profiles, with an external legend ---
    line_row = nrow_now + nrow_feature
    line_ax = plt.subplot2grid(grid_shape, (line_row, 0),
                               rowspan=nrow_feature, colspan=colspan_plot)
    legend_holder = plt.subplot2grid(grid_shape, (line_row, colspan_plot),
                                     rowspan=nrow_feature, colspan=colspan_legend)
    for label_idx, profile in enumerate(drf_matrix, start=1):
        line_ax.plot(profile, label=labels[label_idx])

    line_ax.set_xlim([s, e])
    pltxticks_subplot2grid(s, e, figstart, figend, 10, ax=line_ax)
    plot_legend_in_subplot2grid(line_ax, legend_holder)


def plot_triangle_ratio_multi(samples, labels, nrow, nrow_now, nrow_heatmap, nrow_feature,
                              s, e, figstart, figend, distance_max, resolution,
                              vmin, vmax, vmin_ratio, vmax_ratio,
                              colspan_plot, colspan_colorbar, colspan_legend, colspan_full,
                              *, smooth_median_filter=3):
    """For each matrix from make3dmatrixRatio(), draw a log-ratio map and DRF tracks.

    For each matrix ``k`` returned by
    ``make3dmatrixRatio(samples, smooth_median_filter)``, this stacks three
    panels vertically:

      1. a triangular log-ratio contact map titled ``labels[k + 1]``, with
         ``labels[0]`` as the control label (``nrow_heatmap`` rows);
      2. DRF right/left line plot with a legend (``nrow_feature`` rows);
      3. DRF right-minus-left bar plot (``nrow_feature`` rows).

    ``vmin_ratio``/``vmax_ratio`` set the colour range of the log-ratio maps.
    ``vmin`` and ``vmax`` are accepted for call-signature compatibility but
    are not used here.
    """
    EnrichMatrices = make3dmatrixRatio(samples, smooth_median_filter)
    grid_shape = (nrow, colspan_full)

    for i, sample in enumerate(EnrichMatrices):
        # Hi-C logratio Map
        heatmap_ax  = plt.subplot2grid(grid_shape, (nrow_now, 0), rowspan=nrow_heatmap, colspan=colspan_plot)
        colorbar_ax = plt.subplot2grid(grid_shape, (nrow_now, colspan_plot + 1), rowspan=nrow_heatmap, colspan=colspan_colorbar)

        drawHeatmapTriangle_subplot2grid(sample, resolution,
                                         figstart=figstart, figend=figend,
                                         vmax=vmax_ratio, vmin=vmin_ratio,
                                         cmap=generate_cmap(['#1310cc', '#FFFFFF', '#d10a3f']),
                                         distance_max=distance_max,
                                         label=labels[i+1], xticks=True,
                                         logratio=True, control_label=labels[0],
                                         heatmap_ax=heatmap_ax, colorbar_ax=colorbar_ax)
        nrow_now += nrow_heatmap

        # DRF line plot (left & right)
        heatmap_ax = plt.subplot2grid(grid_shape, (nrow_now, 0), rowspan=nrow_feature, colspan=colspan_plot)
        # The legend starts at column colspan_plot, like every other legend
        # panel: colspan_plot + colspan_legend == colspan_full, so starting
        # one column further overran the grid and was silently clipped.
        legend_ax  = plt.subplot2grid(grid_shape, (nrow_now, colspan_plot), rowspan=nrow_feature, colspan=colspan_legend)

        drf = DirectionalRelativeFreq(sample, resolution)

        heatmap_ax.plot(drf.getarrayplus(), label="Right")
        heatmap_ax.plot(drf.getarrayminus(), label="Left")
        heatmap_ax.set_xlim([s, e])
        xtickoff_ax(ax=heatmap_ax)
        plot_legend_in_subplot2grid(heatmap_ax, legend_ax)

        nrow_now += nrow_feature

        # DRF difference as a bar plot
        heatmap_ax = plt.subplot2grid(grid_shape, (nrow_now, 0), rowspan=nrow_feature, colspan=colspan_plot)

        diff = drf.getarraydiff()
        heatmap_ax.bar(range(len(diff)), diff)
        heatmap_ax.set_xlim([s, e])
        xtickoff_ax(ax=heatmap_ax)
        heatmap_ax.set_title("Directional Relative Frequency (Right - Left)")

        nrow_now += nrow_feature

#    plt.tight_layout()

def plot_directionality_index(samples, labels, nrow, nrow_now, nrow_feature,
                              s, e, figstart, figend, distance,
                              colspan_plot, colspan_colorbar, colspan_legend, colspan_full):
    """Draw the Directionality Index of all samples as a heatmap plus line plot.

    DI values come from ``getDirectionalityIndexOfMultiSample(samples, labels,
    distance=distance)``.  With a single sample the result is reshaped to one
    row so that it can still be shown as a heatmap.  The line plot is fixed
    to a y-range of [-1000, 1000].
    """
    di = getDirectionalityIndexOfMultiSample(samples, labels, distance=distance)
    if len(samples) == 1:
        di = di.reshape((1, -1))

    grid_shape = (nrow, colspan_full)

    # --- heatmap, columns s..e ---
    map_ax = plt.subplot2grid(grid_shape, (nrow_now, 0),
                              rowspan=nrow_feature, colspan=colspan_plot)
    cbar_ax = plt.subplot2grid(grid_shape, (nrow_now, colspan_plot + 1),
                               rowspan=nrow_feature, colspan=colspan_colorbar)
    diverging = generate_cmap(['#1310cc', '#FFFFFF', '#d10a3f'])
    image = map_ax.imshow(di[:, s:e], cmap=diverging, aspect="auto")
    plot_xy_axis_and_title_of_feature_heatmap(map_ax, labels, "Directionality Index")
    plt.colorbar(image, cax=cbar_ax)

    # --- line plot, one curve per sample ---
    line_row = nrow_now + nrow_feature
    line_ax = plt.subplot2grid(grid_shape, (line_row, 0),
                               rowspan=nrow_feature, colspan=colspan_plot)
    legend_holder = plt.subplot2grid(grid_shape, (line_row, colspan_plot),
                                     rowspan=nrow_feature, colspan=colspan_legend)

    for row_idx, _sample in enumerate(samples):
        line_ax.plot(di[row_idx], label=labels[row_idx])
    line_ax.set_xlim([s, e])
    line_ax.set_ylim([-1000, 1000])
    pltxticks_subplot2grid(s, e, figstart, figend, 10, ax=line_ax)

    plot_legend_in_subplot2grid(line_ax, legend_holder)

def plot_virtual_4c(samples, labels, nrow, nrow_now, nrow_feature,
                    s, e, figstart, figend, anchor, resolution, vmin, vmax,
                    colspan_plot, colspan_full):
    """Plot a virtual-4C profile per sample: the matrix row of the anchor bin.

    Args:
        samples: objects whose ``getmatrix()`` is a DataFrame (row accessed
            with ``.iloc``).
        labels: per-sample track titles.
        nrow: total number of grid rows.
        nrow_now: first grid row used; track ``i`` starts at
            ``nrow_now + i * nrow_feature`` and spans ``nrow_feature`` rows.
        s, e: bin range shown on the x axis.
        figstart, figend: window used for the x tick labels.
        anchor: anchor position; its bin is ``int(anchor / resolution)``.
        vmin, vmax: y-axis limits.
        colspan_plot, colspan_full: grid column layout.

    When the anchor bin lies strictly between ``s`` and ``e``, it is marked
    with a dashed vertical line and a short upward arrow at the bottom of the
    track.
    """
    anchor_bin = int(anchor / resolution)

    for i, sample in enumerate(samples):
        # Step by nrow_feature so tracks never overlap when it is > 1
        # (identical to the previous layout when nrow_feature == 1).
        row = nrow_now + i * nrow_feature
        heatmap_ax = plt.subplot2grid((nrow, colspan_full), (row, 0),
                                      rowspan=nrow_feature, colspan=colspan_plot)

        # The DataFrame row carries its own labels, so plotting it directly
        # would put figstart-figend rather than s-e on the x axis; .values
        # drops the labels to avoid that.
        heatmap_ax.plot(sample.getmatrix().iloc[anchor_bin].values, color="r")
        heatmap_ax.set_xlim([s, e])
        heatmap_ax.set_ylim(vmin, vmax)
        heatmap_ax.set_title(labels[i])

        if s < anchor_bin < e:
            heatmap_ax.axvline(x=anchor_bin, color="black", linestyle="--", linewidth=0.5)
            # Blended coordinates: x in data (bins), y as a fraction of the
            # axes height.  The arrow therefore stays visible and proportionate
            # whatever vmin/vmax are.  A y position in data units such as
            # sin(bin) often falls below vmin, where a data-coordinate
            # annotation is clipped and not drawn.
            heatmap_ax.annotate("", xy=(anchor_bin, 0.15), xytext=(anchor_bin, 0.0),
                                xycoords=heatmap_ax.get_xaxis_transform(),
                                arrowprops=dict(facecolor="black", edgecolor="black"))

        pltxticks_subplot2grid(s, e, figstart, figend, 10, ax=heatmap_ax)
            
def plot_compartment_heatmap(samples, labels, nrow, nrow_now, nrow_feature,
                             s, e, figstart, figend,
                             colspan_plot, colspan_colorbar, colspan_legend, colspan_full):
    """Draw compartment PC1 of all samples as a heatmap plus a line plot.

    Row ``i`` of the heatmap is ``samples[i].getEigen()``, labelled
    ``labels[i]``.  Both panels use a fixed PC1 range of [-0.05, 0.05].  The
    heatmap spans ``nrow_feature`` rows starting at ``nrow_now``; the line
    plot and its legend take the next ``nrow_feature`` rows.
    """
    # One row per sample.  np.vstack on a list always yields a 2-D array
    # (shape (1, n) for a single sample), so no reshape special case is
    # needed.  It also accepts any array-like getEigen() may return, and it
    # avoids re-copying a growing matrix for every sample.
    Matrix = np.vstack([sample.getEigen() for sample in samples])

    # PC1 heatmap
    heatmap_ax  = plt.subplot2grid((nrow, colspan_full), (nrow_now, 0), rowspan=nrow_feature, colspan=colspan_plot)
    colorbar_ax = plt.subplot2grid((nrow, colspan_full), (nrow_now, colspan_plot + 1), rowspan=nrow_feature, colspan=colspan_colorbar)
    img = heatmap_ax.imshow(Matrix[:, s:e], clim=(-0.05, 0.05),
                            cmap=generate_cmap(['#1310cc', '#FFFFFF', '#d10a3f']),
                            aspect="auto")
    plot_xy_axis_and_title_of_feature_heatmap(heatmap_ax, labels, "Compartment PC1")

    plt.colorbar(img, cax=colorbar_ax)

    nrow_now += nrow_feature

    # PC1 line plot
    heatmap_ax = plt.subplot2grid((nrow, colspan_full), (nrow_now, 0), rowspan=nrow_feature, colspan=colspan_plot)
    legend_ax = plt.subplot2grid((nrow, colspan_full), (nrow_now, colspan_plot), rowspan=nrow_feature, colspan=colspan_legend)

    for i, sample in enumerate(samples):
        heatmap_ax.plot(sample.getEigen(), label=labels[i])
    heatmap_ax.set_xlim([s, e])
    heatmap_ax.set_ylim([-0.05, 0.05])
    pltxticks_subplot2grid(s, e, figstart, figend, 10, ax=heatmap_ax)

    plot_legend_in_subplot2grid(heatmap_ax, legend_ax)

def plot_differential_multi_insulation_score(samples, labels, nrow, nrow_now, nrow_feature,
                                                 figstart, figend, s, e,
                                                 colspan_plot, colspan_colorbar, colspan_full):
    """Draw multi-scale insulation-score differences against the first sample.

    ``samples[0]`` is the reference.  For every other sample ``i`` a heatmap of
    ``MI_i - MI_ref`` (transposed, columns ``s..e``, colour range
    [-0.4, 0.4]) is drawn with title ``labels[i]``.  The track starts at grid
    row ``nrow_now + (i - 1) * nrow_feature``.  Only the last track shows x
    tick labels.  Does nothing when ``samples`` is empty.
    """
    if not samples:
        return

    MIref = samples[0].getMultiInsulationScore()
    last = len(samples) - 1

    for i in range(1, len(samples)):
        MI = samples[i].getMultiInsulationScore()

        # Each track spans nrow_feature rows, so step by nrow_feature to keep
        # tracks from overlapping (same layout as before when it is 1).
        row = nrow_now + (i - 1) * nrow_feature
        heatmap_ax  = plt.subplot2grid((nrow, colspan_full), (row, 0), rowspan=nrow_feature, colspan=colspan_plot)
        colorbar_ax = plt.subplot2grid((nrow, colspan_full), (row, colspan_plot + 1), rowspan=nrow_feature, colspan=colspan_colorbar)

        img = heatmap_ax.imshow(MI.T.iloc[:, s:e] - MIref.T.iloc[:, s:e], clim=(-0.4, 0.4),
                                cmap=generate_cmap(['#d10a3f', '#FFFFFF', '#1310cc']),
                                aspect="auto")
        heatmap_ax.set_title(labels[i])

        if i < last:
            xtickoff_ax(ax=heatmap_ax)
        else:
            # The sliced matrix is drawn from x=0, hence the 0..e-s tick range.
            pltxticks_subplot2grid(0, e - s, figstart, figend, 10, ax=heatmap_ax)

        plt.colorbar(img, cax=colorbar_ax)


def plot_multi_insulation_score(samples, labels, nrow, nrow_now, nrow_feature,
                                     figstart, figend, s, e,
                                     colspan_plot, colspan_colorbar, colspan_full):
    """Draw one multi-scale insulation-score heatmap per sample, stacked vertically.

    Track ``i`` shows ``samples[i].getMultiInsulationScore()`` transposed,
    columns ``s..e``, colour range [0.4, 1.0], titled ``labels[i]``.  It starts
    at grid row ``nrow_now + i * nrow_feature``.  Only the last track shows
    x tick labels.
    """
    last = len(samples) - 1

    for i, sample in enumerate(samples):
        MI = sample.getMultiInsulationScore()

        # Each track spans nrow_feature rows, so step by nrow_feature to keep
        # tracks from overlapping (same layout as before when it is 1).
        row = nrow_now + i * nrow_feature
        heatmap_ax  = plt.subplot2grid((nrow, colspan_full), (row, 0), rowspan=nrow_feature, colspan=colspan_plot)
        colorbar_ax = plt.subplot2grid((nrow, colspan_full), (row, colspan_plot + 1), rowspan=nrow_feature, colspan=colspan_colorbar)

        img = heatmap_ax.imshow(MI.T.iloc[:, s:e], clim=(0.4, 1.0),
                                cmap=generate_cmap(['#d10a3f', '#FFFFFF', '#1310cc']),
                                aspect="auto")
        heatmap_ax.set_title(labels[i])

        if i < last:
            xtickoff_ax(ax=heatmap_ax)
        else:
            # The sliced matrix is drawn from x=0, hence the 0..e-s tick range.
            pltxticks_subplot2grid(0, e - s, figstart, figend, 10, ax=heatmap_ax)

        plt.colorbar(img, cax=colorbar_ax)

def plot_single_insulation_score(samples, labels, nrow, nrow_now, nrow_feature,
                                 figstart, figend, s, e,
                                 colspan_plot, colspan_colorbar, colspan_legend, colspan_full):
    """Draw the insulation score of all samples as a heatmap plus a line plot.

    Scores come from ``getInsulationScoreOfMultiSample(samples, labels)``, one
    column per sample.  The heatmap shows rows ``s..e`` transposed, with a
    colour range of [0.4, 1.0].  The line plot uses x-limits
    ``[figstart, figend]``; NOTE(review): this presumably relies on the score
    table's index being genomic coordinates — confirm against
    getInsulationScoreOfMultiSample.
    """
    scores = getInsulationScoreOfMultiSample(samples, labels)
    grid_shape = (nrow, colspan_full)

    # --- heatmap, one row per sample ---
    map_ax = plt.subplot2grid(grid_shape, (nrow_now, 0),
                              rowspan=nrow_feature, colspan=colspan_plot)
    cbar_ax = plt.subplot2grid(grid_shape, (nrow_now, colspan_plot + 1),
                               rowspan=nrow_feature, colspan=colspan_colorbar)
    palette = generate_cmap(['#d10a3f', '#FFFFFF', '#1310cc'])
    image = map_ax.imshow(scores.T.iloc[:, s:e], clim=(0.4, 1.0),
                          cmap=palette, aspect="auto")
    plot_xy_axis_and_title_of_feature_heatmap(map_ax, labels, "Insulation score")
    plt.colorbar(image, cax=cbar_ax)

    # --- line plot, one curve per sample column ---
    line_row = nrow_now + nrow_feature
    line_ax = plt.subplot2grid(grid_shape, (line_row, 0),
                               rowspan=nrow_feature, colspan=colspan_plot)
    legend_holder = plt.subplot2grid(grid_shape, (line_row, colspan_plot),
                                     rowspan=nrow_feature, colspan=colspan_legend)

    for col, _sample in enumerate(samples):
        line_ax.plot(scores.iloc[s:e, col], label=labels[col])
    line_ax.set_xlim([figstart, figend])
    pltxticks_subplot2grid(figstart, figend, figstart, figend, 10, ax=line_ax)

    plot_legend_in_subplot2grid(line_ax, legend_holder)


def _split_input(value):
    """Split a positional ``<Input directory>:<label>`` argument on ':'."""
    return value.split(':')


def _select_plot_mode(args):
    """Return the feature track to draw below the Hi-C map and PC1 track.

    The first enabled option in this order wins and the rest are ignored.
    Figure sizing and the plotting dispatch in main() both use this single
    decision, so the grid always has room for the track that is drawn.
    """
    for mode in ("drf", "triangle_ratio_multi", "di", "compartment",
                 "v4c", "multi", "multidiff"):
        if getattr(args, mode):
            return mode
    return "single"


def _count_grid_rows(mode, nsample, nrow_heatmap, nrow_eigen, nrow_feature):
    """Return the number of subplot2grid rows the figure needs for ``mode``."""
    # The Hi-C map and the PC1 track are always drawn first.
    nrow = nrow_heatmap + nrow_eigen
    if mode in ("multi", "multidiff", "v4c"):
        # One track per sample.
        return nrow + nsample * nrow_feature
    if mode == "triangle_ratio_multi":
        # Per non-reference sample: ratio map + two feature tracks.
        return nrow + (nsample - 1) * (nrow_heatmap + nrow_feature * 2)
    # Heatmap + line-plot pair.
    return nrow + nrow_feature * 2


def _fail(message):
    """Print ``message`` to stderr and terminate with exit status 1."""
    print(message, file=sys.stderr)
    sys.exit(1)


def main():
    """Command-line entry point: parse options, draw the figure, save it.

    Layout: the Hi-C contact map of the first sample and its compartment PC1
    track, followed by one feature track chosen by _select_plot_mode().  The
    figure is written to ``<output>.pdf``.
    """
    # plot_HiC_Map() is called below without a chromosome argument and looks
    # the chromosome up through the module-level name ``chr``.  Binding that
    # name at module scope makes the lookup return the requested chromosome
    # instead of Python's built-in chr() function.
    global chr

    parser = argparse.ArgumentParser()
    parser.add_argument("input",  help="<Input directory>:<label>", type=_split_input, nargs='*')
    parser.add_argument("-o", "--output", help="Output prefix", type=str, default="output")
    parser.add_argument("-c", "--chr", help="chromosome", type=str)
    parser.add_argument("--type", help="normalize type (default: SCALE)", type=str, default="SCALE")
    parser.add_argument("--distance", help="distance for DI (default: 500000)", type=int, default=500000)
    parser.add_argument("-r", "--resolution", help="resolution (default: 25000)", type=int, default=25000)
    parser.add_argument("-s", "--start", help="start bp (default: 0)", type=int, default=0)
    parser.add_argument("-e", "--end",   help="end bp (default: 1000000)", type=int, default=1000000)
    parser.add_argument("--multi",       help="plot MultiInsulation Score", action='store_true')
    parser.add_argument("--multidiff",   help="plot differential MultiInsulation Score", action='store_true')
    parser.add_argument("--compartment", help="plot Compartment (eigen)", action='store_true')
    parser.add_argument("--di",    help="plot Directionaly Index", action='store_true')
    parser.add_argument("--drf",   help="plot Directional Relative Frequency", action='store_true')
    parser.add_argument("--drf_right",  help="(with --drf) plot DirectionalRelativeFreq (Right)", action='store_true')
    parser.add_argument("--drf_left",   help="(with --drf) plot DirectionalRelativeFreq (Left)", action='store_true')
    parser.add_argument("--triangle_ratio_multi",   help="plot Triangle ratio multi", action='store_true')
    parser.add_argument("-d", "--vizdistancemax", help="max distance in heatmap", type=int, default=0)
    parser.add_argument("--v4c",   help="plot virtual 4C from Hi-C data", action='store_true')
    parser.add_argument("--vmax", help="max value of color bar (default: 50)", type=int, default=50)
    parser.add_argument("--vmin", help="min value of color bar (default: 0)", type=int, default=0)
    parser.add_argument("--vmax_ratio", help="max value of color bar for logratio (default: 1)", type=int, default=1)
    parser.add_argument("--vmin_ratio", help="min value of color bar for logratio (default: -1)", type=int, default=-1)
    parser.add_argument("--anchor", help="(for --v4c) anchor point", type=int, default=500000)
    parser.add_argument("--xsize", help="xsize for figure (default: max(length/2M, 10))", type=int, default=0)

    args = parser.parse_args()

    # Each positional item is [dir] or [dir, label, ...]; the label defaults to "".
    dirs = [item[0] for item in args.input]
    labels = [item[1] if len(item) > 1 else "" for item in args.input]

    mode = _select_plot_mode(args)

    # Every error path exits with a non-zero status (previously some used
    # exit() and reported success to the shell).
    if not dirs:
        _fail("Error: specify input data (<Input directory>:<label>).")
    if mode in ("drf", "multidiff", "triangle_ratio_multi") and len(dirs) < 2:
        _fail("Error: --drf|--multidiff|--triangle_ratio_multi requires >= 2 samples.")
    if args.chr is None:
        _fail("Error: specify chromosome (-c).")

    chr = args.chr
    resolution = args.resolution
    norm_type = args.type

    figstart = args.start
    figend = args.end
    length = figend - figstart
    if length <= 0:
        _fail("Error: end < start.")

    # Bin indices of the displayed window.
    s = int(figstart / resolution)
    e = int(figend / resolution)
    binnum = e - s
    vmax = args.vmax
    vmin = args.vmin

    print("chr: " + chr + ", resolution: " + str(resolution) + ", width: " + str(length) + ", " + str(binnum) + " bins.")

    samples = get_samples(dirs, chr, norm_type, resolution)
    nsample = len(samples)

    nrow_heatmap = 2
    nrow_eigen = 1
    nrow_feature = 1
    colspan_plot = 24
    colspan_colorbar = 1
    colspan_legend = 6
    colspan_full = colspan_plot + colspan_legend

    ### Plot
    nrow = _count_grid_rows(mode, nsample, nrow_heatmap, nrow_eigen, nrow_feature)
    figsize_x = set_figsize_x(args.xsize, figstart, figend)
    plt.figure(figsize=(figsize_x, nrow * 2))

    nrow_now = 0

    # Hi-C Map
    plot_HiC_Map(nrow, nrow_now, nrow_heatmap, samples[0], labels[0], dirs[0],
                 norm_type, resolution, vmax, vmin, figstart, figend, args.vizdistancemax,
                 colspan_plot, colspan_colorbar, colspan_full)
    nrow_now += nrow_heatmap

    # Compartment PC1 of the first sample
    plot_PC1(nrow, nrow_now, nrow_eigen, samples[0], labels[0],
             s, e, colspan_plot, colspan_full)
    nrow_now += nrow_eigen

    # Feature track (exactly one, chosen above)
    if mode == "drf":
        plot_directional_relative_frequency(samples, labels,  nrow, nrow_now, nrow_feature,
                                            s, e, figstart, figend, resolution,
                                            args.drf_right, args.drf_left,
                                            colspan_plot, colspan_colorbar, colspan_legend, colspan_full)
    elif mode == "triangle_ratio_multi":
        plot_triangle_ratio_multi(samples, labels, nrow, nrow_now, nrow_heatmap, nrow_feature,
                                  s, e, figstart, figend, args.vizdistancemax, resolution,
                                  args.vmin, args.vmax, args.vmin_ratio, args.vmax_ratio,
                                  colspan_plot, colspan_colorbar, colspan_legend, colspan_full)
    elif mode == "di":
        plot_directionality_index(samples, labels, nrow, nrow_now, nrow_feature,
                                  s, e, figstart, figend, args.distance,
                                  colspan_plot, colspan_colorbar, colspan_legend, colspan_full)
    elif mode == "compartment":
        plot_compartment_heatmap(samples, labels, nrow, nrow_now, nrow_feature,
                                 s, e, figstart, figend,
                                 colspan_plot, colspan_colorbar, colspan_legend, colspan_full)
    elif mode == "v4c":
        plot_virtual_4c(samples, labels, nrow, nrow_now, nrow_feature,
                        s, e, figstart, figend, args.anchor, resolution, vmin, vmax,
                        colspan_plot, colspan_full)
    elif mode == "multi":
        plot_multi_insulation_score(samples, labels, nrow, nrow_now, nrow_feature,
                                    figstart, figend, s, e,
                                    colspan_plot, colspan_colorbar, colspan_full)
    elif mode == "multidiff":
        plot_differential_multi_insulation_score(samples, labels, nrow, nrow_now, nrow_feature,
                                                 figstart, figend, s, e,
                                                 colspan_plot, colspan_colorbar, colspan_full)
    else:
        plot_single_insulation_score(samples, labels, nrow, nrow_now, nrow_feature,
                                     figstart, figend, s, e,
                                     colspan_plot, colspan_colorbar, colspan_legend, colspan_full)

    plt.subplots_adjust(hspace=0.5)
    plt.savefig(args.output + ".pdf")
    plt.close()

# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()
