# vlm/datasets/evaluation/general_vqa_dataset.py
import os
import os.path as osp

import numpy as np
import pandas as pd
from mmengine.dist import master_only
from mmengine.logging import print_log
from xtuner.dataset.utils import decode_base64_to_image
from xtuner.registry import BUILDER

from .base_eval_dataset import BaseEvalDataset
from .utils import custom_data_process


def levenshtein_distance(s1, s2):
    """Edit distance between ``s1`` and ``s2`` via a single-row DP table."""
    if len(s1) > len(s2):
        s1, s2 = s2, s1  # keep the shorter string as the DP row

    distances = range(len(s1) + 1)
    for i2, c2 in enumerate(s2):
        distances_ = [i2 + 1]
        for i1, c1 in enumerate(s1):
            if c1 == c2:
                distances_.append(distances[i1])
            else:
                # 1 + min(substitution, deletion, insertion)
                distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
        distances = distances_
    return distances[-1]
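

# Illustrative check (not part of the original file): the classic pair below
# needs three edits (k -> s, e -> i, append g), so the distance is 3.
def _demo_levenshtein():
    assert levenshtein_distance('kitten', 'sitting') == 3
    assert levenshtein_distance('', 'abc') == 3  # pure insertions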


def anls_compute(groundtruth, prediction):
    """Normalized Levenshtein distance used by ANLS (0 is a perfect match)."""
    gt_answer = ' '.join(groundtruth.strip().lower().split())
    det_answer = ' '.join(prediction.strip().lower().split())
    dist = levenshtein_distance(gt_answer, det_answer)
    # Note: the length is taken over the raw strings (upper() does not change
    # string length), mirroring the reference ANLS implementation.
    length = max(len(groundtruth.upper()), len(prediction.upper()))
    values = 0.0 if length == 0 else float(dist) / float(length)
    return values
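

# Illustrative check (not part of the original file): dropping the space from
# 'forty two' is one edit over max length 9, so the distance is 1/9 and the
# thresholded ANLS credit would be 1 - 1/9 ≈ 0.889.
def _demo_anls():
    assert abs(anls_compute('forty two', 'fortytwo') - 1 / 9) < 1e-9
    assert anls_compute('', '') == 0.0  # empty strings are defined as distance 0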


def hit_calculate(result, dataset_name, anls_threshold=0.5):
    """Convert per-sample 'match' lists into per-sample scores."""
    if 'DocVQA' in dataset_name or 'InfoVQA' in dataset_name:
        # ANLS: take the best (smallest) distance over candidate answers and
        # award 1 - distance as credit, but only when it clears the threshold.
        return [0.0 if 1 - np.min(x['match']) < anls_threshold else 1 - np.min(x['match'])
                for x in result]
    elif 'OCRVQA' in dataset_name:
        # Exact match: a sample scores 1.0 if any candidate answer matched.
        return [np.max(x['match']) for x in result]
    else:
        raise NotImplementedError(f'Dataset {dataset_name} not supported for hit calculation')
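

# Illustrative check (not part of the original file): each entry carries the
# 'match' list that evaluate() builds. With the default 0.5 threshold, an ANLS
# distance of 0.25 earns 0.75 credit, while 0.75 falls below and scores 0.
def _demo_hit_calculate():
    docvqa = [{'match': [0.25]}, {'match': [0.75]}]
    assert hit_calculate(docvqa, 'DocVQA_VAL') == [0.75, 0.0]
    ocrvqa = [{'match': [0.0, 1.0]}]  # one of two candidate answers matched
    assert hit_calculate(ocrvqa, 'OCRVQA_TEST') == [1.0]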


def istype(s, type):
    """Return True if ``s`` is an instance of ``type``, or evals to one."""
    if isinstance(s, type):
        return True
    try:
        return isinstance(eval(s), type)
    except Exception:
        return False
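

# Illustrative check (not part of the original file): multi-answer ground
# truths arrive from the TSV as stringified lists, which istype() recognises.
def _demo_istype():
    assert istype("['a', 'b']", list)
    assert istype([1, 2], list)
    assert not istype('plain answer', list)  # eval() raises, so False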


class GeneralVQADataset(BaseEvalDataset):

    METAINFO: dict = dict(name='gvqa')

    def __init__(self, data_file, image_processor, pad_image_to_square=True,
                 anls_threshold=0.5, metainfo=None):
        super().__init__(metainfo)
        self.anls_threshold = anls_threshold
        self.data_file = data_file
        self.df = pd.read_csv(data_file, sep='\t')
        # OCR-style benchmarks keep their string indices verbatim in get_image().
        self.ocr = 'OCR' in data_file
        # Skip rows that carry no image payload.
        self.df = self.df[~pd.isna(self.df['image'])]
        self.image_processor = BUILDER.build(image_processor)
        self.pad_image_to_square = pad_image_to_square
        self.name = os.path.splitext(os.path.basename(data_file))[0]
        self.results_xlsx_path = self.name + '-results.xlsx'
        self.data = self.load_data_list()

    def get_image(self, image):
        # Cells shorter than 16 characters are index references to the row
        # holding the actual base64 payload; follow them until resolved.
        while len(image) < 16:
            if self.ocr:
                image = self.df[self.df['index'] == image]['image'].values
            else:
                image = self.df[self.df['index'] == int(image)]['image'].values
            assert len(image) == 1
            image = image[0]
        image = decode_base64_to_image(image)
        return image
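
    # Illustrative, self-contained sketch (not part of the original file) of
    # the reference chain that get_image() resolves: a short 'image' cell holds
    # the index of the row carrying the real base64 payload.
    @staticmethod
    def _demo_image_reference():
        df = pd.DataFrame({'index': [0, 1],
                           'image': ['A' * 20, '0']})  # row 1 points at row 0
        image = '0'  # the cell value taken from row 1
        while len(image) < 16:
            values = df[df['index'] == int(image)]['image'].values
            assert len(values) == 1
            image = values[0]
        assert image == 'A' * 20  # resolved to row 0's payload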

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        data = self.data[idx]
        data_dict = custom_data_process(self, data)
        return data_dict

    def load_data_list(self):
        data_list = []
        for idx in range(len(self.df)):
            index = self.df.iloc[idx]['index']
            image = self.df.iloc[idx]['image']
            question = self.df.iloc[idx]['question']
            split = self.df.iloc[idx]['split'] if 'split' in self.df.columns else None
            answer = self.df.iloc[idx]['answer'] if 'answer' in self.df.columns else None

            data = {
                'img': image,
                'question': question,
                'answer': answer,
                'index': index,
                'img_id': idx
            }
            if split is not None:
                data['split'] = split
            data_list.append(data)
        return data_list
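
    # Each record produced above looks like (illustrative):
    #   {'img': <base64 string or short index reference>,
    #    'question': str,
    #    'answer': str or stringified list of answers,
    #    'index': original TSV index,
    #    'img_id': positional row id}
    # evaluate() aligns predictions with records via 'img_id'.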

    @master_only
    def evaluate(self, results, work_dir):
        orig_index = [x['img_id'] for x in self.data]
        new_results = []
        for pred_dict in results:
            index = pred_dict['img_id']
            new_index = orig_index.index(index)
            filtered_rows = self.data[new_index]

            cur_result = {}
            cur_result['question'] = filtered_rows.get('question')
            cur_result['split'] = filtered_rows.get('split')
            cur_result['prediction'] = pred_dict['prediction']
            cur_result['index'] = filtered_rows.get('index')
            cur_result['answer'] = filtered_rows.get('answer')

            answers = filtered_rows.get('answer')
            # Multi-answer ground truths are stored as stringified lists.
            if istype(answers, list):
                answers = eval(answers)
            else:
                answers = [answers]

            if 'OCRVQA' in self.name:
                # Exact (case-insensitive) match against each candidate answer.
                match = [(1.0 if x.strip().lower() == cur_result['prediction'].strip().lower() else 0.0)
                         for x in answers]
            else:
                # ANLS distance against each candidate answer.
                match = [anls_compute(x, cur_result['prediction']) for x in answers]
            cur_result['match'] = match
            new_results.append(cur_result)

        results_df = pd.DataFrame(new_results)
        with pd.ExcelWriter(osp.join(work_dir, self.results_xlsx_path), engine='openpyxl') as writer:
            results_df.to_excel(writer, index=False)

        ret = dict()
        if 'split' in results_df:
            # Report a score per split when the benchmark defines splits.
            splits = list(set(results_df['split']))
            for sp in splits:
                sub = [new_results[i] for i, x in enumerate(new_results) if x['split'] == sp]
                hit = hit_calculate(sub, self.name)
                ret[sp] = np.mean(hit) * 100
        else:
            hit = hit_calculate(new_results, self.name)
            ret['overall'] = np.mean(hit) * 100

        print_log('============================================', 'current')
        print_log(ret, 'current')
        print_log('============================================', 'current')
        print_log(f'{self.name} successfully finished evaluating', 'current')
        return ret
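

# Minimal usage sketch (not part of the original file). The TSV path and the
# image-processor config below are hypothetical placeholders; any processor
# config that BUILDER.build() accepts would do.
#
#   dataset = GeneralVQADataset(
#       data_file='data/DocVQA_VAL.tsv',  # hypothetical path
#       image_processor=dict(...),        # e.g. a CLIP image-processor config
#   )
#   sample = dataset[0]
#   # After inference, pass one dict per sample with 'img_id' and 'prediction':
#   #   metrics = dataset.evaluate(predictions, work_dir='./work_dirs')
#   # -> per-split percentages, or {'overall': ...} when the TSV has no splits.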