# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/ml.coop.ipynb.

# %% auto 0
# Public API of this module: the names exported by `from <this module> import *`.
__all__ = ['TextEncoder', 'PromptLearner', 'ClipVisualEncoder', 'PromptLearningTextEncoder', 'PromptLearningClip',
           'make_prompt_learning_clip', 'prepare_prompt_learning_clip']

# %% ../../nbs/ml.coop.ipynb 3
import torch
import torch.nn as nn

from clip import clip
from .clip import ClipClassificationHead

# %% ../../nbs/ml.coop.ipynb 4
class TextEncoder(nn.Module):
    """Run CLIP's text transformer on prompt *embeddings* instead of token ids.

    Taking embeddings lets learned context vectors be injected into the prompt;
    the token ids are only used to locate the end-of-text position of each prompt.
    """

    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts):
        """Encode prompt embeddings into projected text features.

        Args:
            prompts: prompt embeddings, assumed ``(n_prompts, n_tokens, width)`` with
                ``n_tokens`` matching ``positional_embedding`` (TODO confirm).
            tokenized_prompts: token ids with one row per prompt; the eot token has
                the highest id, so its position is the per-row argmax.

        Returns:
            The final-layer feature at each prompt's eot position, multiplied by
            ``text_projection``.
        """
        x = prompts + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [n_prompts, n_tokens, transformer.width]
        # Take features from the eot embedding (eot is the highest token id in each
        # sequence). Both index tensors are placed on x's device so the gather does
        # not depend on where tokenized_prompts happens to live (CPU attribute vs GPU).
        rows = torch.arange(x.shape[0], device=x.device)
        eot_positions = tokenized_prompts.argmax(dim=-1).to(x.device)
        return x[rows, eot_positions] @ self.text_projection


class PromptLearner(nn.Module):
    """Learnable context vectors spliced around the frozen embeddings of class names.

    The prompt for class ``i`` is assembled from the frozen SOS embedding, ``n_ctx``
    learnable context vectors (``self.ctx``) and the frozen embeddings of the class
    name, the trailing ``"."`` and EOS/padding. ``forward`` returns these assembled
    prompt embeddings for all ``n_cls`` classes.

    Args:
        clip_model: CLIP model providing ``dtype``, ``ln_final`` and ``token_embedding``.
        tokenizer: tokenizer whose ``encode(text)`` result has one entry per token; used
            to count how many tokens the context string and each class name occupy.
            Assumed to be the same BPE tokenizer ``clip.tokenize`` uses (the module's
            factory passes a ``SimpleTokenizer``).
        class_names: class names; underscores are replaced by spaces.
        ctx_init: optional initialization text for the context (underscores become
            spaces). When non-empty it determines ``n_ctx`` and a single shared
            context is learned (``class_specific_contexts`` is not used).
        n_ctx: number of randomly initialized context tokens; used when ``ctx_init``
            is empty or None.
        class_specific_contexts: with random init, learn one context per class
            instead of a single shared one.
        class_token_position: where the class name goes: ``'end'``, ``'middle'`` or
            ``'front'``.
        **kwargs: ignored.

    Raises:
        ValueError: if neither a non-empty ``ctx_init`` nor ``n_ctx`` is given.
    """

    def __init__(
        self, 
        clip_model, 
        tokenizer, 
        class_names, 
        ctx_init=None, 
        n_ctx=None, 
        class_specific_contexts=False, 
        class_token_position='end', 
        **kwargs,
    ):
        super().__init__()
        # An empty ctx_init falls through to random init below, which then needs n_ctx.
        if not ctx_init and n_ctx is None:
            raise ValueError("Either of ctx_init or n_ctx must be specified")
        n_cls = len(class_names)
        dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]

        if ctx_init:
            # use given words to initialize context vectors
            ctx_init = ctx_init.replace("_", " ")
            # Count tokens, not space-separated words: one word can map to several BPE
            # tokens (and repeated spaces give empty "words"). A wrong count misaligns
            # the ctx slice below and the token_suffix slice further down. This is the
            # same counting rule used for name_lens.
            n_ctx = len(tokenizer.encode(ctx_init))
            tokens = clip.tokenize(ctx_init)
            with torch.no_grad():
                embedding = clip_model.token_embedding(tokens).type(dtype)
            # skip the SOS token at index 0 and keep only the context tokens,
            # i.e. excluding special tokens or padding
            ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            # random initialization
            if class_specific_contexts:
                print("Initializing class-specific contexts")
                ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype)
            else:
                print("Initializing a generic context")
                ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            # "X" placeholders reserve one token slot per context vector
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        self.ctx = nn.Parameter(ctx_vectors)  # to be optimized

        class_names = [name.replace("_", " ") for name in class_names]
        name_lens = [len(tokenizer.encode(name)) for name in class_names]
        prompts = [prompt_prefix + " " + name + "." for name in class_names]

        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)

        # These token vectors will be saved when in save_model(),
        # but they should be ignored in load_model() as we want to use
        # those computed using the current class names
        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :])  # CLS, EOS

        self.n_cls = n_cls
        self.n_ctx = n_ctx 
        self.tokenized_prompts = tokenized_prompts  # torch.Tensor
        self.name_lens = name_lens
        self.class_token_position = class_token_position

    def forward(self):
        """Return the assembled prompt embeddings, one prompt per class.

        Raises:
            ValueError: if ``class_token_position`` is not 'end', 'middle' or 'front'.
        """
        ctx = self.ctx
        if ctx.dim() == 2:
            # generic context: share the same vectors across all classes
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        prefix = self.token_prefix
        suffix = self.token_suffix

        if self.class_token_position == "end":
            prompts = torch.cat(
                [
                    prefix,  # (n_cls, 1, dim)
                    ctx,     # (n_cls, n_ctx, dim)
                    suffix,  # (n_cls, *, dim)
                ],
                dim=1,
            )
        elif self.class_token_position == "middle":
            half_n_ctx = self.n_ctx // 2
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i : i + 1, :, :]
                # the suffix starts with the class-name tokens, then ".", EOS, padding
                class_i = suffix[i : i + 1, :name_len, :]
                suffix_i = suffix[i : i + 1, name_len:, :]
                ctx_i_half1 = ctx[i : i + 1, :half_n_ctx, :]
                ctx_i_half2 = ctx[i : i + 1, half_n_ctx:, :]
                prompt = torch.cat(
                    [
                        prefix_i,     # (1, 1, dim)
                        ctx_i_half1,  # (1, n_ctx // 2, dim)
                        class_i,      # (1, name_len, dim)
                        ctx_i_half2,  # (1, n_ctx - n_ctx // 2, dim)
                        suffix_i,     # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)
        elif self.class_token_position == "front":
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i : i + 1, :, :]
                class_i = suffix[i : i + 1, :name_len, :]
                suffix_i = suffix[i : i + 1, name_len:, :]
                ctx_i = ctx[i : i + 1, :, :]
                prompt = torch.cat(
                    [
                        prefix_i,  # (1, 1, dim)
                        class_i,   # (1, name_len, dim)
                        ctx_i,     # (1, n_ctx, dim)
                        suffix_i,  # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)
        else:
            raise ValueError(
                f"Unknown class_token_position {self.class_token_position!r}; "
                "expected 'end', 'middle' or 'front'"
            )

        return prompts

# %% ../../nbs/ml.coop.ipynb 5
class ClipVisualEncoder(nn.Module):
    """Wrap CLIP's visual tower so inputs are cast to the model's dtype first."""

    def __init__(self, clip_model):
        super().__init__()
        self.image_encoder = clip_model.visual
        self.dtype = clip_model.dtype

    def forward(self, image):
        """Cast ``image`` to ``self.dtype`` and run it through the visual tower."""
        casted = image.type(self.dtype)
        features = self.image_encoder(casted)
        return features

class PromptLearningTextEncoder(nn.Module):
    """Text side of prompt-learning CLIP: learnable prompts fed through CLIP's text tower.

    ``forward`` takes no inputs. The prompts are fully determined by the learned
    context and the class names fixed at construction, and the result has one text
    feature row per class prompt.
    """

    def __init__(self, clip_model, tokenizer, class_names, **kwargs):
        super().__init__()
        # Registration order (prompt_learner, then text_encoder) fixes the order of
        # named_parameters() / state_dict keys; keep it.
        self.prompt_learner = PromptLearner(clip_model, tokenizer, class_names, **kwargs)
        self.text_encoder = TextEncoder(clip_model)

    def forward(self):
        learner = self.prompt_learner
        prompt_embeddings = learner()
        return self.text_encoder(prompt_embeddings, learner.tokenized_prompts)

class PromptLearningClip(nn.Module):
    """CLIP classifier whose class prompts come from a learnable prompt learner.

    Images go through ``visual_encoder``, the per-class prompts through
    ``text_encoder``, and ``head`` combines the two into the model output.
    """

    def __init__(self, clip_model, tokenizer, class_names, **kwargs):
        super().__init__()
        # Keep this registration order: it determines named_parameters() and
        # state_dict key order.
        self.visual_encoder = ClipVisualEncoder(clip_model)
        self.text_encoder = PromptLearningTextEncoder(clip_model, tokenizer, class_names, **kwargs)
        self.head = ClipClassificationHead(clip_model)

    def forward(self, image):
        """Encode ``image`` and the class prompts (in that order), then apply the head."""
        return self.head(self.visual_encoder(image), self.text_encoder())


# %% ../../nbs/ml.coop.ipynb 6
def make_prompt_learning_clip(class_names, clip_model_name="ViT-B/32", prec='fp32', **kwargs):
    """Load a CLIP backbone on CPU and wrap it in a ``PromptLearningClip``.

    Args:
        class_names: class names; one prompt is built per name.
        clip_model_name: model name passed to ``clip.load``.
        prec: weight precision. ``'fp32'`` and ``'amp'`` keep fp32 weights;
            ``'fp16'`` applies CLIP's own fp16 weight conversion.
        **kwargs: forwarded to ``PromptLearner`` (e.g. ``n_ctx``, ``ctx_init``,
            ``class_specific_contexts``, ``class_token_position``).

    Returns:
        A ``PromptLearningClip`` on CPU. ``prepare_prompt_learning_clip`` can then
        freeze everything except the prompt learner.

    Raises:
        ValueError: if ``prec`` is not ``'fp32'``, ``'amp'`` or ``'fp16'``.
    """
    from clip.model import convert_weights
    from clip.simple_tokenizer import SimpleTokenizer

    # Validate before the (potentially slow, downloading) model load.
    if prec not in ("fp32", "amp", "fp16"):
        raise ValueError(f"Unsupported prec {prec!r}; expected 'fp32', 'amp' or 'fp16'")

    clip_model = clip.load(clip_model_name, device='cpu')[0]
    if prec == "fp16":
        # NOTE: OpenAI CLIP's clip.load() calls model.float() when device is 'cpu',
        # so without this step 'fp16' would silently produce fp32 weights.
        # convert_weights is the selective fp16 conversion build_model applies.
        convert_weights(clip_model)
    else:
        clip_model.float()
    
    print("Building prompt learning CLIP")
    tokenizer = SimpleTokenizer()
    return PromptLearningClip(clip_model, tokenizer, class_names, **kwargs)

def prepare_prompt_learning_clip(model):
    """Freeze every parameter of ``model`` whose name lacks ``"prompt_learner"``.

    Parameters whose qualified name contains ``"prompt_learner"`` are left
    untouched; all others get ``requires_grad=False``.

    Returns:
        The same ``model`` object, modified in place.
    """
    print("Turning off gradients for all except prompt_learner")
    to_freeze = [
        tensor
        for qualified_name, tensor in model.named_parameters()
        if "prompt_learner" not in qualified_name
    ]
    for tensor in to_freeze:
        tensor.requires_grad_(False)
    return model
