"""
Copyright CNRS/Inria/UniCA
Contributor(s): Eric Debreuve (eric.debreuve@cnrs.fr) since 2022
SEE COPYRIGHT NOTICE BELOW
"""

from __future__ import annotations

import typing as h

import matplotlib.pyplot as pypl  # noqa
import numpy as nmpy
import skimage.measure as msre
from babelplot.backend.specification.backend import backend_e
from babelplot.specification.dimension import dim_e
from babelplot.specification.plot import PlotsFromTemplate, plot_e, plot_type_h
from babelplot.type.base import backend_element_h, backend_plot_h
from babelplot.type.figure import figure_t as base_figure_t
from babelplot.type.frame import frame_t as base_frame_t
from babelplot.type.plot import plot_t as base_plot_t
from logger_36 import LOGGER
from matplotlib.artist import Artist as backend_plot_t  # noqa
from matplotlib.gridspec import GridSpec as grid_spec_t  # noqa
from matplotlib.markers import MarkerStyle as marker_style_t  # noqa
from matplotlib.patches import Polygon as polygon_t  # noqa
from matplotlib.pyplot import Axes as backend_frame_2d_t  # noqa
from matplotlib.pyplot import Figure as backend_figure_t  # noqa
from matplotlib.pyplot import figure as NewBackendFigure  # noqa
from mpl_toolkits.mplot3d import Axes3D as backend_frame_3d_t

import matplotlib as mlpl  # noqa

# Name of this backend, taken from the MATPLOTLIB member of the backend enumeration.
NAME = backend_e.MATPLOTLIB.value


# Short alias for NumPy arrays, used in the annotations of this module.
array_t = nmpy.ndarray
# A backend frame is a Matplotlib axes object, either 2-D (Axes) or 3-D (Axes3D).
backend_frame_h = backend_frame_2d_t | backend_frame_3d_t


def _NewPlot(
    frame: backend_frame_h,
    type_: plot_type_h | type(backend_plot_h),
    plot_function: h.Callable | None,
    *args,
    title: str = None,  # /!\ If _, then it is swallowed by kwargs!
    **kwargs,
) -> tuple[h.Any, h.Callable]:
    """
    Create a plot in a frame and return it together with the function that made it.

    frame: Matplotlib axes receiving the plot. It is passed as first argument to the
        plot function.
    type_: Only used when plot_function is None, in which case it is the name of the
        axes method to call.
    plot_function: Callable invoked as plot_function(frame, *args, **kwargs), or None
        to have it resolved from type_.
    title: Declared only so that it does not end up in kwargs. Not used here.

    Raises ValueError if plot_function is None and type_ names no callable attribute
    of the candidate axes classes.
    """
    if plot_function is None:
        # The class of the frame comes first so that a 3-D frame gets the 3-D version
        # of the plot types that exist in both 2-D and 3-D (scatter, for example). The
        # original order (2-D, then 3-D) is kept as a fallback. The lookup is done on
        # classes, not on the frame itself, to get plain functions taking the frame as
        # first argument.
        candidates = (type(frame), backend_frame_2d_t, backend_frame_3d_t)
        for owner in candidates:
            plot_function = getattr(owner, type_, None)
            if callable(plot_function):
                break
        else:
            raise ValueError(f"{type_}: Unknown {NAME} graph object.")

    return plot_function(frame, *args, **kwargs), plot_function


def _NewFrame(
    figure: backend_figure_t,
    _: int,
    __: int,
    *args,
    title: str = None,
    dim: dim_e = dim_e.XY,
    **kwargs,
) -> backend_frame_h:
    """
    Create a new axes object in a figure and return it.

    figure: Matplotlib figure receiving the axes.
    _, __: Two integers required by the calling convention; not used here.
    title: Optional title set on the new axes.
    dim: dim_e.XY for a 2-D axes (created with figure.subplots) or dim_e.XYZ for a
        3-D axes.
    args, kwargs: Passed as is to the axes creation call.

    Raises NotImplementedError for any other value of dim.
    """
    if dim is dim_e.XYZ:
        # The 3-D axes is created detached from the figure, then added explicitly. See
        # the note on Axes3D at the end of this file.
        frame = backend_frame_3d_t(figure, *args, auto_add_to_figure=False, **kwargs)
        figure.add_axes(frame)
    elif dim is dim_e.XY:
        frame = figure.subplots(*args, **kwargs)
    else:
        raise NotImplementedError(f"{dim}: Dimension management not implemented yet")

    if title is None:
        return frame

    frame.set_title(title)
    return frame


def _AdjustLayout(figure: figure_t, /) -> None:
    """
    Transfer the titles to the backend objects, then position the frames on a grid.

    The title of the figure becomes its "suptitle", the title of a frame becomes the
    title of its axes, and the title of a plot becomes the label of its artist. None
    titles are skipped. Frames are only repositioned when there are at least two of
    them, using the cells of a grid of shape figure.shape and the (row, col) pairs of
    figure.locations.
    """
    backend_figure = figure.backend
    frames = figure.frames

    if figure.title is not None:
        backend_figure.suptitle(figure.title)

    for frame in frames:
        if frame.title is not None:
            frame.backend.set_title(frame.title)
        titled_plots = (_plt for _plt in frame.plots if _plt.title is not None)
        for plot in titled_plots:
            plot.backend.set_label(plot.title)

    if len(frames) >= 2:
        grid = grid_spec_t(*figure.shape, figure=backend_figure)
        bottoms, tops, lefts, rights = grid.get_grid_positions(backend_figure)

        for frame, (row, col) in zip(frames, figure.locations):
            # Rectangle of the grid cell as (left, bottom, width, height).
            cell = (
                lefts[col],
                bottoms[row],
                rights[col] - lefts[col],
                tops[row] - bottoms[row],
            )
            frame.backend.set_position(cell)


def _Show(
    figure: figure_t,
    /,
) -> None:
    """
    Show the backend figure and run the event loop of its canvas.

    A handler stopping the event loop is connected to the "close_event" of the canvas
    before the loop is started.
    """
    backend_figure = figure.backend
    backend_figure.show()

    canvas = backend_figure.canvas

    def _StopEventLoop(_):
        return canvas.stop_event_loop()

    canvas.mpl_connect("close_event", _StopEventLoop)
    canvas.start_event_loop()


def _DefaultProperties(type_: h.Callable, /) -> dict[str, h.Any]:
    """
    Return the Matplotlib rc parameters attached to the name of a plot function.

    type_: Callable whose __name__ is used as the rc parameter group: the selected
        parameters are the ones whose name starts with "<name>.".

    The returned dictionary maps the rc parameter names without the "<name>." prefix
    to their values.
    """
    name = type_.__name__
    properties = mlpl.rcParams.find_all(f"^{name}\\.")

    # Only the leading "<name>." must go: the pattern above guarantees that it is
    # present, and a blanket replacement would also alter a key that happens to
    # contain the same text further down.
    prefix = f"{name}."
    return {_key.removeprefix(prefix): _vle for _key, _vle in properties.items()}


def _SetProperty(element: backend_element_h, name: str, value: h.Any, /) -> None:
    """
    Set a property of a backend element.

    element: Matplotlib object to modify.
    name: Name of the property, as understood by matplotlib.pyplot.setp. The name
        "marker" is handled specifically for elements exposing set_paths.
    value: New value of the property.

    An invalid property name does not raise: an error is logged instead.
    """
    if (name == "marker") and hasattr(element, "set_paths"):
        # Path-based elements (the result of a scatter, for example) have no marker
        # setter: the marker shape must be installed as a path.
        new_marker = marker_style_t(value)
        # A marker is its path combined with its transform (some markers share a path
        # and only differ by the scale or rotation stored in the transform), so the
        # transform must be applied to obtain the actual shape. Matplotlib does the
        # same when it creates a scatter plot.
        path = new_marker.get_path().transformed(new_marker.get_transform())
        element.set_paths((path,))
        return

    # Any other property, and the marker of elements without set_paths, which then
    # goes through their regular setter if they have one.
    try:
        pypl.setp(element, **{name: value})
    except AttributeError:
        LOGGER.error(
            f'Property "{name}": Invalid property for element of type "{type(element).__name__}"'
        )


def _Property(element: backend_element_h, name: str, /) -> h.Any:
    """
    Return the value of a property of a backend element.

    element: Matplotlib object to query.
    name: Name of the property, as understood by matplotlib.pyplot.getp.

    If the property is invalid for the element, an error is logged and None is
    returned.
    """
    try:
        return pypl.getp(element, property=name)
    except AttributeError:
        LOGGER.error(
            f'Property "{name}": Invalid property for element of type "{type(element).__name__}"'
        )
        return None


# Backend classes: subclasses of the generic babelplot plot, frame, and figure types
# with the Matplotlib-specific functions of this module plugged in as class
# attributes.
class plot_t(base_plot_t):
    BackendDefaultProperties = staticmethod(_DefaultProperties)
    BackendSetProperty = staticmethod(_SetProperty)
    BackendProperty = staticmethod(_Property)


class frame_t(base_frame_t):
    plot_class = plot_t
    NewBackendPlot = staticmethod(_NewPlot)
    BackendSetProperty = staticmethod(_SetProperty)
    BackendProperty = staticmethod(_Property)


class figure_t(base_figure_t):
    frame_class = frame_t
    NewBackendFigure = staticmethod(NewBackendFigure)
    NewBackendFrame = staticmethod(_NewFrame)
    # Not wrapped in staticmethod: these two receive the figure instance as their
    # first argument when called as methods.
    AdjustLayout = _AdjustLayout
    BackendShow = _Show
    BackendSetProperty = staticmethod(_SetProperty)
    BackendProperty = staticmethod(_Property)


def _Polygon(
    frame: backend_frame_2d_t, xs: array_t, ys: array_t, *_, **kwargs
) -> polygon_t:
    """
    Add a polygon patch to a frame and return the patch.

    xs, ys: Coordinates of the vertices. They are stacked, then transposed, to build
        the vertex array passed to the patch.
    kwargs: Passed as is to the polygon patch constructor.

    Extra positional arguments are ignored.
    """
    vertices = nmpy.transpose(nmpy.vstack((xs, ys)))
    patch = polygon_t(vertices, **kwargs)
    frame.add_patch(patch)

    return patch


def _Arrows2(frame: backend_frame_2d_t, *args, **kwargs) -> backend_plot_t:
    """
    Draw a 2-D field of arrows with the quiver method of the frame.

    Accepted positional arguments:
    - u, v: 2-D array-likes of arrow components. Arrows are placed on the integer grid
      of the shape of u.
    - x, y, u, v: x and y are either arrays of positions, or two integers giving the
      shape of the integer grid to place the arrows on. u and v are array-likes.

    If u is 1-D, the positions are flattened to match.
    kwargs: Passed as is to quiver.

    Raises ValueError if the number of positional arguments is neither 2 nor 4, or if
    u is not 2-D in the 2-argument form.
    """
    if len(args) == 2:
        # Conversion comes first so that nested sequences, which have no shape, are
        # accepted.
        u, v = (nmpy.asarray(_cmp) for _cmp in args)
        if u.ndim != 2:
            raise ValueError(
                f"Arrow components must be 2-dimensional when passed without "
                f"positions. Actual={u.ndim} dimension(s)."
            )
        x, y = u.shape
    else:
        x, y, u, v = args
        u = nmpy.asarray(u)
        v = nmpy.asarray(v)

    if isinstance(x, int):
        x, y = nmpy.meshgrid(range(x), range(y), indexing="ij")

    if u.ndim == 1:
        # Positions passed as plain sequences have no ravel method: convert them.
        x, y = (
            _pos.ravel() if isinstance(_pos, nmpy.ndarray) else nmpy.asarray(_pos).ravel()
            for _pos in (x, y)
        )

    return frame.quiver(x, y, u, v, **kwargs)


def _Arrows3(frame: backend_frame_3d_t, *args, **kwargs) -> backend_plot_t:
    """
    Draw a 3-D field of arrows with the quiver method of the frame.

    Accepted positional arguments:
    - u, v, w: 3-D array-likes of arrow components. Arrows are placed on the integer
      grid of the shape of u.
    - x, y, z, u, v, w: x, y and z are either arrays of positions, or three integers
      giving the shape of the integer grid to place the arrows on. u, v and w are
      array-likes.

    If u is 1-D, the positions are flattened to match.
    kwargs: Passed to quiver. A "color" given as an array with one row per arrow is
        expanded as explained below; any other color is passed as is.

    Raises ValueError if the number of positional arguments is neither 3 nor 6, or if
    u is not 3-D in the 3-argument form.
    """
    if len(args) == 3:
        # Conversion comes first so that nested sequences, which have no shape, are
        # accepted.
        u, v, w = (nmpy.asarray(_cmp) for _cmp in args)
        if u.ndim != 3:
            raise ValueError(
                f"Arrow components must be 3-dimensional when passed without "
                f"positions. Actual={u.ndim} dimension(s)."
            )
        x, y, z = u.shape
    else:
        x, y, z, u, v, w = args
        u = nmpy.asarray(u)
        v = nmpy.asarray(v)
        w = nmpy.asarray(w)

    if isinstance(x, int):
        x, y, z = nmpy.meshgrid(range(x), range(y), range(z), indexing="ij")

    if u.ndim == 1:
        # Positions passed as plain sequences have no ravel method: convert them.
        x, y, z = (
            _pos.ravel() if isinstance(_pos, nmpy.ndarray) else nmpy.asarray(_pos).ravel()
            for _pos in (x, y, z)
        )

    color = kwargs.get("color")
    # Only arrays with one row per arrow are expanded: the rows are followed by each
    # row repeated twice, so that there are three color rows per arrow. A 1-D array (a
    # single color) cannot be stacked that way and is passed unchanged.
    if isinstance(color, nmpy.ndarray) and (color.ndim > 1):
        kwargs["color"] = nmpy.vstack((color, nmpy.repeat(color, 2, axis=0)))

    return frame.quiver(x, y, z, u, v, w, **kwargs)


def _ElevationSurface(frame: backend_frame_2d_t, *args, **kwargs) -> backend_plot_t:
    """
    Draw a surface with the plot_surface method of the frame.

    Accepted positional arguments:
    - elevation: Array of elevations. The surface is then defined over the integer
      grid of the first two dimensions of the array.
    - x, y, elevation: Positions and elevations, passed as is.

    kwargs: Passed as is to plot_surface.
    """
    if len(args) != 1:
        x, y, elevation = args
        return frame.plot_surface(x, y, elevation, **kwargs)

    (elevation,) = args
    n_rows, n_cols = elevation.shape[0], elevation.shape[1]
    x, y = nmpy.meshgrid(range(n_rows), range(n_cols), indexing="ij")

    return frame.plot_surface(x, y, elevation, **kwargs)


def _Isocontour(frame: backend_frame_2d_t, *args, **kwargs) -> backend_plot_t:
    """
    Draw the contour of a single level with the contour method of the frame.

    Accepted positional arguments:
    - values, value: Data and level.
    - x, y, values, value: Positions, data and level.

    kwargs: Passed as is to contour.

    Unless the y-axis of the frame is already inverted, it gets inverted and the ticks
    of the x-axis are moved to the top.
    """
    if len(args) == 2:
        data, level = args[:1], args[1]
    else:
        x, y, values, level = args
        data = (x, y, values)

    # The contour method expects a sequence of levels, hence the 1-tuple.
    contours = frame.contour(*data, (level,), **kwargs)

    if frame.yaxis_inverted():
        return contours

    frame.invert_yaxis()
    frame.xaxis.tick_top()
    return contours


def _Isosurface(frame: backend_frame_2d_t, *args, **kwargs) -> backend_plot_t:
    """
    Draw an isosurface: a mesh extracted with marching cubes, then drawn with _Mesh.

    args: Passed as is to skimage.measure.marching_cubes.
    kwargs: "step_size", if present, is passed to marching_cubes. All the other
        keyword arguments are passed to _Mesh.
    """
    plot_kwargs = dict(kwargs)
    extraction_kwargs = {}
    if "step_size" in plot_kwargs:
        extraction_kwargs["step_size"] = plot_kwargs.pop("step_size")

    # Only the first two outputs (vertices and triangles) are needed.
    vertices, triangles, *_ = msre.marching_cubes(*args, **extraction_kwargs)

    return _Mesh(frame, triangles, vertices, **plot_kwargs)


def _Mesh(
    frame: backend_frame_2d_t, triangles: array_t, vertices: array_t, *_, **kwargs
) -> backend_plot_t:
    """
    Draw a triangular mesh with the plot_trisurf method of the frame.

    triangles: Passed as is to plot_trisurf as the triangle definition.
    vertices: 2-D array whose columns 0, 1 and 2 are used as the x, y and z
        coordinates, respectively.
    kwargs: Passed as is to plot_trisurf.

    Extra positional arguments are ignored.
    """
    xs = vertices[:, 0]
    ys = vertices[:, 1]
    zs = vertices[:, 2]

    return frame.plot_trisurf(xs, ys, triangles, zs, **kwargs)


def _BarH(frame: backend_frame_2d_t, *args, **kwargs) -> backend_plot_t:
    """
    Draw horizontal bars with the barh method of the frame.

    Accepted positional arguments:
    - counts: Bar lengths. The bars are then placed at positions 0, 1, 2...
    - positions, counts: Bar positions and lengths.

    kwargs: Passed as is to barh.
    """
    if len(args) != 1:
        positions, counts = args
        return frame.barh(positions, counts, **kwargs)

    counts = args[0]
    return frame.barh(range(len(counts)), counts, **kwargs)


def _BarV(frame: backend_frame_2d_t, *args, **kwargs) -> backend_plot_t:
    """
    Draw vertical bars with the bar method of the frame.

    Accepted positional arguments:
    - counts: Bar heights. The bars are then placed at positions 0, 1, 2...
    - positions, counts: Bar positions and heights.

    kwargs: Passed as is to bar.
    """
    if len(args) != 1:
        positions, counts = args
        return frame.bar(positions, counts, **kwargs)

    counts = args[0]
    return frame.bar(range(len(counts)), counts, **kwargs)


def _Bar3(frame: backend_frame_2d_t, *args, **kwargs) -> backend_plot_t:
    """
    Draw 3-D bars with the bar3d method of the frame.

    Accepted positional arguments:
    - counts: 2-D array-like of bar heights. The bars are then placed on the integer
      grid of the shape of counts.
    - x, y, counts: x and y are either arrays of positions, or two integers giving the
      shape of the integer grid to place the bars on.

    If counts is 1-D, the positions are flattened to match.
    kwargs: "width" (default 0.8), "depth" (default 0.8) and "offset" (default 0.0)
        are used as the bar extents along the first two axes and the bar base,
        respectively. All the other keyword arguments are passed to bar3d.
    """
    *positions, counts = args
    counts = nmpy.asarray(counts)
    # Without positions, the shape of counts defines the grid.
    x, y = positions if positions else counts.shape

    if isinstance(x, int):
        x, y = nmpy.meshgrid(range(x), range(y), indexing="ij")
    if counts.ndim == 1:
        x, y = x.ravel(), y.ravel()

    # kwargs is local to this call: popping from it does not affect the caller.
    width = kwargs.pop("width", 0.8)
    depth = kwargs.pop("depth", 0.8)
    offset = kwargs.pop("offset", 0.0)

    return frame.bar3d(x, y, offset, width, depth, counts, **kwargs)


def _Text2(frame: backend_frame_2d_t, text, x, y, *_, **kwargs) -> backend_plot_t:
    """
    Add a text to a frame at position (x, y) with the text method of the frame.

    The text comes first here, whereas the text method takes the position first.
    Extra positional arguments are ignored; kwargs are passed as is.
    """
    position = (x, y)

    return frame.text(*position, text, **kwargs)


def _Text3(frame: backend_frame_2d_t, text, x, y, z, *_, **kwargs) -> backend_plot_t:
    """
    Add a text to a frame at position (x, y, z) with the text method of the frame.

    The text comes first here, whereas the text method takes the position first.
    Extra positional arguments are ignored; kwargs are passed as is.
    """
    position = (x, y, z)

    return frame.text(*position, text, **kwargs)


# Table of the plot functions of this backend, per generic plot type. It starts as the
# object returned by PlotsFromTemplate and is filled in below.
PLOTS = PlotsFromTemplate()

# Index 1 receives functions that are Axes methods or wrappers written for 2-D axes
# (presumably the slot for 2-D frames — confirm against PlotsFromTemplate).
PLOTS[plot_e.SCATTER][1] = backend_frame_2d_t.scatter
PLOTS[plot_e.POLYLINE][1] = backend_frame_2d_t.plot
PLOTS[plot_e.POLYGON][1] = _Polygon
PLOTS[plot_e.ARROWS][1] = _Arrows2
PLOTS[plot_e.ISOSET][1] = _Isocontour
PLOTS[plot_e.BARH][1] = _BarH
PLOTS[plot_e.BARV][1] = _BarV
PLOTS[plot_e.PIE][1] = backend_frame_2d_t.pie
PLOTS[plot_e.IMAGE][1] = backend_frame_2d_t.matshow
PLOTS[plot_e.TEXT][1] = _Text2

# Index 2 receives functions that are Axes3D methods or wrappers calling 3-D plotting
# methods (presumably the slot for 3-D frames — confirm against PlotsFromTemplate).
PLOTS[plot_e.SCATTER][2] = backend_frame_3d_t.scatter
PLOTS[plot_e.POLYLINE][2] = backend_frame_3d_t.plot
PLOTS[plot_e.ARROWS][2] = _Arrows3
PLOTS[plot_e.ELEVATION][2] = _ElevationSurface
PLOTS[plot_e.ISOSET][2] = _Isosurface
PLOTS[plot_e.MESH][2] = _Mesh
PLOTS[plot_e.BAR3][2] = _Bar3
PLOTS[plot_e.TEXT][2] = _Text3


# Translations of argument names into the names used by Matplotlib. Three kinds of keys
# appear below (how they are prioritized is decided by the code consuming this table,
# which is not in this module — confirm there):
# - str: a generic argument name; presumably the translation used when no more
#   specific entry exists.
# - (function or "AddFrame", str): the translation of an argument name for that plot
#   function only, or, presumably, for the frame creation in the case of "AddFrame".
#   Some entries map a name to itself, which looks like a way to cancel the str-keyed
#   translation for that function.
# - (function, int): looks like the keyword name to use for the positional argument
#   with that index — confirm.
TRANSLATIONS = {
    "color": "c",
    "color_edge": "edgecolors",
    "color_face": "facecolors",
    "color_max": "vmax",
    "color_min": "vmin",
    "color_scaling": "norm",
    "depth_shade": "depthshade",
    "plot_non_finite": "plotnonfinite",
    "size": "s",
    "width_edge": "linewidths",
    ("AddFrame", "azimuth"): "azim",
    ("AddFrame", "elevation"): "elev",
    (_Arrows2, "color"): "color",
    (_Arrows3, "color"): "color",
    (_Bar3, "color"): "color",
    (_BarH, "color"): "color",
    (_BarH, "offset"): "left",
    (_BarV, "color"): "color",
    (_BarV, "offset"): "bottom",
    (_ElevationSurface, "color_face"): "color",
    (_Isocontour, "color"): "colors",
    (_Isosurface, "color_face"): "color",
    (_Mesh, "color_face"): "color",
    (_Polygon, "color_edge"): "edgecolor",
    (_Polygon, "color_face"): "facecolor",
    (backend_frame_2d_t.pie, "color"): "colors",
    (backend_frame_2d_t.quiver, "color"): "color",
    (backend_frame_3d_t.quiver, "color"): "colors",
    (backend_frame_3d_t.scatter, 2): "zs",
}


# From: https://matplotlib.org/stable/api/prev_api_changes/api_changes_3.4.0.html
# *Axes3D automatically adding itself to Figure is deprecated*
#
# New Axes3D objects previously added themselves to figures when they were created,
# unlike all other Axes classes, which lead to them being added twice if
# fig.add_subplot(111, projection='3d') was called.
#
# This behavior is now deprecated and will warn. The new keyword argument
# auto_add_to_figure controls the behavior and can be used to suppress the warning. The
# default value will change to False in Matplotlib 3.5, and any non-False value will be
# an error in Matplotlib 3.6.
#
# In the future, Axes3D will need to be explicitly added to the figure
#
# fig = Figure()
# ax = Axes3d(fig)
# fig.add_axes(ax)
#
# as needs to be done for other axes.Axes subclasses. Or, a 3D projection can be made
# via:
#
# fig.add_subplot(projection='3d')


"""
COPYRIGHT NOTICE

This software is governed by the CeCILL  license under French law and
abiding by the rules of distribution of free software.  You can  use,
modify and/ or redistribute the software under the terms of the CeCILL
license as circulated by CEA, CNRS and INRIA at the following URL
"http://www.cecill.info".

As a counterpart to the access to the source code and  rights to copy,
modify and redistribute granted by the license, users are provided only
with a limited warranty  and the software's author,  the holder of the
economic rights,  and the successive licensors  have only  limited
liability.

In this respect, the user's attention is drawn to the risks associated
with loading,  using,  modifying and/or developing or reproducing the
software by the user in light of its specific status of free software,
that may mean  that it is complicated to manipulate,  and  that  also
therefore means  that it is reserved for developers  and  experienced
professionals having in-depth computer knowledge. Users are therefore
encouraged to load and test the software's suitability as regards their
requirements in conditions enabling the security of their systems and/or
data to be ensured and,  more generally, to use and operate it in the
same conditions as regards security.

The fact that you are presently reading this means that you have had
knowledge of the CeCILL license and that you accept its terms.

SEE LICENCE NOTICE: file README-LICENCE-utf8.txt at project source root.

This software is being developed by Eric Debreuve, a CNRS employee and
member of team Morpheme.
Team Morpheme is a joint team between Inria, CNRS, and UniCA.
It is hosted by the Centre Inria d'Université Côte d'Azur, Laboratory
I3S, and Laboratory iBV.

CNRS: https://www.cnrs.fr/index.php/en
Inria: https://www.inria.fr/en/
UniCA: https://univ-cotedazur.eu/
Centre Inria d'Université Côte d'Azur: https://www.inria.fr/en/centre/sophia/
I3S: https://www.i3s.unice.fr/en/
iBV: http://ibv.unice.fr/
Team Morpheme: https://team.inria.fr/morpheme/
"""
