import numpy as np
from sklearn.base import RegressorMixin
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_array
from sklearn.utils.validation import check_is_fitted

from skgrf import grf
from skgrf.ensemble.base import BaseGRFForest
from skgrf.tree.local_linear_regressor import GRFTreeLocalLinearRegressor
from skgrf.utils.validation import check_sample_weight


class GRFForestLocalLinearRegressor(BaseGRFForest, RegressorMixin):
    r"""GRF Local Linear Regression implementation for scikit-learn.

    Provides a sklearn regressor interface to the GRF C++ library using Cython.

    .. warning::

        Because the training dataset is required for prediction, the training dataset
        is recorded onto the estimator instance. This means that serializing this
        estimator will result in a file at least as large as the serialized training
        dataset.

    :param int n_estimators: The number of tree regressors to train
    :param bool ll_split_weight_penalty: Use a covariance ridge penalty if using local
        linear splits.
    :param float ll_split_lambda: Ridge penalty for splitting.
    :param list(int) ll_split_variables: Linear correction variables for splitting. Uses
        all variables if not specified.
    :param float ll_split_cutoff: Leaf size after which the overall beta is used. If
        unspecified, default is sqrt of num samples. Passing 0 means no cutoff.
    :param bool equalize_cluster_weights: Weight the samples such that clusters have
        equal weight. If ``False``, larger clusters will have more weight. If
        ``True``, the number of samples drawn from each cluster is equal to the size of
        the smallest cluster. If ``True``, sample weights should not be passed on
        fitting.
    :param float sample_fraction: Fraction of samples used in each tree. If
        ``ci_group_size`` > 1, the max allowed fraction is 0.5
    :param int mtry: The number of features to split on each node. The default is
        ``sqrt(p) + 20`` where ``p`` is the number of features.
    :param int min_node_size: The minimum number of observations in each tree leaf.
    :param bool honesty: Use honest splitting (subsample splitting).
    :param float honesty_fraction: The fraction of data used for subsample splitting.
    :param bool honesty_prune_leaves: Prune estimation sample tree such that no leaves
        are empty. If ``False``, trees with empty leaves are skipped.
    :param float alpha: The maximum imbalance of a split.
    :param float imbalance_penalty: Penalty applied to imbalanced splits.
    :param int ci_group_size: The quantity of trees grown on each subsample. At least 2
        is required to provide confidence intervals.
    :param int n_jobs: The number of threads. Default is number of CPU cores.
    :param int seed: Random seed value.
    :param bool enable_tree_details: When ``True``, perform additional calculations
        for detailing the underlying decision trees. Must be enabled for ``estimators_``
        and ``get_estimator`` to work. Very slow.

    :ivar list estimators\_: A list of tree objects from the forest.
    :ivar int n_features_in\_: The number of features (columns) from the fit input
        ``X``.
    :ivar dict grf_forest\_: The returned result object from calling C++ grf.
    :ivar int mtry\_: The ``mtry`` value determined by validation.
    :ivar int outcome_index\_: The index of the grf train matrix holding the outcomes.
    :ivar list clusters\_: The cluster labels determined from the fit input ``cluster``.
    :ivar int n_clusters\_: The number of unique cluster labels from the fit input
        ``cluster``.
    :ivar samples_per_cluster\_: The value returned by cluster-weight validation and
        passed to grf as the samples-per-cluster setting.
    :ivar array2d train\_: The ``X,y`` concatenated train matrix passed to grf.
    :ivar list ll_split_variables\_: The resolved linear correction variables (all
        feature indices when ``ll_split_variables`` is ``None``).
    :ivar ll_split_cutoff\_: The resolved leaf-size cutoff (``floor(sqrt(n_samples))``
        when ``ll_split_cutoff`` is ``None``).
    :ivar array1d overall_beta\_: Ridge coefficients of ``y`` on ``[1, X]`` over the
        whole training set (intercept first). Empty when ``ll_split_cutoff_`` is not
        positive.
    :ivar str criterion: The criterion used for splitting: ``mse``
    """

    def __init__(
        self,
        n_estimators=100,
        ll_split_weight_penalty=False,
        ll_split_lambda=0.1,
        ll_split_variables=None,
        ll_split_cutoff=None,
        equalize_cluster_weights=False,
        sample_fraction=0.5,
        mtry=None,
        min_node_size=5,
        honesty=True,
        honesty_fraction=0.5,
        honesty_prune_leaves=True,
        alpha=0.05,
        imbalance_penalty=0,
        ci_group_size=2,
        n_jobs=-1,
        seed=42,
        enable_tree_details=False,
    ):
        self.n_estimators = n_estimators
        self.ll_split_weight_penalty = ll_split_weight_penalty
        self.ll_split_lambda = ll_split_lambda
        self.ll_split_variables = ll_split_variables
        self.ll_split_cutoff = ll_split_cutoff
        self.equalize_cluster_weights = equalize_cluster_weights
        self.sample_fraction = sample_fraction
        self.mtry = mtry
        self.min_node_size = min_node_size
        self.honesty = honesty
        self.honesty_fraction = honesty_fraction
        self.honesty_prune_leaves = honesty_prune_leaves
        self.alpha = alpha
        self.imbalance_penalty = imbalance_penalty
        self.ci_group_size = ci_group_size
        self.n_jobs = n_jobs
        self.seed = seed
        self.enable_tree_details = enable_tree_details

    @property
    def criterion(self):
        return "mse"

    @property
    def estimators_(self):
        """List of all trees in the fitted forest.

        Raises ``AttributeError`` when the forest is not fitted (so ``hasattr``
        behaves like a normal fitted attribute) and ``ValueError`` when the forest
        was trained without ``enable_tree_details``.
        """
        try:
            check_is_fitted(self)
        except NotFittedError:
            raise AttributeError(
                f"{self.__class__.__name__} object has no attribute 'estimators_'"
            ) from None
        if not self.enable_tree_details:
            raise ValueError("enable_tree_details must be True prior to training")
        return [
            GRFTreeLocalLinearRegressor.from_forest(self, idx=idx)
            for idx in range(self.n_estimators)
        ]

    def get_estimator(self, idx):
        """Extract a single estimator tree from the forest.

        :param int idx: The index of the tree to extract.
        """
        check_is_fitted(self)
        if not self.enable_tree_details:
            raise ValueError("enable_tree_details must be True prior to training")
        return GRFTreeLocalLinearRegressor.from_forest(self, idx=idx)

    def fit(self, X, y, sample_weight=None, cluster=None):
        """Fit the grf forest using training data.

        :param array2d X: training input features
        :param array1d y: training input targets
        :param array1d sample_weight: optional weights for input samples
        :param array1d cluster: optional cluster assignments for input samples
        :returns: the fitted estimator (``self``)
        """
        X, y = self._validate_data(X, y)
        self._check_num_samples(X)
        self._check_n_features(X, reset=True)

        self._check_sample_fraction()
        self._check_alpha()

        sample_weight, use_sample_weight = check_sample_weight(sample_weight, X)

        cluster = self._check_cluster(X=X, cluster=cluster)
        self.samples_per_cluster_ = self._check_equalize_cluster_weights(
            cluster=cluster, sample_weight=sample_weight
        )
        self.mtry_ = self._check_mtry(X=X)

        train_matrix = self._create_train_matrices(X, y, sample_weight=sample_weight)
        # Kept on the instance because prediction needs the training data.
        self.train_ = train_matrix

        if self.ll_split_variables is None:
            self.ll_split_variables_ = list(range(X.shape[1]))
        else:
            self.ll_split_variables_ = self.ll_split_variables

        if self.ll_split_cutoff is None:
            self.ll_split_cutoff_ = int(X.shape[0] ** 0.5)
        else:
            self.ll_split_cutoff_ = self.ll_split_cutoff

        self.overall_beta_ = self._compute_overall_beta(X, y)

        self.grf_forest_ = grf.ll_regression_train(
            np.asfortranarray(train_matrix.astype("float64")),
            self.outcome_index_,
            self.sample_weight_index_,
            self.ll_split_lambda,
            self.ll_split_weight_penalty,
            self.ll_split_variables_,
            self.ll_split_cutoff_,
            self.overall_beta_,
            use_sample_weight,
            self.mtry_,
            self.n_estimators,  # num_trees
            self.min_node_size,
            self.sample_fraction,
            self.honesty,
            self.honesty_fraction,
            self.honesty_prune_leaves,
            self.ci_group_size,
            self.alpha,
            self.imbalance_penalty,
            cluster,
            self.samples_per_cluster_,
            self._get_num_threads(),  # num_threads,
            self.seed,
        )
        self._ensure_ptr()

        if self.enable_tree_details:
            # Unweighted fits use unit weights for the per-node summaries.
            sample_weight = (
                sample_weight if sample_weight is not None else np.ones(len(X))
            )
            self._set_node_values(y, sample_weight)
            self._set_n_classes()

        return self

    def _compute_overall_beta(self, X, y):
        """Compute the ridge regression coefficients of ``y`` on ``[1, X]``.

        The intercept (first coefficient) is not penalized. This is the "overall
        beta" passed to grf alongside ``ll_split_cutoff_``.

        :param array2d X: validated training features
        :param array1d y: validated training targets
        :returns: 1-d array of length ``X.shape[1] + 1``, or an empty array when
            ``ll_split_cutoff_`` is not positive.
        """
        if self.ll_split_cutoff_ > 0:
            n_coefs = X.shape[1] + 1
            design = np.concatenate([np.ones((X.shape[0], 1)), X], axis=1)
            # Ridge penalty on every coefficient except the intercept.
            penalty = np.eye(n_coefs)
            penalty[0, 0] = 0
            gram = design.T @ design + self.ll_split_lambda * penalty
            # NOTE(review): beta is fit on all features even when
            # ``ll_split_variables_`` is a subset -- confirm this matches what the
            # grf C++ side expects.
            # Solve the normal equations directly instead of forming an explicit
            # inverse: cheaper and numerically better conditioned.
            return np.linalg.solve(gram, design.T @ y)
        return np.empty((0,), dtype=float, order="F")

    def predict(self, X):
        """Predict regression target for X.

        :param array2d X: prediction input features
        :returns: array of predictions, at least 1-dimensional
        """
        return np.atleast_1d(np.squeeze(np.array(self._predict(X)["predictions"])))

    def _predict(self, X, estimate_variance=False):
        """Run grf local linear prediction and return the raw result object.

        :param array2d X: prediction input features
        :param bool estimate_variance: whether grf should also estimate variances
        """
        check_is_fitted(self)
        X = check_array(X)
        self._check_n_features(X, reset=False)
        self._ensure_ptr()

        result = grf.ll_regression_predict(
            self.grf_forest_cpp_,
            np.asfortranarray(self.train_.astype("float64")),  # train_matrix
            self.outcome_index_,
            np.asfortranarray(X.astype("float64")),  # test_matrix
            [self.ll_split_lambda],  # ll_lambda
            self.ll_split_weight_penalty,  # ll_weight_penalty
            self.ll_split_variables_,  # linear_correction_variables
            self._get_num_threads(),
            estimate_variance,  # estimate variance
        )
        return result

    def _more_tags(self):
        return {
            "requires_y": True,
            "_xfail_checks": {
                "check_sample_weights_invariance": "zero sample_weight is not equivalent to removing samples",
            },
        }
