# Multimodal-GPT — mmgpt/datasets/gqa_dataset.py
# Uploaded by devroop via huggingface_hub (revision 03561be, verified).
import json
import os
import random
from collections import defaultdict
from PIL import Image
from .vqa_dataset import VQADataset
class GQADataset(VQADataset):
    """Visual Reasoning Dataset built on the GQA question files.

    Unlike the parent ``VQADataset``, annotations are loaded from GQA's
    ``{question_id: annotation}`` JSON layout, so the parent is initialized
    with an empty ``ann_paths`` and loading is done here.
    """

    def __init__(self, tokenizer, vis_processor, vis_root, ann_paths, **kwargs):
        # Parent gets ann_paths=[] on purpose: GQA's JSON schema differs from
        # the generic VQA one, so we load the annotations ourselves below.
        super().__init__(tokenizer, vis_processor, vis_root, ann_paths=[], **kwargs)

        self.annotation = self.load_annotations(ann_paths)
        if self.sample_image:
            print("randomly sample one annotation for each image")
            self.annotation = self.parse_annotation(self.annotation)
        self._add_instance_ids()
        # Probability of selecting "fullAnswer" over the short "answer" in
        # process_text. 1.0 means the full answer is always used.
        # TODO: tune — see process_text.
        self.answer_prob = 1.0

    @staticmethod
    def load_annotations(ann_paths):
        """Load and flatten GQA question files.

        Each file maps question_id -> annotation dict; the key is folded into
        its dict under ``"question_id"`` and all entries from all files are
        concatenated into a single list.
        """
        annotation = []
        for ann_path in ann_paths:
            # Use a context manager so the file handle is closed promptly
            # (the original `json.load(open(...))` leaked the handle).
            with open(ann_path, "r") as f:
                ann = json.load(f)
            for k, v in ann.items():
                v["question_id"] = k
                annotation.append(v)
        return annotation

    def parse_annotation(self, annotation):
        """Keep one randomly chosen annotation per image (deduplication)."""
        image_list = defaultdict(list)
        for ann in annotation:
            image_list[ann["imageId"]].append(ann)
        annotation = []
        for ann_list in image_list.values():
            annotation.append(random.choice(ann_list))
        return annotation

    def process_text(self, ann):
        """Build the instruction/answer pair for one annotation.

        Returns a dict with ``instruction`` (prompted question) and
        ``answer`` (either ``fullAnswer`` or ``answer``, chosen at random
        with probability ``self.answer_prob`` for the full answer).
        """
        question = ann["question"]
        answer = ann["answer"]
        full_answer = ann["fullAnswer"]

        # TODO: check which one is better
        # Random select answer or full_answer
        if random.random() < self.answer_prob:
            select_answer = full_answer
        else:
            select_answer = answer

        instruction = self.prompter(question)
        return dict(instruction=instruction, answer=select_answer)

    def process_image(self, ann):
        """Load the image for an annotation as RGB and apply the visual processor."""
        # Images are stored as <imageId>.jpg under vis_root.
        image_path = os.path.join(self.vis_root, ann["imageId"] + ".jpg")
        image = Image.open(image_path).convert("RGB")
        image = self.vis_processor(image)
        return image
def build_gqa_dataset(
    tokenizer,
    vis_processor,
    vis_root="data/gqa/images",
    # Tuple, not list: a mutable default argument is shared across calls.
    ann_paths=(
        "data/gqa/questions/train_all_questions/train_all_questions_0.json",
        "data/gqa/questions/val_all_questions.json",
    ),
    sample_image=False,
):
    """Create a :class:`GQADataset`.

    Args:
        tokenizer: text tokenizer passed through to the dataset.
        vis_processor: image transform/processor applied in ``process_image``.
        vis_root: directory containing GQA images named ``<imageId>.jpg``.
        ann_paths: iterable of GQA question JSON files to load.
        sample_image: if True, keep one random annotation per image.

    Returns:
        A constructed ``GQADataset`` instance.
    """
    return GQADataset(
        tokenizer=tokenizer,
        vis_processor=vis_processor,
        vis_root=vis_root,
        ann_paths=ann_paths,
        sample_image=sample_image,
    )