import random
import numpy as np
from sympy import primerange
import pandas as pd
from rich.progress import Progress

from src.utils.utils import generate_hash_functions, display_results

class privateCMSClient:
    """
    Client/server simulation of the private Count Mean Sketch (privateCMS).

    Each client picks one of ``k`` hash functions at random and one-hot encodes
    its value as a +/-1 vector of length ``m``: +1 at the hashed position and
    -1 elsewhere. It then flips every coordinate independently with
    probability ``1 / (e^(epsilon/2) + 1)``. The server debiases those reports
    into a ``k x m`` sketch matrix. It estimates a frequency by *averaging*
    (not minimising) the ``k`` counters the element hashes to.

    Attributes:
        df: DataFrame containing the dataset (values are read from its
            'value' column).
        epsilon: Privacy parameter for the privatization process.
        k: Number of hash functions (rows of the sketch matrix).
        m: Number of columns of the sketch matrix.
        dataset: List of values from the dataset.
        domain: Unique values in the dataset.
        N: Total number of elements in the dataset.
        M: Sketch matrix of shape (k, m).
        client_matrix: List of (privatized vector, hash index) pairs generated
            by the client.
        H: List of hash functions.

    Methods:
        bernoulli_vector():
            Draws the random +/-1 sign vector used to privatize a report.
        client(d):
            Simulates one client, returning a privatized sketch vector and the
            index of the hash function it used.
        update_sketch_matrix(v, j):
            Adds the debiased contribution of one privatized report to row j.
        estimate_client(d):
            Estimates the frequency of an element from the sketch matrix.
        execute_client():
            Runs the client side for every element of the dataset.
        server_simulator(privatized_data):
            Aggregates privatized reports and estimates every domain element.
    """
    def __init__(self, epsilon, k, m, df):
        """
        Initializes the privateCMSClient with the given parameters.

        Args:
            epsilon (float): Privacy parameter for the privatization process.
            k (int): Number of hash functions.
            m (int): Number of columns of the sketch matrix.
            df (DataFrame): Dataset to be processed; must have a 'value' column.
        """
        self.df = df
        self.epsilon = epsilon
        self.k = k
        self.m = m
        self.dataset = self.df['value'].tolist()
        self.domain = self.df['value'].unique().tolist()
        self.N = len(self.dataset)

        # Sketch matrix: one row per hash function, m counters per row.
        self.M = np.zeros((self.k, self.m))

        # Every privatized report produced by client() is recorded here.
        self.client_matrix = []

        # Hash family over a random prime modulus p in [10^6, 10^7).
        # The literal 3 is forwarded to generate_hash_functions; presumably it
        # is the independence degree of the family (3-wise) -- TODO confirm.
        primes = list(primerange(10**6, 10**7))
        p = random.choice(primes)
        self.H = generate_hash_functions(self.k, p, 3, self.m)

    def bernoulli_vector(self):
        """
        Generates the random sign vector used to privatize a one-hot report.

        Each of the ``m`` entries is independently +1 with probability
        ``e^(epsilon/2) / (e^(epsilon/2) + 1)`` and -1 otherwise.

        Returns:
            numpy.ndarray: Integer vector of length ``m`` with values -1 and 1.
        """
        exp_half_eps = np.exp(self.epsilon / 2)
        keep_prob = exp_half_eps / (exp_half_eps + 1)
        b = np.random.binomial(1, keep_prob, self.m)
        return 2 * b - 1  # map {0, 1} -> {-1, 1}

    def client(self, d):
        """
        Simulates the client side of the private Count Mean Sketch for one value.

        Args:
            d (element): The element for which the privatized sketch vector is
                generated.

        Returns:
            tuple: ``(privatized_vector, j)``, where ``j`` is the index of the
            randomly chosen hash function. The same pair is also appended to
            ``self.client_matrix``.
        """
        j = random.randint(0, self.k - 1)
        # One-hot encoding in {-1, +1}: +1 only at the hashed position.
        v = np.full(self.m, -1)
        v[self.H[j](d)] = 1
        v_aux = v * self.bernoulli_vector()
        self.client_matrix.append((v_aux, j))
        return v_aux, j

    def update_sketch_matrix(self, v, j):
        """
        Adds the debiased contribution of one privatized report to the sketch.

        The report is rescaled as ``k * (c_e/2 * v + 1/2)`` with
        ``c_e = (e^(epsilon/2) + 1) / (e^(epsilon/2) - 1)``, and the result is
        added to row ``j`` of ``self.M``.

        Args:
            v (numpy.ndarray): The privatized sketch vector (length ``m``).
            j (int): The index of the hash function the client used.
        """
        exp_half_eps = np.exp(self.epsilon / 2)
        c_e = (exp_half_eps + 1) / (exp_half_eps - 1)
        x = self.k * ((c_e / 2) * v + 0.5)
        # Vectorized row update; element-wise identical to adding x[i] to
        # M[j, i] for every column i.
        self.M[j] += x

    def estimate_client(self, d):
        """
        Estimates the frequency of an element from the sketch matrix.

        The ``k`` counters ``M[i, H[i](d)]`` are averaged (a Count *Mean*
        estimator) and then corrected for the expected collision mass ``N/m``.

        Args:
            d (element): The element whose frequency is estimated.

        Returns:
            float: The estimated frequency of the element.
        """
        row_sum = sum(self.M[i, self.H[i](d)] for i in range(self.k))
        return (self.m / (self.m - 1)) * ((row_sum / self.k) - (self.N / self.m))

    def execute_client(self):
        """
        Runs the client side for every element in the dataset.

        Returns:
            list: ``(privatized_vector, hash_index)`` pairs, one per element
            of ``self.dataset``, in dataset order.
        """
        with Progress() as progress:
            bar = progress.add_task("Processing client data", total=len(self.dataset))

            privatized_data = []
            for d in self.dataset:
                privatized_data.append(self.client(d))
                progress.update(bar, advance=1)

        return privatized_data

    def server_simulator(self, privatized_data):
        """
        Simulates the server: aggregates reports, then estimates frequencies.

        Note that reports are accumulated into ``self.M``, which is never reset
        here, so calling this method twice counts the reports twice.

        Args:
            privatized_data (list): ``(privatized_vector, hash_index)`` pairs.

        Returns:
            tuple: ``(F_estimated, H)``, where ``F_estimated`` maps each domain
            element to its estimated frequency and ``H`` is the hash family.
        """
        with Progress() as progress:
            bar = progress.add_task('Update sketch matrix', total=len(privatized_data))

            for v, j in privatized_data:
                self.update_sketch_matrix(v, j)
                progress.update(bar, advance=1)

            bar = progress.add_task('Estimate frequencies', total=len(self.domain))
            F_estimated = {}
            for x in self.domain:
                F_estimated[x] = self.estimate_client(x)
                progress.update(bar, advance=1)

        return F_estimated, self.H

def run_private_cms_client(k, m, e, df):
    """
    Run the private Count Mean Sketch pipeline end to end and report results.

    The client side privatizes every value in ``df['value']``. The server side
    aggregates the privatized reports and estimates the frequency of each
    distinct value.

    Args:
        k (int): Number of hash functions.
        m (int): Number of columns of the sketch matrix.
        e (float): Privacy parameter epsilon.
        df (DataFrame): Dataset to be processed.

    Returns:
        tuple: ``(H, data_table, error_table, privatized_data, df_estimated)``,
        where ``H`` is the hash family and ``data_table`` / ``error_table``
        come from ``display_results``. ``privatized_data`` is the list of
        client reports and ``df_estimated`` is a DataFrame with columns
        'Element' and 'Frequency'.
    """
    sketch = privateCMSClient(e, k, m, df)

    # Client phase: every value is privatized independently.
    reports = sketch.execute_client()

    # Server phase: aggregate the reports and estimate per-element frequencies.
    estimates, hashes = sketch.server_simulator(reports)

    # Tabulate the estimates as a DataFrame (this function writes no file).
    estimates_df = pd.DataFrame(
        [(element, freq) for element, freq in estimates.items()],
        columns=['Element', 'Frequency'],
    )

    data_table, error_table = display_results(df, estimates)
    return hashes, data_table, error_table, reports, estimates_df