#!/usr/bin/env python3
"""
Tests for Advanced Error Handling and Recovery System
Goal 7.2 Implementation Tests
"""

import unittest
import time
import asyncio
from unittest.mock import patch, MagicMock
import sys
import os

# Add parent directory to path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

from error_handler import (
    ErrorHandler, CircuitBreaker, RetryHandler, ErrorInfo, ErrorSeverity,
    CircuitState, CircuitBreakerConfig, handle_errors, handle_errors_async
)

class TestErrorHandler(unittest.TestCase):
    """Test cases for ErrorHandler.

    ``setUp`` builds a fresh ``ErrorHandler`` per test so recorded errors,
    custom handlers and health checks do not leak between tests.
    """

    def setUp(self):
        """Create a fresh ErrorHandler for each test."""
        self.error_handler = ErrorHandler()

    def test_error_handling(self):
        """handle_error reports type, message, severity and caller context."""
        try:
            raise ValueError("Test error")
        except ValueError as e:
            error_info = self.error_handler.handle_error(e, {"test": "context"})

        # Assert outside the except block so a failing assertion is not
        # reported as "during handling of the above exception".
        self.assertEqual(error_info.error_type, "ValueError")
        self.assertEqual(error_info.message, "Test error")
        self.assertEqual(error_info.severity, ErrorSeverity.HIGH)
        self.assertIn("test", error_info.context)
        self.assertEqual(error_info.context["test"], "context")

    def test_error_severity_determination(self):
        """Severity is chosen from the exception type."""
        # A generic Exception is expected to be classified as HIGH.
        try:
            raise Exception("Critical error")
        except Exception as e:
            generic_info = self.error_handler.handle_error(e)
        self.assertEqual(generic_info.severity, ErrorSeverity.HIGH)

        # A ConnectionError is expected to be classified as MEDIUM.
        try:
            raise ConnectionError("Connection failed")
        except ConnectionError as e:
            connection_info = self.error_handler.handle_error(e)
        self.assertEqual(connection_info.severity, ErrorSeverity.MEDIUM)

    def test_error_statistics(self):
        """get_error_statistics aggregates counts by type and by severity."""
        for i in range(5):
            try:
                raise ValueError(f"Error {i}")
            except ValueError as e:
                self.error_handler.handle_error(e)

        try:
            raise ConnectionError("Connection error")
        except ConnectionError as e:
            self.error_handler.handle_error(e)

        stats = self.error_handler.get_error_statistics()

        self.assertEqual(stats["total_errors"], 6)
        self.assertEqual(stats["error_types"]["ValueError"], 5)
        self.assertEqual(stats["error_types"]["ConnectionError"], 1)
        self.assertEqual(stats["severity_distribution"]["high"], 5)
        self.assertEqual(stats["severity_distribution"]["medium"], 1)

    def test_custom_error_handler(self):
        """A handler registered for an error type name receives its ErrorInfo."""

        class CustomError(Exception):
            """Real exception type whose class name is ``CustomError``."""

        received = []

        def custom_handler(error_info):
            received.append(error_info)

        self.error_handler.register_error_handler("CustomError", custom_handler)

        # Use a genuine subclass: assigning ``__name__`` on a builtin type
        # such as Exception raises TypeError (builtin types are immutable),
        # and would rename Exception process-wide if it were allowed.
        try:
            raise CustomError("Custom error")
        except CustomError as e:
            self.error_handler.handle_error(e)

        # Check the captured ErrorInfo here rather than inside the callback:
        # an AssertionError raised in the handler might be caught by
        # ErrorHandler and never reach the test runner.
        self.assertTrue(received, "custom handler was not called")
        self.assertEqual(received[0].error_type, "CustomError")

    def test_health_checks(self):
        """check_health reports healthy, unhealthy and raising checks."""
        def healthy_check():
            return True

        def unhealthy_check():
            return False

        def failing_check():
            raise Exception("Health check failed")

        self.error_handler.register_health_check("healthy_service", healthy_check)
        self.error_handler.register_health_check("unhealthy_service", unhealthy_check)
        self.error_handler.register_health_check("failing_service", failing_check)

        health_status = self.error_handler.check_health()

        self.assertEqual(health_status["healthy_service"]["status"], "healthy")
        self.assertEqual(health_status["unhealthy_service"]["status"], "unhealthy")
        # A check that raises is expected to be reported as unhealthy with
        # an "error" entry, not to propagate out of check_health.
        self.assertEqual(health_status["failing_service"]["status"], "unhealthy")
        self.assertIn("error", health_status["failing_service"])

class TestCircuitBreaker(unittest.TestCase):
    """Test cases for CircuitBreaker."""

    # Kept in one place so the trip loops and the recovery sleep always
    # agree with the config passed to the breaker in setUp.
    FAILURE_THRESHOLD = 3
    RECOVERY_TIMEOUT = 1.0

    def setUp(self):
        """Create a breaker that opens after FAILURE_THRESHOLD failures."""
        self.config = CircuitBreakerConfig(
            failure_threshold=self.FAILURE_THRESHOLD,
            recovery_timeout=self.RECOVERY_TIMEOUT,
        )
        self.circuit_breaker = CircuitBreaker("test_circuit", self.config)

    def _trip_circuit(self):
        """Open the breaker with FAILURE_THRESHOLD failing calls.

        Each failure must reach the caller as the original ValueError; a
        breaker that swallowed the error would fail here instead of passing
        silently.
        """
        def failing_function():
            raise ValueError("Test failure")

        for _ in range(self.FAILURE_THRESHOLD):
            with self.assertRaises(ValueError):
                self.circuit_breaker.call(failing_function)

    def test_circuit_breaker_closed_state(self):
        """A successful call passes through and leaves the circuit closed."""
        def successful_function():
            return "success"

        result = self.circuit_breaker.call(successful_function)
        self.assertEqual(result, "success")

        status = self.circuit_breaker.get_status()
        self.assertEqual(status["state"], "closed")
        self.assertEqual(status["failure_count"], 0)

    def test_circuit_breaker_opening(self):
        """The circuit opens at the threshold and then rejects calls."""
        def failing_function():
            raise ValueError("Test failure")

        self._trip_circuit()

        status = self.circuit_breaker.get_status()
        self.assertEqual(status["state"], "open")
        self.assertEqual(status["failure_count"], self.FAILURE_THRESHOLD)

        # If the breaker wrongly ran the function, the ValueError message
        # would lack "OPEN" and the assertIn below would fail.
        with self.assertRaises(Exception) as context:
            self.circuit_breaker.call(failing_function)
        self.assertIn("OPEN", str(context.exception))

    def test_circuit_breaker_half_open_recovery(self):
        """After the timeout the circuit is half-open; a success closes it."""
        def successful_function():
            return "success"

        self._trip_circuit()

        # Sleep slightly past the recovery timeout.
        time.sleep(self.RECOVERY_TIMEOUT + 0.1)

        status = self.circuit_breaker.get_status()
        self.assertEqual(status["state"], "half_open")

        result = self.circuit_breaker.call(successful_function)
        self.assertEqual(result, "success")

        status = self.circuit_breaker.get_status()
        self.assertEqual(status["state"], "closed")
        self.assertEqual(status["failure_count"], 0)

    def test_circuit_breaker_async(self):
        """call_async passes results through and opens on repeated failures."""
        async def async_successful_function():
            return "async_success"

        async def async_failing_function():
            raise ValueError("Async failure")

        async def scenario():
            result = await self.circuit_breaker.call_async(async_successful_function)
            self.assertEqual(result, "async_success")

            for _ in range(self.FAILURE_THRESHOLD):
                with self.assertRaises(ValueError):
                    await self.circuit_breaker.call_async(async_failing_function)

            status = self.circuit_breaker.get_status()
            self.assertEqual(status["state"], "open")

        asyncio.run(scenario())

class TestRetryHandler(unittest.TestCase):
    """Test cases for RetryHandler."""

    def setUp(self):
        """Create a handler with 3 retries and a 0.1s base delay."""
        self.retry_handler = RetryHandler(max_retries=3, base_delay=0.1)

    def test_successful_retry(self):
        """retry returns the first successful result after transient failures."""
        call_count = 0

        def failing_then_success():
            nonlocal call_count
            call_count += 1
            if call_count < 3:
                raise ValueError("Temporary failure")
            return "success"

        result = self.retry_handler.retry(failing_then_success)
        self.assertEqual(result, "success")
        self.assertEqual(call_count, 3)

    def test_max_retries_exceeded(self):
        """Once retries are exhausted, the last original exception propagates."""
        def always_failing():
            raise ValueError("Persistent failure")

        with self.assertRaises(ValueError) as context:
            self.retry_handler.retry(always_failing)

        self.assertEqual(str(context.exception), "Persistent failure")

    def test_retry_async(self):
        """retry_async awaits the coroutine again after a failure."""
        call_count = 0

        async def async_failing_then_success():
            nonlocal call_count
            call_count += 1
            if call_count < 2:
                raise ValueError("Async temporary failure")
            return "async_success"

        async def scenario():
            result = await self.retry_handler.retry_async(async_failing_then_success)
            self.assertEqual(result, "async_success")
            self.assertEqual(call_count, 2)

        asyncio.run(scenario())

    def test_exponential_backoff(self):
        """Total time spent retrying covers the doubling delays."""
        def always_failing():
            raise ValueError("Failure")

        # monotonic() cannot jump backwards or forwards with wall-clock
        # adjustments (NTP, manual changes), unlike time.time().
        start = time.monotonic()
        with self.assertRaises(ValueError):
            self.retry_handler.retry(always_failing)
        elapsed = time.monotonic() - start

        # The test expects delays of 0.1, 0.2 and 0.4 seconds. The float sum
        # is 0.7000000000000001, so compare with >= to avoid a rounding flake.
        expected_min_delay = 0.1 + 0.2 + 0.4
        self.assertGreaterEqual(elapsed, expected_min_delay)

class TestDecorators(unittest.TestCase):
    """Test cases for the handle_errors / handle_errors_async decorators."""

    def test_handle_errors_decorator(self):
        """handle_errors records the error on the shared handler and re-raises."""
        # The file-level ``from error_handler import (...)`` never binds the
        # name ``error_handler``, so the bare reference raised NameError.
        # NOTE(review): assumes the module exposes its shared instance as
        # ``error_handler`` (the name this test originally used) — confirm.
        from error_handler import error_handler as shared_handler

        @handle_errors()
        def test_function():
            raise ValueError("Decorated error")

        # Snapshot first: the shared handler accumulates errors from every
        # decorated call in the process, so an absolute "> 0" check could
        # pass on errors recorded by other tests.
        before = shared_handler.get_error_statistics()["total_errors"]

        with self.assertRaises(ValueError):
            test_function()

        after = shared_handler.get_error_statistics()["total_errors"]
        self.assertGreater(after, before)

    def test_handle_errors_with_retry(self):
        """handle_errors(retry=True) retries until the function succeeds."""
        call_count = 0

        @handle_errors(retry=True)
        def test_function():
            nonlocal call_count
            call_count += 1
            if call_count < 3:
                raise ValueError("Temporary error")
            return "success"

        result = test_function()
        self.assertEqual(result, "success")
        self.assertEqual(call_count, 3)

    def test_handle_errors_with_circuit_breaker(self):
        """handle_errors(circuit_breaker=...) opens after repeated failures."""
        @handle_errors(circuit_breaker="test_circuit")
        def test_function():
            raise ValueError("Circuit breaker error")

        # NOTE(review): relies on the named breaker opening after 3 failures
        # (its default threshold is not visible here) — confirm.
        for _ in range(3):
            with self.assertRaises(ValueError):
                test_function()

        # Once open, the breaker should reject the call instead of running it.
        with self.assertRaises(Exception) as context:
            test_function()
        self.assertIn("OPEN", str(context.exception))

    def test_handle_errors_async_decorator(self):
        """handle_errors_async re-raises the coroutine's exception."""
        @handle_errors_async()
        async def async_test_function():
            raise ValueError("Async decorated error")

        async def run_test():
            with self.assertRaises(ValueError):
                await async_test_function()

        asyncio.run(run_test())

if __name__ == "__main__":
    # Run this file's tests when it is executed directly as a script.
    unittest.main()