#!/usr/bin/env python
#  Runs in either python2 or python3, but is faster in python2.

# Copyright (c) 2011-2016 Timothy Savannah under GPLv3, All Rights Reserved. See LICENSE for more information
"""
Disttask is a utility which provides the ability to distribute a task across a fixed number of processes, for better utilization of multiprocessing.

Use it with existing single-threaded/process tools and scripts to take full advantage of your computer's resources.

"""

import os
import re
import sys
import select
import signal
import subprocess
import threading
import time

from collections import deque

# Version of disttask as a dotted string.
__version__ = '2.2.1'

# The same version expressed as a tuple of integers.
__version_tuple__ = (2, 2, 1)

# Compatibility shim: make sure the name "bytes" exists, aliasing it to str
#  on interpreters that do not provide it.
try:
    bytes
except NameError:
    bytes = str # Python < 2.6

if bytes is str:
    # Python 2, no additional decoding necessary.
    tostr = str
else:
    # Python 3, additional decoding necessary
    try:
        defaultEncoding = sys.getdefaultencoding()
    except Exception:
        defaultEncoding = 'utf-8'

    def tostr(x):
        '''
            tostr - Convert a value to a native string.

            @param x <str/bytes/anything> - Value to convert

            @return <str> - "x" unchanged if already a str, decoded with
                defaultEncoding if bytes, otherwise str(x)
        '''
        if isinstance(x, str):
            return x
        if isinstance(x, bytes):
            return x.decode(defaultEncoding)
        return str(x)


# Legacy regex to find %s and %d, which also takes in the leading character.
#  Retained unchanged as a module-level name for backwards compatibility;
#  replaceItem itself uses _placeholderReg below.
replaceReg = re.compile('(.[%][sd])')

# Matches a placeholder (%s or %d), optionally preceded by an escaping '%'.
#  Group 1 is '%' for the escaped forms (%%s / %%d) and '' otherwise.
#  Group 2 is the placeholder letter ('s' or 'd').
_placeholderReg = re.compile('(%?)%([sd])')

def replaceItem(cmd, item, pipeNum):
    '''
        replaceItem - Replace %s and %d with the item value and pipe num in the given command.

            This method is complete, in that it supports NOT replacing %%s, but is slower than
            replaceItemQuick. If '%%s' or '%%d' occurs in command, use this method, otherwise use replaceItemQuick

            '%%s' and '%%d' are emitted as a literal '%s' and '%d'.

            The substitution is a single pass over "cmd", so text inserted from "item" is never
            re-scanned, and placeholders directly next to each other (e.x. '%s%s') are all replaced.

        @param cmd <str> - Command given to disttask to fill-in
        @param item <str> - Current argset item
        @param pipeNum <int/str> - Current pipe number

        @return <str> - substituted command
    '''

    pipeNum = str(pipeNum)

    def substitute(matchObj):
        (escape, kind) = matchObj.groups()
        if escape:
            # Escaped form: drop one '%' and keep the placeholder literally.
            return '%' + kind
        if kind == 's':
            return item
        return pipeNum

    # A callable replacement is used so backslashes in "item" are inserted verbatim
    #  instead of being interpreted as regex replacement escapes.
    return _placeholderReg.sub(substitute, cmd)

def replaceItemQuick(cmd, item, pipeNum):
    '''
        replaceItemQuick - Replace %s and %d with the item value and pipe num in the given command quickly

            This does NOT support double-percent (%%), and if command contains it, should NOT be used.

            Only placeholders present in "cmd" itself are substituted; a '%d' that appears
            inside "item" is left untouched.

        @param cmd <str> - Command given to disttask to fill-in
        @param item <str> - Current argset item
        @param pipeNum <int/str> - Current pipe number

        @return <str> - substituted command
    '''

    # Split on %d first, fill in %s within each piece, then rejoin with the pipe number.
    #  Chaining two .replace() calls would rescan the inserted item text and wrongly
    #  substitute any '%d' it happens to contain.
    pieces = [piece.replace('%s', item) for piece in cmd.split('%d')]
    return str(pipeNum).join(pieces)

def getReplaceItemMethod(cmd):
    '''
        getReplaceItemMethod - Choose which "replace item" function should fill in the given command.

            Commands containing an escaped placeholder ('%%s' or '%%d') get replaceItem;
            every other command gets replaceItemQuick.

        @param cmd <str> - Command with %s/%d for subs

        @return <function> - The replaceItem method to use to fill in cmd
    '''
    hasEscapedPlaceholder = any(escaped in cmd for escaped in ('%%s', '%%d'))
    return replaceItem if hasEscapedPlaceholder else replaceItemQuick



class StdoutWriter(threading.Thread):
    '''
        StdoutWriter - The thread which writes data to stdout from the several subprocesses.

            Chunks are queued with addData and written out, in order, by run().
            run() keeps going until stopRunning() has been called and the queue is empty.
    '''

    # FLUSH_EVERY - Explicitly flush after this many items.
    FLUSH_EVERY = 1

    def __init__(self, *args, **kwargs):
        '''
            Create the writer thread. All arguments are passed through to threading.Thread.
        '''
        threading.Thread.__init__(self, *args, **kwargs)

        # Queue of pending chunks, oldest on the left.
        self.stdoutData = deque()

        # Cleared by stopRunning to let run() finish once the queue is drained.
        self.keepGoing = True

    def addData(self, data):
        '''
            addData - Use this method to add an item to print

            @param data - Chunk to write; it is handed as-is to the stdout write function
        '''
        self.stdoutData.append(data)

    def setFlushEvery(self, nWrites):
        '''
            setFlushEvery - Set how many written items trigger an explicit flush.

                Takes effect on this instance only, and is honoured even if the
                thread is already running.

            @param nWrites <int> - Flush after this many items
        '''
        self.FLUSH_EVERY = nWrites

    def stopRunning(self):
        '''
            stopRunning - Stop executing "run" at next possible moment.
        '''
        self.keepGoing = False

    def run(self):
        '''
            run - Drain queued data to stdout until stopped and empty.
        '''
        time.sleep(.001) # Block immediately whilst setup happens
        stdoutData = self.stdoutData

        # In python3, we write bytes.
        try:
            writeOutput = sys.stdout.buffer.write
        except AttributeError:
            # No underlying binary buffer (e.x. python2), write to stdout directly.
            writeOutput = sys.stdout.write

        while self.keepGoing is True or len(stdoutData) > 0:
            numWritten = 0
            while len(stdoutData) > 0:
                writeOutput(stdoutData.popleft())
                numWritten += 1
                # Read FLUSH_EVERY each time so setFlushEvery works after start().
                if numWritten >= self.FLUSH_EVERY:
                    numWritten = 0
                    sys.stdout.flush()

            sys.stdout.flush()
            time.sleep(.0005)

class Runner(threading.Thread):
    '''
        Runner - Thread running a subprocess

            The command is run through the shell with stderr merged into stdout.
            Output lines are handed to the given stdoutWriter, either all at once
            when the command ends (collateOutput=True) or line-by-line, prefixed
            with the item name (collateOutput=False).
    '''

    def __init__(self, cmd, stdoutWriter, thisItem, collateOutput=True):
        '''
            Create the runner thread.

            @param cmd <str> - Fully substituted command to execute
            @param stdoutWriter - Object whose "addData" method receives the output
            @param thisItem <str> - Argset item this command is for (used as the line prefix)
            @param collateOutput <bool> Default True - Hold output until the command completes
        '''
        threading.Thread.__init__(self)
        self.cmd = cmd
        self.stdoutWriter = stdoutWriter

        self.thisItem = thisItem
        self.collateOutput = collateOutput

        self.keepGoing = True

    def run(self):
        '''
            run - Execute the command and forward its output to the stdoutWriter.
        '''
        pipe = subprocess.Popen(self.cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        addData = self.stdoutWriter.addData

        if self.collateOutput is True:
            output = []
            handleLine = output.append
        else:
            output = None

            # Build the prefix once rather than on every line.
            prefix = '[%s] ' %(self.thisItem,)
            if sys.version_info[0] >= 3:
                # Lines read from the pipe are bytes in python3, so the prefix must be too.
                prefix = prefix.encode(defaultEncoding)

            def handleLine(line):
                addData(prefix + line)

        pipeStdout = pipe.stdout

        # NOTE(review): on python3 readline() presumably goes through a buffered reader, so
        #  select() may see nothing to read while whole lines wait in that buffer. They would
        #  still be delivered once more data or EOF arrives - confirm if -nc latency matters.
        while self.keepGoing is True:
            try:
                (rlist, wlist, errors) = select.select([pipeStdout], [], [pipeStdout], .004)
                if errors:
                    break

                if not rlist:
                    # select already waited out its timeout, just poll again.
                    continue

                line = pipeStdout.readline()

                if line == b'':
                    # EOF
                    break

                handleLine(line)
            except Exception as e:
                self.keepGoing = False
                pipe.terminate()
                sys.stderr.write('Got exception: %s\n' %(str(e),))
                break

        # Release our end of the pipe before waiting on the child.
        try:
            pipeStdout.close()
        except Exception:
            pass
        pipe.wait()

        if output:
            # Queue the collated output as ONE item so it cannot be interleaved with
            #  output from another task. b''.join works for python2 str and python3
            #  bytes alike. Nothing is queued when the command printed nothing.
            addData(b''.join(output))

class DistTask(object):
    '''
        DistTask - Main class that manages running the tasks.

            Can be run as a thread, or as a standalone.

            As a thread, and when "endWhenDone=False", you can keep feeding in items (addToArgset, addItemToArgset)
    '''

    def __init__(self, cmd, concurrent_tasks, argset, stdoutWriter, endWhenDone=True, collateOutput=True):
        '''
            Create object

            @param cmd <str> - Command with "%s" for argset items and "%d" for pipe number
            @param concurrent_tasks <int> - Number of tasks to run at once, or 0 to mean all items in argset at once
            @param argset <list<str>> - Initial items to run the command for
            @param stdoutWriter <StdoutWriter> - Class to pass stdout output
            @param endWhenDone <bool> Default True - If we should stop running when we have completed the argset.
                If False, we will wait for more items.
            @param collateOutput <bool> Default True - If True, output is only printed when an item completes its task.
                Otherwise, every line is printed prefixed with the argset item name.
        '''

        self.cmd = cmd

        # Use the quicker method of replace if we don't need to support double-percent
        self.replaceItem = getReplaceItemMethod(cmd)

        self.concurrent_tasks = concurrent_tasks or len(argset)
        self.argset = deque(argset)
        self.stdoutWriter = stdoutWriter
        self.endWhenDone = endWhenDone

        # keepGoing is only consulted by run() when endWhenDone is False, but it is
        #  always defined so the attribute exists in either mode.
        self.keepGoing = True

        self.collateOutput = collateOutput

    def addToArgset(self, items):
        '''
            addToArgset - Add several items to argset (pending items)

            @param items list<str> -  Several items
        '''
        self.argset += items

    def addItemToArgset(self, item):
        '''
            addItemToArgset - Add a single item to argset (pending items)

            @param item <str> - An item (%s sub)
        '''
        self.argset.append(item)

    def stopRunning(self):
        '''
            stopRunning - Stop executing "run" at next possible moment.
        '''
        self.keepGoing = False

    def _shouldKeepGoing(self, pipesRunning):
        '''
            _shouldKeepGoing - Decide if run() should do another scheduling pass.

            @param pipesRunning <int> - Number of tasks found running/started on the last pass

            @return <bool> - True if run() should continue
        '''
        if self.endWhenDone is True:
            return pipesRunning != 0
        return self.keepGoing is True or len(self.argset) > 0 or pipesRunning > 0

    def run(self):
        '''
            run - Start executing. Will continue to read off "argset" until complete.

                If endWhenDone=True on this object, after "argset" is empty, this will return.

                Otherwise, it will run until "stopRunning" is called.

                "stopRunning" is called on the StdoutWriter at the end of this method.
        '''

        argset = self.argset
        stdoutWriter = self.stdoutWriter
        collateOutput = self.collateOutput
        numSlots = self.concurrent_tasks

        # One slot per concurrent task; None marks an idle slot. Kept local so that run()
        #  neither depends on a module-level global nor carries state between calls.
        pipes = [None] * numSlots

        # Start at -1 so that the first pass always happens.
        pipesRunning = -1

        while self._shouldKeepGoing(pipesRunning):
            pipesRunning = 0
            for i in range(numSlots):
                runner = pipes[i]
                if runner is not None:
                    # is_alive() is used because Thread.isAlive() no longer exists on newer python3.
                    if runner.is_alive():
                        pipesRunning += 1
                        continue

                    runner.join() # cleanup
                    pipes[i] = None

                # Slot is idle, give it the next pending item, if any.
                #  The slot index is what gets substituted for %d.
                if len(argset) > 0:
                    nextItem = argset.popleft()
                    cmd = self.replaceItem(self.cmd, nextItem, i)
                    pipes[i] = Runner(cmd, stdoutWriter, nextItem, collateOutput)
                    pipes[i].start()
                    pipesRunning += 1

            time.sleep(.0002)

        stdoutWriter.stopRunning()

# Command-line entry point:  disttask [cmd] [concurrent tasks] [argset]
if (__name__ == "__main__"):
    args = sys.argv[1:]

    # Pull the option flags out of args, leaving only the positional arguments.
    collateOutput = True
    if '-nc' in args:
        args.remove('-nc')
        collateOutput = False
    if '--no-collate' in args:
        args.remove('--no-collate')
        collateOutput = False

    if '--version' in args:
        sys.stderr.write('disttask version %s by Tim Savannah\n' %(__version__,))
        sys.exit(0)

    if len(args) < 3 or '--help' in args:
        sys.stderr.write("Usage: " + os.path.basename(sys.argv[0]) + " [cmd] [concurrent tasks] [argset]\n\n")
        sys.stderr.write("Use a %s in [cmd] where you want the args to go. use %d for the pipe number.\nTo run a list of commands (job server), have '%s' be your full command.\nIf you need a literal %s or %d, use %%s or %%d.\n\n")
        sys.stderr.write('''
    Options:

       -nc or --no-collate          By default, the output will be held until the task is completed, so output is not intermixed.
                                       By providing "-nc" or "--no-collate", instead each line that comes in from any running task
                                       is printed, prefixed with the argset in square-brackets (e.x.  "[arg1] Some message"


   Argsets from stdin

      If argset is '--', then the argset items will be read from stdin instead of being provided on the commandline.
      Execution begins immediately, so you can use disttask as a job manager with another process feeding in items
      as they become available.


   Max Concurrency

      You may use "0" or "MAX" as the "concurrent tasks" parameter to execute all items in the argset simultaneously


   Example Usage

      disttask "ssh root@%s hostname" 3 host1 host2 host3 host4 host5 host6 # Connect and get hostname on 6 hosts, 3 at a time.
''')

        sys.stderr.write("\ndisttask version " + __version__ + "\n")
        sys.exit(1)


    # Module-level global, kept for code that expects the name "pipes" to exist.
    pipes = []


    cmd = args.pop(0)

    # "MAX" is an alias for 0, which means "run every item at once".
    concurrent_tasks = args.pop(0)
    if concurrent_tasks.lower() == 'max':
        concurrent_tasks = 0

    elif concurrent_tasks.isdigit() is False:
        sys.stderr.write('Number of concurrent tasks must be an integer, not "%s"\n' %(concurrent_tasks, ))
        sys.exit(127)

    concurrent_tasks = int(concurrent_tasks)
    argset = args

    # An argset of just "--" means the items are read from stdin.
    readFromStdin = bool(len(argset) == 1 and argset[0] == '--')

    if readFromStdin and concurrent_tasks == 0:
        sys.stderr.write('concurrent tasks = 0 (MAX) is not supported with input from stdin.\n')
        sys.exit(127)

    stdoutWriter = StdoutWriter()
    if collateOutput is False:
        stdoutWriter.setFlushEvery(10)
    stdoutWriter.start()

    if readFromStdin:

        # Run the tasks in a background thread while this thread feeds in items from stdin.
        runner = DistTask(cmd, concurrent_tasks, [], stdoutWriter, endWhenDone=False, collateOutput=collateOutput)
        runnerThread = threading.Thread(target=runner.run)
        runnerThread.start()

        nextItem = None
        while not sys.stdin.closed:
            try:
                nextItem = sys.stdin.readline()
                if nextItem == '':
                    # EOF
                    break
            except (Exception, KeyboardInterrupt):
                # Stop reading on any read failure or interrupt, but still let
                #  the tasks already queued finish below.
                break

            # Strip only a trailing newline. The final line of input may not have one,
            #  and must not lose its last character.
            if nextItem.endswith('\n'):
                nextItem = nextItem[:-1]
            runner.addItemToArgset(nextItem)

        runner.stopRunning()
        runnerThread.join()
    else:
        runner = DistTask(cmd, concurrent_tasks, argset, stdoutWriter, endWhenDone=True, collateOutput=collateOutput)
        runner.run()

# vim: set ts=4 sw=4 expandtab
