#!/usr/bin/env python
#

"""
ec2launch: Launch Amazon AWS EC2 instance
"""

import collections
import json
import os
import random
import re
import sys
import time
from optparse import OptionParser

import boto
import ec2common

import gtermapi

# Local SSH directory and the login account used on the launched instance
ssh_dir = os.path.expanduser("~/.ssh")
gterm_user = "ubuntu"

usage = "usage: %prog [-h ... ] <instance-tag|full-domain-name>"

# Parser that serves both the command line and the GraphTerm input form
form_parser = gtermapi.FormParser(
    usage=usage,
    title="Create Amazon EC2 instance with hostname (e.g., tagname.graphterm.net): ",
    command="ec2launch -f")
form_parser.add_argument(label="", help="Instance hostname")

# (option name, default value, help text); options are registered in this order
_form_options = [
    ("type", ("t1.micro", "m1.small", "m1.medium", "m1.large"), "Instance type"),
    ("key_name", "ec2key", "Instance management SSH key name"),
    ("ami", "ami-82fa58eb", "Instance OS (default: Ubuntu 12.04LTS)"),
    ("auth_code", "none", "User authentication code (single space to suppress graphterm startup)"),
    ("server_secret", "", "Server secret for host authentication (to connect from anywhere)"),
    ("connect_to_gterm", "", "Full domain name of GraphTerm server to connect to"),
    ("install", "", "Tar file to install on startup (with optional '; command' suffix)"),
    ("copy_files", "", "List of files to copy (space separated, e.g., .twitter_auth)"),
    ("pylab", False, "Install PyLab"),
    ("R", False, "Install R"),
    ("oshell", False, "Enable OTrace shell"),
    ("dry_run", False, "Dry run"),
    ("verbose", False, "Verbose"),
    ("wildcard", False, "Enable wildcard subdomains"),
]
for _opt_name, _opt_default, _opt_help in _form_options:
    form_parser.add_option(_opt_name, _opt_default, help=_opt_help)

# Display-mode flags, registered with short names and raw=True
form_parser.add_option("fullpage", False, short="f", help="Fullpage display", raw=True)
form_parser.add_option("text", False, short="t", help="Text only", raw=True)

(options, args) = form_parser.parse_args()

# Force text mode when there is no GraphTerm cookie or stdout is not a tty
if not (gtermapi.Lterm_cookie and sys.stdout.isatty()):
    options.text = True

if not args:
    # No hostname given: print usage (text mode) or display the input form,
    # then exit with status 1 either way
    if options.text:
        print >> sys.stderr, form_parser.get_usage()
    else:
        gtermapi.write_form(form_parser.create_form(), command="ec2launch -f")
    sys.exit(1)

instance_tag = args[0]

# Tag must start with a lowercase letter, followed by lowercase letters,
# digits, dots or hyphens
if re.match(r"^[a-z][a-z0-9\.\-]*$", instance_tag) is None:
    raise Exception("Invalid characters in instance name: "+instance_tag)

# Text before the first "." is the instance name; the remainder (possibly
# empty) is the domain
instance_name, sep, instance_domain = instance_tag.partition(".")

# Private key file for the selected key pair name
key_file = os.path.join(ssh_dir, options.key_name+".pem")

if not os.path.exists(ssh_dir):
    print >> sys.stderr, "ec2launch: %s directory not found!" % ssh_dir
    sys.exit(1)

# Defaults used when no install archive is requested
install_file = install_cmd = install_basename = install_baseroot = unarch_cmd = ""
copy_files = []

if options.install:
    # Option value has the form "<archive>[;<command>]"
    install_file, sep, install_cmd = options.install.partition(";")
    if not os.path.isfile(install_file):
        raise Exception("Install file "+install_file+" not found")

    # Supported archive suffixes and the command used to unpack each
    unarch_by_ext = {
        ".tar.gz": "tar zxf",
        ".tgz": "tar zxf",
        ".tar": "tar xf",
        ".zip": "unzip",
    }

    install_basename = os.path.basename(install_file)
    install_baseroot, ext = os.path.splitext(install_basename)
    if ext == ".gz":
        # Look one suffix deeper so that only "<name>.tar.gz" is accepted
        install_baseroot, inner_ext = os.path.splitext(install_baseroot)
        ext = inner_ext + ext

    unarch_cmd = unarch_by_ext.get(ext)
    if unarch_cmd is None:
        raise Exception("Invalid extension for install file: "+install_file)

    # The archive itself is copied to the instance along with any extra files
    copy_files.append(install_file)

if options.copy_files:
    copy_files.extend(options.copy_files.split())

# One ingress rule: protocol, port range, source CIDR and optional source group
GroupRule = collections.namedtuple(
    "GroupRule",
    "ip_protocol from_port to_port cidr_ip src_group_name")

HostGroupName = "gtermhost"
ServerGroupName = "gtermserver"

_ANY_ADDRESS = "0.0.0.0/0"

# Hosts only accept SSH
_host_rules = [
    GroupRule("tcp", "22", "22", _ANY_ADDRESS, None),
]

# Port 8899 is limited to members of the host group unless a server secret
# has been supplied
_port_8899_source = None if options.server_secret else HostGroupName

_server_rules = [
    GroupRule("tcp", "22", "22", _ANY_ADDRESS, None),
    GroupRule("tcp", "80", "80", _ANY_ADDRESS, None),
    GroupRule("tcp", "8888", "8888", _ANY_ADDRESS, None),
    GroupRule("tcp", "8899", "8899", _ANY_ADDRESS, _port_8899_source),
    GroupRule("tcp", "8900", "8900", _ANY_ADDRESS, None),
]

# (group name, description, rules) for every security group managed here
SecurityGroups = [
    (HostGroupName, "GraphTerm host group", _host_rules),
    (ServerGroupName, "GraphTerm server group", _server_rules),
]

# Marker file touched in the login user's home directory at the end of the
# startup script
setup_file = "SETUP_OVER"

# Startup script lines run on the new instance (passed as user data)
startup_commands = [
    "#!/bin/bash",
    "set -e -x",
    "apt-get update && apt-get upgrade -y",
    "apt-get install -y python-setuptools",
    "easy_install tornado",
]

# Skip the packaged graphterm when the install archive is itself a graphterm
# distribution
install_gterm = not install_basename.startswith("graphterm")
if install_gterm:
    startup_commands.extend(["easy_install graphterm", "gterm_setup"])

if options.pylab:
    startup_commands.append("apt-get install -y python-numpy python-scipy python-matplotlib python-scientific libnetcdf-dev netcdf-bin ipython")
    startup_commands.append("easy_install pil qrcode")

if options.R:
    startup_commands.append("apt-get install -y r-base libcurl4-openssl-dev libcairo2-dev libxt-dev")

# Template for the daemon start command; the remaining %s is filled in later
# with the program name (gtermhost or gtermserver)
command_format = "sudo -u %s -i %%s --daemon=start" % gterm_user
if options.oshell:
    command_format = command_format + " --oshell"
if options.server_secret:
    command_format = command_format + ' --server_secret="' + options.server_secret + '"'

# Daemon start command and firewall command; either may stay empty
init_command = ""
fwall_command = ""
if options.connect_to_gterm:
    # Host mode: the instance connects to an existing GraphTerm server
    security_groups = [HostGroupName]
    server_name, sep, server_domain = options.connect_to_gterm.partition(".")
    if not server_domain:
        raise Exception("Must specify fully qualified domain name to connect to GraphTerm server")
    # Use the short name when the instance shares the server's domain
    host_name = instance_name if instance_domain == server_domain else instance_tag
    init_command = (command_format % "gtermhost") + " --server_addr="+ options.connect_to_gterm+" "+host_name
    startup_commands += [ init_command ]
else:
    # Server mode: the instance runs its own GraphTerm server
    security_groups = [ServerGroupName]

    # Without a domain, the public hostname is looked up from the EC2
    # metadata service when the command runs on the instance
    host_domain = instance_tag if instance_domain else "`curl http://169.254.169.254/latest/meta-data/public-hostname`"
    # Redirect incoming port 80 traffic to port 8900
    fwall_command = "iptables -t nat -A PREROUTING -p tcp --dport 80 -j REDIRECT --to 8900"
    startup_commands += [ fwall_command ]
    if install_gterm and options.auth_code.strip():
        init_command = (command_format % "gtermserver") + ' --auth_code="'+options.auth_code+'" --host='+host_domain
        if options.wildcard:
            init_command += " --blob_host=wildcard"
        startup_commands += [ init_command ]

if init_command or fwall_command:
    # Record both commands in an init script so they run again on reboot.
    # The first echo truncates the file; the second must append (">>"),
    # otherwise it would overwrite the daemon start command.
    startup_commands += [ "echo '"+init_command+"' > /etc/init.d/graphterm",
                          "echo '"+fwall_command+"' >> /etc/init.d/graphterm",
                          "chmod +x /etc/init.d/graphterm",
                          "update-rc.d graphterm defaults"]

# Signal completion by creating the marker file in the login user's home
startup_commands += ["sudo -u "+gterm_user+" -i touch ~" + gterm_user + "/" + setup_file]

startup_script = "\n".join(startup_commands) + "\n"

if options.verbose:
    print >> sys.stderr, "Startup_commands:\n   "+ "\n   ".join(startup_commands) + "\n\n"

# Connect to EC2
ec2 = boto.connect_ec2()

# Route53 state; only populated when the instance tag includes a domain
route53conn = None
hosted_zone = None
nameservers = ""
if instance_domain:
    route53conn = ec2common.Route53Connection()
    hosted_zone = ec2common.get_hosted_zone(route53conn, instance_domain)
    # A dry run must not create AWS resources, so a missing zone is only
    # created for a real launch
    if not hosted_zone and not options.dry_run:
        hosted_zone = ec2common.create_hosted_zone(route53conn, instance_domain)
        # Nameservers are reported to the user only for a newly created zone
        nameservers = ec2common.get_nameservers(route53conn, instance_domain)

def get_security_group(ec2conn, group_name):
    """Return the first security group whose name is group_name, or None.

    ec2conn must provide get_all_security_groups(), returning objects that
    have a ``name`` attribute.
    """
    for candidate in ec2conn.get_all_security_groups():
        if candidate.name == group_name:
            return candidate
    return None

def create_or_update_security_group(ec2conn, group_name, description="", rules=None):
    """Create security group group_name, or reconcile an existing one with rules.

    ec2conn:     EC2 connection used to look up and create groups
    group_name:  name of the security group
    description: group description (group_name is used if empty)
    rules:       iterable of GroupRule tuples the group should end up with

    For an existing group, every grant of every rule is compared with the
    requested rules: grants that are not requested are revoked, and requested
    rules that are not yet present are authorized. Returns the group.
    Raises Exception if a rule names a source group that cannot be found.
    """
    # Work on a private copy; matched rules are removed as they are found
    new_rules = list(rules) if rules else []

    group = get_security_group(ec2conn, group_name)
    if group:
        # Group already exists. Iterate over snapshots in case revoke()
        # updates group.rules / rule.grants in place.
        for rule in list(group.rules):
            for grant in list(rule.grants):
                # A grant without a CIDR is a grant to a named source group
                cidr_ip = grant.cidr_ip if grant.cidr_ip else "0.0.0.0/0"
                src_group_name = None if grant.cidr_ip else grant.name
                old_rule = GroupRule(rule.ip_protocol,
                                     rule.from_port,
                                     rule.to_port,
                                     cidr_ip,
                                     src_group_name)
                if old_rule in new_rules:
                    # Old rule still valid
                    new_rules.remove(old_rule)
                else:
                    # Old rule no longer valid
                    group.revoke(ip_protocol=rule.ip_protocol,
                                 from_port=rule.from_port,
                                 to_port=rule.to_port,
                                 cidr_ip=cidr_ip,
                                 src_group=get_security_group(ec2conn, src_group_name) if src_group_name else None)
    else:
        group = ec2conn.create_security_group(group_name, description or group_name)

    for rule in new_rules:
        # Create new rules
        if rule.src_group_name:
            src_group = get_security_group(ec2conn, rule.src_group_name)
            if not src_group:
                raise Exception("Source group %s not found" % rule.src_group_name)
        else:
            src_group = None

        group.authorize(ip_protocol=rule.ip_protocol,
                        from_port=rule.from_port,
                        to_port=rule.to_port,
                        cidr_ip=rule.cidr_ip,
                        src_group=src_group)
    return group

if options.dry_run:
    # Dry run: report what would be launched and exit before the key pair
    # below is created
    print >> sys.stderr, "run_instances:", dict(image_id=options.ami,
                                                instance_type=options.type,
                                                key_name=options.key_name,
                                                security_groups=security_groups)
    print >> sys.stderr, "startup_script:", startup_script
    sys.exit(1)

# Create key pair, if needed
if not os.path.exists(key_file):
    key_pair = ec2.create_key_pair(options.key_name)
    key_pair.save(ssh_dir)
    # Restrict the private key to owner read/write
    os.chmod(key_file, 0o600)

# Make sure every managed security group exists with the expected rules
for group_spec in SecurityGroups:
    group_name, description, rules = group_spec
    create_or_update_security_group(ec2, group_name, description=description, rules=rules)

# Launch instance, handing it the startup script as user data
launch_args = dict(image_id=options.ami,
                   instance_type=options.type,
                   key_name=options.key_name,
                   security_groups=security_groups,
                   user_data=startup_script)
reservation = ec2.run_instances(**launch_args)
instance = reservation.instances[0]

# Status line shown (as a pagelet) while the instance is being created
Status_template =  "<em>Creating instance</em> <b>%s</b>: status=<b>%s</b> (waiting %ds)"

status = instance.update()
start_time = time.time()

if options.text:
    print >> sys.stderr, "Waiting for instance to be created..."

# Poll every 3 seconds for as long as the instance reports "pending"
while status == "pending":
    if not options.text:
        elapsed_sec = int(time.time() - start_time)
        progress_html = Status_template % (instance_tag, status, elapsed_sec)
        gtermapi.write_pagelet(progress_html, overwrite=True, display="fullpage")
    time.sleep(3)
    status = instance.update()

# Any final state other than "running" counts as a failed launch
if status != "running":
    print >> sys.stderr, "ec2launch: ERROR Failed to launch instance: %s" % status
    sys.exit(1)

# Tag instance
instance.add_tag(instance_tag)

# Fetch the launched instance again via its reservation (note that
# instance_id holds the reservation id)
instance_id = reservation.id
instance_obj = None
for resv in ec2.get_all_instances():
    if resv.id != instance_id:
        continue
    instance_obj = resv.instances[0]
    break

if not instance_obj:
    print >> sys.stderr, "ec2launch: ERROR Unable to find launched instance: %s" % status
    sys.exit(1)

public_dns_name = instance_obj.public_dns_name

if hosted_zone:
    # Point CNAME record(s) for the tag at the instance's public DNS name
    cname_targets = [instance_tag]
    if options.wildcard:
        cname_targets.append("*."+instance_tag)
    for cname_target in cname_targets:
        ec2common.cname(route53conn, hosted_zone, cname_target, public_dns_name)

print >> sys.stderr, "Created EC2 instance %s: id=%s, public=%s" % (instance_tag, instance_id, public_dns_name)
if options.verbose:
    # Hints for following the setup progress on the instance
    ssh_target = instance_tag if instance_domain else public_dns_name
    for hint_line in ("To check status, type:",
                      "   ec2ssh ubuntu@"+ssh_target,
                      "   sudo su -",
                      "   tail -f /var/log/dpkg.log"):
        print >> sys.stderr, hint_line

ssh_cmd_args = ["ssh", "-i", os.path.expanduser("~/.ssh/ec2key.pem"),
                "-o", "StrictHostKeyChecking=no", gterm_user+"@"+public_dns_name]
wait_cmd = "while [ ! -f "+setup_file+" ]; do sleep 1; done"

if nameservers:
    print >> sys.stderr, "***NOTE*** Please add the following nameservers for domain "+instance_domain
    print >> sys.stderr, "           ", nameservers
    print >> sys.stderr, ""

print >> sys.stderr, "Waiting for remote setup to complete..."
time.sleep(30)

if copy_files:
    scp_cmd_args = ["scp", "-i", os.path.expanduser("~/.ssh/ec2key.pem"), "-o", "StrictHostKeyChecking=no"] + copy_files + [gterm_user+"@"+public_dns_name+":"]
    print >> sys.stderr, " ".join(scp_cmd_args)
    out, err = gtermapi.command_output(scp_cmd_args, timeout=60)
    if out:
        print out
    if err:
        print err
    if options.install:
        wait_cmd += "; "+unarch_cmd+" "+install_basename+"; cd "+install_baseroot+"; sudo python setup.py install"
        if install_cmd:
            wait_cmd += "; "+install_cmd

print >> sys.stderr, " ".join(ssh_cmd_args)
print >> sys.stderr, wait_cmd
out, err = gtermapi.command_output(ssh_cmd_args + [wait_cmd], timeout=600)

if out:
    print out
if err:
    print >> sys.stderr, err
else:
    print >> sys.stderr, "Remote setup completed."
    
