File size: 2,136 Bytes
03561be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import json
import os
import random

import numpy as np
from PIL import Image
from transformers import LlamaTokenizer

from .vqa_dataset import VQADataset, VQAPrompter


class TextOCRDataset(VQADataset):
    """Dataset that turns image/question/answers annotations into instruction/answer pairs.

    Annotations are read from one or more JSON files; each annotation is expected to
    carry an ``image_id``, a ``question`` and a list of ``answers`` (the keys read by
    ``process_image`` and ``process_text``).
    """

    def __init__(
        self, tokenizer, vis_processor=None, vis_root=None, ann_paths=[], add_eos=True, ignore_instruction=True
    ):
        """
        Args:
            tokenizer: tokenizer whose ``add_eos_token`` attribute must be ``False``.
            vis_processor: callable applied to each loaded RGB image in ``process_image``.
            vis_root (string): Root directory of images (e.g. coco/images/).
            ann_paths (list of string): paths to JSON annotation files; the list stored
                under each file's top-level ``"data"`` key is appended to ``self.annotation``.
                The default list is only iterated, never mutated.
            add_eos: stored on the instance; not used inside this class
                (presumably consumed by the base class -- confirm there).
            ignore_instruction: stored on the instance; not used inside this class
                (presumably consumed by the base class -- confirm there).
        """
        assert tokenizer.add_eos_token is False, "tokenizer should not add eos token by default"
        self.tokenizer: LlamaTokenizer = tokenizer
        self.vis_root = vis_root

        self.annotation = []
        for ann_path in ann_paths:
            # Close every annotation file deterministically instead of leaking the handle.
            with open(ann_path, "r") as ann_file:
                self.annotation.extend(json.load(ann_file)["data"])

        self.vis_processor = vis_processor

        # Not defined in this class; expected to be provided by the base class.
        self._add_instance_ids()
        # Probability of passing the candidate answers to the prompter (see process_text).
        self.option_prob = 0.5
        self.prompter = VQAPrompter()
        self.add_eos = add_eos
        self.ignore_instruction = ignore_instruction

    def process_image(self, ann):
        """Load ``<vis_root>/<image_id>.jpg``, convert it to RGB and run ``vis_processor`` on it.

        Args:
            ann: annotation mapping providing ``"image_id"`` (a string).

        Returns:
            Whatever ``self.vis_processor`` returns for the RGB image.
        """
        image_path = os.path.join(self.vis_root, ann["image_id"] + ".jpg")
        # convert() produces an independent image, so the source file can be closed right away.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        image = self.vis_processor(image)
        return image

    def process_text(self, ann):
        """Build the instruction and pick the target answer for one annotation.

        The target is the most frequent entry of ``ann["answers"]``; ties go to the
        answer that appears first. With probability ``self.option_prob``, and only when
        there is more than one distinct answer, the distinct answers (in first-seen
        order) are passed to the prompter together with the question.

        Args:
            ann: annotation mapping providing ``"question"`` and a non-empty ``"answers"`` list.

        Returns:
            dict with keys ``instruction`` and ``answer``.

        Raises:
            ValueError: if ``ann["answers"]`` is empty.
        """
        question = ann["question"]

        # Integer counts rank answers the same way as the normalised weights would,
        # without accumulating floating-point error.
        answer_count = {}
        for answer in ann["answers"]:
            answer_count[answer] = answer_count.get(answer, 0) + 1

        if not answer_count:
            raise ValueError("annotation has no answers; cannot select a target answer")

        answers = list(answer_count)

        # create instruction
        # max() returns the first maximal key, i.e. the earliest-seen answer on ties.
        true_answer = max(answer_count, key=answer_count.get)
        # random.random() is evaluated first so the RNG stream is consumed on every call.
        is_option = random.random() < self.option_prob and len(answers) > 1
        if is_option:
            instruction = self.prompter(question, answers)
        else:
            instruction = self.prompter(question)

        return dict(instruction=instruction, answer=true_answer)