File size: 4,226 Bytes
aa86478
 
 
 
 
 
 
 
 
 
 
46f48ca
aa86478
 
 
 
 
 
 
 
 
5ded884
 
 
46f48ca
5ded884
 
aa86478
5ded884
aa86478
 
 
 
 
 
 
 
 
 
 
5ded884
aa86478
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import cv2
from PIL import Image
import numpy as np
import torch

import PIL

def overlay_attn(original_image, mask):
    """Blend a colour-mapped attention mask over an image.

    The mask is scaled by its maximum, resized to the spatial size of
    ``original_image``, passed through the VIRIDIS colour map and blended
    with the image using fixed weights (0.4 image, 0.6 colour map).

    Args:
        original_image: Array indexed as ``(rows, cols, ...)``; it is cast
            to ``uint8`` before blending. Presumably an RGB image with
            three channels -- TODO confirm against callers.
        mask: NumPy array of attention weights (any spatial size).

    Returns:
        A ``PIL.Image.Image`` built from the blended ``uint8`` array.
    """
    colormap_attn = cv2.COLORMAP_VIRIDIS
    image_weight, cmap_weight = 0.4, 0.6

    # shape[0] is the row count (height), shape[1] the column count (width).
    rows, cols = original_image.shape[0], original_image.shape[1]

    # Scale the mask by its peak; an all-zero (or NaN-peaked) mask would
    # otherwise divide by zero and feed NaNs into the uint8 conversion.
    peak = mask.max()
    if peak > 0:
        normalized = mask / peak
    else:
        normalized = np.zeros(mask.shape, dtype=np.float32)

    # cv2.resize expects the target size as (width, height).
    resized = cv2.resize(normalized, (cols, rows))[..., np.newaxis]

    # Apply colormap to mask
    cmap = cv2.applyColorMap(np.uint8(255 * resized), colormap_attn)

    # NOTE(review): OpenCV colour maps are documented as BGR; if
    # original_image is RGB the overlay colours are channel-swapped in the
    # returned PIL image -- confirm whether that is intended.
    alpha_blended = cv2.addWeighted(
        np.uint8(original_image), image_weight, cmap, cmap_weight, 0
    )

    return Image.fromarray(alpha_blended)



class VITAttentionGradRollout:
    '''
        Attention rollout over the attention maps of a ViT model.

        Expects timm ViT transformer model: ``model.blocks`` must yield
        blocks that each expose an ``attn`` module with ``qkv``,
        ``num_heads`` and ``scale`` attributes.
        Adapted from https://github.com/samiraabnar/attention_flow

        Forward hooks registered in ``__init__`` recompute the softmax
        attention of every block on each forward pass and store it in
        ``self.attentions`` under the keys ``"attn0"``, ``"attn1"``, ...
        Call ``get_attn_mask`` after a forward pass to obtain the mask.
    '''

    def __init__(self, model, head_fusion='min', discard_ratio=0):
        '''
            Args:
                model: Model whose ``blocks`` children are hooked.
                head_fusion: How heads are combined: ``"mean"``, ``"max"``
                    or ``"min"``. Validated when ``get_attn_mask`` runs.
                discard_ratio: Fraction of the fused attention entries
                    (lowest values) that are zeroed before rollout.
        '''
        self.model = model
        self.head_fusion = head_fusion
        self.discard_ratio = discard_ratio

        self.attentions = {}
        # Keep the handles so the hooks can be detached via remove_hooks().
        self._hook_handles = []
        for idx, block in enumerate(model.blocks.children()):
            handle = block.attn.register_forward_hook(
                self.get_attention(f"attn{idx}")
            )
            self._hook_handles.append(handle)

    def remove_hooks(self):
        '''Detach every forward hook registered by this instance.'''
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def get_attention(self, name):
        '''
            Build a forward hook that stores the attention map under ``name``.

            The hook recomputes ``softmax(q @ k^T * scale)`` from the hooked
            module's ``qkv`` projection applied to the module's first input,
            and writes the result to ``self.attentions[name]``.
        '''
        def hook(module, inputs, output):
            with torch.no_grad():
                tokens = inputs[0]
                B, N, C = tokens.shape
                qkv = (
                    module.qkv(tokens)
                    .reshape(B, N, 3, module.num_heads, C // module.num_heads)
                    .permute(2, 0, 3, 1, 4)
                )
                # Only queries and keys are needed for the attention map.
                q, k = qkv[0], qkv[1]
                attn = (q @ k.transpose(-2, -1)) * module.scale
                self.attentions[name] = attn.softmax(dim=-1)
        return hook

    def _fuse_heads(self, attention):
        '''Combine the head dimension (dim 1) according to ``head_fusion``.'''
        if self.head_fusion == "mean":
            return attention.mean(axis=1)
        if self.head_fusion == "max":
            return attention.max(axis=1)[0]
        if self.head_fusion == "min":
            return attention.min(axis=1)[0]
        # Raising a plain string is itself a TypeError in Python 3.
        raise ValueError("Attention head fusion type Not supported")

    def get_attn_mask(self, k=0, start_layer=11):
        '''
            Roll out the stored attentions into a square patch mask.

            Args:
                k: Unused; kept only so existing callers keep working.
                start_layer: Index of the first stored layer included in
                    the rollout; layers ``start_layer`` .. last are used.
                    Defaults to 11, the previously hard-coded value.

            Returns:
                2-D NumPy array (``width x width``) holding the rolled-out
                attention from token 0 to the remaining tokens of the first
                batch element, divided by its maximum.

            Raises:
                KeyError: If no attention has been captured yet.
                IndexError: If ``start_layer`` does not select any layer.
                ValueError: If ``head_fusion`` is not supported.
        '''
        first = self.attentions['attn0']
        num_layers = len(self.attentions)
        if not 0 <= start_layer < num_layers:
            # Without any layer the rollout stays a 2-D identity and the
            # indexing below would fail with an obscure IndexError.
            raise IndexError(
                f"start_layer={start_layer} is out of range: "
                f"{num_layers} attention layer(s) were captured"
            )

        result = torch.eye(first.size(-1), device=first.device)

        with torch.no_grad():
            for layer_idx in range(start_layer, num_layers):
                attention = self.attentions[f'attn{layer_idx}']
                attention_heads_fused = self._fuse_heads(attention)

                # Drop the lowest attentions, but
                # don't drop the class token
                flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
                _, indices = flat.topk(int(flat.size(-1) * self.discard_ratio), -1, False)
                indices = indices[indices != 0]
                # flat is a view, so this zeroes entries of the fused map
                # (first batch element only).
                flat[0, indices] = 0

                # Average with the identity, then renormalise each row.
                I = torch.eye(
                    attention_heads_fused.size(-1),
                    device=attention_heads_fused.device,
                )
                a = (attention_heads_fused + 1.0 * I) / 2
                a = a / a.sum(dim=-1).unsqueeze(-1)

                result = torch.matmul(a, result)

        # Look at the total attention between the class token,
        # and the image patches
        mask = result[0, 0, 1:]
        # In case of 224x224 image, this brings us from 196 to 14
        width = int(mask.size(-1) ** 0.5)
        mask = mask.reshape(width, width).detach().cpu().numpy()
        mask = mask / np.max(mask)
        return mask