#!/usr/bin/env python
# Copyright Salient Predictions 2024

"""Command line interface for the Salient SDK."""

import argparse

import xarray as xr

from .__init__ import __version__
from .constants import get_model_version
from .data_timeseries_api import data_timeseries
from .downscale_api import downscale
from .forecast_timeseries_api import forecast_timeseries
from .location import Location
from .login_api import get_verify_ssl, login


def main() -> None:
    """Command line interface for the Salient SDK.

    Parse the command line, convert login and location options into a
    session and a ``Location`` (when the chosen subcommand accepts them),
    dispatch to the matching SDK function, and print its result.

    When verbosity is on and the result is a string (a file path), the file
    is opened with xarray and the dataset summary is printed instead of the
    bare path.  With no subcommand, the help text is printed.
    """
    parser = _get_parser()

    # Convert arguments into action -------
    args = parser.parse_args()
    # vars() returns the Namespace's own __dict__, so keys popped below are
    # also removed from `args`.  "verbose" is intentionally never popped.
    args_dict = vars(args)
    cmd = args_dict.pop("command")

    args_dict = _login_from_arg(args_dict)
    args_dict = _location_from_arg(args_dict)

    # Dispatch to the appropriate function based on the command
    if cmd == "forecast_timeseries":
        file_name = forecast_timeseries(**args_dict)
    elif cmd == "downscale":
        file_name = downscale(**args_dict)
    elif cmd == "data_timeseries":
        file_name = data_timeseries(**args_dict)
    elif cmd == "version":
        file_name = __version__
        # The version string is not a file path; never try to open it.
        args.verbose = False
    elif cmd == "login":
        file_name = args_dict["session"]
    else:
        # No subcommand given: show usage instead of doing anything.
        file_name = None
        parser.print_help()

    if file_name is None:
        return
    # getattr: not every subparser necessarily defines the verbosity switches.
    if getattr(args, "verbose", False) and isinstance(file_name, str):
        # Context manager releases the file handle once the summary is printed.
        with xr.open_dataset(file_name) as ds:
            print(ds)
    else:
        print(file_name)


def _location_from_arg(arg: dict) -> dict:
    """Collapse the CLI location options into a single ``loc`` entry.

    If ``arg`` holds all of ``latitude``, ``longitude``, ``location_file`` and
    ``shapefile``, those four keys are removed and replaced by one ``loc``
    key holding a ``Location`` built from them.  Otherwise ``arg`` is
    returned untouched.

    Args:
        arg (dict): Parsed command line options; mutated in place.

    Returns:
        dict: The same ``arg`` object.
    """
    location_keys = ("latitude", "longitude", "location_file", "shapefile")
    if not all(key in arg for key in location_keys):
        return arg

    lat, lon, location_file, shapefile = [arg.pop(key) for key in location_keys]
    arg["loc"] = Location(
        lat=lat,
        lon=lon,
        location_file=location_file,
        shapefile=shapefile,
    )
    return arg


def _login_from_arg(arg: dict) -> dict:
    """Replace CLI credential options with an authenticated ``session``.

    When both ``username`` and ``password`` are present, they are removed
    from ``arg`` together with ``verify``, passed to ``login()``, and the
    resulting session is stored under ``session``.  ``verify`` is then
    re-populated from ``get_verify_ssl()``.  Without credentials, ``arg`` is
    returned untouched.

    Args:
        arg (dict): Parsed command line options; mutated in place.

    Returns:
        dict: The same ``arg`` object.
    """
    if "username" not in arg or "password" not in arg:
        return arg

    username = arg.pop("username")
    password = arg.pop("password")
    verify = arg.pop("verify")
    # "verbose" is read but left in place: later consumers still need it.
    arg["session"] = login(
        username=username,
        password=password,
        verify=verify,
        verbose=arg["verbose"],
    )
    # login() may change the SSL verification setting as a side effect,
    # so re-read the effective value instead of reusing the CLI flag.
    arg["verify"] = get_verify_ssl()
    return arg


def _get_parser() -> argparse.ArgumentParser:
    """Build the top-level argument parser for the ``salientsdk`` command.

    Every subcommand is registered on its own subparser, and the selected
    subcommand's name is stored under ``command``.  Options shared by several
    subcommands are added through ``_add_common_args`` so their names and
    defaults stay consistent.

    Returns:
        argparse.ArgumentParser: The fully configured parser.
    """
    parser = argparse.ArgumentParser(description="salientsdk command line interface")
    subparsers = parser.add_subparsers(dest="command")

    # version command ---------
    # Takes no options, so the subparser object itself is not needed.
    subparsers.add_parser("version", help="Print the Salient SDK version")

    # login command ----------
    # Credentials/SSL/verbosity only: no location or model options.
    _add_common_args(
        subparsers.add_parser("login", help="Login to the Salient API"),
        [],
        location_args=False,
        login_args=True,
    )

    # forecast_timeseries command ------------
    forecast_parser = _add_common_args(
        subparsers.add_parser("forecast_timeseries", help="Run the forecast_timeseries function"),
        ["date", "debias", "version", "force"],
    )
    forecast_parser.add_argument("-fld", "--field", type=str, default="anom")
    forecast_parser.add_argument("-fmt", "--format", type=str, default="nc")
    forecast_parser.add_argument("-mdl", "--model", type=str, default="blend")
    forecast_parser.add_argument("-ref", "--reference_clim", type=str, default="30_yr")
    forecast_parser.add_argument("--timescale", type=str, default="all")
    forecast_parser.add_argument("-var", "--variable", type=str, default="temp")

    # downscale command -----------
    downscale_parser = _add_common_args(
        subparsers.add_parser("downscale", help="Run the downscale function"),
        ["date", "debias", "version", "force"],
    )
    downscale_parser.add_argument("-var", "--variables", type=str, default="temp,precip")
    downscale_parser.add_argument("--members", type=int, default=50)
    # --format ("nc") and --frequency ("daily") are deliberately not exposed:
    # those are currently the only available values for downscale.

    # data_timeseries command --------------
    data_parser = _add_common_args(
        subparsers.add_parser(
            "data_timeseries", help="Get historical ERA5 from the data_timeseries API"
        ),
        ["debias", "force"],
    )
    data_parser.add_argument("-var", "--variable", type=str, default="temp")
    data_parser.add_argument("-fld", "--field", type=str, default="anom")
    data_parser.add_argument("--start", type=str, default="1950-01-01")
    data_parser.add_argument("--end", type=str, default="-today")
    data_parser.add_argument("--format", type=str, default="nc")
    data_parser.add_argument("--frequency", type=str, default="daily")
    return parser


def _add_common_args(
    argparser: argparse.ArgumentParser,
    args: "list[str] | None" = None,
    location_args: bool = True,
    login_args: bool = True,
) -> argparse.ArgumentParser:
    """Add standard arguments to a subparser.

    Keeps option names and defaults consistent across every subcommand of
    the command line interface.

    Args:
        argparser (argparse.ArgumentParser): The subparser to add arguments to.
        args (list[str] | None): Names of optional shared arguments to add.
            Recognized names are ``"debias"``, ``"version"``, ``"force"`` and
            ``"date"``; any other name is ignored.  ``None`` (default) adds
            none of them.
        location_args (bool): If True (default), add the location arguments
            ``--latitude``, ``--longitude``, ``--location_file`` and
            ``--shapefile`` (all defaulting to None).
        login_args (bool): If True (default), add ``--username`` and
            ``--password``, the mutually exclusive ``--verify``/``--noverify``
            switches (``verify`` defaults to None when neither is given), and
            the mutually exclusive ``--verbose``/``--quiet`` switches
            (``verbose`` defaults to True).  The verbosity switches are only
            added when this is True.

    Returns:
        argparse.ArgumentParser: ``argparser`` itself, modified in place.
    """
    # None instead of a shared mutable [] default; only membership is tested.
    if args is None:
        args = []

    if login_args:
        argparser.add_argument(
            "-u", "--username", type=str, default="username", help="Salient-issued user name"
        )
        argparser.add_argument(
            "-p", "--password", type=str, default="password", help="Salient-issued password"
        )

        verify_group = argparser.add_mutually_exclusive_group(required=False)
        verify_group.add_argument(
            "--verify",
            dest="verify",
            action="store_true",
            help="Force verification of SSL certificates.",
        )
        verify_group.add_argument(
            "--noverify",
            dest="verify",
            action="store_false",
            help="Disable verification of SSL certificates.",
        )
        # None distinguishes "neither flag given" from an explicit choice.
        argparser.set_defaults(verify=None)

        verbosity_group = argparser.add_mutually_exclusive_group(required=False)
        verbosity_group.add_argument(
            "--verbose",
            dest="verbose",
            action="store_true",
            help="Print status messages (default behavior)",
        )
        verbosity_group.add_argument(
            "--quiet", dest="verbose", action="store_false", help="Suppress status messages"
        )
        argparser.set_defaults(verbose=True)

    if location_args:
        argparser.add_argument(
            "-lat",
            "--latitude",
            type=float,
            default=None,
            help="Decimal latitude (also requires longitude)",
        )
        argparser.add_argument(
            "-lon",
            "--longitude",
            type=float,
            default=None,
            help="Decimal longitude (also requires latitude)",
        )
        argparser.add_argument("-loc", "--location_file", type=str, default=None)
        argparser.add_argument("-shp", "--shapefile", type=str, default=None)

    if "debias" in args:
        argparser.add_argument(
            "--debias",
            action="store_true",
            help="Debias to observation stations (default is no debiasing)",
        )

    if "version" in args:
        argparser.add_argument(
            "-ver", "--version", type=str, default=get_model_version(), help="Model version to use"
        )

    if "force" in args:
        argparser.add_argument(
            "--force", action="store_true", help="Overwrite existing files (default is to cache)"
        )

    if "date" in args:
        argparser.add_argument("--date", type=str, default="-today")

    return argparser


# Entry point when this module is executed directly.  The relative imports at
# the top of the file mean it must be run as part of its package (e.g. with
# ``python -m``), not as a standalone script path.
if __name__ == "__main__":
    main()
