import importlib.util
import itertools
import json
import os
import pathlib
import re
import shutil
import sys
from typing import Literal

import pandas as pd
import pyranges as pr
import snakemake.utils

import capcruncher.pipeline.utils
from capcruncher.utils import convert_bed_to_pr


# Abort early if the installed Snakemake is older than the minimum version this workflow supports.
snakemake.utils.min_version('7.19.1')


# Default configuration file for the pipeline; its contents populate `config`.
configfile: "capcruncher_config.yml"
# Container image declared for the workflow's jobs (used when container execution is enabled).
container: "library://asmith151/capcruncher/capcruncher:latest"


# Pipeline set-up
# Normalise the loaded config before any values are read from it. The return value
# is unused, so this presumably mutates `config` in place (see capcruncher.pipeline.utils).
capcruncher.pipeline.utils.format_config_dict(config)

## Get fastq files
# Every file matching *.fastq* in the working directory is treated as pipeline input.
FASTQ_SAMPLES = capcruncher.pipeline.utils.FastqSamples.from_files(
    list(pathlib.Path(".").glob("*.fastq*"))
)

## Convert FASTQ files to design matrix
# A user-supplied design file (whitespace-, comma- or tab-separated) takes precedence;
# otherwise the design is derived from the FASTQ samples. The path is checked for a
# truthy value first because os.path.exists(None) raises TypeError instead of
# returning False, which would crash the pipeline when no design is configured.
if config["analysis"].get("design") and os.path.exists(
    config["analysis"]["design"]
):
    DESIGN = pd.read_table(
        config["analysis"]["design"], sep=r"\s+|,|\t", engine="python"
    )
else:
    DESIGN = FASTQ_SAMPLES.design

## Read viewpoints
# Viewpoint names are taken from the BED "Name" column; duplicates are collapsed,
# keeping first-seen order.
VIEWPOINTS = config["analysis"]["viewpoints"]
VIEWPOINT_NAMES = convert_bed_to_pr(VIEWPOINTS).df.Name.drop_duplicates().tolist()


N_SAMPLES = DESIGN["sample"].nunique()
# Analysis method; defaults to "capture" when not configured.
ANALYSIS_METHOD = config["analysis"].get("method", "capture")
BIN_SIZES = capcruncher.pipeline.utils.get_bin_sizes(config)
HIGH_NUMBER_OF_VIEWPOINTS = capcruncher.pipeline.utils.has_high_viewpoint_number(
    VIEWPOINTS, config
)
IGNORE_MULTIPLE_FRAGMENTS_PER_VIEWPOINT = config["analysis"].get(
    "ignore_multiple_fragments_per_viewpoint", False
)

# Details
def _split_summary_methods(value):
    """Return the summary method names configured for comparisons.

    ``value`` may be a list/tuple of names or a single string delimited by
    commas, semicolons, whitespace or "+". Empty entries are dropped.
    ``None`` (key present but left blank) falls back to the default "mean".
    """
    if value is None:
        value = "mean,"
    if isinstance(value, (list, tuple)):
        return [str(m).strip() for m in value if str(m).strip()]
    return [m for m in re.split(r"[,;\s+]", str(value)) if m]


SUMMARY_METHODS = _split_summary_methods(
    (config.get("compare") or {}).get("summary_methods", "mean,")
)


## Optional
# Aggregation is only meaningful with more than one distinct sample.
AGGREGATE_SAMPLES = DESIGN["sample"].nunique() > 1
# Comparison needs more than one distinct condition. A design without a
# "condition" column is treated as a single condition instead of raising KeyError.
COMPARE_SAMPLES = (
    "condition" in DESIGN.columns and DESIGN["condition"].nunique() > 1
)
# Differential analysis requires the configured contrast column to exist in the
# design, at least two conditions, and a capture/tri analysis. A missing
# "differential" section or "contrast" key disables it rather than crashing.
PERFORM_DIFFERENTIAL_ANALYSIS = (
    ((config.get("differential") or {}).get("contrast") in DESIGN.columns)
    and COMPARE_SAMPLES
    and (ANALYSIS_METHOD in ["capture", "tri"])
)
PERFORM_PLOTTING = capcruncher.pipeline.utils.can_perform_plotting(config)
PERFORM_BINNING = capcruncher.pipeline.utils.can_perform_binning(config)

## Pipeline variables
# Reuse ANALYSIS_METHOD so an unset "method" falls back to "capture" here as well,
# instead of raising KeyError on config["analysis"]["method"].
ASSAY = ANALYSIS_METHOD
SAMPLE_NAMES = FASTQ_SAMPLES.sample_names_all

# Interim-file cleanup is currently disabled: the onsuccess handler only removes
# files when CLEANUP is "full" or "partial". To re-enable it, derive the value from
# config, e.g. "full" if config["analysis"].get("cleanup", False) else "partial".
CLEANUP = False


# Rule modules. They are included before `rule all` because it references rules
# defined in them (rules.multiqc_report, rules.create_ucsc_hub).
include: "rules/digest.smk"
include: "rules/fastq.smk"
include: "rules/qc.smk"
include: "rules/align.smk"
include: "rules/annotate.smk"
include: "rules/filter.smk"
include: "rules/pileup.smk"
include: "rules/compare.smk"
include: "rules/statistics.smk"
include: "rules/visualise.smk"


# Global wildcard constraints. Sample and viewpoint names are escaped so that names
# containing regex metacharacters (".", "+", "(", ...) match only themselves.
wildcard_constraints:
    sample="|".join(re.escape(name) for name in SAMPLE_NAMES),
    part=r"\d+",
    viewpoint="|".join(re.escape(name) for name in VIEWPOINT_NAMES),
    combined="|".join(["flashed", "pe"]),


# Default target: requesting these outputs drives the whole pipeline. Optional
# outputs collapse to [] when their feature flag is off.
rule all:
    input:
        # Output of the multiqc_report rule.
        qc=rules.multiqc_report.output[0],
        report="capcruncher_output/results/capcruncher_report.html",
        # Pileup targets as selected by get_pileups from the assay, design and flags.
        pileups=capcruncher.pipeline.utils.get_pileups(
            assay=ASSAY,
            design=DESIGN,
            samples_aggregate=AGGREGATE_SAMPLES,
            samples_compare=COMPARE_SAMPLES,
            sample_names=SAMPLE_NAMES,
            summary_methods=SUMMARY_METHODS,
            viewpoints=VIEWPOINT_NAMES,
        ),
        # One HDF5 file per sample.
        counts=expand(
            "capcruncher_output/results/{sample}/{sample}.hdf5",
            sample=SAMPLE_NAMES,
        ),
        # UCSC hub only for "capture" and "tri" analyses.
        hub=rules.create_ucsc_hub.output[0]
        if ANALYSIS_METHOD in ["capture", "tri"]
        else [],
        # One differential-results path per viewpoint, when differential analysis is enabled.
        differential=expand(
            "capcruncher_output/results/differential/{viewpoint}",
            viewpoint=VIEWPOINT_NAMES,
        )
        if PERFORM_DIFFERENTIAL_ANALYSIS
        else [],
        # One PDF figure per viewpoint, when plotting is possible.
        plots=expand(
            "capcruncher_output/results/figures/{viewpoint}.pdf",
            viewpoint=VIEWPOINT_NAMES,
        )
        if PERFORM_PLOTTING
        else [],


onerror:
    # Copy Snakemake's run log to a fixed name in the working directory so the
    # failure details are easy to find, then point the user at it.
    error_log = "capcruncher_error.log"
    shutil.copyfile(log, error_log)
    message = f"An error occurred. Please check the log file {error_log} for more information."
    print(message)


onsuccess:
    # Copy Snakemake's run log to a fixed name in the working directory.
    log_out = "capcruncher.log"
    shutil.copyfile(log, log_out)
    print(f"Pipeline completed successfully. See {log_out} for more information.")

    # Optional removal of interim files, selected by CLEANUP:
    #   "full"    - delete the whole interim directory;
    #   "partial" - delete only the {sample}_part*.fastq.gz files under the split,
    #               flashed and rebalanced FASTQ directories.
    # Both modes are skipped when the interim directory does not exist, so this
    # handler cannot fail with FileNotFoundError after the pipeline has succeeded.
    interim_dir = pathlib.Path("capcruncher_output/interim")
    if interim_dir.exists():
        if CLEANUP == "full":
            shutil.rmtree(interim_dir)

        elif CLEANUP == "partial":
            fastq_dir = interim_dir / "fastq"
            files_to_remove = []
            for sample_name in SAMPLE_NAMES:
                part_pattern = f"{sample_name}_part*.fastq.gz"
                sample_dirs = [
                    fastq_dir / "split" / sample_name,
                    fastq_dir / "flashed" / sample_name,
                    fastq_dir / "rebalanced" / sample_name / "flashed",
                    fastq_dir / "rebalanced" / sample_name / "pe",
                ]
                for sample_dir in sample_dirs:
                    files_to_remove.extend(sample_dir.glob(part_pattern))
            # Collect every match first, then delete, so no directory is modified
            # while it is still being globbed.
            for fastq in files_to_remove:
                fastq.unlink()
