from PIL import Image | |
from torch.utils.data import Dataset | |
# A custom Dataset subclass is needed so these images can be loaded with torch's DataLoader.
# Custom dataset class for CIFAR-10 images.
class CustomCIFAR10Dataset(Dataset):
    """Map-style dataset pairing images with labels for use with a torch DataLoader.

    Args:
        images: Indexable, sized sequence of images; item ``i`` is returned as
            the image of sample ``i``. (Presumably CIFAR-10 image arrays —
            confirm the format expected by ``transform`` against the caller.)
        labels: Indexable, sized sequence of labels, parallel to ``images``.
        transform: Optional callable applied to each image (not to the label)
            when a sample is fetched.

    Raises:
        ValueError: If ``images`` and ``labels`` have different lengths.
    """

    def __init__(self, images, labels, transform=None):
        # Fail fast: __len__ is driven by ``images`` alone, so a shorter
        # ``labels`` would otherwise only surface as an IndexError mid-epoch,
        # and a longer one would be silently truncated.
        if len(images) != len(labels):
            raise ValueError(
                "images and labels must have the same length "
                f"(got {len(images)} and {len(labels)})"
            )
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.images)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair at ``index``.

        The image is passed through ``self.transform`` when one was provided;
        the label is returned unchanged.
        """
        image = self.images[index]
        label = self.labels[index]
        # Apply the transformations (if any)
        if self.transform is not None:
            image = self.transform(image)
        return image, label