File size: 4,778 Bytes
150ed18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""
An image-caption dataset dataloader.
Luke Melas-Kyriazi, 2021
"""
import warnings
from typing import Optional, Callable
from pathlib import Path
import numpy as np
import torch
import pandas as pd
from torch.utils.data import Dataset
from torchvision.datasets.folder import default_loader
from PIL import ImageFile
from PIL.Image import DecompressionBombWarning
ImageFile.LOAD_TRUNCATED_IMAGES = True
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=DecompressionBombWarning)


class CaptionDataset(Dataset):
    """
    A PyTorch Dataset class for (image, texts) tasks. Note that this dataset
    returns the raw text rather than tokens. This is done on purpose, because
    it's easy to tokenize a batch of text after loading it from this dataset.
    """

    # Supported values for `image_transform_type` (compared case-insensitively).
    _IMAGE_TRANSFORM_TYPES = ('torchvision', 'albumentations')

    def __init__(self, *, images_root: str, captions_path: str, text_transform: Optional[Callable] = None,
                 image_transform: Optional[Callable] = None, image_transform_type: str = 'torchvision',
                 include_captions: bool = True):
        """
        :param images_root: folder where images are stored
        :param captions_path: path to a tab-separated file with a header row that maps image
            filenames (column `image_file`) to captions (column `caption`). The `caption`
            column is only read when a caption is actually requested.
        :param text_transform: text transform pipeline. It is stored on the instance, but
            `__getitem__` does not apply it: captions are returned exactly as read from the file.
            NOTE(review): confirm whether callers expect this transform to be applied.
        :param image_transform: image transform pipeline
        :param image_transform_type: image transform type, either `torchvision` or `albumentations`
        :param include_captions: Returns a dictionary with `image`, `text` if `true`; otherwise returns just the images.
        :raises ValueError: if `image_transform_type` is not one of the supported types
        """

        # Base path for images
        self.images_root = Path(images_root)

        # Load captions as DataFrame; filenames are forced to str so purely numeric
        # names can still be joined onto `images_root`.
        self.captions = pd.read_csv(captions_path, delimiter='\t', header=0)
        self.captions['image_file'] = self.captions['image_file'].astype(str)

        # Transformation pipelines
        self.text_transform = text_transform
        self.image_transform = image_transform
        self.image_transform_type = image_transform_type.lower()
        # Raise explicitly instead of using `assert`, which is stripped under `python -O`.
        if self.image_transform_type not in self._IMAGE_TRANSFORM_TYPES:
            raise ValueError(
                f"image_transform_type must be one of {self._IMAGE_TRANSFORM_TYPES}, "
                f"got {image_transform_type!r}")

        # Total number of datapoints
        self.size = len(self.captions)

        # Return image+captions or just images
        self.include_captions = include_captions

    def verify_that_all_images_exist(self):
        """
        Print every image path listed in the captions file that is not an existing file.

        :return: list of the missing paths (empty when every image exists)
        """
        missing = []
        for image_file in self.captions['image_file']:
            p = self.images_root / image_file
            if not p.is_file():
                print(f'file does not exist: {p}')
                missing.append(p)
        return missing

    def _get_raw_image(self, i):
        """Load the image for row `i` with `default_loader`, without applying any transform."""
        # Select the column first so pandas does not have to build a whole row Series.
        image_file = self.captions['image_file'].iloc[i]
        image_path = self.images_root / image_file
        image = default_loader(image_path)
        return image

    def _get_raw_text(self, i):
        """Return the value of the `caption` column for row `i`, untransformed."""
        return self.captions['caption'].iloc[i]

    def __getitem__(self, i):
        """
        Return the datapoint at row `i`.

        :return: `{'image': image, 'text': caption}` if `include_captions` is set, otherwise
            just the image. The image is passed through `image_transform` when one was given.
        :raises NotImplementedError: if `image_transform_type` holds an unsupported value
        """
        image = self._get_raw_image(i)
        if self.image_transform is not None:
            if self.image_transform_type == 'torchvision':
                image = self.image_transform(image)
            elif self.image_transform_type == 'albumentations':
                # Called with the image as a numpy array keyword argument; the result is
                # read back from the returned mapping's 'image' entry.
                image = self.image_transform(image=np.array(image))['image']
            else:
                raise NotImplementedError(f"{self.image_transform_type=}")
        if not self.include_captions:
            # Images-only mode never touches the `caption` column, so a captions
            # file without that column works.
            return image
        return {'image': image, 'text': self._get_raw_text(i)}

    def __len__(self):
        """Return the number of rows in the captions file."""
        return self.size


if __name__ == "__main__":
    import albumentations as A
    from albumentations.pytorch import ToTensorV2
    from transformers import AutoTokenizer

    # Smoke test: build the dataset from local files, pull one batch and print a summary.
    demo_images_root = './images'
    demo_captions_path = './images-list-clean.tsv'

    # Text side: a tokenizer wrapper handed to the dataset as `text_transform`.
    tokenizer = AutoTokenizer.from_pretrained('distilroberta-base')

    def tokenize(text):
        """Tokenize `text` into PyTorch tensors, padded/truncated to 32 tokens."""
        return tokenizer(text, max_length=32, truncation=True, return_tensors='pt', padding='max_length')

    # Image side: albumentations pipeline ending in a tensor conversion.
    pipeline_steps = [
        A.Resize(256, 256),
        A.CenterCrop(256, 256),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ToTensorV2(),
    ]
    image_transform = A.Compose(pipeline_steps)

    dataset = CaptionDataset(
        images_root=demo_images_root,
        captions_path=demo_captions_path,
        image_transform=image_transform,
        text_transform=tokenize,
        image_transform_type='albumentations',
    )

    # Fetch the first batch of two samples.
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=2)
    batch = next(iter(dataloader))

    # Show tensor entries by shape and everything else verbatim.
    summary = {}
    for key, value in batch.items():
        summary[key] = value.shape if isinstance(value, torch.Tensor) else value
    print(summary)

    # # (Optional) Report any image listed in the captions file that is missing on disk
    # dataset = CaptionDataset(images_root=demo_images_root, captions_path=demo_captions_path)
    # dataset.verify_that_all_images_exist()
    # print('Done')