#!/usr/bin/env python
'''
    BESST - Scaffolder for genomic assemblies 
    Copyright (C) 2013  Kristoffer Sahlin

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.

'''

import sys
import os
import errno
from time import time
import argparse
import shutil
from itertools import izip

from BESST import errorhandle
from BESST import decide_approach

def mkdir_p(path):
    """Create ``path`` and any missing parent directories (like ``mkdir -p``).

    An already existing directory at ``path`` is not an error. Any other
    ``OSError`` (for example when ``path`` exists but is a regular file) is
    re-raised unchanged.
    """
    try:
        os.makedirs(path)
    except OSError as err:
        # Only swallow the "already exists" error, and only for directories.
        already_a_directory = err.errno == errno.EEXIST and os.path.isdir(path)
        if not already_a_directory:
            raise

def ReadInContigseqs(contigfile,contig_filter_length, Information):
    """Read FASTA records from ``contigfile`` into a dict of sequences.

    Args:
        contigfile: iterable of text lines (e.g. an open FASTA file).
        contig_filter_length: if truthy, contigs whose sequence is shorter
            than this length are removed from the result and the number of
            removed contigs is reported to ``Information``.
        Information: writable file-like object that receives the report line.

    Returns:
        dict mapping accession (the first whitespace separated token after
        '>') to the concatenated sequence with line breaks removed. If an
        accession occurs more than once, the last record wins. An input
        without any header line yields an empty dict.

    Raises:
        IndexError: if a header line contains nothing after '>'.
    """
    cont_dict = {}
    accession = None   # None until the first header line has been seen
    chunks = []        # sequence lines of the current record, joined once per record
    for line in contigfile:
        if line.startswith('>'):
            if accession is not None:
                cont_dict[accession] = ''.join(chunks)
            accession = line[1:].strip().split()[0]
            chunks = []
        elif accession is not None:
            chunks.append(line.strip())
        # Text before the first header belongs to no record and is ignored.
    if accession is not None:
        cont_dict[accession] = ''.join(chunks)

    if contig_filter_length:
        # Collect the names first so the dict is never modified while it is
        # being iterated.
        too_short = [name for name, seq in cont_dict.items() if len(seq) < contig_filter_length]
        for name in too_short:
            del cont_dict[name]
        # write() plus an explicit newline emits exactly what the print
        # statement with a chevron redirect would.
        Information.write('Number of contigs discarded from further analysis (with -filter_contigs set to {0}): {1}'.format(contig_filter_length, len(too_short)) + '\n')
    return(cont_dict)

def main(options):
    import pysam

    from BESST import Parameter
    from BESST import CreateGraph as CG
    from BESST import MakeScaffolds as MS
    from BESST import GenerateOutput as GO
    from BESST import libmetrics
    #from BESST import bam_parser
    from BESST.diploid import get_haplotype_regions

    tot_start = time()
    Contigs = {} # contig dict that stores contig objects 
    Scaffolds = {}     # scaffold dict with contig objects for easy fetching of all contigs in a scaffold
    small_contigs = {}
    small_scaffolds = {}

    param = Parameter.parameter() # object containing all parameters (user specified, defaulted and comuted along the way.)
    param.scaffold_indexer = 1 # global indicator for scaffolds, used to index scaffolds when they are created
    param.rel_weight = options.relweight
    param.multiprocess = options.multiprocess
    param.no_score = options.no_score
    param.score_cutoff = options.score_cutoff
    param.max_extensions = options.max_extensions

    if args.devel:
        h = guppy.hpy()
        initial = h.heap()
        print 'Initial:\n{0}'.format(initial)


    param.development = options.devel
    options.output = options.output + '/BESST_output'

    mkdir_p(options.output)
    if options.plots:
        mkdir_p(os.path.join(options.output, 'plots'))

    param.information_file = open(options.output + '/Statistics.txt', 'w')
    param.plots = options.plots
    Information = param.information_file
    open(options.output + '/haplotypes.fa', 'w')
    #Read in the sequences of the contigs in memory
    start = time()
    C_dict = ReadInContigseqs(open(options.contigfile, 'r'), options.contig_filter_length, Information)
    elapsed = time() - start
    print >> Information, 'Time elapsed for reading in contig sequences:' + str(elapsed) + '\n'

    print 'Number of initial contigs:', len(C_dict)
    if options.devel:
        h = guppy.hpy()
        after_ctgs = h.heap()
        print 'After ctg seqs:\n{0}'.format(after_ctgs)

    if options.haplotype:
        get_haplotype_regions.main(options.output, C_dict, options.kmer, options.mmer, 0.8)
        sys.exit()
    #iterate over libraries
    param.first_lib = True
    param.path_threshold = options.path_threshold
    for i in xrange(len(options.bamfiles)):
        param.bamfile = options.bamfiles[i]
        param.orientation = options.orientation[i] if options.orientation != None else 'fr'
        param.mean_ins_size = options.mean[i] if options.mean != None else None
        param.ins_size_threshold = options.threshold[i] if options.threshold != None else None
        param.edgesupport = options.edgesupport[i] if options.edgesupport != None else 5
        param.read_len = options.readlen[i] if options.readlen != None else None
        param.output_directory = options.output
        param.std_dev_ins_size = options.stddev[i] if options.stddev != None else None
        param.contig_threshold = options.minsize[i] if options.minsize != None else None
        param.cov_cutoff = options.covcutoff[i] if options.covcutoff != None else None
        param.hapl_ratio = options.haplratio
        param.hapl_threshold = options.haplthreshold
        param.detect_haplotype = options.haplotype
        param.detect_duplicate = options.duplicate
        param.fosmidpool = options.fosmidpool
        param.extend_paths = options.extendpaths
        print >> Information, '\nPASS ' + str(i + 1) + '\n\n'

        #get library statistics
        with pysam.Samfile(param.bamfile, 'rb') as bam_file:
            #if param.first_lib:
            #print bam_file.references
            #print map(bam_file.gettid, bam_file.references)
            start = time()
            param.contig_index = dict(izip(xrange(len(bam_file.references)), bam_file.references))  
            libmetrics.get_metrics(bam_file, param, Information)
            
            elapsed = time() - start
            print >> Information, 'Time elapsed for getting libmetrics, iteration ' + str(i) + ': ' + str(elapsed) + '\n'


            if args.devel:
                h = guppy.hpy()
                after_libmetrics = h.heap()
                print 'After libmetrics\n{0}\n'.format(after_libmetrics)

            #create graph
            print 'Creating contig graph with library: ', param.bamfile
            start = time()
            (G, G_prime) = CG.PE(Contigs, Scaffolds, Information, C_dict, param, small_contigs, small_scaffolds, bam_file)      #Create graph, single out too short contigs/scaffolds and store them in F
            elapsed = time() - start
            print >> Information, 'Total time for CreateGraph-module, iteration ' + str(i) + ': ' + str(elapsed) + '\n'
            decide_approach.decide_scaffolding_procedure(Scaffolds,small_scaffolds, param)

            if args.devel:
                h = guppy.hpy()
                after_ctg_graph = h.heap()
                print 'After contig graph\n{0}\n'.format(after_ctg_graph)

        #print G, not G == None, not None, None == False
        if not G and not param.no_score:
            print '0 contigs/super-contigs passed the length criteria of this step. Exiting and printing results.. '
            

        print 'Constructed contig graph. Start BESST algorithm for creating scaffolds. '
        start = time()
        MS.Algorithm(G, G_prime, Contigs, small_contigs, Scaffolds, small_scaffolds, Information, param)   # Make scaffolds, store the complex areas (consisting of contig/scaffold) in F, store the created scaffolds in Scaffolds, update Contigs
        elapsed = time() - start
        print >> Information, 'Time elapsed for making scaffolds, iteration ' + str(i) + ': ' + str(elapsed) + '\n'

        print 'Writing out scaffolding results for step', i + 1, ' ...'

        F = [] #list of (ordered) lists of tuples containing (contig_name, direction, position, length, sequence). The tuple is a contig within a scaffold and the list of tuples is the scaffold.
        for scaffold_ in small_scaffolds:
            S_obj = small_scaffolds[scaffold_]
            list_of_contigs = S_obj.contigs   #list of contig objects contained in scaffold object
            F = GO.WriteToF(F, small_contigs, list_of_contigs)
        for scaffold_ in Scaffolds.keys(): #iterate over keys in hash, so that we can remove keys while iterating over it
            ###  Go to function and print to F
            ### Remove Scaf_obj from Scaffolds and Contig_obj from contigs
            S_obj = Scaffolds[scaffold_]
            list_of_contigs = S_obj.contigs   # List of contig objects contained in scaffold object
            F = GO.WriteToF(F, Contigs, list_of_contigs)


        GO.PrintOutput(F, Information, param.output_directory, param, i + 1)

        if not options.separate_repeats:
            destination = open("{0}/pass{1}/Scaffolds_pass{1}.fa".format(param.output_directory, str(i+1) ), 'wb')
            #destination = open(outfile,'wb')
            shutil.copyfileobj(open("{0}/pass{1}/Scaffolds-pass{1}.fa".format(param.output_directory, str(i+1) ), 'rb'), destination)
            shutil.copyfileobj(open("{0}/repeats.fa".format(param.output_directory, str(i+1) ), 'rb'), destination)
            destination.close()
            os.remove( "{0}/pass{1}/Scaffolds-pass{1}.fa".format(param.output_directory, str(i+1)) )
        
        param.first_lib = False   #not the first lib any more    

    if not options.separate_repeats:
        os.remove( "{0}/repeats.fa".format(param.output_directory ) )

    ### Calculate stats for last scaffolding step    
    scaf_lengths = [Scaffolds[scaffold_].s_length for scaffold_ in Scaffolds.keys()]
    sorted_lengths = sorted(scaf_lengths, reverse=True)
    scaf_lengths_small = [small_scaffolds[scaffold_].s_length for scaffold_ in small_scaffolds.keys()]
    sorted_lengths_small = sorted(scaf_lengths_small, reverse=True)
    N50, L50 = CG.CalculateStats(sorted_lengths, sorted_lengths_small, param, Information)
    param.current_L50 = L50
    param.current_N50 = N50

    elapsed = time() - tot_start
    print >> Information, 'Total time for scaffolding: ' + str(elapsed) + '\n'

    print 'Finished\n\n '

    return()



# Command line interface of BESST. Every option is attached to one of the
# argument groups below; the groups only control how the options are grouped
# in the --help output.
parser = argparse.ArgumentParser(prog='BESST', description="BESST - scaffolder for genomic assemblies.")

required = parser.add_argument_group('required', 'Required arguments')
libstats = parser.add_argument_group('read_library', 'Library parameters that can be set, e.g., read length, insert size mean/std_dev, coverage etc.')
haplotypes = parser.add_argument_group('diploid', 'Options involving detection of split allele contigs in diploid assemblies.')
parameters = parser.add_argument_group('parameters', 'Parameters/variables/threshold involved in BESST.')
other = parser.add_argument_group('other', 'Various other parameters')
# Options added to this group (-y and --no_score) cannot be combined.
group = parser.add_mutually_exclusive_group(required=False)

## required
required.add_argument("-c", dest="contigfile", type=str, required=True,
                      help="Fasta file containing contigs.")

required.add_argument("-f", dest="bamfiles", type=str, nargs='+', required=True,
                      help="Path(s) to bamfile(s).")


## libstats
# These options take one or more values (nargs='+'); main() reads value i for
# library i.

libstats.add_argument("-r", dest="readlen", type=int, nargs='+',
                  help="Mean read length of libraries. ")

libstats.add_argument("-m", dest="mean", type=float, nargs='+',
                  help="Mean insert size of libraries.")

libstats.add_argument("-s", dest="stddev", type=float, nargs='+',
                  help="Estimated standard deviation of libraries.")

libstats.add_argument("-z", dest="covcutoff", type=int, nargs='+',
                  help="User specified coverage cutoff. (Manually filter \
                  out contigs with coverage over this value)")


## parameters

parameters.add_argument("-w", dest="relweight", type=int, default=3,
                  help="Threshold value for the relative weight of an edge (deprecated)")

parameters.add_argument("-T", dest="threshold", type=int, nargs='+',
                  help="Threshold value filter out reads that are mapped too far apart given insert size. ")

# No argparse default here: main() falls back to 5 when -e is not given.
parameters.add_argument("-e", dest="edgesupport", type=int, nargs='+',
                  help="Threshold value for the least nr of links that is needed to create an edge. Default for all libs: 5 ")

parameters.add_argument("-k", dest="minsize", type=int, nargs='+',
                    help="Contig size threshold for the library (all contigs below this size is discarded from the \
                    'large contigs' scaffolding, but included in pathfinding). Default: Set to same as -T parameter")

parameters.add_argument("-filter_contigs", dest="contig_filter_length", type=int, default=None,
                    help="Contigs under this size is discarded from all scaffolding (including pathfinding).\
                    they are also not included in the resulting scaffold output fasta file and any\
                    of the other files as well as all statistics. Default: All contigs are included")


parameters.add_argument("--iter", dest="path_threshold", type=int, default=100000,
                  help="The number of iterations performed in breadth first search for placing smaller contigs. ")

parameters.add_argument("--score_cutoff", dest="score_cutoff", type=float, default=1.5,
                  help="Only store paths with score larger than score_cutoff. ")

parameters.add_argument("--max_extensions", dest="max_extensions", type=int,
                  help="Maximum number of (large) scaffolds to search for paths extensions with. ")


## haplotypes

haplotypes.add_argument("-a", dest="haplratio", type=float, default=1.3,
                  help="Maximum length ratio for merging of haplotypic regions.")

haplotypes.add_argument("-b", dest="haplthreshold", type=int, default=3,
                  help="Number of standard deviations over mean/2 of coverage to allow for clasification of haplotype.\
                   Example: contigs with under mean/2 + 3sigma are indicated as potential haplotypes (tested later) \
                   if -b 3")

haplotypes.add_argument("-g", dest="haplotype", action="store_true",
                  help="Haplotype detection function, default = off")

## other

# Shown in the 'required' group of the help text, but not marked with
# required=True; main() uses 'fr' when the option is omitted.
required.add_argument("--orientation", dest="orientation", type=str, nargs='+',
                      help="Mapped orientation for each library given with -f option. Valid input is either fr (forward reverse orientation) or rf (reverse forward orientation).")

other.add_argument("-o", dest="output", type=str, default='.',
                  help="Path to output directory. BESST will create a folder named 'BESST_output' in the directory given by the path.")

# store_false: 'duplicate' is True unless -d is given.
other.add_argument("-d", dest='duplicate' , action="store_false",
                  help="Deactivate sequencing duplicates detection")

other.add_argument("-K", dest="kmer", type=int, default=50,
                  help="k-mer size used in de brujin graph for obtaining contigs in assembly, default 50")

other.add_argument("-M", dest="mmer", type=int, default=40,
                  help="m-mer usted for creating connection graph. Should be set lower than k-mer size ")

# store_false: 'extendpaths' is True unless -y is given. Mutually exclusive
# with --no_score.
group.add_argument("-y", dest='extendpaths', action="store_false",
                  help="Deactivate pathfinder module for including smaller contigs.")

other.add_argument("-q", dest="multiprocess", action="store_true",
                help="Parallellize work load of path finder module in case of multiple processors available.\
                  ",)

# Mutually exclusive with -y.
group.add_argument("--no_score", dest="no_score", action="store_true",
                help="Statistical scoring is not performed. BESST instead searches for paths between contigs.\
                  ",)

# No explicit dest: the value is available as args.devel.
other.add_argument("-devel", action="store_true",
                  help="Run in development mode (bug checking and memory usage etc.)")

# No explicit dest: the value is available as args.plots.
other.add_argument("-plots", action="store_true",
                  help="Plot coverage, histograms of scores e.t.c.")

other.add_argument("--separate_repeats", dest="separate_repeats", action="store_true",
                  help="Separates contigs classified as repeats by BESST into a file 'repeats.fa'. They are \
                  not included in the main scaffolding output file with this flag specified.")

other.add_argument('--version', action='version', version='%(prog)s 1.3.3')




#TEMPORARY parameters here, remove after spruce assembly
# NOTE(review): 'transcriptalignfile' is not read anywhere in the visible
# part of this file -- confirm whether it is still needed.
other.add_argument("-t", dest="transcriptalignfile",
                  help="file of contigs", type=str)

other.add_argument("-p", dest="fosmidpool", default=None,
                  help="""Specify that data comes from a fosmid pool. This parameter sets the number of
                  links that the second strongest link must have in order to not do any scaffolding
                  with the region.""", type=int)


if len(sys.argv) == 1:
    sys.argv.append('-h')

args = parser.parse_args()

errorhandle.check_module('mathstats')
errorhandle.parse_check(args, parser)

if args.devel:
    import guppy

if args.plots:
    try:
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt
    except ImportError:
        args.plots = False


##
# Call BESST algorithm

main(args)

if args.devel:
    h = guppy.hpy()
    at_end= h.heap()

    print 'At end:\n{0}\n'.format(at_end)



