import json
import os
import os.path as osp
from typing import Optional

from mmengine.dist import master_only
from mmengine.logging import print_log

from xtuner.registry import BUILDER

from .base_eval_dataset import BaseEvalDataset
from .utils import custom_data_process


def relaxed_correctness(prediction: str,
                        target: str,
                        max_relative_change: float = 0.05) -> bool:
    """Calculates relaxed correctness.

    The correctness tolerates certain error ratio defined by max_relative_change.
    See https://arxiv.org/pdf/2203.10244.pdf, end of section 5.1:
    “Following Methani et al. (2020), we use a relaxed accuracy measure for the
    numeric answers to allow a minor inaccuracy that may result from the automatic
    data extraction process. We consider an answer to be correct if it is within
    5% of the gold answer. For non-numeric answers, we still need an exact match
    to consider an answer to be correct.”

    Args:
      prediction: Predicted string.
      target: Target string.
      max_relative_change: Maximum relative change.

    Returns:
      Whether the prediction was correct given the specified tolerance.
    """

    def _to_float(text: str) -> Optional[float]:
        try:
            if text.endswith('%'):
                # Convert percentages to floats.
                return float(text.rstrip('%')) / 100.0
            else:
                return float(text)
        except ValueError:
            return None

    prediction_float = _to_float(prediction)
    target_float = _to_float(target)
    # Fall back to an exact (case-insensitive) string match when either value
    # is non-numeric or the target is 0; this also avoids division by zero.
    if prediction_float is not None and target_float:
        relative_change = abs(prediction_float -
                              target_float) / abs(target_float)
        return relative_change <= max_relative_change
    else:
        return prediction.lower() == target.lower()
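
# Example behavior (illustrative values, not part of the upstream module):
#   relaxed_correctness('0.52', '0.5')   -> True   (4% relative error <= 5%)
#   relaxed_correctness('0.56', '0.5')   -> False  (12% relative error)
#   relaxed_correctness('52%', '0.5')    -> True   ('52%' is parsed as 0.52)
#   relaxed_correctness('Blue', 'blue')  -> True   (non-numeric: case-insensitive match)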


def evaluate_relaxed_accuracy(entries):
    """Score each entry with relaxed correctness and return the per-entry
    scores together with the mean accuracy."""
    scores = []
    for elem in entries:
        # A single gold label may be stored as a plain string; normalise it
        # to a list so every entry is handled uniformly.
        if isinstance(elem['label'], str):
            elem['label'] = [elem['label']]
        # An entry counts as correct if the prediction matches any gold label.
        score = max(
            relaxed_correctness(elem['prediction'].strip(), ann)
            for ann in elem['label'])
        scores.append(score)
    return scores, sum(scores) / len(scores)
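
# Sketch of the expected `entries` format (hypothetical values):
#   entries = [
#       {'prediction': '0.52', 'label': '0.5'},           # relaxed numeric match
#       {'prediction': 'red', 'label': ['blue', 'red']},  # best score over all labels
#   ]
#   scores, acc = evaluate_relaxed_accuracy(entries)
#   # scores == [True, True], acc == 1.0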


class ChartQADataset(BaseEvalDataset):
    METAINFO: dict = dict(name='chartqa')

    def __init__(
            self,
            data_file,
            image_folder,
            image_processor,
            pad_image_to_square=True,
            metainfo=None,
    ):
        super().__init__(metainfo)

        if isinstance(data_file, str):
            data_file = [data_file]
        self.raw_data = []
        for f in data_file:
            with open(f) as fp:
                self.raw_data.append(json.load(fp))
        # Split names (e.g. test_human, test_augmented), taken from the
        # annotation file names.
        self.name = [
            os.path.splitext(os.path.basename(f))[0] for f in data_file
        ]
        self.name_map = {name: i for i, name in enumerate(self.name)}
        self.revert_name_map = {i: name for i, name in enumerate(self.name)}

        self.image_folder = image_folder
        self.image_processor = BUILDER.build(image_processor)
        self.pad_image_to_square = pad_image_to_square
        self.data = self.load_data_list()

    def load_data_list(self):
        data_list = []
        idx = 0
        for data_idx, split in enumerate(self.raw_data):
            for sample in split:
                data_list.append({
                    'img_id': idx,
                    'image_path': sample['imgname'],
                    'question': sample['query'],
                    'answer': sample['label'],
                    'category': self.name[data_idx],
                })
                idx += 1
        return data_list
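
    # Each raw ChartQA record is expected to provide at least the keys read
    # above; a hypothetical sample:
    #   {'imgname': 'two_col_1.png', 'query': 'What is ...?', 'label': '42'}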

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        data_dict = custom_data_process(self, data)
        return data_dict

    @master_only
    def evaluate(self, result, work_dir):
        # Map img_id -> position in self.data for O(1) lookup instead of a
        # linear list.index() scan per prediction.
        orig_index = {x['img_id']: i for i, x in enumerate(self.data)}
        results = [[] for _ in range(len(self.name))]
        for pred_dict in result:
            filtered_rows = self.data[orig_index[pred_dict['img_id']]]
            cur_result = {
                'query': filtered_rows.get('question'),
                'prediction': pred_dict['prediction'],
                'label': filtered_rows.get('answer'),
            }
            # Group predictions by split (e.g. test_human / test_augmented).
            split_idx = self.name_map[filtered_rows['category']]
            results[split_idx].append(cur_result)

        print_log('============================================', 'current')
        acc_list = []
        for i, split_result in enumerate(results):
            scores, _accuracy = evaluate_relaxed_accuracy(split_result)

            for res, score in zip(split_result, scores):
                res['score'] = score
            prediction_file = osp.join(work_dir,
                                       self.revert_name_map[i] + '.json')
            with open(prediction_file, 'w') as f:
                json.dump(split_result, f)

            print_log(
                'Acc: {}, Category: {}, # samples: {}'.format(
                    _accuracy, self.revert_name_map[i], len(split_result)),
                'current')
            acc_list.append(_accuracy)

        print_log('============================================', 'current')
        # The overall score is the mean of the per-split accuracies.
        acc = sum(acc_list) / len(acc_list)
        print_log('Overall Acc: {}'.format(acc), 'current')
        print_log('============================================', 'current')
        print_log('ChartQA evaluation finished successfully', 'current')

        return {'Acc': acc}
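
# Usage sketch (hypothetical paths; the image_processor value is an assumption
# and can be any config that BUILDER.build resolves, following the mmengine
# convention used across xtuner):
#
#   dataset = ChartQADataset(
#       data_file=['data/chartqa/test_human.json',
#                  'data/chartqa/test_augmented.json'],
#       image_folder='data/chartqa/test/png',
#       image_processor=dict(...),  # fill in a real processor config
#   )
#   # `result` is a list of dicts carrying 'img_id' and 'prediction', as
#   # produced by the evaluation loop; metrics == {'Acc': ...}.
#   metrics = dataset.evaluate(result, work_dir='work_dirs/chartqa')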