Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3 | |
""" | |
Module containing wrapper classes for PyTorch Datasets | |
Author: Shilpaj Bhalerao | |
Date: Jun 25, 2023 | |
""" | |
# Standard Library Imports | |
from typing import Tuple | |
# Third-Party Imports | |
from torchvision import datasets, transforms | |
class AlbumDataset(datasets.CIFAR10):
    """
    Wrapper class to use albumentations-style transforms with the CIFAR10 PyTorch Dataset

    The transform is invoked with the sample passed as the ``image`` keyword
    argument, and the processed sample is read back from the ``"image"`` key
    of whatever the transform returns.
    """

    def __init__(self, root: str = "./data", train: bool = True, download: bool = True, transform: list = None):
        """
        Constructor; every argument is forwarded unchanged to the parent dataset
        :param root: Directory at which data is stored
        :param train: Param to distinguish if data is training or test
        :param download: Param to download the dataset from source
        :param transform: Transformation pipeline applied to each image. Despite the
                          ``list`` annotation, ``__getitem__`` calls it as
                          ``transform(image=...)``, so it has to be callable
        """
        super().__init__(
            root=root,
            train=train,
            download=download,
            transform=transform,
        )

    def __getitem__(self, index: int) -> Tuple:
        """
        Method to return image and its label
        :param index: Index of image and label in the dataset
        :return: Tuple of (image, label); the image is transformed only when a transform is set
        """
        # Samples are read directly from the `data` / `targets` attributes
        # (presumably populated by the parent constructor); the parent's
        # __getitem__ is not used here.
        sample = self.data[index]
        target = self.targets[index]

        # No transform configured (or a falsy one): return the stored sample as-is.
        if not self.transform:
            return sample, target

        # The transform takes the sample by keyword and returns a mapping
        # whose "image" entry holds the processed sample.
        return self.transform(image=sample)["image"], target