# Copyright CNRS/Inria/UCA
# Contributor(s): Eric Debreuve (since 2019)
#
# eric.debreuve@cnrs.fr
#
# This software is governed by the CeCILL  license under French law and
# abiding by the rules of distribution of free software.  You can  use,
# modify and/ or redistribute the software under the terms of the CeCILL
# license as circulated by CEA, CNRS and INRIA at the following URL
# "http://www.cecill.info".
#
# As a counterpart to the access to the source code and  rights to copy,
# modify and redistribute granted by the license, users are provided only
# with a limited warranty  and the software's author,  the holder of the
# economic rights,  and the successive licensors  have only  limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading,  using,  modifying and/or developing or reproducing the
# software by the user in light of its specific status of free software,
# that may mean  that it is complicated to manipulate,  and  that  also
# therefore means  that it is reserved for developers  and  experienced
# professionals having in-depth computer knowledge. Users are therefore
# encouraged to load and test the software's suitability as regards their
# requirements in conditions enabling the security of their systems and/or
# data to be ensured and,  more generally, to use and operate it in the
# same conditions as regards security.
#
# The fact that you are presently reading this means that you have had
# knowledge of the CeCILL license and that you accept its terms.

import glob
import sys as sstm
from collections import defaultdict as default_dict_t
from csv import reader as csv_reader_t
from pathlib import Path as path_t
from typing import Callable, Optional, Sequence, Tuple

import imageio as mgio
import numpy as nmpy

import daccuracy.brick.csv_io as csio
import daccuracy.brick.image as imge
from daccuracy.brick.csv_io import row_transform_h


# Array type handled by the loaders of this module.
array_t = nmpy.ndarray

# Image shape as a tuple of ints (any number of dimensions).
img_shape_h = Tuple[int, ...]
# Signature shared by the ground-truth loaders below:
# (path, detection shape, coordinate indices, row transform) -> label array.
gt_loading_fct_h = Callable[
    [path_t, img_shape_h, Sequence[int], row_transform_h], array_t
]


# Loaders keyed by lowercase file extension, leading dot included (callers look them up with
# path.suffix.lower()). Any extension without an explicit entry falls back to the image loaders.
# The lambdas defer the lookup of the loader names until a missing key is first requested, so the
# functions may be defined later in the module; the explicit entries are registered at the end of
# the module.
GT_LOADING_FOR_EXTENSION = default_dict_t(lambda: _GroundTruthFromImage)
DN_LOADING_FOR_EXTENSION = default_dict_t(lambda: _DetectionFromImage)


def GroundTruthForDetection(
    detection_name: str,  # Without extension
    detection_shape: img_shape_h,
    ground_truth_path: Optional[path_t],
    ground_truth_folder: path_t,
    ground_truth: Optional[array_t],
    coordinate_idc: Optional[Sequence[int]],
    row_transform: Optional[row_transform_h],
    mode: str,
    /,
) -> Tuple[Optional[array_t], Optional[path_t]]:
    """Return the ground truth to compare with a detection, and the path it comes from.

    In "one-to-one" mode, the ground truth is the file of ground_truth_folder whose name is
    detection_name followed by "." and any extension; the ground_truth and ground_truth_path
    arguments are ignored. If no such file exists, (None, None) is returned.

    In any other mode (one-to-many), the ground truth is read from ground_truth_path, unless
    ground_truth is already given, in which case it is returned unchanged.

    The loader is chosen from GT_LOADING_FOR_EXTENSION by the lowercase file extension and is
    given detection_shape, coordinate_idc and row_transform. The returned array is whatever that
    loader returns, possibly None if loading failed.
    """
    if mode == "one-to-one":
        # glob.escape: the detection name is a literal, so characters such as "[" or "*" in it
        # must not act as wildcards. Path.glob is used because glob.iglob does not accept path
        # objects and yields strings, whereas ".suffix" is needed below.
        pattern = glob.escape(detection_name) + ".*"
        candidates = sorted(
            _pth
            for _pth in path_t(ground_truth_folder).glob(pattern)
            if _pth.is_file()
        )
        if candidates.__len__() == 0:
            return None, None
        # Sorted so that the choice is reproducible when several files share the name
        ground_truth_path = candidates[0]
    elif ground_truth is not None:
        # mode = 'one-to-many': presumably loaded by a previous call; nothing to reload
        return ground_truth, ground_truth_path

    gt_loading_fct = GT_LOADING_FOR_EXTENSION[ground_truth_path.suffix.lower()]
    ground_truth = gt_loading_fct(
        ground_truth_path, detection_shape, coordinate_idc, row_transform
    )

    return ground_truth, ground_truth_path


def _GroundTruthFromImage(
    path: path_t,
    _: img_shape_h,
    __: Optional[Sequence[int]],
    ___: Optional[row_transform_h],
    /,
) -> Optional[array_t]:
    """Load a ground-truth labeled image through imageio.

    The three unnamed parameters exist only so that every ground-truth loader shares the same
    signature; they are not used here. Returns None (after reporting to stderr) when the file
    cannot be loaded or is not a correctly labeled image.
    """
    reader = _ImageFromImagePath
    failure_description = "image or unreadable by imageio"
    return _ImageFromPath(path, reader, None, failure_description)


def _GroundTruthFromNumpy(
    path: path_t,
    _: img_shape_h,
    __: Optional[Sequence[int]],
    ___: Optional[row_transform_h],
    /,
) -> Optional[array_t]:
    """Load a ground-truth labeled image stored as a NumPy .npy or .npz file.

    The three unnamed parameters exist only so that every ground-truth loader shares the same
    signature; they are not used here. Returns None (after reporting to stderr) when the file
    cannot be loaded or is not a correctly labeled image.
    """
    reader = _ImageFromNumpyPath
    failure_description = "Numpy file or unreadable"
    return _ImageFromPath(path, reader, None, failure_description)


def _GroundTruthFromCSV(
    path: path_t,
    shape: img_shape_h,
    coordinate_idc: Optional[Sequence[int]],
    row_transform: Optional[row_transform_h],
    /,
) -> Optional[array_t]:
    """Build a ground-truth label image from a CSV file of point coordinates.

    Every CSV row that csio.CSVLineToCoords converts to coordinates becomes one point. Rows for
    which it returns None are skipped. Points are written into a zero uint64 image of the given
    shape with labels 1, 2, 3... in row order.

    Parameters
    ----------
    path: CSV file.
    shape: Shape of the output image (the detection shape).
    coordinate_idc: Forwarded to csio.CSVLineToCoords.
    row_transform: Forwarded to csio.CSVLineToCoords; if None, csio.SymmetrizedRow parametrized
        by shape[0] is used.

    Returns
    -------
    The label image, or None (with a message on stderr) if:
    - the file cannot be read or parsed;
    - a row's coordinate count differs from the image dimension;
    - a point is out of bounds;
    - two points fall on the same pixel.
    """
    output = nmpy.zeros(shape, dtype=nmpy.uint64)

    # Leave this here since the symmetrization transform must be defined for each image (shape[0])
    if row_transform is None:
        row_transform = lambda f_idx: csio.SymmetrizedRow(f_idx, float(shape[0]))

    try:
        # newline="": what the csv module expects from the file objects it reads
        with open(path, newline="") as csv_accessor:
            csv_reader = csv_reader_t(csv_accessor)
            # Do not enumerate csv_reader below since some rows might be dropped
            label = 1
            for line in csv_reader:
                coordinates = csio.CSVLineToCoords(line, coordinate_idc, row_transform)
                if coordinates is None:
                    continue

                if coordinates.__len__() != output.ndim:
                    print(
                        f"{coordinates.__len__()} != {output.ndim}: Mismatch between (i) CSV coordinates "
                        f"and (ii) detection dimension for {path}",
                        file=sstm.stderr,
                    )
                    return None
                if any(_elm < 0 for _elm in coordinates) or nmpy.any(
                    nmpy.greater_equal(coordinates, output.shape)
                ):
                    expected = (f"0<= . <= {_sze - 1}" for _sze in output.shape)
                    expected = ", ".join(expected)
                    print(
                        f"{coordinates}: CSV coordinates out of bound for detection {path}; Expected={expected}",
                        file=sstm.stderr,
                    )
                    return None

                # tuple(): address exactly one pixel even if coordinates is a list or an array
                # (which numpy would otherwise treat as fancy indexing)
                position = tuple(coordinates)
                if output[position] > 0:
                    print(
                        f"{path}: Multiple GTs at same position (due to rounding or duplicates)",
                        file=sstm.stderr,
                    )
                    return None

                output[position] = label
                label += 1
    except Exception as exc:
        print(f"{path}: Error while reading or unreadable\n({exc})", file=sstm.stderr)
        return None

    return output


def _DetectionFromImage(
    path: path_t, dn_shifts: Optional[Sequence[int]], /
) -> Optional[array_t]:
    """Load a detection labeled image through imageio.

    dn_shifts, when not None, is passed on to _ImageFromPath, which shifts the loaded image.
    Returns None (after reporting to stderr) when the file cannot be loaded or is not a
    correctly labeled image.
    """
    reader = _ImageFromImagePath
    failure_description = "image or unreadable by imageio"
    return _ImageFromPath(path, reader, dn_shifts, failure_description)


def _DetectionFromNumpy(
    path: path_t, dn_shifts: Optional[Sequence[int]], /
) -> Optional[array_t]:
    """Load a detection labeled image stored as a NumPy .npy or .npz file.

    dn_shifts, when not None, is passed on to _ImageFromPath, which shifts the loaded image.
    Returns None (after reporting to stderr) when the file cannot be loaded or is not a
    correctly labeled image.
    """
    reader = _ImageFromNumpyPath
    failure_description = "Numpy file or unreadable"
    return _ImageFromPath(path, reader, dn_shifts, failure_description)


def _ImageFromPath(
    path: path_t,
    LoadingFunction: Callable[[path_t], array_t],
    dn_shifts: Optional[Sequence[int]],
    message: str,
    /,
) -> Optional[array_t]:
    """Load a labeled image, optionally shift it, and check its labels.

    Parameters
    ----------
    path: File to load.
    LoadingFunction: Reads path into an array.
    dn_shifts: If not None, the loaded image is replaced by imge.ShiftedVersion(image, dn_shifts).
    message: Completes "Not a valid ..." in the error printed when loading fails.

    Returns
    -------
    The image, or None (with a message on stderr) if loading, shifting or label checking raises
    an exception, or if LabeledImageIsValid reports the labels as invalid.
    """
    try:
        output = LoadingFunction(path)
        if dn_shifts is not None:
            output = imge.ShiftedVersion(output, dn_shifts)
        is_valid, issues = LabeledImageIsValid(output)
    except Exception as exc:
        # Exception, not BaseException: KeyboardInterrupt/SystemExit must still stop the program
        print(
            f"{path}: Not a valid {message}\n({exc})",
            file=sstm.stderr,
        )
        return None

    if not is_valid:
        print(
            f"{path}: Incorrectly labeled image:\n    {issues}",
            file=sstm.stderr,
        )
        return None

    return output


def _ImageFromImagePath(path: path_t, /) -> array_t:
    """Read an image file with imageio, which is given the path as a string."""
    file_name = str(path)
    image = mgio.imread(file_name)
    return image


def _ImageFromNumpyPath(path: path_t, /) -> array_t:
    """"""
    output = nmpy.load(str(path))

    if hasattr(output, "keys"):
        first_key = tuple(output.keys())[0]
        output = output[first_key]

    return output


def LabeledImageIsValid(image: array_t, /) -> Tuple[bool, Optional[str]]:
    """Check that an image contains exactly the labels 0, 1, ..., max, with at least 0 and 1.

    Returns
    -------
    (True, None) if valid; otherwise (False, description of the issue). When labels are missing,
    the description lists the present labels in increasing order and inserts "0?" if 0 is
    missing and "...?" where labels are skipped, e.g. "0, ...?, 2, 3".
    """
    if image.size == 0:
        return False, "Empty image"
    if not (
        nmpy.issubdtype(image.dtype, nmpy.integer)
        or nmpy.issubdtype(image.dtype, nmpy.bool_)
    ):
        return False, f"Non-integer label type: {image.dtype}"

    unique_values = nmpy.unique(image)  # Sorted, without duplicates
    n_values = unique_values.__len__()
    # int(): arithmetic on a numpy scalar of the image dtype could wrap around (uint8 255 + 1 == 0)
    first = int(unique_values[0])
    last = int(unique_values[-1])

    # Sorted distinct integers starting at 0 and ending at n-1 are exactly 0..n-1. This avoids
    # materializing a range up to the maximum label.
    if (n_values > 1) and (first == 0) and (last == n_values - 1):
        return True, None

    if n_values == 1:
        return (
            False,
            f"Only one value present in image: {unique_values[0]}; Expected=at least 0 and 1",
        )

    issues = []
    if first > 0:
        issues.append("0?")  # Zero is missing
        if first > 1:
            issues.append("...?")
    issues.append(str(unique_values[0]))
    for previous, label in zip(unique_values[:-1], unique_values[1:]):
        if int(label) > int(previous) + 1:
            issues.append("...?")
        issues.append(str(label))

    return False, ", ".join(issues)


def WithFixedDimensions(
    ground_truth: array_t, detection: array_t, /
) -> Tuple[Optional[array_t], Optional[array_t]]:
    """Reduce one of the two images to a single gray channel.

    The ground truth is reduced if it is 3-dimensional; otherwise the detection is. The other
    image is returned unchanged. The reduced image is None if it cannot be reduced to one channel.
    Presumably called when the two images differ in dimension; confirm with callers.
    """
    if ground_truth.ndim != 3:
        return ground_truth, _AsOneGrayChannelOrNone(detection)

    return _AsOneGrayChannelOrNone(ground_truth), detection


def _AsOneGrayChannelOrNone(image: array_t, /) -> Optional[array_t]:
    """"""
    if (
        (3 <= image.shape[2] <= 4)
        and nmpy.array_equal(image[..., 0], image[..., 1])
        and nmpy.array_equal(image[..., 0], image[..., 2])
    ):
        if (image.shape[2] == 3) or nmpy.all(image[..., 3] == image[0, 0, 3]):
            return image[..., 0]

    return None


# Register the non-image loaders, now that they are defined. Any other extension keeps falling
# back to the image loaders through the default factories of the two tables.
GT_LOADING_FOR_EXTENSION.update(
    dict.fromkeys((".npy", ".npz"), _GroundTruthFromNumpy)
)
GT_LOADING_FOR_EXTENSION[".csv"] = _GroundTruthFromCSV
DN_LOADING_FOR_EXTENSION.update(dict.fromkeys((".npy", ".npz"), _DetectionFromNumpy))
