#!python
import datetime
import importlib.metadata
import json
import os
import pathlib
import time
from typing import List, NoReturn, Optional

import typer
from rich.console import Console
from rich.table import Table
from typing_extensions import Annotated
from yaspin import yaspin

from quantagonia.cloud.https_client import HTTPSClient
from quantagonia.cloud.cloud_runner import CloudRunner
from quantagonia.cloud.enums import JobStatus
from quantagonia.cloud.solver_log import SolverLog
from quantagonia.parser.solution_parser import SolutionParser
from quantagonia.parameters import HybridSolverParameters
from quantagonia.enums import HybridSolverServers
from quantagonia.cloud.cloud_runner import build_specs_list


# Typer application; every function decorated with @app.command() below becomes a CLI command.
app = typer.Typer(
    help="Simple CLI client for Quantagonia's cloud-based HybridSolver.",
)
# Shared rich console, used for colored error/warning messages and job tables.
console = Console()

# Credentials and target server; both stay None until the __main__ block fills them in.
API_KEY = None
SERVER = None
# Pause in seconds between two polls of a job (passed to time.sleep), despite the name.
POLL_FREQUENCY = 2

###
# Helpers that minimize code duplication
###

def _exit_with_error(error_msg) -> NoReturn:
    """Print *error_msg* as a red 'Error:' line and terminate the process with exit status 1.

    Raises SystemExit directly instead of calling the ``exit`` builtin: that name is injected
    by the ``site`` module for interactive use and is not guaranteed to exist in programs.
    """
    console.print(f"[red]Error:[/red] {error_msg}")
    raise SystemExit(1)


def _print_warning(msg) -> None:
    """Print *msg* on the shared console, prefixed with a yellow 'Warning:' label."""
    label = "[yellow]Warning:[/yellow]"
    console.print(f"{label} {msg}")


def _submit(
        client : HTTPSClient,
        problem_file : str,
        params_file : Optional[str],
        tag : str,
        context : str,
        quiet : bool,
        **cl_arguments) -> str:
    """Validate the inputs, build the solver specs and submit one job.

    Args:
        client: HTTPS client used to submit the job.
        problem_file: Path to the problem file (.mps/.lp/.qubo, plain or .gz-compressed).
        params_file: Optional path to a JSON file containing an object of solver parameters;
            its entries override the parameters passed via ``cl_arguments``.
        tag: Tag attached to the job.
        context: Billing context the job runs in.
        quiet: If True, print plain text instead of showing a spinner.
        **cl_arguments: Solver parameters from the command line. Entries whose value is
            ``None`` or ``[]`` are treated as "not given". A parameter ``x`` is applied by
            calling ``HybridSolverParameters.set_x(value)``.

    Returns:
        The ID of the submitted job.

    Exits the process with status 1 on invalid input or when the submission fails.
    """

    # check if file exists
    if not os.path.isfile(problem_file):
        _exit_with_error(f"File {problem_file} does not exist, exiting...")

    # Match the extension against the end of the full file name so that names containing
    # additional dots (e.g. "model.v2.mps") are still recognized.
    file_name = pathlib.Path(problem_file).name
    is_mip = file_name.endswith((".mps", ".lp", ".mps.gz", ".lp.gz"))
    is_qubo = file_name.endswith((".qubo", ".qubo.gz"))

    if not is_mip and not is_qubo:
        _exit_with_error(f"File {problem_file} is not in MIP or QUBO file format, exiting...")

    # first collect parameters given as command line arguments
    params_dict = {param: value for param, value in cl_arguments.items() if value not in [None, []]}

    # then overwrite with parameters given in parameter file
    if params_file:
        try:
            with open(params_file, "r", encoding="utf-8") as f:
                given_parameters = json.load(f)
        except FileNotFoundError:
            _exit_with_error(f"Parameter file {params_file} does not exist, exiting...")
        except json.JSONDecodeError as err:
            _exit_with_error(f"Parameter file {params_file} is not valid JSON ({err}), exiting...")

        if not isinstance(given_parameters, dict):
            _exit_with_error(f"Parameter file {params_file} must contain a JSON object, exiting...")

        params_dict.update(given_parameters)

    # create empty params object
    params = HybridSolverParameters()

    # apply collected options through the matching set_<name> method of the parameter object
    for param, value in params_dict.items():
        setter = getattr(params, f"set_{param}", None)
        if not callable(setter):
            _exit_with_error(f"Unknown solver parameter '{param}', exiting...")
        setter(value)

    # build specs from params and problem file; build_specs_list takes parallel lists and its
    # result is passed as `specs` unchanged (presumably one entry per problem file)
    specs = build_specs_list([params], [problem_file], {})

    # start solving job
    if quiet:
        try:
            job_id = client.submit_job(
                problem_files=[problem_file], specs=specs, tag=tag, context=context)
        except Exception as e:
            _exit_with_error(f"Failed to submit job: {e}")

        print(f"Submitted job with ID: {job_id}")

    else:
        # report failures through the spinner and exit only after its context has closed
        has_error = False
        with yaspin(text="Submitting job to the HybridSolver...", color="yellow") as spinner:
            try:
                job_id = client.submit_job(problem_files=[problem_file], specs=specs, tag=tag, context=context)

                spinner.text = f"Submitted job with ID: {job_id}"
                spinner.ok("✅")
            except Exception as e:
                spinner.text = f"Failed to submit job: {e}"
                spinner.fail("❌")
                has_error = True
        # handle exit outside of spinner
        if has_error:
            raise SystemExit(1)

    return job_id


def _status(job_id : str, item : int, client : HTTPSClient, spinner : yaspin) -> bool:
    """Show the current status of batch item *item* of *job_id* in *spinner*.

    Returns:
        True if the status could be fetched and displayed, False otherwise (the
        spinner then shows a failure message).
    """

    try:
        raw_statuses = client.get_current_status(job_id)
        job_status = JobStatus(raw_statuses[item])

        spinner.text = f"Status: {job_status.value}"

        # pick the symbol that finalizes the spinner line for this status
        if job_status == JobStatus.created:
            icon = "⏳"
        elif job_status == JobStatus.running:
            icon = "💻"
        elif job_status in [JobStatus.finished, JobStatus.success]:
            icon = "✅"
        else:
            icon = "❌"
        spinner.ok(icon)
    except Exception:
        spinner.text = f"Failed to retrieve status of job {job_id}"
        spinner.fail("❌")
        return False

    return True


def _monitor_job(job_id : str, client : HTTPSClient, solver_log : Optional[SolverLog] = None, item : int = 0):
    """Follow a job while it is created/running, printing its logs, then show the final status.

    Status and log are polled every POLL_FREQUENCY seconds. The log of batch item *item* is
    handed to *solver_log* on every poll, including the last one.

    Args:
        job_id: ID of the job to follow.
        client: HTTPS client used for polling.
        solver_log: Log printer receiving the current log on each poll; a new ``SolverLog``
            instance is created when omitted.
        item: Index of the batch item to follow. Defaults to 0.

    Exits the process with status 1 if the status or the logs cannot be retrieved.
    """

    # use a real SolverLog instance when the caller does not supply one
    if solver_log is None:
        solver_log = SolverLog()

    # we want the exit to be outside the with statement such that the spinner closes
    has_error = False
    with yaspin(text=f"Processing job {job_id}...", color="yellow") as spinner:

        try:
            while True:
                status = JobStatus(client.get_current_status(job_id)[item])

                # fetch the log on every poll (also the last one) so the final lines are printed
                logs = client.get_current_log(job_id)
                with spinner.hidden():
                    solver_log.update_log(logs[item])

                if status not in [JobStatus.created, JobStatus.running]:
                    break
                # only wait while the job is still created/running
                time.sleep(POLL_FREQUENCY)

            spinner.text = f"Status: {status.value}"

            if status in [JobStatus.finished, JobStatus.success]:
                spinner.ok("✅")
            else:
                spinner.fail("❌")
        except Exception:
            spinner.text = f"Failed to retrieve status for job {job_id}"
            spinner.fail("❌")
            has_error = True

    if has_error:
        raise SystemExit(1)


def _retrieve_billing_time(job_id : str, client : HTTPSClient):
    """Display the minutes billed for *job_id* in a spinner line.

    Any exception raised while fetching the results is reported in the spinner
    instead of being propagated; nothing is returned.
    """
    waiting_text = f"Retrieving billing time for job {job_id}..."
    with yaspin(text=waiting_text, color="yellow") as spinner:

        try:
            # get_results yields a pair; only its second element (billed minutes) is used here
            _results, minutes_billed = client.get_results(job_id)
        except Exception:
            spinner.text = "Failed to retrieve billing time."
            spinner.fail("❌")
        else:
            spinner.text = f"Minutes billed: {minutes_billed}"
            spinner.ok("✅")


###
# CLI commands
###
@app.command()
def solve(
        problem_file : str = typer.Argument(help="Path to optimization problem file."),
        params_file : str = typer.Option(None, help="Path to parameter file. If specified, override other options."),
        relative_gap : float = typer.Option(None, help="Stopping criterion: relative gap"),
        absolute_gap : float = typer.Option(None, help="Stopping criterion: absolute gap"),
        timelimit : int = typer.Option(None, help="Stopping criterion: runtime"),
        as_qubo : bool = typer.Option(None, help="Try to solve MIP as MIP and as QUBO"),
        as_qubo_only : bool = typer.Option(None, help="Try to solve MIP only as QUBO"),
        presolve : bool = typer.Option(None, help="Enable (default) or disable presolve"),
        heuristics_only : bool = typer.Option(None, help="Only apply primal heuristics and then terminate (QUBO only)"),
        objective_limit : float = typer.Option(None, help="Stopping criterion: objective value"),
        quantum_heuristic : Annotated[Optional[List[str]],typer.Option(help="Adds given quantum heuristic to the heuristics pool")] = None,
        context : str = typer.Option("", help="Billing context to run the job in."),
        tag : str = typer.Option("", help="Tag to identify the job later."),
        quiet : bool = typer.Option(False, help="Disable interactive output and only show final logs")):
    """
    Solve the optimization problem specified in the given file (.mps/.lp/.qubo, either plain or .gz)
    and actively monitor the progress.
    Outputs solver logs and the time billed (the time billed is not shown in quiet mode).
    This command is equivalent to a 'submit' command followed by a 'monitor'.

    Args:
        problem_file (str): Path to the optimization problem file.
        params_file (str, optional): Path to the parameters file.
        relative_gap (float, optional): Stopping criterion: relative gap.
        absolute_gap (float, optional): Stopping criterion: absolute gap.
        timelimit (int, optional): Stopping criterion: runtime.
        as_qubo (bool): Attempts to solve a MIP as MIP and QUBO in parallel.
        as_qubo_only (bool): Attempts to solve a MIP only as QUBO.
        presolve (bool): Enables (default) or disable presolve.
        heuristics_only (bool): Only apply primal heuristics and then terminate (QUBO only).
        objective_limit (float): Stopping criterion: objective value.
        quantum_heuristic (str): Adds given quantum heuristic to the heuristics pool. Can be used multiple times.
        context (str, optional): Billing context to run the job in. Defaults to "".
        tag (str, optional): Tag to identify the job later. Defaults to "".
        quiet (bool, optional): Disable interactive output and only show final logs. Defaults to False.
    """

    runner = CloudRunner(api_key = API_KEY, server = SERVER)
    client = runner._https_client

    # submit job; every solver option, including objective_limit, is forwarded to _submit
    job_id = _submit(
        client, problem_file, params_file, tag, context, quiet,
        absolute_gap=absolute_gap, relative_gap=relative_gap, time_limit=timelimit, as_qubo=as_qubo,
        as_qubo_only=as_qubo_only, presolve=presolve, heuristics_only=heuristics_only,
        objective_limit=objective_limit, quantum_heuristics=quantum_heuristic)

    solver_log = SolverLog()
    solver_log.next_time_add_new_line()

    if quiet:
        # wait silently while the job is created/running, then print the log once
        try:
            while True:
                job_status = JobStatus(client.get_current_status(job_id)[0])
                if job_status not in [JobStatus.created, JobStatus.running]:
                    break
                time.sleep(POLL_FREQUENCY)

            logs = client.get_current_log(job_id)
            solver_log.update_log(logs[0])

        except Exception:
            _exit_with_error("Failed to retrieve status")

        print(f"\nFinished job with status {job_status.value}")

    else:
        # monitor job, i.e., print logs, final status etc.
        _monitor_job(job_id, client, solver_log)

        # get time billed
        _retrieve_billing_time(job_id, client)


@app.command()
def submit(
        problem_file : str = typer.Argument(help="Path to optimization problem file."),
        params_file : str = typer.Option(None, help="Path to parameter file. If specified, override other options."),
        relative_gap : float = typer.Option(None, help="Stopping criterion: relative gap"),
        absolute_gap : float = typer.Option(None, help="Stopping criterion: absolute gap"),
        timelimit : int = typer.Option(None, help="Stopping criterion: runtime"),
        as_qubo : bool = typer.Option(None, help="Try to solve MIP as MIP and as QUBO"),
        as_qubo_only : bool = typer.Option(None, help="Try to solve MIP only as QUBO"),
        presolve : bool = typer.Option(None, help="Enable (default) or disable presolve"),
        heuristics_only : bool = typer.Option(None, help="Only apply primal heuristics and then terminate (QUBO only)"),
        objective_limit : float = typer.Option(None, help="Stopping criterion: objective value"),
        quantum_heuristic : Annotated[Optional[List[str]],typer.Option(help="Adds given quantum heuristic to the heuristics pool")] = None,
        context : str = typer.Option("", help="Billing context to run the job in."),
        tag : str = typer.Option("", help="Tag to identify the job later."),
        quiet : bool = typer.Option(False, help="Disable interactive output and only show final logs")):
    """
    Submit the given optimization problem in a non-blocking way.
    Use 'status', 'logs', and 'solution' commands to get results.

    Args:
        problem_file (str): Path to the optimization problem file.
        params_file (str, optional): Path to the parameters file.
        relative_gap (float, optional): Stopping criterion: relative gap.
        absolute_gap (float, optional): Stopping criterion: absolute gap.
        timelimit (int, optional): Stopping criterion: runtime.
        as_qubo (bool): Attempts to solve a MIP as MIP and QUBO in parallel.
        as_qubo_only (bool): Attempts to solve a MIP only as QUBO.
        presolve (bool): Enables (default) or disable presolve.
        heuristics_only (bool): Only apply primal heuristics and then terminate (QUBO only).
        objective_limit (float): Stopping criterion: objective value.
        quantum_heuristic (str): Adds given quantum heuristic to the heuristics pool. Can be used multiple times.
        context (str, optional): Billing context to run the job in. Defaults to "".
        tag (str, optional): Tag to identify the job later. Defaults to "".
        quiet (bool, optional): Disable interactive output and only show final logs. Defaults to False.

    Returns:
        job_id (any): Job ID associated with the submitted job.
    """

    runner = CloudRunner(api_key = API_KEY, server = SERVER)
    client = runner._https_client

    # submit job; every solver option, including objective_limit, is forwarded to _submit
    return _submit(
        client, problem_file, params_file, tag, context, quiet,
        absolute_gap=absolute_gap, relative_gap=relative_gap, time_limit=timelimit, as_qubo=as_qubo,
        as_qubo_only=as_qubo_only, presolve=presolve, heuristics_only=heuristics_only,
        objective_limit=objective_limit, quantum_heuristics=quantum_heuristic)


@app.command()
def logs(job_id : str, item : int = 0):
    """
    Print the current logs of the given job.
    For batched jobs, the optional parameter 'item' selects the item of the batch.

    Args:
        job_id (str): The ID of the job to retrieve logs for.
        item (int, optional): The index of the batch item. Defaults to 0.
    """

    client = CloudRunner(api_key = API_KEY, server = SERVER)._https_client

    # One-shot dump of every log line available right now. The exit happens after the
    # spinner context so that the spinner line is finalized first.
    failed = False
    with yaspin(text=f"Retrieving current logs for job {job_id}...", color="yellow") as spinner:
        try:
            all_logs = client.get_current_log(job_id)
            spinner.text = "Retrieved logs:"
            spinner.ok("✅")

            with spinner.hidden():
                for log_line in all_logs[item].split("\n"):
                    print(log_line)
        except Exception:
            spinner.text = "Failed to retrieve logs"
            spinner.fail("❌")
            failed = True

    if failed:
        exit(1)


@app.command()
def monitor(job_id : str, item : int = 0):
    """
    Resumes monitoring the progress (i.e., logs) of a given job, e.g., after a 'submit' command.
    For batched jobs, the optional parameter 'item' selects the item of the batch.

    Args:
        job_id (str): The ID of the job to monitor.
        item (int, optional): The index of the batch item. Defaults to 0.
    """

    # NOTE(review): `item` is accepted here but not forwarded; the monitoring below follows
    # batch item 0 only. Confirm whether batch items should be selectable.
    client = CloudRunner(api_key = API_KEY, server = SERVER)._https_client

    # stream logs and show the final status, then report the billed minutes
    _monitor_job(job_id, client, SolverLog())
    _retrieve_billing_time(job_id, client)


@app.command()
def status(job_id : str, item : int = 0):
    """
    Retrieves the status of a given job: CREATED, RUNNING, FINISHED, SUCCESS, TIMEOUT, TERMINATED, or ERROR.
    For batched jobs, the optional parameter 'item' selects the item of the batch.

    Args:
        job_id (str): The ID of the job to retrieve status for.
        item (int, optional): The index of the batch item. Defaults to 0.
    """

    with yaspin(text=f"Retrieving status for job {job_id}...", color="yellow") as spinner:
        runner = CloudRunner(api_key = API_KEY, server = SERVER)
        client = runner._https_client
        # _status reports failures in the spinner and returns False
        succeeded = _status(job_id, item, client, spinner)

    # exit non-zero after the spinner has closed, so scripts can detect the failure
    if not succeeded:
        raise SystemExit(1)


@app.command()
def solution(job_id : str, item : int = 0):
    """
    Display the solution vector for a given job if its computation completed with success.
    For batched jobs, the optional parameter 'item' selects the item of the batch.

    Args:
        job_id (str): The ID of the job to retrieve solution for.
        item (int, optional): The index of the batch item. Defaults to 0.
    """

    # the whole retrieval runs inside the spinner context so the spinner is still active
    # when its text and final symbol are set; the exit happens after the context closes
    has_error = False
    with yaspin(text=f"Retrieving solution for job {job_id}...", color="yellow") as spinner:
        runner = CloudRunner(api_key = API_KEY, server = SERVER)
        client = runner._https_client

        try:
            res, _ = client.get_results(job_id)
            res = res[item]

            spinner.text = "Retrieved solution:"
            spinner.ok("✅")

            print(SolutionParser.parse(res["solution_file"]))

        except Exception:
            spinner.text = "Failed to retrieve solution."
            spinner.fail("❌")
            has_error = True

    if has_error:
        raise SystemExit(1)


@app.command()
def time_billed(job_id : str):
    """
    Output the time billed for a particular job in minutes.

    Args:
        job_id (str): The ID of the job to retrieve billing time for.
    """

    # delegate to the shared helper, which reports the result (or failure) in a spinner
    https_client = CloudRunner(api_key = API_KEY, server = SERVER)._https_client
    _retrieve_billing_time(job_id, https_client)


@app.command()
def list(n : int = typer.Option(10, help="Maximum number of jobs to display")):
    """
    Shows a list of the user's n latest jobs with some basic information.

    Args:
        n (int, optional): Maximum number of jobs to display. Defaults to 10.
    """

    def as_cell(value):
        # rich table cells must be str (or renderables); None stays an empty cell
        return None if value is None else str(value)

    def jobs2table(jobs, title=""):
        # output jobs in a nice, tabular format on the shared console
        tbl = Table(title=f"[bold]{title}[/bold]", title_justify="left")
        tbl.add_column()
        tbl.add_column("Job ID")
        tbl.add_column("Size", justify="right")
        tbl.add_column("Tag")
        tbl.add_column("Type(s)")
        tbl.add_column("Filename(s)")
        tbl.add_column("Created")
        tbl.add_column("Time billed", justify="right")

        for job in jobs:
            bs = int(job["batch_size"])
            # e.g. " (+ 2)" when the batch contains more than one problem
            more = f" (+ {bs - 1})" if bs > 1 else ""
            dt = datetime.datetime.fromtimestamp(int(job["created"]))

            finished = bool(job["finished"])
            if finished and bool(job["successful"]):
                marker = "[green]✔[/green]"
            elif finished:
                marker = "[red]✗[/red]"
            else:
                marker = ""

            tbl.add_row(
                marker,
                as_cell(job["job_id"]),
                f"{bs}",
                "---" if job["tag"] == "" else as_cell(job["tag"]),
                job["first_type"] + more,
                job["first_filename"] + more,
                dt.strftime("%d.%m.%Y %H:%M:%S"),
                as_cell(job["time_billed"]))
        console.print(tbl)

    # exit non-zero after the spinner context has closed, consistent with the other commands
    has_error = False
    with yaspin(text=f"Retrieving latest {n} jobs for given API key...", color="yellow") as spinner:
        runner = CloudRunner(api_key = API_KEY, server = SERVER)
        client = runner._https_client

        try:
            res = client.get_jobs(n)
            num_running = len(res["running"])
            num_old = len(res["old"])

            spinner.text = f"You have {num_running} active and {num_old} finished jobs" + \
                (":" if num_running + num_old > 0 else ".")
            spinner.ok("✅")

            if num_running > 0:
                print("")
                jobs2table(res["running"], title="Active jobs")

            if num_old > 0:
                print("")
                jobs2table(res["old"], title="Finished jobs")

        except Exception:
            spinner.text = "Failed to retrieve list of jobs."
            spinner.fail("❌")
            has_error = True

    if has_error:
        raise SystemExit(1)


@app.command()
def cancel(job_id : str):
    """
    Cancel a job that is currently running.

    Args:
        job_id (str): The ID of the job to cancel.
    """

    # exit non-zero only after the spinner context has closed so its line is finalized
    has_error = False
    with yaspin(text=f"Canceling job {job_id}...", color="yellow") as spinner:
        runner = CloudRunner(api_key = API_KEY, server = SERVER)
        client = runner._https_client

        try:
            client.interrupt_job(job_id)

            spinner.text = "Job canceled"
            spinner.ok("✅")
        except Exception:
            spinner.text = "Failed to cancel job"
            spinner.fail("❌")
            has_error = True

    if has_error:
        raise SystemExit(1)


@app.command()
def api_key():
    """
    Prints the API key set through QUANTAGONIA_API_KEY.
    """
    # API_KEY is the module-level value assigned from the environment in the __main__ block
    message = f"Your API Key: [green]{API_KEY}[/green]"
    console.print(message)


@app.command()
def version():
    """
    Prints the version of this Python package.
    """
    # read the installed distribution metadata of the "quantagonia" package
    print(importlib.metadata.version("quantagonia"))


if __name__ == "__main__":

    # a missing or empty QUANTAGONIA_API_KEY is rejected before any command runs
    API_KEY = os.environ.get("QUANTAGONIA_API_KEY", "")
    if not API_KEY:
        _exit_with_error(
            "Quantagonia API Key not found. Please set the 'QUANTAGONIA_API_KEY' environment variable.")

    # internal users can set a server different to PROD through an env variable;
    # an unknown name produces a clean error listing the valid choices instead of a KeyError traceback
    server_name = os.environ.get("QUANTAGONIA_SERVER", "PROD").upper()
    try:
        SERVER = HybridSolverServers[server_name]
    except KeyError:
        valid_servers = ", ".join(server.name for server in HybridSolverServers)
        _exit_with_error(
            f"Unknown server '{server_name}' in QUANTAGONIA_SERVER. Valid values are: {valid_servers}.")
    if SERVER != HybridSolverServers.PROD:
        _print_warning(f"Job is submitted to {SERVER.name} environment.")

    app()
