File size: 898 Bytes
9022436
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
from torchvision import datasets, transforms

from .transforms import test_transforms, train_transforms


class Cifar10SearchDataset(datasets.CIFAR10):
    """CIFAR-10 dataset whose sample transform is called with keyword arguments.

    ``__getitem__`` is overridden so the raw entry of ``self.data`` is passed to
    the transform as ``transform(image=...)``, and the transformed sample is read
    back from the ``"image"`` key of the returned mapping. This is presumably an
    albumentations-style pipeline (see ``.transforms``) — confirm before passing
    a plain torchvision transform, which would not accept ``image=``.
    """

    def __init__(self, root="~/data/cifar10", train=True, download=True,
                 transform=None, target_transform=None):
        """Create the dataset by delegating to ``torchvision.datasets.CIFAR10``.

        Args:
            root: Directory where the dataset files live (or are downloaded to).
            train: Select the training split if True, otherwise the test split.
            download: Download the dataset into ``root`` if it is not present.
            transform: Optional callable invoked as ``transform(image=sample)``
                that returns a mapping containing an ``"image"`` entry.
            target_transform: Optional callable applied to each label. Defaults
                to None, which preserves the previous behaviour.
        """
        super().__init__(
            root=root,
            train=train,
            download=download,
            transform=transform,
            target_transform=target_transform,
        )

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair at *index* after applying transforms."""
        image, label = self.data[index], self.targets[index]

        if self.transform is not None:
            # Keyword-style call: the transform returns a dict, not the image itself.
            image = self.transform(image=image)["image"]

        # torchvision's CIFAR10 supports target_transform; this override used to
        # drop it silently, so apply it here as well.
        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label

def get_dataset(root="./data/cifar10", train=False, download=True):
    """Build a ``Cifar10SearchDataset`` for the requested CIFAR-10 split.

    Calling with no arguments returns the test split stored under
    ``./data/cifar10`` with ``test_transforms`` applied.

    Args:
        root: Directory where the dataset files live (or are downloaded to).
        train: If True, return the training split with ``train_transforms``;
            otherwise return the test split with ``test_transforms``.
        download: Download the dataset into ``root`` if it is not present.

    Returns:
        The constructed ``Cifar10SearchDataset``.
    """
    # Each split gets its matching transform pipeline from .transforms.
    transform = train_transforms if train else test_transforms
    return Cifar10SearchDataset(
        root=root, train=train, download=download, transform=transform)