"""Generated by Sideko (sideko.dev)"""

import abc
import datetime
from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union, cast

import httpx
import jsonpointer

from .request import RequestConfig


class AuthProvider(abc.ABC):
    """
    Common interface for the authentication providers in this module.

    A provider knows how to inject its credentials into a `RequestConfig`
    and can receive a single string credential through `set_value`, which
    lets another provider (such as `OAuth2`) feed it a freshly obtained value.
    """

    @abc.abstractmethod
    def set_value(self, val: Optional[str]) -> None:
        """
        Stores a single credential value on the provider.

        Args:
            val: The credential to store; `None` clears it for providers whose
                `add_to_request` skips unset values.
        """

    @abc.abstractmethod
    def add_to_request(self, cfg: RequestConfig) -> RequestConfig:
        """
        Injects this provider's credentials into `cfg`.

        Args:
            cfg: Request configuration to update.

        Returns:
            The request configuration carrying the authentication details.
        """


class AuthBasic(AuthProvider):
    """
    HTTP Basic authentication provider.

    Holds a username/password pair and, when both are present, places them
    in the request config's `auth` entry as a `(username, password)` tuple.
    """

    username: Optional[str]
    password: Optional[str]

    def __init__(
        self, *, username: Optional[str] = None, password: Optional[str] = None
    ):
        super().__init__()
        self.username, self.password = username, password

    def add_to_request(self, cfg: RequestConfig) -> RequestConfig:
        """
        Sets `cfg["auth"]` to the credential pair.

        The configuration is returned untouched unless both the username and
        the password are set.
        """
        if self.username is None or self.password is None:
            return cfg

        cfg["auth"] = (self.username, self.password)
        return cfg

    def set_value(self, val: Optional[str]) -> None:
        """
        Stores `val` as the username; the password is left as is.
        """
        self.username = val


class AuthBearer(AuthProvider):
    """
    Bearer token authentication provider.

    Writes the token into the `Authorization` header with a `Bearer ` prefix.
    """

    token: Optional[str]

    def __init__(self, *, token: Optional[str] = None):
        super().__init__()
        self.token = token

    def add_to_request(self, cfg: RequestConfig) -> RequestConfig:
        """
        Adds `Authorization: Bearer <token>` to the request headers.

        Any existing headers mapping in `cfg` is updated in place. Nothing is
        changed when no token has been set.
        """
        if self.token is None:
            return cfg

        authorization = f"Bearer {self.token}"
        header_map = cfg.get("headers", {})
        header_map["Authorization"] = authorization
        cfg["headers"] = header_map
        return cfg

    def set_value(self, val: Optional[str]) -> None:
        """
        Stores `val` as the bearer token.
        """
        self.token = val


class AuthKey(AuthProvider):
    """
    API-key style provider for query, header, or cookie authentication.

    The key is stored under `name` in the request's `params`, `headers` or
    `cookies` mapping, depending on `location`.
    """

    name: str
    location: Literal["query", "header", "cookie"]
    val: Optional[str]

    def __init__(
        self,
        *,
        name: str,
        location: Literal["query", "header", "cookie"],
        val: Optional[str] = None,
    ):
        super().__init__()
        self.name = name
        self.location = location
        self.val = val

    def add_to_request(self, cfg: RequestConfig) -> RequestConfig:
        """
        Adds the key to the request as a query parameter, header, or cookie.

        An existing mapping for the chosen location is updated in place; one
        is created when absent. Any location other than "query" or "header"
        is treated as "cookie". Nothing is changed when no value is set.
        """
        if self.val is None:
            return cfg

        if self.location == "query":
            cfg.setdefault("params", {})[self.name] = self.val
        elif self.location == "header":
            cfg.setdefault("headers", {})[self.name] = self.val
        else:
            cfg.setdefault("cookies", {})[self.name] = self.val
        return cfg

    def set_value(self, val: Optional[str]) -> None:
        """
        Stores `val` as the key value.
        """
        self.val = val


# Encoding of the OAuth2 token request body.
BodyContent = Literal["form", "json"]
# Where the OAuth2 client_id / client_secret are placed in the token request.
CredentialsLocation = Literal["request_body", "basic_authorization_header"]
# Grant types the `OAuth2` provider supports out of the box (any other
# string may still be passed where `Union[GrantType, str]` is accepted).
GrantType = Literal["password", "client_credentials"]


class OAuth2Password(TypedDict):
    """
    Form for the OAuth2 resource owner password credentials grant.

    The TypedDict is total, so every key must be present; use `None` for
    the nullable entries you do not need.

    Reference: https://datatracker.ietf.org/doc/html/rfc6749#section-4.3
    """

    # resource owner credentials
    username: str
    password: str
    # optional client credentials
    client_id: Optional[str]
    client_secret: Optional[str]
    # when None, the OAuth2 provider defaults to "password" if a username is set
    grant_type: Optional[Union[GrantType, str]]
    # requested scopes, sent space-separated
    scope: Optional[List[str]]
    # overrides the default token url
    token_url: Optional[str]


class OAuth2ClientCredentials(TypedDict):
    """
    Form for the OAuth2 client credentials grant.

    The TypedDict is total, so every key must be present; use `None` for
    the nullable entries you do not need.

    Reference: https://datatracker.ietf.org/doc/html/rfc6749#section-4.4
    """

    # client credentials
    client_id: str
    client_secret: str
    # when None, the OAuth2 provider defaults to "client_credentials"
    grant_type: Optional[Union[GrantType, str]]
    # requested scopes, sent space-separated
    scope: Optional[List[str]]
    # overrides the default token url
    token_url: Optional[str]


class OAuth2(AuthProvider):
    """
    Implements OAuth2 token retrieval and refreshing.
    Currently supports `password` and `client_credentials`
    grant types.

    A token is requested lazily on the first `add_to_request` call and cached.
    It is requested again once its buffered expiry time has passed. The cached
    token is applied to each request through `request_mutator`.
    """

    # lifetime assumed when the token response carries no usable expiry value
    _DEFAULT_EXPIRES_IN_SECS = 600
    # margin subtracted from the reported lifetime so the token is refreshed
    # shortly before the server would start rejecting it
    _EXPIRY_BUFFER_SECS = 60

    # OAuth2 provider configuration
    base_url: str
    token_url: str
    access_token_pointer: str
    expires_in_pointer: str
    credentials_location: CredentialsLocation
    body_content: BodyContent
    request_mutator: AuthProvider

    # OAuth2 access token request values
    grant_type: Union[GrantType, str]
    username: Optional[str]
    password: Optional[str]
    client_id: Optional[str]
    client_secret: Optional[str]
    scope: Optional[List[str]]

    # access_token storage (expires_at is a timezone-aware UTC datetime)
    access_token: Optional[str]
    expires_at: Optional[datetime.datetime]

    def __init__(
        self,
        *,
        base_url: str,
        default_token_url: str,
        access_token_pointer: str,
        expires_in_pointer: str,
        credentials_location: CredentialsLocation,
        body_content: BodyContent,
        request_mutator: AuthProvider,
        form: Optional[Union[OAuth2Password, OAuth2ClientCredentials]] = None,
    ):
        """
        Args:
            base_url: API base URL, prepended to the token URL when that URL
                starts with "/".
            default_token_url: Token endpoint used unless `form["token_url"]`
                is set.
            access_token_pointer: JSON pointer to the access token in the
                token response body.
            expires_in_pointer: JSON pointer to the token lifetime (seconds)
                in the token response body.
            credentials_location: Whether client_id/client_secret are sent in
                the request body or as an HTTP Basic auth tuple.
            body_content: Token request body encoding, "json" or "form".
            request_mutator: Provider that receives the access token via
                `set_value` and applies it to outgoing requests.
            form: Grant credentials. When omitted (or when it sets none of
                username, password, client_id or client_secret), no token is
                requested and requests pass through unchanged.
        """
        super().__init__()

        form_data: Union[OAuth2Password, OAuth2ClientCredentials] = form or cast(
            OAuth2ClientCredentials, {}
        )

        self.base_url = base_url
        self.token_url = form_data.get("token_url") or default_token_url
        self.access_token_pointer = access_token_pointer
        self.expires_in_pointer = expires_in_pointer
        self.credentials_location = credentials_location
        self.body_content = body_content
        self.request_mutator = request_mutator

        # a username implies the password grant unless one is given explicitly
        default_grant_type: GrantType = (
            "password"
            if form_data.get("username") is not None
            else "client_credentials"
        )
        self.grant_type = form_data.get("grant_type") or default_grant_type
        self.username = cast(Optional[str], form_data.get("username"))
        self.password = cast(Optional[str], form_data.get("password"))
        self.client_id = form_data.get("client_id")
        self.client_secret = form_data.get("client_secret")
        self.scope = form_data.get("scope")

        self.access_token = None
        self.expires_at = None

    @staticmethod
    def _now() -> datetime.datetime:
        """Returns the current time as an aware UTC datetime.

        Unlike naive local time, this is unaffected by DST transitions.
        """
        return datetime.datetime.now(datetime.timezone.utc)

    @staticmethod
    def _parse_expires_in(raw: Any) -> Optional[int]:
        """Converts a raw `expires_in` value to whole seconds; None if unusable."""
        if isinstance(raw, bool):  # bool subclasses int but is never a lifetime
            return None
        if isinstance(raw, int):
            return raw
        # some servers send the lifetime as a float or a numeric string
        if isinstance(raw, (float, str)):
            try:
                return int(float(raw))
            except (ValueError, OverflowError):  # non-numeric, NaN or infinity
                return None
        return None

    def _resolve_token_url(self) -> str:
        """Returns the token endpoint, joining root-relative paths onto base_url."""
        url = self.token_url
        if url.startswith("/"):
            url = f"{self.base_url.strip('/')}/{self.token_url.strip('/')}"
        return url

    def _build_token_request(self) -> Dict[str, Any]:
        """Builds the keyword arguments for the `httpx.post` token request."""
        req_cfg: Dict[str, Any] = {"url": self._resolve_token_url()}
        req_data: Dict[str, str] = {"grant_type": self.grant_type}

        # client credentials go either in a Basic auth tuple or in the body
        if self.client_id is not None or self.client_secret is not None:
            if self.credentials_location == "basic_authorization_header":
                req_cfg["auth"] = (self.client_id or "", self.client_secret or "")
            else:
                if self.client_id is not None:
                    req_data["client_id"] = self.client_id
                if self.client_secret is not None:
                    req_data["client_secret"] = self.client_secret

        if self.username is not None:
            req_data["username"] = self.username
        if self.password is not None:
            req_data["password"] = self.password
        # an empty scope list is omitted rather than sent as `scope=""`
        if self.scope:
            req_data["scope"] = " ".join(self.scope)

        if self.body_content == "json":
            req_cfg["json"] = req_data
            req_cfg["headers"] = {"content-type": "application/json"}
        else:
            req_cfg["data"] = req_data
            req_cfg["headers"] = {"content-type": "application/x-www-form-urlencoded"}

        return req_cfg

    def _parse_token_response(
        self, token_res_json: Dict[str, Any]
    ) -> Tuple[str, datetime.datetime]:
        """Extracts the access token and its buffered UTC expiry from a response body."""
        access_token = str(
            jsonpointer.resolve_pointer(token_res_json, self.access_token_pointer)
        )

        # RFC 6749 §5.1 only RECOMMENDS `expires_in`. The explicit default keeps
        # a missing value from raising JsonPointerException.
        expires_in_secs = self._parse_expires_in(
            jsonpointer.resolve_pointer(token_res_json, self.expires_in_pointer, None)
        )
        if expires_in_secs is None:
            expires_in_secs = self._DEFAULT_EXPIRES_IN_SECS

        # NOTE: a lifetime at or below the buffer yields an already-expired
        # timestamp, so such tokens are re-requested on every add_to_request.
        expires_at = self._now() + datetime.timedelta(
            seconds=expires_in_secs - self._EXPIRY_BUFFER_SECS
        )

        return (access_token, expires_at)

    def _refresh(self) -> Tuple[str, datetime.datetime]:
        """
        Requests a new access token from the token endpoint.

        Returns:
            The access token and the UTC time after which it is treated as expired.

        Raises:
            httpx.HTTPError: if the request fails or returns an error status.
            jsonpointer.JsonPointerException: if `access_token_pointer` does
                not resolve in the response body.
        """
        token_res = httpx.post(**self._build_token_request())
        token_res.raise_for_status()

        token_res_json: Dict[str, Any] = token_res.json()
        return self._parse_token_response(token_res_json)

    def add_to_request(self, cfg: RequestConfig) -> RequestConfig:
        """
        Applies a valid access token to `cfg` through `request_mutator`.

        A new token is requested when none is cached or the cached one has
        expired. `cfg` is returned unchanged when none of username, password,
        client_id or client_secret is configured.

        Raises:
            httpx.HTTPError: if a token refresh request fails.
            jsonpointer.JsonPointerException: if the token response lacks the
                access token.
        """
        if (
            self.username is None
            and self.password is None
            and self.client_id is None
            and self.client_secret is None
        ):
            # provider is not configured to make an oauth token request
            return cfg

        token_expired = self.expires_at is not None and self.expires_at <= self._now()
        if self.access_token is None or token_expired:
            access_token, expires_at = self._refresh()
            self.expires_at = expires_at
            self.access_token = access_token

        self.request_mutator.set_value(self.access_token)
        return self.request_mutator.add_to_request(cfg)

    def set_value(self, _val: Optional[str]) -> None:
        """OAuth2 obtains its own token, so it cannot accept one from outside."""
        raise NotImplementedError("an OAuth2 auth provider cannot be a request_mutator")
