"""
Advanced Computer Vision Engine
Image recognition, object detection, face recognition, and image processing.
"""

import asyncio
import base64
import io
import json
import logging
import os
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np

# Core image processing
try:
    from PIL import Image, ImageDraw, ImageFont, ImageEnhance, ImageFilter
    PIL_AVAILABLE = True
except ImportError:
    PIL_AVAILABLE = False
    print("PIL/Pillow not available. Image processing will be limited.")

# OpenCV for computer vision
try:
    import cv2
    OPENCV_AVAILABLE = True
except ImportError:
    OPENCV_AVAILABLE = False
    print("OpenCV not available. Advanced computer vision features will be limited.")

# Deep learning frameworks
try:
    import torch
    import torchvision.transforms as transforms
    from torchvision import models
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
    print("PyTorch not available. Deep learning features will be limited.")

# Face recognition
try:
    import face_recognition
    FACE_RECOGNITION_AVAILABLE = True
except ImportError:
    FACE_RECOGNITION_AVAILABLE = False
    print("face_recognition library not available. Face recognition will be limited.")

# Module-level logger shared by every class below. basicConfig() only
# installs a handler if the root logger has none yet, so an application
# that configured logging first keeps its own setup.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

@dataclass()
class DetectionResult:
    """One labelled region of an image.

    Used both for localized detections and for whole-image classifications,
    in which case ``bounding_box`` spans the full frame.
    """

    # Label of the detected class.
    class_name: str
    # Score associated with ``class_name``.
    confidence: float
    # Box corners in pixels, ordered (x1, y1, x2, y2).
    bounding_box: Tuple[int, int, int, int]
    # Box centre in pixels, ordered (x, y).
    center: Tuple[int, int]

@dataclass()
class FaceRecognitionResult:
    """Faces found in one image, plus any identity matches.

    When produced by ``FaceRecognitionProcessor.detect_faces`` the four lists
    are index-aligned: element ``i`` of each list refers to the same face.
    """

    # Face boxes exactly as returned by face_recognition.face_locations.
    face_locations: List[Tuple[int, int, int, int]]
    # One encoding vector per detected face.
    face_encodings: List[np.ndarray]
    # Per-face mapping of landmark name -> list of (x, y) points.
    face_landmarks: List[Dict[str, List[Tuple[int, int]]]]
    # Per-face match info, e.g. {'name': ..., 'confidence': ...}.
    known_faces: List[Dict[str, Any]]

@dataclass()
class ImageAnalysisResult:
    """Everything extracted from a single image by one analysis pass.

    The first five fields are always filled in; the optional analyses stay
    ``None`` when they were not requested or not produced.
    """

    # Path of the analysed file, as a string.
    image_path: str
    # Pixel size as (width, height).
    dimensions: Tuple[int, int]
    # Size of the file on disk, in bytes.
    file_size: int
    # Image format label.
    format: str
    # Labels/objects found in the image (possibly empty).
    detected_objects: List[DetectionResult]
    # Face detection/recognition output, when requested.
    faces: Optional[FaceRecognitionResult] = None
    # Quality metrics keyed by metric name, when requested.
    image_quality: Optional[Dict[str, float]] = None
    # Most frequent colours as (r, g, b) tuples.
    dominant_colors: Optional[List[Tuple[int, int, int]]] = None
    # Text lines found in the image; not filled in by the analysis code in this module.
    text_content: Optional[List[str]] = None

class BasicImageProcessor:
    """Basic image processing with Pillow, without deep learning.

    Every method works on ``PIL.Image.Image`` objects. When OpenCV is
    installed it is used only for the sharpness metric in
    :meth:`calculate_image_quality`.
    """

    def __init__(self):
        # File extensions this processor is intended for. Informational only:
        # load_image() does not check it and relies on Pillow's detection.
        self.supported_formats = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.gif'}

    def load_image(self, image_path: Union[str, Path]) -> Optional[Image.Image]:
        """Load an image file and return it as an RGB image.

        Args:
            image_path: Path of the file to read.

        Returns:
            A fully-loaded RGB image, or ``None`` if Pillow is unavailable or
            the file cannot be opened/decoded (the error is logged).
        """
        if not PIL_AVAILABLE:
            logger.error("PIL not available for image loading")
            return None

        try:
            # The context manager closes the file handle. convert() reads all
            # pixel data, so the returned image does not depend on the file.
            with Image.open(image_path) as image:
                return image.convert('RGB')
        except Exception as e:
            # Deliberate best-effort: callers treat None as "could not load".
            logger.error(f"Error loading image {image_path}: {str(e)}")
            return None

    def resize_image(self, image: Image.Image, size: Tuple[int, int],
                     maintain_aspect: bool = True) -> Image.Image:
        """Resize an image, optionally preserving its aspect ratio.

        Args:
            image: Source image. It is never modified.
            size: Target ``(width, height)``.
            maintain_aspect: If True, shrink the image to fit within ``size``
                keeping its aspect ratio (Pillow's ``thumbnail`` never
                enlarges). If False, resize to exactly ``size``.

        Returns:
            A new, resized image.
        """
        target = tuple(size)
        if maintain_aspect:
            # thumbnail() works in place; operate on a copy so the caller's
            # image is not silently shrunk as a side effect.
            resized = image.copy()
            resized.thumbnail(target, Image.Resampling.LANCZOS)
            return resized
        return image.resize(target, Image.Resampling.LANCZOS)

    def enhance_image(self, image: Image.Image,
                      brightness: float = 1.0,
                      contrast: float = 1.0,
                      saturation: float = 1.0,
                      sharpness: float = 1.0) -> Image.Image:
        """Adjust brightness, contrast, saturation and sharpness.

        Each factor of exactly 1.0 leaves that property untouched. Factors are
        applied in the order brightness, contrast, saturation, sharpness.

        Returns:
            The enhanced image (the input itself if every factor is 1.0).
        """
        adjustments = (
            (ImageEnhance.Brightness, brightness),
            (ImageEnhance.Contrast, contrast),
            (ImageEnhance.Color, saturation),
            (ImageEnhance.Sharpness, sharpness),
        )
        for enhancer_cls, factor in adjustments:
            if factor != 1.0:
                image = enhancer_cls(image).enhance(factor)
        return image

    def apply_filters(self, image: Image.Image, filter_type: str) -> Image.Image:
        """Apply a named Pillow filter.

        Supported names: 'blur', 'sharpen', 'edge_enhance', 'emboss',
        'find_edges', 'smooth', 'contour' and 'detail'. An unknown name
        returns the image unchanged and logs a warning.
        """
        # Built per call rather than at class level because ImageFilter only
        # exists when Pillow imported successfully.
        filters = {
            'blur': ImageFilter.BLUR,
            'sharpen': ImageFilter.SHARPEN,
            'edge_enhance': ImageFilter.EDGE_ENHANCE,
            'emboss': ImageFilter.EMBOSS,
            'find_edges': ImageFilter.FIND_EDGES,
            'smooth': ImageFilter.SMOOTH,
            'contour': ImageFilter.CONTOUR,
            'detail': ImageFilter.DETAIL,
        }
        selected = filters.get(filter_type)
        if selected is None:
            logger.warning(f"Unknown filter type {filter_type!r}; returning image unchanged")
            return image
        return image.filter(selected)

    def get_dominant_colors(self, image: Image.Image, num_colors: int = 5) -> List[Tuple[int, int, int]]:
        """Return the most frequent RGB colours of an image.

        The image is downsampled to 150x150 first for speed. The colours are
        therefore exact pixel values of that downsampled copy, not cluster
        centres of the original.

        Args:
            image: Source image in any mode; it is converted to RGB internally.
            num_colors: Number of colours wanted; values <= 0 yield ``[]``.

        Returns:
            Up to ``num_colors`` ``(r, g, b)`` tuples, most frequent first.
        """
        if num_colors <= 0:
            return []

        thumb_side = 150
        # Convert first so getcolors() yields (r, g, b) tuples for every
        # input mode (e.g. 'L' would give ints, 'RGBA' 4-tuples).
        small_image = image.convert('RGB').resize((thumb_side, thumb_side))
        # An image of thumb_side**2 pixels has at most that many distinct
        # colours, so this limit ensures getcolors() does not return None.
        colors = small_image.getcolors(maxcolors=thumb_side * thumb_side)
        if not colors:
            return []

        # Entries are (count, colour); sort is stable for equal counts.
        colors.sort(key=lambda entry: entry[0], reverse=True)
        return [rgb for _count, rgb in colors[:num_colors]]

    def calculate_image_quality(self, image: Image.Image) -> Dict[str, float]:
        """Compute simple heuristic quality metrics.

        Returns:
            A dict with:
              * ``'sharpness'``: Laplacian variance of the grayscale image
                (OpenCV) or mean squared gradient (NumPy fallback), divided
                by 1000 and capped at 1.0. The two back-ends are not
                numerically equivalent.
              * ``'brightness'``: mean RGB intensity scaled to [0, 1].
              * ``'contrast'``: standard deviation of RGB intensities / 255.
        """
        # Normalise to RGB so grayscale/RGBA/palette inputs give the same
        # (H, W, 3) uint8 array that both branches below expect, and the alpha
        # channel does not skew brightness/contrast.
        img_array = np.array(image.convert('RGB'))

        if OPENCV_AVAILABLE:
            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
            raw_sharpness = cv2.Laplacian(gray, cv2.CV_64F).var()
        else:
            gray = img_array.mean(axis=2)
            if min(gray.shape) < 2:
                # np.gradient needs at least two samples along each axis.
                raw_sharpness = 0.0
            else:
                dx = np.gradient(gray, axis=1)
                dy = np.gradient(gray, axis=0)
                raw_sharpness = np.mean(dx ** 2 + dy ** 2)

        # Map the unbounded sharpness score into [0, 1].
        sharpness = min(raw_sharpness / 1000, 1.0)
        brightness = img_array.mean() / 255.0
        contrast = img_array.std() / 255.0

        return {
            'sharpness': float(sharpness),
            'brightness': float(brightness),
            'contrast': float(contrast)
        }

class OpenCVProcessor:
    """Image processing helpers built on OpenCV.

    Images are NumPy arrays. Colour arrays are assumed to be in RGB channel
    order (as produced by ``np.array(pil_image)``); 2-D grayscale, single-
    channel and 4-channel RGBA arrays are accepted too. Methods that need
    OpenCV return a neutral value (input unchanged or an empty list) when it
    is not installed.
    """

    def __init__(self):
        self.face_cascade = None
        self.eye_cascade = None
        self._load_cascades()

    @staticmethod
    def _to_gray(image: np.ndarray) -> np.ndarray:
        """Return a single-channel version of a gray, RGB or RGBA array."""
        if image.ndim == 2:
            return image
        channels = image.shape[2]
        if channels == 1:
            # OpenCV wants contiguous buffers; a channel slice is strided.
            return np.ascontiguousarray(image[:, :, 0])
        if channels == 4:
            return cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
        return cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    @staticmethod
    def _load_cascade(cascade_dir: str, filename: str):
        """Load one Haar cascade file; return None if missing or unusable."""
        path = os.path.join(cascade_dir, filename)
        if not os.path.exists(path):
            return None
        try:
            classifier = cv2.CascadeClassifier(path)
        except Exception as e:
            logger.warning(f"Could not load OpenCV cascades: {str(e)}")
            return None
        # A corrupt/unreadable file does not raise here; it produces an empty
        # classifier that detectMultiScale would later reject.
        if classifier.empty():
            logger.warning(f"Could not load OpenCV cascade: {path}")
            return None
        return classifier

    def _load_cascades(self):
        """Load the frontal-face and eye Haar cascades, if they can be found."""
        if not OPENCV_AVAILABLE:
            return

        # cv2.data.haarcascades is present in some builds (e.g. the pip
        # wheels). When it is missing, '' makes os.path.join fall back to a
        # path relative to the working directory.
        cascade_dir = getattr(getattr(cv2, 'data', None), 'haarcascades', '')
        self.face_cascade = self._load_cascade(cascade_dir, 'haarcascade_frontalface_default.xml')
        self.eye_cascade = self._load_cascade(cascade_dir, 'haarcascade_eye.xml')

    def detect_faces_opencv(self, image: np.ndarray,
                            scale_factor: float = 1.1,
                            min_neighbors: int = 4) -> List[Tuple[int, int, int, int]]:
        """Detect faces with the frontal-face Haar cascade.

        Args:
            image: Gray, RGB or RGBA array.
            scale_factor: Image pyramid scale step passed to detectMultiScale.
            min_neighbors: Neighbour threshold passed to detectMultiScale.

        Returns:
            ``(x1, y1, x2, y2)`` boxes as plain ints; ``[]`` if no cascade is
            loaded or no face is found.
        """
        if self.face_cascade is None:
            return []

        gray = self._to_gray(image)
        faces = self.face_cascade.detectMultiScale(gray, scale_factor, min_neighbors)
        # detectMultiScale yields (x, y, w, h) as numpy ints; convert to
        # corner form with builtin ints.
        return [(int(x), int(y), int(x + w), int(y + h)) for (x, y, w, h) in faces]

    def detect_edges(self, image: np.ndarray, low_threshold: int = 50,
                     high_threshold: int = 150) -> np.ndarray:
        """Run Canny edge detection and return the edge map as an RGB array."""
        if not OPENCV_AVAILABLE:
            return image

        edges = cv2.Canny(self._to_gray(image), low_threshold, high_threshold)
        return cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)

    def detect_contours(self, image: np.ndarray, threshold: int = 127) -> List[np.ndarray]:
        """Find external contours after binary thresholding at ``threshold``."""
        if not OPENCV_AVAILABLE:
            return []

        _, thresh = cv2.threshold(self._to_gray(image), threshold, 255, cv2.THRESH_BINARY)
        # OpenCV 3 returns (image, contours, hierarchy), OpenCV 4 returns
        # (contours, hierarchy); [-2] is the contours in both.
        contours = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
        return list(contours)

    def apply_morphological_operations(self, image: np.ndarray,
                                       operation: str = 'opening',
                                       kernel_size: int = 5) -> np.ndarray:
        """Apply a morphological operation to the grayscale image.

        Args:
            image: Gray, RGB or RGBA array.
            operation: 'erosion', 'dilation', 'opening' or 'closing'; any
                other value returns the plain grayscale image.
            kernel_size: Side length of the square all-ones kernel.

        Returns:
            The result, expanded back to a 3-channel RGB array.
        """
        if not OPENCV_AVAILABLE:
            return image

        kernel = np.ones((kernel_size, kernel_size), np.uint8)
        gray = self._to_gray(image)

        if operation == 'erosion':
            result = cv2.erode(gray, kernel, iterations=1)
        elif operation == 'dilation':
            result = cv2.dilate(gray, kernel, iterations=1)
        elif operation == 'opening':
            result = cv2.morphologyEx(gray, cv2.MORPH_OPEN, kernel)
        elif operation == 'closing':
            result = cv2.morphologyEx(gray, cv2.MORPH_CLOSE, kernel)
        else:
            result = gray

        return cv2.cvtColor(result, cv2.COLOR_GRAY2RGB)

class DeepLearningProcessor:
    """Image classification with an ImageNet-pretrained torchvision ResNet-50."""

    def __init__(self):
        self.models = {}
        # Filled in by _load_models() when PyTorch is available.
        self.transform = None
        self.imagenet_classes: List[str] = []
        self._load_models()

    def _load_models(self):
        """Load the pretrained ResNet-50, its preprocessing and class names.

        Failures are logged and leave the processor without a model, in which
        case classify_image() returns an empty list.
        """
        if not TORCH_AVAILABLE:
            return

        try:
            weights_enum = getattr(models, 'ResNet50_Weights', None)
            if weights_enum is not None:
                # torchvision >= 0.13: explicit weights API. IMAGENET1K_V1 is
                # the checkpoint that the deprecated pretrained=True selects.
                model = models.resnet50(weights=weights_enum.IMAGENET1K_V1)
            else:
                model = models.resnet50(pretrained=True)
            model.eval()
            self.models['resnet'] = model

            # Standard ImageNet preprocessing.
            self.transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])

            self.imagenet_classes = self._load_imagenet_classes()

        except Exception as e:
            logger.error(f"Error loading deep learning models: {str(e)}")

    def _load_imagenet_classes(self) -> List[str]:
        """Return ImageNet-1k category names, index-aligned with ResNet outputs.

        The names come from the metadata of torchvision's weights enum
        (torchvision >= 0.13). If that is unavailable an empty list is
        returned, and classify_image() labels predictions ``class_<index>``
        instead of guessing (a short hand-written list would mislabel
        indices).
        """
        if not TORCH_AVAILABLE:
            return []
        weights_enum = getattr(models, 'ResNet50_Weights', None)
        if weights_enum is None:
            return []
        return list(weights_enum.IMAGENET1K_V1.meta.get('categories', []))

    def classify_image(self, image: Image.Image, top_k: int = 5) -> List[Dict[str, Any]]:
        """Classify an image with the ResNet model.

        Args:
            image: PIL image in any mode; it is converted to RGB first.
            top_k: Number of predictions to return; capped at the number of
                model classes. Values <= 0 yield ``[]``.

        Returns:
            Dicts with keys 'class_name', 'confidence' (softmax probability)
            and 'class_index', highest confidence first. Empty if no model is
            loaded or inference fails (the error is logged).
        """
        if 'resnet' not in self.models or self.transform is None:
            return []
        if top_k <= 0:
            return []

        try:
            # The 3-channel Normalize would fail on grayscale/RGBA input.
            input_tensor = self.transform(image.convert('RGB')).unsqueeze(0)

            with torch.no_grad():
                outputs = self.models['resnet'](input_tensor)
                probabilities = torch.nn.functional.softmax(outputs[0], dim=0)

            # topk raises if k exceeds the number of classes.
            k = min(top_k, probabilities.numel())
            top_predictions = torch.topk(probabilities, k)

            results = []
            for confidence, class_idx in zip(top_predictions.values.tolist(),
                                             top_predictions.indices.tolist()):
                if class_idx < len(self.imagenet_classes):
                    class_name = self.imagenet_classes[class_idx]
                else:
                    class_name = f"class_{class_idx}"

                results.append({
                    'class_name': class_name,
                    'confidence': confidence,
                    'class_index': class_idx
                })

            return results

        except Exception as e:
            logger.error(f"Error in image classification: {str(e)}")
            return []

class FaceRecognitionProcessor:
    """Face detection, encoding and identification via ``face_recognition``."""

    def __init__(self):
        # Maps a person's name to one face encoding; re-adding a name
        # replaces the previous encoding.
        self.known_faces = {}

    @staticmethod
    def _as_array(image) -> np.ndarray:
        """Return *image* as an RGB NumPy array, converting PIL images."""
        if PIL_AVAILABLE and isinstance(image, Image.Image):
            return np.array(image.convert('RGB'))
        return image

    def detect_faces(self, image: np.ndarray) -> FaceRecognitionResult:
        """Detect faces, compute encodings/landmarks and match known faces.

        Args:
            image: RGB image as a NumPy array; a PIL image is also accepted.

        Returns:
            Index-aligned locations, encodings, landmarks and match info.
            All lists are empty if the library is unavailable or an error
            occurs (the error is logged).
        """
        if not FACE_RECOGNITION_AVAILABLE:
            return FaceRecognitionResult([], [], [], [])

        try:
            image = self._as_array(image)

            face_locations = face_recognition.face_locations(image)
            # Passing the locations avoids re-detecting faces in each call.
            face_encodings = face_recognition.face_encodings(image, face_locations)
            face_landmarks = face_recognition.face_landmarks(image, face_locations)
            known_faces = [self._compare_with_known_faces(encoding)
                           for encoding in face_encodings]

            return FaceRecognitionResult(
                face_locations=face_locations,
                face_encodings=face_encodings,
                face_landmarks=face_landmarks,
                known_faces=known_faces
            )

        except Exception as e:
            logger.error(f"Error in face recognition: {str(e)}")
            return FaceRecognitionResult([], [], [], [])

    def _compare_with_known_faces(self, face_encoding: np.ndarray,
                                  tolerance: float = 0.6) -> Dict[str, Any]:
        """Identify a face encoding against the known-face database.

        Returns:
            ``{'name': <best match>, 'confidence': 1 - distance}`` when the
            closest known face is within ``tolerance``; otherwise
            ``{'name': 'Unknown', 'confidence': 0.0}``.
        """
        if not self.known_faces:
            return {'name': 'Unknown', 'confidence': 0.0}

        known_names = list(self.known_faces)
        known_encodings = list(self.known_faces.values())

        # face_recognition.compare_faces is face_distance(...) <= tolerance,
        # so compute the distances once and test the closest match directly.
        face_distances = face_recognition.face_distance(known_encodings, face_encoding)
        best_match_index = int(np.argmin(face_distances))
        best_distance = float(face_distances[best_match_index])

        if best_distance <= tolerance:
            return {'name': known_names[best_match_index], 'confidence': 1.0 - best_distance}

        return {'name': 'Unknown', 'confidence': 0.0}

    def add_known_face(self, image: np.ndarray, name: str) -> bool:
        """Register the first face found in *image* under *name*.

        Args:
            image: RGB image as a NumPy array; a PIL image is also accepted.
            name: Identity label; an existing entry with this name is replaced.

        Returns:
            True if a face was found and stored, False otherwise (logged).
        """
        if not FACE_RECOGNITION_AVAILABLE:
            logger.error(f"Cannot add known face {name}: face_recognition library not available")
            return False

        try:
            face_encodings = face_recognition.face_encodings(self._as_array(image))

            if not face_encodings:
                logger.warning(f"No face found in image for {name}")
                return False

            if len(face_encodings) > 1:
                logger.warning(f"Multiple faces found in image for {name}; using the first one")
            self.known_faces[name] = face_encodings[0]
            logger.info(f"Added known face: {name}")
            return True

        except Exception as e:
            logger.error(f"Error adding known face {name}: {str(e)}")
            return False

class ComputerVisionEngine:
    """
    Comprehensive Computer Vision Engine.
    Supports image recognition, object detection, face recognition, and image processing.

    The underlying work (Pillow / OpenCV / PyTorch) is synchronous; the async
    entry points run it in worker threads so the event loop stays responsive.
    """

    def __init__(self, use_deep_learning: bool = True):
        self.use_deep_learning = use_deep_learning
        self.basic_processor = BasicImageProcessor()
        self.opencv_processor = OpenCVProcessor() if OPENCV_AVAILABLE else None
        self.dl_processor = DeepLearningProcessor() if use_deep_learning else None
        self.face_processor = FaceRecognitionProcessor()

        self.processing_stats = {
            'images_processed': 0,
            'total_processing_time': 0.0,
            'average_processing_time': 0.0
        }

    @staticmethod
    def _detect_format(image_path: Path) -> str:
        """Return Pillow's format name (e.g. 'JPEG'), or the lowercase suffix."""
        # load_image() returns a converted copy whose .format is always None,
        # so re-open the file (header only) to read the real format.
        try:
            with Image.open(image_path) as probe:
                if probe.format:
                    return probe.format
        except OSError:
            pass
        return image_path.suffix.lower()

    def _record_processing_time(self, processing_time: float) -> None:
        """Fold one analysis duration (seconds) into processing_stats."""
        stats = self.processing_stats
        stats['images_processed'] += 1
        stats['total_processing_time'] += processing_time
        stats['average_processing_time'] = (
            stats['total_processing_time'] / stats['images_processed']
        )

    def _run_analysis(self, image_path: Union[str, Path],
                      detect_faces: bool,
                      detect_objects: bool,
                      analyze_quality: bool) -> Tuple[ImageAnalysisResult, float]:
        """Analyse one image synchronously.

        Returns:
            ``(result, processing_time_seconds)``. processing_stats is not
            touched here so this can run in worker threads without racing.

        Raises:
            ValueError: If the image cannot be loaded.
        """
        start_time = datetime.now()
        image_path = Path(image_path)

        logger.info(f"Analyzing image: {image_path}")

        image = self.basic_processor.load_image(image_path)
        if image is None:
            raise ValueError(f"Could not load image: {image_path}")

        width, height = image.size
        file_size = image_path.stat().st_size
        image_format = self._detect_format(image_path)

        # NumPy view for the face-recognition library.
        img_array = np.array(image)

        detected_objects = []
        if detect_objects and self.dl_processor is not None:
            # Whole-image classification: each label is attributed to the
            # full frame rather than a localized box.
            detected_objects = [
                DetectionResult(
                    class_name=classification['class_name'],
                    confidence=classification['confidence'],
                    bounding_box=(0, 0, width, height),
                    center=(width // 2, height // 2)
                )
                for classification in self.dl_processor.classify_image(image)
            ]

        faces = self.face_processor.detect_faces(img_array) if detect_faces else None
        image_quality = (self.basic_processor.calculate_image_quality(image)
                         if analyze_quality else None)
        dominant_colors = self.basic_processor.get_dominant_colors(image)

        result = ImageAnalysisResult(
            image_path=str(image_path),
            dimensions=(width, height),
            file_size=file_size,
            format=image_format,
            detected_objects=detected_objects,
            faces=faces,
            image_quality=image_quality,
            dominant_colors=dominant_colors
        )

        processing_time = (datetime.now() - start_time).total_seconds()
        return result, processing_time

    async def analyze_image(self, image_path: Union[str, Path],
                            detect_faces: bool = True,
                            detect_objects: bool = True,
                            analyze_quality: bool = True) -> ImageAnalysisResult:
        """Comprehensive analysis of one image.

        The blocking work runs in the event loop's default executor.

        Raises:
            ValueError: If the image cannot be loaded.
        """
        loop = asyncio.get_running_loop()
        result, processing_time = await loop.run_in_executor(
            None, self._run_analysis, image_path, detect_faces, detect_objects, analyze_quality
        )
        # Stats are updated back on the awaiting coroutine's thread, never
        # from the worker, so concurrent analyses do not interleave updates.
        self._record_processing_time(processing_time)
        logger.info(f"Image analysis complete in {processing_time:.2f}s")
        return result

    async def batch_analyze_images(self, image_paths: List[Union[str, Path]],
                                   max_workers: int = 4) -> List[ImageAnalysisResult]:
        """Analyse several images concurrently on a bounded thread pool.

        Images that fail are logged and omitted, so the returned list may be
        shorter than ``image_paths``; successful results keep input order.
        """
        logger.info(f"Starting batch analysis of {len(image_paths)} images")

        loop = asyncio.get_running_loop()
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                loop.run_in_executor(executor, self._run_analysis, path, True, True, True)
                for path in image_paths
            ]
            # Awaiting (instead of blocking on .result()) keeps the loop free;
            # return_exceptions lets one bad image not abort the batch.
            outcomes = await asyncio.gather(*futures, return_exceptions=True)

        results = []
        for outcome in outcomes:
            if isinstance(outcome, BaseException):
                logger.error(f"Error in batch image analysis: {str(outcome)}")
                continue
            result, processing_time = outcome
            self._record_processing_time(processing_time)
            logger.info(f"Image analysis complete in {processing_time:.2f}s")
            results.append(result)

        logger.info(f"Batch analysis complete. Processed {len(results)} images")
        return results

    def process_image(self, image: Union[str, Path, Image.Image],
                      operations: List[Dict[str, Any]]) -> Image.Image:
        """Apply a sequence of operations to an image.

        Args:
            image: A path to load, or a PIL image. A PIL image is copied
                first, so the caller's object is never modified.
            operations: Dicts like ``{'type': ..., 'params': {...}}``.
                Supported types: 'resize' (``size``, ``maintain_aspect``),
                'enhance' (keyword args of ``enhance_image``), 'filter'
                (``filter_type``), and, when OpenCV is available, 'edges' and
                'morphological' (params forwarded as keyword args).
                Unknown or unavailable operations are skipped with a warning.

        Returns:
            The processed image.

        Raises:
            ValueError: If the image cannot be loaded.
        """
        if isinstance(image, (str, Path)):
            image = self.basic_processor.load_image(image)
        elif image is not None:
            # Some operations (thumbnail-based resize) work in place.
            image = image.copy()

        if image is None:
            raise ValueError("Could not load image")

        for operation in operations or []:
            op_type = operation.get('type')
            params = operation.get('params', {})

            if op_type == 'resize':
                size = params.get('size', (224, 224))
                maintain_aspect = params.get('maintain_aspect', True)
                image = self.basic_processor.resize_image(image, size, maintain_aspect)

            elif op_type == 'enhance':
                image = self.basic_processor.enhance_image(image, **params)

            elif op_type == 'filter':
                filter_type = params.get('filter_type', 'blur')
                image = self.basic_processor.apply_filters(image, filter_type)

            elif op_type in ('edges', 'morphological'):
                if self.opencv_processor is None:
                    logger.warning(f"Skipping '{op_type}' operation: OpenCV not available")
                    continue
                img_array = np.array(image)
                if op_type == 'edges':
                    processed = self.opencv_processor.detect_edges(img_array, **params)
                else:
                    processed = self.opencv_processor.apply_morphological_operations(img_array, **params)
                image = Image.fromarray(processed)

            else:
                logger.warning(f"Skipping unknown image operation type: {op_type!r}")

        return image

    def create_face_database(self, face_images: Dict[str, Union[str, Path, Image.Image]]) -> Dict[str, bool]:
        """Register known faces from paths or PIL images.

        Returns:
            Mapping of name -> whether a face was successfully registered.
        """
        results = {}

        for name, image_source in face_images.items():
            if isinstance(image_source, (str, Path)):
                image = self.basic_processor.load_image(image_source)
                if image is None:
                    results[name] = False
                    continue
                img_array = np.array(image)
            elif PIL_AVAILABLE and isinstance(image_source, Image.Image):
                img_array = np.array(image_source)
            else:
                results[name] = False
                continue

            results[name] = self.face_processor.add_known_face(img_array, name)

        return results

    def get_processing_stats(self) -> Dict[str, Any]:
        """Return a shallow copy of the processing statistics."""
        return self.processing_stats.copy()

# Example usage and testing
async def main():
    """Run an end-to-end demo of the Computer Vision Engine.

    Draws a synthetic test image, analyses it, runs a small processing
    pipeline on it, then prints processing statistics and which optional
    back-ends are installed. Demo files live in a temporary directory that is
    removed afterwards even if a step fails.
    """
    import tempfile  # only the demo needs it

    print("=== Computer Vision Engine Demo ===")

    # Initialize CV engine
    cv_engine = ComputerVisionEngine(use_deep_learning=True)

    if PIL_AVAILABLE:
        # TemporaryDirectory guarantees cleanup without polluting the CWD or
        # swallowing unrelated errors.
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Create a simple test image with a few coloured shapes
            test_image = Image.new('RGB', (300, 200), color='red')
            draw = ImageDraw.Draw(test_image)
            draw.rectangle([50, 50, 150, 150], fill='blue')
            draw.ellipse([200, 100, 280, 180], fill='green')
            test_image_path = os.path.join(tmp_dir, "test_image.jpg")
            test_image.save(test_image_path)

            print("\n1. Testing image analysis...")

            try:
                result = await cv_engine.analyze_image(
                    test_image_path,
                    detect_faces=True,
                    detect_objects=True,
                    analyze_quality=True
                )

                print(f"Image dimensions: {result.dimensions}")
                print(f"File size: {result.file_size} bytes")
                print(f"Format: {result.format}")
                print(f"Detected objects: {len(result.detected_objects)}")
                print(f"Image quality: {result.image_quality}")
                print(f"Dominant colors: {(result.dominant_colors or [])[:3]}")

            except Exception as e:
                print(f"Error in image analysis: {str(e)}")

            print("\n2. Testing image processing...")

            operations = [
                {'type': 'resize', 'params': {'size': (150, 100)}},
                {'type': 'enhance', 'params': {'brightness': 1.2, 'contrast': 1.1}},
                {'type': 'filter', 'params': {'filter_type': 'sharpen'}}
            ]

            processed_path = os.path.join(tmp_dir, "processed_test_image.jpg")
            try:
                processed_image = cv_engine.process_image(test_image, operations)
                processed_image.save(processed_path)
                print(f"Processed image saved as '{processed_path}'")
            except Exception as e:
                print(f"Error in image processing: {str(e)}")

    print("\n3. Processing statistics:")
    stats = cv_engine.get_processing_stats()
    for key, value in stats.items():
        print(f"{key}: {value}")

    print("\n4. Available features:")
    print(f"PIL/Pillow: {PIL_AVAILABLE}")
    print(f"OpenCV: {OPENCV_AVAILABLE}")
    print(f"PyTorch: {TORCH_AVAILABLE}")
    print(f"Face Recognition: {FACE_RECOGNITION_AVAILABLE}")

    print("\n=== Computer Vision Engine Demo Complete ===")

if __name__ == "__main__":
    # Script entry point: build the demo coroutine and drive it on a fresh
    # event loop.
    demo = main()
    asyncio.run(demo)