#!python
# -*- coding: utf-8 -*-
# Copyright(c) Ryuichiro Nakato <rnakato@iqb.u-tokyo.ac.jp>
# All rights reserved.

import argparse
import os
import sys
import pandas as pd
import subprocess
import numpy as np
import random
from pybedtools import BedTool
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

def calculate_ratio(border_bed, gene_bed, allgene_bed):
    """Return the fraction of border-overlapping genes that are test genes.

    Args:
        border_bed: Intervals to intersect against (passed to ``intersect``).
        gene_bed: Genes to be tested; must provide ``intersect(..., u=True)``.
        allgene_bed: Background genes; must provide ``intersect(..., u=True)``.

    Returns:
        ``len(test hits) / len(background hits)``, or ``0`` when no
        background gene overlaps ``border_bed``.
    """
    # Both intersections are always performed, background first.
    background_hits = len(allgene_bed.intersect(border_bed, u=True))
    test_hits = len(gene_bed.intersect(border_bed, u=True))

    if background_hits == 0:
        # Avoid dividing by zero when nothing overlaps the borders.
        return 0
    return test_hits / background_hits

def back_function(border, gene_bed, permutation_times, len_border, allgene_bed):
    """Build the background distribution of ratios from random border subsets.

    For each permutation, ``len_border`` rows are drawn from ``border`` with
    ``DataFrame.sample``, converted to a BedTool, and scored with
    ``calculate_ratio``.

    Args:
        border: DataFrame of background borders to sample from.
        gene_bed: Genes to be tested.
        permutation_times: Number of random draws.
        len_border: Number of rows drawn per permutation.
        allgene_bed: Background genes.

    Returns:
        A list of six quantiles of the random ratios, in the order
        0.25, 0.75, 0.05, 0.95, 0.025, 0.975.
    """
    quantile_levels = [0.25, 0.75, 0.05, 0.95, 0.025, 0.975]

    null_ratios = np.array([
        calculate_ratio(
            BedTool.from_dataframe(border.sample(len_border)),
            gene_bed,
            allgene_bed,
        )
        for _ in range(permutation_times)
    ])

    return [np.quantile(null_ratios, level) for level in quantile_levels]

def plot_graph(df, outputname):
    """Plot the observed ratio against distance with background bands.

    Args:
        df: DataFrame with columns ``Distance`` (bp), ``Ratio`` and the band
            limits ``low50``/``high50``, ``low90``/``high90``,
            ``low95``/``high95``.
        outputname: Path handed to ``plt.savefig``.
    """
    plt.rcParams['font.size'] = '12'

    # The x axis is shown in kb while the table stores bp.
    distance_kb = df["Distance"] / 1000

    # (lower column, upper column, alpha, legend label) for each grey band.
    bands = (
        ("low50", "high50", 0.5, '50% quantile'),
        ("low90", "high90", 0.3, '90% quantile'),
        ("low95", "high95", 0.1, '95% quantile'),
    )

    plt.plot(distance_kb, df["Ratio"], "m")

    legend_handles = []
    for lower, upper, alpha, label in bands:
        plt.fill_between(distance_kb, df[lower], df[upper], color="grey", alpha=alpha)
        legend_handles.append(mpatches.Patch(color='grey', alpha=alpha, label=label))

    plt.xlabel("Distance from TAD boundary (kb)", fontsize=15)
    plt.ylabel("Fraction of DEGs", fontsize=15)
    plt.legend(handles=legend_handles, fontsize=12)

    plt.savefig(outputname)

def set_border(border, i):
    """Return a copy of ``border`` widened by ``i`` on both sides.

    Column ``1`` is decreased by ``i`` but never below 0; column ``2`` is
    increased by ``i``. The input DataFrame is left untouched.

    Args:
        border: DataFrame whose columns ``1`` and ``2`` hold the interval
            limits.
        i: Amount added on each side.

    Returns:
        The widened copy.
    """
    widened = border.copy()
    # Clamp at zero so the widened start cannot become negative.
    widened[1] = (widened[1] - i).clip(lower=0)
    widened[2] += i
    return widened

def permutation_test_ratio(border, allborder, gene_bed, allgene_bed, permutation_times, max_distance, distance_step):
    """Compute observed and background ratios over a range of distances.

    For every distance from 0 to ``max_distance`` (inclusive) in steps of
    ``distance_step``, both border tables are widened with ``set_border``.
    The observed ratio is computed on the widened test borders, and the
    background quantiles come from ``back_function`` on the widened control
    borders, sampling as many rows as there are test borders.

    Args:
        border: DataFrame of borders to be tested.
        allborder: DataFrame of background borders.
        gene_bed: Genes to be tested.
        allgene_bed: Background genes.
        permutation_times: Number of random draws per distance.
        max_distance: Largest distance to evaluate.
        distance_step: Increment between evaluated distances.

    Returns:
        A tuple ``(select_ratios, random_ratios, positions)``: the list of
        observed ratios, an array with one row per value returned by
        ``back_function`` and one column per distance, and the list of
        distances.
    """
    # The sample size is the same for every distance, so compute it once.
    n_test_borders = len(border)

    observed = []
    background_quantiles = []
    distances = []

    for offset in range(0, max_distance + 1, distance_step):
        print(f"Distance {offset} bp")
        widened_test = set_border(border, offset)
        widened_control = set_border(allborder, offset)

        observed.append(
            calculate_ratio(BedTool.from_dataframe(widened_test), gene_bed, allgene_bed)
        )
        background_quantiles.append(
            back_function(widened_control, gene_bed, permutation_times, n_test_borders, allgene_bed)
        )
        distances.append(offset)

    # Stack the per-distance quantile lists, then transpose so that each row
    # holds one quantile across all distances.
    return observed, np.array(background_quantiles).T, distances

def main():
    """Parse the command line, run the permutation test and save the plot.

    Reads the two border tables as tab-separated files without a header and
    the two gene files as BedTool objects, computes the observed and
    background ratios with ``permutation_test_ratio`` and hands the
    resulting table to ``plot_graph``.

    Exits with status 1 (after printing an error message and the help text)
    when a mandatory input file option is missing or when ``-n``,
    ``--maxdistance`` or ``--step`` is out of range.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--border_test",  help="<TAD boundary to be tested (BED format)>", type=str, default=None)
    parser.add_argument("--border_control",  help="<TAD boundary as background (BED format)>", type=str, default=None)
    parser.add_argument("--gene_test",  help="<Genes to be tested (BED format)>", type=str, default=None)
    parser.add_argument("--gene_control",  help="<Genes as background (BED format)>", type=str, default=None)
    parser.add_argument("-o", "--output", help="Output name (*.pdf or *.png, default: output.pdf)", type=str, default="output.pdf")
    parser.add_argument("-n", help="Number of permutation (default: 1000)", type=int, default=1000)
    parser.add_argument("--maxdistance", help="Max distance (bp, default: 300000)", type=int, default=300000)
    parser.add_argument("--step", help="Step of distance (bp, default: 10000)", type=int, default=10000)

    args = parser.parse_args()

    # All four input files are mandatory. Report the first missing one and
    # exit with a non-zero status so calling scripts can detect the failure.
    for option in ("border_test", "border_control", "gene_test", "gene_control"):
        if getattr(args, option) is None:
            print(f"Error: specify --{option}.")
            parser.print_help()
            sys.exit(1)

    # Reject numeric settings that would otherwise fail later with an
    # unrelated traceback (empty distance range or empty permutation set).
    if args.n < 1:
        problem = "-n must be a positive integer."
    elif args.maxdistance < 0:
        problem = "--maxdistance must not be negative."
    elif args.step < 1:
        problem = "--step must be a positive integer."
    else:
        problem = None
    if problem is not None:
        print("Error: " + problem)
        parser.print_help()
        sys.exit(1)

    print(f"   TAD boundary to be tested: {args.border_test}")
    print(f"   TAD boundary as background: {args.border_control}")
    print(f"   Genes to be tested: {args.gene_test}")
    print(f"   Genes as background: {args.gene_control}")
    print(f"   Permutation time: {args.n}")
    print(f"   Max distance: {args.maxdistance} bp")
    print(f"   Step of distance: {args.step} bp")
    print(f"   Output file: {args.output}")

    # header=None makes pandas label the columns 0, 1, 2, ... which is what
    # set_border relies on when it adjusts columns 1 and 2.
    border = pd.read_csv(args.border_test, sep="\t", header=None)
    allborder = pd.read_csv(args.border_control, sep="\t", header=None)
    gene_bed = BedTool(args.gene_test)
    allgene_bed = BedTool(args.gene_control)

    select_ratios, random_ratios, positions = permutation_test_ratio(
        border,
        allborder,
        gene_bed,
        allgene_bed,
        args.n,
        args.maxdistance,
        args.step,
    )

    # Rows of random_ratios follow the quantile order produced by
    # back_function: 0.25, 0.75, 0.05, 0.95, 0.025, 0.975.
    df = pd.DataFrame({
        'Distance': positions,
        'Ratio': select_ratios,
        'low50': random_ratios[0],
        'high50': random_ratios[1],
        'low90': random_ratios[2],
        'high90': random_ratios[3],
        'low95': random_ratios[4],
        'high95': random_ratios[5],
    })

    plot_graph(df, args.output)


# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()
