import os
import os.path as osp

import numpy as np
import pandas as pd
from mmengine.dist import master_only
from mmengine.logging import print_log

from xtuner.dataset.utils import decode_base64_to_image
from xtuner.registry import BUILDER

from .base_eval_dataset import BaseEvalDataset
from .utils import custom_data_process


def levenshtein_distance(s1, s2):
    # Classic dynamic-programming edit distance, keeping only one rolling row.
    # Iterate over the longer string so the row stays as short as possible.
    if len(s1) > len(s2):
        s1, s2 = s2, s1

    distances = range(len(s1) + 1)
    for i2, c2 in enumerate(s2):
        distances_ = [i2 + 1]
        for i1, c1 in enumerate(s1):
            if c1 == c2:
                distances_.append(distances[i1])
            else:
                distances_.append(1 + min(distances[i1], distances[i1 + 1],
                                          distances_[-1]))
        distances = distances_
    return distances[-1]
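
# Quick sanity examples (illustrative, not part of the original file):
#   levenshtein_distance('kitten', 'sitting') -> 3
#   levenshtein_distance('', 'abc') -> 3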


def anls_compute(groundtruth, prediction):
    # Normalize case and internal whitespace before measuring edit distance.
    gt_answer = ' '.join(groundtruth.strip().lower().split())
    det_answer = ' '.join(prediction.strip().lower().split())
    dist = levenshtein_distance(gt_answer, det_answer)
    # Normalize by the longer raw string; 0.0 means an exact match.
    length = max(len(groundtruth), len(prediction))
    return 0.0 if length == 0 else float(dist) / float(length)
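
# Illustrative examples: anls_compute returns a *normalized distance*, so 0.0
# is a perfect match and larger values mean more divergence:
#   anls_compute('Hello  World', 'hello world') -> 0.0  (case/whitespace folded)
#   anls_compute('abc', 'abd') -> 1/3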


def hit_calculate(result, dataset_name, anls_threshold=0.5):
    if 'DocVQA' in dataset_name or 'InfoVQA' in dataset_name:
        # ANLS: turn the best (smallest) normalized distance into a similarity
        # and zero out predictions that fall below the threshold.
        return [0.0 if 1 - np.min(x['match']) < anls_threshold
                else 1 - np.min(x['match']) for x in result]
    elif 'OCRVQA' in dataset_name:
        # Exact-match datasets: best 0/1 match over all reference answers.
        return [np.max(x['match']) for x in result]
    else:
        raise NotImplementedError(
            f'Dataset {dataset_name} not supported for hit calculation')
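
# Illustrative example: each result entry carries per-answer 'match' scores.
# For an ANLS dataset (any name containing 'DocVQA' or 'InfoVQA'), a
# prediction within the threshold keeps its similarity:
#   hit_calculate([{'match': [0.2]}], 'DocVQA') -> [0.8]
#   hit_calculate([{'match': [0.7]}], 'DocVQA') -> [0.0]   (below threshold)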


def istype(s, tp):
    # True if `s` is already an instance of `tp`, or a string that evaluates
    # to one. The parameter is named `tp` to avoid shadowing the builtin.
    if isinstance(s, tp):
        return True
    try:
        return isinstance(eval(s), tp)
    except Exception:
        return False
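
# Illustrative examples, used below to detect answers stored as stringified
# lists in the TSV:
#   istype("['yes', 'no']", list) -> True
#   istype('a plain answer', list) -> False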


class GeneralVQADataset(BaseEvalDataset):
    METAINFO: dict = dict(name='gvqa')

    def __init__(self, data_file, image_processor,
                 pad_image_to_square=True,
                 anls_threshold=0.5, metainfo=None):
        super().__init__(metainfo)
        self.anls_threshold = anls_threshold
        self.data_file = data_file
        self.df = pd.read_csv(data_file, sep='\t')
        # OCR-style files keep their 'index' column as strings.
        self.ocr = 'OCR' in data_file
        # Drop rows without image data.
        self.df = self.df[~pd.isna(self.df['image'])]
        self.image_processor = BUILDER.build(image_processor)
        self.pad_image_to_square = pad_image_to_square
        self.name = os.path.splitext(os.path.basename(data_file))[0]
        self.results_xlsx_path = self.name + '-results.xlsx'
        self.data = self.load_data_list()
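
    # A hypothetical build config for this dataset (the processor type and
    # file path below are assumptions for illustration, not from this repo):
    #
    #   from transformers import CLIPImageProcessor
    #   gvqa_dataset = dict(
    #       type=GeneralVQADataset,
    #       data_file='data/DocVQA_VAL.tsv',
    #       image_processor=dict(
    #           type=CLIPImageProcessor.from_pretrained,
    #           pretrained_model_name_or_path='openai/clip-vit-large-patch14-336'))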

    def get_image(self, image):
        # Cells shorter than 16 characters are not base64 payloads but
        # pointers to another row's 'index'; follow them until real data.
        while len(image) < 16:
            if self.ocr:
                image = self.df[self.df['index'] == image]['image'].values
            else:
                image = self.df[self.df['index'] == int(image)]['image'].values
            assert len(image) == 1
            image = image[0]
        return decode_base64_to_image(image)
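
    # Illustrative example (assumed data): if get_image('12') finds that row
    # 12's 'image' cell is itself the short string '7', the loop dereferences
    # row 7 next and finally decodes that row's full base64 payload.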

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        data = self.data[idx]
        return custom_data_process(self, data)

    def load_data_list(self):
        data_list = []
        has_split = 'split' in self.df.columns
        has_answer = 'answer' in self.df.columns
        for idx in range(len(self.df)):
            row = self.df.iloc[idx]
            data = {
                'img': row['image'],
                'question': row['question'],
                'answer': row['answer'] if has_answer else None,
                'index': row['index'],
                'img_id': idx
            }
            if has_split:
                data['split'] = row['split']
            data_list.append(data)
        return data_list
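
    # Expected TSV columns, inferred from the loader above ('answer' and
    # 'split' are optional): index, image (base64 or a short pointer index),
    # question, answer, split.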

    @master_only
    def evaluate(self, results, work_dir):
        orig_index = [x['img_id'] for x in self.data]
        new_results = []
        for pred_dict in results:
            # Look up the original sample for this prediction.
            index = pred_dict['img_id']
            new_index = orig_index.index(index)
            filtered_rows = self.data[new_index]

            cur_result = {
                'question': filtered_rows.get('question'),
                'prediction': pred_dict['prediction'],
                'index': filtered_rows.get('index'),
                'answer': filtered_rows.get('answer'),
            }
            if 'split' in filtered_rows:
                cur_result['split'] = filtered_rows['split']

            # Answers may be stored as a stringified list; normalize to a list.
            answers = filtered_rows.get('answer')
            if istype(answers, list):
                answers = eval(answers)
            else:
                answers = [answers]

            if 'OCRVQA' in self.name:
                # Exact (case-insensitive) match against each reference answer.
                match = [1.0 if x.strip().lower() ==
                         cur_result['prediction'].strip().lower() else 0.0
                         for x in answers]
            else:
                # ANLS-style normalized edit distance per reference answer.
                match = [anls_compute(x, cur_result['prediction'])
                         for x in answers]
            cur_result['match'] = match
            new_results.append(cur_result)

        results_df = pd.DataFrame(new_results)
        with pd.ExcelWriter(osp.join(work_dir, self.results_xlsx_path),
                            engine='openpyxl') as writer:
            results_df.to_excel(writer, index=False)

        ret = dict()
        if 'split' in results_df:
            # Report a score per annotation split when one is provided.
            for sp in set(results_df['split']):
                sub = [r for r in new_results if r['split'] == sp]
                ret[sp] = np.mean(hit_calculate(sub, self.name)) * 100
        else:
            ret['overall'] = np.mean(hit_calculate(new_results, self.name)) * 100

        print_log('============================================', 'current')
        print_log(ret, 'current')
        print_log('============================================', 'current')
        print_log(f'{self.name} successfully finished evaluating', 'current')
        return ret
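

if __name__ == '__main__':
    # Minimal smoke test for the metric helpers only (illustrative; running
    # the dataset class needs a real TSV file and an image_processor config).
    assert levenshtein_distance('kitten', 'sitting') == 3
    assert anls_compute('hello', 'hello') == 0.0
    assert abs(hit_calculate([{'match': [0.1]}], 'DocVQA')[0] - 0.9) < 1e-9
    assert hit_calculate([{'match': [0.0, 1.0]}], 'OCRVQA') == [1.0]
    print('metric helpers OK')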