from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
import redis
import subprocess
import sys
import tempfile
import time

import ray

# Name of the Redis list that _broadcast_event pushes JSON payloads onto and
# that _wait_for_event reads back in full on every poll.
EVENT_KEY = "RAY_MULTI_NODE_TEST_KEY"
"""This key is used internally within this file for coordinating drivers."""


def _wait_for_nodes_to_join(num_nodes, timeout=20):
    """Wait until the expected number of nodes have joined the cluster.

    This polls the client table every 0.1 seconds and returns as soon as it
    contains exactly num_nodes entries. No per-node component check is
    performed; only the number of entries is compared.

    Args:
        num_nodes: The number of nodes to wait for.
        timeout: The amount of time in seconds to wait before failing.

    Raises:
        Exception: An exception is raised if too many nodes join the cluster or
            if the timeout expires while we are waiting.
    """
    # Start from zero so that the timeout message below is well defined even
    # if the loop body never runs (e.g. when timeout <= 0).
    num_ready_nodes = 0
    start_time = time.time()
    while time.time() - start_time < timeout:
        # In raylet mode, the client table is a list of maps. The GCS info
        # for a node appears as a whole instead of part by part, so counting
        # the entries is enough.
        num_ready_nodes = len(ray.global_state.client_table())
        if num_ready_nodes == num_nodes:
            return
        if num_ready_nodes > num_nodes:
            # Too many nodes have joined. Something must be wrong.
            raise Exception("{} nodes have joined the cluster, but we were "
                            "expecting {} nodes.".format(
                                num_ready_nodes, num_nodes))
        time.sleep(0.1)

    # If we get here then we timed out. The expected count goes first and the
    # observed count second, matching the order of the placeholders.
    raise Exception("Timed out while waiting for {} nodes to join. Only {} "
                    "nodes have joined so far.".format(num_nodes,
                                                       num_ready_nodes))


def _broadcast_event(event_name, redis_address, data=None):
    """Publish an event so that other drivers can observe it.

    This is used to synchronize drivers for the multi-node tests. The event
    is appended to the Redis list stored under EVENT_KEY as a JSON-encoded
    (event_name, data) pair.

    Args:
        event_name: The name of the event to broadcast.
        redis_address: The "host:port" address of the Redis server to use
            for synchronization.
        data: Extra data to include in the broadcast (this will be returned by
            the corresponding _wait_for_event call). This data must be json
            serializable.
    """
    host, port = redis_address.split(":")
    client = redis.StrictRedis(host=host, port=int(port))
    client.rpush(EVENT_KEY, json.dumps((event_name, data)))


def _wait_for_event(event_name, redis_address, extra_buffer=0):
    """Poll Redis until the named event shows up, then return its data.

    This is used to synchronize drivers for the multi-node tests. The whole
    event list stored under EVENT_KEY is re-read every 0.1 seconds; there is
    no timeout, so this blocks until the event appears.

    Args:
        event_name: The name of the event to wait for.
        redis_address: The "host:port" address of the Redis server to use
            for synchronization.
        extra_buffer: An amount of time in seconds to wait after the event.

    Returns:
        The data that was passed into the corresponding _broadcast_event call.

    Raises:
        Exception: If any event name appears more than once in the list.
    """
    host, port = redis_address.split(":")
    client = redis.StrictRedis(host=host, port=int(port))
    while True:
        # Rebuild the name -> data mapping from scratch on every poll.
        seen = {}
        for raw_entry in client.lrange(EVENT_KEY, 0, -1):
            name, data = json.loads(raw_entry)
            if name in seen:
                raise Exception("The same event {} was broadcast twice."
                                .format(name))
            seen[name] = data

        if event_name not in seen:
            # Not broadcast yet; back off briefly and poll again.
            time.sleep(0.1)
            continue

        # Potentially sleep a little longer and then hand back the payload.
        time.sleep(extra_buffer)
        return seen[event_name]


def _pid_alive(pid):
    """Check if the process with this PID is alive or not.

    Args:
        pid: The pid to check.

    Returns:
        This returns false if the process is dead. Otherwise, it returns true.
        A process that exists but that we are not permitted to signal is
        reported as alive.
    """
    # Imported locally so this function does not depend on a change to the
    # module-level imports.
    import errno

    try:
        # Signal 0 delivers nothing; it only performs the existence and
        # permission checks.
        os.kill(pid, 0)
    except OSError as e:
        # EPERM means the process exists but belongs to someone else, so it
        # is still alive. Any other error (e.g. ESRCH) is treated as dead.
        return e.errno == errno.EPERM
    return True


def wait_for_pid_to_exit(pid, timeout=20):
    """Block until the process with the given PID is no longer alive.

    The process is checked with _pid_alive roughly every 0.1 seconds.

    Args:
        pid: The pid of the process to wait for.
        timeout: The amount of time in seconds to wait before failing.

    Raises:
        Exception: If the process is still alive when the timeout expires.
    """
    began = time.time()
    while time.time() - began < timeout:
        if _pid_alive(pid):
            # Still running; pause briefly before checking again.
            time.sleep(0.1)
            continue
        return
    raise Exception("Timed out while waiting for process to exit.")


def run_and_get_output(command):
    """Run a command to completion and return everything it printed.

    Both stdout and stderr of the child process are redirected into the same
    temporary file, which is read back once the process has exited.

    Args:
        command: The command to run, in a form accepted by subprocess.Popen
            (e.g. a list of arguments).

    Returns:
        The combined stdout and stderr of the command as a single string.

    Raises:
        RuntimeError: If the command exits with a nonzero return code.
    """
    with tempfile.NamedTemporaryFile() as tmp:
        p = subprocess.Popen(command, stdout=tmp, stderr=tmp)
        if p.wait() != 0:
            raise RuntimeError("ray start did not terminate properly")
        with open(tmp.name, 'r') as f:
            # Return the text verbatim. readlines() keeps each line's
            # terminator, so joining the lines with "\n" would double every
            # newline in the output.
            return f.read()


def run_string_as_driver(driver_script):
    """Execute a driver script in a child Python process and wait for it.

    Args:
        driver_script: A string to run as a Python script. It is encoded as
            ASCII before being written to disk.

    Returns:
        The script's output.
    """
    # subprocess needs a real file to execute, so write the script to a
    # temporary file that lives for the duration of the run.
    with tempfile.NamedTemporaryFile() as script_file:
        script_file.write(driver_script.encode("ascii"))
        script_file.flush()
        command = [sys.executable, script_file.name]
        raw_output = subprocess.check_output(command)
        return ray.utils.decode(raw_output)


def run_string_as_driver_nonblocking(driver_script):
    """Launch a driver script in a child Python process without waiting.

    Args:
        driver_script: A string to run as a Python script. It is encoded as
            ASCII before being written to disk.

    Returns:
        A handle to the driver process. Its stdout is captured through a
        pipe.
    """
    # The script is written to a file that is deliberately left on disk
    # (delete=False): if it were removed on close, it could disappear before
    # the child Python process gets around to reading it.
    with tempfile.NamedTemporaryFile(delete=False) as script_file:
        script_file.write(driver_script.encode("ascii"))
        script_file.flush()
        command = [sys.executable, script_file.name]
        driver_process = subprocess.Popen(command, stdout=subprocess.PIPE)
    return driver_process
