"""
Core hypersets functionality: Simple SQL interface with download measurement.

Provides minimal surface area - just execute SQL queries on HF datasets
and let DuckDB handle all optimization.
"""

import time
import logging
import random
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
from urllib.parse import urlparse
import threading
import io
from unittest.mock import patch
from functools import wraps

import duckdb
import pandas as pd
import requests
from huggingface_hub import HfApi, hf_hub_url

# Configure logging for 429 handling: a default console handler is attached
# only while the module logger has none, so reloading the module does not
# stack duplicate handlers.
logger = logging.getLogger(__name__)
if not logger.handlers:
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    logger.setLevel(logging.INFO)
    logger.addHandler(handler)


@dataclass
class QueryStats:
    """Timing and transfer figures for one query.

    Plain fields are filled in by the producer; ``mb_downloaded`` and
    ``effective_query_time`` are derived from them on access.
    """

    # Total duration, in seconds.
    query_time_seconds: float
    # Bytes transferred (may be an estimate, depending on the producer).
    bytes_downloaded: int
    # Number of distinct files touched.
    files_accessed: int
    # Shape of the result.
    rows_returned: int
    columns_returned: int
    # Seconds spent sleeping on 429 back-off, and how many retries happened.
    wait_time_seconds: float = 0
    retry_count: int = 0

    @property
    def mb_downloaded(self) -> float:
        """``bytes_downloaded`` converted to megabytes (1 MB = 1024 * 1024 bytes)."""
        return self.bytes_downloaded / (1024 * 1024)

    @property
    def effective_query_time(self) -> float:
        """Query time with the 429 back-off (``wait_time_seconds``) subtracted."""
        elapsed, waited = self.query_time_seconds, self.wait_time_seconds
        return elapsed - waited


class DownloadTracker:
    """Accumulate download and 429 back-off counters.

    Every mutating method in this class takes ``_lock`` first, so concurrent
    updates from several threads do not lose increments.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self.bytes_downloaded = 0
        self.files_accessed = set()
        self.total_wait_time = 0
        self.retry_count = 0

    def add_download(self, url: str, bytes_count: int):
        """Add ``bytes_count`` to the byte total and record ``url`` as accessed."""
        lock = self._lock
        with lock:
            self.bytes_downloaded += bytes_count
            self.files_accessed.add(url)

    def add_wait_time(self, wait_seconds: float):
        """Record one retry that slept for ``wait_seconds``."""
        lock = self._lock
        with lock:
            self.total_wait_time += wait_seconds
            self.retry_count += 1

    def reset(self):
        """Zero all counters; the ``files_accessed`` set is emptied in place."""
        with self._lock:
            self.bytes_downloaded = 0
            self.files_accessed.clear()
            self.total_wait_time = self.retry_count = 0


def smart_retry_429(
    max_retries: int = 10, base_delay: float = 1.0, max_delay: float = 60.0
):
    """
    Decorator for smart 429 retry with exponential backoff and logging.

    Args:
        max_retries: Maximum number of retry attempts
        base_delay: Base delay in seconds
        max_delay: Maximum delay between retries
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            last_exception = None
            total_wait = 0
            last_log_time = 0

            for attempt in range(max_retries + 1):
                try:
                    return func(*args, **kwargs)
                except (duckdb.HTTPException, requests.exceptions.HTTPError) as e:
                    error_str = str(e)
                    if "429" not in error_str and "rate limit" not in error_str.lower():
                        # Not a rate limit error, re-raise
                        raise

                    if attempt == max_retries:
                        logger.error(
                            f"Max retries ({max_retries}) exceeded for 429 errors. Total wait time: {total_wait:.1f}s"
                        )
                        raise

                    # Calculate exponential backoff with jitter
                    delay = min(base_delay * (2**attempt), max_delay)
                    jitter = random.uniform(0.1, 0.3) * delay  # Add 10-30% jitter
                    actual_delay = delay + jitter

                    total_wait += actual_delay
                    _download_tracker.add_wait_time(actual_delay)

                    # Log rate limit delays (3s threshold with exponential re-logging)
                    should_log = False
                    if total_wait >= 3.0:  # First log after 3s total wait
                        # Log exponentially: at 3s, 6s, 12s, 24s, etc.
                        next_log_threshold = 3.0 * (
                            2
                            ** (
                                len(
                                    [
                                        x
                                        for x in [3, 6, 12, 24, 48]
                                        if x <= last_log_time
                                    ]
                                )
                            )
                        )
                        if total_wait >= next_log_threshold:
                            should_log = True
                            last_log_time = total_wait

                    if should_log:
                        logger.warning(
                            f"Rate limited (429) - waiting {actual_delay:.1f}s "
                            f"(attempt {attempt + 1}/{max_retries}, total wait: {total_wait:.1f}s)"
                        )

                    time.sleep(actual_delay)
                    last_exception = e

            # Should never reach here due to max_retries check above
            raise last_exception

        return wrapper

    return decorator


# Module-wide tracker: smart_retry_429 records back-off waits on it and
# query_dataset resets it at the start of every query.
_download_tracker: DownloadTracker = DownloadTracker()


def discover_parquet_files(
    dataset_name: str, config: Optional[str] = None, token: Optional[str] = None
) -> List[str]:
    """
    Discover parquet files for a dataset using minimal file listing.

    Args:
        dataset_name: HF dataset name
        config: Optional config name. When given, only files under
            ``"{config}/"`` are kept.
        token: Optional HF token

    Returns:
        List of parquet file URLs

    Raises:
        RuntimeError: If listing the repo fails (HTTP 429 responses are
            retried first), the dataset has no parquet files, or ``config``
            matches no files.
    """
    api = HfApi(token=token)

    # Only the network call is retried. Decorating the whole function would not
    # work: the ``except Exception`` below turns every HTTP error into a
    # RuntimeError before a retry decorator could recognise it as a 429.
    @smart_retry_429(max_retries=5, base_delay=0.5, max_delay=30.0)
    def _list_repo_files() -> List[str]:
        return api.list_repo_files(dataset_name, repo_type="dataset", token=token)

    try:
        # List all parquet files
        repo_files = _list_repo_files()
        parquet_files = [f for f in repo_files if f.endswith(".parquet")]

        if not parquet_files:
            raise ValueError(f"No parquet files found in dataset {dataset_name}")

        # Filter by config if specified
        if config:
            parquet_files = [f for f in parquet_files if f.startswith(f"{config}/")]
            if not parquet_files:
                # Sorted so the error message is deterministic.
                available_configs = sorted(
                    {
                        f.split("/")[0]
                        for f in repo_files
                        if "/" in f and f.endswith(".parquet")
                    }
                )
                raise ValueError(
                    f"No files found for config '{config}'. Available configs: {available_configs}"
                )

        # Convert to URLs
        return [
            hf_hub_url(repo_id=dataset_name, filename=file_path, repo_type="dataset")
            for file_path in parquet_files
        ]

    except Exception as e:
        # The ValueErrors above are wrapped as well, so callers see one type.
        raise RuntimeError(f"Failed to discover parquet files: {e}") from e


def _sql_string_literal(value: str) -> str:
    """Return *value* as a single-quoted SQL string literal.

    Embedded single quotes are doubled (the standard SQL escape), so a path or
    setting containing ``'`` cannot terminate the literal early.
    """
    return "'" + value.replace("'", "''") + "'"


def query_dataset(
    dataset_name: str,
    query: str,
    config: Optional[str] = None,
    token: Optional[str] = None,
    memory_limit: str = "1GB",
) -> tuple[pd.DataFrame, QueryStats]:
    """
    Execute SQL query on HF dataset and return results with download stats.

    This is the core function - just takes SQL and lets DuckDB optimize everything.
    The dataset's parquet files are exposed to ``query`` as a view named
    ``dataset``. Creating that view and running ``query`` are retried on
    HTTP 429 via ``smart_retry_429``, and the time spent backing off is
    reported in the returned stats.

    Args:
        dataset_name: HF dataset name (e.g. "wikimedia/wikipedia")
        query: SQL query to execute. Use 'dataset' as table name.
        config: Optional config (e.g. "20231101.en"). If None, uses all parquet files.
        token: Optional HF token
        memory_limit: DuckDB memory limit

    Returns:
        Tuple of (DataFrame with results, QueryStats with download info).
        ``bytes_downloaded`` is a rough estimate unless the global download
        tracker recorded real byte counts.

    Raises:
        RuntimeError: If parquet file discovery fails.
        duckdb.Error (or a subclass): If DuckDB setup, view creation, or the
            query itself fails.

    Examples:
        >>> # Random sampling - let DuckDB optimize
        >>> df, stats = query_dataset(
        ...     "wikimedia/wikipedia",
        ...     "SELECT title, text FROM dataset USING SAMPLE 100",
        ...     config="20231101.en"
        ... )
        >>> print(f"Downloaded {stats.mb_downloaded:.1f} MB")

        >>> # Filtering - DuckDB will only read necessary files
        >>> df, stats = query_dataset(
        ...     "imdb",
        ...     "SELECT * FROM dataset WHERE label = 1 LIMIT 1000"
        ... )
    """
    # Monotonic clock: durations must not jump with wall-clock adjustments.
    start_time = time.perf_counter()
    _download_tracker.reset()

    # Discover parquet files
    parquet_urls = discover_parquet_files(dataset_name, config, token)

    conn = duckdb.connect(":memory:")
    try:
        # Setup is inside the try so the connection is closed even if, e.g.,
        # INSTALL httpfs fails offline or memory_limit is rejected.
        conn.execute(f"SET memory_limit={_sql_string_literal(memory_limit)}")
        conn.execute("INSTALL httpfs")
        conn.execute("LOAD httpfs")
        _install_download_tracker(conn)

        # A one-element list is equivalent to a single path for read_parquet.
        files_expr = "[" + ", ".join(_sql_string_literal(u) for u in parquet_urls) + "]"
        view_sql = (
            f"CREATE OR REPLACE VIEW dataset AS SELECT * FROM read_parquet({files_expr})"
        )

        retry_429 = smart_retry_429(max_retries=5, base_delay=0.5, max_delay=30.0)

        # Binding read_parquet may already fetch remote parquet metadata, so
        # view creation can be rate limited too; OR REPLACE keeps a retry safe.
        @retry_429
        def _create_view() -> None:
            conn.execute(view_sql)

        # Execute user query - DuckDB handles optimization
        @retry_429
        def _run_query() -> pd.DataFrame:
            return conn.execute(query).df()

        _create_view()
        result_df = _run_query()

        query_time = time.perf_counter() - start_time

        # The tracker reports bytes only if something calls add_download; when
        # it reports 0, fall back to a very rough estimate from the result size.
        bytes_downloaded = _download_tracker.bytes_downloaded
        if bytes_downloaded == 0 and not result_df.empty:
            # x2 as a crude compression factor; cast away the numpy integer.
            bytes_downloaded = int(result_df.memory_usage(deep=True).sum() * 2)

        stats = QueryStats(
            query_time_seconds=query_time,
            bytes_downloaded=bytes_downloaded,
            # Without tracked accesses, assume every discovered file was read.
            files_accessed=len(_download_tracker.files_accessed) or len(parquet_urls),
            rows_returned=len(result_df),
            # A zero-row result still has columns; report them.
            columns_returned=len(result_df.columns),
            wait_time_seconds=_download_tracker.total_wait_time,
            retry_count=_download_tracker.retry_count,
        )

        return result_df, stats

    finally:
        conn.close()


def _install_download_tracker(conn: duckdb.DuckDBPyConnection):
    """Enable DuckDB progress tracking on ``conn`` without printing it.

    Despite the name, no HTTP interception happens here. Real byte-level
    download tracking would require patching DuckDB's HTTP client, so for now
    download size is estimated from query result size and file access. This
    limitation should be acknowledged to users.
    """
    progress_settings = (
        "SET enable_progress_bar=true",
        # Never print the bar; progress is meant to be tracked programmatically.
        "SET enable_progress_bar_print=false",
    )
    for setting in progress_settings:
        conn.execute(setting)


def list_configs(dataset_name: str, token: Optional[str] = None) -> List[str]:
    """
    List available configurations for a dataset.

    A config is the top-level directory of any ``.parquet`` file. Hidden
    directories (leading ``.``) and ``data`` are excluded.

    Args:
        dataset_name: HF dataset name
        token: Optional HF token

    Returns:
        Sorted list of available config names

    Raises:
        RuntimeError: If listing the repo fails (HTTP 429 responses are
            retried first).
    """
    api = HfApi(token=token)

    # Only the network call is retried. Decorating the whole function would not
    # work: the ``except Exception`` below turns every HTTP error into a
    # RuntimeError before a retry decorator could recognise it as a 429.
    @smart_retry_429(max_retries=5, base_delay=0.5, max_delay=30.0)
    def _list_repo_files() -> List[str]:
        return api.list_repo_files(dataset_name, repo_type="dataset", token=token)

    try:
        repo_files = _list_repo_files()
    except Exception as e:
        raise RuntimeError(f"Failed to list configs: {e}") from e

    # Extract configs from file paths
    configs = {
        file_path.split("/")[0]
        for file_path in repo_files
        if file_path.endswith(".parquet") and "/" in file_path
    }
    # Filter out non-config directories
    return sorted(c for c in configs if not c.startswith(".") and c != "data")


# Simple convenience functions
def sample_random(
    dataset_name: str, n: int, config: Optional[str] = None, **kwargs
) -> tuple[pd.DataFrame, QueryStats]:
    """Draw ``n`` random rows with DuckDB's ``USING SAMPLE`` clause.

    Any extra keyword arguments (e.g. ``token``, ``memory_limit``) are passed
    straight through to ``query_dataset``.
    """
    sampling_sql = f"SELECT * FROM dataset USING SAMPLE {n}"
    return query_dataset(dataset_name, sampling_sql, config, **kwargs)


def count_rows(
    dataset_name: str, config: Optional[str] = None, **kwargs
) -> tuple[int, QueryStats]:
    """Return the dataset's total row count together with the query stats."""
    count_sql = "SELECT COUNT(*) as count FROM dataset"
    frame, stats = query_dataset(dataset_name, count_sql, config, **kwargs)
    total = frame["count"].iloc[0]
    return int(total), stats
