# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
from typing import Callable, List, Optional, Tuple, Union

import torch

from PIL import Image
from torch import Tensor
from torch.utils.data import Dataset


class VQADataset(Dataset):
    """
    Create the dataset for VQA task.

    Args:
        ann_file (List[str]): The paths to annotation json files.
        vqa_root (str): The path to vqa data directory.
        vg_root (str): The path to vg data directory.
        image_transform (Callable[[Image.Image], Tensor]): image data transform.
        question_transform (Callable[[Union[List[str], str]], Tensor]): text data transform for questions.
        answer_transform (Callable[[Union[List[str], str]], Tensor]): text data transform for answers.
        split (str): Indicates train or test. Default is train.
        answer_list (str): The path to the answers list. Required for test split.

    Dataset Outputs:
        if split is train:
            image (Tensor): Transformed image input tensor of shape (C, W, H).
            question (Tensor): Transformed question token input ids.
            answers (List[Tensor]): List of transformed answers token input ids.
            answer_weights (List[float]): List of answer weights.
                answer_weights[i] is proportional to the number of occurences of answers[i]
        if split is test:
            image (Tensor): Transformed image input tensor of shape (C, W, H).
            question (Tensor): Transformed text token input ids.
            question_id (int): The question sample id.
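
    Example:
        A minimal usage sketch. The paths are placeholders and ``text_transform``
        stands in for whatever tokenizing transform the surrounding project
        supplies; neither is defined by this module.

        >>> from torchvision import transforms
        >>> image_transform = transforms.Compose(
        ...     [transforms.Resize((384, 384)), transforms.ToTensor()]
        ... )
        >>> dataset = VQADataset(
        ...     ann_file=["annotations/vqa_train.json"],
        ...     vqa_root="data/vqa/images",
        ...     vg_root="data/vg/images",
        ...     image_transform=image_transform,
        ...     question_transform=text_transform,
        ...     answer_transform=text_transform,
        ... )
        >>> image, question, answers, answer_weights = dataset[0]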
    """

    def __init__(
        self,
        ann_file: List[str],
        vqa_root: str,
        vg_root: str,
        image_transform: Callable[[Image.Image], Tensor],
        question_transform: Callable[[Union[List[str], str]], Tensor],
        answer_transform: Callable[[Union[List[str], str]], Tensor],
        split: str = "train",
        answer_list: Optional[str] = None,
    ) -> None:
        # Concatenate annotations from all provided files
        self.ann = []
        for f in ann_file:
            with open(f, "r") as ann_f:
                self.ann += json.load(ann_f)

        self.vqa_root = vqa_root
        self.vg_root = vg_root
        self.image_transform = image_transform
        self.question_transform = question_transform
        self.answer_transform = answer_transform
        self.split = split

        if split == "test":
            self.answer_list = json.load(open(answer_list, "r"))
            self.answer_input_ids = self.answer_transform(self.answer_list)
            self.answer_attention_mask = (self.answer_input_ids != 0).type(torch.long)

    def __len__(self) -> int:
        return len(self.ann)

    def __getitem__(
        self, index: int
    ) -> Union[
        Tuple[Tensor, Tensor, int], Tuple[Tensor, Tensor, List[Tensor], List[float]]
    ]:
        ann = self.ann[index]

        image_root = self.vqa_root if ann["dataset"] == "vqa" else self.vg_root
        image_path = os.path.join(image_root, ann["image"])
        image = Image.open(image_path).convert("RGB")
        image = self.image_transform(image)
        question = self.question_transform(ann["question"])

        if self.split == "test":
            return image, question, ann["question_id"]

        elif self.split == "train":
            if ann["dataset"] == "vqa":
                # Each VQA sample question has a list of answers (with potential repeats):
                # answer_weights[answer] = count(answer) / len(answers for the question)
                answer_weights = {}
                for answer in ann["answer"]:
                    if answer in answer_weights:
                        answer_weights[answer] += 1 / len(ann["answer"])
                    else:
                        answer_weights[answer] = 1 / len(ann["answer"])

                answers = list(answer_weights.keys())
                answer_weights = list(answer_weights.values())
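                # e.g. ann["answer"] == ["2", "2", "3"] yields
                # answers == ["2", "3"] and answer_weights == [2/3, 1/3]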

            elif ann["dataset"] == "vg":
                # A VG sample question has one answer so assign it a constant weight (0.5)
                answers = [ann["answer"]]
                answer_weights = [0.5]

            answers = list(self.answer_transform(answers))

            return image, question, answers, answer_weights

        else:
            raise ValueError(f"split should be 'train' or 'test', got '{self.split}'")