MikailDuzenli committed
Commit d34cc3f · 1 Parent(s): 43e5916

Add heatmap viz

Files changed (1)
  1. app.py +230 -41
app.py CHANGED
@@ -1,9 +1,11 @@
 import gradio as gr
 import torch
+import torch.nn.functional as F
 import requests
 import numpy as np
 import re
 import io
+import matplotlib.pyplot as plt
 
 from PIL import Image
 from transformers import ViltProcessor, ViltForMaskedLM
@@ -15,6 +17,7 @@ model = ViltForMaskedLM.from_pretrained("dandelin/vilt-b32-mlm")
 device = "cuda:0" if torch.cuda.is_available() else "cpu"
 model.to(device)
 
+
 class MinMaxResize:
     def __init__(self, shorter=800, longer=1333):
         self.min = shorter
@@ -36,7 +39,8 @@ class MinMaxResize:
         newh, neww = int(newh + 0.5), int(neww + 0.5)
         newh, neww = newh // 32 * 32, neww // 32 * 32
 
-        return x.resize((neww, newh), resample=Image.BICUBIC)
+        return x.resize((neww, newh), resample=Image.Resampling.BICUBIC)
+
 
 def pixelbert_transform(size=800):
     longer = int((1333 / 800) * size)
@@ -44,16 +48,99 @@ def pixelbert_transform(size=800):
         [
             MinMaxResize(shorter=size, longer=longer),
             transforms.ToTensor(),
-            transforms.Compose([transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])]),
+            transforms.Compose([transforms.Normalize(
+                mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])]),
         ]
     )
 
 
-def infer(url, mp_text):
+def cost_matrix_cosine(x, y, eps=1e-5):
+    """Compute cosine distance across every pair of x, y (batched)
+    [B, L_x, D] [B, L_y, D] -> [B, Lx, Ly]"""
+    assert x.dim() == y.dim()
+    assert x.size(0) == y.size(0)
+    assert x.size(2) == y.size(2)
+    x_norm = F.normalize(x, p=2, dim=-1, eps=eps)
+    y_norm = F.normalize(y, p=2, dim=-1, eps=eps)
+    cosine_sim = x_norm.matmul(y_norm.transpose(1, 2))
+    cosine_dist = 1 - cosine_sim
+    return cosine_dist
+
+
+@torch.no_grad()
+def ipot(C, x_len, x_pad, y_len, y_pad, joint_pad, beta, iteration, k):
+    """[B, M, N], [B], [B, M], [B], [B, N], [B, M, N]"""
+    b, m, n = C.size()
+    sigma = torch.ones(b, m, dtype=C.dtype,
+                       device=C.device) / x_len.unsqueeze(1)
+    T = torch.ones(b, n, m, dtype=C.dtype, device=C.device)
+    A = torch.exp(-C.transpose(1, 2) / beta)
+
+    # mask padded positions
+    sigma.masked_fill_(x_pad, 0)
+    joint_pad = joint_pad.transpose(1, 2)
+    T.masked_fill_(joint_pad, 0)
+    A.masked_fill_(joint_pad, 0)
+
+    # broadcastable lengths
+    x_len = x_len.unsqueeze(1).unsqueeze(2)
+    y_len = y_len.unsqueeze(1).unsqueeze(2)
+
+    # mask to zero out padding in delta and sigma
+    x_mask = (x_pad.to(C.dtype) * 1e4).unsqueeze(1)
+    y_mask = (y_pad.to(C.dtype) * 1e4).unsqueeze(1)
+
+    for _ in range(iteration):
+        Q = A * T  # bs * n * m
+        sigma = sigma.view(b, m, 1)
+        for _ in range(k):
+            delta = 1 / (y_len * Q.matmul(sigma).view(b, 1, n) + y_mask)
+            sigma = 1 / (x_len * delta.matmul(Q) + x_mask)
+        T = delta.view(b, n, 1) * Q * sigma
+    T.masked_fill_(joint_pad, 0)
+    return T
+
+
+def get_model_embedding_and_mask(model, input_ids, pixel_values):
+
+    input_shape = input_ids.size()
+
+    text_batch_size, seq_length = input_shape
+    device = input_ids.device
+    attention_mask = torch.ones(((text_batch_size, seq_length)), device=device)
+    image_batch_size = pixel_values.shape[0]
+    image_token_type_idx = 1
+
+    if image_batch_size != text_batch_size:
+        raise ValueError(
+            "The text inputs and image inputs need to have the same batch size")
+
+    pixel_mask = torch.ones((image_batch_size, model.vilt.config.image_size,
+                             model.vilt.config.image_size), device=device)
+
+    text_embeds = model.vilt.embeddings.text_embeddings(
+        input_ids=input_ids, token_type_ids=None, inputs_embeds=None)
+
+    image_embeds, image_masks, patch_index = model.vilt.embeddings.visual_embed(
+        pixel_values=pixel_values, pixel_mask=pixel_mask, max_image_length=model.vilt.config.max_image_length
+    )
+    text_embeds = text_embeds + model.vilt.embeddings.token_type_embeddings(
+        torch.zeros_like(attention_mask, dtype=torch.long,
+                         device=text_embeds.device)
+    )
+    image_embeds = image_embeds + model.vilt.embeddings.token_type_embeddings(
+        torch.full_like(image_masks, image_token_type_idx,
+                        dtype=torch.long, device=text_embeds.device)
+    )
+
+    return text_embeds, image_embeds, attention_mask, image_masks, patch_index
+
+
+def infer(url, mp_text, hidx):
     try:
         res = requests.get(url)
         image = Image.open(io.BytesIO(res.content)).convert("RGB")
-        img = pixelbert_transform(size=384)(image)
+        img = pixelbert_transform(size=500)(image)
         img = img.unsqueeze(0).to(device)
     except:
         return False
@@ -67,69 +154,171 @@ def infer(url, mp_text):
     encoded = processor.tokenizer(inferred_token)
     input_ids = torch.tensor(encoded.input_ids)
    encoded = encoded["input_ids"][0][1:-1]
-    outputs = model(input_ids=input_ids, pixel_values=encoding.pixel_values)
+    outputs = model(input_ids=input_ids,
+                    pixel_values=encoding.pixel_values)
     mlm_logits = outputs.logits[0]  # shape (seq_len, vocab_size)
-
+
     # only take into account text features (minus CLS and SEP token)
-    mlm_logits = mlm_logits[1 : input_ids.shape[1] - 1, :]
+    mlm_logits = mlm_logits[1: input_ids.shape[1] - 1, :]
     mlm_values, mlm_ids = mlm_logits.softmax(dim=-1).max(dim=-1)
-
+
     # only take into account text
     mlm_values[torch.tensor(encoded) != 103] = 0
     select = mlm_values.argmax().item()
     encoded[select] = mlm_ids[select].item()
     inferred_token = [processor.decode(encoded)]
-
+
     encoded = processor.tokenizer(inferred_token)
     output = processor.decode(encoded.input_ids[0], skip_special_tokens=True)
+    selected_token = ''
+
+    if hidx > 0 and hidx < len(encoded["input_ids"][0][:-1]):
+        input_ids = torch.tensor(encoded.input_ids)
+        outputs = model(
+            input_ids=input_ids, pixel_values=encoding.pixel_values, output_hidden_states=True)
+
+        txt_emb, img_emb, text_masks, image_masks, patch_index = get_model_embedding_and_mask(
+            model, input_ids=input_ids, pixel_values=encoding.pixel_values)
+
+        embedding_output = torch.cat([txt_emb, img_emb], dim=1)
+        attention_mask = torch.cat([text_masks, image_masks], dim=1)
+
+        extended_attention_mask = model.vilt.get_extended_attention_mask(
+            attention_mask, input_ids.size(), device=device)
+
+        encoder_outputs = model.vilt.encoder(
+            embedding_output,
+            attention_mask=extended_attention_mask,
+            head_mask=None,
+            output_attentions=False,
+            output_hidden_states=True,
+            return_dict=True,
+        )
+
+        x = encoder_outputs.hidden_states[-1]
+        x = model.vilt.layernorm(x)
+
+        txt_emb, img_emb = (
+            x[:, :txt_emb.shape[1]],
+            x[:, txt_emb.shape[1]:],
+        )
+
+        txt_mask, img_mask = (
+            text_masks.bool(),
+            image_masks.bool(),
+        )
+
+        for i, _len in enumerate(txt_mask.sum(dim=1)):
+            txt_mask[i, _len - 1] = False
+        txt_mask[:, 0] = False
+        img_mask[:, 0] = False
+        txt_pad, img_pad = ~txt_mask, ~img_mask
+        cost = cost_matrix_cosine(txt_emb.float(), img_emb.float())
+        joint_pad = txt_pad.unsqueeze(-1) | img_pad.unsqueeze(-2)
+        cost.masked_fill_(joint_pad, 0)
+
+        txt_len = (txt_pad.size(1) - txt_pad.sum(dim=1,
+                                                 keepdim=False)).to(dtype=cost.dtype)
+        img_len = (img_pad.size(1) - img_pad.sum(dim=1,
+                                                 keepdim=False)).to(dtype=cost.dtype)
+        T = ipot(cost.detach(),
+                 txt_len,
+                 txt_pad,
+                 img_len,
+                 img_pad,
+                 joint_pad,
+                 0.1,
+                 1000,
+                 1,
+                 )
+        plan = T[0]
+        plan_single = plan * len(txt_emb)
+        cost_ = plan_single.t()
+
+        cost_ = cost_[hidx][1:].cpu()
+
+        patch_index, (H, W) = patch_index
+        heatmap = torch.zeros(H, W)
+        for i, pidx in enumerate(patch_index[0]):
+            h, w = pidx[0].item(), pidx[1].item()
+            heatmap[h, w] = cost_[i]
+
+        heatmap = (heatmap - heatmap.mean()) / heatmap.std()
+        heatmap = np.clip(heatmap, 1.0, 3.0)
+        heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())
+
+        _w, _h = image.size
+        overlay = Image.fromarray(np.uint8(heatmap * 255), "L").resize(
+            (_w, _h), resample=Image.Resampling.NEAREST
+        )
+        image_rgba = image.copy()
+        image_rgba.putalpha(overlay)
+        image = image_rgba
+
+        selected_token = processor.tokenizer.convert_ids_to_tokens(
+            encoded["input_ids"][0][hidx]
+        )
 
-    return [np.array(image), output]
+    return [np.array(image), output, selected_token]
+
 
 title = "What's in the picture ?"
 
 description = """
 Can't find your words to describe an image ? The pre-trained
 ViLT model will help you. Give the url of an image and a caption with [MASK] tokens to be filled or play with the given examples !
+You can even see where the model focused its attention for a given word : just choose the index of the selected word with the slider.
 """
 
 
 inputs_interface = [
-    gr.inputs.Textbox(
-        label="Url of an image.",
-        lines=5,
-    ),
-    gr.inputs.Textbox(label="Caption with [MASK] tokens to be filled.", lines=5),
-]
+    gr.inputs.Textbox(
+        label="Url of an image.",
+        lines=5,
+    ),
+    gr.inputs.Textbox(
+        label="Caption with [MASK] tokens to be filled.", lines=5),
+    gr.inputs.Slider(
+        minimum=0,
+        maximum=38,
+        step=1,
+        label="Index of token for heatmap visualization (ignored if zero)",
+    ),
+]
 outputs_interface = [
-    gr.outputs.Image(label="Image"),
-    gr.outputs.Textbox(label="description"),
-]
+    gr.outputs.Image(label="Image"),
+    gr.outputs.Textbox(label="description"),
+    gr.outputs.Textbox(label="selected token"),
+]
 
 interface = gr.Interface(
-    fn=infer,
-    inputs=inputs_interface,
-    outputs=outputs_interface,
-    title=title,
-    description=description,
-    server_name="0.0.0.0",
-    server_port=8888,
-    examples=[
-        [
+    fn=infer,
+    inputs=inputs_interface,
+    outputs=outputs_interface,
+    title=title,
+    description=description,
+    server_name="0.0.0.0",
+    server_port=8888,
+    examples=[
+        [
             "https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg",
             "a display of flowers growing out and over the [MASK] [MASK] in front of [MASK] on a [MASK] [MASK].",
-        ],
-
-        [
-            "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcT5W71UTcSBm3r5l9NzBemglq983bYvKOHRkw&usqp=CAU",
-            "An [MASK] with the [MASK] in the [MASK].",
-        ],
-
-        [
-            "https://www.referenseo.com/wp-content/uploads/2019/03/image-attractive-960x540.jpg",
-            "An [MASK] is flying with a [MASK] over a [MASK].",
-        ],
+            0,
+        ],
+
+        [
+            "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcT5W71UTcSBm3r5l9NzBemglq983bYvKOHRkw&usqp=CAU",
+            "An [MASK] with the [MASK] in the [MASK].",
+            5,
+        ],
+
+        [
+            "https://www.referenseo.com/wp-content/uploads/2019/03/image-attractive-960x540.jpg",
+            "An [MASK] is flying with a [MASK] over a [MASK].",
+            2,
+        ],
     ],
-)
+)
+
 
-interface.launch()
+interface.launch()
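
For a quick local check of the new heatmap path, here is a minimal sketch (not part of the commit): it assumes the definitions from app.py above are already in scope, and reuses a URL and caption from the examples in the diff. Note that running app.py itself launches the Gradio interface instead.

# Sketch: exercising the new infer(url, mp_text, hidx) signature directly.
# Assumes the functions defined in app.py above are in scope.
url = "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcT5W71UTcSBm3r5l9NzBemglq983bYvKOHRkw&usqp=CAU"
caption = "An [MASK] with the [MASK] in the [MASK]."

result = infer(url, caption, 5)  # hidx=5: overlay the heatmap for token index 5
if result is not False:          # infer returns False when the image download fails
    image_with_heatmap, filled_caption, selected_token = result
    print(filled_caption)
    print("heatmap shown for token:", selected_token)

With hidx=0 the heatmap branch is skipped, so the call behaves like the previous version of infer: the original image and the filled caption are returned, and the selected token stays empty.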