import json
import os
import os.path as osp

from mmengine.dist import master_only
from mmengine.logging import print_log

from xtuner.registry import BUILDER

from .base_eval_dataset import BaseEvalDataset
from .utils import custom_data_process
from .vqav2_utils import EvalAIAnswerProcessor


class VQAv2Dataset(BaseEvalDataset):
    """VQAv2 evaluation dataset.

    Writes model predictions to a JSON file formatted for submission to the
    official VQAv2 evaluation server; accuracy is not computed locally.
    """

    METAINFO: dict = dict(name='vqa_v2')

    def __init__(
        self,
        data_file,
        test_file,
        image_folder,
        image_processor,
        pad_image_to_square=True,
        metainfo=None,
    ):
        super().__init__(metainfo)
        self.data_file = data_file
        self.test_file = test_file
        self.image_folder = image_folder
        # Full per-question outputs, saved for easy inspection
        self.answer_file = 'answer_vqav2_results.json'
        # Processed predictions in the server submission format
        self.prediction_file = 'pred_vqav2_results.json'
        self.answer_processor = EvalAIAnswerProcessor()

        self.pad_image_to_square = pad_image_to_square

        self.image_processor = BUILDER.build(image_processor)

        self.data = self.load_data_list()

    def load_data_list(self):
        # Each line of the data file is a standalone JSON object
        # describing one question.
        with open(os.path.expanduser(self.data_file), 'r') as f:
            question_data = [json.loads(line) for line in f]
        data_list = []
        for idx, sample in enumerate(question_data):
            data = {
                'img_id': idx,
                'index': sample['question_id'],
                'image_path': sample['image'],
                'question': sample['text'],
                'category': sample['category'],
            }
            data_list.append(data)

        return data_list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        data_dict = custom_data_process(self, data)
        return data_dict

    @master_only
    def evaluate(self, results, work_dir):
        # Dump the raw predictions, one JSON object per line, for easy viewing.
        answers_file = osp.join(work_dir, self.answer_file)
        with open(answers_file, 'w') as ans_file:
            for pred_dict in results:
                idx = pred_dict['img_id']
                gt_data = self.data[idx]
                ans_file.write(
                    json.dumps({
                        'question_id': gt_data['index'],
                        'prompt': gt_data['question'],
                        'text': pred_dict['prediction'],
                        'metadata': {},
                    }) + '\n')

        # Read the predictions back, skipping any malformed lines.
        results = []
        error_line = 0
        with open(answers_file) as f:
            for line in f:
                try:
                    results.append(json.loads(line))
                except json.JSONDecodeError:
                    error_line += 1

        results = {x['question_id']: x['text'] for x in results}
        with open(self.test_file) as f:
            test_split = [json.loads(line) for line in f]

        # Build one answer per test question; questions without a prediction
        # get an empty answer so the submission covers the full split.
        all_answers = []
        for x in test_split:
            question_id = x['question_id']
            if question_id in results:
                answer = self.answer_processor(results[question_id])
            else:
                answer = ''
            all_answers.append({'question_id': question_id, 'answer': answer})

        prediction_file = osp.join(work_dir, self.prediction_file)
        with open(prediction_file, 'w') as f:
            json.dump(all_answers, f)

        print_log('============================================', 'current')
        print_log(
            f'total results: {len(results)}, total split: {len(test_split)}, '
            f'error_line: {error_line}', 'current')
        print_log(
            f'Please submit the generated {prediction_file} file to the '
            'official server for evaluation.', 'current')
        print_log('============================================', 'current')
        # Accuracy is computed by the official server; return a placeholder.
        return {'acc': 0}
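

# A minimal usage sketch (not part of the original module): constructing the
# dataset with an xtuner-style image_processor config and inspecting one
# sample. The annotation/image paths below are placeholders, and the per-line
# data_file schema (question_id / image / text / category) is inferred from
# load_data_list above.
if __name__ == '__main__':
    from transformers import CLIPImageProcessor

    # BUILDER accepts a config dict whose `type` is a callable; the callable
    # is invoked with the remaining keys as keyword arguments.
    image_processor_cfg = dict(
        type=CLIPImageProcessor.from_pretrained,
        pretrained_model_name_or_path='openai/clip-vit-large-patch14-336')

    dataset = VQAv2Dataset(
        data_file='data/vqav2/questions.jsonl',  # placeholder path
        test_file='data/vqav2/test_split.jsonl',  # placeholder path
        image_folder='data/vqav2/test2015',  # placeholder path
        image_processor=image_processor_cfg)

    print(f'{len(dataset)} questions loaded')
    sample = dataset[0]  # dict produced by custom_data_process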