import copy
from unittest.mock import Mock, patch

import pytest
from click.exceptions import ClickException

import ray.tests.aws.utils.helpers as helpers
import ray.tests.aws.utils.stubs as stubs
from ray.autoscaler._private.aws.config import (
    DEFAULT_AMI,
    _configure_subnet,
    _get_subnets_or_die,
    bootstrap_aws,
    log_to_cli,
)
from ray.autoscaler._private.aws.node_provider import AWSNodeProvider
from ray.autoscaler._private.providers import _get_node_provider
from ray.tests.aws.utils.constants import (
    AUX_SG,
    AUX_SUBNET,
    CUSTOM_IN_BOUND_RULES,
    DEFAULT_CLUSTER_NAME,
    DEFAULT_INSTANCE_PROFILE,
    DEFAULT_KEY_PAIR,
    DEFAULT_LT,
    DEFAULT_SG,
    DEFAULT_SG_AUX_SUBNET,
    DEFAULT_SG_DUAL_GROUP_RULES,
    DEFAULT_SG_WITH_NAME,
    DEFAULT_SG_WITH_NAME_AND_RULES,
    DEFAULT_SG_WITH_RULES,
    DEFAULT_SG_WITH_RULES_AUX_SUBNET,
    DEFAULT_SUBNET,
)


def test_use_subnets_in_only_one_vpc(iam_client_stub, ec2_client_stub):
    """
    This test validates that when bootstrap_aws populates the SubnetIds field,
    all of the subnets used belong to the same VPC, and that a SecurityGroup
    in that VPC is correctly configured.

    Also validates that head IAM role is correctly filled.
    """
    stubs.configure_iam_role_default(iam_client_stub)
    stubs.configure_key_pair_default(ec2_client_stub)

    # Add a response with a thousand subnets all in different VPCs.
    # After filtering, only subnet in one particular VPC should remain.
    # Thus SubnetIds for each available node type should end up as
    # being length-one lists after the bootstrap_config.
    stubs.describe_a_thousand_subnets_in_different_vpcs(ec2_client_stub)

    # describe the subnet in use while determining its vpc
    stubs.describe_subnets_echo(ec2_client_stub, [DEFAULT_SUBNET])
    # given no existing security groups within the VPC...
    stubs.describe_no_security_groups(ec2_client_stub)
    # expect to create a security group on the VPC
    stubs.create_sg_echo(ec2_client_stub, DEFAULT_SG)
    # expect new security group details to be retrieved after creation
    stubs.describe_sgs_on_vpc(
        ec2_client_stub,
        [DEFAULT_SUBNET["VpcId"]],
        [DEFAULT_SG],
    )

    # given no existing default security group inbound rules...
    # expect to authorize all default inbound rules
    stubs.authorize_sg_ingress(
        ec2_client_stub,
        DEFAULT_SG_WITH_RULES,
    )

    # expect another call to describe the above security group while checking
    # a second time if it has ip_permissions set ("if not sg.ip_permissions")
    stubs.describe_an_sg_2(
        ec2_client_stub,
        DEFAULT_SG_WITH_RULES,
    )

    # given our mocks and an example config file as input...
    # expect the config to be loaded, validated, and bootstrapped successfully
    config = helpers.bootstrap_aws_example_config_file("example-full.yaml")
    _get_subnets_or_die.cache_clear()

    # We've filtered down to only one subnet id -- only one of the thousand
    # subnets generated by ec2.subnets.all() belongs to the right VPC.
    for node_type in config["available_node_types"].values():
        node_config = node_type["node_config"]
        assert node_config["SubnetIds"] == [DEFAULT_SUBNET["SubnetId"]]
        assert node_config["SecurityGroupIds"] == [DEFAULT_SG["GroupId"]]


@pytest.mark.parametrize(
    "correct_az",
    [True, False],
)
def test_create_sg_different_vpc_same_rules(
    iam_client_stub, ec2_client_stub, correct_az: bool
):
    """Head and worker subnets in different VPCs get separate security groups.

    When ``correct_az`` is False the head subnet is moved to another
    availability zone (us-west-2b) and bootstrapping is expected to abort
    with a ClickException instead of producing a config.
    """
    # use default stubs to skip ahead to security group configuration
    stubs.skip_to_configure_sg(ec2_client_stub, iam_client_stub)

    default_subnet = copy.deepcopy(DEFAULT_SUBNET)
    if not correct_az:
        default_subnet["AvailabilityZone"] = "us-west-2b"

    # given head and worker nodes with custom subnets defined...
    # expect to second describe the head subnet ID
    stubs.describe_subnets_echo(ec2_client_stub, [default_subnet])
    # expect to first describe the worker subnet ID
    stubs.describe_subnets_echo(ec2_client_stub, [AUX_SUBNET])
    # given no existing security groups within the VPC...
    stubs.describe_no_security_groups(ec2_client_stub)
    # expect to first create a security group on the worker node VPC
    stubs.create_sg_echo(ec2_client_stub, DEFAULT_SG_AUX_SUBNET)
    # expect new worker security group details to be retrieved after creation
    stubs.describe_sgs_on_vpc(
        ec2_client_stub,
        [AUX_SUBNET["VpcId"]],
        [DEFAULT_SG_AUX_SUBNET],
    )
    # expect to second create a security group on the head node VPC
    stubs.create_sg_echo(ec2_client_stub, DEFAULT_SG)
    # expect new head security group details to be retrieved after creation
    stubs.describe_sgs_on_vpc(
        ec2_client_stub,
        [DEFAULT_SUBNET["VpcId"]],
        [DEFAULT_SG],
    )

    # given no existing default head security group inbound rules...
    # expect to authorize all default head inbound rules
    stubs.authorize_sg_ingress(
        ec2_client_stub,
        DEFAULT_SG_DUAL_GROUP_RULES,
    )
    # given no existing default worker security group inbound rules...
    # expect to authorize all default worker inbound rules
    stubs.authorize_sg_ingress(
        ec2_client_stub,
        DEFAULT_SG_WITH_RULES_AUX_SUBNET,
    )

    # given our mocks and an example config file as input...
    # expect bootstrapping to succeed only when the head subnet is in the
    # configured availability zone. A ClickException in the correct-AZ case
    # must fail the test directly rather than being swallowed. The cached
    # subnet lookup is cleared whatever happens so it cannot leak into other
    # tests.
    try:
        if correct_az:
            config = helpers.bootstrap_aws_example_config_file("example-subnets.yaml")
        else:
            with pytest.raises(ClickException):
                helpers.bootstrap_aws_example_config_file("example-subnets.yaml")
    finally:
        _get_subnets_or_die.cache_clear()

    if not correct_az:
        # Bootstrapping aborted part-way, so some queued stub responses were
        # never consumed; drop them rather than leave them queued.
        iam_client_stub._queue.clear()
        ec2_client_stub._queue.clear()
        return

    # expect the bootstrapped config to show different head and worker security
    # groups residing on different subnets
    for node_type_key, node_type in config["available_node_types"].items():
        node_config = node_type["node_config"]
        security_group_ids = node_config["SecurityGroupIds"]
        subnet_ids = node_config["SubnetIds"]
        if node_type_key == config["head_node_type"]:
            assert security_group_ids == [DEFAULT_SG["GroupId"]]
            assert subnet_ids == [DEFAULT_SUBNET["SubnetId"]]
        else:
            assert security_group_ids == [AUX_SG["GroupId"]]
            assert subnet_ids == [AUX_SUBNET["SubnetId"]]

    # expect no pending responses left in IAM or EC2 client stub queues
    iam_client_stub.assert_no_pending_responses()
    ec2_client_stub.assert_no_pending_responses()


def test_create_sg_with_custom_inbound_rules_and_name(iam_client_stub, ec2_client_stub):
    """Create a security group using a configured name and inbound rules.

    Bootstraps example-security-group.yaml against stubs where no security
    group exists yet in the head subnet's VPC. It then checks the provider's
    security_group GroupName and IpPermissions against the expected constants.
    """
    # use default stubs to skip ahead to security group configuration
    stubs.skip_to_configure_sg(ec2_client_stub, iam_client_stub)

    # expect to describe the head subnet ID
    stubs.describe_subnets_echo(ec2_client_stub, [DEFAULT_SUBNET])
    # given no existing security groups within the VPC...
    stubs.describe_no_security_groups(ec2_client_stub)
    # expect to create a security group on the head node VPC
    stubs.create_sg_echo(ec2_client_stub, DEFAULT_SG_WITH_NAME)
    # expect new head security group details to be retrieved after creation
    stubs.describe_sgs_on_vpc(
        ec2_client_stub,
        [DEFAULT_SUBNET["VpcId"]],
        [DEFAULT_SG_WITH_NAME],
    )

    # given custom inbound rules in the config...
    # expect to authorize both default and custom inbound rules
    stubs.authorize_sg_ingress(
        ec2_client_stub,
        DEFAULT_SG_WITH_NAME_AND_RULES,
    )

    # given the prior modification to the head security group...
    # expect the next read of a head security group property to reload it
    stubs.describe_sg_echo(ec2_client_stub, DEFAULT_SG_WITH_NAME_AND_RULES)

    # Start from an empty subnet-lookup cache so the describe_subnets stub
    # above is actually consumed.
    _get_subnets_or_die.cache_clear()
    # given our mocks and an example config file as input...
    # expect the config to be loaded, validated, and bootstrapped successfully
    config = helpers.bootstrap_aws_example_config_file("example-security-group.yaml")

    # expect the bootstrapped config to have the custom security group
    # name and inbound rules
    assert (
        config["provider"]["security_group"]["GroupName"]
        == DEFAULT_SG_WITH_NAME_AND_RULES["GroupName"]
    )
    assert (
        config["provider"]["security_group"]["IpPermissions"] == CUSTOM_IN_BOUND_RULES
    )

    # expect no pending responses left in IAM or EC2 client stub queues
    iam_client_stub.assert_no_pending_responses()
    ec2_client_stub.assert_no_pending_responses()


def test_subnet_given_head_and_worker_sg(iam_client_stub, ec2_client_stub):
    """Fill SubnetIds when head and worker security groups are user-specified.

    Bootstraps example-head-and-worker-security-group.yaml with one existing
    security group and a thousand candidate subnets in different VPCs. Every
    node type should end up with only the single subnet in the right VPC.
    """
    stubs.configure_iam_role_default(iam_client_stub)
    stubs.configure_key_pair_default(ec2_client_stub)

    # list a security group and a thousand subnets in different vpcs
    stubs.describe_a_security_group(ec2_client_stub, DEFAULT_SG)
    stubs.describe_a_thousand_subnets_in_different_vpcs(ec2_client_stub)

    config = helpers.bootstrap_aws_example_config_file(
        "example-head-and-worker-security-group.yaml"
    )

    # check that just the single subnet in the right vpc is filled
    for node_type in config["available_node_types"].values():
        node_config = node_type["node_config"]
        assert node_config["SubnetIds"] == [DEFAULT_SUBNET["SubnetId"]]

    # expect no pending responses left in IAM or EC2 client stub queues
    iam_client_stub.assert_no_pending_responses()
    ec2_client_stub.assert_no_pending_responses()


# Parametrize across multiple regions, since default AMI is different in each
@pytest.mark.parametrize(
    "iam_client_stub,ec2_client_stub,region",
    [3 * (region,) for region in DEFAULT_AMI],
    indirect=["iam_client_stub", "ec2_client_stub"],
)
def test_fills_out_amis_and_iam(iam_client_stub, ec2_client_stub, region):
    """Bootstrap fills in the region's default AMI and the head IAM profile.

    ImageId is removed from every node config, so bootstrap_aws must supply
    DEFAULT_AMI[region]. The IAM instance profile must land in
    config["head_node"] only, and not in the head node type's node_config,
    so that workers of the head's type do not receive it.
    """
    # Set up expected key pair for specific region
    region_key_pair = DEFAULT_KEY_PAIR.copy()
    region_key_pair["KeyName"] = DEFAULT_KEY_PAIR["KeyName"].replace(
        "us-west-2", region
    )

    # Setup stubs to mock out boto3
    stubs.configure_iam_role_default(iam_client_stub)
    stubs.configure_key_pair_default(
        ec2_client_stub, region=region, expected_key_pair=region_key_pair
    )
    stubs.describe_a_security_group(ec2_client_stub, DEFAULT_SG)
    stubs.configure_subnet_default(ec2_client_stub)

    config = helpers.load_aws_example_config_file("example-full.yaml")
    head_node_config = config["available_node_types"]["ray.head.default"]["node_config"]
    worker_node_config = config["available_node_types"]["ray.worker.default"][
        "node_config"
    ]

    del head_node_config["ImageId"]
    del worker_node_config["ImageId"]

    # Pass in SG for stub to work
    head_node_config["SecurityGroupIds"] = ["sg-1234abcd"]
    worker_node_config["SecurityGroupIds"] = ["sg-1234abcd"]

    config["provider"]["region"] = region

    defaults_filled = bootstrap_aws(config)

    # Index directly (instead of .get) so a dropped or unknown region fails
    # loudly rather than comparing ImageId against None.
    assert defaults_filled["provider"]["region"] == region
    ami = DEFAULT_AMI[region]

    for node_type in defaults_filled["available_node_types"].values():
        node_config = node_type["node_config"]
        assert node_config.get("ImageId") == ami

    # Correctly configured IAM role
    assert defaults_filled["head_node"]["IamInstanceProfile"] == {
        "Arn": DEFAULT_INSTANCE_PROFILE["Arn"]
    }
    # Workers of the head's type do not get the IAM role: it must not be
    # written into the node_config shared by every node of the head's type.
    # (Checking the node-type dict itself would always pass, because that
    # key only ever appears inside node_config.)
    head_type = config["head_node_type"]
    head_type_node_config = defaults_filled["available_node_types"][head_type][
        "node_config"
    ]
    assert "IamInstanceProfile" not in head_type_node_config

    iam_client_stub.assert_no_pending_responses()
    ec2_client_stub.assert_no_pending_responses()


def test_iam_already_configured(iam_client_stub, ec2_client_stub):
    """
    Checks that things work as expected when IAM role is supplied by user.

    The head node_config sets IamInstanceProfile itself and no IAM responses
    are stubbed. Bootstrapping must keep the user's value on the head
    node_config and must not add a profile to config["head_node"].
    """
    stubs.configure_key_pair_default(ec2_client_stub)
    stubs.describe_a_security_group(ec2_client_stub, DEFAULT_SG)
    stubs.configure_subnet_default(ec2_client_stub)

    config = helpers.load_aws_example_config_file("example-full.yaml")
    head_node_config = config["available_node_types"]["ray.head.default"]["node_config"]
    worker_node_config = config["available_node_types"]["ray.worker.default"][
        "node_config"
    ]

    # User-supplied profile; bootstrap should leave it untouched.
    head_node_config["IamInstanceProfile"] = "mock_profile"

    # Pass in SG for stub to work
    head_node_config["SecurityGroupIds"] = ["sg-1234abcd"]
    worker_node_config["SecurityGroupIds"] = ["sg-1234abcd"]

    defaults_filled = bootstrap_aws(config)
    filled_head = defaults_filled["available_node_types"]["ray.head.default"][
        "node_config"
    ]
    assert filled_head["IamInstanceProfile"] == "mock_profile"
    assert "IamInstanceProfile" not in defaults_filled["head_node"]

    iam_client_stub.assert_no_pending_responses()
    ec2_client_stub.assert_no_pending_responses()


def test_create_sg_multinode(iam_client_stub, ec2_client_stub):
    """
    Test AWS Bootstrap logic when config being bootstrapped has the
    following properties:

    (1) auth config does not specify ssh key path
    (2) available_node_types is provided
    (3) security group name and ip permissions set in provider field
    (4) Available node types have SubnetIds field set and this
        field is of form SubnetIds: [subnet-xxxxx].
        Both node types specify the same subnet-xxxxx.

    Tests creation of a security group and key pair under these conditions.
    """

    # Generate a config of the desired form.
    subnet_id = DEFAULT_SUBNET["SubnetId"]

    # security group info to go in provider field
    provider_data = helpers.load_aws_example_config_file("example-security-group.yaml")[
        "provider"
    ]

    # a multi-node-type config -- will add head/worker stuff and security group
    # info to this.
    base_config = helpers.load_aws_example_config_file("example-full.yaml")

    config = copy.deepcopy(base_config)
    # Add security group data
    config["provider"] = provider_data
    # Add head and worker fields.
    head_node_config = config["available_node_types"]["ray.head.default"]["node_config"]
    worker_node_config = config["available_node_types"]["ray.worker.default"][
        "node_config"
    ]
    head_node_config["SubnetIds"] = [subnet_id]
    worker_node_config["SubnetIds"] = [subnet_id]

    # Generate stubs
    stubs.configure_iam_role_default(iam_client_stub)
    stubs.configure_key_pair_default(ec2_client_stub)

    # Only one of these (the one specified in the available_node_types)
    # is in the correct vpc.
    # This list of subnets is generated by the ec2.subnets.all() call
    # and then ignored, since available_node_types already specify
    # subnet_ids.
    stubs.describe_a_thousand_subnets_in_different_vpcs(ec2_client_stub)

    # The rest of the stubbing logic is copied from
    # test_create_sg_with_custom_inbound_rules_and_name.

    # expect to describe the head subnet ID
    stubs.describe_subnets_echo(ec2_client_stub, [DEFAULT_SUBNET])
    # given no existing security groups within the VPC...
    stubs.describe_no_security_groups(ec2_client_stub)
    # expect to create a security group on the head node VPC
    stubs.create_sg_echo(ec2_client_stub, DEFAULT_SG_WITH_NAME)
    # expect new head security group details to be retrieved after creation
    stubs.describe_sgs_on_vpc(
        ec2_client_stub,
        [DEFAULT_SUBNET["VpcId"]],
        [DEFAULT_SG_WITH_NAME],
    )

    # given custom inbound rules in the provider config...
    # expect to authorize both default and custom inbound rules
    stubs.authorize_sg_ingress(
        ec2_client_stub,
        DEFAULT_SG_WITH_NAME_AND_RULES,
    )

    # given the prior modification to the head security group...
    # expect the next read of a head security group property to reload it
    stubs.describe_sg_echo(ec2_client_stub, DEFAULT_SG_WITH_NAME_AND_RULES)

    # Start from an empty subnet-lookup cache so the describe_subnets stub
    # above is actually consumed.
    _get_subnets_or_die.cache_clear()

    # given our mocks and the config as input...
    # expect the config to be validated and bootstrapped successfully
    bootstrapped_config = helpers.bootstrap_aws_config(config)

    # expect the bootstrapped config to have the custom security group
    # name and inbound rules
    assert (
        bootstrapped_config["provider"]["security_group"]["GroupName"]
        == DEFAULT_SG_WITH_NAME_AND_RULES["GroupName"]
    )
    assert (
        bootstrapped_config["provider"]["security_group"]["IpPermissions"]
        == CUSTOM_IN_BOUND_RULES
    )

    # Confirming correct security group got filled for head and workers
    sg_id = DEFAULT_SG["GroupId"]
    for node_type in bootstrapped_config["available_node_types"].values():
        node_config = node_type["node_config"]
        assert node_config["SecurityGroupIds"] == [sg_id]

    # Confirming bootstrap config updates available node types with
    # default KeyName
    for node_type in bootstrapped_config["available_node_types"].values():
        node_config = node_type["node_config"]
        assert node_config["KeyName"] == DEFAULT_KEY_PAIR["KeyName"]

    # Confirm security group is in the right VPC.
    # (Doesn't really confirm anything except for the structure of this test
    # data.)
    bootstrapped_head_type = bootstrapped_config["head_node_type"]
    bootstrapped_types = bootstrapped_config["available_node_types"]
    bootstrapped_head_config = bootstrapped_types[bootstrapped_head_type]["node_config"]
    assert DEFAULT_SG["VpcId"] == DEFAULT_SUBNET["VpcId"]
    assert DEFAULT_SUBNET["SubnetId"] == bootstrapped_head_config["SubnetIds"][0]

    # ssh private key filled in
    assert "ssh_private_key" in bootstrapped_config["auth"]

    # expect no pending responses left in IAM or EC2 client stub queues
    iam_client_stub.assert_no_pending_responses()
    ec2_client_stub.assert_no_pending_responses()


def test_missing_keyname(iam_client_stub, ec2_client_stub):
    """Require KeyName or UserData when an ssh_private_key is configured.

    First bootstraps a copy of the config that sets auth.ssh_private_key but
    gives the node configs neither KeyName nor UserData, and expects an
    AssertionError. Then sets UserData on both node configs and expects
    bootstrapping to succeed.
    """
    config = helpers.load_aws_example_config_file("example-full.yaml")
    config["auth"]["ssh_private_key"] = "/path/to/private/key"
    head_node_config = config["available_node_types"]["ray.head.default"]["node_config"]
    worker_node_config = config["available_node_types"]["ray.worker.default"][
        "node_config"
    ]

    # Setup stubs to mock out boto3. Should fail on assertion after
    # checking KeyName/UserData.
    stubs.configure_iam_role_default(iam_client_stub)

    # Bootstrap a deep copy so the failed attempt cannot leave changes in
    # `config` (in case bootstrap_aws mutates its input).
    missing_user_data_config = copy.deepcopy(config)
    with pytest.raises(AssertionError):
        # Config specified ssh_private_key, but missing KeyName/UserData in
        # node configs
        bootstrap_aws(missing_user_data_config)

    # Pass in SG for stub to work
    head_node_config["SecurityGroupIds"] = ["sg-1234abcd"]
    worker_node_config["SecurityGroupIds"] = ["sg-1234abcd"]

    # Set UserData for both node configs
    head_node_config["UserData"] = {"someKey": "someValue"}
    worker_node_config["UserData"] = {"someKey": "someValue"}

    # Stubs to mock out boto3. Should no longer fail on assertion
    # and go on to describe security groups + configure subnet
    stubs.configure_iam_role_default(iam_client_stub)
    stubs.describe_a_security_group(ec2_client_stub, DEFAULT_SG)
    stubs.configure_subnet_default(ec2_client_stub)

    # Should work without error now that UserData is set
    bootstrap_aws(config)

    iam_client_stub.assert_no_pending_responses()
    ec2_client_stub.assert_no_pending_responses()


def test_log_to_cli(iam_client_stub, ec2_client_stub):
    """Smoke-test log_to_cli on a fully bootstrapped example config."""
    config = helpers.load_aws_example_config_file("example-full.yaml")

    head_node_config = config["available_node_types"]["ray.head.default"]["node_config"]
    worker_node_config = config["available_node_types"]["ray.worker.default"][
        "node_config"
    ]

    # Pass in SG for stub to work
    head_node_config["SecurityGroupIds"] = ["sg-1234abcd"]
    worker_node_config["SecurityGroupIds"] = ["sg-1234abcd"]

    stubs.configure_iam_role_default(iam_client_stub)
    stubs.configure_key_pair_default(ec2_client_stub)
    stubs.describe_a_security_group(ec2_client_stub, DEFAULT_SG)
    stubs.configure_subnet_default(ec2_client_stub)

    config = helpers.bootstrap_aws_config(config)

    # Only side-effect is to print logs to cli, called just to
    # check that it runs without error
    log_to_cli(config)
    iam_client_stub.assert_no_pending_responses()
    ec2_client_stub.assert_no_pending_responses()


def test_network_interfaces(
    ec2_client_stub,
    iam_client_stub,
    ec2_client_stub_fail_fast,
    ec2_client_stub_max_retries,
):
    """Bootstrap and launch nodes whose config uses NetworkInterfaces.

    Bootstraps example-network-interfaces.yaml, then creates one node per
    node type. It checks that run_instances is issued with each node type's
    NetworkInterfaces configuration.
    """
    # use default stubs to skip ahead to subnet configuration
    stubs.configure_iam_role_default(iam_client_stub)
    stubs.configure_key_pair_default(ec2_client_stub)

    # given the security groups associated with our network interfaces...
    sgids = ["sg-00000000", "sg-11111111", "sg-22222222", "sg-33333333"]
    security_groups = []
    for suffix, sgid in enumerate(sgids):
        # Give each copy of the default group a unique name and ID.
        sg = copy.deepcopy(DEFAULT_SG)
        sg["GroupName"] += f"-{suffix}"
        sg["GroupId"] = sgid
        security_groups.append(sg)
    # expect to describe all security groups to ensure they share the same VPC
    stubs.describe_sgs_by_id(ec2_client_stub, sgids, security_groups)

    # use a default stub to skip subnet configuration
    stubs.configure_subnet_default(ec2_client_stub)
    stubs.describe_subnets_echo(
        ec2_client_stub,
        [DEFAULT_SUBNET, {**DEFAULT_SUBNET, "SubnetId": "subnet-11111111"}],
    )
    stubs.describe_subnets_echo(
        ec2_client_stub, [{**DEFAULT_SUBNET, "SubnetId": "subnet-22222222"}]
    )
    stubs.describe_subnets_echo(
        ec2_client_stub, [{**DEFAULT_SUBNET, "SubnetId": "subnet-33333333"}]
    )

    # given our mocks and an example config file as input...
    # expect the config to be loaded, validated, and bootstrapped successfully
    config = helpers.bootstrap_aws_example_config_file(
        "example-network-interfaces.yaml"
    )

    # instantiate a new node provider
    new_provider = _get_node_provider(
        config["provider"],
        DEFAULT_CLUSTER_NAME,
        False,
    )

    for name, node_type in config["available_node_types"].items():
        node_cfg = node_type["node_config"]
        tags = helpers.node_provider_tags(config, name)
        # given our bootstrapped node config as input to create a new node...
        # expect to first describe all stopped instances that could be reused
        stubs.describe_instances_with_any_filter_consumer(ec2_client_stub_max_retries)
        # given no stopped EC2 instances to reuse...
        # expect to create new nodes with the given network interface config
        stubs.run_instances_with_network_interfaces_consumer(
            ec2_client_stub_fail_fast,
            node_cfg["NetworkInterfaces"],
        )
        new_provider.create_node(node_cfg, tags, 1)

    iam_client_stub.assert_no_pending_responses()
    ec2_client_stub.assert_no_pending_responses()
    ec2_client_stub_fail_fast.assert_no_pending_responses()
    ec2_client_stub_max_retries.assert_no_pending_responses()


def test_network_interface_conflict_keys():
    """Reject subnet/security-group keys placed next to NetworkInterfaces.

    When a node type defines NetworkInterfaces, SubnetId, SubnetIds and
    SecurityGroupIds may only appear inside the individual interfaces. Each
    conflicting key is tried on its own against a freshly loaded config.
    """
    expected_error_msg = (
        "If NetworkInterfaces are defined, subnets and "
        "security groups must ONLY be given in each "
        "NetworkInterface."
    )
    conflicting_fields = (
        ("SubnetId", "subnet-0000000"),
        ("SubnetIds", ["subnet-0000000", "subnet-1111111"]),
        ("SecurityGroupIds", ["sg-1234abcd", "sg-dcba4321"]),
    )
    for key, value in conflicting_fields:
        config = helpers.load_aws_example_config_file("example-network-interfaces.yaml")
        head_node_type = config["available_node_types"][config["head_node_type"]]
        head_node_type["node_config"][key] = value
        with pytest.raises(ValueError, match=expected_error_msg):
            helpers.bootstrap_aws_config(config)


def test_network_interface_missing_subnet():
    """Every NetworkInterface must define a SubnetId.

    The SubnetId is removed from one interface at a time, each time on a
    fresh copy of the example config. Every check therefore exercises exactly
    one missing subnet, rather than the removals accumulating across
    iterations.
    """
    # If NetworkInterfaces are defined, each must have a subnet ID
    expected_error_msg = (
        "NetworkInterfaces are defined but at least one is "
        "missing a subnet. Please ensure all interfaces "
        "have a subnet assigned."
    )
    base_config = helpers.load_aws_example_config_file("example-network-interfaces.yaml")
    for name, node_type in base_config["available_node_types"].items():
        num_interfaces = len(node_type["node_config"]["NetworkInterfaces"])
        for index in range(num_interfaces):
            config = copy.deepcopy(base_config)
            node_cfg = config["available_node_types"][name]["node_config"]
            node_cfg["NetworkInterfaces"][index].pop("SubnetId")
            with pytest.raises(ValueError, match=expected_error_msg):
                helpers.bootstrap_aws_config(config)


def test_network_interface_missing_security_group():
    """Every NetworkInterface must define its security Groups.

    The Groups entry is removed from one interface at a time, each time on a
    fresh copy of the example config. Every check therefore exercises exactly
    one missing security group, rather than the removals accumulating across
    iterations.
    """
    # If NetworkInterfaces are defined, each must have security groups
    expected_error_msg = (
        "NetworkInterfaces are defined but at least one is "
        "missing a security group. Please ensure all "
        "interfaces have a security group assigned."
    )
    base_config = helpers.load_aws_example_config_file("example-network-interfaces.yaml")
    for name, node_type in base_config["available_node_types"].items():
        num_interfaces = len(node_type["node_config"]["NetworkInterfaces"])
        for index in range(num_interfaces):
            config = copy.deepcopy(base_config)
            node_cfg = config["available_node_types"][name]["node_config"]
            node_cfg["NetworkInterfaces"][index].pop("Groups")
            with pytest.raises(ValueError, match=expected_error_msg):
                helpers.bootstrap_aws_config(config)


def test_launch_templates(
    ec2_client_stub, ec2_client_stub_fail_fast, ec2_client_stub_max_retries
):
    """Bootstrap and launch nodes whose config references launch templates.

    Bootstraps example-launch-templates.yaml (head template referenced by ID,
    worker template by name), then creates one node per node type. It checks
    that run_instances is issued with the launch-template-based config.
    """

    # given the launch template associated with our default head node type...
    # expect to first describe the default launch template by ID
    stubs.describe_launch_template_versions_by_id_default(ec2_client_stub, ["$Latest"])
    # given the launch template associated with our default worker node type...
    # expect to next describe the same default launch template by name
    stubs.describe_launch_template_versions_by_name_default(ec2_client_stub, ["2"])
    # use default stubs to skip ahead to subnet configuration
    stubs.configure_key_pair_default(ec2_client_stub)

    # given the security groups associated with our launch template...
    sgids = [DEFAULT_SG["GroupId"]]
    security_groups = [DEFAULT_SG]
    # expect to describe all security groups to ensure they share the same VPC
    stubs.describe_sgs_by_id(ec2_client_stub, sgids, security_groups)

    # use a default stub to skip subnet configuration
    stubs.configure_subnet_default(ec2_client_stub)

    # given our mocks and an example config file as input...
    # expect the config to be loaded, validated, and bootstrapped successfully
    config = helpers.bootstrap_aws_example_config_file("example-launch-templates.yaml")

    # instantiate a new node provider
    new_provider = _get_node_provider(
        config["provider"],
        DEFAULT_CLUSTER_NAME,
        False,
    )

    # One node per node type; passed both to the run_instances stub and to
    # create_node so the expected and actual counts agree.
    max_count = 1
    for name, node_type in config["available_node_types"].items():
        # given our bootstrapped node config as input to create a new node...
        # expect to first describe all stopped instances that could be reused
        stubs.describe_instances_with_any_filter_consumer(ec2_client_stub_max_retries)
        # given no stopped EC2 instances to reuse...
        # expect to create new nodes with the given launch template config
        node_cfg = node_type["node_config"]
        stubs.run_instances_with_launch_template_consumer(
            ec2_client_stub_fail_fast,
            config,
            node_cfg,
            name,
            DEFAULT_LT["LaunchTemplateData"],
            max_count,
        )
        tags = helpers.node_provider_tags(config, name)
        new_provider.create_node(node_cfg, tags, max_count)

    ec2_client_stub.assert_no_pending_responses()
    ec2_client_stub_fail_fast.assert_no_pending_responses()
    ec2_client_stub_max_retries.assert_no_pending_responses()


@pytest.mark.parametrize("num_on_demand_nodes", [0, 1001, 9999])
@pytest.mark.parametrize("num_spot_nodes", [0, 1001, 9999])
@pytest.mark.parametrize("stop", [True, False])
def test_terminate_nodes(num_on_demand_nodes, num_spot_nodes, stop):
    """Check that terminate_nodes stops/terminates exactly the right nodes.

    This test makes sure that we stop or terminate all the nodes we're
    supposed to stop or terminate when we call ``terminate_nodes``. It also
    makes sure that we don't try to stop or terminate too many nodes in a
    single EC2 request: every call must carry at most
    ``provider.max_terminate_nodes`` instance IDs. By default only 1000
    nodes can be stopped/terminated in one request, so larger batches must
    be split into multiple smaller requests.

    Args:
        num_on_demand_nodes: number of on-demand nodes to stop or terminate.
        num_spot_nodes: number of spot nodes to terminate.
        stop: True to stop on-demand nodes, False to terminate them. Spot
            instances are always terminated, even if ``stop`` is True.
    """
    # Generate disjoint sets of unique instance ids for each node kind.
    on_demand_nodes = {"i-{:017d}".format(i) for i in range(num_on_demand_nodes)}
    spot_nodes = {
        "i-{:017d}".format(i + num_on_demand_nodes) for i in range(num_spot_nodes)
    }
    node_ids = list(on_demand_nodes.union(spot_nodes))

    with patch("ray.autoscaler._private.aws.node_provider.make_ec2_resource"):
        provider = AWSNodeProvider(
            provider_config={"region": "nowhere", "cache_stopped_nodes": stop},
            cluster_name="default",
        )

    # "_get_cached_node" is used by the AWSNodeProvider to determine whether a
    # node is a spot instance or an on-demand instance.
    def mock_get_cached_node(node_id):
        result = Mock()
        result.spot_instance_request_id = (
            "sir-08b93456" if node_id in spot_nodes else ""
        )
        return result

    provider._get_cached_node = mock_get_cached_node

    provider.terminate_nodes(node_ids)

    client = provider.ec2.meta.client
    stop_calls = client.stop_instances.call_args_list
    terminate_calls = client.terminate_instances.call_args_list

    # Copy spot_nodes instead of aliasing it: updating an alias would
    # silently mutate the set that mock_get_cached_node closes over.
    nodes_to_stop = set()
    nodes_to_terminate = set(spot_nodes)
    if stop:
        nodes_to_stop |= on_demand_nodes
    else:
        nodes_to_terminate |= on_demand_nodes

    for calls, expected_nodes in (
        (stop_calls, nodes_to_stop),
        (terminate_calls, nodes_to_terminate),
    ):
        nodes_included_in_call = set()
        for call in calls:
            instance_ids = call[1]["InstanceIds"]
            # No single request may exceed the provider's batch limit.
            assert len(instance_ids) <= provider.max_terminate_nodes
            nodes_included_in_call.update(instance_ids)

        assert expected_nodes == nodes_included_in_call


def test_use_subnets_ordered_by_az(ec2_client_stub):
    """
    Validate that _configure_subnet orders SubnetIds by availability zone.

    Twenty subnets are round-robined across the four AZs in `us-west-2`
    (a, b, c, d). With availability_zone set to "us-west-2c,us-west-2d,
    us-west-2a", 15 subnets should remain: five from `us-west-2c`, then five
    from `us-west-2d`, then five from `us-west-2a`.
    """
    stubs.describe_twenty_subnets_in_different_azs(ec2_client_stub)

    cfg = helpers.load_aws_example_config_file("example-full.yaml")
    cfg["provider"]["availability_zone"] = "us-west-2c,us-west-2d,us-west-2a"
    config = _configure_subnet(cfg)

    # (AZ offset derived from the subnet number mod 4, failure message), in
    # the order the groups must appear.
    expected_groups = (
        (2, "First 5 should be in us-west-2c"),
        (3, "Next 5 should be in us-west-2d"),
        (0, "Last 5 should be in us-west-2a"),
    )
    for node_type in config["available_node_types"].values():
        subnet_ids = node_type["node_config"]["SubnetIds"]
        assert len(subnet_ids) == 15
        az_offsets = [int(subnet.split("-")[1]) % 4 for subnet in subnet_ids]
        for group_index, (offset, message) in enumerate(expected_groups):
            start = group_index * 5
            assert set(az_offsets[start : start + 5]) == {offset}, message


if __name__ == "__main__":
    # Allow running this file directly; exit with pytest's status code.
    raise SystemExit(pytest.main(["-v", __file__]))
