Ali Mohammad committed
Commit e7eee8e
1 Parent(s): d8b02aa

add app file

Files changed (1)
  1. demo_search.py +189 -0
demo_search.py ADDED
@@ -0,0 +1,189 @@
import copy
import time

import gradio as gr
import numpy as np
import torch
from PIL import Image

from vilt.config import ex
from vilt.modules import ViLTransformerSS
from vilt.modules.objectives import cost_matrix_cosine, ipot
from vilt.transforms import pixelbert_transform
from vilt.datamodules.datamodule_base import get_pretrained_tokenizer

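# Sacred entry point: configuration comes from vilt/config.py and can be
# overridden on the command line, e.g. (assuming the released IRTR checkpoint
# is available locally):
#   python demo_search.py with num_gpus=0 load_path="weights/vilt_irtr_f30k.ckpt"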
@ex.automain
def main(_config):
    _config = copy.deepcopy(_config)

    loss_names = {
        "itm": 1,
        "mlm": 0.5,
        "mpp": 0,
        "vqa": 0,
        "imgcls": 0,
        "nlvr2": 0,
        "irtr": 1,
        "arc": 0,
    }
    tokenizer = get_pretrained_tokenizer(_config["tokenizer"])

    _config.update({"loss_names": loss_names})

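    # Nonzero "itm"/"irtr" weights make ViLTransformerSS build the image-text
    # matching head (model.rank_output) used below to score each image.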
    model = ViLTransformerSS(_config)
    model.setup("test")
    model.eval()

    device = "cuda:0" if _config["num_gpus"] > 0 else "cpu"
    model.to(device)

    lst_imgs = [
        f"C:\\Users\\alimh\\PycharmProjects\\ViLT\\assets\\database\\{i}.jpg"
        for i in range(1, 10)
    ]

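    # The search database is a hard-coded local folder of nine images named
    # 1.jpg through 9.jpg; point lst_imgs at other files to search a
    # different collection.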
    def infer(mp_text, hidx=0):
        def get_image(path):
            image = Image.open(path).convert("RGB")
            img = pixelbert_transform(size=384)(image)
            return img.unsqueeze(0).to(device)

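        # pixelbert_transform(size=384) applies ViLT's standard resize and
        # normalization. Re-encoding every database image on each query is
        # simple but slow; a larger demo would cache these tensors.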
        imgs = [get_image(pth) for pth in lst_imgs]

        batch = []
        for img in imgs:
            batch.append({"text": [mp_text], "image": [img]})

        for dic in batch:
            encoded = tokenizer(dic["text"])
            dic["text_ids"] = torch.tensor(encoded["input_ids"]).to(device)
            dic["text_labels"] = torch.tensor(encoded["input_ids"]).to(device)
            dic["text_masks"] = torch.tensor(encoded["attention_mask"]).to(device)

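        # text_labels merely mirrors text_ids: it would only matter for the
        # MLM objective, while matching uses text_ids and text_masks.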
        scores = []
        with torch.no_grad():
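            # One forward pass per database image; rank_output maps the fused
            # [CLS] feature to a scalar image-text matching score.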
            for dic in batch:
                s = time.time()
                ret = model(dic)
                e = time.time()
                print("time ", round(e - s, 2))

                score = model.rank_output(ret["cls_feats"])
                scores.append(score.item())
        print(scores)
        img_idx = np.argmax(scores)
        print(img_idx + 1)
88
+ selected_image = Image.open(lst_imgs[img_idx]).convert("RGB")
89
+ selected_image = np.asarray(selected_image)
90
+ print(selected_image.shape)
91
+ selected_token =""
92
+ if hidx > 0 and hidx < len(encoded["input_ids"][0][:-1]):
93
+ image = Image.open(lst_imgs[img_idx]).convert("RGB")
94
+ selected_batch = batch[img_idx]
95
+ with torch.no_grad():
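                # Re-encode the winning image and derive a word-patch
                # alignment heatmap for token `hidx` via optimal transport,
                # following ViLT's IRTR/heatmap demo.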
                ret = model(selected_batch)
                txt_emb, img_emb = ret["text_feats"], ret["image_feats"]
                txt_mask, img_mask = (
                    ret["text_masks"].bool(),
                    ret["image_masks"].bool(),
                )
                # Mask out [CLS], [SEP], and padding in both modalities.
                for i, _len in enumerate(txt_mask.sum(dim=1)):
                    txt_mask[i, _len - 1] = False
                txt_mask[:, 0] = False
                img_mask[:, 0] = False
                txt_pad, img_pad = ~txt_mask, ~img_mask

                cost = cost_matrix_cosine(txt_emb.float(), img_emb.float())
                joint_pad = txt_pad.unsqueeze(-1) | img_pad.unsqueeze(-2)
                cost.masked_fill_(joint_pad, 0)

                txt_len = (txt_pad.size(1) - txt_pad.sum(dim=1)).to(dtype=cost.dtype)
                img_len = (img_pad.size(1) - img_pad.sum(dim=1)).to(dtype=cost.dtype)
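                # ipot approximates the optimal-transport plan between word
                # and patch embeddings; the trailing positional arguments are
                # beta, iteration count, and k in ViLT's implementation.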
                T = ipot(
                    cost.detach(),
                    txt_len,
                    txt_pad,
                    img_len,
                    img_pad,
                    joint_pad,
                    0.1,
                    1000,
                    1,
                )

                plan = T[0]
                plan_single = plan * len(txt_emb)
                cost_ = plan_single.t()

                # Alignment strengths between token `hidx` and each image patch.
                cost_ = cost_[hidx][1:].cpu()

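                # Scatter those strengths back onto the (H, W) patch grid.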
                patch_index, (H, W) = ret["patch_index"]
                heatmap = torch.zeros(H, W)
                for i, pidx in enumerate(patch_index[0]):
                    h, w = pidx[0].item(), pidx[1].item()
                    heatmap[h, w] = cost_[i]

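                # Standardize, clip to [1, 3] standard deviations so only
                # strongly aligned patches survive, then rescale to [0, 1]
                # for use as an alpha mask.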
                heatmap = (heatmap - heatmap.mean()) / heatmap.std()
                heatmap = np.clip(heatmap, 1.0, 3.0)
                heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())

                _w, _h = image.size
                overlay = Image.fromarray(np.uint8(heatmap * 255), "L").resize(
                    (_w, _h), resample=Image.NEAREST
                )
                # The heatmap becomes the alpha channel: well-aligned patches
                # stay opaque, the rest fade out.
                image_rgba = image.copy()
                image_rgba.putalpha(overlay)
                selected_image = image_rgba

            # Decoded token at `hidx` (computed for reference; not returned).
            selected_token = tokenizer.convert_ids_to_tokens(
                encoded["input_ids"][0][hidx]
            )

        # Return the (possibly heatmap-overlaid) image and the 1-based index
        # of the best-matching database image.
        return [selected_image, int(img_idx) + 1]

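    # Gradio UI: a free-text query plus a slider choosing which token to
    # visualize as a heatmap (0 disables the overlay).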
    inputs = [
        gr.inputs.Textbox(label="Text query to search the image database.", lines=5),
        gr.inputs.Slider(
            minimum=0,
            maximum=38,
            step=1,
            label="Index of token for heatmap visualization (ignored if zero)",
        ),
    ]
    outputs = [
        gr.outputs.Image(label="Image"),
        gr.outputs.Textbox(label="Matching index"),
    ]

    interface = gr.Interface(
        fn=infer,
        inputs=inputs,
        outputs=outputs,
    )

    # server_name and server_port are launch() options in Gradio.
    interface.launch(server_name="localhost", server_port=8888, debug=True, share=False)