# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

"""Tool for creating multi-resolution TFRecords datasets for StyleGAN and ProGAN."""
# pylint: disable=too-many-lines
import os
import sys
import glob
import argparse
import threading
import six.moves.queue as Queue # pylint: disable=import-error
import traceback
import numpy as np
import tensorflow as tf
import PIL.Image
import dnnlib.tflib as tflib

from training import dataset

#----------------------------------------------------------------------------
def error(msg):
    print('Error: ' + msg)
    sys.exit(1)

#----------------------------------------------------------------------------
class TFRecordExporter:
    def __init__(self, tfrecord_dir, expected_images, print_progress=True, progress_interval=10):
        self.tfrecord_dir = tfrecord_dir
        self.tfr_prefix = os.path.join(self.tfrecord_dir, os.path.basename(self.tfrecord_dir))
        self.expected_images = expected_images
        self.cur_images = 0
        self.shape = None
        self.resolution_log2 = None
        self.tfr_writers = []
        self.print_progress = print_progress
        self.progress_interval = progress_interval

        if self.print_progress:
            print('Creating dataset "%s"' % tfrecord_dir)
        if not os.path.isdir(self.tfrecord_dir):
            os.makedirs(self.tfrecord_dir)
        assert os.path.isdir(self.tfrecord_dir)
    def close(self):
        if self.print_progress:
            print('%-40s\r' % 'Flushing data...', end='', flush=True)
        for tfr_writer in self.tfr_writers:
            tfr_writer.close()
        self.tfr_writers = []
        if self.print_progress:
            print('%-40s\r' % '', end='', flush=True)
            print('Added %d images.' % self.cur_images)

    def choose_shuffled_order(self): # Note: Images and labels must be added in shuffled order.
        order = np.arange(self.expected_images)
        np.random.RandomState(123).shuffle(order)
        return order
    def add_image(self, img):
        if self.print_progress and self.cur_images % self.progress_interval == 0:
            print('%d / %d\r' % (self.cur_images, self.expected_images), end='', flush=True)
        if self.shape is None:
            # The first image defines the dataset shape; open one writer per level of detail.
            self.shape = img.shape
            self.resolution_log2 = int(np.log2(self.shape[1]))
            assert self.shape[0] in [1, 3]
            assert self.shape[1] == self.shape[2]
            assert self.shape[1] == 2**self.resolution_log2
            tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE)
            for lod in range(self.resolution_log2 - 1):
                tfr_file = self.tfr_prefix + '-r%02d.tfrecords' % (self.resolution_log2 - lod)
                self.tfr_writers.append(tf.python_io.TFRecordWriter(tfr_file, tfr_opt))
        assert img.shape == self.shape
        for lod, tfr_writer in enumerate(self.tfr_writers):
            if lod:
                # Downsample by a factor of 2 per level using a 2x2 box filter.
                img = img.astype(np.float32)
                img = (img[:, 0::2, 0::2] + img[:, 0::2, 1::2] + img[:, 1::2, 0::2] + img[:, 1::2, 1::2]) * 0.25
            quant = np.rint(img).clip(0, 255).astype(np.uint8)
            ex = tf.train.Example(features=tf.train.Features(feature={
                'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=quant.shape)),
                'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[quant.tobytes()]))}))
            tfr_writer.write(ex.SerializeToString())
        self.cur_images += 1
    def add_labels(self, labels):
        if self.print_progress:
            print('%-40s\r' % 'Saving labels...', end='', flush=True)
        assert labels.shape[0] == self.cur_images
        with open(self.tfr_prefix + '-rxx.labels', 'wb') as f:
            np.save(f, labels.astype(np.float32))

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()
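
# Example usage of TFRecordExporter (a minimal sketch mirroring the create_*
# functions below; 'images' and 'onehot' are assumed to be pre-loaded arrays):
#
#   with TFRecordExporter('datasets/mydataset', images.shape[0]) as tfr:
#       order = tfr.choose_shuffled_order()
#       for idx in range(order.size):
#           tfr.add_image(images[order[idx]])   # uint8, CHW, C in {1, 3}, square power-of-two
#       tfr.add_labels(onehot[order])           # optional float32 one-hot labels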
#----------------------------------------------------------------------------

class ExceptionInfo(object):
    def __init__(self):
        self.value = sys.exc_info()[1]
        self.traceback = traceback.format_exc()

#----------------------------------------------------------------------------
class WorkerThread(threading.Thread):
    def __init__(self, task_queue):
        threading.Thread.__init__(self)
        self.task_queue = task_queue

    def run(self):
        while True:
            func, args, result_queue = self.task_queue.get()
            if func is None: # None is the sentinel that tells the thread to exit.
                break
            try:
                result = func(*args)
            except:
                result = ExceptionInfo()
            result_queue.put((result, args))
#----------------------------------------------------------------------------

class ThreadPool(object):
    def __init__(self, num_threads):
        assert num_threads >= 1
        self.task_queue = Queue.Queue()
        self.result_queues = dict()
        self.num_threads = num_threads
        for _idx in range(self.num_threads):
            thread = WorkerThread(self.task_queue)
            thread.daemon = True
            thread.start()

    def add_task(self, func, args=()):
        assert hasattr(func, '__call__') # must be a function
        if func not in self.result_queues:
            self.result_queues[func] = Queue.Queue()
        self.task_queue.put((func, args, self.result_queues[func]))

    def get_result(self, func): # returns (result, args)
        result, args = self.result_queues[func].get()
        if isinstance(result, ExceptionInfo):
            print('\n\nWorker thread caught an exception:\n' + result.traceback)
            raise result.value
        return result, args

    def finish(self):
        for _idx in range(self.num_threads):
            self.task_queue.put((None, (), None))

    def __enter__(self): # for 'with' statement
        return self

    def __exit__(self, *excinfo):
        self.finish()
    def process_items_concurrently(self, item_iterator, process_func=lambda x: x, pre_func=lambda x: x, post_func=lambda x: x, max_items_in_flight=None):
        # Yields processed items in input order, keeping at most
        # max_items_in_flight items queued at any given time.
        if max_items_in_flight is None: max_items_in_flight = self.num_threads * 4
        assert max_items_in_flight >= 1
        results = []
        retire_idx = [0]

        def task_func(prepared, _idx):
            return process_func(prepared)

        def retire_result():
            processed, (_prepared, idx) = self.get_result(task_func)
            results[idx] = processed
            while retire_idx[0] < len(results) and results[retire_idx[0]] is not None:
                yield post_func(results[retire_idx[0]])
                results[retire_idx[0]] = None
                retire_idx[0] += 1

        for idx, item in enumerate(item_iterator):
            prepared = pre_func(item)
            results.append(None)
            self.add_task(func=task_func, args=(prepared, idx))
            while retire_idx[0] < idx - max_items_in_flight + 2:
                for res in retire_result(): yield res
        while retire_idx[0] < len(results):
            for res in retire_result(): yield res
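
# Example usage of ThreadPool (illustrative sketch; 'load' and 'save' are
# hypothetical helpers): results arrive in input order, with at most
# max_items_in_flight items (default: 4 * num_threads) queued, and any
# exception raised inside a worker is re-raised by get_result().
#
#   with ThreadPool(num_threads=8) as pool:
#       for img in pool.process_items_concurrently(filenames, process_func=load):
#           save(img)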
#----------------------------------------------------------------------------

def display(tfrecord_dir):
    print('Loading dataset "%s"' % tfrecord_dir)
    tflib.init_tf({'gpu_options.allow_growth': True})
    dset = dataset.TFRecordDataset(tfrecord_dir, max_label_size='full', repeat=False, shuffle_mb=0)
    tflib.init_uninitialized_vars()
    import cv2 # pip install opencv-python

    idx = 0
    while True:
        try:
            images, labels = dset.get_minibatch_np(1)
        except tf.errors.OutOfRangeError:
            break
        if idx == 0:
            print('Displaying images')
            cv2.namedWindow('dataset_tool')
            print('Press SPACE or ENTER to advance, ESC to exit')
        print('\nidx = %-8d\nlabel = %s' % (idx, labels[0].tolist()))
        cv2.imshow('dataset_tool', images[0].transpose(1, 2, 0)[:, :, ::-1]) # CHW => HWC, RGB => BGR
        idx += 1
        if cv2.waitKey() == 27:
            break
    print('\nDisplayed %d images.' % idx)
#----------------------------------------------------------------------------

def extract(tfrecord_dir, output_dir):
    print('Loading dataset "%s"' % tfrecord_dir)
    tflib.init_tf({'gpu_options.allow_growth': True})
    dset = dataset.TFRecordDataset(tfrecord_dir, max_label_size=0, repeat=False, shuffle_mb=0)
    tflib.init_uninitialized_vars()

    print('Extracting images to "%s"' % output_dir)
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)
    idx = 0
    while True:
        if idx % 10 == 0:
            print('%d\r' % idx, end='', flush=True)
        try:
            images, _labels = dset.get_minibatch_np(1)
        except tf.errors.OutOfRangeError:
            break
        if images.shape[1] == 1:
            img = PIL.Image.fromarray(images[0][0], 'L')
        else:
            img = PIL.Image.fromarray(images[0].transpose(1, 2, 0), 'RGB')
        img.save(os.path.join(output_dir, 'img%08d.png' % idx))
        idx += 1
    print('Extracted %d images.' % idx)
#----------------------------------------------------------------------------

def compare(tfrecord_dir_a, tfrecord_dir_b, ignore_labels):
    max_label_size = 0 if ignore_labels else 'full'
    print('Loading dataset "%s"' % tfrecord_dir_a)
    tflib.init_tf({'gpu_options.allow_growth': True})
    dset_a = dataset.TFRecordDataset(tfrecord_dir_a, max_label_size=max_label_size, repeat=False, shuffle_mb=0)
    print('Loading dataset "%s"' % tfrecord_dir_b)
    dset_b = dataset.TFRecordDataset(tfrecord_dir_b, max_label_size=max_label_size, repeat=False, shuffle_mb=0)
    tflib.init_uninitialized_vars()

    print('Comparing datasets')
    idx = 0
    identical_images = 0
    identical_labels = 0
    while True:
        if idx % 100 == 0:
            print('%d\r' % idx, end='', flush=True)
        try:
            images_a, labels_a = dset_a.get_minibatch_np(1)
        except tf.errors.OutOfRangeError:
            images_a, labels_a = None, None
        try:
            images_b, labels_b = dset_b.get_minibatch_np(1)
        except tf.errors.OutOfRangeError:
            images_b, labels_b = None, None
        if images_a is None or images_b is None:
            if images_a is not None or images_b is not None:
                print('Datasets contain different numbers of images')
            break
        if images_a.shape == images_b.shape and np.all(images_a == images_b):
            identical_images += 1
        else:
            print('Image %d is different' % idx)
        if labels_a.shape == labels_b.shape and np.all(labels_a == labels_b):
            identical_labels += 1
        else:
            print('Label %d is different' % idx)
        idx += 1
    print('Identical images: %d / %d' % (identical_images, idx))
    if not ignore_labels:
        print('Identical labels: %d / %d' % (identical_labels, idx))
#----------------------------------------------------------------------------

def create_mnist(tfrecord_dir, mnist_dir):
    print('Loading MNIST from "%s"' % mnist_dir)
    import gzip
    with gzip.open(os.path.join(mnist_dir, 'train-images-idx3-ubyte.gz'), 'rb') as file:
        images = np.frombuffer(file.read(), np.uint8, offset=16)
    with gzip.open(os.path.join(mnist_dir, 'train-labels-idx1-ubyte.gz'), 'rb') as file:
        labels = np.frombuffer(file.read(), np.uint8, offset=8)
    images = images.reshape(-1, 1, 28, 28)
    images = np.pad(images, [(0,0), (0,0), (2,2), (2,2)], 'constant', constant_values=0) # pad 28x28 to 32x32
    assert images.shape == (60000, 1, 32, 32) and images.dtype == np.uint8
    assert labels.shape == (60000,) and labels.dtype == np.uint8
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 9
    onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32)
    onehot[np.arange(labels.size), labels] = 1.0

    with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr:
        order = tfr.choose_shuffled_order()
        for idx in range(order.size):
            tfr.add_image(images[order[idx]])
        tfr.add_labels(onehot[order])
#----------------------------------------------------------------------------

def create_mnistrgb(tfrecord_dir, mnist_dir, num_images=1000000, random_seed=123):
    print('Loading MNIST from "%s"' % mnist_dir)
    import gzip
    with gzip.open(os.path.join(mnist_dir, 'train-images-idx3-ubyte.gz'), 'rb') as file:
        images = np.frombuffer(file.read(), np.uint8, offset=16)
    images = images.reshape(-1, 28, 28)
    images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0)
    assert images.shape == (60000, 32, 32) and images.dtype == np.uint8
    assert np.min(images) == 0 and np.max(images) == 255

    with TFRecordExporter(tfrecord_dir, num_images) as tfr:
        rnd = np.random.RandomState(random_seed)
        for _idx in range(num_images):
            tfr.add_image(images[rnd.randint(images.shape[0], size=3)]) # stack 3 random digits as RGB channels
#----------------------------------------------------------------------------

def create_cifar10(tfrecord_dir, cifar10_dir):
    print('Loading CIFAR-10 from "%s"' % cifar10_dir)
    import pickle
    images = []
    labels = []
    for batch in range(1, 6):
        with open(os.path.join(cifar10_dir, 'data_batch_%d' % batch), 'rb') as file:
            data = pickle.load(file, encoding='latin1')
        images.append(data['data'].reshape(-1, 3, 32, 32))
        labels.append(data['labels'])
    images = np.concatenate(images)
    labels = np.concatenate(labels)
    assert images.shape == (50000, 3, 32, 32) and images.dtype == np.uint8
    assert labels.shape == (50000,) and labels.dtype == np.int32
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 9
    onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32)
    onehot[np.arange(labels.size), labels] = 1.0

    with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr:
        order = tfr.choose_shuffled_order()
        for idx in range(order.size):
            tfr.add_image(images[order[idx]])
        tfr.add_labels(onehot[order])
#----------------------------------------------------------------------------

def create_cifar100(tfrecord_dir, cifar100_dir):
    print('Loading CIFAR-100 from "%s"' % cifar100_dir)
    import pickle
    with open(os.path.join(cifar100_dir, 'train'), 'rb') as file:
        data = pickle.load(file, encoding='latin1')
    images = data['data'].reshape(-1, 3, 32, 32)
    labels = np.array(data['fine_labels'])
    assert images.shape == (50000, 3, 32, 32) and images.dtype == np.uint8
    assert labels.shape == (50000,) and labels.dtype == np.int32
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 99
    onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32)
    onehot[np.arange(labels.size), labels] = 1.0

    with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr:
        order = tfr.choose_shuffled_order()
        for idx in range(order.size):
            tfr.add_image(images[order[idx]])
        tfr.add_labels(onehot[order])
#----------------------------------------------------------------------------

def create_svhn(tfrecord_dir, svhn_dir):
    print('Loading SVHN from "%s"' % svhn_dir)
    import pickle
    images = []
    labels = []
    for batch in range(1, 4):
        with open(os.path.join(svhn_dir, 'train_%d.pkl' % batch), 'rb') as file:
            data = pickle.load(file, encoding='latin1')
        images.append(data[0])
        labels.append(data[1])
    images = np.concatenate(images)
    labels = np.concatenate(labels)
    assert images.shape == (73257, 3, 32, 32) and images.dtype == np.uint8
    assert labels.shape == (73257,) and labels.dtype == np.uint8
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 9
    onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32)
    onehot[np.arange(labels.size), labels] = 1.0

    with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr:
        order = tfr.choose_shuffled_order()
        for idx in range(order.size):
            tfr.add_image(images[order[idx]])
        tfr.add_labels(onehot[order])
#----------------------------------------------------------------------------

def create_lsun(tfrecord_dir, lmdb_dir, resolution=256, max_images=None):
    print('Loading LSUN dataset from "%s"' % lmdb_dir)
    import lmdb # pip install lmdb # pylint: disable=import-error
    import cv2 # pip install opencv-python
    import io
    with lmdb.open(lmdb_dir, readonly=True).begin(write=False) as txn:
        total_images = txn.stat()['entries'] # pylint: disable=no-value-for-parameter
        if max_images is None:
            max_images = total_images
        with TFRecordExporter(tfrecord_dir, max_images) as tfr:
            for _idx, (_key, value) in enumerate(txn.cursor()):
                try:
                    try:
                        img = cv2.imdecode(np.frombuffer(value, dtype=np.uint8), 1)
                        if img is None:
                            raise IOError('cv2.imdecode failed')
                        img = img[:, :, ::-1] # BGR => RGB
                    except IOError:
                        img = np.asarray(PIL.Image.open(io.BytesIO(value)))
                    crop = np.min(img.shape[:2]) # center-crop to a square
                    img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]
                    img = PIL.Image.fromarray(img, 'RGB')
                    img = img.resize((resolution, resolution), PIL.Image.LANCZOS)
                    img = np.asarray(img)
                    img = img.transpose([2, 0, 1]) # HWC => CHW
                    tfr.add_image(img)
                except:
                    print(sys.exc_info()[1])
                if tfr.cur_images == max_images:
                    break
#----------------------------------------------------------------------------

def create_lsun_wide(tfrecord_dir, lmdb_dir, width=512, height=384, max_images=None):
    assert width == 2 ** int(np.round(np.log2(width))) # width must be a power of two
    assert height <= width
    print('Loading LSUN dataset from "%s"' % lmdb_dir)
    import lmdb # pip install lmdb # pylint: disable=import-error
    import cv2 # pip install opencv-python
    import io
    with lmdb.open(lmdb_dir, readonly=True).begin(write=False) as txn:
        total_images = txn.stat()['entries'] # pylint: disable=no-value-for-parameter
        if max_images is None:
            max_images = total_images
        with TFRecordExporter(tfrecord_dir, max_images, print_progress=False) as tfr:
            for idx, (_key, value) in enumerate(txn.cursor()):
                try:
                    try:
                        img = cv2.imdecode(np.frombuffer(value, dtype=np.uint8), 1)
                        if img is None:
                            raise IOError('cv2.imdecode failed')
                        img = img[:, :, ::-1] # BGR => RGB
                    except IOError:
                        img = np.asarray(PIL.Image.open(io.BytesIO(value)))

                    ch = int(np.round(width * img.shape[0] / img.shape[1]))
                    if img.shape[1] < width or ch < height:
                        continue
                    img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2]
                    img = PIL.Image.fromarray(img, 'RGB')
                    img = img.resize((width, height), PIL.Image.LANCZOS)
                    img = np.asarray(img)
                    img = img.transpose([2, 0, 1]) # HWC => CHW

                    canvas = np.zeros([3, width, width], dtype=np.uint8) # letterbox onto a square canvas
                    canvas[:, (width - height) // 2 : (width + height) // 2] = img
                    tfr.add_image(canvas)
                    print('\r%d / %d => %d ' % (idx + 1, total_images, tfr.cur_images), end='')
                except:
                    print(sys.exc_info()[1])
                if tfr.cur_images == max_images:
                    break
    print()
#----------------------------------------------------------------------------

def create_celeba(tfrecord_dir, celeba_dir, cx=89, cy=121):
    print('Loading CelebA from "%s"' % celeba_dir)
    glob_pattern = os.path.join(celeba_dir, 'img_align_celeba_png', '*.png')
    image_filenames = sorted(glob.glob(glob_pattern))
    expected_images = 202599
    if len(image_filenames) != expected_images:
        error('Expected to find %d images' % expected_images)

    with TFRecordExporter(tfrecord_dir, len(image_filenames)) as tfr:
        order = tfr.choose_shuffled_order()
        for idx in range(order.size):
            img = np.asarray(PIL.Image.open(image_filenames[order[idx]]))
            assert img.shape == (218, 178, 3)
            img = img[cy - 64 : cy + 64, cx - 64 : cx + 64] # 128x128 crop centered at (cx, cy)
            img = img.transpose(2, 0, 1) # HWC => CHW
            tfr.add_image(img)
#----------------------------------------------------------------------------

def create_from_images(tfrecord_dir, image_dir, shuffle):
    print('Loading images from "%s"' % image_dir)
    image_filenames = sorted(glob.glob(os.path.join(image_dir, '*')))
    if len(image_filenames) == 0:
        error('No input images found')

    img = np.asarray(PIL.Image.open(image_filenames[0]))
    resolution = img.shape[0]
    channels = img.shape[2] if img.ndim == 3 else 1
    if img.shape[1] != resolution:
        error('Input images must have the same width and height')
    if resolution != 2 ** int(np.floor(np.log2(resolution))):
        error('Input image resolution must be a power-of-two')
    if channels not in [1, 3]:
        error('Input images must be stored as RGB or grayscale')

    with TFRecordExporter(tfrecord_dir, len(image_filenames)) as tfr:
        order = tfr.choose_shuffled_order() if shuffle else np.arange(len(image_filenames))
        for idx in range(order.size):
            img = np.asarray(PIL.Image.open(image_filenames[order[idx]]))
            if channels == 1:
                img = img[np.newaxis, :, :] # HW => CHW
            else:
                img = img.transpose([2, 0, 1]) # HWC => CHW
            tfr.add_image(img)
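
# Example (illustrative paths): for a directory of 1024x1024 RGB images,
#
#   python dataset_tool.py create_from_images datasets/mydataset myimagedir
#
# writes one tfrecords file per level of detail, following the naming scheme
# used by TFRecordExporter.add_image():
#
#   datasets/mydataset/mydataset-r10.tfrecords   # 1024x1024
#   datasets/mydataset/mydataset-r09.tfrecords   # 512x512
#   ...
#   datasets/mydataset/mydataset-r02.tfrecords   # 4x4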
#----------------------------------------------------------------------------

def create_from_hdf5(tfrecord_dir, hdf5_filename, shuffle):
    print('Loading HDF5 archive from "%s"' % hdf5_filename)
    import h5py # conda install h5py
    with h5py.File(hdf5_filename, 'r') as hdf5_file:
        # Pick the highest-resolution 'data*' dataset in the archive.
        hdf5_data = max([value for key, value in hdf5_file.items() if key.startswith('data')], key=lambda lod: lod.shape[3])
        with TFRecordExporter(tfrecord_dir, hdf5_data.shape[0]) as tfr:
            order = tfr.choose_shuffled_order() if shuffle else np.arange(hdf5_data.shape[0])
            for idx in range(order.size):
                tfr.add_image(hdf5_data[order[idx]])
            npy_filename = os.path.splitext(hdf5_filename)[0] + '-labels.npy'
            if os.path.isfile(npy_filename):
                tfr.add_labels(np.load(npy_filename)[order])
#----------------------------------------------------------------------------

def execute_cmdline(argv):
    prog = argv[0]
    parser = argparse.ArgumentParser(
        prog        = prog,
        description = 'Tool for creating multi-resolution TFRecords datasets for StyleGAN and ProGAN.',
        epilog      = 'Type "%s <command> -h" for more information.' % prog)

    subparsers = parser.add_subparsers(dest='command')
    subparsers.required = True

    def add_command(cmd, desc, example=None):
        epilog = 'Example: %s %s' % (prog, example) if example is not None else None
        return subparsers.add_parser(cmd, description=desc, help=desc, epilog=epilog)

    p = add_command(    'display',          'Display images in dataset.',
                                            'display datasets/mnist')
    p.add_argument(     'tfrecord_dir',     help='Directory containing dataset')

    p = add_command(    'extract',          'Extract images from dataset.',
                                            'extract datasets/mnist mnist-images')
    p.add_argument(     'tfrecord_dir',     help='Directory containing dataset')
    p.add_argument(     'output_dir',       help='Directory to extract the images into')

    p = add_command(    'compare',          'Compare two datasets.',
                                            'compare datasets/mydataset datasets/mnist')
    p.add_argument(     'tfrecord_dir_a',   help='Directory containing first dataset')
    p.add_argument(     'tfrecord_dir_b',   help='Directory containing second dataset')
    p.add_argument(     '--ignore_labels',  help='Ignore labels (default: 0)', type=int, default=0)

    p = add_command(    'create_mnist',     'Create dataset for MNIST.',
                                            'create_mnist datasets/mnist ~/downloads/mnist')
    p.add_argument(     'tfrecord_dir',     help='New dataset directory to be created')
    p.add_argument(     'mnist_dir',        help='Directory containing MNIST')

    p = add_command(    'create_mnistrgb',  'Create dataset for MNIST-RGB.',
                                            'create_mnistrgb datasets/mnistrgb ~/downloads/mnist')
    p.add_argument(     'tfrecord_dir',     help='New dataset directory to be created')
    p.add_argument(     'mnist_dir',        help='Directory containing MNIST')
    p.add_argument(     '--num_images',     help='Number of composite images to create (default: 1000000)', type=int, default=1000000)
    p.add_argument(     '--random_seed',    help='Random seed (default: 123)', type=int, default=123)

    p = add_command(    'create_cifar10',   'Create dataset for CIFAR-10.',
                                            'create_cifar10 datasets/cifar10 ~/downloads/cifar10')
    p.add_argument(     'tfrecord_dir',     help='New dataset directory to be created')
    p.add_argument(     'cifar10_dir',      help='Directory containing CIFAR-10')

    p = add_command(    'create_cifar100',  'Create dataset for CIFAR-100.',
                                            'create_cifar100 datasets/cifar100 ~/downloads/cifar100')
    p.add_argument(     'tfrecord_dir',     help='New dataset directory to be created')
    p.add_argument(     'cifar100_dir',     help='Directory containing CIFAR-100')

    p = add_command(    'create_svhn',      'Create dataset for SVHN.',
                                            'create_svhn datasets/svhn ~/downloads/svhn')
    p.add_argument(     'tfrecord_dir',     help='New dataset directory to be created')
    p.add_argument(     'svhn_dir',         help='Directory containing SVHN')

    p = add_command(    'create_lsun',      'Create dataset for single LSUN category.',
                                            'create_lsun datasets/lsun-car-100k ~/downloads/lsun/car_lmdb --resolution 256 --max_images 100000')
    p.add_argument(     'tfrecord_dir',     help='New dataset directory to be created')
    p.add_argument(     'lmdb_dir',         help='Directory containing LMDB database')
    p.add_argument(     '--resolution',     help='Output resolution (default: 256)', type=int, default=256)
    p.add_argument(     '--max_images',     help='Maximum number of images (default: none)', type=int, default=None)

    p = add_command(    'create_lsun_wide', 'Create LSUN dataset with non-square aspect ratio.',
                                            'create_lsun_wide datasets/lsun-car-512x384 ~/downloads/lsun/car_lmdb --width 512 --height 384')
    p.add_argument(     'tfrecord_dir',     help='New dataset directory to be created')
    p.add_argument(     'lmdb_dir',         help='Directory containing LMDB database')
    p.add_argument(     '--width',          help='Output width (default: 512)', type=int, default=512)
    p.add_argument(     '--height',         help='Output height (default: 384)', type=int, default=384)
    p.add_argument(     '--max_images',     help='Maximum number of images (default: none)', type=int, default=None)

    p = add_command(    'create_celeba',    'Create dataset for CelebA.',
                                            'create_celeba datasets/celeba ~/downloads/celeba')
    p.add_argument(     'tfrecord_dir',     help='New dataset directory to be created')
    p.add_argument(     'celeba_dir',       help='Directory containing CelebA')
    p.add_argument(     '--cx',             help='Center X coordinate (default: 89)', type=int, default=89)
    p.add_argument(     '--cy',             help='Center Y coordinate (default: 121)', type=int, default=121)

    p = add_command(    'create_from_images', 'Create dataset from a directory full of images.',
                                            'create_from_images datasets/mydataset myimagedir')
    p.add_argument(     'tfrecord_dir',     help='New dataset directory to be created')
    p.add_argument(     'image_dir',        help='Directory containing the images')
    p.add_argument(     '--shuffle',        help='Randomize image order (default: 1)', type=int, default=1)

    p = add_command(    'create_from_hdf5', 'Create dataset from legacy HDF5 archive.',
                                            'create_from_hdf5 datasets/celebahq ~/downloads/celeba-hq-1024x1024.h5')
    p.add_argument(     'tfrecord_dir',     help='New dataset directory to be created')
    p.add_argument(     'hdf5_filename',    help='HDF5 archive containing the images')
    p.add_argument(     '--shuffle',        help='Randomize image order (default: 1)', type=int, default=1)

    args = parser.parse_args(argv[1:] if len(argv) > 1 else ['-h'])
    func = globals()[args.command]
    del args.command
    func(**vars(args))
#----------------------------------------------------------------------------

if __name__ == "__main__":
    execute_cmdline(sys.argv)

#----------------------------------------------------------------------------