# s12/utils/dataset.py
# srikanthp07's picture
# Upload 27 files
# 9022436
import torch
from torchvision import datasets, transforms
from .transforms import test_transforms, train_transforms
class Cifar10SearchDataset(datasets.CIFAR10):
    """CIFAR-10 dataset whose transform is invoked with an ``image=`` keyword.

    The transform is called as ``transform(image=...)`` and its result is
    indexed with ``"image"`` (this matches the Albumentations calling
    convention -- presumably the intended transform library; confirm against
    the ``transforms`` module).
    """

    def __init__(
        self,
        root="~/data/cifar10",
        train=True,
        download=True,
        transform=None,
    ):
        """Forward all arguments to ``datasets.CIFAR10`` as keywords.

        Args:
            root: Directory that holds (or will receive) the dataset files.
            train: Split selector passed through to the parent class.
            download: Download flag passed through to the parent class.
            transform: Optional callable applied in ``__getitem__``.
        """
        super().__init__(
            root=root,
            train=train,
            download=download,
            transform=transform,
        )

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``.

        The image is read straight from ``self.data``; no conversion is done
        here before the transform sees it. Without a transform the raw entry
        is returned unchanged.
        """
        sample = self.data[index]
        target = self.targets[index]

        # Guard clause: nothing to apply, hand back the stored entry as-is.
        if self.transform is None:
            return sample, target

        # Keyword call + "image" lookup on the transform's result.
        return self.transform(image=sample)["image"], target
def get_dataset(train=False, root='./data/cifar10'):
    """Build a ``Cifar10SearchDataset`` for one CIFAR-10 split.

    Called with no arguments this returns the test split with
    ``test_transforms``, exactly as before. The training split, previously
    available only as commented-out code, is now selectable via ``train``.

    Args:
        train: When true, load the training split and apply
            ``train_transforms``; otherwise load the test split and apply
            ``test_transforms``.
        root: Directory that holds (or will receive) the dataset files.

    Returns:
        The constructed ``Cifar10SearchDataset`` (created with
        ``download=True``).
    """
    # Pair each split with its matching transform pipeline.
    transform = train_transforms if train else test_transforms
    return Cifar10SearchDataset(
        root=root, train=train, download=True, transform=transform)