import json
import os
import random
from collections import defaultdict

from PIL import Image

from .vqa_dataset import VQADataset


class GQADataset(VQADataset):
    """Visual Reasoning Dataset."""

    def __init__(self, tokenizer, vis_processor, vis_root, ann_paths, **kwargs):
        super().__init__(tokenizer, vis_processor, vis_root, ann_paths=[], **kwargs)

        self.annotation = self.load_annotations(ann_paths)
        if self.sample_image:
            print("randomly sample one annotation for each image")
            self.annotation = self.parse_annotation(self.annotation)
        self._add_instance_ids()

        # Probability of training on the full-sentence answer instead of the
        # short answer; 1.0 means the full answer is always selected.
        self.answer_prob = 1.0

    @staticmethod
    def load_annotations(ann_paths):
        annotation = []
        for ann_path in ann_paths:
            with open(ann_path, "r") as f:
                ann = json.load(f)
            # GQA stores questions as a dict keyed by question id;
            # flatten it to a list and keep the id on each record.
            for k, v in ann.items():
                v["question_id"] = k
                annotation.append(v)
        return annotation

    def parse_annotation(self, annotation):
        # Group annotations by image, then keep one random question per image.
        image_list = defaultdict(list)
        for ann in annotation:
            image_list[ann["imageId"]].append(ann)
        annotation = []
        for ann_list in image_list.values():
            annotation.append(random.choice(ann_list))
        return annotation

    def process_text(self, ann):
        question = ann["question"]
        answer = ann["answer"]
        full_answer = ann["fullAnswer"]

        # TODO: check which one is better
        # Randomly select answer or full_answer
        if random.random() < self.answer_prob:
            select_answer = full_answer
        else:
            select_answer = answer

        instruction = self.prompter(question)
        return dict(instruction=instruction, answer=select_answer)

    def process_image(self, ann):
        image_path = os.path.join(self.vis_root, ann["imageId"] + ".jpg")
        image = Image.open(image_path).convert("RGB")
        image = self.vis_processor(image)
        return image


def build_gqa_dataset(
    tokenizer,
    vis_processor,
    vis_root="data/gqa/images",
    ann_paths=[
        "data/gqa/questions/train_all_questions/train_all_questions_0.json",
        "data/gqa/questions/val_all_questions.json",
    ],
    sample_image=False,
):
    return GQADataset(
        tokenizer=tokenizer,
        vis_processor=vis_processor,
        vis_root=vis_root,
        ann_paths=ann_paths,
        sample_image=sample_image,
    )
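

if __name__ == "__main__":
    # Minimal usage sketch, not a definitive recipe: it assumes the GQA images
    # and question files live at the default paths above, and that any
    # HuggingFace tokenizer plus a torchvision-style transform satisfy the
    # interfaces VQADataset expects of `tokenizer` and `vis_processor`.
    from transformers import AutoTokenizer
    from torchvision import transforms

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    vis_processor = transforms.Compose(
        [transforms.Resize((224, 224)), transforms.ToTensor()]
    )

    dataset = build_gqa_dataset(tokenizer, vis_processor, sample_image=True)
    print(len(dataset.annotation), "annotations loaded")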