#!/usr/bin/env python

import json
import sys
import time
import traceback

from argparse import ArgumentParser
from collections import namedtuple
from urlparse import urljoin

import requests

# Timeout, in seconds, for the request to the Discovery server.
discotimeout = 4

# Announcement metadata key under which the server token is stored.
tokenkey = 'server-token'

# Value sent as the User-Agent header on every request.
# NB: Version is duplicated in setup.py.
useragent = 'otpl-service-check/1.0.5'

# One check outcome: a numeric severity code plus the text to print.
# Being a tuple, results order by code first and message second.
Result = namedtuple('Result', ['code', 'message'])

# Health check endpoint response parsers.  We have different parsers for
# different content types.  "Entry point" is Parser.parse.

class Parser(object):
  # Registry of response parsers, keyed by lower-case content type
  # prefix.  Each value is a class that can be constructed without
  # arguments and has a parse(text) method returning a string.
  parsers = {}

  # Render a response body for display.
  #
  # contenttype -- value of the Content-Type header; matched
  #                case-insensitively by prefix, so parameters such as
  #                "; charset=utf-8" are ignored.  None or '' is allowed
  #                and selects the default parser.
  # text        -- response body.
  #
  # Returns whatever the selected parser's parse(text) returns.
  @classmethod
  def parse(cls, contenttype, text):
    # A missing header must not break the check; treat it as "no type".
    contenttype = (contenttype or '').lower()
    # items() instead of iteritems(): the latter exists only on
    # Python 2, while items() behaves the same here on 2 and 3.
    for prefix, parsercls in cls.parsers.items():
      if contenttype.startswith(prefix):
        return parsercls().parse(text)
    # No registered prefix matched.
    return DefaultParser().parse(text)

class LimitedParser(Parser):
  # Passes text through unchanged unless it is longer than `limit`
  # characters, in which case it is cut at `limit` and '...' is appended.
  def __init__(self, limit=128):
    self.limit = limit

  def parse(self, text):
    if len(text) <= self.limit:
      return text
    return text[:self.limit] + '...'

Parser.parsers['text/plain'] = LimitedParser

class HtmlParser(Parser):
  # HTML bodies are never shown; a fixed placeholder is returned
  # regardless of the input.
  def parse(self, text):
    return 'html response elided'

Parser.parsers['text/html'] = HtmlParser

class JsonParser(Parser):
  # Maximum length of the pretty-printed output before truncation.
  limit = 1024

  # Re-indent a JSON body for readability and cap its length.  A body
  # that does not decode as JSON is handed to the default parser.
  def parse(self, text):
    try:
      decoded = json.loads(text)
    except ValueError:
      return DefaultParser().parse(text)
    truncator = LimitedParser(self.limit)
    return truncator.parse(json.dumps(decoded, indent=2))

Parser.parsers['application/json'] = JsonParser

# Fallback parser: used by Parser.parse when no registered content type
# prefix matches, and by JsonParser when the body is not valid JSON.
DefaultParser = LimitedParser

class Main(object):
  # Command-line health check for one service type registered with a
  # Discovery server.  Exit codes used throughout: 0 ok, 1 warning,
  # 2 critical, 3 unknown.

  # Parse and validate arguments.
  #
  # argv -- list of argument strings; None (the default) makes argparse
  #         read sys.argv[1:].
  #
  # Any usage problem ends the process with exit code 3 via parser_error.
  def __init__(self, argv=None):
    self.parser = ArgumentParser(
      description='Check Discovery service for health.')

    # argparse reports its own failures (bad value type, unknown option)
    # through parser.error(), which exits with code 2.  Send those through
    # parser_error as well so that every usage problem exits with 3.
    self.parser.error = self.parser_error

    # These first two are actually required.  See below.
    self.parser.add_argument('-d', '--discovery', default=None,
      help='discovery server URL')
    self.parser.add_argument('-s', '--service', default=None,
      help='service name to check')

    self.parser.add_argument('-e', '--endpoint', default='health',
      help='endpoint to check; default %(default)r')
    self.parser.add_argument('-t', '--timeout', type=float, default=5,
      help='endpoint check timeout in seconds; default %(default)s')
    self.parser.add_argument('-c', '--critical-fewer', type=int, default=1,
      help='minimum instances before critical; default %(default)s; '
      'set to 0 to disable')
    self.parser.add_argument('-w', '--warn-fewer', type=int, default=1,
      help='minimum instances before warning; default %(default)s; '
      'set to 0 to disable')
    args = self.parser.parse_args(argv)

    # We do this manually here since the argparse default is to exit
    # with code 2.  See parser_error.
    if args.discovery is None:
      self.parser_error('argument -d/--discovery is required')
    if args.service is None:
      self.parser_error('argument -s/--service is required')

    if args.timeout <= 0:
      self.parser_error('timeout must be positive')
    if args.critical_fewer < 0:
      self.parser_error('critical-fewer must be non-negative')
    if args.warn_fewer < 0:
      self.parser_error('warn-fewer must be non-negative')
    if args.warn_fewer < args.critical_fewer:
      self.parser_error(
        'warn-fewer must be at least as large as critical-fewer')

    self.args = args

    # track output we've already seen and remove dupes
    self.response_data_seen = set()

  # Print usage plus the given message and terminate the process.
  def parser_error(self, message):
    # Code 3 is "UNKNOWN".  (argparse default is 2, which would be
    # "CRITICAL"--inappropriate.)
    self.parser.print_usage()
    self.parser.exit(3, '%s: error: %s\n' % (self.parser.prog, message))

  # GET url with the given timeout (seconds), identifying ourselves with
  # the module's User-Agent.  Returns the requests response object.
  def requestsget(self, url, timeout):
    return requests.get(url, timeout=timeout,
      headers={'User-Agent': useragent})

  # Fetch the Discovery server's 'state' document and return only the
  # announcements whose serviceType equals the requested service.
  def get_announcements(self):
    url = urljoin(self.args.discovery, 'state')
    state = self.requestsget(url, discotimeout).json()
    return [a for a in state if a['serviceType'] == self.args.service]

  # Count announcements, counting those that share a server token only
  # once.  Announcements without a token are each counted.
  @staticmethod
  def count_announcements(announcements):
    seen = set()
    count = 0
    for ann in announcements:
      # "or {}" also covers an explicit null metadata value.
      metadata = ann.get('metadata') or {}
      if tokenkey not in metadata:
        # No token; this is ok.
        count += 1
        continue
      token = metadata[tokenkey]
      if token not in seen:
        seen.add(token)
        count += 1
    return count

  # Build a Result whose message reads "<topic> <state>: <message>".
  # code must be 0, 1 or 2; anything else raises KeyError.
  @staticmethod
  def make_result(code, topic, message):
    state = {2: 'critical', 1: 'warning', 0: 'ok'}[code]
    return Result(code, '%s %s: %s' % (topic, state, message))

  # Like make_result, with the checked URI appended on its own line.
  @classmethod
  def make_response_with_uri(cls, code, topic, uri, message):
    msg = '%s\ncheck URI %s' % (message, uri)
    return cls.make_result(code, topic, msg)

  # Result for the announcement count, including both thresholds.
  def make_announcement_result(self, code, count):
    msg = '%s\ncrit./warn thresh.: %s/%s' % (
      count, self.args.critical_fewer, self.args.warn_fewer)
    return self.make_result(code, 'announcements', msg)

  # Result for an endpoint that answered.  For non-ok codes the parsed
  # body is included, but only the first time a given body is seen;
  # later identical bodies produce a short "<duplicate ...>" result.
  def make_response_result(
      self, code, uri, status_code, duration, contenttype, text):
    msg = '%s from endpoint' % status_code
    if code != 0:
      if text in self.response_data_seen:
        # extra leading space on next line is important so it sorts after real results
        return self.make_result(code, 'health', " <duplicate '%s'>" % uri)
      msg += '\n' + Parser.parse(contenttype, text)
      self.response_data_seen.add(text)
    msg += '\nduration %.3fs' % duration
    return self.make_response_with_uri(code, 'health', uri, msg)

  # Critical result for a timed-out request; type names the phase
  # ('connect' or 'read') and becomes part of the topic.
  def make_timeout_result(self, uri, type):
    return self.make_response_with_uri(
      2, '%s timeout' % type, uri, 'thresh. %.3f' % self.args.timeout)

  # Request the health endpoint of one announcement and return a Result.
  #
  # 2xx -> ok, 4xx -> warning, 5xx -> critical.  Timeouts and other
  # requests errors become critical results.  Any other status class
  # raises a plain Exception, which is not handled here and therefore
  # propagates to the caller.
  def check_endpoint(self, ann):
    serviceuri = ann['serviceUri']
    uri = urljoin(serviceuri, self.args.endpoint)
    start = time.time()
    try:
      resp = self.requestsget(uri, self.args.timeout)
      duration = time.time() - start
      statuscode = resp.status_code
      # Content-Type is optional in HTTP.  Indexing the headers raised
      # KeyError when it was absent, which aborted the whole check;
      # an empty type simply selects the default body parser.
      contenttype = resp.headers.get('content-type', '')
      text = resp.text
      # Status class (first digit) -> result code.
      severity = {2: 0, 4: 1, 5: 2}.get(statuscode // 100)
      if severity is None:
        raise Exception('unexpected status code', statuscode)
      return self.make_response_result(
        severity, uri, statuscode, duration, contenttype, text)
    except requests.exceptions.ConnectTimeout:
      return self.make_timeout_result(uri, 'connect')
    except requests.exceptions.ReadTimeout:
      return self.make_timeout_result(uri, 'read')
    except requests.exceptions.RequestException:
      return self.make_response_with_uri(2, 'health',
        uri, 'unhandled exception\n' + traceback.format_exc())

  # Run the whole check, print every result, and return the exit code:
  # 3 if the announcements could not be fetched, otherwise the worst
  # code among the results.
  def run(self):
    try:
      announcements = self.get_announcements()
    except Exception:
      # Single-argument print(...) gives the same output on Python 2
      # and is also valid Python 3 syntax.
      print('failed to get announcements')
      print(traceback.format_exc())
      return 3

    count = self.count_announcements(announcements)
    if count < self.args.critical_fewer:
      code = 2
    elif count < self.args.warn_fewer:
      code = 1
    else:
      code = 0

    # Will contain Result instances.  The announcement result is always
    # present, so the list is never empty.
    results = [self.make_announcement_result(code, count)]

    for ann in announcements:
      res = self.check_endpoint(ann)
      if res is not None:
        results.append(res)

    # Print worst results first and return with worst code.
    results.sort(reverse=True)
    for res in results:
      print(res.message)
      print('---')
    return results[0].code

# Script entry: run the check and exit with its code.  Any unexpected
# exception is printed and reported as exit code 3.  SystemExit (raised
# by argument handling) is not an Exception subclass and passes through.
try:
  exit_code = Main().run()
except Exception:
  print('unhandled exception')
  print(traceback.format_exc())
  exit_code = 3
sys.exit(exit_code)
