import copy
import random
from functools import wraps

import torch
from torch import nn
import torch.nn.functional as F

from torchvision import transforms as T
from einops import rearrange

# augmentations

class RandomApply(nn.Module):
    """Apply ``fn`` to the input with probability ``p``; otherwise return the input unchanged."""

    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn  # transform applied when the coin flip succeeds
        self.p = p    # probability of applying ``fn``

    def forward(self, x):
        # one draw per call: skip the transform when the draw exceeds p
        skip = random.random() > self.p
        return x if skip else self.fn(x)

def get_default_aug(image_size, channels = 3):
    """Build the default augmentation pipeline as a ``torch.nn.Sequential``.

    Stages, in order:
      1. color jitter, applied with p = 0.3 (1- or 3-channel input only),
      2. random grayscale, p = 0.2 (3-channel input only),
      3. random horizontal flip,
      4. Gaussian blur, applied with p = 0.2,
      5. random resized crop to ``image_size`` x ``image_size``,
      6. mean/std normalization with ImageNet statistics (3-channel input only).

    Stages that do not apply to ``channels`` are replaced by ``nn.Identity`` so
    the pipeline has the same number of stages for every channel count.

    NOTE(review): torchvision transforms presumably draw their random
    parameters once per call, so a batched tensor would get one shared
    augmentation for all samples -- confirm this is intended.
    """
    is_rgb = channels == 3
    is_greyscale = channels == 1

    stages = []

    if is_rgb or is_greyscale:
        stages.append(RandomApply(T.ColorJitter(0.8, 0.8, 0.8, 0.2), p = 0.3))
    else:
        stages.append(nn.Identity())

    stages.append(T.RandomGrayscale(p = 0.2) if is_rgb else nn.Identity())
    stages.append(T.RandomHorizontalFlip())
    stages.append(RandomApply(T.GaussianBlur((3, 3), (1.0, 2.0)), p = 0.2))
    stages.append(T.RandomResizedCrop((image_size, image_size)))

    if is_rgb:
        imagenet_mean = torch.tensor([0.485, 0.456, 0.406])
        imagenet_std = torch.tensor([0.229, 0.224, 0.225])
        stages.append(T.Normalize(mean = imagenet_mean, std = imagenet_std))
    else:
        stages.append(nn.Identity())

    return torch.nn.Sequential(*stages)

# helper functions

def default(val, def_val):
    """Return ``val`` unless it is ``None``, in which case return ``def_val``.

    Only ``None`` triggers the fallback; other falsy values (0, '', []) are kept.
    """
    if val is not None:
        return val
    return def_val

def flatten(t):
    """Reshape ``t`` to 2-D, keeping the leading axis and merging all remaining axes."""
    batch = t.size(0)
    return t.reshape(batch, -1)

def singleton(cache_key):
    """Decorator factory caching a method's result on ``self.<cache_key>``.

    The first call whose cached attribute is ``None`` runs the wrapped method
    and stores its result with ``setattr``; later calls return the stored
    object and ignore their arguments. The attribute must already exist on
    the instance (``getattr`` is called without a default).
    """
    def decorator(builder):
        @wraps(builder)
        def cached_builder(self, *args, **kwargs):
            cached = getattr(self, cache_key)
            if cached is None:
                cached = builder(self, *args, **kwargs)
                setattr(self, cache_key, cached)
            return cached
        return cached_builder
    return decorator

def get_module_device(module):
    """Return the device ``module`` lives on.

    The device of the first parameter is used; if the module has no
    parameters, the first buffer is used instead. A module with neither has
    no device of its own, so CPU is reported rather than letting a bare
    ``StopIteration`` escape from ``next``.
    """
    tensor = next(module.parameters(), None)
    if tensor is None:
        tensor = next(module.buffers(), None)
    return tensor.device if tensor is not None else torch.device('cpu')

def set_requires_grad(model, val):
    """Set ``requires_grad = val`` on every parameter of ``model``, in place."""
    params = model.parameters()
    for param in params:
        param.requires_grad = val

def l2norm(t):
    """Scale ``t`` to unit Euclidean (L2) norm along its last axis."""
    return F.normalize(t, dim = -1, p = 2)

# simclr loss fn

def contrastive_loss(queries, keys, temperature = 0.1):
    """Cross-entropy contrastive loss where row i of ``queries`` should match row i of ``keys``.

    Similarities are raw dot products (no normalization is applied here);
    each row is shifted by its maximum and then divided by ``temperature``.
    Returns the mean cross-entropy over the batch.
    """
    similarities = queries @ keys.t()
    # per-row shift by the max; cross-entropy is invariant to per-row shifts
    row_max, _ = similarities.max(dim = -1, keepdim = True)
    logits = (similarities - row_max) / temperature
    targets = torch.arange(queries.shape[0], device = queries.device)
    return F.cross_entropy(logits, targets)

def nt_xent_loss(queries, keys, temperature = 0.1):
    """NT-Xent loss over the 2b stacked embeddings ``[queries; keys]``.

    For every embedding, the other 2b - 1 embeddings are candidates and its
    counterpart in the other view (query i <-> key i) is the positive.
    Similarities are raw dot products divided by ``temperature``. Returns the
    summed cross-entropy divided by 2b.
    """
    batch, device = queries.shape[0], queries.device
    total = 2 * batch

    embeddings = torch.cat((queries, keys))
    similarities = embeddings @ embeddings.t()

    # drop each embedding's similarity with itself -> shape (2b, 2b - 1)
    off_diagonal = ~torch.eye(total, device = device).bool()
    logits = similarities[off_diagonal].reshape(total, total - 1) / temperature

    # with the diagonal removed, query i's positive (column b + i) shifts left
    # to b + i - 1, while key i's positive (column i) stays at i
    index = torch.arange(batch, device = device)
    labels = torch.cat((index + batch - 1, index), dim = 0)

    return F.cross_entropy(logits, labels, reduction = 'sum') / total

# loss fn

def loss_fn(x, y):
    """Return ``2 - 2 * cos(x, y)`` along the last axis, without reduction.

    For L2-normalized vectors this equals their squared Euclidean distance,
    so it ranges from 0 (same direction) to 4 (opposite directions).
    """
    cosine = (l2norm(x) * l2norm(y)).sum(dim = -1)
    return 2 - 2 * cosine

# MLP class for projector and predictor

def MLP(dim, projection_size, hidden_size = None):
    """Two-layer head: Linear -> BatchNorm1d -> ReLU -> Linear.

    ``hidden_size`` defaults to ``dim`` when not given. Because of the
    BatchNorm1d layer, training-mode calls need more than one sample per batch.
    """
    width = default(hidden_size, dim)
    layers = [
        nn.Linear(dim, width),
        nn.BatchNorm1d(width),
        nn.ReLU(inplace = True),
        nn.Linear(width, projection_size),
    ]
    return nn.Sequential(*layers)

def SimSiamMLP(dim, projection_size, hidden_size = 4096):
    """Three-layer bias-free projector ending in a non-affine BatchNorm1d.

    Layout: two (Linear -> BatchNorm1d -> ReLU) stages of width ``hidden_size``,
    then Linear to ``projection_size`` followed by BatchNorm1d(affine=False).
    Passing ``hidden_size=None`` selects ``2 * projection_size``.
    """
    width = default(hidden_size, projection_size * 2)

    layers = []
    in_features = dim
    for _ in range(2):
        layers += [
            nn.Linear(in_features, width, bias = False),
            nn.BatchNorm1d(width),
            nn.ReLU(inplace = True),
        ]
        in_features = width

    layers += [
        nn.Linear(width, projection_size, bias = False),
        nn.BatchNorm1d(projection_size, affine = False),
    ]
    return nn.Sequential(*layers)

# a wrapper class for the base neural network
# will manage the interception of the hidden layer output
# and pipe it into the projector net

class NetWrapper(nn.Module):
    """Wrap a backbone, capture one of its layer outputs, and project it.

    ``layer`` selects the representation: an int indexes ``net.children()``, a
    str names an entry of ``net.named_modules()``, and ``-1`` uses the
    backbone's own output directly. For any other layer, a forward hook records
    the layer's output flattened to ``(batch, -1)``. The projector
    (``SimSiamMLP``) is built lazily on the first projected call, once the
    representation width is known.
    """

    def __init__(self, net, projection_size, projection_hidden_size = 4096, layer = -2):
        super().__init__()
        self.net = net
        self.layer = layer

        # created on first use by _get_projector (cached via @singleton)
        self.projector = None
        self.projection_size = projection_size
        self.projection_hidden_size = projection_hidden_size

        # hooked outputs keyed by input device -- presumably so replicas on
        # different devices (e.g. DataParallel) do not overwrite each other
        self.hidden = {}
        self.hook_registered = False

    def _find_layer(self):
        """Resolve ``self.layer`` to a submodule, or ``None`` if it cannot be found."""
        if isinstance(self.layer, str):
            modules = dict(self.net.named_modules())
            return modules.get(self.layer, None)
        if isinstance(self.layer, int):
            children = list(self.net.children())
            try:
                return children[self.layer]
            except IndexError:
                # let _register_hook report a clear "not found" message
                return None
        return None

    def _hook(self, _, input, output):
        # forward-hook signature is (module, input_tuple, output)
        device = input[0].device
        self.hidden[device] = flatten(output)

    def _register_hook(self):
        layer = self._find_layer()
        assert layer is not None, f'hidden layer ({self.layer}) not found'
        # the hook handle is not kept, so the hook stays for the module's lifetime
        layer.register_forward_hook(self._hook)
        self.hook_registered = True

    @singleton('projector')
    def _get_projector(self, hidden):
        _, dim = hidden.shape
        projector = SimSiamMLP(dim, self.projection_size, self.projection_hidden_size)
        # match the representation's device and dtype
        return projector.to(hidden)

    def get_representation(self, x):
        """Run the backbone on ``x`` and return the selected representation."""
        if self.layer == -1:
            return self.net(x)

        if not self.hook_registered:
            self._register_hook()

        self.hidden.clear()
        self.net(x)
        # .get so a missing hook output reaches the assert below instead of a KeyError
        hidden = self.hidden.get(x.device)
        self.hidden.clear()

        assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
        return hidden

    def forward(self, x, return_projection = True):
        """Return ``(projection, representation)``, or only the representation.

        Every leading axis of the representation is folded into the batch axis
        before projection: ``(b, d)`` is projected as is, and token outputs
        ``(b, n, d)`` are projected as ``(b * n, d)``.
        """
        representation = self.get_representation(x)

        if not return_projection:
            return representation

        # '...' accepts both the 2-D hooked output and 3-D token outputs
        flattened_representation = rearrange(representation, '... d -> (...) d')
        projector = self._get_projector(flattened_representation)
        projection = projector(flattened_representation)
        return projection, representation

# main class

class SimSiam(nn.Module):
    """SimSiam-style self-supervised objective over two augmented views.

    Both views go through the same ``NetWrapper`` encoder (hidden layer
    ``hidden_layer``) and an ``MLP`` predictor. Each view's prediction is
    regressed onto the stop-gradient projection of the other view with
    ``loss_fn``. ``forward`` returns the mean of the two summed per-row losses.
    """

    def __init__(
        self,
        net,
        image_size,
        channels = 3,
        hidden_layer = -2,
        projection_size = 256,
        projection_hidden_size = 4096,
        augment_fn = None,
        augment_fn2 = None
    ):
        super().__init__()
        self.net = net

        # default SimCLR augmentation; the second view reuses the first pipeline unless given
        self.augment1 = default(augment_fn, get_default_aug(image_size, channels))
        self.augment2 = default(augment_fn2, self.augment1)

        self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer = hidden_layer)
        self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)

        # get device of network and make wrapper same device
        device = get_module_device(net)
        self.to(device)

        # send a mock image tensor (2 samples, for batchnorm) to instantiate the
        # lazily-built projector; no gradients are needed for this pass
        with torch.no_grad():
            self.forward(torch.randn(2, channels, image_size, image_size, device = device))

    def forward(self, x):
        assert not (self.training and x.shape[0] == 1), 'you must have greater than 1 sample when training, due to the batchnorm in the projection layer'

        image_one, image_two = self.augment1(x), self.augment2(x)

        online_proj_one, _ = self.online_encoder(image_one)
        online_proj_two, _ = self.online_encoder(image_two)

        online_pred_one = self.online_predictor(online_proj_one)
        online_pred_two = self.online_predictor(online_proj_two)

        # the target branch is the same encoder under stop-gradient, so reuse the
        # projections already computed instead of running the encoder twice more
        target_proj_one = online_proj_one.detach()
        target_proj_two = online_proj_two.detach()

        loss_one = loss_fn(online_pred_one, target_proj_two)
        loss_two = loss_fn(online_pred_two, target_proj_one)

        loss = loss_one + loss_two
        return loss.mean()

# SimCLR

class SimCLR(nn.Module):
    """SimCLR-style contrastive objective over two views of a batch.

    The keys view is always augmented; the queries view is augmented only when
    ``augment_both`` is true, and otherwise is the raw input. Both views are
    projected by a ``NetWrapper`` (projection size ``project_dim``) and scored
    with ``nt_xent_loss`` at ``temperature``.

    ``project_hidden`` and ``use_nt_xent_loss`` are accepted for interface
    compatibility but are not read: the projector is always applied and the
    NT-Xent loss is always used.
    """

    def __init__(
        self,
        net,
        image_size,
        channels = 3,
        hidden_layer = -2,
        project_hidden = True,
        project_dim = 128,
        augment_both = True,
        use_nt_xent_loss = False,
        augment_fn = None,
        temperature = 0.1
    ):
        super().__init__()
        self.net = NetWrapper(net, project_dim, layer = hidden_layer)
        self.augment = default(augment_fn, get_default_aug(image_size, channels))
        self.augment_both = augment_both
        self.temperature = temperature

        # get device of network and make wrapper same device
        device = get_module_device(net)
        self.to(device)

        # send a mock image tensor to instantiate the lazily-built projector; it
        # needs >1 sample (batchnorm in training mode) and must be on net's device
        with torch.no_grad():
            self.forward(torch.randn(2, channels, image_size, image_size, device = device))

    def forward(self, x):
        assert not (self.training and x.shape[0] == 1), 'you must have greater than 1 sample when training, due to the batchnorm in the projection layer'

        # identity when only the keys view should be augmented
        transform_fn = self.augment if self.augment_both else (lambda t: t)

        queries, _ = self.net(transform_fn(x))
        keys, _ = self.net(self.augment(x))

        # NOTE(review): queries/keys are not L2-normalized before the loss, so
        # the temperature scales raw dot products -- confirm this is intended
        queries, keys = map(flatten, (queries, keys))
        return nt_xent_loss(queries, keys, temperature = self.temperature)
