#!python
# -*- coding: utf-8 -*-
"""
Created on Mon Sep 26 12:57:25 2022

@author: evanqu
"""
import argparse
import os
import sys

# sys.set_int_max_str_digits(0)
# So I don't get weird ValueError: Exceeds the limit (4300) for integer string conversion

import phlame.classify as classify
import phlame.tree as tree
import phlame.makedb as makedb
import phlame.plot as plot
import phlame.countsCMT as CMT

#%%

# sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname('phlame.py'), '../phlame')))

def print_help():
    """Print the top-level PHLAME help message listing every operation.

    Called instead of argparse's generated help when ``phlame`` is run with
    no arguments or with ``-h``/``--help`` as the first argument.
    """
    print('')
    print('            ...::: PHLAME v1.0 :::...')
    print('       Evan Qu, Lieberman Lab, MIT. 2025\n')
    # Keep this list in sync with the subparsers registered in the CLI.
    print('Usage: phlame [-h] {classify,makedb,tree,plot,cmt,counts} ...\n')
    print('''
Choose one of the operations below for more detailed help.
Example: phlame classify -h

Main operations:
    classify ->  Estimate metagenomic clade-level frequencies using a PHLAME database.
    makedb   ->  Create a PHLAME database.
    tree     ->  Create a phylogenetic tree for a PHLAME database.

Auxiliary operations:
    plot -> Generate informative plots from classify output.
    cmt -> Convert counts files into a candidate mutation table.
    counts -> Convert aligned pileup files into compressed counts matrix format.
            ''')

if __name__ == '__main__':

    # Main help message
    if len(sys.argv) == 1 or sys.argv[1] == '-h' or sys.argv[1] == '--help':
        print_help()
        sys.exit(0)

    
    parser = argparse.ArgumentParser(
                    prog = 'phlame',
                    description = 'What the program does',
                    epilog = 'Text at the bottom of help',
                    add_help = False)
    
    subparsers = parser.add_subparsers(help='Desired operation',dest='operation')
    
    classify_op = subparsers.add_parser('classify', help='Estimate metagenomic clade-level frequencies using a PHLAME database.')
    makedb_op = subparsers.add_parser('makedb', help='Create a PHLAME database.')
    tree_op = subparsers.add_parser('tree', help='Create a phylogenetic tree or .phylip file.')
    plot_op = subparsers.add_parser('plot', help='Generate informative plots from classify output.')
    cmt_op = subparsers.add_parser('cmt', help='Convert counts files into a candidate mutation table.')
    counts_op = subparsers.add_parser('counts', help='Convert aligned pileup files into a compressed counts file.')


    # Classify arguments
    classify_op.add_argument('-i', dest='input', type=str, required=True, 
                          help='Path to input bam file.')
    classify_op.add_argument('-c', dest='classifier', type=str, required=True,
                          help='Path to classifer file.')
    classify_op.add_argument('-r', dest='ref', type=str, required=True,
                            help='Path to reference genome (.fasta).')
    classify_op.add_argument('-l', dest='level', type=str, default=False, required=False,
                          help='Level specification.')
    classify_op.add_argument('-o', dest='output', type=str, required=True,
                          help='Path to output frequencies file (.csv)')
    classify_op.add_argument('-p', dest='outputdata', type=str, default=False, required=False,
                          help='Path to output data file (.pickle.gz)')
    classify_op.add_argument('-m', choices=['bayesian', 'mle'], required=True,
                             help="Inference algorithm to use (default mle).", default="mle")
    classify_op.add_argument('--max_pi', type=float, default=0.3, required=False,
                          help='Maximum pi value to count a lineage as present.')
    classify_op.add_argument('--min_snps', type=int, default=10, required=False,
                          help='Minimum number of present mutations to count a lineage as present.')
    classify_op.add_argument('--min_prob', type=float, default=0.5, required=False,
                          help='Bayesian only: Minimum probability score to count a lineage as present.')
    classify_op.add_argument('--min_hpd', type=int, default=0.15, required=False,
                          help='Bayesian only: Minimum value the highest posterior density interval over divergence must cover to count a lineage as present.')
    classify_op.add_argument('--seed', type=int, required=False, default=False,
                          help='Set random seed for reproducibility.') 
    classify_op.add_argument('--verbose', required=False,  action='store_true', default=True,
                            help='Print progress messages.')

    # Make_classifier arguments
    makedb_op.add_argument('-i', dest='input', type=str, required=True, 
                          help='Path to input candidate mutation table.')
    makedb_op.add_argument('-t', dest='intree', type=str, required=True,
                          help='Path to input phylogeny (Newick format).')
    
    makedb_op.add_argument('-o', dest='outdb', type=str, required=True,
                          help='Path to output classifier.')
    makedb_op.add_argument('-p', dest='outclades', type=str, required=True,
                          help='Path to output clades file.')
    makedb_op.add_argument('-y', dest='outtree', type=str, required=False, default=None,
                            help='Path to output tree labelled with clade names.')
    
    makedb_op.add_argument('-c', dest='inclades', type=str, required=False, default=False,
                        help='Manually define clades to be included in the classifier.')

    makedb_op.add_argument('--outgroup', type=float, default=False, required=False,
                          help='Specify a comma separated list of samples to be considered outgroups.')
    makedb_op.add_argument('--max_outgroup', type=float, default=False, required=False,
                          help='Maximum number of outgroup samples a position can be found in (non-N) to be considered.')

    makedb_op.add_argument('--min_snps', type=int, default=10, required=False,
                          help='Minimum number of SNPs to include a candidate clade')
    makedb_op.add_argument('--maxn', type=float, default=0.1, required=False,
                          help='Maximum percentage of Ns for a position to be considered')
    makedb_op.add_argument('--core', type=float, default=0.9, required=False,
                          help='Minimum percent of samples a position must be found in (non-N) to be considered.')


    makedb_op.add_argument('--min_branchlen', type=float, default=100, required=False,
                          help='Minimum branch length leading up to a clade.')
    makedb_op.add_argument('--min_leaves', type=int, default=3, required=False,
                          help='Minimum number of samples in a clade.')
    makedb_op.add_argument('--min_support', type=float, default=0.75, required=False,
                          help='Minimum bootstrap support for a clade.')
    
    makedb_op.add_argument('--minAF', type=float, default=0.75, required=False,
                          help='Minimum allele frequency to make a base call.')
    makedb_op.add_argument('--min_strand_cov', type=int, default=2, required=False,
                            help='Minimum per-strand coverage across a position to make a base call.')
    makedb_op.add_argument('--qual', type=int, default=-30, required=False,
                            help='Minimum MAPQ score to make a base call.')
    makedb_op.add_argument('--max_frac_ambiguous', type=float, default=0.5, required=False,
                            help='Maximum fraction of ambiguous (N) calls to include a sample.')
    makedb_op.add_argument('--midpoint', required=False,  action='store_true', default=False,
                            help='Root the input tree at the midpoint.')

    # Tree arguments
    tree_op.add_argument('-i', dest='cmt', type=str, required=False, default=None, 
                          help='Path to input candidate mutation table.')
    tree_op.add_argument('-o', dest='outtree', type=str, required=False, default=None,
                          help='Path to output tree.')
    tree_op.add_argument('-p', dest='outphylip', type=str, required=False, default=None,
                          help='Path to output phylip file.')
    tree_op.add_argument('-r', dest='renaming', type=str, required=False, default=None,
                          help='Path to phylip renaming file.')
    
    tree_op.add_argument('-q', dest='inphylip', type=str, required=False, default=None,
                          help='Path to existing phylip file (this will just run RaXML).')
    
    tree_op.add_argument('--rescale', required=False,  action='store_true', default=True,
                         help='Rescale tree branch lengths into numbers of SNVs.')
    
    tree_op.add_argument('--min_cov', type=int, default=10, required=False,
                        help='Minimum coverage across a position to make a base call.')
    tree_op.add_argument('--minAF', type=float, default=0.75, required=False,
                        help='Minimum allele frequency to make a base call.')
    tree_op.add_argument('--min_strand_cov', type=int, default=2, required=False,
                        help='Minimum per-strand coverage across a position to make a base call.')
    tree_op.add_argument('--qual', type=int, default=-30, required=False,
                        help='Minimum MAPQ score to make a base call.')
    tree_op.add_argument('--core', type=float, default=0.9, required=False,
                        help='Minimum percent of samples a position must be found in (non-N) to be considered.')
    tree_op.add_argument('--min_cov_position', type=int, default=3, required=False,
                        help='Minimum median coverage of position across samples to include.')
    tree_op.add_argument('--max_frac_ambiguous', type=float, default=0.1, required=False,
                        help='Maximum fraction of ambiguous (N) calls to include a sample.')
    tree_op.add_argument('--copynum', type=float, default=2.5, required=False,
                        help='Maximum average copy number to include a position.')
    
    tree_op.add_argument('--remov_recomb', required=False,  action='store_true', default=False,
                        help='Remove recombination events from the tree.')
    tree_op.add_argument('--outgroup', type=float, default=None, required=False,
                          help='Specify a comma separated list of samples to be considered outgroups; i.e. not included in tree building.')

    # Plot arguments
    plot_op.add_argument('-f', dest='infreq', type=str, required=True,
                        help='Path to input frequencies file.')
    plot_op.add_argument('-d', dest='indata', type=str, required=True,
                        help='Path to input data file.')
    plot_op.add_argument('-o', dest='output', type=str, required=True,
                            help='Path to plot of classify output data (.pdf).')
    plot_op.add_argument('--max_pi', type=float, default=0.35, required=False,
                        help='Maximum pi value to count a lineage as present.')
    plot_op.add_argument('--min_prob', type=float, default=0.5, required=False,
                        help='Minimum probability score to count a lineage as present.')
    

    # CMT arguments
    cmt_op.add_argument('-i', dest='counts_files', type=str, required=True,
                        help='Path to file (newline-separated) listing counts files, in same order as sample names.')
    cmt_op.add_argument('-s', dest='sample_names', type=str, required=True,
                        help='Path to file (newline-separated) listing sample names.')
    cmt_op.add_argument('-r', dest='ref', type=str, required=True,
                        help='Path to reference genome.')
    cmt_op.add_argument('-o', dest='out_cmt', type=str, required=True,
                        help='Path to output CMT file.')
       
    
    # Counts arguments
    counts_op.add_argument('-p', dest='pileup', type=str, required=True,
                            help='Path to input pileup file.')
    counts_op.add_argument('-v', dest='vcf', type=str, required=True,
                            help='Path to input VCF file.')
    counts_op.add_argument('-w', dest='variant_vcf', type=str, required=True,
                            help='Path to input variant VCF file.')
    counts_op.add_argument('-r', dest='ref', type=str, required=True,
                            help='Path to reference genome.')
    counts_op.add_argument('-o', dest='output_counts', type=str, required=True,
                            help='Path to output counts file.')
    


    args = parser.parse_args()
    
    if args.operation=='classify':
        
        results = classify.Classify(path_to_bam=args.input,
                                    path_to_classifier=args.classifier,
                                    ref_file=args.ref,
                                    path_to_frequencies=args.output,
                                    path_to_data=args.outputdata,
                                    level_input=args.level,
                                    mode=args.m,
                                    max_pi = args.max_pi,
                                    min_snps = args.min_snps,
                                    min_prob = args.min_prob,
                                    min_hpd=args.min_hpd,
                                    seed = args.seed,
                                    verbose = args.verbose)
        
        results.main()

    if args.operation=='makedb':

        db = makedb.MakeDB(path_to_cmt=args.input,
                           path_to_tree=args.intree,
                           path_to_output_clades=args.outclades,
                           path_to_output_db=args.outdb,
                           outgroup_str=args.outgroup,
                           path_to_input_clades=args.inclades,
                           min_branch_len=args.min_branchlen,
                           min_nsamples=args.min_leaves,
                           min_support=args.min_support,
                           min_snps = args.min_snps,
                           maxn = args.maxn,
                           core = args.core,
                           min_maf_for_call=args.minAF,
                           min_strand_cov_for_call=args.min_strand_cov,
                           max_qual_for_call=args.qual,
                           max_frac_ambiguous=args.max_frac_ambiguous,
                           max_outgroup=args.max_outgroup,
                           midpoint_root=args.midpoint)

        db.main()
        
    if args.operation=='tree':
        
        results = tree.CMT2tree(input_cmt_file=args.cmt,
                                output_phylip=args.outphylip,
                                output_renaming_file=args.renaming,
                                input_phylip=args.inphylip,
                                output_tree=args.outtree,
                                outgroup_str=args.outgroup,
                                rescale_bool=args.rescale,
                                min_cov_to_include=args.min_cov,
                                min_maf_for_call=args.minAF,
                                min_strand_cov_for_call=args.min_strand_cov, 
                                max_qual_for_call=args.qual,
                                min_presence_core=args.core,
                                min_median_cov_samples=args.min_cov_position,
                                max_frac_ambiguous_pos=args.max_frac_ambiguous,
                                max_mean_copynum=args.copynum,
                                remov_recomb=args.remov_recomb)
        
        results.main()

    if args.operation=='plot':

        results = plot.PlotSample(path_to_frequencies_file=args.infreq,
                                  path_to_data_file=args.indata,
                                  path_to_output_plot=args.output,
                                  max_pi=args.max_pi,
                                  min_prob=args.min_prob)
        
        results.main()

    if args.operation=='cmt':

        results = CMT.Case(path_to_sample_names=args.sample_names,
                           path_to_diversity_files=args.counts_files,
                           path_to_ref=args.ref,
                           path_to_out_cmt=args.out_cmt)
        
        results.main()

    if args.operation=='counts':

        results = CMT.Pileup2Diversity(path_to_pileup=args.pileup,
                                       path_to_vcf=args.vcf,
                                       path_to_variant_vcf=args.variant_vcf,
                                       path_to_ref=args.ref,
                                       path_to_output_diversity=args.output_counts)
        
        results.main()
