"""Basic scaffolding for handling OAuth"""

# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/api/08_oauth.ipynb.

# %% auto 0
# Public names exported by `from <this module> import *`
__all__ = ['http_patterns', 'GoogleAppClient', 'GitHubAppClient', 'HuggingFaceClient', 'DiscordAppClient', 'Auth0AppClient',
           'get_host', 'redir_url', 'url_match', 'OAuth', 'load_creds']

# %% ../nbs/api/08_oauth.ipynb
from .common import *
from oauthlib.oauth2 import WebApplicationClient
from urllib.parse import urlparse, urlencode, parse_qs, quote, unquote
import secrets, httpx

# %% ../nbs/api/08_oauth.ipynb
class _AppClient(WebApplicationClient):
    "Shared base for the provider clients: an oauthlib `WebApplicationClient` that also keeps `client_secret`"
    # Field of the user-info response holding the user's identifier; subclasses override it (e.g. 'id')
    id_key = 'sub'
    def __init__(self, client_id, client_secret, code=None, scope=None, **kwargs):
        super().__init__(client_id, code=code, scope=scope, **kwargs)
        # Not passed to oauthlib; read later by `parse_response` (token exchange) and `creds`
        self.client_secret = client_secret

# %% ../nbs/api/08_oauth.ipynb
class GoogleAppClient(_AppClient):
    "A `WebApplicationClient` for Google oauth2"
    base_url = "https://accounts.google.com/o/oauth2/v2/auth"
    token_url = "https://oauth2.googleapis.com/token"
    info_url = "https://openidconnect.googleapis.com/v1/userinfo"

    def __init__(self, client_id, client_secret, code=None, scope=None, project_id=None, **kwargs):
        "Create a Google client; without `scope`, request openid plus the userinfo email and profile scopes"
        userinfo = "https://www.googleapis.com/auth/userinfo"
        if not scope:
            scope = ["openid"] + [f"{userinfo}.{part}" for part in ("email", "profile")]
        super().__init__(client_id, client_secret, code=code, scope=scope, **kwargs)
        # Used as the default project by `consent_url`
        self.project_id = project_id

    @classmethod
    def from_file(cls, fname, code=None, scope=None, **kwargs):
        "Build a client from the 'web' section of a Google client-secrets JSON file"
        web = Path(fname).read_json()['web']
        return cls(web['client_id'], client_secret=web['client_secret'], project_id=web['project_id'],
                   code=code, scope=scope, **kwargs)

# %% ../nbs/api/08_oauth.ipynb
class GitHubAppClient(_AppClient):
    "A `WebApplicationClient` for GitHub oauth2"
    # Authorization and token endpoints share this prefix
    prefix = "https://github.com/login/oauth/"
    base_url = f"{prefix}authorize"
    token_url = f"{prefix}access_token"
    info_url = "https://api.github.com/user"
    # The user's identifier is read from the `id` field of the user-info response
    id_key = 'id'
    # No `__init__` here: `_AppClient.__init__(client_id, client_secret, code=None, scope=None, **kwargs)`
    # is inherited unchanged, which is exactly what the previous forwarding constructor did.

# %% ../nbs/api/08_oauth.ipynb
class HuggingFaceClient(_AppClient):
    "A `WebApplicationClient` for HuggingFace oauth2"
    prefix = "https://huggingface.co/oauth/"
    base_url = f"{prefix}authorize"
    token_url = f"{prefix}token"
    info_url = f"{prefix}userinfo"

    def __init__(self, client_id, client_secret, code=None, scope=None, state=None, **kwargs):
        "Create a HuggingFace client, defaulting to the `openid`/`profile` scopes and a random URL-safe `state`"
        # NOTE(review): a generated `state` is fixed for this client's lifetime rather than per login — confirm intended
        super().__init__(client_id, client_secret, code=code,
                         scope=scope or ["openid", "profile"],
                         state=state or secrets.token_urlsafe(16),
                         **kwargs)

# %% ../nbs/api/08_oauth.ipynb
class DiscordAppClient(_AppClient):
    "A `WebApplicationClient` for Discord oauth2"
    base_url = "https://discord.com/oauth2/authorize"
    token_url = "https://discord.com/api/oauth2/token"
    revoke_url = "https://discord.com/api/oauth2/token/revoke"
    info_url = "https://discord.com/api/users/@me"
    id_key = 'id'

    def __init__(self, client_id, client_secret, is_user=False, perms=0, scope=None, **kwargs):
        "Create a Discord client; `is_user` selects `integration_type` 1 (else 0) in the login link"
        if not scope: scope="applications.commands applications.commands.permissions.update identify"
        self.integration_type = 1 if is_user else 0
        # Stored on the client; not currently sent by `login_link`
        self.perms = perms
        super().__init__(client_id, client_secret, scope=scope, **kwargs)

    def login_link(self, redirect_uri=None, scope=None, state=None):
        "Build the Discord authorization URL; a list/tuple `scope` is joined with spaces"
        use_scope = scope or self.scope
        # OAuth2 scopes are space-delimited; urlencoding a list would send its Python repr instead
        if isinstance(use_scope, (list, tuple)): use_scope = ' '.join(use_scope)
        d = dict(response_type='code', client_id=self.client_id,
                 integration_type=self.integration_type, scope=use_scope)
        if state: d['state'] = state
        if redirect_uri: d['redirect_uri'] = redirect_uri
        return f'{self.base_url}?' + urlencode(d)

    def parse_response(self, code, redirect_uri=None):
        "Exchange `code` for a token (client credentials sent via HTTP Basic auth) and load it into the client"
        headers = {'Content-Type': 'application/x-www-form-urlencoded'}
        data = dict(grant_type='authorization_code', code=code)
        if redirect_uri: data['redirect_uri'] = redirect_uri
        r = httpx.post(self.token_url, data=data, headers=headers, auth=(self.client_id, self.client_secret))
        r.raise_for_status()
        self.parse_request_body_response(r.text)

# %% ../nbs/api/08_oauth.ipynb
class Auth0AppClient(_AppClient):
    "A `WebApplicationClient` for Auth0 OAuth2"
    def __init__(self, domain, client_id, client_secret, code=None, scope=None, redirect_uri="", **kwargs):
        # Endpoints come from the tenant's OpenID discovery document, so constructing this makes a network request
        self.redirect_uri,self.domain = redirect_uri,domain
        config = self._fetch_openid_config()
        self.base_url,self.token_url,self.info_url = config["authorization_endpoint"],config["token_endpoint"],config["userinfo_endpoint"]
        super().__init__(client_id, client_secret, code=code, scope=scope, redirect_uri=redirect_uri, **kwargs)

    def _fetch_openid_config(self):
        "Fetch and return the `.well-known/openid-configuration` JSON for `self.domain`"
        r = httpx.get(f"https://{self.domain}/.well-known/openid-configuration")
        r.raise_for_status()
        return r.json()

    def login_link(self, req, scope=None, state=None):
        """Get the Auth0 authorization URL.

        `req` is either a request (the redirect URL is then `self.redirect_uri` on its host) or an
        already-built redirect URL string, which is how `OAuth.login_link` calls client `login_link`s.
        A list/tuple `scope` is joined with spaces; an unset scope is omitted rather than sent as "None"."""
        redirect_uri = req if isinstance(req, str) else redir_url(req, self.redirect_uri)
        use_scope = scope or self.scope
        if isinstance(use_scope, (list, tuple)): use_scope = ' '.join(use_scope)
        d = dict(response_type="code", client_id=self.client_id)
        if use_scope: d['scope'] = use_scope
        d['redirect_uri'] = redirect_uri
        if state: d['state'] = state
        return f"{self.base_url}?{urlencode(d)}"

# %% ../nbs/api/08_oauth.ipynb
@patch
def login_link(self:WebApplicationClient, redirect_uri, scope=None, state=None, **kwargs):
    "Get a login link for this client, defaulting `scope` to the client's and `state` to any client-level `state`"
    return self.prepare_request_uri(self.base_url, redirect_uri, scope or self.scope,
                                    state=state or getattr(self, 'state', None), **kwargs)

# %% ../nbs/api/08_oauth.ipynb
def get_host(request):
    """Get the host, preferring X-Forwarded-Host if available.

    `request` must provide `headers.get(...)` and `url.netloc` (a Starlette-style request is assumed).
    An empty X-Forwarded-Host falls back to the request URL's netloc."""
    forwarded_host = request.headers.get('x-forwarded-host')
    return forwarded_host if forwarded_host else request.url.netloc

# %% ../nbs/api/08_oauth.ipynb
def redir_url(req, redir_path, scheme=None):
    """Get the absolute redirect URL for `redir_path` on the host of `req`.

    If `scheme` ('http' or 'https') is given it is used as-is; otherwise 'http' is chosen for
    localhost/127.0.0.1 (port ignored) and 'https' for every other host."""
    host = get_host(req)
    # Only infer the scheme when the caller did not supply one
    if not scheme: scheme = 'http' if host.split(':')[0] in ("localhost", "127.0.0.1") else 'https'
    return f"{scheme}://{host}{redir_path}"

# %% ../nbs/api/08_oauth.ipynb
@patch
def parse_response(self:_AppClient, code, redirect_uri):
    "Exchange the authorization `code` for a token (client credentials sent in the form body) and load it into the client"
    form = dict(code=code, redirect_uri=redirect_uri, client_id=self.client_id,
                client_secret=self.client_secret, grant_type='authorization_code')
    resp = httpx.post(self.token_url, data=form)
    resp.raise_for_status()
    self.parse_request_body_response(resp.text)

# %% ../nbs/api/08_oauth.ipynb
@patch
def get_info(self:_AppClient, token=None):
    "Get the info for authenticated user, as the decoded JSON from `info_url`"
    # Without an explicit token, use the client's stored token (presumably set by `parse_response`)
    bearer = token or self.token["access_token"]
    resp = httpx.get(self.info_url, headers={'Authorization': f'Bearer {bearer}'})
    return resp.json()

# %% ../nbs/api/08_oauth.ipynb
@patch
def retr_info(self:_AppClient, code, redirect_uri):
    "Exchange `code` for a token via `parse_response`, then return the user info from `get_info`"
    self.parse_response(code, redirect_uri)
    user_info = self.get_info()
    return user_info

# %% ../nbs/api/08_oauth.ipynb
@patch
def retr_id(self:_AppClient, code, redirect_uri):
    "Call `retr_info` and return only the user's identifier (the `id_key` field)"
    user_info = self.retr_info(code, redirect_uri)
    return user_info[self.id_key]

# %% ../nbs/api/08_oauth.ipynb
# Hosts that are served over plain http (see `OAuth.redir_url`)
http_patterns = (r'^(localhost|127\.0\.0\.1)(:\d+)?$',)
def url_match(request, patterns=http_patterns):
    "True when the request's host, with any `:port` removed, matches one of the regexes in `patterns`"
    hostname = get_host(request).partition(':')[0]
    return any(re.match(pat, hostname) for pat in patterns)

# %% ../nbs/api/08_oauth.ipynb
class OAuth:
    "Wire an OAuth2 client `cli` into `app`: auth-checking beforeware plus redirect and logout routes"
    def __init__(self, app, cli, skip=None, redir_path='/redirect', error_path='/error', logout_path='/logout', login_path='/login', https=True, http_patterns=http_patterns):
        if not skip: skip = [redir_path,error_path,login_path]
        store_attr()
        def before(req, session):
            # Copy the session's auth ident onto the request scope (unless already set) for handlers to read
            if 'auth' not in req.scope: req.scope['auth'] = session.get('auth')
            auth = req.scope['auth']
            if not auth: return self.redir_login(session)
            res = self.check_invalid(req, session, auth)
            if res: return res
        app.before.append(Beforeware(before, skip=skip))

        @app.get(redir_path)
        def redirect(req, session, code:str=None, error:str=None, state:str=None):
            if not code:
                session['oauth_error']=error
                return RedirectResponse(self.error_path, status_code=303)
            # The token request's redirect_uri must be identical to the one sent by `login_link`,
            # so both are built by `self.redir_url`
            info = AttrDictDefault(cli.retr_info(code, self.redir_url(req)))
            ident = info.get(self.cli.id_key)
            if not ident: return self.redir_login(session)
            res = self.get_auth(info, ident, session, state)
            if not res:   return self.redir_login(session)
            req.scope['auth'] = session['auth'] = ident
            return res

        @app.get(logout_path)
        def logout(session):
            session.pop('auth', None)
            return self.logout(session)

    def redir_login(self, session): return RedirectResponse(self.login_path, status_code=303)
    def redir_url(self, req):
        "Absolute URL of `redir_path` on `req`'s host: http if the host matches `http_patterns` or `https` is False, else https"
        scheme = 'http' if url_match(req,self.http_patterns) or not self.https else 'https'
        return f"{scheme}://{get_host(req)}{self.redir_path}"

    def login_link(self, req, scope=None, state=None): return self.cli.login_link(self.redir_url(req), scope=scope, state=state)
    # Hook: return a response to short-circuit an authenticated request, or a falsy value to let it through
    def check_invalid(self, req, session, auth): return False
    # Hook: response returned after `auth` is removed from the session; defaults to the login redirect
    def logout(self, session): return self.redir_login(session)
    # Must be overridden: return a response after a successful login, or a falsy value to send the user back to login
    def get_auth(self, info, ident, session, state): raise NotImplementedError()

# %% ../nbs/api/08_oauth.ipynb
try:
    from google.oauth2.credentials import Credentials
    from google.auth.transport.requests import Request
except ImportError:
    # google-auth is optional: fall back to placeholders so the `@patch`ed helpers below still have a
    # class to attach to. If only the second import failed, the real `Credentials` imported above is
    # replaced by this stub as well.
    Request=None
    class Credentials: pass

# %% ../nbs/api/08_oauth.ipynb
@patch
def consent_url(self:GoogleAppClient, proj=None):
    "Get Google OAuth consent screen URL for this client, defaulting to the client's `project_id`"
    project = self.project_id if proj is None else proj
    console = "https://console.cloud.google.com/auth/clients"
    return f"{console}/{self.client_id}?project={project}"

# %% ../nbs/api/08_oauth.ipynb
@patch
def update(self:Credentials):
    "Return these credentials, first refreshing them if `expired` is truthy"
    if not self.expired: return self
    self.refresh(Request())
    return self

# %% ../nbs/api/08_oauth.ipynb
@patch
def save(self:Credentials, fname):
    "Save credentials to `fname`"
    # Written with `save_pickle`; read back with `load_creds`
    save_pickle(fname, self)

# %% ../nbs/api/08_oauth.ipynb
def load_creds(fname):
    """Load credentials from `fname` and return them after `update` (refreshed if expired).

    NOTE(review): `load_pickle` presumably unpickles, which can run arbitrary code — only load trusted files."""
    cred = load_pickle(fname)
    return cred.update()

# %% ../nbs/api/08_oauth.ipynb
@patch
def creds(self:GoogleAppClient):
    """Create `Credentials` from this client's tokens and return them after `update`.

    NOTE(review): no `expiry` is passed, so whether `update` ever refreshes depends on how
    `Credentials.expired` treats a missing expiry — confirm."""
    cred = Credentials(token=self.access_token, refresh_token=self.refresh_token,
                       token_uri=self.token_url, client_id=self.client_id,
                       client_secret=self.client_secret, scopes=self.scope)
    return cred.update()
