import os

import click
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import rsa, padding

# PEM file locations for the passphrase-encrypted private key and the public
# key. Both paths are relative, so they resolve against the current working
# directory; generate_key_pair creates the ".pdg" directory if it is missing.
PRIVATE_KEY_FILE = ".pdg/private_key.pem"
PUBLIC_KEY_FILE = ".pdg/public_key.pem"

def generate_key_pair(passphrase):
    """Generate a 2048-bit RSA key pair and save both halves as PEM files.

    The private key is written to ``PRIVATE_KEY_FILE`` in PKCS#8 format,
    encrypted with *passphrase*, with file permissions restricted to owner
    read/write (0o600). The public key is written unencrypted to
    ``PUBLIC_KEY_FILE``. Existing files at either path are overwritten.

    Args:
        passphrase: Text used to encrypt the private key (encoded as UTF-8).

    Raises:
        ValueError: if *passphrase* is empty (``BestAvailableEncryption``
            rejects zero-length passwords). No key file is written then.
    """
    private_key = rsa.generate_private_key(
        public_exponent=65537,
        key_size=2048,
        backend=default_backend(),
    )
    # PKCS#8 is the format the cryptography docs recommend for encrypted
    # keys; legacy TraditionalOpenSSL encryption uses a much weaker key
    # derivation. load_pem_private_key reads either format, so previously
    # written keys remain loadable.
    encrypted_pem = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.BestAvailableEncryption(passphrase.encode()),
    )
    public_pem = private_key.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )

    # Serialize first so a rejected passphrase leaves the filesystem untouched.
    os.makedirs(os.path.dirname(PRIVATE_KEY_FILE), exist_ok=True)

    # Create the private key file as 0o600 instead of relying on the process
    # umask (a plain open() commonly yields world-readable 0o644).
    # O_BINARY (Windows only) keeps the CRT from translating newlines.
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC | getattr(os, "O_BINARY", 0)
    fd = os.open(PRIVATE_KEY_FILE, flags, 0o600)
    with os.fdopen(fd, "wb") as f:
        # os.open's mode only applies when it creates the file; tighten a
        # pre-existing file as well before any key material goes into it.
        os.chmod(PRIVATE_KEY_FILE, 0o600)
        f.write(encrypted_pem)

    with open(PUBLIC_KEY_FILE, "wb") as f:
        f.write(public_pem)
    click.echo("Predator Drone: Key pair generated.")

def load_private_key(passphrase):
    """Load and decrypt the private key stored at ``PRIVATE_KEY_FILE``.

    Args:
        passphrase: Passphrase the key was encrypted with. A ``str`` is
            encoded as UTF-8, ``bytes`` are used as-is, and ``None`` loads an
            unencrypted key file.

    Returns:
        The private key object produced by
        ``serialization.load_pem_private_key``.

    Raises:
        OSError: if the key file cannot be opened (e.g. FileNotFoundError).
        ValueError / TypeError: propagated from ``load_pem_private_key``, for
            example on a wrong passphrase, or when a passphrase is missing
            for an encrypted key.
    """
    # Accept bytes and None directly; only text needs encoding.
    if passphrase is None or isinstance(passphrase, bytes):
        password = passphrase
    else:
        password = passphrase.encode()
    with open(PRIVATE_KEY_FILE, "rb") as f:
        pem_data = f.read()
    return serialization.load_pem_private_key(
        pem_data,
        password=password,
        backend=default_backend(),
    )

def load_public_key():
    """Read ``PUBLIC_KEY_FILE`` and return the deserialized PEM public key.

    Raises:
        OSError: if the public key file cannot be opened.
    """
    with open(PUBLIC_KEY_FILE, "rb") as key_file:
        pem_bytes = key_file.read()
    return serialization.load_pem_public_key(pem_bytes, backend=default_backend())

def sign_data(data, private_key):
    """Sign *data* with *private_key* using RSA-PSS over SHA-256.

    Args:
        data: Message to sign. A ``str`` is encoded as UTF-8 first;
            ``bytes``/``bytearray`` payloads are signed as-is.
        private_key: RSA private key, e.g. as returned by ``load_private_key``.

    Returns:
        The signature produced by ``private_key.sign``.
    """
    message = bytes(data) if isinstance(data, (bytes, bytearray)) else data.encode()
    return private_key.sign(
        message,
        # verify_data must use these exact PSS parameters for the signature
        # to validate.
        padding.PSS(
            mgf=padding.MGF1(hashes.SHA256()),
            salt_length=padding.PSS.MAX_LENGTH,
        ),
        hashes.SHA256(),
    )

def verify_data(data, signature, public_key):
    """Check that *signature* is a valid RSA-PSS/SHA-256 signature of *data*.

    Args:
        data: Signed message. A ``str`` is encoded as UTF-8;
            ``bytes``/``bytearray`` are used as-is.
        signature: Signature bytes, e.g. produced by ``sign_data``.
        public_key: RSA public key, e.g. as returned by ``load_public_key``.

    Returns:
        ``True`` if the signature matches, ``False`` if verification fails
        with ``InvalidSignature``.

    Raises:
        Any other exception (for instance one caused by a wrongly typed key or
        argument) propagates. It is not reported as an invalid signature,
        because that would hide programming errors.
    """
    message = bytes(data) if isinstance(data, (bytes, bytearray)) else data.encode()
    try:
        public_key.verify(
            signature,
            message,
            # Must mirror the PSS parameters used by sign_data.
            padding.PSS(
                mgf=padding.MGF1(hashes.SHA256()),
                salt_length=padding.PSS.MAX_LENGTH,
            ),
            hashes.SHA256(),
        )
    except InvalidSignature:
        return False
    return True