from typing import List, Optional, Set

from deprecation import deprecated
from gemd.enumeration.base_enumeration import BaseEnumeration

from citrine._rest.resource import Resource
from citrine._serialization import properties as _properties
from citrine.informatics.data_sources import DataSource
from citrine.informatics.descriptors import Descriptor
from citrine.informatics.predictors import PredictorNode

# Names exported by ``from <this module> import *``.
__all__ = ['AutoMLPredictor', 'AutoMLEstimator']


class AutoMLEstimator(BaseEnumeration):
    """[ALPHA] Algorithms to be used during AutoML model selection.

    * LINEAR corresponds to a linear regression estimator
        (valid for single-task regression problems)
    * RANDOM_FOREST corresponds to a random forest estimator
        (valid for single-task and multi-task regression and classification problems)
    * GAUSSIAN_PROCESS corresponds to a Gaussian process estimator
        (valid for single-task regression and classification problems)
    * SUPPORT_VECTOR_MACHINE corresponds to a support vector machine estimator
        (valid for single-task classification problems)
    * ALL combines all estimator choices (valid for all learning tasks)
    """

    # Each member's string value is identical to its name.
    LINEAR = "LINEAR"
    RANDOM_FOREST = "RANDOM_FOREST"
    GAUSSIAN_PROCESS = "GAUSSIAN_PROCESS"
    SUPPORT_VECTOR_MACHINE = "SUPPORT_VECTOR_MACHINE"
    ALL = "ALL"


class AutoMLPredictor(Resource["AutoMLPredictor"], PredictorNode):
    """A predictor interface that builds a single ML model.

    The model uses the set of inputs to predict the output(s).
    Only one machine learning model is built.

    Parameters
    ----------
    name: str
        name of the configuration
    description: str
        the description of the predictor
    inputs: list[Descriptor]
        Descriptors that represent inputs to the model
    outputs: list[Descriptor]
        Descriptors that represent the output(s) of the model.
    estimators: Optional[Set[AutoMLEstimator]]
        Set of estimators to consider during AutoML model selection.
        If None (or an empty set) is provided, defaults to
        {AutoMLEstimator.RANDOM_FOREST}.
    training_data: Optional[List[DataSource]] (deprecated)
        Sources of training data. Each can be either a CSV or a GEM Table. Candidates from
        multiple data sources will be combined into a flattened list and de-duplicated by uid and
        identifiers. De-duplication is performed if a uid or identifier is shared between two or
        more rows. The content of a de-duplicated row will contain the union of data across all
        rows that share the same uid or at least 1 identifier. Training data is unnecessary if the
        predictor is part of a graph that includes all training data required by this predictor.

    """

    inputs = _properties.List(_properties.Object(Descriptor), 'inputs')
    outputs = _properties.List(_properties.Object(Descriptor), 'outputs')
    estimators = _properties.Set(
        _properties.Enumeration(AutoMLEstimator),
        'estimators',
        default={AutoMLEstimator.RANDOM_FOREST}
    )
    # Backing field for the deprecated ``training_data`` property below;
    # it is still serialized under the public 'training_data' key.
    _training_data = _properties.List(
        _properties.Object(DataSource),
        'training_data',
        default=[]
    )

    # Type tag stored under the 'type' key; declared with deserializable=False.
    typ = _properties.String('type', default='AutoML', deserializable=False)

    def __init__(self,
                 name: str,
                 *,
                 description: str,
                 outputs: List[Descriptor],
                 inputs: List[Descriptor],
                 estimators: Optional[Set[AutoMLEstimator]] = None,
                 training_data: Optional[List[DataSource]] = None) -> None:
        self.name: str = name
        self.description: str = description
        self.inputs: List[Descriptor] = inputs
        # A falsy value (None or an empty set) falls back to the random forest default.
        self.estimators: Set[AutoMLEstimator] = estimators or {AutoMLEstimator.RANDOM_FOREST}
        self.outputs: List[Descriptor] = outputs
        # Only go through the deprecated setter when training data was actually
        # supplied, so constructing a predictor without it does not invoke the
        # deprecated code path.
        if training_data:
            self.training_data: List[DataSource] = training_data

    @property
    @deprecated(deprecated_in="3.5.0", removed_in="4.0.0",
                details="Training data must be accessed through the top-level GraphPredictor.")
    def training_data(self) -> List[DataSource]:
        """[DEPRECATED] Retrieve training data associated with this node."""
        return self._training_data

    @training_data.setter
    @deprecated(deprecated_in="3.5.0", removed_in="4.0.0",
                details="Training data should only be added to the top-level GraphPredictor.")
    def training_data(self, value: List[DataSource]) -> None:
        """[DEPRECATED] Set the training data associated with this node."""
        self._training_data = value

    def __str__(self) -> str:
        return '<AutoMLPredictor {!r}>'.format(self.name)
