Spaces:
Running
Running
Upload data/utils.py
Browse files- data/utils.py +112 -0
data/utils.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.distributed as dist
|
7 |
+
|
8 |
+
import utils
|
9 |
+
|
10 |
+
def pre_caption(caption, max_words=50):
    """Normalize a caption string and cap its length in words.

    The caption is lower-cased, each of the characters ``. ! " ( ) * # : ; ~``
    is replaced by a space, runs of two or more whitespace characters are
    collapsed to a single space, trailing newlines are removed and surrounding
    spaces are stripped.

    Args:
        caption: Raw caption text.
        max_words: Maximum number of space-separated tokens to keep.

    Returns:
        The cleaned caption, cut down to its first ``max_words`` tokens when it
        contains more than that.
    """
    # Punctuation becomes a space (not deleted) so adjacent words stay apart.
    text = re.sub(r"([.!\"()*#:;~])", ' ', caption.lower())
    # Squash the whitespace runs that the substitution above may have created.
    text = re.sub(r"\s{2,}", ' ', text)
    text = text.rstrip('\n').strip(' ')

    # Truncate on single-space boundaries only when the caption is too long.
    tokens = text.split(' ')
    return ' '.join(tokens[:max_words]) if len(tokens) > max_words else text
|
30 |
+
|
31 |
+
def pre_question(question, max_ques_words=50):
    """Normalize a question string and cap its length in words.

    The question is lower-cased, each of the characters
    ``. ! " ( ) * # : ; ~`` is deleted, and trailing spaces are stripped.
    Unlike ``pre_caption`` style cleaning, interior whitespace is left as is.

    Args:
        question: Raw question text.
        max_ques_words: Maximum number of space-separated tokens to keep.

    Returns:
        The cleaned question, cut down to its first ``max_ques_words`` tokens
        when it contains more than that.
    """
    # Punctuation is removed outright (empty replacement), not spaced out.
    text = re.sub(r"([.!\"()*#:;~])", '', question.lower())
    text = text.rstrip(' ')

    # Truncate on single-space boundaries only when the question is too long.
    tokens = text.split(' ')
    return ' '.join(tokens[:max_ques_words]) if len(tokens) > max_ques_words else text
|
45 |
+
|
46 |
+
|
47 |
+
def save_result(result, result_dir, filename, remove_duplicate=''):
    """Dump this process's results to JSON and merge all ranks on the main process.

    Every process writes ``result`` to
    ``<result_dir>/<filename>_rank<rank>.json``. After a barrier, the main
    process loads each rank's file in rank order, concatenates them,
    optionally drops duplicates, and writes ``<result_dir>/<filename>.json``.

    Args:
        result: JSON-serializable results of this process (presumably a list
            of dict records -- the merge concatenates them and the
            de-duplication indexes each one by key).
        result_dir: Directory for the per-rank and merged files. It is not
            created here.
        filename: Base name, without extension, for the output files.
        remove_duplicate: When non-empty, the record key used for
            de-duplication; only the first record seen for each value is
            kept. The values must be hashable.

    Returns:
        Path of the merged result file. The path is returned on every
        process, but only the main process writes the file, and there is no
        synchronization after that write.
    """
    rank_file = os.path.join(result_dir, '%s_rank%d.json' % (filename, utils.get_rank()))
    final_result_file = os.path.join(result_dir, '%s.json' % filename)

    # Close the handle explicitly so the data is flushed to disk before the
    # barrier lets the main process start reading.
    with open(rank_file, 'w') as f:
        json.dump(result, f)

    # Wait until every rank has written its file. Skipped when no process
    # group is initialised, where barrier() would raise.
    if dist.is_available() and dist.is_initialized():
        dist.barrier()

    if utils.is_main_process():
        # Combine results from all processes, in rank order.
        merged = []
        for rank in range(utils.get_world_size()):
            rank_file = os.path.join(result_dir, '%s_rank%d.json' % (filename, rank))
            with open(rank_file, 'r') as f:
                merged += json.load(f)

        if remove_duplicate:
            # Set membership keeps this linear in the number of records while
            # still preserving first-seen order.
            seen = set()
            unique = []
            for res in merged:
                key = res[remove_duplicate]
                if key not in seen:
                    seen.add(key)
                    unique.append(res)
            merged = unique

        with open(final_result_file, 'w') as f:
            json.dump(merged, f)
        print('result file saved to %s' % final_result_file)

    return final_result_file
|
77 |
+
|
78 |
+
|
79 |
+
|
80 |
+
from pycocotools.coco import COCO
|
81 |
+
from pycocoevalcap.eval import COCOEvalCap
|
82 |
+
from torchvision.datasets.utils import download_url
|
83 |
+
|
84 |
+
def coco_caption_eval(coco_gt_root, results_file, split):
    """Evaluate a caption results file against the COCO Karpathy ground truth.

    Fetches the ground-truth annotation file for ``split`` into
    ``coco_gt_root`` via ``download_url``, runs ``COCOEvalCap`` on the
    results, and prints every metric with three decimals.

    Args:
        coco_gt_root: Directory handed to ``download_url`` and from which the
            annotation file is then read.
        results_file: Caption results, passed to ``COCO.loadRes``.
        split: ``'val'`` or ``'test'``; any other value raises ``KeyError``.

    Returns:
        The ``COCOEvalCap`` object after ``evaluate()`` has run.
    """
    gt_urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json',
               'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json'}
    gt_filenames = {'val':'coco_karpathy_val_gt.json','test':'coco_karpathy_test_gt.json'}

    download_url(gt_urls[split], coco_gt_root)
    gt_path = os.path.join(coco_gt_root, gt_filenames[split])

    # Ground-truth index, and the predictions loaded against it.
    ground_truth = COCO(gt_path)
    predictions = ground_truth.loadRes(results_file)

    evaluator = COCOEvalCap(ground_truth, predictions)

    # To score only the images present in the results (a subset), set
    #     evaluator.params['image_id'] = predictions.getImgIds()
    # here. It is left disabled so the full split is evaluated.

    # Original author's note: SPICE takes a few minutes the first time and is
    # faster afterwards thanks to caching.
    evaluator.evaluate()

    # Report each metric.
    for metric, score in evaluator.eval.items():
        print(f'{metric}: {score:.3f}')

    return evaluator
|