import json
import os
import random
from collections import defaultdict

from PIL import Image

from .vqa_dataset import VQADataset


class CLEVRDataset(VQADataset):
    """Visual reasoning dataset (CLEVR). It also contains dialog.

    Note: the images are fairly simple, each showing several objects on a
    plain background.
    """

    def __init__(self, tokenizer, vis_processor, vis_root, ann_paths, **kwargs):
        # Pass an empty ann_paths to the parent so the CLEVR-specific
        # question files can be loaded here instead.
        super().__init__(tokenizer, vis_processor, vis_root, ann_paths=[], **kwargs)

        self.annotation = self.load_annotations(ann_paths)
        if self.sample_image:
            print("randomly sample one annotation for each image")
            self.annotation = self.parse_annotation(self.annotation)
        self._add_instance_ids()

    @staticmethod
    def load_annotations(ann_paths):
        """Load and concatenate the question lists from all annotation files."""
        annotation = []
        for ann_path in ann_paths:
            with open(ann_path, "r") as f:
                ann = json.load(f)
            annotation.extend(ann["questions"])
        return annotation

    def parse_annotation(self, annotation):
        """Keep one randomly chosen question per image."""
        image_list = defaultdict(list)
        for ann in annotation:
            image_list[ann["image_filename"]].append(ann)
        annotation = []
        for ann_list in image_list.values():
            annotation.append(random.choice(ann_list))
        return annotation

    def process_text(self, ann):
        question = ann["question"]
        answer = ann["answer"]
        instruction = self.prompter(question)
        return dict(instruction=instruction, answer=answer)

    def process_image(self, ann):
        # Images are stored per split, e.g. images/train/<image_filename>.
        split = ann["split"]
        image_path = os.path.join(self.vis_root, split, ann["image_filename"])
        image = Image.open(image_path).convert("RGB")
        image = self.vis_processor(image)
        return image


def build_clevr_dataset(
    tokenizer,
    vis_processor,
    vis_root="data/clevr/CLEVR_v1.0/images",
    ann_paths=(
        "data/clevr/CLEVR_v1.0/questions/CLEVR_train_questions.json",
        "data/clevr/CLEVR_v1.0/questions/CLEVR_val_questions.json",
    ),
    sample_image=False,
):
    return CLEVRDataset(
        tokenizer=tokenizer,
        vis_processor=vis_processor,
        vis_root=vis_root,
        ann_paths=ann_paths,
        sample_image=sample_image,
    )
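

# A minimal usage sketch. The tokenizer and visual processor below are
# assumptions for illustration (a HuggingFace tokenizer and a plain
# torchvision transform); the surrounding repo likely provides its own
# builders for both, and the exact sample format depends on how
# VQADataset implements __getitem__.
#
#     from transformers import AutoTokenizer
#     from torchvision import transforms
#
#     tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#     vis_processor = transforms.Compose(
#         [transforms.Resize((224, 224)), transforms.ToTensor()]
#     )
#     dataset = build_clevr_dataset(
#         tokenizer, vis_processor, sample_image=True
#     )
#     sample = dataset[0]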