#!/usr/bin/env python3
"""
API Gateway for TuskLang Python SDK
===================================
Request routing, authentication, rate limiting, and API management

This module provides a comprehensive API gateway for the TuskLang Python SDK,
enabling request routing, authentication, rate limiting, monitoring, and
advanced API management capabilities.
"""

import asyncio
import base64
import hashlib
import hmac
import inspect
import json
import logging
import threading
import time
import uuid
from collections import defaultdict, deque
from dataclasses import dataclass, asdict
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Dict, List, Optional, Callable, Union, Tuple

import aiohttp
from aiohttp import web
import jwt
import redis
import yaml


class RouteMethod(Enum):
    """HTTP methods a gateway route can be registered for.

    Each member's value is the upper-case method name. Route keys are built
    from ``method.value`` (e.g. ``"GET:/api/users"``), and
    ``add_gateway_route`` looks members up with ``RouteMethod(method.upper())``.
    """
    GET = "GET"
    POST = "POST"
    PUT = "PUT"
    DELETE = "DELETE"
    PATCH = "PATCH"
    OPTIONS = "OPTIONS"
    HEAD = "HEAD"


class AuthType(Enum):
    """Authentication scheme required by a route.

    ``NONE`` makes the gateway skip authentication entirely. Every other
    member selects the matching strategy in ``Authenticator.authenticate``.
    Values are lower-case because ``add_gateway_route`` converts its
    ``auth_type`` string with ``.lower()`` before the lookup.
    """
    NONE = "none"
    API_KEY = "api_key"
    JWT = "jwt"
    OAUTH2 = "oauth2"
    BASIC = "basic"


@dataclass
class RouteConfig:
    """Configuration for a single gateway route.

    The gateway keys routes by ``f"{method.value}:{path}"`` (see
    ``APIGateway.add_route``), so the same path may be registered once per
    HTTP method.
    """
    path: str  # gateway request path the route is looked up by
    method: RouteMethod  # HTTP method the route answers
    target_service: str  # upstream "host[:port]"; requests go to http://{target_service}{target_path}
    target_path: str  # path on the upstream service
    auth_type: AuthType  # authentication scheme; AuthType.NONE skips authentication
    rate_limit: int  # max requests per client IP within the RateLimiter's 60-second window
    timeout: int  # total upstream request timeout, passed to aiohttp.ClientTimeout(total=...)
    cache_ttl: int  # seconds a cached response stays valid; values <= 0 disable caching
    enabled: bool = True  # route enable flag; reported by APIGateway.get_route_info()


@dataclass
class APIRequest:
    """Gateway-internal snapshot of an incoming HTTP request.

    Built by ``APIGateway._handle_request`` and passed through gateway
    middleware, routing, rate limiting and authentication.
    """
    request_id: str  # identifier assigned by the gateway for this request
    timestamp: datetime  # time the request was received (datetime.now(), naive local time)
    method: str  # HTTP method as received, e.g. "GET"
    path: str  # request path, without the query string
    headers: Dict[str, str]  # plain-dict copy of the incoming headers (lookups are case-sensitive)
    query_params: Dict[str, str]  # query string as a dict (one value per key)
    body: Any  # raw body bytes from request.read(), or None when the request has no body
    client_ip: str  # peer address from aiohttp's request.remote; used as the rate-limit key
    user_id: Optional[str] = None  # set from the authenticator's result once authentication succeeds


@dataclass
class APIResponse:
    """Record of a response the gateway produced for a forwarded request.

    It is handed to ``APIMonitor.record_request`` for statistics.
    """
    request_id: str  # same id as the originating APIRequest
    status_code: int  # HTTP status returned to the client
    headers: Dict[str, str]  # response headers as a plain dict
    body: Any  # response payload
    response_time: float  # seconds elapsed since the gateway started handling the request
    timestamp: datetime  # time the response record was created (datetime.now())


class APIGateway:
    """API Gateway for TuskLang.

    Runs an aiohttp server whose single catch-all route hands each request to
    ``_handle_request``. That method applies gateway middleware, route lookup,
    rate limiting, authentication and response caching, then proxies the
    request to the route's target service.
    """

    # Hop-by-hop headers (RFC 7230, section 6.1) describe one transport
    # connection and must not be forwarded by a proxy in either direction.
    _HOP_BY_HOP_HEADERS = frozenset({
        'connection', 'keep-alive', 'proxy-authenticate', 'proxy-authorization',
        'te', 'trailer', 'transfer-encoding', 'upgrade',
    })
    # Host must describe the upstream, and aiohttp recomputes Content-Length
    # from the body it actually sends.
    _REQUEST_SKIP_HEADERS = _HOP_BY_HOP_HEADERS | {'host', 'content-length'}
    # aiohttp's client decompresses response bodies by default, so the
    # upstream Content-Encoding/Content-Length no longer describe the bytes
    # relayed to the client.
    _RESPONSE_SKIP_HEADERS = _HOP_BY_HOP_HEADERS | {'content-length', 'content-encoding'}
    # Only safe, idempotent requests may be answered from the response cache.
    _CACHEABLE_METHODS = frozenset({'GET'})

    def __init__(self, config: Dict[str, Any] = None):
        self.config = config or {}
        self.logger = logging.getLogger('tusklang.apigateway')

        # Initialize components
        self.routes = {}
        self.middleware = []
        self.rate_limiters = {}
        self.cache = {}
        self.request_log = deque(maxlen=10000)

        # Initialize gateway components
        self.router = Router()
        self.authenticator = Authenticator()
        self.rate_limiter = RateLimiter()
        self.cache_manager = CacheManager()
        self.monitor = APIMonitor()

        # Initialize gateway
        self.gateway_active = True
        self.app = web.Application()
        self.runner = None

        # Start background processes (runs as soon as the gateway is constructed)
        self._start_background_processes()

    def _start_background_processes(self):
        """Start the daemon threads for cache cleanup and metrics collection."""
        # Cache cleanup
        self.cache_cleanup_thread = threading.Thread(target=self._cache_cleanup_loop, daemon=True)
        self.cache_cleanup_thread.start()

        # Metrics collection
        self.metrics_thread = threading.Thread(target=self._metrics_collection_loop, daemon=True)
        self.metrics_thread.start()

    def add_route(self, route_config: RouteConfig) -> bool:
        """Register a route with the gateway, its router and its rate limiter.

        Returns True on success, False (after logging) on any error.
        """
        try:
            route_key = f"{route_config.method.value}:{route_config.path}"
            self.routes[route_key] = route_config

            # Add route to router
            self.router.add_route(route_config)

            # Initialize rate limiter for route
            self.rate_limiter.add_route(route_config.path, route_config.rate_limit)

            self.logger.info(f"Added route: {route_config.method.value} {route_config.path}")
            return True

        except Exception as e:
            self.logger.error(f"Failed to add route: {e}")
            return False

    def add_middleware(self, middleware: Callable):
        """Add gateway middleware.

        Gateway middleware is an async callable that receives an APIRequest and
        returns the (possibly modified) APIRequest, or a falsy value to reject
        the request with HTTP 400. ``_handle_request`` applies the entries in
        registration order. They are NOT aiohttp middlewares.
        """
        self.middleware.append(middleware)
        # functools.partial objects and other callables may lack __name__.
        name = getattr(middleware, '__name__', repr(middleware))
        self.logger.info(f"Added middleware: {name}")

    async def start_gateway(self, host: str = "0.0.0.0", port: int = 8080) -> bool:
        """Start serving on *host*:*port*.

        Returns True once the server is listening, or False (after logging and
        releasing any partially set-up runner) if startup failed.
        """
        try:
            # One catch-all aiohttp route forwards every method and path to
            # _handle_request. Per-route matching is done by self.router.
            # self.middleware is deliberately NOT added to app.middlewares:
            # gateway middleware takes an APIRequest and is applied in
            # _handle_request, while aiohttp calls its middlewares as
            # (request, handler).
            self.app.router.add_route('*', '/{path:.*}', self._handle_request)

            # Start server
            self.runner = web.AppRunner(self.app)
            await self.runner.setup()

            site = web.TCPSite(self.runner, host, port)
            await site.start()

            self.logger.info(f"API Gateway started on {host}:{port}")
            return True

        except Exception as e:
            self.logger.error(f"Failed to start API gateway: {e}")
            # Do not leak the runner if setup succeeded but binding failed.
            if self.runner is not None:
                try:
                    await self.runner.cleanup()
                except Exception as cleanup_error:
                    self.logger.error(f"Failed to clean up after startup error: {cleanup_error}")
                self.runner = None
            return False

    async def stop_gateway(self):
        """Stop the server (if running) and signal the background loops to exit."""
        if self.runner:
            await self.runner.cleanup()
            self.runner = None
            self.gateway_active = False
            self.logger.info("API Gateway stopped")

    async def _handle_request(self, request: web.Request) -> web.Response:
        """Handle an incoming API request.

        The pipeline is: build the APIRequest, run gateway middleware, look up
        the route, check the rate limit, authenticate, check the cache,
        forward upstream, store in the cache and record metrics. Any
        unexpected exception yields HTTP 500.
        """
        start_time = time.time()
        # Random ids cannot collide between concurrent requests the way
        # millisecond timestamps can.
        request_id = uuid.uuid4().hex

        try:
            # Create API request object
            api_request = APIRequest(
                request_id=request_id,
                timestamp=datetime.now(),
                method=request.method,
                path=str(request.path),
                headers=dict(request.headers),
                query_params=dict(request.query),
                body=await request.read() if request.body_exists else None,
                client_ip=request.remote
            )

            # Log request
            self.request_log.append(api_request)

            # Apply gateway middleware (APIRequest -> APIRequest | falsy)
            for middleware in self.middleware:
                api_request = await middleware(api_request)
                if not api_request:
                    return web.Response(status=400, text="Request blocked by middleware")

            # Find route
            route = self.router.find_route(api_request.method, api_request.path)
            if not route:
                return web.Response(status=404, text="Route not found")

            # Check rate limit
            if not self.rate_limiter.check_rate_limit(api_request.client_ip, route.path):
                return web.Response(status=429, text="Rate limit exceeded")

            # Authenticate request
            if route.auth_type != AuthType.NONE:
                auth_result = await self.authenticator.authenticate(api_request, route.auth_type)
                if not auth_result["success"]:
                    return web.Response(status=401, text=auth_result["message"])
                api_request.user_id = auth_result.get("user_id")

            # Never cache non-idempotent methods (e.g. POST /payments): a
            # cache hit would skip the upstream side effect entirely.
            cacheable = route.cache_ttl > 0 and api_request.method in self._CACHEABLE_METHODS
            cache_key = self._generate_cache_key(api_request) if cacheable else None
            if cacheable:
                cached_response = self.cache_manager.get(cache_key)
                if cached_response:
                    return web.Response(
                        status=cached_response["status_code"],
                        headers=cached_response["headers"],
                        body=cached_response["body"]
                    )

            # Forward request to target service
            response = await self._forward_request(api_request, route)

            # Cache response (web.Response.body is a property holding bytes)
            if cacheable and response.status == 200:
                self.cache_manager.set(cache_key, {
                    "status_code": response.status,
                    "headers": list(response.headers.items()),
                    "body": response.body,
                    "ttl": route.cache_ttl
                })

            # Record metrics
            api_response = APIResponse(
                request_id=request_id,
                status_code=response.status,
                headers=dict(response.headers),
                body=response.body,
                response_time=time.time() - start_time,
                timestamp=datetime.now()
            )
            self.monitor.record_request(api_request, api_response)

            return response

        except Exception as e:
            self.logger.error(f"Request handling error: {e}")
            return web.Response(status=500, text="Internal server error")

    async def _forward_request(self, api_request: APIRequest, route: RouteConfig) -> web.Response:
        """Proxy *api_request* to the route's target service and relay the reply.

        The method, query parameters, raw body and end-to-end headers are
        forwarded unchanged. The upstream body is relayed as bytes so binary
        payloads survive. Any error talking to the upstream yields HTTP 502.
        """
        try:
            target_url = f"http://{route.target_service}{route.target_path}"
            headers = self._filter_headers(api_request.headers, self._REQUEST_SKIP_HEADERS)

            body = api_request.body
            if body is None or isinstance(body, (bytes, bytearray, str)):
                body_kwargs = {"data": body}
            else:
                # Middleware replaced the raw body with a structured object;
                # serialize it as JSON.
                body_kwargs = {"json": body}

            timeout = aiohttp.ClientTimeout(total=route.timeout)
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.request(
                    api_request.method,
                    target_url,
                    headers=headers,
                    params=api_request.query_params or None,
                    **body_kwargs
                ) as upstream:
                    payload = await upstream.read()
                    return web.Response(
                        status=upstream.status,
                        headers=self._filter_headers(upstream.headers, self._RESPONSE_SKIP_HEADERS),
                        body=payload
                    )

        except Exception as e:
            self.logger.error(f"Request forwarding error: {e}")
            return web.Response(status=502, text="Bad gateway")

    @staticmethod
    def _filter_headers(headers, skip) -> List[Tuple[str, str]]:
        """Return (name, value) pairs of *headers* whose name is not in *skip*.

        Names are compared case-insensitively. Any header named in a
        Connection header is dropped too. A list of pairs (not a dict) keeps
        repeated headers such as Set-Cookie.
        """
        if not headers:
            return []
        dropped = set(skip)
        for name, value in headers.items():
            if name.lower() == 'connection':
                dropped.update(token.strip().lower() for token in value.split(','))
        return [(name, value) for name, value in headers.items()
                if name.lower() not in dropped]

    def _generate_cache_key(self, api_request: APIRequest) -> str:
        """Generate the response-cache key for a request.

        The key covers the method, path, query parameters (independent of
        their order) and the authenticated user id, so a response cached for
        one user is never served to another.
        """
        key_data = json.dumps(
            [
                api_request.method,
                api_request.path,
                sorted((api_request.query_params or {}).items()),
                api_request.user_id,
            ],
            default=str,
        )
        return hashlib.sha256(key_data.encode('utf-8')).hexdigest()

    def get_gateway_stats(self) -> Dict[str, Any]:
        """Get gateway statistics from the monitor."""
        return self.monitor.get_stats()

    def get_route_info(self, path: str, method: str) -> Optional[Dict[str, Any]]:
        """Describe the route for *method* (any case) and *path*, or return None."""
        route = self.routes.get(f"{method.upper()}:{path}")
        if route is None:
            return None
        return {
            "path": route.path,
            "method": route.method.value,
            "target_service": route.target_service,
            "target_path": route.target_path,
            "auth_type": route.auth_type.value,
            "rate_limit": route.rate_limit,
            "timeout": route.timeout,
            "cache_ttl": route.cache_ttl,
            "enabled": route.enabled
        }

    def list_routes(self) -> List[Dict[str, Any]]:
        """List all registered routes."""
        return [
            self.get_route_info(route.path, route.method.value)
            for route in self.routes.values()
        ]

    def _cache_cleanup_loop(self):
        """Purge expired cache entries every minute while the gateway is active."""
        while self.gateway_active:
            try:
                self.cache_manager.cleanup()
                time.sleep(60)  # Cleanup every minute

            except Exception as e:
                self.logger.error(f"Cache cleanup error: {e}")
                time.sleep(120)

    def _metrics_collection_loop(self):
        """Recompute derived metrics every 30 seconds while the gateway is active."""
        while self.gateway_active:
            try:
                self.monitor.collect_metrics()
                time.sleep(30)  # Collect metrics every 30 seconds

            except Exception as e:
                self.logger.error(f"Metrics collection error: {e}")
                time.sleep(60)


class Router:
    """Request router for API gateway.

    Routes are keyed by ``"METHOD:path"``. A route path may contain
    ``{name}`` placeholder segments, each matching exactly one non-empty
    request path segment. For example, ``/users/{id}`` matches ``/users/42``
    but not ``/users/`` or ``/users/42/orders``. Routes whose ``enabled`` flag
    is false are never returned.
    """

    def __init__(self):
        self.logger = logging.getLogger('tusklang.apigateway.router')
        self.routes = {}

    def add_route(self, route_config: "RouteConfig"):
        """Add (or replace) a route, keyed by its method and path."""
        route_key = f"{route_config.method.value}:{route_config.path}"
        self.routes[route_key] = route_config

    def find_route(self, method: str, path: str) -> Optional["RouteConfig"]:
        """Return the enabled route for *method* and *path*, or None.

        An exact path match wins. Otherwise the first registered route (in
        insertion order) whose pattern matches is returned.
        """
        exact = self.routes.get(f"{method}:{path}")
        if exact is not None and exact.enabled:
            return exact

        for route in self.routes.values():
            if not route.enabled or route.method.value != method:
                continue
            if self._match_path(route.path, path):
                return route

        return None

    def _match_path(self, route_path: str, request_path: str) -> bool:
        """Match *request_path* against *route_path*, honouring ``{param}`` segments."""
        if route_path == request_path:
            return True

        route_parts = route_path.split('/')
        request_parts = request_path.split('/')
        if len(route_parts) != len(request_parts):
            return False

        for pattern, actual in zip(route_parts, request_parts):
            if self._is_placeholder(pattern):
                # A placeholder must bind to a non-empty segment.
                if not actual:
                    return False
            elif pattern != actual:
                return False
        return True

    @staticmethod
    def _is_placeholder(segment: str) -> bool:
        """True if *segment* looks like ``{name}``."""
        return len(segment) > 2 and segment.startswith('{') and segment.endswith('}')


class Authenticator:
    """Authentication handler for API gateway.

    Each ``_authenticate_*`` strategy returns a dict with ``"success"`` and,
    on success, ``"user_id"``; on failure, ``"message"``. Every strategy
    fails closed: a request is accepted only when its credentials are
    verified against data configured on this instance.

    Args:
        jwt_secret: HS256 secret used to verify JWT bearer tokens. The default
            is a publicly known placeholder. Replace it in any real
            deployment, or anyone can mint valid tokens.
        api_keys: Mapping of API key -> user id.
        basic_users: Mapping of Basic-auth username -> password.
        oauth2_validator: Callable receiving an OAuth2 bearer token and
            returning the user id, or a falsy value if the token is invalid.
            It may return an awaitable.
    """

    def __init__(self, jwt_secret: str = "your-secret-key",
                 api_keys: Optional[Dict[str, str]] = None,
                 basic_users: Optional[Dict[str, str]] = None,
                 oauth2_validator: Optional[Callable[[str], Any]] = None):
        self.logger = logging.getLogger('tusklang.apigateway.authenticator')
        self.api_keys = dict(api_keys) if api_keys else {}
        self.basic_users = dict(basic_users) if basic_users else {}
        self.oauth2_validator = oauth2_validator
        self.jwt_secret = jwt_secret

    async def authenticate(self, api_request: "APIRequest", auth_type: "AuthType") -> Dict[str, Any]:
        """Authenticate *api_request* with the strategy for *auth_type*.

        Unexpected errors are logged and reported as a failed authentication.
        """
        try:
            if auth_type == AuthType.API_KEY:
                return await self._authenticate_api_key(api_request)
            elif auth_type == AuthType.JWT:
                return await self._authenticate_jwt(api_request)
            elif auth_type == AuthType.OAUTH2:
                return await self._authenticate_oauth2(api_request)
            elif auth_type == AuthType.BASIC:
                return await self._authenticate_basic(api_request)
            else:
                return {"success": False, "message": "Unsupported authentication type"}

        except Exception as e:
            self.logger.error(f"Authentication error: {e}")
            return {"success": False, "message": "Authentication failed"}

    @staticmethod
    def _get_header(api_request: "APIRequest", name: str) -> Optional[str]:
        """Return header *name* from the request, matching the name case-insensitively.

        The request's headers may be a plain dict whose keys keep the case the
        client sent, so a direct ``.get()`` can miss e.g. ``authorization``.
        """
        headers = api_request.headers or {}
        value = headers.get(name)
        if value is not None:
            return value
        wanted = name.lower()
        for key, candidate in headers.items():
            if key.lower() == wanted:
                return candidate
        return None

    @classmethod
    def _get_bearer_token(cls, api_request: "APIRequest") -> Optional[str]:
        """Extract the token from ``Authorization: Bearer <token>``, or return None.

        The scheme name is matched case-insensitively, per RFC 7235.
        """
        auth_header = cls._get_header(api_request, 'Authorization')
        if not auth_header:
            return None
        scheme, _, token = auth_header.strip().partition(' ')
        if scheme.lower() != 'bearer':
            return None
        return token.strip() or None

    async def _authenticate_api_key(self, api_request: "APIRequest") -> Dict[str, Any]:
        """Authenticate using the ``X-API-Key`` header."""
        api_key = self._get_header(api_request, 'X-API-Key')
        if not api_key:
            return {"success": False, "message": "API key required"}

        if api_key in self.api_keys:
            return {"success": True, "user_id": self.api_keys[api_key]}
        return {"success": False, "message": "Invalid API key"}

    async def _authenticate_jwt(self, api_request: "APIRequest") -> Dict[str, Any]:
        """Authenticate using an HS256 JWT bearer token; user id comes from its ``user_id`` claim."""
        token = self._get_bearer_token(api_request)
        if not token:
            return {"success": False, "message": "Bearer token required"}

        try:
            payload = jwt.decode(token, self.jwt_secret, algorithms=['HS256'])
        except jwt.InvalidTokenError:
            return {"success": False, "message": "Invalid token"}
        return {"success": True, "user_id": payload.get('user_id')}

    async def _authenticate_oauth2(self, api_request: "APIRequest") -> Dict[str, Any]:
        """Authenticate an OAuth2 bearer token via the configured validator."""
        token = self._get_bearer_token(api_request)
        if not token:
            return {"success": False, "message": "Bearer token required"}

        if self.oauth2_validator is None:
            # Fail closed: with no validator there is no way to verify the
            # token, and accepting it would let any caller through.
            self.logger.warning("OAuth2 route requested but no oauth2_validator is configured")
            return {"success": False, "message": "OAuth2 token validation not configured"}

        user_id = self.oauth2_validator(token)
        if inspect.isawaitable(user_id):
            user_id = await user_id
        if not user_id:
            return {"success": False, "message": "Invalid token"}
        return {"success": True, "user_id": user_id}

    async def _authenticate_basic(self, api_request: "APIRequest") -> Dict[str, Any]:
        """Authenticate ``Authorization: Basic <base64(user:password)>`` against basic_users."""
        auth_header = self._get_header(api_request, 'Authorization')
        scheme, _, encoded = (auth_header or '').strip().partition(' ')
        if scheme.lower() != 'basic':
            return {"success": False, "message": "Basic authentication required"}

        try:
            # binascii.Error and UnicodeDecodeError are both ValueErrors.
            decoded = base64.b64decode(encoded.strip(), validate=True).decode('utf-8')
        except ValueError:
            return {"success": False, "message": "Invalid credentials"}

        username, separator, password = decoded.partition(':')
        expected = self.basic_users.get(username)
        if not separator or expected is None:
            return {"success": False, "message": "Invalid credentials"}
        # Constant-time comparison so response timing does not leak the password.
        if not hmac.compare_digest(password.encode('utf-8'), expected.encode('utf-8')):
            return {"success": False, "message": "Invalid credentials"}
        return {"success": True, "user_id": username}


class RateLimiter:
    """Sliding-window rate limiter for API gateway.

    Each (client IP, route path) pair may make at most the route's
    ``rate_limit`` requests within any ``window_seconds`` period (60 s by
    default). Paths without a configured limit are never throttled.
    """

    def __init__(self, window_seconds: float = 60.0):
        self.logger = logging.getLogger('tusklang.apigateway.ratelimiter')
        self.window_seconds = window_seconds
        self.rate_limits = {}
        # Unbounded deques on purpose: a fixed maxlen silently evicts
        # timestamps and makes any limit above that size unenforceable. Each
        # deque still stays bounded by its route's rate_limit, because a
        # timestamp is only appended while the window holds fewer entries.
        self.request_counts = defaultdict(deque)

    def add_route(self, path: str, rate_limit: int):
        """Set the per-client request limit for *path*."""
        self.rate_limits[path] = rate_limit

    def check_rate_limit(self, client_ip: str, path: str) -> bool:
        """Record a request and return True if it is within the limit, else False."""
        rate_limit = self.rate_limits.get(path)
        if rate_limit is None:
            return True

        key = f"{client_ip}:{path}"
        # Monotonic clock: wall-clock jumps must not freeze or flush windows.
        now = time.monotonic()
        requests = self.request_counts[key]

        # Drop timestamps that have slid out of the window.
        while requests and now - requests[0] > self.window_seconds:
            requests.popleft()

        if len(requests) >= rate_limit:
            return False

        requests.append(now)
        return True


class CacheManager:
    """In-memory TTL cache for API gateway responses.

    The gateway reads and writes entries from its request handler and
    purges them from a separate cleanup thread. Every access to the
    underlying dict is therefore guarded by a lock, so iteration in
    cleanup() cannot race with inserts and two threads cannot delete the
    same key.
    """

    def __init__(self):
        self.logger = logging.getLogger('tusklang.apigateway.cache')
        self.cache = {}
        self._lock = threading.Lock()

    def get(self, key: str) -> Optional[Dict[str, Any]]:
        """Return the cached data for *key*, or None if absent or expired."""
        now = time.monotonic()
        with self._lock:
            cached_item = self.cache.get(key)
            if cached_item is None:
                return None
            if now < cached_item["expires_at"]:
                return cached_item["data"]
            # Lazily evict the expired entry.
            del self.cache[key]
            return None

    def set(self, key: str, data: Dict[str, Any]):
        """Cache *data* under *key* for ``data["ttl"]`` seconds (default 300)."""
        ttl = data.get("ttl", 300)  # Default 5 minutes
        with self._lock:
            self.cache[key] = {
                "data": data,
                "expires_at": time.monotonic() + ttl
            }

    def cleanup(self):
        """Remove all expired cache entries."""
        now = time.monotonic()
        with self._lock:
            expired_keys = [
                key for key, item in self.cache.items()
                if now >= item["expires_at"]
            ]
            for key in expired_keys:
                del self.cache[key]


class APIMonitor:
    """Request statistics for the API gateway.

    The request handler calls record_request(), while collect_metrics() is
    called from a separate background thread. All shared state is therefore
    guarded by a lock.
    """

    # Length of the sliding window behind "requests_per_minute".
    _RATE_WINDOW_SECONDS = 60.0

    def __init__(self):
        self.logger = logging.getLogger('tusklang.apigateway.monitor')
        self.stats = {
            "total_requests": 0,
            "successful_requests": 0,
            "failed_requests": 0,
            "average_response_time": 0.0,
            "requests_per_minute": 0,
            "error_rate": 0.0
        }
        # Response times of the most recent 1000 requests (for the average).
        self.request_times = deque(maxlen=1000)
        # Monotonic arrival times of the requests inside the rate window.
        self.request_timestamps = deque()
        self._lock = threading.Lock()

    def record_request(self, api_request: "APIRequest", api_response: "APIResponse"):
        """Record one handled request.

        Status codes below 400 count as successful; all others count as
        failed. *api_request* is accepted for interface compatibility but not
        inspected.
        """
        now = time.monotonic()
        with self._lock:
            self.stats["total_requests"] += 1

            if api_response.status_code < 400:
                self.stats["successful_requests"] += 1
            else:
                self.stats["failed_requests"] += 1

            self.request_times.append(api_response.response_time)
            self.request_timestamps.append(now)
            # Prune here too so the deque stays bounded between collections.
            self._prune_timestamps(now)

    def _prune_timestamps(self, now: float):
        """Drop arrival times older than the rate window. The caller must hold the lock."""
        cutoff = now - self._RATE_WINDOW_SECONDS
        while self.request_timestamps and self.request_timestamps[0] < cutoff:
            self.request_timestamps.popleft()

    def collect_metrics(self):
        """Recompute the derived statistics.

        Updates average response time, error rate, and requests in the last
        minute.
        """
        now = time.monotonic()
        with self._lock:
            if self.request_times:
                self.stats["average_response_time"] = sum(self.request_times) / len(self.request_times)

            if self.stats["total_requests"] > 0:
                self.stats["error_rate"] = self.stats["failed_requests"] / self.stats["total_requests"]

            self._prune_timestamps(now)
            self.stats["requests_per_minute"] = len(self.request_timestamps)

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot copy of the current statistics."""
        with self._lock:
            return self.stats.copy()


# Global API gateway instance used by the module-level helper functions below.
# NOTE: constructing it at import time also starts the gateway's background
# cache-cleanup and metrics-collection threads (APIGateway.__init__ calls
# _start_background_processes).
api_gateway: APIGateway = APIGateway()


def add_gateway_route(path: str, method: Union[str, RouteMethod], target_service: str, target_path: str,
                      auth_type: Union[str, AuthType] = "none", rate_limit: int = 100, timeout: int = 30,
                      cache_ttl: int = 300, enabled: bool = True) -> bool:
    """Add a route to the module-level API gateway.

    Args:
        path: Gateway request path to match.
        method: HTTP method, as a RouteMethod member or a case-insensitive
            string such as ``"get"``.
        target_service: Upstream ``host[:port]``. Requests are proxied to
            ``http://{target_service}{target_path}``.
        target_path: Path on the upstream service.
        auth_type: AuthType member or case-insensitive string
            (``"none"``, ``"api_key"``, ``"jwt"``, ``"oauth2"``, ``"basic"``).
        rate_limit: Max requests per client IP within the rate limiter's window.
        timeout: Total upstream request timeout in seconds.
        cache_ttl: Response cache lifetime in seconds; values <= 0 disable caching.
        enabled: Route enable flag stored on the RouteConfig.

    Returns:
        True if the route was registered, False if registration failed.

    Raises:
        ValueError: If *method* or *auth_type* is not a recognised value.
    """
    # Accept enum members directly; strings are normalised before lookup.
    if isinstance(method, RouteMethod):
        method_enum = method
    else:
        method_enum = RouteMethod(str(method).strip().upper())

    if isinstance(auth_type, AuthType):
        auth_enum = auth_type
    else:
        auth_enum = AuthType(str(auth_type).strip().lower())

    route_config = RouteConfig(
        path=path,
        method=method_enum,
        target_service=target_service,
        target_path=target_path,
        auth_type=auth_enum,
        rate_limit=rate_limit,
        timeout=timeout,
        cache_ttl=cache_ttl,
        enabled=enabled
    )

    return api_gateway.add_route(route_config)


def add_gateway_middleware(middleware: Callable):
    """Register *middleware* on the module-level gateway.

    Delegates to ``APIGateway.add_middleware`` on ``api_gateway``.
    """
    gateway = api_gateway
    gateway.add_middleware(middleware)


async def start_gateway(host: str = "0.0.0.0", port: int = 8080) -> bool:
    """Start the module-level gateway listening on *host*:*port*.

    Returns True once the server is up, False if startup failed.
    """
    started = await api_gateway.start_gateway(host, port)
    return started


async def stop_gateway():
    """Shut down the module-level gateway's server, if one is running."""
    gateway = api_gateway
    await gateway.stop_gateway()


def get_gateway_stats() -> Dict[str, Any]:
    """Return a copy of the module-level gateway's request statistics."""
    stats = api_gateway.get_gateway_stats()
    return stats


def get_route_info(path: str, method: str) -> Optional[Dict[str, Any]]:
    """Describe the route registered for *method* and *path* on the module-level gateway.

    Returns a dict of the route's settings, or None when no such route exists.
    """
    info = api_gateway.get_route_info(path, method)
    return info


def list_gateway_routes() -> List[Dict[str, Any]]:
    """Return a settings dict for every route on the module-level gateway."""
    routes = api_gateway.list_routes()
    return routes


if __name__ == "__main__":
    banner = "API Gateway for TuskLang Python SDK"
    print(banner)
    print("=" * 50)

    # Step 1: register a few demo routes, then show what was registered.
    print("\n1. Testing API Gateway Setup:")

    demo_routes = (
        ("/api/users", "GET", "user-service:8001", "/users", "api_key", 100),
        ("/api/auth", "POST", "auth-service:8002", "/auth", "jwt", 50),
        ("/api/payments", "POST", "payment-service:8003", "/payments", "oauth2", 20),
    )
    for route_args in demo_routes:
        add_gateway_route(*route_args)

    registered = list_gateway_routes()
    print(f"  Added routes: {len(registered)}")
    for info in registered:
        print(f"    - {info['method']} {info['path']} -> {info['target_service']}")

    # Step 2: start the server, print its stats, then shut it down again.
    print("\n2. Testing Gateway Startup:")

    async def test_gateway():
        started = await start_gateway("localhost", 8080)
        print(f"  Gateway started: {started}")
        if not started:
            return

        print(f"  Gateway stats: {get_gateway_stats()}")

        await stop_gateway()
        print("  Gateway stopped")

    asyncio.run(test_gateway())

    print("\nAPI Gateway testing completed!")