"""
Query Interface for Hypersets

Provides clean API for querying datasets with multiple output formats.
"""

import time
import logging
from typing import Optional, Any, Dict, Union, List
from dataclasses import dataclass

import pandas as pd
from .duckdb_mount import DuckDBMount, mount_dataset
from .dataset_info import get_dataset_info, DatasetInfo

logger = logging.getLogger(__name__)

# Optional imports for HuggingFace datasets
try:
    from datasets import Dataset

    HF_DATASETS_AVAILABLE = True
except ImportError:
    Dataset = None
    HF_DATASETS_AVAILABLE = False
    logger.warning("datasets library not available. to_dataset() will not work.")


@dataclass
class QueryResult:
    """
    Result of a hypersets query with multiple output format options.

    Bundles the DataFrame produced by a query with the SQL text, the dataset
    coordinates it ran against and the measured duration. When
    ``track_downloads`` is true, download statistics are created once in
    ``__post_init__`` and exposed through ``data_savings``.

    Attributes:
        _df: DataFrame holding the query output. Treated as private; the
            accessor methods return copies or derived frames.
        query: SQL text that produced the result.
        dataset_name: Name of the dataset that was queried.
        config: Config filter used for the query, if any.
        split: Split filter used for the query, if any.
        query_time_seconds: Duration reported for the query, in seconds.
        dataset_info: Dataset metadata supplied by the caller.
        track_downloads: Whether download statistics should be created.
        _download_stats: Statistics object built in ``__post_init__`` when
            tracking is enabled, otherwise ``None``.
    """

    _df: pd.DataFrame
    query: str
    dataset_name: str
    config: Optional[str]
    split: Optional[str]
    query_time_seconds: float
    dataset_info: DatasetInfo
    track_downloads: bool = False
    _download_stats: Optional[Any] = None

    def __post_init__(self):
        """Initialize download tracking if requested."""
        if self.track_downloads:
            # Imported lazily so the tracking module is only loaded when used.
            from .download_tracking import create_download_stats

            # NOTE(review): DataFrame.empty is also True for a zero-row frame
            # that still has columns, so such results report 0 columns here.
            # Confirm that this is what the statistics expect.
            self._download_stats = create_download_stats(
                dataset_name=self.dataset_name,
                query=self.query,
                query_time=self.query_time_seconds,
                result_rows=len(self._df),
                result_columns=len(self._df.columns) if not self._df.empty else 0,
                config=self.config,
                split=self.split,
            )

    def to_pandas(self) -> pd.DataFrame:
        """Return results as an independent pandas DataFrame copy."""
        return self._df.copy()

    def to_dataset(self) -> "Dataset":
        """
        Return results as HuggingFace Dataset.

        Raises:
            RuntimeError: If the optional ``datasets`` library is not installed.
        """
        if not HF_DATASETS_AVAILABLE:
            raise RuntimeError(
                "datasets library not installed. " "Install with: pip install datasets"
            )

        return Dataset.from_pandas(self._df)

    def __len__(self) -> int:
        """Return number of rows."""
        return len(self._df)

    def __repr__(self) -> str:
        """Return a short summary, plus the savings report when tracked."""
        base = f"QueryResult(rows={len(self._df)}, time={self.query_time_seconds:.2f}s)"
        if self.track_downloads and self._download_stats:
            savings = self._download_stats.data_savings
            base += f"\n{savings}"
        return base

    @property
    def shape(self) -> tuple:
        """Return (rows, columns) shape."""
        return self._df.shape

    @property
    def columns(self) -> list:
        """Return column names."""
        return self._df.columns.tolist()

    @property
    def data_savings(self):
        """
        Get data savings information if tracking is enabled.

        Raises:
            RuntimeError: If tracking was not requested or no statistics
                object is available.
        """
        if not self.track_downloads or not self._download_stats:
            raise RuntimeError(
                "Download tracking not enabled. Use track_downloads=True"
            )
        return self._download_stats.data_savings

    def head(self, n: int = 5) -> pd.DataFrame:
        """Return the first n rows as an independent copy."""
        # DataFrame.head() may hand back a view of the stored frame; copy it
        # so edits made by the caller cannot leak into this result, in line
        # with to_pandas().
        return self._df.head(n).copy()

    def sample(self, n: int = 5) -> pd.DataFrame:
        """Return random sample of n rows (capped at the number of rows)."""
        return self._df.sample(n=min(n, len(self._df)))


def query(
    sql: str,
    dataset: str,
    config: Optional[str] = None,
    split: Optional[str] = None,
    token: Optional[str] = None,
    memory_limit: str = "1GB",
    threads: int = 4,
    track_downloads: bool = False,
) -> QueryResult:
    """
    Execute SQL query against a HuggingFace dataset.

    The reported query time covers the dataset metadata lookup, mounting the
    dataset and running the SQL.

    Args:
        sql: SQL query to execute (use 'dataset' as table name)
        dataset: HuggingFace dataset name (e.g. "wikimedia/wikipedia")
        config: Optional config filter (e.g. "20231101.en")
        split: Optional split filter (e.g. "train")
        token: Optional HF token
        memory_limit: DuckDB memory limit
        threads: Number of threads for DuckDB
        track_downloads: Whether to track and estimate data savings

    Returns:
        QueryResult object with to_pandas() and to_dataset() methods

    Examples:
        # Basic query
        result = hs.query(
            "SELECT title FROM dataset LIMIT 100",
            dataset="wikimedia/wikipedia",
            config="20231101.en"
        )
        df = result.to_pandas()

        # With download tracking
        result = hs.query(
            "SELECT * FROM dataset USING SAMPLE 1000",
            dataset="imdb",
            track_downloads=True
        )
        print(result.data_savings)  # Shows estimated savings
    """
    # perf_counter() is monotonic, so the measured duration cannot be skewed
    # (or turn negative) when the system wall clock is adjusted mid-query.
    started = time.perf_counter()

    # Get dataset info for metadata
    dataset_info = get_dataset_info(dataset, token=token)

    # Execute query using mount system; the mount is released as soon as the
    # result DataFrame has been produced.
    with mount_dataset(
        dataset_name=dataset,
        config=config,
        split=split,
        memory_limit=memory_limit,
        threads=threads,
        token=token,
    ) as mount:
        df = mount.query(sql)

    query_time = time.perf_counter() - started

    return QueryResult(
        _df=df,
        query=sql,
        dataset_name=dataset,
        config=config,
        split=split,
        query_time_seconds=query_time,
        dataset_info=dataset_info,
        track_downloads=track_downloads,
    )


def sample(
    n: int,
    dataset: str,
    config: Optional[str] = None,
    split: Optional[str] = None,
    columns: Optional[list] = None,
    track_downloads: bool = False,
    **kwargs,
) -> QueryResult:
    """
    Draw a random sample of rows from a dataset.

    Builds a ``SELECT ... FROM dataset USING SAMPLE n`` statement and runs it
    through query().

    Args:
        n: Number of rows to sample
        dataset: HuggingFace dataset name
        config: Optional config filter
        split: Optional split filter
        columns: Optional list of columns to select; all columns when
            omitted or empty
        track_downloads: Whether to track and estimate data savings
        **kwargs: Additional arguments forwarded to query()

    Returns:
        QueryResult object
    """
    # A missing or empty column list selects every column.
    projection = ", ".join(columns) if columns else "*"
    statement = f"SELECT {projection} FROM dataset USING SAMPLE {n}"

    return query(
        statement,
        dataset=dataset,
        config=config,
        split=split,
        track_downloads=track_downloads,
        **kwargs,
    )


def head(
    n: int = 5,
    dataset: Optional[str] = None,
    config: Optional[str] = None,
    split: Optional[str] = None,
    columns: Optional[List[str]] = None,
    track_downloads: bool = False,
    **kwargs,
) -> QueryResult:
    """
    Get first n rows from a dataset.

    Builds a ``SELECT ... FROM dataset LIMIT n`` statement and runs it
    through query().

    Args:
        n: Number of rows to return
        dataset: HuggingFace dataset name. Required; it only carries a
            default because it follows ``n`` in the signature.
        config: Optional config filter
        split: Optional split filter
        columns: Optional list of columns to select; all columns when
            omitted or empty
        track_downloads: Whether to track and estimate data savings
        **kwargs: Additional arguments for query()

    Returns:
        QueryResult object

    Raises:
        ValueError: If no dataset name is given.
    """
    # Fail fast with a clear message instead of passing None down to query().
    if dataset is None:
        raise ValueError('head() requires a dataset name (pass dataset="...")')

    # Build column selection
    if columns:
        column_str = ", ".join(columns)
    else:
        column_str = "*"

    sql = f"SELECT {column_str} FROM dataset LIMIT {n}"
    return query(
        sql,
        dataset=dataset,
        config=config,
        split=split,
        track_downloads=track_downloads,
        **kwargs,
    )


def count(
    dataset: str, config: Optional[str] = None, split: Optional[str] = None, **kwargs
) -> int:
    """
    Return the total number of rows in a dataset.

    Runs a ``COUNT(*)`` query through query() and reads the single value
    from the ``count`` column of the result.

    Args:
        dataset: HuggingFace dataset name
        config: Optional config filter
        split: Optional split filter
        **kwargs: Additional arguments forwarded to query()

    Returns:
        Total number of rows
    """
    count_sql = "SELECT COUNT(*) as count FROM dataset"
    counted = query(
        count_sql, dataset=dataset, config=config, split=split, **kwargs
    )
    frame = counted.to_pandas()
    # Convert the first (only) value of the count column to a plain int.
    return int(frame["count"].iloc[0])


def schema(
    dataset: str, config: Optional[str] = None, split: Optional[str] = None, **kwargs
) -> pd.DataFrame:
    """
    Return schema information for a dataset.

    Runs ``DESCRIBE dataset`` through query() and returns the result as a
    DataFrame.

    Args:
        dataset: HuggingFace dataset name
        config: Optional config filter
        split: Optional split filter
        **kwargs: Additional arguments forwarded to query()

    Returns:
        DataFrame with schema information
    """
    describe_sql = "DESCRIBE dataset"
    described = query(
        describe_sql, dataset=dataset, config=config, split=split, **kwargs
    )
    return described.to_pandas()
