#!python

import argparse
import datetime
import shutil
import sys

import pysam

import deviaTE.deviaTE_pileup as pileup
from deviaTE.deviaTE_IO import get_data, execute


# Command-line interface: --input and --family are mandatory, every other
# option may be omitted.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--input',
    required=True,
    help='alignment file to be analysed',
)
parser.add_argument(
    '--family',
    required=True,
    help='TE family to analyse',
)
parser.add_argument(
    '--library',
    help='path to alternative reference sequence library',
)
parser.add_argument(
    '--output',
    help='name of output table',
)
parser.add_argument(
    '--sample_id',
    help='sample identifier',
)
parser.add_argument(
    '--annotation',
    help='alternative annotation in gff-format',
)
parser.add_argument(
    '--no_freq_corr',
    action='store_true',
    help='deactivate frequency correction for internal deletions',
)
parser.add_argument(
    '--hq_threshold',
    type=int,
    default=20,
    help='mapping quality threshold to consider as high quality',
)

# --rpm and --single_copy_genes are registered in a mutually exclusive
# group, so argparse rejects a command line that passes both.
normalization = parser.add_mutually_exclusive_group()
normalization.add_argument(
    '--rpm',
    action='store_true',
    help='normalize all abundances by reads per million',
)
normalization.add_argument(
    '--single_copy_genes',
    help='comma-separated names of single-copy genes in reference',
)

args = parser.parse_args()

# Fill in options the user left out. Order matters: the default output name
# is derived from the (possibly defaulted) sample id.
if args.sample_id is None:
    args.sample_id = args.input

if args.output is None:
    args.output = '{}.{}'.format(args.sample_id, args.family)

# Library and annotation fall back to resources looked up through get_data();
# get_data() is only called for an option that was actually omitted.
for option, resource in (('library', 'lib/te_library'),
                         ('annotation', 'lib/te_features.gff')):
    if getattr(args, option) is None:
        setattr(args, option, get_data(resource))

# Announce the run and create the sample object for the requested family.
print('Starting analysis of: ' + args.family + ' in: ' + args.input)
sample = pileup.Sample(name=args.sample_id, fam=args.family, lib=args.library,
                       anno=args.annotation, bam=args.input)

# get reference of family; a refseq of None is treated as "family not found"
# and aborts the run
sample.get_ref_multi()
if sample.refseq is None:
    raise ValueError('Selected family is not in references. Typo? ' + args.family)

# get annotation for family; a missing annotation is only reported, not fatal
sample.get_anno()
# Compare the length by value. `is 0` tests object identity, which only works
# by accident of CPython's small-int cache and triggers a SyntaxWarning on
# Python 3.8+. An explicit length check is kept (rather than `not fam_anno`)
# because the container type of fam_anno is not visible here.
if len(sample.fam_anno) == 0:
    print('no annotaions found for: ' + args.family)

# fill sites of sample: one Site per base of the reference, numbered from 0
for position, refbase in enumerate(sample.refseq):
    site = pileup.Site(pos=position, refbase=refbase,
                       sid=args.sample_id, fam=args.family)
    sample.sites.append(site)

# perform pileup
sample.perform_pileup(hq_threshold=args.hq_threshold)

# sum coverage per site, then average it over all sites of the family
for s in sample.sites:
    s.sum_coverage()
cov_list = [s.cov for s in sample.sites]
if not cov_list:
    # Without any site the mean below would divide by zero; fail with a
    # message that names the cause instead of a bare ZeroDivisionError.
    raise ValueError('Reference sequence of selected family is empty: ' + args.family)
mean_cov = sum(cov_list) / len(cov_list)

# All count thresholds are fractions of the mean coverage and do not change
# between sites, so compute them once.
snp_min_count = mean_cov * 0.1
int_del_min_count = mean_cov * 0.02
indel_min_count = mean_cov * 0.1
trunc_min_count = mean_cov * 0.05

# Per-site calls: SNP detection, indel/truncation filtering and annotation.
for s in sample.sites:
    s.is_snp(min_count=snp_min_count, min_freq=0.1, A=s.A, C=s.C, G=s.G, T=s.T, cov=s.cov)
    s.filter_IND(att='int_del', min_count=int_del_min_count)
    s.filter_IND(att='ins', min_count=indel_min_count)
    s.filter_IND(att='delet', min_count=indel_min_count)
    s.filter_trunc(min_trunc_count=trunc_min_count)
    s.check_annotation(anno=sample.fam_anno)


# collect int dels and add phys cov to base cov
sample.collect_int_dels()
sample.calc_phys_cov()

# estimate freqs., correct and parse new col
# The correction factor is derived from the mean read length; with
# --no_freq_corr a neutral factor of 1 is passed instead.
read_length = sample.mean_read_length()
correction = pileup.correction_factor(read_length)
applied_factor = 1 if args.no_freq_corr else correction
for deletion in sample.int_dels:
    deletion.est_freq(sites=sample.sites, corr_factor=applied_factor)
    deletion.write_freq(sites=sample.sites)


# write raw results
# The raw table is always written to '<output>.raw' first. ihat stays 'NA'
# unless the single-copy-gene normalization below replaces it.
ihat = 'NA'
comm = ' '.join(sys.argv)
timestamp = datetime.datetime.now().strftime("%y-%m-%d_%H:%M:%S")
sample.write_frame(out=args.output + '.raw', insertions=ihat, command=comm, t=timestamp, norm='raw')

if args.rpm:
    # normalize by rpm
    print('Normalization: reads per million')
    rpm_fac = sample.get_norm_fac_rpm()
    for s in sample.sites:
        s.normalize(norm_factor=rpm_fac)
    sample.write_frame(out=args.output, insertions=ihat, command=comm, t=timestamp, norm='rpm')

elif args.single_copy_genes is not None:
    # normalize with set of single copy genes; the insertion estimate is
    # requested before the sites are normalized
    print('Normalization: single copy genes')
    scg_fac = sample.get_norm_fac_scg(genes=args.single_copy_genes)
    ihat = sample.estimate_insertions(norm_factor=scg_fac)
    for s in sample.sites:
        s.normalize(norm_factor=scg_fac)
    sample.write_frame(out=args.output, insertions=ihat, command=comm, t=timestamp, norm='scg')

else:
    # No normalization requested: the raw table becomes the final output.
    # Rename it in-process instead of assembling an unquoted `mv` shell
    # string, which breaks on paths with spaces or shell metacharacters.
    print("Normalization: none (values are raw abundances)")
    shutil.move(args.output + '.raw', args.output)

print('analysis completed; table written to: ' + args.output)
