# -*- coding: utf-8 -*-
"""
Components that do not require PyWake nor Floris.

@author: ricriv
"""

# %% Import.

from functools import partial

import jax
import jax.numpy as jnp
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from scipy.interpolate import PchipInterpolator

# %% Classes and functions.


class PchipInterpolatorWithExtrap(PchipInterpolator):
    r"""PCHIP 1-D monotonic cubic interpolation, with bounded output.

    This class is the same as PchipInterpolator from Scipy. The difference is that
    it provides a new argument ``y_extrapolate``, which sets the extrapolation value.
    Outside of the range of ``x`` the interpolator returns ``y_extrapolate`` instead
    of extrapolating the cubic polynomials.

    Parameters
    ----------
    x : ndarray, shape (npoints, )
        1D input grid.
    y : ndarray, shape (npoints, ny)
        2D output grid.
    y_extrapolate : array_like, shape (ny)
        1D output used for extrapolation.
    """

    def __init__(self, x, y, y_extrapolate):
        # Disable SciPy's extrapolation: out-of-range queries then evaluate to NaN,
        # which __call__ replaces with `y_extrapolate`.
        super().__init__(x=x, y=y, axis=0, extrapolate=False)
        self.y_extrapolate = y_extrapolate

    def __call__(self, x):
        """
        Evaluate the interpolator.

        Parameters
        ----------
        x : (N, M, ...) ndarray or float
            Input.

        Returns
        -------
        yp : (N, M, ..., ny) ndarray
            Output. Every NaN (i.e. every point outside of the input grid, or a NaN
            input) is replaced by the corresponding entry of `y_extrapolate`.

        """
        # Predict output (NaN outside of the grid, since extrapolate=False).
        yp = super().__call__(x)
        # Replace the NaN in each output with the extrapolation value of that output.
        # `y_extrapolate` has shape (ny,) and broadcasts along the last axis of `yp`.
        # A new array is returned: the result of SciPy is a NumPy array, so NumPy
        # (not JAX, whose arrays are immutable) must be used here.
        fill = np.asarray(self.y_extrapolate)
        return np.where(np.isnan(yp), fill, yp)


def load_ccblade_performance(file):
    """
    Read a CCBlade performance table into a DataFrame.

    The first column becomes the index, which is renamed to ``"Wind speed (m/s)"``.
    Surrounding whitespace is stripped from every column label.

    Parameters
    ----------
    file : str
        File path (anything accepted by `pandas.read_table`).

    Returns
    -------
    df : pandas DataFrame
        Table with the performance data, indexed by wind speed.

    """
    performance = pd.read_table(file, index_col=0)
    performance.columns = [label.strip() for label in performance.columns]
    # The first header field (which starts with "#") would otherwise name the index.
    performance.index.name = "Wind speed (m/s)"
    return performance


@partial(jax.jit, inline=True)
def _rot_x(x):
    """
    Rotation tensor around x axis.

    Parameters
    ----------
    x : float
        Angle in radian.

    Returns
    -------
    (3, 3) float ndarray
        Rotation tensor.

    Notes
    -----
    See `Wikipedia <https://en.wikipedia.org/wiki/Rotation_matrix#Basic_3D_rotations>`__.

    """
    s = jnp.sin(x)
    c = jnp.cos(x)
    # fmt: off
    return jnp.array([[1.0, 0.0, 0.0],
                      [0.0,   c,  -s],
                      [0.0,   s,   c]])
    # fmt: on


@partial(jax.jit, inline=True)
def _rot_y(x):
    """
    Rotation tensor around y axis.

    Parameters
    ----------
    x : float
        Angle in radian.

    Returns
    -------
    (3, 3) float ndarray
        Rotation tensor.

    Notes
    -----
    See `Wikipedia <https://en.wikipedia.org/wiki/Rotation_matrix#Basic_3D_rotations>`__.

    """
    s = jnp.sin(x)
    c = jnp.cos(x)
    # fmt: off
    return jnp.array([[c,   0.0,   s],
                      [0.0, 1.0, 0.0],
                      [-s,  0.0,   c]])
    # fmt: on


@partial(jax.jit, inline=True)
def _rot_z(x):
    """
    Rotation tensor around z axis.

    Parameters
    ----------
    x : float
        Angle in radian.

    Returns
    -------
    (3, 3) float ndarray
        Rotation tensor.

    Notes
    -----
    See `Wikipedia <https://en.wikipedia.org/wiki/Rotation_matrix#Basic_3D_rotations>`__.

    """
    s = jnp.sin(x)
    c = jnp.cos(x)
    # fmt: off
    return jnp.array([[c,    -s, 0.0],
                      [s,     c, 0.0],
                      [0.0, 0.0, 1.0]])
    # fmt: on


@jax.jit
def make_rectangular_grid(y, z):
    """
    Make a rectangular grid. The grid is in the vertical plane for 0 yaw and tilt.

    This function is meant to work with PyWake, and therefore it uses the same reference frame.
    Usually, the grid will be generated with the rotor center in (0, 0).

    The reference frame of the output grid is:

        - x is downwind and set to 0.
        - y is crosswind.
        - z is up.

    Parameters
    ----------
    y : (N, ) array_like
        1D array with the y coordinate (crosswind). Python lists are accepted.
    z : (M, ) array_like
        1D array with the z coordinate (up). Python lists are accepted.

    Returns
    -------
    xyz : (N, M, 3) ndarray
        x, y and z coordinates of the grid points.

    """
    # JAX functions such as jnp.meshgrid reject Python lists (under jit they arrive
    # as lists of tracers), so convert the array_like inputs explicitly.
    y_2d, z_2d = jnp.meshgrid(jnp.asarray(y), jnp.asarray(z), indexing="ij")
    # For yaw = tilt = 0 the grid is in the vertical plane.
    x_2d = jnp.zeros_like(y_2d)
    return jnp.stack((x_2d, y_2d, z_2d), axis=2)


@partial(jax.jit, static_argnames=["degrees"])
def make_polar_grid(radius, azimuth, degrees=False):
    """
    Make a regular grid using polar coordinates. The grid is in the vertical plane for 0 yaw and tilt.

    This function is meant to work with PyWake, and therefore it uses the same reference frame.
    Usually, the grid will be generated with the rotor center in (0, 0).

    The reference frame of the output grid is:

        - x is downwind and set to 0.
        - y is crosswind.
        - z is up.

    Parameters
    ----------
    radius : (N, ) array_like
        Array of radius. The rotor center is at 0 and grows towards the blade tip.
        For example, `radius = jnp.linspace(0.0, 100.0, 10)`. Python lists are accepted.
    azimuth : (M, ) array_like
        Array of azimuth angles in radians, typically [0, 2*pi).
        0 is horizontal (crosswind) and grows clockwise looking downwind.
        For example, `azimuth = jnp.linspace(0.0, 2.0*jnp.pi, 20, endpoint=False)`.
        Due to periodicity, it is best to skip 2*pi, and thus set `endpoint=False`.
        Python lists are accepted.
    degrees : bool, optional
        If `True`, then the given angles are assumed to be in degrees. Default is `False`, which means radian.

    Returns
    -------
    xyz : (N, M, 3) ndarray
        x, y and z coordinates of the grid points.

    """
    # JAX rejects Python lists in indexing and in ufuncs like jnp.deg2rad
    # (under jit they arrive as lists of tracers), so convert them first.
    radius = jnp.asarray(radius)
    azimuth = jnp.asarray(azimuth)
    # Convert the 1D azimuth array once, before broadcasting it.
    if degrees:
        azimuth = jnp.deg2rad(azimuth)
    # This is more efficient than:
    #    radius_2d, azimuth_2d = jnp.meshgrid(radius, azimuth, indexing="ij")
    shape = (radius.size, azimuth.size)
    radius_2d = jnp.broadcast_to(radius[:, jnp.newaxis], shape)
    azimuth_2d = jnp.broadcast_to(azimuth[jnp.newaxis, :], shape)
    y_2d = radius_2d * jnp.cos(azimuth_2d)
    z_2d = radius_2d * jnp.sin(azimuth_2d)
    # For yaw = tilt = 0 the grid is in the vertical plane.
    x_2d = jnp.zeros_like(y_2d)
    return jnp.stack((x_2d, y_2d, z_2d), axis=2)


@partial(jax.jit, inline=True)
def _cross(a, b):
    """
    Return the cross product of two vectors.

    Parameters
    ----------
    a : (3,) array_like
        First vector.
    a : (3,) array_like
        Second vector.

    Returns
    -------
    v : (3,) ndarray
        Cross product of a and b.

    """
    # Adapted from _cross3 in
    # https://github.com/scipy/scipy/blob/main/scipy/spatial/transform/_rotation.pyx
    # This function is much simpler than the one in numpy and JAX, and consequently has less overhead.
    return jnp.array(
        [
            a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0],
        ]
    )


@partial(jax.jit, inline=True)
def _compose_quat_single(p, q):
    # Adapted from _compose_quat_single in
    # https://github.com/scipy/scipy/blob/main/scipy/spatial/transform/_rotation.pyx

    cross = _cross(p[:3], q[:3])

    return jnp.array(
        [
            p[3] * q[0] + q[3] * p[0] + cross[0],
            p[3] * q[1] + q[3] * p[1] + cross[1],
            p[3] * q[2] + q[3] * p[2] + cross[2],
            p[3] * q[3] - p[0] * q[0] - p[1] * q[1] - p[2] * q[2],
        ]
    )


@jax.jit
def _compose_quat(p, q):
    """Compose quaternions."""
    # Adapted from _compose_quat in
    # https://github.com/scipy/scipy/blob/main/scipy/spatial/transform/_rotation.pyx
    n = q.shape[0] if p.shape[0] == 1 else p.shape[0]

    product = jnp.zeros((n, 4))

    # dealing with broadcasting
    if p.shape[0] == 1:
        for ind in range(n):
            product = product.at[ind, :].set(_compose_quat_single(p[0, :], q[ind, :]))
    elif q.shape[0] == 1:
        for ind in range(n):
            product = product.at[ind, :].set(_compose_quat_single(p[ind, :], q[0, :]))
    else:
        for ind in range(n):
            product = product.at[ind, :].set(_compose_quat_single(p[ind, :], q[ind, :]))

    return product


@jax.jit
def _make_elementary_quat(axis, angles):
    """
    Convert from axis-angle representation to quaternion using the Hamilton representation.

    Parameters
    ----------
    axis : int
        Rotation axes. Must be 0, 1 or 2.
        0 = x, 1 = y and 2 = z.
    angles : (N, 3) ndarray
        Rotation angles in rad.

    Returns
    -------
    quat : (N, 4) ndarray
        Quaternions.

    """
    # Adapted from _make_elementary_quat in
    # https://github.com/scipy/scipy/blob/main/scipy/spatial/transform/_rotation.pyx
    # See formula at https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles#Definition
    quat = jnp.zeros((angles.shape[0], 4))
    half_angle = angles / 2
    quat = quat.at[:, axis].set(jnp.sin(half_angle))
    quat = quat.at[:, 3].set(jnp.cos(half_angle))

    return quat


@partial(jax.jit, static_argnames=["intrinsic"])
def _elementary_quat_compose(seq, angles, intrinsic):
    """
    Compose a sequence of elementary rotations into quaternions.

    Parameters
    ----------
    seq : sequence of int
        Axis of each elementary rotation, in order (0 = x, 1 = y, 2 = z).
    angles : (N, len(seq)) ndarray
        Rotation angles in rad; column ``k`` belongs to ``seq[k]``.
    intrinsic : bool
        If `True`, each further rotation multiplies the running product on the right
        (intrinsic rotations); otherwise on the left (extrinsic rotations).

    Returns
    -------
    quat : (N, 4) ndarray
        Composed quaternions.

    """
    # Adapted from _elementary_quat_compose in
    # https://github.com/scipy/scipy/blob/main/scipy/spatial/transform/_rotation.pyx
    elementary = [
        _make_elementary_quat(axis, angles[:, k]) for k, axis in enumerate(seq)
    ]
    result = elementary[0]
    for quat in elementary[1:]:
        if intrinsic:
            result = _compose_quat(result, quat)
        else:
            result = _compose_quat(quat, result)
    return result


@partial(jax.jit, inline=True)
def _quat_as_matrix(quat):
    """
    Represent a quaternion as a rotation tensor.

    Parameters
    ----------
    quat : (4,) array_like
        Quaternion.

    Returns
    -------
    matrix : (3, 3) ndarray
        Rotation tensor.

    """
    # Adapted from as_matrix in
    # https://github.com/scipy/scipy/blob/main/scipy/spatial/transform/_rotation.pyx
    # Formula from the wiki
    # https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles#Rotation_matrices
    x, y, z, w = quat

    x2 = x * x
    y2 = y * y
    z2 = z * z
    w2 = w * w

    xy = x * y
    zw = z * w
    xz = x * z
    yw = y * w
    yz = y * z
    xw = x * w

    return jnp.array(
        [
            [x2 - y2 - z2 + w2, 2 * (xy - zw), 2 * (xz + yw)],
            [2 * (xy + zw), -x2 + y2 - z2 + w2, 2 * (yz - xw)],
            [2 * (xz - yw), 2 * (yz + xw), -x2 - y2 + z2 + w2],
        ]
    )


@partial(jax.jit, static_argnames=["degrees"])
def rotate_grid(grid, yaw=0.0, tilt=0.0, degrees=False):
    """
    Rotate in yaw and tilt the grid generated by `make_rectangular_grid` or `make_polar_grid`.

    This function is meant to work with PyWake, and therefore it uses the same reference frame.
    The PyWake reference frame convention is illustrated in
    `R. Riva et al., Incorporation of floater rotation and displacement in a static wind farm simulator <https://iopscience.iop.org/article/10.1088/1742-6596/2767/6/062019>`_
    and described hereafter.

     - Cartesian reference frame.
     - Origin in the rotor center.
     - x axis horizontal, positive downwind.
     - y axis horizontal, positive left when looking downwind. It is the crosswind.
     - z axis vertical, positive up.
     - yaw angle is measured starting from y, positive counter-clockwise when looking down.
     - tilt angle is measured starting from z, positive clockwise when looking crosswind.
       That is, tilt causes the top of the rotor to move downwind.

    It is assumed that the turbine is bottom-fixed. Therefore, the order of rotations is first yaw around the vertical axis (z),
    and then tilt.

    Parameters
    ----------
    grid : (N, M, 3) ndarray
        x, y and z coordinates of the grid points,
        typically generated by `make_rectangular_grid` or `make_polar_grid`.
    yaw : float, optional
        Rotor yaw angle. The default is 0.
    tilt : float, optional
        Rotor tilt angle. The default is 0.
    degrees : bool, optional
        If `True`, then the given angles are assumed to be in degrees. Default is `False`, which means radian.

    Returns
    -------
    grid_rotated : (N, M, 3) ndarray
        x, y and z coordinates of the grid points, after rotation by yaw and tilt.

    """
    # One rotation, with angles ordered as the axis sequence (y, z) used below.
    angles = jnp.array([[tilt, yaw]])
    if degrees:
        angles = jnp.deg2rad(angles)

    # Extrinsic (global-frame) composition: tilt about y, then yaw about z.
    # This is equivalent to yawing first and then tilting about the yawed y axis.
    quat = _elementary_quat_compose((1, 2), angles, intrinsic=False)[0]
    rotation = _quat_as_matrix(quat)

    # One point per row, (N*M, 3); transposing gives one point per column for the product.
    points = grid.reshape(-1, 3)
    rotated = rotation @ points.T

    # Back to one point per row, then to the original (N, M, 3) layout and dtype.
    return rotated.T.reshape(grid.shape).astype(grid.dtype)


def plot_grid(grid, close_grid=False, fig=None):
    """
    Make a 3D plot with the grid generated by `make_rectangular_grid` or `make_polar_grid`.

    The grid points are drawn as a scatter plot, and a wireframe connects them.

    Parameters
    ----------
    grid : (N, M, 3) ndarray
        x, y and z coordinates of the grid points.
    close_grid : bool
        If `True` the first azimuth value (column) is repeated to allow closing the mesh.
        This only makes sense for polar grids, and therefore the default is `False`.
    fig : matplotlib Figure, optional
        Figure to plot the grid. The default is `None`, in which case a new one will be created.

    Returns
    -------
    ax :  matplotlib Axes
        Matplotlib axes.

    """
    # Split the grid into its three coordinate planes.
    x_2d, y_2d, z_2d = (grid[:, :, k] for k in range(3))

    if close_grid:
        # Append the first column so that the wireframe wraps around in azimuth.
        x, y, z = (
            jnp.concatenate((plane, plane[:, [0]]), axis=1)
            for plane in (x_2d, y_2d, z_2d)
        )
    else:
        x, y, z = x_2d, y_2d, z_2d

    figure = plt.figure() if fig is None else fig
    ax = figure.add_subplot(projection="3d")
    ax.set_proj_type("ortho")
    ax.set_aspect("equal")
    # Each box side spans at least 2, so that a constant coordinate (e.g. x for an
    # unrotated grid) does not produce a zero-thickness box.
    box = tuple(jnp.maximum(jnp.ptp(plane), 2.0) for plane in (x, y, z))
    ax.set_box_aspect(box)
    ax.set_xlabel("x (downwind)")
    ax.set_ylabel("y (crosswind)")
    ax.set_zlabel("z (up)")
    ax.invert_xaxis()
    ax.invert_yaxis()
    # Scatter the original points only; the wireframe uses the (possibly closed) mesh.
    ax.scatter(x_2d.ravel(), y_2d.ravel(), z_2d.ravel())
    ax.plot_wireframe(x, y, z)
    return ax


def _full_of_nan(shape, dtype):
    """
    Allocate an array of the given shape filled with NaN.

    NumPy float types give a NumPy array, JAX float types give a JAX array.

    Raises
    ------
    ValueError
        If `dtype` is not one of the supported NumPy or JAX float types.
    """
    # Identity (`is`) is used rather than `==`/`in`, since JAX scalar types may
    # compare equal to their NumPy counterparts and would be misrouted.
    if any(dtype is t for t in (np.float16, np.float32, np.float64)):
        return np.full(shape, np.nan, dtype=dtype)
    if any(dtype is t for t in (jnp.float16, jnp.float32, jnp.float64)):
        return jnp.full(shape, jnp.nan, dtype=dtype)
    raise ValueError("Unsupported dtype.")


def preallocate_ilktn(
    wt, wd=None, ws=None, time=None, dtype=np.float32, data=None, **names
):
    """
    Preallocate a xarray `DataArray` to store a quantity dependent on inflow and names. Typically used to store loads and sector-average.

    The shape of the quantity depends if time is used or not.

        - Without time the shape is `(turbine, wind direction, wind speed, name_0, name_1, ...)`.
        - With time the shape is `(turbine, time, name_0, name_1, ...)`. If provided, wind direction and wind speed are saved as time-dependent coordinates.

    Parameters
    ----------
    wt : list_like
        Wind turbines.
    wd : array_like, optional
        Wind directions. Required if `time` is `None`.
        With `time`, a single value is broadcast to all time instants.
    ws : array_like, optional
        Wind speeds. Required if `time` is `None`.
        With `time`, a single value is broadcast to all time instants.
    time : array_like, optional
        Time. If provided, `wd` and `ws` become time-dependent coordinates.
    dtype : data-type, optional
        The desired data-type for the result. The default is numpy single precision,
        which should be enough for all outputs. The properties of each type can
        be checked with `np.finfo(np.float32(1.0))`.
        The only supported types are floating points from Numpy and JAX.
        Ignored if `data` is not `None`.
    data : array_like, optional
        Data to be used, with compatible shape.
        If `None` (default) the resulting array is filled with NaN.
    **names : list_like
        Additional coordinates, one dimension per keyword, in the given order.

    Returns
    -------
    quantity : xarray.DataArray
        Quantity for each turbine, inflow condition and name.

    Raises
    ------
    TypeError
        If `time` is `None` and `wd` or `ws` is missing.
    ValueError
        If `data` is `None` and `dtype` is not supported.

    """
    names_shape = tuple(len(x) for x in names.values())

    if time is None:
        # Wind direction and wind speed are independent dimensions.
        if wd is None or ws is None:
            raise TypeError("`wd` and `ws` are required when `time` is None.")
        coords = {
            "wt": jnp.asarray(wt),
            "wd": jnp.asarray(wd),
            "ws": jnp.asarray(ws),
            **names,
        }
        dims = list(coords)
        shape = (len(wt), len(wd), len(ws), *names_shape)
    else:
        # Set the independent coordinates: turbine, time and names.
        time_ = jnp.asarray(time)
        coords = {
            "wt": jnp.asarray(wt),
            "time": time_,
            **names,
        }
        # Dimensions are fixed before adding the time-dependent coordinates.
        dims = list(coords)
        shape = (len(wt), time_.size, *names_shape)

        # Set the dependent coordinates: wind direction and wind speed.
        for key, value in (("wd", wd), ("ws", ws)):
            if value is not None:
                value_ = jnp.asarray(value)
                if value_.size == 1:
                    value_ = jnp.broadcast_to(value_, time_.shape)
                coords[key] = (["time"], value_)

    # Allocate the quantity.
    data_ = _full_of_nan(shape, dtype) if data is None else data
    return xr.DataArray(data=data_, coords=coords, dims=dims)


def _get_sensor_names(surrogates):
    """
    Get the sensor names from a dictionary of surrogate models.

    When a surrogate model has only 1 output, then its name is used directly.
    If instead the surrogate has multiple outputs, then a set of names in the
    form `surrogate_name.out_0, surrogate_name.out_1, ...` is created., where
    `out_i` are the surrogate `output_names`.

    Parameters
    ----------
    surrogates : dict of surrogates_interface.surrogates.SurrogateModel
        Dictionary of surrogate models.

    Returns
    -------
    names : list of str
        Names for all output sensors.

    """
    names = []
    for key, val in surrogates.items():
        if val.n_outputs == 1:
            names.append(key)
        else:
            names.extend([f"{key}.{out}" for out in val.output_names])
    return names


def _plot_flow_on_grid(x_grid, y_grid, z_grid, values, title, cmap, label):
    """
    Make one figure with a 3D scatter plot of a flow quantity over a rotor grid.

    Parameters
    ----------
    x_grid, y_grid, z_grid : ndarray
        Coordinates of the grid points, all with the same shape.
    values : ndarray
        Quantity used to color the points, with one value per grid point.
    title : str
        Figure title.
    cmap : str
        Matplotlib colormap name.
    label : str
        Colorbar label.
    """
    fig = plt.figure()
    fig.suptitle(title)
    ax = fig.add_subplot(projection="3d")
    ax.set_proj_type("ortho")
    ax.set_aspect("equal")
    # Each axis spans at least 10% of the largest extent, so that a planar rotor
    # does not collapse into a zero-thickness box.
    ptp = np.array([np.ptp(x_grid), np.ptp(y_grid), np.ptp(z_grid)])
    ptp_min = 0.1 * np.max(ptp)
    ax.set_box_aspect(tuple(np.maximum(ptp, ptp_min)))
    ax.set_xlabel("x (east)")
    ax.set_ylabel("y (north)")
    ax.set_zlabel("z (up)")
    patch_ws = ax.scatter(
        x_grid.ravel(),
        y_grid.ravel(),
        z_grid.ravel(),
        c=values,
        cmap=cmap,
    )
    plt.colorbar(patch_ws, label=label, ax=ax)
    plt.tight_layout()


def plot_flow_map(flow_map, wd=None, ws=None, wt=None, time=None, quantity="WS_eff"):
    """
    Plot the effective wind speed and Turbulence Intensity over each rotor for any inflow condition.

    One figure is created for each combination of turbine and flow case.

    Parameters
    ----------
    flow_map : xarray DataSet
        Flow map generated by `compute_flow_map`. It contains the effective wind speed,
        effective turbulence intensity and corresponding grid points for each turbine and flow case.
    wd : float, (L) array_like, optional
        Wind direction, in deg. Must be a subset of the one contained in `flow_map`.
        The default is `None`, which means to use all available wind directions.
        It is ignored if `time` is not `None`.
    ws : float, (K) array_like, optional
        Wind speed. Must be a subset of the one contained in `flow_map`.
        The default is `None`, which means to use all available wind speeds.
        It is ignored if `time` is not `None`.
    wt : int, (I) array_like, optional
        Wind turbines. Must be a subset of the one contained in `flow_map`.
        The default is `None`, which means to use all available wind turbines.
    time : float, (Time) array_like, optional
        Time. Must be a subset of the one contained in `flow_map`.
        The default is `None`, which means to use all available time instants.
    quantity : str, optional
        Quantity to plot. Can be `"WS_eff"` or `"TI_eff"`. The default is `"WS_eff"`.

    Returns
    -------
    None.

    Raises
    ------
    ValueError
        If `quantity` is not `"WS_eff"` or `"TI_eff"`.

    """
    if quantity == "WS_eff":
        cmap = "Blues_r"
        label = "Effective wind speed [m/s]"
    elif quantity == "TI_eff":
        cmap = "Oranges"
        label = "Effective Turbulence Intensity [-]"
    else:
        raise ValueError("quantity must be 'WS_eff' or 'TI_eff'.")

    # Build the list of (selection, title) pairs, one per figure.
    wt_ = flow_map["wt"].values if wt is None else np.atleast_1d(wt)
    if "time" in flow_map.dims:
        time_ = flow_map["time"].values if time is None else np.atleast_1d(time)
        cases = [
            ({"wt": wt_i, "time": time_t}, f"Turbine {wt_i}, Time {time_t}")
            for wt_i in wt_
            for time_t in time_
        ]
    else:  # "time" not in flow_map.dims
        wd_ = flow_map["wd"].values if wd is None else np.atleast_1d(wd)
        ws_ = flow_map["ws"].values if ws is None else np.atleast_1d(ws)
        cases = [
            (
                {"wt": wt_i, "wd": wd_l, "ws": ws_k},
                f"Turbine {wt_i}, Wind direction {wd_l}, Wind speed {ws_k}",
            )
            for wt_i in wt_
            for wd_l in wd_
            for ws_k in ws_
        ]

    grid = flow_map["grid"]
    flow = flow_map["flow"]
    for selection, title in cases:
        x_grid = grid.loc[{**selection, "axis": "x"}].values
        y_grid = grid.loc[{**selection, "axis": "y"}].values
        z_grid = grid.loc[{**selection, "axis": "z"}].values
        values = flow.loc[{**selection, "quantity": quantity}].values
        _plot_flow_on_grid(x_grid, y_grid, z_grid, values, title, cmap, label)


def compute_sector_average(
    sim,
    radius,
    n_azimuth_per_sector,
    look="downwind",
    align_in_yaw=True,
    align_in_tilt=True,
    axial_wind=False,
    wt=None,
    wd=None,
    ws=None,
    time=None,
    dtype=jnp.float32,
    **kwargs,
):
    r"""
    Compute the sector-averaged effective wind speed and effective turbulence intensity, assuming 4 sectors.

    Each sector spans 90 deg in azimuth and they are oriented as left, up, right and down, as described in
    `A. Guilloré, F. Campagnolo and C. L. Bottasso, A control-oriented load surrogate model based on sector-averaged inflow quantities: capturing damage for unwaked, waked, wake-steering and curtailed wind turbines. Presented at TORQUE 2024. <https://doi.org/10.1088/1742-6596/2767/3/032019>`_

    The result can be visualized via `wind_farm_loads.tool_agnostic.plot_sector_average()`.

    The sector average is computed as

    .. math::
      V_{\mathrm{avg}}
        = \frac{4}{\pi R^2}
           \int_{0}^{R}
            \int_{\theta_{\mathrm{start}}}^{\theta_{\mathrm{end}}}
             V(r, \theta) r \mathrm{d}r \mathrm{d}\theta

    where:
        - :math:`V` is the quantity to be integrated (effective wind speed or turbulence intensity);
        - :math:`r` the radius and :math:`R` the rotor radius;
        - :math:`\theta` is the azimuth, :math:`\theta_{\mathrm{start}}` is where the sector starts
          and :math:`\theta_{\mathrm{end}}` where it ends.

    The integral is numerically computed with the trapezoidal method.
    The sector area :math:`\pi R^2 / 4` uses the last point of the radius grid as :math:`R`,
    therefore the radius grid is expected to span the whole rotor.

    Parameters
    ----------
    sim : py_wake SimulationResult or floris FlorisModel
        Either:

         - Floris model. Must follow a call to `run()`.
         - Simulation result computed by PyWake. Must follow a call to the wind farm model.

    radius : int or (Grid,) or (Grid, Turbine_type) array_like
        Radius grid in [0, 1]. Can be:

          - Integer: number of grid points along an equally-spaced array.
          - 1D array_like: non-dimensional radius array, applied to all turbine types.
          - 2D array_like: non-dimensional radius array, 1 column per turbine type.
            A single column is applied to all turbine types.

        The grid must contain at least 2 points.
        In all cases, the rotor radii are obtained from `sim`.
    n_azimuth_per_sector : int
        Number of points in the azimuth grid to cover one 90 deg sector. Must be at least 2.
        The azimuth step is `90.0 / (n_azimuth_per_sector - 1)` [deg].
    look : str, optional
        The left and right sectors are determined by an observer that can look
        `"upwind"` or `"downwind"`. The default is `"downwind"`.
    align_in_yaw : bool, optional
        If `True` (default) the grid is aligned in yaw with the rotor plane.
    align_in_tilt : bool, optional
        If `True` (default) the grid is aligned in tilt with the rotor plane.
    axial_wind : bool, optional
        If `True` the axial wind speed and TI are returned. The default is `False`.
    wt : int, (I) array_like, optional
        Wind turbines. Must be a subset of the one contained in `sim`.
        The default is `None`, which means to use all available wind turbines.
    wd : float, (L) array_like, optional
        Wind direction, in deg. Must be a subset of the one contained in `sim`.
        The default is `None`, which means to use all available wind directions.
    ws : float, (K) array_like, optional
        Wind speed. Must be a subset of the one contained in `sim`.
        The default is `None`, which means to use all available wind speeds.
    time : float, (Time) array_like, optional
        Time. Must be a subset of the one contained in `sim`.
        The default is `None`, which means to use all available time instants.
    dtype : data-type, optional
        The desired data-type for the result. The default is single precision,
        which should be enough for all outputs. The properties of each type can
        be checked with `np.finfo(np.float32(1.0))`.
    kwargs : dict
        Additional keyword arguments are passed to `compute_flow_map` and are tool-specific.
        Possible values are:

            - `use_single_precision` : bool. If `True`, the PyWake flow map is computed in single precision.

    Returns
    -------
    sa : xarray DataArray
        Sector-averaged wind speed, effective turbulence intensity for each turbine and flow case.
        The sectors are ordered as up, right, down and left.

    Raises
    ------
    ValueError
        If `look` is not `"upwind"` or `"downwind"`, if `n_azimuth_per_sector` is smaller than 2,
        or if `radius` does not describe a valid radius grid.
    TypeError
        If `sim` is neither a floris FlorisModel nor a py-wake SimulationResult.
    NotImplementedError
        If `sim` is a Floris model with multiple rotor diameters.

    """
    # Validate the inexpensive arguments first, so that mistakes are reported
    # before the (potentially expensive) flow map is computed.
    if look not in ("upwind", "downwind"):
        raise ValueError("Parameter look must be 'upwind' or 'downwind'")
    if n_azimuth_per_sector < 2:
        raise ValueError(
            "Parameter n_azimuth_per_sector must be at least 2. "
            f"Received {n_azimuth_per_sector}."
        )
    # An integer is the number of points of an equally-spaced radius grid,
    # anything else is interpreted as an array_like non-dimensional radius grid.
    radius_is_count = isinstance(radius, (int, np.integer))
    if radius_is_count:
        n_radius = int(radius)
    else:
        radius = jnp.asarray(radius, dtype=dtype)
        if radius.ndim not in (1, 2):
            raise ValueError(
                "Parameter radius must be an integer, a 1D or a 2D array. "
                f"Received an array with {radius.ndim} dimensions."
            )
        n_radius = radius.shape[0]
    # The trapezoidal rule needs at least 2 points along the radius.
    if n_radius < 2:
        raise ValueError(
            f"The radius grid must contain at least 2 points. Received {n_radius}."
        )

    # Part of this function that depends on the tool.
    if type(sim).__name__ == "FlorisModel":
        # Get all rotor diameters.
        rotor_diameter = sim.core.farm.rotor_diameters[0, :]  # First findex.
        # Only 1 turbine type is supported for now.
        rotor_diameter = np.unique(rotor_diameter)
        turbine_type = np.zeros((sim.core.farm.n_turbines,), dtype=int)
        if rotor_diameter.size > 1:
            raise NotImplementedError(
                "Multiple turbine types are not yet supported for Floris."
            )
        # In this context, we only care about the diameter to determine the type.
        # An alternative would be:
        #   len(sim.core.farm.turbine_type)
        # But that might create more types than intended.
        from wind_farm_loads.floris import compute_flow_map
    elif type(sim).__name__ == "SimulationResult":  # PyWake
        rotor_diameter = sim.windFarmModel.windTurbines._diameters
        turbine_type = sim["type"].values
        from wind_farm_loads.py_wake import compute_flow_map
    else:
        raise TypeError(
            f"Input 'sim' must be a floris FlorisModel or a py-wake SimulationResult. Received {type(sim).__name__}."
        )

    # The tool-agnostic part starts here.
    rotor_radius = 0.5 * jnp.astype(rotor_diameter, dtype)
    n_type = rotor_radius.size

    # Make the azimuth grid. We cover 360 deg starting from -45 deg,
    # which is the beginning of the first sector. Consecutive sectors share
    # their boundary point, hence the grid has 4 * i_sector_size + 1 points
    # and the azimuth step is 90 / i_sector_size deg.
    i_sector_size = n_azimuth_per_sector - 1
    n_azimuth_rotor = 4 * i_sector_size + 1
    azimuth_grid_deg = jnp.linspace(
        -45.0, 360.0 - 45.0, n_azimuth_rotor, endpoint=True, dtype=dtype
    )
    azimuth_grid_rad = jnp.deg2rad(azimuth_grid_deg)

    # Make the non-dimensional radius grid.
    # axis 0: grid points
    # axis 1: turbine types, or a single column shared by all types.
    if radius_is_count:
        radius_nondim = jnp.linspace(0.0, 1.0, n_radius, dtype=dtype)[:, jnp.newaxis]
    elif radius.ndim == 1:
        radius_nondim = radius[:, jnp.newaxis]
    elif radius.shape[1] in (1, n_type):
        radius_nondim = radius
    else:
        raise ValueError(
            "A 2D radius grid must have 1 column per turbine type. "
            f"Received {radius.shape[1]} columns for {n_type} turbine types."
        )
    # Make the grid dimensional, by multiplying by the rotors radius.
    # Shape: (Radius, turbine types)
    radius_grid = radius_nondim * rotor_radius[jnp.newaxis, :]

    # Make the polar grid for each turbine type.
    grid = jnp.zeros((n_radius, azimuth_grid_rad.size, 3, n_type), dtype=dtype)
    for i in range(n_type):
        grid = grid.at[:, :, :, i].set(
            make_polar_grid(radius_grid[:, i], azimuth_grid_rad, degrees=False)
        )

    # Get the flow map.
    flow_map = compute_flow_map(
        sim,
        grid,
        align_in_yaw=align_in_yaw,
        align_in_tilt=align_in_tilt,
        axial_wind=axial_wind,
        wt=wt,
        wd=wd,
        ws=ws,
        time=time,
        dtype=dtype,
        save_grid=False,
        **kwargs,
    )

    # Convert flow map to a JAX array.
    # When wd and ws are present, the axes ordered as
    #   0   1   2         3       4        5
    #  wt, wd, ws, quantity, radius, azimuth
    # When time is present, axes 1 and 2 are replaced by time.
    # The important point here is that we are integrating over the last 2 axes.
    flow_map_data = jnp.asarray(flow_map["flow"].values)

    # The flow map may contain only a subset of the turbines in `sim` (see `wt`),
    # so select the turbine type of each turbine actually present in it.
    # NOTE(review): assumes the "wt" coordinate holds turbine indices (as in PyWake).
    turbine_type_flow_map = np.asarray(turbine_type)[flow_map["wt"].values]
    # Shape: (wt, Radius)
    radius_grid_wt = radius_grid[:, turbine_type_flow_map].T
    n_wt = radius_grid_wt.shape[0]

    # Area of 1 sector for each turbine, using the last radius point as rotor radius.
    # Shape: (wt, )
    sector_area = np.pi * radius_grid_wt[:, -1] ** 2 / 4.0

    # Broadcast radius grid and area to compute the integral.
    # Between wt and (radius, azimuth) the flow map has 3 axes (wd, ws, quantity)
    # or, when time is present, 2 axes (time, quantity).
    # The radius is broadcasted to all dimensions of the flow map except the azimuth.
    # The area is broadcasted to all dimensions except radius and azimuth, plus one for the 4 sectors.
    n_case_axes = flow_map_data.ndim - 3
    radius_broadcasted = jnp.broadcast_to(
        radius_grid_wt.reshape((n_wt,) + (1,) * n_case_axes + (n_radius,)),
        flow_map_data.shape[:-1],
    )
    area_broadcasted = jnp.broadcast_to(
        sector_area.reshape((n_wt,) + (1,) * (n_case_axes + 1)),
        (*flow_map_data.shape[:-2], 4),
    )

    def _sector_integral(i_start, i_stop):
        """Double integral of V * r over the radius and the azimuth points [i_start, i_stop)."""
        azimuth_integral = jnp.trapezoid(
            flow_map_data[..., i_start:i_stop],
            x=azimuth_grid_rad[i_start:i_stop],
        )
        return jnp.trapezoid(radius_broadcasted * azimuth_integral, x=radius_broadcasted)

    # Azimuth ranges of the 4 sectors (i = i_sector_size):
    #   points [0, i]   -> [-45, 45] deg, left or right depending on `look`;
    #   points [i, 2i]  -> [45, 135] deg, up;
    #   points [2i, 3i] -> [135, 225] deg, right or left depending on `look`;
    #   points [3i, 4i] -> [225, 315] deg, down.
    sector_a = _sector_integral(0, i_sector_size + 1)
    sector_up = _sector_integral(i_sector_size, 2 * i_sector_size + 1)
    sector_b = _sector_integral(2 * i_sector_size, 3 * i_sector_size + 1)
    sector_down = _sector_integral(3 * i_sector_size, 4 * i_sector_size + 1)
    if look == "upwind":
        sector_left, sector_right = sector_a, sector_b
    else:  # "downwind", as validated at the top.
        sector_left, sector_right = sector_b, sector_a

    # Store the sectors as up, right, down and left, then divide the double
    # integral by the sector area to obtain the average.
    sa_data = (
        jnp.stack((sector_up, sector_right, sector_down, sector_left), axis=-1)
        / area_broadcasted
    ).astype(dtype)

    # Floris does not need JAX.
    if type(sim).__name__ == "FlorisModel":
        sa_data = np.asarray(sa_data)

    # Store the sector average into a xarray DataArray.
    sa = preallocate_ilktn(
        wt=flow_map["wt"].values,
        wd=flow_map["wd"].values,
        ws=flow_map["ws"].values,
        time=flow_map["time"].values if "time" in flow_map.dims else None,
        data=sa_data,
        quantity=["WS_eff", "TI_eff"],
        direction=["up", "right", "down", "left"],
    )

    return sa


def _plot_sector_average_pie(title, data, directions, cmap, label, format_str):
    """Draw one pie chart, with a colorbar, showing the sector-averaged `data`."""
    fig, ax = plt.subplots()
    fig.suptitle(title)
    normalizer = mpl.colors.Normalize(vmin=0.9 * data.min(), vmax=1.1 * data.max())
    colors = cmap(normalizer(data))

    # Matplotlib calls `autopct` once per wedge, in the order of the wedges,
    # so consuming an iterator over the data yields the value of the current sector.
    values = iter(data)

    # To draw the sectors we use a pie chart.
    # Unfortunately, this means that the visualization is 2D.
    ax.pie(
        (1,) * len(data),
        explode=None,
        labels=directions,
        colors=colors,
        autopct=lambda _: f"{next(values):{format_str}}",
        startangle=135,
        counterclock=False,
        wedgeprops={"edgecolor": "k"},
    )
    plt.colorbar(
        mpl.cm.ScalarMappable(norm=normalizer, cmap=cmap),
        label=label,
        ax=ax,
    )
    plt.tight_layout()


def plot_sector_average(
    sector_average, wd=None, ws=None, wt=None, time=None, quantity="WS_eff"
):
    """
    Plot the sector-average effective wind speed or Turbulence Intensity for each rotor and inflow condition.

    One pie chart is drawn for each turbine and flow case (wind direction and wind speed, or time).
    Each wedge is one sector, colored and labelled with its value.

    Parameters
    ----------
    sector_average : xarray DataArray
        Sector average generated by `compute_sector_average`. It contains the effective wind speed
        and effective turbulence intensity for each turbine and flow case.
    wd : float, (L) array_like, optional
        Wind direction, in deg. Must be a subset of the one contained in `sector_average`.
        The default is `None`, which means to use all available wind directions.
        It is ignored if `sector_average` has a `"time"` dimension.
    ws : float, (K) array_like, optional
        Wind speed. Must be a subset of the one contained in `sector_average`.
        The default is `None`, which means to use all available wind speeds.
        It is ignored if `sector_average` has a `"time"` dimension.
    wt : int, (I) array_like, optional
        Wind turbines. Must be a subset of the one contained in `sector_average`.
        The default is `None`, which means to use all available wind turbines.
    time : float, (Time) array_like, optional
        Time. Must be a subset of the one contained in `sector_average`.
        The default is `None`, which means to use all available time instants.
        It is only used if `sector_average` has a `"time"` dimension.
    quantity : str, optional
        Quantity to plot. Must be `"WS_eff"` or `"TI_eff"`. The default is `"WS_eff"`.

    Returns
    -------
    None.

    Raises
    ------
    ValueError
        If `quantity` is not `"WS_eff"` or `"TI_eff"`.

    """
    if quantity == "WS_eff":
        cmap = plt.get_cmap("Blues_r")
        label = "Effective wind speed [m/s]"
        format_str = ".2f"
    elif quantity == "TI_eff":
        cmap = plt.get_cmap("Oranges")
        label = "Effective Turbulence Intensity [-]"
        format_str = ".2%"
    else:
        raise ValueError("Parameter quantity must be 'WS_eff' or 'TI_eff'.")

    directions = sector_average["direction"].values
    wt_ = sector_average["wt"].values if wt is None else np.atleast_1d(wt)
    if "time" in sector_average.dims:
        time_ = sector_average["time"].values if time is None else np.atleast_1d(time)
        for wt_i in wt_:
            for time_t in time_:
                data = sector_average.loc[
                    {"wt": wt_i, "time": time_t, "quantity": quantity}
                ].values
                _plot_sector_average_pie(
                    f"Turbine {wt_i}, Time {time_t}",
                    data,
                    directions,
                    cmap,
                    label,
                    format_str,
                )

    else:  # "time" not in sector_average.dims
        wd_ = sector_average["wd"].values if wd is None else np.atleast_1d(wd)
        ws_ = sector_average["ws"].values if ws is None else np.atleast_1d(ws)
        for wt_i in wt_:
            for wd_l in wd_:
                for ws_k in ws_:
                    data = sector_average.loc[
                        {"wt": wt_i, "wd": wd_l, "ws": ws_k, "quantity": quantity}
                    ].values
                    _plot_sector_average_pie(
                        f"Turbine {wt_i}, Wind direction {wd_l}, Wind speed {ws_k}",
                        data,
                        directions,
                        cmap,
                        label,
                        format_str,
                    )


def _predict_loads_pod(
    surrogates,
    flow_map,
    expand_shape_fun,
    *additional_inputs,
    dtype=jnp.float32,
    ti_in_percent=True,
):
    """
    Tool-agnostic part of `predict_loads_pod()`.

    Parameters
    ----------
    surrogates : dict
        Load surrogates, one per sensor. Each value must provide `predict_output(x)`.
    flow_map : flow map
        Output of the tool-specific `compute_flow_map`. Its `"flow"` variable must contain
        the `"WS_eff"` and `"TI_eff"` quantities over the `q0` and `q1` grid axes.
    expand_shape_fun : callable
        Called as `expand_shape_fun(th, *sizes)` to expand each additional input to the
        (wt, wd, ws) or (wt, time) shape of the flow cases.
    *additional_inputs
        Additional surrogate inputs, appended after wind speed and turbulence intensity.
    dtype : data-type, optional
        Data-type of the surrogate inputs and of the loads. The default is single precision.
    ti_in_percent : bool, optional
        If `True` (default) the turbulence intensity is multiplied by 100 before prediction.

    Returns
    -------
    loads : xarray DataArray
        Loads for each turbine, flow case and sensor, as built by `preallocate_ilktn`.
    """
    name_sensors = _get_sensor_names(surrogates)

    # Each flow case is either (wt, time) or (wt, wd, ws).
    # The sizes come from the flow map, rather than sim_res, because the flow map might contain a subset of it.
    if "time" in flow_map.dims:
        case_dims = ("wt", "time")
    else:
        case_dims = ("wt", "wd", "ws")
    shape_load = tuple(flow_map[dim].size for dim in case_dims)
    n_cases = int(np.prod(shape_load))

    # JAX array that will hold the loads, one slice per sensor.
    loads_data = jnp.full((*shape_load, len(name_sensors)), jnp.nan, dtype=dtype)

    ti = flow_map["flow"].loc[{"quantity": "TI_eff"}].values
    if ti_in_percent:
        ti = ti * 100.0

    # Bring each additional input to the flow-case shape, then make it a column.
    theta = [
        expand_shape_fun(th, *shape_load).astype(dtype).reshape(-1, 1)
        for th in additional_inputs
    ]

    # Surrogate input: one row per flow case, one column per feature.
    # WS and TI are kept separate, each contributing q0 * q1 features,
    # followed by the additional inputs.
    shape_flow = (n_cases, flow_map["q0"].size * flow_map["q1"].size)
    ws_features = (
        flow_map["flow"]
        .loc[{"quantity": "WS_eff"}]
        .values.astype(dtype)
        .reshape(shape_flow)
    )
    ti_features = ti.astype(dtype).reshape(shape_flow)
    x_jax = jnp.concatenate((ws_features, ti_features, *theta), axis=1)

    # The surrogates_interface package does not yet support JAX, therefore we must convert to Numpy.
    x = np.asarray(x_jax)

    for i_sensor, sensor_name in enumerate(surrogates.keys()):
        prediction = surrogates[sensor_name].predict_output(x).reshape(shape_load)
        loads_data = loads_data.at[..., i_sensor].set(prediction)

    # Floris does not need JAX.
    if expand_shape_fun.__name__ == "_findex_to_ilk":
        loads_data = np.asarray(loads_data)

    return preallocate_ilktn(
        wt=flow_map["wt"].values,
        wd=flow_map["wd"].values,
        ws=flow_map["ws"].values,
        time=flow_map["time"].values if "time" in flow_map.dims else None,
        name=name_sensors,
        data=loads_data,
    )


def _predict_loads_sector_average(
    surrogates,
    sector_average,
    expand_shape_fun,
    *additional_inputs,
    dtype=jnp.float32,
    ti_in_percent=True,
):
    """
    Tool-agnostic part of `predict_loads_sector_average()`.

    Parameters
    ----------
    surrogates : dict
        Load surrogates, one per sensor. Each value must provide `predict_output(x)`.
    sector_average : xarray DataArray
        Sector average with `"quantity"` (`"WS_eff"`, `"TI_eff"`) and `"direction"` dimensions.
        The sectors must be ordered as expected by the surrogates (up, right, down and left).
    expand_shape_fun : callable
        Called as `expand_shape_fun(th, *sizes)` to expand each additional input to the
        (wt, wd, ws) or (wt, time) shape of the flow cases.
    *additional_inputs
        Additional surrogate inputs, appended after the sector-averaged WS and TI.
    dtype : data-type, optional
        Data-type of the surrogate inputs and of the loads. The default is single precision.
    ti_in_percent : bool, optional
        If `True` (default) the turbulence intensity is multiplied by 100 before prediction.

    Returns
    -------
    loads : xarray DataArray
        Loads for each turbine, flow case and sensor, as built by `preallocate_ilktn`.
    """
    name_sensors = _get_sensor_names(surrogates)

    # Each flow case is either (wt, time) or (wt, wd, ws).
    # The sizes come from the sector average, rather than sim_res, because the sector average might contain a subset of it.
    if "time" in sector_average.dims:
        case_dims = ("wt", "time")
    else:
        case_dims = ("wt", "wd", "ws")
    shape_load = tuple(sector_average[dim].size for dim in case_dims)

    # JAX array that will hold the loads, one slice per sensor.
    loads_data = jnp.full((*shape_load, len(name_sensors)), jnp.nan, dtype=dtype)

    ti = sector_average.loc[{"quantity": "TI_eff"}].values
    if ti_in_percent:
        ti = ti * 100.0

    # Bring each additional input to the flow-case shape, then make it a column.
    theta = [
        expand_shape_fun(th, *shape_load).astype(dtype).reshape(-1, 1)
        for th in additional_inputs
    ]

    # Surrogate input: one row per flow case, one column per feature.
    # Features: sector-averaged WS (one per sector), sector-averaged TI (one per sector)
    # and the additional inputs. The sector order is taken as is from `sector_average`.
    n_sectors = sector_average["direction"].size
    ws_features = (
        sector_average.loc[{"quantity": "WS_eff"}]
        .values.astype(dtype)
        .reshape((-1, n_sectors))
    )
    ti_features = ti.astype(dtype).reshape((-1, n_sectors))
    x_jax = jnp.concatenate((ws_features, ti_features, *theta), axis=1)

    # The surrogates_interface package does not yet support JAX, therefore we must convert to Numpy.
    x = np.asarray(x_jax)

    for i_sensor, sensor_name in enumerate(surrogates.keys()):
        prediction = surrogates[sensor_name].predict_output(x).reshape(shape_load)
        loads_data = loads_data.at[..., i_sensor].set(prediction)

    # Floris does not need JAX.
    if expand_shape_fun.__name__ == "_findex_to_ilk":
        loads_data = np.asarray(loads_data)

    return preallocate_ilktn(
        wt=sector_average["wt"].values,
        wd=sector_average["wd"].values,
        ws=sector_average["ws"].values,
        time=sector_average["time"].values if "time" in sector_average.dims else None,
        name=name_sensors,
        data=loads_data,
    )


def compute_lifetime_del(
    probability, loads, material, factor=20.0 * 365.0 * 24.0 * 3600.0 * 1e-8
):
    r"""
    Compute Lifetime Damage Equivalent Loads.

    The Lifetime DEL for each turbine, load channel and Wöhler exponent is computed as

    .. math::
      \mathrm{LDEL}
          = \left(
                \phi \int_{-\pi}^{+\pi} \int_{V_{\text{cut-in}}}^{V_{\text{cut-out}}} p(V, \theta) \mathrm{DEL}(V, \theta)^m \mathrm{d}V \mathrm{d}\theta
            \right)^{1/m}

    where:
        - :math:`V` is the wind speed;
        - :math:`\theta` is the wind direction;
        - :math:`\mathrm{DEL}` is the Damage Equivalent Load for a given turbine and Wöhler exponent;
        - :math:`p(V, \theta)` is the probability of the flow case;
        - :math:`\phi` is the number of seconds in 20 years divided by the estimated number of cycles;
        - :math:`m` is the Wöhler exponent.

    The integral is numerically computed with the trapezoidal method.

    Parameters
    ----------
    probability : (L, K) xarray DataArray
        Probability for each wind direction (l) and wind speed (k).
        The dimensions must be `"wd"` and `"ws"`, in any order.
    loads : (I, L, K, N) xarray DataArray or JAX ArrayImpl
        Damage Equivalent Loads for each turbine (i), wind direction (l), wind speed (k) and output channel (n).
        The axes are used in this order. If it is a `DataArray` then one dimension must be named `"wt"`.
    material : (N) xarray DataArray
        Wöhler exponent for each output channel (n).
        The dimension must be `"name"`.
    factor : float, optional
        Number of seconds in 20 years divided by the estimated number of cycles.
        The default is :math:`20 \cdot 365 \cdot 24 \cdot 3600 / 10^8`.

    Raises
    ------
    ValueError
        If the probability is not normalized to 1. That is, it must be

        .. math::
          \int_{-\pi}^{+\pi} \int_{0}^{+\infty} p(V, \theta) dV d\theta = 1

    Returns
    -------
    ldel : (I, N) xarray DataArray or JAX ArrayImpl
        Lifetime Damage Equivalent Loads for each turbine (i) and output channel (n).
        If it is a `DataArray` then the dimensions are `"wt"` and `"name"`.

    """
    # Test that the probability integrates to 1, with a tolerance that depends on its precision.
    # `probability.dtype` is a dtype instance, so it must be compared with `==`:
    # an identity check against the scalar type `np.float32` is always False.
    if np.dtype(probability.dtype) == np.float32:
        atol = 1e-6
    else:  # double precision.
        atol = 1e-14

    tot_p = float(probability.integrate("wd").integrate("ws"))
    # Raise instead of `assert`, which is stripped when Python runs with -O.
    if not np.isclose(tot_p, 1.0, atol=atol):
        raise ValueError(
            f"The probability must integrate to 1. It integrates to {tot_p}."
        )

    # The integration below relies on the axis order (wd, ws) of the probability.
    probability_wd_ws = probability.transpose("wd", "ws")

    # Carry out the numerical integration with the trapezoidal method using JAX.
    if isinstance(loads, xr.DataArray):
        loads_ = loads.values
    else:  # JAX array.
        loads_ = loads

    probability_ = jnp.broadcast_to(
        probability_wd_ws.values[jnp.newaxis, :, :, jnp.newaxis], loads_.shape
    )
    material_ = jnp.broadcast_to(
        material.values[jnp.newaxis, jnp.newaxis, jnp.newaxis, :], loads_.shape
    )
    ldel = probability_ * loads_**material_  # Shape: (wt, wd, ws, name)
    ldel = jnp.trapezoid(
        jnp.trapezoid(ldel, x=probability_wd_ws["wd"].values, axis=1),
        x=probability_wd_ws["ws"].values,
        axis=1,
    )  # Shape: (wt, name)
    ldel = (factor * ldel) ** (1.0 / material.values)  # Shape: (wt, name)

    # Convert back to xarray.
    if isinstance(loads, xr.DataArray):
        return xr.DataArray(
            data=ldel,
            dims=("wt", "name"),
            coords={
                "wt": loads["wt"],
                "name": material["name"],
            },
        )
    # JAX array.
    return ldel
