#!/usr/bin/env python3
# Copyright 2015-2020 CERN for the benefit of the ATLAS collaboration.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Authors:
# - Fernando Lopez <fernando.e.lopez@gmail.com>, 2015
# - Vincent Garonne <vincent.garonne@cern.ch>, 2018
# - Hannes Hansen <hannes.jakob.hansen@cern.ch>, 2018
# - Joaquin Bogado <jbogado@linti.unlp.edu.ar>, 2019
# - Martin Barisits <martin.barisits@cern.ch>, 2019
# - Benedikt Ziemons <benedikt.ziemons@cern.ch>, 2020
#
# PY3K COMPATIBLE

import argparse
import logging
import logging.handlers
import os
import signal
import sys
import textwrap
import time
from datetime import datetime
from functools import partial
from multiprocessing import Queue, Process, Event, Pipe

import rucio.common.config as config
import rucio.common.dumper as dumper
import rucio.daemons.auditor
from rucio.client.rseclient import RSEClient


def setup_pipe_logger(pipe, loglevel):
    """
    Configure and return the 'auditor' logger with a pipe-backed handler.

    :param pipe: object handed to ``dumper.LogPipeHandler``; presumably the
                 writable end of a multiprocessing pipe (confirm against
                 ``rucio.common.dumper``).
    :param loglevel: level applied to the 'auditor' logger.
    :returns: the configured ``logging.Logger`` named 'auditor'.
    """
    auditor_log = logging.getLogger('auditor')
    auditor_log.setLevel(loglevel)

    # Build the handler fully (formatter included) before attaching it.
    pipe_handler = dumper.LogPipeHandler(pipe)
    pipe_handler.setFormatter(
        logging.Formatter("%(asctime)s  %(name)-22s  %(levelname)-8s %(message)s")
    )
    auditor_log.addHandler(pipe_handler)

    return auditor_log


def _all_alive(procs):
    """Return True only if every process in ``procs`` reports ``is_alive()``."""
    return all(p.is_alive() for p in procs)


def _sleep_while_alive(procs, seconds, step=60):
    """
    Sleep for up to ``seconds`` seconds, checking the processes between naps.

    :param procs: objects exposing ``is_alive()`` (the child processes).
    :param seconds: maximum total time to sleep, in seconds.
    :param step: longest single nap between two liveness checks, in seconds.
    :returns: True if the whole period elapsed with every process alive,
              False as soon as one of them is found dead.
    """
    remaining = seconds
    while remaining > 0:
        if not _all_alive(procs):
            return False
        nap = min(step, remaining)
        time.sleep(nap)
        remaining -= nap
    return _all_alive(procs)


def main(args):
    """
    Start the auditor worker and logger processes and feed RSEs to the workers.

    Every RSE returned by the RSE client is queued for checking at most once
    per calendar month; entries found on the retry queue are re-queued two
    weeks after each pass. The loop ends when any child process dies, after
    which the terminate event is set and all children are joined.

    :param args: parsed command-line namespace; ``nprocs``, ``rses``,
                 ``keep_dumps`` and ``delta`` are read from it.
    """
    import traceback  # Local import: only needed to report a main-loop failure.

    RETRY_AFTER = 60 * 60 * 24 * 14  # Two weeks
    ONE_DAY = 60 * 60 * 24

    nprocs = args.nprocs
    assert nprocs >= 1
    if args.rses is None:
        rses_gen = RSEClient().list_rses()
    else:
        rses_gen = RSEClient().list_rses(args.rses)

    rses = [entry['rse'] for entry in rses_gen]
    assert len(rses) > 0

    procs = []
    queue = Queue()
    retry = Queue()
    terminate = Event()
    logpipes = []

    loglevel = logging.getLevelName(config.config_get('common', 'loglevel'))

    # The main process logs through its own pipe; the read end is handed to
    # the logger process together with the workers' pipes.
    mainlogr, mainlogw = Pipe(duplex=False)
    logpipes.append(mainlogr)
    logger = setup_pipe_logger(mainlogw, loglevel)

    assert config.config_has_section('auditor')
    cache_dir = config.config_get('auditor', 'cache')
    results_dir = config.config_get('auditor', 'results')

    logfilename = os.path.join(config.config_get('common', 'logdir'), 'auditor.log')
    logger.info('Starting auditor')

    def termhandler(sign, trace):
        logger.error('Main process received signal %d, terminating child processes', sign)
        terminate.set()
        for proc in procs:
            proc.join()

    signal.signal(signal.SIGTERM, termhandler)

    for n in range(nprocs):
        logpiper, logpipew = Pipe(duplex=False)
        p = Process(
            target=partial(
                rucio.daemons.auditor.check,
                queue,
                retry,
                terminate,
                logpipew,
                cache_dir,
                results_dir,
                args.keep_dumps,
                args.delta,
            ),
            name='auditor-worker'
        )
        p.start()
        procs.append(p)
        logpipes.append(logpiper)

    p = Process(
        target=partial(
            rucio.daemons.auditor.activity_logger,
            logpipes,
            logfilename,
            terminate
        ),
        name='auditor-logger'
    )
    p.start()
    procs.append(p)

    last_run_month = None  # Don't check more than once per month. FIXME: Save on DB or file...

    try:
        while _all_alive(procs):
            if last_run_month == datetime.now().month:
                # Already ran this month: poll once a day, but keep watching
                # the children so a dead worker or a SIGTERM is noticed.
                _sleep_while_alive(procs, ONE_DAY)
                continue

            # Record the month of this pass; without this assignment the
            # once-per-month guard above could never trigger.
            last_run_month = datetime.now().month
            for rse in rses:
                queue.put((rse, 1))

            # Sleep in short slices instead of one two-week sleep so the main
            # process stops promptly once a child is gone.
            if not _sleep_while_alive(procs, RETRY_AFTER):
                break

            # Avoid infinite loop if an alternative check() implementation doesn't
            # decrement the number of attempts and keeps pushing failed checks.
            tmp_list = []
            while not retry.empty():
                tmp_list.append(retry.get())

            for each in tmp_list:
                queue.put(each)

    except KeyboardInterrupt:
        logger.error('Main process interrupted, terminating child processes')
    except Exception:
        # Log through the pipe-backed auditor logger rather than the root
        # logger; the traceback is passed as a plain string argument.
        logger.error('Main process failed: %s\n%s', sys.exc_info()[0], traceback.format_exc())
    finally:
        terminate.set()
        for proc in procs:
            proc.join()


def get_parser():
    """
    Build the command-line parser of the auditor daemon.

    :returns: an ``argparse.ArgumentParser`` accepting ``--nprocs`` (int,
              default 1), ``--rses`` (str, default None), ``--keep-dumps``
              (flag) and ``--delta`` (int, default 3), with usage examples
              in its epilog.
    """
    parser = argparse.ArgumentParser(description="The auditor daemon is the one responsible for the detection of inconsistencies on storage, i.e.: dark data discovery.",
                                     formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument(
        '--nprocs',
        help='Number subprocess, each subprocess check a fraction of the DDM '
             'Endpoints in sequence (default: 1).',
        default=1,
        type=int,
    )
    parser.add_argument(
        '--rses',
        help='RSEs to check specified as a RSE expression, defaults to check '
             'all the RSEs known to Rucio (default: check all RSEs).',
        default=None,
        type=str,
    )
    parser.add_argument(
        '--keep-dumps',
        help='Keep RSE and Rucio Replica Dumps on cache '
             '(default: False).',
        action='store_true',
    )
    parser.add_argument(
        '--delta',
        help='How many days older/newer than the RSE dump must the Rucio replica dumps be '
             '(default: 3).',
        default=3,
        type=int,
    )
    # Raw string: the last example contains a literal backslash ("\("), which
    # is an invalid escape sequence in a normal string literal.
    parser.epilog = textwrap.dedent(r'''
        examples:
            # Check all RSEs using only 1 subprocess
            %(prog)s

            # Check all SCRATCHDISKs with 4 subprocesses
            %(prog)s --nprocs 4 --rses "type=SCRATCHDISK"

            # Check all Tier 1 DATADISKs, except "BLUE_DATADISK" and "RED_DATADISK"
            %(prog)s --rses "tier=1&type=DATADISK\(BLUE_DATADISK|RED_DATADISK)"
    ''')
    return parser


if __name__ == '__main__':
    # Script entry point: parse the command line and hand the result to main().
    main(get_parser().parse_args())
