"""
DuckDB query engine with proper random sampling and async streaming support.

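This module provides efficient querying of remote Parquet files with:
- Random sampling using DuckDB's USING SAMPLE clause (reservoir, bernoulli, system)
- Sync and async batch streaming of query results, with no on-disk result cache
- An in-memory DuckDB database (no database file is written, although DuckDB
  may spill to its temp directory when the memory limit is exceeded)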
"""

import asyncio
import functools
import warnings
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union

import duckdb
import pandas as pd
from datasets import Dataset


class DuckDBQueryEngine:
    """
    Query engine for executing SQL queries on remote Parquet files using DuckDB.

    Features:
    - Random sampling via DuckDB's ``USING SAMPLE`` clause (not just LIMIT)
    - Sync and async streaming that executes the query once and fetches batches
    - In-memory database (no database file; DuckDB may still spill to its temp
      directory when ``memory_limit`` is exceeded)
    - Remote file access through the ``httpfs`` extension

    Note: ``where`` and custom ``query`` strings are interpolated verbatim into
    SQL; never pass untrusted input for them.
    """

    def __init__(self, memory_limit: str = "1GB", threads: Optional[int] = None):
        """
        Initialize DuckDB query engine.

        Args:
            memory_limit: Memory limit for DuckDB (e.g., "1GB", "500MB")
            threads: Number of threads to use. If None (or 0), uses DuckDB default.

        Raises:
            Whatever DuckDB raises if a setting cannot be applied or ``httpfs``
            cannot be installed/loaded; the connection is closed before re-raising.
        """
        # In-memory database: no database file is created.
        self.connection = duckdb.connect(":memory:")
        try:
            self.connection.execute(
                f"SET memory_limit={self._sql_string(memory_limit)}"
            )
            if threads:
                self.connection.execute(f"SET threads={int(threads)}")

            # httpfs lets read_parquet open remote (HTTP/S3) URLs.
            self.connection.execute("INSTALL httpfs")
            self.connection.execute("LOAD httpfs")

            # Directory DuckDB may spill to when an operation exceeds memory_limit.
            self.connection.execute("SET temp_directory='/tmp'")
            self.connection.execute("SET enable_object_cache=true")
            self.connection.execute("SET enable_http_metadata_cache=true")
        except Exception:
            # Don't leak the native connection when configuration fails part-way.
            self.connection.close()
            raise

    def query(
        self,
        parquet_files: List[str],
        query: Optional[str] = None,
        columns: Optional[List[str]] = None,
        where: Optional[str] = None,
        limit: Optional[int] = None,
        sample_size: Optional[int] = None,
        sample_method: str = "reservoir",
    ) -> pd.DataFrame:
        """
        Execute a query on remote Parquet files.

        Args:
            parquet_files: List of Parquet file URLs to query
            query: Custom SQL query. If provided, other parameters are ignored;
                the placeholder ``{parquet_files}`` is replaced by a
                ``read_parquet(...)`` expression over ``parquet_files``.
            columns: List of columns to select. If None, selects all.
            where: WHERE clause body for filtering (raw SQL). Filtering is
                applied before sampling.
            limit: Maximum number of rows to return (applied after sampling).
                ``0`` returns no rows; ``None`` means no limit.
            sample_size: Number of rows to sample. ``reservoir`` draws a uniform
                sample of up to ``sample_size`` rows. ``bernoulli`` and
                ``system`` sample roughly 1% of the (filtered) rows and then keep
                at most ``sample_size`` of them, so they may return fewer.
            sample_method: Sampling method ("reservoir", "bernoulli", "system")

        Returns:
            pandas DataFrame with query results

        Raises:
            ValueError: If ``parquet_files`` is empty or ``sample_method`` is unknown.
            RuntimeError: If DuckDB fails to execute the query.
        """
        if not parquet_files:
            raise ValueError("No parquet files provided")

        # A custom query bypasses every other option.
        if query:
            return self._execute_custom_query(query, parquet_files)

        sql_query = self._build_query_sql(
            parquet_files,
            columns=columns,
            where=where,
            limit=limit,
            sample_size=sample_size,
            sample_method=sample_method,
        )
        return self._execute_query(sql_query)

    def stream_query(
        self, parquet_files: List[str], batch_size: int = 1000, **kwargs
    ) -> Iterator[pd.DataFrame]:
        """
        Stream query results in batches.

        The query is executed once on a dedicated cursor and its result is
        fetched ``batch_size`` rows at a time, so every result row is yielded
        exactly once and any sample is drawn only once.

        Args:
            parquet_files: List of Parquet file URLs to query
            batch_size: Number of rows per batch (the last batch may be smaller)
            **kwargs: ``query``, ``columns``, ``where``, ``sample_size`` and
                ``sample_method`` as accepted by :meth:`query`. ``limit`` is
                ignored so that all rows are streamed; other keys are ignored.

        Yields:
            pandas DataFrame batches

        Raises:
            ValueError: If ``parquet_files`` is empty or ``batch_size`` < 1.
            RuntimeError: If DuckDB fails to execute the query.
        """
        if not parquet_files:
            raise ValueError("No parquet files provided")
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")

        custom_query = kwargs.get("query")
        if custom_query:
            sql_query = custom_query.replace(
                "{parquet_files}", self._build_from_clause(parquet_files)
            )
        else:
            # limit is deliberately not forwarded: streaming returns all rows.
            sql_query = self._build_query_sql(
                parquet_files,
                columns=kwargs.get("columns"),
                where=kwargs.get("where"),
                sample_size=kwargs.get("sample_size"),
                sample_method=kwargs.get("sample_method", "reservoir"),
            )

        # A dedicated cursor keeps this pending result valid even if other
        # queries run on self.connection while the caller is iterating.
        cursor = self.connection.cursor()
        try:
            try:
                # fetch_record_batch needs pyarrow; it returns a
                # RecordBatchReader yielding batch_size-row batches.
                reader = cursor.execute(sql_query).fetch_record_batch(batch_size)
            except Exception as e:
                raise RuntimeError(
                    f"Query execution failed: {e}\nQuery: {sql_query}"
                ) from e
            for record_batch in reader:
                if record_batch.num_rows:
                    yield record_batch.to_pandas()
        finally:
            cursor.close()

    async def astream_query(
        self, parquet_files: List[str], batch_size: int = 1000, **kwargs
    ) -> AsyncIterator[pd.DataFrame]:
        """
        Async stream query results in batches.

        Query execution and every batch fetch block, so each step of
        :meth:`stream_query` is run in the event loop's default executor.

        Args:
            parquet_files: List of Parquet file URLs to query
            batch_size: Number of rows per batch
            **kwargs: Other query parameters (see :meth:`stream_query`)

        Yields:
            pandas DataFrame batches
        """
        loop = asyncio.get_running_loop()
        batches = self.stream_query(parquet_files, batch_size, **kwargs)
        exhausted = object()
        try:
            while True:
                batch = await loop.run_in_executor(None, next, batches, exhausted)
                if batch is exhausted:
                    break
                yield batch
        finally:
            try:
                # Closing the generator releases its cursor.
                batches.close()
            except ValueError:
                # "generator already executing": a cancelled await left next()
                # running in a worker thread; the generator is finalized once
                # that call returns and the object is collected.
                pass

    def sample_random(
        self,
        parquet_files: List[str],
        n: int,
        method: str = "reservoir",
        columns: Optional[List[str]] = None,
        where: Optional[str] = None,
    ) -> pd.DataFrame:
        """
        Random sampling from Parquet files.

        Args:
            parquet_files: List of Parquet file URLs
            n: Number of rows to sample (``0`` returns no rows)
            method: Sampling method ("reservoir", "bernoulli", "system").
                Only "reservoir" reliably returns ``n`` rows when enough rows
                exist; the others take ~1% of rows capped at ``n``.
            columns: List of columns to select
            where: Optional WHERE clause for filtering (applied before sampling)

        Returns:
            pandas DataFrame with randomly sampled rows
        """
        return self.query(
            parquet_files=parquet_files,
            columns=columns,
            where=where,
            sample_size=n,
            sample_method=method,
        )

    def get_schema(self, parquet_files: List[str]) -> Dict[str, str]:
        """
        Get schema information from the first Parquet file.

        Args:
            parquet_files: List of Parquet file URLs

        Returns:
            Dictionary mapping column names to data types; empty (with a
            warning) if the schema could not be read.
        """
        if not parquet_files:
            return {}

        # Only the first file is inspected; the others are assumed to match.
        from_clause = self._build_from_clause(parquet_files[:1])
        query = f"DESCRIBE SELECT * FROM {from_clause} LIMIT 0"

        try:
            result = self._execute_query(query)
            return dict(zip(result["column_name"], result["column_type"]))
        except Exception as e:
            warnings.warn(f"Could not get schema: {e}")
            return {}

    def count_rows(self, parquet_files: List[str], where: Optional[str] = None) -> int:
        """
        Count total rows in Parquet files.

        Args:
            parquet_files: List of Parquet file URLs
            where: Optional WHERE clause for filtering

        Returns:
            Total number of rows; 0 (with a warning) if the count failed.
        """
        if not parquet_files:
            return 0

        from_clause = self._build_from_clause(parquet_files)
        where_clause = f"WHERE {where}" if where else ""

        query = f"SELECT COUNT(*) as count FROM {from_clause} {where_clause}"

        try:
            result = self._execute_query(query)
            return int(result["count"].iloc[0])
        except Exception as e:
            # Best-effort like get_schema, but make the failure visible instead
            # of silently reporting an empty dataset.
            warnings.warn(f"Could not count rows: {e}")
            return 0

    @staticmethod
    def _sql_string(value: str) -> str:
        """Return *value* as a single-quoted SQL string literal (quotes doubled)."""
        return "'" + str(value).replace("'", "''") + "'"

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return *name* as a double-quoted SQL identifier (quotes doubled)."""
        return '"' + str(name).replace('"', '""') + '"'

    def _build_select_list(self, columns: Optional[List[str]]) -> str:
        """Return a quoted, comma-separated column list, or ``*`` when none given."""
        if not columns:
            return "*"
        return ", ".join(self._quote_identifier(col) for col in columns)

    def _build_from_clause(self, parquet_files: List[str]) -> str:
        """Build a ``read_parquet(...)`` FROM expression with escaped file paths."""
        literals = [self._sql_string(path) for path in parquet_files]
        if len(literals) == 1:
            return f"read_parquet({literals[0]})"
        # Use list syntax for multiple files
        return f"read_parquet([{', '.join(literals)}])"

    def _build_query_sql(
        self,
        parquet_files: List[str],
        columns: Optional[List[str]] = None,
        where: Optional[str] = None,
        limit: Optional[int] = None,
        sample_size: Optional[int] = None,
        sample_method: str = "reservoir",
    ) -> str:
        """
        Build the SELECT statement used by :meth:`query` and :meth:`stream_query`.

        Raises:
            ValueError: If ``sample_size`` is given with an unknown ``sample_method``.
        """
        select_clause = self._build_select_list(columns)
        from_clause = self._build_from_clause(parquet_files)
        where_clause = f"WHERE {where}" if where else ""

        sample_clause = ""
        if sample_size is not None:
            n_rows = int(sample_size)
            if sample_method == "reservoir":
                # Fixed-size uniform sample.
                sample_clause = f"USING SAMPLE reservoir({n_rows} ROWS)"
            elif sample_method in ("bernoulli", "system"):
                # These methods only accept percentages in DuckDB: sample ~1%,
                # then cap the result at sample_size rows.
                sample_clause = f"USING SAMPLE {sample_method}(1%)"
                limit = n_rows if limit is None else min(int(limit), n_rows)
            else:
                raise ValueError(f"Unknown sample method: {sample_method}")

        if sample_clause and where_clause:
            # DuckDB applies USING SAMPLE to the FROM result *before* WHERE, so
            # filter inside a subquery and sample the matching rows instead.
            from_clause = f"(SELECT * FROM {from_clause} {where_clause})"
            where_clause = ""

        # `is not None` so that limit=0 really returns no rows.
        limit_clause = f"LIMIT {int(limit)}" if limit is not None else ""

        parts = [
            f"SELECT {select_clause}",
            f"FROM {from_clause}",
            where_clause,
            sample_clause,
            limit_clause,
        ]
        return "\n".join(part for part in parts if part)

    def _build_streaming_query(
        self, parquet_files: List[str], offset: int, limit: int, **kwargs
    ) -> str:
        """
        Build one OFFSET/LIMIT page of a filtered/projected scan.

        Retained for backward compatibility; :meth:`stream_query` does not use
        it. Deliberately has no ``ORDER BY RANDOM()``: a random order is re-drawn
        on every execution, so consecutive pages would overlap and skip rows.
        Page stability relies on DuckDB preserving scan order
        (``preserve_insertion_order``, on by default).
        """
        select_clause = self._build_select_list(kwargs.get("columns"))
        from_clause = self._build_from_clause(parquet_files)
        where = kwargs.get("where")
        where_clause = f"WHERE {where}" if where else ""

        parts = [
            f"SELECT {select_clause}",
            f"FROM {from_clause}",
            where_clause,
            f"LIMIT {int(limit)} OFFSET {int(offset)}",
        ]
        return "\n".join(part for part in parts if part)

    def _execute_query(self, query: str) -> pd.DataFrame:
        """
        Execute SQL query and return results as DataFrame.

        Raises:
            RuntimeError: Wrapping (and chained to) any DuckDB error, with the
                failing SQL included in the message.
        """
        try:
            return self.connection.execute(query).df()
        except Exception as e:
            raise RuntimeError(f"Query execution failed: {e}\nQuery: {query}") from e

    def _execute_custom_query(
        self, query: str, parquet_files: List[str]
    ) -> pd.DataFrame:
        """
        Execute a custom SQL query, replacing placeholders with actual file paths.

        The query can use '{parquet_files}' as a placeholder for the parquet files.
        """
        from_clause = self._build_from_clause(parquet_files)
        formatted_query = query.replace("{parquet_files}", from_clause)

        return self._execute_query(formatted_query)

    def close(self):
        """Close the DuckDB connection."""
        if self.connection:
            self.connection.close()

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: close the connection."""
        self.close()


class AsyncDuckDBQueryEngine(DuckDBQueryEngine):
    """
    Async wrapper that runs the blocking DuckDB calls in the event loop's
    default executor.

    NOTE(review): concurrent awaits run on separate executor threads against
    the same engine (and therefore the same ``self.connection``); confirm that
    is safe for the DuckDB version in use, or serialize calls.
    """

    async def _run_blocking(self, func, *args, **kwargs):
        """Run blocking *func* in the default executor and await its result."""
        loop = asyncio.get_running_loop()
        # run_in_executor() forwards positional arguments only, so keyword
        # arguments have to be bound with functools.partial.
        return await loop.run_in_executor(
            None, functools.partial(func, *args, **kwargs)
        )

    async def aquery(self, *args, **kwargs) -> pd.DataFrame:
        """Async version of ``query``; accepts the same arguments."""
        return await self._run_blocking(self.query, *args, **kwargs)

    async def asample_random(self, *args, **kwargs) -> pd.DataFrame:
        """Async version of ``sample_random``; accepts the same arguments."""
        return await self._run_blocking(self.sample_random, *args, **kwargs)

    async def acount_rows(self, *args, **kwargs) -> int:
        """Async version of ``count_rows``; accepts the same arguments."""
        return await self._run_blocking(self.count_rows, *args, **kwargs)

    async def __aenter__(self):
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit: close the connection."""
        self.close()
