# WiggleGAN / utils.py
import os, gzip, torch
import torch.nn as nn
import numpy as np
import imageio  # replaces the removed scipy.misc.imsave for writing images
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import datasets, transforms
import visdom
import random
def save_wiggle(images, rows=1, name="test"):
    # images are CHW tensors, so shape[1] is the height and shape[2] the width
    height = images[0].shape[1]
    width = images[0].shape[2]
columns = int(len(images)/rows)
rows = int(rows)
margin = 4
total_width = (width + margin) * columns
total_height = (height + margin) * rows
new_im = Image.new('RGB', (total_width, total_height))
transToPil = transforms.ToPILImage()
x_offset = 3
y_offset = 3
for y in range(rows):
for x in range(columns):
im = images[x+y*columns]
im = transToPil((im+1)/2)
new_im.paste(im, (x_offset, y_offset))
x_offset += width + margin
x_offset = 3
y_offset += height + margin
    os.makedirs('./WiggleResults', exist_ok=True)
    new_im.save('./WiggleResults/' + name + '.jpg')
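# Usage sketch (hypothetical): assuming `outputs` is a list of CHW tensors in [-1, 1]
# produced by the generator, a two-row strip is written to ./WiggleResults/:
#   save_wiggle(outputs, rows=2, name="epoch_010")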
def load_mnist(dataset):
data_dir = os.path.join("./data", dataset)
def extract_data(filename, num_data, head_size, data_size):
with gzip.open(filename) as bytestream:
bytestream.read(head_size)
buf = bytestream.read(data_size * num_data)
            data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
return data
data = extract_data(data_dir + '/train-images-idx3-ubyte.gz', 60000, 16, 28 * 28)
trX = data.reshape((60000, 28, 28, 1))
data = extract_data(data_dir + '/train-labels-idx1-ubyte.gz', 60000, 8, 1)
trY = data.reshape((60000))
data = extract_data(data_dir + '/t10k-images-idx3-ubyte.gz', 10000, 16, 28 * 28)
teX = data.reshape((10000, 28, 28, 1))
data = extract_data(data_dir + '/t10k-labels-idx1-ubyte.gz', 10000, 8, 1)
teY = data.reshape((10000))
    trY = np.asarray(trY).astype(np.int64)
    teY = np.asarray(teY)
    X = np.concatenate((trX, teX), axis=0)
    y = np.concatenate((trY, teY), axis=0).astype(np.int64)
seed = 547
np.random.seed(seed)
np.random.shuffle(X)
np.random.seed(seed)
np.random.shuffle(y)
    y_vec = np.zeros((len(y), 10), dtype=np.float32)
for i, label in enumerate(y):
y_vec[i, y[i]] = 1
X = X.transpose(0, 3, 1, 2) / 255.
# y_vec = y_vec.transpose(0, 3, 1, 2)
X = torch.from_numpy(X).type(torch.FloatTensor)
y_vec = torch.from_numpy(y_vec).type(torch.FloatTensor)
return X, y_vec
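# Usage sketch (hypothetical): assuming the gzipped MNIST files live under ./data/mnist,
# this returns a [70000, 1, 28, 28] float tensor and matching one-hot label tensor:
#   X, y = load_mnist('mnist')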
def load_celebA(dir, transform, batch_size, shuffle):
# transform = transforms.Compose([
# transforms.CenterCrop(160),
# transform.Scale(64)
# transforms.ToTensor(),
# transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
# ])
# data_dir = 'data/celebA' # this path depends on your computer
dset = datasets.ImageFolder(dir, transform)
data_loader = torch.utils.data.DataLoader(dset, batch_size, shuffle)
return data_loader
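# Usage sketch (hypothetical): the transform mirrors the commented-out pipeline above;
# 'data/celebA' is a placeholder path containing one sub-folder per class.
#   transform = transforms.Compose([
#       transforms.CenterCrop(160),
#       transforms.Resize(64),
#       transforms.ToTensor(),
#       transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
#   loader = load_celebA('data/celebA', transform, batch_size=64, shuffle=True)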
def print_network(net):
num_params = 0
for param in net.parameters():
num_params += param.numel()
print(net)
print('Total number of parameters: %d' % num_params)
def save_images(images, size, image_path):
return imsave(images, size, image_path)
def imsave(images, size, path):
    image = np.squeeze(merge(images, size))
    # scipy.misc.imsave was removed from SciPy; rescale to uint8 (as imsave used to do)
    # and write with imageio instead
    image = ((image - image.min()) / max(image.max() - image.min(), 1e-8) * 255).astype(np.uint8)
    return imageio.imwrite(path, image)
def merge(images, size):
#print ("shape", images.shape)
h, w = images.shape[1], images.shape[2]
if (images.shape[3] in (3,4)):
c = images.shape[3]
img = np.zeros((h * size[0], w * size[1], c))
for idx, image in enumerate(images):
i = idx % size[1]
j = idx // size[1]
img[j * h:j * h + h, i * w:i * w + w, :] = image
return img
elif images.shape[3]== 1:
img = np.zeros((h * size[0], w * size[1]))
for idx, image in enumerate(images):
#print("indez ",idx)
i = idx % size[1]
j = idx // size[1]
img[j * h:j * h + h, i * w:i * w + w] = image[:,:,0]
return img
    else:
        raise ValueError('in merge(images, size) each image must have dimensions '
                         'HxWx1, HxWx3 or HxWx4')
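# Usage sketch (hypothetical): assuming `samples` is an [N, H, W, C] numpy array of
# N = 64 generated images, this tiles them into an 8x8 grid and writes a single file:
#   save_images(samples, size=[8, 8], image_path='results/samples.png')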
def generate_animation(path, num):
images = []
for e in range(num):
img_name = path + '_epoch%04d' % (e+1) + '.png'
images.append(imageio.imread(img_name))
imageio.mimsave(path + '_generate_animation.gif', images, fps=5)
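# Usage sketch (hypothetical): assuming per-epoch frames were saved as
# results/run_epoch0001.png ... results/run_epoch0025.png, this stitches them into a GIF:
#   generate_animation('results/run', 25)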
def loss_plot(hist, path = 'Train_hist.png', model_name = ''):
x1 = range(len(hist['D_loss_train']))
x2 = range(len(hist['G_loss_train']))
y1 = hist['D_loss_train']
y2 = hist['G_loss_train']
    if len(y1) != len(y2):
        # pad the shorter series with zeros so both curves share the same x axis
        y1 = [0.0] * (len(y2) - len(y1)) + y1
        x1 = x2
plt.plot(x1, y1, label='D_loss_train')
plt.plot(x2, y2, label='G_loss_train')
plt.xlabel('Iter')
plt.ylabel('Loss')
plt.legend(loc=4)
plt.grid(True)
plt.tight_layout()
path = os.path.join(path, model_name + '_loss.png')
plt.savefig(path)
plt.close()
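# Usage sketch (hypothetical): assuming `train_hist` holds per-iteration lists under the
# keys 'D_loss_train' and 'G_loss_train', this writes <path>/<model_name>_loss.png:
#   loss_plot(train_hist, path='results', model_name='WiggleGAN')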
def initialize_weights(net):
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            m.weight.data.normal_(0, 0.02)
            # layers built with bias=False have no bias tensor to zero
            if m.bias is not None:
                m.bias.data.zero_()
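# Usage sketch (hypothetical): typically called once on freshly constructed networks,
# e.g. assuming `G` and `D` are the generator and discriminator modules:
#   initialize_weights(G)
#   initialize_weights(D)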
class VisdomLinePlotter(object):
"""Plots to Visdom"""
def __init__(self, env_name='main'):
self.viz = visdom.Visdom()
self.env = env_name
self.ini = False
self.count = 1
def plot(self, var_name,names, split_name, hist):
x = []
y = []
for i, name in enumerate(names):
x.append(self.count)
y.append(hist[name])
self.count+=1
#x1 = (len(hist['D_loss_' +split_name]))
#x2 = (len(hist['G_loss_' +split_name]))
#y1 = hist['D_loss_'+split_name]
#y2 = hist['G_loss_'+split_name]
np.array(x)
for i,n in enumerate(names):
x[i] = np.arange(1, x[i]+1)
if not self.ini:
for i, name in enumerate(names):
if i == 0:
self.win = self.viz.line(X=x[i], Y=np.array(y[i]), env=self.env,name = name,opts=dict(
title=var_name + '_'+split_name, showlegend = True
))
else:
self.viz.line(X=x[i], Y=np.array(y[i]), env=self.env,win=self.win, name=name, update='append')
self.ini = True
else:
x[0] = np.array([x[0][-2], x[0][-1]])
for i,n in enumerate(names):
y[i] = np.array([y[i][-2], y[i][-1]])
self.viz.line(X=x[0], Y=np.array(y[i]), env=self.env, win=self.win, name=n, update='append')
class VisdomLineTwoPlotter(VisdomLinePlotter):
def plot(self, var_name, epoch,names, hist):
x1 = epoch
y1 = hist[names[0]]
y2 = hist[names[1]]
y3 = hist[names[2]]
y4 = hist[names[3]]
#y1 = hist['D_loss_' + split_name]
#y2 = hist['G_loss_' + split_name]
#y3 = hist['D_loss_' + split_name2]
#y4 = hist['G_loss_' + split_name2]
#x1 = np.arange(1, x1+1)
if not self.ini:
self.win = self.viz.line(X=np.array([x1]), Y=np.array(y1), env=self.env,name = names[0],opts=dict(
title=var_name,
showlegend = True,
linecolor = np.array([[0, 0, 255]])
))
self.viz.line(X=np.array([x1]), Y=np.array(y2), env=self.env,win=self.win, name=names[1],
update='append', opts=dict(
linecolor=np.array([[255, 153, 51]])
))
self.viz.line(X=np.array([x1]), Y=np.array(y3), env=self.env, win=self.win, name=names[2],
update='append', opts=dict(
linecolor=np.array([[0, 51, 153]])
))
self.viz.line(X=np.array([x1]), Y=np.array(y4), env=self.env, win=self.win, name=names[3],
update='append', opts=dict(
linecolor=np.array([[204, 51, 0]])
))
self.ini = True
else:
y4 = np.array([y4[-2], y4[-1]])
y3 = np.array([y3[-2], y3[-1]])
y2 = np.array([y2[-2], y2[-1]])
y1 = np.array([y1[-2], y1[-1]])
x1 = np.array([x1 - 1, x1])
self.viz.line(X=x1, Y=np.array(y1), env=self.env, win=self.win, name=names[0], update='append')
self.viz.line(X=x1, Y=np.array(y2), env=self.env, win=self.win, name=names[1], update='append')
self.viz.line(X=x1, Y=np.array(y3), env=self.env, win=self.win, name=names[2],
update='append')
self.viz.line(X=x1, Y=np.array(y4), env=self.env, win=self.win, name=names[3],
update='append')
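# Usage sketch (hypothetical): assuming a Visdom server is running locally and `hist`
# maps each of the four curve names to a list of values, one call per epoch appends points:
#   plotter = VisdomLineTwoPlotter(env_name='wiggle')
#   plotter.plot('loss', epoch, ['D_loss_train', 'G_loss_train', 'D_loss_test', 'G_loss_test'], hist)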
class VisdomImagePlotter(object):
"""Plots to Visdom"""
def __init__(self, env_name='main'):
self.viz = visdom.Visdom()
self.env = env_name
def plot(self, epoch,images,rows):
list_images = []
for image in images:
#transforms.ToPILImage()(image)
image = (image + 1)/2
            # move to CPU before converting, in case the tensor lives on the GPU
            image = image.detach().cpu().numpy() * 255
list_images.append(image)
self.viz.images(
list_images,
padding=2,
nrow =rows,
opts=dict(title="epoch: " + str(epoch)),
env=self.env
)
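# Usage sketch (hypothetical): assuming `fake_batch` is a batch of CHW tensors in [-1, 1],
# this pushes them to the Visdom environment as a grid with `rows` images per row:
#   image_plotter = VisdomImagePlotter(env_name='wiggle')
#   image_plotter.plot(epoch, fake_batch, rows=4)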
def augmentData(x,y, randomness = 1, percent_noise = 0.1):
"""
:param x: image X
:param y: image Y
:param randomness: Value of randomness (between 1 and 0)
:return: data x,y augmented
"""
sampleX = torch.tensor([])
sampleY = torch.tensor([])
for aumX, aumY in zip(x,y):
# Preparing to get image # transforms.ToPILImage()(pil_to_tensor.squeeze_(0))
#percent_noise = percent_noise
#noise = torch.randn(aumX.shape)
#aumX = noise * percent_noise + aumX * (1 - percent_noise)
#aumY = noise * percent_noise + aumY * (1 - percent_noise)
aumX = (aumX + 1) / 2
aumY = (aumY + 1) / 2
imgX = transforms.ToPILImage()(aumX)
imgY = transforms.ToPILImage()(aumY)
        # Values for augmentation #
        brightness = random.uniform(0.7, 1.2) * randomness + (1 - randomness)
        saturation = random.uniform(0, 2) * randomness + (1 - randomness)
        contrast = random.uniform(0.4, 2) * randomness + (1 - randomness)
        gamma = random.uniform(0.7, 1.3) * randomness + (1 - randomness)
        hue = random.uniform(-0.3, 0.3) * randomness
imgX = transforms.functional.adjust_gamma(imgX, gamma)
        imgX = transforms.functional.adjust_brightness(imgX, brightness)
imgX = transforms.functional.adjust_contrast(imgX, contrast)
imgX = transforms.functional.adjust_saturation(imgX, saturation)
imgX = transforms.functional.adjust_hue(imgX, hue)
#imgX.show()
imgY = transforms.functional.adjust_gamma(imgY, gamma)
        imgY = transforms.functional.adjust_brightness(imgY, brightness)
imgY = transforms.functional.adjust_contrast(imgY, contrast)
imgY = transforms.functional.adjust_saturation(imgY, saturation)
imgY = transforms.functional.adjust_hue(imgY, hue)
#imgY.show()
sx = transforms.ToTensor()(imgX)
sx = (sx * 2)-1
sy = transforms.ToTensor()(imgY)
sy = (sy * 2)-1
sampleX = torch.cat((sampleX, sx.unsqueeze_(0)), 0)
sampleY = torch.cat((sampleY, sy.unsqueeze_(0)), 0)
return sampleX,sampleY
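# Usage sketch (hypothetical): assuming `x_batch` and `y_batch` are NxCxHxW tensors in
# [-1, 1], this returns colour-jittered copies with the same jitter applied to each pair:
#   x_aug, y_aug = augmentData(x_batch, y_batch, randomness=0.5)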
def RGBtoL(x):
    # keep only the first channel of an NxCxHxW batch: NxCxHxW -> Nx1xHxW
    return x[:, 0, :, :].unsqueeze(0).transpose(0, 1)

def LtoRGB(x):
    # replicate a single-channel batch across three channels: Nx1xHxW -> Nx3xHxW
    return x.repeat(1, 3, 1, 1)
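# Shape note (hypothetical example): for an RGB batch `imgs` of shape [N, 3, H, W],
# RGBtoL(imgs) returns [N, 1, H, W] (first channel only) and LtoRGB(RGBtoL(imgs))
# returns [N, 3, H, W] with that channel replicated.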