MR committed
Commit de36b67
1 parent: d6ad5e3

Create new file

Files changed (1)
app.py +238 -0
app.py ADDED
@@ -0,0 +1,238 @@
import copy
import io
import re

import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image

from vilt.config import ex
from vilt.modules import ViLTransformerSS
from vilt.modules.objectives import cost_matrix_cosine, ipot
from vilt.transforms import pixelbert_transform
from vilt.datamodules.datamodule_base import get_pretrained_tokenizer
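

# Sacred experiment entry point: @ex.automain below both registers main() and
# runs it when the module is executed, with _config assembled from the
# experiment defined in vilt/config.py plus any command-line overrides.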
@ex.automain
def main(_config):
    _config = copy.deepcopy(_config)
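
    # Only the masked-language-modeling head carries weight; every other
    # pre-training objective is switched off, since the demo only fills
    # [MASK] tokens.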
    loss_names = {
        "itm": 0,
        "mlm": 0.5,
        "mpp": 0,
        "vqa": 0,
        "imgcls": 0,
        "nlvr2": 0,
        "irtr": 0,
        "arc": 0,
    }
    tokenizer = get_pretrained_tokenizer(_config["tokenizer"])

    _config.update(
        {
            "loss_names": loss_names,
        }
    )
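
    # Instantiate ViLT from the merged config, put it in eval mode, and move
    # it to the GPU when one is configured.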
    model = ViLTransformerSS(_config)
    model.setup("test")
    model.eval()

    device = "cuda:0" if _config["num_gpus"] > 0 else "cpu"
    model.to(device)
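
    # infer() is the Gradio callback: it fetches the image, fills every
    # [MASK] in the caption one at a time, and optionally renders a
    # token-to-patch alignment heatmap for the token at index hidx.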
    def infer(url, mp_text, hidx):
        try:
            res = requests.get(url)
            image = Image.open(io.BytesIO(res.content)).convert("RGB")
            img = pixelbert_transform(size=384)(image)
            img = img.unsqueeze(0).to(device)
        except Exception:
            return False

        batch = {"text": [""], "image": [None]}
        tl = len(re.findall(r"\[MASK\]", mp_text))
        inferred_token = [mp_text]
        batch["image"][0] = img
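
        # Greedily fill the masks: one forward pass per [MASK], committing
        # only the single most confident prediction each time so later passes
        # can condition on earlier fills.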
        with torch.no_grad():
            for i in range(tl):
                batch["text"] = inferred_token
                encoded = tokenizer(inferred_token)
                batch["text_ids"] = torch.tensor(encoded["input_ids"]).to(device)
                batch["text_labels"] = torch.tensor(encoded["input_ids"]).to(device)
                batch["text_masks"] = torch.tensor(encoded["attention_mask"]).to(device)
                encoded = encoded["input_ids"][0][1:-1]
                out = model(batch)
                mlm_logits = model.mlm_score(out["text_feats"])[0, 1:-1]
                mlm_values, mlm_ids = mlm_logits.softmax(dim=-1).max(dim=-1)
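                # 103 is the [MASK] token id in the bert-base-uncased vocab;
                # zeroing every non-mask position makes argmax pick the most
                # confident still-unfilled slot.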
                mlm_values[torch.tensor(encoded) != 103] = 0
                select = mlm_values.argmax().item()
                encoded[select] = mlm_ids[select].item()
                inferred_token = [tokenizer.decode(encoded)]

        selected_token = ""
        encoded = tokenizer(inferred_token)

        if 0 < hidx < len(encoded["input_ids"][0][:-1]):
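            # Token index is in range: re-encode the completed caption and run
            # one more forward pass to get features for the alignment heatmap,
            # masking out the [CLS]/[SEP] positions on both modalities.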
            with torch.no_grad():
                batch["text"] = inferred_token
                batch["text_ids"] = torch.tensor(encoded["input_ids"]).to(device)
                batch["text_labels"] = torch.tensor(encoded["input_ids"]).to(device)
                batch["text_masks"] = torch.tensor(encoded["attention_mask"]).to(device)
                out = model(batch)
                txt_emb, img_emb = out["text_feats"], out["image_feats"]
                txt_mask, img_mask = (
                    out["text_masks"].bool(),
                    out["image_masks"].bool(),
                )
                for i, _len in enumerate(txt_mask.sum(dim=1)):
                    txt_mask[i, _len - 1] = False
                txt_mask[:, 0] = False
                img_mask[:, 0] = False
                txt_pad, img_pad = ~txt_mask, ~img_mask
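
                # IPOT (inexact proximal point optimal transport) between text
                # and image features: the transport plan T says how much each
                # text token "moves" to each image patch under a
                # cosine-distance cost.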
                cost = cost_matrix_cosine(txt_emb.float(), img_emb.float())
                joint_pad = txt_pad.unsqueeze(-1) | img_pad.unsqueeze(-2)
                cost.masked_fill_(joint_pad, 0)

                txt_len = (txt_pad.size(1) - txt_pad.sum(dim=1, keepdim=False)).to(
                    dtype=cost.dtype
                )
                img_len = (img_pad.size(1) - img_pad.sum(dim=1, keepdim=False)).to(
                    dtype=cost.dtype
                )
                T = ipot(
                    cost.detach(),
                    txt_len,
                    txt_pad,
                    img_len,
                    img_pad,
                    joint_pad,
                    0.1,
                    1000,
                    1,
                )

                plan = T[0]
                plan_single = plan * len(txt_emb)
                cost_ = plan_single.t()

                cost_ = cost_[hidx][1:].cpu()
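
                # Row hidx of the transposed plan is the selected token's mass
                # over image positions; the leading entry (the image [CLS]
                # slot) was dropped above before scattering onto the (H, W)
                # patch grid.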
                patch_index, (H, W) = out["patch_index"]
                heatmap = torch.zeros(H, W)
                for i, pidx in enumerate(patch_index[0]):
                    h, w = pidx[0].item(), pidx[1].item()
                    heatmap[h, w] = cost_[i]

                heatmap = (heatmap - heatmap.mean()) / heatmap.std()
                heatmap = np.clip(heatmap.numpy(), 1.0, 3.0)
                heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())
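
                # Upsample the patch heatmap to image resolution and use it as
                # an alpha channel, so well-aligned regions show through.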
                _w, _h = image.size
                overlay = Image.fromarray(np.uint8(heatmap * 255), "L").resize(
                    (_w, _h), resample=Image.NEAREST
                )
                image_rgba = image.copy()
                image_rgba.putalpha(overlay)
                image = image_rgba

            selected_token = tokenizer.convert_ids_to_tokens(
                encoded["input_ids"][0][hidx]
            )

        return [np.array(image), inferred_token[0], selected_token]
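
    # UI definition, written against the legacy (pre-3.x) Gradio API where
    # components live under gr.inputs / gr.outputs.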
    inputs = [
        gr.inputs.Textbox(
            label="URL of an image.",
            lines=5,
        ),
        gr.inputs.Textbox(label="Caption with [MASK] tokens to be filled.", lines=5),
        gr.inputs.Slider(
            minimum=0,
            maximum=38,
            step=1,
            label="Index of token for heatmap visualization (ignored if zero)",
        ),
    ]
    outputs = [
        gr.outputs.Image(label="Image"),
        gr.outputs.Textbox(label="description"),
        gr.outputs.Textbox(label="selected token"),
    ]
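
    # Each example row supplies the three inputs (URL, caption, token index);
    # server_name/server_port in the constructor also follow the legacy
    # Gradio signature (newer versions take them in launch() instead).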
    interface = gr.Interface(
        fn=infer,
        inputs=inputs,
        outputs=outputs,
        server_name="0.0.0.0",
        server_port=8888,
        examples=[
            [
                "https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg",
                "a display of flowers growing out and over the [MASK] [MASK] in front of [MASK] on a [MASK] [MASK].",
                0,
            ],
            [
                "https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg",
                "a display of flowers growing out and over the retaining wall in front of cottages on a cloudy day.",
                4,
            ],
            [
                "https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg",
                "a display of flowers growing out and over the retaining wall in front of cottages on a cloudy day.",
                11,
            ],
            [
                "https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg",
                "a display of flowers growing out and over the retaining wall in front of cottages on a cloudy day.",
                15,
            ],
            [
                "https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg",
                "a display of flowers growing out and over the retaining wall in front of cottages on a cloudy day.",
                18,
            ],
            [
                "https://upload.wikimedia.org/wikipedia/commons/thumb/4/40/Living_Room.jpg/800px-Living_Room.jpg",
                "a room with a [MASK], a [MASK], a [MASK], and a [MASK].",
                0,
            ],
            [
                "https://upload.wikimedia.org/wikipedia/commons/thumb/4/40/Living_Room.jpg/800px-Living_Room.jpg",
                "a room with a rug, a chair, a painting, and a plant.",
                5,
            ],
            [
                "https://upload.wikimedia.org/wikipedia/commons/thumb/4/40/Living_Room.jpg/800px-Living_Room.jpg",
                "a room with a rug, a chair, a painting, and a plant.",
                8,
            ],
            [
                "https://upload.wikimedia.org/wikipedia/commons/thumb/4/40/Living_Room.jpg/800px-Living_Room.jpg",
                "a room with a rug, a chair, a painting, and a plant.",
                11,
            ],
            [
                "https://upload.wikimedia.org/wikipedia/commons/thumb/4/40/Living_Room.jpg/800px-Living_Room.jpg",
                "a room with a rug, a chair, a painting, and a plant.",
                15,
            ],
        ],
    )

    interface.launch(debug=True)
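
# A sketch of a typical launch command via Sacred's CLI (a hypothetical
# example; the checkpoint path depends on your setup):
#   python app.py with num_gpus=0 load_path="weights/vilt_200k_mlm_itm.ckpt"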