from __future__ import annotations

from abc import ABC, abstractmethod

import optuna
import pandas as pd

from dataclr._typing import DataSplits
from dataclr.metrics import Metric
from dataclr.models import BaseModel
from dataclr.results import Result, ResultPerformance

# Lower Optuna's log level to WARNING so its INFO-level messages are not printed.
# NOTE: this runs at import time and changes the level of Optuna's own library
# logger, so it affects every Optuna study in the process, not just this module.
optuna.logging.set_verbosity(optuna.logging.WARNING)


class Method(ABC):
    """
    A base class for feature selection methods.

    This class defines the structure for methods that integrate with a machine learning
    model to select the best features from a dataset. It includes the main functionality
    for fitting a model and returning results.

    Subclasses must implement :meth:`transform` and :meth:`_get_results` (enforced by
    ``abc.abstractmethod``: a subclass missing either cannot be instantiated).
    Overriding :meth:`fit` is optional; the default is a no-op that returns ``self``.

    Attributes:
        model (:class:`~dataclr.models.BaseModel`): The machine learning model used for
                                             feature evaluation.
        metric (:data:`~dataclr.metrics.Metric`): The metric used to assess feature
                                                importance.
        n_results (int): The number of top features or results to select.
        total_combinations (int): The total number of feature combinations evaluated.
                                  Initialised to 0.
        seed (int): Number determining the randomness. Defaults to 42.
    """

    def __init__(
        self, model: BaseModel, metric: Metric, n_results: int, seed: int = 42
    ) -> None:
        self.model: BaseModel = model
        self.metric: Metric = metric
        self.n_results: int = n_results
        self.seed: int = seed

        # Starts at zero; nothing in this base class increments it, so presumably
        # subclasses update it as they evaluate combinations — confirm.
        self.total_combinations: int = 0

    def fit_transform(
        self,
        X_train: pd.DataFrame,
        X_test: pd.DataFrame,
        y_train: pd.Series,
        y_test: pd.Series,
    ) -> list[Result]:
        """
        Fits the model using the training data and returns results based on the model.

        This method combines the functionality of `fit` and `transform` to perform both
        steps in sequence.

        Args:
            X_train (pd.DataFrame): The training features.
            X_test (pd.DataFrame): The test features.
            y_train (pd.Series): The training target variable.
            y_test (pd.Series): The test target variable.

        Returns:
            list[Result]: A list of results generated by the transformation.
                          Returns an empty list if `fit` raises ``ValueError``.

        Raises:
            Exception: Any exception other than ``ValueError`` raised by `fit`, and any
                       exception raised by `transform`, propagates unchanged.
        """
        try:
            self.fit(X_train, y_train)
        except ValueError:
            # Deliberate best-effort: a method that cannot be fitted on this data
            # contributes no results instead of aborting the caller.
            return []

        return self.transform(X_train, X_test, y_train, y_test)

    def fit(self, X_train: pd.DataFrame, y_train: pd.Series) -> Method:
        """
        Fits the method on the provided training data.

        The default implementation does nothing and returns ``self``, so methods that
        need no fitting step can rely on it. Subclasses that learn state from the
        training data should override it and also return ``self``.

        Args:
            X_train (pd.DataFrame): The training features.
            y_train (pd.Series): The training target variable.

        Returns:
            Method: The instance itself, which allows chaining such as
                    ``method.fit(X, y).transform(...)``.
        """
        return self

    @abstractmethod
    def transform(
        self,
        X_train: pd.DataFrame,
        X_test: pd.DataFrame,
        y_train: pd.Series,
        y_test: pd.Series,
    ) -> list[Result]:
        """
        Returns results based on the fitted model.

        Subclasses must override this method to define their transformation logic;
        because it is declared with ``abc.abstractmethod``, a subclass that does not
        override it cannot be instantiated.

        Args:
            X_train (pd.DataFrame): The training features.
            X_test (pd.DataFrame): The test features.
            y_train (pd.Series): The training target variable.
            y_test (pd.Series): The test target variable.

        Returns:
            list[Result]: A list of results generated by the transformation.
        """
        pass

    @abstractmethod
    def _get_results(
        self,
        data_splits: DataSplits,
        cached_performance: dict[str, ResultPerformance],
    ) -> list[Result]:
        """
        Builds the list of results for this method.

        Subclasses must override this method (enforced by ``abc.abstractmethod``).

        Args:
            data_splits (DataSplits): The train/test data to evaluate on.
                                      NOTE(review): exact structure is defined in
                                      ``dataclr._typing`` — confirm there.
            cached_performance (dict[str, ResultPerformance]): Previously computed
                performances keyed by string, presumably so that already-evaluated
                feature sets are not re-scored — confirm the key format with callers.

        Returns:
            list[Result]: The results produced by the method.
        """
        pass
