HarryLee committed
Commit: c82bc0b
Parent: 19c6711

remove useless files

tasks/mm_tasks/image_gen.py DELETED
@@ -1,329 +0,0 @@
-# Copyright 2022 The OFA-Sys Team.
-# All rights reserved.
-# This source code is licensed under the Apache 2.0 license
-# found in the LICENSE file in the root directory.
-
-from dataclasses import dataclass, field
-import json
-import logging
-import os
-import math
-import base64
-from typing import Optional
-from argparse import Namespace
-from omegaconf import DictConfig, OmegaConf
-from torchvision import transforms
-from PIL import Image
-from io import BytesIO
-
-import torch
-import numpy as np
-from fairseq import metrics
-from fairseq.tasks import register_task
-from fairseq.dataclass import ChoiceEnum
-
-from models import search, clip
-from models.taming.models.vqgan import GumbelVQ
-from data.mm_data.image_gen_dataset import ImageGenDataset
-from data.file_dataset import FileDataset
-
-from tasks.ofa_task import OFATask, OFAConfig
-
-logger = logging.getLogger(__name__)
-
-
-def custom_to_pil(x):
-    x = x.detach().cpu()
-    x = torch.clamp(x, -1., 1.)
-    x = (x + 1.) / 2.
-    x = x.permute(1, 2, 0).numpy()
-    x = (255 * x).astype(np.uint8)
-    x = Image.fromarray(x)
-    if not x.mode == "RGB":
-        x = x.convert("RGB")
-    return x
-
-
-EVAL_CLIP_METHOD = ChoiceEnum(["ii_sim", "ti_sim"])
-
-@dataclass
-class ImageGenConfig(OFAConfig):
-    sampling_times: int = field(
-        default=1, metadata={"help": "sample times"}
-    )
-
-    code_image_size: int = field(
-        default=256, metadata={"help": "code image size"}
-    )
-
-    # options for reporting CLIP score during validation
-    eval_clip_method: EVAL_CLIP_METHOD = field(
-        default='ti_sim',
-        metadata={
-            "help": "evaluation with CLIP scores. ii_sim means Similarity between generated Images and ref Images, ti_sim means Similarity between generated Images and input Text"}
-    )
-
-    eval_args: Optional[str] = field(
-        default='{}',
-        metadata={
-            "help": 'generation args for clip scoring, e.g., \'{"beam": 4, "lenpen": 0.6}\', as JSON string'
-        },
-    )
-
-    scst: bool = field(
-        default=False, metadata={"help": "Self-critical sequence training"}
-    )
-    scst_args: str = field(
-        default='{}',
-        metadata={
-            "help": 'generation args for Self-critical sequence training, as JSON string'
-        },
-    )
-
-    vqgan_model_path: Optional[str] = field(
-        default=None,
-        metadata={"help": "path of vqgan model"}
-    )
-    vqgan_config_path: Optional[str] = field(
-        default=None,
-        metadata={"help": "path of vqgan config"}
-    )
-    clip_model_path: Optional[str] = field(
-        default=None,
-        metadata={"help": "clip model path"}
-    )
-    gen_images_path: str = field(
-        default='', metadata={"help": "where to store generated images during evalution. Don't dump images if None. "}
-    )
-
-
-@register_task("image_gen", dataclass=ImageGenConfig)
-class ImageGenTask(OFATask):
-    def __init__(self, cfg: ImageGenConfig, src_dict, tgt_dict):
-        super().__init__(cfg, src_dict, tgt_dict)
-
-    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
-        paths = self.cfg.data.split(',')
-        assert len(paths) > 0
-
-        if split == 'train':
-            file_path = paths[(epoch - 1) % (len(paths) - 1)]
-        else:
-            file_path = paths[-1]
-        dataset = FileDataset(file_path, self.cfg.selected_cols)
-
-        self.datasets[split] = ImageGenDataset(
-            split,
-            dataset,
-            self.bpe,
-            self.src_dict,
-            self.tgt_dict,
-            max_src_length=self.cfg.max_src_length,
-            code_dict_size=self.cfg.code_dict_size,
-            code_image_size=self.cfg.code_image_size
-        )
-
-    def build_model(self, cfg):
-        model = super().build_model(cfg)
-
-        device = torch.cuda.current_device()
-        clip_model, clip_preprocess = clip.load(self.cfg.clip_model_path, device=device)
-        self.clip_model = clip_model
-        self.clip_preprocess = clip_preprocess
-        self.clip_model.to(device)
-        self.clip_model.eval()
-
-        vqgan_config = OmegaConf.load(self.cfg.vqgan_config_path)
-        vqgan = GumbelVQ(**vqgan_config.model.params)
-        sd = torch.load(self.cfg.vqgan_model_path, map_location="cpu")["state_dict"]
-        missing, unexpected = vqgan.load_state_dict(sd, strict=False)
-        for k, v in vqgan.named_parameters():
-            v.requires_grad = False
-        self.image_tokenizer = vqgan
-        self.image_tokenizer.to(device)
-        self.image_tokenizer.eval()
-
-        gen_args = json.loads(self.cfg.eval_args)
-        self.sequence_generator = self.build_generator(
-            [model], Namespace(**gen_args)
-        )
-        if self.cfg.scst:
-            scst_args = json.loads(self.cfg.scst_args)
-            self.scst_generator = self.build_generator(
-                [model], Namespace(**scst_args)
-            )
-
-        return model
-
-    def build_generator(
-        self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None, prefix_allowed_tokens_fn=None,
-    ):
-        """
-        Build a :class:`~fairseq.SequenceGenerator` instance for this
-        task.
-
-        Args:
-            models (List[~fairseq.models.FairseqModel]): ensemble of models
-            args (fairseq.dataclass.configs.GenerationConfig):
-                configuration object (dataclass) for generation
-            extra_gen_cls_kwargs (Dict[str, Any]): extra options to pass
-                through to SequenceGenerator
-            prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]]):
-                If provided, this function constrains the beam search to
-                allowed tokens only at each step. The provided function
-                should take 2 arguments: the batch ID (`batch_id: int`)
-                and a unidimensional tensor of token ids (`inputs_ids:
-                torch.Tensor`). It has to return a `List[int]` with the
-                allowed tokens for the next generation step conditioned
-                on the previously generated tokens (`inputs_ids`) and
-                the batch ID (`batch_id`). This argument is useful for
-                constrained generation conditioned on the prefix, as
-                described in "Autoregressive Entity Retrieval"
-                (https://arxiv.org/abs/2010.00904) and
-                https://github.com/facebookresearch/GENRE.
-        """
-        from models.sequence_generator import SequenceGenerator
-
-        # Choose search strategy. Defaults to Sampling.
-        self.sampling_times = self.cfg.sampling_times
-        sampling = True  # we have to use sampling instead of beam search in image generation task
-        sampling_topk = getattr(args, "sampling_topk", -1)
-        sampling_topp = getattr(args, "sampling_topp", -1.0)
-
-        assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
-        assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling"
-
-        search_strategy = search.Sampling(
-            self.target_dictionary, sampling_topk, sampling_topp
-        )
-        extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
-
-        return SequenceGenerator(
-            models,
-            self.target_dictionary,
-            beam_size=getattr(args, "beam", 5),
-            max_len_a=getattr(args, "max_len_a", 0),
-            max_len_b=getattr(args, "max_len_b", 200),
-            min_len=getattr(args, "min_len", 1),
-            normalize_scores=(not getattr(args, "unnormalized", False)),
-            len_penalty=getattr(args, "lenpen", 1),
-            unk_penalty=getattr(args, "unkpen", 0),
-            temperature=getattr(args, "temperature", 1.0),
-            match_source_len=getattr(args, "match_source_len", False),
-            no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
-            search_strategy=search_strategy,
-            constraint_range=self.cfg.constraint_range,
-            gen_code=True,
-            **extra_gen_cls_kwargs,
-        )
-
-    def compute_ref_image_similarity(self, hyps, ref, device):
-        hyp_images = torch.stack(
-            [self.clip_preprocess(hyp_image) for hyp_image in hyps], dim=0
-        ).to(device)
-
-        ref_images = self.clip_preprocess(ref).unsqueeze(0).to(device)
-        with torch.no_grad():
-            hyp_image_features = self.clip_model.encode_image(hyp_images)
-            ref_image_features = self.clip_model.encode_image(ref_images)
-            hyp_image_features /= hyp_image_features.norm(dim=-1, keepdim=True)
-            ref_image_features /= ref_image_features.norm(dim=-1, keepdim=True)
-            similarity = hyp_image_features @ ref_image_features.T
-            # scores.append(similarity.max().item())
-            sorted_score, indices = torch.sort(similarity.view(-1), descending=True)
-        return sorted_score, indices
-
-    def compute_text_similarity(self, hyps, text, device):
-        hyp_images = torch.stack(
-            [self.clip_preprocess(hyp_image) for hyp_image in hyps], dim=0
-        ).to(device)
-
-        clip_input = clip.tokenize([text]).to(device)
-        with torch.no_grad():
-            hyp_image_features = self.clip_model.encode_image(hyp_images)
-            hyp_image_features /= hyp_image_features.norm(dim=-1, keepdim=True)
-            text_features = self.clip_model.encode_text(clip_input)
-            text_features /= text_features.norm(dim=-1, keepdim=True)
-            ti_similarity = hyp_image_features @ text_features.T
-            sorted_score, indices = torch.sort(ti_similarity.view(-1), descending=True)
-        return sorted_score, indices
-
-    def valid_step(self, sample, model, criterion):
-        loss, sample_size, logging_output = criterion(model, sample)
-
-        model.eval()
-        device = sample['target'].device
-
-        hyps, ref = self.inference_image(self.sequence_generator, sample, [model])
-        scores = []
-
-        tokens = sample['net_input']['src_tokens'][0].view(-1).tolist()
-        caption = self.bpe.decode(self.tgt_dict.string([token for token in tokens if token >= 4]))[
-            38:].replace('/', '')
-        if self.cfg.eval_clip_method == 'ii_sim':
-            similarity_score, indices = self.compute_ref_image_similarity(hyps, ref, device)
-        elif self.cfg.eval_clip_method == 'ti_sim':
-            similarity_score, indices = self.compute_text_similarity(hyps, caption, device)
-        else:
-            raise ValueError("unsupported eval method.")
-
-        scores.append(similarity_score.max().item())
-        sorted_hyps = [hyps[indice] for indice in indices]
-
-        if self.cfg.gen_images_path:
-            caption_tokens = sample['net_input']['src_tokens'][0].view(-1).tolist()
-            caption = self.bpe.decode(self.tgt_dict.string([token for token in caption_tokens if token >= 4]))[
-                38:].replace('/', '')
-            self.dump_images(sorted_hyps, text=caption, path=os.path.join(self.cfg.gen_images_path, 'all_results'))
-            self.dump_images(sorted_hyps, text=caption, path=os.path.join(self.cfg.gen_images_path, 'top1'), topk=1)
-
-        logging_output["_score_sum"] = sum(scores)
-        logging_output["_score_cnt"] = len(scores)
-
-        return loss, sample_size, logging_output
-
-    def reduce_metrics(self, logging_outputs, criterion):
-        super().reduce_metrics(logging_outputs, criterion)
-
-        def sum_logs(key):
-            import torch
-            result = sum(log.get(key, 0) for log in logging_outputs)
-            if torch.is_tensor(result):
-                result = result.cpu()
-            return result
-
-        def compute_score(meters):
-            score = meters["_score_sum"].sum / meters["_score_cnt"].sum
-            score = score if isinstance(score, float) else score.item()
-            return round(score, 3)
-
-        if sum_logs("_score_cnt") > 0:
-            metrics.log_scalar("_score_sum", sum_logs("_score_sum"))
-            metrics.log_scalar("_score_cnt", sum_logs("_score_cnt"))
-            metrics.log_derived("score", compute_score)
-
-    def inference_image(self, generator, sample, models):
-        hyps, ref = [], None
-        for j in range(self.sampling_times):
-            gen_out = self.inference_step(generator, models, sample)
-            for i in range(len(gen_out)):
-                with torch.no_grad():
-                    tokens = torch.stack([item['tokens'][:-1] for item in gen_out[i]], dim=0)
-                    tokens += -len(self.src_dict) + self.cfg.code_dict_size + self.cfg.num_bins
-                    images = self.image_tokenizer.decode_code(
-                        tokens.view(-1, self.cfg.code_image_size // 8, self.cfg.code_image_size // 8)
-                    )
-                    images = [custom_to_pil(image) for image in images]
-                hyps += images
-        if 'code_images' in sample:
-            ref = Image.open(BytesIO(base64.urlsafe_b64decode(sample['code_images'][0]))).convert('RGB')
-
-        return hyps, ref
-
-    def dump_images(self, images, text, path, topk=None):
-        os.makedirs(path, exist_ok=True)
-        if topk:
-            images = images[:topk]
-        for j, image in enumerate(images):
-            save_path = os.path.join(path, f'{text}_{j}.png')
-            image.save(save_path)
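
For context on the scoring above: with the default code_image_size of 256 and the 8x downsampling in the reshape (code_image_size // 8), each hypothesis is a 32x32 grid of 1024 discrete codes that decode_code turns back into an image, and validation ranks the sampled images with CLIP. A minimal standalone sketch of the default ti_sim path (text-image similarity), assuming the OpenAI-style clip package whose load/tokenize/encode_image/encode_text API the task's models.clip mirrors; clip_text_image_scores, pil_images, and the "ViT-B/16" checkpoint name are illustrative and not part of the removed file:

import torch
import clip  # OpenAI CLIP package; the task imports an equivalent module from models.clip

def clip_text_image_scores(pil_images, caption, device="cuda"):
    # The removed task loads the checkpoint from cfg.clip_model_path instead of by name.
    model, preprocess = clip.load("ViT-B/16", device=device)
    model.eval()
    images = torch.stack([preprocess(im) for im in pil_images], dim=0).to(device)
    text = clip.tokenize([caption]).to(device)
    with torch.no_grad():
        image_feat = model.encode_image(images)
        text_feat = model.encode_text(text)
        image_feat /= image_feat.norm(dim=-1, keepdim=True)
        text_feat /= text_feat.norm(dim=-1, keepdim=True)
        sim = image_feat @ text_feat.T  # one cosine similarity per sampled image
    # Best-first ordering, mirroring compute_text_similarity: the top score is what
    # valid_step logs, and the ordering drives dump_images' "top1" output.
    return torch.sort(sim.view(-1), descending=True)
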
tasks/mm_tasks/refcoco.py DELETED
@@ -1,160 +0,0 @@
-# Copyright 2022 The OFA-Sys Team.
-# All rights reserved.
-# This source code is licensed under the Apache 2.0 license
-# found in the LICENSE file in the root directory.
-
-from dataclasses import dataclass, field
-import json
-import logging
-from typing import Optional
-from argparse import Namespace
-
-import torch
-from fairseq import metrics
-from fairseq.tasks import register_task
-
-from tasks.ofa_task import OFATask, OFAConfig
-from data.mm_data.refcoco_dataset import RefcocoDataset
-from data.file_dataset import FileDataset
-
-logger = logging.getLogger(__name__)
-
-
-@dataclass
-class RefcocoConfig(OFAConfig):
-    eval_acc: bool = field(
-        default=False, metadata={"help": "evaluation with accuracy"}
-    )
-    eval_args: Optional[str] = field(
-        default='{}',
-        metadata={
-            "help": 'generation args, e.g., \'{"beam": 4, "lenpen": 0.6}\', as JSON string'
-        },
-    )
-    eval_print_samples: bool = field(
-        default=False, metadata={"help": "print sample generations during validation"}
-    )
-
-    max_image_size: int = field(
-        default=512, metadata={"help": "max image size for normalization"}
-    )
-    scst: bool = field(
-        default=False, metadata={"help": "Self-critical sequence training"}
-    )
-    scst_args: str = field(
-        default='{}',
-        metadata={
-            "help": 'generation args for Self-critical sequence training, as JSON string'
-        },
-    )
-
-
-@register_task("refcoco", dataclass=RefcocoConfig)
-class RefcocoTask(OFATask):
-    def __init__(self, cfg: RefcocoConfig, src_dict, tgt_dict):
-        super().__init__(cfg, src_dict, tgt_dict)
-
-    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
-        paths = self.cfg.data.split(',')
-        assert len(paths) > 0
-
-        if split == 'train':
-            file_path = paths[(epoch - 1) % (len(paths) - 1)]
-        else:
-            file_path = paths[-1]
-        dataset = FileDataset(file_path, self.cfg.selected_cols)
-
-        self.datasets[split] = RefcocoDataset(
-            split,
-            dataset,
-            self.bpe,
-            self.src_dict,
-            self.tgt_dict,
-            max_src_length=self.cfg.max_src_length,
-            max_tgt_length=self.cfg.max_tgt_length,
-            patch_image_size=self.cfg.patch_image_size,
-            imagenet_default_mean_and_std=self.cfg.imagenet_default_mean_and_std,
-            num_bins=self.cfg.num_bins,
-            max_image_size=self.cfg.max_image_size
-        )
-
-    def build_model(self, cfg):
-        model = super().build_model(cfg)
-        if self.cfg.eval_acc:
-            gen_args = json.loads(self.cfg.eval_args)
-            self.sequence_generator = self.build_generator(
-                [model], Namespace(**gen_args)
-            )
-        if self.cfg.scst:
-            scst_args = json.loads(self.cfg.scst_args)
-            self.scst_generator = self.build_generator(
-                [model], Namespace(**scst_args)
-            )
-
-        return model
-
-    def _calculate_ap_score(self, hyps, refs, thresh=0.5):
-        interacts = torch.cat(
-            [torch.where(hyps[:, :2] < refs[:, :2], refs[:, :2], hyps[:, :2]),
-             torch.where(hyps[:, 2:] < refs[:, 2:], hyps[:, 2:], refs[:, 2:])],
-            dim=1
-        )
-        area_predictions = (hyps[:, 2] - hyps[:, 0]) * (hyps[:, 3] - hyps[:, 1])
-        area_targets = (refs[:, 2] - refs[:, 0]) * (refs[:, 3] - refs[:, 1])
-        interacts_w = interacts[:, 2] - interacts[:, 0]
-        interacts_h = interacts[:, 3] - interacts[:, 1]
-        area_interacts = interacts_w * interacts_h
-        ious = area_interacts / (area_predictions + area_targets - area_interacts + 1e-6)
-        return ((ious >= thresh) & (interacts_w > 0) & (interacts_h > 0)).float()
-
-    def valid_step(self, sample, model, criterion):
-        loss, sample_size, logging_output = criterion(model, sample)
-
-        model.eval()
-        if self.cfg.eval_acc:
-            hyps, refs = self._inference(self.sequence_generator, sample, model)
-            hyps = hyps / (self.cfg.num_bins - 1) * self.cfg.max_image_size
-            refs = refs / (self.cfg.num_bins - 1) * self.cfg.max_image_size
-            hyps[:, ::2] /= sample['w_resize_ratios'].unsqueeze(1)
-            hyps[:, 1::2] /= sample['h_resize_ratios'].unsqueeze(1)
-            refs[:, ::2] /= sample['w_resize_ratios'].unsqueeze(1)
-            refs[:, 1::2] /= sample['h_resize_ratios'].unsqueeze(1)
-
-            # scores = self._calculate_ap_score(hyps, refs)
-            scores = self._calculate_ap_score(hyps, sample['region_coords'].float())
-            logging_output["_score_sum"] = scores.sum().item()
-            logging_output["_score_cnt"] = scores.size(0)
-
-        return loss, sample_size, logging_output
-
-    def reduce_metrics(self, logging_outputs, criterion):
-        super().reduce_metrics(logging_outputs, criterion)
-
-        def sum_logs(key):
-            import torch
-            result = sum(log.get(key, 0) for log in logging_outputs)
-            if torch.is_tensor(result):
-                result = result.cpu()
-            return result
-
-        def compute_score(meters):
-            score = meters["_score_sum"].sum / meters["_score_cnt"].sum
-            score = score if isinstance(score, float) else score.item()
-            return round(score, 4)
-
-        if sum_logs("_score_cnt") > 0:
-            metrics.log_scalar("_score_sum", sum_logs("_score_sum"))
-            metrics.log_scalar("_score_cnt", sum_logs("_score_cnt"))
-            metrics.log_derived("score", compute_score)
-
-    def _inference(self, generator, sample, model):
-        gen_out = self.inference_step(generator, [model], sample)
-        hyps, refs = [], []
-        for i in range(len(gen_out)):
-            hyps.append(gen_out[i][0]["tokens"][:-1] - len(self.src_dict) + self.cfg.num_bins)
-            refs.append(sample["target"][i][:-1] - len(self.src_dict) + self.cfg.num_bins)
-        if self.cfg.eval_print_samples:
-            logger.info("example hypothesis: ", hyps[0])
-            logger.info("example reference: ", refs[0])
-
-        return torch.stack(hyps, dim=0), torch.stack(refs, dim=0)
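
The metric reduced to "score" above is Acc@0.5: a predicted box counts as correct when its IoU with the ground-truth box reaches 0.5 and the intersection is non-degenerate (interacts_w > 0 and interacts_h > 0). A minimal sketch of the same IoU computation for one pair of [x0, y0, x1, y1] boxes; box_iou and the coordinate values are illustrative only, and it clamps the intersection at zero instead of checking the width/height signs separately:

import torch

def box_iou(hyp, ref, eps=1e-6):
    # hyp, ref: tensors of shape (4,) holding [x0, y0, x1, y1]
    x0 = torch.max(hyp[0], ref[0])
    y0 = torch.max(hyp[1], ref[1])
    x1 = torch.min(hyp[2], ref[2])
    y1 = torch.min(hyp[3], ref[3])
    inter = (x1 - x0).clamp(min=0) * (y1 - y0).clamp(min=0)
    area_hyp = (hyp[2] - hyp[0]) * (hyp[3] - hyp[1])
    area_ref = (ref[2] - ref[0]) * (ref[3] - ref[1])
    return inter / (area_hyp + area_ref - inter + eps)

hyp = torch.tensor([10., 10., 110., 110.])  # hypothetical prediction
ref = torch.tensor([20., 20., 120., 120.])  # hypothetical ground truth
iou = box_iou(hyp, ref)  # ~0.68, so it would count as a hit at the 0.5 threshold
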
tasks/mm_tasks/snli_ve.py DELETED
@@ -1,197 +0,0 @@
-# Copyright 2022 The OFA-Sys Team.
-# All rights reserved.
-# This source code is licensed under the Apache 2.0 license
-# found in the LICENSE file in the root directory.
-
-import json
-import logging
-import math
-from dataclasses import dataclass, field
-from typing import Optional
-
-import torch
-from fairseq import metrics
-from fairseq.tasks import register_task
-
-from tasks.ofa_task import OFAConfig, OFATask
-from data.mm_data.snli_ve_dataset import SnliVeDataset
-from data.file_dataset import FileDataset
-from data import data_utils
-from utils.trie import Trie
-
-logger = logging.getLogger(__name__)
-
-
-@dataclass
-class SnliVeConfig(OFAConfig):
-    ans2label_dict: Optional[str] = field(
-        default='{"no": 0, "yes":1, "maybe": 2}',
-        metadata={"help": 'answer to label dict'},
-    )
-    add_caption: bool = field(
-        default=False,
-        metadata={"help": "add caption to encoder"},
-    )
-    valid_batch_size: int = field(
-        default=20,
-        metadata={"help": "valid batch size per step"},
-    )
-    prompt_type: Optional[str] = field(
-        default=None,
-        metadata={"help": "prompt_type"},
-    )
-
-
-@register_task("snli_ve", dataclass=SnliVeConfig)
-class SnliVeTask(OFATask):
-    def __init__(self, cfg: SnliVeConfig, src_dict, tgt_dict):
-        super().__init__(cfg, src_dict, tgt_dict)
-        self.ans2label_dict = json.loads(self.cfg.ans2label_dict)
-
-    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
-        paths = self.cfg.data.split(',')
-        assert len(paths) > 0
-
-        if split == 'train':
-            file_path = paths[(epoch - 1) % (len(paths) - 1)]
-        else:
-            file_path = paths[-1]
-        dataset = FileDataset(file_path, self.cfg.selected_cols)
-
-        self.datasets[split] = SnliVeDataset(
-            split,
-            dataset,
-            self.bpe,
-            self.src_dict,
-            self.tgt_dict,
-            max_src_length=self.cfg.max_src_length,
-            max_tgt_length=self.cfg.max_tgt_length,
-            patch_image_size=self.cfg.patch_image_size,
-            add_caption=self.cfg.add_caption,
-            constraint_trie=self.constraint_trie,
-            imagenet_default_mean_and_std=self.cfg.imagenet_default_mean_and_std,
-            prompt_type=self.cfg.prompt_type
-        )
-
-    def build_model(self, cfg):
-        model = super().build_model(cfg)
-        answer_item_list = []
-        self.index2ans = {}
-        self.constraint_trie = Trie(self.tgt_dict.eos())
-        for i, answer in enumerate(self.ans2label_dict.keys()):
-            answer_item = self.tgt_dict.encode_line(
-                line=self.bpe.encode(' ' + answer),
-                add_if_not_exist=False,
-                append_eos=False
-            ).long()
-            answer_item_list.append(answer_item)
-            self.index2ans[i] = answer
-            self.constraint_trie.insert([self.tgt_dict.bos()] + answer_item.tolist() + [self.tgt_dict.eos()])
-
-        constraint_mask_list = []
-        for answer_item in answer_item_list:
-            constraint_mask = torch.zeros((len(answer_item)+1, len(self.tgt_dict))).bool()
-            for i in range(len(answer_item)+1):
-                constraint_prefix_token = [self.src_dict.bos()] + answer_item[:i].tolist()
-                constraint_nodes = self.constraint_trie.get_next_layer(constraint_prefix_token)
-                constraint_mask[i][constraint_nodes] = True
-            constraint_mask_list.append(constraint_mask)
-
-        self.valid_answers_list = []
-        self.valid_constraint_masks_list = []
-        for i in range(0, len(answer_item_list), self.cfg.valid_batch_size):
-            self.valid_answers_list += [answer_item_list[i:i+self.cfg.valid_batch_size]]
-            self.valid_constraint_masks_list += [constraint_mask_list[i:i+self.cfg.valid_batch_size]]
-
-        return model
-
-    def build_generator(
-        self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None, prefix_allowed_tokens_fn=None,
-    ):
-        seq_generator = super().build_generator(models, args, seq_gen_cls, extra_gen_cls_kwargs, prefix_allowed_tokens_fn)
-        seq_generator.constraint_trie = self.constraint_trie
-
-        return seq_generator
-
-    def valid_step(self, sample, model, criterion, **extra_kwargs):
-        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
-
-        model.eval()
-        with torch.no_grad():
-            encoder_out = model.encoder(
-                sample["net_input"]["src_tokens"],
-                src_lengths=sample["net_input"]["src_lengths"],
-                patch_images=sample["net_input"]["patch_images"],
-                patch_masks=sample["net_input"]["patch_masks"]
-            )
-            device = sample["net_input"]["src_tokens"].device
-            eos_item = torch.tensor([self.src_dict.eos()])
-            pad = self.src_dict.pad()
-            valid_result = []
-            for valid_answers, valid_constraint_masks in zip(self.valid_answers_list, self.valid_constraint_masks_list):
-                valid_size = len(valid_answers)
-                valid_tgt_items = [
-                    torch.cat([torch.tensor(decoder_prompt[1:]), valid_answer, eos_item])
-                    for decoder_prompt in sample["decoder_prompts"] for valid_answer in valid_answers
-                ]
-                valid_prev_items = [
-                    torch.cat([torch.tensor(decoder_prompt), valid_answer])
-                    for decoder_prompt in sample["decoder_prompts"] for valid_answer in valid_answers
-                ]
-                valid_constraint_mask_items = [
-                    torch.cat([torch.zeros(len(decoder_prompt)-1, valid_constraint_mask.size(1)).bool(), valid_constraint_mask], dim=0)
-                    for decoder_prompt in sample["decoder_prompts"] for valid_constraint_mask in valid_constraint_masks
-                ]
-                valid_tgt = data_utils.collate_tokens(valid_tgt_items, pad_idx=pad, left_pad=False).to(device)
-                valid_prev_output = data_utils.collate_tokens(valid_prev_items, pad_idx=pad, left_pad=False).to(device)
-                valid_constraint_masks = data_utils.collate_tokens(valid_constraint_mask_items, pad_idx=pad, left_pad=False).to(device)
-
-                new_encoder_out = {}
-                new_encoder_out["encoder_out"] = [
-                    encoder_out["encoder_out"][0].repeat_interleave(valid_size, dim=1)
-                ]
-                new_encoder_out["encoder_padding_mask"] = [
-                    encoder_out["encoder_padding_mask"][0].repeat_interleave(valid_size, dim=0)
-                ]
-                new_encoder_out["position_embeddings"] = [
-                    encoder_out["position_embeddings"][0].repeat_interleave(valid_size, dim=0)
-                ]
-
-                decoder_out = model.decoder(valid_prev_output, encoder_out=new_encoder_out)
-                decoder_out[0].masked_fill_(~valid_constraint_masks, -math.inf)
-                lprobs = model.get_normalized_probs(decoder_out, log_probs=True)
-                scores = lprobs.gather(dim=-1, index=valid_tgt.unsqueeze(-1)).squeeze(-1)
-                scores = scores.masked_fill(valid_tgt.eq(self.tgt_dict.pad()), 0)
-                scores = scores.masked_fill((~valid_constraint_masks).all(2), 0)
-                scores = scores.sum(1)
-                scores = scores.view(-1, valid_size)
-                valid_result.append(scores)
-
-        valid_result = torch.cat(valid_result, dim=-1)
-        predicts = valid_result.argmax(1).tolist()
-        hyps = [self.index2ans[predict_index] for predict_index in predicts]
-        scores = [ref_dict.get(hyp, 0) for ref_dict, hyp in zip(sample['ref_dict'], hyps)]
-        logging_output["_snli_score_sum"] = sum(scores)
-        logging_output["_snli_cnt"] = len(scores)
-
-        return loss, sample_size, logging_output
-
-    def reduce_metrics(self, logging_outputs, criterion):
-        super().reduce_metrics(logging_outputs, criterion)
-
-        def sum_logs(key):
-            import torch
-            result = sum(log.get(key, 0) for log in logging_outputs)
-            if torch.is_tensor(result):
-                result = result.cpu()
-            return result
-
-        def compute_score(meters):
-            score = meters["_snli_score_sum"].sum / meters["_snli_cnt"].sum
-            score = score if isinstance(score, float) else score.item()
-            return round(score, 4)
-
-        if sum_logs("_snli_cnt") > 0:
-            metrics.log_scalar("_snli_score_sum", sum_logs("_snli_score_sum"))
-            metrics.log_scalar("_snli_cnt", sum_logs("_snli_cnt"))
-            metrics.log_derived("snli_score", compute_score)
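
The closed-vocabulary evaluation above never free-decodes: every candidate in ans2label_dict ("no", "yes", "maybe") is scored by the summed log-probability of its tokens under the constraint-masked logits, and the argmax becomes the hypothesis checked against ref_dict. A toy sketch of that ranking step, with hypothetical token ids, a hypothetical vocabulary size, and random log-probabilities standing in for the decoder output:

import torch

# Candidate answers mapped to (made-up) token id sequences; in the task these come
# from tgt_dict.encode_line(bpe.encode(' ' + answer)).
candidates = {"no": [117], "yes": [4420], "maybe": [172, 919]}
vocab_size = 50000  # illustrative only
lprobs = torch.log_softmax(torch.randn(4, vocab_size), dim=-1)  # (decoder steps, vocab)

scores = {}
for answer, token_ids in candidates.items():
    # Decoder step t scores token t of the candidate; shorter answers use fewer steps.
    scores[answer] = sum(lprobs[t, tok].item() for t, tok in enumerate(token_ids))
prediction = max(scores, key=scores.get)  # analogous to valid_result.argmax(1)
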
tasks/mm_tasks/vqa_gen.py DELETED
@@ -1,278 +0,0 @@
-# Copyright 2022 The OFA-Sys Team.
-# All rights reserved.
-# This source code is licensed under the Apache 2.0 license
-# found in the LICENSE file in the root directory.
-
-from dataclasses import dataclass, field
-import json
-import logging
-import os
-import math
-import pickle
-from typing import Optional
-from argparse import Namespace
-from data.file_dataset import FileDataset
-
-import torch
-from fairseq import metrics
-from fairseq.tasks import register_task
-
-from models import search
-from data.mm_data.vqa_gen_dataset import VqaGenDataset
-from data import data_utils
-from tasks.ofa_task import OFAConfig, OFATask
-from utils.trie import Trie
-
-logger = logging.getLogger(__name__)
-
-
-def get_symbols_to_strip_from_output(generator):
-    if hasattr(generator, "symbols_to_strip_from_output"):
-        return generator.symbols_to_strip_from_output
-    else:
-        return {generator.bos, generator.eos}
-
-
-def decode_fn(x, tgt_dict, bpe, generator, tokenizer=None):
-    x = tgt_dict.string(x.int().cpu(), extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator))
-    if bpe is not None:
-        x = bpe.decode(x)
-    if tokenizer is not None:
-        x = tokenizer.decode(x)
-    return x
-
-
-@dataclass
-class VqaGenConfig(OFAConfig):
-    max_object_length: int = field(
-        default=30, metadata={"help": "the maximum object sequence length"}
-    )
-    ans2label_dict: Optional[str] = field(
-        default='{"no": 0, "yes":1}',
-        metadata={"help": 'answer to label dict'},
-    )
-    ans2label_file: Optional[str] = field(
-        default=None,
-        metadata={"help": "path to load ans2label file"},
-    )
-
-    add_object: bool = field(
-        default=False,
-        metadata={"help": "add object to encoder"},
-    )
-    valid_batch_size: int = field(
-        default=20,
-        metadata={"help": "valid batch size per step"},
-    )
-    prompt_type: Optional[str] = field(
-        default=None,
-        metadata={"help": "prompt_type"},
-    )
-    uses_ema: Optional[bool] = field(
-        default=False,
-        metadata={"help": "whether to use ema"},
-    )
-    val_inference_type: Optional[str] = field(
-        default='allcand',
-        metadata={"help": "inference type in validation (allcand or beamsearch), default to allcand"},
-    )
-    eval_args: Optional[str] = field(
-        default='{"beam":5,"unnormalized":true,"temperature":1.0}',
-        metadata={
-            "help": 'generation args as JSON string for inference, only activated when --val-inference-type=beamsearch'
-        },
-    )
-
-
-@register_task("vqa_gen", dataclass=VqaGenConfig)
-class VqaGenTask(OFATask):
-    def __init__(self, cfg: VqaGenConfig, src_dict, tgt_dict):
-        super().__init__(cfg, src_dict, tgt_dict)
-
-        self.ans2label_dict = None
-        if self.cfg.ans2label_file is not None:
-            self.ans2label_dict = pickle.load(open(self.cfg.ans2label_file, "rb"))
-        else:
-            self.ans2label_dict = json.loads(self.cfg.ans2label_dict)
-
-        self.uses_ema = self.cfg.uses_ema
-
-        assert self.cfg.val_inference_type in ["allcand", "beamsearch"], \
-            "Unknown inference type encountered: {}, should be allcand or beamsearch.".format(self.cfg.val_inference_type)
-
-    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
-        paths = self.cfg.data.split(',')
-        assert len(paths) > 0
-
-        if split == 'train':
-            table_path = paths[(epoch - 1) % (len(paths) - 1)]
-        else:
-            table_path = paths[-1]
-        dataset = FileDataset(table_path, self.cfg.selected_cols)
-
-        self.datasets[split] = VqaGenDataset(
-            split,
-            dataset,
-            self.bpe,
-            self.src_dict,
-            self.tgt_dict,
-            max_src_length=self.cfg.max_src_length,
-            max_object_length=self.cfg.max_object_length,
-            max_tgt_length=self.cfg.max_tgt_length,
-            patch_image_size=self.cfg.patch_image_size,
-            add_object=self.cfg.add_object,
-            constraint_trie=self.constraint_trie,
-            imagenet_default_mean_and_std=self.cfg.imagenet_default_mean_and_std,
-            prompt_type=self.cfg.prompt_type
-        )
-
-    def build_model(self, cfg):
-        model = super().build_model(cfg)
-        answer_item_list = []
-        self.index2ans = {}
-        self.constraint_trie = Trie(self.tgt_dict.eos())
-        for i, answer in enumerate(self.ans2label_dict.keys()):
-            answer_item = self.tgt_dict.encode_line(
-                line=self.bpe.encode(' ' + answer),
-                add_if_not_exist=False,
-                append_eos=False
-            ).long()
-            answer_item_list.append(answer_item)
-            self.index2ans[i] = answer
-            self.constraint_trie.insert([self.tgt_dict.bos()] + answer_item.tolist() + [self.tgt_dict.eos()])
-
-        constraint_mask_list = []
-        for answer_item in answer_item_list:
-            constraint_mask = torch.zeros((len(answer_item)+1, len(self.tgt_dict))).bool()
-            for i in range(len(answer_item)+1):
-                constraint_prefix_token = [self.src_dict.bos()] + answer_item[:i].tolist()
-                constraint_nodes = self.constraint_trie.get_next_layer(constraint_prefix_token)
-                constraint_mask[i][constraint_nodes] = True
-            constraint_mask_list.append(constraint_mask)
-
-        if self.cfg.val_inference_type == "allcand":
-            self.valid_answers_list = []
-            self.valid_constraint_masks_list = []
-            for i in range(0, len(answer_item_list), self.cfg.valid_batch_size):
-                self.valid_answers_list += [answer_item_list[i:i+self.cfg.valid_batch_size]]
-                self.valid_constraint_masks_list += [constraint_mask_list[i:i+self.cfg.valid_batch_size]]
-        elif self.cfg.val_inference_type == "beamsearch":
-            gen_args = json.loads(self.cfg.eval_args)
-            self.generator = self.build_generator(
-                [model], Namespace(**gen_args)
-            )
-        else:
-            raise NotImplementedError("Error: Unknown inference type encountered.")
-
-        return model
-
-    def build_generator(
-        self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None, prefix_allowed_tokens_fn=None,
-    ):
-        seq_generator = super().build_generator(models, args, seq_gen_cls, extra_gen_cls_kwargs, prefix_allowed_tokens_fn)
-        seq_generator.constraint_trie = self.constraint_trie
-
-        return seq_generator
-
-    def valid_step(self, sample, model, criterion, **extra_kwargs):
-        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
-
-        if self.uses_ema:
-            assert 'ema_model' in extra_kwargs and extra_kwargs['ema_model'] is not None
-        if self.uses_ema:
-            eval_model = extra_kwargs['ema_model']
-        else:
-            eval_model = model
-
-        eval_model.eval()
-        with torch.no_grad():
-            if self.cfg.val_inference_type == "allcand":
-                encoder_out = eval_model.encoder(
-                    sample["net_input"]["src_tokens"],
-                    src_lengths=sample["net_input"]["src_lengths"],
-                    patch_images=sample["net_input"]["patch_images"],
-                    patch_masks=sample["net_input"]["patch_masks"]
-                )
-                device = sample["net_input"]["src_tokens"].device
-                eos_item = torch.tensor([self.src_dict.eos()])
-                pad = self.src_dict.pad()
-                valid_result = []
-                for valid_answers, valid_constraint_masks in zip(self.valid_answers_list, self.valid_constraint_masks_list):
-                    valid_size = len(valid_answers)
-                    valid_tgt_items = [
-                        torch.cat([torch.tensor(decoder_prompt[1:]), valid_answer, eos_item])
-                        for decoder_prompt in sample["decoder_prompts"] for valid_answer in valid_answers
-                    ]
-                    valid_prev_items = [
-                        torch.cat([torch.tensor(decoder_prompt), valid_answer])
-                        for decoder_prompt in sample["decoder_prompts"] for valid_answer in valid_answers
-                    ]
-                    valid_constraint_mask_items = [
-                        torch.cat([torch.zeros(len(decoder_prompt)-1, valid_constraint_mask.size(1)).bool(), valid_constraint_mask], dim=0)
-                        for decoder_prompt in sample["decoder_prompts"] for valid_constraint_mask in valid_constraint_masks
-                    ]
-                    valid_tgt = data_utils.collate_tokens(valid_tgt_items, pad_idx=pad, left_pad=False).to(device)
-                    valid_prev_output = data_utils.collate_tokens(valid_prev_items, pad_idx=pad, left_pad=False).to(device)
-                    valid_constraint_masks = data_utils.collate_tokens(valid_constraint_mask_items, pad_idx=pad, left_pad=False).to(device)
-
-                    new_encoder_out = {}
-                    new_encoder_out["encoder_out"] = [
-                        encoder_out["encoder_out"][0].repeat_interleave(valid_size, dim=1)
-                    ]
-                    new_encoder_out["encoder_padding_mask"] = [
-                        encoder_out["encoder_padding_mask"][0].repeat_interleave(valid_size, dim=0)
-                    ]
-                    new_encoder_out["position_embeddings"] = [
-                        encoder_out["position_embeddings"][0].repeat_interleave(valid_size, dim=0)
-                    ]
-
-                    decoder_out = eval_model.decoder(valid_prev_output, encoder_out=new_encoder_out)
-                    decoder_out[0].masked_fill_(~valid_constraint_masks, -math.inf)
-                    lprobs = eval_model.get_normalized_probs(decoder_out, log_probs=True)
-                    scores = lprobs.gather(dim=-1, index=valid_tgt.unsqueeze(-1)).squeeze(-1)
-                    scores = scores.masked_fill(valid_tgt.eq(self.tgt_dict.pad()), 0)
-                    scores = scores.masked_fill((~valid_constraint_masks).all(2), 0)
-                    scores = scores.sum(1)
-                    scores = scores.view(-1, valid_size)
-                    valid_result.append(scores)
-
-                valid_result = torch.cat(valid_result, dim=-1)
-                predicts = valid_result.argmax(1).tolist()
-                hyps = [self.index2ans[predict_index] for predict_index in predicts]
-
-            elif self.cfg.val_inference_type == "beamsearch":
-                raw_hyps = self.inference_step(self.generator, [eval_model], sample, prefix_tokens=sample['prefix_tokens'])
-                hyps = []
-                for i, sample_id in enumerate(sample["id"].tolist()):
-                    prefix_len = sample['prefix_tokens'][i].ne(1).sum().item()
-                    detok_hypo_str = decode_fn(raw_hyps[i][0]["tokens"][prefix_len:], self.tgt_dict, self.bpe, self.generator)
-                    hyps.append(detok_hypo_str.strip())
-
-            else:
-                raise NotImplementedError("Error: Unknown inference type encountered.")
-
-        scores = [ref_dict.get(hyp, 0) for ref_dict, hyp in zip(sample['ref_dict'], hyps)]
-        logging_output["_vqa_score_sum"] = sum(scores)
-        logging_output["_vqa_cnt"] = len(scores)
-
-        return loss, sample_size, logging_output
-
-    def reduce_metrics(self, logging_outputs, criterion):
-        super().reduce_metrics(logging_outputs, criterion)
-
-        def sum_logs(key):
-            import torch
-            result = sum(log.get(key, 0) for log in logging_outputs)
-            if torch.is_tensor(result):
-                result = result.cpu()
-            return result
-
-        def compute_score(meters):
-            score = meters["_vqa_score_sum"].sum / meters["_vqa_cnt"].sum
-            score = score if isinstance(score, float) else score.item()
-            return round(score, 4)
-
-        if sum_logs("_vqa_cnt") > 0:
-            metrics.log_scalar("_vqa_score_sum", sum_logs("_vqa_score_sum"))
-            metrics.log_scalar("_vqa_cnt", sum_logs("_vqa_cnt"))
-            metrics.log_derived("vqa_score", compute_score)
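
The "_vqa_score_sum" / "_vqa_cnt" pair above implements VQA-style soft accuracy: each example carries a ref_dict mapping acceptable answers to weights, the generated (or highest-scoring) answer is looked up in it, and misses score 0. A tiny worked example with hypothetical reference dictionaries and hypothetical model answers:

ref_dicts = [{"yes": 1.0}, {"2": 1.0, "two": 0.6}, {"red": 0.3}]  # hypothetical ref_dict values
hyps = ["yes", "two", "blue"]  # hypothetical model answers

scores = [ref_dict.get(hyp, 0) for ref_dict, hyp in zip(ref_dicts, hyps)]
vqa_score = sum(scores) / len(scores)  # (1.0 + 0.6 + 0.0) / 3 = 0.533...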