| import re
|
| import json
|
| import os
|
|
|
| import torch
|
| import torch.distributed as dist
|
|
|
| import utils
|
|
|
def pre_caption(caption, max_words=50):
    """Normalize a caption string and cap its length in words.

    The caption is lower-cased. Each of the characters ``. ! " ( ) * # : ; ~``
    is replaced by a space. Any run of two or more whitespace characters is
    collapsed to one space. Trailing newlines are removed, then leading and
    trailing spaces are removed.

    Args:
        caption: Raw caption text.
        max_words: Maximum number of space-separated tokens to keep.

    Returns:
        The cleaned caption. If it has more than ``max_words`` space-separated
        tokens, only the first ``max_words`` are kept, joined by single spaces.
    """
    # Punctuation becomes a space (not deleted) so that adjacent words stay apart.
    text = re.sub(r"([.!\"()*#:;~])", ' ', caption.lower())
    # Squash the doubled spaces that the replacement above can introduce.
    text = re.sub(r"\s{2,}", ' ', text)
    text = text.rstrip('\n').strip(' ')

    tokens = text.split(' ')
    return ' '.join(tokens[:max_words]) if len(tokens) > max_words else text
|
|
|
def pre_question(question, max_ques_words=50):
    """Normalize a question string and cap its length in words.

    The question is lower-cased, and each of the characters
    ``. ! " ( ) * # : ; ~`` is deleted (not replaced by a space, unlike
    ``pre_caption``). Runs of two or more whitespace characters are collapsed
    to one space. Trailing newlines are removed, then leading and trailing
    spaces are removed.

    Args:
        question: Raw question text.
        max_ques_words: Maximum number of space-separated tokens to keep.

    Returns:
        The cleaned question. If it has more than ``max_ques_words``
        space-separated tokens, only the first ``max_ques_words`` are kept,
        joined by single spaces.
    """
    question = re.sub(r"([.!\"()*#:;~])", '', question.lower())

    # Collapse whitespace runs and strip both ends (mirroring pre_caption).
    # Without this, split(' ') below yields empty strings for doubled or
    # leading spaces, and those empty "words" count toward max_ques_words,
    # truncating real words too early.
    question = re.sub(r"\s{2,}", ' ', question)
    question = question.rstrip('\n').strip(' ')

    question_words = question.split(' ')
    if len(question_words) > max_ques_words:
        question = ' '.join(question_words[:max_ques_words])

    return question
|
|
|
|
|
def save_result(result, result_dir, filename, remove_duplicate=''):
    """Write this rank's results to JSON and merge all ranks on the main process.

    Every rank writes ``<result_dir>/<filename>_rank<rank>.json``. After all
    ranks have synchronised, the main process reads every per-rank file in rank
    order and concatenates the lists. It optionally de-duplicates the merged
    list, then writes ``<result_dir>/<filename>.json``.

    Args:
        result: JSON-serialisable list of result records for this rank.
        result_dir: Directory in which the JSON files are written.
        filename: Base name (without extension) for the output files.
        remove_duplicate: If non-empty, the record key used for de-duplication.
            Only the first record seen for each key value is kept.

    Returns:
        Path of the merged result file. The same path is returned on every
        rank, but only the main process writes it.
    """
    rank_file = os.path.join(result_dir, '%s_rank%d.json' % (filename, utils.get_rank()))
    final_result_file = os.path.join(result_dir, '%s.json' % filename)

    # Close (and thereby flush) the per-rank file before the barrier; the main
    # process reads it immediately afterwards.
    with open(rank_file, 'w') as f:
        json.dump(result, f)

    # dist.barrier() raises when no process group has been initialised, so only
    # synchronise when actually running distributed.
    # NOTE(review): assumes utils.get_rank()/get_world_size() fall back to 0/1
    # in that case — confirm against utils.
    if dist.is_available() and dist.is_initialized():
        dist.barrier()

    if utils.is_main_process():
        merged = []
        for rank in range(utils.get_world_size()):
            part_file = os.path.join(result_dir, '%s_rank%d.json' % (filename, rank))
            with open(part_file, 'r') as f:
                merged += json.load(f)

        if remove_duplicate:
            # Keep the first record per key. A set gives O(1) membership tests
            # instead of O(n) list scans. This assumes the key values are
            # hashable (JSON scalars such as ids).
            seen = set()
            deduped = []
            for res in merged:
                key = res[remove_duplicate]
                if key not in seen:
                    seen.add(key)
                    deduped.append(res)
            merged = deduped

        with open(final_result_file, 'w') as f:
            json.dump(merged, f)
        print('result file saved to %s' % final_result_file)

    return final_result_file
|
|
|
|
|
|
|
| from pycocotools.coco import COCO
|
| from pycocoevalcap.eval import COCOEvalCap
|
| from torchvision.datasets.utils import download_url
|
|
|
def coco_caption_eval(coco_gt_root, results_file, split):
    """Evaluate a caption results file against COCO Karpathy-split ground truth.

    Downloads the ground-truth annotation JSON for ``split`` into
    ``coco_gt_root`` via ``download_url``. It then loads that file with
    ``COCO``, loads ``results_file`` as the result set, runs ``COCOEvalCap``,
    and prints each metric to three decimals.

    Args:
        coco_gt_root: Directory where the ground-truth JSON is stored.
        results_file: Path to the caption results JSON to score.
        split: Either ``'val'`` or ``'test'``; any other value raises KeyError.

    Returns:
        The ``COCOEvalCap`` object after ``evaluate()``. Its ``eval`` dict maps
        metric names to scores.
    """
    # split -> (download URL, local file name of the ground-truth JSON)
    gt_sources = {
        'val': ('https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json',
                'coco_karpathy_val_gt.json'),
        'test': ('https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json',
                 'coco_karpathy_test_gt.json'),
    }
    gt_url, gt_name = gt_sources[split]

    download_url(gt_url, coco_gt_root)

    coco = COCO(os.path.join(coco_gt_root, gt_name))
    coco_eval = COCOEvalCap(coco, coco.loadRes(results_file))
    coco_eval.evaluate()

    for metric, score in coco_eval.eval.items():
        print(f'{metric}: {score:.3f}')

    return coco_eval