# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
from typing import Callable, List, Optional, Tuple, Union

import torch
from PIL import Image
from torch import Tensor
from torch.utils.data import Dataset


class VQADataset(Dataset):
    """
    Create the dataset for the VQA task.

    Args:
        ann_file (List[str]): The paths to annotation json files.
        vqa_root (str): The path to the vqa data directory.
        vg_root (str): The path to the vg data directory.
        image_transform (Callable[[Image.Image], Tensor]): Image data transform.
        question_transform (Callable[[Union[List[str], str]], Tensor]): Text data transform for questions.
        answer_transform (Callable[[Union[List[str], str]], Tensor]): Text data transform for answers.
        split (str): Indicates train or test. Default is train.
        answer_list (Optional[str]): The path to the answer list. Required for the test split.

    Dataset Outputs:
        if split is train:
            image (Tensor): Transformed image input tensor of shape (C, W, H).
            question (Tensor): Transformed question token input ids.
            answers (List[Tensor]): List of transformed answer token input ids.
            answer_weights (List[float]): List of answer weights.
                answer_weights[i] is proportional to the number of occurrences of answers[i].
        if split is test:
            image (Tensor): Transformed image input tensor of shape (C, W, H).
            question (Tensor): Transformed question token input ids.
            question_id (int): The question sample id.
    """
    def __init__(
        self,
        ann_file: List[str],
        vqa_root: str,
        vg_root: str,
        image_transform: Callable[[Image.Image], Tensor],
        question_transform: Callable[[Union[List[str], str]], Tensor],
        answer_transform: Callable[[Union[List[str], str]], Tensor],
        split: str = "train",
        answer_list: Optional[str] = None,
    ) -> None:
        # Concatenate the annotation entries from all provided json files.
        self.ann = []
        for f in ann_file:
            with open(f, "r") as ann_f:
                self.ann += json.load(ann_f)
        self.vqa_root = vqa_root
        self.vg_root = vg_root
        self.image_transform = image_transform
        self.question_transform = question_transform
        self.answer_transform = answer_transform
        self.split = split
        if split == "test":
            # Pre-tokenize the candidate answers for the test split and build their attention mask.
            with open(answer_list, "r") as answer_f:
                self.answer_list = json.load(answer_f)
            self.answer_input_ids = self.answer_transform(self.answer_list)
            self.answer_attention_mask = (self.answer_input_ids != 0).type(torch.long)

    def __len__(self) -> int:
        return len(self.ann)

    def __getitem__(
        self, index: int
    ) -> Union[
        Tuple[Tensor, Tensor, int], Tuple[Tensor, Tensor, List[Tensor], List[float]]
    ]:
        ann = self.ann[index]
        image_root = self.vqa_root if ann["dataset"] == "vqa" else self.vg_root
        image_path = os.path.join(image_root, ann["image"])
        image = Image.open(image_path).convert("RGB")
        image = self.image_transform(image)
        question = self.question_transform(ann["question"])
        if self.split == "test":
            return image, question, ann["question_id"]
        elif self.split == "train":
            if ann["dataset"] == "vqa":
                # Each VQA sample question has a list of answers (with potential repeats)
                # answer_weight[answer] = count(answer) / len(answers for the question)
                answer_weights = {}
                for answer in ann["answer"]:
                    if answer in answer_weights:
                        answer_weights[answer] += 1 / len(ann["answer"])
                    else:
                        answer_weights[answer] = 1 / len(ann["answer"])
                answers = list(answer_weights.keys())
                answer_weights = list(answer_weights.values())
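                # Illustrative arithmetic (hypothetical answers, not from the dataset):
                # ann["answer"] == ["blue", "blue", "navy"] gives
                # answers == ["blue", "navy"] and answer_weights == [2/3, 1/3].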
            elif ann["dataset"] == "vg":
                # A VG sample question has one answer, so assign it a constant weight (0.5)
                answers = [ann["answer"]]
                answer_weights = [0.5]
            answers = list(self.answer_transform(answers))
            return image, question, answers, answer_weights
        else:
            raise ValueError("dataset split should be train or test")
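
# Minimal usage sketch (not part of the original module). The annotation file,
# data roots, and image size are hypothetical placeholders, and the toy text
# transform below stands in for a real tokenizer; only VQADataset itself comes
# from this file.
if __name__ == "__main__":
    from torchvision import transforms

    def toy_text_transform(text: Union[List[str], str]) -> Tensor:
        # Stand-in tokenizer: assign consecutive dummy ids to whitespace tokens,
        # padded with zeros to a fixed length of 16.
        texts = [text] if isinstance(text, str) else list(text)
        ids = torch.zeros(len(texts), 16, dtype=torch.long)
        for i, t in enumerate(texts):
            for j, _ in enumerate(t.split()[:16]):
                ids[i, j] = j + 1
        return ids[0] if isinstance(text, str) else ids

    dataset = VQADataset(
        ann_file=["annotations/vqa_train.json"],  # hypothetical path
        vqa_root="data/vqa_images",  # hypothetical path
        vg_root="data/vg_images",  # hypothetical path
        image_transform=transforms.Compose(
            [transforms.Resize((384, 384)), transforms.ToTensor()]
        ),
        question_transform=toy_text_transform,
        answer_transform=toy_text_transform,
        split="train",
    )
    image, question, answers, answer_weights = dataset[0]
    print(image.shape, question.shape, len(answers), answer_weights)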