from typing import Callable, Iterable

import numpy as np

from .em import EMRun
from ..core.logs import logger
from ..core.mu import (calc_sum_arcsine_distance,
                       calc_mean_arcsine_distance,
                       calc_pearson)

# Value stored in EMRunsK.converged in place of the iteration count for
# any EM run that did not converge.
NOCONV = 0


def get_common_k(runs: list[EMRun]):
    """ Return the number of clusters (k) shared by all EM runs.

    Raise ValueError unless exactly one distinct k occurs among `runs`
    (so an empty list of runs also raises). """
    distinct_ks: list[int] = sorted(set(run.k for run in runs))
    if len(distinct_ks) == 1:
        return distinct_ks[0]
    raise ValueError(
        f"Expected 1 unique number of clusters, but got {distinct_ks}"
    )


def sort_runs(runs: list[EMRun]):
    """ Order EM runs from best (largest log likelihood) to worst.

    Raise ValueError (via get_common_k) if the runs do not all have the
    same number of clusters, since likelihoods of runs with different
    numbers of clusters are not directly comparable. """
    if runs:
        # Only validate; the common k itself is not needed here.
        get_common_k(runs)

    def by_log_like(em_run):
        return em_run.log_like

    return sorted(runs, key=by_log_like, reverse=True)


class EMRunsK(object):
    """ One or more EM runs with the same number of clusters. """

    def __init__(self,
                 runs: list[EMRun],
                 max_pearson_run: float,
                 min_marcd_run: float,
                 max_jackpot_quotient: float,
                 max_loglike_vs_best: float,
                 min_pearson_vs_best: float,
                 max_marcd_vs_best: float):
        if not runs:
            raise ValueError(f"{self} got no EM runs")
        # Sort the runs from largest to smallest likelihood.
        runs = sort_runs(runs)
        # Flag runs that fail to meet the filters.
        self.max_pearson_run = max_pearson_run
        self.min_marcd_run = min_marcd_run
        self.max_jackpot_quotient = max_jackpot_quotient
        # Set the criteria for whether this number of clusters passes.
        self.max_loglike_vs_best = max_loglike_vs_best
        self.min_pearson_vs_best = min_pearson_vs_best
        self.max_marcd_vs_best = max_marcd_vs_best
        # Check whether each run shows signs of being overclustered.
        # To select only the valid runs, use "not" with the opposite of
        # the desired inequality because runs with just one cluster will
        # produce NaN values, which should always compare as True here.
        self.run_not_overclustered = np.array(
            [not (run.max_pearson > max_pearson_run
                  or run.min_marcd < min_marcd_run)
             for run in runs]
        )
        # Check whether each run shows signs of being underclustered.
        self.run_not_underclustered = np.array(
            [not (run.jackpot_quotient > max_jackpot_quotient)
             for run in runs]
        )
        # Select the best run.
        self.best = runs[self.best_index()]
        # Number of runs.
        self.n_runs_total = len(runs)
        # Number of clusters (K).
        self.k = get_common_k(runs)
        # Number of iterations until convergenge for each run.
        self.converged = np.array([run.iter if run.converged else NOCONV
                                   for run in runs])
        # Log-likelihood of each run.
        self.log_likes = np.array([run.log_like for run in runs])
        # BIC of each run.
        self.bics = np.array([run.bic for run in runs])
        # Jackpotting quotient of each run.
        self.jackpot_quotients = np.array([run.jackpot_quotient
                                           for run in runs])
        # Minimum MARCD between any two clusters in each run.
        self.min_marcds = np.array([run.min_marcd for run in runs])
        # Maximum correlation between any two clusters in each run.
        self.max_pearsons = np.array([run.max_pearson for run in runs])
        # MARCD between each run and the best run.
        self.marcds_vs_best = np.array(
            [calc_mean_arcsine_distance_clusters(run.mus.values,
                                                 self.best.mus.values)
             for run in runs]
        )
        # Correlation between each run and the best run.
        self.pearsons_vs_best = np.array(
            [calc_mean_pearson_clusters(run.mus.values,
                                        self.best.mus.values)
             for run in runs]
        )

    def run_passing(self, allow_underclustered: bool = False):
        """ Whether each run passed the filters. """
        if allow_underclustered:
            return self.run_not_overclustered
        return self.run_not_overclustered & self.run_not_underclustered

    def n_runs_passing(self, **kwargs):
        """ Number of runs passing the filters. """
        return int(np.count_nonzero(self.run_passing(**kwargs)))

    def get_valid_index(self, i: int | list[int] | np.ndarray, **kwargs):
        """ Index(es) of valid run number(s) `i`. """
        return np.flatnonzero(self.run_passing(**kwargs))[i]

    def best_index(self, **kwargs) -> int:
        """ Index of the best valid run. """
        try:
            # The best run is the valid run with the largest likelihood.
            return self.get_valid_index(0, **kwargs)
        except IndexError:
            # If no runs are valid, then use the best invalid run.
            logger.warning(f"{self} got no EM runs that passed all filters")
            return 0

    def subopt_indexes(self, **kwargs):
        """ Indexes of the valid suboptimal runs. """
        return self.get_valid_index(np.arange(1, self.n_runs_passing(**kwargs)),
                                    **kwargs)

    def loglike_vs_best(self, **kwargs):
        """ Log likelihood difference between the best and second-best
        runs. """
        try:
            index1, index2 = self.get_valid_index([0, 1], **kwargs)
            return float(self.log_likes[index1] - self.log_likes[index2])
        except IndexError:
            return np.nan

    def pearson_vs_best(self, **kwargs):
        """ Maximum Pearson correlation between the best run and any
        other run. """
        try:
            return float(np.max(
                self.pearsons_vs_best[self.subopt_indexes(**kwargs)]
            ))
        except ValueError:
            return np.nan

    def marcd_vs_best(self, **kwargs):
        """ Minimum MARCD between the best run and any other run. """
        try:
            return float(np.min(
                self.marcds_vs_best[self.subopt_indexes(**kwargs)]
            ))
        except ValueError:
            return np.nan

    def _n_min_runs_passing(self, **kwargs):
        n_runs_passing = self.n_runs_passing(**kwargs)
        min_runs_passing = min(self.n_runs_total, 2)
        return n_runs_passing, min_runs_passing

    def enough_runs_passing(self, **kwargs):
        """ Whether enough runs passed. """
        n_runs_passing, min_runs_passing = self._n_min_runs_passing(**kwargs)
        return n_runs_passing >= min_runs_passing

    def passing(self, **kwargs):
        """ Whether this number of clusters passes the filters. """
        n_runs_passing, min_runs_passing = self._n_min_runs_passing(**kwargs)
        if n_runs_passing < min_runs_passing:
            logger.detail(f"{self} did not pass: {n_runs_passing} runs passed, "
                          f"but needed {min_runs_passing}")
            return False
        # Make sure that if any attribute is NaN, the run will still be
        # able to pass; this can be done by requiring each inequality
        # to be True in order to not pass (since < and > will be False
        # if one side is NaN).
        loglike_vs_best = self.loglike_vs_best(**kwargs)
        if loglike_vs_best > self.max_loglike_vs_best > 0.:
            logger.detail(f"{self} did not pass: difference between 1st/2nd "
                          f"log likelihoods is {loglike_vs_best}, but needed "
                          f"to be ≤ {self.max_loglike_vs_best}")
            return False
        pearson_vs_best = self.pearson_vs_best(**kwargs)
        if pearson_vs_best < self.min_pearson_vs_best:
            logger.detail(f"{self} did not pass: Pearson correlation between "
                          f"best run and any other run is {pearson_vs_best}, "
                          f"but needed to be ≥ {self.min_pearson_vs_best}")
            return False
        marcd_vs_best = self.marcd_vs_best(**kwargs)
        if marcd_vs_best > self.max_marcd_vs_best:
            logger.detail(f"{self} did not pass: MARCD between best run and "
                          f"any other run is {marcd_vs_best}, but needed to "
                          f"be ≤ {self.max_marcd_vs_best}")
            return False
        logger.detail(f"{self} passed all filters using {kwargs}")
        return True

    def summarize(self, **kwargs):
        """ Summarize the results of the runs. """
        lines = [f"EM runs for K={self.k}",
                 "\nPARAMETERS\n"]
        for attr in ["max_pearson_run",
                     "min_marcd_run",
                     "max_jackpot_quotient",
                     "max_loglike_vs_best",
                     "min_pearson_vs_best",
                     "max_marcd_vs_best"]:
            lines.append(f"{attr} = {getattr(self, attr)}")
        lines.append("\nRUNS\n")
        for attr in ["n_runs_total",
                     "converged",
                     "log_likes",
                     "bics",
                     "jackpot_quotients",
                     "min_marcds",
                     "max_pearsons",
                     "marcds_vs_best",
                     "pearsons_vs_best",
                     "run_not_overclustered",
                     "run_not_underclustered"]:
            lines.append(f"{attr} = {getattr(self, attr)}")
        lines.append("\nPASSING\n")
        for attr in ["run_passing",
                     "n_runs_passing",
                     "best_index",
                     "loglike_vs_best",
                     "pearson_vs_best",
                     "marcd_vs_best",
                     "enough_runs_passing",
                     "passing"]:
            func = getattr(self, attr)
            lines.append(f"{attr} = {func(**kwargs)}")
        return "\n".join(lines)


def find_best_k(ks: Iterable[EMRunsK], **kwargs):
    """ Choose the number of clusters (K) with the smallest BIC among
    those that pass the filters.

    Every K except the largest is checked with all filters; the largest
    K is checked with `kwargs` (e.g. to allow underclustering, so that it
    can be identified as the best number so far). Return 0 if `ks` is
    empty or if no K passes. """
    by_k = sorted(ks, key=lambda runs: runs.k)
    if len(by_k) == 0:
        logger.warning("No numbers of clusters exist")
        return 0
    *smaller, largest = by_k
    candidates = [runs for runs in smaller if runs.passing()]
    if largest.passing(**kwargs):
        candidates.append(largest)
    if not candidates:
        logger.warning("No numbers of clusters pass the filters")
        return 0
    # The stable sort keeps the smaller K first when BICs tie.
    best_runs = sorted(candidates, key=lambda runs: runs.best.bic)[0]
    return best_runs.k


def _compare_groups(func: Callable, mus1: np.ndarray, mus2: np.ndarray):
    """ Apply a comparison function to every pairing of a cluster from
    group 1 (a column of `mus1`) with a cluster from group 2 (a column of
    `mus2`).

    Return a (k1, k2) array whose element [i, j] is
    func(mus1[:, i], mus2[:, j]). Raise ValueError if the two groups have
    different numbers of positions (rows). """
    num_pos1, num_clust1 = mus1.shape
    num_pos2, num_clust2 = mus2.shape
    if num_pos1 != num_pos2:
        raise ValueError(
            f"Numbers of positions in groups 1 ({num_pos1}) and 2 "
            f"({num_pos2}) differ"
        )
    table = []
    for i in range(num_clust1):
        column1 = mus1[:, i]
        table.append([func(column1, mus2[:, j]) for j in range(num_clust2)])
    # Reshape so the result is 2D even if group 1 has no clusters.
    return np.array(table).reshape((num_clust1, num_clust2))


def assign_clusterings(mus1: np.ndarray, mus2: np.ndarray):
    """ Optimally pair each cluster of group 1 with a cluster of group 2.

    The pairing minimizes the total summed arcsine distance between the
    paired clusters (columns of `mus1` and `mus2`). Return (rows, cols)
    such that cluster rows[i] of group 1 is paired with cluster cols[i]
    of group 2. Raise ValueError if the groups differ in their numbers of
    positions or clusters. """
    (n1, k1), (n2, k2) = mus1.shape, mus2.shape
    if n1 != n2:
        raise ValueError(
            f"Numbers of positions in groups 1 ({n1}) and 2 ({n2}) differ"
        )
    if k1 != k2:
        raise ValueError(
            f"Numbers of clusters in groups 1 ({k1}) and 2 ({k2}) differ"
        )
    if n1 < 1:
        # With no positions, every cost would be NaN, which would make
        # linear_sum_assignment raise; fall back to the identity pairing.
        rows = np.arange(k1)
        cols = np.arange(k1)
    else:
        costs = _compare_groups(calc_sum_arcsine_distance, mus1, mus2)
        assert costs.shape == (k1, k2)
        from scipy.optimize import linear_sum_assignment
        rows, cols = linear_sum_assignment(costs)
    assert np.array_equal(rows, np.arange(k1))
    assert rows.shape == cols.shape
    return rows, cols


def calc_mean_arcsine_distance_clusters(mus1: np.ndarray, mus2: np.ndarray):
    """ Average, over the optimally paired clusters of two groups, of the
    mean arcsine distance (MARCD) between each pair of clusters.

    Clusters are paired with assign_clusterings. """
    rows, cols = assign_clusterings(mus1, mus2)
    pairwise = _compare_groups(calc_mean_arcsine_distance, mus1, mus2)
    # Keep only the entries for clusters that were paired together.
    paired = pairwise[rows, cols]
    return float(np.mean(paired))


def calc_mean_pearson_clusters(mus1: np.ndarray, mus2: np.ndarray):
    """ Average Pearson correlation between the paired clusters of two
    groups.

    Clusters are paired with assign_clusterings (which pairs them by
    arcsine distance, not by correlation). """
    rows, cols = assign_clusterings(mus1, mus2)
    pairwise = _compare_groups(calc_pearson, mus1, mus2)
    # Keep only the entries for clusters that were paired together.
    paired = pairwise[rows, cols]
    return float(np.mean(paired))
