import typing as tp
import logging

from flask import Response, Request
import flask
from satella.time import measure

# Module-level logger; every request/exception log entry emitted by FlaskRequestsLogging
# below is sent through this logger.
logger = logging.getLogger(__name__)
# Package version string.
__version__ = '0.5'


def _level_for_status(status_code: int, level_mapping: tp.Dict[int, int]) -> int:
    """
    Pick the logging level for an HTTP status code.

    An entry keyed by the exact status code takes precedence; otherwise the entry keyed by
    the code's leading digit (``status_code // 100``, e.g. ``4`` for 404) is used. Codes
    matching neither fall back to ``logging.INFO``.
    """
    if status_code in level_mapping:
        return level_mapping[status_code]
    return level_mapping.get(status_code // 100, logging.INFO)


def _current_request_timer():
    """
    Return the timer stored on the current Flask request as ``time_elapsed``.

    If the request has no timer yet (our before_request hook did not run for it, e.g. because
    an earlier-registered before_request hook returned a response or raised), a new timer is
    started and stored on the request so the logging callbacks can still call
    ``request.time_elapsed()``. In that case the reported time only covers the period since
    this call.
    """
    request = flask.request
    timer = getattr(request, 'time_elapsed', None)
    if timer is None:
        timer = measure()
        request.time_elapsed = timer
    return timer


def FlaskRequestsLogging(app, default_level_mapping: tp.Optional[tp.Dict[int, int]] = None,
                         log_template: tp.Callable[[Request, Response], str] = \
                                 lambda req,
                                        resp: f'Request {req.method} {req.path} '
                                              f'finished with {resp.status_code} '
                                              f'took {req.time_elapsed()} seconds',
                         extra_args_gen: tp.Callable[[Request, Response], dict] = \
                                 lambda req, resp: {'url': req.path,
                                                    'method': req.method,
                                                    'status_code': resp.status_code,
                                                    'elapsed': req.time_elapsed()},
                         unhandled_exception_template: tp.Callable[[Request, Exception], str] = \
                                 lambda req, exc: f'Got exception while processing {req.method} '
                                                  f'{req.path}',
                         extra_except_args_gen: tp.Callable[[Request, Exception], dict] = \
                                 lambda req, exc: {'url': req.path,
                                                   'method': req.method,
                                                   'elapsed': req.time_elapsed()},
                         log_unhandled_exceptions_as: int = logging.ERROR,
                         pass_as_extras: bool = True):
    """
    Instantiate Flask-Requests-Logging.

    Registers before_request, after_request and teardown_request hooks on *app*. Your Flask
    requests gain a new attribute :code:`time_elapsed`; calling it returns the amount of time
    it took to execute the given request.

    Every response produces one log entry built from log_template. An unhandled exception
    additionally produces a second entry built from unhandled_exception_template, with the
    exception passed as exc_info so its traceback is logged. No exception entry is generated
    for an exception that was handled by an error handler.

    :param app: Flask app to install the hooks on
    :param default_level_mapping: mapping from either an entire status code or its leftmost
        digit to a logging level. An exact status-code entry takes precedence over a
        leftmost-digit entry; status codes matching neither are logged with INFO. If not given
        (None), 2xx and 3xx are logged with INFO, 4xx with WARNING and 5xx with ERROR. Pass
        ``{}`` to log every request with INFO.
    :param log_template: a function that, called with two arguments (the Flask request and the
        Response), returns the logging message. Defaults to
        'Request {method} {path} finished with {status_code} took {elapsed} seconds'
    :param extra_args_gen: generator of a dictionary that will be attached as extras to the
        logging entry. By default returns a dict of ('url' => request path, 'method' => method,
        'status_code' => status code, 'elapsed' => seconds it took)
    :param unhandled_exception_template: a callable that takes two arguments - the Flask request
        and an exception instance - and returns the logging message
    :param extra_except_args_gen: generator of a dictionary that will be attached as extras to
        the logging entry if an unhandled exception occurs. exc_info is attached implicitly.
    :param log_unhandled_exceptions_as: logging level to log unhandled exceptions as
    :param pass_as_extras: if True, the dictionary generated by extra_args_gen /
        extra_except_args_gen is passed as the ``extra`` keyword; if False its items are passed
        as keyword arguments to ``logger.log``.
    """
    # Explicit None check so that an explicitly passed empty mapping is honoured instead of
    # being silently replaced by the defaults.
    if default_level_mapping is None:
        default_level_mapping = {2: logging.INFO,
                                 3: logging.INFO,
                                 4: logging.WARNING,
                                 5: logging.ERROR}

    # The decorator alone registers the hook; registering it a second time would make Flask
    # run it twice for every request.
    @app.before_request
    def before_request():
        flask.request.time_elapsed = measure()

    @app.after_request
    def after_request(r: Response) -> Response:
        _current_request_timer().stop()
        level = _level_for_status(r.status_code, default_level_mapping)

        msg = log_template(flask.request, r)
        extras = extra_args_gen(flask.request, r)
        if pass_as_extras:
            logger.log(level, msg, extra=extras)
        else:
            # NOTE(review): a standard-library Logger.log() only accepts exc_info, stack_info,
            # stacklevel and extra as keyword arguments, so this branch presumably relies on a
            # custom logger class (e.g. installed via logging.setLoggerClass) - confirm.
            logger.log(level, msg, **extras)

        return r

    @app.teardown_request
    def teardown_request(e: tp.Optional[Exception] = None) -> None:
        # Flask passes None here unless the request ended with an unhandled exception.
        if e is None:
            return

        # Ensure request.time_elapsed exists for the message/extras callbacks.
        _current_request_timer()
        msg = unhandled_exception_template(flask.request, e)
        extras = extra_except_args_gen(flask.request, e)

        if pass_as_extras:
            logger.log(log_unhandled_exceptions_as, msg, exc_info=e, extra=extras)
        else:
            logger.log(log_unhandled_exceptions_as, msg, exc_info=e, **extras)
