import pytest

import ray
from ray.exceptions import RayError
from ray._private.test_utils import wait_for_condition
from ray import serve
from ray.serve.constants import REPLICA_HEALTH_CHECK_UNHEALTHY_THRESHOLD


class Counter:
    """Plain integer counter.

    The tests in this module wrap it with ``ray.remote`` so that replicas and
    the test driver can share a single count.
    """

    def __init__(self):
        self.reset()

    def get(self):
        """Return the current count."""
        return self._count

    def inc(self):
        """Add one to the count and return the updated value."""
        self._count = self._count + 1
        return self.get()

    def reset(self):
        """Set the count back to zero."""
        self._count = 0


@serve.deployment(_health_check_period_s=1, _health_check_timeout_s=1)
class Patient:
    """Deployment whose health check can be switched to fail or to hang.

    Each request method returns the serving replica's own actor handle, so
    tests can tell which replica handled a call.
    """

    def __init__(self):
        self.healthy = True
        self.should_hang = False

    def check_health(self):
        """Raise when marked unhealthy; block indefinitely when marked hanging."""
        if not self.should_hang:
            if not self.healthy:
                raise Exception("intended to fail")
            return

        import time

        # Sleep far past the configured 1s _health_check_timeout_s so the
        # check never completes in time.
        time.sleep(10000)

    def __call__(self, *args):
        return ray.get_runtime_context().current_actor

    def set_should_fail(self):
        """Make subsequent health checks raise; return this replica's handle."""
        self.healthy = False
        return ray.get_runtime_context().current_actor

    def set_should_hang(self):
        """Make subsequent health checks hang; return this replica's handle."""
        self.should_hang = True
        return ray.get_runtime_context().current_actor


def check_new_actor_started(handle, original_actors):
    """Return True if a call through ``handle`` is served by an unseen actor.

    ``original_actors`` is either a ``set`` of actor IDs, or a single actor
    handle (anything exposing ``_actor_id``), which is converted to a
    one-element set. A ``RayError`` raised while making the call is treated
    as "not yet" and yields False.
    """
    if isinstance(original_actors, set):
        known_ids = original_actors
    else:
        known_ids = {original_actors._actor_id}

    try:
        serving_actor = ray.get(handle.remote())
        return serving_actor._actor_id not in known_ids
    except RayError:
        return False


@pytest.mark.parametrize("use_class", [True, False])
def test_no_user_defined_method(serve_instance, use_class):
    """Check the default behavior when an actor crashes.

    Without a user-defined ``check_health``, a killed replica (class- or
    function-based) should be replaced quickly.
    """
    if not use_class:

        @serve.deployment
        def A(*args):
            return ray.get_runtime_context().current_actor

    else:

        @serve.deployment
        class A:
            def __call__(self, *args):
                return ray.get_runtime_context().current_actor

    A.deploy()
    handle = A.get_handle()
    victim = ray.get(handle.remote())
    ray.kill(victim)

    # This would time out if we waited for multiple health check failures.
    wait_for_condition(
        check_new_actor_started, handle=handle, original_actors=victim
    )


def test_user_defined_method_fails(serve_instance):
    """A replica whose ``check_health`` raises is replaced, and the deployment
    keeps serving requests afterwards."""
    Patient.deploy()
    handle = Patient.get_handle()
    first_actor = ray.get(handle.remote())
    ray.get(handle.set_should_fail.remote())

    wait_for_condition(
        check_new_actor_started, handle=handle, original_actors=first_actor
    )

    # The deployment must still handle a burst of requests.
    pending = [handle.remote() for _ in range(100)]
    ray.get(pending)


def test_user_defined_method_hangs(serve_instance):
    """A replica whose ``check_health`` never returns is replaced, and the
    deployment keeps serving requests afterwards."""
    Patient.deploy()
    handle = Patient.get_handle()
    first_actor = ray.get(handle.remote())
    ray.get(handle.set_should_hang.remote())

    wait_for_condition(
        check_new_actor_started, handle=handle, original_actors=first_actor
    )

    # The deployment must still handle a burst of requests.
    pending = [handle.remote() for _ in range(100)]
    ray.get(pending)


def test_multiple_replicas(serve_instance):
    """With two replicas, only the one failing its health check is replaced."""
    Patient.options(num_replicas=2).deploy()
    handle = Patient.get_handle()

    def serving_actor_ids():
        # Fan out 100 requests so that, very likely, every replica is hit.
        results = ray.get([handle.remote() for _ in range(100)])
        return {actor._actor_id for actor in results}

    before = serving_actor_ids()
    assert len(before) == 2

    # This call reaches exactly one replica, which then starts failing.
    ray.get(handle.set_should_fail.remote())

    wait_for_condition(
        check_new_actor_started, handle=handle, original_actors=before
    )

    after = serving_actor_ids()
    assert len(after) == 2
    assert len(after & before) == 1


def test_inherit_healthcheck(serve_instance):
    """A ``check_health`` inherited from a base class is honored by Serve."""

    class Parent:
        def __init__(self):
            self.should_fail = False

        def check_health(self):
            if self.should_fail:
                raise Exception("intended to fail")

        def set_should_fail(self):
            self.should_fail = True

    @serve.deployment(_health_check_period_s=1)
    class Child(Parent):
        def __call__(self, *args):
            return ray.get_runtime_context().current_actor

    Child.deploy()
    h = Child.get_handle()
    # Submit all requests first and resolve them with one ray.get, as the
    # other tests do, instead of making 100 sequential blocking round trips.
    actors = {a._actor_id for a in ray.get([h.remote() for _ in range(100)])}
    assert len(actors) == 1

    ray.get(h.set_should_fail.remote())
    wait_for_condition(check_new_actor_started, handle=h, original_actors=actors)


def test_nonconsecutive_failures(serve_instance):
    """A health check failing every other call isn't marked unhealthy."""
    counter = ray.remote(Counter).remote()

    @serve.deployment(_health_check_period_s=0.1)
    class FlakyHealthCheck:
        def check_health(self):
            # Odd-numbered checks pass and even-numbered ones fail, so two
            # failures never occur back to back.
            if ray.get(counter.inc.remote()) % 2 == 0:
                raise Exception("Ah! I had evens!")

        def __call__(self, *args):
            return ray.get_runtime_context().current_actor

    FlakyHealthCheck.deploy()
    handle = FlakyHealthCheck.get_handle()
    original = ray.get(handle.remote())

    # Let more than 10 health checks run; the replica must never be replaced.
    wait_for_condition(lambda: ray.get(counter.get.remote()) > 10)
    assert ray.get(handle.remote())._actor_id == original._actor_id


def test_consecutive_failures(serve_instance):
    """The health check must fail N times in a row before the replica is
    marked unhealthy, and a replacement replica starts counting afresh."""
    counter = ray.remote(Counter).remote()

    @serve.deployment(_health_check_period_s=1)
    class ChronicallyUnhealthy:
        def __init__(self):
            self._actor_id = ray.get_runtime_context().current_actor._actor_id
            self._should_fail = False

        def check_health(self):
            if not self._should_fail:
                return
            # Record each failing check so the test can count them.
            ray.get(counter.inc.remote())
            raise Exception("intended to fail")

        def set_should_fail(self):
            self._should_fail = True
            return self._actor_id

        def __call__(self, *args):
            return self._actor_id

    ChronicallyUnhealthy.deploy()
    handle = ChronicallyUnhealthy.get_handle()

    def assert_replaced_after_threshold_failures():
        failing_id = ray.get(handle.set_should_fail.remote())

        # Block until a different replica answers requests.
        wait_for_condition(lambda: ray.get(handle.remote()) != failing_id)

        # Exactly N failed checks must have happened before the replacement.
        failures = ray.get(counter.get.remote())
        assert failures == REPLICA_HEALTH_CHECK_UNHEALTHY_THRESHOLD

    # Run twice, clearing the counter in between. This shows that the new
    # replica again needs exactly N failures before it is replaced.
    assert_replaced_after_threshold_failures()
    ray.get(counter.reset.remote())
    assert_replaced_after_threshold_failures()


if __name__ == "__main__":
    import sys

    exit_code = pytest.main(["-v", "-s", __file__])
    sys.exit(exit_code)
