import json
from typing import Callable, Optional

import torch
from torchvision.datasets import VisionDataset
from torchvision.io import ImageReadMode, read_image
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode


class Transform(torch.nn.Module):
    """
    Returns a transformed version of the input image.

    >>> preprocess = Transform(config.vision_config.image_size)
    >>> preprocess = torch.jit.script(preprocess)
    """

    def __init__(self, image_size):
        super().__init__()
        self.transforms = torch.nn.Sequential(
            Resize([image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
            ConvertImageDtype(torch.float),
            Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711),
            ),
        )

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.transforms(x)


class ImageTextDataset(VisionDataset):
    """
    Dataset for loading image-text data for tasks like CLIP training and image captioning.

    Args:
        root (string): The root path where the dataset is stored.
        file_path (string): Path to the file containing the image paths and associated captions.
            The expected format is jsonlines, where each line is a JSON object containing two keys:
            `image_path`: The path to the image.
            `captions`: An `array` of captions.
        captions_per_image (int, optional): The maximum number of captions to keep per image.
        transform (callable, optional): A function/transform that takes in a PIL image and
            returns a transformed version. E.g., ``transforms.ToTensor``.
        target_transform (callable, optional): A function/transform that takes in the target and
            transforms it.
        transforms (callable, optional): A function/transform that takes an input sample and its
            target as entry and returns a transformed version.
    """

    def __init__(
        self,
        root: str,
        file_path: str,
        captions_per_image=5,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ):
        super().__init__(root, transforms, transform, target_transform)

        # The file is jsonlines: parse one JSON object per line.
        with open(file_path, "r") as f:
            examples = [json.loads(line) for line in f]

        self.captions = []
        self.image_paths = []

        # Flatten (image, captions) pairs so each caption becomes its own example,
        # repeating the image path once per kept caption.
        for example in examples:
            captions = example["captions"][:captions_per_image]
            self.captions.extend(captions)
            self.image_paths.extend([example["image_path"]] * len(captions))

    def _load_image(self, idx: int):
        path = self.image_paths[idx]
        return read_image(path, mode=ImageReadMode.RGB)

    def _load_target(self, idx):
        return self.captions[idx]

    def __getitem__(self, index: int):
        image = self._load_image(index)
        target = self._load_target(index)

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        return len(self.captions)
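

if __name__ == "__main__":
    # Usage sketch (illustrative, not part of the original module): wires the
    # transform and dataset into a DataLoader. The root, file path, and image
    # size below are hypothetical placeholders; each line of the jsonlines file
    # is expected to look like
    #   {"image_path": "images/0001.jpg", "captions": ["a dog on a beach", ...]}
    from torch.utils.data import DataLoader

    preprocess = Transform(image_size=224)
    preprocess = torch.jit.script(preprocess)

    dataset = ImageTextDataset(
        root="data",
        file_path="data/train.jsonl",
        captions_per_image=5,
        transform=preprocess,
    )

    # VisionDataset wraps `transform`/`target_transform` into `self.transforms`,
    # so each item comes back as a (preprocessed image, caption) pair; the
    # default collate then stacks the images into a [B, 3, 224, 224] tensor and
    # returns the captions as a list of strings.
    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    images, captions = next(iter(loader))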