#!python

import argparse
from datetime import time
from etna_api import EtnaSession
import json
import logging
from logging.handlers import TimedRotatingFileHandler
import os
from panza.jobs import new_job_workspace, add_logger_handler, DataFetchingError
from panza.config import init_config
import pika
from pika.exceptions import AMQPError
from pprint import pformat
import shutil
import signal
import ssl
from typing import Dict, Any


def extract_info(body: Dict[str, Any]):
    """Collect the fields of a job message that the worker uses.

    ``body["files_path"]`` must split on ``/`` into exactly four components;
    the second and fourth ones are parsed as the module and activity ids.
    The other values are copied as-is from ``body["result"]``,
    ``body["tokens"]`` and ``body["date"]``.

    Raises:
        KeyError: an expected key is missing from the message.
        ValueError: ``files_path`` does not have four components, or one of
            the ids is not an integer.
    """
    path_parts = body["files_path"].split("/")
    _, raw_module_id, _, raw_activity_id = path_parts
    return {
        "module_id": int(raw_module_id),
        "activity_id": int(raw_activity_id),
        "group_id": body["result"]["group_id"],
        "leader": body["result"]["leader"],
        "stage": body["result"]["stage"],
        "stage_end": body["tokens"]["stage_end"],
        "request_date": body["date"],
    }


def download_moulinette(destination_dir: str, info: Dict[str, Any]):
    """Download the moulinette files of a job's stage into ``destination_dir``.

    The files of the stage are listed through the intranet session; only the
    ones located under ``resources/moulinette/`` (relative to the stage
    directory) are downloaded, and they are written below ``destination_dir``
    with that prefix removed.

    ``destination_dir`` is created with ``os.mkdir``, so it must not exist yet.

    :param destination_dir: directory to create and fill
    :param info: job information; ``module_id``, ``activity_id`` and
        ``stage`` are read from it
    """
    logger.info("Downloading the moulinette from the intranet resources...")
    session = EtnaSession(
        username=config["intra_user"],
        password=config["intra_password"],
        request_retries=10,
        retry_on_statuses=(500, 502, 504)
    )

    moulinette_prefix = "resources/moulinette/"
    stage_files = session.get_activity_stage_files_list(info["module_id"], info["activity_id"], info["stage"])
    stage_marker = f"/stages/{info['stage']}/"

    os.mkdir(destination_dir)

    for entry in stage_files:
        remote_path = entry["rel_path"]
        # Path of the file relative to the stage directory
        in_stage_path = remote_path.rsplit(stage_marker, maxsplit=1)[1]
        if not in_stage_path.startswith(moulinette_prefix):
            continue

        local_path = os.path.relpath(in_stage_path, moulinette_prefix)
        parent_dir = f"{destination_dir}/{os.path.dirname(local_path)}"
        if not os.path.exists(parent_dir):
            os.makedirs(parent_dir)

        content = session.download_file_from_activity(info["module_id"], info["activity_id"], remote_path)
        with open(f"{destination_dir}/{local_path}", "wb") as output_file:
            output_file.write(content)


def sanitize_filename(name: str) -> str:
    """Turn ``name`` into a string made only of ``a-z``, ``0-9`` and ``_``.

    The name is lower-cased and every other character becomes an underscore.
    Consecutive underscores are collapsed into one and leading underscores
    are dropped; a trailing underscore is kept.
    """
    import string

    keep = set(string.ascii_lowercase + string.digits)
    pieces = []
    for char in name.lower():
        if char in keep:
            pieces.append(char)
        elif pieces and pieces[-1] != '_':
            # First separator after a kept character: emit a single underscore
            pieces.append('_')
    return "".join(pieces)


def retry_job(job_name: str, body: Dict[str, Any]):
    """Publish the job message again so that it gets re-scheduled.

    The number of attempts is tracked in ``body["retries_count"]``, which is
    updated in place. Once it exceeds 4, the job is dropped and its body is
    logged instead of being published.

    :param job_name: name of the job, only used in log messages
    :param body: the decoded job message
    :raises AMQPError: if the message cannot be published (logged, then
        re-raised)
    """
    try:
        attempts = body.get("retries_count", 0) + 1
        body["retries_count"] = attempts
        if attempts > 4:
            logger.error(f"Dropping job {job_name}: too many retries")
            logger.error("Body was:")
            logger.error(pformat(body))
            return

        logger.warning(f"Requesting a re-schedule of job {job_name}...")
        payload = json.dumps(body)
        channel.basic_publish(exchange='moulinette', routing_key='quest_moulinette', body=payload)
    except AMQPError as e:
        logger.critical(f"Cannot request a re-schedule of the job: {e}")
        raise


def handle_job(ch, method, properties, body: bytes):
    """Consumer callback: process one job message and publish its result.

    Steps:
      1. decode the JSON message and extract the job information;
      2. download the moulinette of the stage, unless a directory for the
         same module/activity/stage/request date already exists in the
         downloads directory;
      3. run the job in a fresh workspace;
      4. publish ``body["result"]``, completed with ``output`` and
         ``status``, with the ``quest_result.import`` routing key.

    A message that cannot be decoded, or that lacks the expected fields, is
    logged and dropped: letting the error escape the callback would stop the
    whole consumer because of a single bad message.

    A failed download, or a ``DataFetchingError`` while running the job,
    triggers a re-schedule through ``retry_job``; any other job failure is
    logged and the job is abandoned.

    :param ch: unused, part of the callback signature
    :param method: unused, part of the callback signature
    :param properties: unused, part of the callback signature
    :param body: the raw message, expected to be a JSON-encoded object
    :raises AMQPError: if the result cannot be published (logged, then
        re-raised), or if ``retry_job`` fails to publish
    """
    try:
        message = json.loads(body.decode())
    except ValueError as e:
        # Covers both undecodable bytes and invalid JSON
        logger.error(f"Dropping job: the message is not valid JSON: {e}")
        return
    logger.debug(f"Job received, with raw body: {message}")

    try:
        result = message["result"]
        info = extract_info(message)
        sanitized_stage_name = sanitize_filename(info['stage'])
        sanitized_req_date = sanitize_filename(info["request_date"])
    except (KeyError, ValueError, TypeError, AttributeError) as e:
        logger.error(f"Dropping job: the message is malformed: {e!r}")
        logger.error(pformat(message))
        return

    stage_id = f"{info['module_id']}-{info['activity_id']}-{sanitized_stage_name}"
    job_label = f"{stage_id}, group {info['group_id']}"

    logger.info(f"Job received for {job_label}")

    job_name = f"{stage_id}-{info['group_id']}"
    # The request date is part of the directory name, so a moulinette is only
    # reused by jobs carrying the same date.
    moulinette_directory = f"{downloads_dir}/{stage_id}-{sanitized_req_date}"

    if not os.path.exists(moulinette_directory):
        try:
            download_moulinette(moulinette_directory, info)
        except Exception as e:
            logger.warning(f"Cannot download the moulinette from the intranet: {e}")
            logger.warning(f"Aborting job {job_label}")
            # Remove any partial download so that it is not mistaken for a cached moulinette
            shutil.rmtree(moulinette_directory, ignore_errors=True)
            retry_job(job_name, message)
            return
    else:
        logger.info(f"Using cached moulinette from {info['request_date']}")

    environment_name = f"rocinante-{stage_id}"

    logger.info("Processing job...")

    try:
        with new_job_workspace(moulinette_directory, job_name) as workspace:
            job_status, job_feedback = workspace \
                .build_job_environment(environment_name) \
                .fetch_data(context=info) \
                .execute_job(context=info)
    except Exception as e:
        logger.warning(f"Cannot process job: {e}")
        logger.warning(f"Aborting job {job_label}")
        # Only retry if data could not be fetched. TODO: find a better way to detect errors
        if isinstance(e, DataFetchingError):
            retry_job(job_name, message)
        return

    logger.debug(f"Job result: {job_status}, {job_feedback}")

    if job_status.is_success():
        logger.info("Job successfully processed")
    else:
        logger.warning(f"Job executor reported an error: {job_status.message}")

    # Any feedback message marks the job as KO
    feedback_messages = job_feedback["messages"]
    if len(feedback_messages) > 0:
        result["output"] = "KO:\n" + "".join(feedback_messages)
        result["status"] = 1
    else:
        result["output"] = "OK"
        result["status"] = 0

    logger.info("Sending job result...")

    try:
        channel.basic_publish(
            exchange='moulinette',
            routing_key='quest_result.import',
            body=json.dumps(result)
        )
    except AMQPError as e:
        logger.critical(f"Cannot publish job result: {e}")
        raise

    logger.info("Result successfully sent")


def handle_sigterm(signal_number, frame):
    """Signal handler: stop consuming, remove the root directory and exit.

    :param signal_number: number of the received signal
    :param frame: unused, part of the handler signature
    """
    logger.info(f"Received signal {signal.Signals(signal_number).name} ({signal_number}), cleaning up...")
    # The handler is installed before the connection is opened, so the signal
    # can arrive while the module-level ``channel`` is not bound yet. Look it
    # up defensively so that the cleanup below still runs in that case.
    active_channel = globals().get("channel")
    if active_channel is not None:
        active_channel.stop_consuming()
    shutil.rmtree(root_dir)
    logger.info("Quitting.")
    exit(0)


# Command-line interface of the worker
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument(
    "-c",
    "--config-file",
    type=str,
    required=True,
    help="the path to the configuration file",
)
arg_parser.add_argument(
    "-r",
    "--root-dir",
    type=str,
    default="/var/run/rocinante",
    help="the path of the directory to use as root directory",
)
arg_parser.add_argument(
    "--docker-bridge-ip",
    type=str,
    default="10.9.8.7/24",
    help="the range of addresses to use for Docker's bridge interface",
)
arg_parser.add_argument(
    "-l",
    "--log-dir",
    type=str,
    default="/var/log/rocinante",
    help="the path of the directory to use as log directory",
)
arg_parser.add_argument(
    "--debug",
    action='store_true',
    help="whether debug logs should be emitted",
)

args = arg_parser.parse_args()

# Load the JSON configuration file. It is expected to hold a "rabbitmq"
# section (connection settings) and a "config" section (worker settings).
try:
    with open(args.config_file, 'r') as config_file:
        config = json.load(config_file)
except (OSError, ValueError):
    # ValueError covers invalid JSON (json.JSONDecodeError) and undecodable
    # bytes (UnicodeDecodeError), so a corrupt file is reported the same way
    # as a missing or unreadable one.
    print("cannot load configuration file from", args.config_file)
    exit(1)

rmq_config: Dict[str, Any] = config["rabbitmq"]

host = rmq_config["host"]
port = rmq_config["port"]
username = rmq_config["username"]
password = rmq_config["password"]
virtual_host = rmq_config["virtual_host"]

# From here on, ``config`` only refers to the "config" section of the file
config: Dict[str, Any] = config["config"]
init_config({**config, "root_dir": args.root_dir, "docker_bridge_ip": args.docker_bridge_ip})

root_dir, log_dir = args.root_dir, args.log_dir

# Refuse to start on top of an existing root directory
if os.path.exists(root_dir):
    print(f"cannot create root directory at {root_dir}: directory already exists")
    exit(1)

# Create the root directory, then the downloads directory inside it
downloads_dir = f"{root_dir}/downloads"
for new_dir in (root_dir, downloads_dir):
    os.mkdir(new_dir)

# Unlike the root directory, an existing log directory is reused
if not os.path.exists(log_dir):
    os.mkdir(log_dir)

# Logging: records go both to the console and to a rotating file in the log
# directory; the same handlers are also handed to panza.
if args.debug:
    log_level = logging.DEBUG
else:
    log_level = logging.INFO

logger = logging.getLogger("rocinante")
logger.setLevel(log_level)

console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)

# Rotate the log file every day at 02:00, keeping 7 backups
file_handler = TimedRotatingFileHandler(
    f"{log_dir}/rocinante.log",
    when='midnight',
    atTime=time(hour=2),
    backupCount=7
)

formatter = logging.Formatter('%(asctime)s [%(name)s] %(levelname)s: %(message)s')
for handler in (console_handler, file_handler):
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    add_logger_handler(handler)

# RabbitMQ connection parameters (TLS connection, plain credentials)
credentials = pika.PlainCredentials(username, password)

# PROTOCOL_TLS_CLIENT negotiates the highest TLS version supported by both
# peers, instead of pinning the connection to one deprecated version.
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
if hasattr(context, "minimum_version"):
    # Keep TLS 1.0 as the floor so that an old broker stays reachable.
    # NOTE(review): raise this once the broker's TLS support is confirmed.
    context.minimum_version = ssl.TLSVersion.TLSv1
# NOTE(review): certificate verification is disabled, so the broker is not
# authenticated. check_hostname has to be turned off first, otherwise setting
# verify_mode to CERT_NONE raises ValueError.
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE

params = pika.ConnectionParameters(
    host=host,
    port=port,
    credentials=credentials,
    virtual_host=virtual_host,
    ssl_options=pika.SSLOptions(context)
)

# Clean up and exit when asked to terminate
signal.signal(signal.SIGTERM, handle_sigterm)

try:
    logger.info("Connecting to RabbitMQ server...")
    connection = pika.BlockingConnection(params)
    channel = connection.channel()
    # passive=True: only check that the queue and the exchange exist
    channel.queue_declare(queue="quest_moulinette", durable=True, passive=True)
    channel.exchange_declare(exchange='moulinette', exchange_type='topic', durable=True, passive=True)

    logger.info("Registering as consumer...")
    # NOTE(review): with auto_ack=True a job is not redelivered by the broker
    # if the worker dies while processing it; re-scheduling relies on
    # retry_job publishing the message again.
    channel.basic_consume(queue="quest_moulinette", on_message_callback=handle_job, auto_ack=True)
    logger.info("Waiting for jobs")
    channel.start_consuming()
except AMQPError as e:
    logger.critical(f"unable to consume jobs from RabbitMQ: {str(e)}")
    exit(1)
finally:
    # The worker refuses to start when its root directory already exists, so
    # the directory must be removed on every exit path, not only on SIGTERM.
    # ignore_errors: the SIGTERM handler may have removed it already.
    shutil.rmtree(root_dir, ignore_errors=True)
