# RobustViT / robustness_dataset_per_class.py
# Author: Hila
# Commit: 7754b29 ("init commit")
import json
from torchvision.datasets import ImageFolder
import torch
import os
from PIL import Image
import collections
import torchvision.transforms as transforms
from label_str_to_imagenet_classes import label_str_to_imagenet_classes
# Seed torch's global RNG as a side effect of importing this module.
torch.manual_seed(0)

# One dataset entry: the image file name plus the sub-folder ("tag") that holds it.
ImageItem = collections.namedtuple('ImageItem', ['image_name', 'tag'])

# Per-channel (x - 0.5) / 0.5. ToTensor produces values in [0, 1] for 8-bit
# images, so normalized inputs land in [-1, 1].
normalize = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)

# Evaluation preprocessing: resize the shorter side to 256 px, take the central
# 224x224 crop, convert to a float tensor, then apply `normalize`.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]
)
class RobustnessDataset(ImageFolder):
    """Dataset over the images of a single class folder, for per-class evaluation.

    Every entry of ``<imagenet_path>/<folder>`` becomes one sample, and all
    samples share the same integer label, which is derived from ``folder``.

    NOTE(review): this subclasses ``ImageFolder`` but never calls
    ``ImageFolder.__init__``. Attributes normally set there (e.g. ``root``,
    ``samples``, ``transform``) therefore do not exist on instances. Rely only
    on the methods defined in this class.

    Args:
        imagenet_path: Root directory that contains the class folder.
        folder: Name of the class sub-directory to load. It is also the key
            used to derive the label.
        imagenet_classes_path: JSON file mapping folder names to class indices
            (values are passed through ``int()``). It is always loaded, but only
            consulted when neither ``isV2`` nor ``isSI`` is set.
        isV2: If true, ``folder`` itself is parsed as the integer label.
        isSI: If true, the label is looked up in
            ``label_str_to_imagenet_classes``. Ignored when ``isV2`` is true.
    """

    def __init__(self, imagenet_path, folder, imagenet_classes_path='imagenet_classes.json', isV2=False, isSI=False):
        self._isV2 = isV2
        self._isSI = isSI
        self._folder = folder
        self._imagenet_path = imagenet_path
        with open(imagenet_classes_path, 'r', encoding='utf-8') as f:
            self._imagenet_classes = json.load(f)
        base_dir = os.path.join(self._imagenet_path, folder)
        # os.listdir returns entries in arbitrary, filesystem-dependent order;
        # sort so that sample indices are reproducible across machines.
        self._all_images = [
            ImageItem(file_name, folder) for file_name in sorted(os.listdir(base_dir))
        ]

    def _label_for(self, tag):
        """Return the integer class label for the class-folder name ``tag``."""
        if self._isV2:
            return int(tag)
        if self._isSI:
            return int(label_str_to_imagenet_classes[tag])
        return int(self._imagenet_classes[tag])

    def __getitem__(self, item):
        """Return ``(image_tensor, label)`` for sample index ``item``."""
        image_item = self._all_images[item]
        image_path = os.path.join(self._imagenet_path, image_item.tag, image_item.image_name)
        # Close the underlying file handle. convert() returns a new, fully
        # loaded image, so it remains valid after the source file is closed.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        image = transform(image)
        return image, self._label_for(image_item.tag)

    def __len__(self):
        """Return the number of files found in the class folder."""
        return len(self._all_images)

    def get_classname(self):
        """Return the integer label shared by every sample in this dataset."""
        return self._label_for(self._folder)