# Hugging Face Spaces status: Runtime error
import os
import shutil
import tempfile

import torchvision
from tqdm.auto import tqdm
# Short class names used as filename prefixes by main(); position i holds
# the name for integer label i produced by the dataset.
CLASSES = tuple(
    "plane car bird cat deer dog frog horse ship truck".split()
)
def main():
    """Export the CIFAR-10 train and test splits as individual PNG files.

    Each split is written to ``cifar_<split>/`` in the current working
    directory, one file per example, named ``<class>_<index>.png``.
    ``<class>`` is looked up in ``CLASSES`` by the example's integer label,
    and ``<index>`` is the example's zero-padded (5-digit) position within
    the split. A split whose output directory already exists is skipped.

    Images are first written to ``cifar_<split>.partial/`` and renamed into
    place only after every image has been saved. An interrupted or failed
    run therefore never leaves a half-filled ``cifar_<split>/`` that a later
    run would mistake for a finished export and skip.
    """
    # One download directory shared by both splits. With download=True,
    # torchvision's CIFAR10 does not re-download files already present under
    # `root`, so the second split reuses what the first one fetched.
    with tempfile.TemporaryDirectory() as download_dir:
        for split in ["train", "test"]:
            out_dir = f"cifar_{split}"
            if os.path.exists(out_dir):
                print(f"skipping split {split} since {out_dir} already exists.")
                continue

            print("downloading...")
            dataset = torchvision.datasets.CIFAR10(
                root=download_dir, train=split == "train", download=True
            )

            print("dumping images...")
            staging_dir = f"{out_dir}.partial"
            # Clear leftovers from a previous run that was killed before it
            # could clean up after itself.
            shutil.rmtree(staging_dir, ignore_errors=True)
            os.mkdir(staging_dir)
            try:
                for i in tqdm(range(len(dataset))):
                    image, label = dataset[i]
                    filename = os.path.join(
                        staging_dir, f"{CLASSES[label]}_{i:05d}.png"
                    )
                    image.save(filename)
            except BaseException:
                # BaseException so Ctrl-C (KeyboardInterrupt) also removes the
                # partial output before propagating.
                shutil.rmtree(staging_dir, ignore_errors=True)
                raise
            # The staging dir sits next to out_dir (same parent directory), so
            # the rename publishes the complete split in a single step.
            os.rename(staging_dir, out_dir)
# Run the export only when this file is executed as a script, not on import.
if "__main__" == __name__:
    main()