# Source: Hugging Face Hub commit d4ab5ac ("Add code", unverified) by din0s
from .image_classification import CIFAR10DataModule
from argparse import ArgumentParser
from functools import partial
from torch import LongTensor
from torch.utils.data import default_collate, random_split, Sampler
from torchvision import transforms
from torchvision.datasets import VisionDataset
from typing import Iterator, Optional
import itertools
import random
import torch
class CIFAR10QADataModule(CIFAR10DataModule):
    @staticmethod
    def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
        """Register the Visual QA command-line options on the given parser."""
        group = parent_parser.add_argument_group("Visual QA")
        group.add_argument(
            "--class_idx",
            type=int,
            default=3,
            help="The class (index) to count.",
        )
        group.add_argument(
            "--grid_size",
            type=int,
            default=3,
            help="The number of images per row in the grid.",
        )
        return parent_parser

    def __init__(
        self,
        class_idx: int,
        grid_size: int = 3,
        feature_extractor: callable = None,
        data_dir: str = "data/",
        batch_size: int = 32,
        add_noise: bool = False,
        add_rotation: bool = False,
        add_blur: bool = False,
        num_workers: int = 4,
    ):
        """A CIFAR10 datamodule for a counting-based Question Answering task.

        Each sample is a square grid of CIFAR10 images and the target is the
        number of grid cells showing the class of interest.

        Args:
            class_idx (int): the class (index) whose occurrences are counted
            grid_size (int): the number of images per row in the grid
            feature_extractor (callable): a callable feature extractor instance
            data_dir (str): the directory to store the dataset
            batch_size (int): the batch size (in grids) for the dataloaders
            add_noise (bool): whether to add noise to the images
            add_rotation (bool): whether to add rotation augmentation
            add_blur (bool): whether to add blur augmentation
            num_workers (int): the number of workers to use for data loading
        """
        # Every grid consumes grid_size**2 underlying images, so the parent
        # datamodule has to draw that many raw samples per requested grid.
        super().__init__(
            feature_extractor,
            data_dir,
            (grid_size**2) * batch_size,
            add_noise,
            add_rotation,
            add_blur,
            num_workers,
        )
        self.class_idx = class_idx
        self.grid_size = grid_size
        # Defer the parent's transform pipeline until after the grid has been
        # assembled; before that, images are merely converted PIL -> tensor.
        self.post_transform = self.transform
        self.transform = transforms.PILToTensor()
        # Grids are assembled by the collate function; the samplers guarantee
        # a fair (uniform) count of the class of interest per grid.
        self.collate_fn = self.custom_collate_fn
        self.shuffled_sampler = partial(
            FairGridSampler,
            class_idx=class_idx,
            grid_size=grid_size,
            shuffle=True,
        )
        self.sequential_sampler = partial(
            FairGridSampler,
            class_idx=class_idx,
            grid_size=grid_size,
            shuffle=False,
        )

    def custom_collate_fn(self, batch):
        """Re-group a flat batch of (image, label) pairs into grid samples.

        Consecutive runs of grid_size**2 images are tiled into one square
        image; the target is how many cells carry the class of interest.
        Incomplete trailing groups are dropped, matching the sampler.
        """
        n = self.grid_size
        cells = n * n
        collated = []
        for start in range(0, len(batch) - cells + 1, cells):
            group = batch[start : start + cells]
            # Concatenate each run of n images along the width (dim 2), then
            # stack the n resulting rows along the height (dim 1).
            rows = [
                torch.dstack([item[0] for item in group[row * n : (row + 1) * n]])
                for row in range(n)
            ]
            grid_img = self.post_transform(torch.hstack(rows))
            # The target counts how many cells show the class of interest
            count = sum(1 for item in group if item[1] == self.class_idx)
            collated.append((grid_img, count))
        return default_collate(collated)
class ToyQADataModule(CIFAR10QADataModule):
    """A datamodule for the toy dataset as described in the paper."""

    def prepare_data(self):
        # The toy data is generated on the fly in setup(), nothing to download
        pass

    def setup(self, stage: Optional[str] = None):
        img_size = 16
        samples = []
        # 6000 samples: 1000 solid-color patches for each RGB corner color,
        # excluding black (0,0,0) and white (1,1,1)
        for rgb in itertools.product((0, 1), repeat=3):
            if len(set(rgb)) == 1:
                continue
            # Unique id per color: its binary value, shifted so ids start at 0
            target = int("".join(str(c) for c in rgb), 2) - 1
            for _ in range(1000):
                patch = torch.vstack(
                    [c * torch.ones(1, img_size, img_size) for c in rgb]
                )
                samples.append((patch, target))
        # 90% train / 5% validation / 5% test
        n_train = int(len(samples) * 0.9)
        n_val = (len(samples) - n_train) // 2
        n_test = len(samples) - n_train - n_val
        self.train_data, self.val_data, self.test_data = random_split(
            samples,
            [n_train, n_val, n_test],
        )
class FairGridSampler(Sampler[int]):
    def __init__(
        self,
        dataset: VisionDataset,
        class_idx: int,
        grid_size: int,
        shuffle: bool = False,
    ):
        """A sampler that returns a grid of images from the dataset, with a uniformly random
        amount of appearances for a specific class of interest.

        Args:
            dataset (VisionDataset): the dataset to sample from
            class_idx (int): the class (index) to treat as the class of interest
            grid_size (int): the number of images per row in the grid
            shuffle (bool): whether to shuffle the dataset before sampling
        """
        super().__init__(dataset)
        # Save the hyperparameters
        self.dataset = dataset
        self.grid_size = grid_size
        self.n_images = grid_size**2
        # Split the dataset indices by class in a SINGLE pass: indexing a
        # VisionDataset may decode an image, so iterating it twice (once per
        # index list, as before) doubles the setup cost for no benefit.
        class_indices, other_indices = [], []
        for i, sample in enumerate(dataset):
            if sample[1] == class_idx:
                class_indices.append(i)
            else:
                other_indices.append(i)
        # Indices of the class of interest / of all other classes
        self.class_indices = LongTensor(class_indices)
        self.other_indices = LongTensor(other_indices)
        # Fix the seed if shuffle is False, so that every epoch replays the
        # exact same sequence of grids
        self.seed = None if shuffle else self._get_seed()

    @staticmethod
    def _get_seed() -> int:
        """Utility function for generating a random seed."""
        return int(torch.empty((), dtype=torch.int64).random_().item())

    def __iter__(self) -> Iterator[int]:
        # Fresh seed per epoch when shuffling, the stored seed otherwise
        seed = self.seed if self.seed is not None else self._get_seed()
        gen = torch.Generator()
        gen.manual_seed(seed)
        # Sample the batches (one full grid per iteration)
        for _ in range(len(self.dataset) // self.n_images):
            # Pick the number of instances for the class of interest,
            # uniformly from {0, ..., n_images}
            n_samples = int(torch.randint(self.n_images + 1, (), generator=gen).item())
            # Sample the indices from the class of interest
            idx_from_class = torch.randperm(
                len(self.class_indices),
                generator=gen,
            )[:n_samples]
            # Sample the indices from the other classes
            idx_from_other = torch.randperm(
                len(self.other_indices),
                generator=gen,
            )[: self.n_images - n_samples]
            # Concatenate the corresponding lists of patches to form a grid
            grid = (
                self.class_indices[idx_from_class].tolist()
                + self.other_indices[idx_from_other].tolist()
            )
            # Shuffle the order of the patches within the grid with the SAME
            # seeded generator. BUG FIX: the previous random.shuffle(grid)
            # drew from Python's global RNG, so even the fixed-seed
            # (shuffle=False) sampler was non-deterministic across epochs.
            order = torch.randperm(len(grid), generator=gen).tolist()
            yield from (grid[i] for i in order)

    def __len__(self) -> int:
        # BUG FIX: __iter__ emits only whole grids, so report the number of
        # indices actually yielded instead of the raw dataset size (they
        # differ whenever len(dataset) is not a multiple of grid_size**2).
        return (len(self.dataset) // self.n_images) * self.n_images