from __future__ import annotations

from pathlib import Path
from snapatac2_scooby._snapatac2 import AnnData, AnnDataSet
import snapatac2_scooby._snapatac2 as _snapatac2
import logging
from snapatac2_scooby.genome import Genome

def macs3(
    adata: AnnData | AnnDataSet,
    *,
    groupby: str | list[str] | None = None,
    qvalue: float = 0.05,
    call_broad_peaks: bool = False,
    broad_cutoff: float = 0.1,
    replicate: str | list[str] | None = None,
    replicate_qvalue: float | None = None,
    max_frag_size: int | None = None,
    selections: set[str] | None = None,
    nolambda: bool = False,
    shift: int = -100,
    extsize: int = 200,
    min_len: int | None = None,
    blacklist: Path | None = None,
    key_added: str = 'macs3',
    tempdir: Path | None = None,
    inplace: bool = True,
    n_jobs: int = 8,
) -> dict[str, 'polars.DataFrame'] | None:
    """Call peaks using MACS3.

    Parameters
    ----------
    adata
        The (annotated) data matrix of shape `n_obs` x `n_vars`.
        Rows correspond to cells and columns to regions.
    groupby
        Group the cells before peak calling. If a `str`, groups are obtained from
        `.obs[groupby]`. If None, peaks will be called for all cells.
    qvalue
        qvalue cutoff used in MACS3.
    call_broad_peaks
        If True, MACS3 will call broad peaks. The broad peak calling process
        utilizes two distinct cutoffs to discern broader, weaker peaks (`broad_cutoff`)
        and narrower, stronger peaks (`qvalue`), which are subsequently nested to
        provide a detailed peak landscape. To conceptualize "nested" peaks, picture
        a gene structure housing regions analogous to exons (strong peaks) and
        introns coupled with UTRs (weak peaks). Please note that, if you only want to
        call "broader" peaks and are not interested in the nested peak structure,
        simply use `qvalue` with a weaker cutoff instead of the `call_broad_peaks` option.
    broad_cutoff
        qvalue cutoff used in MACS3 for calling broad peaks.
    replicate
        Replicate information. If a `str`, replicates are obtained from
        `.obs[replicate]`. If provided, reproducible peaks will be called
        for each group.
    replicate_qvalue
        qvalue cutoff used in MACS3 for calling peaks in replicates.
        This parameter is only used when `replicate` is provided.
        Typically this parameter is used to call peaks in replicates with a more lenient cutoff.
        If not provided, `qvalue` will be used.
    max_frag_size
        Maximum fragment size. If provided, fragments with sizes larger than
        `max_frag_size` will not be used in peak calling.
        This is used in ATAC-seq data to remove fragments that are not
        from nucleosome-free regions.
        You can use :func:`~snapatac2_scooby.pl.frag_size_distr` to choose a proper value for
        this parameter.
    selections
        Call peaks for the selected groups only.
    nolambda
        If True, macs3 will use the background lambda as local lambda.
        This means macs3 will not consider the local bias at peak candidate regions.
    shift
        The shift size in MACS.
    extsize
        The extension size in MACS.
    min_len
        The minimum length of a called peak. If None, it is set to `extsize`.
    blacklist
        Path to the blacklist file in BED format. If provided, regions in the blacklist will be
        removed. Only used when `groupby` is provided.
    key_added
        `.uns` key under which to add the peak information. When `groupby` is None
        the result is stored under `key_added + "_pseudobulk"` instead.
    tempdir
        If provided, a temporary directory will be created in the directory.
        Otherwise, a temporary directory will be created in the system default temporary directory.
        Only used when `groupby` is provided.
    inplace
        Whether to store the result inplace.
    n_jobs
        Number of processes to use for peak calling. Only used when `groupby` is provided.

    Returns
    -------
    dict[str, 'polars.DataFrame'] | None
        If `inplace=True` the result is stored in `adata.uns[key_added]` (or
        `adata.uns[key_added + "_pseudobulk"]` when `groupby` is None) and None
        is returned. Otherwise the result is returned: a dict mapping each group
        to its peaks, or, when `groupby` is None, the single pseudobulk result.

    See Also
    --------
    merge_peaks
    """
    from MACS3.Signal.PeakDetect import PeakDetect
    from math import log
    import tempfile

    if isinstance(groupby, str):
        groupby = list(adata.obs[groupby])
    if isinstance(replicate, str):
        replicate = list(adata.obs[replicate])

    # MACS3's PeakDetect receives its settings through an options object (`opt=`);
    # build a bare object and attach the attributes below.
    options = type('MACS3_OPT', (), {})()
    options.info = lambda _: None
    options.debug = lambda _: None
    # `logging.warn` is a deprecated alias of `logging.warning`.
    options.warn = logging.warning
    options.name = "MACS3"
    options.bdg_treat = 't'
    options.bdg_control = 'c'
    options.cutoff_analysis = False
    options.cutoff_analysis_file = 'a'
    options.store_bdg = False
    options.do_SPMR = False
    options.trackline = False
    options.log_pvalue = None
    options.log_qvalue = log(qvalue, 10) * -1
    options.PE_MODE = False

    options.gsize = adata.uns['reference_sequences']['reference_seq_length'].sum()    # Estimated genome size
    options.maxgap = 30    # The maximum allowed gap between two nearby regions to be merged
    options.minlen = extsize if min_len is None else min_len
    options.shift = shift
    options.nolambda = nolambda
    options.smalllocal = 1000
    options.largelocal = 10000
    options.call_summits = not call_broad_peaks
    options.broad = call_broad_peaks
    if options.broad:
        options.log_broadcutoff = log(broad_cutoff, 10) * -1

    options.fecutoff = 1.0
    options.d = extsize
    options.scanwindow = 2 * options.d

    if groupby is None:
        # Pseudobulk path: peaks are called once over all cells.
        peaks = _snapatac2.call_peaks_bulk(adata, options, max_frag_size)
        if inplace:
            adata.uns[key_added + "_pseudobulk"] = peaks.to_pandas() if not adata.isbacked else peaks
            return None
        return peaks

    with tempfile.TemporaryDirectory(dir=tempdir) as tmpdirname:
        logging.info("Exporting fragments...")
        fragments = _snapatac2.export_tags(adata, tmpdirname, groupby, replicate, max_frag_size, selections)

        def _call_peaks(tags):
            """Call peaks for one group (and its replicates) and keep reproducible ones."""
            import tempfile  # re-imported so the function is self-contained in worker processes
            root_logger = logging.getLogger()
            saved_tempdir = tempfile.tempdir
            # Overwrite the default tempdir used by MACS3 so its temporary files go
            # into the managed directory. It is restored afterwards so the
            # process-wide setting does not keep pointing at a directory that is
            # deleted when the `with` block exits.
            tempfile.tempdir = tmpdirname
            try:
                merged, reps = _snapatac2.create_fwtrack_obj(tags)
                options.log_qvalue = log(qvalue, 10) * -1
                saved_level = root_logger.level
                # A level above CRITICAL suppresses every standard log record while
                # MACS3 runs.
                root_logger.setLevel(logging.CRITICAL + 1)
                try:
                    peakdetect = PeakDetect(treat=merged, opt=options)
                    peakdetect.call_peaks()
                    peakdetect.peaks.filter_fc(fc_low=options.fecutoff)
                    merged = peakdetect.peaks

                    others = []
                    if replicate_qvalue is not None:
                        options.log_qvalue = log(replicate_qvalue, 10) * -1
                    for x in reps:
                        peakdetect = PeakDetect(treat=x, opt=options)
                        peakdetect.call_peaks()
                        peakdetect.peaks.filter_fc(fc_low=options.fecutoff)
                        others.append(peakdetect.peaks)
                finally:
                    # Restore the caller's logging level (even if MACS3 raised)
                    # instead of forcing INFO.
                    root_logger.setLevel(saved_level)
                return _snapatac2.find_reproducible_peaks(merged, others, blacklist)
            finally:
                tempfile.tempdir = saved_tempdir

        logging.info("Calling peaks...")
        if n_jobs == 1:
            peaks = [_call_peaks(x) for x in fragments.values()]
        else:
            peaks = _par_map(_call_peaks, [(x,) for x in fragments.values()], n_jobs)
        peaks = dict(zip(fragments.keys(), peaks))
        if not inplace:
            return peaks
        if adata.isbacked:
            adata.uns[key_added] = peaks
        else:
            adata.uns[key_added] = {k: v.to_pandas() for k, v in peaks.items()}
    return None

def merge_peaks(
    peaks: dict[str, 'polars.DataFrame'],
    chrom_sizes: dict[str, int] | Genome,
    half_width: int = 250,
) -> 'polars.DataFrame':
    """Merge peaks from different groups.

    Combine per-group peak tables into a single set of non-overlapping peaks.
    The usual input is the output of :func:`~snapatac2_scooby.tools.macs3`.

    First, every peak summit is extended by `half_width` on each side. Overlaps
    are then resolved iteratively. The most significant remaining peak (the one
    with the smallest p-value) is kept, and every peak overlapping it is
    discarded. The procedure repeats with the next most significant surviving
    peak until no peaks are left to examine. The result is a list of
    non-overlapping peaks, all with the fixed width produced by the initial
    extension.

    Parameters
    ----------
    peaks
        Mapping from group name to that group's peaks. Values may be polars or
        pandas dataframes; pandas inputs are converted to polars first.
    chrom_sizes
        Chromosome sizes. If a :class:`~snapatac2_scooby.genome.Genome` is given,
        its chromosome sizes are used.
    half_width
        Half width of the merged peaks.

    Returns
    -------
    'polars.DataFrame'
        A dataframe with merged peaks.

    See Also
    --------
    macs3
    """
    import pandas as pd
    import polars as pl

    if isinstance(chrom_sizes, Genome):
        chrom_sizes = chrom_sizes.chrom_sizes

    # The native merger works on polars frames, so convert any pandas inputs.
    as_polars = {}
    for group, table in peaks.items():
        if isinstance(table, pd.DataFrame):
            table = pl.from_pandas(table)
        as_polars[group] = table

    return _snapatac2.py_merge_peaks(as_polars, chrom_sizes, half_width)

def _par_map(mapper, args, nprocs):
    """Apply `mapper(*x)` to every argument tuple `x` in `args` using a process pool.

    Jobs run in a "spawn"-context pool of `nprocs` workers, with a tqdm progress
    bar. Results are returned in the same order as `args`.

    Raises
    ------
    RuntimeError
        If any of the pool's original worker processes is found dead while jobs
        are still pending.
    Any exception raised by `mapper` is re-raised when its result is collected.
    """
    import time
    from multiprocess import get_context
    from tqdm import tqdm

    with get_context("spawn").Pool(nprocs) as pool:
        # Snapshot of the worker processes started by the pool, used for crash detection.
        procs = set(pool._pool)
        pending = [(i, pool.apply_async(mapper, x)) for i, x in enumerate(args)]
        results = [None] * len(pending)
        with tqdm(total=len(pending)) as pbar:
            while pending:
                if any(not p.is_alive() for p in procs):
                    raise RuntimeError("Some worker process has died unexpectedly.")

                still_pending = []
                for i, job in pending:
                    if job.ready():
                        results[i] = job.get()
                        pbar.update(1)
                    else:
                        still_pending.append((i, job))
                pending = still_pending
                # Poll (rather than block on a single job) so that a dead worker is
                # noticed. Skip the final sleep once every job is done.
                if pending:
                    time.sleep(0.5)
        return results