"""
Reads in a ``Cycling_Information`` json representation and uses it in
conjunction with input files generated by
``ep_bolfi.kadi_tools.gitt_preprocessing`` to fit any PyBaMM model to it
via EP-BOLFI.
"""

from ast import literal_eval
from contextlib import redirect_stdout
from copy import deepcopy
import json
from os import linesep
from os.path import isfile
import xmlhelpy


def _parse_model_name(model_name):
    """
    Split a PyBaMM model specification into class name and arguments.

    :param model_name:
        Either a bare class name from ``pybamm.lithium_ion``, e.g.
        ``"SPMe"``, or a call of such a name with literal arguments, e.g.
        ``"DFN(options={'thermal': 'lumped', 'SEI': 'none'})"``.
    :returns:
        A 3-tuple: the class name, the list of positional arguments and
        the dictionary of keyword arguments.
    :raises ValueError:
        If *model_name* is not a bare name or a call of a bare name with
        literal arguments only.
    """
    # Parse with the Python grammar instead of splitting on "(", "," and
    # "=": those characters also occur inside literal argument values,
    # e.g. in the ``options`` dictionary of PyBaMM models.
    import ast

    try:
        expression = ast.parse(model_name.strip(), mode='eval').body
    except SyntaxError as error:
        raise ValueError(
            "Malformed model name: " + repr(model_name)
        ) from error
    if isinstance(expression, ast.Name):
        return expression.id, [], {}
    if not (
        isinstance(expression, ast.Call)
        and isinstance(expression.func, ast.Name)
    ):
        raise ValueError(
            "Model name must be a class name, optionally followed by "
            "literal arguments in parentheses: " + repr(model_name)
        )
    model_args = [literal_eval(argument) for argument in expression.args]
    model_kwargs = {}
    for keyword in expression.keywords:
        if keyword.arg is None:  # "**mapping" unpacking is not supported.
            raise ValueError(
                "Keyword unpacking is not supported in model name: "
                + repr(model_name)
            )
        model_kwargs[keyword.arg] = literal_eval(keyword.value)
    return expression.func.id, model_args, model_kwargs


@xmlhelpy.command(
    name='python -m ep_bolfi.kadi_tools.gitt_parameterization',
    version='${VERSION}'
)
@xmlhelpy.option(
    'input-record',
    char='r',
    param_type=xmlhelpy.Integer,
    required=True,
    description=(
        "Persistent record identifier of the record with the optimizer input."
    )
)
@xmlhelpy.option(
    'input-file',
    char='f',
    param_type=xmlhelpy.String,
    required=True,
    description="File name of the optimizer input file."
)
@xmlhelpy.option(
    'parameters-file',
    char='m',
    param_type=xmlhelpy.String,
    required=True,
    description=(
        "File name of the model parameters. It must be a Python file and "
        "contain the following global variables:"
        + linesep + linesep
        + " - parameters: The dictionary of parameters to pass on to the "
        "solver. May be a ep_bolfi.utility.preprocessing.SubstitutionDict."
        + linesep + linesep
        + " - unknowns: The dictionary of unknown parameters. Instead of "
        "single values as in 'parameters', input 2-tuples with their lower "
        "and upper bounds, e.g. from literature."
        + linesep + linesep
        + "It may contain the additional following global variables:"
        + linesep + linesep
        + " - transform_unknowns: Dictionary for transforming unknowns before "
        "inferring them via a normal distribution. Either give as 2-tuples "
        "with the first entry being the back-transformation and the second "
        "one being the transformation. For convenience, putting 'log' gives "
        "log-normal distributions."
        + linesep + linesep
        + " - negative_SOC_from_cell_SOC: A callable, used for OCV "
        + "subtraction."
        + linesep + linesep
        + " - positive_SOC_from_cell_SOC: A callable, used for OCV "
        + "subtraction."
        + linesep + linesep
        + " - uncertainties: The dictionary of parameter uncertainties. Used "
        "for scrambling them in the simulation samples. Give them as tuples: "
        "the first entry is the name of the distribution in scipy.stats, and "
        "the following are its parameters. Example: ('norm', mean, std)."
    )
)
@xmlhelpy.option(
    'output-record',
    char='o',
    param_type=xmlhelpy.Integer,
    required=True,
    description=(
        "Persistent record identifier of the record to store the output of "
        "the optimizer in."
    )
)
@xmlhelpy.option(
    'convergence-mode',
    char='c',
    param_type=xmlhelpy.Bool,
    default=False,
    description=(
        "If set to True, the features will be treated differently. The goal "
        "is better behaviour in cases where the initial unknowns' 95% "
        "confidence boundaries are already pretty close to the best possible "
        "fit. The method is to mostly disable EP, as it is allowing BOLFI to "
        "treat the features one-by-one, introducing the possibility that one "
        "feature will shrink the confidence boundaries enough such that the "
        "next feature is not contained within these boundaries anymore. "
        "Instead, BOLFI will fit the vector of all features at once."
    )
)
@xmlhelpy.option(
    'seed',
    char='s',
    param_type=xmlhelpy.Integer,
    default=None,
    description="Seed for RNG. Set to a number for unvarying results."
)
@xmlhelpy.option(
    'overwrite',
    char='w',
    default=False,
    param_type=xmlhelpy.Bool,
    description=(
        "Whether or not an already existing file by the same name in the "
        "record gets overwritten."
    )
)
def gitt_parameterization(
    input_record,
    input_file,
    parameters_file,
    output_record,
    seed,
    convergence_mode,
    overwrite
):
    """
    Please refer to the --help output of this file.

    Outline of what this command does:

     1. Downloads the optimizer input file and the parameters file from
        the Kadi record *input_record*, unless they are present locally.
        The parameters file is stored as ``local_parameter_file.py``.
     2. Builds the PyBaMM model named in the input file and a simulator
        that adds parameter noise and measurement (white) noise.
     3. Runs EP-BOLFI. Its stdout goes to ``<prefix>_optimization.log``.
        The result is written to ``<prefix>_parameterization.json`` and the
        evaluation log to ``<prefix>_evaluations.json``.
     4. Uploads those three files to the Kadi record *output_record*.

    If *seed* is None, the noise generators are seeded from fresh entropy,
    so results vary between runs.
    """
    from ep_bolfi import EP_BOLFI
    from ep_bolfi.kadi_tools.gitt_preprocessing import (
        gitt_feature_names, gitt_features
    )
    from ep_bolfi.models.solversetup import (
        simulation_setup, spectral_mesh_pts_and_method
    )
    from ep_bolfi.utility.preprocessing import (
        calculate_desired_voltage, SubstitutionDict
    )
    from kadi_apy.lib.core import KadiManager, Record
    from numpy import array
    from numpy.random import RandomState
    from pybamm.expression_tree.exceptions import SolverError
    from pybamm.models.full_battery_models import lithium_ion
    # Import the submodule explicitly: a bare "import scipy" only exposes
    # scipy.stats on SciPy versions with lazy submodule loading.
    import scipy.stats

    def derived_seed(offset):
        # "None + offset" would raise a TypeError. RandomState(seed=None)
        # instead seeds from fresh OS entropy.
        return None if seed is None else seed + offset

    manager = KadiManager()

    file_prefix = parameters_file.split(".")[0]
    # Only contact Kadi if an input is missing locally. An existing
    # "local_parameter_file.py" is reused as-is.
    if not isfile(input_file) or not isfile("local_parameter_file.py"):
        input_record_handle = Record(manager, id=input_record, create=False)
    if not isfile(input_file):
        input_id = input_record_handle.get_file_id(input_file)
        input_record_handle.download_file(input_id, input_file)
    if not isfile("local_parameter_file.py"):
        parameters_id = input_record_handle.get_file_id(parameters_file)
        input_record_handle.download_file(
            parameters_id, "local_parameter_file.py"
        )
    import local_parameter_file
    parameters = local_parameter_file.parameters
    unknowns = local_parameter_file.unknowns
    # Documented as optional in --help, so do not require it.
    transform_unknowns = getattr(
        local_parameter_file, 'transform_unknowns', {}
    )

    with open(input_file, 'r') as f:
        input_data = json.load(f)
    list_of_feature_indices = input_data['list_of_feature_indices']
    initial_socs = input_data['initial_socs']
    current_input = input_data['current_input']
    overpotential = input_data['overpotential']
    three_electrode = input_data['three_electrode']
    dimensionless_reference_electrode_location = (
        input_data['dimensionless_reference_electrode_location']
    )
    sqrt_cutoff = input_data['sqrt_cutoff']
    sqrt_start = input_data['sqrt_start']
    exp_cutoff = input_data['exp_cutoff']
    white_noise = input_data['white_noise']
    uncertainties = input_data['uncertainties']
    model_name = input_data['model_name']
    discretization = input_data['discretization']
    optimizer_settings = input_data['optimizer_settings']
    experiment_data = input_data['experiment_data']

    # model_name is either a bare class name from lithium_ion or a call
    # expression with literal arguments, e.g. "DFN(options={...})".
    model_prefix, model_args, model_kwargs = _parse_model_name(model_name)
    model_instance = getattr(lithium_ion, model_prefix)(
        *model_args, **model_kwargs
    )

    # Apply the correct initial SOC.
    for electrode in ["negative", "positive"]:
        concentration_key = (
            "Initial concentration in "
            + electrode
            + " electrode [mol.m-3]"
        )
        soc_value = initial_socs[concentration_key]
        if soc_value is not None:
            parameters[concentration_key] = soc_value

    # When using SubstitutionDict, additional parameters may be variable.
    if isinstance(parameters, SubstitutionDict):
        solver_free_parameters = set(
            list(unknowns.keys())
            + list(uncertainties.keys())
            + list(parameters.dependent_variables(unknowns.keys()))
            + list(parameters.dependent_variables(uncertainties.keys()))
        )
    else:
        solver_free_parameters = set(
            list(unknowns.keys())
            + list(uncertainties.keys())
        )

    solver, callback = simulation_setup(
        model_instance,
        current_input,
        parameters,
        *spectral_mesh_pts_and_method(**discretization),
        free_parameters=solver_free_parameters,
        verbose=False,
        # logging_file=file_prefix + '_evaluations.log',
    )

    # Each noise source gets its own RandomState, offset from the seed.
    white_noise_generator = scipy.stats.norm(0, white_noise)
    white_noise_generator.random_state = RandomState(seed=derived_seed(1))
    parameter_noise_rng = {}
    for i, (p_name, (s_name, *args)) in enumerate(uncertainties.items()):
        # s_name is the name of a distribution in scipy.stats.
        parameter_noise_rng[p_name] = getattr(scipy.stats, s_name)(*args)
        parameter_noise_rng[p_name].random_state = RandomState(
            seed=derived_seed(i + 2)
        )

    # Simulated times are shifted so they start at the first experimental
    # time stamp.
    t0 = experiment_data[0][0][0]

    # Will be set to True if PyBaMM errors out due to function variables
    # being used in the variable solver input. In that case, the
    # simulator will have to be built each trial, which is slower.
    function_variable = False
    try:
        test_parameters = {
            key: parameters[key] for key in solver_free_parameters
        }
        solver(
            calc_esoh=False,
            inputs=test_parameters,
            callbacks=callback
        )
    except AttributeError:  # 'function' object has no attribute 'shape'
        function_variable = True

    def failed_simulation():
        # Zero-filled data shaped like experiment_data. It marks a failed
        # simulation without aborting the whole optimization.
        return [[[0.0 for _ in n] for n in e] for e in experiment_data]

    def gitt_simulator(trial_parameters):
        # The function_variable branch rebinds the enclosing solver and
        # callback, so reuse of the rebuilt simulator is preserved.
        nonlocal solver, callback
        # Apply the parameter noise (modifies trial_parameters in place).
        for p_name, p_rng in parameter_noise_rng.items():
            trial_parameters[p_name] = p_rng.rvs(size=1)[0]
        full_parameters = deepcopy(trial_parameters)
        # Fail silently if the simulation did not work.
        try:
            if not function_variable:
                # Only pass the parameters the solver treats as inputs.
                solution = solver(
                    calc_esoh=False,
                    inputs={
                        key: trial_parameters[key]
                        for key in solver_free_parameters
                    },
                    callbacks=callback
                )
            else:
                solver, callback = simulation_setup(
                    model_instance,
                    current_input,
                    trial_parameters,
                    *spectral_mesh_pts_and_method(**discretization),
                    verbose=False,
                    # logging_file=file_prefix + '_evaluations.log',
                )
                solution = solver(calc_esoh=False, callbacks=callback)
        except SolverError:
            return failed_simulation()
        # An incomplete run (fewer cycles than inputs) also counts as failed.
        if solution is None or len(solution.cycles) < len(current_input):
            return failed_simulation()
        sim_data = [[], []]
        for cycle in solution.cycles:
            t_eval = t0 + cycle["Time [h]"].entries * 3600.0
            u_eval = calculate_desired_voltage(
                cycle,
                t_eval,
                1.0,  # voltage_scale
                overpotential,
                three_electrode,
                dimensionless_reference_electrode_location,
                full_parameters,
            )
            # For some reason, sometimes U goes on longer than t.
            if len(u_eval) < len(t_eval):
                t_eval = t_eval[:len(u_eval)]
            if len(t_eval) < len(u_eval):
                u_eval = u_eval[:len(t_eval)]
            sim_data[0].append(t_eval)
            # Apply the measurement noise.
            sim_data[1].append(
                u_eval + white_noise_generator.rvs(size=len(sim_data[0][-1]))
            )
        return sim_data

    def features(dataset):
        current_features = gitt_features(
            dataset,
            list_of_feature_indices,
            sqrt_cutoff,
            sqrt_start,
            exp_cutoff
        )
        # In convergence mode, all features form one vector so that BOLFI
        # fits them at once (see the --convergence-mode help).
        if convergence_mode:
            return [array(current_features)]
        else:
            return current_features

    def feature_names(index):
        if convergence_mode:
            return "all features"
        total_index = list_of_feature_indices[index]
        # Features are indexed in groups of five per segment.
        return (
            gitt_feature_names[total_index % 5]
            + "(segment #"
            + str(total_index // 5)
            + ")"
        )

    estimator = EP_BOLFI(
        [gitt_simulator],
        [experiment_data],
        [features],
        parameters,
        free_parameters_boundaries=unknowns,
        transform_parameters=transform_unknowns,
        display_current_feature=[feature_names],
    )

    # A seed in the optimizer settings takes precedence over --seed.
    ep_bolfi_seed = optimizer_settings.pop('seed', seed)

    with open(file_prefix + '_optimization.log', 'w') as f:
        with redirect_stdout(f):
            estimator.run(seed=ep_bolfi_seed, **optimizer_settings)

    with open(file_prefix + '_parameterization.json', 'w') as f:
        f.write(estimator.result_to_json(seed=ep_bolfi_seed))

    with open(file_prefix + '_evaluations.json', 'w') as f:
        f.write(estimator.log_to_json())

    output_record_handle = Record(manager, id=output_record, create=False)
    # output_record_handle.upload_file(
    #     file_prefix + '_evaluations.log', force=overwrite
    # )
    output_record_handle.upload_file(
        file_prefix + '_optimization.log', force=overwrite
    )
    output_record_handle.upload_file(
        file_prefix + '_parameterization.json', force=overwrite
    )
    output_record_handle.upload_file(
        file_prefix + '_evaluations.json', force=overwrite
    )


if "__main__" == __name__:
    # Called without arguments: the xmlhelpy command decorator presumably
    # fills them in from the command line (verify against xmlhelpy docs).
    gitt_parameterization()
