# ae_gen/cli.py
import os
from functools import partial

import matplotlib as mpl
mpl.use('Agg')  # select a non-interactive backend before pyplot is imported
import matplotlib.pyplot as plt

import numpy as np
import torch
import torch.nn as nn
from clize import run
from skimage.io import imsave

from data import load_dataset
from model import ConvAE, DeepConvAE, DenseAE, KAE, SimpleConvAE, ZAE
from viz import grid_of_images_default

device = "cuda" if torch.cuda.is_available() else "cpu"


def plot_dataset(code_2d, categories):
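    """Scatter-plot the 2D codes of real examples, one color per digit class (0-9)."""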
colors = [
'r',
'b',
'g',
'crimson',
'gold',
'yellow',
'maroon',
'm',
'c',
'orange'
]
for cat in range(0, 10):
g = (categories == cat)
plt.scatter(
code_2d[g, 0],
code_2d[g, 1],
marker='+',
c=colors[cat],
s=40,
alpha=0.7,
label="digit {}".format(cat)
)


def plot_generated(code_2d, categories):
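    """Scatter-plot the 2D codes of generated samples (marked with label -1) in gray."""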
g = (categories < 0)
plt.scatter(
code_2d[g, 0],
code_2d[g, 1],
marker='+',
c='gray',
s=30
)


def grid_embedding(h):
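    """Arrange embedded points on a square 2D grid by solving a linear assignment
    problem (Jonker-Volgenant) between grid positions and points, using squared
    euclidean distances as costs. Returns the index permutation to apply to the
    examples; the number of examples must be a perfect square."""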
from lapjv import lapjv
from scipy.spatial.distance import cdist
assert int(np.sqrt(h.shape[0])) ** 2 == h.shape[0], 'Nb of examples must be a square number'
size = int(np.sqrt(h.shape[0]))
grid = np.dstack(np.meshgrid(np.linspace(0, 1, size), np.linspace(0, 1, size))).reshape(-1, 2)
cost_matrix = cdist(grid, h, "sqeuclidean").astype('float32')
cost_matrix = cost_matrix * (100000 / cost_matrix.max())
_, rows, cols = lapjv(cost_matrix)
return rows


def save_weights(m, folder='.'):
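    """Save a grid image of a layer's filters to `folder`. For `nn.Linear` layers
    with a 28*28 side, each row of the weight matrix is reshaped into a 28x28
    tile; for `nn.ConvTranspose2d` layers mapping to 1 or 3 image channels, the
    kernels are tiled directly. Intended to be used via `model.apply`."""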
if isinstance(m, nn.Linear):
w = m.weight.data
if w.size(1) == 28*28 or w.size(0) == 28*28:
w0, w1 = w.size(0), w.size(1)
if w0 == 28*28:
w = w.transpose(0, 1)
w = w.contiguous()
w = w.view(w.size(0), 1, 28, 28)
            gr = grid_of_images_default(w.cpu().numpy(), normalize=True)
imsave('{}/feat_{}.png'.format(folder, w0), gr)
elif isinstance(m, nn.ConvTranspose2d):
w = m.weight.data
if w.size(0) in (32, 64, 128, 256, 512) and w.size(1) in (1, 3):
            gr = grid_of_images_default(w.cpu().numpy(), normalize=True)
imsave('{}/feat.png'.format(folder), gr)


@torch.no_grad()
def iterative_refinement(ae, nb_examples=1, nb_iter=10, w=28, h=28, c=1,
                         batch_size=None, binarize_threshold=None):
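    """Generate samples by iterated reconstruction: start from uniform noise and
    repeatedly feed each batch through the autoencoder, optionally binarizing
    the output at `binarize_threshold` after every step. Returns a tensor of
    shape (nb_iter, nb_examples, c, w, h) holding the full trajectory, where
    iteration 0 is the initial noise."""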
if batch_size is None:
batch_size = nb_examples
x = torch.rand(nb_iter, nb_examples, c, w, h)
for i in range(1, nb_iter):
for j in range(0, nb_examples, batch_size):
oldv = x[i-1][j:j + batch_size].to(device)
newv = ae(oldv)
            if binarize_threshold is not None:
                newv = (newv > binarize_threshold).float()
            newv = newv.cpu()
x[i][j:j + batch_size] = newv
return x


def build_model(name, w, h, c):
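    """Build one of the autoencoder variants defined in `model`, with fixed
    hyperparameters, for images of size (c, h, w)."""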
if name == 'convae':
ae = ConvAE(
w=w, h=h, c=c,
nb_filters=128,
spatial=True,
channel=True,
channel_stride=4,
)
elif name == 'zae':
ae = ZAE(
w=w, h=h, c=c,
theta=3,
nb_hidden=1000,
)
elif name == 'kae':
ae = KAE(
w=w, h=h, c=c,
nb_active=1000,
nb_hidden=1000,
)
elif name == 'denseae':
ae = DenseAE(
w=w, h=h, c=c,
encode_hidden=[1000],
decode_hidden=[],
ksparse=True,
nb_active=50,
)
elif name == 'simple_convae':
ae = SimpleConvAE(
w=w, h=h, c=c,
nb_filters=128,
)
elif name == 'deep_convae':
ae = DeepConvAE(
w=w, h=h, c=c,
nb_filters=128,
spatial=True,
channel=True,
channel_stride=4,
nb_layers=3,
)
    else:
        raise ValueError('Unknown model: {}'.format(name))
return ae


def salt_and_pepper(X, proba=0.5):
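    """Corrupt X with salt-and-pepper noise: each element is replaced, with
    probability `proba`, by 0 or 1 (each with probability 1/2)."""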
a = (torch.rand(X.size()).to(device) <= (1 - proba)).float()
b = (torch.rand(X.size()).to(device) <= 0.5).float()
c = ((a == 0).float() * b)
return X * a + c


def train(*, dataset='mnist', folder='mnist', resume=False, model='convae',
          walkback=False, denoise=False, epochs=100, batch_size=64,
          log_interval=100):
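    """Train an autoencoder on `dataset` and write reconstructions, filter
    grids and the checkpoint `model.th` to `folder`. `walkback` enables
    walkback training (corrupt, reconstruct, sample, repeat); `denoise`
    trains a denoising autoencoder on salt-and-pepper-corrupted inputs."""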
    os.makedirs(folder, exist_ok=True)
    gamma = 0.99  # smoothing factor for the running average of the loss
dataset = load_dataset(dataset, split='train')
x0, _ = dataset[0]
c, h, w = x0.size()
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
num_workers=4
)
if resume:
        ae = torch.load('{}/model.th'.format(folder), map_location="cpu")
        ae = ae.to(device)
else:
ae = build_model(model, w=w, h=h, c=c)
ae = ae.to(device)
optim = torch.optim.Adadelta(ae.parameters(), lr=0.1, eps=1e-7, rho=0.95, weight_decay=0)
avg_loss = 0.
nb_updates = 0
_save_weights = partial(save_weights, folder=folder)
for epoch in range(epochs):
for X, y in dataloader:
ae.zero_grad()
X = X.to(device)
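            # anneal sparsity: one fewer active unit per update, floored at 32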
if hasattr(ae, 'nb_active'):
ae.nb_active = max(ae.nb_active - 1, 32)
# walkback + denoise
if walkback:
loss = 0.
x = X.data
nb = 5
for _ in range(nb):
x = salt_and_pepper(x, proba=0.3) # denoise
x = x.to(device)
x = ae(x) # reconstruct
Xr = x
loss += (((x - X) ** 2).view(X.size(0), -1).sum(1).mean()) / nb
x = (torch.rand(x.size()).to(device) <= x.data).float() # sample
# denoise only
elif denoise:
Xc = salt_and_pepper(X.data, proba=0.3)
Xr = ae(Xc)
loss = ((Xr - X) ** 2).view(X.size(0), -1).sum(1).mean()
# normal training
else:
Xr = ae(X)
loss = ((Xr - X) ** 2).view(X.size(0), -1).sum(1).mean()
loss.backward()
optim.step()
avg_loss = avg_loss * gamma + loss.item() * (1 - gamma)
if nb_updates % log_interval == 0:
                print('Epoch: {:05d} AvgTrainLoss: {:.6f} BatchLoss: {:.6f}'.format(epoch, avg_loss, loss.item()))
                gr = grid_of_images_default(Xr.detach().cpu().numpy())
                imsave('{}/rec.png'.format(folder), (gr * 255).astype("uint8"))
ae.apply(_save_weights)
torch.save(ae, '{}/model.th'.format(folder))
nb_updates += 1


def test(*, dataset='mnist', folder='out', model_path=None, nb_iter=25,
         nb_generate=100, nb_active=160, tsne=False):
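    """Generate samples from a trained model by iterative refinement and save
    image grids to `folder`; with `tsne`, also embed real and generated
    samples with t-SNE and save the scatter and grid plots."""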
    os.makedirs(folder, exist_ok=True)
dataset = load_dataset(dataset, split='train')
x0, _ = dataset[0]
c, h, w = x0.size()
nb = nb_generate
print('Load model...')
if model_path is None:
model_path = os.path.join(folder, "model.th")
ae = torch.load(model_path, map_location="cpu")
ae = ae.to(device)
    if hasattr(ae, 'nb_active'):
        ae.nb_active = nb_active  # relevant for k-sparse models (fc_sparse.th) only
def enc(X):
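        """Encode images into flat feature vectors for t-SNE: ConvAE codes are
        max-pooled over a spatial axis and flattened; the other models fall
        back to the raw flattened pixels."""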
batch_size = 64
h_list = []
for i in range(0, X.size(0), batch_size):
x = X[i:i + batch_size]
x = x.to(device)
name = ae.__class__.__name__
if name in ('ConvAE',):
h = ae.encode(x)
h, _ = h.max(2)
h = h.view((h.size(0), -1))
elif name in ('DenseAE',):
x = x.view(x.size(0), -1)
h = x
#h = ae.encode(x)
else:
h = x.view(x.size(0), -1)
h = h.data.cpu()
h_list.append(h)
return torch.cat(h_list, 0)
print('iterative refinement...')
g = iterative_refinement(
ae,
nb_iter=nb_iter,
nb_examples=nb,
w=w, h=h, c=c,
batch_size=64
)
np.savez('{}/generated.npz'.format(folder), X=g.numpy())
    g_subset = g[:, 0:100]
    gr = grid_of_images_default(
        g_subset.reshape((g_subset.shape[0] * g_subset.shape[1], h, w, 1)).numpy(),
        shape=(g_subset.shape[0], g_subset.shape[1])
    )
    imsave('{}/gen_full_iters.png'.format(folder), (gr * 255).astype("uint8"))
g = g[-1] # last iter
print(g.shape)
    gr = grid_of_images_default(g.numpy())
    imsave('{}/gen_full.png'.format(folder), (gr * 255).astype("uint8"))
if tsne:
from sklearn.manifold import TSNE
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=nb,
shuffle=True,
num_workers=1
)
print('Load data...')
X, y = next(iter(dataloader))
print('Encode data...')
xh = enc(X)
print('Encode generated...')
gh = enc(g)
X = X.numpy()
g = g.numpy()
xh = xh.numpy()
gh = gh.numpy()
a = np.concatenate((X, g), axis=0)
ah = np.concatenate((xh, gh), axis=0)
labels = np.array(y.tolist() + [-1] * len(g))
sne = TSNE()
print('fit tsne...')
ah = sne.fit_transform(ah)
print('grid embedding...')
        # 450 real + 450 generated: grid_embedding needs a perfect-square
        # number of points, and 450 + 450 = 900 = 30 * 30
        assert nb_generate >= 450, 'nb_generate must be at least 450'
        asmall = np.concatenate((a[0:450], a[nb:nb + 450]), axis=0)
        ahsmall = np.concatenate((ah[0:450], ah[nb:nb + 450]), axis=0)
rows = grid_embedding(ahsmall)
asmall = asmall[rows]
gr = grid_of_images_default(asmall)
        imsave('{}/sne_grid.png'.format(folder), (gr * 255).astype("uint8"))
fig = plt.figure(figsize=(10, 10))
plot_dataset(ah, labels)
plot_generated(ah, labels)
plt.legend(loc='best')
plt.axis('off')
plt.savefig('{}/sne.png'.format(folder))
plt.close(fig)


if __name__ == '__main__':
run([train, test])