Anash committed on
Commit 007be14 · 1 Parent(s): 121245f

adding main file

Files changed (1)
  1. app.py +201 -0
app.py ADDED
@@ -0,0 +1,201 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision.transforms as transforms
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoConfig
+ import clip
+ from PIL import Image
+ import re
+ import numpy as np
+ import cv2
+ import gradio as gr
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ def proc_ques(ques):
+     # lowercase the question and strip punctuation
+     words = re.sub(r"([.,'!?\"()*#:;])", '', ques.lower()).replace('-', ' ').replace('/', ' ')
+     return words
+
+
+ def change_requires_grad(model, req_grad):
+     # freeze or unfreeze all parameters of a model
+     for p in model.parameters():
+         p.requires_grad = req_grad
+
+
+ def load_checkpoint(ckpt_path, epoch):
+
+     model_name = 'nle_model_{}'.format(str(epoch))
+     tokenizer_name = 'nle_gpt2_tokenizer_0'
+     tokenizer = GPT2Tokenizer.from_pretrained(ckpt_path + tokenizer_name)  # load tokenizer
+     model = GPT2LMHeadModel.from_pretrained(ckpt_path + model_name).to(device)  # load model with config
+
+     return tokenizer, model
+
+
+ class ImageEncoder(nn.Module):
+
+     def __init__(self):
+         super(ImageEncoder, self).__init__()
+
+         self.encoder, _ = clip.load("ViT-B/16", device=device)  # loads already in eval mode
+
+     def forward(self, x):
+         """
+         Expects a tensor of size (batch_size, 3, 224, 224)
+         """
+         with torch.no_grad():
+             x = x.type(self.encoder.visual.conv1.weight.dtype)
+             x = self.encoder.visual.conv1(x)  # shape = [*, width, grid, grid]
+             x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+             x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+             x = torch.cat([self.encoder.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
+             x = x + self.encoder.visual.positional_embedding.to(x.dtype)
+             x = self.encoder.visual.ln_pre(x)
+             x = x.permute(1, 0, 2)  # NLD -> LND
+             x = self.encoder.visual.transformer(x)
+             grid_feats = x.permute(1, 0, 2)  # LND -> NLD (N, 197, 768)
+             grid_feats = self.encoder.visual.ln_post(grid_feats[:, 1:])
+
+         return grid_feats.float()
+
+
+ def top_filtering(logits, top_k=0., top_p=0.9, threshold=-float('Inf'), filter_value=-float('Inf')):
+
+     assert logits.dim() == 1  # only works for batch size 1 for now - could update, but it would obfuscate the code a bit
+     top_k = min(top_k, logits.size(-1))
+     if top_k > 0:
+         # Remove all tokens with a probability less than the last token in the top-k tokens
+         indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+         logits[indices_to_remove] = filter_value
+
+     if top_p > 0.0:
+         # Compute cumulative probabilities of sorted tokens
+         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+         cumulative_probabilities = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+         # Remove tokens with cumulative probability above the threshold
+         sorted_indices_to_remove = cumulative_probabilities > top_p
+         # Shift the indices to the right to also keep the first token above the threshold
+         sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+         sorted_indices_to_remove[..., 0] = 0
+
+         # Back to unsorted indices and set them to -infinity
+         indices_to_remove = sorted_indices[sorted_indices_to_remove]
+         logits[indices_to_remove] = filter_value
+
+     indices_to_remove = logits < threshold
+     logits[indices_to_remove] = filter_value
+
+     return logits
+
+
+ def sample_sequences(img, model, input_ids, segment_ids, tokenizer):
+
+     SPECIAL_TOKENS = ['<|endoftext|>', '<pad>', '<question>', '<answer>', '<explanation>']
+     special_tokens_ids = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)
+     because_token = tokenizer.convert_tokens_to_ids('Ġbecause')
+     max_len = 20
+     current_output = []
+     img_embeddings = image_encoder(img)
+     always_exp = False
+
+     with torch.no_grad():
+
+         for step in range(max_len + 1):
+
+             if step == max_len:
+                 break
+
+             outputs = model(input_ids=input_ids,
+                             past_key_values=None,
+                             attention_mask=None,
+                             token_type_ids=segment_ids,
+                             position_ids=None,
+                             encoder_hidden_states=img_embeddings,
+                             encoder_attention_mask=None,
+                             labels=None,
+                             use_cache=False,
+                             output_attentions=True,
+                             return_dict=True)
+
+             lm_logits = outputs.logits
+             xa_maps = outputs.cross_attentions
+             logits = lm_logits[0, -1, :] / temperature
+             logits = top_filtering(logits, top_k=top_k, top_p=top_p)
+             probs = F.softmax(logits, dim=-1)
+             prev = torch.topk(probs, 1)[1] if no_sample else torch.multinomial(probs, 1)
+
+             if prev.item() in special_tokens_ids:
+                 break
+
+             # take care of when to start the <explanation> token. Nasty code in here (I hate lots of ifs)
+             if not always_exp:
+
+                 if prev.item() != because_token:
+                     new_segment = special_tokens_ids[-2]  # answer segment
+                 else:
+                     new_segment = special_tokens_ids[-1]  # explanation segment
+                     always_exp = True
+             else:
+                 new_segment = special_tokens_ids[-1]  # explanation segment
+
+             new_segment = torch.LongTensor([new_segment]).to(device)
+             current_output.append(prev.item())
+             input_ids = torch.cat((input_ids, prev.unsqueeze(0)), dim=1)
+             segment_ids = torch.cat((segment_ids, new_segment.unsqueeze(0)), dim=1)
+
+     decoded_sequences = tokenizer.decode(current_output, skip_special_tokens=True).lstrip()
+
+     return decoded_sequences, xa_maps
+
+ img_size = 224
+ ckpt_path = 'VQAX_p/'
+ max_seq_len = 40
+ load_from_epoch = 11
+ no_sample = True  # setting this to False will greatly reduce the evaluation scores, be careful!
+ top_k = 0
+ top_p = 0.9
+ temperature = 1
+
+ image_encoder = ImageEncoder().to(device)
+ change_requires_grad(image_encoder, False)
+ tokenizer, model = load_checkpoint(ckpt_path, load_from_epoch)
+ model.eval()
+
+
+ img_transform = transforms.Compose([transforms.Resize((img_size, img_size)),
+                                     transforms.ToTensor(),
+                                     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
+
+ def get_inputs(text, tokenizer):
+     q_segment_id, a_segment_id, e_segment_id = tokenizer.convert_tokens_to_ids(['<question>', '<answer>', '<explanation>'])
+     tokens = tokenizer.tokenize(text)
+     segment_ids = [q_segment_id] * len(tokens)
+     answer = [tokenizer.bos_token] + tokenizer.tokenize(" the answer is")
+     answer_len = len(answer)
+     tokens += answer
+     segment_ids += [a_segment_id] * answer_len
+     input_ids = tokenizer.convert_tokens_to_ids(tokens)
+     input_ids = torch.tensor(input_ids, dtype=torch.long)
+     segment_ids = torch.tensor(segment_ids, dtype=torch.long)
+     return input_ids.unsqueeze(0).to(device), segment_ids.unsqueeze(0).to(device)
+
+ def inference(raw_image, question):
+
+     oimg = raw_image.convert('RGB').resize((224, 224))
+     img = img_transform(oimg).unsqueeze(0).to(device)
+     text = proc_ques(question)
+     input_ids, segment_ids = get_inputs(text, tokenizer)
+     question_len = len(tokenizer.convert_ids_to_tokens(input_ids[0]))
+     seq, xa_maps = sample_sequences(img, model, input_ids, segment_ids, tokenizer)
+     # visual explanation: last-layer cross-attention, averaged over heads, at the first generated token
+     last_am = xa_maps[-1].mean(1)[0, question_len:]
+     mask = last_am[0, :].reshape(14, 14).cpu().numpy()
+     mask = cv2.resize(mask / mask.max(), oimg.size)[..., np.newaxis]
+     attention_map = (mask * oimg).astype("uint8")
+     # split the generated sequence into the answer and the textual explanation
+     splitted_seq = seq.split("because")
+     return splitted_seq[0].strip(), "because " + splitted_seq[-1].strip(), Image.fromarray(attention_map)
+
+ inputs = [gr.inputs.Image(type='pil', label="Load the image of your interest"), gr.inputs.Textbox(label="Ask a question about this image")]
+ outputs = [gr.outputs.Textbox(label="Answer"), gr.outputs.Textbox(label="Textual Explanation"), gr.outputs.Image(type='pil', label="Visual Explanation")]
+
+ title = "NLX-GPT: Explanations with Natural Text (Visual Question Answering Demo)"
+ gr.Interface(inference, inputs, outputs, title=title).launch()