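# RobustViT / robustness_dataset_per_class.py
# Committed by Hila; commit 7754b29 ("init commit"); file size 2.11 kB.
# (Header recovered from the hosting page; its "raw / history / blame" links
# and "No virus" scan badge were page chrome, not part of the source.)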
import json
from torchvision.datasets import ImageFolder
import torch
import os
from PIL import Image
import collections
import torchvision.transforms as transforms
from label_str_to_imagenet_classes import label_str_to_imagenet_classes
# Seed torch's global RNG at import time (a module-level side effect).
torch.manual_seed(0)

# Record describing one image: its file name and the folder ("tag") it lives in.
ImageItem = collections.namedtuple('ImageItem', ['image_name', 'tag'])

# Per-channel (x - 0.5) / 0.5, i.e. maps ToTensor's [0, 1] range onto [-1, 1].
normalize = transforms.Normalize(
    mean=[0.5] * 3,
    std=[0.5] * 3,
)

# Evaluation preprocessing: resize the shorter side to 256, take the central
# 224x224 crop, convert to a float tensor, then normalize.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]
)
class RobustnessDataset(ImageFolder):
    """Dataset over the images in a single class folder.

    Every image is read from ``<imagenet_path>/<folder>`` and all of them share
    one integer label derived from the folder name:

    * ``isV2=True``: the folder name itself is parsed as the integer label.
    * ``isSI=True``: the folder name is looked up in
      ``label_str_to_imagenet_classes``.
    * otherwise: the folder name is looked up in the JSON mapping loaded from
      ``imagenet_classes_path``.

    ``isV2`` takes precedence over ``isSI`` when both are set.

    NOTE(review): ``ImageFolder.__init__`` is never called, so the attributes it
    would normally set up are not populated here; only ``__getitem__``,
    ``__len__`` and ``get_classname`` defined below should be relied on.
    """

    def __init__(self, imagenet_path, folder, imagenet_classes_path='imagenet_classes.json', isV2=False, isSI=False):
        """Index every entry of ``<imagenet_path>/<folder>``.

        Args:
            imagenet_path: Root directory that contains ``folder``.
            folder: Name of the class folder to load images from.
            imagenet_classes_path: JSON file mapping folder names to labels;
                consulted only when neither ``isV2`` nor ``isSI`` is set.
            isV2: Treat the folder name as the integer label.
            isSI: Map the folder name via ``label_str_to_imagenet_classes``.
        """
        self._isV2 = isV2
        self._isSI = isSI
        self._folder = folder
        self._imagenet_path = imagenet_path
        # The JSON mapping is loaded unconditionally, even in isV2/isSI mode.
        with open(imagenet_classes_path, 'r', encoding='utf-8') as f:
            self._imagenet_classes = json.load(f)
        base_dir = os.path.join(self._imagenet_path, folder)
        # os.listdir returns entries in arbitrary order; sort so that item
        # indices are reproducible across runs and filesystems.
        self._all_images = [ImageItem(name, folder) for name in sorted(os.listdir(base_dir))]

    def _label_for(self, tag):
        """Return the integer label for folder name ``tag`` (see class docstring)."""
        if self._isV2:
            return int(tag)
        if self._isSI:
            return int(label_str_to_imagenet_classes[tag])
        return int(self._imagenet_classes[tag])

    def __getitem__(self, item):
        """Return ``(image_tensor, label)`` for the ``item``-th image.

        The image is converted to RGB and passed through the module-level
        ``transform``.
        """
        image_item = self._all_images[item]
        image_path = os.path.join(self._imagenet_path, image_item.tag, image_item.image_name)
        # Close the underlying file once convert() has decoded the pixels;
        # a bare Image.open() keeps the handle open until garbage collection.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        image = transform(image)
        return image, self._label_for(image_item.tag)

    def __len__(self):
        """Return the number of indexed images."""
        return len(self._all_images)

    def get_classname(self):
        """Return the integer label shared by every image in this dataset.

        Despite the name, this is an integer label, not a human-readable name.
        """
        return self._label_for(self._folder)