import torch
import torch.nn as nn
import torch.nn.functional as F
from fedbiomed.common.training_plans import TorchTrainingPlan
from fedbiomed.common.data import DataManager
from torch.utils.data import Dataset
from torchvision import transforms
import pandas as pd
from PIL import Image
import os
import numpy as np
class CelebaTrainingPlan(TorchTrainingPlan):
         
    # Defines the model to be trained
    def init_model(self):
        model = self.Net()
        return model 
    
    # Here we declare the extra dependencies (imports) the node needs to run this
    # training plan, in particular for the custom CelebaDataset defined below
    def init_dependencies(self):
        deps = ["from torch.utils.data import Dataset",
                "from torchvision import transforms",
                "import pandas as pd",
                "from PIL import Image",
                "import os",
                "import numpy as np"]
        return deps

    # The torch nn.Module defining the network architecture
    class Net(nn.Module):
        
        def __init__(self):
            super().__init__()
            # convolution layers
            self.conv1 = nn.Conv2d(3, 32, 3, 1)
            self.conv2 = nn.Conv2d(32, 32, 3, 1)
            self.conv3 = nn.Conv2d(32, 32, 3, 1)
            self.conv4 = nn.Conv2d(32, 32, 3, 1)
            self.dropout1 = nn.Dropout(0.25)
            self.dropout2 = nn.Dropout(0.5)
            # classifier
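            # for the 3 x 218 x 178 CelebA images, the four conv + max-pool stages
            # leave a 32 x 11 x 9 feature map, i.e. 32 * 11 * 9 = 3168 flattened features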
            self.fc1 = nn.Linear(3168, 128)
            self.fc2 = nn.Linear(128, 2)

        def forward(self, x):
            x = self.conv1(x)
            x = F.max_pool2d(x, 2)
            x = F.relu(x)

            x = self.conv2(x)
            x = F.max_pool2d(x, 2)
            x = F.relu(x)

            x = self.conv3(x)
            x = F.max_pool2d(x, 2)
            x = F.relu(x)

            x = self.conv4(x)
            x = F.max_pool2d(x, 2)
            x = F.relu(x)

            x = self.dropout1(x)
            x = torch.flatten(x, 1)
            x = self.fc1(x)
            x = F.relu(x)

            x = self.dropout2(x)
            x = self.fc2(x)
            output = F.log_softmax(x, dim=1)
            return output


    class CelebaDataset(Dataset):
        """Custom Dataset for loading CelebA face images"""
        
        # We don't load all images into memory up front: each item is read from disk
        # in __getitem__. With 67,533 images of 218 x 178 pixels and 3 colour channels,
        # preloading would need at least ~7 GB of RAM. Loading images on demand is a
        # bit slower during training, but it keeps memory usage low.
        def __init__(self, txt_path, img_dir, transform=None):
            df = pd.read_csv(txt_path, sep="\t", index_col=0)
            self.img_dir = img_dir
            self.txt_path = txt_path
            self.img_names = df.index.values
            self.y = df['Smiling'].values
            # optional transform; defaults to a plain ToTensor conversion
            self.transform = transform if transform is not None else transforms.ToTensor()
            print("celeba dataset loaded")

        def __getitem__(self, index):
            img = np.asarray(Image.open(os.path.join(self.img_dir,
                                                     self.img_names[index])))
            img = self.transform(img)
            label = self.y[index]
            return img, label

        def __len__(self):
            return self.y.shape[0]
    
    # training_data builds a DataManager wrapping the custom dataset above;
    # Fed-BioMed uses it to create the DataLoader used during training
    def training_data(self):
        dataset = self.CelebaDataset(self.dataset_path + "/target.csv", self.dataset_path + "/data/")
        train_kwargs = {'shuffle': True}
        return DataManager(dataset, **train_kwargs)
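
    # Note: loader options such as the batch size are supplied by the researcher's
    # training_args and combined with the kwargs above when Fed-BioMed builds the
    # DataLoader (the exact merge behaviour depends on the Fed-BioMed version).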
    
    # training_step must return the loss so that Fed-BioMed can backpropagate it
    def training_step(self, data, target):
        output = self.model().forward(data)
        loss = F.nll_loss(output, target)
        return loss

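# A minimal usage sketch on the researcher side (not part of the training plan).
# It assumes the CelebA folders have been added to the nodes under the tag
# '#celeba'; the tag, the training_args values and the number of rounds below are
# illustrative, and the exact Experiment/aggregator import paths and argument
# names may differ between Fed-BioMed versions:
#
#   from fedbiomed.researcher.experiment import Experiment
#   from fedbiomed.researcher.aggregators.fedavg import FedAverage
#
#   exp = Experiment(tags=['#celeba'],
#                    training_plan_class=CelebaTrainingPlan,
#                    training_args={'batch_size': 32, 'epochs': 1},
#                    round_limit=3,
#                    aggregator=FedAverage())
#   exp.run()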