# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, NamedTuple, Tuple

import jax
import jax.numpy as jnp

from blackjax.types import PyTree


class SMCInfo(NamedTuple):
    """Diagnostics emitted together with the particles by one SMC step.

    Attributes
    ----------
    weights: jnp.ndarray
        Normalized weights (summing to one) of the particles output by the
        MCMC pass, computed right before resampling.
    proposals: PyTree
        The particles output by the MCMC pass, i.e. before resampling.
    ancestors: jnp.ndarray
        For every particle kept after resampling, the index of the proposal
        it was copied from.
    log_likelihood_increment: float
        Log of the mean unnormalized weight of this step, i.e. the
        log-likelihood increment contributed by the current SMC step.

    """

    # Normalized weights of the proposals.
    weights: jnp.ndarray
    # Particles produced by the MCMC pass, before resampling.
    proposals: PyTree
    # Indices into `proposals` chosen by the resampling step.
    ancestors: jnp.ndarray
    # Log-likelihood increment of the current step.
    log_likelihood_increment: float


def kernel(
    mcmc_kernel_factory: Callable,
    mcmc_state_generator: Callable,
    resampling_fn: Callable,
    num_mcmc_iterations: int,
):
    """Build a generic SMC kernel.

    In Feynman-Kac equivalent terms, the algo goes roughly as follows:

    ```
        M_t = mcmc_kernel_factory(logdensity_fn)
        for i in range(num_mcmc_iterations):
            x_t^i = M_t(..., x_t^i)
        G_t = log_weights_fn
        log_weights = G_t(x_t)
        idx = resample(log_weights)
        x_t = x_t[idx]
    ```


    Parameters
    ----------
    mcmc_kernel_factory: Callable
        A function of the Markov potential (the log-density) that returns an
        MCMC kernel. The returned kernel is called as ``kernel(rng_key, state)``
        on a single particle's state and must return a ``(new_state, info)``
        pair.
    mcmc_state_generator: Callable
        A function that creates a new MCMC state from a position and a
        logdensity function. The final states must expose a ``position``
        attribute.
    resampling_fn: Callable
        A function that resamples the particles generated by the MCMC kernel,
        based on previously computed weights. It is called as
        ``resampling_fn(weights, rng_key)`` and must return the indices of the
        selected particles.
    num_mcmc_iterations: int
        Number of iterations of the MCMC kernel applied to every particle
        before reweighting.

    Returns
    -------
    A kernel that takes a PRNGKey, a set of particles, the log-density of the
    target distribution and the Feynman-Kac log-potential at time `t`. The
    kernel returns the resampled particles together with an `SMCInfo`.

    """

    def one_step(
        rng_key: jnp.ndarray,
        particles: PyTree,
        logdensity_fn: Callable,
        log_weight_fn: Callable,
    ) -> Tuple[PyTree, SMCInfo]:
        """Take one step with the SMC kernel.

        Parameters
        ----------
        rng_key: jnp.ndarray
            JAX PRNGKey for randomness.
        particles: PyTree
            Current particles sample of the SMC algorithm. Every leaf is
            indexed by particle along its leading axis.
        logdensity_fn: Callable
            Log probability function we wish to sample from.
        log_weight_fn: Callable
            A function that represents the Feynman-Kac log potential at time t.

        Returns
        -------
        particles: PyTree
            The updated (resampled) set of particles.
        info: SMCInfo
            Additional information on the SMC step.

        """
        # The number of particles is read from the leading axis of the first leaf.
        num_particles = jax.tree_util.tree_leaves(particles)[0].shape[0]
        scan_key, resampling_key = jax.random.split(rng_key, 2)

        # First advance the particles using the MCMC kernel
        mcmc_kernel = mcmc_kernel_factory(logdensity_fn)

        def mcmc_body_fn(curr_particles, curr_key):
            # One MCMC iteration applied independently to every particle,
            # each with its own PRNG key; the per-particle info is discarded.
            keys = jax.random.split(curr_key, num_particles)
            new_particles, _ = jax.vmap(mcmc_kernel, in_axes=(0, 0))(
                keys, curr_particles
            )
            return new_particles, None

        # Build one MCMC state per particle; the log-density is shared.
        mcmc_state = jax.vmap(mcmc_state_generator, in_axes=(0, None))(
            particles, logdensity_fn
        )
        keys = jax.random.split(scan_key, num_mcmc_iterations)
        proposed_states, _ = jax.lax.scan(mcmc_body_fn, mcmc_state, keys)
        proposed_particles = proposed_states.position

        # Resample the particles depending on their respective weights
        log_weights = jax.vmap(log_weight_fn, in_axes=(0,))(proposed_particles)
        weights, log_likelihood_increment = _normalize(log_weights)
        resampling_index = resampling_fn(weights, resampling_key)
        # Use `jax.tree_util.tree_map`: the top-level `jax.tree_map` alias is
        # deprecated and removed in recent JAX releases.
        particles = jax.tree_util.tree_map(
            lambda x: x[resampling_index], proposed_particles
        )

        info = SMCInfo(
            weights, proposed_particles, resampling_index, log_likelihood_increment
        )
        return particles, info

    return one_step


def _normalize(log_weights):
    """Turn log-weights into normalized weights and a log-likelihood increment.

    Parameters
    ----------
    log_weights: jnp.ndarray
        Unnormalized log-weights, one per particle along the leading axis.

    Returns
    -------
    The normalized weights (summing to one) and the log of the mean
    unnormalized weight, ``log(mean(exp(log_weights)))``.

    """
    num_weights = log_weights.shape[0]
    # Shift by the maximum before exponentiating so the largest term is
    # exp(0) == 1 and the exponentials cannot overflow.
    shift = jnp.max(log_weights)
    scaled = jnp.exp(log_weights - shift)
    scaled_mean = scaled.mean()

    # Undo the shift in log-space to recover log(mean(exp(log_weights))).
    increment = jnp.log(scaled_mean) + shift

    # num_weights * scaled_mean is the sum of the scaled weights.
    normalized = scaled / (num_weights * scaled_mean)
    return normalized, increment
