HarryLee committed on
Commit
6710444
1 Parent(s): ec234f9

Update files

Browse files
Files changed (3) hide show
  1. utils/eval_utils.py +3 -313
  2. utils/transforms.py +0 -5
  3. utils/trie.py +1 -6
utils/eval_utils.py CHANGED
@@ -1,19 +1,9 @@
1
- # Copyright 2022 The OFA-Sys Team.
2
- # All rights reserved.
3
- # This source code is licensed under the Apache 2.0 license
4
- # found in the LICENSE file in the root directory.
5
-
6
  import string
7
  import math
8
- import json
9
- from itertools import chain
10
- import os
11
 
12
  import torch
13
- import torch.distributed as dist
14
 
15
  from data import data_utils
16
- from tasks.nlg_tasks.gigaword import fix_tokenization
17
 
18
 
19
  def get_symbols_to_strip_from_output(generator):
@@ -32,7 +22,7 @@ def decode_fn(x, tgt_dict, bpe, generator, tokenizer=None):
32
  return x
33
 
34
 
35
- def eval_caption(task, generator, models, sample, **kwargs):
36
  transtab = str.maketrans({key: None for key in string.punctuation})
37
  hypos = task.inference_step(generator, models, sample)
38
  results = []
@@ -42,308 +32,8 @@ def eval_caption(task, generator, models, sample, **kwargs):
42
  return results, None
43
 
44
 
45
- def eval_vqa_gen(task, generator, models, sample, **kwargs):
46
- if kwargs['beam_search_vqa_eval']:
47
- hypos = task.inference_step(generator, models, sample, prefix_tokens=sample['prefix_tokens'])
48
- results = []
49
- for i, sample_id in enumerate(sample["id"].tolist()):
50
- prefix_len = sample['prefix_tokens'][i].ne(1).sum().item()
51
- detok_hypo_str = decode_fn(hypos[i][0]["tokens"][prefix_len:], task.tgt_dict, task.bpe, generator)
52
- results.append({"question_id": int(sample_id), "answer": detok_hypo_str.strip()})
53
- scores = [ref_dict.get(result['answer'], 0) for ref_dict, result in zip(sample['ref_dict'], results)]
54
- return results, scores
55
-
56
- encoder_out = models[0].encoder(
57
- sample["net_input"]["src_tokens"],
58
- src_lengths=sample["net_input"]["src_lengths"],
59
- patch_images=sample["net_input"]["patch_images"],
60
- patch_masks=sample["net_input"]["patch_masks"]
61
- )
62
- device = sample["net_input"]["src_tokens"].device
63
- eos_item = torch.tensor([task.src_dict.eos()])
64
- pad = task.src_dict.pad()
65
- valid_result = []
66
- for valid_answers, valid_constraint_masks in zip(task.valid_answers_list, task.valid_constraint_masks_list):
67
- valid_size = len(valid_answers)
68
- valid_tgt_items = [
69
- torch.cat([torch.tensor(decoder_prompt[1:]), valid_answer, eos_item])
70
- for decoder_prompt in sample["decoder_prompts"] for valid_answer in valid_answers
71
- ]
72
- valid_prev_items = [
73
- torch.cat([torch.tensor(decoder_prompt), valid_answer])
74
- for decoder_prompt in sample["decoder_prompts"] for valid_answer in valid_answers
75
- ]
76
- valid_constraint_mask_items = [
77
- torch.cat(
78
- [torch.zeros(len(decoder_prompt) - 1, valid_constraint_mask.size(1)).bool(), valid_constraint_mask],
79
- dim=0
80
- )
81
- for decoder_prompt in sample["decoder_prompts"] for valid_constraint_mask in valid_constraint_masks
82
- ]
83
- valid_tgt = data_utils.collate_tokens(valid_tgt_items, pad_idx=pad).to(device)
84
- valid_prev_output = data_utils.collate_tokens(valid_prev_items, pad_idx=pad).to(device)
85
- valid_constraint_masks = data_utils.collate_tokens(valid_constraint_mask_items, pad_idx=pad).to(device)
86
-
87
- new_encoder_out = {}
88
- new_encoder_out["encoder_out"] = [
89
- encoder_out["encoder_out"][0].repeat_interleave(valid_size, dim=1)
90
- ]
91
- new_encoder_out["encoder_padding_mask"] = [
92
- encoder_out["encoder_padding_mask"][0].repeat_interleave(valid_size, dim=0)
93
- ]
94
- new_encoder_out["position_embeddings"] = [
95
- encoder_out["position_embeddings"][0].repeat_interleave(valid_size, dim=0)
96
- ]
97
-
98
- decoder_out = models[0].decoder(valid_prev_output, encoder_out=new_encoder_out)
99
- decoder_out[0].masked_fill_(~valid_constraint_masks, -math.inf)
100
- lprobs = models[0].get_normalized_probs(decoder_out, log_probs=True)
101
- scores = lprobs.gather(dim=-1, index=valid_tgt.unsqueeze(-1)).squeeze(-1)
102
- scores = scores.masked_fill(valid_tgt.eq(task.tgt_dict.pad()), 0)
103
- scores = scores.masked_fill((~valid_constraint_masks).all(2), 0)
104
- scores = scores.sum(1)
105
- scores = scores.view(-1, valid_size)
106
- valid_result.append(scores)
107
- valid_result = torch.cat(valid_result, dim=-1)
108
- predicts = valid_result.argmax(1).tolist()
109
- hyps = [task.index2ans[predict_index] for predict_index in predicts]
110
- results = [{"question_id": int(id), "answer": hyp} for id, hyp in zip(sample["id"].tolist(), hyps)]
111
- scores = [ref_dict.get(hyp, 0) for ref_dict, hyp in zip(sample['ref_dict'], hyps)]
112
- return results, scores
113
-
114
-
115
- def eval_refcoco(task, generator, models, sample, **kwargs):
116
- def _calculate_ap_score(hyps, refs, thresh=0.5):
117
- interacts = torch.cat(
118
- [torch.where(hyps[:, :2] < refs[:, :2], refs[:, :2], hyps[:, :2]),
119
- torch.where(hyps[:, 2:] < refs[:, 2:], hyps[:, 2:], refs[:, 2:])],
120
- dim=1
121
- )
122
- area_predictions = (hyps[:, 2] - hyps[:, 0]) * (hyps[:, 3] - hyps[:, 1])
123
- area_targets = (refs[:, 2] - refs[:, 0]) * (refs[:, 3] - refs[:, 1])
124
- interacts_w = interacts[:, 2] - interacts[:, 0]
125
- interacts_h = interacts[:, 3] - interacts[:, 1]
126
- area_interacts = interacts_w * interacts_h
127
- ious = area_interacts / (area_predictions + area_targets - area_interacts + 1e-6)
128
- return ((ious >= thresh) & (interacts_w > 0) & (interacts_h > 0)).float()
129
-
130
- gen_out = task.inference_step(generator, models, sample)
131
- hyps = []
132
- for i in range(len(gen_out)):
133
- hyps.append(gen_out[i][0]["tokens"][:-1] - len(task.src_dict) + task.cfg.num_bins)
134
- hyps = torch.stack(hyps, dim=0)
135
- hyps = hyps / (task.cfg.num_bins - 1) * task.cfg.max_image_size
136
- hyps[:, ::2] /= sample['w_resize_ratios'].unsqueeze(1)
137
- hyps[:, 1::2] /= sample['h_resize_ratios'].unsqueeze(1)
138
-
139
- results = [
140
- {"uniq_id": sample_id,
141
- "box": [hyps[i][0].item(), hyps[i][1].item(), hyps[i][2].item(), hyps[i][3].item()]}
142
- for i, sample_id in enumerate(sample["id"].tolist())
143
- ]
144
- scores = _calculate_ap_score(hyps, sample['region_coords'].float())
145
- return results, scores
146
-
147
-
148
- def eval_snli_ve(task, generator, models, sample, **kwargs):
149
- encoder_out = models[0].encoder(
150
- sample["net_input"]["src_tokens"],
151
- src_lengths=sample["net_input"]["src_lengths"],
152
- patch_images=sample["net_input"]["patch_images"],
153
- patch_masks=sample["net_input"]["patch_masks"]
154
- )
155
- device = sample["net_input"]["src_tokens"].device
156
- eos_item = torch.tensor([task.src_dict.eos()])
157
- pad = task.src_dict.pad()
158
- valid_result = []
159
- for valid_answers, valid_constraint_masks in zip(task.valid_answers_list, task.valid_constraint_masks_list):
160
- valid_size = len(valid_answers)
161
- valid_tgt_items = [
162
- torch.cat([torch.tensor(decoder_prompt[1:]), valid_answer, eos_item])
163
- for decoder_prompt in sample["decoder_prompts"] for valid_answer in valid_answers
164
- ]
165
- valid_prev_items = [
166
- torch.cat([torch.tensor(decoder_prompt), valid_answer])
167
- for decoder_prompt in sample["decoder_prompts"] for valid_answer in valid_answers
168
- ]
169
- valid_constraint_mask_items = [
170
- torch.cat(
171
- [torch.zeros(len(decoder_prompt) - 1, valid_constraint_mask.size(1)).bool(), valid_constraint_mask],
172
- dim=0
173
- )
174
- for decoder_prompt in sample["decoder_prompts"] for valid_constraint_mask in valid_constraint_masks
175
- ]
176
- valid_tgt = data_utils.collate_tokens(valid_tgt_items, pad_idx=pad).to(device)
177
- valid_prev_output = data_utils.collate_tokens(valid_prev_items, pad_idx=pad).to(device)
178
- valid_constraint_masks = data_utils.collate_tokens(valid_constraint_mask_items, pad_idx=pad).to(device)
179
-
180
- new_encoder_out = {}
181
- new_encoder_out["encoder_out"] = [
182
- encoder_out["encoder_out"][0].repeat_interleave(valid_size, dim=1)
183
- ]
184
- new_encoder_out["encoder_padding_mask"] = [
185
- encoder_out["encoder_padding_mask"][0].repeat_interleave(valid_size, dim=0)
186
- ]
187
- new_encoder_out["position_embeddings"] = [
188
- encoder_out["position_embeddings"][0].repeat_interleave(valid_size, dim=0)
189
- ]
190
-
191
- decoder_out = models[0].decoder(valid_prev_output, encoder_out=new_encoder_out)
192
- decoder_out[0].masked_fill_(~valid_constraint_masks, -math.inf)
193
- lprobs = models[0].get_normalized_probs(decoder_out, log_probs=True)
194
- scores = lprobs.gather(dim=-1, index=valid_tgt.unsqueeze(-1)).squeeze(-1)
195
- scores = scores.masked_fill(valid_tgt.eq(task.tgt_dict.pad()), 0)
196
- scores = scores.masked_fill((~valid_constraint_masks).all(2), 0)
197
- scores = scores.sum(1)
198
- scores = scores.view(-1, valid_size)
199
- valid_result.append(scores)
200
- valid_result = torch.cat(valid_result, dim=-1)
201
- predicts = valid_result.argmax(1).tolist()
202
- hyps = [task.index2ans[predict_index] for predict_index in predicts]
203
- results = [{"uniq_id": id, "answer": hyp} for id, hyp in zip(sample["id"].tolist(), hyps)]
204
- scores = [ref_dict.get(hyp, 0) for ref_dict, hyp in zip(sample['ref_dict'], hyps)]
205
- return results, scores
206
-
207
-
208
- def eval_image_gen(task, generator, models, sample, **kwargs):
209
- hypos, _ = task.inference_image(generator, sample, models)
210
- tokens = sample['net_input']['src_tokens'][0].view(-1).tolist()
211
- caption = task.bpe.decode(task.tgt_dict.string([token for token in tokens if token >= 4]))[
212
- 38:].replace('/', '')
213
-
214
- text_similarity_score, indices = task.compute_text_similarity(hypos, caption,
215
- sample['net_input']['src_tokens'].device)
216
- results = []
217
- for i, indice in enumerate(indices):
218
- results.append({"sample_id": str(sample["id"][0]), "score": text_similarity_score[i], "image": hypos[indice]})
219
- scores = [max(text_similarity_score).item()]
220
- sorted_hyps = [hypos[indice] for indice in indices]
221
- # dump results
222
- if task.cfg.gen_images_path:
223
- caption_tokens = sample['net_input']['src_tokens'][0].view(-1).tolist()
224
- caption = task.bpe.decode(task.tgt_dict.string([token for token in caption_tokens if token >= 4]))[
225
- 38:].replace('/', '')
226
- task.dump_images(sorted_hyps, text=caption, path=os.path.join(task.cfg.gen_images_path, 'all_results'))
227
- task.dump_images(sorted_hyps, text=caption, path=os.path.join(task.cfg.gen_images_path, 'top1'), topk=1)
228
-
229
- return results, scores
230
-
231
-
232
- def eval_glue(task, generator, models, sample, **kwargs):
233
- net_output = models[0](**sample["net_input"])
234
- net_output[0].masked_fill_(~sample["constraint_masks"], -math.inf)
235
- last_token_ids = sample["net_input"]["prev_output_tokens"].ne(task.src_dict.pad()).sum(1, keepdim=True) - 1
236
- logits = net_output[0].gather(1, last_token_ids.unsqueeze(2).expand(-1, -1, net_output[0].size(2)))
237
- logits = logits.squeeze(1)
238
- predicts = logits.argmax(1).tolist()
239
- hyps = [task.bpe.decode(task.src_dict[predict]).strip() for predict in predicts]
240
- results = [{"hyp": hyp, "ref": ref_dict.keys()[0]} for hyp, ref_dict in zip(hyps, sample['ref_dict'])]
241
- return results, None
242
-
243
-
244
- def eval_gigaword(task, generator, models, sample, **kwargs):
245
- gen_out = task.inference_step(generator, models, sample)
246
- hyps, refs = [], []
247
- results = []
248
- for i in range(len(gen_out)):
249
- hyp = decode_fn(gen_out[i][0]["tokens"], task.tgt_dict, task.bpe, generator).lower().strip()
250
- hyp = fix_tokenization(hyp).replace('1', '#')
251
- ref = sample['target_strs'][i]
252
- hyps.append(hyp)
253
- refs.append(ref)
254
- results.append({"hyp": hyp, "ref": ref})
255
- return results, None
256
-
257
-
258
- def eval_image_classify(task, generator, models, sample, **kwargs):
259
- batch_size = sample["net_input"]["src_tokens"].size(0)
260
- encoder_out = models[0].encoder(
261
- sample["net_input"]["src_tokens"],
262
- src_lengths=sample["net_input"]["src_lengths"],
263
- patch_images=sample["net_input"]["patch_images"],
264
- patch_masks=sample["net_input"]["patch_masks"]
265
- )
266
- device = sample["net_input"]["src_tokens"].device
267
- valid_result = []
268
- for valid_tgt, valid_prev_output, valid_constraint_masks in zip(task.valid_tgt_list,
269
- task.valid_prev_output_list,
270
- task.valid_constraint_masks_list):
271
- valid_tgt_size = valid_tgt.size(0)
272
- valid_tgt = valid_tgt.repeat(batch_size, 1).to(device)
273
- valid_prev_output = valid_prev_output.repeat(batch_size, 1).to(device)
274
- valid_constraint_masks = valid_constraint_masks.repeat(batch_size, 1, 1).to(device)
275
- new_encoder_out = {}
276
- new_encoder_out["encoder_out"] = [
277
- encoder_out["encoder_out"][0].repeat_interleave(valid_tgt_size, dim=1)
278
- ]
279
- new_encoder_out["encoder_padding_mask"] = [
280
- encoder_out["encoder_padding_mask"][0].repeat_interleave(valid_tgt_size, dim=0)
281
- ]
282
- new_encoder_out["position_embeddings"] = [
283
- encoder_out["position_embeddings"][0].repeat_interleave(valid_tgt_size, dim=0)
284
- ]
285
-
286
- decoder_out = models[0].decoder(valid_prev_output, encoder_out=new_encoder_out)
287
- decoder_out[0].masked_fill_(~valid_constraint_masks, -math.inf)
288
- lprobs = models[0].get_normalized_probs(decoder_out, log_probs=True)
289
- scores = lprobs.gather(dim=-1, index=valid_tgt.unsqueeze(-1)).squeeze(-1)
290
- scores = scores.masked_fill(valid_tgt.eq(task.tgt_dict.pad()), 0)
291
- scores = scores.sum(1)
292
- scores = scores.view(-1, valid_tgt_size)
293
- valid_result.append(scores)
294
- valid_result = torch.cat(valid_result, dim=-1)
295
- predicts = valid_result.argmax(1).tolist()
296
- hyps = [task.index2ans[predict_index] for predict_index in predicts]
297
- scores = [ref_dict.get(hyp, 0) for ref_dict, hyp in zip(sample['ref_dict'], hyps)]
298
- results = [{"uniq_id": id, "answer": hyp} for id, hyp in zip(sample["id"].tolist(), hyps)]
299
- return results, scores
300
-
301
-
302
- def eval_step(task, generator, models, sample, **kwargs):
303
  if task.cfg._name == 'caption':
304
- return eval_caption(task, generator, models, sample, **kwargs)
305
- elif task.cfg._name == 'vqa_gen':
306
- return eval_vqa_gen(task, generator, models, sample, **kwargs)
307
- elif task.cfg._name == 'refcoco':
308
- return eval_refcoco(task, generator, models, sample, **kwargs)
309
- elif task.cfg._name == 'snli_ve':
310
- return eval_snli_ve(task, generator, models, sample, **kwargs)
311
- elif task.cfg._name == 'image_gen':
312
- return eval_image_gen(task, generator, models, sample, **kwargs)
313
- elif task.cfg._name in {'cola', 'mnli', 'mrpc', 'qnli', 'qqp', 'rte', 'sst2'}:
314
- return eval_glue(task, generator, models, sample, **kwargs)
315
- elif task.cfg._name == 'gigaword':
316
- return eval_gigaword(task, generator, models, sample, **kwargs)
317
- elif task.cfg._name == 'image_classify':
318
- return eval_image_classify(task, generator, models, sample, **kwargs)
319
  else:
320
  raise NotImplementedError
321
-
322
-
323
- def merge_results(task, cfg, logger, score_cnt, score_sum, results):
324
- if task.cfg._name == 'image_gen':
325
- if cfg.distributed_training.distributed_world_size > 1:
326
- dist.all_reduce(score_sum.data)
327
- dist.all_reduce(score_cnt.data)
328
- if score_cnt.item() > 0:
329
- logger.info("score_sum: {}, score_cnt: {}, score: {}".format(
330
- score_sum, score_cnt, round(score_sum.item() / score_cnt.item(), 4)
331
- ))
332
- else:
333
- gather_results = None
334
- if cfg.distributed_training.distributed_world_size > 1:
335
- gather_results = [None for _ in range(dist.get_world_size())]
336
- dist.all_gather_object(gather_results, results)
337
- dist.all_reduce(score_sum.data)
338
- dist.all_reduce(score_cnt.data)
339
- if score_cnt.item() > 0:
340
- logger.info("score_sum: {}, score_cnt: {}, score: {}".format(
341
- score_sum, score_cnt, round(score_sum.item() / score_cnt.item(), 4)
342
- ))
343
-
344
- if cfg.distributed_training.distributed_world_size == 1 or dist.get_rank() == 0:
345
- os.makedirs(cfg.common_eval.results_path, exist_ok=True)
346
- output_path = os.path.join(cfg.common_eval.results_path, "{}_predict.json".format(cfg.dataset.gen_subset))
347
- gather_results = list(chain(*gather_results)) if gather_results is not None else results
348
- with open(output_path, 'w') as fw:
349
- json.dump(gather_results, fw)
 
 
 
 
 
1
  import string
2
  import math
 
 
 
3
 
4
  import torch
 
5
 
6
  from data import data_utils
 
7
 
8
 
9
  def get_symbols_to_strip_from_output(generator):
22
  return x
23
 
24
 
25
+ def eval_caption(task, generator, models, sample):
26
  transtab = str.maketrans({key: None for key in string.punctuation})
27
  hypos = task.inference_step(generator, models, sample)
28
  results = []
32
  return results, None
33
 
34
 
35
+ def eval_step(task, generator, models, sample):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  if task.cfg._name == 'caption':
37
+ return eval_caption(task, generator, models, sample)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  else:
39
  raise NotImplementedError
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/transforms.py CHANGED
@@ -1,8 +1,3 @@
1
- # Copyright 2022 The OFA-Sys Team.
2
- # All rights reserved.
3
- # This source code is licensed under the Apache 2.0 license
4
- # found in the LICENSE file in the root directory.
5
-
6
  import random
7
 
8
  import torch
 
 
 
 
 
1
  import random
2
 
3
  import torch
utils/trie.py CHANGED
@@ -1,8 +1,3 @@
1
- # Copyright 2022 The OFA-Sys Team.
2
- # All rights reserved.
3
- # This source code is licensed under the Apache 2.0 license
4
- # found in the LICENSE file in the root directory.
5
-
6
  from collections import defaultdict
7
 
8
 
@@ -27,4 +22,4 @@ class Trie:
27
  cur = cur.child.get(c)
28
  if cur is None:
29
  return [self.eos]
30
- return list(cur.child.keys())
 
 
 
 
 
1
  from collections import defaultdict
2
 
3
 
22
  cur = cur.child.get(c)
23
  if cur is None:
24
  return [self.eos]
25
+ return list(cur.child.keys())