"""Download STL10 through torchvision and dump every split as .jpg files.

Usage:
    python <this_script> <root>

Three sub-folders are created under ``<root>``: ``train``, ``test`` and
``unlabeled``.
"""
import sys
import os.path as osp

from torchvision.datasets import STL10

from dassl.utils import mkdir_if_missing


def extract_and_save_image(dataset, save_dir):
    """Write every sample of ``dataset`` into ``save_dir`` as an image file.

    Files are named ``<index>_<label>.jpg`` where ``<index>`` is the sample
    position zero-padded to six digits and ``<label>`` is ``str(label)``, or
    the literal ``none`` when the label equals ``-1``.

    Args:
        dataset: Object supporting ``len()`` and integer indexing, where each
            item unpacks to ``(img, label)`` and ``img`` has a
            ``save(path)`` method (presumably a PIL image -- confirm against
            the dataset in use).
        save_dir: Destination folder. If it already exists, nothing is
            written and the function returns immediately.
    """
    # NOTE(review): only the existence of the folder is checked, so a folder
    # left behind by an interrupted run is treated as complete and skipped.
    if osp.exists(save_dir):
        print('Folder "{}" already exists'.format(save_dir))
        return

    print('Extracting images to "{}" ...'.format(save_dir))
    mkdir_if_missing(save_dir)

    # Index explicitly: only __len__ and __getitem__ are relied upon, the
    # dataset is not assumed to implement the iterator protocol.
    for i in range(len(dataset)):
        img, label = dataset[i]
        # -1 is mapped to "none"; presumably the marker for samples that
        # carry no class label -- TODO confirm.
        label_name = "none" if label == -1 else str(label)
        imname = str(i).zfill(6) + "_" + label_name + ".jpg"
        img.save(osp.join(save_dir, imname))


def download_and_prepare(root):
    """Fetch STL10 into ``root`` and extract its three splits to image folders.

    Args:
        root: Directory handed to ``STL10`` as the dataset root; the
            ``train``, ``test`` and ``unlabeled`` output folders are created
            beneath it.
    """
    # Only the first constructor requests a download; the other two splits
    # are presumably read from the files fetched by that call.
    train = STL10(root, split="train", download=True)
    test = STL10(root, split="test")
    unlabeled = STL10(root, split="unlabeled")

    train_dir = osp.join(root, "train")
    test_dir = osp.join(root, "test")
    unlabeled_dir = osp.join(root, "unlabeled")

    extract_and_save_image(train, train_dir)
    extract_and_save_image(test, test_dir)
    extract_and_save_image(unlabeled, unlabeled_dir)


if __name__ == "__main__":
    # Fail with a usage hint instead of a bare IndexError traceback when the
    # root directory argument is missing.
    if len(sys.argv) < 2:
        sys.exit("Usage: python {} <root>".format(sys.argv[0]))
    download_and_prepare(sys.argv[1])