|
from torchvision.datasets import ImageFolder |
|
import torch |
|
import os |
|
import collections |
|
|
|
|
|
# Seed PyTorch's default random number generator once, at import time.
torch.manual_seed(seed=0)



# Record of one dataset entry: the file name of an image and the tag
# (sub-directory name) it was found under.
ImageItem = collections.namedtuple("ImageItem", "image_name tag")
|
|
|
class RobustnessDataset(ImageFolder):
    """Index image files laid out as ``dataset_path/<tag>/<file>``.

    Unlike a regular ``ImageFolder``, indexing returns the full filesystem
    path of an image rather than a loaded ``(image, target)`` pair.

    NOTE(review): ``ImageFolder.__init__`` is never called, so inherited
    attributes such as ``root``, ``samples``, ``classes`` and ``transform``
    are not set; only ``__getitem__`` and ``__len__`` are defined here.
    """

    def __init__(self, dataset_path):
        """Scan ``dataset_path`` and record every file found under each tag.

        Args:
            dataset_path: Root directory. Each immediate sub-directory is a
                tag, and each regular file inside a tag directory becomes one
                dataset item.
        """
        self._dataset_path = dataset_path

        # Only directories are tags: a stray file at the root (e.g.
        # ``.DS_Store``) would otherwise make the listing below raise
        # NotADirectoryError. Sorted, as ImageFolder does, because
        # os.listdir order is arbitrary and item indices should be stable.
        self._tag_list = sorted(
            name
            for name in os.listdir(dataset_path)
            if os.path.isdir(os.path.join(dataset_path, name))
        )

        self._all_images = []
        for tag in self._tag_list:
            base_dir = os.path.join(dataset_path, tag)
            # Keep only regular files (symlinks to files included) so that
            # every returned path points at something loadable; sorted for
            # a deterministic index order.
            self._all_images.extend(
                ImageItem(name, tag)
                for name in sorted(os.listdir(base_dir))
                if os.path.isfile(os.path.join(base_dir, name))
            )

    def __getitem__(self, item):
        """Return the filesystem path of the ``item``-th image.

        Args:
            item: Integer index into the dataset; negative values count from
                the end, as with a list.

        Returns:
            ``os.path.join(dataset_path, tag, image_name)`` for that entry.

        Raises:
            IndexError: If ``item`` is out of range.
        """
        image_item = self._all_images[item]
        return os.path.join(
            self._dataset_path, image_item.tag, image_item.image_name
        )

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self._all_images)