#!/usr/bin/env python3
"""
Security Commands for TuskLang Python CLI
=========================================
Implements comprehensive security operations including authentication,
encryption, scanning, and audit functionality
"""

import base64
import fnmatch
import functools
import hashlib
import json
import logging
import os
import secrets
import socket
import subprocess
import sys
import time
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

# Import security components.
# These are relative imports into the surrounding package; any ImportError sets
# SECURITY_AVAILABLE = False. In that case only the three classes below exist:
# the helper functions (encrypt_data, decrypt_data, log_security_event,
# get_security_events, get_security_stats, ...) are NOT defined, so every
# caller in this module checks SECURITY_AVAILABLE before using them.
try:
    from ...advanced_features.advanced_security import (
        AdvancedSecurity, encrypt_data, decrypt_data, hash_password, 
        verify_password, generate_jwt_token, verify_jwt_token,
        detect_threat, log_security_event, get_security_events, get_security_stats
    )
    from ...core.security_manager import SecurityManager
    from ...core.enterprise_security_systems import EnterpriseSecuritySystems
    SECURITY_AVAILABLE = True
except ImportError:
    SECURITY_AVAILABLE = False
    # Create dummy classes for when security components are not available.
    # They are name placeholders only: each has a no-op __init__ and no other
    # methods, so calling e.g. authenticate_user on them would fail.
    class AdvancedSecurity:
        def __init__(self): pass
    class SecurityManager:
        def __init__(self): pass
    class EnterpriseSecuritySystems:
        def __init__(self): pass

from ..utils.output_formatter import OutputFormatter
from ..utils.error_handler import ErrorHandler
from ..utils.config_loader import ConfigLoader


class SecuritySessionManager:
    """Persist and query the CLI's security session.

    The session is stored as JSON in ``~/.tsk/security_session.json`` with the
    keys ``user_id``, ``token``, ``expires_at``/``created_at`` (ISO-8601,
    naive local time) and ``active``. The file contains a bearer token, so
    it is always written with owner-only (0o600) permissions.
    """

    def __init__(self):
        self.session_file = Path.home() / '.tsk' / 'security_session.json'
        self.session_file.parent.mkdir(parents=True, exist_ok=True)
        self.session_data = self._load_session()

    def _load_session(self) -> Dict[str, Any]:
        """Return the persisted session dict, or ``{}`` if missing or unreadable."""
        try:
            with open(self.session_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError):
            # Missing file, unreadable file, or invalid JSON/UTF-8.
            return {}
        # A corrupted or hand-edited file may contain valid JSON that is not
        # an object; every caller relies on dict methods.
        return data if isinstance(data, dict) else {}

    def _save_session(self):
        """Write the session to disk, readable and writable by the owner only."""
        # os.open's mode applies only when the file is created; the explicit
        # chmod also tightens a file left behind by an older, umask-based version.
        fd = os.open(self.session_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, 'w', encoding='utf-8') as f:
            os.chmod(self.session_file, 0o600)
            json.dump(self.session_data, f, indent=2)

    def create_session(self, user_id: str, token: str, expires_at: datetime):
        """Replace any existing session with a new active one and persist it."""
        self.session_data = {
            'user_id': user_id,
            'token': token,
            'expires_at': expires_at.isoformat(),
            'created_at': datetime.now().isoformat(),
            'active': True
        }
        self._save_session()

    def get_session(self) -> Optional[Dict[str, Any]]:
        """Return the active, unexpired session, or None.

        An expired or malformed session (missing/unparseable ``expires_at``,
        or a timezone-aware value that cannot be compared with local naive
        time) is cleared from disk and treated as absent.
        """
        if not self.session_data.get('active', False):
            return None

        try:
            expires_at = datetime.fromisoformat(self.session_data['expires_at'])
            expired = datetime.now() > expires_at
        except (KeyError, TypeError, ValueError):
            expired = True

        if expired:
            self.clear_session()
            return None

        return self.session_data

    def refresh_session(self, new_token: str, expires_at: datetime):
        """Swap in a new token and expiry if a session is active; otherwise no-op."""
        if self.session_data.get('active', False):
            self.session_data['token'] = new_token
            self.session_data['expires_at'] = expires_at.isoformat()
            self._save_session()

    def clear_session(self):
        """Forget the current session and persist the empty state."""
        self.session_data = {}
        self._save_session()


def handle_security_command(args: Any, cli: Any) -> int:
    """Entry point for ``tsk security <subcommand>``.

    Builds the output formatter and error handler from the CLI flags, then
    forwards to the matching ``_handle_*_command`` function. The ``auth`` and
    ``hash`` handlers additionally receive *cli*.

    Returns:
        The handler's exit code, ``ErrorHandler.INVALID_ARGS`` for an
        unrecognised subcommand, or ``error_handler.handle_error(e)`` if
        anything raises.
    """
    formatter = OutputFormatter(cli.json_output, cli.quiet, cli.verbose)
    error_handler = ErrorHandler(cli.json_output, cli.verbose)

    try:
        requested = args.security_command
        # (subcommand name, handler, whether the handler takes cli as a 4th arg)
        routes = (
            ('auth', _handle_auth_command, True),
            ('scan', _handle_scan_command, False),
            ('encrypt', _handle_encrypt_command, False),
            ('decrypt', _handle_decrypt_command, False),
            ('audit', _handle_audit_command, False),
            ('hash', _handle_hash_command, True),
        )
        for name, handler, needs_cli in routes:
            if requested == name:
                extra = (cli,) if needs_cli else ()
                return handler(args, formatter, error_handler, *extra)

        formatter.error("Unknown security command")
        return ErrorHandler.INVALID_ARGS

    except Exception as e:
        return error_handler.handle_error(e)


def _handle_auth_command(args: Any, formatter: OutputFormatter, error_handler: ErrorHandler, cli: Any = None) -> int:
    """Route ``tsk security auth <action>`` to its handler.

    Only ``status`` receives *cli*. Unknown actions yield
    ``ErrorHandler.INVALID_ARGS``; exceptions are passed to
    ``error_handler.handle_error``.
    """
    try:
        action = args.auth_command

        if action == 'login':
            return _handle_auth_login(args, formatter, error_handler)
        if action == 'logout':
            return _handle_auth_logout(args, formatter, error_handler)
        if action == 'status':
            return _handle_auth_status(args, formatter, error_handler, cli)
        if action == 'refresh':
            return _handle_auth_refresh(args, formatter, error_handler)

        formatter.error("Unknown auth command")
        return ErrorHandler.INVALID_ARGS

    except Exception as e:
        return error_handler.handle_error(e)


def _handle_auth_login(args: Any, formatter: OutputFormatter, error_handler: ErrorHandler) -> int:
    """Authenticate ``args.username``/``args.password`` and persist a session.

    Credentials are checked by ``EnterpriseSecuritySystems.authenticate_user``,
    whose result is read as a dict-like object (``success``, ``user_id``,
    ``token``, ``message``). On success a 24-hour session is stored via
    ``SecuritySessionManager``. The backend's token is used when it supplies
    a non-empty one; otherwise a random token is generated locally.

    Returns:
        ``ErrorHandler.SUCCESS`` on success, ``INVALID_ARGS`` when a credential
        is missing, ``AUTHENTICATION_ERROR`` when authentication fails,
        ``GENERAL_ERROR`` when security components are unavailable, or
        ``error_handler.handle_error(e)`` for unexpected exceptions.
    """
    formatter.loading("Authenticating user...")

    try:
        if not SECURITY_AVAILABLE:
            formatter.error("Security components not available")
            return ErrorHandler.GENERAL_ERROR

        username = args.username
        password = args.password

        if not username or not password:
            formatter.error("Username and password are required")
            return ErrorHandler.INVALID_ARGS

        enterprise_security = EnterpriseSecuritySystems()
        auth_result = enterprise_security.authenticate_user(username, password)

        if not auth_result.get('success', False):
            formatter.error("Authentication failed")
            formatter.error(auth_result.get('message', 'Invalid credentials'))
            return ErrorHandler.AUTHENTICATION_ERROR

        # `or` (not a .get default) so an explicit None/empty value from the
        # backend falls back too. The random token is only generated when needed.
        user_id = auth_result.get('user_id') or username
        token = auth_result.get('token') or secrets.token_hex(32)
        expires_at = datetime.now() + timedelta(hours=24)

        SecuritySessionManager().create_session(user_id, token, expires_at)

        formatter.success("Authentication successful")
        formatter.info(f"User: {username}")
        formatter.info(f"Session expires: {expires_at.strftime('%Y-%m-%d %H:%M:%S')}")

        return ErrorHandler.SUCCESS

    except Exception as e:
        return error_handler.handle_error(e)


def _handle_auth_logout(args: Any, formatter: OutputFormatter, error_handler: ErrorHandler) -> int:
    """Clear the stored session, recording a logout event when possible.

    The audit event is best-effort. If ``log_security_event`` fails, a
    warning is shown, but the session (and the token it holds) is still
    removed from disk.

    Returns:
        ``ErrorHandler.SUCCESS`` whether or not a session existed, or
        ``error_handler.handle_error(e)`` for unexpected exceptions.
    """
    formatter.loading("Logging out...")

    try:
        session_mgr = SecuritySessionManager()
        session = session_mgr.get_session()

        if not session:
            formatter.warning("No active session found")
            return ErrorHandler.SUCCESS

        user_id = session.get('user_id')
        if SECURITY_AVAILABLE:
            try:
                log_security_event(
                    "logout", "low", "127.0.0.1",
                    {"user_id": user_id, "method": "cli"},
                    user_id
                )
            except Exception as log_error:
                # An audit failure must not keep the user logged in.
                formatter.warning(f"Could not record logout event: {log_error}")

        session_mgr.clear_session()
        formatter.success("Logged out successfully")
        return ErrorHandler.SUCCESS

    except Exception as e:
        return error_handler.handle_error(e)


def _handle_auth_status(args: Any, formatter: OutputFormatter, error_handler: ErrorHandler, cli: Any = None) -> int:
    """Report whether an unexpired session exists.

    Always prints human-readable status lines. When *cli* has ``json_output``
    set, also prints a JSON summary (status, user, expiry, seconds remaining).

    Returns:
        ``ErrorHandler.SUCCESS`` for both active and inactive states, or
        ``error_handler.handle_error(e)`` on exceptions.
    """
    formatter.loading("Checking authentication status...")

    try:
        session = SecuritySessionManager().get_session()

        if not session:
            formatter.warning("Authentication Status: INACTIVE")
            formatter.info("No active session found")
            if cli and cli.json_output:
                formatter.print_json({
                    'status': 'inactive',
                    'message': 'No active session found'
                })
            return ErrorHandler.SUCCESS

        raw_expiry = session['expires_at']
        expiry = datetime.fromisoformat(raw_expiry)
        remaining = expiry - datetime.now()
        expiry_text = expiry.strftime('%Y-%m-%d %H:%M:%S')
        # Drop the fractional seconds from the timedelta's text form.
        remaining_text = str(remaining).partition('.')[0]

        formatter.success("Authentication Status: ACTIVE")
        formatter.info(f"User: {session.get('user_id', 'unknown')}")
        formatter.info(f"Session expires: {expiry_text}")
        formatter.info(f"Time remaining: {remaining_text}")

        if cli and cli.json_output:
            formatter.print_json({
                'status': 'active',
                'user_id': session.get('user_id'),
                'expires_at': raw_expiry,
                'time_remaining_seconds': remaining.total_seconds()
            })
        return ErrorHandler.SUCCESS

    except Exception as e:
        return error_handler.handle_error(e)


def _handle_auth_refresh(args: Any, formatter: OutputFormatter, error_handler: ErrorHandler) -> int:
    """Rotate the stored session token and push its expiry 24 hours out.

    NOTE(review): the replacement token comes from ``secrets.token_hex``
    locally. No security backend is contacted here, so the new token is not
    issued or validated by a server.

    Returns:
        ``SUCCESS`` after rotating, ``AUTHENTICATION_ERROR`` if there is no
        active session, ``GENERAL_ERROR`` if security components are
        unavailable, or ``error_handler.handle_error(e)`` on exceptions.
    """
    formatter.loading("Refreshing authentication token...")

    try:
        manager = SecuritySessionManager()

        if not manager.get_session():
            formatter.error("No active session to refresh")
            return ErrorHandler.AUTHENTICATION_ERROR

        if not SECURITY_AVAILABLE:
            formatter.error("Security components not available")
            return ErrorHandler.GENERAL_ERROR

        replacement_token = secrets.token_hex(32)
        new_expiry = datetime.now() + timedelta(hours=24)
        manager.refresh_session(replacement_token, new_expiry)

        formatter.success("Token refreshed successfully")
        formatter.info("New session expires: " + new_expiry.strftime('%Y-%m-%d %H:%M:%S'))
        return ErrorHandler.SUCCESS

    except Exception as e:
        return error_handler.handle_error(e)


def _handle_scan_command(args: Any, formatter: OutputFormatter, error_handler: ErrorHandler) -> int:
    """Run the local security scan and print a table of findings.

    Scans ``args.path`` (default: the current directory) for risky file
    permissions and sensitive-looking files. Also probes a fixed set of
    localhost ports and looks for a requirements.txt in the current
    directory. Findings are reported but do not change the exit code.

    Returns:
        ``ErrorHandler.SUCCESS`` after a completed scan, ``FILE_NOT_FOUND``
        if the scan path does not exist, or ``error_handler.handle_error(e)``.
    """
    formatter.loading("Performing security scan...")

    try:
        scan_path = getattr(args, 'path', None) or '.'
        # Walking a nonexistent path yields nothing. Without this check the
        # scan would misleadingly report "No security issues found".
        if not Path(scan_path).exists():
            formatter.error(f"Path not found: {scan_path}")
            return ErrorHandler.FILE_NOT_FOUND

        scan_results = []
        scan_results.extend(_scan_file_permissions(scan_path))
        scan_results.extend(_scan_sensitive_files(scan_path))
        scan_results.extend(_scan_network_ports())
        scan_results.extend(_scan_dependencies())

        if scan_results:
            formatter.warning(f"Found {len(scan_results)} security issues")
            formatter.table(
                ['Severity', 'Issue', 'Location', 'Recommendation'],
                scan_results,
                'Security Scan Results'
            )
        else:
            formatter.success("No security issues found")

        return ErrorHandler.SUCCESS

    except Exception as e:
        return error_handler.handle_error(e)


def _scan_file_permissions(path: str) -> List[List[str]]:
    """Scan for file permission issues"""
    issues = []
    
    try:
        for root, dirs, files in os.walk(path):
            for file in files:
                file_path = Path(root) / file
                try:
                    stat = file_path.stat()
                    mode = stat.st_mode
                    
                    # Check for world-writable files
                    if mode & 0o002:
                        issues.append([
                            'HIGH',
                            'World-writable file',
                            str(file_path),
                            'Remove world write permissions'
                        ])
                    
                    # Check for executable files with sensitive names
                    if mode & 0o111 and any(sensitive in file.lower() for sensitive in ['key', 'secret', 'password', 'token']):
                        issues.append([
                            'MEDIUM',
                            'Executable sensitive file',
                            str(file_path),
                            'Review file permissions'
                        ])
                        
                except (OSError, PermissionError):
                    continue
    except Exception:
        pass
    
    return issues


def _scan_sensitive_files(path: str) -> List[List[str]]:
    """Report files under *path* whose names look like secrets or keys.

    A file is reported at most once, even when its name matches several
    patterns (e.g. ``secret.key`` matches both ``*.key`` and ``*secret*``).
    Severity is CRITICAL if ``_is_in_git`` says the file is tracked, and
    MEDIUM otherwise. Name matching uses ``fnmatch.fnmatch``, which follows
    the platform's case rules. This is best-effort: an unexpected error
    returns the findings collected so far.
    """
    issues: List[List[str]] = []
    sensitive_patterns = (
        '*.key', '*.pem', '*.p12', '*.pfx',
        '*secret*', '*password*', '*token*',
        '.env', '.secret', 'config.json'
    )

    try:
        # A single traversal instead of one recursive glob per pattern.
        for root, _dirs, names in os.walk(path):
            for name in names:
                if not any(fnmatch.fnmatch(name, pattern) for pattern in sensitive_patterns):
                    continue
                file_path = Path(root) / name
                if not file_path.is_file():
                    continue
                if _is_in_git(file_path):
                    issues.append([
                        'CRITICAL',
                        'Sensitive file in version control',
                        str(file_path),
                        'Remove from version control and add to .gitignore'
                    ])
                else:
                    issues.append([
                        'MEDIUM',
                        'Sensitive file found',
                        str(file_path),
                        'Review file security'
                    ])
    except Exception:
        pass

    return issues


def _scan_network_ports() -> List[List[str]]:
    """Scan for open network ports"""
    issues = []
    common_ports = [22, 23, 80, 443, 3306, 5432, 6379, 8080, 9000]
    
    try:
        for port in common_ports:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.settimeout(1)
            result = sock.connect_ex(('localhost', port))
            sock.close()
            
            if result == 0:
                service_name = _get_service_name(port)
                issues.append([
                    'LOW',
                    f'Open port {port} ({service_name})',
                    f'localhost:{port}',
                    'Review if port is necessary'
                ])
    except Exception:
        pass
    
    return issues


def _scan_dependencies() -> List[List[str]]:
    """Scan for vulnerable dependencies"""
    issues = []
    
    try:
        # Check for requirements.txt
        req_file = Path('requirements.txt')
        if req_file.exists():
            # This is a simplified check - in production, use proper vulnerability scanning
            issues.append([
                'INFO',
                'Dependencies found',
                'requirements.txt',
                'Run security audit on dependencies'
            ])
    except Exception:
        pass
    
    return issues


def _is_in_git(file_path: Path) -> bool:
    """Check if file is tracked by git"""
    try:
        result = subprocess.run(
            ['git', 'ls-files', str(file_path)],
            capture_output=True, text=True
        )
        return result.returncode == 0 and result.stdout.strip()
    except Exception:
        return False


def _get_service_name(port: int) -> str:
    """Get service name for port"""
    service_names = {
        22: 'SSH', 23: 'Telnet', 80: 'HTTP', 443: 'HTTPS',
        3306: 'MySQL', 5432: 'PostgreSQL', 6379: 'Redis',
        8080: 'HTTP-Alt', 9000: 'Jenkins'
    }
    return service_names.get(port, 'Unknown')


def _handle_encrypt_command(args: Any, formatter: OutputFormatter, error_handler: ErrorHandler) -> int:
    """Encrypt ``args.file`` into a sibling ``<name>.encrypted`` JSON file.

    The input is read as bytes and decoded as UTF-8, so non-UTF-8 (binary)
    files fail and are reported through *error_handler*. The value returned
    by ``encrypt_data`` is written with ``json.dump``; the original file is
    left in place.

    Returns:
        ``SUCCESS``, ``GENERAL_ERROR`` if security components are unavailable,
        ``FILE_NOT_FOUND`` if the input is missing, or
        ``error_handler.handle_error(e)`` on exceptions.
    """
    formatter.loading("Encrypting file...")

    try:
        if not SECURITY_AVAILABLE:
            formatter.error("Security components not available")
            return ErrorHandler.GENERAL_ERROR

        source = Path(args.file)
        if not source.exists():
            formatter.error(f"File not found: {source}")
            return ErrorHandler.FILE_NOT_FOUND

        plaintext = source.read_bytes().decode('utf-8')
        payload = encrypt_data(plaintext)

        target = source.with_suffix(source.suffix + '.encrypted')
        with open(target, 'w') as handle:
            json.dump(payload, handle, indent=2)

        formatter.success("File encrypted successfully")
        for label, location in (("Original", source), ("Encrypted", target)):
            formatter.info(f"{label}: {location}")

        return ErrorHandler.SUCCESS

    except Exception as e:
        return error_handler.handle_error(e)


def _handle_decrypt_command(args: Any, formatter: OutputFormatter, error_handler: ErrorHandler) -> int:
    """Decrypt a JSON file produced by the encrypt command.

    The output path is the input path with its last suffix removed
    (``notes.txt.encrypted`` -> ``notes.txt``, ``a.tar.gz.encrypted`` ->
    ``a.tar.gz``). If the input has no suffix, ``.decrypted`` is appended
    instead, so the input file is never overwritten. The plaintext is
    written as UTF-8 without newline translation, mirroring how the encrypt
    command decodes the original bytes, so the round trip is byte-exact.

    Returns:
        ``SUCCESS``, ``GENERAL_ERROR`` if security components are unavailable,
        ``FILE_NOT_FOUND`` if the input is missing, or
        ``error_handler.handle_error(e)`` on exceptions.
    """
    formatter.loading("Decrypting file...")

    try:
        if not SECURITY_AVAILABLE:
            formatter.error("Security components not available")
            return ErrorHandler.GENERAL_ERROR

        file_path = Path(args.file)
        if not file_path.exists():
            formatter.error(f"File not found: {file_path}")
            return ErrorHandler.FILE_NOT_FOUND

        with open(file_path, 'r', encoding='utf-8') as f:
            encrypted_data = json.load(f)

        decrypted_content = decrypt_data(encrypted_data)

        decrypted_path = file_path.with_suffix('')
        if decrypted_path == file_path:
            # No suffix to strip: don't clobber the encrypted input.
            decrypted_path = file_path.with_name(file_path.name + '.decrypted')

        with open(decrypted_path, 'w', encoding='utf-8', newline='') as f:
            f.write(decrypted_content)

        formatter.success("File decrypted successfully")
        formatter.info(f"Encrypted: {file_path}")
        formatter.info(f"Decrypted: {decrypted_path}")

        return ErrorHandler.SUCCESS

    except Exception as e:
        return error_handler.handle_error(e)


def _handle_audit_command(args: Any, formatter: OutputFormatter, error_handler: ErrorHandler) -> int:
    """Print security statistics and up to ten recent security events.

    ``args.hours`` (default 24, also used when the attribute is None) is the
    window passed to ``get_security_events``. The event-count line always
    shows ``events_last_24h`` from ``get_security_stats()``, so it is
    labelled as a 24-hour figure regardless of ``--hours``.

    Returns:
        ``SUCCESS``, ``GENERAL_ERROR`` if security components are unavailable,
        or ``error_handler.handle_error(e)`` on exceptions (including missing
        keys in the stats/event records).
    """
    formatter.loading("Performing security audit...")

    try:
        if not SECURITY_AVAILABLE:
            formatter.error("Security components not available")
            return ErrorHandler.GENERAL_ERROR

        security_stats = get_security_stats()

        # Fall back to 24 both when the attribute is absent and when it is
        # present but unset (None).
        hours = getattr(args, 'hours', None)
        if hours is None:
            hours = 24
        security_events = get_security_events(hours)

        formatter.success("Security Audit Complete")
        formatter.info(f"Total events: {security_stats['total_events']}")
        # The stats API exposes a fixed 24h count, so label it as such.
        formatter.info(f"Events last 24h: {security_stats['events_last_24h']}")
        formatter.info(f"Active policies: {security_stats['active_policies']}")

        if security_stats.get('threat_distribution'):
            formatter.info("Threat Distribution:")
            for threat_level, count in security_stats['threat_distribution'].items():
                formatter.info(f"  {threat_level}: {count}")

        if security_events:
            formatter.info(f"Recent Security Events (last {hours}h):")
            # First ten entries as returned. Whether these are the newest
            # depends on get_security_events' ordering — TODO confirm.
            for event in security_events[:10]:
                timestamp = event['timestamp']
                event_type = event['event_type']
                threat_level = event['threat_level']
                formatter.info(f"  [{timestamp}] {event_type} ({threat_level})")

        return ErrorHandler.SUCCESS

    except Exception as e:
        return error_handler.handle_error(e)


def _handle_hash_command(args: Any, formatter: OutputFormatter, error_handler: ErrorHandler, cli: Any = None) -> int:
    """Print the hex digest of a file or a literal string.

    ``args.file`` takes precedence over ``args.string``; strings are hashed
    as UTF-8. ``args.algorithm`` is matched case-insensitively against the
    fixed-length algorithms hashlib guarantees (md5, sha1, the sha2 family,
    sha3_*, blake2b/blake2s). Files are streamed in 64 KiB chunks so large
    inputs are not loaded into memory. When *cli* has ``json_output`` set, a
    JSON summary is printed as well.

    Returns:
        ``SUCCESS``, ``FILE_NOT_FOUND`` for a missing file, ``INVALID_ARGS``
        for a missing input or unsupported algorithm, or
        ``error_handler.handle_error(e)`` on exceptions.
    """
    formatter.loading("Generating hash...")

    supported = (
        'md5', 'sha1', 'sha224', 'sha256', 'sha384', 'sha512',
        'sha3_224', 'sha3_256', 'sha3_384', 'sha3_512', 'blake2b', 'blake2s',
    )
    chunk_size = 64 * 1024

    try:
        file_arg = getattr(args, 'file', None)
        if file_arg:
            file_path = Path(file_arg)
            if not file_path.exists():
                formatter.error(f"File not found: {file_path}")
                return ErrorHandler.FILE_NOT_FOUND
        elif getattr(args, 'string', None) is None:
            formatter.error("Either a file or a string to hash is required")
            return ErrorHandler.INVALID_ARGS

        algorithm = (args.algorithm or '').lower()
        if algorithm not in supported:
            formatter.error(f"Unsupported algorithm: {algorithm}")
            return ErrorHandler.INVALID_ARGS

        hasher = hashlib.new(algorithm)
        if file_arg:
            with open(file_path, 'rb') as f:
                for chunk in iter(lambda: f.read(chunk_size), b''):
                    hasher.update(chunk)
        else:
            hasher.update(args.string.encode('utf-8'))
        hash_value = hasher.hexdigest()

        formatter.success(f"{algorithm.upper()} hash generated")
        formatter.info(f"Hash: {hash_value}")

        if cli and cli.json_output:
            formatter.print_json({
                'algorithm': algorithm,
                'hash': hash_value,
                'input_type': 'file' if file_arg else 'string'
            })

        return ErrorHandler.SUCCESS

    except Exception as e:
        return error_handler.handle_error(e)


# Global security session manager.
# NOTE: instantiated at import time, so merely importing this module creates
# ~/.tsk if it is missing and reads ~/.tsk/security_session.json. Its
# in-memory session_data reflects the file as it was at import.
security_session_manager = SecuritySessionManager()


def check_authentication() -> bool:
    """Return True if an active, unexpired session is stored on disk.

    The shared ``security_session_manager`` is created at import time, while
    the login/logout/refresh handlers write the session file through their
    own ``SecuritySessionManager`` instances. The file is therefore reloaded
    first, so those changes are seen instead of a stale in-memory copy.
    """
    security_session_manager.session_data = security_session_manager._load_session()
    return security_session_manager.get_session() is not None


def require_authentication():
    """Build a decorator that blocks a command unless a valid session exists.

    This is a decorator *factory*: apply it as ``@require_authentication()``.
    Without an active session the wrapped function is not called; a hint is
    printed to stdout and 1 is returned instead. ``functools.wraps`` keeps
    the wrapped function's name and docstring (relevant for help output and
    introspection).
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if not check_authentication():
                print("❌ Authentication required")
                print("Run 'tsk security auth login <username> <password>' to authenticate")
                return 1
            return func(*args, **kwargs)
        return wrapper
    return decorator
    return decorator 