# dalle-mini / dalle_mini / dataset.py
# Pedro Cuenca
# * dalle_mini package with models and utilities:
# 150ed18
# raw
# history blame
# 4.78 kB
"""
An image-caption dataset dataloader.
Luke Melas-Kyriazi, 2021
"""
import warnings
from typing import Optional, Callable
from pathlib import Path
import numpy as np
import torch
import pandas as pd
from torch.utils.data import Dataset
from torchvision.datasets.folder import default_loader
from PIL import ImageFile
from PIL.Image import DecompressionBombWarning
# Module-level configuration: these statements run on import and change global
# PIL / warnings state for the whole interpreter, not just for this dataset.

# Set Pillow's LOAD_TRUNCATED_IMAGES flag so partially-written image files are
# loaded rather than rejected (see Pillow's ImageFile documentation).
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Suppress every UserWarning, regardless of which library emits it.
warnings.filterwarnings("ignore", category=UserWarning)
# Suppress Pillow's DecompressionBombWarning.
warnings.filterwarnings("ignore", category=DecompressionBombWarning)
class CaptionDataset(Dataset):
    """
    A PyTorch Dataset class for (image, texts) tasks. Note that this dataset
    returns the raw text rather than tokens. This is done on purpose, because
    it's easy to tokenize a batch of text after loading it from this dataset.

    The captions file is read as a tab-separated table with a header row and
    must contain an `image_file` column; a `caption` column is required only
    when `include_captions` is true.
    """

    # Calling conventions understood by `__getitem__` for `image_transform`.
    _IMAGE_TRANSFORM_TYPES = ('torchvision', 'albumentations')

    def __init__(self, *, images_root: str, captions_path: str, text_transform: Optional[Callable] = None,
                 image_transform: Optional[Callable] = None, image_transform_type: str = 'torchvision',
                 include_captions: bool = True):
        """
        :param images_root: folder where images are stored
        :param captions_path: path to the tab-separated file that maps image filenames to captions
        :param text_transform: optional callable for captions; it is stored as `self.text_transform`
            but is not applied by `__getitem__`, which returns the raw caption
        :param image_transform: image transform pipeline
        :param image_transform_type: image transform type, either `torchvision` or `albumentations`
            (case-insensitive)
        :param include_captions: Returns a dictionary with `image`, `text` if `true`; otherwise returns just the images.
        :raises ValueError: if `image_transform_type` is not one of the supported types
        """
        # Validate the cheap argument first so a typo fails before any file I/O.
        # An explicit raise (rather than `assert`) keeps the check alive under `python -O`.
        self.image_transform_type = image_transform_type.lower()
        if self.image_transform_type not in self._IMAGE_TRANSFORM_TYPES:
            raise ValueError(
                f"image_transform_type must be one of {self._IMAGE_TRANSFORM_TYPES}, "
                f"got {image_transform_type!r}")

        # Base path for images
        self.images_root = Path(images_root)

        # Load captions as DataFrame
        self.captions = pd.read_csv(captions_path, delimiter='\t', header=0)
        # File names may be parsed as numbers (e.g. `123`); force them to strings
        # so they can be joined onto `images_root`.
        self.captions['image_file'] = self.captions['image_file'].astype(str)

        # Transformation pipelines for the text and the image (normalizing, etc.)
        self.text_transform = text_transform
        self.image_transform = image_transform

        # Total number of datapoints
        self.size = len(self.captions)

        # Return image+captions or just images
        self.include_captions = include_captions

    def verify_that_all_images_exist(self):
        """
        Check that every file listed in the `image_file` column exists under `images_root`.

        Prints one line per missing file.

        :return: list of the paths that are not regular files (empty when nothing is missing)
        """
        missing = []
        for image_file in self.captions['image_file']:
            p = self.images_root / image_file
            if not p.is_file():
                print(f'file does not exist: {p}')
                missing.append(p)
        return missing

    def _get_raw_image(self, i):
        """Load the image for row `i` with torchvision's `default_loader`, untransformed."""
        image_file = self.captions.iloc[i]['image_file']
        image_path = self.images_root / image_file
        return default_loader(image_path)

    def _get_raw_text(self, i):
        """Return the `caption` cell of row `i`, untransformed."""
        return self.captions.iloc[i]['caption']

    def __getitem__(self, i):
        """
        Return the `i`-th sample.

        :return: `{'image': image, 'text': caption}` when `include_captions` is true,
            otherwise just the image
        :raises NotImplementedError: if `image_transform_type` was changed to an
            unsupported value after construction
        """
        image = self._get_raw_image(i)
        if self.image_transform is not None:
            if self.image_transform_type == 'torchvision':
                image = self.image_transform(image)
            elif self.image_transform_type == 'albumentations':
                # This style of transform is called with an `image=` keyword array
                # and the result is read back from the `'image'` key.
                image = self.image_transform(image=np.array(image))['image']
            else:
                raise NotImplementedError(f"{self.image_transform_type=}")

        if not self.include_captions:
            # Image-only mode: skip the caption lookup so caption-less TSVs work.
            return image
        return {'image': image, 'text': self._get_raw_text(i)}

    def __len__(self):
        """Return the number of rows in the captions table."""
        return self.size
if __name__ == "__main__":
    # Smoke test: build the dataset with an albumentations pipeline, pull one
    # batch of two samples through a DataLoader and print what came back.
    # These packages are only needed for this demo, so they are imported here.
    import albumentations as A
    from albumentations.pytorch import ToTensorV2
    from transformers import AutoTokenizer

    # Caption tokenizer, handed to the dataset as its text transform
    tokenizer = AutoTokenizer.from_pretrained('distilroberta-base')

    def tokenize(text):
        return tokenizer(
            text,
            max_length=32,
            truncation=True,
            return_tensors='pt',
            padding='max_length',
        )

    # Image pipeline: resize, center-crop, normalize, then convert to a tensor
    pipeline_steps = [
        A.Resize(256, 256),
        A.CenterCrop(256, 256),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ToTensorV2(),
    ]
    image_transform = A.Compose(pipeline_steps)

    # Dataset over the local image folder and its caption list
    dataset = CaptionDataset(
        images_root='./images',
        captions_path='./images-list-clean.tsv',
        image_transform=image_transform,
        text_transform=tokenize,
        image_transform_type='albumentations',
    )

    # Fetch the first batch only
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=2)
    batch_iterator = iter(dataloader)
    batch = next(batch_iterator)

    # Show tensor shapes for tensor entries and the raw value for everything else
    summary = {}
    for key, value in batch.items():
        summary[key] = value.shape if isinstance(value, torch.Tensor) else value
    print(summary)

    # # (Optional) Check that every listed image file is present on disk
    # dataset = CaptionDataset(images_root='./images', captions_path='./images-list-clean.tsv')
    # dataset.verify_that_all_images_exist()
    # print('Done')