#!/usr/bin/env python
'''
    BESST - Scaffolder for genomic assemblies 
    Copyright (C) 2013  Kristoffer Sahlin

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.

'''
from __future__ import print_function
import sys
import os
import errno
from time import time
import argparse
import shutil
try:
    # Python 2
    from itertools import izip
except ImportError:
    # Python 3
    izip = zip

from BESST import errorhandle
from BESST import decide_approach

def mkdir_p(path):
    """Create directory *path* including parents, like ``mkdir -p``.

    Silently succeeds if the directory already exists; any other OS error
    is re-raised.
    """
    try:
        os.makedirs(path)
    except OSError as exc:
        # Tolerate only "already exists as a directory"; everything else
        # (permissions, path occupied by a file, ...) propagates.
        if exc.errno != errno.EEXIST or not os.path.isdir(path):
            raise

def ReadInContigseqs(contigfile, contig_filter_length, Information):
    """Read a fasta file-like object into a dict {accession: sequence}.

    The accession is the first whitespace-separated token after '>'. If
    *contig_filter_length* is truthy, contigs shorter than that length are
    removed from the result. Progress messages are printed to *Information*
    (a writable file-like object).

    Fixes over the original:
      * deleting keys while iterating ``cont_dict.keys()`` raised
        RuntimeError on Python 3 -- now iterates a snapshot list;
      * an empty input file no longer creates a spurious {'': ''} entry;
      * sequence lines are joined once instead of quadratic ``+=`` concat.
    """
    cont_dict = {}
    accession = ''
    seq_parts = []
    for line in contigfile:
        if line[0] == '>':
            if accession:
                cont_dict[accession] = ''.join(seq_parts)
            accession = line[1:].strip().split()[0]
            seq_parts = []
        else:
            seq_parts.append(line.strip())
    if accession:  # flush the last record; guards against an empty file
        cont_dict[accession] = ''.join(seq_parts)

    print('Initial number of contigs: {}. '.format(len(cont_dict)), file=Information)

    if contig_filter_length:
        singled_out = 0
        # Snapshot the keys so deletion during iteration is safe on Python 3.
        for contig in list(cont_dict.keys()):
            if len(cont_dict[contig]) < contig_filter_length:
                del cont_dict[contig]
                singled_out += 1
        print('Number of contigs discarded from further analysis (with -filter_contigs set to {0}): {1}'.format(contig_filter_length, singled_out), file=Information)
    return(cont_dict)

def main(options):
    """Run the complete BESST scaffolding pipeline.

    Reads contig sequences, then performs one scaffolding pass per BAM
    library: library metrics are estimated, a contig graph is built, the
    scaffolding algorithm is run and the resulting scaffolds are written
    to disk. After the last pass, N50/L50 statistics are computed.

    Parameters
    ----------
    options : argparse.Namespace
        Parsed command-line options (see the parser defined in this module).
    """
    import pysam

    from BESST import Parameter
    from BESST import CreateGraph as CG
    from BESST import MakeScaffolds as MS
    from BESST import GenerateOutput as GO
    from BESST import libmetrics
    #from BESST import bam_parser
    from BESST.diploid import get_haplotype_regions

    tot_start = time()
    Contigs = {}          # contig dict that stores contig objects
    Scaffolds = {}        # scaffold dict with contig objects for easy fetching of all contigs in a scaffold
    small_contigs = {}
    small_scaffolds = {}

    # Object containing all parameters (user specified, defaulted and
    # computed along the way).
    param = Parameter.parameter()
    param.scaffold_indexer = 1 # global indicator for scaffolds, used to index scaffolds when they are created
    param.multiprocess = options.multiprocess
    param.no_score = options.no_score
    param.score_cutoff = options.score_cutoff
    param.max_extensions = options.max_extensions
    param.NO_ILP = options.NO_ILP
    param.FASTER_ILP = options.FASTER_ILP
    param.dfs_traversal = not options.bfs_traversal  # DFS is the default traversal
    param.print_scores = options.print_scores
    param.min_mapq = options.min_mapq
    param.max_contig_overlap = options.max_contig_overlap
    param.cov_cutoff = options.covcutoff
    param.lower_cov_cutoff = options.lower_covcutoff
    # BUGFIX: was 'args.devel' -- use the function argument, not the
    # module-level global, so main() also works when called programmatically.
    if options.devel:
        h = guppy.hpy()
        initial = h.heap()
        print('Initial:\n{0}'.format(initial))


    param.development = options.devel
    options.output = options.output + '/BESST_output'

    mkdir_p(options.output)
    if options.plots:
        mkdir_p(os.path.join(options.output, 'plots'))

    param.information_file = open(options.output + '/Statistics.txt', 'w')
    param.plots = options.plots
    Information = param.information_file
    # Create/truncate the haplotype output file up front; close the handle
    # right away instead of leaking it (the original left it open).
    open(options.output + '/haplotypes.fa', 'w').close()

    # Read in the sequences of the contigs in memory
    start = time()
    C_dict = ReadInContigseqs(open(options.contigfile, 'r'), options.contig_filter_length, Information)

    elapsed = time() - start
    print('Time elapsed for reading in contig sequences:' + str(elapsed) + '\n', file=Information)

    print('Number of initial contigs:', len(C_dict))
    if options.devel:
        h = guppy.hpy()
        after_ctgs = h.heap()
        print('After ctg seqs:\n{0}'.format(after_ctgs))

    if options.haplotype:
        # Haplotype-detection mode: emit split-allele regions and stop.
        get_haplotype_regions.main(options.output, C_dict, options.kmer, options.mmer, 0.8)
        sys.exit()

    # Iterate over libraries -- one scaffolding pass per BAM file.
    param.first_lib = True
    param.path_threshold = options.path_threshold
    for i in range(len(options.bamfiles)):
        param.pass_number = i + 1
        param.bamfile = options.bamfiles[i]
        # Per-library settings fall back to defaults / inferred values when
        # the corresponding option was not given.
        param.orientation = options.orientation[i] if options.orientation != None else 'fr'
        param.mean_ins_size = options.mean[i] if options.mean != None else None
        param.ins_size_threshold = options.threshold[i] if options.threshold != None else None
        param.edgesupport = options.edgesupport[i] if options.edgesupport != None else None
        param.read_len = options.readlen[i] if options.readlen != None else None
        param.output_directory = options.output
        param.std_dev_ins_size = options.stddev[i] if options.stddev != None else None
        param.contig_threshold = options.minsize[i] if options.minsize != None else None
        param.hapl_ratio = options.haplratio
        param.hapl_threshold = options.haplthreshold
        param.detect_haplotype = options.haplotype
        param.detect_duplicate = options.duplicate
        param.extend_paths = options.extendpaths
        print('\nPASS ' + str(i + 1) + '\n\n', file=Information)

        # Get library statistics
        with pysam.Samfile(param.bamfile, 'rb') as bam_file:
            start = time()
            # Map pysam's numeric reference ids to reference (contig) names.
            param.contig_index = dict(izip(range(len(bam_file.references)), bam_file.references))
            libmetrics.get_metrics(bam_file, param, Information)

            elapsed = time() - start
            print('Time elapsed for getting libmetrics, iteration ' + str(i) + ': ' + str(elapsed) + '\n', file=Information)


            if options.devel:  # BUGFIX: was 'args.devel' (module-level global)
                h = guppy.hpy()
                after_libmetrics = h.heap()
                print('After libmetrics\n{0}\n'.format(after_libmetrics))

            # Create graph; contigs/scaffolds that are too short are singled
            # out and stored in small_contigs/small_scaffolds.
            print('Creating contig graph with library: ', param.bamfile)
            start = time()
            (G, G_prime) = CG.PE(Contigs, Scaffolds, Information, C_dict, param, small_contigs, small_scaffolds, bam_file)
            elapsed = time() - start
            print('Total time for CreateGraph-module, iteration ' + str(i) + ': ' + str(elapsed) + '\n', file=Information)
            decide_approach.decide_scaffolding_procedure(Scaffolds, small_scaffolds, param)

            if options.devel:  # BUGFIX: was 'args.devel' (module-level global)
                h = guppy.hpy()
                after_ctg_graph = h.heap()
                print('After contig graph\n{0}\n'.format(after_ctg_graph))

        if not G and not param.no_score:
            # NOTE(review): this only warns -- execution deliberately falls
            # through to the scaffolding step below (matches prior behavior).
            print('0 contigs/super-contigs passed the length criteria of this step. Exiting and printing results.. ')


        print('Constructed contig graph. Start BESST algorithm for creating scaffolds. ')
        start = time()
        # Make scaffolds; created scaffolds are stored in Scaffolds and
        # Contigs is updated in place.
        MS.Algorithm(G, G_prime, Contigs, small_contigs, Scaffolds, small_scaffolds, Information, param)
        elapsed = time() - start
        print('Time elapsed for making scaffolds, iteration ' + str(i) + ': ' + str(elapsed) + '\n', file=Information)

        print('Writing out scaffolding results for step', i + 1, ' ...')

        # F: list of (ordered) lists of tuples (contig_name, direction,
        # position, length, sequence); each inner list is one scaffold.
        F = []
        for scaffold_ in small_scaffolds:
            S_obj = small_scaffolds[scaffold_]
            list_of_contigs = S_obj.contigs   # list of contig objects contained in scaffold object
            F = GO.WriteToF(F, small_contigs, list_of_contigs)
        for scaffold_ in Scaffolds.keys():
            S_obj = Scaffolds[scaffold_]
            list_of_contigs = S_obj.contigs   # list of contig objects contained in scaffold object
            F = GO.WriteToF(F, Contigs, list_of_contigs)


        GO.PrintOutput(F, Information, param.output_directory, param, i + 1)

        if not options.separate_repeats:
            # Merge the repeats back into this pass's scaffold fasta. Context
            # managers close every handle (the originals were leaked), and the
            # spurious extra .format() argument on repeats.fa was dropped.
            with open("{0}/pass{1}/Scaffolds_pass{1}.fa".format(param.output_directory, str(i + 1)), 'wb') as destination:
                with open("{0}/pass{1}/Scaffolds-pass{1}.fa".format(param.output_directory, str(i + 1)), 'rb') as scaf_src:
                    shutil.copyfileobj(scaf_src, destination)
                with open("{0}/repeats.fa".format(param.output_directory), 'rb') as rep_src:
                    shutil.copyfileobj(rep_src, destination)
            os.remove("{0}/pass{1}/Scaffolds-pass{1}.fa".format(param.output_directory, str(i + 1)))

        all_parameter_settings = param.get_params()
        # BUGFIX: previously 'print(Information, ..., file=Information)' wrote
        # the file object's repr into Statistics.txt.
        print(all_parameter_settings, file=Information)

        param.first_lib = False   # not the first lib any more

    if not options.separate_repeats:
        os.remove("{0}/repeats.fa".format(param.output_directory))

    # Calculate stats for the last scaffolding step.
    scaf_lengths = [Scaffolds[scaffold_].s_length for scaffold_ in Scaffolds.keys()]
    sorted_lengths = sorted(scaf_lengths, reverse=True)
    scaf_lengths_small = [small_scaffolds[scaffold_].s_length for scaffold_ in small_scaffolds.keys()]
    sorted_lengths_small = sorted(scaf_lengths_small, reverse=True)
    N50, L50 = CG.CalculateStats(sorted_lengths, sorted_lengths_small, param, Information)
    param.current_L50 = L50
    param.current_N50 = N50

    elapsed = time() - tot_start
    print('Total time for scaffolding: ' + str(elapsed) + '\n', file=Information)

    print('Finished\n\n ')
    # Flush and release the statistics file handle (previously never closed).
    param.information_file.close()

    return()



# Command-line interface. Arguments are grouped for the --help output:
# required inputs, per-library statistics, ploidy/haplotype handling,
# algorithm parameters, and miscellaneous options.
parser = argparse.ArgumentParser(prog='BESST', description="BESST - scaffolder for genomic assemblies.", epilog='''Source code and manual: http://github.com/ksahlin/BESST\n\nPlease cite:\nSahlin K, Vezzi F, Nystedt B, Lundeberg J, Arvestad L (2014) BESST--efficient scaffolding of large fragmented assemblies. BMC Bioinformatics 15, 281\n''', formatter_class=argparse.RawDescriptionHelpFormatter)

required = parser.add_argument_group('Required arguments')
libstats = parser.add_argument_group('Library information', 'Library parameters that can be set, e.g., read length, insert size mean/std_dev, coverage etc.')
haplotypes = parser.add_argument_group('Ploidy', 'Options involving detection of split allele contigs in diploid assemblies.')
parameters = parser.add_argument_group('Parameters/variables/threshold involved in BESST.')
other = parser.add_argument_group('Various other parameters')
# -y (disable pathfinder) and --no_score (pathfinder only) are mutually
# exclusive: they select opposite scaffolding strategies.
group = parser.add_mutually_exclusive_group(required=False)

## required
required.add_argument("-c", dest="contigfile", type=str, required=True,
                      help="Fasta file containing contigs.")

required.add_argument("-f", dest="bamfiles", type=str, nargs='+', required=True,
                      help="Path(s) to bamfile(s).")

required.add_argument("-orientation", dest="orientation", type=str, nargs='+', required=True,
                      help="Mapped orientation for each library given with -f option. Valid input is either fr (forward reverse orientation) or rf (reverse forward orientation).")


## libstats
# NOTE: the nargs='+' options below are positional per library -- the i-th
# value applies to the i-th bamfile given with -f.

libstats.add_argument("-r", dest="readlen", type=int, nargs='+',
                  help="Mean read length of libraries. ")

libstats.add_argument("-m", dest="mean", type=float, nargs='+',
                  help="Mean insert size of libraries.")

libstats.add_argument("-s", dest="stddev", type=float, nargs='+',
                  help="Estimated standard deviation of libraries.")

libstats.add_argument("-z", dest="covcutoff", type=int, default=None,
                  help="User specified coverage cutoff. (Manually filter \
                  out contigs with coverage over this value)")

libstats.add_argument("-z_min", dest="lower_covcutoff", type=int, default=0,
                  help="User specified coverage cutoff. (Manually filter \
                  out contigs with coverage under this value)")

## parameters

parameters.add_argument("-T", dest="threshold", type=int, nargs='+',
                  help="Threshold value filter out reads that are mapped too far apart given insert size. ")

parameters.add_argument("-e", dest="edgesupport", type=int, nargs='+',
                  help="Threshold value for the least nr of links that is needed to create an edge. Default for all libs: Inferred by BESST (see value in output in Statistics.txt). ")

parameters.add_argument("-k", dest="minsize", type=int, nargs='+',
                    help="Contig size threshold for the library (all contigs below this size is discarded from the \
                    'large contigs' scaffolding, but included in pathfinding). Default: Set to same as -T parameter")

parameters.add_argument("-filter_contigs", dest="contig_filter_length", type=int, default=None,
                    help="Contigs under this size is discarded from all scaffolding (including pathfinding).\
                    they are also not included in the resulting scaffold output fasta file and any\
                    of the other files as well as all statistics. Default: All contigs are included")

parameters.add_argument("--min_mapq", dest="min_mapq", type=int, default=11,
                      help="Lowest mapping quality allowed in order to use the read pair (both reads needs to have equal or bigger mapq than this value). This value is compared to the mapping quality column in the BAM file. ")


parameters.add_argument("--iter", dest="path_threshold", type=int, default=100000,
                  help="The number of iterations performed in breadth first search for placing smaller contigs. ")

parameters.add_argument("--score_cutoff", dest="score_cutoff", type=float, default=1.5,
                  help="Only store paths with score larger than score_cutoff. ")

parameters.add_argument("--max_extensions", dest="max_extensions", type=int,
                  help="Maximum number of (large) scaffolds to search for paths extensions with. ")

## related to scaffolding high heterozygozity assemblies with lots of split haplotype contigs

haplotypes.add_argument("-a", dest="haplratio", type=float, default=1.3,
                  help="Maximum length ratio for merging of haplotypic regions.")

haplotypes.add_argument("-b", dest="haplthreshold", type=int, default=3,
                  help="Number of standard deviations over mean/2 of coverage to allow for clasification of haplotype.\
                   Example: contigs with under mean/2 + 3sigma are indicated as potential haplotypes (tested later) \
                   if -b 3")

other.add_argument("-K", dest="kmer", type=int, default=50,
                  help="k-mer size used in de brujin graph for obtaining contigs in assembly, default 50")

other.add_argument("-M", dest="mmer", type=int, default=40,
                  help="m-mer usted for creating connection graph. Should be set lower than k-mer size ")

haplotypes.add_argument("-g", dest="haplotype", action="store_true",
                  help="Haplotype detection function, default = off")


## other


other.add_argument("-o", dest="output", type=str, default='.',
                  help="Path to output directory. BESST will create a folder named 'BESST_output' in the directory given by the path.")

other.add_argument("-d", dest='duplicate' , action="store_false",
                  help="Deactivate sequencing duplicates detection")

group.add_argument("-y", dest='extendpaths', action="store_false",
                  help="Deactivate pathfinder module for including smaller contigs.")

other.add_argument("-q", dest="multiprocess", action="store_true",
                help="Parallellize work load of path finder module in case of multiple processors available.\
                  ",)

group.add_argument("--no_score", dest="no_score", action="store_true",
                help="Statistical scoring is not performed. BESST instead searches for paths between contigs.\
                  ",)

other.add_argument("-devel", action="store_true",
                  help="Run in development mode (bug checking and memory usage etc.)")

other.add_argument("-plots", action="store_true",
                  help="Plot coverage, histograms of scores e.t.c.")

other.add_argument("--separate_repeats", dest="separate_repeats", action="store_true",
                  help="Separates contigs classified as repeats by BESST into a file 'repeats.fa'. They are \
                  not included in the main scaffolding output file with this flag specified.")

parameters.add_argument("--NO_ILP", dest="NO_ILP", action="store_true",
                  help="Warning: Do not use this option! This option is included only for benchmarking purposes of old BESST algorithm. This option will give poor results as it mimics earlier versions of BESST.")

parameters.add_argument("--FASTER_ILP", dest="FASTER_ILP", action="store_true",
                  help="Faster but worse performing heuristic solution to solving ILPs. May be used if BESST is too slow. However, lowering --iter is usually more effective to reduce scaffolding time.")

parameters.add_argument("--print_scores", dest="print_scores", action="store_true",
                  help="Print BESST scores on edges in the Scaffolding graph.")

other.add_argument("--dfs_traversal", dest="dfs_traversal", action="store_true",
                help="Depth first search with DP in the contig graph (default).")

other.add_argument("--bfs_traversal", dest="bfs_traversal", action="store_true",
                help="Choose a breadth first search in the contig graph. Default is the new depth first search with a DP approch that seems to outperform\
                previous traversals. This option is kept for evaluation but is not reccomended to specify.")

other.add_argument("-max_contig_overlap", dest="max_contig_overlap", type=int, default=200,
                  help="BESST checks for overlapping ends in contigs that are adjacent in a scaffold. This parameter sets the maximum identical overlap to search for, default is 200. ")


other.add_argument('--version', action='version', version='%(prog)s 2.2.8')




# If invoked with no arguments at all, show the help text instead of an
# argparse "missing required argument" error.
if len(sys.argv) == 1:
    sys.argv.append('-h')

args = parser.parse_args()

# Verify that the 'mathstats' dependency is importable and that the parsed
# arguments are mutually consistent before starting any real work.
errorhandle.check_module('mathstats')
errorhandle.parse_check(args, parser)

# guppy is only needed (and only imported) for -devel memory profiling.
if args.devel:
    import guppy

if args.plots:
    try:
        import matplotlib
        matplotlib.use('Agg')  # headless backend: no display required
        import matplotlib.pyplot as plt
    except ImportError:
        # Plotting is best-effort: silently disabled if matplotlib is absent.
        args.plots = False


##
# Call BESST algorithm

main(args)

# In development mode, report the heap state after the full run.
if args.devel:
    h = guppy.hpy()
    at_end= h.heap()

    print('At end:\n{0}\n'.format(at_end))



