# Modified from OFA code.
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import string
import math
import json
from itertools import chain
import os

import torch
import torch.distributed as dist

from data import data_utils
from functools import partial


def get_symbols_to_strip_from_output(generator):
    """Return the set of symbols to drop when turning generator output into text.

    Uses ``generator.symbols_to_strip_from_output`` when the generator defines
    it, otherwise falls back to ``{generator.bos, generator.eos}``.
    """
    if hasattr(generator, "symbols_to_strip_from_output"):
        return generator.symbols_to_strip_from_output
    return {generator.bos, generator.eos}


def decode_fn(x, tgt_dict, bpe, generator, tokenizer=None):
    """Convert a tensor of token ids into a detokenized string.

    Args:
        x: tensor of token ids (cast to int and moved to CPU before lookup).
        tgt_dict: dictionary exposing ``string(tokens, extra_symbols_to_ignore=...)``.
        bpe: optional object with ``decode(str)``; skipped when ``None``.
        generator: used only to decide which symbols are stripped.
        tokenizer: optional object with ``decode(str)``; skipped when ``None``.

    Returns:
        The decoded string.
    """
    x = tgt_dict.string(
        x.int().cpu(),
        extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
    )
    if bpe is not None:
        x = bpe.decode(x)
    if tokenizer is not None:
        x = tokenizer.decode(x)
    return x


def eval_caption(task, generator, models, sample, **kwargs):
    """Generate one caption per sample and strip punctuation from it.

    Returns:
        ``(results, None)`` where ``results`` is a list of
        ``{"image_id": str, "caption": str}`` dicts built from the first
        hypothesis of every sample.
    """
    # Translation table that deletes every ASCII punctuation character.
    transtab = str.maketrans({key: None for key in string.punctuation})
    hypos = task.inference_step(generator, models, sample)
    results = []
    for i, sample_id in enumerate(sample["id"].tolist()):
        detok_hypo_str = decode_fn(hypos[i][0]["tokens"], task.tgt_dict, task.bpe, generator)
        results.append({
            "image_id": str(sample_id),
            "caption": detok_hypo_str.translate(transtab).strip(),
        })
    return results, None


def _predict_from_valid_answers(task, models, sample):
    """Pick, for every sample, the best-scoring answer among the task's candidates.

    Each candidate answer from ``task.valid_answers_list`` is appended to each
    sample's decoder prompt, the decoder is run once per candidate chunk, and
    the summed log-probability of the target tokens is used as the candidate
    score. The arg-max over all chunks is mapped through ``task.index2ans``.

    Returns:
        A list with one answer (an entry of ``task.index2ans``) per sample.
    """
    net_input = sample["net_input"]
    encoder_out = models[0].encoder(
        net_input["src_tokens"],
        src_lengths=net_input["src_lengths"],
        patch_images=net_input["patch_images"],
        patch_masks=net_input["patch_masks"],
    )
    device = net_input["src_tokens"].device
    eos_item = torch.tensor([task.src_dict.eos()])
    pad = task.src_dict.pad()
    decoder_prompts = sample["decoder_prompts"]

    valid_result = []
    for valid_answers, valid_constraint_masks in zip(task.valid_answers_list, task.valid_constraint_masks_list):
        valid_size = len(valid_answers)
        # Targets are the decoder inputs shifted left by one, terminated by EOS.
        valid_tgt_items = [
            torch.cat([torch.tensor(decoder_prompt[1:]), valid_answer, eos_item])
            for decoder_prompt in decoder_prompts
            for valid_answer in valid_answers
        ]
        valid_prev_items = [
            torch.cat([torch.tensor(decoder_prompt), valid_answer])
            for decoder_prompt in decoder_prompts
            for valid_answer in valid_answers
        ]
        # Prompt positions get an all-False mask row; they are zeroed out of
        # the score below via the ``all(2)`` check.
        valid_constraint_mask_items = [
            torch.cat(
                [
                    torch.zeros(len(decoder_prompt) - 1, valid_constraint_mask.size(1)).bool(),
                    valid_constraint_mask,
                ],
                dim=0,
            )
            for decoder_prompt in decoder_prompts
            for valid_constraint_mask in valid_constraint_masks
        ]
        valid_tgt = data_utils.collate_tokens(valid_tgt_items, pad_idx=pad).to(device)
        valid_prev_output = data_utils.collate_tokens(valid_prev_items, pad_idx=pad).to(device)
        constraint_masks = data_utils.collate_tokens(valid_constraint_mask_items, pad_idx=pad).to(device)

        # Repeat the encoder state once per candidate so it lines up with the
        # (sample x candidate) decoder batch.
        # NOTE(review): presumably "encoder_out" is laid out (src_len, batch, dim),
        # hence dim=1 there and dim=0 for the other entries -- confirm against the model.
        new_encoder_out = {
            "encoder_out": [
                encoder_out["encoder_out"][0].repeat_interleave(valid_size, dim=1)
            ],
            "encoder_padding_mask": [
                encoder_out["encoder_padding_mask"][0].repeat_interleave(valid_size, dim=0)
            ],
            "position_embeddings": [
                encoder_out["position_embeddings"][0].repeat_interleave(valid_size, dim=0)
            ],
        }

        decoder_out = models[0].decoder(valid_prev_output, encoder_out=new_encoder_out)
        # Disallowed vocabulary entries can never receive probability mass.
        decoder_out[0].masked_fill_(~constraint_masks, -math.inf)

        lprobs = models[0].get_normalized_probs(decoder_out, log_probs=True)
        scores = lprobs.gather(dim=-1, index=valid_tgt.unsqueeze(-1)).squeeze(-1)
        scores = scores.masked_fill(valid_tgt.eq(task.tgt_dict.pad()), 0)
        scores = scores.masked_fill((~constraint_masks).all(2), 0)
        scores = scores.sum(1)
        valid_result.append(scores.view(-1, valid_size))

    valid_result = torch.cat(valid_result, dim=-1)
    predicts = valid_result.argmax(1).tolist()
    return [task.index2ans[predict_index] for predict_index in predicts]


def eval_vqa_gen(task, generator, models, sample, **kwargs):
    """Answer VQA samples either by beam search or by ranking candidate answers.

    Args:
        kwargs: must contain ``beam_search_vqa_eval``. When truthy, answers are
            generated with ``task.inference_step`` using ``sample['prefix_tokens']``;
            otherwise every candidate answer of the task is scored and the
            best one is returned.

    Returns:
        ``(results, scores)`` where ``results`` is a list of
        ``{"question_id": int, "answer": str}`` and ``scores`` holds, per sample,
        ``ref_dict.get(answer, 0)``.
    """
    if kwargs['beam_search_vqa_eval']:
        hypos = task.inference_step(generator, models, sample, prefix_tokens=sample['prefix_tokens'])
        results = []
        for i, sample_id in enumerate(sample["id"].tolist()):
            # Count non-padding prefix tokens so they can be cut off the hypothesis.
            # NOTE(review): 1 is presumably the dictionary's pad index -- confirm.
            prefix_len = sample['prefix_tokens'][i].ne(1).sum().item()
            detok_hypo_str = decode_fn(
                hypos[i][0]["tokens"][prefix_len:], task.tgt_dict, task.bpe, generator
            )
            results.append({"question_id": int(sample_id), "answer": detok_hypo_str.strip()})
        scores = [
            ref_dict.get(result['answer'], 0)
            for ref_dict, result in zip(sample['ref_dict'], results)
        ]
        return results, scores

    hyps = _predict_from_valid_answers(task, models, sample)
    results = [
        {"question_id": int(sample_id), "answer": hyp}
        for sample_id, hyp in zip(sample["id"].tolist(), hyps)
    ]
    scores = [ref_dict.get(hyp, 0) for ref_dict, hyp in zip(sample['ref_dict'], hyps)]
    return results, scores


def _calculate_ap_score(hyps, refs, thresh=0.5, min_area_size=None, max_area_size=None):
    """Compute per-box IoU between predictions and references, optionally thresholded.

    Boxes are read as ``[x0, y0, x1, y1]`` columns.

    Args:
        hyps: predicted boxes, one row per sample.
        refs: reference boxes, one row per sample.
        thresh: when not ``None``, return 1.0 where IoU >= thresh and the boxes
            really overlap, else 0.0. When ``None``, return the raw IoU.
        min_area_size: keep only references whose area is larger than this.
        max_area_size: keep only references whose area is smaller than this.
            When both bounds are given, references with
            ``max_area_size < area < min_area_size`` are kept.

    Returns:
        A float tensor with one value per row.
    """
    # Intersection box: element-wise max of the top-left corners and min of
    # the bottom-right corners.
    interacts = torch.cat(
        [
            torch.where(hyps[:, :2] < refs[:, :2], refs[:, :2], hyps[:, :2]),
            torch.where(hyps[:, 2:] < refs[:, 2:], hyps[:, 2:], refs[:, 2:]),
        ],
        dim=1,
    )
    area_predictions = (hyps[:, 2] - hyps[:, 0]) * (hyps[:, 3] - hyps[:, 1])
    area_targets = (refs[:, 2] - refs[:, 0]) * (refs[:, 3] - refs[:, 1])
    interacts_w = interacts[:, 2] - interacts[:, 0]
    interacts_h = interacts[:, 3] - interacts[:, 1]
    # Clamp so that disjoint boxes (negative extent on both axes) do not yield
    # a spurious positive intersection area.
    area_interacts = interacts_w.clamp(min=0) * interacts_h.clamp(min=0)
    ious = area_interacts / (area_predictions + area_targets - area_interacts + 1e-6)

    if max_area_size is not None and min_area_size is not None:
        ious = ious * (area_targets > max_area_size).float() * (area_targets < min_area_size).float()
    elif min_area_size is not None:
        ious = ious * (area_targets > min_area_size).float()
    elif max_area_size is not None:
        ious = ious * (area_targets < max_area_size).float()

    if thresh is None:
        return ious
    return ((ious >= thresh) & (interacts_w > 0) & (interacts_h > 0)).float()


def _parse_acc_thresholds(threshs):
    """Normalize an accuracy-threshold setting into a list of floats.

    Accepts ``None`` (-> empty list), a comma-separated string, a single
    number, or any iterable of values convertible to ``float``.
    """
    if threshs is None:
        return []
    if isinstance(threshs, str):
        parts = threshs.split(',')
    elif hasattr(threshs, '__iter__'):
        parts = list(threshs)
    else:
        parts = [threshs]
    return [float(t) for t in parts]


def eval_refcoco(task, generator, models, sample, **kwargs):
    """Decode predicted boxes and score them against ``sample['region_coords']``.

    Args:
        kwargs: must contain ``evaluate_cfg`` exposing ``acc_thresh``,
            ``min_area_size`` and ``max_area_size``.

    Returns:
        ``(results, scores)``. When at least one threshold is configured,
        ``results`` is ``[names, boxes]`` and ``scores`` is a list of tensors
        aligned with ``names`` (``"<t>acc"``, then optionally ``"<t>large_acc"``,
        ``"<t>small_acc"``, ``"<t>medium_acc"`` per threshold). When no threshold
        is configured, ``results`` is the plain list of boxes and ``scores`` is
        a single tensor computed with the default 0.5 threshold.
    """
    gen_out = task.inference_step(generator, models, sample)
    # Map location tokens back to bin indices; the last token is dropped.
    hyps_ = torch.stack(
        [hypo[0]["tokens"][:-1] - len(task.src_dict) + task.cfg.num_bins for hypo in gen_out],
        dim=0,
    )
    # Bin index -> coordinate in the resized image -> coordinate in the original image.
    hyps = hyps_ / (task.cfg.num_bins - 1) * task.cfg.max_image_size
    hyps[:, ::2] /= sample['w_resize_ratios'].unsqueeze(1)
    hyps[:, 1::2] /= sample['h_resize_ratios'].unsqueeze(1)

    results = [
        {
            "uniq_id": sample_id,
            "box": [hyps[i][0].item(), hyps[i][1].item(), hyps[i][2].item(), hyps[i][3].item()],
        }
        for i, sample_id in enumerate(sample["id"].tolist())
    ]

    evaluate_cfg = kwargs['evaluate_cfg']  # task.cfg
    refs = sample['region_coords'].float()
    min_area_size = evaluate_cfg.min_area_size
    max_area_size = evaluate_cfg.max_area_size

    scores_list = []
    names = []
    for thresh in _parse_acc_thresholds(evaluate_cfg.acc_thresh):
        names.append(str(thresh) + 'acc')
        scores_list.append(_calculate_ap_score(hyps, refs, thresh=thresh))
        if min_area_size is not None:
            scores_list.append(
                _calculate_ap_score(hyps, refs, thresh=thresh, min_area_size=min_area_size)
            )
            names.append(str(thresh) + 'large_acc')
        if max_area_size is not None:
            scores_list.append(
                _calculate_ap_score(hyps, refs, thresh=thresh, max_area_size=max_area_size)
            )
            names.append(str(thresh) + 'small_acc')
        if max_area_size is not None and min_area_size is not None:
            scores_list.append(
                _calculate_ap_score(
                    hyps, refs, thresh=thresh,
                    max_area_size=max_area_size, min_area_size=min_area_size,
                )
            )
            names.append(str(thresh) + 'medium_acc')

    if scores_list:
        return [names, results], scores_list
    # No threshold configured: fall back to the helper's default threshold so
    # a score is always returned.
    return results, _calculate_ap_score(hyps, refs)


def eval_snli_ve(task, generator, models, sample, **kwargs):
    """Classify samples by ranking the task's candidate answers.

    Returns:
        ``(results, scores)`` where ``results`` is a list of
        ``{"uniq_id": id, "answer": str}`` and ``scores`` holds, per sample,
        ``ref_dict.get(answer, 0)``.
    """
    hyps = _predict_from_valid_answers(task, models, sample)
    results = [
        {"uniq_id": sample_id, "answer": hyp}
        for sample_id, hyp in zip(sample["id"].tolist(), hyps)
    ]
    scores = [ref_dict.get(hyp, 0) for ref_dict, hyp in zip(sample['ref_dict'], hyps)]
    return results, scores


def eval_image_gen(task, generator, models, sample, **kwargs):
    """Generate images for the first sample and rank them by text similarity.

    Returns:
        ``(results, scores)`` where ``results`` has one
        ``{"sample_id", "score", "image"}`` dict per ranked image and ``scores``
        is a one-element list holding the best similarity score.
    """
    hypos, _ = task.inference_image(generator, sample, models)
    src_tokens = sample['net_input']['src_tokens']
    tokens = src_tokens[0].view(-1).tolist()
    # NOTE(review): ids below 4 are presumably the dictionary's special symbols
    # and the first 38 characters presumably the instruction prefix -- confirm.
    caption = task.bpe.decode(
        task.tgt_dict.string([token for token in tokens if token >= 4])
    )[38:].replace('/', '')

    text_similarity_score, indices = task.compute_text_similarity(hypos, caption, src_tokens.device)
    results = [
        {"sample_id": str(sample["id"][0]), "score": text_similarity_score[i], "image": hypos[indice]}
        for i, indice in enumerate(indices)
    ]
    scores = [max(text_similarity_score).item()]
    sorted_hyps = [hypos[indice] for indice in indices]

    # dump results
    if task.cfg.gen_images_path:
        task.dump_images(
            sorted_hyps, text=caption, path=os.path.join(task.cfg.gen_images_path, 'all_results')
        )
        task.dump_images(
            sorted_hyps, text=caption, path=os.path.join(task.cfg.gen_images_path, 'top1'), topk=1
        )
    return results, scores


def eval_image_classify(task, generator, models, sample, **kwargs):
    """Classify images by scoring the task's pre-built candidate targets.

    Unlike the prompt-based ranking used for VQA / SNLI-VE, the candidate
    tensors come ready-made from ``task.valid_tgt_list``,
    ``task.valid_prev_output_list`` and ``task.valid_constraint_masks_list``
    and are tiled across the batch.

    Returns:
        ``(results, scores)`` where ``results`` is a list of
        ``{"uniq_id": id, "answer": str}`` and ``scores`` holds, per sample,
        ``ref_dict.get(answer, 0)``.
    """
    net_input = sample["net_input"]
    batch_size = net_input["src_tokens"].size(0)
    encoder_out = models[0].encoder(
        net_input["src_tokens"],
        src_lengths=net_input["src_lengths"],
        patch_images=net_input["patch_images"],
        patch_masks=net_input["patch_masks"],
    )
    device = net_input["src_tokens"].device

    valid_result = []
    for valid_tgt, valid_prev_output, valid_constraint_masks in zip(
        task.valid_tgt_list, task.valid_prev_output_list, task.valid_constraint_masks_list
    ):
        valid_tgt_size = valid_tgt.size(0)
        valid_tgt = valid_tgt.repeat(batch_size, 1).to(device)
        valid_prev_output = valid_prev_output.repeat(batch_size, 1).to(device)
        valid_constraint_masks = valid_constraint_masks.repeat(batch_size, 1, 1).to(device)

        # Repeat the encoder state once per candidate to match the tiled targets.
        new_encoder_out = {
            "encoder_out": [
                encoder_out["encoder_out"][0].repeat_interleave(valid_tgt_size, dim=1)
            ],
            "encoder_padding_mask": [
                encoder_out["encoder_padding_mask"][0].repeat_interleave(valid_tgt_size, dim=0)
            ],
            "position_embeddings": [
                encoder_out["position_embeddings"][0].repeat_interleave(valid_tgt_size, dim=0)
            ],
        }

        decoder_out = models[0].decoder(valid_prev_output, encoder_out=new_encoder_out)
        decoder_out[0].masked_fill_(~valid_constraint_masks, -math.inf)

        lprobs = models[0].get_normalized_probs(decoder_out, log_probs=True)
        scores = lprobs.gather(dim=-1, index=valid_tgt.unsqueeze(-1)).squeeze(-1)
        scores = scores.masked_fill(valid_tgt.eq(task.tgt_dict.pad()), 0)
        scores = scores.sum(1)
        valid_result.append(scores.view(-1, valid_tgt_size))

    valid_result = torch.cat(valid_result, dim=-1)
    predicts = valid_result.argmax(1).tolist()
    hyps = [task.index2ans[predict_index] for predict_index in predicts]
    scores = [ref_dict.get(hyp, 0) for ref_dict, hyp in zip(sample['ref_dict'], hyps)]
    results = [
        {"uniq_id": sample_id, "answer": hyp}
        for sample_id, hyp in zip(sample["id"].tolist(), hyps)
    ]
    return results, scores


def eval_step(task, generator, models, sample, **kwargs):
    """Dispatch one evaluation step to the routine matching ``task.cfg._name``.

    ``caption`` and ``vqa_gen`` are matched as substrings of the task name;
    the remaining tasks are matched exactly.

    Raises:
        NotImplementedError: when no evaluation routine exists for the task.
    """
    task_name = task.cfg._name
    if 'caption' in task_name:
        return eval_caption(task, generator, models, sample, **kwargs)
    elif 'vqa_gen' in task_name:
        return eval_vqa_gen(task, generator, models, sample, **kwargs)
    elif task_name == 'refcoco':
        return eval_refcoco(task, generator, models, sample, **kwargs)
    elif task_name == 'snli_ve':
        return eval_snli_ve(task, generator, models, sample, **kwargs)
    elif task_name == 'image_gen':
        return eval_image_gen(task, generator, models, sample, **kwargs)
    elif task_name == 'image_classify':
        return eval_image_classify(task, generator, models, sample, **kwargs)
    else:
        raise NotImplementedError


def merge_results(task, cfg, logger, score_cnt, score_sum, results):
    """Aggregate scores across workers, log them and dump predictions to JSON.

    With more than one worker, ``score_sum`` and ``score_cnt`` are all-reduced
    in place. For every task except ``image_gen`` the per-worker ``results``
    are also gathered and written (by rank 0, or by the only worker) to
    ``<results_path>/<gen_subset>_predict.json``.
    """
    is_image_gen = task.cfg._name == 'image_gen'

    gather_results = None
    if cfg.distributed_training.distributed_world_size > 1:
        if not is_image_gen:
            gather_results = [None for _ in range(dist.get_world_size())]
            dist.all_gather_object(gather_results, results)
        dist.all_reduce(score_sum.data)
        dist.all_reduce(score_cnt.data)

    if score_cnt.item() > 0:
        logger.info("score_sum: {}, score_cnt: {}, score: {}".format(
            score_sum, score_cnt, round(score_sum.item() / score_cnt.item(), 4)
        ))

    if is_image_gen:
        # Image generation results are not written to the JSON file.
        return

    if cfg.distributed_training.distributed_world_size == 1 or dist.get_rank() == 0:
        os.makedirs(cfg.common_eval.results_path, exist_ok=True)
        output_path = os.path.join(
            cfg.common_eval.results_path, "{}_predict.json".format(cfg.dataset.gen_subset)
        )
        # Flatten the per-worker lists into one list of predictions.
        gather_results = list(chain(*gather_results)) if gather_results is not None else results
        with open(output_path, 'w') as fw:
            json.dump(gather_results, fw)