# Provenance: uploaded by tongyujun in commit 8c6b5ee
# ("Upload 641 files", verified).
import sys
import os.path as osp
from torchvision.datasets import STL10
from dassl.utils import mkdir_if_missing
def extract_and_save_image(dataset, save_dir):
if osp.exists(save_dir):
print('Folder "{}" already exists'.format(save_dir))
return
print('Extracting images to "{}" ...'.format(save_dir))
mkdir_if_missing(save_dir)
for i in range(len(dataset)):
img, label = dataset[i]
if label == -1:
label_name = "none"
else:
label_name = str(label)
imname = str(i).zfill(6) + "_" + label_name + ".jpg"
impath = osp.join(save_dir, imname)
img.save(impath)
def download_and_prepare(root):
    """Load the STL-10 splits under *root* and dump each to an image folder.

    Builds ``train``, ``test`` and ``unlabeled`` ``STL10`` datasets, in that
    order, and then extracts each one into ``<root>/<split>`` with
    ``extract_and_save_image``.

    Args:
        root: directory that holds (or will receive) the STL-10 files and the
            per-split image folders.
    """
    # All three datasets are constructed before any extraction starts.
    # Only "train" is built with download=True; the other two constructors
    # presumably rely on files that download already placed under root.
    datasets_by_split = {
        "train": STL10(root, split="train", download=True),
        "test": STL10(root, split="test"),
        "unlabeled": STL10(root, split="unlabeled"),
    }
    for split_name, split_data in datasets_by_split.items():
        extract_and_save_image(split_data, osp.join(root, split_name))
if __name__ == "__main__":
    # Usage: python <script> <root>
    # Without a root argument, sys.argv[1] would die with a bare IndexError;
    # print a usage line instead. Extra arguments are ignored, as before.
    if len(sys.argv) < 2:
        sys.exit("Usage: python {} <root>".format(sys.argv[0]))
    download_and_prepare(sys.argv[1])