|
import os, gzip, torch |
|
import torch.nn as nn |
|
import numpy as np |
|
import scipy.misc |
|
import imageio |
|
import matplotlib.pyplot as plt |
|
from PIL import Image |
|
from torchvision import datasets, transforms |
|
import visdom |
|
import random |
|
|
|
def save_wiggle(images, rows=1, name="test"):
    """Paste a sequence of image tensors into one grid and save it as a JPEG.

    Each entry is mapped through ``(im + 1) / 2`` and converted with
    ``transforms.ToPILImage`` before being pasted.

    :param images: indexable sequence of image tensors, all of the same size.
        Assumed to be channel-first (C, H, W) with values in [-1, 1] --
        TODO confirm against callers.
    :param rows: number of grid rows. The column count is
        ``int(len(images) / rows)``, so images that do not fill a complete
        row are left out.
    :param name: file name (without extension) written under
        ``./WiggleResults/``.
    """
    # ToPILImage reads a tensor as (C, H, W): axis 1 is the height and
    # axis 2 the width. Reading them the other way round only works for
    # square images.
    height = images[0].shape[1]
    width = images[0].shape[2]
    columns = int(len(images) / rows)
    rows = int(rows)
    margin = 4
    border = 3

    canvas = Image.new(
        'RGB', ((width + margin) * columns, (height + margin) * rows))
    to_pil = transforms.ToPILImage()

    for row in range(rows):
        top = border + row * (height + margin)
        for col in range(columns):
            left = border + col * (width + margin)
            tile = to_pil((images[col + row * columns] + 1) / 2)
            canvas.paste(tile, (left, top))

    # Saving fails if the target directory is missing, so create it first.
    os.makedirs('./WiggleResults', exist_ok=True)
    canvas.save('./WiggleResults/' + name + '.jpg')
|
|
|
def load_mnist(dataset):
    """Load the gzip-compressed MNIST-format files stored in ``./data/<dataset>``.

    The 60000 training and 10000 test samples are concatenated, shuffled with
    a fixed seed, scaled to [0, 1] and returned as tensors.

    :param dataset: name of the sub-directory of ``./data`` holding the four
        ``*-ubyte.gz`` files.
    :return: ``(X, y_vec)`` where ``X`` is a FloatTensor of shape
        (70000, 1, 28, 28) and ``y_vec`` is a FloatTensor of shape
        (70000, 10) holding one-hot labels.
    """
    data_dir = os.path.join("./data", dataset)

    def extract_data(filename, num_data, head_size, data_size):
        """Skip ``head_size`` header bytes, then read ``num_data`` records."""
        with gzip.open(filename) as bytestream:
            bytestream.read(head_size)
            buf = bytestream.read(data_size * num_data)
        # ``np.float`` was only an alias of the builtin ``float`` and was
        # removed in NumPy 1.24; the builtin gives the same dtype.
        return np.frombuffer(buf, dtype=np.uint8).astype(float)

    data = extract_data(data_dir + '/train-images-idx3-ubyte.gz', 60000, 16, 28 * 28)
    trX = data.reshape((60000, 28, 28, 1))

    data = extract_data(data_dir + '/train-labels-idx1-ubyte.gz', 60000, 8, 1)
    trY = data.reshape((60000))

    data = extract_data(data_dir + '/t10k-images-idx3-ubyte.gz', 10000, 16, 28 * 28)
    teX = data.reshape((10000, 28, 28, 1))

    data = extract_data(data_dir + '/t10k-labels-idx1-ubyte.gz', 10000, 8, 1)
    teY = data.reshape((10000))

    X = np.concatenate((trX, teX), axis=0)
    # ``np.int`` (removed in NumPy 1.24) was an alias of the builtin ``int``.
    y = np.concatenate((trY, teY), axis=0).astype(int)

    # Re-seeding before each shuffle applies the same permutation to the
    # images and to the labels, keeping them aligned.
    seed = 547
    np.random.seed(seed)
    np.random.shuffle(X)
    np.random.seed(seed)
    np.random.shuffle(y)

    # One-hot encode all labels in a single indexed assignment.
    y_vec = np.zeros((len(y), 10), dtype=float)
    y_vec[np.arange(len(y)), y] = 1

    # (N, H, W, C) -> (N, C, H, W), scaled from [0, 255] to [0, 1].
    X = X.transpose(0, 3, 1, 2) / 255.

    X = torch.from_numpy(X).type(torch.FloatTensor)
    y_vec = torch.from_numpy(y_vec).type(torch.FloatTensor)
    return X, y_vec
|
|
|
def load_celebA(dir, transform, batch_size, shuffle):
    """Wrap an ``ImageFolder`` dataset in a ``DataLoader``.

    :param dir: root directory handed to ``datasets.ImageFolder``.
    :param transform: transform handed to ``datasets.ImageFolder``.
    :param batch_size: batch size handed to the ``DataLoader``.
    :param shuffle: shuffle flag handed to the ``DataLoader``.
    :return: the constructed ``torch.utils.data.DataLoader``.
    """
    return torch.utils.data.DataLoader(
        datasets.ImageFolder(dir, transform),
        batch_size,
        shuffle,
    )
|
|
|
|
|
def print_network(net):
    """Print ``net`` followed by the total element count of its parameters.

    :param net: object exposing ``parameters()`` whose items provide
        ``numel()``.
    """
    total = sum(weights.numel() for weights in net.parameters())
    print(net)
    print('Total number of parameters: %d' % total)
|
|
|
def save_images(images, size, image_path):
    """Tile ``images`` into a ``size`` grid and write it to ``image_path``.

    Delegates to ``imsave`` and hands back whatever that call returns.
    """
    outcome = imsave(images, size, image_path)
    return outcome
|
|
|
def imsave(images, size, path):
    """Merge ``images`` into a ``size`` grid and write the result to ``path``.

    ``scipy.misc.imsave`` was removed in SciPy 1.2. It is still used when the
    installed SciPy provides it; otherwise the image is written with
    ``imageio.imwrite``.

    :param images: image batch accepted by ``merge``.
    :param size: ``(rows, cols)`` grid layout accepted by ``merge``.
    :param path: destination file path.
    :return: the return value of the underlying writer.
    """
    image = np.squeeze(merge(images, size))
    legacy_imsave = getattr(getattr(scipy, 'misc', None), 'imsave', None)
    if legacy_imsave is not None:
        return legacy_imsave(path, image)
    return imageio.imwrite(path, _to_uint8(image))


def _to_uint8(image):
    """Rescale ``image`` so its minimum maps to 0 and its maximum to 255.

    NOTE(review): intended to mirror the min/max byte scaling the legacy
    ``scipy.misc.imsave`` applied before writing -- confirm against the
    SciPy version this project used. ``uint8`` input is returned as is.
    """
    image = np.asarray(image)
    if image.dtype == np.uint8:
        return image
    lowest = image.min()
    span = image.max() - lowest
    if span == 0:
        # Constant image: avoid dividing by zero.
        span = 1
    scaled = (image - lowest) * (255.0 / span)
    return (scaled.clip(0, 255) + 0.5).astype(np.uint8)
|
|
|
def merge(images, size):
    """Tile a batch of equally sized images into a single grid image.

    :param images: batch exposing ``.shape`` and iteration over its first
        axis, laid out as (N, H, W, C) with C in {1, 3, 4}, or as (N, H, W)
        for single-channel images.
    :param size: ``(rows, cols)`` of the grid; image ``idx`` lands in row
        ``idx // cols`` and column ``idx % cols``.
    :return: float NumPy array of shape (H * rows, W * cols, C) for 3- or
        4-channel input, (H * rows, W * cols) for single-channel input.
    :raises ValueError: if ``images`` has none of the supported layouts.
    """
    shape = images.shape
    rank = len(shape)
    channels = shape[3] if rank == 4 else None
    # Validate up front so unsupported layouts raise ValueError instead of
    # an IndexError from probing a missing axis.
    if rank not in (3, 4) or (rank == 4 and channels not in (1, 3, 4)):
        raise ValueError('in merge(images,size) images parameter '
                         'must have dimensions: HxW or HxWx3 or HxWx4')

    h, w = shape[1], shape[2]
    rows, cols = size[0], size[1]
    if channels in (3, 4):
        img = np.zeros((h * rows, w * cols, channels))
    else:
        img = np.zeros((h * rows, w * cols))

    for idx, image in enumerate(images):
        col = idx % cols
        row = idx // cols
        # A trailing singleton channel axis is dropped for the 2-D canvas.
        tile = image[:, :, 0] if channels == 1 else image
        img[row * h:row * h + h, col * w:col * w + w] = tile
    return img
|
|
|
def generate_animation(path, num):
    """Assemble per-epoch PNG frames into an animated GIF.

    Reads ``<path>_epoch0001.png`` through ``<path>_epoch<num>.png`` (epoch
    number zero-padded to four digits) and writes
    ``<path>_generate_animation.gif``.

    :param path: common prefix of the frame files and of the output file.
    :param num: number of epochs/frames to read.
    """
    frames = [
        imageio.imread(path + '_epoch%04d' % epoch + '.png')
        for epoch in range(1, num + 1)
    ]
    imageio.mimsave(path + '_generate_animation.gif', frames, fps=5)
|
|
|
def loss_plot(hist, path = 'Train_hist.png', model_name = ''):
    """Plot the discriminator and generator training losses and save the figure.

    :param hist: mapping with the keys ``'D_loss_train'`` and
        ``'G_loss_train'``, each holding a list of loss values.
    :param path: joined with ``model_name + '_loss.png'`` via
        ``os.path.join`` to form the output file path.
    :param model_name: prefix of the output file name.
    """
    d_loss = hist['D_loss_train']
    g_loss = hist['G_loss_train']
    d_steps = range(len(d_loss))
    g_steps = range(len(g_loss))

    if len(d_steps) != len(g_steps):
        # Left-pad the discriminator curve with zeros and plot it on the
        # generator's x axis.
        d_loss = [0.0] * (len(g_loss) - len(d_loss)) + d_loss
        d_steps = g_steps

    curves = (
        (d_steps, d_loss, 'D_loss_train'),
        (g_steps, g_loss, 'G_loss_train'),
    )
    for steps, values, label in curves:
        plt.plot(steps, values, label=label)

    plt.xlabel('Iter')
    plt.ylabel('Loss')
    plt.legend(loc=4)
    plt.grid(True)
    plt.tight_layout()

    plt.savefig(os.path.join(path, model_name + '_loss.png'))
    plt.close()
|
|
|
def initialize_weights(net):
    """Re-initialise every Conv2d, ConvTranspose2d and Linear layer of ``net``.

    Weights are drawn in place from a normal distribution with mean 0 and
    standard deviation 0.02; biases, where present, are zeroed in place.

    :param net: module whose ``modules()`` are visited.
    """
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            m.weight.data.normal_(0, 0.02)
            # Layers created with ``bias=False`` expose ``bias`` as None.
            if m.bias is not None:
                m.bias.data.zero_()
|
|
|
class VisdomLinePlotter(object):
    """Plots to Visdom"""

    def __init__(self, env_name='main'):
        self.viz = visdom.Visdom()
        self.env = env_name
        # False until the first plot() call has created the window.
        self.ini = False
        # Incremented once per plot() call; sets the x-axis length.
        self.count = 1

    def plot(self, var_name, names, split_name, hist):
        """Draw or extend one line per entry of ``names`` in a shared window.

        The first call creates the window from the first series and appends
        the remaining ones to it. Every later call appends a segment built
        from the last two values of each series.

        :param var_name: first part of the window title.
        :param names: keys of ``hist`` to plot; also used as legend names.
        :param split_name: second part of the window title.
        :param hist: mapping from name to a sequence of values.
        """
        step = self.count
        series = [hist[name] for name in names]
        self.count = step + 1

        # Every series shares the x values 1..step.
        axes = [np.arange(1, step + 1) for _ in names]

        if self.ini:
            last_steps = np.array([axes[0][-2], axes[0][-1]])
            for idx, name in enumerate(names):
                tail = np.array([series[idx][-2], series[idx][-1]])
                self.viz.line(X=last_steps, Y=np.array(tail), env=self.env,
                              win=self.win, name=name, update='append')
            return

        for idx, name in enumerate(names):
            values = np.array(series[idx])
            if idx == 0:
                window_opts = dict(
                    title=var_name + '_' + split_name, showlegend=True)
                self.win = self.viz.line(X=axes[idx], Y=values, env=self.env,
                                         name=name, opts=window_opts)
            else:
                self.viz.line(X=axes[idx], Y=values, env=self.env,
                              win=self.win, name=name, update='append')
        self.ini = True
|
|
|
|
|
class VisdomLineTwoPlotter(VisdomLinePlotter):
    """Plots four named series into one Visdom window with fixed colours."""

    def plot(self, var_name, epoch, names, hist):
        """Draw or extend the four series ``names[0]`` .. ``names[3]``.

        The first call creates the window and plots each full series at
        ``epoch``. Later calls append a segment from ``epoch - 1`` to
        ``epoch`` built from the last two values of each series.

        :param var_name: window title.
        :param epoch: current x position.
        :param names: sequence with at least four keys of ``hist``.
        :param hist: mapping from name to a sequence of values.
        """
        curves = [hist[names[k]] for k in range(4)]

        if self.ini:
            tails = [np.array([curve[-2], curve[-1]]) for curve in curves]
            span = np.array([epoch - 1, epoch])
            for k, tail in enumerate(tails):
                self.viz.line(X=span, Y=np.array(tail), env=self.env,
                              win=self.win, name=names[k], update='append')
            return

        # One RGB colour per series, in the order of ``names``.
        palette = ([0, 0, 255], [255, 153, 51], [0, 51, 153], [204, 51, 0])
        origin = np.array([epoch])

        self.win = self.viz.line(
            X=origin, Y=np.array(curves[0]), env=self.env, name=names[0],
            opts=dict(
                title=var_name,
                showlegend=True,
                linecolor=np.array([palette[0]]),
            ))
        for k in range(1, 4):
            self.viz.line(
                X=np.array([epoch]), Y=np.array(curves[k]), env=self.env,
                win=self.win, name=names[k], update='append',
                opts=dict(linecolor=np.array([palette[k]])))
        self.ini = True
|
|
|
class VisdomImagePlotter(object):
    """Plots to Visdom"""

    def __init__(self, env_name='main'):
        self.viz = visdom.Visdom()
        self.env = env_name

    def plot(self, epoch, images, rows):
        """Send a batch of images to Visdom as one grid titled with the epoch.

        Each image is mapped through ``(image + 1) / 2`` and multiplied by
        255 before being sent.

        :param epoch: value shown in the window title.
        :param images: iterable of tensors. Assumed to hold values in
            [-1, 1] -- TODO confirm against callers.
        :param rows: forwarded to ``viz.images`` as ``nrow``.
        """
        # ``.cpu()`` is a no-op for CPU tensors and makes ``.numpy()`` work
        # for tensors living on another device.
        list_images = [
            ((image + 1) / 2).detach().cpu().numpy() * 255
            for image in images
        ]
        self.viz.images(
            list_images,
            padding=2,
            nrow=rows,
            opts=dict(title="epoch: " + str(epoch)),
            env=self.env
        )
|
|
|
|
|
def augmentData(x,y, randomness = 1, percent_noise = 0.1):
    """
    Apply the same random colour jitter to each pair of images in x and y.

    For every pair, one set of brightness, saturation, contrast, gamma and
    hue factors is drawn and applied to both images.

    :param x: iterable of image tensors X (mapped through (v + 1) / 2 before
        conversion to PIL; presumably in [-1, 1] -- verify against caller)
    :param y: iterable of image tensors Y, paired with x
    :param randomness: Value of randomness (between 1 and 0); 0 gives
        identity factors
    :param percent_noise: accepted for compatibility; not used by this
        function
    :return: data x,y augmented, each stacked along a new first axis and
        mapped back through v * 2 - 1; two empty tensors if there are no
        pairs
    """
    augmented_x = []
    augmented_y = []

    for aumX, aumY in zip(x, y):
        imgX = transforms.ToPILImage()((aumX + 1) / 2)
        imgY = transforms.ToPILImage()((aumY + 1) / 2)

        # The draw order is kept fixed so seeded runs stay reproducible.
        brightness = random.uniform(0.7, 1.2) * randomness + (1 - randomness)
        saturation = random.uniform(0, 2) * randomness + (1 - randomness)
        contrast = random.uniform(0.4, 2) * randomness + (1 - randomness)
        gamma = random.uniform(0.7, 1.3) * randomness + (1 - randomness)
        hue = random.uniform(-0.3, 0.3) * randomness

        imgX = _jitter_colors(imgX, gamma, brightness, contrast, saturation, hue)
        imgY = _jitter_colors(imgY, gamma, brightness, contrast, saturation, hue)

        augmented_x.append(transforms.ToTensor()(imgX) * 2 - 1)
        augmented_y.append(transforms.ToTensor()(imgY) * 2 - 1)

    if not augmented_x:
        return torch.tensor([]), torch.tensor([])

    # Stack once at the end instead of re-concatenating on every iteration.
    return torch.stack(augmented_x, 0), torch.stack(augmented_y, 0)


def _jitter_colors(img, gamma, brightness, contrast, saturation, hue):
    """Apply gamma, brightness, contrast, saturation and hue, in that order."""
    adjust = transforms.functional
    img = adjust.adjust_gamma(img, gamma)
    img = adjust.adjust_brightness(img, brightness)
    img = adjust.adjust_contrast(img, contrast)
    img = adjust.adjust_saturation(img, saturation)
    img = adjust.adjust_hue(img, hue)
    return img
|
|
|
def RGBtoL (x):
    """Keep only channel 0 of a 4-D batch, as a (N, 1, H, W) tensor.

    :param x: tensor indexed as ``x[:, 0, :, :]``, i.e. with four axes.
    :return: tensor holding the first channel with a singleton channel axis.
    """
    first_channel = x[:, 0, :, :]
    return first_channel.unsqueeze(1)
|
|
|
def LtoRGB (x):
    """Repeat the input three times along axis 1.

    :param x: tensor passed to ``repeat`` with the pattern (1, 3, 1, 1).
    :return: tensor whose axis 1 is three times as long as the input's.
    """
    repeats = (1, 3, 1, 1)
    return x.repeat(*repeats)
|
|