#!python
# -*- coding: utf-8 -*-
# Copyright(c) Ryuichiro Nakato <rnakato@iqb.u-tokyo.ac.jp>
# All rights reserved.

import argparse
import matplotlib.pyplot as plt
import seaborn as sns
import sys
from scipy import ndimage
from custardpy.HiCmodule import JuicerMatrix
from custardpy.generateCmap import *
from custardpy.PlotModule import *
from custardpy.DirectionalRelativeFreq import make3dmatrixRatio

def main():
    """Plot a two-comparison ratio heatmap and save it as ``<output>.pdf``.

    Four inputs are expected on the command line, each written as ``dir`` or
    ``dir:label``.  Inputs 1 and 2 are passed to ``make3dmatrixRatio`` and
    fill the strict upper triangle of the plotted matrix; inputs 3 and 4 fill
    the strict lower triangle.  The labels of inputs 2 and 4 (when both are
    given) form the figure label.

    Exits with status 2 (via ``parser.error``) when fewer than four inputs
    are supplied, instead of failing later with an ``IndexError``.
    """
    # Imported locally so this function does not depend on ``np`` leaking in
    # through the ``from custardpy... import *`` star imports.
    import numpy as np

    def split_input(value):
        # "dir:label" -> ["dir", "label"]; a bare "dir" -> ["dir"].
        return value.split(':')

    parser = argparse.ArgumentParser()
    parser.add_argument("input", help="<Input matrix>", type=split_input, nargs='*')
    parser.add_argument("-o", "--output", help="Output prefix", type=str, default="output")
    parser.add_argument("-r", "--resolution", help="resolution", type=int, default=25000)
    parser.add_argument("-s", "--start", help="start bp", type=int, default=0)
    parser.add_argument("-e", "--end", help="end bp", type=int, default=1000000)
    # float rather than int so fractional colour limits (e.g. 0.5) are
    # accepted; integer values on the command line still parse.
    parser.add_argument("--vmax", help="max value of color bar", type=float, default=1)
    parser.add_argument("--vmin", help="min value of color bar", type=float, default=-1)
    # NOTE(review): --xsize/--ysize are parsed but not applied anywhere in
    # this function; kept so existing command lines keep working.
    parser.add_argument("--xsize", help="xsize for figure", type=int, default=3)
    parser.add_argument("--ysize", help="ysize (* times of samples)", type=int, default=3)

    args = parser.parse_args()

    # Two pairs of samples are indexed below, so anything shorter than four
    # inputs cannot be plotted.  Fail early with a usage message on stderr.
    required_inputs = 4
    if len(args.input) < required_inputs:
        parser.error("specify four input matrices as <dir>[:<label>] "
                     "(got " + str(len(args.input)) + ").")

    # Only the first four inputs are used; any extra ones are ignored.
    used_inputs = args.input[:required_inputs]
    dirs = [item[0] for item in used_inputs]
    labels = [item[1] if len(item) > 1 else "" for item in used_inputs]

    resolution = args.resolution
    figstart = args.start
    figend = args.end
    vmax = args.vmax
    vmin = args.vmin

    print (resolution)

    samples = [JuicerMatrix("RPM", path, resolution) for path in dirs]

    smooth_median_filter = 3
    mt1 = make3dmatrixRatio([samples[0], samples[1]], smooth_median_filter)
    mt2 = make3dmatrixRatio([samples[2], samples[3]], smooth_median_filter)

    # First comparison above the diagonal, second below it; k=1 / k=-1
    # exclude the diagonal itself from both halves.
    matrix = np.triu(mt1[0], k=1) + np.tril(mt2[0], k=-1)

    if labels[1] and labels[3]:
        label = labels[1] + "(upper) - " + labels[3] + "(lower)"
    else:
        label = ""

    plt.subplot(1, 1, 1)
    drawHeatmapSquare(matrix, resolution,
                      figstart=figstart, figend=figend,
                      vmax=vmax, vmin=vmin,
                      cmap=generate_cmap(['#1310cc', '#FFFFFF', '#d10a3f']),
                      label=label, xticks=False)
    plt.savefig(args.output + ".pdf")

# Run the command-line entry point only when executed as a script,
# not when this file is imported as a module.
if __name__ == "__main__":
    main()
