"""
Dataset Info Retriever for Hypersets

Discovers parquet files, configs, splits, and schema information
from HuggingFace datasets efficiently by reading README.md YAML frontmatter.
"""

import time
import json
import hashlib
import re
import yaml
from typing import List, Dict, Optional, Set, Tuple
from dataclasses import dataclass, field
from pathlib import Path
import tempfile
import logging

import requests
from huggingface_hub import HfApi, hf_hub_url
from huggingface_hub.errors import RepositoryNotFoundError

from .exceptions import DatasetNotFoundError, ConfigNotFoundError, SplitNotFoundError

# Module-level logger named after this module; every function below logs through it.
logger = logging.getLogger(__name__)


@dataclass
class SplitInfo:
    """Parquet files that make up one split of a dataset configuration."""

    # Split name (e.g. "train"); also used as the key in ConfigInfo.splits.
    name: str
    # Number of parquet files in the split; code in this module always passes
    # len(file_paths) here.
    file_count: int
    # Paths of the split's parquet files, in the form accepted as the
    # ``filename`` argument of hf_hub_url (see DatasetInfo.get_parquet_urls).
    file_paths: List[str]
    # Optional size in bytes. Nothing visible in this module computes it; it is
    # only carried through the cache, so it stays None unless a caller sets it.
    estimated_size_bytes: Optional[int] = None


@dataclass
class ConfigInfo:
    """A named dataset configuration together with the splits it contains."""

    name: str
    splits: Dict[str, SplitInfo] = field(default_factory=dict)

    @property
    def split_names(self) -> List[str]:
        """Names of the splits, in the order they were added."""
        return [split_name for split_name in self.splits]

    @property
    def total_files(self) -> int:
        """Sum of ``file_count`` over every split of this configuration."""
        file_total = 0
        for split_info in self.splits.values():
            file_total += split_info.file_count
        return file_total


@dataclass
class DatasetInfo:
    """Discovered layout of a dataset repository: its configs, splits and files."""

    # Repository id of the dataset; passed to hf_hub_url as ``repo_id``.
    name: str
    # Configurations keyed by config name.
    configs: Dict[str, ConfigInfo] = field(default_factory=dict)
    # Number of parquet files found in the whole repository.
    total_parquet_files: int = 0
    # Creation time (epoch seconds); the cache uses it to expire entries.
    cache_timestamp: float = field(default_factory=time.time)

    @property
    def config_names(self) -> List[str]:
        """Names of all configurations, in insertion order."""
        return list(self.configs.keys())

    @property
    def estimated_total_size_gb(self) -> float:
        """Rough estimate of total dataset size in GB.

        Assumes an average parquet file is ~50MB, so this is only an
        order-of-magnitude figure, not a measurement.
        """
        return (self.total_parquet_files * 50) / 1024

    def get_parquet_urls(
        self, config: Optional[str] = None, split: Optional[str] = None
    ) -> List[str]:
        """Get parquet file URLs for the given config/split.

        Args:
            config: Restrict the result to this configuration. When omitted,
                every configuration is considered.
            split: Restrict the result to this split. When ``config`` is
                omitted, the split is looked up in every configuration and
                configurations that do not define it are skipped.

        Returns:
            URLs in config order, then split order, then file order.

        Raises:
            ConfigNotFoundError: ``config`` is given but unknown.
            SplitNotFoundError: ``split`` is given but not defined by the
                selected configuration (or by any configuration when
                ``config`` is omitted).
        """
        if config:
            if config not in self.configs:
                raise ConfigNotFoundError(
                    f"Config '{config}' not found. Available: {self.config_names}"
                )
            candidates = [self.configs[config]]
        else:
            candidates = list(self.configs.values())

        if split:
            # Previously ``split`` was silently ignored unless ``config`` was
            # also given; it now filters across all candidate configs.
            selected = [
                config_info.splits[split]
                for config_info in candidates
                if split in config_info.splits
            ]
            if not selected:
                if config:
                    raise SplitNotFoundError(
                        f"Split '{split}' not found in config '{config}'. Available: {candidates[0].split_names}"
                    )
                available: List[str] = []
                for config_info in candidates:
                    for split_name in config_info.split_names:
                        if split_name not in available:
                            available.append(split_name)
                raise SplitNotFoundError(
                    f"Split '{split}' not found in any config. Available: {available}"
                )
        else:
            selected = [
                split_info
                for config_info in candidates
                for split_info in config_info.splits.values()
            ]

        return [
            hf_hub_url(repo_id=self.name, filename=file_path, repo_type="dataset")
            for split_info in selected
            for file_path in split_info.file_paths
        ]


class DatasetInfoCache:
    """Simple file-based cache holding one JSON document per dataset.

    Entries are stored in a ``hypersets_cache`` directory below ``cache_dir``
    (the system temp directory when ``cache_dir`` is not given) and are
    treated as expired once they are older than ``ttl_hours``.
    """

    def __init__(self, cache_dir: Optional[str] = None, ttl_hours: int = 24):
        """Create the cache directory if needed and record the TTL.

        Args:
            cache_dir: Base directory for the cache; defaults to the system
                temp directory.
            ttl_hours: Maximum age of an entry, in hours.
        """
        self.cache_dir = Path(cache_dir or tempfile.gettempdir()) / "hypersets_cache"
        # parents=True: a caller-supplied base directory may not exist yet.
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.ttl_seconds = ttl_hours * 3600

    def _cache_key(self, dataset_name: str) -> str:
        """Generate cache key for dataset (MD5 hex digest, used as file name only)."""
        return hashlib.md5(dataset_name.encode()).hexdigest()

    def _cache_path(self, dataset_name: str) -> Path:
        """Return the JSON file that holds the entry for ``dataset_name``."""
        return self.cache_dir / f"{self._cache_key(dataset_name)}.json"

    def get(self, dataset_name: str) -> "Optional[DatasetInfo]":
        """Get cached dataset info if present and not expired.

        Returns None when there is no entry, when the entry is older than the
        TTL (the file is then deleted), or when the file cannot be read or
        parsed (a warning is logged; the file is left in place).
        """
        cache_file = self._cache_path(dataset_name)

        if not cache_file.exists():
            return None

        try:
            data = json.loads(cache_file.read_text(encoding="utf-8"))

            # A missing timestamp counts as epoch 0, i.e. always expired.
            if time.time() - data.get("cache_timestamp", 0) > self.ttl_seconds:
                logger.debug(f"Cache expired for {dataset_name}")
                cache_file.unlink()  # Delete expired cache
                return None

            # Reconstruct DatasetInfo from JSON
            dataset_info = DatasetInfo(name=data["name"])
            dataset_info.total_parquet_files = data["total_parquet_files"]
            dataset_info.cache_timestamp = data["cache_timestamp"]

            for config_name, config_data in data["configs"].items():
                config_info = ConfigInfo(name=config_name)
                for split_name, split_data in config_data["splits"].items():
                    config_info.splits[split_name] = SplitInfo(
                        name=split_name,
                        file_count=split_data["file_count"],
                        file_paths=split_data["file_paths"],
                        estimated_size_bytes=split_data.get("estimated_size_bytes"),
                    )
                dataset_info.configs[config_name] = config_info

            logger.info(f"Using cached dataset info for {dataset_name}")
            return dataset_info

        except Exception as e:
            # Best effort: a corrupt or unreadable entry is treated as a miss.
            logger.warning(f"Failed to read cache for {dataset_name}: {e}")
            return None

    def set(self, dataset_info: "DatasetInfo") -> None:
        """Cache dataset info.

        Failures are logged as warnings and otherwise ignored, so caching
        never breaks the caller.
        """
        cache_file = self._cache_path(dataset_info.name)

        try:
            # Convert to JSON-serializable format
            configs_data = {}
            for config_name, config_info in dataset_info.configs.items():
                configs_data[config_name] = {
                    "splits": {
                        split_name: {
                            "file_count": split_info.file_count,
                            "file_paths": split_info.file_paths,
                            "estimated_size_bytes": split_info.estimated_size_bytes,
                        }
                        for split_name, split_info in config_info.splits.items()
                    }
                }

            data = {
                "name": dataset_info.name,
                "total_parquet_files": dataset_info.total_parquet_files,
                "cache_timestamp": dataset_info.cache_timestamp,
                "configs": configs_data,
            }

            cache_file.write_text(json.dumps(data, indent=2), encoding="utf-8")
            logger.debug(f"Cached dataset info for {dataset_info.name}")

        except Exception as e:
            logger.warning(f"Failed to cache dataset info for {dataset_info.name}: {e}")

    def clear(self) -> None:
        """Clear all cached data (every ``*.json`` file in the cache directory)."""
        for cache_file in self.cache_dir.glob("*.json"):
            try:
                cache_file.unlink()
            except Exception as e:
                logger.warning(f"Failed to delete cache file {cache_file}: {e}")


# Global cache instance used by discover_dataset_info() and clear_cache().
# It is built with the default arguments when this module is imported, so the
# import itself creates the "hypersets_cache" directory under the system temp
# directory; entries expire after the default 24 hours.
_cache = DatasetInfoCache()


def _handle_rate_limit(func, max_retries: int = 5):
    """Simple rate limit handler with exponential backoff."""
    for attempt in range(max_retries):
        try:
            return func()
        except Exception as e:
            if "429" in str(e) or "rate limit" in str(e).lower():
                if attempt < max_retries - 1:
                    delay = 2**attempt
                    logger.warning(
                        f"Rate limited, waiting {delay}s (attempt {attempt + 1})"
                    )
                    time.sleep(delay)
                    continue
            raise e


def _parse_readme_yaml(
    dataset_name: str, token: Optional[str] = None
) -> Optional[Dict]:
    """Parse YAML frontmatter from README.md to get config information.

    Args:
        dataset_name: Repository id of the dataset.
        token: Optional token, sent as a ``Bearer`` Authorization header.

    Returns:
        The frontmatter as a dict, or None when the README cannot be fetched
        (non-200 response, network error, timeout), has no frontmatter, or
        the frontmatter is not a YAML mapping. Failures are logged at debug
        level and never raised, so callers can fall back to other discovery.
    """
    # Without a timeout requests can block forever on a stalled connection.
    request_timeout_seconds = 30

    try:
        # Get README.md content
        readme_url = hf_hub_url(
            repo_id=dataset_name, filename="README.md", repo_type="dataset"
        )

        # Add token to headers if provided
        headers = {"Authorization": f"Bearer {token}"} if token else {}

        response = requests.get(
            readme_url, headers=headers, timeout=request_timeout_seconds
        )
        if response.status_code != 200:
            logger.debug(f"Could not fetch README.md for {dataset_name}")
            return None

        # Drop a leading BOM and normalise Windows line endings; either one
        # would stop the frontmatter pattern below from matching.
        readme_content = response.text.lstrip("\ufeff").replace("\r\n", "\n")

        # Extract YAML frontmatter: the text between the leading "---" line
        # and the next line starting with "---".
        yaml_match = re.match(r"^---\n(.*?)\n---", readme_content, re.DOTALL)
        if not yaml_match:
            logger.debug(f"No YAML frontmatter found in README.md for {dataset_name}")
            return None

        yaml_content = yaml.safe_load(yaml_match.group(1))
        if not isinstance(yaml_content, dict):
            # Empty or non-mapping frontmatter carries no config information.
            logger.debug(
                f"YAML frontmatter in README.md for {dataset_name} is not a mapping"
            )
            return None
        return yaml_content

    except Exception as e:
        logger.debug(f"Failed to parse README.md for {dataset_name}: {e}")
        return None


def _expand_glob_pattern(pattern: str, all_files: List[str]) -> List[str]:
    """Expand glob pattern to actual file paths."""
    import fnmatch

    if "*" not in pattern:
        # Not a glob pattern
        return [pattern] if pattern in all_files else []

    # Match files using glob pattern
    matched_files = []
    for file_path in all_files:
        if fnmatch.fnmatch(file_path, pattern):
            matched_files.append(file_path)

    return matched_files


def _as_pattern_list(value) -> List[str]:
    """Coerce a YAML ``path`` value (string or list) into a list of pattern strings."""
    if isinstance(value, str):
        return [value]
    if isinstance(value, (list, tuple)):
        return [item for item in value if isinstance(item, str)]
    return []


def _normalize_data_files(data_files) -> List[Tuple[str, List[str]]]:
    """Turn a config's ``data_files`` YAML value into ``(split, patterns)`` pairs.

    Accepted shapes: a single pattern string, a list of pattern strings (both
    assigned to the ``train`` split), a mapping of split name to pattern(s),
    or a list of ``{"split": ..., "path": ...}`` mappings. Anything else
    yields no pairs.
    """
    if isinstance(data_files, str):
        return [("train", [data_files])]
    if isinstance(data_files, dict):
        return [
            (split_name, _as_pattern_list(paths))
            for split_name, paths in data_files.items()
        ]
    if not isinstance(data_files, (list, tuple)):
        return []

    split_entries = [entry for entry in data_files if isinstance(entry, dict)]
    if split_entries:
        return [
            (entry.get("split", "train"), _as_pattern_list(entry.get("path", [])))
            for entry in split_entries
        ]
    return [("train", _as_pattern_list(data_files))]


def _configs_from_yaml(
    yaml_data: Optional[Dict], parquet_files: List[str]
) -> Dict[str, ConfigInfo]:
    """Build configs from the ``configs`` list of the README frontmatter.

    Splits without matching parquet files and configs without any split are
    omitted. Malformed entries are skipped instead of raising.
    """
    configs: Dict[str, ConfigInfo] = {}

    declared = yaml_data.get("configs") if isinstance(yaml_data, dict) else None
    if not isinstance(declared, list):
        return configs

    logger.debug(f"Found {len(declared)} configs in README.md YAML")

    for config_data in declared:
        if not isinstance(config_data, dict):
            continue
        config_name = config_data.get("config_name")
        if not config_name:
            continue

        config_info = ConfigInfo(name=config_name)
        split_patterns = _normalize_data_files(config_data.get("data_files", []))

        for split_name, patterns in split_patterns:
            split_files: List[str] = []
            for pattern in patterns:
                split_files.extend(_expand_glob_pattern(pattern, parquet_files))
            # Overlapping patterns may match the same file more than once.
            split_files = list(dict.fromkeys(split_files))

            if split_files:
                # A split declared twice keeps the last declaration.
                config_info.splits[split_name] = SplitInfo(
                    name=split_name,
                    file_count=len(split_files),
                    file_paths=split_files,
                )

        if config_info.splits:  # Only add configs with actual files
            configs[config_name] = config_info

    return configs


def _infer_default_config(parquet_files: List[str]) -> ConfigInfo:
    """Guess splits for a single ``default`` config from file path keywords.

    NOTE(review): matching is by case-insensitive substring, so a file can
    land in several splits (or in none, in which case it is left out whenever
    at least one other file matched a keyword). Short keywords such as "val"
    and "dev" also match unrelated names.
    """
    config_info = ConfigInfo(name="default")

    split_keywords = (
        ("train", ("train",)),
        ("test", ("test",)),
        ("validation", ("validation", "valid", "val", "dev")),
    )
    for split_name, keywords in split_keywords:
        matched = [
            file_path
            for file_path in parquet_files
            if any(keyword in file_path.lower() for keyword in keywords)
        ]
        if matched:
            config_info.splits[split_name] = SplitInfo(
                name=split_name, file_count=len(matched), file_paths=matched
            )

    # If no clear splits found, put everything in train
    if not config_info.splits:
        config_info.splits["train"] = SplitInfo(
            name="train", file_count=len(parquet_files), file_paths=parquet_files
        )

    return config_info


def discover_dataset_info(
    dataset_name: str, token: Optional[str] = None, use_cache: bool = True
) -> DatasetInfo:
    """
    Discover dataset configs, splits and parquet files.

    Configs are taken from the ``configs`` section of the README.md YAML
    frontmatter. When the frontmatter is missing, unreadable, or yields no
    config with matching parquet files, a single ``default`` config is
    inferred from keywords in the file paths instead.

    Args:
        dataset_name: HuggingFace dataset name
        token: Optional HF token
        use_cache: Whether to read from and write to the module-level cache

    Returns:
        DatasetInfo object with at least one config

    Raises:
        DatasetNotFoundError: If the repository does not exist, cannot be
            listed for any other reason, or contains no parquet files
    """
    # Check cache first
    if use_cache:
        cached = _cache.get(dataset_name)
        if cached:
            return cached

    logger.info(f"Discovering dataset info for {dataset_name}")

    # Get all files in repository with rate limit handling
    def fetch_files():
        api = HfApi(token=token)
        return api.list_repo_files(dataset_name, repo_type="dataset", token=token)

    try:
        repo_files = _handle_rate_limit(fetch_files)
    except RepositoryNotFoundError as e:
        raise DatasetNotFoundError(f"Dataset '{dataset_name}' not found") from e
    except Exception as e:
        # Any other listing failure is reported under the same exception type
        # so callers only need to handle DatasetNotFoundError.
        raise DatasetNotFoundError(
            f"Failed to access dataset '{dataset_name}': {e}"
        ) from e

    parquet_files = [f for f in repo_files if f.endswith(".parquet")]

    if not parquet_files:
        raise DatasetNotFoundError(
            f"No parquet files found in dataset '{dataset_name}'"
        )

    dataset_info = DatasetInfo(
        name=dataset_name, total_parquet_files=len(parquet_files)
    )

    # Preferred source: config definitions in the README.md YAML frontmatter
    yaml_data = _parse_readme_yaml(dataset_name, token)
    dataset_info.configs.update(_configs_from_yaml(yaml_data, parquet_files))

    # Fallback: if no YAML configs found, parse file structure
    if not dataset_info.configs:
        logger.debug("No configs in YAML, falling back to file structure parsing")
        dataset_info.configs["default"] = _infer_default_config(parquet_files)

    # Cache the result
    if use_cache:
        _cache.set(dataset_info)

    logger.info(
        f"Discovered {len(dataset_info.configs)} configs, {len(parquet_files)} files total"
    )

    return dataset_info


def info(dataset_name: str, **kwargs) -> DatasetInfo:
    """Return the DatasetInfo for ``dataset_name`` (main public API).

    Keyword arguments are forwarded unchanged to ``discover_dataset_info``.
    """
    dataset_info = discover_dataset_info(dataset_name, **kwargs)
    return dataset_info


def list_configs(dataset_name: str, **kwargs) -> List[str]:
    """Return the configuration names available for ``dataset_name``.

    Keyword arguments are forwarded unchanged to ``discover_dataset_info``.
    """
    discovered = discover_dataset_info(dataset_name, **kwargs)
    config_names = discovered.config_names
    return config_names


def clear_cache() -> None:
    """Clear all cached dataset information.

    Delegates to the module-level cache instance, which deletes every
    ``*.json`` entry in its cache directory.
    """
    _cache.clear()
