"""
ETCD Operator - Distributed Key-Value Store Integration
Production-ready ETCD integration with cluster connectivity, transactions, and real-time watching.
"""

import asyncio
import json
import logging
import ssl
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Union
from urllib.parse import urlparse

# ETCD Support
try:
    import etcd3
    from etcd3.client import EtcdTokenCallCredentials
    from etcd3.exceptions import ConnectionFailedError, Etcd3Exception
    ETCD_AVAILABLE = True
except ImportError:
    ETCD_AVAILABLE = False
    print("etcd3 library not available. @etcd operator will be limited.")

# Additional asyncio support
try:
    import aioetcd3
    ASYNC_ETCD_AVAILABLE = True
except ImportError:
    ASYNC_ETCD_AVAILABLE = False
    print("aioetcd3 not available. Async operations limited.")

# NOTE: basicConfig() runs as an import-time side effect. It attaches a default
# handler to the root logger at INFO level, and does nothing when the root
# logger already has handlers configured.
logging.basicConfig(level=logging.INFO)
# Module-level logger shared by every class in this file.
logger = logging.getLogger(__name__)

@dataclass
class EtcdConfig:
    """Connection settings for a single ETCD endpoint.

    ``EtcdConnectionManager._create_client`` always forwards ``host``,
    ``port`` and ``timeout`` to ``etcd3.client``; the remaining fields are
    forwarded only under the conditions noted next to each one.
    """
    host: str = "localhost"
    port: int = 2379
    # TLS material; forwarded to the client only when ``enable_ssl`` is true.
    ca_cert: Optional[str] = None
    cert_key: Optional[str] = None
    cert_cert: Optional[str] = None
    # Passed through as the client's ``timeout`` (presumably seconds -- TODO confirm).
    timeout: int = 30
    # Credentials; forwarded only when both ``user`` and ``password`` are set.
    user: Optional[str] = None
    password: Optional[str] = None
    # Extra gRPC options; forwarded only when the mapping is non-empty.
    grpc_options: Dict[str, Any] = field(default_factory=dict)
    # Gates forwarding of ca_cert / cert_key / cert_cert.
    enable_ssl: bool = False

@dataclass
class EtcdKeyValue:
    """A key, its value, and the revision metadata reported for it.

    ``EtcdOperator.get`` and ``EtcdOperator.get_prefix`` build instances from
    the ``(value, metadata)`` pairs returned by the client.
    """
    key: str
    # Decoded as UTF-8 when the client returns bytes; otherwise stored as-is.
    value: Any
    # The counters below are copied from the client's metadata object.
    create_revision: int = 0
    mod_revision: int = 0
    version: int = 0
    # Copied from ``metadata.lease_id``; presumably 0 means "no lease" -- TODO confirm.
    lease_id: int = 0

@dataclass
class EtcdWatchEvent:
    """Change notification passed to the callbacks of ``watch`` / ``watch_prefix``."""
    event_type: str  # either "PUT" or "DELETE"
    key: str
    # UTF-8 decoded value, or None when the event carries an empty value.
    value: Any
    # UTF-8 decoded previous value, or None when the event carries none.
    prev_value: Any = None
    # Revision counters copied from the underlying client event.
    create_revision: int = 0
    mod_revision: int = 0
    version: int = 0

@dataclass
class EtcdTransaction:
    """Declarative description of a transaction for ``EtcdOperator.transaction``.

    Every entry is a plain dict; the keys listed below are the ones
    ``EtcdOperator.transaction`` reads.
    """
    # Guards. Each dict has 'type' (one of 'compare_value', 'compare_version',
    # 'compare_create_revision', 'compare_mod_revision'), 'key', 'operator',
    # and the operand under 'value', 'version' or 'revision' respectively.
    # NOTE(review): the accepted 'operator' spellings are not defined in this
    # dataclass -- confirm against the transaction implementation.
    conditions: List[Dict[str, Any]] = field(default_factory=list)
    # Operations for the "guards hold" branch. Each dict has 'type' ('put' or
    # 'delete') and 'key'; 'put' additionally needs 'value'.
    success_ops: List[Dict[str, Any]] = field(default_factory=list)
    # Operations for the "guards failed" branch; same shape as success_ops.
    failure_ops: List[Dict[str, Any]] = field(default_factory=list)

@dataclass
class EtcdLease:
    """Summary of a lease returned by ``EtcdOperator.lease_grant``."""
    # Lease identifier reported by the client.
    id: int
    # TTL that was requested when the lease was granted.
    ttl: int
    # TTL reported back by the client's lease object.
    granted_ttl: int
    # Not populated anywhere in this file; stays empty unless a caller fills it.
    keys: List[str] = field(default_factory=list)

class EtcdConnectionManager:
    """Manages ETCD connections with failover and load balancing."""
    
    def __init__(self, endpoints: List[str], config: EtcdConfig):
        self.endpoints = endpoints
        self.config = config
        self.clients = []
        self.current_client_index = 0
        self.connection_pool_size = min(len(endpoints), 5)
        self.health_check_interval = 30
        self._lock = asyncio.Lock()
        
    async def initialize_connections(self) -> bool:
        """Initialize connections to all endpoints."""
        if not ETCD_AVAILABLE:
            logger.error("ETCD client library not available")
            return False
            
        try:
            for endpoint in self.endpoints:
                parsed = urlparse(f"http://{endpoint}")
                host = parsed.hostname or endpoint.split(':')[0]
                port = parsed.port or int(endpoint.split(':')[1]) if ':' in endpoint else 2379
                
                client_config = EtcdConfig(
                    host=host,
                    port=port,
                    ca_cert=self.config.ca_cert,
                    cert_key=self.config.cert_key,
                    cert_cert=self.config.cert_cert,
                    timeout=self.config.timeout,
                    user=self.config.user,
                    password=self.config.password,
                    enable_ssl=self.config.enable_ssl
                )
                
                client = self._create_client(client_config)
                if client:
                    self.clients.append(client)
                    
            if not self.clients:
                logger.error("No ETCD clients could be initialized")
                return False
                
            logger.info(f"Initialized {len(self.clients)} ETCD connections")
            return True
            
        except Exception as e:
            logger.error(f"Error initializing ETCD connections: {str(e)}")
            return False
    
    def _create_client(self, config: EtcdConfig) -> Optional[Any]:
        """Create ETCD client with configuration."""
        try:
            client_args = {
                'host': config.host,
                'port': config.port,
                'timeout': config.timeout
            }
            
            # SSL Configuration
            if config.enable_ssl:
                client_args['ca_cert'] = config.ca_cert
                client_args['cert_key'] = config.cert_key
                client_args['cert_cert'] = config.cert_cert
            
            # Authentication
            if config.user and config.password:
                client_args['user'] = config.user
                client_args['password'] = config.password
            
            # GRPC Options
            if config.grpc_options:
                client_args['grpc_options'] = config.grpc_options
                
            client = etcd3.client(**client_args)
            
            # Test connection
            client.status()
            logger.info(f"Connected to ETCD: {config.host}:{config.port}")
            return client
            
        except Exception as e:
            logger.warning(f"Failed to connect to ETCD {config.host}:{config.port}: {str(e)}")
            return None
    
    async def get_healthy_client(self) -> Optional[Any]:
        """Get a healthy client with automatic failover."""
        async with self._lock:
            attempts = 0
            max_attempts = len(self.clients) * 2
            
            while attempts < max_attempts:
                try:
                    client = self.clients[self.current_client_index]
                    
                    # Health check
                    client.status()
                    return client
                    
                except Exception as e:
                    logger.warning(f"Client {self.current_client_index} unhealthy: {str(e)}")
                    self.current_client_index = (self.current_client_index + 1) % len(self.clients)
                    attempts += 1
                    await asyncio.sleep(0.1)
            
            logger.error("No healthy ETCD clients available")
            return None

class EtcdOperator:
    """@etcd operator implementation with full production features."""
    
    def __init__(self):
        self.connection_manager: Optional[EtcdConnectionManager] = None
        self.watch_tasks = {}
        self.lease_manager = {}
        self.operation_stats = {
            'get_operations': 0,
            'put_operations': 0,
            'delete_operations': 0,
            'watch_operations': 0,
            'transaction_operations': 0,
            'lease_operations': 0
        }
        self._executor = ThreadPoolExecutor(max_workers=10)
    
    async def connect(self, endpoints: Union[str, List[str]], config: Optional[EtcdConfig] = None) -> bool:
        """Connect to ETCD cluster."""
        if isinstance(endpoints, str):
            endpoints = [endpoints]
        
        if config is None:
            config = EtcdConfig()
        
        self.connection_manager = EtcdConnectionManager(endpoints, config)
        success = await self.connection_manager.initialize_connections()
        
        if success:
            # Start background tasks
            asyncio.create_task(self._lease_keepalive_task())
            asyncio.create_task(self._health_monitor_task())
        
        return success
    
    async def get(self, key: str, serializable: bool = False) -> Optional[EtcdKeyValue]:
        """Get value for key."""
        if not self.connection_manager:
            raise RuntimeError("Not connected to ETCD cluster")
        
        client = await self.connection_manager.get_healthy_client()
        if not client:
            raise ConnectionError("No healthy ETCD clients available")
        
        try:
            loop = asyncio.get_event_loop()
            result = await loop.run_in_executor(
                self._executor,
                lambda: client.get(key, serializable=serializable)
            )
            
            self.operation_stats['get_operations'] += 1
            
            if result[0] is not None:
                metadata = result[1]
                return EtcdKeyValue(
                    key=key,
                    value=result[0].decode('utf-8') if isinstance(result[0], bytes) else result[0],
                    create_revision=metadata.create_revision,
                    mod_revision=metadata.mod_revision,
                    version=metadata.version,
                    lease_id=metadata.lease_id
                )
            return None
            
        except Exception as e:
            logger.error(f"Error getting key {key}: {str(e)}")
            raise
    
    async def get_prefix(self, prefix: str, limit: Optional[int] = None) -> List[EtcdKeyValue]:
        """Get all keys with prefix."""
        if not self.connection_manager:
            raise RuntimeError("Not connected to ETCD cluster")
        
        client = await self.connection_manager.get_healthy_client()
        if not client:
            raise ConnectionError("No healthy ETCD clients available")
        
        try:
            loop = asyncio.get_event_loop()
            results = await loop.run_in_executor(
                self._executor,
                lambda: client.get_prefix(prefix, limit=limit)
            )
            
            kvs = []
            for value, metadata in results:
                kvs.append(EtcdKeyValue(
                    key=metadata.key.decode('utf-8'),
                    value=value.decode('utf-8') if isinstance(value, bytes) else value,
                    create_revision=metadata.create_revision,
                    mod_revision=metadata.mod_revision,
                    version=metadata.version,
                    lease_id=metadata.lease_id
                ))
            
            self.operation_stats['get_operations'] += 1
            return kvs
            
        except Exception as e:
            logger.error(f"Error getting prefix {prefix}: {str(e)}")
            raise
    
    async def put(self, key: str, value: Any, lease_id: Optional[int] = None, prev_kv: bool = False) -> bool:
        """Put key-value pair."""
        if not self.connection_manager:
            raise RuntimeError("Not connected to ETCD cluster")
        
        client = await self.connection_manager.get_healthy_client()
        if not client:
            raise ConnectionError("No healthy ETCD clients available")
        
        try:
            # Serialize value if needed
            if not isinstance(value, (str, bytes)):
                value = json.dumps(value)
            
            loop = asyncio.get_event_loop()
            result = await loop.run_in_executor(
                self._executor,
                lambda: client.put(key, value, lease=lease_id, prev_kv=prev_kv)
            )
            
            self.operation_stats['put_operations'] += 1
            return True
            
        except Exception as e:
            logger.error(f"Error putting key {key}: {str(e)}")
            raise
    
    async def delete(self, key: str, prev_kv: bool = False) -> bool:
        """Delete key."""
        if not self.connection_manager:
            raise RuntimeError("Not connected to ETCD cluster")
        
        client = await self.connection_manager.get_healthy_client()
        if not client:
            raise ConnectionError("No healthy ETCD clients available")
        
        try:
            loop = asyncio.get_event_loop()
            result = await loop.run_in_executor(
                self._executor,
                lambda: client.delete(key, prev_kv=prev_kv)
            )
            
            self.operation_stats['delete_operations'] += 1
            return result
            
        except Exception as e:
            logger.error(f"Error deleting key {key}: {str(e)}")
            raise
    
    async def delete_prefix(self, prefix: str) -> int:
        """Delete all keys with prefix."""
        if not self.connection_manager:
            raise RuntimeError("Not connected to ETCD cluster")
        
        client = await self.connection_manager.get_healthy_client()
        if not client:
            raise ConnectionError("No healthy ETCD clients available")
        
        try:
            loop = asyncio.get_event_loop()
            result = await loop.run_in_executor(
                self._executor,
                lambda: client.delete_prefix(prefix)
            )
            
            self.operation_stats['delete_operations'] += 1
            return result
            
        except Exception as e:
            logger.error(f"Error deleting prefix {prefix}: {str(e)}")
            raise
    
    async def watch(self, key: str, callback: Callable[[EtcdWatchEvent], None]) -> str:
        """Watch key for changes."""
        if not self.connection_manager:
            raise RuntimeError("Not connected to ETCD cluster")
        
        client = await self.connection_manager.get_healthy_client()
        if not client:
            raise ConnectionError("No healthy ETCD clients available")
        
        watch_id = f"watch_{key}_{int(time.time())}"
        
        async def watch_task():
            try:
                loop = asyncio.get_event_loop()
                
                def create_watch():
                    return client.watch(key)
                
                watch_gen = await loop.run_in_executor(self._executor, create_watch)
                
                for event in watch_gen:
                    try:
                        watch_event = EtcdWatchEvent(
                            event_type="PUT" if event.type == "PUT" else "DELETE",
                            key=event.key.decode('utf-8'),
                            value=event.value.decode('utf-8') if event.value else None,
                            prev_value=event.prev_value.decode('utf-8') if event.prev_value else None,
                            create_revision=event.create_revision,
                            mod_revision=event.mod_revision,
                            version=event.version
                        )
                        
                        await asyncio.get_event_loop().run_in_executor(
                            None, callback, watch_event
                        )
                        
                    except Exception as e:
                        logger.error(f"Error in watch callback: {str(e)}")
                        
            except Exception as e:
                logger.error(f"Watch task error: {str(e)}")
            finally:
                if watch_id in self.watch_tasks:
                    del self.watch_tasks[watch_id]
        
        task = asyncio.create_task(watch_task())
        self.watch_tasks[watch_id] = task
        
        self.operation_stats['watch_operations'] += 1
        return watch_id
    
    async def watch_prefix(self, prefix: str, callback: Callable[[EtcdWatchEvent], None]) -> str:
        """Watch prefix for changes."""
        if not self.connection_manager:
            raise RuntimeError("Not connected to ETCD cluster")
        
        client = await self.connection_manager.get_healthy_client()
        if not client:
            raise ConnectionError("No healthy ETCD clients available")
        
        watch_id = f"watch_prefix_{prefix}_{int(time.time())}"
        
        async def watch_task():
            try:
                loop = asyncio.get_event_loop()
                
                def create_watch():
                    return client.watch_prefix(prefix)
                
                watch_gen = await loop.run_in_executor(self._executor, create_watch)
                
                for event in watch_gen:
                    try:
                        watch_event = EtcdWatchEvent(
                            event_type="PUT" if event.type == "PUT" else "DELETE",
                            key=event.key.decode('utf-8'),
                            value=event.value.decode('utf-8') if event.value else None,
                            prev_value=event.prev_value.decode('utf-8') if event.prev_value else None,
                            create_revision=event.create_revision,
                            mod_revision=event.mod_revision,
                            version=event.version
                        )
                        
                        await asyncio.get_event_loop().run_in_executor(
                            None, callback, watch_event
                        )
                        
                    except Exception as e:
                        logger.error(f"Error in watch callback: {str(e)}")
                        
            except Exception as e:
                logger.error(f"Watch prefix task error: {str(e)}")
            finally:
                if watch_id in self.watch_tasks:
                    del self.watch_tasks[watch_id]
        
        task = asyncio.create_task(watch_task())
        self.watch_tasks[watch_id] = task
        
        self.operation_stats['watch_operations'] += 1
        return watch_id
    
    async def cancel_watch(self, watch_id: str) -> bool:
        """Cancel watch operation."""
        if watch_id in self.watch_tasks:
            task = self.watch_tasks[watch_id]
            task.cancel()
            del self.watch_tasks[watch_id]
            return True
        return False
    
    async def transaction(self, transaction: EtcdTransaction) -> Dict[str, Any]:
        """Execute transaction."""
        if not self.connection_manager:
            raise RuntimeError("Not connected to ETCD cluster")
        
        client = await self.connection_manager.get_healthy_client()
        if not client:
            raise ConnectionError("No healthy ETCD clients available")
        
        try:
            txn = client.txn()
            
            # Add conditions
            for condition in transaction.conditions:
                if condition['type'] == 'compare_value':
                    txn = txn.compare_value(condition['key'], condition['operator'], condition['value'])
                elif condition['type'] == 'compare_version':
                    txn = txn.compare_version(condition['key'], condition['operator'], condition['version'])
                elif condition['type'] == 'compare_create_revision':
                    txn = txn.compare_create_revision(condition['key'], condition['operator'], condition['revision'])
                elif condition['type'] == 'compare_mod_revision':
                    txn = txn.compare_mod_revision(condition['key'], condition['operator'], condition['revision'])
            
            # Add success operations
            success_ops = []
            for op in transaction.success_ops:
                if op['type'] == 'put':
                    success_ops.append(txn.put(op['key'], op['value']))
                elif op['type'] == 'delete':
                    success_ops.append(txn.delete(op['key']))
            
            # Add failure operations
            failure_ops = []
            for op in transaction.failure_ops:
                if op['type'] == 'put':
                    failure_ops.append(txn.put(op['key'], op['value']))
                elif op['type'] == 'delete':
                    failure_ops.append(txn.delete(op['key']))
            
            loop = asyncio.get_event_loop()
            result = await loop.run_in_executor(
                self._executor,
                lambda: txn.and_then(*success_ops).or_else(*failure_ops).commit()
            )
            
            self.operation_stats['transaction_operations'] += 1
            
            return {
                'succeeded': result[0],
                'responses': result[1] if result[1] else []
            }
            
        except Exception as e:
            logger.error(f"Error executing transaction: {str(e)}")
            raise
    
    async def lease_grant(self, ttl: int) -> EtcdLease:
        """Grant lease."""
        if not self.connection_manager:
            raise RuntimeError("Not connected to ETCD cluster")
        
        client = await self.connection_manager.get_healthy_client()
        if not client:
            raise ConnectionError("No healthy ETCD clients available")
        
        try:
            loop = asyncio.get_event_loop()
            lease = await loop.run_in_executor(
                self._executor,
                lambda: client.lease(ttl)
            )
            
            lease_obj = EtcdLease(
                id=lease.id,
                ttl=ttl,
                granted_ttl=lease.granted_ttl
            )
            
            self.lease_manager[lease.id] = lease
            self.operation_stats['lease_operations'] += 1
            
            return lease_obj
            
        except Exception as e:
            logger.error(f"Error granting lease: {str(e)}")
            raise
    
    async def lease_revoke(self, lease_id: int) -> bool:
        """Revoke lease."""
        if not self.connection_manager:
            raise RuntimeError("Not connected to ETCD cluster")
        
        client = await self.connection_manager.get_healthy_client()
        if not client:
            raise ConnectionError("No healthy ETCD clients available")
        
        try:
            loop = asyncio.get_event_loop()
            await loop.run_in_executor(
                self._executor,
                lambda: client.revoke_lease(lease_id)
            )
            
            if lease_id in self.lease_manager:
                del self.lease_manager[lease_id]
            
            return True
            
        except Exception as e:
            logger.error(f"Error revoking lease {lease_id}: {str(e)}")
            raise
    
    async def lease_keepalive(self, lease_id: int) -> bool:
        """Keep lease alive."""
        if lease_id in self.lease_manager:
            lease = self.lease_manager[lease_id]
            try:
                loop = asyncio.get_event_loop()
                await loop.run_in_executor(
                    self._executor,
                    lambda: lease.refresh()
                )
                return True
            except Exception as e:
                logger.error(f"Error keeping lease {lease_id} alive: {str(e)}")
                return False
        return False
    
    async def get_cluster_status(self) -> Dict[str, Any]:
        """Get cluster status."""
        if not self.connection_manager:
            raise RuntimeError("Not connected to ETCD cluster")
        
        client = await self.connection_manager.get_healthy_client()
        if not client:
            raise ConnectionError("No healthy ETCD clients available")
        
        try:
            loop = asyncio.get_event_loop()
            status = await loop.run_in_executor(
                self._executor,
                lambda: client.status()
            )
            
            return {
                'cluster_id': status.cluster_id,
                'member_id': status.member_id,
                'leader': status.leader,
                'raft_term': status.raft_term,
                'raft_index': status.raft_index,
                'raft_applied_index': status.raft_applied_index,
                'db_size': status.db_size,
                'db_size_in_use': status.db_size_in_use,
                'version': status.version
            }
            
        except Exception as e:
            logger.error(f"Error getting cluster status: {str(e)}")
            raise
    
    async def _lease_keepalive_task(self):
        """Background task to keep leases alive."""
        while True:
            try:
                for lease_id in list(self.lease_manager.keys()):
                    await self.lease_keepalive(lease_id)
                await asyncio.sleep(10)  # Keep alive every 10 seconds
            except Exception as e:
                logger.error(f"Lease keepalive task error: {str(e)}")
                await asyncio.sleep(30)
    
    async def _health_monitor_task(self):
        """Background task to monitor cluster health."""
        while True:
            try:
                if self.connection_manager:
                    client = await self.connection_manager.get_healthy_client()
                    if client:
                        status = await self.get_cluster_status()
                        logger.info(f"ETCD cluster healthy - Leader: {status.get('leader')}")
                    else:
                        logger.warning("No healthy ETCD clients available")
                await asyncio.sleep(30)  # Health check every 30 seconds
            except Exception as e:
                logger.error(f"Health monitor task error: {str(e)}")
                await asyncio.sleep(60)
    
    def get_statistics(self) -> Dict[str, Any]:
        """Get operation statistics."""
        return {
            'operations': self.operation_stats.copy(),
            'active_watches': len(self.watch_tasks),
            'active_leases': len(self.lease_manager),
            'connected': self.connection_manager is not None
        }
    
    async def close(self):
        """Close connections and cleanup."""
        # Cancel all watch tasks
        for watch_id, task in list(self.watch_tasks.items()):
            task.cancel()
            del self.watch_tasks[watch_id]
        
        # Revoke all leases
        for lease_id in list(self.lease_manager.keys()):
            await self.lease_revoke(lease_id)
        
        # Close executor
        self._executor.shutdown(wait=True)
        
        logger.info("ETCD operator closed")

# Public API of this module: the operator plus the dataclasses used in its
# signatures. EtcdConnectionManager is not listed here.
__all__ = ['EtcdOperator', 'EtcdConfig', 'EtcdKeyValue', 'EtcdWatchEvent', 'EtcdTransaction', 'EtcdLease'] 