from __future__ import annotations

from typing import Any, Dict, Set, Tuple, Union

import dask
from dask.blockwise import BlockwiseDepDict, blockwise
from dask.delayed import Delayed
from dask.highlevelgraph import HighLevelGraph

from .base import Pipeline, PipelineExecutor


def wrap_map_task(function):
    """Adapt ``function(item, config=...)`` to the positional shape of a mapped task.

    The returned callable accepts ``(map_arg, config, *dependencies)`` and forwards
    only ``map_arg`` (positionally) and ``config`` (as a keyword) to ``function``.
    Any further positional arguments are ignored; they exist only so the task can
    reference keys of earlier stages and thereby be ordered after them.
    """

    def _call_for_item(map_arg, config, *dependencies):
        outcome = function(map_arg, config=config)
        return outcome

    return _call_for_item


def wrap_standalone_task(function):
    """Adapt ``function(config=...)`` to the positional shape of a single task.

    The returned callable accepts ``(config, *dependencies)`` and calls
    ``function`` with ``config`` as a keyword argument. Extra positional arguments
    are dummies used only to order this task after earlier stages.
    """
    return lambda config, *dependencies: function(config=config)


def checkpoint(*args):
    """No-op task used purely as a join point in the task graph.

    Its positional arguments (the outputs of the tasks it depends on) are ignored,
    and it always returns ``None``.
    """
    del args  # values are irrelevant; only the graph edges to them matter
    return None


def append_token(task_name: str, token: str) -> str:
    """Return ``task_name`` and ``token`` joined by a single hyphen.

    >>> append_token("config", "abc")
    'config-abc'
    """
    return "{}-{}".format(task_name, token)


class DaskPipelineExecutor(PipelineExecutor[Delayed]):
    """Executor that turns a ``Pipeline`` into a dask ``Delayed`` and computes it."""

    @staticmethod
    def compile(pipeline: Pipeline):
        """Build a ``Delayed`` whose computation runs every stage of ``pipeline`` in order.

        All keys are suffixed with ``dask.base.tokenize(pipeline)``. The graph
        starts with a layer holding ``pipeline.config``. Each stage adds a layer
        that depends on the config key and on the previous stage's key, which is
        what serializes the stages.

        * A stage whose ``mappable`` is ``None`` becomes one task calling
          ``stage.function(config=...)``.
        * Any other stage becomes a ``blockwise`` layer with one task per element
          of ``stage.mappable``. It is followed by a ``checkpoint`` task that
          depends on all of them, so the next stage needs only a single upstream key.

        Returns a ``Delayed`` for the last stage's key, or for the config key
        when there are no stages.
        """
        token = dask.base.tokenize(pipeline)

        # The HighLevelGraph is assembled by hand:
        # https://docs.dask.org/en/latest/high-level-graphs.html
        # Values are plain dicts for single-task layers and a Blockwise layer for
        # mapped stages.
        layers: Dict[str, Any] = {}
        dependencies: Dict[str, Set[str]] = {}

        # Root layer: the config object stored as a literal value under its own key.
        config_key = append_token("config", token)
        layers[config_key] = {config_key: pipeline.config}
        dependencies[config_key] = set()

        prev_key: str = config_key
        for stage in pipeline.stages:
            # Every stage waits on the config and on whatever ran just before it.
            upstream = {config_key, prev_key}

            if stage.mappable is not None:
                map_key = append_token(stage.name, token)
                # Blockwise wants a dict keyed by 1-tuples here; this is extra
                # annoying. `BlockwiseDepList` at least would be nice.
                per_item_inputs = BlockwiseDepDict(
                    {(idx,): item for idx, item in enumerate(stage.mappable)}
                )
                map_layer = blockwise(
                    wrap_map_task(stage.function),
                    map_key,
                    "x",  # <-- dimension name doesn't matter
                    per_item_inputs,
                    "x",
                    config_key,
                    None,
                    prev_key,
                    None,
                    numblocks={},
                    # ^ also annoying; the default of None breaks Blockwise
                )
                layers[map_key] = map_layer
                dependencies[map_key] = upstream

                # Collapse the fan-out into one key that the next stage can depend on.
                stage_key = f"{stage.name}-checkpoint-{token}"
                fan_in = (checkpoint, *map_layer.get_output_keys())
                layers[stage_key] = {stage_key: fan_in}
                dependencies[stage_key] = {map_key}
            else:
                stage_key = append_token(stage.name, token)
                # prev_key is passed only as a dummy argument to create the ordering edge.
                task = (wrap_standalone_task(stage.function), config_key, prev_key)
                layers[stage_key] = {stage_key: task}
                dependencies[stage_key] = upstream

            prev_key = stage_key

        return Delayed(prev_key, HighLevelGraph(layers, dependencies))

    @staticmethod
    def execute(delayed: Delayed):
        """Compute ``delayed`` (as produced by ``compile``) and discard the result."""
        delayed.compute()
