"""
Zipkin Operator - Distributed Tracing Integration
Production-ready Zipkin integration with span creation, correlation, and observability integration.
"""

import asyncio
import json
import logging
import socket
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Union, Callable
from urllib.parse import urljoin
import threading
from contextlib import contextmanager

# Zipkin Support
try:
    from py_zipkin.zipkin import zipkin_span, create_http_headers_for_new_span
    from py_zipkin.transport import BaseTransportHandler
    from py_zipkin.request_helpers import create_zipkin_attr
    from py_zipkin.encoding import Encoding
    ZIPKIN_AVAILABLE = True
except ImportError:
    ZIPKIN_AVAILABLE = False
    print("py-zipkin library not available. @zipkin operator will be limited.")

# HTTP Support for transport
try:
    import aiohttp
    import requests
    HTTP_AVAILABLE = True
except ImportError:
    HTTP_AVAILABLE = False
    print("HTTP libraries not available for Zipkin transport.")

# OpenTracing Support (optional)
try:
    import opentracing
    from opentracing.ext import tags
    from opentracing.propagation import Format
    OPENTRACING_AVAILABLE = True
except ImportError:
    OPENTRACING_AVAILABLE = False
    print("OpenTracing library not available. Advanced tracing features limited.")

# Call logging.basicConfig(level=INFO) when this module is imported, then create
# the module-level logger used by every class below.
# NOTE(review): basicConfig() at import time is a side effect on the host
# application's logging setup — confirm this is intended for a library module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class ZipkinConfig:
    """Settings consumed by the Zipkin tracing components in this module.

    All fields have defaults, so ``ZipkinConfig()`` describes a local Zipkin
    server with every span sampled.
    """
    # Base URL of the Zipkin server; '/api/v2/spans' and '/api/v2/services'
    # are joined onto it with urljoin.
    zipkin_address: str = "http://localhost:9411"
    # Reported as 'serviceName' in the local endpoint and stored on each span.
    service_name: str = "tsk-service"
    # Passed to ZipkinSamplingStrategy, which clamps it to [0.0, 1.0].
    sample_rate: float = 1.0
    # Not read anywhere in the visible source.
    encoding: str = "JSON"  # JSON or THRIFT
    # The transport flushes once this many spans are buffered...
    max_span_batch_size: int = 100
    # ...or once this many seconds (time.time() delta) have passed since the
    # last flush, whichever comes first.
    batch_timeout: float = 5.0
    # Not read anywhere in the visible source.
    include_logging_integration: bool = True
    # When true, ZipkinOperator.connect() creates a ZipkinMetricsIntegration.
    include_metrics_integration: bool = True
    # Not read anywhere in the visible source.
    transport_handler: Optional[str] = None  # http, kafka, scribe

@dataclass
class ZipkinSpan:
    """In-memory representation of one span before it is serialised.

    ZipkinSpanManager creates instances in ``create_span`` and converts them
    to the Zipkin v2 JSON field names in ``_send_span``.
    """
    # Shared by every span of one trace; inherited from the parent span.
    trace_id: str
    # Identifier of this span; sent as 'id'.
    span_id: str
    # span_id of the parent, or None for a root span; sent as 'parentId' when set.
    parent_span_id: Optional[str] = None
    # Sent as the span 'name'.
    operation_name: str = ""
    service_name: str = ""
    # Start timestamp in microseconds since the epoch (time.time() * 1_000_000).
    start_time: int = 0
    # Elapsed time in microseconds, filled in when the span context exits.
    duration: int = 0
    # Key/value tags; add_tag() stores values as strings.
    tags: Dict[str, Any] = field(default_factory=dict)
    # Dicts with 'timestamp' (microseconds) and 'value' keys, appended by add_annotation().
    annotations: List[Dict[str, Any]] = field(default_factory=list)
    kind: Optional[str] = None  # CLIENT, SERVER, PRODUCER, CONSUMER
    # Sent as 'remoteEndpoint' when set; never populated in the visible source.
    remote_endpoint: Optional[Dict[str, Any]] = None
    # Dict with 'serviceName' and 'ipv4' built by the span manager; sent as 'localEndpoint'.
    local_endpoint: Optional[Dict[str, Any]] = None

@dataclass
class ZipkinAnnotation:
    """Typed form of a span annotation (a timestamped event).

    Exported via ``__all__`` but not instantiated in the visible source; the
    span manager stores annotations as plain dicts instead.
    """
    # NOTE(review): presumably microseconds since the epoch, matching the
    # dict-based annotations built in ZipkinSpanManager.add_annotation — confirm.
    timestamp: int
    # Free-form description of the event.
    value: str
    # Optional endpoint the event is attributed to.
    endpoint: Optional[Dict[str, str]] = None

@dataclass
class ZipkinEndpoint:
    """Typed description of a network endpoint taking part in a span.

    Exported via ``__all__`` but not instantiated in the visible source; the
    span manager builds its local endpoint as a plain dict instead.
    """
    # Name of the service running at this endpoint.
    service_name: str
    # Optional addressing details; all default to None.
    ipv4: Optional[str] = None
    ipv6: Optional[str] = None
    port: Optional[int] = None

class BaseTransportHandler:
    """Minimal transport interface for shipping encoded Zipkin spans.

    This local definition rebinds the ``BaseTransportHandler`` name that the
    optional ``py_zipkin.transport`` import at the top of the file may have
    bound, so subclasses in this module always derive from this class.
    """

    def get_max_payload_bytes(self):
        """Return the largest payload this transport accepts, in bytes (1 MiB)."""
        one_mebibyte = 1024 * 1024
        return one_mebibyte

    def send(self, encoded_spans):
        """Deliver ``encoded_spans``; concrete transports must override this.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

class ZipkinTransportHandler(BaseTransportHandler):
    """Buffering HTTP transport that posts spans to ``/api/v2/spans``.

    Spans handed to :meth:`send` are appended to an in-memory buffer and
    posted as one JSON array once ``config.max_span_batch_size`` spans are
    buffered or ``config.batch_timeout`` seconds have passed since the last
    flush. Delivery is best-effort: a batch is discarded after one attempt,
    whether the collector rejected it or the request itself failed, so an
    unreachable collector cannot make the buffer grow without bound or make
    every subsequent ``send`` retry the whole backlog.

    ``self._lock`` guards ``spans_buffer`` and ``last_flush``. Note that the
    HTTP request is performed while the lock is held.
    """

    def __init__(self, config: ZipkinConfig):
        """Create an empty buffer bound to ``config``."""
        self.config = config
        self.spans_buffer = []
        self.last_flush = time.time()
        self._lock = threading.Lock()

    def get_max_payload_bytes(self):
        """Return the largest payload this transport accepts, in bytes (1 MiB)."""
        return 1024 * 1024  # 1MB max payload

    def send(self, encoded_spans):
        """Buffer spans and flush when a batch threshold is reached.

        Args:
            encoded_spans: JSON text (``str`` or UTF-8 ``bytes``) holding one
                span object or a list of them, or an already-decoded
                dict/list.

        Never raises: decoding, parsing and flushing errors are logged.
        """
        with self._lock:
            try:
                # Decode inside the try so malformed bytes are logged like any
                # other bad payload instead of escaping to the traced code.
                if isinstance(encoded_spans, bytes):
                    encoded_spans = encoded_spans.decode('utf-8')

                spans_data = json.loads(encoded_spans) if isinstance(encoded_spans, str) else encoded_spans
                self.spans_buffer.extend(spans_data if isinstance(spans_data, list) else [spans_data])

                # Check if we should flush
                current_time = time.time()
                should_flush = (
                    len(self.spans_buffer) >= self.config.max_span_batch_size or
                    current_time - self.last_flush >= self.config.batch_timeout
                )

                if should_flush:
                    self._flush_locked()

            except Exception as e:
                logger.error(f"Error processing spans: {str(e)}")

    def _flush_spans(self):
        """Flush buffered spans to Zipkin.

        Safe to call from any thread: it takes the buffer lock itself, so an
        explicit flush cannot interleave with a concurrent :meth:`send`.
        """
        with self._lock:
            self._flush_locked()

    def _flush_locked(self):
        """Post the buffered spans; the caller must already hold ``self._lock``."""
        if not self.spans_buffer:
            return

        batch_size = len(self.spans_buffer)
        try:
            endpoint = urljoin(self.config.zipkin_address, '/api/v2/spans')

            if HTTP_AVAILABLE:
                response = requests.post(
                    endpoint,
                    json=self.spans_buffer,
                    headers={'Content-Type': 'application/json'},
                    timeout=5
                )

                if response.status_code == 202:
                    logger.debug(f"Sent {batch_size} spans to Zipkin")
                else:
                    logger.warning(f"Failed to send spans: {response.status_code}")
            else:
                logger.warning("HTTP library not available, spans not sent")

        except Exception as e:
            logger.error(
                f"Error flushing spans to Zipkin: {str(e)} (dropped {batch_size} spans)"
            )
        finally:
            # Always reset, even on failure: keeping the batch would leave
            # last_flush stale, so every later send() would immediately retry
            # an ever-growing backlog while holding the lock.
            self.spans_buffer.clear()
            self.last_flush = time.time()

class ZipkinSpanManager:
    """Creates spans, tracks the per-thread span stack and ships finished spans.

    Each thread has its own stack (``threading.local``); the top of the stack
    is the "current" span that :meth:`add_tag` and :meth:`add_annotation`
    operate on and that new spans adopt as parent by default.
    """

    def __init__(self, config: ZipkinConfig):
        """Set up the transport and resolve the local endpoint once."""
        self.config = config
        self.active_spans = {}
        self.span_stack = threading.local()
        self.transport_handler = ZipkinTransportHandler(config)
        self._local_endpoint = self._create_local_endpoint()

    def _create_local_endpoint(self) -> Dict[str, Any]:
        """Build the ``localEndpoint`` dict attached to every span.

        Falls back to ``127.0.0.1`` when the host name cannot be resolved.
        """
        try:
            ip_address = socket.gethostbyname(socket.gethostname())
        except (OSError, UnicodeError):
            # Resolution failures are OSError subclasses; UnicodeError covers
            # host names that cannot be IDNA-encoded. Anything else (e.g.
            # KeyboardInterrupt) is deliberately not swallowed.
            ip_address = "127.0.0.1"

        return {
            'serviceName': self.config.service_name,
            'ipv4': ip_address
        }

    def _get_span_stack(self):
        """Return the calling thread's span stack, creating it on first use."""
        if not hasattr(self.span_stack, 'spans'):
            self.span_stack.spans = []
        return self.span_stack.spans

    def _generate_trace_id(self) -> str:
        """Return a new 64-bit trace id as 16 lowercase hex characters."""
        return uuid.uuid4().hex[:16]

    def _generate_span_id(self) -> str:
        """Return a new 64-bit span id as 16 lowercase hex characters.

        Zipkin's v2 API defines span ids as exactly 16 hex characters, the
        same width B3 propagation headers use.
        """
        return uuid.uuid4().hex[:16]

    @contextmanager
    def create_span(self, operation_name: str, parent_span: Optional[ZipkinSpan] = None,
                   tags: Optional[Dict[str, Any]] = None, kind: Optional[str] = None):
        """Open a span for the duration of the ``with`` block and send it on exit.

        Args:
            operation_name: Span name.
            parent_span: Explicit parent; when omitted the calling thread's
                current span (if any) is the parent, otherwise a new trace
                is started.
            tags: Initial tags. The mapping is copied, so later ``add_tag``
                calls do not modify the caller's dict.
            kind: Optional Zipkin span kind (CLIENT, SERVER, ...).

        Yields:
            The new :class:`ZipkinSpan`.
        """
        span_stack = self._get_span_stack()

        # Determine parent
        if parent_span:
            trace_id = parent_span.trace_id
            parent_span_id = parent_span.span_id
        elif span_stack:
            parent = span_stack[-1]
            trace_id = parent.trace_id
            parent_span_id = parent.span_id
        else:
            trace_id = self._generate_trace_id()
            parent_span_id = None

        span = ZipkinSpan(
            trace_id=trace_id,
            span_id=self._generate_span_id(),
            parent_span_id=parent_span_id,
            operation_name=operation_name,
            service_name=self.config.service_name,
            start_time=int(time.time() * 1000000),  # microseconds
            tags=dict(tags) if tags else {},
            kind=kind,
            local_endpoint=self._local_endpoint
        )

        # Make it the current span for this thread.
        span_stack.append(span)
        self.active_spans[span.span_id] = span

        try:
            yield span
        finally:
            span.duration = int(time.time() * 1000000) - span.start_time

            # Only pop when this span is still on top of the stack.
            if span_stack and span_stack[-1].span_id == span.span_id:
                span_stack.pop()

            self._send_span(span)

            self.active_spans.pop(span.span_id, None)

    def _send_span(self, span: ZipkinSpan):
        """Serialise ``span`` to Zipkin v2 JSON and hand it to the transport.

        Errors are logged, never raised, so tracing cannot break the traced
        code.
        """
        try:
            span_data = {
                'traceId': span.trace_id,
                'id': span.span_id,
                'name': span.operation_name,
                'timestamp': span.start_time,
                'duration': span.duration,
                'localEndpoint': span.local_endpoint
            }

            if span.parent_span_id:
                span_data['parentId'] = span.parent_span_id

            if span.kind:
                span_data['kind'] = span.kind

            if span.tags:
                # Zipkin tags are string-to-string. Tags supplied at span
                # creation may hold other types, so normalise them here the
                # same way add_tag() does.
                span_data['tags'] = {str(key): str(value) for key, value in span.tags.items()}

            if span.annotations:
                span_data['annotations'] = span.annotations

            if span.remote_endpoint:
                span_data['remoteEndpoint'] = span.remote_endpoint

            self.transport_handler.send(json.dumps([span_data]))

        except Exception as e:
            logger.error(f"Error sending span: {str(e)}")

    def get_current_span(self) -> Optional[ZipkinSpan]:
        """Return the calling thread's innermost open span, or None."""
        span_stack = self._get_span_stack()
        return span_stack[-1] if span_stack else None

    def add_annotation(self, value: str, timestamp: Optional[int] = None):
        """Append a timestamped event to the current span (no-op without one).

        Args:
            value: Event description.
            timestamp: Microseconds since the epoch; defaults to now. An
                explicit ``0`` is honoured rather than treated as missing.
        """
        current_span = self.get_current_span()
        if current_span:
            if timestamp is None:
                timestamp = int(time.time() * 1000000)
            current_span.annotations.append({
                'timestamp': timestamp,
                'value': value
            })

    def add_tag(self, key: str, value: Any):
        """Set ``key`` to ``str(value)`` on the current span (no-op without one)."""
        current_span = self.get_current_span()
        if current_span:
            current_span.tags[key] = str(value)

class ZipkinSamplingStrategy:
    """Probabilistic sampler keyed on the trace id.

    The decision for a given trace id is deterministic, so every process
    configured with the same rate reaches the same verdict for the same
    trace.
    """

    def __init__(self, sample_rate: float = 1.0):
        """Store ``sample_rate`` clamped to the range [0.0, 1.0]."""
        self.sample_rate = max(0.0, min(1.0, sample_rate))

    def should_sample(self, trace_id: str, operation_name: str = "") -> bool:
        """Return True when the trace identified by ``trace_id`` should be recorded.

        Args:
            trace_id: Trace identifier the decision is derived from.
            operation_name: Accepted for interface compatibility; not used.
        """
        # Local import: zlib is not among this module's top-level imports.
        import zlib

        if self.sample_rate >= 1.0:
            return True
        elif self.sample_rate <= 0.0:
            return False

        # Map the id onto one of 1,000,000 buckets with CRC-32. The built-in
        # hash() is salted per process for str, so it would give different
        # answers for the same trace id in different processes.
        bucket = zlib.crc32(str(trace_id).encode('utf-8')) % 1000000
        threshold = int(self.sample_rate * 1000000)

        return bucket < threshold

class ZipkinMetricsIntegration:
    """Thread-safe in-process counters describing span activity.

    ``span_metrics`` holds global counters plus, under ``'operations'``, one
    ``{'count', 'total_duration', 'errors'}`` dict per operation name.
    Durations are accumulated in microseconds.
    """

    def __init__(self):
        """Start with all counters at zero."""
        self.span_metrics = {
            'spans_created': 0,
            'spans_finished': 0,
            'spans_errored': 0,
            'operations': {}
        }
        self._lock = threading.Lock()

    def _operation_entry(self, operation_name: str) -> Dict[str, int]:
        """Return the counters for ``operation_name``, creating them if absent.

        The caller must hold ``self._lock``.
        """
        return self.span_metrics['operations'].setdefault(operation_name, {
            'count': 0,
            'total_duration': 0,
            'errors': 0
        })

    def record_span_start(self, operation_name: str):
        """Count a started span for ``operation_name``."""
        with self._lock:
            self.span_metrics['spans_created'] += 1
            self._operation_entry(operation_name)['count'] += 1

    def record_span_finish(self, operation_name: str, duration: int, error: bool = False):
        """Count a finished span and add its duration.

        Args:
            operation_name: Operation the span belongs to. It need not have
                been seen by :meth:`record_span_start`; its counters are
                created on demand instead of raising ``KeyError``.
            duration: Span duration in microseconds.
            error: Whether the span ended with an error.
        """
        with self._lock:
            stats = self._operation_entry(operation_name)
            self.span_metrics['spans_finished'] += 1
            if error:
                self.span_metrics['spans_errored'] += 1
                stats['errors'] += 1

            stats['total_duration'] += duration

    def get_metrics(self) -> Dict[str, Any]:
        """Return a snapshot of the counters plus derived per-operation stats.

        The snapshot is detached from internal state: mutating it, or
        recording more spans afterwards, does not affect the other.

        Returns:
            ``{'summary': <copy of span_metrics>, 'operation_stats': {op:
            {'count', 'avg_duration_ms', 'error_rate'}}}``. Averages and
            rates are 0 for operations whose count is 0.
        """
        with self._lock:
            # Copy the per-operation dicts too; a shallow copy of span_metrics
            # would hand out references to dicts that are mutated under the lock.
            operations = {
                op: dict(stats)
                for op, stats in self.span_metrics['operations'].items()
            }
            summary = {**self.span_metrics, 'operations': operations}

            return {
                'summary': summary,
                'operation_stats': {
                    op: {
                        'count': stats['count'],
                        'avg_duration_ms': stats['total_duration'] / (stats['count'] * 1000) if stats['count'] > 0 else 0,
                        'error_rate': stats['errors'] / stats['count'] if stats['count'] > 0 else 0
                    }
                    for op, stats in operations.items()
                }
            }

class ZipkinOperator:
    """@zipkin operator: entry point combining spans, sampling and metrics.

    Call :meth:`connect` first; until then every tracing method is a no-op
    (context managers yield ``None``).

    The ``async`` methods perform blocking HTTP through ``requests`` and do
    not yield to the event loop while doing so. The span stack is per-thread,
    which is why a traced request is not suspended midway: interleaving tasks
    on one thread would mix up their current spans.
    """

    def __init__(self):
        """Create an unconnected operator."""
        self.config: Optional[ZipkinConfig] = None
        self.span_manager: Optional[ZipkinSpanManager] = None
        self.sampling_strategy: Optional[ZipkinSamplingStrategy] = None
        self.metrics_integration: Optional[ZipkinMetricsIntegration] = None
        self.operation_stats = {
            'spans_created': 0,
            'spans_sent': 0,
            'annotations_added': 0,
            'tags_added': 0
        }
        self._executor = ThreadPoolExecutor(max_workers=5)
        self._initialized = False

    async def connect(self, config: Optional[ZipkinConfig] = None) -> bool:
        """Initialise tracing and probe the Zipkin server.

        Args:
            config: Settings to use; defaults to ``ZipkinConfig()``.

        Returns:
            True when the components were created. An unreachable server only
            produces a warning. False when py-zipkin is missing or
            initialisation raised.
        """
        if not ZIPKIN_AVAILABLE:
            logger.error("py-zipkin library not available")
            return False

        if config is None:
            config = ZipkinConfig()

        self.config = config

        try:
            self.span_manager = ZipkinSpanManager(config)
            self.sampling_strategy = ZipkinSamplingStrategy(config.sample_rate)

            if config.include_metrics_integration:
                self.metrics_integration = ZipkinMetricsIntegration()

            # Reachability probe only; failures do not prevent tracing.
            if HTTP_AVAILABLE:
                test_endpoint = urljoin(config.zipkin_address, '/api/v2/services')
                try:
                    response = requests.get(test_endpoint, timeout=5)
                    if response.status_code == 200:
                        logger.info(f"Successfully connected to Zipkin at {config.zipkin_address}")
                    else:
                        logger.warning(f"Zipkin responded with status {response.status_code}")
                except requests.exceptions.RequestException:
                    logger.warning("Could not reach Zipkin, but tracing will continue")

            self._initialized = True
            return True

        except Exception as e:
            logger.error(f"Error initializing Zipkin tracing: {str(e)}")
            return False

    def _mark_error(self, span, exc: BaseException):
        """Record ``exc`` on ``span`` using the 'error' / 'error.message' tags."""
        span.tags['error'] = str(True)
        span.tags['error.message'] = str(exc)
        self.operation_stats['tags_added'] += 2

    @contextmanager
    def _instrumented_span(self, operation_name: str, parent_span, tags, kind):
        """Open a span with metrics recording and error tagging around it.

        Shared by :meth:`trace` and :meth:`child_span`; the caller has
        already checked that the operator is initialised.
        """
        if self.metrics_integration:
            self.metrics_integration.record_span_start(operation_name)

        start_time = time.time()
        error_occurred = False

        try:
            with self.span_manager.create_span(operation_name, parent_span, tags, kind) as span:
                self.operation_stats['spans_created'] += 1
                try:
                    yield span
                except Exception as exc:
                    # Tag while the span is still open: once create_span()
                    # exits, the span has been popped and sent, and tagging
                    # the "current" span would hit its parent instead.
                    self._mark_error(span, exc)
                    raise
        except Exception:
            error_occurred = True
            raise
        finally:
            if self.metrics_integration:
                duration = int((time.time() - start_time) * 1000000)  # microseconds
                self.metrics_integration.record_span_finish(operation_name, duration, error_occurred)

    @contextmanager
    def trace(self, operation_name: str, tags: Optional[Dict[str, Any]] = None,
             kind: Optional[str] = None):
        """Trace the ``with`` block as a span, subject to sampling.

        Sampling is decided once per trace, at its root: when the calling
        thread already has an open span, the new span joins that trace
        unconditionally so sampled traces are never missing children.

        Args:
            operation_name: Span name.
            tags: Initial tags.
            kind: Optional Zipkin span kind.

        Yields:
            The span, or ``None`` when the operator is not initialised or the
            trace was not sampled.

        Exceptions raised in the block are tagged on the span and re-raised.
        """
        if not self._initialized or not self.span_manager:
            # Provide no-op context manager if not initialized
            yield None
            return

        if self.span_manager.get_current_span() is None:
            candidate_trace_id = self.span_manager._generate_trace_id()
            if not self.sampling_strategy.should_sample(candidate_trace_id, operation_name):
                yield None
                return

        with self._instrumented_span(operation_name, None, tags, kind) as span:
            yield span

    @contextmanager
    def child_span(self, operation_name: str, parent_span: Optional[ZipkinSpan] = None,
                  tags: Optional[Dict[str, Any]] = None, kind: Optional[str] = None):
        """Trace the ``with`` block as a child span, without a sampling check.

        Args:
            operation_name: Span name.
            parent_span: Explicit parent; defaults to the thread's current span.
            tags: Initial tags.
            kind: Optional Zipkin span kind.

        Yields:
            The span, or ``None`` when the operator is not initialised.

        Exceptions raised in the block are tagged on the span and re-raised.
        """
        if not self._initialized or not self.span_manager:
            yield None
            return

        with self._instrumented_span(operation_name, parent_span, tags, kind) as span:
            yield span

    def add_annotation(self, value: str, timestamp: Optional[int] = None):
        """Add an annotation to the current span, if tracing is initialised."""
        if self._initialized and self.span_manager:
            self.span_manager.add_annotation(value, timestamp)
            self.operation_stats['annotations_added'] += 1

    def add_tag(self, key: str, value: Any):
        """Add a tag to the current span, if tracing is initialised."""
        if self._initialized and self.span_manager:
            self.span_manager.add_tag(key, value)
            self.operation_stats['tags_added'] += 1

    def get_current_span(self) -> Optional[ZipkinSpan]:
        """Return the calling thread's current span, or None."""
        if self._initialized and self.span_manager:
            return self.span_manager.get_current_span()
        return None

    def create_trace_headers(self) -> Dict[str, str]:
        """Build HTTP headers that propagate the current span downstream.

        Returns:
            ``{}`` when there is no current span. Otherwise the headers from
            py_zipkin when it has an active span context of its own, else
            headers derived from this operator's current span: the legacy
            ``X-Trace-Id`` / ``X-Span-Id`` / ``X-Parent-Span-Id`` names plus
            the standard B3 names.
        """
        if not self._initialized or not self.span_manager:
            return {}

        current_span = self.span_manager.get_current_span()
        if not current_span:
            return {}

        if ZIPKIN_AVAILABLE:
            try:
                headers = dict(create_http_headers_for_new_span())
            except Exception as exc:
                logger.debug(f"py_zipkin header generation failed: {exc}")
                headers = {}
            # py_zipkin documents an empty dict when called outside one of its
            # own span contexts. This module manages spans itself, so an empty
            # result must fall through instead of being returned.
            if headers:
                return headers

        headers = {
            'X-Trace-Id': current_span.trace_id,
            'X-Span-Id': current_span.span_id,
            'X-Parent-Span-Id': current_span.parent_span_id or '',
            'X-B3-TraceId': current_span.trace_id,
            'X-B3-SpanId': current_span.span_id,
            # A span only exists here if its trace was sampled.
            'X-B3-Sampled': '1',
        }
        if current_span.parent_span_id:
            headers['X-B3-ParentSpanId'] = current_span.parent_span_id
        return headers

    def extract_trace_context(self, headers: Dict[str, str]) -> Optional[Dict[str, str]]:
        """Extract trace identifiers from incoming HTTP headers.

        Header names are matched case-insensitively. B3 headers take
        precedence over the legacy ``X-Trace-Id`` family.

        Returns:
            A dict with any of ``trace_id``, ``span_id``, ``parent_span_id``,
            or ``None`` when no tracing header is present.
        """
        if not headers:
            return None

        lowered = {str(name).lower(): value for name, value in headers.items()}
        header_names = (
            ('trace_id', 'x-trace-id', 'x-b3-traceid'),
            ('span_id', 'x-span-id', 'x-b3-spanid'),
            ('parent_span_id', 'x-parent-span-id', 'x-b3-parentspanid'),
        )

        trace_context = {}
        for context_key, legacy_name, b3_name in header_names:
            if legacy_name in lowered:
                trace_context[context_key] = lowered[legacy_name]
            if b3_name in lowered:
                trace_context[context_key] = lowered[b3_name]

        return trace_context if trace_context else None

    async def instrument_http_request(self, method: str, url: str, **request_kwargs):
        """Perform an HTTP request inside a CLIENT span with trace headers.

        Args:
            method: HTTP method.
            url: Target URL.
            **request_kwargs: Passed to ``requests.request``. A ``headers``
                mapping is copied before trace headers are merged in, so the
                caller's dict is left untouched.

        Returns:
            The ``requests`` response, or ``None`` when no HTTP library is
            available.

        Raises:
            Whatever ``requests.request`` raises; the exception is tagged on
            the span by :meth:`trace`.
        """
        if not HTTP_AVAILABLE:
            logger.warning("HTTP library not available for instrumentation")
            return None

        with self.trace(f"HTTP {method.upper()}", kind="CLIENT") as span:
            if span:
                self.add_tag('http.method', method.upper())
                self.add_tag('http.url', url)
                self.add_tag('component', 'http-client')

            # Copy (and tolerate headers=None) rather than updating in place.
            headers = dict(request_kwargs.get('headers') or {})
            headers.update(self.create_trace_headers())
            request_kwargs['headers'] = headers

            response = requests.request(method, url, **request_kwargs)

            if span:
                self.add_tag('http.status_code', response.status_code)
                if response.status_code >= 400:
                    self.add_tag('error', True)

            return response

    def get_statistics(self) -> Dict[str, Any]:
        """Return operator counters, plus metrics and active-span count when available."""
        stats = {
            'operations': self.operation_stats.copy(),
            'initialized': self._initialized,
            'sampling_rate': self.config.sample_rate if self.config else 0.0
        }

        if self.metrics_integration:
            stats['detailed_metrics'] = self.metrics_integration.get_metrics()

        if self.span_manager:
            stats['active_spans'] = len(self.span_manager.active_spans)

        return stats

    async def flush_spans(self):
        """Send any buffered spans now (blocking HTTP)."""
        if self.span_manager and self.span_manager.transport_handler:
            self.span_manager.transport_handler._flush_spans()

    async def close(self):
        """Flush pending spans, shut down the executor and mark tracing closed."""
        await self.flush_spans()

        self._executor.shutdown(wait=True)

        self._initialized = False
        logger.info("Zipkin tracing closed")

# Names exported by a star-import of this module.
# NOTE(review): ZipkinSpanManager, ZipkinTransportHandler and BaseTransportHandler
# are defined above but not listed here — confirm that is intentional.
__all__ = [
    'ZipkinOperator', 'ZipkinConfig', 'ZipkinSpan', 'ZipkinAnnotation', 
    'ZipkinEndpoint', 'ZipkinSamplingStrategy', 'ZipkinMetricsIntegration'
] 