"""
TuskLang Python SDK - Real-Time Stream Processor (g9.3)
Apache Kafka integration with windowing operations and stream processing
"""

import asyncio
import json
import logging
import time
import threading
from collections import defaultdict, deque
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Dict, List, Optional, Any, Callable, Union, Tuple
import statistics
import hashlib

try:
    from kafka import KafkaProducer, KafkaConsumer, KafkaAdminClient
    from kafka.admin import ConfigResource, ConfigResourceType, NewTopic
    from kafka.errors import KafkaError, TopicAlreadyExistsError
    KAFKA_AVAILABLE = True
except ImportError:
    KAFKA_AVAILABLE = False


class WindowType(Enum):
    """Windowing strategies that StreamProcessor dispatches on."""
    # Fixed-size buckets; a record lands in the one bucket containing its timestamp.
    TUMBLING = "tumbling"
    # Fixed-size, possibly overlapping windows spaced by WindowSpec.advance_ms.
    SLIDING = "sliding"
    # Per-key windows that keep growing while records arrive within the gap.
    SESSION = "session"
    # A single unbounded window per window name; never completed automatically.
    GLOBAL = "global"


class AggregationType(Enum):
    """Aggregations that StreamProcessor._compute_window_result can apply."""
    # Number of records in the window.
    COUNT = "count"
    # SUM / AVG / MIN / MAX operate on each record payload's 'value' entry
    # (a missing entry is treated as 0).
    SUM = "sum"
    AVG = "average"
    MIN = "min"
    MAX = "max"
    # Number of distinct payloads, compared by their str() representation.
    DISTINCT_COUNT = "distinct_count"
    # Payload of the first / last record stored in the window.
    FIRST = "first"
    LAST = "last"


@dataclass
class StreamRecord:
    """A single stream event together with its transport metadata.

    Equality is the dataclass-generated field-by-field comparison, while the
    hash is derived from ``key`` and ``timestamp`` only.
    """
    key: str  # record key; session windows group records by this value
    value: Dict[str, Any]  # event payload
    timestamp: datetime = field(default_factory=datetime.now)  # defaults to construction time
    partition: int = 0  # source partition (filled from the Kafka message when consumed)
    offset: int = 0  # source offset (filled from the Kafka message when consumed)
    headers: Dict[str, str] = field(default_factory=dict)  # fresh dict per instance
    
    def __hash__(self) -> int:
        # Hash the "<key>:<epoch seconds>" string so records can live in sets/dict keys.
        return hash(f"{self.key}:{self.timestamp.timestamp()}")


@dataclass
class WindowSpec:
    """Configuration of one named window.

    For SLIDING windows created without an explicit ``advance_ms``, the
    advance defaults to half the window size.
    """
    window_type: WindowType  # which windowing strategy to apply
    size_ms: int  # window length in ms; session windows use it as the inactivity gap
    advance_ms: Optional[int] = None  # spacing between sliding windows, in ms
    grace_period_ms: int = 1000  # extra wait after a window's end before it is emitted
    retention_ms: int = 86400000  # how long completed results are kept (24 hours)
    
    def __post_init__(self) -> None:
        # Fill in the sliding-window default only when the caller gave no advance.
        if self.window_type == WindowType.SLIDING and self.advance_ms is None:
            self.advance_ms = self.size_ms // 2  # Default 50% overlap


@dataclass
class WindowResult:
    """Outcome of closing a window, as built by StreamProcessor._compute_window_result."""
    window_id: str  # identifier of the window (or session) that was closed
    start_time: datetime  # the window's 'start' value
    end_time: datetime  # the window's 'end' value; also used for retention pruning
    record_count: int  # number of records the window held
    aggregated_value: Any  # result of the aggregation named by aggregation_type
    aggregation_type: AggregationType  # aggregation that produced aggregated_value
    records: List[StreamRecord] = field(default_factory=list)  # copy of the window's records


class StreamProcessor:
    """Core stream processing engine"""
    
    def __init__(self):
        self.processors: Dict[str, Callable[[StreamRecord], StreamRecord]] = {}
        self.aggregators: Dict[str, Callable[[List[StreamRecord]], Any]] = {}
        self.windows: Dict[str, WindowSpec] = {}
        self.window_states: Dict[str, Dict] = defaultdict(dict)
        self.logger = logging.getLogger(__name__)
    
    def add_processor(self, name: str, processor_func: Callable[[StreamRecord], StreamRecord]):
        """Add stream processor function"""
        self.processors[name] = processor_func
        self.logger.info(f"Added processor: {name}")
    
    def add_aggregator(self, name: str, aggregator_func: Callable[[List[StreamRecord]], Any]):
        """Add aggregation function"""
        self.aggregators[name] = aggregator_func
        self.logger.info(f"Added aggregator: {name}")
    
    def add_window(self, name: str, window_spec: WindowSpec):
        """Add window specification"""
        self.windows[name] = window_spec
        self.window_states[name] = {
            'active_windows': {},
            'completed_windows': deque(maxlen=1000)
        }
        self.logger.info(f"Added window: {name} ({window_spec.window_type.value})")
    
    def process_record(self, record: StreamRecord) -> List[StreamRecord]:
        """Process single record through all processors"""
        results = [record]
        
        for name, processor in self.processors.items():
            try:
                processed_results = []
                for rec in results:
                    processed_rec = processor(rec)
                    if processed_rec:
                        processed_results.append(processed_rec)
                results = processed_results
            except Exception as e:
                self.logger.error(f"Processor {name} failed: {e}")
        
        return results
    
    def update_windows(self, record: StreamRecord) -> List[WindowResult]:
        """Update all windows with new record"""
        completed_windows = []
        
        for window_name, window_spec in self.windows.items():
            window_results = self._update_window(window_name, window_spec, record)
            completed_windows.extend(window_results)
        
        return completed_windows
    
    def _update_window(self, window_name: str, window_spec: WindowSpec, 
                      record: StreamRecord) -> List[WindowResult]:
        """Update specific window with record"""
        state = self.window_states[window_name]
        completed = []
        
        if window_spec.window_type == WindowType.TUMBLING:
            completed = self._update_tumbling_window(window_name, window_spec, record, state)
        elif window_spec.window_type == WindowType.SLIDING:
            completed = self._update_sliding_window(window_name, window_spec, record, state)
        elif window_spec.window_type == WindowType.SESSION:
            completed = self._update_session_window(window_name, window_spec, record, state)
        elif window_spec.window_type == WindowType.GLOBAL:
            completed = self._update_global_window(window_name, window_spec, record, state)
        
        return completed
    
    def _update_tumbling_window(self, window_name: str, window_spec: WindowSpec,
                               record: StreamRecord, state: Dict) -> List[WindowResult]:
        """Update tumbling window"""
        window_start = self._get_window_start(record.timestamp, window_spec.size_ms)
        window_end = window_start + timedelta(milliseconds=window_spec.size_ms)
        window_id = f"{window_name}_{int(window_start.timestamp() * 1000)}"
        
        # Add record to window
        if window_id not in state['active_windows']:
            state['active_windows'][window_id] = {
                'start': window_start,
                'end': window_end,
                'records': [],
                'last_update': datetime.now()
            }
        
        state['active_windows'][window_id]['records'].append(record)
        state['active_windows'][window_id]['last_update'] = datetime.now()
        
        # Check for completed windows
        completed = []
        current_time = datetime.now()
        grace_period = timedelta(milliseconds=window_spec.grace_period_ms)
        
        to_remove = []
        for wid, window_data in state['active_windows'].items():
            if current_time > window_data['end'] + grace_period:
                # Window is complete
                window_result = self._compute_window_result(
                    wid, window_data, AggregationType.COUNT
                )
                completed.append(window_result)
                state['completed_windows'].append(window_result)
                to_remove.append(wid)
        
        for wid in to_remove:
            del state['active_windows'][wid]
        
        return completed
    
    def _update_sliding_window(self, window_name: str, window_spec: WindowSpec,
                              record: StreamRecord, state: Dict) -> List[WindowResult]:
        """Update sliding window"""
        completed = []
        
        # Create multiple overlapping windows
        advance_ms = window_spec.advance_ms or window_spec.size_ms
        window_start = self._get_window_start(record.timestamp, advance_ms)
        
        # Generate all windows this record belongs to
        record_time_ms = int(record.timestamp.timestamp() * 1000)
        window_size_ms = window_spec.size_ms
        
        # Find all windows that should contain this record
        for i in range(window_size_ms // advance_ms + 1):
            potential_start_ms = record_time_ms - (i * advance_ms)
            potential_start = datetime.fromtimestamp(potential_start_ms / 1000)
            potential_end = potential_start + timedelta(milliseconds=window_size_ms)
            
            if potential_start <= record.timestamp < potential_end:
                window_id = f"{window_name}_sliding_{int(potential_start.timestamp() * 1000)}"
                
                if window_id not in state['active_windows']:
                    state['active_windows'][window_id] = {
                        'start': potential_start,
                        'end': potential_end,
                        'records': [],
                        'last_update': datetime.now()
                    }
                
                state['active_windows'][window_id]['records'].append(record)
                state['active_windows'][window_id]['last_update'] = datetime.now()
        
        # Check for completed windows
        current_time = datetime.now()
        grace_period = timedelta(milliseconds=window_spec.grace_period_ms)
        
        to_remove = []
        for wid, window_data in state['active_windows'].items():
            if current_time > window_data['end'] + grace_period:
                window_result = self._compute_window_result(
                    wid, window_data, AggregationType.COUNT
                )
                completed.append(window_result)
                state['completed_windows'].append(window_result)
                to_remove.append(wid)
        
        for wid in to_remove:
            del state['active_windows'][wid]
        
        return completed
    
    def _update_session_window(self, window_name: str, window_spec: WindowSpec,
                              record: StreamRecord, state: Dict) -> List[WindowResult]:
        """Update session window (gap-based)"""
        completed = []
        session_gap = timedelta(milliseconds=window_spec.size_ms)
        
        # Find existing session for this key or create new one
        key_sessions = state.get('key_sessions', {})
        if record.key not in key_sessions:
            key_sessions[record.key] = {}
        
        sessions = key_sessions[record.key]
        
        # Find session this record belongs to
        target_session = None
        for session_id, session_data in sessions.items():
            last_record_time = max(r.timestamp for r in session_data['records'])
            if record.timestamp - last_record_time <= session_gap:
                target_session = session_id
                break
        
        if target_session is None:
            # Create new session
            target_session = f"{window_name}_session_{record.key}_{int(record.timestamp.timestamp() * 1000)}"
            sessions[target_session] = {
                'start': record.timestamp,
                'end': record.timestamp,
                'records': [],
                'last_update': datetime.now()
            }
        
        # Add record to session
        sessions[target_session]['records'].append(record)
        sessions[target_session]['end'] = max(sessions[target_session]['end'], record.timestamp)
        sessions[target_session]['last_update'] = datetime.now()
        
        # Check for completed sessions
        current_time = datetime.now()
        grace_period = timedelta(milliseconds=window_spec.grace_period_ms)
        
        to_remove = []
        for session_id, session_data in sessions.items():
            if current_time - session_data['end'] > session_gap + grace_period:
                window_result = self._compute_window_result(
                    session_id, session_data, AggregationType.COUNT
                )
                completed.append(window_result)
                state['completed_windows'].append(window_result)
                to_remove.append(session_id)
        
        for session_id in to_remove:
            del sessions[session_id]
        
        state['key_sessions'] = key_sessions
        return completed
    
    def _update_global_window(self, window_name: str, window_spec: WindowSpec,
                             record: StreamRecord, state: Dict) -> List[WindowResult]:
        """Update global window (no time bounds)"""
        window_id = f"{window_name}_global"
        
        if window_id not in state['active_windows']:
            state['active_windows'][window_id] = {
                'start': datetime.min,
                'end': datetime.max,
                'records': [],
                'last_update': datetime.now()
            }
        
        state['active_windows'][window_id]['records'].append(record)
        state['active_windows'][window_id]['last_update'] = datetime.now()
        
        # Global windows don't complete automatically
        return []
    
    def _get_window_start(self, timestamp: datetime, window_size_ms: int) -> datetime:
        """Get window start time for timestamp"""
        timestamp_ms = int(timestamp.timestamp() * 1000)
        window_start_ms = (timestamp_ms // window_size_ms) * window_size_ms
        return datetime.fromtimestamp(window_start_ms / 1000)
    
    def _compute_window_result(self, window_id: str, window_data: Dict,
                              aggregation_type: AggregationType) -> WindowResult:
        """Compute aggregation result for window"""
        records = window_data['records']
        
        if aggregation_type == AggregationType.COUNT:
            aggregated_value = len(records)
        elif aggregation_type == AggregationType.SUM:
            aggregated_value = sum(r.value.get('value', 0) for r in records)
        elif aggregation_type == AggregationType.AVG:
            values = [r.value.get('value', 0) for r in records]
            aggregated_value = statistics.mean(values) if values else 0
        elif aggregation_type == AggregationType.MIN:
            values = [r.value.get('value', 0) for r in records]
            aggregated_value = min(values) if values else None
        elif aggregation_type == AggregationType.MAX:
            values = [r.value.get('value', 0) for r in records]
            aggregated_value = max(values) if values else None
        elif aggregation_type == AggregationType.DISTINCT_COUNT:
            unique_values = set(str(r.value) for r in records)
            aggregated_value = len(unique_values)
        elif aggregation_type == AggregationType.FIRST:
            aggregated_value = records[0].value if records else None
        elif aggregation_type == AggregationType.LAST:
            aggregated_value = records[-1].value if records else None
        else:
            aggregated_value = len(records)  # Default to count
        
        return WindowResult(
            window_id=window_id,
            start_time=window_data['start'],
            end_time=window_data['end'],
            record_count=len(records),
            aggregated_value=aggregated_value,
            aggregation_type=aggregation_type,
            records=records.copy()
        )


class KafkaStreamProcessor:
    """Kafka-based stream processor"""
    
    def __init__(self, bootstrap_servers: str = "localhost:9092",
                 group_id: str = "tusklang-stream-processor"):
        if not KAFKA_AVAILABLE:
            raise ImportError("Kafka not available - install kafka-python")
        
        self.bootstrap_servers = bootstrap_servers
        self.group_id = group_id
        self.stream_processor = StreamProcessor()
        self.consumers: Dict[str, KafkaConsumer] = {}
        self.producers: Dict[str, KafkaProducer] = {}
        self.admin_client = KafkaAdminClient(
            bootstrap_servers=bootstrap_servers,
            client_id='stream-processor-admin'
        )
        self.logger = logging.getLogger(__name__)
        self.is_running = False
        self.processing_tasks = []
    
    def create_topic(self, topic_name: str, num_partitions: int = 3, 
                    replication_factor: int = 1) -> bool:
        """Create Kafka topic"""
        try:
            topic = NewTopic(
                name=topic_name,
                num_partitions=num_partitions,
                replication_factor=replication_factor
            )
            self.admin_client.create_topics([topic])
            self.logger.info(f"Created topic: {topic_name}")
            return True
        except TopicAlreadyExistsError:
            self.logger.info(f"Topic already exists: {topic_name}")
            return True
        except Exception as e:
            self.logger.error(f"Failed to create topic {topic_name}: {e}")
            return False
    
    def add_source(self, topic_name: str, auto_offset_reset: str = 'latest') -> bool:
        """Add Kafka source topic"""
        try:
            consumer = KafkaConsumer(
                topic_name,
                bootstrap_servers=self.bootstrap_servers,
                group_id=f"{self.group_id}-{topic_name}",
                auto_offset_reset=auto_offset_reset,
                value_deserializer=lambda m: json.loads(m.decode('utf-8')),
                key_deserializer=lambda m: m.decode('utf-8') if m else None,
                enable_auto_commit=True
            )
            self.consumers[topic_name] = consumer
            self.logger.info(f"Added source topic: {topic_name}")
            return True
        except Exception as e:
            self.logger.error(f"Failed to add source {topic_name}: {e}")
            return False
    
    def add_sink(self, topic_name: str) -> bool:
        """Add Kafka sink topic"""
        try:
            producer = KafkaProducer(
                bootstrap_servers=self.bootstrap_servers,
                value_serializer=lambda v: json.dumps(v).encode('utf-8'),
                key_serializer=lambda k: k.encode('utf-8') if k else None
            )
            self.producers[topic_name] = producer
            self.logger.info(f"Added sink topic: {topic_name}")
            return True
        except Exception as e:
            self.logger.error(f"Failed to add sink {topic_name}: {e}")
            return False
    
    def add_processor(self, name: str, processor_func: Callable[[StreamRecord], StreamRecord]):
        """Add stream processor"""
        self.stream_processor.add_processor(name, processor_func)
    
    def add_window(self, name: str, window_type: WindowType, size_ms: int, 
                   advance_ms: Optional[int] = None, grace_period_ms: int = 1000):
        """Add windowing operation"""
        window_spec = WindowSpec(
            window_type=window_type,
            size_ms=size_ms,
            advance_ms=advance_ms,
            grace_period_ms=grace_period_ms
        )
        self.stream_processor.add_window(name, window_spec)
    
    async def start(self):
        """Start stream processing"""
        self.is_running = True
        self.logger.info("Starting Kafka stream processing")
        
        # Start processing task for each source topic
        for topic_name, consumer in self.consumers.items():
            task = asyncio.create_task(self._process_topic(topic_name, consumer))
            self.processing_tasks.append(task)
        
        # Start window management task
        window_task = asyncio.create_task(self._manage_windows())
        self.processing_tasks.append(window_task)
        
        await asyncio.gather(*self.processing_tasks)
    
    async def stop(self):
        """Stop stream processing"""
        self.is_running = False
        self.logger.info("Stopping Kafka stream processing")
        
        # Cancel processing tasks
        for task in self.processing_tasks:
            task.cancel()
        
        # Close consumers and producers
        for consumer in self.consumers.values():
            consumer.close()
        
        for producer in self.producers.values():
            producer.close()
        
        self.admin_client.close()
    
    async def _process_topic(self, topic_name: str, consumer: KafkaConsumer):
        """Process messages from specific topic"""
        self.logger.info(f"Started processing topic: {topic_name}")
        
        while self.is_running:
            try:
                message_batch = consumer.poll(timeout_ms=1000)
                
                for topic_partition, messages in message_batch.items():
                    for message in messages:
                        # Convert Kafka message to StreamRecord
                        record = StreamRecord(
                            key=message.key or "",
                            value=message.value,
                            timestamp=datetime.fromtimestamp(message.timestamp / 1000),
                            partition=message.partition,
                            offset=message.offset
                        )
                        
                        # Process through stream processor
                        processed_records = self.stream_processor.process_record(record)
                        
                        # Update windows
                        for processed_record in processed_records:
                            completed_windows = self.stream_processor.update_windows(processed_record)
                            
                            # Send window results to output topics
                            for window_result in completed_windows:
                                await self._output_window_result(window_result)
                        
                        # Send processed records to sink topics
                        for processed_record in processed_records:
                            await self._output_record(processed_record)
                
                await asyncio.sleep(0.01)  # Small delay
                
            except Exception as e:
                self.logger.error(f"Error processing topic {topic_name}: {e}")
                await asyncio.sleep(1)
    
    async def _manage_windows(self):
        """Manage window lifecycle and cleanup"""
        while self.is_running:
            try:
                # Cleanup expired windows
                current_time = datetime.now()
                
                for window_name, state in self.stream_processor.window_states.items():
                    window_spec = self.stream_processor.windows[window_name]
                    retention_period = timedelta(milliseconds=window_spec.retention_ms)
                    
                    # Remove old completed windows
                    cutoff_time = current_time - retention_period
                    state['completed_windows'] = deque([
                        w for w in state['completed_windows'] 
                        if w.end_time > cutoff_time
                    ], maxlen=1000)
                
                await asyncio.sleep(60)  # Check every minute
                
            except Exception as e:
                self.logger.error(f"Error managing windows: {e}")
                await asyncio.sleep(60)
    
    async def _output_window_result(self, window_result: WindowResult):
        """Output window result to appropriate sink"""
        output_topic = "window-results"
        
        if output_topic in self.producers:
            try:
                result_data = {
                    'window_id': window_result.window_id,
                    'start_time': window_result.start_time.isoformat(),
                    'end_time': window_result.end_time.isoformat(),
                    'record_count': window_result.record_count,
                    'aggregated_value': window_result.aggregated_value,
                    'aggregation_type': window_result.aggregation_type.value
                }
                
                producer = self.producers[output_topic]
                await asyncio.get_event_loop().run_in_executor(
                    None, 
                    lambda: producer.send(output_topic, value=result_data, key=window_result.window_id)
                )
                
            except Exception as e:
                self.logger.error(f"Error outputting window result: {e}")
    
    async def _output_record(self, record: StreamRecord):
        """Output processed record to sink topics"""
        # Determine output topic based on record or configuration
        output_topic = "processed-records"
        
        if output_topic in self.producers:
            try:
                record_data = {
                    'key': record.key,
                    'value': record.value,
                    'timestamp': record.timestamp.isoformat(),
                    'headers': record.headers
                }
                
                producer = self.producers[output_topic]
                await asyncio.get_event_loop().run_in_executor(
                    None,
                    lambda: producer.send(output_topic, value=record_data, key=record.key)
                )
                
            except Exception as e:
                self.logger.error(f"Error outputting record: {e}")
    
    def get_metrics(self) -> Dict[str, Any]:
        """Get stream processing metrics"""
        metrics = {
            'processors': len(self.stream_processor.processors),
            'windows': len(self.stream_processor.windows),
            'sources': len(self.consumers),
            'sinks': len(self.producers),
            'is_running': self.is_running
        }
        
        # Add window statistics
        window_stats = {}
        for window_name, state in self.stream_processor.window_states.items():
            window_stats[window_name] = {
                'active_windows': len(state.get('active_windows', {})),
                'completed_windows': len(state.get('completed_windows', []))
            }
        
        metrics['window_stats'] = window_stats
        return metrics


# Built-in processor factories (filter, map, enrich)
def filter_processor(condition: Callable[[StreamRecord], bool]) -> Callable[[StreamRecord], StreamRecord]:
    """Build a processor that keeps only records satisfying ``condition``.

    The returned processor yields the record itself when ``condition(record)``
    is truthy and ``None`` otherwise.
    """
    def processor(record: StreamRecord) -> Optional[StreamRecord]:
        if condition(record):
            return record
        return None

    return processor


def map_processor(transform_func: Callable[[Dict[str, Any]], Dict[str, Any]]) -> Callable[[StreamRecord], StreamRecord]:
    """Build a processor that replaces a record's payload with ``transform_func(payload)``.

    The returned processor emits a new StreamRecord; key, timestamp, partition,
    offset and headers are carried over from the input record.
    """
    def processor(record: StreamRecord) -> StreamRecord:
        new_payload = transform_func(record.value)
        carried_over = {
            'key': record.key,
            'timestamp': record.timestamp,
            'partition': record.partition,
            'offset': record.offset,
            'headers': record.headers,
        }
        return StreamRecord(value=new_payload, **carried_over)

    return processor


def enrich_processor(enrichment_func: Callable[[StreamRecord], Dict[str, Any]]) -> Callable[[StreamRecord], StreamRecord]:
    """Build a processor that merges extra fields into a record's payload.

    ``enrichment_func`` receives the whole record; the dict it returns is laid
    over the existing payload, so its keys win on conflict. The returned
    processor emits a new StreamRecord with the remaining fields carried over.
    """
    def processor(record: StreamRecord) -> StreamRecord:
        extra_fields = enrichment_func(record)
        carried_over = {
            'key': record.key,
            'timestamp': record.timestamp,
            'partition': record.partition,
            'offset': record.offset,
            'headers': record.headers,
        }
        return StreamRecord(value={**record.value, **extra_fields}, **carried_over)

    return processor


if __name__ == "__main__":
    async def main():
        """Run a small demonstration of the stream processing pipeline.

        Without kafka-python an in-memory pipeline is exercised; otherwise a
        KafkaStreamProcessor is configured against the default broker (but
        not started).
        """
        if not KAFKA_AVAILABLE:
            print("Kafka not available - using in-memory demonstration")

            # Demonstrate in-memory stream processing
            processor = StreamProcessor()

            # Add simple processors
            processor.add_processor("filter_positive",
                filter_processor(lambda r: r.value.get('value', 0) > 0))

            processor.add_processor("double_value",
                map_processor(lambda v: {'value': v.get('value', 0) * 2, 'original': v}))

            # Add windows
            processor.add_window("tumbling_5s", WindowSpec(WindowType.TUMBLING, 5000))
            processor.add_window("sliding_10s", WindowSpec(WindowType.SLIDING, 10000, 2000))

            # Process some test records
            test_records = [
                StreamRecord("key1", {"value": 10}, datetime.now()),
                StreamRecord("key2", {"value": -5}, datetime.now()),
                StreamRecord("key1", {"value": 20}, datetime.now()),
                StreamRecord("key3", {"value": 30}, datetime.now()),
            ]

            for record in test_records:
                processed = processor.process_record(record)

                # Window only what survived the processors, the same way
                # KafkaStreamProcessor._process_topic does; feeding the raw
                # record would count records the filter has dropped.
                windows = []
                for processed_record in processed:
                    windows.extend(processor.update_windows(processed_record))

                print(f"Processed: {[r.value for r in processed]}")
                print(f"Completed windows: {len(windows)}")

            print("g9.3: Stream Processor (in-memory demo) - COMPLETED ✅")

        else:
            # Real Kafka example
            stream_processor = KafkaStreamProcessor()

            # Create topics
            stream_processor.create_topic("input-stream")
            stream_processor.create_topic("processed-records")
            stream_processor.create_topic("window-results")

            # Add source and sinks
            stream_processor.add_source("input-stream")
            stream_processor.add_sink("processed-records")
            stream_processor.add_sink("window-results")

            # Add processors
            stream_processor.add_processor("filter",
                filter_processor(lambda r: r.value.get('amount', 0) > 100))

            stream_processor.add_processor("enrich",
                enrich_processor(lambda r: {'processed_at': datetime.now().isoformat()}))

            # Add windows
            stream_processor.add_window("sales_5min", WindowType.TUMBLING, 300000)  # 5 minutes
            stream_processor.add_window("activity_sliding", WindowType.SLIDING, 600000, 60000)  # 10min/1min

            print("Kafka stream processor configured")
            print("g9.3: Real-Time Stream Processor - COMPLETED ✅")

            # In production, you would call:
            # await stream_processor.start()

    asyncio.run(main())