#!/usr/bin/env python3

import sys
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.collections import PatchCollection
import argparse


class profiler_data:
    """ reads data in csv format and stores in numpy arrays

        event data is held in parallel arrays with one entry per event
        (rank, thread_id, event_name, start_t, end_t, delta_t, depth).
        memory data, when a memory file is given, is held in parallel
        arrays with one entry per sample (mem_rank, mem_time, mem_use).
    """

    def __init__(self):
        # per event columns, filled in by initialize
        self.rank = None
        self.thread_id = None
        self.event_name = None
        self.start_t = None
        self.end_t = None
        self.delta_t = None
        self.depth = None

        # unique event names and thread ids, filled in by subset
        self.event_name_u = None
        self.thread_id_u = None

        # derived quantities, computed on demand or set by subset
        self.n_threads = None
        self.n_events = None
        self.n_ranks = None
        self.max_runtime = None

        # per sample memory profiler columns, None when no memory file
        # was loaded
        self.mem_rank = None
        self.mem_time = None
        self.mem_use = None

        self.verbose = True

    @staticmethod
    def _read_rows(file_name):
        """ generates the list of comma separated fields for each data
            line of the file. lines starting with '#' and blank lines
            are skipped. the file is closed when the generator finishes """
        with open(file_name, 'r') as f:
            for line in f:
                if line.startswith('#') or not line.strip():
                    continue
                yield line.split(',')

    def subset(self, r):
        """ return a new instance with data for rank r

            the arrays in the returned instance are copies (they are
            made by indexing with an index array). in the returned
            instance rank is the scalar r rather than an array, and
            max_runtime and n_ranks are those of the complete data set """

        subset_data = profiler_data()

        # subset the event profiler data
        ii = np.where(self.rank == r)[0]

        subset_data.rank = r

        subset_data.thread_id = self.thread_id[ii]
        subset_data.thread_id_u = set(subset_data.thread_id)
        subset_data.n_threads = len(subset_data.thread_id_u)

        subset_data.event_name = self.event_name[ii]
        subset_data.event_name_u = set(subset_data.event_name)
        subset_data.n_events = len(subset_data.event_name_u)

        subset_data.start_t = self.start_t[ii]
        subset_data.end_t = self.end_t[ii]
        subset_data.delta_t = self.delta_t[ii]
        subset_data.depth = self.depth[ii]

        subset_data.max_runtime = self.get_max_runtime()
        subset_data.n_ranks = self.get_num_ranks()

        if self.verbose:
            sys.stderr.write('On rank %d : %d total events, ' \
                '%d unique events, %d threads, thread_ids=%s\n'%( \
                r, len(subset_data.event_name), subset_data.n_events,
                subset_data.n_threads, str(subset_data.thread_id_u)))

        # subset the memory profiler data
        if self.mem_rank is None:
            return subset_data

        ii = np.where(self.mem_rank == r)[0]

        subset_data.mem_time = self.mem_time[ii]
        subset_data.mem_use = self.mem_use[ii]

        return subset_data

    def initialize(self, prof_file_name, mem_file_name=None):
        """ loads the event data from prof_file_name and, when given,
            the memory data from mem_file_name.

            each event line has at least 7 comma separated fields: rank,
            thread id (parsed as a base 16 integer), event name, start
            time, end time, duration, and depth. each memory line has at
            least 3: rank, time, and memory use. int/float conversion
            errors on malformed lines propagate as ValueError, and lines
            with too few fields raise IndexError """

        rank = []
        thread_id = []
        event_name = []
        start_t = []
        end_t = []
        delta_t = []
        depth = []
        for fields in self._read_rows(prof_file_name):
            rank.append(int(fields[0]))
            thread_id.append(int(fields[1], 16))
            event_name.append(fields[2])
            start_t.append(float(fields[3]))
            end_t.append(float(fields[4]))
            delta_t.append(float(fields[5]))
            depth.append(int(fields[6]))

        self.rank = np.array(rank)
        self.thread_id = np.array(thread_id)
        self.event_name = np.array(event_name)
        self.start_t = np.array(start_t)
        self.end_t = np.array(end_t)
        self.delta_t = np.array(delta_t)
        self.depth = np.array(depth)

        if mem_file_name is None:
            return

        mem_rank = []
        mem_time = []
        mem_use = []
        for fields in self._read_rows(mem_file_name):
            mem_rank.append(int(fields[0]))
            mem_time.append(float(fields[1]))
            mem_use.append(float(fields[2]))

        self.mem_rank = np.array(mem_rank)
        self.mem_time = np.array(mem_time)
        self.mem_use = np.array(mem_use)

    def get_num_threads(self, r=None):
        """ returns the number of unique thread ids. with r=None all of
            the events held by this instance are counted, otherwise only
            the events of rank r are counted. an instance made by subset
            already holds a single rank and returns its stored count """
        if r is None:
            return len(set(self.thread_id))

        if self.n_threads is not None:
            # set by subset, this instance holds the data of one rank
            return self.n_threads

        # not cached here, the answer depends on r
        return len(set(self.thread_id[self.rank == r]))

    def get_num_ranks(self):
        """ returns the largest rank id plus one. the value is cached """
        if self.n_ranks is None:
            self.n_ranks = np.max(self.rank) + 1
        return self.n_ranks

    def get_max_runtime(self):
        """ returns the duration of the longest event, i.e. the largest
            value in delta_t. the value is cached """
        if self.max_runtime is None:
            self.max_runtime = np.max(self.delta_t)
        return self.max_runtime






class display_rank_events:
    """ draws the events of one MPI rank as rectangles on a time line,
        with one horizontal lane per thread, and prints the details of
        the events under the pointer when the plot is clicked on """

    def __init__(self):
        self.data = None
        self.axes = None
        # depth -> PatchCollection holding the rectangles at that depth
        self.collection = None
        # depth -> list of rectangles at that depth
        self.patches = None
        # depth -> list of indices into the event arrays, parallel to
        # the list of rectangles at that depth
        self.patch_ids = None
        # id of the click handler connection, None when not connected
        self.cidpress = None
        self.verbose = False

    def initialize(self, data, ax, title):
        """ renders the data to axes ax and sets up the click handlers
            to print event data when the plot is clicked on

            data  -- a profiler_data instance holding one rank, as
                     returned by profiler_data.subset
            ax    -- the matplotlib axes to draw on
            title -- a string prepended to the plot title

            note: data is modified in place. start_t, end_t and mem_time
            are shifted so that the first event starts at 0 seconds, and
            mem_use is rescaled to the height of the plot """

        self.data = data
        self.axes = ax

        n_threads = data.n_threads

        # sort thread ids by number of events
        # they are plotted from lowest to highest number of events
        counts = [(tid, np.count_nonzero(data.thread_id == tid))
                  for tid in data.thread_id_u]
        thread_ids_us = [tid for tid, _ in
                         sorted(counts, key=lambda x : x[1])]

        # set plot spacing and rectangle height. the max guards
        # against a division by zero when the rank has a single thread
        dy = 1./max(n_threads - 1, 1)
        thread_map = {tid: i for i, tid in enumerate(thread_ids_us)}
        yc = [i*dy + dy/2. for i in range(len(thread_ids_us))]
        yl = list(range(len(thread_ids_us)))

        # get unique events, ids are assigned in sorted name order
        event_map = {enm: i for i, enm in
                     enumerate(sorted(data.event_name_u))}
        n_events = len(event_map)

        # shift time scale to start at 0 seconds
        t0 = np.min(data.start_t)
        data.start_t -= t0
        data.end_t -= t0

        # make color maps. events are spread over the maps, cmap_n
        # events per map. each map is sampled with 1.5 times as many
        # colors as are used from it
        cmap_names = ['Purples', 'Blues', 'Greens', \
             'Oranges', 'Reds', 'Greys']

        cmap_n = int(n_events // len(cmap_names) + 1)

        cmap = [plt.get_cmap(cmap_name, int(1.5*cmap_n))
                for cmap_name in cmap_names]

        # in a patch collection all rectangles have
        # the same zorder, hence we need a collection
        # per zorder
        max_depth = int(np.max(data.depth))
        self.patches = {d: [] for d in range(max_depth + 1)}
        self.patch_ids = {d: [] for d in range(max_depth + 1)}

        # create a rectangle for each event. rectangles are
        # drawn at height corresponding to the thread id
        # start and width are given by the event time
        if not self.verbose:
            sys.stderr.write('%d : '%(data.rank))

        ne = len(data.event_name)

        # number of events per progress bar tick, at most 78 or so
        # ticks are written
        tick = max(ne // 78, 1)

        for i in range(ne):
            tid = thread_map[data.thread_id[i]]
            eid = event_map[data.event_name[i]]

            enm = data.event_name[i]
            et0 = data.start_t[i]
            edt =  data.delta_t[i]

            if self.verbose:
                sys.stderr.write('%srank=%d event_index=%d thread=%d ' \
                    'event_id=%d event_name=%s rect = %s\n' % (
                    title, data.rank, i, tid, eid, enm, str([et0,
                    tid*dy, edt, dy])))
            elif i % tick == 0:
                sys.stderr.write('=')
                sys.stderr.flush()

            cid = int(eid % cmap_n)
            cm = cmap[eid // cmap_n]
            fc = cm(cid)

            # thread pool events get a heavy blue outline
            if ' thread_pool ' in enm:
                ec = 'b'
                ew = 3
            else:
                ec = 'k'
                ew = 1

            # deeper events are inset further from the lane edges so
            # that the enclosing events stay visible. the inset is a
            # fraction of the lane height and is at most 0.375*dy. with
            # a single depth level there is nothing to inset
            dep = int(data.depth[i])
            if max_depth > 0:
                df = dy*dep*(0.5 - 0.125)/float(max_depth)
            else:
                df = 0.

            rect = Rectangle((et0, tid*dy+df), edt, dy-2.*df,
                             alpha=1, facecolor=fc, edgecolor=ec,
                             linewidth=ew, zorder=dep+1, fill=True,
                             label=enm)

            self.patches[dep].append(rect)
            self.patch_ids[dep].append(i)

        if not self.verbose:
            sys.stderr.write('> OK!\n')

        self.collection = {}
        for d,p in self.patches.items():
            pc = PatchCollection(p, match_original=True, zorder=d+1)
            ax.add_collection(pc)
            self.collection[d] = pc

        # add memory profiler data. skipped when there are no samples
        # for this rank
        mem_use = self.data.mem_use
        if mem_use is not None and len(mem_use) > 0:

            # put data into the same coordinate system as events
            self.data.mem_time -= t0

            mmx = np.max(mem_use)
            self.data.mem_use = mem_use/mmx * n_threads*dy

            ax.plot(self.data.mem_time, self.data.mem_use, 'r--',
                    linewidth=3)

            ax.set_title('%sMPI rank %d event times, max RSS %0.1f MiB'%(
                         title, data.rank, mmx/(2.**10)),
                         fontweight='bold')
        else:
            ax.set_title('%sMPI rank %d event times'%(title, data.rank),
                         fontweight='bold')

        # pad the axis limits by a fraction of the run time in x and of
        # the lane height in y
        pfx = 0.01
        pfy = 0.10
        t1 = data.max_runtime
        ax.set_xlim(0-pfx*t1, t1*(1+pfx))
        ax.set_ylim(0-pfy*dy, (n_threads+pfy)*dy)
        ax.set_yticks(yc)
        ax.set_yticklabels(yl)
        ax.set_ylabel('thread', fontweight='bold')
        ax.set_xlabel('seconds', fontweight='bold')
        ax.set_axisbelow(True)
        ax.grid(True)

    def connect(self):
        """ registers on_press as the mouse button press handler of the
            figure's canvas. the handler is registered once no matter
            how many depth collections there are, so that a click is
            reported once. does nothing when already connected """
        if self.cidpress is not None:
            return

        self.cidpress = self.axes.figure.canvas.mpl_connect(
            'button_press_event', self.on_press)

    def on_press(self, event):
        """ writes a line to stderr for each event rectangle that
            contains the clicked point. clicks outside of this
            instance's axes are ignored """
        if event.inaxes != self.axes:
            return

        hit_any = False

        for d,c in self.collection.items():
            hit, ids = c.contains(event)

            if not hit:
                continue

            hit_any = True

            for i in ids['ind']:
                # map the index in the collection to the event index
                ii = self.patch_ids[d][i]
                sys.stderr.write(
                    'rank=%d i=%d ii=%d event_name=%s thread=%d start=%g' \
                    ' end=%g duration=%g depth=%d\n'%(
                    self.data.rank, i, ii, self.data.event_name[ii],
                    self.data.thread_id[ii], self.data.start_t[ii],
                    self.data.end_t[ii], self.data.delta_t[ii],
                    self.data.depth[ii]))

        if hit_any:
            sys.stderr.write('\n')

    def disconnect(self):
        """ removes the click handler registered by connect. does
            nothing when not connected """
        if self.cidpress is None:
            return

        self.axes.figure.canvas.mpl_disconnect(self.cidpress)
        self.cidpress = None



# command line interface
parser = argparse.ArgumentParser(prog='teca_profile_explorer')

parser.add_argument('-e', '--event_file', required=True, type=str,
                    help='path to a TECA profiler event data file')

parser.add_argument('-m', '--mem_file', required=False, type=str, default=None,
                    help='path to a TECA profiler memory data file')

parser.add_argument('-r', '--ranks', nargs='+', required=True, type=int,
                    help='one or more MPI ranks to plot')

parser.add_argument('-o', '--out_prefix', required=False, type=str, default='',
                    help='a string that is prepended to output file names')

parser.add_argument('-t', '--title', required=False, type=str, default='',
                    help='a string that is prepended to plot titles')

parser.add_argument('-x', '--xlim', required=False, type=float, default=-1.,
                    help='set the high x axis limit used in plots')

parser.add_argument('-v', '--verbose', required=False, type=int,
                    default=0, help='set the verbosity level')

args = parser.parse_args()


# load the data. verbosity level 1 reports per rank summaries, level 2
# additionally reports each rectangle as it is created
data = profiler_data()
data.verbose = args.verbose > 0
data.initialize(args.event_file, args.mem_file)

# the display objects are kept in a list for the life of the script. the
# canvas holds only weak references to bound method callbacks, without
# this list the click handlers could be garbage collected
dres = []
for rank in args.ranks:

    # get this rank's data
    subset = data.subset(rank)

    # there is nothing to plot for a rank that recorded no events
    if len(subset.event_name) == 0:
        sys.stderr.write(
            'WARNING: no events found for rank %d, skipping\n'%(rank))
        continue

    # size the plot, the figure grows with the number of threads
    n_threads = subset.n_threads
    fig_width = 12
    fig_height = 0.625*(n_threads + 2)
    fig = plt.figure(figsize=(fig_width, fig_height))

    # draw the plot & connect the event handler
    ax = plt.gca()
    dre = display_rank_events()
    dre.verbose = args.verbose > 1
    dre.initialize(subset, ax, args.title)
    dre.connect()
    dres.append(dre)

    plt.subplots_adjust(bottom=0.2)

    # override the high x axis limit, the low limit is left as is
    if args.xlim > 0.:
        x0, _ = plt.xlim()
        plt.xlim(x0, args.xlim)

    # one image per rank
    plt.savefig('%srank_profile_data_%d.png'%(
                args.out_prefix,rank),
                dpi=200)

# show all of the figures for interactive use
plt.show()

