#! /usr/bin/env Rscript 
# input is table from deviaTE_analyse

# set the location to look for packages
# conda needs to be first if this script is run from a conda environment!
prioritize_conda <- function(lib_tree){
  # Move every library path that mentions 'conda' (case-insensitive) to the
  # front of the search tree.
  # lib_tree: character vector of library paths, e.g. from .libPaths()
  # returns:  the same paths with the conda entries first; the relative order
  #           of the remaining entries is kept and no path is listed twice.
  #           If no conda path is present, lib_tree is returned unchanged.
  cpath <- grep('conda', lib_tree, value=TRUE, ignore.case=TRUE)
  if (length(cpath) == 0){
    return(lib_tree)
  }
  unique(c(cpath, lib_tree))
}

# re-register the library search path with the conda entries moved to the front
new_tree <- prioritize_conda(.libPaths())
.libPaths(new_tree)

rep<-"http://cran.rstudio.com/"
if (!suppressMessages(suppressWarnings(require(ggplot2)))) install.packages('ggplot2', repos=rep)
if (!suppressMessages(suppressWarnings(require(grid)))) install.packages('grid', repos=rep)
if (!suppressMessages(suppressWarnings(require(gridExtra)))) install.packages('gridExtra', repos=rep)
if (!suppressMessages(suppressWarnings(require(argparse)))) install.packages('argparse', repos=rep)
if (!suppressMessages(suppressWarnings(require(data.table)))) install.packages('data.table', repos=rep)
if (!suppressMessages(suppressWarnings(require(cowplot)))) install.packages('cowplot', repos=rep)
library(ggplot2)
library(grid)
library(gridExtra)
library(argparse)
library(data.table)
suppressPackageStartupMessages(suppressWarnings(library(cowplot)))

# command line interface
parser <- ArgumentParser()
parser$add_argument('--input', required=TRUE, help='input as created by TEplotter')
parser$add_argument('--output', default='', help='name for the output file')
parser$add_argument('--free_yaxis', action='store_true', help='free y-axis e.g. if one sample has much higher cov')
# type='double' makes argparse reject non-numeric values right away instead
# of handing a string on to the plotting code
parser$add_argument('--fontsize', type='double', default=14, help='base font size used in the plot')
parser$add_argument('--out_format', default='pdf', help='format of output, either pdf or eps')
args <- parser$parse_args()

# fill colours: one per base, plus insertions ('in') and deletions ('del')
base_colors <- c(
  'A'   = '#1b9e77',
  'C'   = '#d95f02',
  'G'   = '#7570b3',
  'T'   = '#e6ab02',
  'in'  = '#66a61e',
  'del' = '#e7298a'
)


fam_name <- function(df){
    # TE family of a sample: the first entry of its TEfam column as a string
    first_fam <- df$TEfam[1]
    as.character(first_fam)
}


fam_len <- function(df){
    # length of the TE family of one sample, i.e. the number of rows of its table
    dim(df)[1L]
}


snp_positions <- function(df){
  # positions (pos + 1) of all rows that are flagged as snp or as refsnp
  flagged <- df$snp == TRUE | df$refsnp == TRUE
  # rows where the flag cannot be decided (NA) are not selected
  flagged <- flagged & !is.na(flagged)
  as.numeric(df$pos[flagged]) + 1
}


ref_allele <- function(snp_index, refbases, len){
  # Index of the reference base of one SNP inside a vector that is made of
  # four consecutive blocks of length 'len' (A, C, G, then everything else).
  ref_base <- refbases[snp_index]
  if (ref_base == 'A'){
    return(snp_index)
  }
  if (ref_base == 'C'){
    return(snp_index + len)
  }
  if (ref_base == 'G'){
    return(snp_index + 2*len)
  }
  # any other reference base falls into the fourth block
  snp_index + 3*len
}


snp_frame <- function(df){
    # Long-format table of the non-reference allele counts at SNP positions.
    # df: table of one sample (one row per position) with the columns
    #     A, C, G, T, TEfam, sample_id, refbase and those read by snp_positions()
    # returns: data.frame with the columns fam_col, sample_col, base_col,
    #          pos_col and counts_col; only rows at SNP positions whose base
    #          is not the reference base are kept

    # concatenate the base coverage and create the factor columns (4 * l rows)
    l <- fam_len(df)
    counts_col <- c(df$A, df$C, df$G, df$T)
    # seq_len() instead of 1:l: an empty sample must give zero positions,
    # not c(1, 0), otherwise data.frame() below fails on unequal lengths
    pos_col <- rep(seq_len(l), 4)
    fam_col <- rep(df$TEfam, 4)
    sample_col <- rep(df$sample_id, 4)
    base_col <- rep(c('A', 'C', 'G', 'T'), each=l)
    polys <- snp_positions(df)

    # mark the reference allele of every SNP so that it can be removed;
    # vapply keeps the index vector numeric even if there is no SNP at all
    polys_ref_indices <- vapply(polys, ref_allele, numeric(1), refbases=df$refbase, len=l)
    base_col[polys_ref_indices] <- 'X'

    lf <- data.frame(fam_col, sample_col, base_col, pos_col, counts_col)
    snps <- subset(lf, pos_col %in% polys)
    snps_noref <- subset(snps, base_col != 'X')
    return(snps_noref)
}


base_plot <- function(snps, sample){
    # Base plot of one sample: total coverage (light grey) and hq coverage
    # (grey) as filled areas with the SNP alleles as stacked bars on top.
    # Reads the globals 'fontsize' and 'base_colors'.
    # Returns the ggplot object wrapped in a list.
    n_pos <- max(sample$pos) + 1
    margin <- ceiling(n_pos/50)

    # order of levels defines the order of stacked bars; snp at bottom
    snps$base_col <- factor(snps$base_col, levels=c('A','C','G','T'))

    # closed outline of a coverage track: along the values, then filled up
    # with zeros for the way back
    outline <- function(values){
        data.table(px=c(sample$pos, rev(sample$pos)),
                   py=c(values, rep(0, length(values))))
    }
    total_area <- outline(sample$coverage)
    hq_area <- outline(sample$hq_coverage)

    covplot <- ggplot() +
        geom_polygon(data=total_area, aes(x=px, y=py), fill='lightgrey', color='lightgrey') +
        geom_polygon(data=hq_area, aes(x=px, y=py), fill='grey', color='grey') +
        geom_bar(data=snps, aes(x=pos_col, y=counts_col, fill=base_col), stat='identity', width=3) +
        labs(fill='') +
        ylab('') +
        xlab('') +
        ggtitle(paste(sample$TEfam[1], sample$sample_id[1])) +
        scale_x_continuous(limits=c(-margin, n_pos+margin)) +
        scale_fill_manual(values=base_colors, breaks=c('A','C','G','T', 'in', 'del')) +
        theme_minimal(base_size=fontsize) +
        theme(
            legend.position='none',
            legend.background=element_blank(),
            legend.key=element_blank(),
            plot.title = element_text(hjust = 0.5)
            )

    list(covplot)
}


make_legend <- function(){
    # Builds a throw-away plot that contains every fill category (the four
    # bases, 'in' and 'del') and returns only its legend, so that all panels
    # can share one consistent legend.
    dummy_bars <- data.table(base_col=c('A', 'C', 'G', 'T'), counts_col=rep(1, 4), pos_col=1:4)

    # single unit rectangle tagged with the given fill category
    dummy_rect <- function(category){
        data.frame(x=c(1), xend=c(2), y=c(0), yend=c(1), fam_col='t', sample_col='s', fac=category)
    }

    legend_plot <- ggplot(dummy_bars) +
        geom_bar(aes(x=pos_col, y=counts_col, fill=base_col), stat='identity', width=1) +
        labs(fill='') +
        scale_fill_manual(values=base_colors, breaks=c('A','C','G','T', 'in', 'del')) +
        theme(legend.position='bottom') +
        guides(fill = guide_legend(nrow = 1))

    # add_rect() hands the plot back inside a one-element list
    legend_plot <- add_rect(p=legend_plot, coords=dummy_rect('in'))[[1]]
    legend_plot <- add_rect(p=legend_plot, coords=dummy_rect('del'))[[1]]

    get_legend(legend_plot)
}


int_del_data <- function(df){
    # Coordinates for plotting internal deletions as arcs; the 'size' column
    # carries the count of each deletion.
    has_del <- df$int_del != 'NA'
    has_del <- has_del & !is.na(has_del)
    feat_df <- df[has_del, c('TEfam', 'sample_id', 'int_del'), drop=FALSE]

    # nothing to draw: hand back one zero-sized placeholder row
    if (nrow(feat_df) == 0){
        return(data.frame(x=0, xend=1, y=0, yend=0, size=0))
    }

    # a position may hold several comma separated features,
    # each feature is a colon separated triple start:end:count
    feat_list <- strsplit(unlist(strsplit(feat_df$int_del, ',')), ':')
    field <- function(k){as.numeric(lapply(feat_list, '[[', k))}
    xsrt <- field(1)
    xend <- field(2)
    size <- field(3)

    # y of the start: mean coverage over (start-5):start, or the coverage at
    # the start itself when start <= 5
    ysrt <- lapply(xsrt, function(i){
        if (i>5){mean(df$coverage[(i-5):i])} else {df$coverage[i]}
    })
    # y of the end: mean coverage over end:(end+5)
    yend <- lapply(xend, function(i){mean(df$coverage[i:(i+5)])})

    # add factors to plot in correct facet
    n_feat <- length(xsrt)
    data.frame(x=xsrt, y=as.numeric(ysrt), xend=xend, yend=as.numeric(yend),
     size=size, fam_col=rep(feat_df$TEfam[1], n_feat),
     sample_col=rep(feat_df$sample_id[1], n_feat))
}


insertion_data <- function(df){
    # Rectangle coordinates for insertions: x, xend and yend are parsed from
    # the 'insertion' column, y is always 0 and fac is 'in'.
    has_ins <- df$insertion != 'NA'
    has_ins <- has_ins & !is.na(has_ins)
    ins_df <- df[has_ins, c('TEfam', 'sample_id', 'insertion'), drop=FALSE]

    # no insertion in this sample: placeholder row without fill category
    if (nrow(ins_df) == 0){
        return(data.frame(x=0, xend=1, y=0, yend=1, fac=NA))
    }

    # several insertions at one site are comma separated,
    # every insertion is a colon separated triple
    ins_list <- strsplit(unlist(strsplit(ins_df$insertion, ',')), ':')
    field <- function(k){as.numeric(lapply(ins_list, '[[', k))}
    xsrt <- field(1)
    n_ins <- length(xsrt)

    # fam_col / sample_col are the factors to plot in the correct facet
    data.frame(x=xsrt, xend=field(2), y=0, yend=field(3),
               fam_col=rep(ins_df$TEfam[1], n_ins),
               sample_col=rep(ins_df$sample_id[1], n_ins),
               fac=rep('in', n_ins))
}


deletion_data <- function(df){
    # Rectangle coordinates for deletions, parsed from the 'deletion' column
    # (comma separated entries, each a colon separated triple that provides
    # x, xend and yend). y is fixed at 0 and every row is tagged fac = 'del'.
    flagged <- df$deletion != 'NA'
    del_df <- df[flagged & !is.na(flagged), c('TEfam', 'sample_id', 'deletion'), drop=FALSE]

    # no deletion in this sample: placeholder row without fill category
    if (nrow(del_df) == 0){
        return(data.frame(x=0, xend=1, y=0, yend=1, fac=NA))
    }

    # split multiple deletions at one site, then split each into its fields
    entries <- unlist(strsplit(del_df$deletion, ','))
    parts <- strsplit(entries, ':')
    nth <- function(k){as.numeric(lapply(parts, '[[', k))}
    starts <- nth(1)
    n_del <- length(starts)

    # fam_col / sample_col are the factors to plot in the correct facet
    data.frame(x=starts, xend=nth(2), y=0, yend=nth(3),
               fam_col=rep(del_df$TEfam[1], n_del),
               sample_col=rep(del_df$sample_id[1], n_del),
               fac=rep('del', n_del))
}


truncation_data <- function(df){
    # Coordinates of the dashed segments that mark truncations.
    # df: table of one sample with the columns pos (0-based), TEfam,
    #     sample_id, coverage, truncation_left and truncation_right
    # returns: data.frame with x, xend, y, yend, alpha (the truncation value)
    #          plus fam_col and sample_col; zero rows if nothing is left
    te_len <- nrow(df)
    tl_df <- subset(df, truncation_left != 0, select=c('pos', 'TEfam', 'sample_id', 'truncation_left'))
    tr_df <- subset(df, truncation_right != 0, select=c('pos', 'TEfam', 'sample_id', 'truncation_right'))

    # remove truncs at pos 0 and end, only if trunc frame is not empty
    if (nrow(tl_df) != 0 && tl_df$pos[1] == 0){
        tl_df <- tl_df[-1,]
    }
    if (nrow(tr_df) != 0 && tr_df$pos[nrow(tr_df)] == te_len - 1){
        tr_df <- tr_df[-nrow(tr_df),]
    }

    # segments are te_len/50 wide and max(coverage)/5 high
    seg_width <- te_len/50
    seg_height <- max(df$coverage)/5

    # left truncations: segment comes down from the left onto the coverage
    l_xend <- as.numeric(tl_df$pos) + 1
    l_xsrt <- l_xend - seg_width
    l_yend <- as.numeric(df$coverage[l_xend + 1])
    l_ysrt <- l_yend + seg_height
    l_alpha <- as.numeric(tl_df$truncation_left)

    # right truncations: segment rises to the right from the coverage;
    # index 0 (a right truncation at pos 0) is clamped to the first row,
    # because coverage[0] is empty and used to abort the conversion
    r_xsrt <- as.numeric(tr_df$pos)
    r_xend <- r_xsrt + seg_width
    r_ysrt <- as.numeric(df$coverage[pmax(r_xsrt, 1)])
    r_yend <- r_ysrt + seg_height
    r_alpha <- as.numeric(tr_df$truncation_right)

    left_truncs <- data.frame(x=l_xsrt, xend=l_xend, y=l_ysrt, yend=l_yend, alpha=l_alpha)
    right_truncs <- data.frame(x=r_xsrt, xend=r_xend, y=r_ysrt, yend=r_yend, alpha=r_alpha)
    truncs <- rbind(left_truncs, right_truncs)

    # add factors to plot in correct facet; they are taken from df itself,
    # since the left-truncation subset is empty (giving NA) for samples that
    # only have right truncations
    fam_col <- rep(df$TEfam[1], nrow(truncs))
    sample_col <- rep(df$sample_id[1], nrow(truncs))
    truncs <- cbind(truncs, fam_col, sample_col)

    return(truncs)
}


add_curve <- function(p, coords){
    # Draws the internal deletions in 'coords' onto plot 'p' as arcs whose
    # size aesthetic follows the 'size' column, scaled to the global 'sizerange'.
    # The new plot is returned inside a list.
    arcs <- geom_curve(data=coords, mapping=aes(x=x, y=y, xend=xend, yend=yend, size=size),
        curvature=-0.2, ncp=10, show.legend=FALSE)

    list(p + arcs + scale_size(range = sizerange))
}


add_rect <- function(p, coords){
    # Draws insertions or deletions onto plot 'p': one rectangle per row of
    # 'coords', filled by its 'fac' column. The new plot is returned in a list.
    boxes <- geom_rect(data=coords,
        mapping=aes(xmin=x, xmax=xend, ymin=y, ymax=yend, fill=fac))

    list(p + boxes)
}


add_segment <- function(p, coords){
    # Draws the truncations in 'coords' onto plot 'p' as dashed segments whose
    # transparency follows the 'alpha' column, scaled to the global 'alpharange'.
    # The new plot is returned inside a list.
    dashes <- geom_segment(data=coords, mapping=aes(x=x, y=y, xend=xend, yend=yend, alpha=alpha),
        lty='dashed', show.legend=FALSE)

    list(p + dashes + scale_alpha(range = alpharange))
}


polygons <- function(a, fam_len){
  # Outline of one annotation track as a closed polygon.
  # a:       rows that belong to a single annotation; needs the columns
  #          pos (0-based) and annotation
  # fam_len: length of the TE family; the x values run from 0 to fam_len
  # returns: data.table with px, py (1 at annotated positions, 0 elsewhere and
  #          on the way back) and group (the annotation name)
  f <- 0:fam_len
  px <- c(f, rev(f))
  py <- rep(0, length(f))
  # f starts at 0 but R indexes from 1, so position p is element p + 1;
  # indexing with a$pos directly moved the track one base to the left and
  # silently dropped position 0
  py[a$pos + 1] <- 1
  py <- c(py, rep(0, length(f)))
  data.table(px=px, py=py, group=rep(a$annotation[1], length(px)))
}


plot_anno <- function(df){
    # Annotation track of one sample table: one grey-scale polygon per
    # annotation type; 'intergenic' and missing annotations are skipped.
    # A blank plot is returned if nothing is annotated.
    # Reads the global 'fontsize'.
    te_len <- max(df$pos) + 1
    margin <- ceiling(te_len/50)
    annotated <- !(df$annotation %in% c(NA, 'intergenic'))
    anno_df <- df[annotated, c('pos', 'annotation'), drop=FALSE]

    if (nrow(anno_df) == 0){
        return(ggplot(data.frame()) + geom_blank())
    }

    # one polygon frame per annotation type, fused to a single frame
    anno_df$annotation <- as.character(anno_df$annotation)
    per_type <- split(anno_df, f=anno_df$annotation)
    annotation_polygons <- do.call(rbind, lapply(per_type, polygons, fam_len=te_len))

    track_theme <- theme(
        legend.position='bottom',
        legend.key=element_blank(),
        legend.background=element_blank(),
        axis.text = element_text(size = rel(0.8), colour = 'grey50'),
        axis.text.x = element_text(color = 'grey50', angle=0),
        axis.text.y = element_blank(),
        axis.ticks.x = element_line(color= 'grey50'),
        axis.line.x = element_line(color='grey50')
        )

    ggplot() +
        geom_polygon(data=annotation_polygons, aes(x=px, y=py, fill=group)) +
        scale_fill_grey() +
        labs(fill='') +
        xlab('') +
        ylab('') +
        scale_x_continuous(limits=c(-margin, te_len+margin)) +
        theme_void(base_size=fontsize) +
        track_theme
}


find_max <- function(df_list, var, ff){
    # Largest value of column 'var' per level of the factor 'ff', where 'ff'
    # holds one entry for each table in 'df_list'. A maximum of 0 is
    # replaced by 1.
    sample_max <- sapply(df_list, function(tab){suppressWarnings(max(tab[,var]))})
    # max() of an empty sample is -Inf; count it as 0
    sample_max[sample_max==-Inf] <- 0

    # group the per-sample maxima by family and take the maximum of each group
    by_fam <- split(data.table(maxima=sample_max, ff), ff)
    fam_max <- as.numeric(lapply(by_fam, function(grp){max(grp$maxima)}))
    fam_max[fam_max==0] <- 1
    fam_max
}


### --- ### --- ###

# unpack the parsed command line arguments
inp_arg <- args[['input']]
out_arg <- args[['output']]
out_format <- args[['out_format']]
axis_arg <- args[['free_yaxis']]
fontsize <- as.numeric(args[['fontsize']])

# replace commandline arguments for debugging
# inp_arg <- 'jockey_dmel.fastq.DMLINEJA'
# out_arg <- 'test.pdf'
# axis_arg <- FALSE
# fontsize <- 14
# out_format <- 'pdf'


# load full dataframe with fread and name the columns - 20 cols
if (!file.exists(inp_arg)){
  stop(paste('input file not found:', inp_arg))
}
print(paste('Loading data:', inp_arg, '; creating plot'))

# the file name is handed to a shell by fread(cmd=...), so it must be quoted;
# otherwise paths with spaces or shell metacharacters break the command
inp_quoted <- shQuote(inp_arg)

# table body: every line that does not start with '#'
frame <- fread(cmd=paste('grep -v ^#', inp_quoted), data.table=FALSE, header=FALSE,
 colClasses=list(character=c(1,2,4,14,15,18,19,20)))

# normalization method: last comma separated field of the line containing ', norm:'
norm <- names(fread(cmd=paste('grep ", norm:"', inp_quoted), sep=','))
norm_meth <- norm[length(norm)]

names(frame) <- c('TEfam', 'sample_id', 'pos', 'refbase', 'A', 'C', 'G', 'T',
                  'coverage', 'phys_coverage', 'hq_coverage', 'snp', 'refsnp', 'int_del', 'int_del_freq',
                  'truncation_left', 'truncation_right', 'insertion', 'deletion', 'annotation')

frame$TEfam <- as.factor(frame$TEfam)
frame$sample_id <- as.factor(frame$sample_id)

# number of TE families (grid columns) and samples (grid rows)
nfam <- length(levels(frame$TEfam))
nsample <- length(levels(frame$sample_id))


# layout parameters ------------------------------
# ranges of the size scale (internal deletions) and alpha scale (truncations)
sizerange <- c(0.01, 2)
alpharange <- c(0.01, 1)

# page size: 60 per family wide, 20 per sample high
base_w <- nfam * 60
base_h <- nsample * 20

# relative heights of the grid rows: one row per sample plus the annotation row
anno_height <- 0.1 / nsample
base_height <- (1 - anno_height) / nsample
plot_heights <- c(rep(base_height, times=nsample), anno_height)

# relative share of the legend (height) and the shared y-axis label (width)
leg_height <- 0.05 / nsample
yaxs_width <- 0.01 / nfam
# layout parameters ------------------------------


# one table per combination of TE family and sample_id
samples <- split(frame, list(frame$TEfam, frame$sample_id))

# SNP table and base plot of every sample
snp_frames <- lapply(samples, snp_frame)
TEplots <- mapply(base_plot, snps=snp_frames, sample=samples)

# unless the y-axis is free, all samples of one family share the same y-limits
if (!axis_arg){
  fam_fac <- factor(unlist(lapply(samples, function(tab){tab[1,'TEfam']})))
  max_cov <- find_max(samples, 'coverage', fam_fac)

  for (fam_idx in 1:nfam){
    of_fam <- as.numeric(fam_fac) == fam_idx
    upper <- max_cov[fam_idx] + 3
    TEplots[of_fam] <- lapply(TEplots[of_fam], function(p){
      p + scale_y_continuous(limits=c(0, upper)) })
  }
}

# plotting coordinates of all structural features, one entry per sample
int_del_coords <- lapply(samples, int_del_data)
ins_coords <- lapply(samples, insertion_data)
del_coords <- lapply(samples, deletion_data)
trunc_coords <- lapply(samples, truncation_data)

# layer the features onto the plots: arcs, rectangles, dashed segments
TEplots <- mapply(add_curve, p=TEplots, coords=int_del_coords)
TEplots <- mapply(add_rect, p=TEplots, coords=ins_coords)
TEplots <- mapply(add_rect, p=TEplots, coords=del_coords)
TEplots <- mapply(add_segment, p=TEplots, coords=trunc_coords)

# annotation track: built from the first nfam sample tables only, since the
# annotation is the same for every sample; plot_anno() returns a blank plot
# if no annotation was given
anno_plots <- lapply(samples[1:nfam], plot_anno)

# all panels in grid order: the sample plots first, the annotation row last
plot_list <- c(TEplots, anno_plots)

# arrange the panels (without legends)
pdf(NULL)
combo_grob <- suppressWarnings(
    plot_grid(plotlist=plot_list, ncol=nfam, nrow=nsample + 1,
              align='hv', axis='l', rel_heights=plot_heights,
              labels='AUTO', label_size=14))

# put the shared legend on top of the grid
legend <- make_legend()
combo_leg <- plot_grid(legend, combo_grob, ncol=1, nrow=2,
                       rel_heights=c(leg_height, 1 - leg_height))


# add a shared y-axis, tilted label on an empty background
# The normalization is named in the label only if there is a single sample.
# An unknown normalization falls back to the plain label; it used to leave
# 'yaxis' undefined, which aborted the script for single-sample input.
norm_labels <- c('norm: raw'='coverage [raw]',
                 'norm: rpm'='coverage [reads per million]',
                 'norm: scg'='coverage [single copy genes]')
norm_known <- length(norm_meth) == 1 && norm_meth %in% names(norm_labels)
if (!norm_known){
  print('could not identify normalization method')
}

if (nsample > 1 || !norm_known){
  ylabel <- 'coverage'
} else {
  ylabel <- norm_labels[[norm_meth]]
}
yaxis <- ggdraw() + draw_label(ylabel, angle=90, size=fontsize)
combo_y <- plot_grid(yaxis, combo_leg, ncol=2, nrow=1, rel_widths = c(yaxs_width, 1 - yaxs_width))


# write the figure; the default file name is the input name plus extension
if (out_format == 'pdf'){
  if (out_arg == ''){out_arg <- paste0(inp_arg, '.pdf')}

  ggsave(filename=out_arg, plot=combo_y, device='pdf', width=base_w, height=base_h,
         units='cm', limitsize=FALSE)
} else if (out_format == 'eps'){
  if (out_arg == ''){out_arg <- paste0(inp_arg, '.eps')}

  cairo_ps(filename=out_arg, width=base_w/2, height=base_h/2, fallback_resolution=300)
  print(combo_y)
  dev.off()
} else {
  # an unknown format used to write nothing and still report success
  stop(paste0("unsupported --out_format '", out_format, "'; use 'pdf' or 'eps'"))
}

print(paste('plotting completed; written to:', out_arg))

