#!python

from datetime import time
import json
import logging
from logging.handlers import TimedRotatingFileHandler
import os
from panza.jobs import new_job_workspace, DataFetchingError
from panza.config import init_config
from pika.exceptions import AMQPError
from pprint import pformat
import shutil
import signal
import ssl
from typing import Any, Callable, Dict

from rocinante.cli import parse_arguments
from rocinante.config import ConfigurationLoadError, load_config
from rocinante.log_exporter import LogExporter
from rocinante.logging import init_logging
from rocinante.rabbitmq import get_blocking_connection
from rocinante.driver import Driver
from rocinante.drivers.intra import IntraValidationDriver


def send_result(ch, result: Dict[str, Any], routing_key: str):
    """
    Publish a job result on the "moulinette" exchange

    :param ch:              the channel to publish on
    :param result:          the result payload, serialized as JSON
    :param routing_key:     the routing key to publish the result with
    :raises AMQPError:      re-raised (after being logged) when publishing fails
    """
    logger.info("Sending job result...")
    payload = json.dumps(result)
    try:
        ch.basic_publish(exchange="moulinette", routing_key=routing_key, body=payload)
    except AMQPError as publish_error:
        logger.critical(f"Cannot publish job result: {publish_error}")
        raise
    else:
        logger.info("Result successfully sent")


def export_logs(ch, info: Dict[str, Any], status: str, retries_count: int, message: str = None):
    """
    Send the logs recorded for the current job through the log exporter

    :param ch:              the channel handed over to ``log_exporter.send``
    :param info:            the job information extracted by the driver
    :param status:          the job outcome, e.g. "success", "failure", "error" or "dropped"
    :param retries_count:   how many times the job has been re-scheduled
    :param message:         optional extra message attached to the export (e.g. the job output)
    :raises AMQPError:      re-raised (after being logged) when the export fails
    """
    try:
        log_exporter.send(
            ch,
            info=info,
            retries_count=retries_count,
            status=status,
            message=message
        )
    except AMQPError as e:
        logger.error(f"Cannot export logs: {e}")
        raise
    # This exports logs, not a job result (it is also used for dropped jobs, which have no result).
    logger.info("Logs successfully exported")


def retry_job(
        ch,
        job_name: str,
        routing_key: str,
        body: Dict[str, Any],
        info: Dict[str, Any],
        max_retries: int = 4
):
    """
    Re-publish a job so that it gets scheduled again, or drop it after too many retries

    The retry counter travels with the job: ``body["retries_count"]`` is incremented
    in place before the body is re-published (or reported as dropped).

    :param ch:              the channel to publish on
    :param job_name:        the job's name, used for logging
    :param routing_key:     the routing key the job is re-published with
    :param body:            the job body; its "retries_count" entry is updated in place
    :param info:            the job information, exported along with the logs when the job is dropped
    :param max_retries:     how many re-schedules are allowed before the job is dropped
    :raises AMQPError:      when the job cannot be re-published, or the logs of a dropped job cannot be exported
    """
    # A missing (or null) counter means this is the job's first retry.
    retries_count = (body.get("retries_count") or 0) + 1
    body["retries_count"] = retries_count

    if retries_count > max_retries:
        logger.error(f"Dropping job {job_name}: too many retries")
        logger.error("Body was:")
        logger.error(pformat(body))
        # export_logs logs its own failure before re-raising, so it is not wrapped here.
        export_logs(
            ch,
            info=info,
            retries_count=retries_count,
            status="dropped"
        )
        return

    logger.warning(f"Requesting a re-schedule of job {job_name}...")
    try:
        ch.basic_publish(
            exchange='moulinette',
            routing_key=routing_key,
            body=json.dumps(body)
        )
    except AMQPError as e:
        logger.critical(f"Cannot request a re-schedule of the job: {e}")
        raise


def process_job(
        driver: Driver,
        body: Dict[str, Any],
        reply: Callable[[Dict[str, Any], str, Dict[str, Any]], None],
        retry: Callable[[str, Dict[str, Any]], None]
):
    """
    Callback used to process a job and reply with result

    Outcomes:
    - the job information cannot be extracted: the job is dropped (no reply, no retry);
    - the moulinette cannot be retrieved, or the job's data cannot be fetched: ``retry`` is called;
    - any other processing failure: the job is dropped;
    - otherwise ``reply`` is called with the driver-formatted result and a status:
      "success" (no feedback messages), "failure" (feedback messages present),
      or "error" (the job executor reported an error).

    :param driver:          the driver to use for this job
    :param body:            the job body (parsed as JSON from the data received from the queue)
    :param reply:           the function to call to reply with the job's result
    :param retry:           the function to call to request a re-schedule of the job
    """
    try:
        info = driver.extract_job_information(body)
    except ValueError as e:
        logger.warning(f"Unable to extract job information: {e}")
        logger.warning("Dropping invalid job")
        return

    job_name = info["job_name"]
    logger.info(f"Job identified as {job_name}")

    try:
        moulinette_directory = driver.retrieve_moulinette(info)
    except Exception as e:
        logger.warning(f"Cannot retrieve the moulinette: {e}")
        logger.warning(f"Aborting job {job_name}")
        retry(job_name, info)
        return

    logger.info("Processing job...")

    try:
        with new_job_workspace(job_files_dir=moulinette_directory, job_name=job_name) as workspace:
            environment_name = info["job_environment"]

            job_status, job_feedback = workspace \
                .build_job_environment(environment_name) \
                .fetch_data(context=info) \
                .execute_job(context=info)
    except DataFetchingError as e:
        logger.warning(f"Cannot process job: {e}")
        logger.warning(f"Aborting job {job_name}")
        # Only data fetching failures are considered worth a re-schedule.
        retry(job_name, info)
        return
    except Exception as e:
        logger.warning(f"Cannot process job: {e}")
        logger.warning(f"Aborting job {job_name}")
        # TODO: find a better way to tell transient errors from permanent ones
        # Keep the traceback available for debugging instead of discarding it.
        logger.debug("Job processing failure details", exc_info=True)
        return

    logger.debug(f"Job result: {job_status}, {job_feedback}")

    if job_status.is_success():
        logger.info("Job successfully processed")
        status = "success" if len(job_feedback["messages"]) == 0 else "failure"
    else:
        logger.warning(f"Job executor reported an error: {job_status.message}")
        status = "error"

    result = driver.format_result(body, job_feedback)
    reply(result, status, info)


def handle_job(input_queue_name: str, driver: Driver, result_routing_key: str, ch, method, properties, body: bytes):
    """
    Message callback for one queue: decode a job message and process it

    Messages that are not valid UTF-8 JSON objects are logged and dropped. Raising
    here would escape the consumer callback and, presumably, stop the whole worker.

    :param input_queue_name:    name of the queue the message was consumed from (used for logging)
    :param driver:              the driver to use for this job
    :param result_routing_key:  routing key used to publish the job's result
    :param ch:                  the channel the message was received on, also used for replies and retries
    :param method:              delivery information; its routing key is reused to re-schedule the job
    :param properties:          message properties (unused)
    :param body:                raw message payload, expected to be a UTF-8 encoded JSON object
    """
    def reply(result: Dict[str, Any], status: str, info: Dict[str, Any]):
        send_result(ch, result, result_routing_key)
        export_logs(
            ch,
            info=info,
            retries_count=parsed_body.get("retries_count") or 0,
            message=result["output"],  # TODO: find a way to abstract over that entry or to specify it explicitly
            status=status
        )

    def retry(job_name: str, info: Dict[str, Any]):
        retry_job(ch, job_name, method.routing_key, parsed_body, info)

    log_exporter.start_recording()
    try:
        parsed_body = json.loads(body.decode())
    except ValueError as e:
        # Covers both UnicodeDecodeError and json.JSONDecodeError (both subclass ValueError).
        logger.warning(f"Unable to decode job received from queue '{input_queue_name}': {e}")
        logger.warning("Dropping invalid job")
        return
    if not isinstance(parsed_body, dict):
        logger.warning(f"Job received from queue '{input_queue_name}' is not a JSON object")
        logger.warning("Dropping invalid job")
        return

    logger.debug(f"Job received from queue '{input_queue_name}', with JSON body: {parsed_body}")
    process_job(driver, parsed_body, reply, retry)


def handle_sigterm(signal_number, frame):
    """
    SIGTERM handler: stop consuming, remove the worker's root directory and exit

    :param signal_number:   the received signal number
    :param frame:           the interrupted stack frame (unused)
    """
    logger.info(f"Received signal {signal.Signals(signal_number).name} ({signal_number}), cleaning up...")

    # The handler may fire before the RabbitMQ channel has been created, in which
    # case the module-level `channel` name does not exist yet.
    active_channel = globals().get("channel")
    if active_channel is not None:
        try:
            active_channel.stop_consuming()
        except AMQPError as e:
            logger.error(f"Cannot stop consuming cleanly: {e}")

    # A cleanup failure must not prevent the process from exiting.
    try:
        shutil.rmtree(root_dir)
    except OSError as e:
        logger.error(f"Cannot remove root directory {root_dir}: {e}")

    logger.info("Quitting.")
    exit(0)


# --- Command-line arguments and configuration -------------------------------
args = parse_arguments()

try:
    loaded_config = load_config(args.config_file)
except ConfigurationLoadError as load_error:
    print(load_error)
    exit(1)

rabbitmq_config: Dict[str, Any] = loaded_config["rabbitmq"]
# From here on, `config` only holds the "config" section of the configuration file.
config: Dict[str, Any] = loaded_config["config"]

# The panza configuration is the "config" section extended with command-line values.
panza_settings = dict(config)
panza_settings.update(root_dir=args.root_dir, docker_bridge_ip=args.docker_bridge_ip)
init_config(panza_settings)

# --- Working directories ----------------------------------------------------
# The root directory must not pre-exist: it is created here and removed on SIGTERM.
root_dir = args.root_dir
if os.path.exists(root_dir):
    print(f"cannot create root directory at {root_dir}: directory already exists")
    exit(1)

drivers_dir = f"{root_dir}/drivers"
cache_dir = f"{root_dir}/fetcher_cache"
# Created in order: the root first, then its sub-directories.
for directory in (root_dir, drivers_dir, cache_dir):
    os.mkdir(directory)

# --- Logging ----------------------------------------------------------------
log_dir = args.log_dir
# makedirs with exist_ok avoids the check-then-create race and also creates
# any missing parent directories of the log directory.
os.makedirs(log_dir, exist_ok=True)

console_handler = logging.StreamHandler()
# Rotate the log file once a day at 02:00, keeping the last 7 rotated files.
file_handler = TimedRotatingFileHandler(
    f"{log_dir}/rocinante.log",
    when='midnight',
    atTime=time(hour=2),
    backupCount=7
)
# The exporter's handler receives the same records, so job logs can be sent with results.
log_exporter = LogExporter()
logger = init_logging([console_handler, file_handler, log_exporter.log_handler], debug=args.debug is True)

# Available job drivers, keyed by the driver name that queue configurations refer to.
drivers: Dict[str, Driver] = {
    "intra": IntraValidationDriver.create(logger, f"{drivers_dir}/intra", config),
}

# TODO: Create directories *before* actually instantiating drivers
for driver_name in drivers:
    driver_dir = f"{drivers_dir}/{driver_name}"
    os.makedirs(driver_dir, exist_ok=True)

# TLS context for the RabbitMQ connection. PROTOCOL_TLS_CLIENT negotiates the highest
# protocol version both peers support; the previous PROTOCOL_TLSv1 pinned the deprecated
# TLS 1.0, which current OpenSSL builds commonly refuse.
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
# PROTOCOL_TLS_CLIENT enables hostname checking by default, and it has to be disabled
# before verify_mode can be set to CERT_NONE.
context.check_hostname = False
# SECURITY NOTE(review): the broker certificate is NOT verified, which leaves the
# connection open to man-in-the-middle attacks. Consider loading the broker's CA
# (context.load_verify_locations) and using CERT_REQUIRED instead.
context.verify_mode = ssl.CERT_NONE

signal.signal(signal.SIGTERM, handle_sigterm)


def _make_job_callback(queue_name: str, driver: Driver, result_routing_key: str) -> Callable[..., None]:
    """
    Build the message callback for one consumed queue

    The queue-specific values are bound as arguments of this factory instead of being
    captured from the registration loop: a lambda defined in the loop body closes over
    the loop variable itself, so every consumer would end up using the configuration
    of the *last* registered queue.
    """
    def on_message(*callback_args, **callback_kwargs):
        handle_job(queue_name, driver, result_routing_key, *callback_args, **callback_kwargs)

    return on_message


try:
    logger.info("Connecting to RabbitMQ server...")
    connection = get_blocking_connection(
        username=rabbitmq_config["username"],
        password=rabbitmq_config["password"],
        host=rabbitmq_config["host"],
        port=rabbitmq_config["port"],
        virtual_host=rabbitmq_config["virtual_host"],
        ssl_context=context
    )
    channel = connection.channel()
    # passive=True: only check that the exchange exists, never create it.
    channel.exchange_declare(exchange='moulinette', exchange_type='topic', durable=True, passive=True)

    for queue_config in args.queue_config:
        if queue_config.driver_name not in drivers:
            logger.warning(f"Unknown driver '{queue_config.driver_name}', ignoring it.")
            continue

        logger.info(
            f"Registering driver '{queue_config.driver_name}' as consumer for queue '{queue_config.queue_name}'..."
        )

        # passive=True: the queue must already exist on the broker.
        channel.queue_declare(queue=queue_config.queue_name, durable=True, passive=True)
        channel.basic_consume(
            queue=queue_config.queue_name,
            on_message_callback=_make_job_callback(
                queue_config.queue_name,
                drivers[queue_config.driver_name],
                queue_config.result_routing_key,
            ),
            auto_ack=True
        )

    logger.info("Waiting for jobs")
    channel.start_consuming()
except AMQPError as e:
    logger.critical(f"Unable to consume jobs from RabbitMQ: {str(e)}")
    # Remove the root directory, otherwise a restarted worker refuses to start
    # because the directory already exists.
    shutil.rmtree(root_dir, ignore_errors=True)
    exit(1)
