import json
import os
import os.path as osp
import re

from mmengine.dist import master_only
from mmengine.logging import print_log
from xtuner.registry import BUILDER

from .base_eval_dataset import BaseEvalDataset
from .textvqa_utils import TextVQAAccuracyEvaluator
from .utils import custom_data_process

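# Extract the bare question from a TextVQA prompt. Three layouts are
# handled: (1) a single-line 'OCR tokens: ... Question: <q> Short answer:'
# prompt, (2) a three-line prompt containing 'Reference OCR token: ', and
# (3) a plain two-line prompt whose first line is the question.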
def prompt_processor(prompt):
    if prompt.startswith('OCR tokens: '):
        pattern = r'Question: (.*?) Short answer:'
        match = re.search(pattern, prompt, re.DOTALL)
        if match is None:
            raise ValueError(f'Unrecognized OCR-token prompt: {prompt!r}')
        question = match.group(1)
    elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3:
        if prompt.startswith('Reference OCR token:'):
            question = prompt.split('\n')[1]
        else:
            question = prompt.split('\n')[0]
    elif len(prompt.split('\n')) == 2:
        question = prompt.split('\n')[0]
    else:
        raise ValueError(f'Unrecognized prompt format: {prompt!r}')

    return question.lower()


class TextVQADataset(BaseEvalDataset):
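    """TextVQA evaluation dataset.

    Reads a ``.jsonl`` question file, pairs each model prediction with the
    official TextVQA annotations and reports VQA accuracy via
    ``TextVQAAccuracyEvaluator``.
    """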
    METAINFO: dict = dict(name='textvqa')

    def __init__(self, data_file, ann_file,
                 image_folder, image_processor,
                 pad_image_to_square=True, metainfo=None):
        super().__init__(metainfo)
        self.data_file = data_file
        self.ann_file = ann_file
        self.image_folder = image_folder

        self.image_processor = BUILDER.build(image_processor)
        self.pad_image_to_square = pad_image_to_square
        self.name = os.path.splitext(os.path.basename(data_file))[0]
        self.results_path = self.name + '-results.jsonl'
        self.data = self.load_data_list()

    def load_data_list(self):
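        # Each line of ``data_file`` is a standalone JSON object; its
        # ``image`` and ``text`` fields are aliased to ``image_path`` and
        # ``question`` for the shared pipeline, and a sequential ``img_id``
        # links predictions back to their ground truth.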
        with open(os.path.expanduser(self.data_file), 'r') as f:
            data = [json.loads(line) for line in f]
        for i, d in enumerate(data):
            d['img_id'] = i
            d['image_path'] = d['image']
            d['question'] = d['text']
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        data_dict = custom_data_process(self, data)
        return data_dict

    @master_only
    def evaluate(self, result, work_dir, show=True):
        answers_file = osp.join(work_dir, self.results_path)
        with open(answers_file, 'w') as ans_file:
            for pred_dict in result:
                idx = pred_dict['img_id']
                gt_data = self.data[idx]
                ans_file.write(
                    json.dumps({
                        'question_id': gt_data['question_id'],
                        'prompt': gt_data['text'],
                        'text': pred_dict['prediction'],
                        'metadata': {}
                    }) + '\n')

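        # Official annotations are keyed by (image_id, lowercased question);
        # the ``question_id`` written above is matched against ``image_id``
        # when looking each prediction back up.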
        with open(self.ann_file, 'r') as f:
            annotations = json.load(f)['data']
        annotations = {(ann['image_id'], ann['question'].lower()): ann
                       for ann in annotations}
        with open(answers_file, 'r') as f:
            results = [json.loads(line) for line in f]

        pred_list = []
        for res in results:
            annotation = annotations[(res['question_id'],
                                      prompt_processor(res['prompt']))]
            pred_list.append({
                'pred_answer': res['text'],
                'gt_answers': annotation['answers'],
            })

        evaluator = TextVQAAccuracyEvaluator()
        acc = 100. * evaluator.eval_pred_list(pred_list)
        print_log('============================================', 'current')
        print_log('Samples: {}, Accuracy: {:.2f}%'.format(len(pred_list), acc), 'current')
        print_log('============================================', 'current')
        print_log('TextVQA evaluation finished successfully.', 'current')
        return {'acc': acc}
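

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; the file paths and the
# ``CLIPImageProcessor`` checkpoint below are hypothetical placeholders):
#
#   from transformers import CLIPImageProcessor
#
#   dataset = TextVQADataset(
#       data_file='llava_textvqa_val_v051_ocr.jsonl',
#       ann_file='TextVQA_0.5.1_val.json',
#       image_folder='data/textvqa/train_images',
#       image_processor=dict(
#           type=CLIPImageProcessor.from_pretrained,
#           pretrained_model_name_or_path='openai/clip-vit-large-patch14-336'))
#
#   # ``result`` is a list of dicts with ``img_id`` and ``prediction`` keys,
#   # produced by the evaluation loop.
#   metrics = dataset.evaluate(result, work_dir='./work_dirs/textvqa')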