import copy
import json
import os
import random
from collections import defaultdict

import torch
import torch.nn.functional as F
from PIL import Image

from .vqa_dataset import VQADataset
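
# Prompt templates appended to each NLVR statement; one is sampled per example
# so the model sees varied phrasings of the same true/false verification task.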
QUESTIONS = [
    "Is this true?",
    "Is this right?",
    "Can you confirm this information?",
    "Do you agree with this statement?",
    "Does this align with your understanding?",
    "How do you interpret this information?",
    "Can you confirm this?",
    "Is this statement correct?",
    "Could you verify this information?",
    "Do you agree with this?",
    "Is this accurate?",
    "Can you validate this claim?",
    "Are these details valid?",
    "Is this factually correct?",
    "Is the following information correct?",
    "Could you please verify this fact?",
    "Do you agree with this assertion?",
    "Are these details accurate?",
    "Does this claim hold true?",
]


class NLVRv1Dataset(VQADataset):
    """Visual Reasoning Dataset (NLVR v1)."""

    def __init__(self, tokenizer, vis_processor, vis_root, ann_paths, **kwargs):
        super().__init__(tokenizer, vis_processor, vis_root, ann_paths=[], **kwargs)
        self.annotation = self.load_annotations(ann_paths)
        if self.sample_image:
            print("randomly sampling one annotation for each image")
            self.annotation = self.parse_annotation(self.annotation)
        self._add_instance_ids()

    def load_annotations(self, ann_paths):
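        """Load NLVR v1 annotations from one JSON-lines file per split.

        The split (train/dev/test) is inferred from the file name and stored
        on each record so that image paths can be resolved later.
        """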
        annotation = []
        for ann_path in ann_paths:
            if "train.json" in ann_path:
                split = "train"
            elif "dev.json" in ann_path:
                split = "dev"
            elif "test.json" in ann_path:
                split = "test"
            else:
                raise ValueError(f"Unknown split for {ann_path}")
            with open(ann_path, "r") as f:
                for line in f.readlines():
                    line = line.strip()
                    if len(line) != 0:
                        ann = json.loads(line)
                        ann["split"] = split
                        annotation.append(ann)
        return annotation

    def parse_annotation(self, annotation):
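        """Group annotations by image and keep one randomly chosen statement per image."""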
        image_list = defaultdict(list)
        for ann in annotation:
            img_key = f"{ann['split']}-{ann['identifier']}"
            image_list[img_key].append(ann)
        annotation = []
        for ann_list in image_list.values():
            annotation.append(random.choice(ann_list))
        return annotation

    def process_text(self, ann):
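        # With probability self.option_prob, show the candidate answers
        # ("true"/"false") in the prompt; otherwise ask open-ended.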
        question = ann["sentence"] + " " + random.choice(QUESTIONS)
        true_answer = ann["label"]
        if random.random() < self.option_prob:
            instruction = self.prompter(question, ["true", "false"])
        else:
            instruction = self.prompter(question)
        return dict(instruction=instruction, answer=true_answer)

    def process_image(self, ann):
        # Each statement has 6 rendered images; randomly select one of them.
        # TODO: check whether to use all 6 images.
        random_id = random.randint(0, 5)
        image_name = f"{ann['split']}-{ann['identifier']}-{random_id}.png"
        image_path = os.path.join(self.vis_root, ann["split"], "images", ann["directory"], image_name)
        image = Image.open(image_path).convert("RGB")
        image = self.vis_processor(image)
        return image


class NLVRv2Dataset(VQADataset):
    """Visual Reasoning Dataset (NLVR v2: one statement paired with two images)."""

    def __init__(self, tokenizer, vis_processor, vis_root, ann_paths, **kwargs):
        super().__init__(tokenizer, vis_processor, vis_root, ann_paths, **kwargs)
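        # Probability of applying the left/right flip augmentation in __getitem__.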
        self.flip_prob = 0.5

    def parse_annotation(self, annotation):
        image_list = defaultdict(list)
        for ann in annotation:
            image_list[ann["images"][0]].append(ann)
        annotation = []
        for ann_list in image_list.values():
            annotation.append(random.choice(ann_list))
        return annotation

    def process_text(self, ann):
        question = ann["sentence"] + " " + random.choice(QUESTIONS)
        true_answer = ann["label"]
        if random.random() < self.option_prob:
            instruction = self.prompter(question, ["true", "false"])
        else:
            instruction = self.prompter(question)
        return dict(instruction=instruction, answer=true_answer)

    def process_image(self, ann):
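        """Load and preprocess both images of the pair; they are fused in __getitem__."""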
        image_0_path = os.path.join(self.vis_root, ann["images"][0])
        image_1_path = os.path.join(self.vis_root, ann["images"][1])
        image_0 = Image.open(image_0_path).convert("RGB")
        image_1 = Image.open(image_1_path).convert("RGB")
        image_0 = self.vis_processor(image_0)
        image_1 = self.vis_processor(image_1)
        return image_0, image_1

    def _flip(self, samples):
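        """Swap the two images half of the time; if the sentence mentions
        "left" or "right", also swap those words so the statement still
        matches the swapped image order.
        """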
        sentence = samples["sentence"]
        image0, image1 = samples["image0"], samples["image1"]
        if "left" not in sentence and "right" not in sentence:
            if random.random() < 0.5:
                image0, image1 = image1, image0
        else:
            if random.random() < 0.5:
                sentence = sentence.replace("left", "[TEMP_TOKEN]")
                sentence = sentence.replace("right", "left")
                sentence = sentence.replace("[TEMP_TOKEN]", "right")
                image0, image1 = image1, image0
        samples["sentence"] = sentence
        samples["image0"] = image0
        samples["image1"] = image1
        return samples

    def __getitem__(self, index):
        ann = copy.deepcopy(self.annotation[index])
        image_0, image_1 = self.process_image(ann)
        if random.random() < self.flip_prob:
            samples = self._flip({"sentence": ann["sentence"], "image0": image_0, "image1": image_1})
            image_0, image_1 = samples["image0"], samples["image1"]
            ann["sentence"] = samples["sentence"]
        # Concatenate the two images side by side along the width dimension.
        # TODO: https://github.com/salesforce/LAVIS/blob/main/lavis/models/blip_models/blip_nlvr.py
        # The model logic needs updating if NLVR2 is used with a dual-image architecture.
        image = torch.cat([image_0, image_1], dim=2)
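        # Resize the doubled-width composite back to the spatial size of a
        # single processed image so the vision encoder input shape is unchanged.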
        image = F.interpolate(image[None, ...], size=(image_0.shape[1], image_0.shape[2]))[0]
        text = self.process_text(ann)
        res = self.tokenize(text)
        res.update(image=image)
        res.update(text)
        return res


def build_nlvrv1_dataset(
    tokenizer,
    vis_processor,
    vis_root="data/nlvr",
    ann_paths=["data/nlvr/train/train.json"],
    sample_image=False,
):
    return NLVRv1Dataset(
        tokenizer=tokenizer,
        vis_processor=vis_processor,
        vis_root=vis_root,
        ann_paths=ann_paths,
        sample_image=sample_image,
    )


def build_nlvrv2_dataset(
    tokenizer,
    vis_processor,
    vis_root="data/nlvr2",
    ann_paths=["data/nlvr2/annotations/nlvr_train.json"],
    sample_image=False,
):
    return NLVRv2Dataset(
        tokenizer=tokenizer,
        vis_processor=vis_processor,
        vis_root=vis_root,
        ann_paths=ann_paths,
        sample_image=sample_image,
    )
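

if __name__ == "__main__":
    # Minimal usage sketch, not part of the training pipeline. Assumptions:
    # `transformers` and `torchvision` are installed, the NLVR2 annotations
    # and images sit under the default paths above, and "bert-base-uncased"
    # is a stand-in for whichever tokenizer the surrounding project uses.
    # Because of the relative import at the top, run this as a module
    # (python -m <package>.<this_module>) rather than as a plain script.
    from torchvision import transforms
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    vis_processor = transforms.Compose(
        [transforms.Resize((224, 224)), transforms.ToTensor()]
    )
    dataset = build_nlvrv2_dataset(tokenizer, vis_processor)
    sample = dataset[0]  # dict with tokenized instruction/answer and the fused image
    print(sample["image"].shape)  # expected: torch.Size([3, 224, 224]) for a 224x224 processor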