"""Embedding provider registry for NLQL.

This module manages the embedding provider used for semantic similarity operations.
"""

from collections.abc import Callable

# Type alias for embedding provider: a callable mapping a list of texts to
# one embedding vector (list of floats) per text.
EmbeddingProvider = Callable[[list[str]], list[list[float]]]


class EmbeddingRegistry:
    """Registry for the embedding provider.

    Unlike other registries, this maintains a single "current" provider
    that can be overridden by users. When no custom provider has been
    registered, a default provider is lazily constructed on first use.
    The default is cached separately from the user-registered provider so
    that loading it does not make ``has_custom()`` report True.
    """

    def __init__(self) -> None:
        # Provider explicitly registered by the user, if any.
        self._provider: EmbeddingProvider | None = None
        # Lazily-built fallback provider. Kept apart from ``_provider`` so
        # ``has_custom()`` only reflects explicit registrations.
        self._default: EmbeddingProvider | None = None

    def register(self, provider: EmbeddingProvider) -> None:
        """Register an embedding provider, replacing any previous one.

        Args:
            provider: Embedding provider function that takes list of texts
                     and returns list of embedding vectors
        """
        self._provider = provider

    def get(self) -> EmbeddingProvider:
        """Get the current embedding provider.

        Returns the user-registered provider if there is one; otherwise
        returns the default provider, constructing it on the first such call
        and reusing that instance afterwards (until ``clear()``).

        Returns:
            Current embedding provider

        Raises:
            Whatever importing or constructing
            ``nlql.text.embedding.DefaultEmbeddingProvider`` raises when no
            custom provider is registered. This method raises nothing
            itself; presumably the default raises NLQLConfigError when it
            cannot be loaded (verify in ``nlql.text.embedding``).
        """
        if self._provider is not None:
            return self._provider

        if self._default is None:
            # Lazy import: only pull in the default implementation when a
            # caller actually needs it.
            from nlql.text.embedding import DefaultEmbeddingProvider

            self._default = DefaultEmbeddingProvider()

        return self._default

    def has_custom(self) -> bool:
        """Check if a custom provider has been registered.

        Loading the default provider via ``get()`` does not count.
        """
        return self._provider is not None

    def clear(self) -> None:
        """Clear the registered provider and the cached default.

        The next ``get()`` call constructs a fresh default provider.
        """
        self._provider = None
        self._default = None


# Global registry instance: the single module-level EmbeddingRegistry that
# register_embedding_provider() and get_embedding_provider() delegate to.
_global_embedding_registry: EmbeddingRegistry = EmbeddingRegistry()


def register_embedding_provider(provider: EmbeddingProvider) -> EmbeddingProvider:
    """Install ``provider`` as the module-wide embedding provider.

    Works both as a plain function call and as a decorator, because the
    provider is handed back unchanged. Registering replaces any provider
    registered earlier. The provider must accept a list of texts and return
    a list of embedding vectors, one per text.

    Args:
        provider: Embedding provider function

    Returns:
        ``provider`` itself, so a decorated function keeps its name binding.

    Example:
        As a decorator::

            @register_embedding_provider
            def my_embedding(texts: list[str]) -> list[list[float]]:
                # Custom embedding logic
                return [[0.1, 0.2, 0.3] for _ in texts]

        Or by direct call::

            register_embedding_provider(my_embedding_function)
    """
    registry = _global_embedding_registry
    registry.register(provider)
    return provider


def get_embedding_provider() -> EmbeddingProvider:
    """Return the embedding provider held by the module-level registry.

    When nothing has been registered, this is the registry's lazily-loaded
    default provider.
    """
    provider = _global_embedding_registry.get()
    return provider

