import unittest
import time
from unittest.mock import patch, MagicMock

from sagemaker_sparkmonitor.cache_utils import ttl_cache

class TestCacheUtils(unittest.TestCase):
    """Unit tests for the ``ttl_cache`` decorator.

    Each test decorates a local function that counts its own invocations via a
    ``nonlocal`` counter. A cache hit shows up as an unchanged count; a miss
    (first call, expired entry, or cleared cache) shows up as an incremented
    count.
    """

    # TTL for tests that assert cache *hits*. It is deliberately much longer
    # than any test runs, so a slow or paused CI worker cannot make an entry
    # expire mid-test and turn an expected hit into a spurious miss.
    LONG_TTL = 60

    def test_ttl_cache_basic_functionality(self):
        """Repeated calls with the same args hit the cache; new args miss."""
        call_count = 0

        @ttl_cache(ttl_seconds=self.LONG_TTL)
        def test_function(x):
            nonlocal call_count
            call_count += 1
            return x * 2

        # First call should execute the function.
        result1 = test_function(5)
        self.assertEqual(result1, 10)
        self.assertEqual(call_count, 1)

        # Second call with the same args should be served from the cache.
        result2 = test_function(5)
        self.assertEqual(result2, 10)
        self.assertEqual(call_count, 1)

        # A call with different args is a separate key and must execute.
        result3 = test_function(3)
        self.assertEqual(result3, 6)
        self.assertEqual(call_count, 2)

    def test_ttl_cache_expiration(self):
        """Entries expire after the TTL using the real wall clock."""
        call_count = 0

        @ttl_cache(ttl_seconds=0.1)  # Very short TTL for testing
        def test_function(x):
            nonlocal call_count
            call_count += 1
            return x * 3

        # First call populates the cache.
        result1 = test_function(2)
        self.assertEqual(result1, 6)
        self.assertEqual(call_count, 1)

        # Sleep for at least twice the TTL so the entry is guaranteed stale.
        time.sleep(0.2)

        # Second call should execute the function again.
        result2 = test_function(2)
        self.assertEqual(result2, 6)
        self.assertEqual(call_count, 2)

    def test_ttl_cache_with_kwargs(self):
        """Keyword arguments participate in the cache key."""
        call_count = 0

        @ttl_cache(ttl_seconds=self.LONG_TTL)
        def test_function(x, y=1):
            nonlocal call_count
            call_count += 1
            return x + y

        # First call populates the cache.
        result1 = test_function(5, y=2)
        self.assertEqual(result1, 7)
        self.assertEqual(call_count, 1)

        # An identical call should be served from the cache.
        result2 = test_function(5, y=2)
        self.assertEqual(result2, 7)
        self.assertEqual(call_count, 1)

        # A different kwarg value is a different key and must execute.
        result3 = test_function(5, y=3)
        self.assertEqual(result3, 8)
        self.assertEqual(call_count, 2)

    def test_ttl_cache_clear_cache(self):
        """``clear_cache()`` forces the next call to execute again."""
        call_count = 0

        @ttl_cache(ttl_seconds=self.LONG_TTL)
        def test_function(x):
            nonlocal call_count
            call_count += 1
            return x * 4

        # First call populates the cache.
        result1 = test_function(3)
        self.assertEqual(result1, 12)
        self.assertEqual(call_count, 1)

        # Manually drop all cached entries.
        test_function.clear_cache()

        # The next call cannot be a hit, so the function executes again.
        result2 = test_function(3)
        self.assertEqual(result2, 12)
        self.assertEqual(call_count, 2)

    def test_ttl_cache_default_ttl(self):
        """The decorator works without arguments, using its default TTL."""
        call_count = 0

        @ttl_cache()  # Use default TTL
        def test_function(x):
            nonlocal call_count
            call_count += 1
            return x * 5

        # First call populates the cache.
        result1 = test_function(4)
        self.assertEqual(result1, 20)
        self.assertEqual(call_count, 1)

        # The immediate second call should be a hit. This assumes the default
        # TTL is comfortably longer than the gap between these two calls.
        result2 = test_function(4)
        self.assertEqual(result2, 20)
        self.assertEqual(call_count, 1)

    @patch('time.time')
    def test_ttl_cache_time_mocking(self, mock_time):
        """Expiry is driven by ``time.time``, checked deterministically.

        This relies on ``ttl_cache`` reading the clock through ``time.time``,
        which is what the patch below replaces.
        """
        call_count = 0
        mock_time.return_value = 1000.0

        @ttl_cache(ttl_seconds=5)
        def test_function(x):
            nonlocal call_count
            call_count += 1
            return x * 6

        # First call at t=1000 populates the cache.
        result1 = test_function(2)
        self.assertEqual(result1, 12)
        self.assertEqual(call_count, 1)

        # At t=1003 (3s later, inside the 5s TTL) the entry is still fresh.
        mock_time.return_value = 1003.0
        result2 = test_function(2)
        self.assertEqual(result2, 12)
        self.assertEqual(call_count, 1)  # Should use cache

        # At t=1006 (6s later, past the TTL) the entry has expired.
        mock_time.return_value = 1006.0
        result3 = test_function(2)
        self.assertEqual(result3, 12)
        self.assertEqual(call_count, 2)  # Should execute function again

    def test_ttl_cache_exception_handling(self):
        """Exceptions propagate, are never cached, and do not evict entries."""
        call_count = 0

        # A long TTL matters here. With a short TTL, a cached exception could
        # expire between calls, and the "not cached" check would pass falsely.
        @ttl_cache(ttl_seconds=self.LONG_TTL)
        def test_function(x):
            nonlocal call_count
            call_count += 1
            if x < 0:
                raise ValueError("Negative value")
            return x * 2

        # A valid input populates the cache.
        result1 = test_function(5)
        self.assertEqual(result1, 10)
        self.assertEqual(call_count, 1)

        # An invalid input raises, and the underlying function did run.
        with self.assertRaises(ValueError):
            test_function(-1)
        self.assertEqual(call_count, 2)

        # The same invalid input must execute (and raise) again.
        with self.assertRaises(ValueError):
            test_function(-1)
        self.assertEqual(call_count, 3)  # Exception not cached

        # The earlier successful entry survives the failed calls.
        self.assertEqual(test_function(5), 10)
        self.assertEqual(call_count, 3)

    def test_ttl_cache_cleanup_expired_entries(self):
        """All entries expire after the TTL and are recomputed on access.

        Internal eviction of stale entries is not observable through the
        decorator's public behaviour. This test therefore checks only that
        every expired key is treated as a miss.
        """
        call_count = 0

        @ttl_cache(ttl_seconds=0.1)
        def test_function(x):
            nonlocal call_count
            call_count += 1
            return x * 2

        # Populate several independent cache entries.
        test_function(1)
        test_function(2)
        test_function(3)
        self.assertEqual(call_count, 3)

        # Sleep for at least twice the TTL so every entry is stale.
        time.sleep(0.2)

        # A new key executes normally after the others have expired.
        result = test_function(4)
        self.assertEqual(result, 8)
        self.assertEqual(call_count, 4)

        # A previously cached but now expired key must execute again.
        test_function(1)
        self.assertEqual(call_count, 5)

    def test_ttl_cache_kwargs_order_independence(self):
        """The order of keyword arguments does not affect the cache key."""
        call_count = 0

        @ttl_cache(ttl_seconds=self.LONG_TTL)
        def test_function(a=1, b=2, c=3):
            nonlocal call_count
            call_count += 1
            return a + b + c

        # The same kwargs, supplied in three different orders.
        result1 = test_function(a=1, b=2, c=3)
        result2 = test_function(c=3, a=1, b=2)
        result3 = test_function(b=2, c=3, a=1)

        self.assertEqual(result1, 6)
        self.assertEqual(result2, 6)
        self.assertEqual(result3, 6)
        self.assertEqual(call_count, 1)  # Should use cache for all calls


# Allow running this module directly (python <file>.py) as well as via a
# test runner such as pytest or `python -m unittest`.
if __name__ == "__main__":
    unittest.main()