Spaces:
Runtime error
Runtime error
import json | |
import os | |
import random | |
from collections import defaultdict | |
from PIL import Image | |
from .vqa_dataset import VQADataset | |
class GQADataset(VQADataset):
    """Visual Reasoning Dataset.

    Loads GQA-style question files (JSON objects mapping a question id to a
    record) and turns each record into an instruction/answer pair plus a
    processed image. Everything not defined here (``sample_image``,
    ``prompter``, ``vis_root``, ``vis_processor``, ``_add_instance_ids``) is
    expected to come from ``VQADataset``.
    """

    def __init__(self, tokenizer, vis_processor, vis_root, ann_paths, **kwargs):
        """Build the dataset from one or more GQA question files.

        Args:
            tokenizer: Forwarded unchanged to ``VQADataset.__init__``.
            vis_processor: Forwarded unchanged to ``VQADataset.__init__``.
            vis_root: Forwarded unchanged to ``VQADataset.__init__``.
            ann_paths: Iterable of paths to GQA question JSON files.
            **kwargs: Extra options forwarded to ``VQADataset.__init__``.
        """
        # The parent is given an empty annotation list on purpose: GQA files
        # are dicts keyed by question id, so they are loaded by
        # ``load_annotations`` below instead of by the parent.
        super().__init__(tokenizer, vis_processor, vis_root, ann_paths=[], **kwargs)

        self.annotation = self.load_annotations(ann_paths)
        # NOTE(review): ``sample_image`` is presumably set by the parent
        # constructor from ``kwargs`` -- confirm against VQADataset.
        if self.sample_image:
            print("randomly sample one annotation for each image")
            self.annotation = self.parse_annotation(self.annotation)
        self._add_instance_ids()
        # Probability of answering with ``fullAnswer`` rather than ``answer``
        # in ``process_text``; 1.0 means the full answer is always used.
        self.answer_prob = 1.0

    @staticmethod
    def load_annotations(ann_paths):
        """Read every question file and flatten the records into one list.

        Declared as a static method so that the ``self.load_annotations(...)``
        call in ``__init__`` binds ``ann_paths`` correctly; without the
        decorator the instance would be passed as an extra argument.

        Args:
            ann_paths: Iterable of paths to JSON files, each holding an object
                that maps a question id to a record (a dict).

        Returns:
            A list of the records, in file order, each with its key stored
            under ``"question_id"``.
        """
        annotation = []
        for ann_path in ann_paths:
            # Context manager so the handle is closed even if parsing fails.
            with open(ann_path, "r", encoding="utf-8") as ann_file:
                ann = json.load(ann_file)
            for question_id, record in ann.items():
                record["question_id"] = question_id
                annotation.append(record)
        return annotation

    def parse_annotation(self, annotation):
        """Keep one randomly chosen annotation per image.

        Args:
            annotation: List of records, each with an ``"imageId"`` key.

        Returns:
            A list with one record per distinct ``"imageId"``, ordered by the
            first appearance of each image id.
        """
        image_list = defaultdict(list)
        for ann in annotation:
            image_list[ann["imageId"]].append(ann)

        return [random.choice(ann_list) for ann_list in image_list.values()]

    def process_text(self, ann):
        """Build the instruction/answer pair for one record.

        Args:
            ann: Record with ``"question"``, ``"answer"`` and ``"fullAnswer"``
                keys.

        Returns:
            ``dict(instruction=..., answer=...)`` where the instruction is
            ``self.prompter(question)`` and the answer is the full answer with
            probability ``self.answer_prob``, otherwise the short answer.
        """
        question = ann["question"]
        answer = ann["answer"]
        full_answer = ann["fullAnswer"]

        # TODO: check which one is better
        # Randomly select between the short answer and the full answer.
        if random.random() < self.answer_prob:
            select_answer = full_answer
        else:
            select_answer = answer

        instruction = self.prompter(question)
        return dict(instruction=instruction, answer=select_answer)

    def process_image(self, ann):
        """Load and preprocess the image that belongs to one record.

        Args:
            ann: Record with an ``"imageId"`` key; the image is read from
                ``<vis_root>/<imageId>.jpg``.

        Returns:
            Whatever ``self.vis_processor`` returns for the RGB image.
        """
        image_path = os.path.join(self.vis_root, ann["imageId"] + ".jpg")
        image = Image.open(image_path).convert("RGB")

        image = self.vis_processor(image)
        return image
def build_gqa_dataset(
    tokenizer,
    vis_processor,
    vis_root="data/gqa/images",
    ann_paths=None,
    sample_image=False,
):
    """Construct a ``GQADataset`` with the default GQA file layout.

    Args:
        tokenizer: Passed through to ``GQADataset``.
        vis_processor: Passed through to ``GQADataset``.
        vis_root: Directory containing the images.
        ann_paths: Paths to the question JSON files. ``None`` (the default)
            selects the first train split file and the validation file listed
            below.
        sample_image: Passed through to ``GQADataset`` as a keyword option.

    Returns:
        The constructed ``GQADataset`` instance.
    """
    # Resolve the default here rather than in the signature so that no list
    # object is shared between calls.
    if ann_paths is None:
        ann_paths = [
            "data/gqa/questions/train_all_questions/train_all_questions_0.json",
            "data/gqa/questions/val_all_questions.json",
        ]
    return GQADataset(
        tokenizer=tokenizer,
        vis_processor=vis_processor,
        vis_root=vis_root,
        ann_paths=ann_paths,
        sample_image=sample_image,
    )