import json
import os
import random

import numpy as np
from PIL import Image
from transformers import LlamaTokenizer

from .vqa_dataset import VQADataset, VQAPrompter


class TextOCRDataset(VQADataset):
    def __init__(
        self, tokenizer, vis_processor=None, vis_root=None, ann_paths=[], add_eos=True, ignore_instruction=True
    ):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_paths (list of string): Paths to the annotation files
        """
        assert tokenizer.add_eos_token is False, "tokenizer should not add eos token by default"
        self.tokenizer: LlamaTokenizer = tokenizer
        self.vis_root = vis_root

        # Merge the "data" records from every annotation file into one list.
        self.annotation = []
        for ann_path in ann_paths:
            with open(ann_path, "r") as f:
                self.annotation.extend(json.load(f)["data"])

        self.vis_processor = vis_processor
        self._add_instance_ids()
        # Probability of phrasing a sample as a multiple-choice question.
        self.option_prob = 0.5
        self.prompter = VQAPrompter()
        self.add_eos = add_eos
        self.ignore_instruction = ignore_instruction

    def process_image(self, ann):
        image_path = os.path.join(self.vis_root, ann["image_id"] + ".jpg")
        image = Image.open(image_path).convert("RGB")

        image = self.vis_processor(image)
        return image

    def process_text(self, ann):
        question = ann["question"]

        # Weight each distinct answer by its frequency among the annotators.
        answer_weight = {}
        for answer in ann["answers"]:
            if answer in answer_weight:
                answer_weight[answer] += 1 / len(ann["answers"])
            else:
                answer_weight[answer] = 1 / len(ann["answers"])

        answers = list(answer_weight.keys())
        weights = list(answer_weight.values())

        # Use the most frequent answer as the ground truth.
        true_answer = answers[np.argmax(weights)]

        # With probability `option_prob`, present the candidate answers as
        # multiple-choice options in the instruction.
        is_option = random.random() < self.option_prob and len(answers) > 1
        if is_option:
            instruction = self.prompter(question, answers)
        else:
            instruction = self.prompter(question)

        return dict(instruction=instruction, answer=true_answer)
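

# A minimal usage sketch (not part of the original file): construction assumes
# a LLaMA tokenizer with `add_eos_token=False` (the transformers default), a
# `vis_processor` callable that maps a PIL image to a model-ready tensor, and
# annotation JSON files with a top-level "data" list. All paths and the
# transform name below are hypothetical placeholders.
#
#     from transformers import LlamaTokenizer
#
#     tokenizer = LlamaTokenizer.from_pretrained("path/to/llama")
#     dataset = TextOCRDataset(
#         tokenizer,
#         vis_processor=image_transform,                # hypothetical transform
#         vis_root="data/textocr/train_images",         # hypothetical image root
#         ann_paths=["data/textocr/annotations.json"],  # hypothetical annotations
#     )
#     sample = dataset.process_text(dataset.annotation[0])
#     # -> {"instruction": "...", "answer": "..."}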