import json
import os
import os.path as osp
from typing import Optional

from mmengine.dist import master_only
from mmengine.logging import print_log

from xtuner.registry import BUILDER

from .base_eval_dataset import BaseEvalDataset
from .utils import custom_data_process


def relaxed_correctness(prediction: str,
                        target: str,
                        max_relative_change: float = 0.05) -> bool:
    """Calculates relaxed correctness.

    The correctness tolerates a certain error ratio defined by
    ``max_relative_change``. See https://arxiv.org/pdf/2203.10244.pdf, end of
    section 5.1:

    “Following Methani et al. (2020), we use a relaxed accuracy measure for
    the numeric answers to allow a minor inaccuracy that may result from the
    automatic data extraction process. We consider an answer to be correct if
    it is within 5% of the gold answer. For non-numeric answers, we still need
    an exact match to consider an answer to be correct.”

    Args:
        prediction: Predicted string.
        target: Target string.
        max_relative_change: Maximum relative change allowed for a numeric
            answer to still count as correct.

    Returns:
        Whether the prediction was correct given the specified tolerance.
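
    Example (illustrative values, chosen to show the 5% tolerance):
        >>> relaxed_correctness('0.52', '0.5')  # 4% relative change
        True
        >>> relaxed_correctness('52%', '0.5')   # '%' strings are normalized
        True
        >>> relaxed_correctness('12', '15')     # 20% relative change
        False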
    """

    def _to_float(text: str) -> Optional[float]:
        try:
            if text.endswith('%'):
                # Normalize percentage strings, e.g. '42%' -> 0.42.
                return float(text.rstrip('%')) / 100.0
            else:
                return float(text)
        except ValueError:
            return None

    prediction_float = _to_float(prediction)
    target_float = _to_float(target)
    # The truthiness check on `target_float` also guards against division by
    # zero: a gold answer of 0 falls back to the exact string comparison.
    if prediction_float is not None and target_float:
        relative_change = abs(prediction_float -
                              target_float) / abs(target_float)
        return relative_change <= max_relative_change
    else:
        return prediction.lower() == target.lower()


def evaluate_relaxed_accuracy(entries):
    scores = []
    for elem in entries:
        # Some samples carry a single gold answer; normalize to a list.
        if isinstance(elem['label'], str):
            elem['label'] = [elem['label']]
        # A prediction counts as correct if it matches any gold answer.
        score = max(
            relaxed_correctness(elem['prediction'].strip(), ann)
            for ann in elem['label'])
        scores.append(score)
    return scores, sum(scores) / len(scores)
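

# Usage sketch (hypothetical entries, for illustration only):
#   scores, acc = evaluate_relaxed_accuracy([
#       {'prediction': '0.52', 'label': '0.5'},   # within 5% -> True
#       {'prediction': 'cat', 'label': ['dog']},  # exact match fails -> False
#   ])
#   # scores == [True, False], acc == 0.5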


class ChartQADataset(BaseEvalDataset):
    METAINFO: dict = dict(name='chartqa')

    def __init__(
        self,
        data_file,
        image_folder,
        image_processor,
        pad_image_to_square=True,
        metainfo=None,
    ):
        super().__init__(metainfo)

        if isinstance(data_file, str):
            data_file = [data_file]
        self.raw_data = []
        for f in data_file:
            # Use a context manager so file handles are closed promptly.
            with open(f) as fp:
                self.raw_data.append(json.load(fp))

        # Derive a split name from each annotation file name,
        # e.g. 'test_human.json' -> 'test_human'.
        self.name = [osp.splitext(osp.basename(f))[0] for f in data_file]
        self.name_map = {name: i for i, name in enumerate(self.name)}
        self.revert_name_map = {i: name for i, name in enumerate(self.name)}

        self.image_folder = image_folder
        self.image_processor = BUILDER.build(image_processor)
        self.pad_image_to_square = pad_image_to_square
        self.data = self.load_data_list()

    def load_data_list(self):
        data_list = []
        idx = 0
        for data_idx, split_data in enumerate(self.raw_data):
            for sample in split_data:
                data = {
                    'img_id': idx,
                    'image_path': sample['imgname'],
                    'question': sample['query'],
                    'answer': sample['label'],
                    'category': self.name[data_idx],
                }
                data_list.append(data)
                idx += 1
        return data_list
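
    # Note: each annotation file is expected to be a JSON list of samples
    # providing the fields read above, e.g. (hypothetical values):
    #   {"imgname": "chart_0.png", "query": "What is the highest value?",
    #    "label": "42"}
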
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        data_dict = custom_data_process(self, data)
        return data_dict

    @master_only
    def evaluate(self, result, work_dir):
        orig_index = [x['img_id'] for x in self.data]
        # Group predictions by annotation-file category.
        results = [[] for _ in range(len(self.name))]
        for pred_dict in result:
            index = pred_dict['img_id']
            new_index = orig_index.index(index)
            filtered_rows = self.data[new_index]
            cur_result = {}
            cur_result['query'] = filtered_rows.get('question')
            cur_result['prediction'] = pred_dict['prediction']
            cur_result['label'] = filtered_rows.get('answer')

            index = self.name_map[filtered_rows['category']]
            results[index].append(cur_result)

        print_log('============================================', 'current')
        acc_list = []
        for i, cat_results in enumerate(results):
            scores, _accuracy = evaluate_relaxed_accuracy(cat_results)

            # Attach the per-sample score before dumping predictions.
            for res, score in zip(cat_results, scores):
                res['score'] = score
            prediction_file = osp.join(work_dir,
                                       self.revert_name_map[i] + '.json')
            with open(prediction_file, 'w') as f:
                json.dump(cat_results, f)

            print_log(
                'Acc: {}, Category: {}, # samples: {}'.format(
                    _accuracy, self.revert_name_map[i], len(cat_results)),
                'current')
            acc_list.append(_accuracy)

        print_log('============================================', 'current')
        acc = sum(acc_list) / len(acc_list)
        print_log('Overall Acc: {}'.format(acc), 'current')
        print_log('============================================', 'current')
        print_log('ChartQA successfully finished evaluating', 'current')

        return {'Acc': acc}
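

# Construction sketch (hypothetical paths; `image_processor` is whatever
# processor config your BUILDER registry accepts):
#   dataset = ChartQADataset(
#       data_file=['annotations/test_human.json',
#                  'annotations/test_augmented.json'],
#       image_folder='data/chartqa/images',
#       image_processor=image_processor_cfg,
#   )
#   print(len(dataset), dataset.name)  # sample count and split names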