File size: 3,137 Bytes
8895b4f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
def get_noise(n_samples, z_dim, device='cpu'):
    """Sample a batch of latent noise vectors.

    Args:
        n_samples: number of noise vectors to draw (rows of the result).
        z_dim: length of each noise vector (columns of the result).
        device: device on which the tensor is created; defaults to 'cpu'.

    Returns:
        A tensor of shape (n_samples, z_dim) filled by ``torch.randn``,
        i.e. independent standard-normal values.
    """
    noise = torch.randn(n_samples, z_dim, device=device)
    return noise
def get_random_labels(n_samples, device='cpu', n_classes=10):
    """Draw uniformly random integer class labels.

    Args:
        n_samples: number of labels to draw.
        device: device on which the tensor is created; defaults to 'cpu'.
        n_classes: number of distinct classes; every label lies in
            ``[0, n_classes)``. Defaults to 10, the value that used to be
            hard-coded, so existing callers are unaffected.

    Returns:
        A 1-D ``torch.long`` tensor of shape (n_samples,).
    """
    # Request int64 up front rather than converting the result afterwards.
    return torch.randint(
        0, n_classes, (n_samples,), dtype=torch.long, device=device
    )
def get_generator_block(input_dim, output_dim):
    """Assemble one hidden stage of the generator MLP.

    The stage applies, in order: a linear projection from ``input_dim`` to
    ``output_dim`` features, 1-D batch normalization over those
    ``output_dim`` features, and an in-place ReLU.

    Args:
        input_dim: number of input features.
        output_dim: number of output features.

    Returns:
        An ``nn.Sequential`` holding the three layers at indices 0, 1, 2.
    """
    layers = [
        nn.Linear(input_dim, output_dim),
        nn.BatchNorm1d(output_dim),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)
class Generator(nn.Module):
    """Class-conditional MLP generator.

    The network input is a noise vector concatenated with the one-hot
    encoding of the requested class, so its width is ``z_dim + n_classes``.
    Four generator blocks widen the features to ``hidden_dim``,
    ``2 * hidden_dim``, ``4 * hidden_dim`` and ``8 * hidden_dim``; a final
    linear layer maps to ``im_dim`` values and a sigmoid squashes them
    into (0, 1).

    Args:
        z_dim: length of the noise vector.
        im_dim: number of values in each generated (flattened) image.
        hidden_dim: base width of the hidden layers.
        n_classes: number of condition classes, i.e. the length of the
            one-hot label vector. Defaults to 10, the value that used to be
            hard-coded.
    """

    def __init__(self, z_dim=10, im_dim=784, hidden_dim=128, n_classes=10):
        super().__init__()
        # Remembered so forward() one-hot encodes with the same width the
        # first layer was built for.
        self.n_classes = n_classes
        # Input to the stack has shape (batch_size, z_dim + n_classes).
        self.gen = nn.Sequential(
            get_generator_block(z_dim + n_classes, hidden_dim),
            get_generator_block(hidden_dim, hidden_dim * 2),
            get_generator_block(hidden_dim * 2, hidden_dim * 4),
            get_generator_block(hidden_dim * 4, hidden_dim * 8),
            nn.Linear(hidden_dim * 8, im_dim),
            nn.Sigmoid(),  # output between 0 and 1
        )

    def forward(self, noise, classes):
        """Generate one flattened image per (noise, class) pair.

        Args:
            noise: tensor of shape (batch_size, z_dim), one noise vector per
                image in the batch.
            classes: ``torch.long`` tensor of shape (batch_size,), the
                condition class for each image in the batch.

        Returns:
            The output of the layer stack for the conditioned noise; with
            the default stack this has shape (batch_size, im_dim).
        """
        # One-hot encode the condition class, e.g. with 10 classes
        # 3 -> [0,0,0,1,0,0,0,0,0,0]; result is (batch_size, n_classes).
        one_hot_vec = F.one_hot(classes, num_classes=self.n_classes).float()
        # (batch_size, z_dim + n_classes)
        conditioned_noise = torch.cat((noise, one_hot_vec), dim=1)
        return self.gen(conditioned_noise)
def get_discriminator_block(input_dim, output_dim):
    """Assemble one hidden stage of the discriminator MLP.

    The stage is a linear projection from ``input_dim`` to ``output_dim``
    features followed by an in-place LeakyReLU with negative slope 0.2.

    Args:
        input_dim: number of input features.
        output_dim: number of output features.

    Returns:
        An ``nn.Sequential`` holding the two layers at indices 0 and 1.
    """
    projection = nn.Linear(input_dim, output_dim)
    activation = nn.LeakyReLU(negative_slope=0.2, inplace=True)
    return nn.Sequential(projection, activation)
class Discriminator(nn.Module):
    """MLP discriminator that emits one raw logit per input row.

    The first layer expects ``im_dim + n_classes`` input features
    (presumably a flattened image concatenated with its one-hot class
    label; the concatenation is done by the caller, confirm there). Three
    discriminator blocks narrow the features to ``4 * hidden_dim``,
    ``2 * hidden_dim`` and ``hidden_dim``; a final linear layer produces a
    single value.

    Args:
        im_dim: number of image values in each input row.
        hidden_dim: base width of the hidden layers.
        n_classes: number of extra conditioning features appended to each
            image. Defaults to 10, the value that used to be hard-coded.
    """

    def __init__(self, im_dim=784, hidden_dim=128, n_classes=10):
        super().__init__()
        self.disc = nn.Sequential(
            get_discriminator_block(im_dim + n_classes, hidden_dim * 4),
            get_discriminator_block(hidden_dim * 4, hidden_dim * 2),
            get_discriminator_block(hidden_dim * 2, hidden_dim),
            nn.Linear(hidden_dim, 1),
            # Deliberately no nn.Sigmoid() here: a sigmoid followed by
            # BCELoss is less numerically stable than feeding raw logits to
            # BCEWithLogitsLoss, which fuses both steps (see the PyTorch
            # BCEWithLogitsLoss documentation).
        )

    def forward(self, image_batch):
        """Score a batch of conditioned images.

        Args:
            image_batch: tensor of shape (batch_size, im_dim + n_classes).

        Returns:
            The output of the layer stack; with the default stack this is
            one unnormalized logit per row, shape (batch_size, 1).
        """
        return self.disc(image_batch)