from revcore_micro.ddb.exceptions import InstanceNotFound
from jwcrypto.jwk import JWK
from revcore_micro.flask import exceptions
from revcore_micro.flask.users import User
import boto3
import jwt
import requests
import json


class BaseVerifier:
    """Common base for credential verifiers.

    Holds the client lookup class and the AWS region name; concrete
    subclasses override ``verify``.
    """

    def __init__(self, client_class, region_name='ap-northeast-1'):
        """Remember the client lookup class and the AWS region.

        Args:
            client_class: Object whose ``get(id=...)`` subclasses use to load
                a client.
            region_name: AWS region name stored on the instance.
        """
        self.region_name = region_name
        self.client_class = client_class

    def verify(self, *args, **kwargs):
        """Verify a credential; must be implemented by subclasses.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError('verify')


class ClientVerifier(BaseVerifier):
    """Resolve a client directly from its id."""

    def verify(self, client_id, **kwargs):
        """Look up ``client_id`` and return ``(client, None)``.

        Raises:
            exceptions.PermissionDenied: when ``client_class.get`` raises
                ``InstanceNotFound``. Any other error propagates unchanged.
        """
        try:
            found = self.client_class.get(id=client_id)
        except InstanceNotFound:
            raise exceptions.PermissionDenied(detail=f'{client_id} not found')
        return found, None


class ClientSecretVerifier(BaseVerifier):
    """Resolve a client from an API Gateway API key.

    The ``name`` of the API key returned by API Gateway is used as the client
    id. Every failure, whether from AWS or from the client lookup, surfaces
    as ``exceptions.PermissionDenied`` carrying the original error text.
    """

    # Class-level default; BaseVerifier.__init__ sets an instance attribute
    # of the same name that shadows it on constructed verifiers.
    region_name = 'ap-northeast-1'

    def verify(self, client_secret, **kwargs):
        """Return ``(client, None)`` for the API key ``client_secret``.

        Raises:
            exceptions.PermissionDenied: on any error while querying API
                Gateway or loading the client.
        """
        try:
            apigw = boto3.client('apigateway', region_name=self.region_name)
            api_key = apigw.get_api_key(apiKey=client_secret)
            owner_id = api_key['name']
            return self.client_class.get(id=owner_id), None
        except Exception as err:
            raise exceptions.PermissionDenied(detail=str(err))


class JWTVerifier(BaseVerifier):
    """Verify an RS256-signed JWT and resolve the client and user it names.

    The token's ``aud`` claim serves two purposes: it is the audience the
    signature check must match, and it is the id handed to
    ``client_class.get``. Every failure is reported to the caller as
    ``exceptions.PermissionDenied``.
    """

    # Value the token's ``typ`` claim must equal.
    token_type = 'access'
    # Classes instantiated with ``user=``; each instance's
    # ``check_user_permission()`` is called in order. Kept a list so existing
    # code that appends to it keeps working.
    permission_classes = []
    # Called with ``client_class=`` and ``user=`` to wrap the decoded claims.
    user_instance_class = User
    # Seconds to wait on the key endpoints. Without a timeout ``requests``
    # waits indefinitely, so a stalled key server would hang every request.
    key_fetch_timeout = 10
    # Key list used for issuers whose version segment is not ``'v3'``.
    fallback_keys_url = 'https://keys.revtel-api.com/pub.json'
    # Optional collection of hosts accepted in the ``iss`` claim.
    # NOTE(review): ``iss`` is read from the token *before* its signature is
    # checked, and for ``v3`` issuers the key URL is built from it. With the
    # default ``None`` any host is accepted, so a token can name a server that
    # serves a key of the token author's choosing. Set this to the real
    # issuer hosts.
    trusted_issuer_hosts = None

    def get_user_instance(self, user):
        """Wrap the decoded token claims in ``user_instance_class``."""
        return self.user_instance_class(client_class=self.client_class, user=user)

    def check_user_permission(self, user):
        """Run every configured permission class against ``user``.

        Each permission's ``check_user_permission()`` is expected to raise
        when access is denied (presumably ``exceptions.PermissionDenied`` —
        confirm against the permission classes).
        """
        for permission_class in self.permission_classes:
            permission = permission_class(user=user)
            permission.check_user_permission()

    def check_token_type(self, typ):
        """Raise ``PermissionDenied`` unless ``typ`` equals ``token_type``."""
        if typ != self.token_type:
            detail = f'invalid token type: {typ}'
            raise exceptions.PermissionDenied(detail=detail)

    def get_public_key(self, kid, unverified):
        """Fetch the JWK identified by ``kid`` and return it as PEM.

        ``unverified['iss']`` must have the form ``'<host>/<version>'``. For
        version ``'v3'``, keys are read from the ``keys`` list served at
        ``https://<host>/v3/certs``. For any other version, a plain JSON list
        of keys is read from ``fallback_keys_url``.

        Raises:
            exceptions.PermissionDenied: if ``trusted_issuer_hosts`` is set and
                does not contain the issuer host, or if no published key has
                this ``kid``.
            ValueError, KeyError, requests.RequestException: malformed
                ``iss`` or key documents and network failures propagate.
        """
        host, version = unverified['iss'].split('/')
        trusted = self.trusted_issuer_hosts
        if trusted is not None and host not in trusted:
            raise exceptions.PermissionDenied(detail=f'untrusted issuer: {host}')

        if version == 'v3':
            resp = requests.get(
                f'https://{host}/{version}/certs', timeout=self.key_fetch_timeout,
            )
            keys = resp.json()['keys']
        else:
            resp = requests.get(self.fallback_keys_url, timeout=self.key_fetch_timeout)
            keys = resp.json()

        jwk_dict = next((candidate for candidate in keys if candidate['kid'] == kid), None)
        if jwk_dict is None:
            raise exceptions.PermissionDenied(detail=f'no public key found for kid: {kid}')
        return JWK.from_json(json.dumps(jwk_dict)).export_to_pem()

    def verify(self, token, **kwargs):
        """Verify ``token`` and return ``(client, user)``.

        Steps: read the ``kid`` header and the unverified claims, fetch the
        matching public key, verify the RS256 signature with ``aud`` as the
        expected audience, check ``typ``, build the user, run the permission
        checks, and load the client named by ``aud``.

        Raises:
            exceptions.PermissionDenied: on any failure. A ``PermissionDenied``
                raised by a step is re-raised as-is; any other error is
                wrapped, with its message as the detail.
        """
        try:
            header = jwt.get_unverified_header(token)
            # Decode without checking the signature, only to learn which key
            # and audience to verify against. ``verify=False`` is the PyJWT 1.x
            # switch; the ``verify_signature`` option is the PyJWT 2.x one
            # (2.x ignores ``verify``). Passing both works on either version.
            unverified = jwt.decode(
                token, verify=False, options={'verify_signature': False},
            )
            audience = unverified['aud']
            pub = self.get_public_key(header['kid'], unverified)
            # ``algorithms`` must be a list: PyJWT tests ``alg in algorithms``,
            # which is a substring check when given a plain string.
            claims = jwt.decode(
                token, key=pub, algorithms=['RS256'], verify=True, audience=audience,
            )
            self.check_token_type(claims['typ'])
            user = self.get_user_instance(claims)
            self.check_user_permission(user)
            client = self.client_class.get(id=audience)
            return client, user
        except exceptions.PermissionDenied:
            # Already the right type with a meaningful detail; re-wrapping via
            # str() could replace that detail.
            raise
        except Exception as e:
            raise exceptions.PermissionDenied(detail=str(e)) from e


class RefreshVerifier(JWTVerifier):
    """JWT verifier that accepts tokens whose ``typ`` claim is ``'refresh'``.

    Unlike the parent class, the decoded claims are handed back unchanged
    instead of being wrapped in ``user_instance_class``.
    """

    token_type = 'refresh'

    def get_user_instance(self, user):
        """Return the decoded claims as-is (no wrapping)."""
        claims = user
        return claims
