from concurrent.futures import thread
from . import parsimony_pb2
import treeswift
from alive_progress import alive_it, alive_bar
from Bio import SeqIO
from typing import ClassVar

from dataclasses import dataclass
from collections import defaultdict


def reverse_complement(input_string):
    """Return the reverse complement of an upper-case DNA string.

    Only the letters A, T, C and G are swapped for their partner; every
    other character is left unchanged (but is still reversed with the rest).
    """
    pairing = str.maketrans({"A": "T", "T": "A", "C": "G", "G": "C"})
    complemented = input_string.translate(pairing)
    return complemented[::-1]


def complement(input_string):
    """Return the base-wise complement of an upper-case DNA string.

    The order of the characters is kept. Characters other than A, T, C and
    G pass through unchanged.
    """
    pairing = str.maketrans({"A": "T", "T": "A", "C": "G", "G": "C"})
    return input_string.translate(pairing)


@dataclass(frozen=True)
class AAMutation:
    """Immutable, hashable record of an amino-acid change at one codon."""

    # Gene the codon belongs to.
    gene: str
    # Codon number within the gene, counted from 1.
    one_indexed_codon: int
    # Amino acid before the change.
    initial_aa: str
    # Amino acid after the change.
    final_aa: str
    # A genome position identifying the codon.
    nuc_for_codon: int
    # Discriminator separating amino-acid records from nucleotide ones.
    type: str = "aa"


@dataclass(frozen=True)
class NucMutation:
    """Immutable, hashable record of a single-nucleotide change."""

    # Genome position of the change, counted from 1.
    one_indexed_position: int
    # Nucleotide before the change.
    par_nuc: str
    # Nucleotide after the change.
    mut_nuc: str
    chromosome: str = "chrom"
    # Discriminator separating nucleotide records from amino-acid ones.
    type: str = "nt"


@dataclass(frozen=True)
class Gene:
    """Immutable, hashable description of a gene's span on the genome."""

    name: str
    strand: int
    # Zero-indexed start coordinate.
    start: int
    # Zero-indexed end coordinate.
    end: int


from dataclasses import dataclass


@dataclass(eq=False)
class Codon:
    """One codon of a gene together with the genome positions of its bases.

    Equality and hashing consider only ``gene`` and ``codon_number``, so
    two Codon objects for the same codon are interchangeable as dict keys
    even if their ``positions`` or ``strand`` differ.
    """

    # NOTE(review): annotated as Gene, but load_genbank_file passes the
    # gene *name* (a string) here.
    gene: Gene
    # Zero-indexed codon number within the gene.
    codon_number: int
    # Offset within the codon -> zero-indexed genome position,
    # e.g. {0: 123, 1: 124, 2: 125}.
    positions: dict
    strand: int

    def __hash__(self):
        identity = (self.gene, self.codon_number)
        return hash(identity)

    def __eq__(self, other):
        if not isinstance(other, Codon):
            return False
        return self.gene == other.gene and self.codon_number == other.codon_number


def get_codon_table():
    """Build a dict mapping every DNA codon to its one-letter amino acid.

    Codons are enumerated with the bases in T, C, A, G order for each of
    the three positions, and paired one-to-one with the amino-acid string
    below. ``*`` marks a stop codon.
    """
    bases = "TCAG"
    amino_acids = 'FFLLSSSSYY**CC*WLLLLPPPPHHQQRRRRIIIMTTTTNNKKSSRRVVVVAAAADDEEGGGG'
    table = {}
    index = 0
    for first in bases:
        for second in bases:
            for third in bases:
                table[first + second + third] = amino_acids[index]
                index += 1
    return table


# Built once at import time and shared by every translation lookup.
codon_table = get_codon_table()


def get_gene_name(cds):
    """Return the name to use for a CDS feature.

    The first value of the ``gene`` qualifier is preferred; if it is
    absent, the first value of ``locus_tag`` is used.

    Raises:
        ValueError: if the feature has neither qualifier.
    """
    for qualifier in ("gene", "locus_tag"):
        if qualifier in cds.qualifiers:
            return cds.qualifiers[qualifier][0]
    raise ValueError(f"No gene name or locus tag for {cds}")


def get_genes_dict(cdses):
    """Return a dict mapping gene name to a ``Gene`` for each CDS feature.

    The name comes from ``get_gene_name``; strand, start and end are read
    from the feature and its location. If two features share a name, the
    later one replaces the earlier one.
    """
    genes = {}
    for cds in cdses:
        name = get_gene_name(cds)
        location = cds.location
        genes[name] = Gene(name, cds.strand, location.start, location.end)
    return genes


def get_mutations(past_nuc_muts_dict,
                  new_nuc_mutations_here,
                  seq,
                  nuc_to_codon,
                  disable_check_for_differences=False):
    """Translate the nucleotide mutations on one node into amino-acid changes.

    Args:
        past_nuc_muts_dict: Zero-indexed genome position -> nucleotide for
            mutations already applied above this node. It is updated IN
            PLACE with the new mutations before returning.
        new_nuc_mutations_here: Mutations on this node; each needs
            ``one_indexed_position`` and ``mut_nuc``. Iterated twice, so it
            must be re-iterable (e.g. a list).
        seq: Sequence indexed by zero-indexed genome position that supplies
            any base not covered by ``past_nuc_muts_dict``.
        nuc_to_codon: Zero-indexed genome position -> list of Codon objects
            containing that position.
        disable_check_for_differences: When true, emit a record for every
            touched codon, even if the amino acid is unchanged.

    Returns:
        A list of AAMutation, one per touched codon whose amino acid
        changed (or per touched codon, see above).
    """
    # Bucket the new nucleotide mutations by every codon they fall in.
    muts_per_codon = defaultdict(list)
    for nuc_mut in new_nuc_mutations_here:
        genome_index = nuc_mut.one_indexed_position - 1
        if genome_index not in nuc_to_codon:
            continue
        for codon in nuc_to_codon[genome_index]:
            muts_per_codon[codon].append(nuc_mut)

    aa_changes = []
    for codon, codon_muts in muts_per_codon.items():
        # Strand is ignored while assembling the codon, so for a negative
        # strand codon this collects the complement; fixed further down.
        before = [seq[codon.positions[offset]] for offset in range(3)]

        # Genome position -> offset of that base within the codon.
        offset_of = {
            position: offset
            for offset, position in codon.positions.items()
        }

        # Overlay the mutations already applied above this node.
        for position in codon.positions.values():
            if position in past_nuc_muts_dict:
                before[offset_of[position]] = past_nuc_muts_dict[position]

        # Apply this node's mutations on top to get the resulting codon.
        after = list(before)
        for nuc_mut in codon_muts:
            offset = offset_of[nuc_mut.one_indexed_position - 1]
            after[offset] = nuc_mut.mut_nuc

        before_codon = "".join(before)
        after_codon = "".join(after)
        if codon.strand == -1:
            before_codon = complement(before_codon)
            after_codon = complement(after_codon)

        aa_before = codon_table[before_codon]
        aa_after = codon_table[after_codon]
        if disable_check_for_differences or aa_before != aa_after:
            aa_changes.append(
                AAMutation(gene=codon.gene,
                           one_indexed_codon=codon.codon_number + 1,
                           initial_aa=aa_before,
                           final_aa=aa_after,
                           nuc_for_codon=codon.positions[1]))

    # Record this node's mutations so callers see them as "past" state.
    for nuc_mut in new_nuc_mutations_here:
        past_nuc_muts_dict[nuc_mut.one_indexed_position - 1] = nuc_mut.mut_nuc

    return aa_changes


def recursive_mutation_analysis(node, past_nuc_muts_dict, seq, cdses, pbar,
                                nuc_to_codon):
    """Annotate ``node`` and all its descendants with amino-acid mutations.

    Nodes are visited parent-first. For each node, a copy of the
    nucleotide state inherited from its parent is extended with the node's
    own ``nuc_mutations`` by ``get_mutations``, and the resulting list is
    stored on ``node.aa_muts``. The caller's ``past_nuc_muts_dict`` is
    never modified.

    The traversal uses an explicit stack instead of Python recursion, so
    trees deeper than the interpreter's recursion limit are handled.

    Args:
        node: Root of the subtree to annotate; nodes need ``nuc_mutations``
            and ``children``.
        past_nuc_muts_dict: Zero-indexed position -> nucleotide state
            inherited by ``node``.
        seq: Sequence passed through to ``get_mutations``.
        cdses: Unused; kept so existing callers keep working.
        pbar: Zero-argument callable invoked once per visited node.
        nuc_to_codon: Position -> codons mapping passed to ``get_mutations``.
    """
    # Each entry pairs a node with the state dict inherited from its parent.
    pending = [(node, past_nuc_muts_dict)]
    while pending:
        current, inherited = pending.pop()
        pbar()

        # Copy so that siblings do not see each other's mutations.
        state = inherited.copy()
        current.aa_muts = get_mutations(state, current.nuc_mutations, seq,
                                        nuc_to_codon)

        # Push in reverse so children are visited in their original order.
        pending.extend(
            (child, state) for child in reversed(list(current.children)))


# Nucleotide letters, indexed by the integer nucleotide codes that
# convert_nuc_mutation reads from the parsed protobuf mutations.
NUC_ENUM = "ACGT"


def preorder_traversal(node):
    """Yield ``node`` and all its descendants in pre-order.

    A node is yielded before its children, and children are visited in the
    order they appear in ``node.children``. An explicit stack is used
    rather than recursion, so the depth of the tree is not bounded by the
    interpreter's recursion limit and each node costs O(1) to produce.
    """
    stack = [node]
    while stack:
        current = stack.pop()
        yield current
        # Reverse so that the first child ends up on top of the stack.
        stack.extend(reversed(list(current.children)))


def preorder_traversal_internal(node):
    """Yield ``node``, then every non-leaf descendant, in pre-order.

    The starting ``node`` is always yielded, even if it is a leaf.
    Descendants are yielded only when ``is_leaf()`` is false. An explicit
    stack is used rather than recursion, so deep trees do not hit the
    recursion limit and each node is tested exactly once.
    """
    yield node
    stack = list(reversed(list(node.children)))
    while stack:
        current = stack.pop()
        if not current.is_leaf():
            yield current
        # Reverse so that the first child ends up on top of the stack.
        stack.extend(reversed(list(current.children)))


def preorder_traversal_iter(node):
    """Return an iterator over ``node`` and its descendants.

    Thin wrapper that calls ``iter()`` on ``preorder_traversal(node)``.
    """
    return iter(preorder_traversal(node))


def find_cds(position, cdses):
    """Return the first CDS feature that contains ``position``, else None.

    ``position`` is a zero-indexed genome coordinate, matching the
    arithmetic in ``find_codon``. A feature covers the half-open interval
    ``[location.start, location.end)``, so a position equal to ``end`` is
    outside the feature (it belongs to whatever starts there).

    Args:
        position: Zero-indexed genome position to look up.
        cdses: Iterable of features exposing ``location.start`` and
            ``location.end``.
    """
    for cds in cdses:
        if cds.location.start <= position < cds.location.end:
            return cds
    return None


def find_codon(position, cds):
    """Locate the codon of ``cds`` that contains ``position``.

    For ``strand == 1`` codons are counted forward from ``location.start``;
    for any other strand value they are counted backward from
    ``location.end``.

    Returns:
        A ``(codon_number, codon_start, codon_end)`` tuple: the
        zero-indexed codon number and the genome span of that codon, with
        ``codon_end == codon_start + 3``.
    """
    location = cds.location
    if cds.strand != 1:
        # Count codons back from the end of the feature.
        codon_number = (location.end - position - 1) // 3
        codon_end = location.end - codon_number * 3
        return codon_number, codon_end - 3, codon_end

    # Count codons forward from the start of the feature.
    codon_number = (position - location.start) // 3
    codon_start = location.start + codon_number * 3
    return codon_number, codon_start, codon_start + 3


class UsherMutationAnnotatedTree:

    def __init__(self,
                 tree_file,
                 genbank_file=None,
                 name_internal_nodes=False,
                 clade_types=None,
                 shear=False,
                 shear_threshold=1000):
        """Parse an UShER protobuf tree and annotate every node.

        Args:
            tree_file: Readable binary file object; its full contents are
                parsed as a ``parsimony_pb2.data`` message.
            genbank_file: Optional GenBank input handed to ``SeqIO.read``.
                When given, the root sequence is reconstructed and
                amino-acid mutations are annotated.
            name_internal_nodes: If true, unlabelled internal nodes get a
                generated ``node_<n>`` label.
            clade_types: Names for the per-node clade annotations, in the
                order they are stored in the protobuf metadata. ``None`` or
                an empty sequence skips clade annotation.
            shear: If true, prune sibling subtrees that are very small
                relative to their largest sibling.
            shear_threshold: Size ratio above which a sibling is pruned.
        """
        # Avoid a shared mutable default; None means "no clade annotation".
        if clade_types is None:
            clade_types = []

        self.data = parsimony_pb2.data()
        self.data.ParseFromString(tree_file.read())
        self.condensed_nodes_dict = self.get_condensed_nodes_dict(
            self.data.condensed_nodes)
        print("Loading tree, this may take a while...")
        self.tree = treeswift.read_tree(self.data.newick, schema="newick")
        if name_internal_nodes:
            self.name_internal_nodes()
        # The Newick text has been parsed into self.tree; drop the string.
        self.data.newick = ''

        # Both annotation passes index the protobuf lists by pre-order
        # position, so they must run before the tree's shape is changed.
        self.annotate_mutations()
        self.annotate_clades(clade_types)

        self.expand_condensed_nodes()
        self.assign_num_tips()
        print(f"Loaded initial tree with {self.tree.root.num_tips} tips")
        if genbank_file:
            # We need to reconstruct root seq before shearing as shearing can mess it up
            self.load_genbank_file(genbank_file)
            self.get_root_sequence()
        if shear:
            print("Shearing tree...")
            self.shear_tree(shear_threshold)
        # Recount: shearing may have removed tips.
        self.assign_num_tips()
        print(f"Tree to use now has {self.tree.root.num_tips} tips")
        self.set_branch_lengths()
        if genbank_file:
            self.perform_aa_analysis()

    def prune_node(self, node_to_prune):
        """Detach ``node_to_prune`` from the tree and tidy up its parent.

        After the node is removed from its parent:

        * If the parent has no children left, the parent is pruned too.
        * If the parent has exactly one child left and is not the root, the
          parent is collapsed into that child. The child receives every
          parent mutation whose position it does not already mutate, and
          every parent clade annotation it lacks (or has as an empty
          string). The child then takes the parent's place under the
          grandparent.
        * If the parent is the root, it is kept as-is, even with a single
          child, so that the remaining subtree stays attached to the tree.

        Calling this on a node without a parent (the root) does nothing.
        """
        parent = node_to_prune.parent
        if parent is None:
            # Nothing to detach the root from.
            return

        parent.remove_child(node_to_prune)
        remaining = len(parent.children)
        if remaining == 0:
            self.prune_node(parent)
            return
        if remaining != 1:
            return

        grandparent = parent.parent
        if grandparent is None:
            # The parent is the root. Collapsing it would detach the only
            # remaining child from the tree, so leave it alone.
            return

        child = parent.children[0]

        # Hand down parent mutations unless the child already mutates that
        # position. Tracking positions in a set avoids rescanning the list.
        taken_positions = {
            mutation.one_indexed_position for mutation in child.nuc_mutations
        }
        for mutation in parent.nuc_mutations:
            if mutation.one_indexed_position not in taken_positions:
                child.nuc_mutations.append(mutation)
                taken_positions.add(mutation.one_indexed_position)

        # Hand down clade annotations the child does not define itself.
        if hasattr(parent, "clades"):
            for clade_type, clade_annotation in parent.clades.items():
                if clade_type not in child.clades or child.clades[
                        clade_type] == "":
                    child.clades[clade_type] = clade_annotation

        # Replace the parent with the child under the grandparent.
        parent.remove_child(child)
        grandparent.remove_child(parent)
        grandparent.add_child(child)

    def shear_tree(self, theshold=1000):
        """Prune children that are tiny compared with their largest sibling.

        Every non-root node with more than one child is visited in
        post-order. A child is pruned when the largest sibling's
        ``num_tips`` divided by the child's ``num_tips`` exceeds
        ``theshold``. ``num_tips`` values are read as they were before
        shearing started; they are not refreshed here.
        """
        root = self.tree.root
        for node in alive_it(list(self.tree.traverse_postorder())):
            if node == root or len(node.children) <= 1:
                continue
            largest = max(node.children, key=lambda sibling: sibling.num_tips)
            # Iterate over a copy: prune_node edits node.children.
            for child in list(node.children):
                ratio = largest.num_tips / child.num_tips
                if ratio > theshold:
                    self.prune_node(child)

    def create_mutation_like_objects_to_record_root_seq(self):
        """Encode the root sequence as one NucMutation per base.

        Each character of ``self.root_sequence`` becomes a mutation at its
        one-indexed position, with the character as ``mut_nuc`` and the
        placeholder ``"X"`` as ``par_nuc``.
        """
        return [
            NucMutation(one_indexed_position=position,
                        mut_nuc=base,
                        par_nuc="X")
            for position, base in enumerate(self.root_sequence, start=1)
        ]

    def annotate_clades(self, clade_types):
        """Attach a ``clades`` dict to every node of the tree.

        Does nothing when ``clade_types`` is empty. Otherwise the i-th node
        in pre-order receives the clade annotations of the i-th protobuf
        metadata entry, keyed by the name at the same index in
        ``clade_types``.
        """
        if not clade_types:
            return

        indexed_nodes = list(enumerate(preorder_traversal(self.tree.root)))
        for i, node in alive_it(indexed_nodes, title="Annotating clades"):
            annotations = self.data.metadata[i].clade_annotations
            node.clades = {
                clade_types[index]: part
                for index, part in enumerate(annotations)
            }

    def perform_aa_analysis(self):
        """Annotate every node with amino-acid mutations.

        First the whole tree is annotated from the nodes' nucleotide
        mutations, using the GenBank sequence for unmutated bases. The
        root is then overwritten: its ``nuc_mutations`` become
        mutation-like records spelling out the full root sequence, and its
        ``aa_muts`` list every codon those records touch, changed or not.
        """
        reference = str(self.genbank.seq)
        node_count = self.tree.num_nodes()
        with alive_bar(node_count, title="Annotating amino acids") as pbar:
            recursive_mutation_analysis(self.tree.root, {}, reference,
                                        self.cdses, pbar, self.nuc_to_codon)

        root_muts = self.create_mutation_like_objects_to_record_root_seq()
        root_aa_muts = get_mutations({},
                                     root_muts,
                                     reference,
                                     self.nuc_to_codon,
                                     disable_check_for_differences=True)
        root = self.tree.root
        root.aa_muts = root_aa_muts
        root.nuc_mutations = root_muts

    def load_genbank_file(self, genbank_file):
        """Read a GenBank record and index its coding positions by codon.

        Sets:
            self.genbank: The record returned by ``SeqIO.read``.
            self.cdses: The record's features of type ``"CDS"``.
            self.genes: Gene name -> Gene, from ``get_genes_dict``.
            self.nuc_to_codon: Zero-indexed genome position -> list of
                Codon objects that contain that position.

        Multi-part locations are walked part by part. Parts whose strand
        is not 1 are walked from high to low coordinates.

        NOTE(review): the completeness check below is an ``assert``, so it
        is skipped when Python runs with ``-O``.
        """
        self.genbank = SeqIO.read(genbank_file, "genbank")
        self.cdses = [
            feat for feat in self.genbank.features if feat.type == "CDS"
        ]
        self.genes = get_genes_dict(self.cdses)

        # gene name -> codon number -> offset in codon -> genome position
        positions_by_gene = defaultdict(lambda: defaultdict(dict))

        for feature in self.cdses:
            gene_name = get_gene_name(feature)

            # Counts coding bases across all parts of this feature.
            coding_index = 0
            for part in feature.location.parts:
                if part.strand == 1:
                    genome_positions = range(part.start, part.end)
                else:
                    # The forward branch treats part.end as exclusive, so
                    # the backward walk starts at end - 1 and must still
                    # include part.start, hence the stop of start - 1.
                    genome_positions = range(part.end - 1, part.start - 1,
                                             -1)

                for genome_position in genome_positions:
                    codon_number, offset = divmod(coding_index, 3)
                    positions_by_gene[gene_name][codon_number][
                        offset] = genome_position
                    coding_index += 1

        nuc_to_codon = defaultdict(list)

        for gene_name, codons in positions_by_gene.items():
            for codon_number, offsets in codons.items():
                codon = Codon(gene_name, codon_number, offsets,
                              self.genes[gene_name].strand)

                # Every codon must have all three positions filled in.
                assert len(offsets) % 3 == 0
                for genome_position in offsets.values():
                    nuc_to_codon[genome_position].append(codon)

        self.nuc_to_codon = nuc_to_codon

    def convert_nuc_mutation(self, usher_mutation):
        """Convert one protobuf mutation into a NucMutation.

        The protobuf position is used as the one-indexed position. The
        integer nucleotide codes are turned into letters via ``NUC_ENUM``.
        Only the first entry of ``usher_mutation.mut_nuc`` is used.
        """
        position = usher_mutation.position
        parent_base = NUC_ENUM[usher_mutation.par_nuc]
        mutated_base = NUC_ENUM[usher_mutation.mut_nuc[0]]
        return NucMutation(one_indexed_position=position,
                           par_nuc=parent_base,
                           mut_nuc=mutated_base)

    def annotate_mutations(self):
        """Attach a ``nuc_mutations`` list to every node of the tree.

        The i-th node in pre-order receives the converted mutations of the
        i-th entry of ``self.data.node_mutations``.
        """
        indexed_nodes = list(enumerate(preorder_traversal(self.tree.root)))
        for i, node in alive_it(indexed_nodes, title="Annotating nuc muts"):
            usher_mutations = self.data.node_mutations[i].mutation
            node.nuc_mutations = list(
                map(self.convert_nuc_mutation, usher_mutations))

    def get_root_sequence(self):
        """Reconstruct the sequence at the root and store it as a string.

        For every node except the root, the parent nucleotide (``par_nuc``)
        of each mutation is recorded by one-indexed position; nodes visited
        later in the post-order walk overwrite earlier ones. The GenBank
        sequence is then copied with those positions replaced, and the
        result is stored in ``self.root_sequence``.
        """
        root = self.tree.root
        ancestral_bases = {}
        indexed_nodes = list(enumerate(root.traverse_postorder()))
        for _, node in alive_it(indexed_nodes, title="Getting root sequence"):
            if node == root:
                continue
            for mutation in node.nuc_mutations:
                ancestral_bases[
                    mutation.one_indexed_position] = mutation.par_nuc

        reference = str(self.genbank.seq)
        self.root_sequence = "".join(
            ancestral_bases.get(position, base)
            for position, base in enumerate(reference, start=1))

    def name_internal_nodes(self):
        """Give every unlabelled internal node a ``node_<n>`` label.

        ``n`` is the node's one-based rank in the sequence produced by
        ``preorder_traversal_internal``. Nodes that already have a
        non-empty label keep it but still use up a number.
        """
        numbered_nodes = list(
            enumerate(preorder_traversal_internal(self.tree.root), start=1))
        for number, node in alive_it(numbered_nodes,
                                     title="Naming internal nodes"):
            if node.label:
                continue
            node.label = "node_" + str(number)

    def set_branch_lengths(self):
        """Set each node's edge length to its number of nucleotide mutations."""
        all_nodes = list(preorder_traversal(self.tree.root))
        for node in alive_it(all_nodes, title="Setting branch length"):
            mutation_count = len(node.nuc_mutations)
            node.edge_length = mutation_count

    def expand_condensed_nodes(self):
        """Replace each condensed leaf with one leaf per original sample.

        A leaf whose label is a key of ``self.condensed_nodes_dict`` is
        removed from its parent. In its place, a new node is added to that
        parent for every label listed for it. The condensed leaf must
        carry no mutations of its own.

        NOTE: the new nodes are given the condensed leaf's
        ``nuc_mutations`` list (and ``clades`` dict, if present) itself,
        not a copy, so all of them share the same objects.
        """
        leaves = list(self.tree.traverse_leaves())
        for node in alive_it(leaves, title="Expanding condensed nodes"):
            if not node.label or node.label not in self.condensed_nodes_dict:
                continue

            assert len(node.nuc_mutations) == 0

            parent = node.parent
            has_clades = hasattr(node, "clades")
            for new_node_label in self.condensed_nodes_dict[node.label]:
                new_node = treeswift.Node(label=new_node_label)
                new_node.nuc_mutations = node.nuc_mutations
                if has_clades:
                    new_node.clades = node.clades
                parent.add_child(new_node)

            node.label = ""
            parent.remove_child(node)

    def get_condensed_nodes_dict(self, condensed_nodes_dict):
        """Index condensed-node entries by name.

        Despite its name, ``condensed_nodes_dict`` is an iterable of
        entries with ``node_name`` and ``condensed_leaves`` attributes.

        Returns:
            A dict mapping each ``node_name`` to its ``condensed_leaves``.
        """
        return {
            condensed_node.node_name: condensed_node.condensed_leaves
            for condensed_node in alive_it(
                condensed_nodes_dict, title="Reading condensed nodes dict")
        }

    def assign_num_tips(self):
        """Store on every node the number of leaves at or below it.

        Leaves get ``num_tips = 1``. Every other node gets the sum over its
        children, which the post-order walk has already processed.
        """
        for node in self.tree.traverse_postorder():
            node.num_tips = (1 if node.is_leaf() else sum(
                child.num_tips for child in node.children))
