import torch
import torch.nn as nn
import torch.nn.functional as F
from fedbiomed.common.training_plans import TorchTrainingPlan
from fedbiomed.common.data import DataManager
from fedbiomed.common.constants import ProcessTypes
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from monai.transforms import (Compose, NormalizeIntensity, AddChannel, Resize, AsDiscrete)
from fedbiomed.common.data import MedicalFolderDataset
import numpy as np
from torch.optim import AdamW
from unet import UNet
class UNetTrainingPlan(TorchTrainingPlan):
    """Training plan that fits a U-Net segmentation model with a Dice loss.

    The network (:class:`Net`) wraps the ``UNet`` class imported from the
    ``unet`` package and applies a softmax over the channel axis. Training
    data comes from a ``MedicalFolderDataset`` whose ``T1`` modality is the
    input and whose ``label`` modality is the segmentation target.
    """

    def init_model(self, model_args):
        """Build and return the network configured by ``model_args``.

        Args:
            model_args: dictionary of U-Net hyper-parameters; see
                :class:`Net` for the recognised keys and their defaults.
        """
        return self.Net(model_args)

    def init_optimizer(self):
        """Return an ``AdamW`` optimizer (default settings) over the model parameters."""
        return AdamW(self.model().parameters())

    def init_dependencies(self):
        """Return the import statements required by this training plan."""
        # Here we define the custom dependencies that will be needed by our custom Dataloader
        deps = ["from monai.transforms import (Compose, NormalizeIntensity, AddChannel, Resize, AsDiscrete)",
               "import torch.nn as nn",
               'import torch.nn.functional as F',
               "from fedbiomed.common.data import MedicalFolderDataset",
               'import numpy as np',
               'from torch.optim import AdamW',
               'from unet import UNet']
        return deps

    class Net(nn.Module):
        """U-Net followed by a softmax over the channel axis."""

        def __init__(self, model_args: dict = None):
            """Create the underlying ``UNet``.

            Args:
                model_args: optional dictionary of ``UNet`` constructor
                    arguments. Every key is optional; missing keys fall back
                    to the defaults listed below. ``None`` is treated as an
                    empty dictionary.
            """
            super().__init__()
            # A fresh dict per call: a `{}` default would be one object shared
            # by every instance.
            if model_args is None:
                model_args = {}

            # Axis holding the class scores in the network output.
            self.CHANNELS_DIMENSION = 1

            # NOTE(review): the default `dimensions` is 2, while `training_data`
            # resizes images to a 3-element shape; callers presumably pass
            # `dimensions=3` in model_args -- confirm.
            self.unet = UNet(
                in_channels=model_args.get('in_channels', 1),
                out_classes=model_args.get('out_classes', 2),
                dimensions=model_args.get('dimensions', 2),
                num_encoding_blocks=model_args.get('num_encoding_blocks', 5),
                out_channels_first_layer=model_args.get('out_channels_first_layer', 64),
                normalization=model_args.get('normalization', None),
                pooling_type=model_args.get('pooling_type', 'max'),
                upsampling_type=model_args.get('upsampling_type', 'conv'),
                preactivation=model_args.get('preactivation', False),
                residual=model_args.get('residual', False),
                padding=model_args.get('padding', 0),
                padding_mode=model_args.get('padding_mode', 'zeros'),
                activation=model_args.get('activation', 'ReLU'),
                initial_dilation=model_args.get('initial_dilation', None),
                dropout=model_args.get('dropout', 0),
                monte_carlo_dropout=model_args.get('monte_carlo_dropout', 0)
            )

        def forward(self, x):
            """Run the U-Net on ``x`` and return softmax scores over the channel axis."""
            # Call the module rather than `.forward` so registered hooks run.
            x = self.unet(x)
            return F.softmax(x, dim=self.CHANNELS_DIMENSION)

    @staticmethod
    def get_dice_loss(output, target, epsilon=1e-9):
        """Compute the soft Dice loss, reduced over the spatial axes.

        Axes 0 and 1 are kept; every axis from index 2 onwards is summed
        over, so the loss works for any number of spatial dimensions.

        Args:
            output: predicted scores, with the same layout as ``target``.
            target: ground-truth tensor (one-hot in ``training_data``).
            epsilon: small constant added to the denominator to avoid a
                division by zero when prediction and target are both empty.

        Returns:
            Tensor of ``1 - dice`` values with the spatial axes removed.
        """
        # Everything after the first two axes is treated as spatial.
        spatial_dims = tuple(range(2, output.dim()))

        def _spatial_sum(tensor):
            # With no spatial axes there is nothing to reduce; an empty `dim`
            # tuple must not be handed to `sum`, which could reduce all axes.
            return tensor.sum(dim=spatial_dims) if spatial_dims else tensor

        p0 = output
        g0 = target
        p1 = 1 - p0
        g1 = 1 - g0
        tp = _spatial_sum(p0 * g0)
        fp = _spatial_sum(p0 * g1)
        fn = _spatial_sum(p1 * g0)
        num = 2 * tp
        denom = 2 * tp + fp + fn + epsilon
        dice_score = num / denom
        return 1. - dice_score

    @staticmethod
    def demographics_transform(demographics: dict):
        """Transforms dict of demographics into data type for ML.

        This function is provided for demonstration purposes, but
        note that if you intend to use demographics data as part
        of your model's input, you **must** provide a
        `demographics_transform` function which at the very least
        converts the demographics dict into a torch.Tensor.

        Must return either a torch Tensor or something Tensor-like
        that can be easily converted through the torch.as_tensor()
        function.

        The result is the concatenation of the numeric features (in the
        fixed order HEIGHT, WEIGHT, skipping absent keys) and a one-hot
        encoding of the site name with a final slot for unknown or missing
        sites. An empty dict is returned unchanged."""

        if isinstance(demographics, dict) and not demographics:
            # when input is empty dict, we don't want to transform anything
            return demographics

        # simple example: keep only some keys
        # Iterate over `keys_to_keep`, not over the input, so that the feature
        # order does not depend on the ordering of the incoming columns.
        keys_to_keep = ('HEIGHT', 'WEIGHT')
        out = np.array([float(demographics[key]) for key in keys_to_keep if key in demographics])

        # more complex: generate dummy variables for site name
        # not ideal as it requires knowing the site names in advance
        # could be better implemented with some preprocess
        site_names = ['Guys', 'IOP', 'HH']
        unknown_site_idx = len(site_names)  # last slot: any other (or missing) site
        dummy_vars = np.zeros(shape=(len(site_names) + 1,))
        site_name = demographics.get('SITE_NAME')
        if site_name in site_names:
            site_idx = site_names.index(site_name)
        else:
            site_idx = unknown_site_idx
        dummy_vars[site_idx] = 1.

        return np.concatenate((out, dummy_vars))

    def training_data(self):
        """Create the ``DataManager`` that provides the training samples.

        Images and labels are given a channel axis and resized to a common
        shape; images are then intensity-normalised and labels are passed
        through ``AsDiscrete(to_onehot=2)``. Shuffling is enabled.
        """
        common_shape = (48, 60, 48)
        training_transform = Compose([AddChannel(), Resize(common_shape), NormalizeIntensity(),])
        target_transform = Compose([AddChannel(), Resize(common_shape), AsDiscrete(to_onehot=2)])

        dataset = MedicalFolderDataset(
            root=self.dataset_path,
            data_modalities='T1',
            target_modalities='label',
            transform=training_transform,
            target_transform=target_transform,
            demographics_transform=UNetTrainingPlan.demographics_transform)
        loader_arguments = {'shuffle': True}
        return DataManager(dataset, **loader_arguments)

    def training_step(self, data, target):
        """Return the batch-averaged Dice loss to be back-propagated.

        Args:
            data: batch whose first element maps modality names to images;
                only the ``T1`` image is used (demographics are ignored).
            target: mapping holding the ground truth under ``'label'``.
        """
        img = data[0]['T1']
        # Call the module rather than `.forward` so registered hooks run.
        output = self.model()(img)
        loss = UNetTrainingPlan.get_dice_loss(output, target['label'])
        return loss.mean()

    def testing_step(self, data, target):
        """Return the batch-averaged Dice loss used as the testing metric.

        Args:
            data: batch whose first element maps modality names to images;
                only the ``T1`` image is used (demographics are ignored).
            target: mapping holding the ground truth under ``'label'``.
        """
        img = data[0]['T1']
        target = target['label']
        # Call the module rather than `.forward` so registered hooks run.
        prediction = self.model()(img)
        loss = UNetTrainingPlan.get_dice_loss(prediction, target)
        return loss.mean()  # average per batch
