import pandas as pd
import pysam
from celescope.tools import utils
from celescope.tools.step import Step, s_common
from collections import defaultdict
from celescope.__init__ import HELP_DICT


def count_pathseq(args):
    """Entry point of the step: build a ``Count_pathseq`` runner and execute it.

    The runner is used as a context manager, so whatever the step does on
    enter/exit wraps the call to ``run``.
    """
    step = Count_pathseq(args, display_title="Count")
    with step as runner:
        runner.run()


class Count_pathseq(Step):
    def __init__(self, args, display_title=None):
        """Prepare inputs, the output path and the read counters of the step.

        Args:
            args: parsed command-line arguments; ``args.match_dir`` is read here.
            display_title: forwarded unchanged to ``Step.__init__``.
        """
        Step.__init__(self, args, display_title=display_title)

        # Matched-cell information parsed from the match directory.
        match_info = utils.parse_match_dir(args.match_dir)
        self.match_dict = match_info
        self.match_barcode = match_info["match_barcode"]

        # Output file: the raw UMI matrix.
        raw_matrix_path = f"{self.outdir}/{self.sample}_raw_UMI_matrix.tsv.gz"
        self.raw_matrix_file = raw_matrix_path
        self.outs = [raw_matrix_path]

        # Read counters, incremented while the pathseq BAM is parsed.
        self.total_pathseq_reads = 0
        self.total_YP_reads = 0

    def pathseq_score_to_dict(self):
        """
        Returns: {taxID: genus_name}
        {
        '1703962': 'Rhizobium',
        '1720346': 'Rhizobium',
        }
        """
        pathseq_score = open(self.args.pathseq_score_file, "r")
        genus_dict = {}
        set_for_genera = set()
        for each_line in pathseq_score:
            each_line = each_line.rstrip("\n")
            each_line_list = each_line.split("\t")
            level = each_line_list[2]
            tax = each_line_list[3]
            if level == "genus":
                set_for_genera.add(tax)
            if "|" in each_line_list[1]:
                name_string_list = each_line_list[1].split("|")
                for n in range(len(name_string_list)):
                    pointer = -n - 1  ## -1,-2,-3
                    if "_" not in name_string_list[pointer]:
                        name = name_string_list[pointer]
                        break
                    if "unclassified" in name_string_list[pointer]:
                        name = name_string_list[pointer]
                        break
                id = each_line_list[0]
                genus_dict[id] = name
        print("len(dict_for_genus) = ", len(genus_dict))
        return genus_dict

    def parse_pathseq_bam(self):
        """
        Collect the YP and AS tags of every pathseq record that has a YP tag.

        YP: taxIDs
        AS: Alignment score generated by pathseq

        Side effects: ``self.total_pathseq_reads`` is incremented for every
        record and ``self.total_YP_reads`` for every record with a YP tag.
        When several records share a read name, the last one wins.

        Returns: {read_name: {"YP": YP tag value, "AS": AS tag as int}}
        """
        read_dict = {}
        # Open through a context manager so the BAM handle is always closed.
        with pysam.AlignmentFile(
            self.args.pathseq_bam_file, "rb", threads=self.thread
        ) as pathseq_bam:
            for segment in pathseq_bam:
                self.total_pathseq_reads += 1
                if not segment.has_tag("YP"):
                    continue
                self.total_YP_reads += 1
                read_dict[segment.query_name] = {
                    "YP": segment.get_tag("YP"),
                    "AS": int(segment.get_tag("AS")),
                }
        return read_dict

    def parse_unmap_bam(self):
        """
        Collect the CB and UB tags of every record that carries both tags.

        Records missing either tag are ignored. When several records share a
        read name, the last one wins.

        Returns: {read_name: {"CB": CB tag value, "UB": UB tag value}}
        """
        unmap_read_dict = {}
        # Open through a context manager so the BAM handle is always closed.
        with pysam.AlignmentFile(
            self.args.unmap_bam_file, threads=self.thread
        ) as unmap_bam:
            for segment in unmap_bam:
                if not (segment.has_tag("CB") and segment.has_tag("UB")):
                    continue
                unmap_read_dict[segment.query_name] = {
                    "CB": segment.get_tag("CB"),
                    "UB": segment.get_tag("UB"),
                }
        return unmap_read_dict

    def select_highest_AS(self, read_dict, unmap_read_dict):
        """
        collapse read with same (CB,UB), select the one with highest AS
        """
        read_highest_AS = {}
        for read_name in read_dict:
            CB, UB = unmap_read_dict[read_name]["CB"], unmap_read_dict[read_name]["UB"]
            if (CB, UB) not in read_highest_AS or read_dict[read_name][
                "AS"
            ] > read_highest_AS[(CB, UB)]["AS"]:
                read_highest_AS[(CB, UB)] = read_dict[read_name]
        return read_highest_AS

    def select_unique_YP(self, read_highest_AS, read_dict, genus_dict):
        cb_genus_umi = defaultdict(lambda: defaultdict(int))
        for CB, UB in read_highest_AS:
            YP = read_highest_AS[(CB, UB)]["YP"]
            genus = set([genus_dict[id] for id in YP.split(",")])
            if len(genus) == 1:
                genus = genus.pop()
                cb_genus_umi[CB][genus] += 1
        return cb_genus_umi

    def write_raw_matrix(self, cb_genus_umi):
        """
        rows are genus and columns are cb
        """
        df_umi = pd.DataFrame.from_dict(cb_genus_umi, orient="index").transpose()
        df_raw = pd.DataFrame()
        for cb in self.match_barcode:
            if cb in df_umi.columns:
                df_raw[cb] = df_umi[cb]
            else:
                df_raw[cb] = 0
        df_raw.fillna(0, inplace=True)
        df_raw = df_raw.astype(int)
        non_zero_rows = df_raw.index[df_raw.sum(axis=1) > 0]
        df_raw = df_raw.loc[non_zero_rows, :]
        df_raw.to_csv(self.raw_matrix_file, sep="\t")

    def run(self):
        """
        Run the step end to end: read the score file and both BAM files, keep
        the best-scoring read per (CB, UB), count UMIs with a single genus per
        barcode and write the raw matrix.
        """
        genus_by_taxid = self.pathseq_score_to_dict()
        pathseq_reads = self.parse_pathseq_bam()
        barcoded_reads = self.parse_unmap_bam()
        best_per_umi = self.select_highest_AS(pathseq_reads, barcoded_reads)
        umi_counts = self.select_unique_YP(
            best_per_umi, pathseq_reads, genus_by_taxid
        )
        self.write_raw_matrix(umi_counts)


def get_opts_count_pathseq(parser, sub_program):
    """Register the command-line options of the count_pathseq step.

    Options are only added when ``sub_program`` is truthy; otherwise the
    parser is returned untouched.

    Args:
        parser: argument parser to extend.
        sub_program: whether the step-specific options should be registered.

    Returns:
        The same ``parser`` object.
    """
    if not sub_program:
        return parser

    # Required input files that share the same plain help/required layout.
    for flag, description in (
        ("--pathseq_bam_file", "pathseq bam file"),
        ("--pathseq_score_file", "pathseq score file"),
        ("--unmap_bam_file", "unmap bam file"),
    ):
        parser.add_argument(flag, help=description, required=True)
    parser.add_argument("--match_dir", help=HELP_DICT["match_dir"], required=True)
    s_common(parser)
    return parser
