"""
         Inverse problem with two forward models that share a common parameter
----------------------------------------------------------------------------------------
                       ---> Additive model prediction error <---
----------------------------------------------------------------------------------------
The first model equation is y(x) = a * x + b with a, b being the model parameters and
the second model equation is y(x) = alpha * x**2 + b where alpha is an additional model
parameter, and b is the same model parameter as in the first model equation. Both
forward models have the same additive error model with a normal zero-mean distribution
where the standard deviation is to be inferred. The problem is approached with a
maximum likelihood estimation.
"""

# standard library imports
import unittest
import os

# third party imports
import numpy as np
import matplotlib.pyplot as plt

# local imports (problem definition)
from probeye.definition.inverse_problem import InverseProblem
from probeye.definition.forward_model import ForwardModelBase
from probeye.definition.distribution import Normal, Weibull
from probeye.definition.sensor import Sensor
from probeye.definition.likelihood_model import GaussianLikelihoodModel

# local imports (knowledge graph)
from probeye.ontology.knowledge_graph_export import (
    export_knowledge_graph_including_results,
)

# local imports (testing related)
from probeye.inference.scipy.solver import MaxLikelihoodSolver
from tests.integration_tests.subroutines import run_inference_engines


class TestProblem(unittest.TestCase):
    """Integration test for two forward models that share the parameter 'b'."""

    def test_two_models(
        self,
        n_steps: int = 200,
        n_initial_steps: int = 100,
        n_walkers: int = 20,
        plot: bool = False,
        show_progress: bool = False,
        write_to_graph: bool = True,
        run_scipy: bool = True,
        run_emcee: bool = False,  # intentionally False for faster test-runs
        run_dynesty: bool = False,  # intentionally False for faster test-runs
    ):
        """
        Integration test for the problem described at the top of this file.

        Parameters
        ----------
        n_steps
            Number of steps (samples) to run. Note that the default number is rather low
            just so the test does not take too long.
        n_initial_steps
            Number of steps for initial (burn-in) sampling.
        n_walkers
            Number of walkers used by the estimator.
        plot
            If True, the data and the posterior distributions are plotted. This is
            deactivated by default, so that the test does not stop until the generated
            plots are closed.
        show_progress
            If True, progress-bars will be shown, if available.
        write_to_graph
            If True, the results of the scipy solver are exported to a knowledge graph
            file (<name of this file>.owl) located next to this file. The export only
            takes place if run_scipy is True as well.
        run_scipy
            If True, the problem is solved with scipy (maximum likelihood est).
            Otherwise, no maximum likelihood estimate is derived.
        run_emcee
            If True, the problem is solved with the emcee solver. Otherwise, the emcee
            solver will not be used.
        run_dynesty
            If True, the problem is solved with the dynesty solver. Otherwise, the
            dynesty solver will not be used.
        """

        # ============================================================================ #
        #                              Set numeric values                              #
        # ============================================================================ #

        # 'true' value of a, and its normal prior parameters
        a_true = 2.5
        mean_a = 2.0
        std_a = 1.0

        # 'true' value of b, and its normal prior parameters
        b_true = 1.7
        mean_b = 1.0
        std_b = 1.0

        # 'true' value of alpha, and its normal prior parameters
        alpha_true = 0.7
        mean_alpha = 2.0
        std_alpha = 1.0

        # 'true' value of sigma, and its Weibull prior parameters
        sigma_true = 0.15
        scale_sigma = 0.2
        shape_sigma = 5.0

        # the number of generated data points per experiment and the seed for the
        # random numbers
        n_tests = 100
        seed = 1

        # the experiment names are referenced when adding the data, the forward models
        # and the likelihood models; defining them once keeps these references in sync
        name_linear = "TestSeries_linear"
        name_quadratic = "TestSeries_quadratic"

        # ============================================================================ #
        #                         Define the Inference Problem                         #
        # ============================================================================ #

        # initialize the inverse problem with a useful name
        problem = InverseProblem("Two forward models with a shared parameter (AME)")

        # add all parameters to the problem
        problem.add_parameter(
            name="a",
            info="Slope of the graph in linear model",
            tex="$a$ (linear)",
            prior=Normal(mean=mean_a, std=std_a),
        )
        problem.add_parameter(
            name="alpha",
            info="Factor of quadratic term",
            tex=r"$\alpha$ (quad.)",
            prior=Normal(mean=mean_alpha, std=std_alpha),
        )
        problem.add_parameter(
            name="b",
            info="Intersection of graph with y-axis",
            tex="$b$ (shared)",
            prior=Normal(mean=mean_b, std=std_b),
        )
        problem.add_parameter(
            name="sigma",
            domain="(0, +oo)",
            tex=r"$\sigma$ (likelihood)",
            info="Standard deviation of zero-mean additive model error",
            prior=Weibull(scale=scale_sigma, shape=shape_sigma),
        )

        # ============================================================================ #
        #                    Add test data to the Inference Problem                    #
        # ============================================================================ #

        # data generation for linear model; note that the seed is set only once, so
        # the order of the two np.random.normal calls determines the generated data
        np.random.seed(seed)
        x_test = np.linspace(0.0, 1.0, n_tests)
        y_true_linear = a_true * x_test + b_true
        y_test_linear = np.random.normal(loc=y_true_linear, scale=sigma_true)

        # add the experimental data
        problem.add_experiment(
            name=name_linear,
            sensor_data={
                "x": x_test,
                "y": y_test_linear,
            },
        )

        # data generation for quadratic model (same x-values as for the linear model)
        y_true_quadratic = alpha_true * x_test**2 + b_true
        y_test_quadratic = np.random.normal(loc=y_true_quadratic, scale=sigma_true)

        # add the experimental data
        problem.add_experiment(
            name=name_quadratic,
            sensor_data={
                "x": x_test,
                "y": y_test_quadratic,
            },
        )

        # plot the true and noisy data
        if plot:
            plt.scatter(
                x_test,
                y_test_linear,
                label="measured data (linear)",
                s=10,
                c="red",
                zorder=10,
            )
            plt.plot(x_test, y_true_linear, label="true (linear)", c="black")
            plt.scatter(
                x_test,
                y_test_quadratic,
                s=10,
                c="orange",
                zorder=10,
                label="measured data (quadratic)",
            )
            plt.plot(x_test, y_true_quadratic, label="true (quadratic)", c="blue")
            plt.xlabel("x")
            plt.ylabel("y")
            plt.legend()
            plt.tight_layout()
            plt.draw()  # does not stop execution

        # ============================================================================ #
        #                          Define the Forward Models                           #
        # ============================================================================ #

        class LinearModel(ForwardModelBase):
            """Forward model for y(x) = a * x + b."""

            def interface(self):
                self.parameters = ["a", "b"]
                self.input_sensors = Sensor("x")
                self.output_sensors = Sensor("y", std_model="sigma")

            def response(self, inp: dict) -> dict:
                x = inp["x"]
                a = inp["a"]
                b = inp["b"]
                return {"y": a * x + b}

        class QuadraticModel(ForwardModelBase):
            """Forward model for y(x) = alpha * x**2 + b."""

            def interface(self):
                # the problem's parameter 'b' is accessed under the name 'beta' within
                # this model (see the response method below)
                self.parameters = ["alpha", {"b": "beta"}]
                self.input_sensors = Sensor("x")
                self.output_sensors = Sensor("y", std_model="sigma")

            def response(self, inp: dict) -> dict:
                x = inp["x"]
                alpha = inp["alpha"]
                beta = inp["beta"]
                return {"y": alpha * x**2 + beta}

        # add the forward models to the problem
        linear_model = LinearModel("LinearModel")
        problem.add_forward_model(linear_model, experiments=name_linear)
        quadratic_model = QuadraticModel("QuadraticModel")
        problem.add_forward_model(quadratic_model, experiments=name_quadratic)

        # ============================================================================ #
        #                           Add likelihood model(s)                            #
        # ============================================================================ #

        # one likelihood model per experiment, both with an additive model error
        problem.add_likelihood_model(
            GaussianLikelihoodModel(
                experiment_name=name_linear,
                model_error="additive",
            )
        )
        problem.add_likelihood_model(
            GaussianLikelihoodModel(
                experiment_name=name_quadratic,
                model_error="additive",
            )
        )

        # give problem overview
        problem.info()

        # ============================================================================ #
        #                    Solve problem with inference engine(s)                    #
        # ============================================================================ #

        # this routine is imported from another script because it is used by all
        # integration tests in the same way
        true_values = {
            "a": a_true,
            "alpha": alpha_true,
            "b": b_true,
            "sigma": sigma_true,
        }
        run_inference_engines(
            problem,
            true_values=true_values,
            n_steps=n_steps,
            n_initial_steps=n_initial_steps,
            n_walkers=n_walkers,
            plot=plot,
            show_progress=show_progress,
            run_scipy=False,  # is called below separately
            run_emcee=run_emcee,
            run_dynesty=run_dynesty,
        )

        # the ScipySolver is called separately here to test the knowledge graph export
        # routine that directly includes the results in the graph
        if run_scipy:
            scipy_solver = MaxLikelihoodSolver(problem, show_progress=show_progress)
            inference_data = scipy_solver.run(true_values=true_values)
            if write_to_graph:
                # the graph is written next to this file as <name of this file>.owl;
                # splitext strips only the file extension from the basename
                dir_path = os.path.dirname(__file__)
                stem = os.path.splitext(os.path.basename(__file__))[0]
                knowledge_graph_file = os.path.join(dir_path, stem + ".owl")
                export_knowledge_graph_including_results(
                    scipy_solver.problem,
                    inference_data,
                    knowledge_graph_file,
                    data_dir=dir_path,
                )


# allows running this integration test directly as a script; unittest.main() then
# collects and runs the test cases defined in this module
if __name__ == "__main__":
    unittest.main()
