import faiss
import numpy as np
from typing import List, Optional, Tuple
from evoagent.vector_stores import VectorStore

class FaissVectorStore(VectorStore):
    """FAISS-based vector store implementation.

    Embeddings live in an exact (brute-force) ``faiss.IndexFlatL2`` index.
    Texts are kept in ``self.texts`` in insertion order, mirroring the
    sequential row ids FAISS assigns on ``add``; a search hit with id ``i``
    therefore maps to ``self.texts[i]``.
    """

    def __init__(self, embedding_dim: int):
        """
        Initialize an empty FAISS index.

        Args:
            embedding_dim (int): Dimension of the embedding vectors

        Raises:
            ValueError: If embedding_dim is not positive
        """
        if embedding_dim <= 0:
            raise ValueError("embedding_dim must be positive")

        self.embedding_dim = embedding_dim
        self.index = faiss.IndexFlatL2(embedding_dim)
        self.texts: List[str] = []

    def _check_dim(self, embedding: List[float], label: str) -> None:
        """Raise ValueError if ``embedding`` does not have ``embedding_dim`` entries.

        Args:
            embedding (List[float]): Vector to validate
            label (str): Prefix for the error message (e.g. "Embedding")
        """
        if len(embedding) != self.embedding_dim:
            raise ValueError(
                f"{label} dimension {len(embedding)} does not match "
                f"expected dimension {self.embedding_dim}"
            )

    def _search(self, query_embedding: List[float], k: int) -> List[Tuple[int, float]]:
        """Return up to ``k`` (row id, squared L2 distance) pairs, nearest first.

        Shared implementation of :meth:`query` and :meth:`query_with_scores`.

        Raises:
            ValueError: If query embedding dimension doesn't match
        """
        self._check_dim(query_embedding, "Query embedding")

        k = min(k, len(self.texts))
        # An empty store or a non-positive k (FAISS rejects negative k)
        # means there is nothing to return.
        if k <= 0:
            return []

        query_vector = np.array([query_embedding], dtype=np.float32)
        distances, indices = self.index.search(query_vector, k)

        # FAISS pads unfilled result slots with id -1. Skip them, otherwise
        # Python's negative indexing would silently return the last text.
        return [
            (int(i), float(dist))
            for i, dist in zip(indices[0], distances[0])
            if i >= 0
        ]

    def add(self, text: str, embedding: List[float]) -> None:
        """
        Add a text and its embedding to FAISS.

        Args:
            text (str): Text to store
            embedding (List[float]): Embedding vector for the text

        Raises:
            ValueError: If embedding dimension doesn't match initialization
        """
        self._check_dim(embedding, "Embedding")

        # FAISS expects a 2-D float32 array of shape (n, d); here n == 1.
        vector = np.array([embedding], dtype=np.float32)
        self.index.add(vector)
        self.texts.append(text)

    def add_batch(self, texts: List[str], embeddings: List[List[float]]) -> None:
        """
        Add multiple texts and their embeddings in batch.

        An empty batch is a no-op.

        Args:
            texts (List[str]): List of texts to store
            embeddings (List[List[float]]): List of embedding vectors

        Raises:
            ValueError: If dimensions don't match or inputs have different lengths
        """
        if len(texts) != len(embeddings):
            raise ValueError("Number of texts and embeddings must match")

        if not all(len(emb) == self.embedding_dim for emb in embeddings):
            raise ValueError(f"All embeddings must have dimension {self.embedding_dim}")

        if len(texts) == 0:
            # np.array([]) is 1-D, which FAISS's add() rejects; nothing to do.
            return

        vectors = np.array(embeddings, dtype=np.float32)
        self.index.add(vectors)
        self.texts.extend(texts)

    def query(self, query_embedding: List[float], k: int = 10) -> List[str]:
        """
        Query FAISS for similar texts.

        Args:
            query_embedding (List[float]): Query embedding vector
            k (int): Maximum number of results to return; values <= 0
                yield an empty list

        Returns:
            List[str]: Similar texts, nearest first

        Raises:
            ValueError: If query embedding dimension doesn't match
        """
        return [self.texts[i] for i, _ in self._search(query_embedding, k)]

    def query_with_scores(self, query_embedding: List[float], k: int = 10) -> List[Tuple[str, float]]:
        """
        Query FAISS for similar texts and return similarity scores.

        Args:
            query_embedding (List[float]): Query embedding vector
            k (int): Maximum number of results to return; values <= 0
                yield an empty list

        Returns:
            List[Tuple[str, float]]: (text, score) tuples, nearest first.
                The score is the negated squared L2 distance, so larger
                means more similar and 0.0 is an exact match.

        Raises:
            ValueError: If query embedding dimension doesn't match
        """
        # IndexFlatL2 reports *squared* L2 distances; negate them so that a
        # higher score means a closer match.
        return [(self.texts[i], -dist) for i, dist in self._search(query_embedding, k)]

    def save(self, path: str) -> None:
        """
        Save the vector store to disk.

        Writes two files: ``{path}.faiss`` (the index) and
        ``{path}_texts.npy`` (the texts as a pickled object array).

        Args:
            path (str): Path prefix to save the index and texts under
        """
        faiss.write_index(self.index, f"{path}.faiss")
        np.save(f"{path}_texts.npy", np.array(self.texts, dtype=object))

    @classmethod
    def load(cls, path: str, embedding_dim: int) -> 'FaissVectorStore':
        """
        Load a vector store from disk.

        Args:
            path (str): Path prefix previously passed to :meth:`save`
            embedding_dim (int): Dimension of the embedding vectors

        Returns:
            FaissVectorStore: Loaded vector store

        Raises:
            ValueError: If the stored index dimension differs from
                embedding_dim, or the number of stored texts differs from
                the number of stored vectors
        """
        instance = cls(embedding_dim)

        index = faiss.read_index(f"{path}.faiss")
        if index.d != embedding_dim:
            raise ValueError(
                f"Stored index dimension {index.d} does not match "
                f"expected dimension {embedding_dim}"
            )

        # SECURITY: allow_pickle=True unpickles arbitrary objects, which can
        # execute code. Only load files produced by save() from a trusted source.
        texts = np.load(f"{path}_texts.npy", allow_pickle=True).tolist()
        if len(texts) != index.ntotal:
            raise ValueError(
                f"Stored text count {len(texts)} does not match "
                f"stored vector count {index.ntotal}"
            )

        instance.index = index
        instance.texts = texts
        return instance