# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/ml.clip.ipynb.

# %% auto 0
__all__ = ['load_clip_preprocess', 'make_tfms_from_clip_preprocess', 'ClipClassificationHead', 'ClipZeroShotClassifier']

# %% ../../nbs/ml.clip.ipynb 3
import clip
import torch
import torch.nn as nn
from torchvision import transforms
from .vision import TorchVisionTransform

# %% ../../nbs/ml.clip.ipynb 4
def load_clip_preprocess(clip_model_name):
    """Return the preprocessing pipeline that ships with a CLIP model.

    `clip.load` is called for `clip_model_name` on the CPU and the second
    item of its result (the preprocess transform) is returned; the first
    item (the model itself) is discarded.
    """
    # Imported lazily so the submodule is only resolved when this is called.
    from clip import clip as clip_api

    loaded = clip_api.load(clip_model_name, device='cpu')
    return loaded[1]


# %% ../../nbs/ml.clip.ipynb 5
def make_tfms_from_clip_preprocess(clip_preprocess):
    """Split a CLIP preprocess pipeline into item-level and batch-level transforms.

    `clip_preprocess` must expose a `.transforms` sequence. Every step except
    the last two is composed into the item transform; the last two steps are
    composed into the batch transform.
    NOTE(review): for the stock CLIP pipeline the last two steps are presumably
    tensor conversion and normalisation -- confirm against the loaded pipeline.

    Returns:
        An `(item_tfms, batch_tfms)` tuple of `TorchVisionTransform` wrappers.
    """
    steps = clip_preprocess.transforms
    per_item_steps, per_batch_steps = steps[:-2], steps[-2:]
    item_tfms = TorchVisionTransform(transforms.Compose(per_item_steps))
    batch_tfms = TorchVisionTransform(transforms.Compose(per_batch_steps))
    return item_tfms, batch_tfms

# %% ../../nbs/ml.clip.ipynb 6
class ClipClassificationHead(nn.Module):
    """Turn image/text feature pairs into scaled cosine-similarity logits.

    The head owns a single trainable parameter, `logit_scale`, initialised
    from a detached copy of `clip_model.logit_scale` so that updating it does
    not touch the original model's parameter.
    """

    def __init__(self, clip_model):
        super().__init__()
        initial_scale = clip_model.logit_scale.detach().clone()
        self.logit_scale = nn.Parameter(initial_scale, requires_grad=True)

    @staticmethod
    def _unit_norm(features):
        # Divide each vector along the last dimension by its norm.
        return features / features.norm(dim=-1, keepdim=True)

    def forward(self, image_features, text_features):
        """Return `exp(logit_scale) * (image_features_n @ text_features_n.T)`.

        Both inputs are normalised along their last dimension first.
        `text_features` is transposed with `.t()`, so it must be at most 2-D.
        """
        unit_image = self._unit_norm(image_features)
        unit_text = self._unit_norm(text_features)
        similarity = unit_image @ unit_text.t()
        return self.logit_scale.exp() * similarity

# %% ../../nbs/ml.clip.ipynb 7
class ClipZeroShotClassifier(nn.Module):
    """Zero-shot image classifier built from a CLIP model and class descriptions.

    The text features of `class_descriptions` are computed once at
    construction time and stored as the frozen parameter
    `class_text_features`. `forward` encodes images with the CLIP model and
    scores them against those cached features through a
    `ClipClassificationHead`.

    Args:
        clip_model: model exposing `encode_image`, `encode_text`,
            `parameters()` and `logit_scale`.
        class_descriptions: text(s) accepted by `clip.tokenize`, one per class.
    """

    def __init__(self, clip_model, class_descriptions):
        super().__init__()
        self.clip_model = clip_model
        self.head = ClipClassificationHead(clip_model)
        # Use no_grad rather than inference_mode: tensors created under
        # inference_mode reject in-place updates outside of it, which would
        # break `load_state_dict` (it copies into parameters in place).
        with torch.no_grad():
            ctf = self.compute_text_features(class_descriptions)
        self.class_text_features = nn.Parameter(ctf, requires_grad=False)

    def forward(self, image):
        """Return the head's logits for `image` against the cached class features."""
        # Cast to float32 to match the cached text features (see
        # `compute_text_features`); an encoder running in reduced precision
        # would otherwise cause a dtype mismatch in the head's matmul.
        image_features = self.clip_model.encode_image(image).float()
        return self.head(image_features, self.class_text_features)

    def compute_text_features(self, texts):
        """Tokenize `texts` and return their CLIP text features as float32.

        Tokens are moved to the device of the CLIP model's first parameter
        before encoding. Gradient tracking is left to the caller's context.
        """
        device = next(self.clip_model.parameters()).device
        text_tokens = clip.tokenize(texts).to(device)
        return self.clip_model.encode_text(text_tokens).float()

