# -*- coding: utf-8 -*-
"""
Components that require Floris.

@author: ricriv
"""

# %% Import.

import sys
from copy import deepcopy

import numpy as np
import xarray as xr

from wind_farm_loads.tool_agnostic import (
    _get_sensor_names,
    _preallocate_ilktn,
    _predict_loads_pod,
    _predict_loads_sector_average,
    rotate_grid,
)

# %% Utils.


def monkey_patch_floris_solver():
    """
    Swap the Floris solver module for the patched one shipped with this package.

    The patched solvers, defined in `wind_farm_loads._floris_solver`, are modified to
    provide no deficit around the current turbine. The swap is done by registering the
    patched module in `sys.modules` under the name `floris.core.solver`, so that later
    imports of that name resolve to it. For this reason the function must be called
    before Floris is imported anywhere else. Floris is imported at the end, after the
    replacement is in place.

    Returns
    -------
    None.

    """
    import wind_farm_loads._floris_solver as patched_solver

    # Make the patched module the one returned for "floris.core.solver".
    sys.modules["floris.core.solver"] = patched_solver
    # Import Floris only now that the solver replacement is registered.
    import floris  # noqa: F401


def _findex_to_ilk(v, I, L, K):
    """
    Convert a variable from shape `(findex, I)` to `(I, L, K)`.

    Assumes that the flow conditions have been generated as

    .. code-block:: python

       wd, ws = np.meshgrid(wd_ambient, ws_ambient, indexing="ij")
       wind_directions = wd.ravel()
       wind_speeds = ws.ravel()

    Parameters
    ----------
    v : float, array_like
        Input variable that needs to be converted to `(I, L, K)` shape.
    I : int
        Number of turbines.
    L : int
        Number of wind directions.
    K : int
        Number of wind speeds.


    Returns
    -------
    r : (I, L, K) ndarray
        Input variable converted to `(I, L, K)` shape.

    """
    return v.T.reshape(I, L, K)


# %% Functions to extract the inflow.


def _broadcast_grid_to_types(x_grid, y_grid, z_grid, n_types):
    """
    Ensure that the grid coordinates have shape `(N, M, Type)`.

    Parameters
    ----------
    x_grid, y_grid, z_grid : (N, M) or (N, M, Type) ndarray
        Grid coordinates. 2D grids are shared by all turbine types.
    n_types : int
        Number of entries in the Floris `turbine_type` list.

    Returns
    -------
    x_grid_t, y_grid_t, z_grid_t : (N, M, Type) ndarray
        Grid coordinates with one grid per turbine type.
        For 2D inputs these are read-only broadcast views.

    Raises
    ------
    ValueError
        If the grids are neither all 2D nor all 3D, or if the number of 3D grids
        differs from `n_types`.

    """
    if x_grid.ndim == 2 and y_grid.ndim == 2 and z_grid.ndim == 2:
        return tuple(
            np.broadcast_to(g[:, :, np.newaxis], (g.shape[0], g.shape[1], n_types))
            for g in (x_grid, y_grid, z_grid)
        )
    if x_grid.ndim == 3 and y_grid.ndim == 3 and z_grid.ndim == 3:
        # Check that there is 1 grid per turbine type.
        if x_grid.shape[2] != n_types:
            raise ValueError(
                f"{x_grid.shape[2]} grid types provided, but {n_types} were expected."
            )
        return x_grid, y_grid, z_grid
    raise ValueError("The grid must be a 2D or 3D array.")


def compute_flow_map(
    fmodel,
    x_grid,
    y_grid,
    z_grid,
    align_in_yaw=True,
    align_in_tilt=True,
    axial_wind=False,
    wt=None,
    wd=None,
    ws=None,
    time=None,
    dtype=np.float32,
    save_grid=False,
):
    r"""
    Compute the effective wind speed and turbulence intensity over each rotor.

    This function receives a grid and rotates it by the wind direction. Optionally,
    the grid is also rotated by the yaw and tilt of each turbine to align it with the rotor plane.
    Finally, the grid is translated to each rotor center and the flow map is computed.

    The Floris model must contain all combinations of wind directions and speeds, generated as

    .. code-block:: python

       wd, ws = np.meshgrid(wd_ambient, ws_ambient, indexing="ij")
       wind_directions = wd.ravel()
       wind_speeds = ws.ravel()

    with `wd_ambient` and `ws_ambient` sorted in ascending order.

    Parameters
    ----------
    fmodel : FlorisModel
        Floris model. Must follow a call to `run()`. It is not modified; a deep copy is used internally.
    x_grid : (N, M) or (N, M, Type) ndarray
        x coordinate (downwind) of the grid points, before rotation by yaw and tilt. Should be 0.
        Typically generated by `make_rectangular_grid` or `make_polar_grid`.
        The first 2 dimensions cover the rotor, while the last is over the entries of the
        Floris `turbine_type` list. If that list has one entry per turbine, turbine `i` uses grid `i`.
        Otherwise grid 0 is used for every turbine.
        If the user passes a 2D array, the grid is assumed to be the same for all turbine types.
    y_grid : (N, M) or (N, M, Type) ndarray
        y coordinate (crosswind) of the grid points, before rotation by yaw and tilt.
        Typically generated by `make_rectangular_grid` or `make_polar_grid`.
        If the user passes a 2D array, the grid is assumed to be the same for all turbine types.
    z_grid : (N, M) or (N, M, Type) ndarray
        z coordinate (up) of the grid points, before rotation by yaw and tilt.
        Typically generated by `make_rectangular_grid` or `make_polar_grid`.
        If the user passes a 2D array, the grid is assumed to be the same for all turbine types.
    align_in_yaw : bool, optional
        If `True` (default) the grid is aligned in yaw with the rotor plane.
    align_in_tilt : bool, optional
        If `True` (default) the grid is aligned in tilt with the rotor plane.
    axial_wind : bool, optional
        If `True` both the effective wind speed and the effective TI in the flow map are
        multiplied by :math:`\cos(\mathrm{yaw}) \cos(\mathrm{tilt})`. The default is `False`.
    wt : int, (I) array_like, optional
        Wind turbine indices. Must be a subset of the ones contained in `fmodel`.
        The default is `None`, which means to use all available wind turbines.
    wd : float, (L) array_like, optional
        Wind direction, in deg. Must be a subset of the ones contained in `fmodel`.
        The default is `None`, which means to use all available wind directions.
    ws : float, (K) array_like, optional
        Wind speed. Must be a subset of the ones contained in `fmodel`.
        The default is `None`, which means to use all available wind speeds.
    time : float, (Time) array_like, optional
        Currently ignored by this Floris implementation.
    dtype : data-type, optional
        The desired data-type for the result. The default is single precision,
        which should be enough for all outputs. The properties of each type can
        be checked with `np.finfo(np.float32(1.0))`.
    save_grid : bool, optional
        If `True` the grid will be saved for all inflow conditions. Since this comes at a significant
        memory cost, it is recommended to switch it on only for debug purposes.
        The default is `False`.

    Returns
    -------
    flow_map : xarray DataSet
        Contains `flow`, with dimensions `(wt, wd, ws, quantity, q0, q1)` and
        `quantity = ["WS_eff", "TI_eff"]`. If `save_grid` is `True`, it also contains `grid`,
        with dimensions `(wt, wd, ws, axis, q0, q1)`, holding the east-north-z coordinates of
        the grid points. `q0` and `q1` are the grid dimensions, which might either be
        y and z or radius and azimuth.

    Raises
    ------
    ValueError
        If the grid has the wrong dimensions or number of types, or if `wt`, `wd` or `ws`
        are not contained in `fmodel`.
    AssertionError
        If the number of flow cases in `fmodel` is not the number of wind directions times
        the number of wind speeds.

    Notes
    -----
    The TI is not sampled over the grid. The value returned by `get_turbine_TIs()` for the
    current turbine is assigned to every grid point.

    """
    n_types = len(fmodel.configuration["farm"]["turbine_type"])
    x_grid_t, y_grid_t, z_grid_t = _broadcast_grid_to_types(
        x_grid, y_grid, z_grid, n_types
    )

    flow_field = fmodel.core.flow_field
    farm = fmodel.core.farm

    # Map wind direction and speed to findex.
    # unique() is needed because fmodel should contain all combinations of wind direction and speed.
    wd_all = np.unique(flow_field.wind_directions)
    ws_all = np.unique(flow_field.wind_speeds)
    L = wd_all.size
    K = ws_all.size
    assert (
        flow_field.n_findex == L * K
    ), f"There are {flow_field.n_findex} flow cases, but I have found {L} wind directions and {K} wind speeds."
    findex_table = np.arange(flow_field.n_findex).reshape(L, K)

    # The default value of wt, wd and ws is the one in fmodel.
    n_wt = farm.layout_x.size
    wt_ = np.arange(n_wt)
    if wt is not None:
        # Test that the input for wt is contained in fmodel.
        if not np.all(np.isin(np.atleast_1d(wt), wt_)):
            raise ValueError("Input 'wt' must be contained in the Floris model.")
        wt_ = np.atleast_1d(wt)

    wd_ = wd_all
    if wd is not None:
        wd_ = np.atleast_1d(wd)
        # Test that the input for wd is contained in fmodel.
        if not np.all(np.isin(wd_, flow_field.wind_directions)):
            raise ValueError("Input 'wd' must be contained in the Floris model.")

    ws_ = ws_all
    if ws is not None:
        ws_ = np.atleast_1d(ws)
        # Test that the input for ws is contained in fmodel.
        if not np.all(np.isin(ws_, flow_field.wind_speeds)):
            raise ValueError("Input 'ws' must be contained in the Floris model.")

    # Position of each requested direction and speed along the full (L, K) axes of fmodel.
    # wd_all and ws_all are sorted (np.unique) and contain every requested value.
    wd_pos = np.searchsorted(wd_all, wd_)
    ws_pos = np.searchsorted(ws_all, ws_)

    # Yaw and tilt of every turbine and flow case, shape (I, L, K), in rad.
    if align_in_yaw:
        yaw_ilk = np.deg2rad(_findex_to_ilk(farm.yaw_angles, n_wt, L, K))
    else:
        yaw_ilk = np.zeros((n_wt, L, K))
    if align_in_tilt:
        # Compute tilt angle for the effective wind speed at each findex and turbine.
        tilt_findex = farm.calculate_tilt_for_eff_velocities(
            fmodel.turbine_average_velocities
        )
        tilt_ilk = np.deg2rad(_findex_to_ilk(tilt_findex, n_wt, L, K))
    else:
        tilt_ilk = np.zeros((n_wt, L, K))

    # Conveniently access turbine position.
    x_turbines = farm.layout_x  # Shape: (I, )
    y_turbines = farm.layout_y  # Shape: (I, )
    z_turbines = farm.hub_heights  # Indexed as (findex, I) below.

    # Preallocate the outputs. wt, wd and ws are placed first, followed by the quantity
    # and grid dimensions; this order enables vectorization in predict_loads_pod().
    n_q0, n_q1 = x_grid_t.shape[:2]
    ilk_shape = (wt_.size, wd_.size, ws_.size)
    flow = np.full((*ilk_shape, 2, n_q0, n_q1), np.nan, dtype=dtype)
    grid = (
        np.full((*ilk_shape, 3, n_q0, n_q1), np.nan, dtype=dtype)
        if save_grid
        else None
    )

    wd_rad = np.deg2rad(wd_)
    angle_ref = np.deg2rad(90.0)

    # Make a local copy of fmodel.
    fmodel_ = deepcopy(fmodel)

    for j_wd, lf in enumerate(wd_pos):
        for j_ws, kf in enumerate(ws_pos):
            findex = int(findex_table[lf, kf])

            # Set the current flow condition and run the model.
            # This does not depend on the turbine, so it is done once per flow case.
            fmodel_.set(
                wind_directions=[wd_[j_wd]],
                wind_speeds=[ws_[j_ws]],
                turbulence_intensities=[flow_field.turbulence_intensities[findex]],
                yaw_angles=farm.yaw_angles[[findex], :],
            )
            fmodel_.run()

            for j_wt, i in enumerate(wt_):
                # Entry i of the turbine_type list belongs to turbine i when the list has
                # one entry per turbine; otherwise the first grid is used.
                i_type = i if n_types == n_wt else 0

                # Convert grid from downwind-crosswind-z to east-north-z.
                # While doing that, also rotate by yaw and tilt.
                # This can be done because the order of rotations is first yaw and then tilt.
                # It will NOT work for a floating turbine.
                # The formula for the yaw is taken from py_wake.wind_turbines._wind_turbines.WindTurbines.plot_xy()
                x_rot, y_rot, z_rot = rotate_grid(
                    x_grid_t[:, :, i_type],
                    y_grid_t[:, :, i_type],
                    z_grid_t[:, :, i_type],
                    yaw=angle_ref - wd_rad[j_wd] + yaw_ilk[i, lf, kf],  # [rad]
                    tilt=-tilt_ilk[i, lf, kf],  # [rad]
                    degrees=False,
                )

                # Move grid to rotor center. New arrays are created, so the inputs are never modified.
                x_pts = x_rot + x_turbines[i]
                y_pts = y_rot + y_turbines[i]
                z_pts = z_rot + z_turbines[findex, i]

                if grid is not None:
                    grid[j_wt, j_wd, j_ws] = np.stack((x_pts, y_pts, z_pts))

                # Compute wind speed over rotor disk.
                # There is only 1 flow index on the rows. Points are on the columns.
                ws_points = fmodel_.sample_flow_at_points(
                    x_pts.ravel(), y_pts.ravel(), z_pts.ravel()
                )[0, :]
                flow[j_wt, j_wd, j_ws, 0] = ws_points.reshape(x_pts.shape)

                # Get the TI of the current turbine and assume that it is uniform over the rotor disk.
                # First dimension is findex, second is turbine, third and fourth are grid points (there is only 1).
                turbine_ti = fmodel_.get_turbine_TIs()[0, :, 0, 0]
                flow[j_wt, j_wd, j_ws, 1] = turbine_ti[i]

    # Project wind speed (and TI) using only the requested turbines and flow cases.
    if axial_wind:
        cos_yaw_cos_tilt = np.cos(yaw_ilk) * np.cos(tilt_ilk)
        flow *= cos_yaw_cos_tilt[np.ix_(wt_, wd_pos, ws_pos)][
            :, :, :, np.newaxis, np.newaxis, np.newaxis
        ]

    ilk_coords = {"wt": wt_, "wd": wd_, "ws": ws_}
    data_vars = {
        "flow": xr.DataArray(
            data=flow,
            coords={**ilk_coords, "quantity": ["WS_eff", "TI_eff"]},
            dims=["wt", "wd", "ws", "quantity", "q0", "q1"],
        )
    }
    if save_grid:
        data_vars["grid"] = xr.DataArray(
            data=grid,
            coords={**ilk_coords, "axis": ["x", "y", "z"]},
            dims=["wt", "wd", "ws", "axis", "q0", "q1"],
        )
    return xr.Dataset(data_vars)


# %% Functions to evaluate the loads.


def predict_loads_rotor_average(
    surrogates, fmodel, *additional_inputs, dtype=np.float32, ti_in_percent=True
):
    r"""
    Evaluate the load surrogate models based on rotor-averaged wind speed and turbulence intensity. Additional (control) inputs are supported as well.

    Each load surrogate is evaluated as

    .. math::
      y = f(\mathrm{WS}, \mathrm{TI}, \boldsymbol{\theta})

    where :math:`\mathrm{WS}` is the rotor-averaged wind speed, :math:`\mathrm{TI}` is the rotor-averaged turbulence intensity and
    :math:`\boldsymbol{\theta}` are the additional inputs (typically, control parameters). The surrogates are evaluated
    for all turbines and ambient inflow conditions.

    The load database has been described in
    `Guilloré, A., Campagnolo, F. & Bottasso, C. L. (2024). A control-oriented load surrogate model based on sector-averaged inflow quantities: capturing damage for unwaked, waked, wake-steering and curtailed wind turbines <https://doi.org/10.1088/1742-6596/2767/3/032019>`_
    where it was proposed to include the controller set point by adding the yaw, pitch and rotor speed.
    This function has been developed using the surrogate models trained by Hari, which are based on the database provided by TUM.

    Parameters
    ----------
    surrogates : dict of surrogates_interface.surrogates.SurrogateModel
        Dictionary containing surrogate models. The keys will be used as sensor names.
    fmodel : FlorisModel
        Floris model. Must follow a call to `run()`, and must contain all combinations
        of wind directions and speeds.
    additional_inputs : list of (findex, I) ndarray
        Additional inputs to evaluate the load surrogate models.
        Must be coherent with the Floris model. Typical additional inputs are:

            - Yaw, pitch and rotor speed.
            - Yaw and curtailment level.

        It is the user responsibility to pass the inputs in the order required by the surrogates, and to use the correct units.
    dtype : data-type, optional
        The desired data-type for the result. The default is single precision,
        which should be enough for all outputs. The properties of each type can
        be checked with `np.finfo(np.float32(1.0))`.
    ti_in_percent : bool
        If `True` (default) the turbulence intensity is multiplied by 100 before evaluating the surrogates.

    Returns
    -------
    loads : xarray.DataArray
        Loads for each turbine, ambient inflow condition and sensor.

    Raises
    ------
    AssertionError
        If the number of flow cases is not the number of wind directions times the number of wind speeds.

    """
    flow_field = fmodel.core.flow_field
    n_wt = fmodel.core.farm.layout_x.size

    # The ambient directions and speeds are recovered with unique(), since fmodel
    # is expected to hold every direction-speed combination.
    wind_directions = np.unique(flow_field.wind_directions)
    wind_speeds = np.unique(flow_field.wind_speeds)
    n_wd, n_ws = wind_directions.size, wind_speeds.size
    assert (
        flow_field.n_findex == n_wd * n_ws
    ), f"There are {flow_field.n_findex} flow cases, but I have found {n_wd} wind directions and {n_ws} wind speeds."

    loads = _preallocate_ilktn(
        wt=np.arange(n_wt),
        wd=wind_directions,
        ws=wind_speeds,
        name=_get_sensor_names(surrogates),
        dtype=dtype,
    )

    def as_column(values):
        # (findex, I) -> (I, L, K) -> flat column in the requested precision.
        return _findex_to_ilk(values, n_wt, n_wd, n_ws).ravel().astype(dtype)

    # Surrogate input columns: wind speed [m/s], turbulence intensity, additional inputs.
    columns = [as_column(fmodel.turbine_average_velocities)]
    turbine_ti = fmodel.get_turbine_TIs()
    if ti_in_percent:
        turbine_ti = turbine_ti * 100.0
    columns.append(as_column(turbine_ti))
    columns.extend([as_column(extra) for extra in additional_inputs])
    x = np.column_stack(columns)

    for sensor, surrogate in surrogates.items():
        loads.loc[{"name": sensor}] = surrogate.predict_output(x).reshape(
            n_wt, n_wd, n_ws
        )
    return loads


def predict_loads_pod(
    surrogates,
    flow_map,
    *additional_inputs,
    dtype=np.float32,
    ti_in_percent=True,
):
    r"""
    Evaluate the load surrogate models based on Proper Orthogonal Decomposition of wind speed and turbulence intensity. Additional (control) inputs are supported as well.

    Each load surrogate is evaluated as

    .. math::
      y = f(\mathrm{WS}, \mathrm{TI}, \boldsymbol{\theta})

    where :math:`\mathrm{WS}` is the wind speed over the grid used to generate the POD basis, :math:`\mathrm{TI}` is
    the turbulence intensity over the grid used to generate the POD basis and :math:`\boldsymbol{\theta}` are the
    additional inputs (typically, control parameters). The surrogates are evaluated for all turbines and ambient inflow conditions.

    The load database has been described in
    `Guilloré, A., Campagnolo, F. & Bottasso, C. L. (2024). A control-oriented load surrogate model based on sector-averaged inflow quantities: capturing damage for unwaked, waked, wake-steering and curtailed wind turbines <https://doi.org/10.1088/1742-6596/2767/3/032019>`_
    where it was proposed to include the controller set point by adding the yaw, pitch and rotor speed.
    This function has been developed using the surrogate models trained by Hari, which are based on the database provided by TUM.

    This is a thin Floris-specific wrapper. The work is done by the shared
    `_predict_loads_pod`, which receives `_findex_to_ilk` as the converter from the Floris
    flow-index layout to `(wt, wd, ws)`.

    Parameters
    ----------
    surrogates : dict of surrogates_interface.surrogates.SurrogateModel
        Dictionary containing surrogate models. The keys will be used as sensor names.
    flow_map : xarray DataSet
        Effective wind speed, effective turbulence intensity and corresponding grid points
        for each turbine and flow case. Generated by `compute_flow_map()`.
    additional_inputs : list of ndarray
        Additional inputs to evaluate the load surrogate models.
        Must be coherent with the flow map. Since `_findex_to_ilk` is the converter, they are
        presumably given per flow index and turbine, i.e. with shape `(findex, I)`.
        Typical additional inputs are:

            - Yaw, pitch and rotor speed.
            - Yaw and curtailment level.

        It is the user responsibility to pass the inputs in the order required by the surrogates, and to use the correct units.
    dtype : data-type, optional
        The desired data-type for the result. The default is single precision,
        which should be enough for all outputs. The properties of each type can
        be checked with `np.finfo(np.float32(1.0))`.
    ti_in_percent : bool
        If `True` (default) the turbulence intensity is multiplied by 100 before evaluating the surrogates.

    Returns
    -------
    loads : xarray.DataArray
        Loads for each turbine, ambient inflow condition and sensor.

    """
    converter = _findex_to_ilk
    options = {"dtype": dtype, "ti_in_percent": ti_in_percent}
    return _predict_loads_pod(
        surrogates, flow_map, converter, *additional_inputs, **options
    )


def predict_loads_sector_average(
    surrogates,
    sector_average,
    *additional_inputs,
    dtype=np.float32,
    ti_in_percent=True,
):
    r"""
    Evaluate the load surrogate models based on sector average of wind speed and turbulence intensity. Additional (control) inputs are supported as well.

    Each load surrogate is evaluated as

    .. math::
      y = f(\mathrm{WS}, \mathrm{TI}, \boldsymbol{\theta})

    where :math:`\mathrm{WS}` is the sector-averaged wind speed, :math:`\mathrm{TI}` is the sector-averaged
    turbulence intensity and :math:`\boldsymbol{\theta}` are the additional inputs (typically, control parameters).
    The surrogates are evaluated for all turbines and ambient inflow conditions.

    The load database has been described in
    `Guilloré, A., Campagnolo, F. & Bottasso, C. L. (2024). A control-oriented load surrogate model based on sector-averaged inflow quantities: capturing damage for unwaked, waked, wake-steering and curtailed wind turbines <https://doi.org/10.1088/1742-6596/2767/3/032019>`_
    where it was proposed to include the controller set point by adding the yaw, pitch and rotor speed.
    This function has been developed using the surrogate models trained by Hari, which are based on the database provided by TUM.

    This is a thin Floris-specific wrapper. The work is done by the shared
    `_predict_loads_sector_average`, which receives `_findex_to_ilk` as the converter from
    the Floris flow-index layout to `(wt, wd, ws)`.

    Parameters
    ----------
    surrogates : dict of surrogates_interface.surrogates.SurrogateModel
        Dictionary containing surrogate models. The keys will be used as sensor names.
    sector_average : xarray DataArray
        Sector average of effective wind speed and effective turbulence intensity
        for each turbine and flow case. Generated by `compute_sector_average()`.
    additional_inputs : list of ndarray
        Additional inputs to evaluate the load surrogate models.
        Must be coherent with the sector average. Typical additional inputs are:

            - Yaw, pitch and rotor speed.
            - Yaw and curtailment level.

        It is the user responsibility to pass the inputs in the order required by the surrogates, and to use the correct units.
    dtype : data-type, optional
        The desired data-type for the result. The default is single precision,
        which should be enough for all outputs. The properties of each type can
        be checked with `np.finfo(np.float32(1.0))`.
    ti_in_percent : bool
        If `True` (default) the turbulence intensity is multiplied by 100 before evaluating the surrogates.

    Returns
    -------
    loads : xarray.DataArray
        Loads for each turbine, ambient inflow condition and sensor.

    """
    converter = _findex_to_ilk
    options = {"dtype": dtype, "ti_in_percent": ti_in_percent}
    return _predict_loads_sector_average(
        surrogates, sector_average, converter, *additional_inputs, **options
    )
