|
import re |
|
import json |
|
import os |
|
|
|
import torch |
|
import torch.distributed as dist |
|
|
|
import utils |
|
|
|
def pre_caption(caption, max_words=50):
    """Normalise a caption string and cap its length in words.

    Processing steps, in order:
      1. lower-case the text and replace each of the characters
         ``. ! " ( ) * # : ; ~`` with a single space;
      2. collapse every run of two or more whitespace characters into one space;
      3. drop trailing newline characters, then leading/trailing spaces;
      4. keep at most ``max_words`` space-separated tokens.

    Args:
        caption: Raw caption text.
        max_words: Maximum number of space-separated tokens to keep.

    Returns:
        The cleaned (and possibly truncated) caption.
    """
    text = re.sub(r"([.!\"()*#:;~])", ' ', caption.lower())
    text = re.sub(r"\s{2,}", ' ', text)
    text = text.rstrip('\n').strip(' ')

    words = text.split(' ')
    # Only rebuild the string when truncation is needed; shorter captions are
    # returned exactly as cleaned above.
    return ' '.join(words[:max_words]) if len(words) > max_words else text
|
|
|
def pre_question(question, max_ques_words=50):
    """Normalise a question string and cap its length in words.

    The question is lower-cased, the characters ``. ! " ( ) * # : ; ~`` are
    deleted (not replaced by spaces), and trailing spaces are removed. Leading
    spaces and inner runs of spaces are left untouched. At most
    ``max_ques_words`` space-separated tokens are kept.

    Args:
        question: Raw question text.
        max_ques_words: Maximum number of space-separated tokens to keep.

    Returns:
        The cleaned (and possibly truncated) question.
    """
    cleaned = re.sub(r"([.!\"()*#:;~])", '', question.lower()).rstrip(' ')

    tokens = cleaned.split(' ')
    if len(tokens) <= max_ques_words:
        return cleaned
    return ' '.join(tokens[:max_ques_words])
|
|
|
|
|
def _remove_duplicates(records, key):
    """Return ``records`` keeping only the first record seen for each ``record[key]`` value."""
    unique = []
    seen = set()
    # JSON arrays/objects load as lists/dicts, which cannot go in a set; compare
    # those by equality instead so behaviour matches a plain membership test.
    seen_unhashable = []
    for record in records:
        value = record[key]
        try:
            if value in seen:
                continue
            seen.add(value)
        except TypeError:
            if value in seen_unhashable:
                continue
            seen_unhashable.append(value)
        unique.append(record)
    return unique


def save_result(result, result_dir, filename, remove_duplicate=''):
    """Write this rank's results and merge all ranks' results on the main process.

    Every rank dumps ``result`` to ``<result_dir>/<filename>_rank<R>.json``, where
    R is ``utils.get_rank()``. When a torch.distributed process group is
    initialised, all ranks then wait at a barrier. Next, the main process
    concatenates the per-rank lists in rank order for ranks
    ``0 .. utils.get_world_size() - 1``. It optionally drops records whose
    ``remove_duplicate`` field repeats an earlier one, keeping the first
    occurrence. Finally it writes the merged list to
    ``<result_dir>/<filename>.json``.

    Args:
        result: JSON-serialisable results produced on this rank (concatenated
            with ``+=`` on the main process, so presumably a list).
        result_dir: Directory for the per-rank and merged JSON files.
        filename: Base file name, without extension.
        remove_duplicate: If non-empty, the record key used to de-duplicate
            the merged results.

    Returns:
        Path of the merged result file. The path is returned on every rank, but
        only the main process writes the file.
    """
    rank_file = os.path.join(result_dir, '%s_rank%d.json' % (filename, utils.get_rank()))
    final_result_file = os.path.join(result_dir, '%s.json' % filename)

    # Close (and thus flush) the file before the barrier so the main process
    # never reads a partially written rank file.
    with open(rank_file, 'w') as f:
        json.dump(result, f)

    # dist.barrier() raises when no default process group exists (e.g. a plain
    # single-process evaluation run); then there is nobody to wait for.
    if dist.is_available() and dist.is_initialized():
        dist.barrier()

    if utils.is_main_process():
        merged = []
        for rank in range(utils.get_world_size()):
            part_file = os.path.join(result_dir, '%s_rank%d.json' % (filename, rank))
            with open(part_file, 'r') as f:
                merged += json.load(f)

        if remove_duplicate:
            merged = _remove_duplicates(merged, remove_duplicate)

        with open(final_result_file, 'w') as f:
            json.dump(merged, f)
        print('result file saved to %s' % final_result_file)

    return final_result_file
|
|
|
|
|
|
|
from pycocotools.coco import COCO |
|
from pycocoevalcap.eval import COCOEvalCap |
|
from torchvision.datasets.utils import download_url |
|
|
|
def coco_caption_eval(coco_gt_root, results_file, split):
    """Score a caption results file against the COCO Karpathy ground truth.

    The ground-truth annotation JSON for ``split`` is fetched into
    ``coco_gt_root`` with torchvision's ``download_url``. The predictions in
    ``results_file`` are loaded through ``COCO.loadRes`` and scored with
    ``COCOEvalCap``. Every metric is printed with three decimal places.

    Args:
        coco_gt_root: Directory that holds (or receives) the ground-truth file.
        results_file: Path of the caption results in COCO results format.
        split: Either ``'val'`` or ``'test'``.

    Returns:
        The evaluated ``COCOEvalCap`` object; scores live in its ``eval`` dict.

    Raises:
        KeyError: If ``split`` is not ``'val'`` or ``'test'``.
    """
    # split -> (download URL, local file name). download_url is called without
    # an explicit filename, so the local name must equal the URL's basename.
    gt_sources = {
        'val': (
            'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json',
            'coco_karpathy_val_gt.json',
        ),
        'test': (
            'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json',
            'coco_karpathy_test_gt.json',
        ),
    }
    url, gt_name = gt_sources[split]

    download_url(url, coco_gt_root)
    annotation_file = os.path.join(coco_gt_root, gt_name)

    ground_truth = COCO(annotation_file)
    predictions = ground_truth.loadRes(results_file)

    evaluator = COCOEvalCap(ground_truth, predictions)
    evaluator.evaluate()

    for metric_name, metric_value in evaluator.eval.items():
        print(f'{metric_name}: {metric_value:.3f}')

    return evaluator