"""Hook infrastructure for Bitfount."""
from __future__ import annotations

from abc import ABCMeta, abstractmethod
from enum import Enum
from functools import wraps
import logging
from types import FunctionType, MappingProxyType
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    List,
    Literal,
    Mapping,
    Protocol,
    Tuple,
    Type,
    Union,
    cast,
    overload,
    runtime_checkable,
)

from bitfount.exceptions import HookError

if TYPE_CHECKING:
    from bitfount.federated.algorithms.base import _BaseAlgorithm
    from bitfount.federated.pod import Pod

__all__: List[str] = [
    "BaseAlgorithmHook",
    "BasePodHook",
    "HookType",
    "get_hooks",
]

logger = logging.getLogger(__name__)

# Name of the function attribute consulted by the auto-decorating hook
# metaclass. `ignore_decorator` sets it to False to opt a method out of
# decoration; functions that do not carry it are decorated by default.
_HOOK_DECORATED_ATTRIBUTE = "_decorate"


class HookType(Enum):
    """Enum for hook types.

    Members are used as the keys of the hook registry and of
    `HOOK_TYPE_TO_PROTOCOL_MAPPING`.
    """

    POD = "POD"  # hooks into pod lifecycle events (startup, tasks, shutdown)
    ALGORITHM = "ALGORITHM"  # hooks into algorithm initialisation and run


@runtime_checkable
class HookProtocol(Protocol):
    """Structural interface shared by all hooks, used for type annotations."""

    # Name identifying the hook.
    hook_name: str

    @property
    def type(self) -> HookType:
        """The kind of hook this is."""
        ...

    @property
    def registered(self) -> bool:
        """True if the hook is present in the registry."""
        ...

    def register(self) -> None:
        """Put the hook into the registry under its hook type."""
        ...


@runtime_checkable
class PodHookProtocol(HookProtocol, Protocol):
    """Structural interface every pod hook must satisfy."""

    def on_pod_startup_start(self, pod: Pod, *args: Any, **kwargs: Any) -> None:
        """Invoked as the first step of pod startup."""
        ...

    def on_pod_startup_end(self, pod: Pod, *args: Any, **kwargs: Any) -> None:
        """Invoked once pod startup has finished."""
        ...

    def on_task_start(self, pod: Pod, *args: Any, **kwargs: Any) -> None:
        """Invoked at the start of handling a newly received task."""
        ...

    def on_task_end(self, pod: Pod, *args: Any, **kwargs: Any) -> None:
        """Invoked at the end of handling a newly received task."""
        ...

    def on_pod_shutdown_start(self, pod: Pod, *args: Any, **kwargs: Any) -> None:
        """Invoked as the first step of pod shutdown."""
        ...

    def on_pod_shutdown_end(self, pod: Pod, *args: Any, **kwargs: Any) -> None:
        """Invoked as the last step of pod shutdown."""
        ...


@runtime_checkable
class AlgorithmHookProtocol(HookProtocol, Protocol):
    """Structural interface every algorithm hook must satisfy."""

    def on_init_start(
        self, algorithm: _BaseAlgorithm, *args: Any, **kwargs: Any
    ) -> None:
        """Invoked as the first step of algorithm initialisation."""
        ...

    def on_init_end(self, algorithm: _BaseAlgorithm, *args: Any, **kwargs: Any) -> None:
        """Invoked as the last step of algorithm initialisation."""
        ...

    def on_run_start(
        self, algorithm: _BaseAlgorithm, *args: Any, **kwargs: Any
    ) -> None:
        """Invoked as the first step of an algorithm run."""
        ...

    def on_run_end(self, algorithm: _BaseAlgorithm, *args: Any, **kwargs: Any) -> None:
        """Invoked as the last step of an algorithm run."""
        ...


# Maps each hook type to the runtime-checkable protocol that a hook of that
# type must satisfy; `BaseHook.register` checks instances against it.
HOOK_TYPE_TO_PROTOCOL_MAPPING: Dict[HookType, Type[HookProtocol]] = {
    HookType.POD: PodHookProtocol,
    HookType.ALGORITHM: AlgorithmHookProtocol,
}

# The mutable underlying dict that holds the registry information.
# Keys are hook types; values are the registered hook instances in the order
# they were registered.
_registry: Dict[HookType, List[HookProtocol]] = {}
# The read-only version of the registry that is allowed to be imported.
# NOTE: only the mapping itself is read-only; its values are the same mutable
# list objects stored in `_registry`.
registry: Mapping[HookType, List[HookProtocol]] = MappingProxyType(_registry)


@overload
def get_hooks(type: Literal[HookType.POD]) -> List[PodHookProtocol]:
    ...


@overload
def get_hooks(type: Literal[HookType.ALGORITHM]) -> List[AlgorithmHookProtocol]:
    ...


def get_hooks(
    type: HookType,
) -> Union[List[AlgorithmHookProtocol], List[PodHookProtocol]]:
    """Get all registered hooks of a particular type.

    Args:
        type: The hook type to look up.

    Returns:
        A new list holding the registered hooks of ``type``, in registration
        order. The list is empty if no hooks of that type are registered. It is
        a copy, so mutating it does not affect the registry.

    Raises:
        ValueError: If ``type`` is not a supported ``HookType``.
    """
    # Copy the stored list. `registry` is only a read-only view of the outer
    # mapping, so returning the underlying list would let callers mutate the
    # registry, or see it change while they iterate.
    hooks = list(registry.get(type, []))
    if type is HookType.POD:
        return cast(List[PodHookProtocol], hooks)
    elif type is HookType.ALGORITHM:
        return cast(List[AlgorithmHookProtocol], hooks)
    # Previously fell through and returned None implicitly.
    raise ValueError(f"Unsupported hook type: {type!r}")


def ignore_decorator(f: Callable) -> Callable:
    """Flag ``f`` so the hook metaclass does not auto-decorate it.

    The flag is set on the function object itself, and that same object is
    returned, so this can be stacked like an ordinary decorator.
    """
    marked = f
    setattr(marked, _HOOK_DECORATED_ATTRIBUTE, False)
    return marked


def hook_decorator(f: Callable) -> Callable:
    """Wrap a hook function so it is logged around each call and never raises.

    An exception raised by the wrapped function is logged with its traceback
    and swallowed, so a failing hook does not propagate to its caller.

    Args:
        f: The hook function to wrap.

    Returns:
        A wrapper carrying ``f``'s metadata (via ``functools.wraps``). The wrapper
        returns whatever ``f`` returns, or ``None`` if ``f`` raised.
    """

    @wraps(f)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        """Call the wrapped hook, logging before, after and on failure."""
        logger.debug(f"Calling hook {f.__name__}")
        try:
            return_val = f(*args, **kwargs)
        except Exception:
            # Deliberately swallowed: hook failures are logged, not propagated.
            # `logger.exception` emits the message and the traceback as a single
            # ERROR record, instead of the two separate records emitted before.
            logger.exception(f"Exception in hook {f.__name__}")
            return None
        logger.debug(f"Called hook {f.__name__}")
        return return_val

    return wrapper


def get_hook_decorator_metaclass(decorator: Callable) -> type:
    """Build a metaclass that applies ``decorator`` to public instance methods.

    A class attribute is decorated when all of the following hold:

    - its name is public (no leading underscore and no ``"__"`` anywhere in it);
    - it is a plain Python function;
    - it has not been opted out via the ``_HOOK_DECORATED_ATTRIBUTE`` flag, for
      example by ``ignore_decorator``.

    Decoration happens when the class is created and also when a qualifying
    function is later assigned onto the class.

    Args:
        decorator: Callable that takes a function and returns its replacement.

    Returns:
        The newly created metaclass.
    """

    class HookDecoratorMetaClass(type):
        """Decorate all instance methods (unless excluded) with the same decorator."""

        @classmethod
        def do_decorate(cls, attr: str, value: Any) -> bool:
            """Return True if ``value`` stored under ``attr`` should be decorated."""
            return (
                "__" not in attr
                and not attr.startswith("_")
                and isinstance(value, FunctionType)
                and getattr(value, _HOOK_DECORATED_ATTRIBUTE, True)
            )

        def __new__(
            cls,
            name: str,
            bases: Tuple[type, ...],
            dct: Dict[str, Any],
            **kwargs: Any,
        ) -> type:
            # Iterate over a snapshot because entries of `dct` are rewritten.
            for attr, value in list(dct.items()):
                if cls.do_decorate(attr, value):
                    setattr(value, _HOOK_DECORATED_ATTRIBUTE, True)
                    dct[attr] = decorator(value)
            # Forward class keyword arguments (e.g. those consumed by
            # `__init_subclass__`) instead of rejecting them with a TypeError.
            return super().__new__(cls, name, bases, dct, **kwargs)

        def __setattr__(self, attr: str, value: Any) -> None:
            # Resolve the predicate on the metaclass so that an attribute named
            # `do_decorate` defined on the class itself cannot shadow it.
            if type(self).do_decorate(attr, value):
                value = decorator(value)
            super().__setattr__(attr, value)

    return HookDecoratorMetaClass


# Metaclass for hook classes. It combines `ABCMeta`, so `@abstractmethod` members
# are enforced, with the auto-decorating metaclass built from `hook_decorator`,
# which wraps public methods in logging and exception handling.
AbstractHookMetaClass = type(
    "AbstractHookMetaClass",
    (ABCMeta, get_hook_decorator_metaclass(decorator=hook_decorator)),
    {},
)


# Mypy explicitly does not support dynamically computed metaclasses yet.
class BaseHook(metaclass=AbstractHookMetaClass):  # type: ignore[misc] # Reason: See above # noqa: B950
    """Common base class for all hooks.

    Subclasses must provide the ``type`` property. Registering an instance
    stores it in the module-level registry under that type.
    """

    def __init__(self) -> None:
        """Set up the hook, naming it after its concrete class."""
        self.hook_name = type(self).__name__

    @property
    @abstractmethod
    def type(self) -> HookType:
        """Kind of hook; concrete subclasses must override this."""
        raise NotImplementedError

    @property
    def registered(self) -> bool:
        """True if a hook with this hook's name is already registered for its type."""
        same_type_hooks = _registry.get(self.type, [])
        return any(hook.hook_name == self.hook_name for hook in same_type_hooks)

    @ignore_decorator
    def register(self) -> None:
        """Add this hook to the registry under its hook type.

        Registering an already-registered hook (matched by name) is a no-op.

        Raises:
            HookError: If the hook does not satisfy the protocol for its type.
        """
        expected_protocol = HOOK_TYPE_TO_PROTOCOL_MAPPING[self.type]
        if not isinstance(self, expected_protocol):
            raise HookError("Hook does not implement the specified protocol")

        if self.registered:
            logger.info("Hook already registered")
            return

        logger.debug(f"Adding {self.hook_name} to Hooks registry")
        _registry.setdefault(self.type, []).append(self)
        logger.info(f"Added {self.hook_name} to Hooks registry")


class BasePodHook(BaseHook):
    """Base class for pod hooks.

    Every hook method defaults to a no-op, so subclasses only override the
    events they care about.
    """

    @property
    def type(self) -> HookType:
        """Hook type for pod hooks: always ``HookType.POD``."""
        return HookType.POD

    def on_pod_startup_start(self, pod: Pod, *args: Any, **kwargs: Any) -> None:
        """Hook point at the very start of pod startup (no-op by default)."""

    def on_pod_startup_end(self, pod: Pod, *args: Any, **kwargs: Any) -> None:
        """Hook point at the end of pod startup (no-op by default)."""

    def on_task_start(self, pod: Pod, *args: Any, **kwargs: Any) -> None:
        """Hook point when handling of a new task starts (no-op by default)."""

    def on_task_end(self, pod: Pod, *args: Any, **kwargs: Any) -> None:
        """Hook point when handling of a new task ends (no-op by default)."""

    def on_pod_shutdown_start(self, pod: Pod, *args: Any, **kwargs: Any) -> None:
        """Hook point at the very start of pod shutdown (no-op by default)."""

    def on_pod_shutdown_end(self, pod: Pod, *args: Any, **kwargs: Any) -> None:
        """Hook point at the very end of pod shutdown (no-op by default)."""


class BaseAlgorithmHook(BaseHook):
    """Base class for algorithm hooks.

    Every hook method defaults to a no-op, so subclasses only override the
    events they care about.
    """

    @property
    def type(self) -> HookType:
        """Hook type for algorithm hooks: always ``HookType.ALGORITHM``."""
        return HookType.ALGORITHM

    def on_init_start(
        self, algorithm: _BaseAlgorithm, *args: Any, **kwargs: Any
    ) -> None:
        """Hook point at the very start of algorithm initialisation (no-op)."""

    def on_init_end(self, algorithm: _BaseAlgorithm, *args: Any, **kwargs: Any) -> None:
        """Hook point at the very end of algorithm initialisation (no-op)."""

    def on_run_start(
        self, algorithm: _BaseAlgorithm, *args: Any, **kwargs: Any
    ) -> None:
        """Hook point at the very start of an algorithm run (no-op)."""

    def on_run_end(self, algorithm: _BaseAlgorithm, *args: Any, **kwargs: Any) -> None:
        """Hook point at the very end of an algorithm run (no-op)."""
