import torch
import torch.nn as nn
import torch.nn.functional as F
from fedbiomed.common.training_plans import TorchTrainingPlan
from fedbiomed.common.data import DataManager
from fedbiomed.common.constants import ProcessTypes
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.utils.data import Dataset
import pandas as pd
class MyTrainingPlan(TorchTrainingPlan):
    """Fed-BioMed training plan for a regression task on a CSV dataset.

    The plan bundles three things: a small two-layer fully connected
    network (``Net``), a ``Dataset`` wrapper around a CSV file
    (``csv_Dataset``), and the training step, which uses the root mean
    squared error as loss.
    """

    # Model
    def init_model(self):
        """Instantiate the network from the plan's model arguments.

        Returns:
            A ``Net`` built from ``self.model_args()``, which must provide
            the ``'in_features'`` and ``'out_features'`` keys.
        """
        return self.Net(self.model_args())

    # Dependencies
    def init_dependencies(self):
        """Return the extra import statements this plan relies on.

        ``Dataset`` and ``pd`` are used by ``csv_Dataset``.
        """
        deps = ["from torch.utils.data import Dataset",
                "import pandas as pd"]
        return deps

    # network
    class Net(nn.Module):
        """Two-layer perceptron: ``in_features -> 5 -> out_features``."""

        def __init__(self, model_args):
            """Build the layers.

            Args:
                model_args: mapping with the integer entries
                    ``'in_features'`` and ``'out_features'``.
            """
            super().__init__()
            self.in_features = model_args['in_features']
            self.out_features = model_args['out_features']
            self.fc1 = nn.Linear(self.in_features, 5)
            self.fc2 = nn.Linear(5, self.out_features)

        def forward(self, x):
            """Apply ``fc1``, a ReLU, then ``fc2`` to ``x``."""
            x = self.fc1(x)
            x = F.relu(x)
            x = self.fc2(x)
            return x

    def training_step(self, data, target):
        """Compute the RMSE loss of the model on one batch.

        Args:
            data: batch of input features.
            target: batch of regression targets. It gets an extra
                dimension at position 1 before being compared with the
                prediction -- assumes a 1-D ``(batch,)`` target and a
                ``(batch, 1)`` model output; confirm if ``out_features``
                is ever different from 1.

        Returns:
            Scalar tensor holding the root mean squared error.
        """
        # Call the module itself rather than ``.forward`` so that
        # ``nn.Module.__call__`` runs and registered hooks are honoured.
        output = self.model()(data).float()
        criterion = torch.nn.MSELoss()
        loss = torch.sqrt(criterion(output, target.unsqueeze(1)))
        return loss

    class csv_Dataset(Dataset):
        """Torch ``Dataset`` built from a comma-separated file.

        The file is read once at construction time; the feature columns
        and the target column are converted to float tensors.
        """

        # Columns read from the CSV file.
        FEATURE_COLUMNS = ('year', 'transmission', 'mileage', 'tax', 'mpg', 'engineSize')
        TARGET_COLUMN = 'price'

        def __init__(self, dataset_path, x_dim):
            """Load the CSV file and build the feature/target tensors.

            Args:
                dataset_path: path of the ``.csv`` file to read.
                x_dim: number of input features. Currently unused; kept
                    so that existing callers keep working.
            """
            self.input_file = pd.read_csv(dataset_path, sep=',', index_col=False)
            # A list (not a tuple) is the unambiguous way to select several
            # columns with ``.loc``; tuples are meant for MultiIndex keys.
            x_train = self.input_file.loc[:, list(self.FEATURE_COLUMNS)].values
            y_train = self.input_file.loc[:, self.TARGET_COLUMN].values
            # NOTE(review): ``torch.from_numpy`` needs numeric arrays, so every
            # selected column is presumably already numeric -- confirm for
            # 'transmission'.
            self.X_train = torch.from_numpy(x_train).float()
            self.Y_train = torch.from_numpy(y_train).float()

        def __len__(self):
            """Return the number of samples (rows with a target value)."""
            return len(self.Y_train)

        def __getitem__(self, idx):
            """Return the ``(features, target)`` pair stored at ``idx``."""
            return (self.X_train[idx], self.Y_train[idx])

    def training_data(self):
        """Wrap the CSV dataset in a ``DataManager`` with shuffling enabled.

        The dataset is read from ``self.dataset_path`` (not set in this
        class; presumably provided by the framework before this is called).

        Returns:
            A ``DataManager`` holding the dataset, created with
            ``shuffle=True``.
        """
        dataset = self.csv_Dataset(self.dataset_path, self.model_args()["in_features"])
        train_kwargs = {'shuffle': True}
        return DataManager(dataset=dataset, **train_kwargs)
