#!python

"""Run deep cash experiment."""

import click

from metalearn import experiment
from metalearn.data_types import ExperimentType


# Default parameters of the METALEARN_REINFORCE experiment. The options of
# the `run experiment` command below look up their default values here.
DEFAULT = experiment.get_default_parameters("METALEARN_REINFORCE")


@click.group()
def cli():
    """Run experiments and create experiment configuration files."""
    # The docstring above is what click prints as the description in
    # `--help`; the group itself has no work to do besides dispatching to
    # its sub-groups.


@cli.group("run")
def run():
    """Run an experiment from command-line options or from a config file."""
    # Pure container group: click shows the docstring as help text and
    # dispatches to the `experiment` / `from-config` sub-commands.


@run.command("experiment")
@click.argument("datasets", nargs=-1)
# --- output location and controller dimensions -----------------------------
# No explicit `type` here: click derives the parameter type from the default.
@click.option("--output_fp", default=DEFAULT["output_fp"])
@click.option("--input_size", default=DEFAULT["input_size"])
@click.option("--hidden_size", default=DEFAULT["hidden_size"])
@click.option("--output_size", default=DEFAULT["output_size"])
@click.option("--n_layers", default=DEFAULT["n_layers"])
@click.option("--dropout_rate", default=DEFAULT["dropout_rate"])
# --- training hyperparameters ----------------------------------------------
@click.option("--gamma", default=DEFAULT["gamma"], type=float)
@click.option(
    "--entropy_coef",
    default=DEFAULT["entropy_coef"],
    type=float,
)
@click.option(
    "--entropy_coef_anneal_by",
    default=DEFAULT["entropy_coef_anneal_by"],
    type=float,
)
@click.option("--n_episodes", default=DEFAULT["n_episodes"])
@click.option("--n_iter", default=DEFAULT["n_iter"])
@click.option(
    "--learning_rate",
    default=DEFAULT["learning_rate"],
    type=float,
)
# --- task environment selection (both options may be repeated) -------------
@click.option(
    "--env_sources",
    "-e",
    default=DEFAULT["env_sources"],
    type=click.Choice(["SKLEARN", "OPEN_ML", "KAGGLE"]),
    multiple=True,
)
@click.option(
    "--target_types",
    "-t",
    default=DEFAULT["target_types"],
    type=click.Choice(["BINARY", "MULTICLASS", "REGRESSION"]),
    multiple=True,
)
@click.option(
    "--error_reward",
    default=DEFAULT["error_reward"],
    type=float,
)
# --- resource limits, logging and seeds ------------------------------------
@click.option(
    "--per_framework_time_limit",
    default=DEFAULT["per_framework_time_limit"],
)
@click.option(
    "--per_framework_memory_limit",
    default=DEFAULT["per_framework_memory_limit"],
)
@click.option("--metric_logger", default=DEFAULT["metric_logger"])
@click.option("--fit_verbose", default=DEFAULT["fit_verbose"])
@click.option("--controller_seed", default=DEFAULT["controller_seed"])
@click.option(
    "--task_environment_seed",
    default=DEFAULT["task_environment_seed"],
)
def run_experiment(*args, **kwargs):
    """Run deep cash experiment with single configuration."""
    # Every parsed argument/option is forwarded untouched to
    # `experiment.run_experiment`.
    #
    # TODO: implement a proper list-valued option, following the approach in
    # https://stackoverflow.com/questions/47631914/how-to-pass-several-list-of-arguments-to-click-option
    experiment.run_experiment(*args, **kwargs)


@run.command("from-config")
@click.argument("config_fp")
def from_config(config_fp):
    """Run an experiment using the parameters stored in a config file.

    CONFIG_FP is the path of the experiment configuration file to read.
    A copy of the configuration is written to the configured output
    location as experiment_config.yml before the experiment starts.
    """
    print("running experiment from config %s" % config_fp)
    config = experiment.read_config(config_fp)
    parameters = config.parameters

    # Write the config to the output directory so it is stored alongside
    # whatever the experiment produces.
    experiment.write_config(
        config, parameters["output_fp"], fname="experiment_config.yml")

    print("\nparameters:")
    for key, value in parameters.items():
        print("\t%s: %s" % (key, value))

    # Configs without an `experiment_type` field fall back to
    # METALEARN_REINFORCE. The fallback is resolved in one place so the
    # printed notice and the value actually used cannot drift apart.
    if "experiment_type" in config._fields:
        experiment_type = config.experiment_type
    else:
        print("experiment_type not specified in config. Assuming experiment "
              "type is METALEARN_REINFORCE.")
        experiment_type = "METALEARN_REINFORCE"

    experiment_fn = experiment.get_experiment_fn(experiment_type)
    experiment_fn(**parameters)


@cli.group("create")
def create():
    """Create experiment artifacts such as configuration files."""
    # Pure container group: click shows the docstring as help text and
    # dispatches to the `config` sub-command.


@create.command("config")
@click.argument("name")
@click.argument("dir-path")
@click.option(
    "--experiment-type",
    type=click.Choice([member.name for member in ExperimentType]),
    default="METALEARN_REINFORCE",
)
@click.option("--description", default="")
@click.option("--floyd", is_flag=True)
def create_config(name, dir_path, experiment_type, description, floyd):
    """Creates experiment configuration file."""
    # With --floyd, the generated config overrides the output path, metric
    # logger and fit verbosity; otherwise no parameter is overridden.
    # NOTE(review): "floyd" presumably refers to the FloydHub platform --
    # confirm against the metric logger implementation.
    if floyd:
        overrides = {
            "output_fp": "/output",
            "metric_logger": "floyd",
            "fit_verbose": 0,
        }
    else:
        overrides = {}

    # Build the config first, then persist it under `dir_path`.
    new_config = experiment.create_config(
        name, experiment_type, description, **overrides)
    experiment.write_config(new_config, dir_path)


# Hand control to the root click group when the file is executed as a
# script (not when it is imported).
if __name__ == "__main__":
    cli()
