# RobustViT / tokencut_image_dataset.py
# Source revision: initial commit 7754b29 ("init commit") by Hila.
from torchvision.datasets import ImageFolder
import torch
import os
import collections
# Fix torch's global RNG seed as soon as this module is imported; presumably
# so that later random draws in the pipeline are reproducible (TODO confirm
# which consumers rely on this).
torch.manual_seed(0)

# One dataset entry: the file name and the tag (the subdirectory of the
# dataset root that the file was found in).
ImageItem = collections.namedtuple('ImageItem', 'image_name tag')
class RobustnessDataset(ImageFolder):
    """Dataset of image *paths* laid out as ``<dataset_path>/<tag>/<file>``.

    Every immediate subdirectory of ``dataset_path`` is a tag, and every
    regular file inside a tag directory is one item. Items are returned as
    path strings; nothing is loaded or transformed here.

    Note: ``ImageFolder.__init__`` is not called, so the attributes it would
    normally set (``root``, ``classes``, ``transform``, ...) are absent on
    instances. Only the subclass's own ``__getitem__``/``__len__`` are
    meaningful.
    """

    def __init__(self, dataset_path):
        """Index all files under ``dataset_path``.

        Args:
            dataset_path: Root directory whose subdirectories are the tags.

        Raises:
            OSError: If ``dataset_path`` (or a tag directory) cannot be listed.
        """
        self._dataset_path = dataset_path
        # os.listdir returns entries in arbitrary, filesystem-dependent order;
        # sort so that index -> image is reproducible across machines and runs.
        # Non-directory entries at the root (e.g. .DS_Store) are skipped:
        # listing one as a tag directory would raise NotADirectoryError.
        self._tag_list = sorted(
            tag for tag in os.listdir(dataset_path)
            if os.path.isdir(os.path.join(dataset_path, tag))
        )
        self._all_images = []
        for tag in self._tag_list:
            base_dir = os.path.join(dataset_path, tag)
            # Only regular files become items; nested directories are not images.
            self._all_images.extend(
                ImageItem(name, tag)
                for name in sorted(os.listdir(base_dir))
                if os.path.isfile(os.path.join(base_dir, name))
            )

    def __getitem__(self, item):
        """Return the full path of the ``item``-th image as a string."""
        image_item = self._all_images[item]
        return os.path.join(self._dataset_path, image_item.tag, image_item.image_name)

    def __len__(self):
        """Return the number of indexed images."""
        return len(self._all_images)