from typing import TYPE_CHECKING, Callable, Dict, List, Union

import numpy as np

from ray.air.util.data_batch_conversion import BatchFormat
from ray.air.util.tensor_extensions.utils import _create_possibly_ragged_ndarray
from ray.data.preprocessor import Preprocessor
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
    import torch


@PublicAPI(stability="alpha")
class TorchVisionPreprocessor(Preprocessor):
    """Apply a `TorchVision transform <https://pytorch.org/vision/stable/transforms.html>`_
    to image columns.

    Examples:
        >>> import ray
        >>> dataset = ray.data.read_images("s3://anonymous@air-example-data-2/imagenet-sample-images")
        >>> dataset  # doctest: +ELLIPSIS
        Dataset(num_blocks=..., num_rows=..., schema={image: ArrowTensorType(shape=(..., 3), dtype=float)})

        Torch models expect inputs of shape :math:`(B, C, H, W)` in the range
        :math:`[0.0, 1.0]`. To convert images to this format, add ``ToTensor`` to your
        preprocessing pipeline.

        >>> from torchvision import transforms
        >>> from ray.data.preprocessors import TorchVisionPreprocessor
        >>> transform = transforms.Compose([
        ...     transforms.ToTensor(),
        ...     transforms.Resize((224, 224)),
        ... ])
        >>> preprocessor = TorchVisionPreprocessor(["image"], transform=transform)
        >>> preprocessor.transform(dataset)  # doctest: +ELLIPSIS
        Dataset(num_blocks=..., num_rows=..., schema={image: ArrowTensorType(shape=(3, 224, 224), dtype=float)})

        For better performance, set ``batched`` to ``True`` and replace ``ToTensor``
        with a batch-supporting ``Lambda``.

        >>> import numpy as np
        >>> import torch
        >>> def to_tensor(batch: np.ndarray) -> torch.Tensor:
        ...     tensor = torch.as_tensor(batch, dtype=torch.float)
        ...     # (B, H, W, C) -> (B, C, H, W)
        ...     tensor = tensor.permute(0, 3, 1, 2).contiguous()
        ...     # [0., 255.] -> [0., 1.]
        ...     tensor = tensor.div(255)
        ...     return tensor
        >>> transform = transforms.Compose([
        ...     transforms.Lambda(to_tensor),
        ...     transforms.Resize((224, 224))
        ... ])
        >>> preprocessor = TorchVisionPreprocessor(
        ...     ["image"], transform=transform, batched=True
        ... )
        >>> preprocessor.transform(dataset)  # doctest: +ELLIPSIS
        Dataset(num_blocks=..., num_rows=..., schema={image: ArrowTensorType(shape=(3, 224, 224), dtype=float)})

    Args:
        columns: The columns to apply the TorchVision transform to.
        transform: The TorchVision transform you want to apply. This transform should
            accept a ``np.ndarray`` or ``torch.Tensor`` as input and return a
            ``torch.Tensor`` as output. The input is first converted to a
            ``torch.Tensor``; if the transform raises ``TypeError`` on that tensor
            (as ``ToTensor`` does), it is called again with the raw ``np.ndarray``.
        batched: If ``True``, apply ``transform`` to batches of shape
            :math:`(B, H, W, C)`. Otherwise, apply ``transform`` to individual images.

    Raises:
        ValueError: At transform time, if ``transform`` returns something other
            than a ``torch.Tensor``.
    """  # noqa: E501

    _is_fittable = False

    def __init__(
        self,
        columns: List[str],
        transform: Callable[[Union["np.ndarray", "torch.Tensor"]], "torch.Tensor"],
        batched: bool = False,
    ):
        self._columns = columns
        self._torchvision_transform = transform
        self._batched = batched

    def __repr__(self) -> str:
        # Include ``batched`` since it changes how ``transform`` is invoked.
        return (
            f"{self.__class__.__name__}(columns={self._columns}, "
            f"transform={self._torchvision_transform!r}, "
            f"batched={self._batched})"
        )

    def _transform_numpy(
        self, np_data: Union["np.ndarray", Dict[str, "np.ndarray"]]
    ) -> Union["np.ndarray", Dict[str, "np.ndarray"]]:
        """Apply the TorchVision transform to ``np_data``.

        Args:
            np_data: Either a dict of column name -> array (only ``self._columns``
                are transformed; the dict is updated in place and returned) or a
                single array that is transformed directly.

        Returns:
            The transformed dict or array.

        Raises:
            ValueError: If the transform returns a non-``torch.Tensor`` value.
        """
        import torch
        from ray.air._internal.torch_utils import convert_ndarray_to_torch_tensor

        def apply_torchvision_transform(array: np.ndarray) -> np.ndarray:
            try:
                tensor = convert_ndarray_to_torch_tensor(array)
                output = self._torchvision_transform(tensor)
            except TypeError:
                # Transforms like `ToTensor` expect a `np.ndarray` as input.
                output = self._torchvision_transform(array)

            if not isinstance(output, torch.Tensor):
                raise ValueError(
                    "`TorchVisionPreprocessor` expected your transform to return a "
                    "`torch.Tensor`, but your transform returned a "
                    f"`{type(output).__name__}` instead."
                )

            # `Tensor.numpy()` fails for tensors that require grad or live on an
            # accelerator; detach/cpu are no-ops for plain CPU tensors.
            return output.detach().cpu().numpy()

        def transform_batch(batch: np.ndarray) -> np.ndarray:
            if self._batched:
                return apply_torchvision_transform(batch)
            # Per-image outputs may differ in shape, so build a (possibly
            # ragged) array from the individual results.
            return _create_possibly_ragged_ndarray(
                [apply_torchvision_transform(array) for array in batch]
            )

        if isinstance(np_data, dict):
            outputs = np_data
            for column in self._columns:
                outputs[column] = transform_batch(np_data[column])
        else:
            outputs = transform_batch(np_data)

        return outputs

    @classmethod
    def preferred_batch_format(cls) -> BatchFormat:
        """Return the batch format this preprocessor operates on (NumPy)."""
        return BatchFormat.NUMPY
