#!python
"""
Convenience script for gathering metrics data

"""
import argparse
import logging
import os
import shutil
from abc import ABC, abstractmethod
from collections import namedtuple
from datetime import datetime, timedelta

from indy_common.config_util import getConfig
from plenum.common.constants import KeyValueStorageType
from plenum.common.metrics_collector import MetricsName
from plenum.common.metrics_stats import load_metrics_from_kv_store
from storage.helper import initKeyValueStorage

# Logging setup for the script: fetch the root logger, remove every handler
# currently attached to it, stop propagation and mark the logger as disabled.
# NOTE(review): presumably done so that log output of the imported libraries
# does not mix with the report printed below -- confirm.
logger = logging.getLogger()  # no name given -> root logger
logger.disabled = True
logger.propagate = False
logging.root.handlers = []


def read_args(argv=None):
    """Parse the command-line options of the script.

    :param argv: optional list of argument strings to parse; ``None`` (the
        default) lets argparse read ``sys.argv[1:]``
    :return: ``argparse.Namespace`` with the attributes ``data_dir``,
        ``node_name``, ``network``, ``output``, ``ts_fmt``, ``min_ts``,
        ``max_ts`` and ``step``
    """

    def positive_int(text):
        # The step is later turned into a timedelta and used as a divisor,
        # so zero or negative values are rejected right at the command line.
        value = int(text)
        if value <= 0:
            raise argparse.ArgumentTypeError(
                "must be a positive number of seconds, got {}".format(value))
        return value

    parser = argparse.ArgumentParser(description="Process metrics")

    parser.add_argument(
        '--data_dir',
        required=False,
        help="Path to metrics data; when given, --node_name and --network are ignored")
    parser.add_argument('--node_name', required=False, help="Node's name")
    parser.add_argument('--network', required=False, help="Network name to read metrics from")
    parser.add_argument('--output', required=False, default=None, help="Output CSV file name with metrics")
    parser.add_argument('--ts_fmt', required=False, default="%Y-%m-%d %H:%M:%S", help="Output timestamp format")

    # NOTE: argparse %-formats help strings, so the expected timestamp layout
    # is spelled out in words instead of strftime directives.
    parser.add_argument(
        '--min_ts',
        required=False,
        default=None,
        help="Process all metrics starting from given timestamp, "
             "formatted as 'YYYY-MM-DD HH:MM:SS' (beginning by default)")
    parser.add_argument(
        '--max_ts',
        required=False,
        default=None,
        help="Process all metrics till given timestamp, "
             "formatted as 'YYYY-MM-DD HH:MM:SS' (end by default)")
    parser.add_argument(
        '--step',
        type=positive_int,
        required=False,
        default=60,
        help="Build timelog with given timestep, seconds (60 by default)"
    )

    return parser.parse_args(argv)


def get_data_dir(node_name=None, network=None):
    """Build the path of a node's data directory from the configuration.

    The base directory is ``<config.LEDGER_DIR>/<network>/data``. When
    ``network`` is empty, ``config.NETWORK_NAME`` is used. When ``node_name``
    is empty, the alphabetically first sub-directory of the base directory
    is taken.

    :param node_name: name of the node's directory inside the base directory
    :param network: network name, overrides ``config.NETWORK_NAME``
    :return: ``<base directory>/<node_name>``
    :raises SystemExit: with status 1 (after printing an explanation) when
        the base directory is missing or contains no sub-directory
    """
    config = getConfig()
    _network = network if network else config.NETWORK_NAME
    data_dir = os.path.join(config.LEDGER_DIR, _network, 'data')

    if not os.path.exists(data_dir):
        print("No such file or directory: {}".format(data_dir))
        print("Please check, that network: '{}' was used ".format(_network))
        # Non-zero status so that callers and shell scripts see the failure;
        # the ``exit()`` helper from ``site`` would report success here.
        raise SystemExit(1)

    if not node_name:
        # os.listdir() returns entries in arbitrary order and also lists plain
        # files. Sorting makes the choice reproducible and puts a node's
        # directory before any "<node>-read-copy" sibling next to it.
        node_dirs = sorted(
            entry for entry in os.listdir(data_dir)
            if os.path.isdir(os.path.join(data_dir, entry)))
        if not node_dirs:
            print("Node's 'data' folder not found: {}".format(data_dir))
            raise SystemExit(1)
        node_name = node_dirs[0]

    return os.path.join(data_dir, node_name)


def make_copy_of_data(data_dir):
    """Copy ``data_dir`` into a sibling directory named ``<data_dir>-read-copy``.

    A copy that already exists at that location is deleted first.

    :param data_dir: path of the directory to copy
    :return: path of the freshly created copy
    """
    snapshot_dir = data_dir + '-read-copy'

    # Get rid of a copy left at the target path, copytree needs a free target
    stale_copy_present = os.path.exists(snapshot_dir)
    if stale_copy_present:
        shutil.rmtree(snapshot_dir)

    # copytree() hands back the destination it was given
    return shutil.copytree(data_dir, snapshot_dir)


class Column(ABC):
    """Abstract description of a single column of the CSV output.

    A column carries a ``title`` (written to the header line) and knows how
    to produce its cell for one row through :meth:`value`.
    """

    # Values shared by all rows: the parsed arguments and the loaded stats
    Context = namedtuple('Context', ['args', 'stats'])
    # Values of one row: its timestamp and the frame belonging to it
    Data = namedtuple('Data', ['ts', 'frame'])

    def __init__(self, title):
        self.title = title

    @abstractmethod
    def value(self, context: Context, data: Data):
        """Return the cell of this column for the row described by ``data``."""


class Timestamp(Column):
    """Column titled ``timestamp`` that renders the timestamp of a row."""

    def __init__(self):
        super().__init__("timestamp")

    def value(self, context, data):
        """Format ``data.ts`` with ``context.args.ts_fmt``, microseconds zeroed."""
        whole_seconds = data.ts.replace(microsecond=0)
        return whole_seconds.strftime(context.args.ts_fmt)


class MetricsValue(Column):
    """Column exposing one attribute (``param``) of one metric's accumulator."""

    def __init__(self, name: MetricsName, param: str = 'avg'):
        self._name = name
        self._param = param
        super().__init__(self._gen_title(name, param))

    def value(self, context, data):
        """Return the chosen attribute of the metric for the row's frame.

        A falsy attribute is reported as ``0``. The ``sum`` and ``count``
        attributes are divided by the timestep length in seconds; all other
        attributes are returned unchanged.
        """
        accumulator = data.frame.get(self._name)
        raw = getattr(accumulator, self._param)
        if not raw:
            return 0
        if self._param not in ('sum', 'count'):
            return raw
        return raw / context.stats.timestep.total_seconds()

    @staticmethod
    def _gen_title(name: MetricsName, param: str):
        """Derive the column title from the metric name and the attribute."""
        # Last dot-separated part of the metric's string form, lower-cased
        short_name = str(name).rpartition('.')[2].lower()
        if param == 'count':
            return "{}_count_per_sec".format(short_name)
        if param == 'sum':
            return "{}_per_sec".format(short_name)
        return "{}_{}".format(param, short_name)


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0 when the denominator is not positive."""
    return numerator / denominator if denominator > 0 else 0


def process_storage(storage, args):
    """Load metrics from ``storage``, print a summary and optionally dump a CSV.

    :param storage: key-value storage handed to ``load_metrics_from_kv_store``
    :param args: parsed command-line arguments; ``min_ts``, ``max_ts``,
        ``step`` and ``output`` are read here, ``ts_fmt`` is read by the
        ``Timestamp`` column
    :raises ValueError: from ``datetime.strptime`` when ``args.min_ts`` or
        ``args.max_ts`` does not match ``%Y-%m-%d %H:%M:%S``

    The summary goes to stdout. When ``args.output`` is not ``None`` a CSV
    file is written there: one header line, then one line per frame ordered
    by timestamp.
    """
    # Input timestamps use a fixed format, independent of args.ts_fmt
    fmt = "%Y-%m-%d %H:%M:%S"
    min_ts = datetime.strptime(args.min_ts, fmt) if args.min_ts else None
    max_ts = datetime.strptime(args.max_ts, fmt) if args.max_ts else None
    stats = load_metrics_from_kv_store(storage, min_ts, max_ts, timedelta(seconds=args.step))

    print("Start time: {}".format(stats.min_ts))
    print("End time: {}".format(stats.max_ts))
    print("Duration: {}".format(stats.max_ts - stats.min_ts))
    print("")

    total = stats.total

    print("Number of messages processed in one looper run:")
    print("   Node:   {}".format(total.get(MetricsName.NODE_STACK_MESSAGES_PROCESSED).to_str()))
    print("   Client: {}".format(total.get(MetricsName.CLIENT_STACK_MESSAGES_PROCESSED).to_str()))
    print("")

    print("Seconds passed between looper runs:")
    print("   {}".format(total.get(MetricsName.LOOPER_RUN_TIME_SPENT).to_str()))
    print("")

    print("Number of messages in one transport batch:")
    print("   {}".format(total.get(MetricsName.TRANSPORT_BATCH_SIZE).to_str()))
    print("")

    print("Node message size, bytes:")
    print("   Outgoing: {}".format(total.get(MetricsName.OUTGOING_NODE_MESSAGE_SIZE).to_str()))
    print("   Incoming: {}".format(total.get(MetricsName.INCOMING_NODE_MESSAGE_SIZE).to_str()))
    print("")

    print("Client message size, bytes:")
    print("   Outgoing: {}".format(total.get(MetricsName.OUTGOING_CLIENT_MESSAGE_SIZE).to_str()))
    print("   Incoming: {}".format(total.get(MetricsName.INCOMING_CLIENT_MESSAGE_SIZE).to_str()))
    print("")

    print("Number of requests in one 3PC batch:")
    print("   Created: {}".format(total.get(MetricsName.THREE_PC_BATCH_SIZE).to_str()))
    print("   Ordered: {}".format(total.get(MetricsName.ORDERED_BATCH_SIZE).to_str()))
    print("")

    print("Time spent on write request processing, ms:")
    print("   {}".format(total.get(MetricsName.REQUEST_PROCESSING_TIME).
                         to_str(show_sums=False, stats_mul=1000.0, stats_unit="ms")))
    print("")

    print("Monitor throughput, TPS")
    print("   Master: {}".format(total.get(MetricsName.MONITOR_AVG_THROUGHPUT).to_str(show_sums=False)))
    print("   Backup: {}".format(total.get(MetricsName.BACKUP_MONITOR_AVG_THROUGHPUT).to_str(show_sums=False)))
    print("")

    print("Monitor latency, seconds")
    print("   Master: {}".format(total.get(MetricsName.MONITOR_AVG_LATENCY).to_str(show_sums=False)))
    print("   Backup: {}".format(total.get(MetricsName.BACKUP_MONITOR_AVG_LATENCY).to_str(show_sums=False)))
    print("")

    # Sizes are scaled by 1 / 1024 / 1024 for display
    print("RAM info:")
    print("   Available, Mb: {}".format(
        total.get(MetricsName.AVAILABLE_RAM_SIZE).to_str(show_sums=False, stats_mul=1.0 / 1024 / 1024)))
    print("   Node RSS, Mb: {}".format(
        total.get(MetricsName.NODE_RSS_SIZE).to_str(show_sums=False, stats_mul=1.0 / 1024 / 1024)))
    print("   Node VMS, Mb: {}".format(
        total.get(MetricsName.NODE_VMS_SIZE).to_str(show_sums=False, stats_mul=1.0 / 1024 / 1024)))
    print("")

    print("Additional statistics:")
    three_pc = total.get(MetricsName.ORDERED_BATCH_SIZE)
    node_in = total.get(MetricsName.INCOMING_NODE_MESSAGE_SIZE)
    node_out = total.get(MetricsName.OUTGOING_NODE_MESSAGE_SIZE)
    client_in = total.get(MetricsName.INCOMING_CLIENT_MESSAGE_SIZE)
    client_out = total.get(MetricsName.OUTGOING_CLIENT_MESSAGE_SIZE)
    node_traffic = node_in.sum + node_out.sum
    client_traffic = client_in.sum + client_out.sum
    # Every ratio goes through _safe_ratio: a time span without client or
    # node traffic must print 0.00 instead of raising ZeroDivisionError.
    print("   Client incoming/outgoing: {:.2f} messages, {:.2f} traffic"
          .format(_safe_ratio(client_in.count, client_out.count),
                  _safe_ratio(client_in.sum, client_out.sum)))
    print("   Node incoming/outgoing traffic: {:.2f}".format(_safe_ratio(node_in.sum, node_out.sum)))
    print("   Node/client traffic: {:.2f}".format(_safe_ratio(node_traffic, client_traffic)))
    if three_pc.count > 0:
        print("   Node traffic per batch: {:.2f}".format(node_traffic / three_pc.count))
        # Batches may have been counted while their summed size is still 0
        print("   Node traffic per request: {:.2f}".format(_safe_ratio(node_traffic, three_pc.sum)))
    print("")

    print("Profiling info:")
    for m in MetricsName:
        # Stops at the first metric above 30000
        # NOTE(review): relies on MetricsName iterating in ascending order -- confirm
        if m > 30000:
            break
        # Below NODE_PROD_TIME only the explicitly listed timing metrics are shown
        if m < MetricsName.NODE_PROD_TIME and \
                m not in {MetricsName.REQUEST_PROCESSING_TIME,
                          MetricsName.BACKUP_REQUEST_PROCESSING_TIME,
                          MetricsName.GC_GEN0_TIME,
                          MetricsName.GC_GEN1_TIME,
                          MetricsName.GC_GEN2_TIME}:
            continue
        acc = total.get(m)
        print("   {} : {}".format(str(m).split('.')[-1],
                                  acc.to_str(sums_unit='seconds', stats_unit='ms', stats_mul=1000)))

    if args.output is not None:
        columns = [Timestamp()]
        for m in MetricsName:
            # Metrics below 30000 get the full set of columns, every metric
            # gets a 'max' column.
            # NOTE(review): confirm that 'max' is meant to be emitted for
            # metrics >= 30000 as well and is not an indentation slip.
            if m < 30000:
                columns.append(MetricsValue(m, 'sum'))
                columns.append(MetricsValue(m, 'count'))
                columns.append(MetricsValue(m, 'min'))
                columns.append(MetricsValue(m, 'lo'))
                columns.append(MetricsValue(m, 'avg'))
                columns.append(MetricsValue(m, 'hi'))
            columns.append(MetricsValue(m, 'max'))

        with open(args.output, 'w') as f:
            # Header line first, then one line per frame in timestamp order
            f.write(",".join(c.title for c in columns))
            f.write("\n")
            context = Column.Context(args=args, stats=stats)
            for ts, frame in sorted(stats.frames(), key=lambda v: v[0]):
                data = Column.Data(ts=ts, frame=frame)
                f.write(",".join(str(c.value(context, data)) for c in columns))
                f.write("\n")


def detect_storage_type(path):
    """Guess the key-value storage type from the files found at ``path``.

    :param path: path of the storage, without file extension
    :return: ``KeyValueStorageType.BinaryFile`` when ``<path>.bin`` is a file,
        ``KeyValueStorageType.Leveldb`` when the directory ``path`` holds an
        entry ending in ``.ldb``, ``KeyValueStorageType.Rocksdb`` otherwise
    """
    binary_file = '{}.bin'.format(path)
    if os.path.isfile(binary_file):
        return KeyValueStorageType.BinaryFile

    # Scan the directory, the first '.ldb' entry settles the question
    for entry in os.listdir(path):
        if entry.endswith('.ldb'):
            return KeyValueStorageType.Leveldb

    return KeyValueStorageType.Rocksdb


if __name__ == '__main__':
    args = read_args()

    if args.data_dir is not None:
        # Explicit path given: the storage is opened in place, its type is
        # guessed from the files on disk.
        # A trailing separator (common with shell tab-completion) is dropped
        # first; otherwise os.path.split() would yield an empty storage name
        # and the '.bin' probe would look inside the directory.
        metrics_path = args.data_dir.rstrip(os.sep) or args.data_dir
        storage_type = detect_storage_type(metrics_path)
        location, name = os.path.split(metrics_path)
        storage = initKeyValueStorage(storage_type, location, name)
        process_storage(storage, args)
    else:
        # No explicit path: locate the node's data directory via the config
        data_dir = get_data_dir(args.node_name, args.network)
        data_is_copied = False
        try:
            config = getConfig()
            # Any storage type other than RocksDB is read from a temporary
            # copy of the data directory
            if config.METRICS_KV_STORAGE != KeyValueStorageType.Rocksdb:
                data_dir = make_copy_of_data(data_dir)
                data_is_copied = True

            storage = initKeyValueStorage(
                config.METRICS_KV_STORAGE,
                data_dir,
                config.METRICS_KV_DB_NAME)

            process_storage(storage, args)
        finally:
            # Remove the temporary copy even when processing failed
            if data_is_copied:
                shutil.rmtree(data_dir)
