import json
import os
import os.path as osp
from typing import Optional

from mmengine.dist import master_only
from mmengine.logging import print_log

from xtuner.registry import BUILDER

from .base_eval_dataset import BaseEvalDataset
from .utils import custom_data_process


def relaxed_correctness(prediction: str,
                        target: str,
                        max_relative_change: float = 0.05) -> bool:
    """Calculates relaxed correctness.

    The correctness tolerates a certain error ratio defined by
    max_relative_change. See https://arxiv.org/pdf/2203.10244.pdf, end of
    section 5.1:

    “Following Methani et al. (2020), we use a relaxed accuracy measure for
    the numeric answers to allow a minor inaccuracy that may result from the
    automatic data extraction process. We consider an answer to be correct if
    it is within 5% of the gold answer. For non-numeric answers, we still need
    an exact match to consider an answer to be correct.”

    Args:
        prediction: Predicted string.
        target: Target string.
        max_relative_change: Maximum relative change.

    Returns:
        Whether the prediction was correct given the specified tolerance.
    """

    def _to_float(text: str) -> Optional[float]:
        try:
            if text.endswith('%'):
                # Convert percentages to floats.
                return float(text.rstrip('%')) / 100.0
            else:
                return float(text)
        except ValueError:
            return None

    prediction_float = _to_float(prediction)
    target_float = _to_float(target)
    # The truthiness check on target_float deliberately excludes a target of
    # 0.0 (which would make the relative change undefined); such cases fall
    # back to the exact string match below.
    if prediction_float is not None and target_float:
        relative_change = abs(prediction_float -
                              target_float) / abs(target_float)
        return relative_change <= max_relative_change
    else:
        return prediction.lower() == target.lower()
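# Illustrative behaviour of relaxed_correctness (examples only, not exhaustive):
#   relaxed_correctness('0.52', '0.5') -> True   (|0.52 - 0.5| / 0.5 = 0.04 <= 0.05)
#   relaxed_correctness('0.6', '0.5')  -> False  (relative change 0.2 > 0.05)
#   relaxed_correctness('12%', '0.12') -> True   (percentages are converted to floats)
#   relaxed_correctness('Red', 'red')  -> True   (non-numeric: case-insensitive match)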


def evaluate_relaxed_accuracy(entries):
    scores = []
    for elem in entries:
        # A label may be a single string or a list of acceptable answers;
        # normalize to a list and count the prediction as correct if it
        # matches any of them.
        if isinstance(elem['label'], str):
            elem['label'] = [elem['label']]
        score = max([
            relaxed_correctness(elem['prediction'].strip(), ann)
            for ann in elem['label']
        ])
        scores.append(score)
    return scores, sum(scores) / len(scores)
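# Illustrative input/output for evaluate_relaxed_accuracy:
#   entries = [{'prediction': '0.52', 'label': '0.5'},
#              {'prediction': 'blue', 'label': ['red', 'green']}]
#   evaluate_relaxed_accuracy(entries) -> ([True, False], 0.5)
# i.e. the per-sample relaxed-correctness scores and their mean.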


class ChartQADataset(BaseEvalDataset):

    METAINFO: dict = dict(name='chartqa')

    def __init__(
        self,
        data_file,
        image_folder,
        image_processor,
        pad_image_to_square=True,
        metainfo=None,
    ):
        super().__init__(metainfo)
        if isinstance(data_file, str):
            data_file = [data_file]
        self.raw_data = [json.load(open(f)) for f in data_file]
        # Split names, e.g. test_human, test_augmented.
        self.name = [
            os.path.splitext(os.path.basename(f))[0] for f in data_file
        ]
        self.name_map = {name: i for i, name in enumerate(self.name)}
        self.revert_name_map = {i: name for i, name in enumerate(self.name)}
        self.image_folder = image_folder
        self.image_processor = BUILDER.build(image_processor)
        self.pad_image_to_square = pad_image_to_square
        self.data = self.load_data_list()

    def load_data_list(self):
        data_list = []
        idx = 0
        for split_name, split_samples in zip(self.name, self.raw_data):
            for sample in split_samples:
                data_list.append({
                    'img_id': idx,
                    'image_path': sample['imgname'],
                    'question': sample['query'],
                    'answer': sample['label'],
                    'category': split_name
                })
                idx += 1
        return data_list
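    # Note on the expected annotation format: each data_file should hold a JSON
    # list of ChartQA records exposing at least 'imgname', 'query' and 'label',
    # which load_data_list maps to 'image_path', 'question' and 'answer'.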

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        data_dict = custom_data_process(self, data)
        return data_dict

    @master_only
    def evaluate(self, result, work_dir):
        orig_index = [x['img_id'] for x in self.data]
        # Group predictions by split (e.g. test_human / test_augmented).
        results = [[] for _ in range(len(self.name))]
        for pred_dict in result:
            index = pred_dict['img_id']
            new_index = orig_index.index(index)
            filtered_rows = self.data[new_index]
            cur_result = {}
            cur_result['query'] = filtered_rows.get('question')
            cur_result['prediction'] = pred_dict['prediction']
            cur_result['label'] = filtered_rows.get('answer')
            split_index = self.name_map[filtered_rows['category']]
            results[split_index].append(cur_result)

        print_log('============================================', 'current')
        acc_list = []
        for i, split_result in enumerate(results):
            scores, _accuracy = evaluate_relaxed_accuracy(split_result)
            for res, score in zip(split_result, scores):
                res['score'] = score
            prediction_file = osp.join(work_dir,
                                       self.revert_name_map[i] + '.json')
            with open(prediction_file, 'w') as f:
                json.dump(split_result, f)
            print_log(
                'Acc: {}, Category: {}, # samples: {}'.format(
                    _accuracy, self.revert_name_map[i], len(split_result)),
                'current')
            acc_list.append(_accuracy)

        print_log('============================================', 'current')
        acc = sum(acc_list) / len(acc_list)
        print_log('Overall Acc: {}'.format(acc), 'current')
        print_log('============================================', 'current')
        print_log('ChartQA successfully finished evaluating', 'current')
        return {'Acc': acc}
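

# Usage sketch (illustrative; the processor config, paths and variable names
# below are assumptions, not part of this module). image_processor can be any
# config buildable by xtuner's BUILDER, e.g. a HuggingFace image processor:
#
#   from transformers import CLIPImageProcessor
#
#   dataset = ChartQADataset(
#       data_file=['annotations/test_human.json',
#                  'annotations/test_augmented.json'],
#       image_folder='data/chartqa/test/png',
#       image_processor=dict(
#           type=CLIPImageProcessor.from_pretrained,
#           pretrained_model_name_or_path='openai/clip-vit-large-patch14-336'),
#   )
#   sample = dataset[0]  # preprocessed by custom_data_process
#
#   # Predictions pair each sample's img_id with the model's answer string:
#   predictions = [{'img_id': d['img_id'], 'prediction': '42'} for d in dataset.data]
#   metrics = dataset.evaluate(predictions, work_dir='work_dirs/chartqa')
#   # -> {'Acc': ...}; per-split prediction JSONs are written under work_dir.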