#!python

"""
Process a single run with straxen using pema
Copy of straxen/bin/straxer with some nice new features
"""

import argparse
import datetime
import logging
import time
import os
import os.path as osp
import platform
import psutil
import sys
import json
import pema
import gc


def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    :param argv: optional list of argument strings to parse. When left at
        ``None`` argparse falls back to ``sys.argv[1:]``, which is what the
        script entry point relies on.
    :return: ``argparse.Namespace`` with one attribute per option below.
        Note that ``run_id`` is always a list, while ``target`` is a list
        when given on the command line but the plain string ``'event_info'``
        when omitted.
    """
    parser = argparse.ArgumentParser(
        description='Process a single run with pema',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        'run_id',
        metavar='RUN_ID',
        nargs='*',
        help="ID of the run to process; usually the run name.")
    parser.add_argument(
        '--context',
        default='xenon1t_dali',
        help="Name of pema context to use")
    parser.add_argument(
        '--target',
        nargs='*',
        default='event_info',
        help='Target final data type to produce')
    parser.add_argument(
        '--init_from_json', default='',
        help='Use a json file to use for the start of the context')
    parser.add_argument(
        '--config_from_json', default='',
        help='Use a json-file to load the config from')
    parser.add_argument(
        '--from_scratch',
        action='store_true',
        # Adjacent literals are concatenated verbatim, so every fragment
        # must carry its own trailing space.
        help='Start processing at raw_records, regardless of what data is available. '
             'Saving will ONLY occur to ./strax_data! If you already have the target '
             'data in ./strax_data, you need to delete it there first.')
    parser.add_argument(
        '--max_messages',
        default=4, type=int,
        help=("Size of strax's internal mailbox buffers. "
              "Lower to reduce memory usage, at increasing risk of deadlocks."))
    parser.add_argument(
        '--timeout',
        default=None, type=int,
        help="Strax' internal mailbox timeout in seconds")
    parser.add_argument(
        '--workers',
        default=1, type=int,
        help=("Number of worker threads/processes. "
              "Strax will multithread (1/plugin) even if you set this to 1."))
    parser.add_argument(
        '--notlazy',
        action='store_true',
        help='Forbid lazy single-core processing. Not recommended.')
    parser.add_argument(
        '--multiprocess',
        action='store_true',
        help="Allow multiprocessing.")
    parser.add_argument(
        '--shm',
        action='store_true',
        help="Allow passing data via /dev/shm when multiprocessing.")
    parser.add_argument(
        '--profile_to',
        default='',
        help="Filename to output profile information to. If omitted, "
             "no profiling will occur.")
    parser.add_argument(
        '--profile_ram',
        action='store_true',
        help="Use memory_profiler for a more accurate measurement of the "
             "peak RAM usage of the process.")
    parser.add_argument(
        '--diagnose_sorting',
        action='store_true',
        help="Diagnose sorting problems during processing")
    parser.add_argument(
        '--debug',
        action='store_true',
        help="Enable debug logging to stdout")
    parser.add_argument(
        '--build_lowlevel',
        action='store_true',
        help='Build low-level data even if the context forbids it.')
    parser.add_argument(
        '--rechunk_rr',
        action='store_true',
        help='Rechunk the raw-records (especially useful when simulating data)')
    return parser.parse_args(argv)


def main(args):
    """Process every requested run/target pair with a pema context.

    Instantiates the context named by ``args.context`` (optionally with
    keyword arguments read from ``args.init_from_json`` and a config read
    from ``args.config_from_json``), copies the processing options from
    ``args`` into the context config and then iterates over all chunks of
    every (run_id, target) pair that is not stored yet. The resident memory
    of this process is sampled along the way and the peak is logged at the
    end.

    :param args: namespace as produced by ``parse_args``.
    :return: None
    """
    logging.basicConfig(
        level=logging.DEBUG if args.debug else logging.INFO,
        format='%(asctime)s - %(threadName)s - %(name)s - %(levelname)s - %(message)s')

    logging.info(f"Starting processing of run {args.run_id} until {args.target}")
    logging.info(f"\tpython {platform.python_version()} at {sys.executable}")

    # These imports take a bit longer, so it's nicer
    # to do them after argparsing (so --help is fast)
    import strax
    import straxen
    straxen.print_versions(['strax', 'straxen', 'pema', 'wfsim'])

    if getattr(args, 'from_scratch', False):
        # The option is advertised by the parser but nothing below acts on
        # it; tell the user instead of silently ignoring the request.
        logging.warning('--from_scratch was requested but is not implemented '
                        'in this script; the option is ignored.')

    if args.init_from_json != '':
        context_init = json_to_dict(args.init_from_json)
        logging.info(f'Overwriting context with {context_init}')
    else:
        context_init = {}

    st = getattr(pema.contexts, args.context)(**context_init)

    # ignore strax warnings
    st.set_context_config({'free_options': tuple(st.config.keys())})

    if args.diagnose_sorting:
        st.set_config(dict(diagnose_sorting=True))

    if args.config_from_json != '':
        conf = json_to_dict(args.config_from_json)
        logging.info(f'Overwriting config with {conf}')
        st.set_config(conf)

    st.context_config['allow_multiprocess'] = args.multiprocess
    st.context_config['allow_shm'] = args.shm
    st.context_config['allow_lazy'] = args.notlazy is not True

    if args.timeout is not None:
        st.context_config['timeout'] = args.timeout

    st.context_config['max_messages'] = args.max_messages

    if args.build_lowlevel:
        st.context_config['forbid_creation_of'] = tuple()
        # Rechunking is only switched on together with --build_lowlevel.
        if args.rechunk_rr:
            st._plugin_class_registry['raw_records'].rechunk_on_save = True
            st._plugin_class_registry['records'].rechunk_on_save = True
            st._plugin_class_registry['peaklets'].rechunk_on_save = True
    else:
        st.context_config['forbid_creation_of'] = straxen.DAQReader.provides

    process = psutil.Process(os.getpid())
    peak_ram = 0

    def get_results(run_id, target):
        """Yield the chunks of ``target`` for ``run_id``, profiled if requested."""
        kwargs = dict(
            run_id=run_id,
            targets=target,
            max_workers=int(args.workers))

        if args.profile_to:
            with strax.profile_threaded(args.profile_to+run_id+target):
                yield from st.get_iter(**kwargs)
        else:
            yield from st.get_iter(**kwargs)

    clock_start = time.time()
    run_ids = strax.to_str_tuple(args.run_id)
    targets = strax.to_str_tuple(args.target)
    for run_i, run_id in enumerate(run_ids):
        for tar_i, target in enumerate(targets):
            if st.is_stored(run_id, target):
                print(f'{run_id}-{target} is stored')
                continue
            for i, d in enumerate(get_results(run_id, target)):
                mem_mb = process.memory_info().rss / 1e6
                peak_ram = max(mem_mb, peak_ram)

                if not len(d):
                    logging.info(f"Got chunk {i}, but it is empty! Using {mem_mb:.1f} MB RAM.")

            # Sample once more after the chunk loop: if the iterator yielded
            # nothing there would otherwise be no (or only a stale) value to
            # report below.
            mem_mb = process.memory_info().rss / 1e6
            peak_ram = max(mem_mb, peak_ram)
            print(f'Using {mem_mb:.1f} MB RAM. R{run_i}-T{tar_i}')

    logging.info(f"\npema_straxer is done! "
                 f"We took {time.time() - clock_start:.1f} seconds, "
                 f"peak RAM usage was around {peak_ram:.1f} MB.")
    logging.warning('processing ended')
    print('Processing ended, bye bye')


def list_to_tuple(items):
    """Return ``items`` as a tuple, converting nested lists recursively.

    Anything that is not a list is handed back untouched. Inside a list,
    nested lists become tuples and nested dicts are passed through
    ``dict_to_tuple``; all other elements are kept as they are.
    """
    if not isinstance(items, list):
        return items

    def _convert(entry):
        # Lists first: once converted the result is a tuple, so the dict
        # branch can never apply to the same element.
        if isinstance(entry, list):
            return list_to_tuple(entry)
        if isinstance(entry, dict):
            return dict_to_tuple(entry)
        return entry

    return tuple(_convert(entry) for entry in items)


def dict_to_tuple(res):
    """Replace list values in ``res`` by tuples, recursing into nested dicts.

    The dictionary is modified in place and the very same object is
    returned. Anything that is not a dict is handed back untouched.
    """
    if not isinstance(res, dict):
        return res
    # Iterate over a snapshot so that writing back into ``res`` is safe.
    for key, value in tuple(res.items()):
        if isinstance(value, list):
            value = list_to_tuple(value)
        elif isinstance(value, dict):
            value = dict_to_tuple(value)
        res[key] = value
    return res


def json_to_dict(path:str):
    """Load the JSON file at ``path`` with lists converted to tuples.

    Every top-level key/value pair is printed as it was read. Top-level
    dict values are passed through ``dict_to_tuple`` and top-level list
    values through ``list_to_tuple``.

    :param path: location of the JSON file.
    :return: the decoded (and converted) top-level object.
    :raises FileNotFoundError: if ``path`` does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f'{path} does not exist')

    with open(path) as json_file:
        loaded = json.load(json_file)

    # Work on a snapshot of the items since values are replaced in place.
    for key, value in tuple(loaded.items()):
        print(key, value)
        if isinstance(value, list):
            loaded[key] = list_to_tuple(value)
        elif isinstance(value, dict):
            loaded[key] = dict_to_tuple(value)

    return loaded


if __name__ == '__main__':
    # Script entry point: parse the command line, then run main() either
    # directly or wrapped in memory_profiler when --profile_ram is given.
    cli_args = parse_args()
    if not cli_args.profile_ram:
        sys.exit(main(cli_args))

    # Only reached when RAM profiling was requested; memory_profiler is
    # imported lazily so it is not needed for a normal run.
    from memory_profiler import memory_usage

    ram_samples = memory_usage(proc=(main, [cli_args], dict()))
    print(f"Memory profiler says peak RAM usage was: {max(ram_samples):.1f} MB")
