#!/usr/bin/env python3
"""
Tests for Advanced Monitoring and Observability Framework
Goal 7.3 Implementation Tests
"""

import unittest
import time
import asyncio
from unittest.mock import patch, MagicMock
import sys
import os

# Add parent directory to path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

from monitoring_framework import (
    MonitoringFramework, MetricsCollector, StructuredLogger, TraceCollector,
    AlertManager, HealthChecker, Metric, LogEntry, TraceSpan, Alert,
    monitor_operation, monitor_async_operation
)

class TestMetricsCollector(unittest.TestCase):
    """Test cases for MetricsCollector"""
    
    def setUp(self):
        """Set up test fixtures"""
        self.collector = MetricsCollector()
    
    def test_record_metric(self):
        """Test metric recording"""
        self.collector.record_metric("test_metric", 42.5, {"tag1": "value1"})
        
        metrics = self.collector.get_metric("test_metric")
        self.assertEqual(len(metrics), 1)
        self.assertEqual(metrics[0].name, "test_metric")
        self.assertEqual(metrics[0].value, 42.5)
        self.assertEqual(metrics[0].tags["tag1"], "value1")
    
    def test_metric_filtering(self):
        """Test metric filtering by tags and time range"""
        # Record metrics with different tags
        self.collector.record_metric("test_metric", 1.0, {"env": "prod"})
        self.collector.record_metric("test_metric", 2.0, {"env": "dev"})
        
        # Filter by tag
        prod_metrics = self.collector.get_metric("test_metric", {"env": "prod"})
        self.assertEqual(len(prod_metrics), 1)
        self.assertEqual(prod_metrics[0].value, 1.0)
        
        # Filter by time range
        time.sleep(0.1)
        self.collector.record_metric("test_metric", 3.0, {"env": "prod"})
        
        recent_metrics = self.collector.get_metric("test_metric", time_range=0.05)
        self.assertEqual(len(recent_metrics), 1)
        self.assertEqual(recent_metrics[0].value, 3.0)
    
    def test_metric_aggregation(self):
        """Test metric aggregation"""
        # Record multiple metrics
        for i in range(10):
            self.collector.record_metric("test_metric", float(i))
        
        # Wait for aggregation
        time.sleep(1.1)
        
        # Check aggregations
        avg_value = self.collector.get_aggregation("test_metric", "avg")
        self.assertIsNotNone(avg_value)
        self.assertAlmostEqual(avg_value, 4.5, places=1)
        
        max_value = self.collector.get_aggregation("test_metric", "max")
        self.assertEqual(max_value, 9.0)
        
        min_value = self.collector.get_aggregation("test_metric", "min")
        self.assertEqual(min_value, 0.0)

class TestStructuredLogger(unittest.TestCase):
    """Test cases for StructuredLogger"""
    
    def setUp(self):
        """Set up test fixtures"""
        self.logger = StructuredLogger("test-service")
    
    def test_logging(self):
        """Test basic logging functionality"""
        self.logger.log("info", "Test message", "test_component", {"tag1": "value1"})
        
        logs = self.logger.get_logs()
        self.assertEqual(len(logs), 1)
        
        log_entry = logs[0]
        self.assertEqual(log_entry.level, "info")
        self.assertEqual(log_entry.message, "Test message")
        self.assertEqual(log_entry.component, "test_component")
        self.assertEqual(log_entry.tags["tag1"], "value1")
        self.assertEqual(log_entry.service, "test-service")
    
    def test_correlation_id(self):
        """Test correlation ID functionality"""
        # Set correlation ID
        correlation_id = self.logger.set_correlation_id("test-correlation")
        self.assertEqual(correlation_id, "test-correlation")
        
        # Log message
        self.logger.log("info", "Test message")
        
        # Check correlation ID
        logs = self.logger.get_logs()
        self.assertEqual(logs[0].correlation_id, "test-correlation")
        
        # Get correlation ID
        retrieved_id = self.logger.get_correlation_id()
        self.assertEqual(retrieved_id, "test-correlation")
    
    def test_log_filtering(self):
        """Test log filtering"""
        # Log messages with different levels and components
        self.logger.log("info", "Info message", "component1")
        self.logger.log("error", "Error message", "component1")
        self.logger.log("info", "Info message", "component2")
        
        # Filter by level
        error_logs = self.logger.get_logs(level="error")
        self.assertEqual(len(error_logs), 1)
        self.assertEqual(error_logs[0].message, "Error message")
        
        # Filter by component
        component1_logs = self.logger.get_logs(component="component1")
        self.assertEqual(len(component1_logs), 2)
        
        # Filter by correlation ID
        correlation_id = self.logger.set_correlation_id("filter-test")
        self.logger.log("info", "Filtered message")
        
        filtered_logs = self.logger.get_logs(correlation_id="filter-test")
        self.assertEqual(len(filtered_logs), 1)
        self.assertEqual(filtered_logs[0].message, "Filtered message")

class TestTraceCollector(unittest.TestCase):
    """Test cases for TraceCollector"""
    
    def setUp(self):
        """Set up test fixtures"""
        self.tracer = TraceCollector()
    
    def test_span_creation(self):
        """Test span creation and management"""
        span_id = self.tracer.start_span("test_span", tags={"tag1": "value1"})
        
        # Check active spans
        self.assertIn(span_id, self.tracer.active_spans)
        
        span = self.tracer.active_spans[span_id]
        self.assertEqual(span.name, "test_span")
        self.assertEqual(span.tags["tag1"], "value1")
        self.assertIsNone(span.end_time)
    
    def test_span_completion(self):
        """Test span completion"""
        span_id = self.tracer.start_span("test_span")
        
        # End span
        self.tracer.end_span(span_id, "success", {"result": "ok"})
        
        # Check span is no longer active
        self.assertNotIn(span_id, self.tracer.active_spans)
        
        # Check span in traces
        trace_id = None
        for tid, spans in self.tracer.traces.items():
            for span in spans:
                if span.span_id == span_id:
                    trace_id = tid
                    break
            if trace_id:
                break
        
        self.assertIsNotNone(trace_id)
        spans = self.tracer.traces[trace_id]
        completed_span = next(s for s in spans if s.span_id == span_id)
        self.assertEqual(completed_span.status, "success")
        self.assertEqual(completed_span.tags["result"], "ok")
        self.assertIsNotNone(completed_span.end_time)
    
    def test_span_events(self):
        """Test adding events to spans"""
        span_id = self.tracer.start_span("test_span")
        
        # Add events
        self.tracer.add_span_event(span_id, "test_event", {"attr1": "value1"})
        self.tracer.add_span_event(span_id, "another_event")
        
        span = self.tracer.active_spans[span_id]
        self.assertEqual(len(span.events), 2)
        
        event1 = span.events[0]
        self.assertEqual(event1["name"], "test_event")
        self.assertEqual(event1["attributes"]["attr1"], "value1")
        
        event2 = span.events[1]
        self.assertEqual(event2["name"], "another_event")
    
    def test_trace_retrieval(self):
        """Test trace retrieval"""
        # Create spans in same trace
        trace_id = "test-trace"
        span1_id = self.tracer.start_span("span1", trace_id)
        span2_id = self.tracer.start_span("span2", trace_id, parent_span_id=span1_id)
        
        # End spans
        self.tracer.end_span(span2_id)
        self.tracer.end_span(span1_id)
        
        # Get trace
        trace = self.tracer.get_trace(trace_id)
        self.assertEqual(len(trace), 2)
        
        span1 = next(s for s in trace if s.span_id == span1_id)
        span2 = next(s for s in trace if s.span_id == span2_id)
        
        self.assertEqual(span2.parent_span_id, span1_id)

class TestAlertManager(unittest.TestCase):
    """Test cases for AlertManager"""
    
    def setUp(self):
        """Set up test fixtures"""
        self.alert_manager = AlertManager()
        self.metrics_collector = MetricsCollector()
    
    def test_alert_creation(self):
        """Test alert creation"""
        alert = Alert(
            name="test_alert",
            condition="gt",
            threshold=100.0,
            severity="warning",
            message="Test alert"
        )
        
        self.alert_manager.add_alert(alert)
        self.assertIn("test_alert", self.alert_manager.alerts)
    
    def test_alert_triggering(self):
        """Test alert triggering"""
        # Create alert
        alert = Alert(
            name="test_metric",
            condition="gt",
            threshold=50.0,
            severity="warning",
            message="Metric too high"
        )
        self.alert_manager.add_alert(alert)
        
        # Record metric that triggers alert
        self.metrics_collector.record_metric("test_metric", 75.0)
        
        # Wait for aggregation
        time.sleep(1.1)
        
        # Check alerts
        triggered_alerts = self.alert_manager.check_alerts(self.metrics_collector)
        self.assertEqual(len(triggered_alerts), 1)
        
        alert_info = triggered_alerts[0]
        self.assertEqual(alert_info["name"], "test_metric")
        self.assertEqual(alert_info["current_value"], 75.0)
        self.assertEqual(alert_info["threshold"], 50.0)
    
    def test_alert_handler(self):
        """Test alert handler registration and execution"""
        handler_called = False
        alert_info_received = None
        
        def test_handler(alert_info):
            nonlocal handler_called, alert_info_received
            handler_called = True
            alert_info_received = alert_info
        
        self.alert_manager.register_alert_handler(test_handler)
        
        # Create and trigger alert
        alert = Alert(
            name="test_metric",
            condition="gt",
            threshold=10.0,
            severity="warning",
            message="Test alert"
        )
        self.alert_manager.add_alert(alert)
        
        self.metrics_collector.record_metric("test_metric", 20.0)
        time.sleep(1.1)
        
        self.alert_manager.check_alerts(self.metrics_collector)
        
        self.assertTrue(handler_called)
        self.assertIsNotNone(alert_info_received)
        self.assertEqual(alert_info_received["name"], "test_metric")
    
    def test_alert_cooldown(self):
        """Test alert cooldown functionality"""
        alert = Alert(
            name="test_metric",
            condition="gt",
            threshold=10.0,
            severity="warning",
            message="Test alert",
            cooldown_period=1.0
        )
        self.alert_manager.add_alert(alert)
        
        # Trigger alert
        self.metrics_collector.record_metric("test_metric", 20.0)
        time.sleep(1.1)
        
        triggered1 = self.alert_manager.check_alerts(self.metrics_collector)
        self.assertEqual(len(triggered1), 1)
        
        # Try to trigger again immediately (should be in cooldown)
        triggered2 = self.alert_manager.check_alerts(self.metrics_collector)
        self.assertEqual(len(triggered2), 0)
        
        # Wait for cooldown to expire
        time.sleep(1.1)
        triggered3 = self.alert_manager.check_alerts(self.metrics_collector)
        self.assertEqual(len(triggered3), 1)

class TestHealthChecker(unittest.TestCase):
    """Test cases for HealthChecker"""
    
    def setUp(self):
        """Set up test fixtures"""
        self.health_checker = HealthChecker()
    
    def test_health_check_registration(self):
        """Test health check registration"""
        def test_check():
            return True
        
        self.health_checker.register_health_check("test_service", test_check)
        self.assertIn("test_service", self.health_checker.health_checks)
    
    def test_health_check_execution(self):
        """Test health check execution"""
        def healthy_check():
            return True
        
        def unhealthy_check():
            return False
        
        def failing_check():
            raise Exception("Health check failed")
        
        self.health_checker.register_health_check("healthy", healthy_check)
        self.health_checker.register_health_check("unhealthy", unhealthy_check)
        self.health_checker.register_health_check("failing", failing_check)
        
        results = self.health_checker.run_health_checks()
        
        self.assertEqual(results["healthy"]["status"], "healthy")
        self.assertEqual(results["unhealthy"]["status"], "unhealthy")
        self.assertEqual(results["failing"]["status"], "unhealthy")
        self.assertIn("error", results["failing"])
    
    def test_health_status_retrieval(self):
        """Test health status retrieval"""
        def test_check():
            return True
        
        self.health_checker.register_health_check("test_service", test_check)
        self.health_checker.run_health_checks()
        
        # Get all status
        all_status = self.health_checker.get_health_status()
        self.assertIn("test_service", all_status)
        
        # Get specific status
        service_status = self.health_checker.get_health_status("test_service")
        self.assertEqual(service_status["status"], "healthy")

class TestMonitoringFramework(unittest.TestCase):
    """Test cases for MonitoringFramework"""
    
    def setUp(self):
        """Set up test fixtures"""
        self.framework = MonitoringFramework("test-service")
    
    def test_trace_operation_context_manager(self):
        """Test trace operation context manager"""
        with self.framework.trace_operation("test_op", tags={"tag1": "value1"}) as span_id:
            time.sleep(0.1)  # Simulate work
        
        # Check that span was created and completed
        self.assertNotIn(span_id, self.framework.tracer.active_spans)
        
        # Find the completed span
        trace_id = None
        for tid, spans in self.framework.tracer.traces.items():
            for span in spans:
                if span.span_id == span_id:
                    trace_id = tid
                    break
            if trace_id:
                break
        
        self.assertIsNotNone(trace_id)
        spans = self.framework.tracer.traces[trace_id]
        completed_span = next(s for s in spans if s.span_id == span_id)
        self.assertEqual(completed_span.status, "success")
        self.assertEqual(completed_span.tags["tag1"], "value1")
    
    def test_metric_recording(self):
        """Test metric recording through framework"""
        self.framework.record_metric("test_metric", 42.5, {"tag1": "value1"})
        
        metrics = self.framework.metrics.get_metric("test_metric")
        self.assertEqual(len(metrics), 1)
        self.assertEqual(metrics[0].value, 42.5)
        self.assertEqual(metrics[0].tags["tag1"], "value1")
    
    def test_logging(self):
        """Test logging through framework"""
        correlation_id = self.framework.set_correlation_id("test-correlation")
        self.framework.log("info", "Test message", "test_component")
        
        logs = self.framework.logger.get_logs()
        self.assertEqual(len(logs), 1)
        self.assertEqual(logs[0].correlation_id, "test-correlation")
        self.assertEqual(logs[0].message, "Test message")
    
    def test_alert_management(self):
        """Test alert management through framework"""
        self.framework.add_alert(
            "test_metric",
            "gt",
            100.0,
            "warning",
            "Test alert"
        )
        
        self.assertIn("test_metric", self.framework.alerts.alerts)
        
        # Test alert handler registration
        handler_called = False
        
        def test_handler(alert_info):
            nonlocal handler_called
            handler_called = True
        
        self.framework.register_alert_handler(test_handler)
        
        # Trigger alert
        self.framework.record_metric("test_metric", 150.0)
        time.sleep(1.1)
        
        self.framework.check_alerts()
        self.assertTrue(handler_called)
    
    def test_health_checks(self):
        """Test health checks through framework"""
        results = self.framework.run_health_checks()
        
        # Should have default health checks
        self.assertIn("memory_usage", results)
        self.assertIn("disk_space", results)
        self.assertIn("service_uptime", results)
    
    def test_status_report(self):
        """Test status report generation"""
        # Generate some activity
        self.framework.record_metric("test_metric", 42.5)
        self.framework.log("info", "Test message")
        
        with self.framework.trace_operation("test_op"):
            time.sleep(0.1)
        
        report = self.framework.get_status_report()
        
        self.assertEqual(report["service"], "test-service")
        self.assertIn("metrics", report)
        self.assertIn("health", report)
        self.assertIn("alerts", report)
        self.assertGreater(report["metrics"]["total_metrics"], 0)

class TestDecorators(unittest.TestCase):
    """Test cases for monitoring decorators"""
    
    def setUp(self):
        """Set up test fixtures"""
        self.framework = MonitoringFramework("test-service")
    
    def test_monitor_operation_decorator(self):
        """Test monitor_operation decorator"""
        @monitor_operation("test_function")
        def test_function():
            time.sleep(0.1)
            return "success"
        
        result = test_function()
        self.assertEqual(result, "success")
        
        # Check that metrics were recorded
        metrics = self.framework.metrics.get_metric("test_function_duration")
        self.assertGreater(len(metrics), 0)
        
        metrics = self.framework.metrics.get_metric("test_function_success")
        self.assertGreater(len(metrics), 0)
    
    def test_monitor_operation_with_error(self):
        """Test monitor_operation decorator with error"""
        @monitor_operation("error_function")
        def error_function():
            raise ValueError("Test error")
        
        with self.assertRaises(ValueError):
            error_function()
        
        # Check that error metrics were recorded
        metrics = self.framework.metrics.get_metric("error_function_error")
        self.assertGreater(len(metrics), 0)
    
    def test_monitor_async_operation_decorator(self):
        """Test monitor_async_operation decorator"""
        @monitor_async_operation("async_function")
        async def async_function():
            await asyncio.sleep(0.1)
            return "async_success"
        
        async def run_test():
            result = await async_function()
            self.assertEqual(result, "async_success")
            
            # Check that metrics were recorded
            metrics = self.framework.metrics.get_metric("async_function_duration")
            self.assertGreater(len(metrics), 0)
        
        asyncio.run(run_test())

if __name__ == "__main__":
    # Run every TestCase defined in this module through unittest's CLI runner.
    unittest.main(module="__main__")