#!/usr/bin/env python3
"""cPing concurrently checks if hosts are responding using ICMP echo or TCP SYN."""

# pylint: disable=broad-except  # Don't care about which exception; cleaning up

import argparse
import collections
import math
import queue
import re
import shutil
import signal
import socket
import statistics
import subprocess
import sys
import threading
import time

INTERVAL_MIN = 0.1  # Lower bound on the test interval, in seconds
RESULTS_DEQUE_MIN = 5  # Lower bound on the number of results kept per host


class Host:
    """Represents a host and its test results.

    One thread runs `test` (which appends results) while another calls
    `get_status` (which reads and may resize them); `results` is therefore
    only touched while `lock` is held.
    """

    def __init__(self, host, port, interval):
        self.host = host  # Host name or address to test
        self.port = port  # TCP port to test; -1 selects ICMP
        self.interval = interval  # Seconds between tests
        self.lock = threading.Lock()  # Guards `results`
        # Most recent results in ms; -1 marks a test that got no reply
        self.results = collections.deque(maxlen=RESULTS_DEQUE_MIN)
        self.status = None  # Set to a message once testing has stopped

    def add_result(self, result):
        """Append `result` (ms, or -1 for no reply) to the results deque."""
        with self.lock:  # Released even if the append raises
            self.results.append(result)

    def change_deque_size(self, new_size):
        """Change the deque size to be `new_size`, keeping the newest results.

        Sizes below RESULTS_DEQUE_MIN are raised to RESULTS_DEQUE_MIN.
        """
        new_size = max(new_size, RESULTS_DEQUE_MIN)

        with self.lock:
            if self.results.maxlen == new_size:  # Already at new size
                return

            # A bounded deque built from an iterable keeps the rightmost
            # (newest) `maxlen` items, in their original order
            self.results = collections.deque(self.results, maxlen=new_size)

    def check_icmp(self, shutdown):
        """Run ping and add the reply times to the results buffer.

        Blocks until the `shutdown` event is set or the ping process exits.
        Raises RuntimeError with ping's own message if ping exits after
        reporting an error. The ping process is always stopped and reaped
        before returning.
        """
        def read_fd(stdout, output):
            # Blocking reads happen here so the main loop never stalls
            for line in iter(stdout.readline, b''):
                output.put(line)
            stdout.close()

        command = ['ping', self.host, '-i{}'.format(self.interval)]
        output = queue.Queue()
        reply_re = re.compile(r'bytes from.+time=([.0-9]+) ms$')

        ping = subprocess.Popen(command,
                                stdout=subprocess.PIPE,
                                stderr=subprocess.STDOUT)
        threading.Thread(target=read_fd, args=(ping.stdout, output)).start()

        try:
            while not shutdown.wait(self.interval):  # Run until signaled
                reply = False

                while not output.empty():
                    # Undecodable bytes must not abort the whole test
                    line = output.get(False).decode(errors='replace')

                    if line.startswith('ping:') and ping.poll() is not None:
                        raise RuntimeError(line[6:].strip())

                    match = reply_re.search(line)

                    # Reply within interval range (i.e. not retrospective)
                    if match and float(match.group(1)) < self.interval * 1000:
                        reply = True
                        self.add_result(float(match.group(1)))
                        break  # Don't process more than one result per loop

                if not reply:  # Didn't find a reply within the interval
                    self.add_result(-1)

                if ping.poll() is not None:  # Process is not running
                    self.status = 'Ping process died'
                    break
        finally:
            # Stop ping on every exit path, then reap it so that no zombie
            # process is left behind
            if ping.poll() is None:
                ping.send_signal(signal.SIGINT)

            try:
                ping.wait(timeout=1)
            except subprocess.TimeoutExpired:  # Ignored SIGINT; force it
                ping.kill()
                ping.wait()

    def check_tcp(self, shutdown):
        """Run TCP SYN tests and add the reply times to the results buffer.

        Blocks until the `shutdown` event is set. The host is resolved on
        every iteration and only the first address returned is used. A
        connect that times out is recorded as -1; any other error (e.g.
        socket.gaierror, connection refused) propagates to the caller.
        """
        timeout = self.interval

        while True:
            sock_addr = socket.getaddrinfo(self.host, self.port, 0, 0,
                                           socket.IPPROTO_TCP)
            family, sock_type, _, _, address = sock_addr[0]

            # The context manager closes the socket on every path
            with socket.socket(family, sock_type) as sock:
                sock.settimeout(self.interval)

                # Monotonic clock: unaffected by system clock adjustments
                time_start = time.monotonic()

                try:
                    sock.connect(address)
                except socket.timeout:
                    self.add_result(-1)
                    timeout = 0  # Expired; no need to wait on next iteration
                else:
                    self.add_result((time.monotonic() - time_start) * 1000)
                    timeout = self.interval  # Succeeded; wait before next

            if shutdown.wait(timeout):  # Sleep until signaled or expired
                break  # Shutdown signaled

    def get_status(self, host_padding=64, output_width=None, stats_padding=8):
        """Return a string that represents the status of the host.

        The line is the host name padded to `host_padding`, then either the
        stored status message or five statistics columns (min, mean, max,
        stdev, success rate), each `stats_padding` wide, followed by one
        colored character per stored result. When `output_width` is given,
        the results deque is first resized to fill the remaining width.
        """
        status = self.host.ljust(host_padding) + '    '

        if self.status:  # Status already set (maybe an error)
            return status + self.status

        if output_width:  # Update deque len
            new_size = output_width - len(status) - stats_padding * 5
            self.change_deque_size(new_size)

        # Snapshot under the lock, format outside it; the lock can then
        # never be left held by an error while formatting
        with self.lock:
            history = list(self.results)

        hits = [x for x in history if x != -1]  # Successful tests only
        blank = '  -'.ljust(stats_padding)

        if not hits:
            status += blank * 5
        elif len(hits) == 1:  # stdev needs at least two data points
            status += '{:.3f}'.format(hits[0]).ljust(stats_padding)
            status += blank * 4
        else:
            for stat in min, statistics.mean, max, statistics.stdev:
                status += '{:.3f}'.format(stat(hits)).ljust(stats_padding)

            success_rate = int(len(hits) * 100 / len(history))
            success_rate = '{}%'.format(success_rate).rjust(4)
            status += success_rate.ljust(stats_padding)

        # History; includes failed tests
        status += ''.join(color('.', False) if result == -1 else color('!')
                          for result in history)

        return status

    def test(self, delay, shutdown):
        """Start the continuous tests (blocking operation).

        Waits `delay` seconds, then tests over ICMP (port -1) or TCP until
        the `shutdown` event is set. Any failure is stored in `status`
        instead of being raised.
        """
        # Waiting on the event rather than sleeping lets a shutdown during
        # the start-up delay end the thread without launching a test
        if shutdown.wait(delay):
            return

        try:
            if self.port == -1:  # ICMP
                self.check_icmp(shutdown)
            else:  # TCP
                self.check_tcp(shutdown)
        except socket.gaierror:
            self.status = 'Host resolution failed'
        except Exception as exception:
            self.status = str(exception)


def args_init(argv=None):
    """Initializes arguments and returns the output of `parse_args`.

    `argv` is the list of argument strings to parse; when None, argparse
    reads them from the command line. Exits through `parser.error` when
    the interval is not a finite number or the port is outside 1-65535.
    Intervals below INTERVAL_MIN are raised to INTERVAL_MIN.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('host',
                        type=str,
                        nargs='+',
                        help='one or more hosts to ping')
    parser.add_argument('-i',
                        metavar='sec',
                        type=float,
                        help='ping interval (default: %(default)s)',
                        default=1)
    parser.add_argument('-p',
                        metavar='port',
                        type=int,
                        help='test using TCP SYN (default: ICMP echo)',
                        default=-1)

    args = parser.parse_args(argv)

    # float() accepts 'inf' and 'nan'; neither is usable as a sleep or
    # timeout value, so reject them here rather than fail in a thread
    if not math.isfinite(args.i):
        parser.error('Interval must be a finite number')

    if args.p != -1 and not 0 < args.p < 65536:
        parser.error('Port outside of range 1-65535')

    args.i = max(INTERVAL_MIN, args.i)

    return args


def color(string, success=True):
    """Wrap `string` in ANSI escapes: green for a success, red otherwise."""
    code = '32m' if success else '31m'
    return ''.join(('\033[', code, string, '\033[0m'))


def print_summary(hosts, full=False):
    """Print one status line per host.

    Unless `full` is set, output stops once it would fill all but the last
    terminal line, and that last line reports how many hosts were left out.
    """
    term_size = shutil.get_terminal_size((64, 20))
    visible_rows = term_size.lines - 1  # Last line is kept for the overflow
    host_padding = max(len(entry.host) for entry in hosts)

    rows = []
    for entry in hosts:
        rows.append(entry.get_status(host_padding, term_size.columns))

        if not full and len(rows) >= visible_rows:  # Limit reached
            break

    # Erase to end of each line and below the last to clear any ghosting
    print('\033[K\n'.join(rows) + '\033[J')

    hidden = len(hosts) - visible_rows
    if not full and hidden > 0:  # Add overflow
        print('+{} more'.format(hidden), end='', flush=True)


def main():
    """Parse arguments, start one test thread per host and show the results.

    The summary is redrawn once per interval on the alternate screen buffer
    until every test thread has ended or the user interrupts; a final full
    summary is then printed on the normal screen.
    """
    args = args_init()
    shutdown = threading.Event()
    hosts = [Host(name, args.p, args.i) for name in args.host]

    # Spread test start up across one interval
    stagger = args.i / len(hosts)
    threads = []

    for index, host in enumerate(hosts):
        worker = threading.Thread(target=host.test,
                                  args=(stagger * index, shutdown))
        worker.start()
        threads.append(worker)

    try:
        try:  # Need to run clean up code before handling exception
            print('\033[?1049h', end='')  # Enable alternate screen buffer

            while any(worker.is_alive() for worker in threads):
                print_summary(hosts)
                time.sleep(args.i)
                print('\033[H', end='')  # Move to 1;1
        finally:
            shutdown.set()  # Signal host threads to quit
            print('\033[?1049l', end='')  # Disable alternate screen buffer
            print_summary(hosts, full=True)  # Print the last summary
    except KeyboardInterrupt:
        pass  # Break out of the application


# Run only when executed as a script, not when imported as a module
if __name__ == '__main__':
    sys.exit(main())  # main() has no return value, so a clean run exits 0
