Commit d0b4edc by nielsr (HF staff)
Parent(s): 4b621a2

First commit

Files changed (2):
  1. app.py +71 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,71 @@
+ import os
+ os.system('pip install gradio --upgrade')
+ os.system('pip install git+https://github.com/NielsRogge/transformers.git@add_dino --upgrade')
+
+ import gradio as gr
+ from transformers import ViTFeatureExtractor, ViTModel
+ import torch
+ import torch.nn as nn
+ import torchvision
+ import matplotlib.pyplot as plt
+
+ def get_attention_maps(pixel_values, attentions, nh):
+     threshold = 0.6  # keep the patches covering 60% of each head's attention mass
+     w_featmap = pixel_values.shape[-2] // model.config.patch_size
+     h_featmap = pixel_values.shape[-1] // model.config.patch_size
+
+     # we keep only a certain percentage of the mass
+     val, idx = torch.sort(attentions)
+     val /= torch.sum(val, dim=1, keepdim=True)
+     cumval = torch.cumsum(val, dim=1)
+     th_attn = cumval > (1 - threshold)
+     idx2 = torch.argsort(idx)
+     for head in range(nh):
+         th_attn[head] = th_attn[head][idx2[head]]
+     th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
+     # interpolate (the thresholded maps are computed here but not visualized below)
+     th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=model.config.patch_size, mode="nearest")[0].cpu().numpy()
+
+     attentions = attentions.reshape(nh, w_featmap, h_featmap)
+     attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=model.config.patch_size, mode="nearest")[0].cpu()
+     attentions = attentions.detach().numpy()
+
+     # save attention heatmaps and return list of filenames
+     output_dir = '.'
+     os.makedirs(output_dir, exist_ok=True)
+     attention_maps = []
+     print("Number of heads:", nh)
+     for j in range(nh):
+         fname = os.path.join(output_dir, "attn-head" + str(j) + ".png")
+         # save the attention map
+         plt.imsave(fname=fname, arr=attentions[j], format='png')
+         # append file name
+         attention_maps.append(fname)
+
+     return attention_maps
+
+ feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/dino-vits8", do_resize=False)
+ model = ViTModel.from_pretrained("facebook/dino-vits8", add_pooling_layer=False)
+
+ def visualize_attention(image):
+     # normalize channels
+     pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
+
+     # forward pass
+     outputs = model(pixel_values, output_attentions=True, interpolate_pos_encoding=True)
+
+     # get attentions of last layer
+     attentions = outputs.attentions[-1]
+     nh = attentions.shape[1]  # number of heads
+
+     # we keep only the output patch attention (CLS token to all patch tokens)
+     attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
+
+     attention_maps = get_attention_maps(pixel_values, attentions, nh)
+
+     return attention_maps
+
+ iface = gr.Interface(fn=visualize_attention,
+                      inputs=gr.inputs.Image(shape=(480, 480), type="pil"),
+                      outputs=[gr.outputs.Image(type='file', label=f'attention_head_{i}') for i in range(6)])
+ iface.launch()
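
The mass-thresholding step in `get_attention_maps` follows the same pattern as the visualization script in the original DINO repository: per head, the CLS-to-patch attention weights are sorted, renormalized, and only the patches covering roughly `threshold` (60%) of the total attention mass are marked as kept. A minimal standalone sketch of that step on a dummy tensor (the head count and grid size below are illustrative assumptions, not values from the app):

```python
import torch

# Toy re-run of the thresholding above: 6 heads, a 4x4 patch grid
# flattened to 16 attention weights per head (shapes are assumptions).
nh, n_patches = 6, 16
attentions = torch.rand(nh, n_patches).softmax(dim=-1)

threshold = 0.6
val, idx = torch.sort(attentions)           # sort each head's weights ascending
val /= torch.sum(val, dim=1, keepdim=True)  # renormalize to sum to 1
cumval = torch.cumsum(val, dim=1)
th_attn = cumval > (1 - threshold)          # True where the top ~60% of mass sits
idx2 = torch.argsort(idx)                   # inverse permutation of the sort
for head in range(nh):
    th_attn[head] = th_attn[head][idx2[head]]  # back to original patch order

# fraction of patches each head keeps (fewer when its attention is more peaked)
print(th_attn.float().mean(dim=1))
```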
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ torch
+ torchvision
+ Pillow
+ matplotlib
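
For a quick local check without launching the Gradio UI, `visualize_attention` can be called directly on a PIL image; a minimal sketch, assuming the definitions from app.py above are in scope and a test image exists at `test.jpg` (the filename is hypothetical). Since `do_resize=False` leaves dimensions untouched, the input sides should be divisible by the patch size (8 for dino-vits8):

```python
from PIL import Image

# Hypothetical smoke test; the image path is an assumption, not part of
# the commit. Resize to 480x480 so the 8x8 patch grid divides evenly.
image = Image.open("test.jpg").convert("RGB").resize((480, 480))

files = visualize_attention(image)  # defined in app.py above
print(files)  # ['./attn-head0.png', ..., './attn-head5.png'] for the 6 heads
```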