"""
=========================
The Core Population Model
=========================

This module contains tools for sampling and assigning core demographic
characteristics to simulants.

"""
from typing import Callable, Dict, Iterable, List

import numpy as np
import pandas as pd
from vivarium.framework.engine import Builder
from vivarium.framework.event import Event
from vivarium.framework.population import PopulationView, SimulantData
from vivarium.framework.randomness import RandomnessStream

from vivarium_public_health import utilities
from vivarium_public_health.population.data_transformations import (
    assign_demographic_proportions,
    load_population_structure,
    rescale_binned_proportions,
    smooth_ages,
)


class BasePopulation:
    """Component for producing and aging simulants based on demographic data.

    When simulants are initialized it creates the columns listed in
    ``columns_created`` ('age', 'sex', 'alive', 'location', 'entrance_time',
    'exit_time'). On every ``time_step`` event it advances the age of the
    simulants that are still alive. Handling of simulants that reach
    ``population.exit_age`` is delegated to the ``AgeOutSimulants``
    sub-component.
    """

    configuration_defaults = {
        "population": {
            "age_start": 0,
            "age_end": 125,
            "exit_age": None,
            "include_sex": "Both",  # Either Female, Male, or Both
        }
    }

    def __init__(self):
        self._sub_components = [AgeOutSimulants()]

    def __repr__(self) -> str:
        return "BasePopulation()"

    ##############
    # Properties #
    ##############

    @property
    def name(self) -> str:
        return "base_population"

    @property
    def sub_components(self) -> List:
        return self._sub_components

    @property
    def columns_created(self) -> List[str]:
        return ["age", "sex", "alive", "location", "entrance_time", "exit_time"]

    #################
    # Setup methods #
    #################

    # noinspection PyAttributeOutsideInit
    def setup(self, builder: Builder) -> None:
        """Configure the component from the simulation builder.

        Validates ``population.include_sex`` and loads the population structure,
        converting it to demographic proportions. It then acquires the randomness
        streams and the population view. Finally it registers the simulant
        initializer and a ``time_step`` listener with priority 8.

        Raises
        ------
        ValueError
            If ``population.include_sex`` is not one of 'Male', 'Female' or 'Both'.
        """
        self.config = builder.configuration.population
        self.key_columns = builder.configuration.randomness.key_columns
        if self.config.include_sex not in ["Male", "Female", "Both"]:
            raise ValueError(
                "Configuration key 'population.include_sex' must be one "
                "of ['Male', 'Female', 'Both']. "
                f"Provided value: {self.config.include_sex}."
            )

        source_population_structure = load_population_structure(builder)
        self.demographic_proportions = assign_demographic_proportions(
            source_population_structure,
            include_sex=self.config.include_sex,
        )

        self.randomness = self.get_randomness_streams(builder)
        self.register_simulants = builder.randomness.register_simulants
        self.population_view = self.get_population_view(builder)
        builder.population.initializes_simulants(
            self.on_initialize_simulants, creates_columns=self.columns_created
        )

        builder.event.register_listener("time_step", self.on_time_step, priority=8)

    @staticmethod
    def get_randomness_streams(builder: Builder) -> Dict[str, RandomnessStream]:
        """Create the randomness streams used when generating simulants.

        Every stream except ``general_purpose`` is requested with
        ``initializes_crn_attributes=True``.
        """
        return {
            "general_purpose": builder.randomness.get_stream("population_generation"),
            "bin_selection": builder.randomness.get_stream(
                "bin_selection", initializes_crn_attributes=True
            ),
            "age_smoothing": builder.randomness.get_stream(
                "age_smoothing", initializes_crn_attributes=True
            ),
            "age_smoothing_age_bounds": builder.randomness.get_stream(
                "age_smoothing_age_bounds", initializes_crn_attributes=True
            ),
        }

    def get_population_view(self, builder: Builder) -> PopulationView:
        """Return a population view over the columns this component creates."""
        return builder.population.get_view(self.columns_created)

    ########################
    # Event-driven methods #
    ########################

    # TODO: Move most of this docstring to an rst file.
    def on_initialize_simulants(self, pop_data: SimulantData) -> None:
        """Creates a population with fundamental demographic and simulation properties.

        When the simulation framework creates new simulants (essentially producing a new
        set of simulant ids) and this component is being used, the newly created simulants
        arrive here first. They are assigned the demographic qualities 'age', 'sex',
        and 'location' in a way that is consistent with the demographic distributions
        represented by the population-level data. Additionally, the simulants are assigned
        the simulation properties 'alive', 'entrance_time', and 'exit_time'.

        The 'alive' column holds either 'alive' or 'dead'.
        In general, most simulation components (except for those computing summary statistics)
        ignore simulants if they are not in the 'alive' category. The 'entrance_time' and
        'exit_time' columns mark when the simulant enters or leaves the simulation,
        respectively. Here we are agnostic to the methods of entrance and exit (e.g., birth,
        migration, death, etc.) as these characteristics can be inferred from this column and
        other information about the simulant and the simulation parameters.

        The age range defaults to the configured ``age_start``/``age_end``, but either
        bound may be overridden through ``pop_data.user_data``. Demographic proportions
        are taken from the reference year matching ``pop_data.creation_time.year``.
        """

        age_params = {
            "age_start": pop_data.user_data.get("age_start", self.config.age_start),
            "age_end": pop_data.user_data.get("age_end", self.config.age_end),
        }

        demographic_proportions = self.get_demographic_proportions_for_creation_time(
            self.demographic_proportions, pop_data.creation_time.year
        )

        self.population_view.update(
            generate_population(
                simulant_ids=pop_data.index,
                creation_time=pop_data.creation_time,
                step_size=pop_data.creation_window,
                age_params=age_params,
                demographic_proportions=demographic_proportions,
                randomness_streams=self.randomness,
                register_simulants=self.register_simulants,
                key_columns=self.key_columns,
            )
        )

    def on_time_step(self, event: Event) -> None:
        """Ages simulants each time step.

        Only simulants whose 'alive' column equals 'alive' are aged; their age
        increases by the step size converted to years.
        """
        population = self.population_view.get(event.index, query="alive == 'alive'")
        population["age"] += utilities.to_years(event.step_size)
        self.population_view.update(population)

    ##################
    # Helper methods #
    ##################

    @staticmethod
    def get_demographic_proportions_for_creation_time(
        demographic_proportions, year: int
    ) -> pd.DataFrame:
        """Select the demographic proportions of the reference year covering ``year``.

        The reference years are the distinct ``year_start`` values of
        ``demographic_proportions``. The selected reference year is the latest
        one that is less than or equal to ``year``. Years after the last
        reference year use the last one, and years before the first reference
        year use the first one.

        Parameters
        ----------
        demographic_proportions
            Table containing at least a 'year_start' column.
        year
            The calendar year for which proportions are needed.

        Returns
        -------
        pandas.DataFrame
            The rows of ``demographic_proportions`` belonging to the selected
            reference year.
        """
        reference_years = sorted(set(demographic_proportions.year_start))
        # np.digitize returns 0 when ``year`` precedes every reference year. Without
        # clamping, the resulting index of -1 would silently pick the *last*
        # reference year instead of the earliest one.
        ref_year_index = max(np.digitize(year, reference_years).item() - 1, 0)
        return demographic_proportions[
            demographic_proportions.year_start == reference_years[ref_year_index]
        ]


class AgeOutSimulants:
    """Component for handling aged-out simulants.

    It only becomes active when ``population.exit_age`` is configured. At each
    ``time_step__cleanup`` event, tracked simulants whose age has reached the
    exit age are marked untracked and stamped with the event time as their
    exit time.
    """

    @property
    def name(self) -> str:
        return "age_out_simulants"

    # noinspection PyAttributeOutsideInit
    def setup(self, builder: Builder) -> None:
        """Register the cleanup listener, but only if an exit age is configured."""
        population_config = builder.configuration.population
        if population_config.exit_age is not None:
            self.config = population_config
            self.population_view = builder.population.get_view(
                ["age", "exit_time", "tracked"]
            )
            builder.event.register_listener(
                "time_step__cleanup", self.on_time_step_cleanup
            )

    def on_time_step_cleanup(self, event: Event) -> None:
        """Untrack simulants whose age is at or above the configured exit age."""
        state = self.population_view.get(event.index)
        exit_age = float(self.config.exit_age)
        reached_exit_age = (state["age"] >= exit_age) & state["tracked"]
        leaving = state[reached_exit_age].copy()
        if leaving.empty:
            return
        leaving["tracked"] = pd.Series(False, index=leaving.index)
        leaving["exit_time"] = event.time
        self.population_view.update(leaving)

    def __repr__(self) -> str:
        return "AgeOutSimulants()"


def generate_population(
    simulant_ids: pd.Index,
    creation_time: pd.Timestamp,
    step_size: pd.Timedelta,
    age_params: Dict[str, float],
    demographic_proportions: pd.DataFrame,
    randomness_streams: Dict[str, RandomnessStream],
    register_simulants: Callable[[pd.DataFrame], None],
    key_columns: Iterable[str] = ("entrance_time", "age"),
) -> pd.DataFrame:
    """Produces a random set of simulants sampled from `demographic_proportions`.

    If ``age_params['age_start']`` equals ``age_params['age_end']``, every simulant
    starts near that single age. Otherwise ages are sampled from the age range
    they span.

    Parameters
    ----------
    simulant_ids
        Values to serve as the index in the newly generated simulant DataFrame.
    creation_time
        The simulation time when the simulants are created.
    step_size
        The size of the initial time step. It is only used when a single
        starting age is requested.
    age_params
        Dictionary with keys
            age_start : Start of an age range
            age_end : End of an age range

        When the two values are equal they specify a single starting age.
    demographic_proportions
        Table with columns 'age', 'age_start', 'age_end', 'sex', 'year',
        'location', 'population', 'P(sex, location, age| year)',
        'P(sex, location | age, year)'.
    randomness_streams
        Source of random number generation within the vivarium common random number framework.
    register_simulants
        A function to register the new simulants with the CRN framework.
    key_columns
        A list of key columns for random number generation. It is only forwarded
        when sampling from an age range. The single-age path registers
        simulants on 'entrance_time' and 'age'.

    Returns
    -------
    pandas.DataFrame
        Table with columns
            'entrance_time'
                The `pandas.Timestamp` describing when the simulant entered
                the simulation. Set to `creation_time` for all simulants.
            'exit_time'
                The `pandas.Timestamp` describing when the simulant exited
                the simulation. Set initially to `pandas.NaT`.
            'alive'
                Indicates how the simulation interacts with the simulant.
                Set to 'alive' for all new simulants.
            'age'
                The age of the simulant at the current time step.
            'location'
                The location indicating where the simulant resides.
            'sex'
                The sex of the simulant, drawn from the 'sex' values present
                in `demographic_proportions`.

    """
    bookkeeping_columns = {
        "entrance_time": pd.Series(creation_time, index=simulant_ids),
        "exit_time": pd.Series(pd.NaT, index=simulant_ids),
        "alive": pd.Series("alive", index=simulant_ids),
    }
    simulants = pd.DataFrame(bookkeeping_columns, index=simulant_ids)

    lower_age = float(age_params["age_start"])
    upper_age = float(age_params["age_end"])

    if lower_age != upper_age:
        # A genuine age range: sample ages from the binned population structure.
        return _assign_demography_with_age_bounds(
            simulants,
            demographic_proportions,
            lower_age,
            upper_age,
            randomness_streams,
            register_simulants,
            key_columns,
        )

    # Degenerate range: every simulant starts at (approximately) the same age.
    return _assign_demography_with_initial_age(
        simulants,
        demographic_proportions,
        lower_age,
        step_size,
        randomness_streams,
        register_simulants,
    )


def _assign_demography_with_initial_age(
    simulants: pd.DataFrame,
    pop_data: pd.DataFrame,
    initial_age: float,
    step_size: pd.Timedelta,
    randomness_streams: Dict[str, RandomnessStream],
    register_simulants: Callable[[pd.DataFrame], None],
) -> pd.DataFrame:
    """Assigns age, sex, and location information to the provided simulants given a fixed age.

    Each simulant's age is ``initial_age`` plus a draw from the 'age_smoothing'
    stream scaled by ``step_size`` (converted to years). Sex and location are then
    sampled jointly using the 'P(sex, location | age, year)' column. That column
    comes from the rows of ``pop_data`` whose age bin contains ``initial_age``.

    Parameters
    ----------
    simulants
        Table that represents the new cohort of agents being added to the simulation.
    pop_data
        Table with columns 'age', 'age_start', 'age_end', 'sex', 'year',
        'location', 'population', 'P(sex, location, age| year)',
        'P(sex, location | age, year)'
    initial_age
        The age to assign the new simulants.
    step_size
        The size of the initial time step.
    randomness_streams
        Source of random number generation within the vivarium common random number framework.
    register_simulants
        A function to register the new simulants with the CRN framework.

    Returns
    -------
    pandas.DataFrame
        Table with same columns as `simulants` and with the additional
        columns 'age', 'sex',  and 'location'.

    Raises
    ------
    ValueError
        If no age bin in `pop_data` contains `initial_age`.
    """
    bin_contains_age = (pop_data.age_start <= initial_age) & (
        pop_data.age_end >= initial_age
    )
    pop_data = pop_data[bin_contains_age]

    if pop_data.empty:
        raise ValueError(
            f"The age {initial_age} is not represented by the population data structure"
        )

    smoothing_draws = randomness_streams["age_smoothing"].get_draw(simulants.index)
    simulants["age"] = initial_age + smoothing_draws * utilities.to_years(step_size)

    # Register with the columns that exist so far; sex and location are drawn next.
    # NOTE(review): registration appears to be required before the
    # 'general_purpose' draw below — confirm against the CRN setup.
    register_simulants(simulants[["entrance_time", "age"]])

    # Assign a demographically accurate location and sex distribution.
    probability_column = "P(sex, location | age, year)"
    choices = pop_data.set_index(["sex", "location"])[probability_column].reset_index()
    decisions = randomness_streams["general_purpose"].choice(
        simulants.index, choices=choices.index, p=choices[probability_column]
    )

    for attribute in ("sex", "location"):
        simulants[attribute] = choices.loc[decisions, attribute].values

    return simulants


def _assign_demography_with_age_bounds(
    simulants: pd.DataFrame,
    pop_data: pd.DataFrame,
    age_start: float,
    age_end: float,
    randomness_streams: Dict[str, RandomnessStream],
    register_simulants: Callable[[pd.DataFrame], None],
    key_columns: Iterable[str] = ("entrance_time", "age"),
) -> pd.DataFrame:
    """Assigns an age, sex, and location to the provided simulants given a range of ages.

    The steps are:

    1. Rescale ``pop_data``'s binned proportions to the requested range.
    2. Draw one (age, sex, location) row per simulant from the 'bin_selection'
       stream, weighted by 'P(sex, location, age| year)'.
    3. Smooth the ages with the 'age_smoothing_age_bounds' stream.
    4. Register the simulants on ``key_columns``.

    Parameters
    ----------
    simulants
        Table that represents the new cohort of agents being added to the simulation.
    pop_data
        Table with columns 'age', 'age_start', 'age_end', 'sex', 'year',
        'location', 'population', 'P(sex, location, age| year)',
        'P(sex, location | age, year)'
    age_start, age_end
        The start and end of the age range of interest, respectively.
    randomness_streams
        Source of random number generation within the vivarium common random number framework.
    register_simulants
        A function to register the new simulants with the CRN framework.
    key_columns
        A list of key columns for random number generation.

    Returns
    -------
    pandas.DataFrame
        Table with same columns as `simulants` and with the additional columns
        'age', 'sex',  and 'location'.

    Raises
    ------
    ValueError
        If rescaling `pop_data` to the requested range leaves no rows.

    """
    pop_data = rescale_binned_proportions(pop_data, age_start, age_end)
    if pop_data.empty:
        raise ValueError(
            f"The age range ({age_start}, {age_end}) is not represented by the "
            f"population data structure."
        )

    # Assign a demographically accurate age, location, and sex distribution,
    # restricted to the bins lying entirely inside the requested range.
    joint_probability = "P(sex, location, age| year)"
    bin_in_range = (pop_data.age_start >= age_start) & (pop_data.age_end <= age_end)
    choices = (
        pop_data[bin_in_range]
        .set_index(["age", "sex", "location"])[joint_probability]
        .reset_index()
    )
    decisions = randomness_streams["bin_selection"].choice(
        simulants.index, choices=choices.index, p=choices[joint_probability]
    )
    for attribute in ("age", "sex", "location"):
        simulants[attribute] = choices.loc[decisions, attribute].values

    simulants = smooth_ages(
        simulants, pop_data, randomness_streams["age_smoothing_age_bounds"]
    )
    register_simulants(simulants[list(key_columns)])
    return simulants
