import torch
from torchvision import transforms as T
import numpy as np
from CLIP import clip_explainability as clip

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# https://github.com/hila-chefer/Transformer-MM-Explainability/blob/main/CLIP_explainability.ipynb
class ClipRelevancy(torch.nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        # TODO it would make more sense not to load the model again (it is already loaded in the extractor)
        self.model = clip.load("ViT-B/32", device=device, jit=False)[0]
        clip_input_size = 224
        # Standard CLIP preprocessing: resize to the model input size and normalize with CLIP's channel statistics.
        self.preprocess = T.Compose(
            [
                T.Resize((clip_input_size, clip_input_size)),
                T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
            ]
        )
        input_prompts = cfg["bootstrap_text"]
        if isinstance(input_prompts, str):
            input_prompts = [input_prompts]
        self.text = clip.tokenize(input_prompts).to(cfg["device"])

        if self.cfg["use_negative_bootstrap"]:
            input_negative_prompts = cfg["bootstrap_negative_text"]
            if isinstance(input_negative_prompts, str):
                input_negative_prompts = [input_negative_prompts]
            self.bootstrap_negative_text = clip.tokenize(input_negative_prompts).to(cfg["device"])

    def image_relevance(self, image_relevance):
        # Reshape the per-patch relevance scores into the ViT patch grid (224 / 32 = 7, so a 7x7 grid),
        # upsample to the input resolution, and min-max normalize to [0, 1].
        patch_size = 32  # hardcoded for ViT-B/32 which we use
        h = w = 224
        image_relevance = image_relevance.reshape(1, 1, h // patch_size, w // patch_size)
        image_relevance = torch.nn.functional.interpolate(image_relevance, size=(h, w), mode="bilinear")
        image_relevance = image_relevance.reshape(h, w).to(device)
        image_relevance = (image_relevance - image_relevance.min()) / (image_relevance.max() - image_relevance.min())
        return image_relevance

    def interpret(self, image, negative=False):
        text = self.text if not negative else self.bootstrap_negative_text
        batch_size = text.shape[0]
        # One copy of the image per prompt, so each image-text pair gets its own relevancy map.
        images = image.repeat(batch_size, 1, 1, 1)
        # TODO this is pretty inefficient; we could compute the text embeddings once instead of recomputing them at each call
        logits_per_image, logits_per_text = self.model(images, text)
        probs = logits_per_image.softmax(dim=-1).detach().cpu().numpy()  # not used below
        # Backpropagate from the diagonal logits, i.e. each image copy paired with its own prompt.
        index = list(range(batch_size))
        one_hot = np.zeros((logits_per_image.shape[0], logits_per_image.shape[1]), dtype=np.float32)
        one_hot[np.arange(logits_per_image.shape[0]), index] = 1
        one_hot = torch.from_numpy(one_hot).requires_grad_(True)
        one_hot = torch.sum(one_hot.to(device) * logits_per_image)
        self.model.zero_grad()

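        # Relevancy propagation as in Chefer et al. (Transformer-MM-Explainability):
        # starting from R = I over the visual tokens, every selected block updates
        #     R <- R + mean_heads(clamp(grad(A) * A, min=0)) @ R
        # where A is that block's attention map and grad(A) its gradient w.r.t. the
        # selected image-text logit; row 0 of R then holds the CLS-token relevancies.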
        image_attn_blocks = list(dict(self.model.visual.transformer.resblocks.named_children()).values())
        num_tokens = image_attn_blocks[0].attn_probs.shape[-1]
        R = torch.eye(num_tokens, num_tokens, dtype=image_attn_blocks[0].attn_probs.dtype).to(device)
        R = R.unsqueeze(0).expand(batch_size, num_tokens, num_tokens)
        for i, blk in enumerate(image_attn_blocks):
            # Only the later transformer blocks (indices > relevancy_num_layers) contribute to the rollout.
            if i <= self.cfg["relevancy_num_layers"]:
                continue
            grad = torch.autograd.grad(one_hot, [blk.attn_probs], retain_graph=True)[0].detach()
            cam = blk.attn_probs.detach()
            cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
            grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1])
            cam = grad * cam
            cam = cam.reshape(batch_size, -1, cam.shape[-1], cam.shape[-1])
            cam = cam.clamp(min=0).mean(dim=1)
            R = R + torch.bmm(cam, R)
        image_relevance = R[:, 0, 1:]

        return image_relevance

    def forward(self, img, preprocess=True, negative=False):
        # Returns one normalized (224, 224) relevancy map per prompt (per negative prompt when negative=True).
        if preprocess:
            img = self.preprocess(img)
        R_image = self.interpret(img, negative=negative)
        res = torch.stack([self.image_relevance(el).float() for el in R_image], dim=0)
        return res
        return res
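

# A minimal usage sketch (not part of the module). The cfg keys below are exactly the ones
# read by ClipRelevancy; the prompt string, the relevancy_num_layers value, and the random
# input image are placeholders chosen here for illustration only.
if __name__ == "__main__":
    cfg = {
        "device": device,
        "bootstrap_text": "a photo of a cat",  # placeholder prompt
        "use_negative_bootstrap": False,
        "relevancy_num_layers": 10,  # assumed value: only blocks after this index are accumulated
    }
    relevancy = ClipRelevancy(cfg).to(device)
    image = torch.rand(1, 3, 256, 256, device=device)  # placeholder RGB image in [0, 1]
    relevance = relevancy(image)  # shape: (num_prompts, 224, 224), values in [0, 1]
    print(relevance.shape, relevance.min().item(), relevance.max().item())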