Spaces:
Running
Running
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Download MNIST, Omniglot datasets for Rebar."""
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import urllib | |
import gzip | |
import os | |
import config | |
import struct | |
import numpy as np | |
import cPickle as pickle | |
import datasets | |
# Download locations. The values below are placeholders ('see README'); they
# must be replaced with real URLs before the script can fetch anything.
# MNIST_URL and MNIST_BINARIZED_URL are used as base URLs (a file name is
# appended after a '/'), while OMNIGLOT_URL is fetched as-is.
MNIST_URL = 'see README'
MNIST_BINARIZED_URL = 'see README'
OMNIGLOT_URL = 'see README'

# Name of the uncompressed MNIST training-image file; the script downloads
# '<MNIST_URL>/<MNIST_FLOAT_TRAIN>.gz' and unpacks it under this name.
MNIST_FLOAT_TRAIN = 'train-images-idx3-ubyte'
def load_mnist_float(local_filename):
    """Load an uncompressed IDX3-ubyte image file as flattened float images.

    The file is expected to start with a 16-byte big-endian header (magic
    number, image count, rows, columns) followed by one unsigned byte per
    pixel.

    Args:
        local_filename: Path to the uncompressed IDX3-ubyte file.

    Returns:
        A float32 numpy array of shape (nimages, rows * cols) whose values
        are the raw pixel bytes divided by 255.0.

    Raises:
        ValueError: If the magic number is not 2051 (IDX3, unsigned byte) or
            the pixel payload size disagrees with the header.
        struct.error: If the file is shorter than the 16-byte header.
    """
    idx3_ubyte_magic = 0x00000803  # 2051: unsigned-byte data, 3 dimensions

    with open(local_filename, 'rb') as f:
        magic, nimages, rows, cols = struct.unpack('>IIII', f.read(16))
        if magic != idx3_ubyte_magic:
            # Fail loudly instead of silently scaling arbitrary bytes.
            raise ValueError(
                '%s: bad IDX magic number %d (expected %d)'
                % (local_filename, magic, idx3_ubyte_magic))
        # Read through the Python file object so the data starts exactly
        # after the header that was just consumed.
        raw = f.read()

    dim = rows * cols
    expected = nimages * dim
    if len(raw) != expected:
        raise ValueError(
            '%s: header declares %d pixel bytes but file contains %d'
            % (local_filename, expected, len(raw)))

    images = np.frombuffer(raw, dtype=np.uint8)
    # Division promotes to float64 first; the cast to float32 happens after.
    images = (images / 255.0).astype('float32').reshape((nimages, dim))
    return images
if __name__ == '__main__':
    # Script entry point: download each dataset into config.DATA_DIR and
    # convert it to the file names given by the project-local `config`
    # module. Files that already exist on disk are not downloaded again.
    # NOTE: urllib.urlretrieve is the Python 2 location of this function.
    if not os.path.exists(config.DATA_DIR):
        os.makedirs(config.DATA_DIR)

    # Get MNIST and convert to npy file
    local_filename = os.path.join(config.DATA_DIR, MNIST_FLOAT_TRAIN)
    if not os.path.exists(local_filename):
        gz_filename = local_filename + '.gz'
        urllib.urlretrieve(
            "%s/%s.gz" % (MNIST_URL, MNIST_FLOAT_TRAIN), gz_filename)
        # Decompress next to the archive, then drop the archive.
        with gzip.open(gz_filename, 'rb') as f:
            file_content = f.read()
        with open(local_filename, 'wb') as f:
            f.write(file_content)
        os.remove(gz_filename)

    # Keep everything except the last 10000 images.
    # NOTE(review): presumably those are held out for validation -- confirm.
    mnist_float_train = load_mnist_float(local_filename)[:-10000]

    # save in a nice format
    np.save(os.path.join(config.DATA_DIR, config.MNIST_FLOAT),
            mnist_float_train)

    # Get binarized MNIST
    splits = ['train', 'valid', 'test']
    mnist_binarized = []
    for split in splits:
        filename = 'binarized_mnist_%s.amat' % split
        url = '%s/binarized_mnist_%s.amat' % (MNIST_BINARIZED_URL, split)
        local_filename = os.path.join(config.DATA_DIR, filename)
        if not os.path.exists(local_filename):
            urllib.urlretrieve(url, local_filename)

        # Each line is a whitespace-separated row of integers. An explicit
        # nested comprehension is used instead of [map(int, ...)] so the rows
        # are real lists regardless of whether map() returns a list.
        with open(local_filename, 'rb') as f:
            rows = [[int(token) for token in line.split()] for line in f]
        # Stored as (images, None) pairs, one per split, in `splits` order.
        mnist_binarized.append((np.array(rows).astype('float32'), None))

    # save in a nice format
    # Pickle streams are byte data, so the output file is opened in binary
    # mode rather than text mode.
    with open(os.path.join(config.DATA_DIR, config.MNIST_BINARIZED),
              'wb') as out:
        pickle.dump(mnist_binarized, out)

    # Get Omniglot
    local_filename = os.path.join(config.DATA_DIR, config.OMNIGLOT)
    if not os.path.exists(local_filename):
        urllib.urlretrieve(OMNIGLOT_URL,
                           local_filename)