import json
import os
import os.path as osp

from mmengine.dist import master_only
from mmengine.logging import print_log

from xtuner.registry import BUILDER

from .base_eval_dataset import BaseEvalDataset
from .gqa_eval_utils import eval_gqa
from .utils import custom_data_process


class GQADataset(BaseEvalDataset):
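    """GQA evaluation dataset.

    Loads GQA questions from ``data_file`` (one JSON record per line),
    serves processed samples for model inference, and scores the collected
    predictions against ``ann_file`` with the GQA evaluator.
    """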
    METAINFO: dict = dict(name='gqa')

    def __init__(
            self,
            data_file,
            ann_file,
            image_folder,
            image_processor,
            pad_image_to_square=True,
            metainfo=None,
    ):
        super().__init__(metainfo)
        self.data_file = data_file
        self.ann_file = ann_file
        # detailed per-question outputs, kept for manual inspection
        self.answer_file = 'answer_gqa_results.jsonl'
        # normalized predictions consumed by the GQA evaluator
        self.prediction_file = 'pred_gqa_results.jsonl'

        self.image_folder = image_folder

        self.image_processor = BUILDER.build(image_processor)
        self.pad_image_to_square = pad_image_to_square
        self.data = self.load_data_list()

    def load_data_list(self):
        # Each line of the question file is a standalone JSON record.
        with open(os.path.expanduser(self.data_file), 'r') as f:
            question_data = [json.loads(q) for q in f]

        data_list = []
        for idx, sample in enumerate(question_data):
            data = {
                'img_id': idx,
                'index': sample['question_id'],
                'image_path': sample['image'],
                'question': sample['text'],
                'category': sample['category'],
            }
            data_list.append(data)
        return data_list
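
    # For reference, each ``data_file`` line is expected to carry the four
    # keys read above; the concrete values here are illustrative only:
    #
    #   {"question_id": "201307251", "image": "n15740.jpg",
    #    "text": "Is the sky dark?", "category": "default"}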

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        data_dict = custom_data_process(self, data)
        return data_dict

    @master_only
    def evaluate(self, results, work_dir):
        # Dump the detailed per-question predictions for manual inspection.
        answers_file = osp.join(work_dir, self.answer_file)
        with open(answers_file, 'w') as ans_file:
            for pred_dict in results:
                idx = pred_dict['img_id']
                gt_data = self.data[idx]
                ans_file.write(
                    json.dumps(
                        {
                            'question_id': gt_data['index'],
                            'prompt': gt_data['question'],
                            'text': pred_dict['prediction'],
                            'metadata': {},
                        }
                    )
                    + '\n'
                )

        # Normalize the predictions (strip trailing periods, lowercase)
        # into the format expected by the GQA evaluator.
        all_preds = []
        with open(answers_file, 'r') as f:
            for line in f:
                res = json.loads(line)
                all_preds.append({
                    'questionId': res['question_id'],
                    'prediction': res['text'].rstrip('.').lower(),
                })

        prediction_file = osp.join(work_dir, self.prediction_file)
        with open(prediction_file, 'w') as f:
            json.dump(all_preds, f)

        evaluator = eval_gqa(questions=self.ann_file, predictions=prediction_file)
        print_log('============================================', 'current')
        scores = evaluator.forward()
        print_log('============================================', 'current')
        print_log('GQA evaluation finished successfully', 'current')
        return scores
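

# A minimal usage sketch. The paths and the image-processor config below are
# illustrative assumptions (any mmengine-buildable image processor config
# should work); ``results`` must be a list of dicts carrying 'img_id' and
# 'prediction' keys, as consumed by ``evaluate`` above:
#
#   dataset = GQADataset(
#       data_file='data/gqa/llava_gqa_testdev_balanced.jsonl',
#       ann_file='data/gqa/testdev_balanced_questions.json',
#       image_folder='data/gqa/images',
#       image_processor=dict(
#           type='CLIPImageProcessor.from_pretrained',
#           pretrained_model_name_or_path='openai/clip-vit-large-patch14-336'),
#   )
#   sample = dataset[0]  # processed inputs for a single question
#   # ... run the model over all samples, collect ``results`` ...
#   scores = dataset.evaluate(results, work_dir='./work_dirs/gqa')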