import csv
import os

import pandas as pd

from .utils import index_file, run_command


class QCMetrics:
    """Sequential variant- and sample-level QC of a VCF using PLINK and bcftools.

    Order matters in the QC metric operations:
    genotype metrics -> variant metrics -> sample metrics.
    Therefore they cannot all be computed at the same time; each step reads
    the file produced by the previous one (see ``output_file``).

    Metrics are first computed using PLINK to generate the required metrics
    files, then the data is filtered using either bcftools or PLINK.
    """

    # Extra PLINK flags appended to every PLINK filtering command so that the
    # filtered data is written back out as a bgzipped VCF.
    plink_vcf_flags = [
        "--recode",
        "vcf-iid",  # Uses only sample id for header
        "bgz",
        "--output-chr",  # Outputs chromosomes with chr##
        "chrMT",
        "--real-ref-alleles",  # Needed to code the correct allele as reference
    ]

    # Attributes set by the filtering steps, listed latest pipeline stage
    # first so that ``output_file`` returns the most recently produced file.
    attrs_created = ["sc_file", "hwe_file", "vc_file"]

    def __init__(
        self, file, tempdir, hwe=None, variant_call_rate=None, sample_call_rate=None
    ):
        """Store the input file, working directory and QC thresholds.

        Args:
            file: Path to the input VCF.
            tempdir: Directory in which all metrics and filtered files are
                written.
            hwe: Threshold passed to ``plink --hwe`` (with ``midp``).
            variant_call_rate: Minimum variant call rate; variants are kept
                when their missingness is below ``1 - variant_call_rate``.
            sample_call_rate: Minimum sample call rate; passed to PLINK as
                ``--mind (1 - sample_call_rate)``.
        """
        self.file = file
        self.tempdir = tempdir

        self.variant_call_rate = variant_call_rate
        self.hwe = hwe
        self.sample_call_rate = sample_call_rate

    @staticmethod
    def _to_numeric_if_possible(column):
        """Return ``column`` converted to a numeric dtype, or unchanged on failure.

        Explicit replacement for ``pd.to_numeric(..., errors="ignore")``,
        which is deprecated in recent pandas releases.
        """
        try:
            return pd.to_numeric(column)
        except (ValueError, TypeError):
            return column

    @staticmethod
    def _parse_metrics_file(file):
        """Parse a space-padded PLINK metrics file into a DataFrame.

        The first row of the file is used as the header. For variant-level
        files (``.lmiss`` / ``.hwe``) the ``SNP`` column is split on ``_``
        into ``ID, CHROM, POS, REF, ALT`` and the first two original columns
        are dropped. For any other file, ``IID`` is renamed to ``SAMPLE_ID``
        and ``FID`` is dropped.

        Args:
            file: Path to the metrics file.

        Returns:
            pandas.DataFrame with numeric columns converted where possible.
        """
        rows = []
        with open(file) as f:
            for line in csv.reader(f, delimiter=" "):
                # Columns are padded with runs of spaces, which the csv
                # reader yields as empty fields; drop them.
                fields = [field for field in line if field]
                if fields:
                    rows.append(fields)
        df = pd.DataFrame(rows[1:], columns=rows[0]).apply(
            QCMetrics._to_numeric_if_possible
        )

        # Dispatch on the file extension only: matching against the whole
        # path would misclassify sample files stored under a directory whose
        # name happens to contain "hwe" or "lmiss".
        if os.path.splitext(file)[1] in (".lmiss", ".hwe"):
            # NOTE(review): assumes SNP ids are exactly ID_CHROM_POS_REF_ALT
            # with no other underscores - confirm against the upstream step.
            variant_fields = pd.DataFrame(
                df["SNP"].str.split("_").to_list(),
                columns=["ID", "CHROM", "POS", "REF", "ALT"],
            )
            return pd.concat([variant_fields, df.iloc[:, 2:]], axis=1)
        return df.rename(columns={"IID": "SAMPLE_ID"}).drop(columns="FID")

    def genotype_qc(self):
        """Genotype-level QC placeholder; always raises ``NotImplementedError``."""
        raise NotImplementedError("Genotype level QC has not been implemented yet.")

    def variant_qc(self):
        """Run the variant-level filters whose thresholds are set.

        Variant call rate is applied first, then HWE; each step records its
        output in ``vc_file`` / ``hwe_file`` respectively.
        """
        if self.variant_call_rate:
            self.vc_file = self.filter_on_variant_call_rate()
        if self.hwe:
            self.hwe_file = self.filter_on_hwe()

    def sample_qc(self):
        """Run the sample call rate filter and record its output in ``sc_file``.

        Does nothing when no sample call rate threshold was provided, instead
        of failing on ``1 - None``.
        """
        if self.sample_call_rate is not None:
            self.sc_file = self.filter_on_sample_call_rate()

    @staticmethod
    def _compute_base_qc(file, output_prefix):
        """Run PLINK ``--hardy midp --missing`` on ``file``.

        Computes QC metrics. They need to be recomputed after each sequential
        filtering step. Output files are written under ``output_prefix``.
        """
        run_command(
            [
                "plink",
                "--vcf",
                file,
                "--hardy",
                "midp",
                "--missing",
                "--out",
                output_prefix,
            ]
        )

    def _write_call_rate(self, file_input, prefix, extension, column, output_name):
        """Compute a call rate table (``1 - F_MISS``) and write it as a TSV.

        Args:
            file_input: VCF to compute the metrics on.
            prefix: PLINK output prefix name inside ``tempdir``.
            extension: PLINK missingness file extension (``lmiss``/``imiss``).
            column: Name of the call rate column to create.
            output_name: File name of the TSV written inside ``tempdir``.
        """
        output_prefix = os.path.join(self.tempdir, prefix)
        self._compute_base_qc(file=file_input, output_prefix=output_prefix)
        metrics = self._parse_metrics_file(f"{output_prefix}.{extension}")
        metrics[column] = 1 - metrics["F_MISS"]
        metrics = metrics.drop(columns="F_MISS")
        metrics.to_csv(
            os.path.join(self.tempdir, output_name), sep="\t", index=False
        )

    def _compute_variant_call_rate(self, file_input):
        """Write per-variant call rates to ``variant_call_rate.tsv`` in tempdir."""
        self._write_call_rate(
            file_input=file_input,
            prefix="vcr",
            extension="lmiss",
            column="VARIANT_CALL_RATE",
            output_name="variant_call_rate.tsv",
        )

    def _compute_hwe(self, file_input):
        """Write per-variant HWE metrics to ``hwe.tsv`` in tempdir."""
        output_prefix = os.path.join(self.tempdir, "hwe")
        self._compute_base_qc(file=file_input, output_prefix=output_prefix)
        hwe = self._parse_metrics_file(f"{output_prefix}.hwe").drop(
            columns=["TEST", "A1", "A2"]
        )
        hwe.to_csv(os.path.join(self.tempdir, "hwe.tsv"), sep="\t", index=False)

    def _compute_sample_call_rate(self, file_input):
        """Write per-sample call rates to ``sample_call_rate.tsv`` in tempdir."""
        self._write_call_rate(
            file_input=file_input,
            prefix="scr",
            extension="imiss",
            column="SAMPLE_CALL_RATE",
            output_name="sample_call_rate.tsv",
        )

    # NOTE(review): ``index_file`` comes from ``.utils``; presumably it indexes
    # the returned VCF - confirm there.
    @index_file
    def filter_on_variant_call_rate(self):
        """Filter variants on call rate with bcftools.

        Returns:
            Path to the filtered, bgzipped VCF (``vc_qced.vcf.gz`` in tempdir).
        """
        file_input = self.output_file
        # Compute metrics file
        self._compute_variant_call_rate(file_input=file_input)

        # Filter the same file the metrics were computed on. The command is
        # passed as an argument list (no shell) so paths containing spaces or
        # shell metacharacters are handled safely.
        vc_qc_file = os.path.join(self.tempdir, "vc_qced.vcf.gz")
        run_command(
            [
                "bcftools",
                "view",
                "--include",
                f"F_MISSING < {1 - self.variant_call_rate}",
                "-Oz",
                "-o",
                vc_qc_file,
                file_input,
            ]
        )
        return vc_qc_file

    @index_file
    def filter_on_hwe(self):
        """Filter variants on HWE with PLINK, using the mid p-value correction.

        Returns:
            Path to the filtered, bgzipped VCF (``hwe_qced.vcf.gz`` in tempdir).
        """
        file_input = self.output_file

        # Compute metrics file
        self._compute_hwe(file_input=file_input)
        hwe_qc_file = os.path.join(self.tempdir, "hwe_qced")
        run_command(
            [
                "plink",
                "--vcf",
                file_input,
                "--hwe",
                str(self.hwe),
                "midp",
                "--out",
                hwe_qc_file,
            ]
            + self.plink_vcf_flags
        )
        return f"{hwe_qc_file}.vcf.gz"

    @index_file
    def filter_on_sample_call_rate(self):
        """Filter samples on call rate with PLINK ``--mind``.

        Returns:
            Path to the filtered, bgzipped VCF (``sc_qced.vcf.gz`` in tempdir).
        """
        file_input = self.output_file
        self._compute_sample_call_rate(file_input=file_input)

        sc_qc_file = os.path.join(self.tempdir, "sc_qced")
        run_command(
            [
                "plink",
                "--vcf",
                file_input,
                "--mind",
                str(1 - self.sample_call_rate),
                "--out",
                sc_qc_file,
            ]
            + self.plink_vcf_flags
        )
        return f"{sc_qc_file}.vcf.gz"

    @property
    def output_file(self):
        """Most recently produced QC file, or the original input if none exists.

        Checks the attributes in ``attrs_created`` in order and returns the
        first one that has been set on this instance.
        """
        for attr in self.attrs_created:
            if hasattr(self, attr):
                return getattr(self, attr)
        return self.file

    @property
    def files_created(self):
        """List of QC output files set so far, in ``attrs_created`` order."""
        return [
            getattr(self, attr) for attr in self.attrs_created if hasattr(self, attr)
        ]
