#!python
# -*- coding: utf-8 -*-
# Copyright(c) Ryuichiro Nakato <rnakato@iqb.u-tokyo.ac.jp>
# All rights reserved.


import argparse

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from custardpy.HiCmodule import JuicerMatrix
from custardpy.InsulationScore import getInsulationScoreOfMultiSample
from custardpy.generateCmap import *
from custardpy.loadData import *
from custardpy.PlotModule import *

#import pdb

def main(argv=None):
    """Draw two matrices as one split square heatmap and save it as a PDF.

    Both samples are loaded with ``JuicerMatrix``. The first sample fills the
    upper triangle (diagonal included) and the second sample fills the
    strictly lower triangle of the plotted matrix. The figure is written to
    ``<output>.pdf``.

    Args:
        argv: List of command-line arguments to parse. ``None`` (the default)
            lets argparse read ``sys.argv[1:]``, so calling ``main()`` with no
            arguments behaves like a normal script invocation.

    Raises:
        SystemExit: With status 2 (raised by ``parser.error``) unless exactly
            two input samples are given.
    """
    parser = argparse.ArgumentParser()
    # Each positional argument is "<matrix>" or "<matrix>:<label>".
    parser.add_argument("input", help="<Input matrix>:<label>",
                        type=lambda spec: spec.split(':'), nargs='*')
    parser.add_argument("-o", "--output", help="Output prefix", type=str, default="output")
    parser.add_argument("--log", help="logged count", action='store_true')
    parser.add_argument("-r", "--resolution", help="resolution", type=int, default=25000)
    parser.add_argument("-s", "--start", help="start bp", type=int, default=0)
    parser.add_argument("-e", "--end", help="end bp", type=int, default=1000000)
    parser.add_argument("--vmax", help="max value of color bar", type=int, default=50)
    parser.add_argument("--vmin", help="min value of color bar", type=int, default=0)
    # The three options below are parsed but never read by this script; they
    # stay registered so that existing command lines keep working.
    parser.add_argument("-d", "--vizdistancemax", help="max distance in heatmap", type=int, default=0)
    parser.add_argument("--xsize", help="xsize for figure", type=int, default=10)
    parser.add_argument("--ysize", help="ysize (* times of samples)", type=int, default=3)

    args = parser.parse_args(argv)

    # Validate before any file is loaded. parser.error() prints the usage and
    # the message to stderr and exits with a non-zero status.
    if len(args.input) != 2:
        parser.error("specify two samples as <Input matrix>:<label>.")

    inputfiles = [spec[0] for spec in args.input]
    labels = [spec[1] if len(spec) > 1 else "" for spec in args.input]

    resolution = args.resolution
    figstart = args.start
    figend = args.end
    vmax = args.vmax
    vmin = args.vmin
    if args.log:
        # The matrices are log-transformed below, so put the colour limits on
        # the same log1p scale.
        vmax = np.log1p(vmax)
        vmin = np.log1p(vmin)

    print (resolution)
    samples = [JuicerMatrix("RPM", inputfile, resolution)
               for inputfile in inputfiles]

    plt.subplot(1, 1, 1)
    if args.log:
        m1 = samples[0].getlog()
        m2 = samples[1].getlog()
    else:
        m1 = samples[0].getmatrix()
        m2 = samples[1].getmatrix()

    # Upper triangle plus diagonal from sample 1, strictly lower triangle
    # from sample 2.
    matrix = np.triu(m1) + np.tril(m2, k=-1)

    # A title is only shown when both samples were given a label.
    if labels[0] != "" and labels[1] != "":
        label = labels[0] + "(upper) - " + labels[1] + "(lower)"
    else:
        label = ""

    drawHeatmapSquare(matrix, resolution,
                      figstart=figstart, figend=figend,
                      vmax=vmax, vmin=vmin,
                      label=label, xticks=True)
    plt.savefig(args.output + ".pdf")


# Run the command-line tool only when this file is executed as a script.
if __name__ == "__main__":
    main()
