JustinLin610 committed
Commit 9eb2477
1 Parent(s): 08374eb

remove unnecessary eval functions

Files changed (1)
  1. utils/eval_utils.py +1 -331
utils/eval_utils.py CHANGED
@@ -33,32 +33,6 @@ def decode_fn(x, tgt_dict, bpe, generator, tokenizer=None):
     return x
 
 
-def eval_caption(task, generator, models, sample, **kwargs):
-    transtab = str.maketrans({key: None for key in string.punctuation})
-    hypos = task.inference_step(generator, models, sample)
-    results = []
-    for i, sample_id in enumerate(sample["id"].tolist()):
-        detok_hypo_str = decode_fn(hypos[i][0]["tokens"], task.tgt_dict, task.bpe, generator)
-        results.append({"image_id": str(sample_id), "caption": detok_hypo_str.translate(transtab).strip()})
-    return results, None
-
-
-def eval_caption_cn(task, generator, models, sample, **kwargs):
-    hypos = task.inference_step(generator, models, sample)
-    results = []
-    for i, sample_id in enumerate(sample["id"].tolist()):
-        detok_hypo_str = decode_fn(
-            hypos[i][0]["tokens"], task.tgt_dict, task.bpe, generator
-        )
-        results.append(
-            {
-                "image_id": str(sample_id),
-                "caption": detok_hypo_str.strip(),
-            }
-        )
-    return results, None
-
-
 def eval_ocr(task, generator, models, sample, **kwargs):
     gen_out = task.inference_step(generator, models, sample)
     hyps, refs, results = [], [], []
@@ -88,312 +62,8 @@ def eval_ocr(task, generator, models, sample, **kwargs):
     return results, acc
 
 
-def eval_vqa_gen(task, generator, models, sample, **kwargs):
-    if kwargs['beam_search_vqa_eval']:
-        hypos = task.inference_step(generator, models, sample, prefix_tokens=sample['prefix_tokens'])
-        results = []
-        for i, sample_id in enumerate(sample["id"].tolist()):
-            prefix_len = sample['prefix_tokens'][i].ne(1).sum().item()
-            detok_hypo_str = decode_fn(hypos[i][0]["tokens"][prefix_len:], task.tgt_dict, task.bpe, generator)
-            results.append({"question_id": int(sample_id), "answer": detok_hypo_str.strip()})
-        scores = [ref_dict.get(result['answer'], 0) for ref_dict, result in zip(sample['ref_dict'], results)]
-        return results, scores
-
-    encoder_out = models[0].encoder(
-        sample["net_input"]["src_tokens"],
-        src_lengths=sample["net_input"]["src_lengths"],
-        patch_images=sample["net_input"]["patch_images"],
-        patch_masks=sample["net_input"]["patch_masks"]
-    )
-    device = sample["net_input"]["src_tokens"].device
-    eos_item = torch.tensor([task.src_dict.eos()])
-    pad = task.src_dict.pad()
-    valid_result = []
-    for valid_answers, valid_constraint_masks in zip(task.valid_answers_list, task.valid_constraint_masks_list):
-        valid_size = len(valid_answers)
-        valid_tgt_items = [
-            torch.cat([torch.tensor(decoder_prompt[1:]), valid_answer, eos_item])
-            for decoder_prompt in sample["decoder_prompts"] for valid_answer in valid_answers
-        ]
-        valid_prev_items = [
-            torch.cat([torch.tensor(decoder_prompt), valid_answer])
-            for decoder_prompt in sample["decoder_prompts"] for valid_answer in valid_answers
-        ]
-        valid_constraint_mask_items = [
-            torch.cat(
-                [torch.zeros(len(decoder_prompt) - 1, valid_constraint_mask.size(1)).bool(), valid_constraint_mask],
-                dim=0
-            )
-            for decoder_prompt in sample["decoder_prompts"] for valid_constraint_mask in valid_constraint_masks
-        ]
-        valid_tgt = data_utils.collate_tokens(valid_tgt_items, pad_idx=pad).to(device)
-        valid_prev_output = data_utils.collate_tokens(valid_prev_items, pad_idx=pad).to(device)
-        valid_constraint_masks = data_utils.collate_tokens(valid_constraint_mask_items, pad_idx=pad).to(device)
-
-        new_encoder_out = {}
-        new_encoder_out["encoder_out"] = [
-            encoder_out["encoder_out"][0].repeat_interleave(valid_size, dim=1)
-        ]
-        new_encoder_out["encoder_padding_mask"] = [
-            encoder_out["encoder_padding_mask"][0].repeat_interleave(valid_size, dim=0)
-        ]
-        new_encoder_out["position_embeddings"] = [
-            encoder_out["position_embeddings"][0].repeat_interleave(valid_size, dim=0)
-        ]
-
-        decoder_out = models[0].decoder(valid_prev_output, encoder_out=new_encoder_out)
-        decoder_out[0].masked_fill_(~valid_constraint_masks, -math.inf)
-        lprobs = models[0].get_normalized_probs(decoder_out, log_probs=True)
-        scores = lprobs.gather(dim=-1, index=valid_tgt.unsqueeze(-1)).squeeze(-1)
-        scores = scores.masked_fill(valid_tgt.eq(task.tgt_dict.pad()), 0)
-        scores = scores.masked_fill((~valid_constraint_masks).all(2), 0)
-        scores = scores.sum(1)
-        scores = scores.view(-1, valid_size)
-        valid_result.append(scores)
-    valid_result = torch.cat(valid_result, dim=-1)
-    predicts = valid_result.argmax(1).tolist()
-    hyps = [task.index2ans[predict_index] for predict_index in predicts]
-    results = [{"question_id": int(id), "answer": hyp} for id, hyp in zip(sample["id"].tolist(), hyps)]
-    scores = [ref_dict.get(hyp, 0) for ref_dict, hyp in zip(sample['ref_dict'], hyps)]
-    return results, scores
-
-
-def eval_refcoco(task, generator, models, sample, **kwargs):
-    def _calculate_ap_score(hyps, refs, thresh=0.5):
-        interacts = torch.cat(
-            [torch.where(hyps[:, :2] < refs[:, :2], refs[:, :2], hyps[:, :2]),
-             torch.where(hyps[:, 2:] < refs[:, 2:], hyps[:, 2:], refs[:, 2:])],
-            dim=1
-        )
-        area_predictions = (hyps[:, 2] - hyps[:, 0]) * (hyps[:, 3] - hyps[:, 1])
-        area_targets = (refs[:, 2] - refs[:, 0]) * (refs[:, 3] - refs[:, 1])
-        interacts_w = interacts[:, 2] - interacts[:, 0]
-        interacts_h = interacts[:, 3] - interacts[:, 1]
-        area_interacts = interacts_w * interacts_h
-        ious = area_interacts / (area_predictions + area_targets - area_interacts + 1e-6)
-        return ((ious >= thresh) & (interacts_w > 0) & (interacts_h > 0)).float()
-
-    gen_out = task.inference_step(generator, models, sample)
-    hyps = []
-    for i in range(len(gen_out)):
-        hyps.append(gen_out[i][0]["tokens"][:-1] - len(task.src_dict) + task.cfg.num_bins)
-    hyps = torch.stack(hyps, dim=0)
-    hyps = hyps / (task.cfg.num_bins - 1) * task.cfg.max_image_size
-    hyps[:, ::2] /= sample['w_resize_ratios'].unsqueeze(1)
-    hyps[:, 1::2] /= sample['h_resize_ratios'].unsqueeze(1)
-
-    results = [
-        {"uniq_id": sample_id,
-         "box": [hyps[i][0].item(), hyps[i][1].item(), hyps[i][2].item(), hyps[i][3].item()]}
-        for i, sample_id in enumerate(sample["id"].tolist())
-    ]
-    scores = _calculate_ap_score(hyps, sample['region_coords'].float())
-    return results, scores
-
-
-def eval_snli_ve(task, generator, models, sample, **kwargs):
-    encoder_out = models[0].encoder(
-        sample["net_input"]["src_tokens"],
-        src_lengths=sample["net_input"]["src_lengths"],
-        patch_images=sample["net_input"]["patch_images"],
-        patch_masks=sample["net_input"]["patch_masks"]
-    )
-    device = sample["net_input"]["src_tokens"].device
-    eos_item = torch.tensor([task.src_dict.eos()])
-    pad = task.src_dict.pad()
-    valid_result = []
-    for valid_answers, valid_constraint_masks in zip(task.valid_answers_list, task.valid_constraint_masks_list):
-        valid_size = len(valid_answers)
-        valid_tgt_items = [
-            torch.cat([torch.tensor(decoder_prompt[1:]), valid_answer, eos_item])
-            for decoder_prompt in sample["decoder_prompts"] for valid_answer in valid_answers
-        ]
-        valid_prev_items = [
-            torch.cat([torch.tensor(decoder_prompt), valid_answer])
-            for decoder_prompt in sample["decoder_prompts"] for valid_answer in valid_answers
-        ]
-        valid_constraint_mask_items = [
-            torch.cat(
-                [torch.zeros(len(decoder_prompt) - 1, valid_constraint_mask.size(1)).bool(), valid_constraint_mask],
-                dim=0
-            )
-            for decoder_prompt in sample["decoder_prompts"] for valid_constraint_mask in valid_constraint_masks
-        ]
-        valid_tgt = data_utils.collate_tokens(valid_tgt_items, pad_idx=pad).to(device)
-        valid_prev_output = data_utils.collate_tokens(valid_prev_items, pad_idx=pad).to(device)
-        valid_constraint_masks = data_utils.collate_tokens(valid_constraint_mask_items, pad_idx=pad).to(device)
-
-        new_encoder_out = {}
-        new_encoder_out["encoder_out"] = [
-            encoder_out["encoder_out"][0].repeat_interleave(valid_size, dim=1)
-        ]
-        new_encoder_out["encoder_padding_mask"] = [
-            encoder_out["encoder_padding_mask"][0].repeat_interleave(valid_size, dim=0)
-        ]
-        new_encoder_out["position_embeddings"] = [
-            encoder_out["position_embeddings"][0].repeat_interleave(valid_size, dim=0)
-        ]
-
-        decoder_out = models[0].decoder(valid_prev_output, encoder_out=new_encoder_out)
-        decoder_out[0].masked_fill_(~valid_constraint_masks, -math.inf)
-        lprobs = models[0].get_normalized_probs(decoder_out, log_probs=True)
-        scores = lprobs.gather(dim=-1, index=valid_tgt.unsqueeze(-1)).squeeze(-1)
-        scores = scores.masked_fill(valid_tgt.eq(task.tgt_dict.pad()), 0)
-        scores = scores.masked_fill((~valid_constraint_masks).all(2), 0)
-        scores = scores.sum(1)
-        scores = scores.view(-1, valid_size)
-        valid_result.append(scores)
-    valid_result = torch.cat(valid_result, dim=-1)
-    predicts = valid_result.argmax(1).tolist()
-    hyps = [task.index2ans[predict_index] for predict_index in predicts]
-    results = [{"uniq_id": id, "answer": hyp} for id, hyp in zip(sample["id"].tolist(), hyps)]
-    scores = [ref_dict.get(hyp, 0) for ref_dict, hyp in zip(sample['ref_dict'], hyps)]
-    return results, scores
-
-
-def eval_image_gen(task, generator, models, sample, **kwargs):
-    hypos, _ = task.inference_image(generator, sample, models)
-    tokens = sample['net_input']['src_tokens'][0].view(-1).tolist()
-    caption = task.bpe.decode(task.tgt_dict.string([token for token in tokens if token >= 4]))[
-        38:].replace('/', '')
-
-    text_similarity_score, indices = task.compute_text_similarity(hypos, caption,
-                                                                  sample['net_input']['src_tokens'].device)
-    results = []
-    for i, indice in enumerate(indices):
-        results.append({"sample_id": str(sample["id"][0]), "score": text_similarity_score[i], "image": hypos[indice]})
-    scores = [max(text_similarity_score).item()]
-    sorted_hyps = [hypos[indice] for indice in indices]
-    # dump results
-    if task.cfg.gen_images_path:
-        caption_tokens = sample['net_input']['src_tokens'][0].view(-1).tolist()
-        caption = task.bpe.decode(task.tgt_dict.string([token for token in caption_tokens if token >= 4]))[
-            38:].replace('/', '')
-        task.dump_images(sorted_hyps, text=caption, path=os.path.join(task.cfg.gen_images_path, 'all_results'))
-        task.dump_images(sorted_hyps, text=caption, path=os.path.join(task.cfg.gen_images_path, 'top1'), topk=1)
-
-    return results, scores
-
-
-def eval_glue(task, generator, models, sample, **kwargs):
-    net_output = models[0](**sample["net_input"])
-    net_output[0].masked_fill_(~sample["constraint_masks"], -math.inf)
-    last_token_ids = sample["net_input"]["prev_output_tokens"].ne(task.src_dict.pad()).sum(1, keepdim=True) - 1
-    logits = net_output[0].gather(1, last_token_ids.unsqueeze(2).expand(-1, -1, net_output[0].size(2)))
-    logits = logits.squeeze(1)
-    predicts = logits.argmax(1).tolist()
-    hyps = [task.bpe.decode(task.src_dict[predict]).strip() for predict in predicts]
-    results = [{"hyp": hyp, "ref": ref_dict.keys()[0]} for hyp, ref_dict in zip(hyps, sample['ref_dict'])]
-    return results, None
-
-
-def eval_gigaword(task, generator, models, sample, **kwargs):
-    gen_out = task.inference_step(generator, models, sample)
-    hyps, refs = [], []
-    results = []
-    for i in range(len(gen_out)):
-        hyp = decode_fn(gen_out[i][0]["tokens"], task.tgt_dict, task.bpe, generator).lower().strip()
-        hyp = fix_tokenization(hyp).replace('1', '#')
-        ref = sample['target_strs'][i]
-        hyps.append(hyp)
-        refs.append(ref)
-        results.append({"hyp": hyp, "ref": ref})
-    return results, None
-
-
-def eval_image_classify(task, generator, models, sample, **kwargs):
-    batch_size = sample["net_input"]["src_tokens"].size(0)
-    encoder_out = models[0].encoder(
-        sample["net_input"]["src_tokens"],
-        src_lengths=sample["net_input"]["src_lengths"],
-        patch_images=sample["net_input"]["patch_images"],
-        patch_masks=sample["net_input"]["patch_masks"]
-    )
-    device = sample["net_input"]["src_tokens"].device
-    valid_result = []
-    for valid_tgt, valid_prev_output, valid_constraint_masks in zip(task.valid_tgt_list,
-                                                                    task.valid_prev_output_list,
-                                                                    task.valid_constraint_masks_list):
-        valid_tgt_size = valid_tgt.size(0)
-        valid_tgt = valid_tgt.repeat(batch_size, 1).to(device)
-        valid_prev_output = valid_prev_output.repeat(batch_size, 1).to(device)
-        valid_constraint_masks = valid_constraint_masks.repeat(batch_size, 1, 1).to(device)
-        new_encoder_out = {}
-        new_encoder_out["encoder_out"] = [
-            encoder_out["encoder_out"][0].repeat_interleave(valid_tgt_size, dim=1)
-        ]
-        new_encoder_out["encoder_padding_mask"] = [
-            encoder_out["encoder_padding_mask"][0].repeat_interleave(valid_tgt_size, dim=0)
-        ]
-        new_encoder_out["position_embeddings"] = [
-            encoder_out["position_embeddings"][0].repeat_interleave(valid_tgt_size, dim=0)
-        ]
-
-        decoder_out = models[0].decoder(valid_prev_output, encoder_out=new_encoder_out)
-        decoder_out[0].masked_fill_(~valid_constraint_masks, -math.inf)
-        lprobs = models[0].get_normalized_probs(decoder_out, log_probs=True)
-        scores = lprobs.gather(dim=-1, index=valid_tgt.unsqueeze(-1)).squeeze(-1)
-        scores = scores.masked_fill(valid_tgt.eq(task.tgt_dict.pad()), 0)
-        scores = scores.sum(1)
-        scores = scores.view(-1, valid_tgt_size)
-        valid_result.append(scores)
-    valid_result = torch.cat(valid_result, dim=-1)
-    predicts = valid_result.argmax(1).tolist()
-    hyps = [task.index2ans[predict_index] for predict_index in predicts]
-    scores = [ref_dict.get(hyp, 0) for ref_dict, hyp in zip(sample['ref_dict'], hyps)]
-    results = [{"uniq_id": id, "answer": hyp} for id, hyp in zip(sample["id"].tolist(), hyps)]
-    return results, scores
-
-
 def eval_step(task, generator, models, sample, **kwargs):
-    if task.cfg._name == 'caption':
-        return eval_caption(task, generator, models, sample, **kwargs)
-    elif task.cfg._name == "caption_cn":
-        return eval_caption_cn(task, generator, models, sample, **kwargs)
-    elif task.cfg._name == "ocr":
+    if task.cfg._name == "ocr":
         return eval_ocr(task, generator, models, sample, **kwargs)
-    elif task.cfg._name == 'vqa_gen':
-        return eval_vqa_gen(task, generator, models, sample, **kwargs)
-    elif task.cfg._name == 'refcoco':
-        return eval_refcoco(task, generator, models, sample, **kwargs)
-    elif task.cfg._name == 'snli_ve':
-        return eval_snli_ve(task, generator, models, sample, **kwargs)
-    elif task.cfg._name == 'image_gen':
-        return eval_image_gen(task, generator, models, sample, **kwargs)
-    elif task.cfg._name in {'cola', 'mnli', 'mrpc', 'qnli', 'qqp', 'rte', 'sst2'}:
-        return eval_glue(task, generator, models, sample, **kwargs)
-    elif task.cfg._name == 'gigaword':
-        return eval_gigaword(task, generator, models, sample, **kwargs)
-    elif task.cfg._name == 'image_classify':
-        return eval_image_classify(task, generator, models, sample, **kwargs)
     else:
         raise NotImplementedError
-
-
-def merge_results(task, cfg, logger, score_cnt, score_sum, results):
-    if task.cfg._name == 'image_gen':
-        if cfg.distributed_training.distributed_world_size > 1:
-            dist.all_reduce(score_sum.data)
-            dist.all_reduce(score_cnt.data)
-        if score_cnt.item() > 0:
-            logger.info("score_sum: {}, score_cnt: {}, score: {}".format(
-                score_sum, score_cnt, round(score_sum.item() / score_cnt.item(), 4)
-            ))
-    else:
-        gather_results = None
-        if cfg.distributed_training.distributed_world_size > 1:
-            gather_results = [None for _ in range(dist.get_world_size())]
-            dist.all_gather_object(gather_results, results)
-            dist.all_reduce(score_sum.data)
-            dist.all_reduce(score_cnt.data)
-        if score_cnt.item() > 0:
-            logger.info("score_sum: {}, score_cnt: {}, score: {}".format(
-                score_sum, score_cnt, round(score_sum.item() / score_cnt.item(), 4)
-            ))
-
-        if cfg.distributed_training.distributed_world_size == 1 or dist.get_rank() == 0:
-            os.makedirs(cfg.common_eval.results_path, exist_ok=True)
-            output_path = os.path.join(cfg.common_eval.results_path, "{}_predict.json".format(cfg.dataset.gen_subset))
-            gather_results = list(chain(*gather_results)) if gather_results is not None else results
-            with open(output_path, 'w') as fw:
-                json.dump(gather_results, fw)
 
33
  return x
34
 
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def eval_ocr(task, generator, models, sample, **kwargs):
37
  gen_out = task.inference_step(generator, models, sample)
38
  hyps, refs, results = [], [], []
 
62
  return results, acc
63
 
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  def eval_step(task, generator, models, sample, **kwargs):
66
+ if task.cfg._name == "ocr":
 
 
 
 
67
  return eval_ocr(task, generator, models, sample, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  else:
69
  raise NotImplementedError