osanseviero (HF staff) committed
Commit b6236ba
1 Parent(s): 4d36fab

Update app.py

Files changed (1)
  1. app.py +111 -36
app.py CHANGED
@@ -1,6 +1,9 @@
 import os
 os.system('pip install git+https://github.com/huggingface/transformers.git --upgrade')
+os.system('pip install gradio --upgrade')
+os.system('pip freeze')
 
+import os
 import gradio as gr
 from transformers import ViTFeatureExtractor, ViTModel
 import torch
@@ -8,7 +11,16 @@ import torch.nn as nn
 import torchvision
 import matplotlib.pyplot as plt
 
-def get_attention_maps(pixel_values, attentions, nh):
+import cv2
+import numpy as np
+from tqdm import tqdm
+import glob
+from PIL import Image
+
+feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/dino-vits8", do_resize=True, padding=True)
+model = ViTModel.from_pretrained("facebook/dino-vits8", add_pooling_layer=False)
+
+def get_attention_maps(pixel_values, attentions, nh, out, img_path):
     threshold = 0.6
     w_featmap = pixel_values.shape[-2] // model.config.patch_size
     h_featmap = pixel_values.shape[-1] // model.config.patch_size
@@ -25,57 +37,120 @@ def get_attention_maps(pixel_values, attentions, nh):
 
     # interpolate
     th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=model.config.patch_size, mode="nearest")[0].cpu().numpy()
+
     attentions = attentions.reshape(nh, w_featmap, h_featmap)
     attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=model.config.patch_size, mode="nearest")[0].cpu()
     attentions = attentions.detach().numpy()
 
-    # save attentions heatmaps and return list of filenames
-    output_dir = '.'
-    os.makedirs(output_dir, exist_ok=True)
-    attention_maps = []
-    for j in range(nh):
-        fname = os.path.join(output_dir, "attn-head" + str(j) + ".png")
-
-        # save the attention map
-        plt.imsave(fname=fname, arr=attentions[j], format='png')
-
-        # append file name
-        attention_maps.append(fname)
-
-    return attention_maps
+    # sum all attentions
+    fname = os.path.join(out, "attn-" + os.path.basename(img_path))
+    plt.imsave(
+        fname=fname,
+        arr=sum(
+            attentions[i] * 1 / attentions.shape[0]
+            for i in range(attentions.shape[0])
+        ),
+        cmap="inferno",
+        format="png",
+    )
 
 
-def visualize_attention(video):
-    return video
-    """
-    # normalize channels
-    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
-
-    # forward pass
-    outputs = model(pixel_values, output_attentions=True, interpolate_pos_encoding=True)
-
-    # get attentions of last layer
-    attentions = outputs.attentions[-1]
-    nh = attentions.shape[1] # number of heads
-
-    # we keep only the output patch attention
-    attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
-    attention_maps = get_attention_maps(pixel_values, attentions, nh)
-
-    return attention_maps"""
-
-feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/dino-vits8", do_resize=False)
-model = ViTModel.from_pretrained("facebook/dino-vits8", add_pooling_layer=False)
+def inference(inp: str, out: str):
+    print(f"Generating attention images to {out}")
+
+    # I had to process one at a time since colab was crashing...
+    for img_path in tqdm(sorted(glob.glob(os.path.join(inp, "*.jpg")))):
+        with open(img_path, "rb") as f:
+            img = Image.open(f)
+            img = img.convert("RGB")
+
+        # normalize channels
+        pixel_values = feature_extractor(images=img, return_tensors="pt").pixel_values
+
+        # forward pass
+        outputs = model(pixel_values, output_attentions=True, interpolate_pos_encoding=True)
+
+        # get attentions of last layer
+        attentions = outputs.attentions[-1]
+        nh = attentions.shape[1] # number of heads
+
+        # we keep only the output patch attention
+        attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
+
+        # sum and save attention maps
+        get_attention_maps(pixel_values, attentions, nh, out, img_path)
+
+def extract_frames_from_video(inp: str, out: str):
+    vidcap = cv2.VideoCapture(inp)
+    fps = vidcap.get(cv2.CAP_PROP_FPS)
+
+    print(f"Video: {inp} ({fps} fps)")
+    print(f"Extracting frames to {out}")
+
+    success, image = vidcap.read()
+    count = 0
+    while success:
+        cv2.imwrite(
+            os.path.join(out, f"frame-{count:04}.jpg"),
+            image,
+        )
+        success, image = vidcap.read()
+        count += 1
+    return fps
+
+def generate_video_from_images(inp: str, out_name: str, fps: int):
+    img_array = []
+    attention_images_list = sorted(glob.glob(os.path.join(inp, "attn-*.jpg")))
+
+    # Get size of the first image
+    with open(attention_images_list[0], "rb") as f:
+        img = Image.open(f)
+        img = img.convert("RGB")
+        size = (400, 400)
+        img_array.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
+
+    print(f"Generating video {size} to {out_name}")
+
+    for filename in tqdm(attention_images_list[1:]):
+        with open(filename, "rb") as f:
+            img = Image.open(f)
+            img = img.convert("RGB")
+            img_array.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
+
+    out = cv2.VideoWriter(
+        out_name,
+        cv2.VideoWriter_fourcc(*"MP4V"),
+        fps,
+        size,
+    )
+
+    for i in range(len(img_array)):
+        out.write(img_array[i])
+    out.release()
+    print("Done")
+    return
+
+def func(video):
+    frames_folder = os.path.join("output", "frames")
+    attention_folder = os.path.join("output", "attention")
+
+    os.makedirs(frames_folder, exist_ok=True)
+    os.makedirs(attention_folder, exist_ok=True)
+
+    fps = extract_frames_from_video(video, frames_folder)
+
+    inference(frames_folder, attention_folder)
+    generate_video_from_images(attention_folder, "video.mp4", fps)
+
+    return "video.mp4"
 
+
 title = "Interactive demo: DINO"
 description = "Demo for Facebook AI's DINO, a new method for self-supervised training of Vision Transformers. Using this method, they are capable of segmenting objects within an image without having ever been trained to do so. This can be observed by displaying the self-attention of the heads from the last layer for the [CLS] token query. This demo uses a ViT-S/8 trained with DINO. To use it, simply upload an image or use the example image below. Results will show up in a few seconds."
 article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2104.14294'>Emerging Properties in Self-Supervised Vision Transformers</a> | <a href='https://github.com/facebookresearch/dino'>Github Repo</a></p>"
-examples =[['video.mp4']]
-iface = gr.Interface(fn=visualize_attention,
-                     inputs=gr.inputs.Video(gr.inputs.Video()),
-                     outputs=[gr.outputs.Video(label=f'result_video')],
+iface = gr.Interface(fn=func,
+                     inputs=gr.inputs.Video(type=None),
+                     outputs="video",
                      title=title,
                      description=description,
-                     article=article,
-                     examples=examples)
-iface.launch()
+                     article=article)
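
A note on the slice at the heart of the new inference loop: attentions[0, :, 0, 1:] keeps one attention row per head, where query position 0 is the [CLS] token and key positions 1 onward are the image patches. A minimal sketch with dummy tensors; the 6 heads and the 28x28 patch grid are assumed values for ViT-S/8 on a 224x224 frame, not taken from the commit:

import torch

nh, num_patches = 6, 28 * 28   # assumed: ViT-S/8 on a 224x224 frame
seq_len = num_patches + 1      # +1 for the [CLS] token

# Dummy stand-in for outputs.attentions[-1]: (batch, heads, queries, keys)
attn = torch.rand(1, nh, seq_len, seq_len)

# Query row 0 is the [CLS] token; key columns 1: are the image patches.
cls_attn = attn[0, :, 0, 1:].reshape(nh, 28, 28)
print(cls_attn.shape)  # torch.Size([6, 28, 28])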
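
The generator expression that get_attention_maps passes to plt.imsave weights each head's map by 1 / attentions.shape[0] and sums them, which is exactly the mean over heads. A self-contained check (the array shape is illustrative only):

import numpy as np

attentions = np.random.rand(6, 224, 224)  # toy stand-in for the per-head maps

# The commit's head-by-head weighted sum...
weighted = sum(attentions[i] * 1 / attentions.shape[0] for i in range(attentions.shape[0]))

# ...equals the mean over the head axis.
assert np.allclose(weighted, attentions.mean(axis=0))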
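
Both the thresholded and the raw maps are upsampled from the patch grid back to pixel resolution with nearest-neighbor interpolation, scaled by the model's patch size. A minimal sketch under the same assumed sizes as above:

import torch
import torch.nn as nn

patch_size = 8                # assumed: ViT-S/8
maps = torch.rand(6, 28, 28)  # one patch-level map per head

# (heads, 28, 28) -> (1, heads, 28, 28) -> upsample -> (heads, 224, 224)
full = nn.functional.interpolate(maps.unsqueeze(0), scale_factor=patch_size, mode="nearest")[0]
print(full.shape)  # torch.Size([6, 224, 224])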
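
End to end, func chains frame extraction, per-frame attention inference, and video reassembly. A hypothetical local smoke test; input.mp4 is a placeholder clip, and note that importing app.py also triggers its os.system pip installs:

# Hypothetical local run; assumes app.py and a short clip input.mp4
# sit in the current directory.
from app import func

result = func("input.mp4")  # writes video.mp4 next to app.py and returns its path
print(result)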