#!/usr/bin/env python
#

"""
ec2launch: Launch Amazon AWS EC2 instance
"""

import collections
import json
import os
import random
import re
import sys
import time
from optparse import OptionParser

import boto
import ec2common

import gtermapi

# Directory holding SSH key files, and the login name used on the instance
ssh_dir = os.path.expanduser("~/.ssh")
gterm_user = "ubuntu"

usage = "usage: %prog [-h ... ] <instance-tag|full-domain-name>"

# Describe the launch form: one hostname argument plus the options below
form_parser = gtermapi.FormParser(usage=usage,
                                  title="Create Amazon EC2 instance with the following hostname: ",
                                  command="ec2launch -f")
form_parser.add_argument(label="", help="Instance hostname")

# Plain options as (name, second positional value, help text), in form order
for opt_name, opt_value, opt_help in [
        ("type", ("t1.micro", "m1.small", "m1.medium", "m1.large"), "Instance type"),
        ("key_name", "ec2key", "Instance management SSH key name"),
        ("ami", "ami-82fa58eb", "Instance OS (default: Ubuntu 12.04LTS)"),
        ("auth_code", "none", "User authentication code"),
        ("server_secret", "", "Server secret for host authentication (to connect from anywhere)"),
        ("connect_to_gterm", "", "Full domain name of GraphTerm server to connect to"),
        ("pylab", False, "Install PyLab"),
        ("oshell", False, "Enable OTrace shell"),
        ("dry_run", False, "Dry run"),
        ]:
    form_parser.add_option(opt_name, opt_value, help=opt_help)

# Display-mode switches: each has a one-letter short form and is added with raw=True
for opt_name, opt_short, opt_help in [
        ("fullpage", "f", "Fullpage display"),
        ("text", "t", "Text only"),
        ]:
    form_parser.add_option(opt_name, False, short=opt_short, help=opt_help, raw=True)

options, args = form_parser.parse_args()

# Use plain-text mode unless a GraphTerm cookie is set and stdout is a terminal
if not (gtermapi.Lterm_cookie and sys.stdout.isatty()):
    options.text = True

if not args:
    # No hostname supplied: show the usage text (text mode) or the input
    # form (GraphTerm mode), then stop with exit status 1 either way
    if not options.text:
        gtermapi.write_form(form_parser.create_form(), command="ec2launch -f")
    else:
        print >> sys.stderr, form_parser.get_usage()
    sys.exit(1)

instance_tag = args[0]

# Tag must start with a lowercase letter, followed only by lowercase
# letters, digits, dots and hyphens
if re.match(r"^[a-z][a-z0-9\.\-]*$", instance_tag) is None:
    raise Exception("Invalid characters in instance name: "+instance_tag)

# Text before the first dot is the instance name; the rest (possibly empty)
# is treated as the domain
instance_name, sep, instance_domain = instance_tag.partition(".")

# Private key file for the chosen key name
key_file = os.path.join(ssh_dir, options.key_name+".pem")

ssh_dir_found = os.path.exists(ssh_dir)
if not ssh_dir_found:
    print >> sys.stderr, "ec2launch: %s directory not found!" % ssh_dir
    sys.exit(1)

# One inbound permission of a security group, in a form that can be compared
# for equality with the rules reported for an existing group
GroupRule = collections.namedtuple(
    "GroupRule", "ip_protocol from_port to_port cidr_ip src_group_name")

HostGroupName = "gtermhost"
ServerGroupName = "gtermserver"

# (group name, description, rules) for every security group this script manages
SecurityGroups = [
    (
        HostGroupName,
        "GraphTerm host group",
        [
            GroupRule(ip_protocol="tcp", from_port="22", to_port="22",
                      cidr_ip="0.0.0.0/0", src_group_name=None),
        ],
    ),
    (
        ServerGroupName,
        "GraphTerm server group",
        [
            GroupRule(ip_protocol="tcp", from_port="22", to_port="22",
                      cidr_ip="0.0.0.0/0", src_group_name=None),
            GroupRule(ip_protocol="tcp", from_port="80", to_port="80",
                      cidr_ip="0.0.0.0/0", src_group_name=None),
            # Port 8899 names the host group as its source unless a server
            # secret was given, in which case no source group is set
            GroupRule(ip_protocol="tcp", from_port="8899", to_port="8899",
                      cidr_ip="0.0.0.0/0",
                      src_group_name=HostGroupName if not options.server_secret else None),
        ],
    ),
]

# Name of the marker file touched on the instance as the last startup command
setup_file = "SETUP_OVER"

# Shell commands passed to the instance as its user-data startup script
startup_commands = [
    "#!/bin/bash",
    "set -e -x",
    "apt-get update && apt-get upgrade -y",
    "apt-get install -y python-setuptools",
    "easy_install tornado",
    "easy_install graphterm",
    "gterm_setup",
]

if options.pylab:
    startup_commands.append("apt-get install -y python-numpy python-scipy python-matplotlib")

# Template for starting a GraphTerm daemon as gterm_user; the %s slot is
# filled with the program name (gtermhost or gtermserver) below
command_format = " ".join(["sudo -u", gterm_user, "-i %s --daemon=start"])
if options.oshell:
    command_format += " --oshell"
if options.server_secret:
    command_format += ' --server_secret="'+options.server_secret+'"'

if not options.connect_to_gterm:
    # No server to connect to: this instance runs gtermserver itself
    security_groups = [ServerGroupName]
    host_domain = "`hostname`"
    # Redirect incoming port 80 traffic to port 8900
    startup_commands.append("iptables -t nat -A PREROUTING -p tcp --dport 80 -j REDIRECT --to 8900")
    startup_commands.append((command_format % "gtermserver") + ' --auth_code="'+options.auth_code+'" --host='+host_domain)
else:
    # This instance runs gtermhost and connects to the named GraphTerm server
    security_groups = [HostGroupName]
    server_name, sep, server_domain = options.connect_to_gterm.partition(".")
    if not server_domain:
        raise Exception("Must specify fully qualified domain name to connect to GraphTerm server")
    # Use the short name when instance and server share a domain, else the full tag
    host_name = instance_name if instance_domain == server_domain else instance_tag
    startup_commands.append((command_format % "gtermhost") + " --server_addr="+ options.connect_to_gterm+" "+host_name)

# Final command: create the marker file in gterm_user's home directory
startup_commands.append("sudo -u "+gterm_user+" -i touch ~" + gterm_user + "/" + setup_file)

startup_script = "\n".join(startup_commands) + "\n"

# Connect to EC2
ec2 = boto.connect_ec2()

# Route 53 state; all three stay at these defaults when the instance tag
# has no domain part
route53conn = None
hosted_zone = None
nameservers = ""
if instance_domain:
    route53conn = ec2common.Route53Connection()
    hosted_zone = ec2common.get_hosted_zone(route53conn, instance_domain)
    if not hosted_zone and not options.dry_run:
        # No zone exists for this domain yet: create one and remember its
        # nameservers so the user can be told to register them.
        # Skipped on a dry run, which must not create AWS resources.
        hosted_zone = ec2common.create_hosted_zone(route53conn, instance_domain)
        nameservers = ec2common.get_nameservers(route53conn, instance_domain)

def get_security_group(ec2conn, group_name):
    """Return the first security group named group_name, or None if there is none"""
    for candidate in ec2conn.get_all_security_groups():
        if candidate.name == group_name:
            return candidate
    return None

def create_or_update_security_group(ec2conn, group_name, description="", rules=None):
    """Create (or update) security group

    If a group named group_name already exists, every grant on it that is
    not listed in rules is revoked, and every entry of rules that is not yet
    present is authorized.  Otherwise the group is created (described by
    description, or by its own name when description is empty) and all of
    rules are authorized.

    ec2conn: EC2 connection object
    group_name: name of the security group
    description: description used only when the group has to be created
    rules: iterable of GroupRule tuples the group should end up with
           (None or empty means no rules); the argument is not modified

    Returns the security group object.
    Raises Exception if a rule names a source group that cannot be found.
    """
    # Work on a private list so the caller's sequence is never modified
    new_rules = list(rules) if rules else []
    group = get_security_group(ec2conn, group_name)
    if group:
        # Group already exists.
        # Iterate over snapshots: revoke() may prune the group's own
        # rule/grant lists, which would make a live iteration skip entries.
        for rule in list(group.rules):
            # A rule can carry several grants (one per source); check each
            for grant in list(rule.grants):
                # Grants without a CIDR are source-group grants; they are
                # represented with the catch-all CIDR plus the group name
                cidr_ip = grant.cidr_ip if grant.cidr_ip else "0.0.0.0/0"
                src_group_name = None if grant.cidr_ip else grant.name
                old_rule = GroupRule(rule.ip_protocol,
                                     rule.from_port,
                                     rule.to_port,
                                     cidr_ip,
                                     src_group_name)
                if old_rule in new_rules:
                    # Old rule still valid
                    new_rules.remove(old_rule)
                else:
                    # Old rule no longer valid
                    group.revoke(ip_protocol=rule.ip_protocol,
                                 from_port=rule.from_port,
                                 to_port=rule.to_port,
                                 cidr_ip=cidr_ip,
                                 src_group=get_security_group(ec2conn, src_group_name) if src_group_name else None)
    else:
        group = ec2conn.create_security_group(group_name, description or group_name)

    for rule in new_rules:
        # Create new rules
        if rule.src_group_name:
            src_group = get_security_group(ec2conn, rule.src_group_name)
            if not src_group:
                raise Exception("Source group %s not found" % rule.src_group_name)
        else:
            src_group = None

        group.authorize(ip_protocol=rule.ip_protocol,
                        from_port=rule.from_port,
                        to_port=rule.to_port,
                        cidr_ip=rule.cidr_ip,
                        src_group=src_group)
    return group

# On a dry run, report what would be launched and stop here, before any
# resource (including the key pair below) is created
if options.dry_run:
    print >> sys.stderr, "run_instances:", dict(image_id=options.ami,
                                                instance_type=options.type,
                                                key_name=options.key_name,
                                                security_groups=security_groups)
    print >> sys.stderr, "startup_script:", startup_script
    sys.exit(1)

# Create key pair, if needed: the private key is saved in ssh_dir and
# made readable/writable by the owner only
if not os.path.exists(key_file):
    key_pair = ec2.create_key_pair(options.key_name)
    key_pair.save(ssh_dir)
    os.chmod(key_file, 0o600)

# Create security groups as needed
# (every group listed in SecurityGroups is created or updated, not only the
# one selected for this instance)
for group_name, description, rules in SecurityGroups:
    create_or_update_security_group(ec2, group_name, description=description, rules=rules)

# Launch instance, passing the startup script as user data
reservation = ec2.run_instances(image_id=options.ami,
                                instance_type=options.type,
                                key_name=options.key_name,
                                security_groups=security_groups,
                                user_data=startup_script)
# Only the first instance of the reservation is used from here on
instance = reservation.instances[0]

# Wait for instance to start running
# HTML progress line: instance tag, current status, seconds waited so far
Status_template =  """<em>Creating instance</em> <b>%s</b>: status=<b>%s</b> (waiting %ds)"""

status = instance.update()
start_time = time.time()

if options.text:
    print >> sys.stderr, "Waiting for instance to be created..."
    
# Poll every 3 seconds for as long as the status is "pending"; there is no
# upper limit on the wait. Progress is displayed only in non-text mode.
while status == "pending":
    timesec = int(time.time() - start_time)
    if not options.text:
        gtermapi.write_html(Status_template % (instance_tag, status, timesec), display="fullpage")
    time.sleep(3)
    status = instance.update()

# Any final status other than "running" is treated as a failed launch
if status != "running":
    print >> sys.stderr, "ec2launch: ERROR Failed to launch instance: %s" % status
    sys.exit(1)

# Tag instance
instance.add_tag(instance_tag)

instance_id = reservation.id
instance_obj = None
all_instances = ec2.get_all_instances()
for r in all_instances:
    if r.id == instance_id:
        instance_obj = r.instances[0]
        break

if not instance_obj:
    print >> sys.stderr, "ec2launch: ERROR Unable to find launched instance: %s" % status
    sys.exit(1)

public_dns_name = instance_obj.public_dns_name

if hosted_zone:
    # Create new CNAME entry pointing to instance public DNS
    ec2common.cname(route53conn, hosted_zone, instance_tag, public_dns_name)

print "Created EC2 instance %s: id=%s, public=%s" % (instance_tag, instance_id, public_dns_name)

ssh_cmd_args = ["ssh", "-i", os.path.expanduser("~/.ssh/ec2key.pem"),
                "-o", "StrictHostKeyChecking=no", gterm_user+"@"+public_dns_name]
wait_cmd = "while [ ! -f "+setup_file+" ]; do sleep 1; done"


if nameservers:
    print >> sys.stderr, "***NOTE*** Please add the following nameservers for domain "+instance_domain
    print >> sys.stderr, "           ", nameservers
    print >> sys.stderr, ""

print >> sys.stderr, "Waiting for remote setup to complete..."
time.sleep(30)
out, err = gtermapi.command_output(ssh_cmd_args + [wait_cmd], timeout=360)

if out:
    print out
if err:
    print >> sys.stderr, err
else:
    print >> sys.stderr, "Remote setup completed."
    
