#!python
# tool to automate various AWS commands
import datetime as dt
import os
import shlex
import subprocess
import sys
import time
from typing import Dict

import pytz

import ncluster

from ncluster import aws_util as u
from ncluster import util
from ncluster.aws_backend import INSTANCE_INFO
from boto3_type_annotations.ec2 import Volume
from boto3_type_annotations.ec2 import Image

# When truthy, vprint() forwards its arguments to print(); otherwise vprint() prints nothing.
VERBOSE = False


def _run_shell(user_cmd):
  """Runs shell command, returns list of outputted lines
with newlines stripped"""
  #  print(cmd)
  p = subprocess.Popen(user_cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
  (stdout, _) = p.communicate()
  stdout = stdout.decode('ascii') if stdout else ''
  lines = stdout.split('\n')
  stripped_lines = []
  for l in lines:
    stripped_line = l.strip()
    if l:
      stripped_lines.append(stripped_line)
  return stripped_lines


def _check_instance_found(instances, fragment, states=()):
  if not instances:
    if states:
      print(f"Couldn't find instances in state {states} matching '{fragment}' for key {u.get_keypair_name()}")
    else:
      print(f"Couldn't find instances matching '{fragment}' for key {u.get_keypair_name()}")
    return False
  return True


def vprint(*args):
  """Print `args` like print() does, but only when the module-level VERBOSE flag is truthy."""
  if not VERBOSE:
    return
  print(*args)


def toseconds(dt_):
  """Convert a datetime object to seconds since the Unix epoch (UTC).

  Args:
    dt_: datetime object; aware datetimes are converted to UTC, naive ones are
      taken as already being in UTC (that is how utctimetuple() treats them).

  Returns:
    Seconds since 1970-01-01T00:00:00Z as a float (sub-second part is dropped).
  """
  import calendar  # local import, only this function needs it

  # time.mktime() would interpret the UTC tuple as *local* time and make the
  # result depend on the machine's timezone; timegm() is its UTC counterpart.
  return float(calendar.timegm(dt_.utctimetuple()))


def ls(fragment=''):
  """List running instances

  Prints a link to the EC2 console, the names of matching stopped instances
  (which are otherwise ignored), a table with one row per instance returned by
  u.lookup_instances(fragment), and the instance types of spot requests whose
  state is not 'cancelled', 'closed' or 'active'.

  Args:
    fragment: passed to u.lookup_instances to select instances.
  """
  print(f"https://console.aws.amazon.com/ec2/v2/home?region={u.get_region()}")

  stopped_instances = u.lookup_instances(fragment, states=['stopped'])
  stopped_names = [u.get_name(i) for i in stopped_instances]
  if stopped_names:
    print("ignored stopped instances: ", ", ".join(stopped_names))

  instances = u.lookup_instances(fragment)
  print('-' * 80)
  print(
    f"{'name':18s} {'hours_live':>10s} {'cost_in_$':>10s} {'instance_type':>15s} {'public_ip':>15s} "
    f"{'key/owner':>15s} {'private ip':>15s}")
  print('-' * 80)
  # current time in UTC zone (default AWS); taken once so every row uses the same "now"
  now_time = dt.datetime.now(dt.timezone.utc)
  for instance in instances[::-1]:
    elapsed_hours = (toseconds(now_time) - toseconds(instance.launch_time)) / 3600
    instance_type = instance.instance_type
    if instance_type in INSTANCE_INFO:
      cost = INSTANCE_INFO[instance_type]['cost'] * elapsed_hours
    else:
      cost = -1  # no price information for this instance type
    # These attributes may be None; str() keeps the ':>15s' format specs from
    # raising TypeError on a None value.
    key_name = str(instance.key_name)
    public_ip = str(instance.public_ip_address)
    private_ip = str(instance.private_ip_address)
    # NOTE(review): key_name[9:] presumably drops a fixed-length key prefix — confirm against u.get_prefix()
    print(f"{u.get_name(instance):18s} {elapsed_hours:10.1f} {cost:10.0f} {instance_type[:5]:>15s} "
          f"{public_ip:>15s} {key_name[9:]:>15s} {private_ip:>15s} "
          f"{instance.placement_group.name} ")

  # list spot requests, ignore active ones since they show up already
  client = u.get_ec2_client()
  spot_requests = []
  for request in client.describe_spot_instance_requests()['SpotInstanceRequests']:
    # TODO(y) also ignore state == 'fulfilled'?
    if request['State'] in ('cancelled', 'closed', 'active'):
      continue
    spot_requests.append(request['LaunchSpecification']['InstanceType'])
  if spot_requests:
    print(f"Pending spot instances: {','.join(spot_requests)}")


def etchosts(_):
  """Print lines suitable for pasting into /etc/hosts.

  One "<public ip> <name>" line is printed per instance returned by
  u.lookup_instances(), sorted by name, followed by the standard localhost
  entries. The argument is ignored.
  """
  separator = '-' * 80
  entries = [(u.get_name(inst), inst.public_ip_address) for inst in u.lookup_instances()]
  print(separator)
  print("paste following into your /etc/hosts")
  print(separator)
  for host_name, address in sorted(entries):
    print(f"{address} {host_name}")

  print("""\n127.0.0.1    localhost
255.255.255.255    broadcasthost
::1             localhost""")


def _user_keypair_check(instance):
  """Assert that $USER is the user encoded in the instance's key name.

  The launching user is the part of instance.key_name that follows the first
  len(u.get_prefix()) + 1 characters.

  Raises:
    AssertionError: if the launching user differs from os.environ['USER'].
    KeyError: if USER is not set in the environment.
  """
  user_offset = len(u.get_prefix()) + 1
  launching_user = instance.key_name[user_offset:]
  current_user = os.environ['USER']
  assert launching_user == current_user, (
    f"Set USER={launching_user} to connect to this machine, and make sure their "
    f".pem file is in your ~/.ncluster")


def ssh(fragment=''):
  """SSH into the first instance returned by the lookup for `fragment`.

  Uses the keypair file from u.get_keypair_fn() and requests a tty (-t). The
  command is printed before it is executed through os.system.
  """
  matches = u.lookup_instances(fragment)
  if not _check_instance_found(matches, fragment):
    return

  target = matches[0]
  if len(matches) > 1:
    print(f"Found {len(matches)} instances matching {fragment}, connecting to most recent  {u.get_name(target)} "
          f"launched by {target.key_name}")
  else:
    print(f"Connecting to  {u.get_name(target)} launched by {target.key_name}")

  _user_keypair_check(target)
  user_cmd = (f"ssh -t -i {u.get_keypair_fn()} -o StrictHostKeyChecking=no "
              "-o ServerAliveCountMax=1 -o ServerAliveInterval=60 "
              f"{u.get_aws_username(target)}@{target.public_ip_address} ")
  print(user_cmd)
  os.system(user_cmd)


def reboot(fragment=''):
  """reboots given instance.

  Only the first instance returned by u.lookup_instances(fragment) is
  rebooted, and only after _user_keypair_check passes for it.

  Args:
    fragment: passed to u.lookup_instances to select the instance.
  """
  instances = u.lookup_instances(fragment)
  if not _check_instance_found(instances, fragment):
    return
  instance = instances[0]
  # The messages used to be copies of the ssh ones ("connecting to ...",
  # "Rebooting to ...") and lacked a space before "launched by".
  if len(instances) > 1:
    print(f"Found {len(instances)} instances matching {fragment}, rebooting most recent  {u.get_name(instance)} "
          f"({instance.id}) launched by {instance.key_name}")
  else:
    print(f"Rebooting {u.get_name(instance)} ({instance.id}) "
          f"launched by {instance.key_name}")

  _user_keypair_check(instance)
  instance.reboot()


def old_ssh(fragment=''):
  """SSH into the first instance returned by the lookup for `fragment`. Works on dumb terminals.

  Unlike ssh(), no tty is requested (-t is not passed) and a 10 second connect timeout is set.
  """
  matches = u.lookup_instances(fragment)
  if not _check_instance_found(matches, fragment):
    return

  target = matches[0]
  if len(matches) > 1:
    print(f"Found {len(matches)} instances matching {fragment}, connecting to most recent  {u.get_name(target)} "
          f"launched by {target.key_name}")
  else:
    print(f"Connecting to  {u.get_name(target)} launched by {target.key_name}")

  _user_keypair_check(target)
  user_cmd = (f"ssh -i {u.get_keypair_fn()} -o StrictHostKeyChecking=no -o ConnectTimeout=10  "
              "-o ServerAliveCountMax=1 -o ServerAliveInterval=60 "
              f"{u.get_aws_username(target)}@{target.public_ip_address}")
  print(user_cmd)
  os.system(user_cmd)


def connect(fragment=''):
  """SSH into the instance using authorized keys mechanism.

  No keypair file is passed to ssh. Unless running inside Emacs, on a dumb
  terminal, or with NO_TMUX set in the environment, `tmux a` is appended to
  the ssh command. If that command fails, the command that would create a new
  tmux session is printed for the user to run manually; it is not executed.

  Args:
    fragment: passed to u.lookup_instances to select the instance.
  """
  instances = u.lookup_instances(fragment)
  if not _check_instance_found(instances, fragment):
    return
  instance = instances[0]
  if len(instances) > 1:
    print(f"Found {len(instances)} instances matching {fragment}, connecting to most recent  {u.get_name(instance)} "
          f"launched by {instance.key_name}")
  else:
    print(f"Connecting to  {u.get_name(instance)} "
          f"launched by {instance.key_name}")

  ssh_cmd = f"ssh -t -o StrictHostKeyChecking=no -o ConnectTimeout=10  " \
            f"-o ServerAliveCountMax=1 " \
            f"-o ServerAliveInterval=60 " \
            f"{u.get_aws_username(instance)}@{instance.public_ip_address} "
  connect_cmd = ssh_cmd
  do_tmux = False
  if 'INSIDE_EMACS' in os.environ:
    print("detected Emacs, skipping tmux attach")
  elif os.environ.get('TERM', 'dumb') == 'dumb':
    print("Dumb terminal, doesn't support tmux, skipping tmux attach")
  elif 'NO_TMUX' in os.environ:
    print("detected NO_TMUX, skipping tmux attach")
  else:
    do_tmux = True
    connect_cmd += " tmux a"

  print(connect_cmd)
  exit_code = os.system(connect_cmd)

  if exit_code != 0 and do_tmux:
    # Only tell the user how to recreate the session. The automatic retry that
    # used to follow here sat behind an unconditional `return` and never ran.
    fix_cmd = ssh_cmd + " tmux new"
    print(f"Creating ssh tmux a returned {exit_code}, recreate tmux using '{fix_cmd}'")
  else:
    print(f"cmd {connect_cmd} returned {exit_code}")


def connectm(fragment=''):
  """Like connect, but uses mosh

  No keypair file is passed to the underlying ssh and no tmux command is appended.
  """
  matches = u.lookup_instances(fragment)
  if not _check_instance_found(matches, fragment):
    return

  target = matches[0]
  if len(matches) > 1:
    print(f"Found {len(matches)} instances matching {fragment}, connecting to most recent  {u.get_name(target)} "
          f"launched by {target.key_name}")
  else:
    print(f"Connecting to  {u.get_name(target)} launched by {target.key_name}")

  user_cmd = ("mosh --ssh='ssh -o StrictHostKeyChecking=no -o ConnectTimeout=10  "
              "-o ServerAliveCountMax=1 -o ServerAliveInterval=60' "
              f"{u.get_aws_username(target)}@{target.public_ip_address}")
  print(user_cmd)
  os.system(user_cmd)


def kill(fragment: str = '', stop_instead_of_kill: bool = False):
  """Terminate (or stop) instances matching `fragment` after an interactive confirmation.

  While the module-level LIMIT_TO_CURRENT_USER flag is true, instances whose
  key_name differs from u.get_keypair_name() are skipped and only reported.

  Args:
    fragment: passed to u.lookup_instances to select instances.
    stop_instead_of_kill: use stop_instances instead of terminate_instances
  """
  if stop_instead_of_kill:
    states = ['running']
    action = 'stopping'
    override_action = 'reallystop'
  else:
    states = ['running', 'stopped']
    action = 'terminating'
    override_action = 'reallykill'

  instances = u.lookup_instances(fragment, states=states, limit_to_current_user=False)
  instances_to_kill = []
  instances_to_kill_formatted = []
  instances_to_skip = []
  users_to_skip = set()
  for i in instances:
    state = i.state['Name']

    if not LIMIT_TO_CURRENT_USER or i.key_name == u.get_keypair_name():
      instances_to_kill_formatted.append(("  ", u.get_name(i), i.instance_type, i.key_name, state if state == 'stopped' else ''))
      instances_to_kill.append(i)
    else:
      instances_to_skip.append(u.get_name(i))
      # key_name may be None (ls() guards against this too); str() keeps the slice from raising TypeError
      users_to_skip.add(str(i.key_name)[9:])

  if instances_to_skip:
    # sorted() makes the message deterministic, set iteration order is not
    print(f"Skipping {','.join(instances_to_skip)} launched by ({', '.join(sorted(users_to_skip))}), override with {override_action}")
  if not _check_instance_found(instances_to_kill, fragment, states):
    return

  print(f"{action}:")
  for line in instances_to_kill_formatted:
    print(*line)

  ec2_client = u.get_ec2_client()
  num_instances = len(instances_to_kill)
  # Confirmation is requested for both stopping and terminating; there is
  # currently no way to skip it.
  answer = input(f"{num_instances} instances found, {action} in {u.get_region()}? (y/N) ")

  if answer.lower() == "y":
    instance_ids = [i.id for i in instances_to_kill]

    if stop_instead_of_kill:
      response = ec2_client.stop_instances(InstanceIds=instance_ids)
    else:
      response = ec2_client.terminate_instances(InstanceIds=instance_ids)

    assert u.is_good_response(response), response
    print(f"{action} {num_instances} instances: success")
  else:
    print("Didn't get y, doing nothing")


def stop(fragment=''):
  """Stop instances matching `fragment`; delegates to kill() with stop_instead_of_kill=True."""
  kill(fragment=fragment, stop_instead_of_kill=True)


# Read by kill(): while true, instances whose key_name differs from u.get_keypair_name()
# are skipped. reallykill()/reallystop() switch it off temporarily.
LIMIT_TO_CURRENT_USER = True


def reallykill(*args, **kwargs):
  """Kill instances, including ones launched by other users.

  Temporarily clears the module-level LIMIT_TO_CURRENT_USER flag and forwards
  all arguments to kill().
  """
  global LIMIT_TO_CURRENT_USER
  LIMIT_TO_CURRENT_USER = False
  try:
    kill(*args, **kwargs)
  finally:
    # restore the flag even if kill() raises (failed assertion, Ctrl-C at the prompt, ...)
    LIMIT_TO_CURRENT_USER = True


def reallystop(*args, **kwargs):
  """Stop instances, including ones launched by other users.

  Temporarily clears the module-level LIMIT_TO_CURRENT_USER flag and forwards
  all arguments to stop().
  """
  global LIMIT_TO_CURRENT_USER
  LIMIT_TO_CURRENT_USER = False
  try:
    stop(*args, **kwargs)
  finally:
    # restore the flag even if stop() raises (failed assertion, Ctrl-C at the prompt, ...)
    LIMIT_TO_CURRENT_USER = True


def start(fragment=''):
  """Start all stopped instances matching `fragment`.

  Lists the matching stopped instances, starts each of them without asking
  for confirmation, then prints the EFS mount command as a reminder.

  Args:
    fragment: passed to u.lookup_instances to select instances.
  """
  instances = u.lookup_instances(fragment, states=['stopped'])
  for i in instances:
    print(u.get_name(i), i.instance_type, i.key_name)

  if not instances:
    print("no stopped instances found, quitting")
    return

  # There is deliberately no confirmation prompt here. The previous code
  # hard-coded the answer to 'y', which left its "doing nothing" branch
  # unreachable.
  for i in instances:
    print(f"starting {u.get_name(i)}")
    i.start()

  print("Warning, need to manually mount efs on instance: ")
  print_efs_mount_command()


def mosh(fragment=''):
  """Connect with mosh to the first instance returned by the lookup and run `tmux attach` there.

  The underlying ssh uses the keypair file from u.get_keypair_fn().
  """
  matches = u.lookup_instances(fragment)
  if not _check_instance_found(matches, fragment):
    return

  target = matches[0]
  print(f"Found {len(matches)} instances matching {fragment}, connecting to most recent  {u.get_name(target)}")
  _user_keypair_check(target)

  ssh_part = f"mosh --ssh='ssh -i {u.get_keypair_fn()} -o StrictHostKeyChecking=no' "
  user_cmd = ssh_part + f"{u.get_aws_username(target)}@{target.public_ip_address} tmux attach"
  print(user_cmd)
  os.system(user_cmd)


def print_efs_mount_command():
  """Print the value returned by u.get_efs_mount_command()."""
  mount_command = u.get_efs_mount_command()
  print(mount_command)


def efs(_):
  """Print the EFS mount command, then every EFS file system with its mount targets.

  For each file system the id and the name derived from its tags are printed,
  followed by one line per mount target (zone, ip, mount target id, state) or
  "<no mount targets>". The argument is ignored.
  """
  print("EFS information. To upload to remote EFS use 'ncluster efs_sync'")
  print_efs_mount_command()
  print()
  print()

  efs_client = u.get_efs_client()
  filesystems_response = efs_client.describe_file_systems()
  assert u.is_good_response(filesystems_response), filesystems_response

  # Each entry is a dict describing one file system; only 'FileSystemId' is
  # read here (other keys include 'Name', 'LifeCycleState', 'SizeInBytes', ...).
  for filesystem in filesystems_response['FileSystems']:
    efs_id = filesystem['FileSystemId']
    tags_response = efs_client.describe_tags(FileSystemId=efs_id)
    assert u.is_good_response(tags_response)
    key = u.get_name(tags_response.get('Tags', ''))
    print("%-16s %-16s" % (efs_id, key))
    print('-' * 40)

    # list mount points
    mount_targets = efs_client.describe_mount_targets(FileSystemId=efs_id)['MountTargets']
    ec2 = u.get_ec2_resource()
    if not mount_targets:
      print("<no mount targets>")
      continue

    for target in mount_targets:
      zone = ec2.Subnet(target['SubnetId']).availability_zone
      state = target['LifeCycleState']
      id_ = target['MountTargetId']
      ip = target['IpAddress']
      print('%-16s %-16s %-16s %-16s' % (zone, ip, id_, state,))


def terminate_tmux(_):
  """Script to clean-up tmux sessions.

  Kills every session listed by `tmux ls` except 'tensorboard', 'jupyter' and
  'dropbox'. The argument is ignored.
  """
  protected_sessions = ('tensorboard', 'jupyter', 'dropbox')

  for line in _run_shell('tmux ls'):
    # The session name is the text before the first ':'. _run_shell merges
    # stderr into its output, so a line without ':' is not a session entry
    # (e.g. a tmux error message) and must not be passed to kill-session.
    session_name, separator, _details = line.partition(':')
    if not separator:
      continue

    if session_name in protected_sessions:
      print("Skipping " + session_name)
      continue
    print("Killing " + session_name)
    # quote the name so spaces or shell metacharacters cannot alter the command
    _run_shell('tmux kill-session -t ' + shlex.quote(session_name))


def nano(*_unused_args):
  """Bring up t2.nano instance."""
  task_spec = {'name': 'shell', 'instance_type': 't2.nano'}
  ncluster.make_task(**task_spec)


def cmd(user_cmd):
  """Run `user_cmd` over ssh on the first instance returned for the current user.

  ssh is executed through os.system. Fails with an AssertionError when the
  lookup returns no instances.
  """
  owned_instances = u.lookup_instances(limit_to_current_user=True)
  assert owned_instances, (f"{u.get_username()} doesn't have an instances to connect to. Use 'ncluster nano'"
                           f" to bring up a small instance.")
  target = owned_instances[0]
  key_file = u.get_keypair_fn()
  remote_login = f"{u.get_aws_username(target)}@{target.public_ip_address}"
  os.system(f"ssh -t -i {key_file} -o StrictHostKeyChecking=no {remote_login} {user_cmd}")


def cat(user_cmd):
  """Run `cat <user_cmd>` remotely through cmd()."""
  cmd('cat ' + user_cmd)


def ls_(user_cmd):
  """Run `ls <user_cmd>` remotely through cmd()."""
  cmd('ls ' + user_cmd)


def cleanup_placement_groups(*_args):
  """Try to delete every placement group returned by describe_placement_groups.

  Deletion is best-effort: a failure for one group is reported together with
  its reason and the loop continues with the next group. Arguments are ignored.
  """
  print("Deleting all placement groups")
  # TODO(y): don't delete groups that have currently stopped instances
  client = u.get_ec2_client()
  for group in client.describe_placement_groups().get('PlacementGroups', []):
    name = group['GroupName']
    sys.stdout.write(f"Deleting {name} ... ")
    sys.stdout.flush()
    try:
      client.delete_placement_group(GroupName=name)
    except Exception as e:
      # keep going, but say why instead of silently discarding the error
      print(f"failed ({e})")
    else:
      print("success")


# ncluster launch --image_name=dlami23-efa --instance_type=c5.large --name=test
def launch(args_str):
  """Parse `args_str` as command-line options and launch a task with them.

  Recognized options: --name, --image_name, --instance_type, --disk_size. The
  parsed values are forwarded as keyword arguments to ncluster.make_task, whose
  result is returned.

  Args:
    args_str: option string, split with shlex.split before parsing.
  """
  import argparse

  # can also use --image_name='Deep Learning AMI (Amazon Linux) Version 23.0' # cybertronai01
  option_specs = (
    ('--name', dict(type=str, default='ncluster_launch', help="instance name")),
    ('--image_name', dict(type=str, default='Deep Learning AMI (Ubuntu) Version 23.0')),
    ('--instance_type', dict(type=str, default='c5.large', help="type of instance")),
    ('--disk_size', dict(type=int, default=0,
                         help="size of disk in GBs. If 0, use default size for the image")),
  )
  parser = argparse.ArgumentParser()
  for flag, spec in option_specs:
    parser.add_argument(flag, **spec)

  parsed = parser.parse_args(shlex.split(args_str))
  return ncluster.make_task(**vars(parsed))


def fix_default_security_group(_):
  """Allows ncluster and ncluster_nd security groups to exchange traffic with each other.

  Ingress rules for icmp, tcp and udp are added in both directions. A rule
  that already exists only produces a warning. The argument is ignored.

  Raises:
    AssertionError: if authorizing a rule fails for any reason other than the
      rule already being present.
  """

  def authorize(group, rule):
    """Add `rule` to `group`, tolerating an already existing identical rule."""
    try:
      group.authorize_ingress(IpPermissions=[rule])
    except Exception as e:
      # The error code is carried by the exception itself (botocore's
      # ClientError exposes it as e.response['Error']['Code']), not by the
      # return value of the call that failed.
      error = (getattr(e, 'response', None) or {}).get('Error', {})
      if error.get('Code') == 'InvalidPermission.Duplicate':
        print("Warning, got " + str(e))
      else:
        raise AssertionError("Failed while authorizing ingress with " + str(e)) from e

  def peer(current, other):
    """allow current group to accept all traffic from other group"""
    groups = u.get_security_group_dict()
    current_group = groups[current]
    other_group = groups[other]
    # (protocol, FromPort, ToPort); icmp uses -1/-1, tcp and udp the full port range
    for protocol, from_port, to_port in (('icmp', -1, -1), ('tcp', 0, 65535), ('udp', 0, 65535)):
      rule = {'FromPort': from_port,
              'IpProtocol': protocol,
              'IpRanges': [],
              'PrefixListIds': [],
              'ToPort': to_port,
              'UserIdGroupPairs': [{'GroupId': other_group.id}]}
      authorize(current_group, rule)

  group1 = u.get_security_group_name()
  group2 = u.get_security_group_nd_name()

  peer(group1, group2)
  peer(group2, group1)


def keys(_):
  """Print the public key returned by util.get_public_key() below usage instructions.

  Per the original note, obtaining the key runs ssh-keygen if necessary. The
  argument is ignored.
  """
  public_key = util.get_public_key()
  instructions = ("Your public key is below. Append all of your your team-members  public keys to "
                  "NCLUSTER_AUTHORIZED_KEYS env var separated by ; ie \n"
                  "NCLUSTER_AUTHORIZED_KEYS=<key1>;<key2>;<key3>\n")
  print(instructions)
  print(public_key)


def lookup_image(image_id: str) -> Image:
  """Look up an image by id (e.g. 'ami-0cc96feef8c6bbff3'), print its name and return it.

  Raises:
    AssertionError: if `image_id` does not start with 'ami-', or if the lookup
      does not yield exactly one image.
  """
  assert image_id.startswith('ami-')
  matching = list(u.get_ec2_resource().images.filter(ImageIds=[image_id]))
  assert matching, f"No images found with id={image_id}"
  assert len(matching) == 1, f"Multiple images found with id={image_id}: {','.join(i.name for i in matching)}"
  same_id = [candidate for candidate in matching if candidate.id == image_id]
  image = same_id[0]
  print(image.name)
  return image


def grow_disks(fragment: str, target_size_gb=500):
  """Grow every volume of the instance found for `fragment` to `target_size_gb`.

  Volumes whose size is already at least `target_size_gb` are left alone.

  Args:
    fragment: passed to u.lookup_instance to select the instance.
    target_size_gb: requested volume size, 500 by default.
  """
  instance = u.lookup_instance(fragment)
  client = u.get_ec2_client()

  for volume in list(instance.volumes.all()):
    if volume.size >= target_size_gb:
      print(f"Volume {volume.id} is already {volume.size} GB's, skipping")
      continue
    print("Growing %s to %s" % (volume.id, target_size_gb))
    response = client.modify_volume(VolumeId=volume.id, Size=target_size_gb)
    assert u.is_good_response(response)


def disks(fragment):
  """Print the volumes of each running or stopped instance matching `fragment`, then all unattached volumes."""
  matches = u.lookup_instances(fragment, states=('running', 'stopped'))
  print(f"{'device':>10s} {'size':>10s} {'type':>10s}({'iops'}) {'id':>30s} ")
  print("-" * 50)
  for inst in matches:
    print()
    print(f"Disks on instance '{u.get_name(inst)}' ({inst.id}, {inst.placement['AvailabilityZone']})")
    for vol in inst.volumes.all():
      # the device name is taken from the volume's first attachment
      device = vol.attachments[0]['Device']
      print(f"{device:>10s} {vol.size:>10d} {vol.volume_type:>10s}({vol.iops}) "
            f"{vol.id:>30s} {u.get_name(vol)}")
  print("Unattached disks")

  for vol in u.get_ec2_resource().volumes.all():
    if vol.attachments:
      continue
    print(f"{u.get_name(vol)} {vol.size:>10d} {vol.volume_type:>10s} {vol.id:>30s}")


def fixkeys(_):
  """Delete the current user's AWS keypair and the matching local .pem file.

  If there are non-terminated instances, they are listed and the user is asked
  to confirm first; otherwise the deletion proceeds without asking. Does
  nothing when the keypair named by u.get_keypair_name() does not exist. The
  argument is ignored.
  """
  key_name = u.get_keypair_name()
  pairs = u.get_keypair_dict()
  if key_name not in pairs:
    print(f"Default keypair {key_name} does not exist, returning")
    return
  keypair = pairs[key_name]

  print(f"Deleting current user keypair {key_name}")
  ec2 = u.get_ec2_resource()
  # instance.state is a dict (kill() reads state['Name'] as well), so the state
  # name has to be compared, not the dict itself; comparing the dict to a
  # string never matched and let terminated instances into the warning.
  instance_list = [instance for instance in ec2.instances.all()
                   if instance.state['Name'] != 'terminated']
  if instance_list:
    print("Warning, after deleting keypair, the following instances will be no longer accessible:")
    for i in instance_list:
      print(u.get_name(i), i.id)
    answer = input("Proceed? (y/N) ")
  else:
    answer = "y"
  if answer.lower() == 'y':
    keypair_fn = u.get_keypair_fn()
    if os.path.exists(keypair_fn):
      print(f"Deleting local .pem file '{keypair_fn}'")
      # quote the path so spaces or shell metacharacters cannot change the command
      os.system(f'sudo rm -f {shlex.quote(keypair_fn)}')
    print(f"Deleting AWS keypair '{keypair.name}'")
    keypair.delete()


def efs_sync(_):
  """Starts a daemon to sync local /ncluster/sync with remote ncluster sync.

  Creates /ncluster/sync locally if it is missing, picks the first instance of
  the current user (launching a t2.nano task when there is none) and runs
  `nsync` against that instance. The argument is ignored.
  """

  print("Syncing local /ncluster/sync to remote /ncluster/sync")
  if not os.path.exists('/ncluster/sync'):
    print("Local /ncluster/sync doesn't exist, creating")
    if not os.path.exists('/ncluster'):
      os.system('sudo mkdir /ncluster')
      # The directory was just created by root, so changing its owner needs
      # sudo as well; otherwise chown fails and so does the mkdir below.
      # `whoami` is expanded by the calling shell, i.e. to the invoking user.
      os.system('sudo chown `whoami` /ncluster')
    os.system('mkdir /ncluster/sync')

  instances = u.lookup_instances(limit_to_current_user=True)
  if instances:
    instance = instances[0]
    print(f"Found {len(instances)} instances owned by {u.get_username()}, using {u.get_name(instance)} for syncing")
  else:
    print(f"Found no instances by {u.get_username()}, Launching t2.nano instance to do the sync.")
    task = ncluster.make_task(name='shell', instance_type='t2.nano')
    instance = task.instance

  os.system(f'cd /ncluster/sync && nsync -m {u.get_name(instance)} -d /ncluster/sync')


def spot_prices(instance_type: str) -> Dict[str, float]:
  """
  Print spot instance pricing per availability zone.

  Args:
    instance_type: AWS instance name. The shortcuts p3, p3dn, p2, c5 and c5n are
      expanded to a full instance type; an empty value means p3.16xlarge.

  Returns:
    dictionary of zone->price for given instance; zones without any price
    history entry are left out.
  """
  shortcuts = {
    'p3': 'p3.16xlarge',
    'p3dn': 'p3dn.24xlarge',
    'p2': 'p2.16xlarge',
    'c5': 'c5.18xlarge',
    'c5n': 'c5n.18xlarge',
  }
  instance_type = shortcuts.get(instance_type, instance_type) or 'p3.16xlarge'

  product = 'Linux/UNIX (Amazon VPC)'
  prices_by_zone = {}
  client = u.get_ec2_client()
  print(f"Prices for instance {instance_type} on Linux")
  zone_descriptions = client.describe_availability_zones()['AvailabilityZones']
  available_zones = [z['ZoneName'] for z in zone_descriptions if z['State'] == 'available']
  for zone in available_zones:
    try:
      history = client.describe_spot_price_history(InstanceTypes=[instance_type],
                                                   MaxResults=1,
                                                   ProductDescriptions=[product],
                                                   AvailabilityZone=zone)
      # an empty history raises IndexError here, which skips the zone
      price = float(history['SpotPriceHistory'][0]['SpotPrice'])
      print(f"{zone}: {price:.2f}")
      prices_by_zone[zone] = price
    except IndexError:
      continue
  return prices_by_zone


# Maps the command name given as the first command-line argument to the function
# implementing it. main() calls the selected function with a single string
# argument built from the remaining command-line arguments.
COMMANDS = {
  'ls': ls,
  'ssh': ssh,
  'ssh_': old_ssh,
  'old_ssh': old_ssh,
  'mosh': mosh,
  'kill': kill,
  'reallykill': reallykill,
  'stop': stop,
  'reallystop': reallystop,
  'start': start,
  'efs': efs,
  'cat': cat,
  'ls_': ls_,
  'nano': nano,
  'cmd': cmd,
  '/etc/hosts': etchosts,
  'hosts': etchosts,
  'terminate_tmux': terminate_tmux,
  'cleanup_placement_groups': cleanup_placement_groups,
  'lookup_image': lookup_image,
  'reboot': reboot,
  'launch': launch,
  'fix_default_security_group': fix_default_security_group,
  'keys': keys,
  'connect': connect,
  'connectm': connectm,
  'grow_disks': grow_disks,
  'disks': disks,
  'fixkeys': fixkeys,
  'efs_sync': efs_sync,
  'spot_prices': spot_prices,
}


def main():
  """Command-line entry point: dispatch sys.argv[1] to the matching COMMANDS entry.

  Without arguments the 'ls' command is run. 'help' lists all commands with
  their docstrings. Any further command-line arguments are shell-quoted, joined
  with spaces and passed to the command as one string.

  Raises:
    AssertionError: if the command name is not in COMMANDS.
  """
  print(f"Region ({u.get_region()}) $USER ({u.get_username()}) account ({u.get_account_number()}:{u.get_account_name()})")
  mode = sys.argv[1] if len(sys.argv) >= 2 else 'ls'
  if mode == 'help':
    for k, v in COMMANDS.items():
      if v.__doc__:
        print(f'{k}\t{v.__doc__}')
      else:
        print(k)
    return
  if mode not in COMMANDS:
    # Raise explicitly: an `assert` statement is removed under `python -O`,
    # which would turn this into a bare KeyError on the lookup below.
    raise AssertionError(f"unknown command '{mode}', available commands are {', '.join(str(a) for a in COMMANDS.keys())}")
  COMMANDS[mode](' '.join([shlex.quote(arg) for arg in sys.argv[2:]]))

# Run the command dispatcher only when executed as a script, not when imported.
if __name__ == '__main__':
  main()
