import abc
from bamboo_lib.logger import logger


class BasePipeline(object):
    """Interface that concrete pipelines are expected to implement.

    Every hook is a static method. All hooks except ``validate_params``
    raise ``NotImplementedError`` in this base class, so a subclass has to
    override them before they can be called.
    """

    # NOTE: ``__metaclass__`` is the Python 2 spelling. Under Python 3 it is an
    # ordinary class attribute, so the abstract methods are not enforced at
    # instantiation time there.
    __metaclass__ = abc.ABCMeta

    @staticmethod
    @abc.abstractmethod
    def validate_params(params):
        """Validate ``params``; the base implementation accepts anything.

        Returns:
            True, unconditionally, in this base class.
        """
        return True

    @staticmethod
    @abc.abstractmethod
    def pipeline_id():
        """Must be overridden; the base implementation always raises.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Please implement the pipeline_id() method")

    @staticmethod
    @abc.abstractmethod
    def name():
        """Must be overridden; the base implementation always raises.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Please implement the name() method")

    @staticmethod
    @abc.abstractmethod
    def description():
        """Must be overridden; the base implementation always raises.

        Raises:
            NotImplementedError: always, in this base class.
        """
        # The message previously named name() by mistake.
        raise NotImplementedError("Please implement the description() method")

    @staticmethod
    @abc.abstractmethod
    def params():
        """Must be overridden; the base implementation always raises.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Please implement the params() method")

    @staticmethod
    @abc.abstractmethod
    def website():
        """Must be overridden; the base implementation always raises.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Please implement the website() method")

    @staticmethod
    def run(params_dict, **kwargs):
        """Entry point of the pipeline; not marked abstract, but raises here.

        Args:
            params_dict: parameters for the run (unused by the base class).
            **kwargs: extra options (unused by the base class).

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Please implement the run() method")


class Parameter(object):
    """Plain record holding the settings that describe one parameter.

    The constructor stores every argument unchanged on the instance under
    the attribute of the same name; no validation is performed.
    """

    def __init__(self, name, dtype, options=None, allow_multiple=False, label="", source=None):
        # Required fields first, then the optional ones, in signature order.
        self.name, self.dtype = name, dtype
        self.options, self.allow_multiple = options, allow_multiple
        self.label, self.source = label, source


class LinearPipelineExecutor(object):
    """Runs a sequence of steps one after another.

    Each step's ``run_step`` receives the value returned by the step before
    it (``None`` for the first step) together with the shared ``params``.
    """

    def __init__(self, steps=None, params=None):
        self.steps = steps
        self.params = params

    def set_steps(self, steps):
        """Replace the sequence of steps to execute."""
        self.steps = steps

    def run_pipeline(self):
        """Execute every step in order, feeding each result to the next step.

        The result of the last step is not returned.
        """
        logger.info("==== Pipeline run started...")
        carried = None  # output of the previous step; nothing before the first
        for current in self.steps:
            logger.info("Starting step %s ....", str(current.__class__))
            carried = current.run_step(carried, self.params)
            logger.info("Ending step %s ....", str(current.__class__))
        logger.info("==== Pipeline run completed!")


class GraphPipelineExecutor(object):
    """Runs a chain of steps once for every item produced by a source step.

    ``for_each.run_step(None, params)`` supplies the items; for each item
    the ``do_steps`` are executed in order, the first one receiving the item
    and every later one receiving the previous step's return value.
    """

    def __init__(self, for_each=None, do_steps=None, params=None):
        self.for_each = for_each
        self.do_steps = do_steps
        self.params = params

    def run_pipeline(self):
        """Iterate over the source step's output and run ``do_steps`` per item."""
        logger.info("==== Pipeline run started...")
        chunks = self.for_each.run_step(None, self.params)
        # Passes are numbered from 1 for the log output.
        for count, chunk_result in enumerate(chunks, 1):
            logger.info("Starting pass {}".format(count))
            carried = chunk_result
            for current in self.do_steps:
                logger.info("Starting step %s ....", str(current.__class__))
                carried = current.run_step(carried, self.params)
                logger.info("Ending step %s ....", str(current.__class__))
        logger.info("==== Pipeline run completed!")


class PipelineStep(object):
    """Base class for a single unit of work inside a pipeline.

    Any keyword arguments given to the constructor become attributes of the
    instance, so subclasses can be configured without defining ``__init__``.
    """

    __metaclass__ = abc.ABCMeta

    def __init__(self, **kwargs):
        # Expose every keyword argument as an instance attribute.
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])

    @abc.abstractmethod
    def run_step(self, prev_result, params):
        """Run operations here and return data required for next step"""


class Node(object):
    """Linked-list cell wrapping one step of an execution chain.

    ``next`` / ``prev`` link the neighbouring nodes; ``is_iterator`` flags a
    node whose step produces an iterator, and ``started`` starts out False.
    """

    def __init__(self, step, is_iterator=False):
        self.step = step
        # A fresh node is unlinked and has no children yet.
        self.next = self.prev = self.children = None
        self.is_iterator = is_iterator
        self.iterator = None
        self.started = False


class EndNode(Node):
    """Marker node without a step of its own.

    ``ComplexPipelineExecutor.endeach`` appends one of these, and the executor
    recognises it with ``isinstance`` while walking the chain.
    """

    def __init__(self):
        # Delegate to Node so an EndNode carries the full attribute set
        # (step, children, iterator, started) instead of only next / prev /
        # is_iterator; code that inspects any Node attribute then works on
        # marker nodes as well.
        super(EndNode, self).__init__(step=None, is_iterator=False)


class ComplexPipelineExecutor(object):
    """Executor that walks a linked list of nodes supporting foreach loops.

    The chain is built with ``next()``, ``foreach()`` and ``endeach()``; each
    appends one node at the end of the list and returns the executor, so the
    calls can be chained. ``run_pipeline()`` then walks the list from
    ``root``.

    NOTE(review): ``run_pipeline_helper`` recurses once per executed node
    (and therefore once per loop iteration), so long runs can approach the
    interpreter's recursion limit. Iterator nodes also keep ``started = True``
    after a run; nothing visible here resets it.
    """

    def __init__(self, params=None):
        # Never filled in by this class; only read by __str__.
        self.execution_plan = []
        self.params = params
        self.root = None
        # Most recently inserted node.
        self.pointer = None

    def insert(self, head, node_to_add):
        """Append ``node_to_add`` at the end of the chain starting at ``head``.

        When the executor has no root yet and ``head`` is empty, the node
        becomes the root. Otherwise the chain is followed through ``next``
        until its last node, and the new node is linked after it.

        Args:
            head: node to start searching from (normally ``self.root``).
            node_to_add: node to append; its ``prev`` is set to the old tail.

        Returns:
            ``node_to_add``.
        """
        if not self.root and not head:
            self.root = node_to_add
        else:
            # Walk iteratively: one recursive call per existing node would hit
            # the recursion limit on chains of roughly a thousand nodes.
            tail = head
            while tail.next:
                tail = tail.next
            tail.next = node_to_add
            node_to_add.prev = tail
        self.pointer = node_to_add
        return node_to_add

    def next(self, step):
        """Append a plain step to the chain and return the executor."""
        node = Node(step)
        self.insert(self.root, node)
        return self

    def foreach(self, step):
        """Append an iterator step (loop start) and return the executor."""
        node = Node(step, is_iterator=True)
        self.insert(self.root, node)
        return self

    def endeach(self):
        """Append an end-of-loop marker and return the executor."""
        node = EndNode()
        self.insert(self.root, node)
        return self

    def __str__(self):
        return str(self.execution_plan)

    def run_pipeline(self):
        """Run the chain from the root with fresh loop bookkeeping."""
        return self.run_pipeline_helper(self.root, None, stack=[], history=[])

    def run_pipeline_helper(self, curr_node, prev_result, stack=None, history=None):
        """Execute ``curr_node`` and recurse into whichever node follows.

        Args:
            curr_node: node to execute; a falsy value ends the walk.
            prev_result: value handed to the node's step.
            stack: open loops as ``[iterator, foreach_node]`` pairs, innermost
                first. A new list is created when omitted.
            history: exhausted loops in the same format, most recent first.
                A new list is created when omitted.

        Returns:
            Whatever the final recursive call returns (``None`` when the walk
            ends normally).
        """
        # The signature defaults used to reach stack.insert()/history.pop()
        # as None and crash; give direct callers working containers instead.
        if stack is None:
            stack = []
        if history is None:
            history = []

        if not curr_node:
            logger.debug("Done.")
        elif curr_node.is_iterator and not curr_node.started:
            # First visit of a loop start: create its iterator, register it as
            # the innermost open loop and enter the body with the first item.
            iterator = curr_node.step.run_step(prev_result, self.params)
            stack.insert(0, [iterator, curr_node])
            curr_node.started = True
            return self.run_pipeline_helper(curr_node.next, next(iterator), stack=stack, history=history)
        elif curr_node.is_iterator and curr_node.started:
            try:
                # Unused, but kept: fails early when no loop is open.
                iterator, target_node = stack[0]
                # Re-entering a loop start that already ran: take the most
                # recently exhausted loop from history, rebuild its iterator
                # from the current input and reopen it.
                hiterator, htarget_node = history.pop(0)
                hiterator = htarget_node.step.run_step(prev_result, self.params)
                stack.insert(0, [hiterator, htarget_node])
                return self.run_pipeline_helper(curr_node.next, next(hiterator), stack=stack, history=history)
            except StopIteration:
                raise Exception("TODO ... not yet implemented!")

        elif isinstance(curr_node, EndNode):
            # End of a loop body: ask the innermost open loop for its next
            # item and jump back to the node right after its foreach node.
            iterator, target_node = stack[0]
            try:
                return self.run_pipeline_helper(target_node.next, next(iterator), stack=stack, history=history)
            except StopIteration:
                # NOTE(review): the recursive call sits inside this try, so a
                # StopIteration escaping from deeper in the walk lands here too.
                # Loop exhausted: move it from the open stack to history and
                # continue after the end marker with the unchanged input.
                old_node = stack.pop(0)
                history.insert(0, old_node)
                return self.run_pipeline_helper(curr_node.next, prev_result, stack=stack, history=history)
        else:
            # Plain step: run it and pass its result on.
            prev_result = curr_node.step.run_step(prev_result, self.params)
            return self.run_pipeline_helper(curr_node.next, prev_result, stack=stack, history=history)


class AdvancedPipelineExecutor(object):
    """Executor driven by a flat instruction list with nestable loops.

    ``execution_plan`` holds ``[instruction, payload]`` pairs:

    * ``["standard", step]``    - run ``step`` on the previous result.
    * ``["iterator", step]``    - ``step`` produces the items of a loop.
    * ``["enditerator", index]`` - closes the loop opened at plan ``index``.

    ``next()``, ``foreach()`` and ``endeach()`` append to the plan and return
    the executor so calls can be chained.
    """

    def __init__(self, params=None):
        self.execution_plan = []
        self.params = params
        # Plan indices of foreach() calls not yet closed; most recent first.
        self.iterator_idx_stack = []

    def next(self, step):
        """Append a plain step and return the executor."""
        self.execution_plan.append(["standard", step])
        return self

    def foreach(self, step):
        """Open a loop whose items come from ``step`` and return the executor."""
        self.iterator_idx_stack.insert(0, len(self.execution_plan))
        self.execution_plan.append(["iterator", step])
        return self

    def endeach(self):
        """Close the most recently opened loop and return the executor.

        The plan index of the closing instruction is also stored on the
        loop's step as ``endpoint``.

        Raises:
            ValueError: if no loop is currently open.
        """
        if not self.iterator_idx_stack:
            raise ValueError("Cannot have end iterator without matching start iterator")
        last_iterator_idx = self.iterator_idx_stack.pop(0)
        _, my_it_step = self.execution_plan[last_iterator_idx]
        setattr(my_it_step, "endpoint", len(self.execution_plan))
        self.execution_plan.append(["enditerator", last_iterator_idx])
        return self

    def next_endstate(self, start):
        """Return the index of the first ``enditerator`` at or after ``start``.

        Raises:
            Exception: if the plan has no such instruction from ``start`` on.
        """
        for i in range(start, len(self.execution_plan)):
            if self.execution_plan[i][0] == 'enditerator':
                return i
        raise Exception("TODO!")

    def run_pipeline(self):
        """Execute the plan.

        Standard steps are chained: each receives the previous step's return
        value. A loop body runs once per item of its iterator step, with the
        first body step receiving the item; loops may be nested. After a loop
        finishes, the next step receives ``None``. A loop that yields no
        items skips its body entirely.

        Raises:
            ValueError: if a ``foreach`` was never closed with ``endeach``,
                or if the plan contains an unknown instruction.
        """
        plan = self.execution_plan

        # Pair every loop start with its closing instruction up front so an
        # empty loop can jump straight past its body.
        end_index = {}
        for idx, (instruction, payload) in enumerate(plan):
            if instruction == "enditerator":
                end_index[payload] = idx
        for idx, (instruction, _) in enumerate(plan):
            if instruction == "iterator" and idx not in end_index:
                raise ValueError("Cannot have start iterator without matching end iterator")

        pointer = 0
        prev_result = None
        it_stack = []  # iterators of the open loops, innermost last
        while pointer < len(plan):
            instruction, step = plan[pointer]

            if instruction == "standard":
                prev_result = step.run_step(prev_result, self.params)
                pointer += 1
            elif instruction == "iterator":
                items = iter(step.run_step(prev_result, self.params))
                try:
                    prev_result = next(items)
                except StopIteration:
                    # Nothing to loop over: skip the body and its end marker.
                    prev_result = None
                    pointer = end_index[pointer] + 1
                else:
                    it_stack.append(items)
                    pointer += 1
            elif instruction == "enditerator":
                if not it_stack:
                    # No open loop (hand-edited plan); just move on.
                    prev_result = None
                    pointer += 1
                    continue
                try:
                    # One item per pass over the body, pulled only here so
                    # every body step of a pass works on the same item chain.
                    prev_result = next(it_stack[-1])
                except StopIteration:
                    it_stack.pop()
                    prev_result = None
                    pointer += 1
                else:
                    # Here ``step`` is the plan index of the loop start.
                    pointer = step + 1
            else:
                raise ValueError("Invalid instruction!", instruction)


class ResultWrapper(object):
    """Simple attribute bag: each keyword argument becomes an attribute."""

    def __init__(self, **kwargs):
        for field_name in kwargs:
            setattr(self, field_name, kwargs[field_name])
