# Copyright [2024] Expedia, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=unused-argument
# pylint: disable=invalid-name
# pylint: disable=too-many-ancestors
# pylint: disable=no-member
from typing import List, Optional

import pyspark.sql.functions as F
import tensorflow as tf
from pyspark import keyword_only
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.sql import DataFrame
from pyspark.sql.types import DataType, DoubleType, FloatType

from kamae.spark.params import SingleInputSingleOutputParams
from kamae.spark.utils import single_input_single_output_scalar_transform
from kamae.tensorflow.layers import RoundLayer

from .base import BaseTransformer


class RoundParams(Params):
    """
    Params mixin that adds the roundType parameter used by rounding transformers.
    """

    # Declared against a dummy parent, as is conventional for class-level Params.
    roundType = Param(
        Params._dummy(),
        "roundType",
        "Round type to use in round transform, one of 'floor', 'ceil' or 'round'.",
        typeConverter=TypeConverters.toString,
    )

    def setRoundType(self, value: str) -> "RoundParams":
        """
        Validates and stores the roundType parameter.

        :param value: Name of the rounding operation; must be one of
        'floor', 'ceil' or 'round'.
        :raises ValueError: If value is not one of the supported names.
        :returns: Instance of class mixed in.
        """
        supported_types = ("floor", "ceil", "round")
        if value in supported_types:
            return self._set(roundType=value)
        raise ValueError("roundType must be one of 'floor', 'ceil' or 'round'")

    def getRoundType(self) -> str:
        """
        Reads the roundType parameter (or its default when not explicitly set).

        :returns: Name of the rounding operation,
        one of 'floor', 'ceil' or 'round'.
        """
        round_type = self.getOrDefault(self.roundType)
        return round_type


class RoundTransformer(
    BaseTransformer,
    SingleInputSingleOutputParams,
    RoundParams,
):
    """
    Spark Transformer that rounds a single input column.
    The rounding operation applied ('floor', 'ceil' or 'round') is selected
    by the roundType parameter. Intended for use in Spark pipelines.
    """

    @keyword_only
    def __init__(
        self,
        inputCol: Optional[str] = None,
        outputCol: Optional[str] = None,
        inputDtype: Optional[str] = None,
        outputDtype: Optional[str] = None,
        layerName: Optional[str] = None,
        roundType: str = "round",
    ) -> None:
        """
        Initializes a RoundTransformer.

        :param inputCol: Name of the column to read from.
        :param outputCol: Name of the column to write the result to.
        :param inputDtype: Data type the input column is cast to before
        the transform is applied.
        :param outputDtype: Data type the output column is cast to after
        the transform is applied.
        :param layerName: Name given to the tensorflow layer in the keras model.
        If not set, we use the uid of the Spark transformer.
        :param roundType: Which rounding operation to apply,
        one of 'floor', 'ceil' or 'round'. Defaults to 'round'.
        :returns: None - class instantiated.
        """
        super().__init__()
        # Register the default first so roundType resolves even when the caller
        # does not pass it; explicitly supplied keyword arguments are then applied.
        self._setDefault(roundType="round")
        self.setParams(**self._input_kwargs)

    @property
    def compatible_dtypes(self) -> Optional[List[DataType]]:
        """
        Data types this layer's computation supports.
        A value of None would signal that any data type is acceptable.

        :returns: List containing the float and double Spark types.
        """
        supported_types = [FloatType(), DoubleType()]
        return supported_types

    def _transform(self, dataset: DataFrame) -> DataFrame:
        """
        Adds a column named `outputCol` to the dataset, holding the result of
        applying the configured rounding operation to the input column.

        :param dataset: Pyspark dataframe to transform.
        :returns: Pyspark dataframe with the rounded output column added.
        """
        # NOTE(review): Spark's F.round uses HALF_UP rounding; confirm that the
        # tensorflow RoundLayer agrees on exact .5 values.
        rounding_ops = {
            "floor": F.floor,
            "ceil": F.ceil,
            "round": F.round,
        }

        def _apply_rounding(column):
            # Resolve the operation at call time from the roundType parameter.
            return rounding_ops[self.getRoundType()](column)

        input_name = self.getInputCol()
        column_datatype = self.get_column_datatype(
            dataset=dataset, column_name=input_name
        )
        rounded_col = single_input_single_output_scalar_transform(
            input_col=F.col(input_name),
            input_col_datatype=column_datatype,
            func=_apply_rounding,
        )
        return dataset.withColumn(self.getOutputCol(), rounded_col)

    def get_tf_layer(self) -> tf.keras.layers.Layer:
        """
        Builds the tensorflow counterpart of this transformer.

        :returns: RoundLayer configured with this transformer's layer name,
         input/output tensorflow dtypes and round type.
        """
        layer_config = {
            "name": self.getLayerName(),
            "input_dtype": self.getInputTFDtype(),
            "output_dtype": self.getOutputTFDtype(),
            "round_type": self.getRoundType(),
        }
        return RoundLayer(**layer_config)
