|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" Convert dataset to HDF5 |
|
This script preprocesses a dataset and saves it (images and labels) to |
|
an HDF5 file for improved I/O. """ |
|
import os |
|
import sys |
|
from argparse import ArgumentParser |
|
from tqdm import tqdm, trange |
|
import h5py as h5 |
|
|
|
import numpy as np |
|
import torch |
|
import torchvision.datasets as dset |
|
import torchvision.transforms as transforms |
|
from torchvision.utils import save_image |
|
import torchvision.transforms as transforms |
|
from torch.utils.data import DataLoader |
|
|
|
import utils |
|
|
|
|
|
def prepare_parser():
    """Build the command-line argument parser for the HDF5 conversion script.

    Returns:
        argparse.ArgumentParser: parser exposing dataset location, output
        location, resolution, split, long-tail mode, dataloader settings,
        and HDF5 chunking/compression options.
    """
    usage = "Parser for ImageNet HDF5 scripts."
    parser = ArgumentParser(description=usage)
    parser.add_argument(
        "--resolution",
        type=int,
        default=128,
        help="Which Dataset resolution to train on, out of 64, 128, 256, 512 (default: %(default)s)",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="train",
        help="Which Dataset to convert: train, val (default: %(default)s)",
    )
    parser.add_argument(
        "--data_root",
        type=str,
        default="data",
        help="Default location where data is stored (default: %(default)s)",
    )
    parser.add_argument(
        "--out_path",
        type=str,
        default="data",
        help="Default location where data in hdf5 format will be stored (default: %(default)s)",
    )
    parser.add_argument(
        "--longtail",
        action="store_true",
        default=False,
        help="Use long-tail version of the dataset",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=256,
        help="Default overall batchsize (default: %(default)s)",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=16,
        help="Number of dataloader workers (default: %(default)s)",
    )
    parser.add_argument(
        "--chunk_size",
        type=int,
        default=500,
        # Fixed copy-pasted help text: this is the HDF5 chunk size, not the batch size.
        help="HDF5 chunk size along the sample axis (default: %(default)s)",
    )
    parser.add_argument(
        "--compression",
        action="store_true",
        default=False,
        help="Use LZF compression? (default: %(default)s)",
    )
    return parser
|
|
|
|
|
def run(config):
    """Stream the dataset through a DataLoader and write it to one HDF5 file.

    The output file ``<out_path>/ILSVRC<resolution>[longtail]_xy.hdf5``
    contains two datasets: ``imgs`` (uint8, N x 3 x res x res) and
    ``labels`` (int64, N). The first batch creates the datasets; every
    subsequent batch resizes them and appends.
    """

    # Map the boolean flag to the string h5py expects ("lzf") or None (no compression).
    config["compression"] = (
        "lzf" if config["compression"] else None
    )

    # Dataloader options; pin_memory is pointless here (no GPU transfer),
    # drop_last=False so the final partial batch is still written.
    kwargs = {
        "num_workers": config["num_workers"],
        "pin_memory": False,
        "drop_last": False,
    }
    dataset = utils.get_dataset_images(
        config["resolution"],
        data_path=os.path.join(config["data_root"], config["split"]),
        longtail=config["longtail"],
    )
    # shuffle=False keeps images in dataset order so labels stay aligned
    # and the file is reproducible.
    train_loader = utils.get_dataloader(
        dataset, config["batch_size"], shuffle=False, **kwargs
    )

    print(
        "Starting to load dataset into an HDF5 file with chunk size %i and compression %s..."
        % (config["chunk_size"], config["compression"])
    )

    for i, (x, y) in enumerate(tqdm(train_loader)):

        # Convert images to uint8 in [0, 255].
        # NOTE(review): assumes the transform pipeline yields tensors in
        # [-1, 1] — confirm against utils.get_dataset_images.
        x = (255 * ((x + 1) / 2.0)).byte().numpy()

        y = y.numpy()

        if i == 0:
            # First batch: create the file and both datasets, sized to this
            # batch but resizable (maxshape) up to the full dataset length.
            with h5.File(
                config["out_path"]
                + "/ILSVRC%i%s_xy.hdf5"
                % (config["resolution"], "" if not config["longtail"] else "longtail"),
                "w",
            ) as f:
                print("Producing dataset of len %d" % len(train_loader.dataset))
                imgs_dset = f.create_dataset(
                    "imgs",
                    x.shape,
                    dtype="uint8",
                    maxshape=(
                        len(train_loader.dataset),
                        3,
                        config["resolution"],
                        config["resolution"],
                    ),
                    chunks=(
                        config["chunk_size"],
                        3,
                        config["resolution"],
                        config["resolution"],
                    ),
                    compression=config["compression"],
                )
                print("Image chunks chosen as " + str(imgs_dset.chunks))
                imgs_dset[...] = x
                labels_dset = f.create_dataset(
                    "labels",
                    y.shape,
                    dtype="int64",
                    maxshape=(len(train_loader.dataset),),
                    chunks=(config["chunk_size"],),
                    compression=config["compression"],
                )
                print("Label chunks chosen as " + str(labels_dset.chunks))
                labels_dset[...] = y

        else:
            # Subsequent batches: reopen in append mode, grow both datasets
            # along axis 0, and write the new batch into the freshly added
            # tail region. Reopening per batch flushes progress to disk.
            with h5.File(
                config["out_path"]
                + "/ILSVRC%i%s_xy.hdf5"
                % (config["resolution"], "" if not config["longtail"] else "longtail"),
                "a",
            ) as f:
                f["imgs"].resize(f["imgs"].shape[0] + x.shape[0], axis=0)
                f["imgs"][-x.shape[0] :] = x
                f["labels"].resize(f["labels"].shape[0] + y.shape[0], axis=0)
                f["labels"][-y.shape[0] :] = y
|
|
|
|
|
def main():
    """Entry point: parse CLI arguments, echo them, and run the conversion."""
    args = vars(prepare_parser().parse_args())
    print(args)
    run(args)
|
|
|
|
|
# Run the conversion only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|