import collections
import json
import os

import torch
import torchvision.transforms as transforms
from PIL import Image
from torchvision.datasets import ImageFolder

from label_str_to_imagenet_classes import label_str_to_imagenet_classes

torch.manual_seed(0)

# One dataset entry: the image's file name and the sub-folder (tag) it lives in.
ImageItem = collections.namedtuple('ImageItem', ('image_name', 'tag'))

# Per-channel (x - 0.5) / 0.5; applied after ToTensor, whose output is in [0, 1],
# so the normalized values lie in [-1, 1].
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

# Resize shorter side to 256, center-crop 224x224, convert to tensor, normalize.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])


class RobustnessDataset(ImageFolder):
    """Dataset over the images of a single folder, all sharing one label.

    Images are listed from ``os.path.join(imagenet_path, folder)``. Each item is
    ``(image_tensor, class_index)``, where the tensor is produced by the module-level
    ``transform`` and ``class_index`` is derived from ``folder``:

    * ``isV2=True``: the folder name itself is parsed as the integer class index.
    * ``isSI=True``: the folder name is looked up in ``label_str_to_imagenet_classes``.
    * otherwise: the folder name is looked up in the JSON mapping loaded from
      ``imagenet_classes_path``.

    ``isV2`` takes precedence over ``isSI`` when both are set.

    Note: ``ImageFolder.__init__`` is not called, so ImageFolder attributes such as
    ``samples``/``classes`` are not populated; only the methods defined here are used.
    """

    def __init__(self, imagenet_path, folder,
                 imagenet_classes_path='imagenet_classes.json',
                 isV2=False, isSI=False):
        """Index every file in ``<imagenet_path>/<folder>``.

        Args:
            imagenet_path: Root directory containing ``folder``.
            folder: Sub-directory name; also the key used to derive the label.
            imagenet_classes_path: JSON file mapping folder names to class indices.
                It is read unconditionally, even when ``isV2``/``isSI`` is set.
            isV2: Interpret ``folder`` directly as an integer class index.
            isSI: Resolve ``folder`` via ``label_str_to_imagenet_classes``.

        Raises:
            OSError: If the JSON file or the image directory cannot be read.
        """
        self._isV2 = isV2
        self._isSI = isSI
        self._folder = folder
        self._imagenet_path = imagenet_path
        with open(imagenet_classes_path, 'r', encoding='utf-8') as f:
            self._imagenet_classes = json.load(f)

        base_dir = os.path.join(self._imagenet_path, folder)
        # os.listdir returns entries in arbitrary order; sort so that dataset
        # indices are reproducible across runs and filesystems.
        self._all_images = [
            ImageItem(file_name, folder) for file_name in sorted(os.listdir(base_dir))
        ]

    def _label_for(self, tag):
        """Return the integer class index for the folder name ``tag``."""
        if self._isV2:
            return int(tag)
        if self._isSI:
            return int(label_str_to_imagenet_classes[tag])
        return int(self._imagenet_classes[tag])

    def __getitem__(self, item):
        """Load, convert to RGB, and transform image ``item``; return ``(tensor, label)``."""
        image_item = self._all_images[item]
        image_path = os.path.join(self._imagenet_path, image_item.tag, image_item.image_name)
        # Image.open is lazy and keeps the file handle open until the image is
        # garbage-collected; the context manager closes it. convert() returns a
        # fully loaded copy, so it remains valid after the file is closed.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        image = transform(image)
        return image, self._label_for(image_item.tag)

    def __len__(self):
        """Return the number of files found in the folder."""
        return len(self._all_images)

    def get_classname(self):
        """Return the integer class index shared by every image in this dataset."""
        return self._label_for(self._folder)