from typing import Any, Protocol

import numpy as np
from numpy.typing import NDArray
from pydantic import BaseModel

from lego.lego_types import JSONDict


class MilvusDBSettings(BaseModel):
    """Configuration model for MilvusDBConnector.

    Only `collection` is mandatory; every other field carries a default.
    """

    # Required: no default value is provided.
    collection: str
    partition: str = "_default"

    # Both parameter dicts default to the "IP" metric type.
    index_param: JSONDict = {
        "index_type": "AUTOINDEX",
        "metric_type": "IP",
    }
    search_param: JSONDict = {
        "metric_type": "IP",
    }

    # Optional threshold; left unset (None) by default.
    sim_threshold_to_add: float | None = None
    embedding_field: str = "vector"
    # Comparison operator name, read as: x (more_similar_op) threshold
    more_similar_op: str = "gt"
    primary_key: str = "id"


class EmbedModel(Protocol):
    """Embedding model protocol.

    Attributes:
        embed_fn: Callable invoked with a list of texts; both methods below
            delegate to it.
        embed_dim: Embedding dimension. This class never assigns it, so
            implementers are expected to provide the value.
    """

    embed_fn: Any
    embed_dim: int

    # NDArray's type parameter must be a NumPy scalar type (a subclass of
    # np.generic); the builtin `float` is rejected by static type checkers.
    def __call__(self, texts: list[str]) -> list[NDArray[np.floating]]:
        """Return embeddings for the input texts.

        Args:
            texts: Texts to embed; passed to `embed_fn` unchanged.

        Returns:
            Whatever `embed_fn` returns for `texts`, annotated as one
            floating-point array per input text.
        """
        return self.embed_fn(texts)

    def inspect_embed_dim(self) -> int:
        """Measure the embedding dimension.

        Embeds the single probe text "test" through `embed_fn` and reads the
        first axis length of the first returned embedding.

        Returns:
            `shape[0]` of the embedding produced for the probe text.
        """
        return self.embed_fn(["test"])[0].shape[0]
