# Copyright [2024] Expedia, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional

import tensorflow as tf

import kamae
from kamae.tensorflow.typing import Tensor
from kamae.tensorflow.utils import (
    enforce_single_tensor_input,
    unix_timestamp_to_datetime,
)

from .base import BaseLayer


@tf.keras.utils.register_keras_serializable(package=kamae.__name__)
class UnixTimestampToDateTimeLayer(BaseLayer):
    """
    Converts Unix timestamps into formatted datetimes.

    Returns the date in yyyy-MM-dd HH:mm:ss.SSS format from a Unix timestamp.
    If `include_time` is set to `False`, the output will be in yyyy-MM-dd format.
    Timestamps may be expressed in seconds or milliseconds (see `unit`).
    """

    def __init__(
        self,
        name: Optional[str] = None,
        input_dtype: Optional[str] = None,
        output_dtype: Optional[str] = None,
        unit: str = "s",
        include_time: bool = True,
        **kwargs,
    ) -> None:
        """
        Initialises an instance of the UnixTimestampToDateTime layer.

        :param name: Name of the layer. Defaults to `None`.
        :param input_dtype: The dtype to cast the input to. Defaults to `None`.
        :param output_dtype: The dtype to cast the output to. Defaults to `None`.
        :param unit: Unit of the timestamp. Can be `milliseconds` (or `ms`)
        or `seconds` (or `s`). Defaults to `s`. Long-form names are normalised
        to their short aliases (`ms` / `s`) before being stored on the layer.
        :param include_time: Whether to include the time in the output.
        Defaults to `True`.
        :raises ValueError: If `unit` is not one of the supported values.
        """
        super().__init__(
            name=name, input_dtype=input_dtype, output_dtype=output_dtype, **kwargs
        )
        if unit not in ("milliseconds", "seconds", "ms", "s"):
            raise ValueError(
                """Unit must be one of ["milliseconds", "seconds", "ms", "s"]"""
            )
        # Keep a single canonical spelling so `_call` and `get_config` only
        # ever see "ms" or "s".
        self.unit = {"milliseconds": "ms", "seconds": "s"}.get(unit, unit)
        self.include_time = include_time

    @property
    def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]:
        """
        Returns the compatible dtypes of the layer.

        The layer accepts numeric Unix timestamps stored as `float64` or
        `int64`; they are cast to `float64` before being converted.

        :returns: The compatible dtypes of the layer.
        """
        return [
            tf.float64,
            tf.int64,
        ]

    @enforce_single_tensor_input
    def _call(self, inputs: Tensor, **kwargs) -> tf.Tensor:
        """
        Returns the datetime in yyyy-MM-dd HH:mm:ss.SSS format if `include_time` is
        set to `True`. Otherwise, returns the date in yyyy-MM-dd format.

        Decorated with `@enforce_single_tensor_input` to ensure that
        the input is a single tensor. Raises an error if multiple tensors are passed
        in as an iterable.

        :param inputs: Tensor of Unix timestamps, expressed in the unit given by
        `self.unit` (seconds or milliseconds).
        :returns: Datetime in either yyyy-MM-dd HH:mm:ss.SSS or yyyy-MM-dd format.
        """
        # Timestamp needs to be in float64 for unix_timestamp_to_datetime.
        timestamp_in_seconds = self._cast(inputs, cast_dtype="float64")
        if self.unit != "s":
            # Millisecond timestamps are scaled down to (fractional) seconds.
            # The divisor is a non-zero constant, so no NaN guarding is needed.
            timestamp_in_seconds = timestamp_in_seconds / 1000.0
        return unix_timestamp_to_datetime(
            timestamp_in_seconds, include_time=self.include_time
        )

    def get_config(self) -> Dict[str, Any]:
        """
        Gets the configuration of the UnixTimestampToDateTime layer.
        Used for saving and loading from a model.

        Specifically sets the `unit` and `include_time` parameters in the config.

        :returns: Dictionary of the configuration of the layer.
        """
        config = super().get_config()
        config.update(
            {
                "unit": self.unit,
                "include_time": self.include_time,
            }
        )
        return config
