|
import json |
|
from typing import Callable, Optional |
|
|
|
import torch |
|
from torchvision.datasets import VisionDataset |
|
from torchvision.io import ImageReadMode, read_image |
|
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize |
|
from torchvision.transforms.functional import InterpolationMode |
|
|
|
|
|
class Transform(torch.nn.Module):
    """Image preprocessing pipeline: resize, center-crop, float-convert, normalize.

    Scriptable with TorchScript:

    >>> preprocess = Transform(config.vision_config.image_size)
    >>> preprocess = torch.jit.script(preprocess)
    """

    def __init__(self, image_size):
        super().__init__()
        # Normalization constants; these match the values commonly used for
        # CLIP's image encoder — confirm against the paired model config.
        mean = (0.48145466, 0.4578275, 0.40821073)
        std = (0.26862954, 0.26130258, 0.27577711)
        self.transforms = torch.nn.Sequential(
            Resize([image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
            ConvertImageDtype(torch.float),
            Normalize(mean, std),
        )

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the transformed version of the input image tensor ``x``."""
        out = self.transforms(x)
        return out
|
|
|
|
|
class ImageTextDataset(VisionDataset):
    """
    Dataset for loading image-text data for tasks like CLIP training, Image Captioning.

    Args:
        root: (string): The root path where the dataset is stored.
        file_path: (string): Path to the file containing the image paths and associated captions.
            The expected format is a JSON file holding an array of objects, each with two keys:
            `file_path`: The path to the image.
            `captions`: An `array` of captions.
        captions_per_image (int, optional): Maximum number of captions kept per image. Defaults to 5.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    def __init__(
        self,
        root: str,
        file_path: str,
        captions_per_image: int = 5,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ):
        super().__init__(root, transforms, transform, target_transform)

        # Explicit encoding so parsing does not depend on the platform default.
        with open(file_path, "r", encoding="utf-8") as f:
            examples = json.load(f)

        # Flatten to parallel lists: one (image_path, caption) pair per
        # caption, keeping at most `captions_per_image` captions per image.
        self.captions = []
        self.image_paths = []

        for example in examples:
            captions = example["captions"][:captions_per_image]
            self.captions.extend(captions)
            self.image_paths.extend([example["file_path"]] * len(captions))

    def _load_image(self, idx: int):
        """Decode the image at index ``idx`` as a 3-channel RGB tensor."""
        path = self.image_paths[idx]
        return read_image(path, mode=ImageReadMode.RGB)

    def _load_target(self, idx: int):
        """Return the caption associated with index ``idx``."""
        return self.captions[idx]

    def __getitem__(self, index: int):
        """Return the ``(image, caption)`` pair at ``index``, transformed if configured."""
        image = self._load_image(index)
        target = self._load_target(index)

        # `self.transforms` is set by the VisionDataset base from the
        # transform/target_transform/transforms arguments.
        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        """Total number of (image, caption) pairs."""
        return len(self.captions)
|
|