#!python

import sys
import os
import argparse
import logging
import datetime
import time
from collections import namedtuple
import tempfile
import subprocess
import atexit
import shutil

from io import StringIO

from plenum.server.general_config import ubuntu_platform_config

# Root logger; init() sets its level from the --log-level argument.
logger = logging.getLogger()

# One group of files destined for the archive:
#   abs_fs_path  - directory on the file system the files were collected from
#   archive_path - directory inside the archive the files are placed under
#   include      - paths of the files to archive; when empty, build_archive_cmd()
#                  adds the directory itself instead
ArchiveEntry = namedtuple('archive_entry', ['abs_fs_path', 'archive_path', 'include'])

# (output file name, command) pairs. collect_env() runs each command and stores
# its combined stdout/stderr in a file with the given name.
ENV_COLLECTION_CMD = [
    ("python_version", ['python3', '-V']),
    ("pip_freeze", ['pip3', 'freeze']),
    ("app_install", ['apt', 'list', '--installed']),
    ("validator_info", ['validator-info', '-v', '--json']),
]

# Default value of the --root-dir argument.
DEFAULT_ROOT_DIR = '/'


# ***************
# *  Command-line Argument Parsing
# ***************
def str2bool(v):
    """Convert a yes/no style command-line string into a bool.

    Matching is case-insensitive.

    :param v: the string to interpret
    :return: True for yes/true/t/y/1, False for no/false/f/n/0
    :raises argparse.ArgumentTypeError: for any other value
    """
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')

    token = v.lower()
    if token in truthy:
        return True
    if token in falsy:
        return False

    raise argparse.ArgumentTypeError(
        'Boolean value (yes, no, true, false, y, n, 1, or 0) expected.')


# Help text shown for the --log-level option.
LOG_LEVEL_HELP = """Logging level.
                      [LOG-LEVEL]: notset, debug, info, warning, error, critical
                      Default: warning"""
# Maps the level names accepted on the command line to the logging module's
# numeric levels; log_level() looks names up here.
levels = {
    'notset': logging.NOTSET,
    'debug': logging.DEBUG,
    'info': logging.INFO,
    'warning': logging.WARNING,
    'error': logging.ERROR,
    'critical': logging.CRITICAL
}


def log_level(v):
    """argparse type converter: turn a level name into its numeric logging level.

    The lookup is case-insensitive and uses the module-level ``levels`` table.

    :param v: level name given on the command line
    :return: the matching numeric logging level
    :raises argparse.ArgumentTypeError: if the name is not in ``levels``
    """
    name = v.lower()
    if name not in levels.keys():
        allowed = ', '.join(levels.keys())
        raise argparse.ArgumentTypeError(
            'Expected one of the following: {}.'.format(allowed))
    return levels[name]


def program_args():
    """Build the command-line parser for the capture tool.

    :return: an argparse.ArgumentParser with all supported options registered
    """
    # (option flags, add_argument keyword settings), registered in this order.
    option_specs = (
        (('-t', '--test'),
         dict(action='store_true', default=False,
              help='Runs unit tests and exits.')),
        (('-l', '--log-level'),
         dict(type=log_level, nargs='?', const=logging.INFO,
              default=logging.INFO, help=LOG_LEVEL_HELP)),
        (('-o', '--output-dir'),
         dict(default=os.getcwd(),
              help='the directory where the captured file will be writen to. Default: CWD')),
        (('-n', '--node-name'),
         dict(help='The name of the node to be captured, required if the script finds more than one node.')),
        (('-r', '--root-dir'),
         dict(default=DEFAULT_ROOT_DIR,
              help='The root to look for node artifacts. Default: ' + DEFAULT_ROOT_DIR)),
        (('-x', '--exclude-recording'),
         dict(action='store_true',
              help='Will exclude recording data from capture.')),
        (('-d', '--dry-run'),
         dict(action='store_true', default=False,
              help='Will not create archive file but will collect files to be collected')),
    )

    parser = argparse.ArgumentParser()
    for flags, settings in option_specs:
        parser.add_argument(*flags, **settings)
    return parser


def parse_args(argv=None, parser=None):
    """Parse command-line arguments.

    :param argv: list of argument strings; None makes argparse read sys.argv
    :param parser: parser to use; when None a fresh one is built with
        program_args() on every call, so defaults computed while building the
        parser (such as the current working directory) are not frozen at
        import time and no parser object is shared between calls
    :return: the argparse.Namespace produced by the parser
    """
    if parser is None:
        parser = program_args()
    return parser.parse_args(args=argv)


def get_system_dirs():
    """Return the platform directory configuration used to locate node files.

    Only the Ubuntu configuration is supported at present.
    """
    # TODO detect the system someday (help support windows)
    platform_dirs = ubuntu_platform_config
    return platform_dirs


# ***************
# *  Capture File Collection
# ***************
def _log_capture(collection):
    """Log, at info level, every archive entry and each file it includes.

    :param collection: iterable of ArchiveEntry objects
    """
    for archive_entry in collection:
        header = "%s found at %s" % (archive_entry.archive_path, archive_entry.abs_fs_path)
        logger.info(header)
        for included in archive_entry.include:
            logger.info("    %s" % included)


def _collect_files(dir_type, directory, archive_path, filter_fn, recursive=True):
    """Gather the files under ``directory`` that pass ``filter_fn``.

    File paths are recorded relative to ``directory``; files directly inside it
    therefore look like ``./name``.

    :param dir_type: human readable directory kind, used only in the warning message
    :param directory: directory to scan
    :param archive_path: location the files should get inside the archive
    :param filter_fn: called with each relative file path; truthy means keep
    :param recursive: unless exactly True, only the top level of ``directory`` is scanned
    :return: an ArchiveEntry, or None when ``directory`` does not exist
    """
    if not os.path.exists(directory):
        return None

    entry = ArchiveEntry(directory, archive_path, [])

    try:
        for current_dir, _subdirs, names in os.walk(directory):
            prefix = os.path.relpath(current_dir, directory)
            candidates = [os.path.join(prefix, name) for name in names]
            entry.include.extend(path for path in candidates if filter_fn(path))
            if recursive is not True:
                # Non-recursive scan: stop after the first (top-level) directory.
                break
    except FileNotFoundError:
        logger.warning("Unable to find a %s directory for pool! Looked for '%s'", dir_type, directory)

    return entry


def _run_env_capture_cmd(output_file, cmd, entry):
    """Run ``cmd`` and store its combined stdout/stderr in ``output_file``.

    A failing or missing command is logged as a warning and otherwise ignored.
    Whatever happens, the output file is added to ``entry.include`` when it
    exists and is not empty.

    :param output_file: path of the file receiving the command output
    :param cmd: command and arguments as a list of strings
    :param entry: ArchiveEntry whose include list is extended
    """
    try:
        with open(output_file, 'w') as sink:
            subprocess.run(cmd, stdout=sink, stderr=subprocess.STDOUT, check=True)
    except (subprocess.CalledProcessError, FileNotFoundError) as err:
        logger.warning("Unable to capture env data: cmd '%s' - error: %s", ' '.join(cmd), str(err))
    finally:
        # Keep partial output too: a failed command may still have written something useful.
        has_output = os.path.exists(output_file) and os.path.getsize(output_file) > 0
        if has_output:
            entry.include.append(output_file)


def collect_env(args):
    """Capture environment information by running every command in ENV_COLLECTION_CMD.

    The outputs are written into a new temporary directory, which is registered
    in TEMP_DIRS so that clean_up() removes it.

    :param args: parsed program arguments (not used here)
    :return: ArchiveEntry placing the captured files under '/capture'
    """
    logger.info("Collecting Env Files")

    scratch_dir = tempfile.mkdtemp()
    TEMP_DIRS.append(scratch_dir)

    env_entry = ArchiveEntry(scratch_dir, '/capture', [])
    for out_name, command in ENV_COLLECTION_CMD:
        out_path = os.path.join(scratch_dir, out_name)
        _run_env_capture_cmd(out_path, command, env_entry)

    logger.info("Collected %s files to captured", str(len(env_entry.include)))
    return env_entry


def collect_logs(args):
    """Collect the node's log files from the pool's log directory (top level only).

    :param args: parsed arguments; uses root_dir, pool_name and node_name
    :return: ArchiveEntry placing the logs under '/log', or None if the
        directory does not exist
    """
    pool_log_dir = os.path.join(indy_log_dir(args), args.pool_name)

    def belongs_to_node(rel_path):
        # Top-level files are reported as './<file>', so match './<node_name>...'.
        return rel_path.startswith(os.path.join('.', args.node_name))

    logger.info("Collecting Log Files")
    logs = _collect_files('log', pool_log_dir, '/log', belongs_to_node, recursive=False)

    if not logs:
        logger.warning("Unable to collect files from Log")
    else:
        logger.info("Collected %s files to captured", str(len(logs.include)))
    return logs


def collect_config(args):
    """Collect every file found directly inside the indy config directory.

    :param args: parsed arguments; uses root_dir
    :return: ArchiveEntry placing the files under '/config', or None if the
        directory does not exist
    """
    config_dir = indy_config_dir(args)

    def accept_all(_rel_path):
        return True

    logger.info("Collecting Config Files")
    configs = _collect_files('config', config_dir, '/config', accept_all, recursive=False)

    if not configs:
        logger.warning("Unable to collect files from Config")
    else:
        logger.info("Collected %s files to captured", str(len(configs.include)))

    return configs


def collect_data(args):
    """Collect the pool's application data for the selected node.

    Files outside 'data/' are always kept. Inside 'data/' only the directories
    '<node_name>/' and '<node_name>C/' are kept, and paths containing
    'recorder' are dropped when args.exclude_recording is set.

    :param args: parsed arguments; uses root_dir, pool_name, node_name and
        exclude_recording
    :return: ArchiveEntry placing the files at the archive root, or None if
        the pool directory does not exist
    """
    pool_dir = os.path.join(indy_app_dir(args), args.pool_name)

    def keep(rel_path):
        if not rel_path.startswith('data/'):
            return True

        node_prefixes = (
            os.path.join('data/', args.node_name + '/'),
            os.path.join('data/', args.node_name + 'C/'),
        )
        if not rel_path.startswith(node_prefixes):
            # Data that belongs to some other node.
            return False

        if 'recorder' in rel_path:
            return not args.exclude_recording
        return True

    logger.info("Collecting Data Files")
    data = _collect_files('application data', pool_dir, '/', keep, recursive=True)

    if not data:
        logger.warning("Unable to collect files from Data")
    else:
        logger.info("Collected %s files to captured", str(len(data.include)))

    return data


def collect_plugins(args):
    """Collect every file below the 'plugins' directory of the application dir.

    :param args: parsed arguments; uses root_dir
    :return: ArchiveEntry placing the files under '/plugins', or None if the
        directory does not exist
    """
    plugins_dir = os.path.join(indy_app_dir(args), 'plugins')

    def accept_all(_rel_path):
        return True

    logger.info("Collecting Plugins Files")
    plugins = _collect_files('plugin data', plugins_dir, '/plugins', accept_all, recursive=True)

    if not plugins:
        logger.warning("Unable to collect files from Plugins")
    else:
        logger.info("Collected %s files to captured", str(len(plugins.include)))

    return plugins


# Collector functions run, in this order, by collect_capture_files(). Each is
# called with the parsed arguments; a falsy result is left out of the capture.
COLLECT_FN = [
    collect_env,
    collect_logs,
    collect_config,
    collect_data,
    collect_plugins,
]


def collect_capture_files(args):
    """Run every collector in COLLECT_FN and gather the entries they produce.

    :param args: parsed program arguments, passed to each collector
    :return: list of the truthy collector results, in COLLECT_FN order
    """
    logger.info("Collecting file to be captured")
    produced = (collector(args) for collector in COLLECT_FN)
    # Collectors return None when their directory is missing; skip those.
    entries = [entry for entry in produced if entry]
    logger.info("Collection complete")
    return entries


# ***************
# *  Indy Directories
# ***************
def _find_dir(args, dir_location, name):
    """Resolve ``dir_location`` below the configured root directory.

    :param args: parsed arguments; uses root_dir
    :param dir_location: directory path; leading slashes are removed so it is
        always joined beneath root_dir
    :param name: label used in the debug log message
    :return: the joined path
    """
    relative_location = dir_location.lstrip("/")
    resolved = os.path.join(args.root_dir, relative_location)
    logger.debug("using %s dir: %s", name, resolved)
    return resolved


def indy_app_dir(args):
    """Return the application directory (NODE_INFO_DIR) below args.root_dir."""
    node_info_dir = get_system_dirs().NODE_INFO_DIR
    return _find_dir(args, node_info_dir, "application dir")


def indy_log_dir(args):
    """Return the logging directory (LOG_DIR) below args.root_dir."""
    log_location = get_system_dirs().LOG_DIR
    return _find_dir(args, log_location, "logging dir")


def indy_config_dir(args):
    """Return the config directory ('/etc/indy') below args.root_dir."""
    config_location = '/etc/indy'
    return _find_dir(args, config_location, "config dir")


# ***************
# *  Capture File Output File Name
# ***************
def capture_file_name(args):
    """Build the base name (no extension) of the capture archive.

    The name has the form '<node_name>.<pool_name>.<YYYYmmddHHMMSS>', using the
    current local time.

    :param args: object expected to carry node_name and pool_name
    :return: the file name
    :raises Exception: when node_name or pool_name is missing or None
    """
    node = getattr(args, 'node_name', None)
    pool = getattr(args, 'pool_name', None)

    if node is None or pool is None:
        logger.error("node_name or pool_name has not been specified, can not create capture file name. Object: %s",
                     str(args))
        raise Exception("Can not create capture file name!")

    stamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    return "%s.%s.%s" % (node, pool, stamp)


# ***************
# *  Node Discovery
# ***************
def find_node_name(args):
    """Discover the node to capture and record it on ``args``.

    Every directory of the application dir except 'plugins' is treated as a
    pool; every directory inside '<pool>/data' is a candidate node. A directory
    named '<node>C' next to '<node>' is that node's client directory and is not
    reported as a node of its own.

    On success ``args.node_name`` and ``args.pool_name`` are set.

    :param args: parsed arguments; uses root_dir and node_name
    :raises FileNotFoundError: when the application directory does not exist
    :raises SystemExit: when no node is found, when several nodes are found and
        args.node_name is None, or when the requested node does not exist
    """
    app_dir = indy_app_dir(args)

    def is_dir_check(base_dir):
        def is_dir(d):
            return os.path.isdir(os.path.join(base_dir, d))

        return is_dir

    def find_nodes_in_pool_dir(pool_name):
        """Return a list of (node_name, pool_name) tuples, sorted by node name."""
        data_dir = os.path.join(app_dir, pool_name, "data")

        try:
            entries = sorted(os.listdir(data_dir))
        except (FileNotFoundError, NotADirectoryError):
            logger.warning("Did not find a data directory for pool '%s'! (Likely not an issue) Looked for '%s'",
                           pool_name, data_dir)
            return []

        possible_nodes = list(filter(is_dir_check(data_dir), entries))
        logger.debug("possible node names in %s: %s", pool_name, possible_nodes)

        found = []
        client_dirs = set()
        # The names are sorted, so '<node>' is always seen before '<node>C'.
        for a_name in possible_nodes:
            if a_name in client_dirs:
                continue
            client_dirs.add(a_name + 'C')
            found.append((a_name, pool_name))

        return found

    try:
        # Sorted so that discovery order does not depend on os.listdir().
        pool_dirs = sorted(filter(is_dir_check(app_dir), os.listdir(app_dir)))
    except FileNotFoundError:
        logger.error("Unable to find indy application directory! Looked for '%s'", app_dir)
        raise

    nodes = []
    for pool in pool_dirs:
        # ignore plugins dir - it is not a pool directory
        if "plugins" == pool:
            continue

        nodes.extend(find_nodes_in_pool_dir(pool))

    if args.node_name is None:
        if len(nodes) == 1:
            node_name, pool = nodes.pop()
            args.node_name = node_name
            args.pool_name = pool
        elif len(nodes) == 0:
            logger.error("No nodes where found. Check root dir?")
            sys.exit(-1)
        else:
            logger.error("Found more than one node and a node name was not specified! Found -- %s", nodes)
            sys.exit(-1)
    else:
        for node in nodes:
            if args.node_name == node[0]:
                args.pool_name = node[1]
                break
        else:
            logger.error("Node with name '%s' was not found, can not continue! Found -- %s", args.node_name, nodes)
            sys.exit(-1)

    logger.info("Capturing Node '%s' for pool '%s'", args.node_name, args.pool_name)


# Temporary directories created by this script; clean_up() removes them.
TEMP_DIRS = []


# Clean up anything that is created by this script (excluding the output file)
def clean_up():
    """Remove every directory registered in TEMP_DIRS, emptying the list.

    A directory that can not be removed is logged as a warning and skipped, so
    one failure neither prevents the remaining directories from being removed
    nor raises out of the atexit handler.
    """
    while TEMP_DIRS:
        tmp_dir = TEMP_DIRS.pop()
        try:
            shutil.rmtree(tmp_dir)
        except OSError as e:
            logger.warning("Unable to remove temporary directory '%s': %s", tmp_dir, str(e))


# ***************
# *  Init Details
# ***************
def init(args):
    """Print the sensitive-data warning, configure logging and locate the node.

    After the banner the function pauses for one second, sets the logger level
    from args.log_level, resolves the node/pool via find_node_name() and
    registers clean_up() to run at interpreter exit.

    :param args: parsed program arguments
    """
    margin = "!" * 2 + "  "
    banner = (
        "!" * 20 + " WARNING " + '!' * 21,
        margin + "This tool captures node state replay and diagnostic purposes (for development and QA)",
        margin + "By the nature of these needs, that includes sensitive data (including secret keys)",
        margin + "DO NOT USE THIS TOOL on nodes that need to be secure",
        "!" * 50,
    )
    for line in banner:
        print(line)
    time.sleep(1)

    logger.setLevel(args.log_level)
    logger.debug("args: %s", args)
    find_node_name(args)

    atexit.register(clean_up)


def build_archive_cmd(output_file, files):
    """Build the tar command line that creates the capture archive.

    tar is run relative to '/' ('-C /'), so every member is given as an
    absolute path without its leading slash. For each entry an '--xform' sed
    expression rewrites the entry's source directory to its archive path.

    :param output_file: path of the archive to create
    :param files: iterable of ArchiveEntry objects
    :return: the command as a list of strings
    """
    cmd = ['tar']
    flags = 'cfz'
    if logger.level == logging.DEBUG:
        flags += 'v'
    cmd.append(flags)
    cmd.append(output_file)
    cmd.append('--show-transformed-names')
    cmd.extend(["-C", "/"])

    for entry in files:
        # Normalise the source directory once. The member paths below are made
        # absolute, so the transform pattern must be built from the same
        # absolute form or it will not match when the entry path is relative
        # or unnormalised (e.g. a relative --root-dir).
        base_dir = os.path.abspath(entry.abs_fs_path)
        fs_path_pattern = base_dir.strip('/')
        archive_path_pattern = entry.archive_path.lstrip('/')
        if archive_path_pattern:
            archive_path_pattern += '/'  # Add slash for directories otherwise leave blank

        # NOTE(review): the path is inserted into the sed expression unescaped;
        # a path containing '|' or regex metacharacters may not transform as
        # intended.
        if entry.include:
            cmd.append('--xform')
            cmd.append("s|%s/|%s|1" % (fs_path_pattern, archive_path_pattern))
            for file in entry.include:
                f_path = os.path.abspath(os.path.join(base_dir, file))
                cmd.append(f_path.lstrip('/'))
        else:
            # Nothing to include: archive the (empty) directory itself.
            cmd.append('--xform')
            cmd.append("s|%s|%s|1" % (fs_path_pattern, archive_path_pattern))
            cmd.append(fs_path_pattern + '/')  # Add slash to end for sed expression matching

    return cmd


# ***************
# *  Running the archive command
# ***************
def run_cmd(cmd):
    """Run the archive command, logging the outcome.

    Any exception (including a non-zero exit status) is logged as an error and
    not re-raised.

    :param cmd: command and arguments as a list of strings
    """
    logger.info("Running archive command")
    try:
        subprocess.run(cmd, check=True)
    except Exception as err:
        logger.error("Unable to complete archive command: %s", str(err))
        return
    logger.info("Archive command complete")


def do_capture(args, files):
    """Create the capture archive, or only list its contents on a dry run.

    :param args: parsed arguments; uses output_dir, node_name, pool_name,
        dry_run and root_dir
    :param files: list of ArchiveEntry objects to archive
    :raises Exception: when capturing from the default root dir without super
        user rights, or when the archive file name can not be built
    """
    output_file = os.path.join(args.output_dir, capture_file_name(args) + '.tar.gz')
    logger.info("Writing archive to '%s'" % output_file)

    logger.info("Building archive command")
    cmd = build_archive_cmd(output_file, files)
    logger.debug("Archive command to be run:")
    logger.debug(' '.join(cmd))

    if args.dry_run:
        logger.info("Dry-run")
        logger.info("Collected following files:")
        _log_capture(files)
        return

    # Compare by value: 'is not 0' tests object identity, which only happens to
    # work for small ints on CPython.
    if args.root_dir == DEFAULT_ROOT_DIR and os.getuid() != 0:
        logger.error("Must be run as super user (sudo) to capture all files.")
        raise Exception("Not Super User")

    run_cmd(cmd)


def main(args):
    """Run the three capture stages: initialise, collect files, build the archive.

    When a stage raises an Exception, a stage-specific error is logged and the
    exception is re-raised.

    :param args: parsed program arguments
    :return: 0 on completion
    """
    def run_stage(failure_message, stage, *stage_args):
        try:
            return stage(*stage_args)
        except Exception:
            logger.error(failure_message)
            raise

    run_stage('Unable to initialize script', init, args)
    files = run_stage('Unable to collect the files for this capture', collect_capture_files, args)
    run_stage('Unable to build capture archive', do_capture, args, files)

    return 0


# ***************
# *  UNIT TESTS !!!!! (use -t to run them)
# ***************
def test(args=None, module='__main__'):
    """Fallback test entry point, used when the 'unittest' module can not be imported.

    Accepts the same arguments as the real test runner so that the
    ``test(arguments)`` call in the script entry point works with either
    definition.

    :param args: parsed program arguments (ignored)
    :param module: module that would be tested (ignored)
    :return: 0
    """
    print("The 'unittest' module is not available!\nUnable to run tests!")
    return 0


try:
    import unittest

    def test(args, module='__main__'):
        """Run the unit tests found in ``module`` and return a process exit code.

        :param args: parsed program arguments (not used)
        :param module: name of the module whose tests are run
        :return: 0 when every test passed, 1 otherwise
        """
        program = unittest.main(argv=['capture_test'], module=module, exit=False, verbosity=10)
        passed = program.result.wasSuccessful()
        return 0 if passed else 1

    def _touch(path, content):
        """Write ``content`` to ``path``, creating missing parent directories first."""
        parent = os.path.dirname(path)
        os.makedirs(parent, exist_ok=True)
        with open(path, 'w') as handle:
            handle.write(content)

    class TestCapture(unittest.TestCase):
        """Unit tests for argument parsing, node discovery, file collection and
        archive command building.

        File-system tests run against a throw-away directory tree created by
        create_sample_dir().
        """

        @classmethod
        def setUpClass(cls):
            # Silence stderr (argparse usage errors) and logging while the
            # tests run; both are restored in tearDownClass.
            cls._saved_stderr = sys.stderr
            cls._saved_log_level = logger.level
            sys.stderr = StringIO()
            logger.setLevel(sys.maxsize)

        @classmethod
        def tearDownClass(cls):
            sys.stderr = cls._saved_stderr
            logger.setLevel(cls._saved_log_level)
            clean_up()

        @staticmethod
        def create_sample_dir(node_names=None, pool_name='sandbox'):
            """Create a temporary root dir populated with sample node files.

            For every node a data dir, a client data dir ('<name>C'), recorder
            files and a log file are created, plus one plugin file and one
            config file.

            :param node_names: node names to create; defaults to ['Node1']
            :param pool_name: name of the pool directory
            :return: the tempfile.TemporaryDirectory (usable as a context manager)
            """
            d = tempfile.TemporaryDirectory()
            args = argparse.Namespace(root_dir=d.name)

            if not node_names:
                node_names = ["Node1"]

            app_dir = indy_app_dir(args)
            for name in node_names:
                node_data_dir = os.path.join(app_dir, pool_name, "data", name)
                client_data_dir = os.path.join(app_dir, pool_name, "data", name + "C")

                _touch(os.path.join(node_data_dir, 'ledger.db'), 'data')
                _touch(os.path.join(client_data_dir, 'actions.txt'), 'action')

                _touch(os.path.join(node_data_dir, 'recorder', name, 'msg.log'), 'received msgs')
                _touch(os.path.join(client_data_dir, 'recorder', name + 'C', 'msg.log'), 'received msgs')

                log_dir = indy_log_dir(args)
                _touch(os.path.join(log_dir, pool_name, name + '.log'), "Test Log")

            _touch(os.path.join(app_dir, 'plugins', 'plugin.txt'), "plugin data")

            config_dir = indy_config_dir(args)
            _touch(os.path.join(config_dir, 'indy.conf'), "Test config")

            return d

        def test_arg_log_level(self):
            for k, v in levels.items():
                test_args = parse_args(['-l', k])
                self.assertEqual(test_args.log_level, v)

            test_args = parse_args(['-l'])
            self.assertEqual(test_args.log_level, logging.INFO, msg='Invalid const level')
            test_args = parse_args([])
            self.assertEqual(test_args.log_level, logging.INFO, msg='Invalid default level')

        def test_arg_output_dir(self):
            # argparse exits via SystemExit on a usage error, hence BaseException.
            with self.assertRaises(BaseException):
                parse_args(['-o'])

            test_args = parse_args(['-o=/tmp/'])
            self.assertEqual('/tmp/', test_args.output_dir)

        def test_arg_exclude_recording(self):
            test_args = parse_args(['-x'])
            self.assertEqual(True, test_args.exclude_recording)

        def test_arg_node_name(self):
            with self.assertRaises(BaseException):
                parse_args(['-n'])

            with self.assertRaises(BaseException):
                parse_args(['-n', 'Node1', 'Node2'])

            test_args = parse_args(['-n=Node1'])
            self.assertEqual('Node1', test_args.node_name)

        def test_arg_root_dir(self):
            test_args = parse_args([])
            self.assertEqual('/', test_args.root_dir)

            with self.assertRaises(BaseException):
                parse_args(['-r'])

            test_args = parse_args(['-r=/tmp/pytest234'])
            self.assertEqual('/tmp/pytest234', test_args.root_dir)

        def test_arg_dry_run(self):
            test_args = parse_args(['-d'])
            self.assertEqual(True, test_args.dry_run)

        def test_system_dirs(self):
            self.assertIsNotNone(get_system_dirs().LOG_DIR)

        def test_find_node_name(self):
            with self.create_sample_dir(node_names=['test1']) as d:
                test_args = parse_args(['-r', d])
                self.assertIsNone(test_args.node_name)
                find_node_name(test_args)
                self.assertEqual(test_args.node_name, 'test1')
                self.assertEqual(test_args.pool_name, 'sandbox')

                test_args = parse_args(['-r', d, '-n', 'test2'])
                with self.assertRaises(BaseException):
                    find_node_name(test_args)

            with self.create_sample_dir(node_names=['test1']) as d:
                os.makedirs(os.path.join(d, "var/lib/indy", "dkms", "data", 'test2'))
                os.makedirs(os.path.join(d, "var/lib/indy", "dkms", "data", 'test2C'))
                test_args = parse_args(['-r', d])
                self.assertIsNone(test_args.node_name)
                # There are multiple nodes but the args don't specify which one should be used
                # Should fail
                with self.assertRaises(SystemExit):
                    find_node_name(test_args)

                test_args = parse_args(['-r', d, '-n', 'test1'])
                self.assertEqual(test_args.node_name, 'test1')
                find_node_name(test_args)
                self.assertEqual(test_args.node_name, 'test1')
                self.assertEqual(test_args.pool_name, 'sandbox')

        def test_find_node_name_without_client_dir(self):
            with self.create_sample_dir(node_names=['test1']) as d:
                test_args = parse_args(['-r', d])
                self.assertIsNone(test_args.node_name)
                shutil.rmtree(os.path.join(indy_app_dir(test_args), 'sandbox', "data", 'test1' + "C"))
                find_node_name(test_args)
                self.assertEqual(test_args.node_name, 'test1')
                self.assertEqual(test_args.pool_name, 'sandbox')

                test_args = parse_args(['-r', d, '-n', 'test2'])
                with self.assertRaises(BaseException):
                    find_node_name(test_args)

            with self.create_sample_dir(node_names=['test1']) as d:
                os.makedirs(os.path.join(d, "var/lib/indy", "dkms", "data", 'test2'))
                test_args = parse_args(['-r', d])
                self.assertIsNone(test_args.node_name)
                # There are multiple nodes but the args don't specify which one should be used
                # Should fail
                with self.assertRaises(SystemExit):
                    find_node_name(test_args)

                test_args = parse_args(['-r', d, '-n', 'test1'])
                self.assertEqual(test_args.node_name, 'test1')
                find_node_name(test_args)
                self.assertEqual(test_args.node_name, 'test1')
                self.assertEqual(test_args.pool_name, 'sandbox')

        def test_find_node_name_not_found(self):
            test_args = parse_args(['-r', '/nonsense/help/'])
            with self.assertRaises(FileNotFoundError):
                find_node_name(test_args)

        def test_capture_file_name(self):
            test_args = parse_args(['-n', 'test1'])
            test_args.pool_name = "sandbox"
            file_name = capture_file_name(test_args)
            self.assertTrue(file_name.startswith("test1.sandbox"))

        def test_capture_file_name_bad(self):
            # pool_name is never set on these args, so no name can be built.
            test_args = parse_args(['-n', 'test1'])
            with self.assertRaises(Exception):
                capture_file_name(test_args)

            test_args = parse_args([])
            with self.assertRaises(Exception):
                capture_file_name(test_args)

        def test_collect_capture_files(self):
            with self.create_sample_dir(node_names=['test1']) as d:
                test_args = parse_args(['-n', 'test1', '-r', d])
                test_args.pool_name = 'sandbox'
                collection = collect_capture_files(test_args)
                self.assertIsInstance(collection, list)
                self.assertNotEqual(0, len(collection))

        def test_collect_logs(self):
            with self.create_sample_dir(node_names=['test1']) as d:
                test_args = parse_args(['-n', 'test1', '-r', d])
                test_args.pool_name = 'sandbox'
                collection = collect_logs(test_args)

                self.assertIsInstance(collection.abs_fs_path, str)
                self.assertIsInstance(collection.archive_path, str)
                self.assertEqual(1, len(collection.include))
                self.assertEqual('./test1.log', collection.include[0])

        def test_collect_logs_no_dir(self):
            with self.create_sample_dir(node_names=['test1']) as d:
                shutil.rmtree(os.path.join(d, 'var/log/indy/sandbox'))
                test_args = parse_args(['-n', 'test1', '-r', d])
                test_args.pool_name = 'sandbox'
                collection = collect_logs(test_args)

                self.assertIsNone(collection)

        def test_collect_config(self):
            with self.create_sample_dir(node_names=['test1']) as d:
                test_args = parse_args(['-n', 'test1', '-r', d])
                test_args.pool_name = 'sandbox'
                collection = collect_config(test_args)

                self.assertIsInstance(collection.abs_fs_path, str)
                self.assertIsInstance(collection.archive_path, str)
                self.assertEqual(1, len(collection.include))
                self.assertEqual('./indy.conf', collection.include[0])

        def test_collect_data(self):
            with self.create_sample_dir(node_names=['test1']) as d:
                test_args = parse_args(['-n', 'test1', '-r', d])
                test_args.pool_name = 'sandbox'
                collection = collect_data(test_args)

                self.assertIsInstance(collection.abs_fs_path, str)
                self.assertIsInstance(collection.archive_path, str)
                self.assertEqual(4, len(collection.include))
                self.assertIn('data/test1/ledger.db', collection.include)

        def test_collect_data_no_client_dir(self):
            with self.create_sample_dir(node_names=['test1']) as d:
                test_args = parse_args(['-n', 'test1', '-r', d])
                test_args.pool_name = 'sandbox'
                shutil.rmtree(os.path.join(indy_app_dir(test_args), test_args.pool_name, "data", 'test1' + "C"))
                collection = collect_data(test_args)

                self.assertIsInstance(collection.abs_fs_path, str)
                self.assertIsInstance(collection.archive_path, str)
                self.assertEqual(2, len(collection.include))
                self.assertEqual('data/test1/ledger.db', collection.include[0])

        def test_collect_plugins(self):
            with self.create_sample_dir(node_names=['test1']) as d:
                test_args = parse_args(['-n', 'test1', '-r', d])
                test_args.pool_name = 'sandbox'
                collection = collect_plugins(test_args)

                self.assertIsInstance(collection.abs_fs_path, str)
                self.assertIsInstance(collection.archive_path, str)
                self.assertEqual(1, len(collection.include))
                self.assertEqual('./plugin.txt', collection.include[0])

        def test_collect_plugins_empty_dir(self):
            with self.create_sample_dir(node_names=['test1']) as d:
                test_args = parse_args(['-n', 'test1', '-r', d])
                test_args.pool_name = 'sandbox'
                # indy_app_dir() already resolves below the sample root dir.
                os.remove(os.path.join(indy_app_dir(test_args), 'plugins', 'plugin.txt'))
                collection = collect_plugins(test_args)

                self.assertIsInstance(collection.abs_fs_path, str)
                self.assertIsInstance(collection.archive_path, str)
                self.assertEqual(0, len(collection.include))

        def test_collect_env(self):
            test_args = parse_args([])
            collection = collect_env(test_args)
            self.assertGreater(len(collection.include), 0)

        def test_build_archive_empty_plugin(self):
            with self.create_sample_dir(node_names=['test1']) as d:
                test_args = parse_args(['-n', 'test1', '-r', d])
                test_args.pool_name = 'sandbox'
                os.remove(os.path.join(indy_app_dir(test_args), 'plugins', 'plugin.txt'))
                collection = collect_plugins(test_args)
                cmd = build_archive_cmd("/test/test/", [collection])
                # An entry without files is archived as the directory itself.
                self.assertTrue(cmd[-1].endswith('plugins/'))

except ImportError:
    pass

if __name__ == '__main__':
    arguments = parse_args()

    # -t/--test runs the embedded unit tests instead of performing a capture.
    entry_point = test if arguments.test else main
    sys.exit(entry_point(arguments))
