from __future__ import annotations

import argparse
import json
import typing

import rich_argparse

if typing.TYPE_CHECKING:
    # Type alias for the sub-parser action object used with argparse.
    # It is defined only while type checking, so it does not exist at runtime:
    # use it only inside annotations, which stay unevaluated because of
    # ``from __future__ import annotations`` at the top of this module.
    Subparsers = argparse._SubParsersAction[argparse.ArgumentParser]


def open_interactive_session(*, variables: dict[str, typing.Any]) -> None:
    """Open an interactive Python shell with ``variables`` in its namespace.

    A banner is printed first, listing every variable name (in bold) with the
    name of its value's type. IPython's embedded shell is used when IPython
    can be imported; otherwise the standard-library ``code`` console is used.

    Args:
        variables: names and values to expose inside the shell.
    """
    header_lines = ['data stored in these variables:']
    header_lines.extend(
        '- \033[1m\033[97m' + key + '\033[0m: ' + type(value).__name__
        for key, value in variables.items()
    )
    header = '\n'.join(header_lines)

    # Keep the try narrow: only a failed IPython import should select the
    # fallback, not an ImportError raised while the IPython session runs.
    try:
        from IPython.terminal.embed import InteractiveShellEmbed
    except ImportError:
        import code
        import sys

        class ExitInteract:
            """Stand-in for ``exit``/``quit`` that only leaves this console.

            In CPython the builtin ``exit``/``quit`` close ``sys.stdin``
            before raising ``SystemExit``, which would break later input.
            This object only raises. Its ``__repr__`` raises as well, so
            typing the bare name (without calling it) also exits.
            """

            def __call__(self) -> None:
                raise SystemExit

            def __repr__(self) -> str:
                raise SystemExit

        exiter = ExitInteract()
        sys.ps1 = '>>> '
        try:
            code.interact(
                banner='\n' + header + '\n',
                local=dict(variables, exit=exiter, quit=exiter),
            )
        except SystemExit:
            # raised by ``exiter``: the user asked to leave the console
            pass
    else:
        ipshell = InteractiveShellEmbed(colors='Linux')  # type: ignore
        ipshell(header=header, local_ns=variables)


def _enter_debugger() -> None:
    """open debugger to most recent exception

    - adapted from https://stackoverflow.com/a/242514
    """
    import sys
    import traceback

    # print stacktrace
    extype, value, tb = sys.exc_info()
    print('[ENTERING DEBUGGER]')
    traceback.print_exc()
    print()

    try:
        import ipdb  # type: ignore
        import types

        tb = typing.cast(types.TracebackType, tb)
        ipdb.post_mortem(tb)

    except ImportError:
        import pdb

        pdb.post_mortem(tb)


class HelpFormatter(rich_argparse.RichHelpFormatter):
    """Rich-styled argparse help formatter used by this CLI.

    Differences from ``RichHelpFormatter``:

    - a fixed color scheme (``styles``) and ``usage_markup`` enabled
    - help column defaults to ``max_help_position=32``
    - compact metavars for ``nargs='*'`` and ``nargs='+'`` (``_format_args``)
    - lines beginning with ``{`` (choice lists such as the subcommand
      header ``{cmd1,cmd2}``) are removed from the help output
    """

    usage_markup = True

    styles = {
        'argparse.prog': 'bold white',
        'argparse.groups': 'bold green',
        'argparse.args': 'bold white',
        'argparse.metavar': 'grey62',
        'argparse.help': 'grey62',
        'argparse.text': 'blue',
        'argparse.syntax': 'blue',
        'argparse.default': 'blue',
    }

    def __init__(self, prog: str, **kwargs: typing.Any) -> None:
        """Create the formatter for ``prog``.

        Extra keyword arguments (e.g. ``width``, ``indent_increment``) are
        forwarded to ``RichHelpFormatter``; ``max_help_position`` defaults
        to 32 unless given explicitly.
        """
        kwargs.setdefault('max_help_position', 32)
        super().__init__(prog, **kwargs)

    def _format_args(self, action, default_metavar):  # type: ignore
        """Render an action's metavar, shortening variadic ``nargs``.

        With metavar ``X``: ``nargs='*'`` renders as ``[X [X ...]]`` and
        ``nargs='+'`` as ``X [...]``; every other ``nargs`` is delegated to
        the base implementation.
        """
        get_metavar = self._metavar_formatter(action, default_metavar)
        if action.nargs == argparse.ZERO_OR_MORE:
            return '[%s [%s ...]]' % get_metavar(2)
        elif action.nargs == argparse.ONE_OR_MORE:
            return '%s [...]' % get_metavar(1)
        return super()._format_args(action, default_metavar)

    def format_help(self) -> str:
        """Return the help text without lines that start with ``{``.

        Such lines list an action's choices (e.g. the ``{cmd1,cmd2}`` header
        of a subparsers action). Both the styled form (``argparse.args`` is
        bold white, i.e. ``ESC[1;37m``) and the plain form emitted when
        color is disabled (e.g. output piped to a file) are removed.
        """
        hidden_prefixes = ('  \x1b[1;37m{', '  {')
        return '\n'.join(
            line
            for line in super().format_help().split('\n')
            if not line.startswith(hidden_prefixes)
        )


class SpecJsonEncoder(json.JSONEncoder):
    """JSON encoder that can also serialize polars dtypes.

    A polars dtype (either a dtype class such as ``pl.Int64`` or a dtype
    instance) is written as its ``str()`` form. Any other value the standard
    encoder cannot handle is passed to ``json.JSONEncoder.default``, which
    raises ``TypeError``.
    """

    def default(self, obj: typing.Any) -> str:
        import polars as pl

        dtype_kinds = (pl.DataTypeClass, pl.DataType)
        if not isinstance(obj, dtype_kinds):
            return super().default(obj)
        return str(obj)
