#!python

from enum import Enum
import json
import os
from pprint import pformat
import shutil
import signal
import ssl
from typing import Any, Callable, Dict

from panza import new_job_workspace
from panza.backends.docker import DockerWithAdditionalDaemonBackend
from panza.cache import Cache
from panza.errors import DataFetchingError
from pika.exceptions import AMQPError

from rocinante.config import ConfigurationLoadError, RabbitMQConfiguration, load_config
from rocinante.driver import Driver
from rocinante.drivers.intra import IntraValidationDriver
from rocinante.drivers.la_mancha import LaManchaDriver
from rocinante.errors import RetryableError
from rocinante.logging import init_logging, logger_for_driver
from rocinante.rabbitmq import RecoverableBlockingConsumerPublisher, build_connection_parameters
from rocinante.utils import make_credentials_context, sanitize_for_filename


class JobStatus(Enum):
    """Final outcome of handling a single job pulled from the queue."""

    PASSED = 0                   # result was computed and successfully published
    RETRY_REQUESTED = 1          # job was re-published for a later attempt
    DROPPED = 2                  # unexpected failure, job abandoned
    RETRIED_TOO_MANY_TIMES = 3   # retry budget exhausted, job abandoned
    UNCOMPLETED = 4              # processing did not run to completion


def send_result(publisher, result: Dict[str, Any], routing_key: str) -> JobStatus:
    """
    Publish a job result on the 'moulinette' exchange.

    :param publisher:    channel-like object used to publish the message
    :param result:       result payload, serialized to JSON before sending
    :param routing_key:  routing key to publish the result under

    :return: JobStatus.PASSED when the result was published
    :raises AMQPError: when publishing fails (logged as critical, then re-raised)
    """
    logger.info("Sending job result...")
    payload = json.dumps(result)
    try:
        publisher.basic_publish(
            exchange="moulinette",
            routing_key=routing_key,
            body=payload,
        )
    except AMQPError as e:
        logger.critical(f"Cannot publish job result: {e}")
        raise
    else:
        logger.info("Result successfully sent")
        return JobStatus.PASSED


def retry_job(publisher, job_name: str, routing_key: str, body: Dict[str, Any]) -> JobStatus:
    """
    Request a re-schedule of a job, decrementing its retry budget.

    The job body is re-published on the 'moulinette' exchange under its
    original routing key with ``remaining_retries`` decremented; once the
    budget reaches zero the job is dropped instead.

    :param publisher:    channel-like object used to publish the message
    :param job_name:     human-readable job identifier, used for logging
    :param routing_key:  routing key the job was originally received on
    :param body:         original job body; mutated in place to carry the
                         updated retry counter

    :return: RETRY_REQUESTED on success, RETRIED_TOO_MANY_TIMES when the
             retry budget is exhausted
    :raises AMQPError: when the re-schedule request cannot be published
    """
    # Jobs that never carried a counter start from the configured maximum.
    remaining_retries = body.get("remaining_retries", config.max_job_retries)
    if remaining_retries <= 0:
        logger.error(f"Dropping job {job_name}: too many retries")
        # Plain string (was a needless f-string with no placeholders).
        logger.error("Body was:")
        logger.error(pformat(body))
        return JobStatus.RETRIED_TOO_MANY_TIMES

    body["remaining_retries"] = remaining_retries - 1
    logger.warning(f"Requesting a re-schedule of job {job_name}...")

    try:
        publisher.basic_publish(
            exchange='moulinette',
            routing_key=routing_key,
            body=json.dumps(body)
        )
        return JobStatus.RETRY_REQUESTED
    except AMQPError as e:
        logger.critical(f"Cannot request a re-schedule of the job: {e}")
        raise


def process_job(
        driver: Driver,
        body: Dict[str, Any],
        reply: Callable[[Dict[str, Any]], JobStatus],
        retry: Callable[[str], JobStatus]
) -> JobStatus:
    """
    Callback used to process a job and reply with result

    :param driver:          the driver to use for this job
    :param body:            the job body (parsed as JSON from the data received from the queue)
    :param reply:           the function to call to reply with the job's result
    :param retry:           the function to call to request a re-schedule of the job

    :return: the status returned by ``reply`` or ``retry``, or
             JobStatus.DROPPED when an unexpected error aborts the job
    """
    job_name = None  # assigned inside the try so error paths can still log it
    try:
        info = driver.extract_job_information(body)
        job_name = info["job_name"]
        logger.info(f"Job identified as {job_name}")
        logger.info(f"Job information: {json.dumps(info)}")

        # Fetch the moulinette (grading harness) files required to run the job.
        moulinette_directory = driver.retrieve_moulinette(info)

        logger.info("Processing job...")

        job_root = f"{config.root_directory}/panza/{job_name}"

        with new_job_workspace(backend, with_files=moulinette_directory, with_root=job_root) as workspace:
            environment_name = info["job_environment"]
            # Cache entries are keyed on job name + request date, so a retried
            # job with the same request date reuses previously fetched data.
            cache_entry_name = f"{job_name}_{sanitize_for_filename(info['request_date'])}"
            credentials_context = make_credentials_context(config.credentials)

            handle = workspace \
                .build_job_environment(environment_name) \
                .fetch_data(context={**info, **credentials_context}, cache=cache, cache_entry=cache_entry_name)
            job_result = handle.execute_job(
                context={**info, **credentials_context},
                environment_tag=environment_name,
                timeout=config.job_timeout
            )

        # NOTE(review): handle.blueprint is read after the workspace context has
        # exited — presumably the blueprint remains valid once the job has run;
        # confirm against panza's workspace semantics.
        result = driver.format_result(body, handle.blueprint, job_result)
        logger.info(f"Job successfully processed, result is: {json.dumps(result)}")
        return reply(result)

    except (RetryableError, DataFetchingError) as e:
        # Transient failures (explicitly retryable, or data fetching) are worth
        # another attempt: hand the job back via the retry callback.
        logger.warning(f"Cannot process job: {e}")
        logger.warning(f"Aborting job {job_name or '<unnamed>'}")
        return retry(job_name or '<unnamed>')

    except Exception as e:
        # Anything else is treated as unrecoverable for this job: log and drop
        # rather than crash the whole consumer loop.
        logger.warning(f"Cannot process job: {e}")
        logger.warning(f"Dropping job {job_name or '<unnamed>'}")
        return JobStatus.DROPPED


def handle_job(input_queue_name: str, driver: Driver, ch, method, properties, body: bytes):
    """
    Queue-consumer callback: decode a job, process it, and log its final status.

    :param input_queue_name: name of the queue the message was received on
    :param driver:           driver bound to that queue
    :param ch:               the pika channel (unused here)
    :param method:           delivery metadata; its routing key is reused on retry
    :param properties:       message properties (unused here)
    :param body:             raw message payload, expected to be JSON
    """
    parsed_body = json.loads(body.decode())
    logger.info(f"Job received from queue '{input_queue_name}', bound to driver '{type(driver).__name__}'")
    logger.debug(f"Job JSON body: {parsed_body}")

    result_routing_key = parsed_body["result"]["routing.key"]

    def reply(result: Dict[str, Any]) -> JobStatus:
        return send_result(consumer_publisher, result, result_routing_key)

    def retry(job_name: str) -> JobStatus:
        return retry_job(consumer_publisher, job_name, method.routing_key, parsed_body)

    status = JobStatus.UNCOMPLETED
    try:
        status = process_job(driver, parsed_body, reply, retry)
    finally:
        logger.info(f"Finished processing job, status: {status.name}")


def configure_publisher_channel(channel):
    """Passively assert that the 'moulinette' topic exchange already exists."""
    channel.exchange_declare(
        exchange='moulinette',
        exchange_type='topic',
        durable=True,
        passive=True,
    )


def configure_consumer_channel(channel):
    """
    Passively assert the 'moulinette' exchange and register one consumer per
    configured queue, dispatching each queue's messages to its driver.

    Queues referencing an unknown driver are logged and skipped.

    :param channel: the pika channel to declare exchanges/queues on
    """
    channel.exchange_declare(exchange='moulinette', exchange_type='topic', durable=True, passive=True)

    for queue_config in config.queues:
        if queue_config.driver not in drivers:
            logger.warning(f"Unknown driver '{queue_config.driver}', ignoring it.")
            continue

        logger.info(f"Registering driver '{queue_config.driver}' as consumer for queue '{queue_config.name}'...")

        channel.queue_declare(queue=queue_config.name, durable=True, passive=True)
        # Bind the per-queue values as lambda defaults: a plain closure would
        # capture the loop variable itself (late binding), so after the loop
        # every callback would dispatch with the LAST queue's name and driver.
        channel.basic_consume(
            queue=queue_config.name,
            on_message_callback=lambda *args, _queue_name=queue_config.name,
                                       _driver=drivers[queue_config.driver], **kwargs: handle_job(
                _queue_name,
                _driver,
                *args,
                **kwargs,
            ),
            auto_ack=True
        )
    logger.info("Successfully registered as consumer for each specified queue")


def get_consumer_publisher(rabbitmq_config: RabbitMQConfiguration) -> RecoverableBlockingConsumerPublisher:
    """
    Build a consumer/publisher connected to RabbitMQ as configured.

    :param rabbitmq_config: connection settings (host, port, credentials,
                            virtual host, SSL flag)
    :return: a RecoverableBlockingConsumerPublisher wired to this module's
             channel-configuration callbacks
    """
    ssl_context = None
    if rabbitmq_config.use_ssl:
        # PROTOCOL_TLS_CLIENT negotiates the highest TLS version both sides
        # support; the previous PROTOCOL_TLSv1 pinned the connection to the
        # deprecated TLS 1.0, which modern brokers refuse.
        ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        # NOTE(review): certificate verification is deliberately disabled, as
        # in the original code — enable it if the broker presents a verifiable
        # certificate. check_hostname must be cleared before verify_mode.
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE

    params = build_connection_parameters(
        username=rabbitmq_config.username,
        password=rabbitmq_config.password,
        host=rabbitmq_config.host,
        port=rabbitmq_config.port,
        virtual_host=rabbitmq_config.virtual_host,
        ssl_context=ssl_context,
    )

    return RecoverableBlockingConsumerPublisher(
        params,
        configure_consumer_channel=configure_consumer_channel,
        configure_publisher_channel=configure_publisher_channel,
        max_reconnection_retries=8,
        backoff_factor=0.1,
    )


def cleanup():
    """Remove the worker's scratch root directory and everything under it."""
    shutil.rmtree(config.root_directory)


def handle_sigterm(signal_number, frame):
    """
    SIGTERM handler: stop consuming, remove scratch data, then exit cleanly.

    :param signal_number: number of the received signal
    :param frame:         interrupted stack frame (unused)
    """
    signal_name = signal.Signals(signal_number).name
    logger.info(f"Received signal {signal_name} ({signal_number}), cleaning up...")
    consumer_publisher.stop_consuming()
    cleanup()
    logger.info("Quitting.")
    exit(0)


try:
    config = load_config()
except ConfigurationLoadError as e:
    # Logging is not initialized yet at this point: report to stdout and bail.
    print(e)
    exit(1)

# Docker backend with an additional dedicated daemon for running jobs.
backend = DockerWithAdditionalDaemonBackend(config.additional_docker_daemon)

# The root directory must be created fresh for each run; refusing to reuse an
# existing one avoids mixing state left over by a previous worker.
if os.path.exists(config.root_directory):
    print(f"cannot create root directory at {config.root_directory}: directory already exists")
    exit(1)
os.makedirs(config.root_directory)
drivers_dir = f"{config.root_directory}/drivers"
os.mkdir(drivers_dir)
cache_dir = f"{config.root_directory}/fetcher_cache"
# Shared data-fetcher cache, bounded by the configured number of entries.
cache = Cache.create_at(cache_dir, max_entries=config.cache.max_entries)

os.makedirs(config.log_directory, exist_ok=True)

logger = init_logging(config.log_directory, debug=config.debug is True)

# Driver names (as referenced by queue configuration) mapped to factories;
# each factory entry is replaced by a constructed driver instance below.
drivers = {
    "intra": IntraValidationDriver.create,
    "la_mancha": LaManchaDriver.create,
}

for driver in drivers.keys():
    os.makedirs(f"{drivers_dir}/{driver}", exist_ok=True)
    # Instantiate each driver with its own logger, working directory and config.
    drivers[driver] = drivers[driver](logger_for_driver(driver, debug=config.debug), f"{drivers_dir}/{driver}", config)

signal.signal(signal.SIGTERM, handle_sigterm)

try:
    logger.debug(f"Connecting to RabbitMQ at {config.rabbitmq.host}:{config.rabbitmq.port}...")
    consumer_publisher = get_consumer_publisher(config.rabbitmq)

    logger.info(f"Successfully connected to {config.rabbitmq.host}:{config.rabbitmq.port}")
    # Blocks here, dispatching deliveries to handle_job, until consuming stops
    # or an AMQP error escapes the recoverable consumer.
    consumer_publisher.start_consuming()
except AMQPError as e:
    logger.critical(f"Unable to consume jobs from RabbitMQ: {str(e)}")
    cleanup()
    exit(1)
