import logging
import pstats
import typing as t
from io import StringIO
from cProfile import Profile
from seagrass.base import ProtoHook

# Allowed element types for a single "restriction" passed on to
# pstats.Stats.print_stats: a regex string, a line count, or a fraction.
R = t.Union[(str, int, float)]


class ProfilerHook(ProtoHook[bool]):
    """A Seagrass hook that uses the built-in cProfile module to collect and log performance
    statistics on events.

    The profiler is enabled when the outermost hooked event starts and disabled when that
    event finishes. Nested events run inside the already-active profiling session, and
    statistics accumulate across events until :py:meth:`reset` is called.
    """

    # The underlying cProfile profiler that accumulates statistics across events.
    profile: Profile
    # Positional arguments forwarded to pstats.Stats.sort_stats by log_results.
    sort_keys: t.Union[t.Tuple[int], t.Tuple[str, ...]]
    # Positional arguments forwarded to pstats.Stats.print_stats by log_results.
    restrictions: t.Tuple[R, ...]
    # Whether this hook currently has the profiler enabled. This must start out False:
    # cleanup() only disables the profiler when the state it restores is False, so a
    # True default would leave the profiler running forever after the first event.
    is_active: bool = False

    # Set a high prehook_priority and posthook_priority to ensure
    # that the profiler only gets called directly before and after
    # the event.
    prehook_priority: int = 10
    posthook_priority: int = 10

    def __init__(
        self,
        sort_keys: t.Union[t.Union[str, int], t.Tuple[str, ...]] = tuple(),
        restrictions: t.Union[R, t.Tuple[R, ...]] = tuple(),
    ) -> None:
        """Create a new ProfilerHook.

        :param Union[str,int,Tuple[str,...]] sort_keys: a key or tuple of keys to use to sort
            the output generated by :py:meth:`log_results`. The available keys and their
            meanings are the same as those of `pstats.Stats.sort_stats`_. Any iterable of
            string keys is accepted and stored as a tuple.
        :param Union[R,Tuple[R,...]] restrictions: a restriction or tuple of restrictions
            (each a ``str``, ``int``, or ``float``) to use for the output generated by
            :py:meth:`log_results`. The available restrictions and their meanings are the
            same as those of `pstats.Stats.print_stats`_. Any iterable of restrictions is
            accepted and stored as a tuple.

        .. _pstats.Stats.sort_stats: https://docs.python.org/3/library/profile.html#pstats.Stats.sort_stats
        .. _pstats.Stats.print_stats: https://docs.python.org/3/library/profile.html#pstats.Stats.print_stats
        """

        # A single key is wrapped in a 1-tuple; the str and int cases are kept separate so
        # that each branch produces one of the two tuple shapes declared for sort_keys.
        if isinstance(sort_keys, str):
            self.sort_keys = (sort_keys,)
        elif isinstance(sort_keys, int):
            self.sort_keys = (sort_keys,)
        else:
            self.sort_keys = tuple(sort_keys)

        if isinstance(restrictions, (str, int, float)):
            self.restrictions = (restrictions,)
        else:
            self.restrictions = tuple(restrictions)

        # Profiling is off until the first hooked event starts.
        self.is_active = False
        self.profile = Profile()

    def prehook(
        self, event_name: str, args: t.Tuple[t.Any, ...], kwargs: t.Dict[str, t.Any]
    ) -> bool:
        """Start profiling unless an enclosing event has already started it.

        :return: whether the profiler was already active before this event. The value is
            handed back to :py:meth:`cleanup` so it can restore the previous state.
        """
        was_active = self.is_active
        if not was_active:
            # Only the outermost event turns the profiler on; nested events leave the
            # running session alone.
            self.profile.enable()
            self.is_active = True
        return was_active

    def cleanup(self, event_name: str, context: bool, exc: t.Optional[Exception]) -> None:
        """Restore the profiling state saved by :py:meth:`prehook`.

        :param bool context: the value of ``is_active`` before the matching
            :py:meth:`prehook` call. The profiler is only disabled when this is ``False``,
            i.e. when leaving the outermost profiled event.
        """
        self.is_active = context
        if not context:
            self.profile.disable()

    def get_stats(self, **kwargs: t.Any) -> t.Optional[pstats.Stats]:
        """Return the profiling statistics as a pstats.Stats class.

        :param kwargs: Keyword arguments to pass to the constructor of `pstats.Stats`_.
        :return: an instance of `pstats.Stats`_. If no profiling information was collected,
            this function will return ``None`` instead.
        :rtype: Optional[pstats.Stats]

        .. _pstats.Stats: https://docs.python.org/3/library/profile.html#pstats.Stats
        """
        if not self.profile.getstats():  # type: ignore
            return None
        return pstats.Stats(self.profile, **kwargs)

    def reset(self) -> None:
        """Reset the internal profiler, discarding all statistics collected so far."""
        self.profile.clear()  # type: ignore

    def log_results(self, logger: logging.Logger) -> None:
        """Log the results captured by ProfilerHook.

        The statistics are sorted by ``sort_keys``, filtered by ``restrictions``, and
        emitted line by line at INFO level on ``logger``.
        """
        logger.info("Results from %s:", self.__class__.__name__)

        # Dump results to an in-memory stream
        output = StringIO()
        if (stats := self.get_stats(stream=output)) is None:
            logger.info("   (no samples were collected)")
            return

        stats.sort_stats(*self.sort_keys).print_stats(*self.restrictions)

        # Replay the captured report through the provided logger, indented under the header.
        logger.info("")
        output.seek(0)
        for line in output:
            logger.info("    %s", line.rstrip())
