|
import os |
|
import numpy as np |
|
import os.path as osp |
|
import argparse |
|
from PIL import Image |
|
from scipy.io import loadmat |
|
|
|
|
|
def mkdir_if_missing(directory):
    """Create ``directory`` (and any missing parents) if it does not exist.

    ``exist_ok=True`` replaces a separate existence check, so a directory
    created by someone else between "check" and "create" no longer raises
    ``FileExistsError``.

    Args:
        directory: Path of the directory to create.
    """
    os.makedirs(directory, exist_ok=True)
|
|
|
|
|
def extract_and_save(data, label, save_dir):
    """Write every image in ``data`` into ``save_dir`` as a JPEG file.

    Each file is named ``<index>_<label>.jpg``, where ``index`` is the
    1-based position of the image zero-padded to six digits.

    Args:
        data: Iterable of image arrays indexed as (H, W, C); presumably uint8
            pixel data -- TODO confirm for every dataset.
        label: Iterable of labels, paired element-wise with ``data``.
        save_dir: Existing directory that receives the images.
    """
    for idx, (img, lbl) in enumerate(zip(data, label), start=1):
        # Single-channel images are tiled to three channels so that every
        # image can be written in RGB mode.
        img = np.repeat(img, 3, axis=2) if img.shape[2] == 1 else img
        # A label of 10 is written as 0 (presumably SVHN's encoding of the
        # digit 0 -- confirm).
        lbl = 0 if lbl == 10 else lbl
        filename = "{}_{}.jpg".format(str(idx).zfill(6), str(lbl))
        Image.fromarray(img, mode="RGB").save(osp.join(save_dir, filename))
|
|
|
|
|
def load_mnist(data_dir, raw_data_dir):
    """Load the MNIST domain of Digit-Five from ``mnist_data.mat``.

    Args:
        data_dir: Unused; accepted so that every ``load_*`` function shares
            the ``(data_dir, raw_data_dir)`` signature used by ``main``.
        raw_data_dir: Directory containing ``mnist_data.mat``.

    Returns:
        Tuple ``(train_data, test_data, train_label, test_label)``.

        - The image arrays are reshaped to (N, 32, 32, 1).
        - The labels are the column indices of the non-zero entries of the
          ``label_train`` / ``label_test`` matrices. This assumes exactly one
          non-zero entry per row (one-hot).
    """
    filepath = osp.join(raw_data_dir, "mnist_data.mat")
    data = loadmat(filepath)

    # Infer the sample count rather than hard-coding 55000 / 10000, so the
    # loader also works on subsets or on files with a different split size.
    train_data = np.reshape(data["train_32"], (-1, 32, 32, 1))
    test_data = np.reshape(data["test_32"], (-1, 32, 32, 1))

    train_label = np.nonzero(data["label_train"])[1]
    test_label = np.nonzero(data["label_test"])[1]

    return train_data, test_data, train_label, test_label
|
|
|
|
|
def load_mnist_m(data_dir, raw_data_dir):
    """Read MNIST-M images and labels from ``mnistm_with_label.mat``.

    ``data_dir`` is not used; it is accepted only for parity with the other
    ``load_*`` functions.

    Returns:
        Tuple ``(train_data, test_data, train_label, test_label)``.

        - The image arrays are returned exactly as stored under the
          ``train`` / ``test`` keys.
        - The labels are the column indices of the non-zero entries of
          ``label_train`` / ``label_test``.
    """
    mat = loadmat(osp.join(raw_data_dir, "mnistm_with_label.mat"))
    images = [mat[key] for key in ("train", "test")]
    labels = [np.nonzero(mat[key])[1] for key in ("label_train", "label_test")]
    return images[0], images[1], labels[0], labels[1]
|
|
|
|
|
def load_svhn(data_dir, raw_data_dir):
    """Read the SVHN train/test splits from their two ``.mat`` files.

    ``data_dir`` is not used; it is accepted only for parity with the other
    ``load_*`` functions.

    Returns:
        Tuple ``(train_data, test_data, train_label, test_label)``.

        - In the image arrays, the last axis of ``X`` is moved to the front.
        - The labels are the first column of ``y``.
    """
    def _read(filename):
        # Load one split, bring axis 3 of X to the front, and flatten y.
        split = loadmat(osp.join(raw_data_dir, filename))
        return split["X"].transpose(3, 0, 1, 2), split["y"][:, 0]

    train_data, train_label = _read("svhn_train_32x32.mat")
    test_data, test_label = _read("svhn_test_32x32.mat")
    return train_data, test_data, train_label, test_label
|
|
|
|
|
def load_syn(data_dir, raw_data_dir):
    """Read the synthetic-digits domain from ``syn_number.mat``.

    ``data_dir`` is not used; it is accepted only for parity with the other
    ``load_*`` functions.

    Returns:
        Tuple ``(train_data, test_data, train_label, test_label)``.

        - The image arrays are returned as stored.
        - The labels are the first column of ``train_label`` / ``test_label``.
    """
    mat = loadmat(osp.join(raw_data_dir, "syn_number.mat"))
    train_data, test_data = mat["train_data"], mat["test_data"]
    # Labels are stored as 2-D arrays; keep only their first column.
    train_label, test_label = (
        mat[key][:, 0] for key in ("train_label", "test_label")
    )
    return train_data, test_data, train_label, test_label
|
|
|
|
|
def load_usps(data_dir, raw_data_dir):
    """Load the USPS domain of Digit-Five from ``usps_28x28.mat``.

    The ``dataset`` entry is indexed as ``dataset[split][field]``:

    - split 0 is train and split 1 is test;
    - field 0 is the images and field 1 is the labels.

    Args:
        data_dir: Unused; accepted so every ``load_*`` function shares the
            ``(data_dir, raw_data_dir)`` signature used by ``main``.
        raw_data_dir: Directory containing ``usps_28x28.mat``.

    Returns:
        Tuple ``(train_data, test_data, train_label, test_label)``.

        - The image arrays are uint8, with axis 1 of the stored array moved
          to the end.
        - The labels are the first column of the stored label arrays.
    """
    filepath = osp.join(raw_data_dir, "usps_28x28.mat")
    data = loadmat(filepath)["dataset"]

    def _to_uint8(images):
        # Move axis 1 to the end; presumably (N, C, H, W) -> (N, H, W, C).
        images = images.transpose(0, 2, 3, 1)
        # Intensities appear to be stored in [0, 1]; scale them to bytes.
        # - Round rather than truncate, so 199.9999 becomes 200, not 199.
        # - Clip so out-of-range values saturate instead of wrapping around
        #   on the uint8 cast.
        # - Scale out-of-place so the loadmat arrays are not mutated.
        return np.clip(np.rint(images * 255), 0, 255).astype(np.uint8)

    train_data = _to_uint8(data[0][0])
    test_data = _to_uint8(data[1][0])

    train_label = data[0][1][:, 0]
    test_label = data[1][1][:, 0]

    return train_data, test_data, train_label, test_label
|
|
|
|
|
def main(data_dir):
    """Convert every Digit-Five domain into per-image JPEG files.

    Reads the raw ``.mat`` files from ``<data_dir>/Digit-Five``. For each of
    mnist, mnist_m, svhn, syn and usps, it writes images into
    ``<data_dir>/<name>/train_images`` and ``<data_dir>/<name>/test_images``.

    Args:
        data_dir: Root directory. ``~`` is expanded and the path is made
            absolute.

    Raises:
        FileNotFoundError: If ``data_dir`` or its ``Digit-Five``
            subdirectory does not exist.
    """
    data_dir = osp.abspath(osp.expanduser(data_dir))
    raw_data_dir = osp.join(data_dir, "Digit-Five")

    if not osp.exists(data_dir):
        raise FileNotFoundError('"{}" does not exist'.format(data_dir))
    # Fail fast with a clear message instead of from inside the first loader.
    if not osp.exists(raw_data_dir):
        raise FileNotFoundError('"{}" does not exist'.format(raw_data_dir))

    # An explicit dispatch table replaces eval("load_" + name). The lookup is
    # the same, but no source code is constructed and executed.
    loaders = {
        "mnist": load_mnist,
        "mnist_m": load_mnist_m,
        "svhn": load_svhn,
        "syn": load_syn,
        "usps": load_usps,
    }

    for name, loader in loaders.items():
        print("Creating {}".format(name))

        train_data, test_data, train_label, test_label = loader(
            data_dir, raw_data_dir
        )

        print("# train: {}".format(train_data.shape[0]))
        print("# test: {}".format(test_data.shape[0]))

        train_dir = osp.join(data_dir, name, "train_images")
        mkdir_if_missing(train_dir)
        test_dir = osp.join(data_dir, name, "test_images")
        mkdir_if_missing(test_dir)

        extract_and_save(train_data, train_label, train_dir)
        extract_and_save(test_data, test_label, test_dir)
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: one positional argument naming the root
    # directory that holds the raw Digit-Five/ folder.
    cli = argparse.ArgumentParser()
    cli.add_argument("data_dir", type=str, help="directory containing Digit-Five/")
    main(cli.parse_args().data_dir)
|
|