import numpy as np
from scipy.spatial.ckdtree import cKDTree


class DecisionMaking:
    """Base class for decision-making procedures on an objective matrix ``F``.

    This class only stores its constructor settings and forwards ``do`` to
    ``_do``. ``_do`` is not defined here and must be provided by a subclass.
    """

    def __init__(self, normalize=True, ideal_point=None, nadir_point=None) -> None:
        super().__init__()
        # settings are stored as-is; this base class never reads them itself
        self.normalize = normalize
        self.ideal_point = ideal_point
        self.nadir_point = nadir_point

    def do(self, F, *args, **kwargs):
        """Run the subclass procedure on ``F``.

        Extra positional and keyword arguments are passed through unchanged
        to ``_do``. Its return value is returned unchanged.
        """
        result = self._do(F, *args, **kwargs)
        return result


def normalize(F, ideal_point=None, nadir_point=None, estimate_bounds_if_none=True, return_bounds=False):
    """Shift ``F`` by the ideal point and scale it by the ideal-nadir range.

    Parameters
    ----------
    F : np.ndarray
        Matrix to normalize. Bounds are estimated column-wise (``axis=0``), so
        one row per point and one column per objective is expected.
    ideal_point, nadir_point : array-like or None
        Lower and upper reference points. Lists/tuples are accepted. A missing
        bound is estimated as the column-wise min/max of ``F`` when
        ``estimate_bounds_if_none`` is true. Otherwise the corresponding step
        (shift or scale) is skipped.
    estimate_bounds_if_none : bool
        Whether to estimate missing bounds from ``F``.
    return_bounds : bool
        If true, also return the scaling vector and both bounds.

    Returns
    -------
    N or (N, norm, ideal_point, nadir_point)
        ``N`` is a new float array; ``F`` itself is not modified. ``norm`` is
        the divisor used per column (all ones if no nadir point is available).

    Raises
    ------
    Exception
        If any component of the range is smaller than 1e-8.
    """
    F = np.asarray(F)

    # ``np.float`` was removed in NumPy 1.24; the builtin ``float`` maps to the same float64 dtype.
    # ``astype`` returns a copy, so the in-place operations below never touch F.
    N = F.astype(float)

    if estimate_bounds_if_none:
        if ideal_point is None:
            ideal_point = np.min(F, axis=0)
        if nadir_point is None:
            nadir_point = np.max(F, axis=0)

    if ideal_point is not None:
        # accept lists/tuples so the subtraction below is element-wise
        ideal_point = np.asarray(ideal_point)
        N -= ideal_point

    if nadir_point is not None:
        nadir_point = np.asarray(nadir_point)

        # calculate the norm for each objective; without an ideal point N was not
        # shifted, so the range is measured from zero
        norm = nadir_point if ideal_point is None else nadir_point - ideal_point

        # check if normalization makes sense
        if np.any(norm < 1e-8):
            raise Exception("Normalization failed because the range between the ideal and nadir point is not "
                            "large enough.")

        N /= norm

    else:
        norm = np.ones(F.shape[1])

    if return_bounds:
        return N, norm, ideal_point, nadir_point
    return N


class NeighborFinder:
    """Find the neighbours of individual rows of a matrix ``N``.

    If ``N`` has exactly two columns and ``consider_2d`` is enabled, the
    neighbours of a point are its direct predecessor and successor when all
    points are sorted by the first column. Otherwise a ``cKDTree`` over all
    rows is queried in one of two ways:

    - every point within ``epsilon`` of the query point, or
    - if ``epsilon`` is None, its ``n_neighbors`` nearest points.

    In these tree-based modes the query point itself is part of the result.
    """

    def __init__(self, N,
                 epsilon=0.125,
                 n_neighbors=None,
                 n_min_neigbors=None,
                 consider_2d=True):
        """
        Parameters
        ----------
        N : np.ndarray
            Two-dimensional matrix, one row per point.
        epsilon : float or None
            Radius of the ball query. Takes precedence over ``n_neighbors``.
        n_neighbors : int or None
            Number of nearest neighbours (in addition to the query point)
            used when ``epsilon`` is None.
        n_min_neigbors : "auto", int or None
            Minimum size of the result in the tree-based modes. If fewer
            points are found, the nearest ``n_min_neigbors + 1`` points are
            returned instead (capped at the number of points). "auto" means
            two times the number of columns. None disables the minimum.
        consider_2d : bool
            Use the sorting-based lookup when ``N`` has exactly two columns.

        Raises
        ------
        Exception
            If ``N`` has only one column.
        """
        super().__init__()
        self.N = N
        self.consider_2d = consider_2d

        _, n_dim = N.shape

        if n_min_neigbors == "auto":
            # at least find dimensionality times two neighbors
            self.n_min_neigbors = 2 * n_dim
        elif isinstance(n_min_neigbors, (int, np.integer)) and not isinstance(n_min_neigbors, bool):
            # an explicitly requested minimum (previously silently ignored)
            self.n_min_neigbors = int(n_min_neigbors)
        else:
            # np.inf marks the minimum-neighbour requirement as disabled
            self.n_min_neigbors = np.inf

        # either choose epsilon
        self.epsilon = epsilon

        # if none choose the number of neighbors
        self.n_neighbors = n_neighbors

        if self.N.shape[1] == 1:
            raise Exception("At least 2 objectives must be provided.")

        elif self.consider_2d and self.N.shape[1] == 2:
            self.min, self.max = N.min(), N.max()
            # rank: point indices sorted by the first column;
            # pos_in_rank: inverse permutation (position of each point in rank)
            self.rank = np.argsort(N[:, 0])
            self.pos_in_rank = np.argsort(self.rank)

        else:
            self.tree = cKDTree(N)

    def _k_nearest(self, i, k):
        """Return the indices of the (at most) ``k`` points closest to point ``i``, nearest first."""
        # cKDTree pads missing neighbours with the invalid index len(N) when k exceeds the number of points
        k = int(min(k, len(self.N)))
        _, idx = self.tree.query(self.N[i], k=k)
        # for k == 1 the query returns a scalar instead of an array
        return np.atleast_1d(idx).tolist()

    def find(self, i):
        """Return the indices of the neighbours of point ``i`` as a list.

        Raises
        ------
        Exception
            In the tree-based mode, if both ``epsilon`` and ``n_neighbors`` are None.
        """
        if self.consider_2d and self.N.shape[1] == 2:
            # predecessor and successor in the ordering by the first column
            neighbours = []

            pos = self.pos_in_rank[i]
            if pos > 0:
                neighbours.append(self.rank[pos - 1])
            if pos < len(self.N) - 1:
                neighbours.append(self.rank[pos + 1])

            return neighbours

        # for each neighbour in a specific radius of that solution
        if self.epsilon is not None:
            neighbours = self.tree.query_ball_point([self.N[i]], self.epsilon).tolist()[0]
        elif self.n_neighbors is not None:
            # +1 leaves room for the query point itself (distance 0)
            neighbours = self._k_nearest(i, self.n_neighbors + 1)
        else:
            raise Exception("Either define epsilon or number of neighbors.")

        # only when n_min_neigbors is enabled: comparing against the np.inf sentinel
        # would always be true and trigger a query with k=inf
        if np.isfinite(self.n_min_neigbors) and len(neighbours) < self.n_min_neigbors:
            neighbours = self._k_nearest(i, self.n_min_neigbors + 1)

        return neighbours


def find_outliers_upper_tail(mu):
    """Return the indices of unusually large values in the 1-D array ``mu``.

    Non-finite entries (nan, +inf, -inf) are ignored. Among the remaining
    values, every value at least two standard deviations above their mean is
    an outlier. If there is no such value but the largest value is more than
    one standard deviation above the mean, only its index is returned.

    Returns
    -------
    np.ndarray of int or None
        Indices into the original ``mu``. Returns None if no outlier is
        found, which includes the cases where ``mu`` has no finite entries or
        all finite entries are equal.
    """
    mu = np.asarray(mu)

    # indices of the finite entries (nan and +/-inf are excluded)
    I = np.flatnonzero(np.isfinite(mu))
    values = mu[I]

    # nothing to compare against (the original code raised on deviation.max() here)
    if len(values) == 0:
        return None

    # calculate mean and sigma
    mean, sigma = values.mean(), values.std()

    # all finite values are identical: nothing stands out, and dividing would be 0/0
    if sigma == 0:
        return None

    # calculate the deviation in terms of sigma
    deviation = (values - mean) / sigma

    # 2 * sigma is considered as an outlier
    S = I[deviation >= 2]

    # fallback: take the single largest value if it is at least moderately (> 1 sigma) above the mean
    if len(S) == 0 and deviation.max() > 1:
        S = I[[np.argmax(values)]]

    return S if len(S) > 0 else None
