#!/usr/bin/env python3
"""
Advanced Security for TuskLang Python SDK
=========================================
Comprehensive security features and quantum-safe cryptography

This module provides advanced security capabilities for the TuskLang Python SDK,
including quantum-safe encryption, advanced authentication, security monitoring,
and comprehensive threat detection and prevention.
"""

import hashlib
import hmac
import secrets
import base64
import json
import time
import threading
from typing import Any, Dict, List, Optional, Tuple, Union
from dataclasses import dataclass, asdict
from datetime import datetime, timedelta
from enum import Enum
import logging
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa, padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.serialization import load_pem_private_key
import jwt
import bcrypt
import argon2


class SecurityLevel(Enum):
    """Named security tiers: basic, standard, high and quantum_safe.

    NOTE(review): no code in this module references this enum; it is
    presumably exported for SDK callers -- confirm before changing values.
    """
    BASIC = "basic"
    STANDARD = "standard"
    HIGH = "high"
    QUANTUM_SAFE = "quantum_safe"


class ThreatLevel(Enum):
    """Severity attached to a detected threat or a logged security event.

    Members are declared LOW, MEDIUM, HIGH, CRITICAL (ascending severity by
    name), and iterating the enum yields them in that order -- keep it.
    The lower-case string values are what the module-level
    ``log_security_event`` accepts and ``detect_threat`` returns.
    """
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"


@dataclass
class SecurityEvent:
    """One recorded security event.

    Created by ``AdvancedSecurity.log_security_event`` and converted to a
    plain dict with ``dataclasses.asdict`` when events are queried.
    """
    # Random identifier; log_security_event uses secrets.token_hex(16).
    event_id: str
    # When the event was logged. log_security_event uses naive local
    # datetime.now(), and the 24-hour / 30-day windows compare against it.
    timestamp: datetime
    event_type: str
    threat_level: ThreatLevel
    # Caller-supplied address string; not validated here.
    source_ip: str
    user_id: Optional[str]
    # Arbitrary caller-supplied context.
    details: Dict[str, Any]
    # log_security_event always records "logged".
    action_taken: str


@dataclass
class SecurityPolicy:
    """A named set of security rules registered on ``AdvancedSecurity``."""
    # Key under which the policy is stored in AdvancedSecurity.security_policies.
    policy_id: str
    name: str
    description: str
    # Rule dicts; each default rule has a "type" key plus rule-specific
    # settings. NOTE(review): nothing in this module evaluates these rules.
    rules: List[Dict[str, Any]]
    # Only enabled policies are counted in get_security_stats()["active_policies"].
    enabled: bool
    # NOTE(review): not read anywhere in this module (defaults use 1..3);
    # confirm whether a lower number means higher precedence.
    priority: int


class AdvancedSecurity:
    """Facade over TuskLang's encryption, authentication, threat-detection
    and monitoring components.

    Constructing an instance starts a daemon thread. About once a minute it
    calls ``SecurityMonitor.monitor()`` and discards security events older
    than 30 days. Call :meth:`stop` to end that thread.

    Args:
        config: Optional settings dict, stored as ``self.config``.
            NOTE(review): nothing in this class reads it yet.
    """

    # Seconds between monitoring passes, and the back-off after a failed pass.
    _MONITOR_INTERVAL_SECONDS = 60
    _MONITOR_ERROR_BACKOFF_SECONDS = 120
    # Logged events older than this are pruned by the background loop.
    _EVENT_RETENTION = timedelta(days=30)

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        self.config = config or {}
        self.logger = logging.getLogger('tusklang.security')

        # Initialize components
        self.security_events = []
        self.security_policies = {}
        self.threat_detection = ThreatDetection()
        self.encryption_manager = EncryptionManager()
        self.authentication_manager = AuthenticationManager()
        self.security_monitor = SecurityMonitor()

        # The monitor thread rebinds security_events while callers append to
        # and read it from their own threads. Without a lock, an event
        # appended to the old list during cleanup is silently lost.
        self._events_lock = threading.Lock()
        # Set by stop() so the monitor loop wakes from its wait immediately.
        self._stop_event = threading.Event()

        # Initialize security policies
        self._init_security_policies()

        # Start security monitoring. Both synchronisation objects above must
        # exist before the thread runs.
        self.security_active = True
        self.security_thread = threading.Thread(target=self._security_monitor_loop, daemon=True)
        self.security_thread.start()

    def _init_security_policies(self):
        """Register the default access-control, data-protection and
        threat-prevention policies in ``self.security_policies``."""
        policies = [
            SecurityPolicy(
                policy_id="access_control",
                name="Access Control Policy",
                description="Controls access to sensitive resources",
                rules=[
                    {"type": "rate_limit", "max_attempts": 5, "window_minutes": 15},
                    {"type": "session_timeout", "timeout_minutes": 30},
                    {"type": "password_policy", "min_length": 12, "require_special": True}
                ],
                enabled=True,
                priority=1
            ),
            SecurityPolicy(
                policy_id="data_protection",
                name="Data Protection Policy",
                description="Protects sensitive data with encryption",
                rules=[
                    {"type": "encryption_at_rest", "algorithm": "AES-256"},
                    {"type": "encryption_in_transit", "protocol": "TLS-1.3"},
                    {"type": "data_classification", "sensitive_patterns": ["password", "token", "key"]}
                ],
                enabled=True,
                priority=2
            ),
            SecurityPolicy(
                policy_id="threat_prevention",
                name="Threat Prevention Policy",
                description="Prevents and detects security threats",
                rules=[
                    {"type": "sql_injection", "enabled": True},
                    {"type": "xss_prevention", "enabled": True},
                    {"type": "csrf_protection", "enabled": True}
                ],
                enabled=True,
                priority=3
            )
        ]

        for policy in policies:
            self.security_policies[policy.policy_id] = policy

    def encrypt_data(self, data: str, key: Optional[str] = None,
                    algorithm: str = "AES-256") -> Dict[str, Any]:
        """Encrypt ``data``; delegates to ``EncryptionManager.encrypt``."""
        return self.encryption_manager.encrypt(data, key, algorithm)

    def decrypt_data(self, encrypted_data: Dict[str, Any], key: Optional[str] = None) -> str:
        """Decrypt a dict produced by :meth:`encrypt_data`."""
        return self.encryption_manager.decrypt(encrypted_data, key)

    def generate_quantum_safe_key(self, key_size: int = 256) -> str:
        """Return a random base64 key of ``key_size`` bits."""
        return self.encryption_manager.generate_quantum_safe_key(key_size)

    def hash_password(self, password: str, algorithm: str = "argon2") -> str:
        """Hash ``password`` with bcrypt, argon2 or pbkdf2."""
        return self.authentication_manager.hash_password(password, algorithm)

    def verify_password(self, password: str, hashed_password: str) -> bool:
        """Return True if ``password`` matches ``hashed_password``."""
        return self.authentication_manager.verify_password(password, hashed_password)

    def generate_jwt_token(self, payload: Dict[str, Any], secret: str,
                          algorithm: str = "HS256") -> str:
        """Sign ``payload`` (plus iat/exp/iss claims) as a JWT."""
        return self.authentication_manager.generate_jwt_token(payload, secret, algorithm)

    def verify_jwt_token(self, token: str, secret: str) -> Dict[str, Any]:
        """Decode and validate an HS256 JWT; raises ValueError if invalid."""
        return self.authentication_manager.verify_jwt_token(token, secret)

    def detect_threat(self, event_data: Dict[str, Any]) -> Optional["ThreatLevel"]:
        """Return the detected threat level for ``event_data``, or None."""
        return self.threat_detection.detect_threat(event_data)

    def log_security_event(self, event_type: str, threat_level: "ThreatLevel",
                          source_ip: str, details: Dict[str, Any],
                          user_id: Optional[str] = None):
        """Record a security event and emit a warning log line.

        Safe to call from any thread; the event list is guarded by a lock.
        """
        event = SecurityEvent(
            event_id=secrets.token_hex(16),
            timestamp=datetime.now(),
            event_type=event_type,
            threat_level=threat_level,
            source_ip=source_ip,
            user_id=user_id,
            details=details,
            action_taken="logged"
        )

        with self._events_lock:
            self.security_events.append(event)
        self.logger.warning(f"Security event: {event_type} - {threat_level.value}")

    def get_security_events(self, hours: int = 24) -> List[Dict[str, Any]]:
        """Return events logged within the last ``hours`` hours as dicts.

        Dicts come from ``dataclasses.asdict``, so ``threat_level`` remains a
        ``ThreatLevel`` member and ``timestamp`` a ``datetime``.
        """
        cutoff_time = datetime.now() - timedelta(hours=hours)
        with self._events_lock:
            recent_events = [
                event for event in self.security_events
                if event.timestamp >= cutoff_time
            ]

        return [asdict(event) for event in recent_events]

    def get_security_stats(self) -> Dict[str, Any]:
        """Summarise stored events, the last 24 hours by threat level, and
        the number of enabled policies."""
        with self._events_lock:
            total_events = len(self.security_events)
        recent_events = self.get_security_events(24)

        threat_counts = {}
        for event in recent_events:
            threat_level = event["threat_level"]
            threat_counts[threat_level] = threat_counts.get(threat_level, 0) + 1

        return {
            "total_events": total_events,
            "events_last_24h": len(recent_events),
            "threat_distribution": threat_counts,
            "active_policies": len([p for p in self.security_policies.values() if p.enabled])
        }

    def stop(self, timeout: Optional[float] = None) -> None:
        """Stop the background monitoring thread.

        Args:
            timeout: Maximum seconds to wait for the thread to exit. ``None``
                waits until it finishes any pass currently in progress.
        """
        self.security_active = False
        self._stop_event.set()
        thread = self.security_thread
        # Joining the current thread would deadlock (e.g. stop() called
        # from inside a monitor callback).
        if thread.is_alive() and thread is not threading.current_thread():
            thread.join(timeout)

    def _security_monitor_loop(self):
        """Run monitoring passes until ``security_active`` is cleared or
        :meth:`stop` is called; errors are logged and trigger a longer back-off."""
        while self.security_active:
            try:
                # Monitor for security threats
                self.security_monitor.monitor()

                # Clean up old events
                self._cleanup_old_events()

                delay = self._MONITOR_INTERVAL_SECONDS

            except Exception as e:
                self.logger.error(f"Security monitor error: {e}")
                delay = self._MONITOR_ERROR_BACKOFF_SECONDS

            # Event.wait returns True as soon as stop() sets the event, so
            # shutdown does not have to sit out the full sleep.
            if self._stop_event.wait(delay):
                break

    def _cleanup_old_events(self):
        """Drop events older than the retention window (30 days)."""
        cutoff_time = datetime.now() - self._EVENT_RETENTION
        with self._events_lock:
            self.security_events = [
                event for event in self.security_events
                if event.timestamp >= cutoff_time
            ]


class EncryptionManager:
    """Encryption management system.

    ``encrypt`` returns a JSON-serialisable dict whose ``"algorithm"`` entry
    tells ``decrypt`` how to reverse it. Binary fields are base64 text.

    Supported algorithms:

    * ``"AES-256"``: AES-256-GCM with a base64-encoded 32-byte key.
    * ``"RSA"``: hybrid encryption. The data is encrypted with a fresh
      AES-256-GCM key, which is then wrapped with RSA-OAEP(SHA-256) under a
      PEM-encoded RSA public key. Decrypt with the matching unencrypted PEM
      private key.
    * ``"quantum_safe"``: AES-256-GCM under a subkey derived from the base64
      key with HKDF-SHA256. This is symmetric encryption only. 256-bit
      symmetric keys are generally considered to keep adequate strength
      against known quantum attacks, but this is not a post-quantum
      public-key scheme.

    Payloads written by earlier versions of this class for ``"RSA"`` and
    ``"quantum_safe"`` were only base64-encoded, not encrypted. ``decrypt``
    still reads them (logging a warning) so stored data stays recoverable.
    """

    # GCM nonce length recommended by NIST SP 800-38D (96 bits).
    _GCM_NONCE_BYTES = 12
    _AES256_KEY_BYTES = 32
    # HKDF context label: separates the quantum_safe subkey from any other
    # use of the same input key.
    _QUANTUM_SAFE_INFO = b"tusklang-sdk quantum_safe AES-256-GCM"

    def __init__(self):
        self.logger = logging.getLogger('tusklang.security.encryption')

    def encrypt(self, data: str, key: Optional[str] = None,
                algorithm: str = "AES-256") -> Dict[str, Any]:
        """Encrypt ``data`` with ``algorithm``.

        Args:
            data: Text to encrypt (UTF-8 encoded).
            key: For ``"AES-256"``/``"quantum_safe"``, a base64 symmetric key.
                For ``"RSA"``, a PEM public key. If omitted for a symmetric
                algorithm, a random 32-byte key is generated and returned in
                the result under ``"key"``; otherwise the ciphertext could
                never be decrypted. Anyone holding that full dict can
                decrypt it, so store the key separately.
            algorithm: ``"AES-256"``, ``"RSA"`` or ``"quantum_safe"``.

        Returns:
            Dict with ``algorithm``, ``ciphertext``, ``iv`` and ``tag``. RSA
            results also carry ``encrypted_key`` and ``public_key``. The dict
            includes ``key`` when the key was generated here.

        Raises:
            ValueError: Unsupported algorithm, missing RSA public key, or an
                AES-256 key that does not decode to 32 bytes.
        """
        try:
            generated_key = None
            if not key:
                if algorithm == "RSA":
                    raise ValueError("RSA encryption requires a PEM-encoded public key")
                key = generated_key = self.generate_key(32)

            if algorithm == "AES-256":
                result = self._encrypt_aes256(data, key)
            elif algorithm == "RSA":
                result = self._encrypt_rsa(data, key)
            elif algorithm == "quantum_safe":
                result = self._encrypt_quantum_safe(data, key)
            else:
                raise ValueError(f"Unsupported encryption algorithm: {algorithm}")

            if generated_key is not None:
                result["key"] = generated_key
            return result

        except Exception as e:
            self.logger.error(f"Encryption error: {e}")
            raise

    def decrypt(self, encrypted_data: Dict[str, Any], key: Optional[str] = None) -> str:
        """Decrypt a dict produced by :meth:`encrypt`.

        Args:
            encrypted_data: Result of ``encrypt``. A missing ``"algorithm"``
                entry is treated as ``"AES-256"``.
            key: Base64 symmetric key, or a PEM private key for ``"RSA"``.
                If omitted, a ``"key"`` entry embedded by ``encrypt`` is used.

        Raises:
            ValueError: Unsupported algorithm or missing key.
            cryptography.exceptions.InvalidTag: Wrong key or tampered data.
        """
        try:
            algorithm = encrypted_data.get("algorithm", "AES-256")
            if not key:
                # encrypt() embeds the key only when it generated one itself.
                key = encrypted_data.get("key")

            if algorithm == "AES-256":
                return self._decrypt_aes256(encrypted_data, key)
            elif algorithm == "RSA":
                return self._decrypt_rsa(encrypted_data, key)
            elif algorithm == "quantum_safe":
                return self._decrypt_quantum_safe(encrypted_data, key)
            else:
                raise ValueError(f"Unsupported decryption algorithm: {algorithm}")

        except Exception as e:
            self.logger.error(f"Decryption error: {e}")
            raise

    def generate_key(self, length: int = 32) -> str:
        """Return ``length`` cryptographically random bytes, base64-encoded."""
        return base64.b64encode(secrets.token_bytes(length)).decode()

    def generate_quantum_safe_key(self, key_size: int = 256) -> str:
        """Return ``key_size`` random bits (base64) for the ``"quantum_safe"`` mode.

        This is plain random key material. The default 256 bits is what
        gives that mode its intended strength.
        """
        return base64.b64encode(secrets.token_bytes(key_size // 8)).decode()

    @staticmethod
    def _decode_key(key: Optional[str], purpose: str) -> bytes:
        """Base64-decode ``key``; raise ValueError if it is missing."""
        if not key:
            raise ValueError(f"A key is required for {purpose}")
        return base64.b64decode(key)

    @staticmethod
    def _oaep() -> Any:
        """OAEP padding (MGF1 + SHA-256) used to wrap RSA data keys."""
        return padding.OAEP(
            mgf=padding.MGF1(algorithm=hashes.SHA256()),
            algorithm=hashes.SHA256(),
            label=None,
        )

    def _aes_gcm_encrypt(self, plaintext: bytes, key_bytes: bytes, label: str) -> Dict[str, Any]:
        """AES-GCM encrypt ``plaintext``; ``label`` becomes the ``algorithm`` entry."""
        iv = secrets.token_bytes(self._GCM_NONCE_BYTES)
        encryptor = Cipher(algorithms.AES(key_bytes), modes.GCM(iv)).encryptor()
        ciphertext = encryptor.update(plaintext) + encryptor.finalize()
        return {
            "algorithm": label,
            "ciphertext": base64.b64encode(ciphertext).decode(),
            "iv": base64.b64encode(iv).decode(),
            "tag": base64.b64encode(encryptor.tag).decode(),
        }

    @staticmethod
    def _aes_gcm_decrypt(encrypted_data: Dict[str, Any], key_bytes: bytes) -> bytes:
        """Reverse :meth:`_aes_gcm_encrypt`, verifying the authentication tag."""
        iv = base64.b64decode(encrypted_data["iv"])
        ciphertext = base64.b64decode(encrypted_data["ciphertext"])
        tag = base64.b64decode(encrypted_data["tag"])
        decryptor = Cipher(algorithms.AES(key_bytes), modes.GCM(iv, tag)).decryptor()
        # finalize() raises InvalidTag if the key, nonce, tag or ciphertext differ.
        return decryptor.update(ciphertext) + decryptor.finalize()

    def _decode_legacy(self, encrypted_data: Dict[str, Any], algorithm: str) -> str:
        """Read a pre-fix payload whose "ciphertext" is only base64 of the text."""
        self.logger.warning(
            f"Reading legacy {algorithm} payload that was only base64-encoded, "
            f"not encrypted; re-encrypt it"
        )
        return base64.b64decode(encrypted_data["ciphertext"]).decode()

    def _encrypt_aes256(self, data: str, key: str) -> Dict[str, Any]:
        """Encrypt data using AES-256-GCM with a 32-byte base64 key."""
        key_bytes = self._decode_key(key, "AES-256 encryption")
        if len(key_bytes) != self._AES256_KEY_BYTES:
            # Otherwise a 16/24-byte key silently selects AES-128/192.
            raise ValueError(
                f"AES-256 requires a {self._AES256_KEY_BYTES}-byte key, "
                f"got {len(key_bytes)} bytes"
            )
        return self._aes_gcm_encrypt(data.encode(), key_bytes, "AES-256")

    def _decrypt_aes256(self, encrypted_data: Dict[str, Any], key: str) -> str:
        """Decrypt an ``"AES-256"`` payload."""
        # Key length is deliberately not enforced here, so data encrypted
        # with shorter keys before validation existed stays readable.
        key_bytes = self._decode_key(key, "AES-256 decryption")
        return self._aes_gcm_decrypt(encrypted_data, key_bytes).decode()

    def _encrypt_rsa(self, data: str, public_key: str) -> Dict[str, Any]:
        """Hybrid-encrypt ``data`` for the holder of PEM ``public_key``."""
        if not public_key:
            raise ValueError("RSA encryption requires a PEM-encoded public key")
        recipient = serialization.load_pem_public_key(public_key.encode())
        if not isinstance(recipient, rsa.RSAPublicKey):
            raise ValueError("RSA encryption requires an RSA public key")
        # RSA-OAEP can only encrypt a few hundred bytes, so RSA wraps a
        # random AES key and AES-GCM carries the data.
        data_key = secrets.token_bytes(self._AES256_KEY_BYTES)
        result = self._aes_gcm_encrypt(data.encode(), data_key, "RSA")
        result["encrypted_key"] = base64.b64encode(
            recipient.encrypt(data_key, self._oaep())
        ).decode()
        result["public_key"] = public_key
        return result

    def _decrypt_rsa(self, encrypted_data: Dict[str, Any], private_key: str) -> str:
        """Decrypt an ``"RSA"`` payload with an unencrypted PEM private key."""
        if "encrypted_key" not in encrypted_data:
            return self._decode_legacy(encrypted_data, "RSA")
        if not private_key:
            raise ValueError("RSA decryption requires a PEM-encoded private key")
        owner = load_pem_private_key(private_key.encode(), password=None)
        if not isinstance(owner, rsa.RSAPrivateKey):
            raise ValueError("RSA decryption requires an RSA private key")
        data_key = owner.decrypt(base64.b64decode(encrypted_data["encrypted_key"]), self._oaep())
        return self._aes_gcm_decrypt(encrypted_data, data_key).decode()

    def _derive_quantum_safe_key(self, key: Optional[str]) -> bytes:
        """Derive the 32-byte AES key for ``"quantum_safe"`` via HKDF-SHA256."""
        material = self._decode_key(key, "quantum_safe encryption")
        if len(material) < self._AES256_KEY_BYTES:
            self.logger.warning(
                "quantum_safe key has fewer than 256 bits; use generate_quantum_safe_key(256)"
            )
        return HKDF(
            algorithm=hashes.SHA256(),
            length=self._AES256_KEY_BYTES,
            salt=None,
            info=self._QUANTUM_SAFE_INFO,
        ).derive(material)

    def _encrypt_quantum_safe(self, data: str, key: str) -> Dict[str, Any]:
        """Encrypt with AES-256-GCM under an HKDF-derived subkey (key not emitted)."""
        return self._aes_gcm_encrypt(data.encode(), self._derive_quantum_safe_key(key), "quantum_safe")

    def _decrypt_quantum_safe(self, encrypted_data: Dict[str, Any], key: str) -> str:
        """Decrypt a ``"quantum_safe"`` payload (or read a legacy base64 one)."""
        if "iv" not in encrypted_data:
            return self._decode_legacy(encrypted_data, "quantum_safe")
        return self._aes_gcm_decrypt(encrypted_data, self._derive_quantum_safe_key(key)).decode()


class AuthenticationManager:
    """Password hashing and JWT helpers used by ``AdvancedSecurity``.

    Password hashes identify their own algorithm. bcrypt (``$2a$``, ``$2b$``,
    ``$2y$``) and argon2 (``$argon2...``) use their modular-crypt prefixes.
    Anything else is treated as base64(16-byte salt || 32-byte
    PBKDF2-HMAC-SHA256 key).
    """

    # bcrypt variants recognised by verify_password().
    _BCRYPT_PREFIXES = ("$2a$", "$2b$", "$2y$")
    # PBKDF2 layout/parameters. The encoded hash stores only salt + key, so
    # changing these would make existing PBKDF2 hashes unverifiable.
    _PBKDF2_SALT_BYTES = 16
    _PBKDF2_KEY_BYTES = 32
    _PBKDF2_ITERATIONS = 100000
    # Standard claims added to every generated JWT.
    _JWT_LIFETIME_SECONDS = 3600
    _JWT_ISSUER = "tusklang-sdk"

    def __init__(self):
        self.logger = logging.getLogger('tusklang.security.authentication')

    def _pbkdf2(self, salt: bytes) -> Any:
        """Return a single-use PBKDF2-HMAC-SHA256 KDF configured for ``salt``."""
        return PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=self._PBKDF2_KEY_BYTES,
            salt=salt,
            iterations=self._PBKDF2_ITERATIONS,
        )

    def hash_password(self, password: str, algorithm: str = "argon2") -> str:
        """Hash ``password`` with ``"bcrypt"``, ``"argon2"`` or ``"pbkdf2"``.

        Raises:
            ValueError: Unsupported ``algorithm``.
        """
        try:
            if algorithm == "bcrypt":
                return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
            elif algorithm == "argon2":
                return argon2.PasswordHasher().hash(password)
            elif algorithm == "pbkdf2":
                salt = secrets.token_bytes(self._PBKDF2_SALT_BYTES)
                key = self._pbkdf2(salt).derive(password.encode())
                return base64.b64encode(salt + key).decode()
            else:
                raise ValueError(f"Unsupported hashing algorithm: {algorithm}")

        except Exception as e:
            self.logger.error(f"Password hashing error: {e}")
            raise

    def verify_password(self, password: str, hashed_password: str) -> bool:
        """Return True if ``password`` matches ``hashed_password``.

        The algorithm is inferred from the hash format. Malformed hashes and
        unexpected errors yield False; unexpected errors are also logged.
        """
        try:
            if hashed_password.startswith(self._BCRYPT_PREFIXES):
                return bcrypt.checkpw(password.encode(), hashed_password.encode())

            if hashed_password.startswith("$argon2"):
                try:
                    return argon2.PasswordHasher().verify(hashed_password, password)
                except argon2.exceptions.VerificationError:
                    # A wrong password is an expected outcome, not an error.
                    return False

            # Otherwise assume base64(salt || PBKDF2 key).
            hash_bytes = base64.b64decode(hashed_password)
            if len(hash_bytes) != self._PBKDF2_SALT_BYTES + self._PBKDF2_KEY_BYTES:
                return False
            salt = hash_bytes[:self._PBKDF2_SALT_BYTES]
            stored_key = hash_bytes[self._PBKDF2_SALT_BYTES:]
            key = self._pbkdf2(salt).derive(password.encode())
            return hmac.compare_digest(stored_key, key)

        except Exception as e:
            self.logger.error(f"Password verification error: {e}")
            return False

    def generate_jwt_token(self, payload: Dict[str, Any], secret: str,
                          algorithm: str = "HS256") -> str:
        """Sign ``payload`` as a JWT valid for one hour.

        ``iat``, ``exp`` and ``iss`` claims are added (overriding same-named
        payload keys). The caller's ``payload`` dict is not modified.
        """
        try:
            issued_at = int(time.time())
            claims = {
                **payload,
                "iat": issued_at,
                "exp": issued_at + self._JWT_LIFETIME_SECONDS,
                "iss": self._JWT_ISSUER,
            }
            token = jwt.encode(claims, secret, algorithm=algorithm)
            # PyJWT < 2.0 returns bytes; normalise to the documented str.
            return token.decode() if isinstance(token, bytes) else token

        except Exception as e:
            self.logger.error(f"JWT generation error: {e}")
            raise

    def verify_jwt_token(self, token: str, secret: str,
                         algorithms: Optional[List[str]] = None) -> Dict[str, Any]:
        """Decode and validate ``token`` signed with ``secret``.

        Args:
            algorithms: Allow-list of accepted signing algorithms; defaults to
                ``["HS256"]``. Pass the algorithm given to
                ``generate_jwt_token`` when it was not HS256. Never derive
                this list from the token itself.

        Raises:
            ValueError: The token is expired or otherwise invalid.
        """
        allowed = list(algorithms) if algorithms else ["HS256"]
        try:
            return jwt.decode(token, secret, algorithms=allowed)
        except jwt.ExpiredSignatureError as e:
            raise ValueError("Token has expired") from e
        except jwt.InvalidTokenError as e:
            raise ValueError(f"Invalid token: {e}") from e


class ThreatDetection:
    """Signature-based detector for common injection and traversal payloads.

    Matching is a case-insensitive substring search over every key and
    scalar value in the event data. It is a coarse heuristic: benign text
    that contains a pattern is flagged, and obfuscated payloads are missed.
    """

    # More failed logins than this counts as a brute-force attempt.
    _BRUTE_FORCE_THRESHOLD = 5

    def __init__(self):
        self.logger = logging.getLogger('tusklang.security.threat_detection')
        self.threat_patterns = self._init_threat_patterns()

    def _init_threat_patterns(self) -> Dict[str, List[str]]:
        """Return the pattern lists for each threat category."""
        return {
            "sql_injection": [
                # Stacked queries. No leading quote is required, because
                # injection into numeric contexts needs none.
                "; DROP TABLE",
                "UNION SELECT",
                "OR 1=1",
                "; INSERT INTO",
                "EXEC xp_"
            ],
            "xss": [
                "<script>",
                "javascript:",
                "onload=",
                "onerror=",
                "eval("
            ],
            "path_traversal": [
                "../",
                "..\\",
                "/etc/passwd",
                "C:\\Windows\\System32"
            ],
            "command_injection": [
                "; ls",
                "| cat",
                "&& rm",
                "`whoami`"
            ]
        }

    def detect_threat(self, event_data: Dict[str, Any]) -> Optional["ThreatLevel"]:
        """Return the most severe threat level detected in ``event_data``.

        Returns None when nothing matches or detection itself fails. If
        several checks fire, the highest severity wins; for example, command
        injection plus brute force is CRITICAL, not MEDIUM.
        """
        try:
            text = self._event_text(event_data)
            detected = []
            if self._matches("sql_injection", text):
                detected.append(ThreatLevel.HIGH)
            if self._matches("xss", text):
                detected.append(ThreatLevel.MEDIUM)
            if self._matches("path_traversal", text):
                detected.append(ThreatLevel.HIGH)
            if self._matches("command_injection", text):
                detected.append(ThreatLevel.CRITICAL)
            if self._check_brute_force(event_data):
                detected.append(ThreatLevel.MEDIUM)

            if not detected:
                return None
            # ThreatLevel is declared LOW..CRITICAL, so declaration order ranks severity.
            severity_order = list(ThreatLevel)
            return max(detected, key=severity_order.index)

        except Exception as e:
            self.logger.error(f"Threat detection error: {e}")
            return None

    @staticmethod
    def _event_text(event_data: Any) -> str:
        """Flatten all keys and scalar values of ``event_data`` into lower-case text.

        Values are joined with line breaks, so a pattern cannot be formed
        across two values. Unlike ``str(event_data)``, no repr() escaping is
        applied. That escaping doubled backslashes and hid Windows paths.
        Self-referencing containers are visited once.
        """
        parts = []
        stack = [event_data]
        seen = set()
        while stack:
            item = stack.pop()
            if isinstance(item, (dict, list, tuple, set, frozenset)):
                if id(item) in seen:
                    continue
                seen.add(id(item))
                if isinstance(item, dict):
                    for k, v in item.items():
                        stack.append(k)
                        stack.append(v)
                else:
                    stack.extend(item)
            else:
                parts.append(str(item))
        return "\n".join(parts).lower()

    def _matches(self, category: str, text: str) -> bool:
        """True if any pattern of ``category`` occurs in lower-cased ``text``."""
        return any(pattern.lower() in text for pattern in self.threat_patterns[category])

    def _check_sql_injection(self, event_data: Dict[str, Any]) -> bool:
        """Check for SQL injection patterns"""
        return self._matches("sql_injection", self._event_text(event_data))

    def _check_xss(self, event_data: Dict[str, Any]) -> bool:
        """Check for XSS patterns"""
        return self._matches("xss", self._event_text(event_data))

    def _check_path_traversal(self, event_data: Dict[str, Any]) -> bool:
        """Check for path traversal patterns"""
        return self._matches("path_traversal", self._event_text(event_data))

    def _check_command_injection(self, event_data: Dict[str, Any]) -> bool:
        """Check for command injection patterns"""
        return self._matches("command_injection", self._event_text(event_data))

    def _check_brute_force(self, event_data: Dict[str, Any]) -> bool:
        """True if ``event_data["failed_attempts"]`` exceeds the threshold.

        Numeric strings are accepted. Missing or non-numeric values count as
        no brute force, rather than aborting the other checks.
        """
        if not isinstance(event_data, dict):
            return False
        attempts = event_data.get("failed_attempts")
        if attempts is None:
            return False
        try:
            return float(attempts) > self._BRUTE_FORCE_THRESHOLD
        except (TypeError, ValueError):
            return False


class SecurityMonitor:
    """Host-level security checks, run once per ``monitor()`` call.

    All three checks are placeholders that currently do nothing. They run in
    a fixed order, and the first exception aborts the rest of that pass. The
    exception is logged instead of propagated.
    """

    def __init__(self):
        self.monitoring_active = True
        self.logger = logging.getLogger('tusklang.security.monitor')

    def monitor(self):
        """Run one pass: system resources, then network, then file system."""
        try:
            for check in (self._monitor_system_resources,
                          self._monitor_network_activity,
                          self._monitor_file_system):
                check()
        except Exception as exc:
            self.logger.error(f"Security monitoring error: {exc}")

    def _monitor_system_resources(self):
        """Placeholder for CPU/memory/disk threshold alerts; does nothing yet."""
        return None

    def _monitor_network_activity(self):
        """Placeholder for unusual-connection / traffic checks; does nothing yet."""
        return None

    def _monitor_file_system(self):
        """Placeholder for unauthorized file-change checks; does nothing yet."""
        return None


# Module-wide AdvancedSecurity instance used by the convenience functions
# below. It is created at import time, and its constructor starts the
# background monitoring thread.
advanced_security = AdvancedSecurity()


def encrypt_data(data: str, key: Optional[str] = None, algorithm: str = "AES-256") -> Dict[str, Any]:
    """Encrypt ``data`` through the module-wide ``advanced_security`` instance.

    See ``EncryptionManager.encrypt`` for key handling and the result format.
    """
    result = advanced_security.encrypt_data(data, key=key, algorithm=algorithm)
    return result


def decrypt_data(encrypted_data: Dict[str, Any], key: Optional[str] = None) -> str:
    """Decrypt a dict from ``encrypt_data`` via the module-wide instance."""
    plaintext = advanced_security.decrypt_data(encrypted_data, key=key)
    return plaintext


def generate_quantum_safe_key(key_size: int = 256) -> str:
    """Return a random base64 key of ``key_size`` bits from the module-wide instance."""
    new_key = advanced_security.generate_quantum_safe_key(key_size=key_size)
    return new_key


def hash_password(password: str, algorithm: str = "argon2") -> str:
    """Hash ``password`` ("bcrypt", "argon2" or "pbkdf2") via the module-wide instance."""
    digest = advanced_security.hash_password(password, algorithm=algorithm)
    return digest


def verify_password(password: str, hashed_password: str) -> bool:
    """Return True if ``password`` matches ``hashed_password`` (module-wide instance)."""
    matches = advanced_security.verify_password(password, hashed_password=hashed_password)
    return matches


def generate_jwt_token(payload: Dict[str, Any], secret: str, algorithm: str = "HS256") -> str:
    """Sign ``payload`` as a JWT with ``secret`` via the module-wide instance."""
    token = advanced_security.generate_jwt_token(payload, secret, algorithm=algorithm)
    return token


def verify_jwt_token(token: str, secret: str) -> Dict[str, Any]:
    """Decode and validate ``token`` via the module-wide instance.

    Raises ValueError when the token is expired or invalid.
    """
    claims = advanced_security.verify_jwt_token(token, secret=secret)
    return claims


def detect_threat(event_data: Dict[str, Any]) -> Optional[str]:
    """Return the detected threat level's string value (e.g. "high"), or None."""
    level = advanced_security.detect_threat(event_data)
    if not level:
        return None
    return level.value


def log_security_event(event_type: str, threat_level: Union[str, "ThreatLevel"], source_ip: str,
                      details: Dict[str, Any], user_id: Optional[str] = None):
    """Record a security event on the module-wide ``advanced_security`` instance.

    Args:
        event_type: Free-form category, e.g. ``"login_attempt"``.
        threat_level: A ``ThreatLevel`` member, or its value as a string.
            Strings are matched case-insensitively after trimming
            surrounding whitespace (``" Medium "`` -> ``ThreatLevel.MEDIUM``).
        source_ip: Originating address string.
        details: Arbitrary context stored with the event.
        user_id: Optional user the event relates to.

    Raises:
        ValueError: ``threat_level`` is not a recognised level.
    """
    if isinstance(threat_level, ThreatLevel):
        threat_enum = threat_level
    else:
        threat_enum = ThreatLevel(threat_level.strip().lower())
    advanced_security.log_security_event(event_type, threat_enum, source_ip, details, user_id)


def get_security_events(hours: int = 24) -> List[Dict[str, Any]]:
    """Return events from the last ``hours`` hours (as dicts) from the module-wide instance."""
    events = advanced_security.get_security_events(hours=hours)
    return events


def get_security_stats() -> Dict[str, Any]:
    """Return event counts and active-policy totals from the module-wide instance."""
    stats = advanced_security.get_security_stats()
    return stats


if __name__ == "__main__":
    # Manual smoke test of the module-level helpers; results go to stdout.
    print("Advanced Security for TuskLang Python SDK")
    print("=" * 50)

    # Test encryption. Use an explicit 32-byte key and pass it to both calls.
    # Decrypting without the key used for encryption cannot work.
    print("\n1. Testing Encryption:")
    test_data = "Hello, World! This is a test message."
    encryption_key = base64.b64encode(secrets.token_bytes(32)).decode()
    encrypted = encrypt_data(test_data, encryption_key)
    decrypted = decrypt_data(encrypted, encryption_key)
    print(f"  Original: {test_data}")
    print(f"  Encrypted: {encrypted['ciphertext'][:50]}...")
    print(f"  Decrypted: {decrypted}")

    # Test password hashing
    print("\n2. Testing Password Hashing:")
    password = "MySecurePassword123!"
    hashed = hash_password(password)
    verified = verify_password(password, hashed)
    print(f"  Password: {password}")
    print(f"  Hashed: {hashed[:50]}...")
    print(f"  Verified: {verified}")

    # Test JWT tokens
    print("\n3. Testing JWT Tokens:")
    payload = {"user_id": "12345", "role": "admin"}
    secret = "my-secret-key"
    # Pass a copy so the printed payload is exactly what we supplied, even
    # if token generation adds claims to the dict it receives.
    token = generate_jwt_token(dict(payload), secret)
    decoded = verify_jwt_token(token, secret)
    print(f"  Payload: {payload}")
    print(f"  Token: {token[:50]}...")
    print(f"  Decoded: {decoded}")

    # Test threat detection
    print("\n4. Testing Threat Detection:")
    safe_data = {"query": "SELECT * FROM users WHERE id = 1"}
    malicious_data = {"query": "SELECT * FROM users; DROP TABLE users;"}

    safe_threat = detect_threat(safe_data)
    malicious_threat = detect_threat(malicious_data)

    print(f"  Safe data threat: {safe_threat}")
    print(f"  Malicious data threat: {malicious_threat}")

    # Test security logging
    print("\n5. Testing Security Logging:")
    log_security_event("login_attempt", "medium", "192.168.1.100",
                      {"username": "testuser", "success": False})

    events = get_security_events(1)
    print(f"  Security events (last hour): {len(events)}")

    # Get security stats
    print("\n6. Security Statistics:")
    stats = get_security_stats()
    print(f"  Total events: {stats['total_events']}")
    print(f"  Events last 24h: {stats['events_last_24h']}")
    print(f"  Active policies: {stats['active_policies']}")

    print("\nAdvanced security testing completed!")