|
import json
import os
import os.path as osp
import re

from mmengine.dist import master_only
from mmengine.logging import print_log

from xtuner.registry import BUILDER

from .base_eval_dataset import BaseEvalDataset
from .textvqa_utils import TextVQAAccuracyEvaluator
from .utils import custom_data_process
|
|
def prompt_processor(prompt):
    if prompt.startswith('OCR tokens: '):
        pattern = r'Question: (.*?) Short answer:'
        match = re.search(pattern, prompt, re.DOTALL)
        question = match.group(1)
    elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3:
        if prompt.startswith('Reference OCR token:'):
            question = prompt.split('\n')[1]
        else:
            question = prompt.split('\n')[0]
    elif len(prompt.split('\n')) == 2:
        question = prompt.split('\n')[0]
    else:
        assert False, f'Unexpected prompt format: {prompt}'

    return question.lower()
|
|
class TextVQADataset(BaseEvalDataset):

    METAINFO: dict = dict(name='textvqa')

    def __init__(self,
                 data_file,
                 ann_file,
                 image_folder,
                 image_processor,
                 pad_image_to_square=True,
                 metainfo=None):
        super().__init__(metainfo)
        self.data_file = data_file
        self.ann_file = ann_file
        self.image_folder = image_folder

        self.image_processor = BUILDER.build(image_processor)
        self.pad_image_to_square = pad_image_to_square
        self.name = os.path.splitext(os.path.basename(data_file))[0]
        self.results_path = self.name + '-results.jsonl'
        self.data = self.load_data_list()
|
    def load_data_list(self):
        data = [
            json.loads(line)
            for line in open(os.path.expanduser(self.data_file), 'r')
        ]
        for i, d in enumerate(data):
            d['img_id'] = i
            d['image_path'] = d['image']
            d['question'] = d['text']
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        data_dict = custom_data_process(self, data)
        return data_dict
|
    @master_only
    def evaluate(self, result, work_dir, show=True):
        answers_file = osp.join(work_dir, self.results_path)
        ans_file = open(answers_file, 'w')

        # Dump each prediction alongside its ground-truth prompt.
        for pred_dict in result:
            idx = pred_dict['img_id']
            gt_data = self.data[idx]

            ans_file.write(
                json.dumps({
                    'question_id': gt_data['question_id'],
                    'prompt': gt_data['text'],
                    'text': pred_dict['prediction'],
                    'metadata': {}
                }) + '\n')
        ans_file.close()

        # Key annotations by (image_id, lower-cased question) so each dumped
        # prediction can be matched back to its ground-truth answers.
        annotations = json.load(open(self.ann_file))['data']
        annotations = {(annotation['image_id'], annotation['question'].lower()):
                       annotation
                       for annotation in annotations}
        results = [json.loads(line) for line in open(answers_file)]

        pred_list = []
        for res in results:
            annotation = annotations[(res['question_id'],
                                      prompt_processor(res['prompt']))]
            pred_list.append({
                'pred_answer': res['text'],
                'gt_answers': annotation['answers'],
            })

        evaluator = TextVQAAccuracyEvaluator()
        acc = 100. * evaluator.eval_pred_list(pred_list)
        print_log('============================================', 'current')
        print_log(
            'Samples: {}, Accuracy: {:.2f}%'.format(len(pred_list), acc),
            'current')
        print_log('============================================', 'current')
        print_log('TextVQA successfully finished evaluating', 'current')
        return {'acc': acc}
|