#!/usr/bin/env python3
"""
Distributed Processor for TuskLang Python SDK
=============================================
Advanced distributed processing and parallel execution capabilities

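This module provides distributed processing capabilities for the TuskLang Python SDK,
enabling parallel processing, load balancing, and distributed task execution across
multiple nodes and clusters.

Note: in the current implementation task execution is simulated in-process on the
scheduler thread; registered nodes are tracked for task assignment, load accounting
and health monitoring.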
"""

import asyncio
import threading
import multiprocessing
import queue
import time
import json
import hashlib
import uuid
from typing import Any, Dict, List, Optional, Callable, Union, Tuple
from dataclasses import dataclass, asdict
from datetime import datetime, timedelta
from enum import Enum
import logging
import socket
import pickle
import zlib
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
import random


class TaskStatus(Enum):
    """Lifecycle states of a distributed task.

    The string values are what status reports expose to callers.
    """

    PENDING = "pending"  # queued (initially, or again after a failed attempt)
    RUNNING = "running"  # assigned to a node and executing
    COMPLETED = "completed"  # finished with a result
    FAILED = "failed"  # out of retries
    CANCELLED = "cancelled"  # withdrawn before it ran


class NodeRole(Enum):
    """Role a node plays in the cluster.

    Values are the lowercase strings accepted when registering a node by name.
    """

    MASTER = "master"  # recorded as the cluster's master node
    WORKER = "worker"  # appended to the cluster's worker list
    COORDINATOR = "coordinator"  # registered without master/worker bookkeeping


@dataclass
class Task:
    """A unit of work submitted for distributed execution.

    The callable is identified by name and invoked with ``args``/``kwargs``.
    The trailing defaulted fields are bookkeeping that the processor fills in
    as the task is scheduled, executed and possibly retried.
    """

    # --- what to run ---
    task_id: str
    function_name: str
    args: tuple
    kwargs: dict
    priority: int  # lower number is dequeued first
    created_at: datetime
    status: TaskStatus

    # --- filled in during scheduling / execution ---
    assigned_node: Optional[str] = None
    result: Any = None
    error: Optional[str] = None  # str() of the last exception raised
    execution_time: Optional[float] = None  # wall-clock seconds of the last attempt

    # --- retry accounting ---
    retry_count: int = 0
    max_retries: int = 3


@dataclass
class Node:
    """A cluster member that tasks can be assigned to."""

    # --- identity and addressing ---
    node_id: str
    host: str
    port: int
    role: NodeRole
    capabilities: List[str]  # matched as substrings of task function names

    # --- runtime state ---
    load: float  # load figure; registration initialises it to 0.0
    status: str  # "active" or "inactive"
    last_heartbeat: datetime
    task_count: int  # tasks currently assigned to this node
    max_tasks: int  # node is skipped for assignment once task_count reaches this


@dataclass
class ClusterConfig:
    """Settings that control scheduling and health checking for one cluster."""

    cluster_id: str
    master_node: str  # "host:port" string, e.g. "localhost:8000"
    worker_nodes: List[str]
    heartbeat_interval: int  # seconds between health checks
    task_timeout: int  # NOTE(review): stored but not enforced by the scheduler shown here
    max_retries: int
    load_balancing_strategy: str  # "round_robin", "least_loaded", "capability_based", else random
    fault_tolerance: bool  # re-queue work from nodes that go inactive


class DistributedProcessor:
    """Distributed processing system for TuskLang.

    Tasks are submitted by function name into a priority queue (lower number =
    higher priority). They are assigned to a registered node by the configured
    load-balancing strategy and executed by a background scheduler thread. Two
    more daemon threads periodically recompute node loads and mark a node
    inactive once its last heartbeat is older than twice ``heartbeat_interval``.

    Execution is simulated in-process: ``_execute_task`` runs the named
    function synchronously on the scheduler thread. The assigned node is used
    for bookkeeping only (task counts, loads, status reports).
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Build the processor state and start its background threads.

        Args:
            config: Optional overrides read by ``_init_cluster_config``.
        """
        self.config = config or {}
        self.logger = logging.getLogger('tusklang.distributed')

        # Initialize components
        self.cluster_config = self._init_cluster_config()
        # Queue entries are (priority, sequence, task). The increasing sequence
        # number breaks priority ties in FIFO order, so the heap never has to
        # compare two Task objects (they define no ordering -> TypeError).
        self.task_queue = queue.PriorityQueue()
        self.completed_tasks = {}
        self.failed_tasks = {}
        self.nodes = {}
        self.node_loads = {}

        # Pending, running and cancelled tasks keyed by task_id. Tasks move to
        # completed_tasks / failed_tasks when they finish.
        self._tasks = {}
        # Guards task status transitions (submit, cancel, schedule, finish).
        # These happen concurrently on caller threads and the scheduler thread.
        self._lock = threading.Lock()
        self._sequence = 0
        self._round_robin_index = 0

        # Initialize processing
        self.processing_active = True
        # Set by shutdown() so background loops wake from their waits at once.
        self._stop_event = threading.Event()
        self.master_node = None
        self.worker_nodes = []

        # Start background processes
        self._start_background_processes()

    def _init_cluster_config(self) -> ClusterConfig:
        """Build the ClusterConfig from ``self.config``, applying defaults.

        A fresh random UUID is used as the cluster id.
        """
        return ClusterConfig(
            cluster_id=str(uuid.uuid4()),
            master_node=self.config.get("master_node", "localhost:8000"),
            worker_nodes=self.config.get("worker_nodes", []),
            heartbeat_interval=self.config.get("heartbeat_interval", 30),
            task_timeout=self.config.get("task_timeout", 300),
            max_retries=self.config.get("max_retries", 3),
            load_balancing_strategy=self.config.get("load_balancing_strategy", "round_robin"),
            fault_tolerance=self.config.get("fault_tolerance", True)
        )

    def _start_background_processes(self):
        """Start the scheduler, load-balancer and health-monitor daemon threads."""
        # Task scheduler
        self.scheduler_thread = threading.Thread(target=self._task_scheduler_loop, daemon=True)
        self.scheduler_thread.start()

        # Load balancer
        self.load_balancer_thread = threading.Thread(target=self._load_balancer_loop, daemon=True)
        self.load_balancer_thread.start()

        # Health monitor
        self.health_monitor_thread = threading.Thread(target=self._health_monitor_loop, daemon=True)
        self.health_monitor_thread.start()

    def register_node(self, node_id: str, host: str, port: int, role: NodeRole,
                      capabilities: List[str], max_tasks: int = 10) -> bool:
        """Register (or re-register) a node in the cluster.

        Re-registering an existing ``node_id`` replaces its record but keeps
        its current ``task_count``, so tasks still running on it stay
        accounted for. The master/worker bookkeeping is updated without
        creating duplicate entries.

        Returns:
            True on success, False if registration raised (the error is logged).
        """
        try:
            previous = self.nodes.get(node_id)
            node = Node(
                node_id=node_id,
                host=host,
                port=port,
                role=role,
                capabilities=capabilities,
                load=0.0,
                status="active",
                last_heartbeat=datetime.now(),
                task_count=previous.task_count if previous is not None else 0,
                max_tasks=max_tasks
            )

            self.nodes[node_id] = node
            self._update_node_load(node_id)

            # Drop role bookkeeping left over from a previous registration.
            if node_id in self.worker_nodes:
                self.worker_nodes.remove(node_id)
            if self.master_node == node_id:
                self.master_node = None

            if role == NodeRole.MASTER:
                self.master_node = node_id
            elif role == NodeRole.WORKER:
                self.worker_nodes.append(node_id)

            self.logger.info(f"Registered node: {node_id} ({role.value}) at {host}:{port}")
            return True

        except Exception as e:
            self.logger.error(f"Failed to register node {node_id}: {e}")
            return False

    def heartbeat(self, node_id: str) -> bool:
        """Record a heartbeat from *node_id*, reactivating it if it was inactive.

        Returns:
            True if the node is registered, False otherwise.
        """
        node = self.nodes.get(node_id)
        if node is None:
            return False
        node.last_heartbeat = datetime.now()
        if node.status != "active":
            node.status = "active"
            self.logger.info(f"Node {node_id} reactivated by heartbeat")
        return True

    def submit_task(self, function_name: str, *args, priority: int = 5,
                    max_retries: int = 3, **kwargs) -> str:
        """Submit a task for distributed processing.

        Args:
            function_name: Name resolved by ``_execute_function`` at run time.
            *args, **kwargs: Arguments forwarded to that function.
            priority: Lower numbers are scheduled first; equal priorities run FIFO.
            max_retries: How many times a failing task is re-queued before it
                is marked FAILED.

        Returns:
            The new task's id.
        """
        task_id = str(uuid.uuid4())

        task = Task(
            task_id=task_id,
            function_name=function_name,
            args=args,
            kwargs=kwargs,
            priority=priority,
            created_at=datetime.now(),
            status=TaskStatus.PENDING,
            max_retries=max_retries
        )

        # Register before queueing so the task is already cancellable and
        # visible to get_task_status when the scheduler can see it.
        with self._lock:
            self._tasks[task_id] = task
        self._enqueue(task)

        self.logger.info(f"Submitted task {task_id}: {function_name}")
        return task_id

    def _enqueue(self, task: Task) -> None:
        """Put *task* on the priority queue with a FIFO tie-breaking sequence number."""
        with self._lock:
            self._sequence += 1
            sequence = self._sequence
        self.task_queue.put((task.priority, sequence, task))

    def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Get a task's status and, once finished, its result or error.

        Returns:
            For completed tasks: task_id, status, result, execution_time,
            assigned_node. For failed tasks: task_id, status, error,
            retry_count, assigned_node. For pending, running or cancelled
            tasks: task_id, status, retry_count, assigned_node. Returns None
            for unknown task ids.
        """
        with self._lock:
            task = self.completed_tasks.get(task_id)
            if task is not None:
                return {
                    "task_id": task.task_id,
                    "status": task.status.value,
                    "result": task.result,
                    "execution_time": task.execution_time,
                    "assigned_node": task.assigned_node
                }

            task = self.failed_tasks.get(task_id)
            if task is not None:
                return {
                    "task_id": task.task_id,
                    "status": task.status.value,
                    "error": task.error,
                    "retry_count": task.retry_count,
                    "assigned_node": task.assigned_node
                }

            task = self._tasks.get(task_id)
            if task is not None:
                return {
                    "task_id": task.task_id,
                    "status": task.status.value,
                    "retry_count": task.retry_count,
                    "assigned_node": task.assigned_node
                }

        return None

    def cancel_task(self, task_id: str) -> bool:
        """Cancel a task that has not started running.

        The queue entry is left in place; the scheduler discards it when it is
        dequeued.

        Returns:
            True if the task was pending and is now cancelled. False if it is
            unknown, running, finished or already cancelled.
        """
        self.logger.info(f"Attempting to cancel task: {task_id}")
        with self._lock:
            task = self._tasks.get(task_id)
            if task is None or task.status != TaskStatus.PENDING:
                return False
            task.status = TaskStatus.CANCELLED
        self.logger.info(f"Cancelled task: {task_id}")
        return True

    def get_cluster_status(self) -> Dict[str, Any]:
        """Get cluster status and statistics.

        Mutable collections in the result are copies, so callers cannot
        modify the processor's internal state through them.
        """
        nodes = list(self.nodes.values())
        active_nodes = [node for node in nodes if node.status == "active"]

        return {
            "cluster_id": self.cluster_config.cluster_id,
            "total_nodes": len(nodes),
            "active_nodes": len(active_nodes),
            "master_node": self.master_node,
            "worker_nodes": list(self.worker_nodes),
            "pending_tasks": self.task_queue.qsize(),
            "completed_tasks": len(self.completed_tasks),
            "failed_tasks": len(self.failed_tasks),
            "node_loads": dict(self.node_loads),
            "load_balancing_strategy": self.cluster_config.load_balancing_strategy
        }

    def _task_scheduler_loop(self):
        """Dequeue tasks, assign each to a node and execute it.

        Runs on a daemon thread until ``processing_active`` is cleared.
        """
        while self.processing_active:
            try:
                try:
                    entry = self.task_queue.get(timeout=0.1)
                except queue.Empty:
                    continue
                # The Task is always the last element of a queue entry.
                task = entry[-1]

                with self._lock:
                    # Skip cancelled tasks and stale duplicate entries, e.g. a
                    # task re-queued by node-failure handling that has since
                    # finished.
                    if task.status != TaskStatus.PENDING:
                        continue
                    assigned_node = self._assign_task_to_node(task)
                    if assigned_node:
                        task.assigned_node = assigned_node
                        task.status = TaskStatus.RUNNING
                        self.nodes[assigned_node].task_count += 1

                if assigned_node:
                    self._execute_task(task, assigned_node)
                else:
                    # No available nodes: put the same entry back (keeping its
                    # queue position) and back off.
                    self.task_queue.put(entry)
                    self._stop_event.wait(1)

            except Exception as e:
                self.logger.error(f"Task scheduler error: {e}")
                self._stop_event.wait(1)

    def _assign_task_to_node(self, task: Task) -> Optional[str]:
        """Pick a node for *task* using the configured load-balancing strategy.

        Only active nodes below their ``max_tasks`` limit are candidates.
        Unknown strategy names fall back to a random choice.

        Returns:
            The chosen node id, or None if no node is available.
        """
        # Snapshot: register_node may add nodes from another thread.
        available_nodes = [
            node_id for node_id, node in list(self.nodes.items())
            if node.status == "active" and node.task_count < node.max_tasks
        ]

        if not available_nodes:
            return None

        strategy = self.cluster_config.load_balancing_strategy
        if strategy == "round_robin":
            return self._round_robin_assign(available_nodes)
        elif strategy == "least_loaded":
            return self._least_loaded_assign(available_nodes)
        elif strategy == "capability_based":
            return self._capability_based_assign(task, available_nodes)
        else:
            return random.choice(available_nodes)

    def _round_robin_assign(self, available_nodes: List[str]) -> str:
        """Round-robin over *available_nodes* using a counter shared across calls."""
        node = available_nodes[self._round_robin_index % len(available_nodes)]
        self._round_robin_index += 1
        return node

    def _least_loaded_assign(self, available_nodes: List[str]) -> str:
        """Return the node with the lowest recorded load (unknown loads count as 0)."""
        return min(available_nodes, key=lambda node_id: self.node_loads.get(node_id, 0.0))

    def _capability_based_assign(self, task: Task, available_nodes: List[str]) -> str:
        """Prefer the least-loaded capable node; otherwise the least-loaded node overall."""
        capable_nodes = [
            node_id for node_id in available_nodes
            if self._node_has_capability(node_id, task.function_name)
        ]

        if capable_nodes:
            return self._least_loaded_assign(capable_nodes)
        else:
            return self._least_loaded_assign(available_nodes)

    def _node_has_capability(self, node_id: str, function_name: str) -> bool:
        """Check whether any capability of *node_id* occurs in *function_name*.

        The substring match is case-insensitive on both sides.
        """
        node = self.nodes.get(node_id)
        if not node:
            return False

        lowered_name = function_name.lower()
        return any(cap.lower() in lowered_name for cap in node.capabilities)

    def _execute_task(self, task: Task, node_id: str):
        """Run *task* (attributed to *node_id*) and record the outcome.

        A failure re-queues the task while ``retry_count < max_retries``.
        After that the task is moved to ``failed_tasks``. The node's task
        count and load are updated in every case.
        """
        start_time = time.time()

        try:
            # Simulated execution: the function runs locally on this thread.
            result = self._execute_function(task.function_name, *task.args, **task.kwargs)
        except Exception as e:
            with self._lock:
                task.error = str(e)
                task.execution_time = time.time() - start_time
                retry = task.retry_count < task.max_retries
                if retry:
                    task.retry_count += 1
                    task.status = TaskStatus.PENDING
                else:
                    task.status = TaskStatus.FAILED
                    self.failed_tasks[task.task_id] = task
                    self._tasks.pop(task.task_id, None)

            if retry:
                self._enqueue(task)
                self.logger.warning(f"Task {task.task_id} failed, retrying ({task.retry_count}/{task.max_retries})")
            else:
                self.logger.error(f"Task {task.task_id} failed permanently after {task.max_retries} retries")
        else:
            with self._lock:
                task.result = result
                # Clear any error left over from an earlier failed attempt.
                task.error = None
                task.status = TaskStatus.COMPLETED
                task.execution_time = time.time() - start_time
                self.completed_tasks[task.task_id] = task
                self._tasks.pop(task.task_id, None)

            self.logger.info(f"Task {task.task_id} completed on node {node_id}")
        finally:
            # Update node load
            node = self.nodes.get(node_id)
            if node is not None:
                node.task_count -= 1
                self._update_node_load(node_id)

    def _execute_function(self, function_name: str, *args, **kwargs) -> Any:
        """Execute a built-in function by name.

        Supported names: "add", "multiply", "process_data", "analyze_text",
        "compute_hash".

        Raises:
            ValueError: For any other name.
        """
        if function_name == "add":
            return sum(args)
        elif function_name == "multiply":
            result = 1
            for arg in args:
                result *= arg
            return result
        elif function_name == "process_data":
            return self._process_data(*args, **kwargs)
        elif function_name == "analyze_text":
            return self._analyze_text(*args, **kwargs)
        elif function_name == "compute_hash":
            return self._compute_hash(*args, **kwargs)
        else:
            raise ValueError(f"Unknown function: {function_name}")

    def _process_data(self, data: List[Any], operation: str = "sum") -> Any:
        """Reduce *data* with "sum", "average", "max" or "min".

        An empty input gives 0 for "average" and None for "max"/"min".

        Raises:
            ValueError: For an unknown operation.
        """
        if operation == "sum":
            return sum(data)
        elif operation == "average":
            return sum(data) / len(data) if data else 0
        elif operation == "max":
            return max(data) if data else None
        elif operation == "min":
            return min(data) if data else None
        else:
            raise ValueError(f"Unknown operation: {operation}")

    def _analyze_text(self, text: str, analysis_type: str = "basic") -> Dict[str, Any]:
        """Return length/word statistics for *text*.

        With ``analysis_type="advanced"`` the result also includes the
        unique-word count and the average word length.
        """
        words = text.split()
        result = {
            "length": len(text),
            "word_count": len(words),
            "character_count": len(text.replace(" ", "")),
            "analysis_type": analysis_type
        }

        if analysis_type == "advanced":
            result["unique_words"] = len(set(text.lower().split()))
            result["average_word_length"] = sum(len(word) for word in words) / len(words) if words else 0

        return result

    def _compute_hash(self, data: str, algorithm: str = "md5") -> str:
        """Return the hex digest of *data* (UTF-8 encoded) with md5, sha1 or sha256.

        Raises:
            ValueError: For any other algorithm name.
        """
        hashers = {"md5": hashlib.md5, "sha1": hashlib.sha1, "sha256": hashlib.sha256}
        hasher = hashers.get(algorithm)
        if hasher is None:
            raise ValueError(f"Unknown hash algorithm: {algorithm}")
        return hasher(data.encode()).hexdigest()

    def _load_balancer_loop(self):
        """Recompute every node's load every 5 seconds until shutdown."""
        while self.processing_active:
            try:
                # Snapshot: register_node may add nodes from another thread.
                for node_id in list(self.nodes):
                    self._update_node_load(node_id)

                self._stop_event.wait(5)

            except Exception as e:
                self.logger.error(f"Load balancer error: {e}")
                self._stop_event.wait(10)

    def _update_node_load(self, node_id: str):
        """Set the node's load to the percentage of its task slots in use."""
        node = self.nodes.get(node_id)
        if node is None:
            return
        if node.max_tasks > 0:
            load = (node.task_count / node.max_tasks) * 100.0
        else:
            # A node that accepts no tasks is saturated. This branch also
            # avoids ZeroDivisionError.
            load = 100.0
        node.load = load
        self.node_loads[node_id] = load

    def _health_monitor_loop(self):
        """Mark nodes inactive when their heartbeat is stale.

        A heartbeat is stale once it is older than 2 x ``heartbeat_interval``
        seconds. When fault tolerance is enabled, work is re-queued from
        nodes that go inactive.
        """
        while self.processing_active:
            try:
                current_time = datetime.now()

                for node_id, node in list(self.nodes.items()):
                    # Act only on the active -> inactive transition. An
                    # inactive node stays inactive until heartbeat() or
                    # re-registration revives it.
                    if node.status != "active":
                        continue

                    time_since_heartbeat = (current_time - node.last_heartbeat).total_seconds()

                    if time_since_heartbeat > self.cluster_config.heartbeat_interval * 2:
                        node.status = "inactive"
                        self.logger.warning(f"Node {node_id} marked as inactive")

                        if self.cluster_config.fault_tolerance:
                            self._handle_node_failure(node_id)

                self._stop_event.wait(self.cluster_config.heartbeat_interval)

            except Exception as e:
                self.logger.error(f"Health monitor error: {e}")
                self._stop_event.wait(30)

    def _handle_node_failure(self, node_id: str):
        """Re-queue the tasks that were running on a failed node."""
        # Running tasks live in self._tasks. completed_tasks never holds a
        # RUNNING task, so it must not be searched here.
        with self._lock:
            orphaned = [
                task for task in self._tasks.values()
                if task.assigned_node == node_id and task.status == TaskStatus.RUNNING
            ]
            for task in orphaned:
                task.status = TaskStatus.PENDING
                task.assigned_node = None

        for task in orphaned:
            self._enqueue(task)

        self.logger.info(f"Reassigned {len(orphaned)} tasks from failed node {node_id}")

    def shutdown(self):
        """Stop the background loops and wait up to 5 s for each thread.

        Setting the stop event wakes loops that are waiting between
        iterations. Shutdown therefore does not wait out a full heartbeat
        interval, although a task that is executing still runs to completion.
        """
        self.processing_active = False
        self._stop_event.set()

        for attr in ('scheduler_thread', 'load_balancer_thread', 'health_monitor_thread'):
            thread = getattr(self, attr, None)
            if thread is not None:
                thread.join(timeout=5)

        self.logger.info("Distributed processor shutdown complete")


class ParallelProcessor:
    """Parallel processing utilities.

    Owns one ThreadPoolExecutor and one ProcessPoolExecutor, each sized to
    ``max_workers``. Release them with ``shutdown()`` or by using the
    instance as a context manager (``with ParallelProcessor() as pp: ...``).
    """

    def __init__(self, max_workers: Optional[int] = None):
        """Create both executor pools.

        Args:
            max_workers: Worker count for each pool. Falsy values fall back to
                ``multiprocessing.cpu_count()``.
        """
        self.max_workers = max_workers or multiprocessing.cpu_count()
        self.thread_pool = ThreadPoolExecutor(max_workers=self.max_workers)
        self.process_pool = ProcessPoolExecutor(max_workers=self.max_workers)

    def __enter__(self) -> "ParallelProcessor":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.shutdown()

    def parallel_map(self, func: Callable, items: List[Any],
                     use_processes: bool = False, chunk_size: int = 1) -> List[Any]:
        """Apply *func* to every item in parallel and return results in input order.

        Args:
            func: Callable applied to each item. With ``use_processes=True``,
                *func* and the items must be picklable.
            items: Inputs to map over.
            use_processes: Use the process pool instead of the thread pool.
            chunk_size: Items sent to a worker process per batch. Used only
                with the process pool.

        Raises:
            The first exception (in input order) raised by *func*.
        """
        if use_processes:
            return list(self.process_pool.map(func, items, chunksize=chunk_size))
        return list(self.thread_pool.map(func, items))

    def parallel_execute(self, tasks: List[Tuple[Callable, tuple, dict]]) -> List[Any]:
        """Run ``func(*args, **kwargs)`` for each (func, args, kwargs) on the thread pool.

        Returns:
            One result per task, in the same order as *tasks*. An entry is
            None if its call raised; the exception is logged, not propagated.
        """
        futures = [self.thread_pool.submit(func, *args, **kwargs) for func, args, kwargs in tasks]

        results = []
        # Collect in submission order, not completion order (as_completed),
        # so that results[i] belongs to tasks[i].
        for index, future in enumerate(futures):
            try:
                results.append(future.result())
            except Exception as exc:
                logging.getLogger('tusklang.distributed').warning(
                    "Parallel task %d failed: %r", index, exc)
                results.append(None)

        return results

    def shutdown(self):
        """Shut down both pools, waiting for outstanding work to finish."""
        self.thread_pool.shutdown(wait=True)
        self.process_pool.shutdown(wait=True)


# Shared, module-wide processor used by the convenience functions below.
# NOTE: constructing it starts the scheduler, load-balancer and health-monitor
# daemon threads, so importing this module has that side effect.
distributed_processor = DistributedProcessor(config=None)


def submit_distributed_task(function_name: str, *args, priority: int = 5, **kwargs) -> str:
    """Queue ``function_name(*args, **kwargs)`` on the shared processor.

    Extra keyword arguments, including ``max_retries``, are forwarded to
    ``DistributedProcessor.submit_task``.

    Returns:
        The generated task id.
    """
    processor = distributed_processor
    task_id = processor.submit_task(function_name, *args, priority=priority, **kwargs)
    return task_id


def get_task_status(task_id: str) -> Optional[Dict[str, Any]]:
    """Look up *task_id* on the shared processor.

    Returns:
        The processor's status record for the task, or None if it has none.
    """
    processor = distributed_processor
    record = processor.get_task_status(task_id)
    return record


def get_cluster_status() -> Dict[str, Any]:
    """Return the shared processor's cluster summary (node counts, queue and task statistics)."""
    processor = distributed_processor
    summary = processor.get_cluster_status()
    return summary


def register_node(node_id: str, host: str, port: int, role: Union[str, NodeRole],
                  capabilities: List[str], max_tasks: int = 10) -> bool:
    """Register a node on the shared processor.

    Args:
        node_id: Unique node identifier.
        host, port: Node address, recorded for reporting.
        role: A NodeRole, or its string value ("master", "worker",
            "coordinator"), case-insensitive.
        capabilities: Substrings matched against task function names.
        max_tasks: Maximum number of tasks assigned to the node at once.

    Returns:
        True on success, False if the processor failed to register the node.

    Raises:
        ValueError: If *role* is a string that names no NodeRole.
    """
    # Accept an enum member directly; previously role.lower() raised AttributeError for it.
    node_role = role if isinstance(role, NodeRole) else NodeRole(role.lower())
    return distributed_processor.register_node(node_id, host, port, node_role, capabilities, max_tasks)


def parallel_map(func: Callable, items: List[Any], use_processes: bool = False) -> List[Any]:
    """Apply *func* to every element of *items* in parallel, preserving order.

    A temporary ParallelProcessor is created for the call and always shut
    down afterwards. Previously its thread and process pools were never
    released, so every call leaked them.

    Args:
        func: Callable applied to each item (must be picklable if use_processes).
        items: Inputs to map over.
        use_processes: Use a process pool instead of a thread pool.

    Returns:
        ``[func(item) for item in items]``, computed in parallel.
    """
    processor = ParallelProcessor()
    try:
        return processor.parallel_map(func, items, use_processes)
    finally:
        processor.shutdown()


if __name__ == "__main__":
    banner = "Distributed Processor for TuskLang Python SDK"
    print(banner)
    print("=" * 50)

    # -- Section 1: register a small demo cluster and submit work ----------
    print("\n1. Testing Distributed Processing:")

    demo_nodes = (
        ("master", 8000, "master", ["all"], 20),
        ("worker1", 8001, "worker", ["processing", "analysis"], 10),
        ("worker2", 8002, "worker", ["computation", "hashing"], 10),
    )
    for node_name, node_port, node_kind, node_caps, node_limit in demo_nodes:
        register_node(node_name, "localhost", node_port, node_kind, node_caps, node_limit)

    # (call arguments, priority) pairs; lower priority numbers run first.
    demo_calls = (
        (("add", 1, 2, 3, 4, 5), 1),
        (("multiply", 2, 3, 4), 2),
        (("process_data", [1, 2, 3, 4, 5], "average"), 3),
        (("analyze_text", "Hello world! This is a test.", "advanced"), 4),
        (("compute_hash", "test data", "sha256"), 5),
    )
    submitted = [submit_distributed_task(*call, priority=prio) for call, prio in demo_calls]

    # Give the background scheduler time to work through the queue.
    time.sleep(2)

    # -- Section 2: per-task results ---------------------------------------
    print("\n2. Task Status:")
    for task_id in submitted:
        status = get_task_status(task_id)
        if not status:
            continue
        outcome = status.get('result', status.get('error', 'N/A'))
        print(f"  Task {task_id}: {status['status']} - {outcome}")

    # -- Section 3: cluster summary ----------------------------------------
    print("\n3. Cluster Status:")
    summary = get_cluster_status()
    print(f"  Active nodes: {summary['active_nodes']}")
    print(f"  Pending tasks: {summary['pending_tasks']}")
    print(f"  Completed tasks: {summary['completed_tasks']}")
    print(f"  Failed tasks: {summary['failed_tasks']}")

    # -- Section 4: local parallel map -------------------------------------
    print("\n4. Testing Parallel Processing:")

    def square(x):
        return x * x

    squared = parallel_map(square, list(range(1, 11)))
    print(f"  Squared numbers: {squared}")

    print("\nDistributed processing testing completed!")