from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import logging
import os
import time
import tensorflow as tf

import ray

logger = logging.getLogger(__name__)


def fetch(oids):
    """Ask the local Ray worker to bring each object in ``oids`` locally.

    With the raylet backend, each id is wrapped in ``ray.ObjectID`` and passed
    to the local scheduler client's ``reconstruct_objects`` (with ``True`` as
    the second argument). Otherwise each id is wrapped in a plasma
    ``ObjectID`` and passed to the plasma client's ``fetch``.

    Args:
        oids: Iterable of raw object ids accepted by the ``ObjectID``
            constructors (presumably binary ids -- confirm against callers).
    """
    if not ray.global_state.use_raylet:
        # Legacy path: request each object from the plasma store one by one.
        for raw_id in oids:
            plasma_object_id = ray.pyarrow.plasma.ObjectID(raw_id)
            ray.worker.global_worker.plasma_client.fetch([plasma_object_id])
        return

    scheduler = ray.worker.global_worker.local_scheduler_client
    for raw_id in oids:
        # NOTE(review): meaning of the ``True`` flag is defined by Ray's
        # local scheduler client; it is passed through unchanged here.
        scheduler.reconstruct_objects([ray.ObjectID(raw_id)], True)


def run_timeline(sess, ops, feed_dict=None, write_timeline=False, name=""):
    """Run ``ops`` in ``sess``, optionally dumping a TensorFlow timeline.

    Args:
        sess: TensorFlow session used to execute ``ops``.
        ops: Fetches passed straight through to ``sess.run``.
        feed_dict: Optional feed dictionary; a falsy value is treated as ``{}``.
        write_timeline: If True, run with full tracing enabled and write the
            step stats as a Chrome trace to ``timeline-<name>-<pid>.json`` in
            the current working directory.
        name: Label embedded in the trace filename.

    Returns:
        Whatever ``sess.run`` returns for ``ops``.
    """
    feed_dict = feed_dict or {}
    if not write_timeline:
        return sess.run(ops, feed_dict=feed_dict)

    # TensorFlow's own Timeline (not the Timeline class defined in this
    # module, whose constructor only takes ``tid``) converts RunMetadata step
    # stats into Chrome trace JSON.
    from tensorflow.python.client import timeline as tf_timeline

    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()
    fetches = sess.run(
        ops,
        options=run_options,
        run_metadata=run_metadata,
        feed_dict=feed_dict)
    trace = tf_timeline.Timeline(step_stats=run_metadata.step_stats)
    outf = "timeline-{}-{}.json".format(name, os.getpid())
    # The context manager guarantees the trace is flushed and the handle closed.
    with open(outf, "w") as trace_file:
        trace_file.write(trace.generate_chrome_trace_format())
    # Log only after the file has actually been written; use a %s placeholder
    # so the logging module can format the extra argument.
    logger.info("wrote tf timeline to %s", os.path.abspath(outf))
    return fetches


class Timeline(object):
    """Records begin/end span events and dumps them as Chrome trace JSON.

    Each event is stored as a ``(tid, phase, name, timestamp)`` tuple, where
    ``phase`` is ``"B"`` (begin) or ``"E"`` (end) and ``timestamp`` is
    ``time.time() + self.offset`` in seconds.
    """

    def __init__(self, tid):
        """Create an empty timeline whose events are tagged with ``tid``."""
        self.events = []
        # Added to every timestamp produced by ``time()``. It must be set before
        # ``start_time`` is computed below.
        self.offset = 0
        self.start_time = self.time()
        self.tid = tid

    def patch_ray(self):
        """Wrap ``ray.worker.log`` so span log calls are also recorded here.

        The original logger is always called first. Afterwards,
        ``LOG_SPAN_START`` maps to ``start``, ``LOG_SPAN_END`` maps to ``end``
        and ``LOG_SPAN_POINT`` maps to ``event``.
        """
        orig_log = ray.worker.log

        def custom_log(event_type, kind, *args, **kwargs):
            orig_log(event_type, kind, *args, **kwargs)
            if kind == ray.worker.LOG_SPAN_START:
                self.start(event_type)
            elif kind == ray.worker.LOG_SPAN_END:
                self.end(event_type)
            elif kind == ray.worker.LOG_SPAN_POINT:
                self.event(event_type)

        ray.worker.log = custom_log

    def time(self):
        """Return the current wall-clock time in seconds, shifted by ``offset``."""
        return time.time() + self.offset

    def reset(self):
        """Drop all recorded events and restart the clock (``offset`` is kept)."""
        self.events = []
        self.start_time = self.time()

    def start(self, name):
        """Record the beginning of span ``name``."""
        self.events.append((self.tid, "B", name, self.time()))

    def end(self, name):
        """Record the end of span ``name``."""
        self.events.append((self.tid, "E", name, self.time()))

    def event(self, name):
        """Record an instantaneous event as a 0.1 ms begin/end pair."""
        now = self.time()
        self.events.append((self.tid, "B", name, now))
        self.events.append((self.tid, "E", name, now + .0001))

    def merge(self, other):
        """Fold ``other``'s events into this timeline.

        The earlier of the two start times is kept. All events are re-sorted
        by timestamp. The sort is stable, so on ties this timeline's events
        stay ahead of ``other``'s.
        """
        if other.start_time < self.start_time:
            self.start_time = other.start_time
        self.events.extend(other.events)
        self.events.sort(key=lambda e: e[3])

    def chrome_trace_format(self, filename):
        """Write all events to ``filename`` as a Chrome trace-event JSON list.

        Timestamps are emitted as integer microseconds relative to
        ``start_time``.
        """
        out = []
        for tid, ph, name, t in self.events:
            ts = int((t - self.start_time) * 1000000)
            out.append({
                "name": name,
                "tid": tid,
                # pid mirrors tid, presumably so each merged timeline shows up
                # as its own process row in the trace viewer.
                "pid": tid,
                "ph": ph,
                "ts": ts,
            })
        with open(filename, "w") as f:
            f.write(json.dumps(out))
        # Use a %s placeholder so logging can format the filename argument.
        logger.info("Wrote chrome timeline to %s", filename)


if __name__ == "__main__":
    # Demo: record overlapping spans on two timelines, merge them, and
    # write the merged trace to test.json.
    pause = .1
    main_tl = Timeline(1)
    other_tl = Timeline(2)

    main_tl.start("hi")
    time.sleep(pause)
    other_tl.start("bye")
    main_tl.start("hi3")
    time.sleep(pause)
    main_tl.end("hi3")
    other_tl.end("bye")
    time.sleep(pause)
    main_tl.end("hi")
    other_tl.start("b1")
    other_tl.end("b1")

    main_tl.merge(other_tl)
    main_tl.chrome_trace_format("test.json")
