Spaces:
Runtime error
Runtime error
import json | |
import os | |
import random | |
import numpy as np | |
from PIL import Image | |
from transformers import LlamaTokenizer | |
from .vqa_dataset import VQADataset, VQAPrompter | |
class TextOCRDataset(VQADataset):
    """VQA-style dataset pairing ``.jpg`` images with question/answer annotations.

    Annotations are read from JSON files; the methods below read the
    ``image_id``, ``question`` and ``answers`` entries of each annotation.
    """

    def __init__(
        self, tokenizer, vis_processor=None, vis_root=None, ann_paths=None, add_eos=True, ignore_instruction=True
    ):
        """
        tokenizer: tokenizer whose ``add_eos_token`` attribute must be False
        vis_processor (callable): applied to every loaded RGB image
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_paths (iterable of string): JSON annotation files; the list stored
            under each file's top-level "data" key is appended to the dataset.
            ``None`` (the default) means no annotation files.
        add_eos, ignore_instruction: stored on the instance unchanged; they are
            not read in this class (presumably consumed by the base class).
        """
        assert tokenizer.add_eos_token is False, "tokenizer should not add eos token by default"
        self.tokenizer: LlamaTokenizer = tokenizer
        self.vis_root = vis_root

        self.annotation = []
        # A None default avoids sharing one mutable list object across calls.
        for ann_path in ann_paths or ():
            # Context manager so the file handle is closed deterministically.
            with open(ann_path, "r") as ann_file:
                self.annotation.extend(json.load(ann_file)["data"])

        self.vis_processor = vis_processor
        # Not defined in this class; expected to come from the base class.
        self._add_instance_ids()
        # Probability of listing the candidate answers in the prompt.
        self.option_prob = 0.5
        self.prompter = VQAPrompter()
        self.add_eos = add_eos
        self.ignore_instruction = ignore_instruction

    def process_image(self, ann):
        """Load ``<vis_root>/<image_id>.jpg`` as RGB and run it through ``vis_processor``."""
        image_path = os.path.join(self.vis_root, ann["image_id"] + ".jpg")
        image = Image.open(image_path).convert("RGB")
        image = self.vis_processor(image)
        return image

    def process_text(self, ann):
        """Build the instruction prompt and pick the target answer for one annotation.

        The target is the most frequent entry of ``ann["answers"]`` (the first
        one encountered wins a tie). With probability ``self.option_prob``, and
        only when there is more than one distinct answer, the distinct answers
        are passed to the prompter as options.

        Returns:
            dict with keys ``instruction`` and ``answer``.

        Raises:
            ValueError: if ``ann["answers"]`` is empty.
        """
        question = ann["question"]
        num_answers = len(ann["answers"])
        if not num_answers:
            # Fail with a clear message instead of numpy's generic
            # "attempt to get argmax of an empty sequence" ValueError.
            raise ValueError("annotation has an empty 'answers' list")

        # Each occurrence contributes an equal share; hoisted out of the loop.
        share = 1 / num_answers
        answer_weight = {}
        for answer in ann["answers"]:
            answer_weight[answer] = answer_weight.get(answer, 0) + share

        # Dicts keep insertion order, so both lists follow first occurrence.
        answers = list(answer_weight)
        weights = list(answer_weight.values())

        # create instruction
        true_answer = answers[np.argmax(weights)]
        is_option = random.random() < self.option_prob and len(answers) > 1
        if is_option:
            instruction = self.prompter(question, answers)
        else:
            instruction = self.prompter(question)

        return dict(instruction=instruction, answer=true_answer)