import torch
import torch.nn as nn
import torch.nn.functional as F
from fedbiomed.common.training_plans import TorchTrainingPlan
from fedbiomed.common.data import DataManager
from fedbiomed.common.constants import ProcessTypes
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
class VariationalAutoencoderPlan(TorchTrainingPlan):
    """Fed-BioMed training plan for a variational autoencoder (VAE) on MNIST.

    The model (``Net``) is made of a two-stage encoder (one shared hidden
    layer followed by separate mean and log-variance heads) and a two-layer
    decoder.
    """

    def init_model(self):
        """Build and return a fresh instance of the VAE network."""
        return self.Net()

    class Net(nn.Module):
        """Fully connected VAE: 784 -> 400 -> 20 (latent) -> 400 -> 784."""

        def __init__(self):
            super().__init__()
            # Encoder: shared hidden layer, then one head per latent statistic.
            self.fc1 = nn.Linear(784, 400)
            self.fc21 = nn.Linear(400, 20)  # mean head
            self.fc22 = nn.Linear(400, 20)  # log-variance head
            # Decoder: latent code back to a 784-element output.
            self.fc3 = nn.Linear(20, 400)
            self.fc4 = nn.Linear(400, 784)

        def encode(self, x):
            """Map a flattened input to the ``(mu, logvar)`` pair of the latent code."""
            h1 = F.relu(self.fc1(x))
            return self.fc21(h1), self.fc22(h1)

        def decode(self, z):
            """Map a latent code ``z`` to a reconstruction with values in (0, 1)."""
            h3 = F.relu(self.fc3(z))
            return torch.sigmoid(self.fc4(h3))

        def reparameterize(self, mu, logvar):
            """Sample ``mu + eps * std`` with ``eps`` drawn from a standard normal.

            ``std`` is computed as ``exp(0.5 * logvar)``; writing the sample
            this way keeps it differentiable with respect to ``mu`` and
            ``logvar``.
            """
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std

        def forward(self, x):
            """Run the three VAE steps: encode, reparameterize and decode.

            Args:
                x: input batch; it is flattened to shape ``(-1, 784)``.

            Returns:
                Tuple ``(reconstruction, mu, logvar)``.
            """
            mu, logvar = self.encode(x.view(-1, 784))
            z = self.reparameterize(mu, logvar)
            return self.decode(z), mu, logvar

    def training_data(self):
        """Create the data manager wrapping the MNIST training set.

        The dataset is read from ``self.dataset_path`` (and downloaded there
        when missing, since ``download=True``), converted with ``ToTensor``
        and shuffled.

        Returns:
            A ``DataManager`` built on the MNIST training dataset.
        """
        mnist_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        train_dataset = datasets.MNIST(
            self.dataset_path,
            transform=mnist_transform,
            train=True,
            download=True,
        )
        return DataManager(train_dataset, shuffle=True)

    def final_loss(self, bce_loss, mu, logvar):
        """Compute the VAE loss: reconstruction term plus KL divergence.

        Args:
            bce_loss: reconstruction loss (binary cross-entropy) already
                computed by the caller.
            mu: latent means returned by the encoder.
            logvar: latent log-variances returned by the encoder.

        Returns:
            ``bce_loss + KLD`` where
            ``KLD = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))``.
        """
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return bce_loss + kld

    def training_step(self, data, target):
        """Compute the loss for one batch.

        This code is executed at each federated learning round on every node
        taking part in the federation.

        Args:
            data: batch of input samples fed to the model.
            target: labels delivered with the batch; unused, because the
                autoencoder is trained to reconstruct its own input.

        Returns:
            The summed VAE loss (reconstruction + KL divergence) of the batch.
        """
        criterion = nn.BCELoss(reduction='sum')

        # Call the module itself rather than ``.forward`` so that hooks
        # registered on the model are executed.
        reconstruction, mu, logvar = self.model()(data)

        # Flatten the input with the batch size actually produced by the
        # model instead of a hard-coded one, so that batches of any size
        # (e.g. a smaller last batch) are handled.
        flat_input = data.view(reconstruction.size(0), -1)
        bce_loss = criterion(reconstruction, flat_input)
        return self.final_loss(bce_loss, mu, logvar)
