#!/usr/bin/env python
#

"""
ec2launch: Launch Amazon AWS EC2 instance
"""

import collections
import json
import os
import random
import re
import sys
import time

import boto
import ec2common

import gterm

# Directory that holds the SSH private key (<key_name>.pem) for the instance
ssh_dir = os.path.expanduser("~/.ssh")
# Account used for ssh/scp logins and for running the GraphTerm command remotely
gterm_user = "ubuntu"

usage = "usage: %prog [-h ... ] <instance-tag|full-domain-name>"

# Option table shared by the command line and the GraphTerm form display.
# For each option the second argument is its default; a tuple lists the
# choices (NOTE(review): presumably the first choice is the default -- confirm
# against gterm.FormParser).
form_parser = gterm.FormParser(usage=usage, title="Create Amazon EC2 instance with hostname (e.g., tagname.graphterm.net): ", command="ec2launch -f")
form_parser.add_argument(label="", help="Instance hostname")
form_parser.add_option("type", ("m3.medium", "m3.large", "c3.large", "r3.large", "r3.xlarge", "m1.small", "t1.micro"), help="Instance type")
form_parser.add_option("key_name", "ec2key", help="Instance management SSH key name")
form_parser.add_option("ami", "ami-2f8f9246", help="Instance OS (default: Ubuntu 12.04LTS amd64, us-east-1, rel. 20140408)")
form_parser.add_option("https", False, help="Use https for security")
form_parser.add_option("auth_type", ("local", "none", "name", "multiuser"), help="Authentication type (local/none/name/multiuser)")
form_parser.add_option("pylab", False, help="Install PyLab")
form_parser.add_option("netcdf", False, help="Install NetCDF")
form_parser.add_option("R", False, help="Install R")
form_parser.add_option("allow_embed", False, help="Allow cross-domain iframe embedding")
form_parser.add_option("allow_share", False, help="Allow sharing between users")
# Help text fixed: previously read "passord-proetected"
form_parser.add_option("notebook_server", False, help="Allow password-protected access to IPython Notebook server")
form_parser.add_option("logging", False, help="Enable logging")
form_parser.add_option("oshell", False, help="Enable OTrace shell")
form_parser.add_option("screen", False, help="Run server/host using GNU screen")
form_parser.add_option("other_opts", "", help="Other options for gtermserver (space-separated)")
form_parser.add_option("install", "", help="Tar file to install on startup (with optional '; command' suffix)")
form_parser.add_option("copy_files", "", help="List of files to copy (space separated, e.g., .twitter_auth)")
form_parser.add_option("connect_to_gterm", "", help="Full domain name of GraphTerm server to connect to (for hosts)")
form_parser.add_option("dry_run", False, help="Dry run")
form_parser.add_option("verbose", False, help="Verbose")
##form_parser.add_option("wildcard", False, help="Enable wildcard subdomains")

# Display-control options (registered with raw=True)
form_parser.add_option("form", False, help="Force form display", raw=True)
form_parser.add_option("fullpage", False, short="f", help="Fullpage display", raw=True)
form_parser.add_option("text", False, short="t", help="Text only", raw=True)

(options, args) = form_parser.parse_args()

# Force text-only mode unless a GraphTerm cookie is present and stdout is a tty
if not (gterm.Lterm_cookie and sys.stdout.isatty()):
    options.text = True

# With no hostname argument (or when the form is explicitly requested) there is
# nothing to launch: show the usage text or the input form, then stop.
if options.form or not args:
    if options.text:
        print >> sys.stderr, form_parser.get_usage()
    else:
        # Prefill the form with the current values only when it was forced via --form
        form_prefill = (options, args) if options.form else None
        gterm.write_form(form_parser.create_form(prefill=form_prefill), command="ec2launch -f")
    sys.exit(1)
    
# The first positional argument is the instance tag (optionally a full domain name)
instance_tag = args[0]

# Must start with a lowercase letter, followed only by lowercase letters,
# digits, dots and hyphens
if re.match(r"^[a-z][a-z0-9\.\-]*$", instance_tag) is None:
    raise Exception("Invalid characters in instance name: %s" % instance_tag)

# Split at the first dot: instance_domain is "" when only a bare tag was given
instance_name, sep, instance_domain = instance_tag.partition(".")

# Private key file matching the selected key pair name
key_file = os.path.join(ssh_dir, options.key_name+".pem")

ssh_dir_missing = not os.path.exists(ssh_dir)
if ssh_dir_missing:
    print >> sys.stderr, "ec2launch: %s directory not found!" % ssh_dir
    sys.exit(1)

# Parse the --install option ("archive[; command]") and build the list of
# local files to copy to the instance.
install_file, install_cmd, install_basename, install_baseroot, unarch_cmd = "", "", "", "", ""
copy_files = []

# Unarchive command for each supported single-suffix extension
# (".tar.gz" is handled separately because splitext() only sees ".gz")
_UNARCHIVE_COMMANDS = {".tgz": "tar zxf",
                       ".tar": "tar xf",
                       ".zip": "unzip"}

if options.install:
    install_file, sep, install_cmd = options.install.partition(";")
    # Tolerate blanks around the ';' separator, e.g. "pkg.tar.gz ; make";
    # without this the file lookup fails on the trailing space
    install_file = install_file.strip()
    install_cmd = install_cmd.strip()
    if not os.path.isfile(install_file):
        raise Exception("Install file "+install_file+" not found")
    install_basename = os.path.basename(install_file)
    install_baseroot, ext = os.path.splitext(install_basename)
    if ext == ".gz":
        # Only gzipped tar archives are accepted; strip the inner ".tar" too
        install_baseroot, ext = os.path.splitext(install_baseroot)
        if ext != ".tar":
            raise Exception("Invalid extension for install file: "+install_file)
        unarch_cmd = "tar zxf"
    elif ext in _UNARCHIVE_COMMANDS:
        unarch_cmd = _UNARCHIVE_COMMANDS[ext]
    else:
        raise Exception("Invalid extension for install file: "+install_file)
    copy_files.append(install_file)

# Additional files requested via --copy_files (whitespace separated)
if options.copy_files:
    copy_files += options.copy_files.split()

auth_type = options.auth_type.strip()
allow_host_connect = auth_type in ("local", "multiuser")

# One ingress rule; src_group_name is None for rules granted to a CIDR block
GroupRule = collections.namedtuple(
    "GroupRule",
    ["ip_protocol", "from_port", "to_port", "cidr_ip", "src_group_name"],
)

host_group = "gtermhost"
server_group = "gtermserver"
nb_server_group = "gtermnbserver"


def _tcp_rule(from_port, to_port=None, src_group_name=None):
    """Build a TCP GroupRule for a port (or port range) with CIDR 0.0.0.0/0."""
    return GroupRule("tcp", from_port, to_port or from_port, "0.0.0.0/0", src_group_name)


# Port 8899 carries no source group when allow_host_connect is set;
# otherwise the rule names the host security group as its source
_port_8899_rule = _tcp_rule("8899", src_group_name=None if allow_host_connect else host_group)

# (group name, description, ingress rules) for every group this script manages.
# The host group comes first so that it exists before a rule refers to it.
SecurityGroups = [
    (host_group, "GraphTerm host group", [
        _tcp_rule("22"),
        ]),
    (server_group, "GraphTerm server group", [
        _tcp_rule("22"),
        _tcp_rule("80"),
        _tcp_rule("443"),
        _tcp_rule("8888"),
        _port_8899_rule,
        _tcp_rule("8900"),
        ]),
    (nb_server_group, "GraphTerm notebook server group", [
        _tcp_rule("22"),
        _tcp_rule("80"),
        _tcp_rule("443"),
        _port_8899_rule,
        _tcp_rule("8900"),
        _tcp_rule("10000", "12000"),
        ]),
    ]

# Marker files created in the login user's home directory when setup ends
setup_file = "SETUP_STATUS"
setup_errfile = "SETUP_ERROR"

# Lines of the script passed to the instance as user data. The first part
# queries the EC2 metadata service; the "cat << EOF" line starts a here-document
# that writes /root/ec2launch.sh (the closing EOF is appended further below).
startup_commands = [
    "#!/bin/bash",
    "AWS_PUBLIC_HOST=`curl http://169.254.169.254/latest/meta-data/public-hostname`",
    "AWS_PUBLIC_IP=`curl http://169.254.169.254/latest/meta-data/public-ipv4`",
    "AWS_LOCAL_IP=`curl http://169.254.169.254/latest/meta-data/local-ipv4`",
    "cat > /root/ec2launch.sh << EOF",
    "#!/bin/bash",
    "set -e -x",
    "apt-get update && apt-get upgrade -qy",
    "apt-get install -qy python-setuptools python-pip",
    "pip install tornado",
]

# Skip the pip install of graphterm when the archive to be installed is itself
# named graphterm*
install_gterm = not install_basename.startswith("graphterm")

if install_gterm:
    startup_commands.extend(["pip install graphterm", "gterm_setup"])

if options.pylab:
    startup_commands.extend([
        "apt-get install -qy python-numpy python-scipy python-matplotlib python-mpltoolkits.basemap python-scientific python-pandas python-jinja2 ipython",
        "pip install pil qrcode",
        "pip install --upgrade ipython pyzmq",
    ])

if options.netcdf:
    startup_commands.extend([
        "apt-get install -qy libnetcdf-dev netcdf-bin libhdf5-serial-dev",
        "pip install netCDF4",
    ])

if options.R:
    startup_commands.append("apt-get install -qy r-base libcurl4-openssl-dev libcairo2-dev libxt-dev")

# Run gtermhost when connecting to an existing server, gtermserver otherwise
command_name = "gtermhost" if options.connect_to_gterm else "gtermserver"

# Collect the space-separated pieces of the remote command, then join them once
_command_words = ["sudo", "-u", gterm_user, "-i"]
if options.screen:
    # Detached screen session named after the command (-L when logging is on)
    _command_words += ["screen", "-d", "-m"]
    if options.logging:
        _command_words.append("-L")
    _command_words += ["-S", command_name, command_name]
else:
    _command_words += [command_name, "--daemon=start"]

_command_words.append("--widget_port=-1")

# Boolean options that translate directly into a single flag, in output order
for _enabled, _flag in ((options.https, "--https"),
                        (options.allow_embed, "--allow_embed"),
                        (options.allow_share, "--allow_share"),
                        (options.notebook_server, "--nb_server")):
    if _enabled:
        _command_words.append(_flag)

if options.other_opts:
    _command_words.append(options.other_opts.strip())
if options.logging:
    _command_words.append("--logging")
if options.oshell:
    _command_words.append("--oshell")
    if options.screen:
        _command_words.append("--oshell_input")

command_line = " ".join(_command_words)
# Finish the command line for either a host (connecting to an existing
# GraphTerm server) or a stand-alone server, and pick the security group.
init_commands = []
fwall_commands = []
if options.connect_to_gterm:
    # Host mode: connect to the named GraphTerm server
    security_groups = [host_group]
    server_name, sep, server_domain = options.connect_to_gterm.partition(".")
    if not server_domain:
        raise Exception("Must specify fully qualified domain name to connect to GraphTerm server")
    # Use the short name when host and server share a domain
    host_name = instance_name if instance_domain == server_domain else instance_tag
    command_line += " --server_addr="+ options.connect_to_gterm+" "+host_name
    init_commands = [ command_line ]
else:
    # Server mode
    security_groups = [nb_server_group if options.notebook_server else server_group]

    external_port = 443 if options.https else 80
    # ${AWS_...} variables are set at the top of the startup script
    public_host = instance_tag if instance_domain else "${AWS_PUBLIC_HOST}"
    local_ip = "${AWS_LOCAL_IP}"
    # Redirect the external port to GraphTerm's default HTTP port
    fwall_commands = ["iptables -t nat -A PREROUTING -p tcp --dport %d -j REDIRECT --to %d" % (external_port, gterm.DEFAULT_HTTP_PORT),
                      "iptables -t nat -A OUTPUT --dst %s -p tcp --dport %d -j DNAT --to-destination %s:%d" % (local_ip, external_port, local_ip, gterm.DEFAULT_HTTP_PORT)]
    # Emit the same stripped value that the comparison below tests, so stray
    # blanks cannot split the option from its value on the command line
    command_line += " --auth_type="+auth_type
    if auth_type == "multiuser":
        # The super user is the account the server runs as (see "sudo -u")
        command_line += " --auto_users --super_users="+gterm_user
    command_line += " --host=%s --external_port=%d" % (public_host, external_port)
    ##if options.wildcard:
    ##    command_line += " --blob_host=wildcard"
    if install_gterm:
        # When graphterm comes from the install archive instead, the server
        # command is not added to the startup script
        init_commands = [ command_line ]

# Firewall and server start commands: run them during first-boot setup and
# also record them in /etc/init.d/graphterm, registered via update-rc.d.
initwall_commands = fwall_commands + init_commands
if initwall_commands:
    startup_commands += initwall_commands
    for cmd_index, cmd in enumerate(initwall_commands):
        # The first echo creates/truncates the init script; the rest must
        # append. Using ">" for every command left only the last one in it.
        redirect = ">" if cmd_index == 0 else ">>"
        startup_commands += [ 'echo "'+cmd+'" '+redirect+' /etc/init.d/graphterm' ]
    startup_commands += [ "chmod +x /etc/init.d/graphterm",
                          "update-rc.d graphterm defaults"]

# Last commands written into /root/ec2launch.sh
# NOTE(review): the script runs under "set -e", so a failing "pip uninstall dap"
# would abort it before the completion message -- confirm dap is always present.
startup_commands.extend([
    "echo ''",
    "echo WORKAROUND: Doing 'pip uninstall dap' to eliminate dap module loading errors for Ubuntu 12.04LTS",
    "yes | pip uninstall dap",
    "echo ec2launch SETUP COMPLETED",
])

# Close the here-document, run the generated script with output logged, and
# leave marker files in the login user's home directory: the error marker only
# on failure, the status marker (holding the log tail) always.
_user_home = "~" + gterm_user + "/"
_touch_as_user = "sudo -u " + gterm_user + " -i touch "
startup_commands.extend([
    "EOF",
    "/bin/bash /root/ec2launch.sh >/root/ec2launch.log 2>&1",
    "if [ $? -ne 0 ]; then",
    "  " + _touch_as_user + _user_home + setup_errfile,
    "fi",
    _touch_as_user + _user_home + setup_file,
    "tail -10 /root/ec2launch.log > " + _user_home + setup_file,
])
startup_script = "\n".join(startup_commands) + "\n"

if options.verbose:
    _indented_commands = "\n   ".join(startup_commands)
    print >> sys.stderr, "Startup_commands:\n   " + _indented_commands + "\n\n"

# Open the EC2 connection
ec2 = boto.connect_ec2()

# Route53 is only involved when a full domain name was given
route53conn, hosted_zone, nameservers = None, None, ""
if instance_domain:
    route53conn = ec2common.Route53Connection()
    hosted_zone = ec2common.get_hosted_zone(route53conn, instance_domain)
    zone_missing = not hosted_zone
    if zone_missing:
        # New zone: its nameservers are reported to the user further below
        hosted_zone = ec2common.create_hosted_zone(route53conn, instance_domain)
        nameservers = ec2common.get_nameservers(route53conn, instance_domain)

def get_security_group(ec2conn, group_name):
    """Return the first security group whose name equals group_name, else None"""
    for candidate in ec2conn.get_all_security_groups():
        if candidate.name == group_name:
            return candidate
    return None

def create_or_update_security_group(ec2conn, group_name, description="", rules=None):
    """Create (or update) security group

    Makes the ingress rules of the group named group_name match rules:
    existing grants that are not in rules are revoked, and rules that are
    not yet present are authorized. The group is created when it does not
    exist, using description (or the group name if description is empty).

    ec2conn: EC2 connection object
    group_name: name of the security group
    description: description used only when the group is created
    rules: iterable of GroupRule tuples (None or empty means no rules)

    Returns the security group object.
    Raises Exception if a rule names a source group that cannot be found.
    Errors from authorizing a rule are not fatal; they are reported on
    stderr only when options.verbose is set.
    """
    new_rules = list(rules) if rules else []
    group = get_security_group(ec2conn, group_name)
    if group:
        # Group already exists. Iterate over snapshots of the rule and grant
        # lists, since revoking may prune the live lists mid-iteration
        # (NOTE(review): boto's SecurityGroup.revoke appears to do this -- confirm).
        for rule in list(group.rules):
            # A rule may carry several grants (CIDR blocks and/or source
            # groups); check every one of them, not just the first
            for grant in list(rule.grants):
                cidr_ip = grant.cidr_ip if grant.cidr_ip else "0.0.0.0/0"
                src_group_name = None if grant.cidr_ip else grant.name
                old_rule = GroupRule(rule.ip_protocol,
                                     rule.from_port,
                                     rule.to_port,
                                     cidr_ip,
                                     src_group_name)
                if old_rule in new_rules:
                    # Old rule still valid
                    new_rules.remove(old_rule)
                else:
                    # Old rule no longer valid
                    group.revoke(ip_protocol=rule.ip_protocol,
                                 from_port=rule.from_port,
                                 to_port=rule.to_port,
                                 cidr_ip=cidr_ip,
                                 src_group=get_security_group(ec2conn, src_group_name) if src_group_name else None)
    else:
        group = ec2conn.create_security_group(group_name, description or group_name)

    for rule in new_rules:
        # Create new rules
        if rule.src_group_name:
            src_group = get_security_group(ec2conn, rule.src_group_name)
            if not src_group:
                raise Exception("Source group %s not found" % rule.src_group_name)
        else:
            src_group = None

        try:
            group.authorize(ip_protocol=rule.ip_protocol,
                            from_port=rule.from_port,
                            to_port=rule.to_port,
                            cidr_ip=rule.cidr_ip,
                            src_group=src_group)
        except Exception as excp:
            # Deliberately best-effort: a failed authorization must not abort
            # the launch
            if options.verbose:
                print >> sys.stderr, "Rule creation error: "+str(excp)+"\n\n"

    return group

# Create the EC2 key pair and save it locally when no private key file exists
if not os.path.exists(key_file):
    key_pair = ec2.create_key_pair(options.key_name)
    key_pair.save(ssh_dir)
    # Owner read/write only
    os.chmod(key_file, 0o600)

# Remote shell command: block until the status marker appears, then show it
_wait_parts = ["while [ ! -f %s ]; do sleep 1; done; echo %s; cat %s" % (setup_file, setup_file, setup_file)]
if copy_files and options.install:
    # Unless setup failed, unpack the copied archive and run its setup.py,
    # followed by the optional user-supplied command
    _wait_parts.append("; if [ ! -f %s ]; then %s %s; cd %s; sudo python setup.py install > install.log"
                       % (setup_errfile, unarch_cmd, install_basename, install_baseroot))
    if install_cmd:
        _wait_parts.append("; " + install_cmd)
    _wait_parts.append("; fi")
wait_cmd = "".join(_wait_parts)

# Dry run: report what would be launched and stop before touching security
# groups or instances.
# NOTE(review): the EC2/Route53 connection and the key pair creation above
# have already run by this point.
if options.dry_run:
    launch_preview = dict(image_id=options.ami,
                          instance_type=options.type,
                          key_name=options.key_name,
                          security_groups=security_groups)
    print >> sys.stderr, "run_instances:", launch_preview
    print >> sys.stderr, "STARTUP_SCRIPT:\n", startup_script
    print >> sys.stderr, "WAIT_CMD:\n", wait_cmd
    sys.exit(1)

# Bring every managed security group up to date (creating it if necessary)
for sg_name, sg_description, sg_rules in SecurityGroups:
    create_or_update_security_group(ec2, sg_name, description=sg_description, rules=sg_rules)

# Launch a single instance; the startup script is passed as user data
launch_kwargs = dict(image_id=options.ami,
                     instance_type=options.type,
                     key_name=options.key_name,
                     security_groups=security_groups,
                     user_data=startup_script)
reservation = ec2.run_instances(**launch_kwargs)
instance = reservation.instances[0]

# Poll the instance until it leaves the "pending" state
Status_template = "<em>Creating instance</em> <b>%s</b>: status=<b>%s</b> (waiting %ds)"

status = instance.update()
start_time = time.time()

show_pagelet = not options.text
if not show_pagelet:
    print >> sys.stderr, "Waiting for instance to be created..."

while status == "pending":
    timesec = int(time.time() - start_time)
    if show_pagelet:
        # Overwrite the previous status pagelet with the current elapsed time
        status_html = Status_template % (instance_tag, status, timesec)
        gterm.write_pagelet(status_html, overwrite=True, display="fullpage")
    time.sleep(3)
    status = instance.update()

# Any final state other than "running" is a launch failure
if status != "running":
    print >> sys.stderr, "ec2launch: ERROR Failed to launch instance: %s" % status
    sys.exit(1)

# Tag instance (the instance tag is used as the tag key)
instance.add_tag(instance_tag)

# Look the launched instance up again through its reservation
reservation_id = reservation.id
instance_obj = None
all_instances = ec2.get_all_instances()
for r in all_instances:
    if r.id == reservation_id:
        instance_obj = r.instances[0]
        break

if not instance_obj:
    print >> sys.stderr, "ec2launch: ERROR Unable to find launched instance: %s" % status
    sys.exit(1)

# Report the instance's own id; previously the reservation id was stored in
# instance_id and shown as "id="
instance_id = instance_obj.id
public_dns_name = instance_obj.public_dns_name
public_ip_addr = instance_obj.ip_address

if hosted_zone:
    # Create new CNAME entry pointing to instance public DNS
    ec2common.cname(route53conn, hosted_zone, instance_tag, public_dns_name)
    ##if options.wildcard:
    ##    ec2common.cname(route53conn, hosted_zone, "*."+instance_tag, public_dns_name)

print >> sys.stderr, "Created EC2 instance %s: id=%s, ip=%s, domain=%s" % (instance_tag, instance_id, public_ip_addr, public_dns_name)
print >> sys.stderr, "To check setup progress, type:"
print >> sys.stderr, "   ec2ssh ubuntu@"+(instance_tag if instance_domain else public_ip_addr)+" sudo tail -f /root/ec2launch.log"

ssh_cmd_args = ["ssh", "-i", os.path.expanduser("~/.ssh/ec2key.pem"),
                "-o", "StrictHostKeyChecking=no", gterm_user+"@"+public_dns_name]

if nameservers:
    print >> sys.stderr, "***NOTE*** Please add the following nameservers for domain "+instance_domain
    print >> sys.stderr, "           ", nameservers
    print >> sys.stderr, ""

print >> sys.stderr, "Waiting for remote setup to complete..."
time.sleep(45)

if copy_files:
    scp_cmd_args = ["scp", "-i", os.path.expanduser("~/.ssh/ec2key.pem"), "-o", "StrictHostKeyChecking=no"] + copy_files + [gterm_user+"@"+public_dns_name+":"]
    print >> sys.stderr, " ".join(scp_cmd_args)
    out, err = gterm.command_output(scp_cmd_args, timeout=60)
    if out:
        print out
    if err:
        print err

print >> sys.stderr, " ".join(ssh_cmd_args)
print >> sys.stderr, wait_cmd
out, err = gterm.command_output(ssh_cmd_args + [wait_cmd], timeout=600)

if out:
    print out
if err:
    print >> sys.stderr, err
else:
    print >> sys.stderr, "Remote setup completed."
    
