|
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
|
|
|
|
|
class Generator(nn.Module):
    """Fully connected generator that maps latent vectors to images.

    The network is a stack of ``Linear`` layers
    (latent -> 128 -> 256 -> 512 -> 1024 -> prod(image_shape)) with
    ``LeakyReLU(0.2)`` activations, ``BatchNorm1d`` on every hidden block
    except the first, and a final ``Tanh`` so outputs lie in ``[-1, 1]``.
    """

    def __init__(
        self,
        image_shape: Tuple[int, int, int],
        latent_space_dimension: int,
        use_cuda: bool = False,
        saved_model: Optional[str] = None,
    ):
        """Build the network and optionally restore saved weights.

        Args:
            image_shape: Shape of one generated image; the flat network
                output is reshaped to ``(batch, *image_shape)``.
            latent_space_dimension: Size of the input latent vector.
            use_cuda: Used only as the ``map_location`` ('cuda' or 'cpu')
                when ``saved_model`` is loaded. This constructor does not
                move the module to the GPU.
            saved_model: Path to a state dict previously written by
                :meth:`save`, or ``None`` to keep the fresh initialisation.
        """
        super().__init__()

        self.image_shape = image_shape

        def block(in_feat, out_feat, normalize=True):
            """Return the layers of one Linear (+BatchNorm) + LeakyReLU block."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # NOTE(review): the second positional argument of BatchNorm1d
                # is `eps`, not `momentum`, so 0.8 has always been the epsilon
                # here. It is kept (now spelled out) because existing
                # checkpoints were trained with it; confirm whether
                # `momentum=0.8` was the real intent before changing it.
                layers.append(nn.BatchNorm1d(out_feat, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # The layer order determines the state-dict keys ("0.weight", ...),
        # so it must stay stable for saved checkpoints to keep loading.
        self.model = nn.Sequential(
            *block(latent_space_dimension, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(image_shape))),
            nn.Tanh(),
        )

        if saved_model is not None:
            # SECURITY: torch.load unpickles the file, so only load
            # checkpoints from trusted sources (consider `weights_only=True`
            # on PyTorch versions that support it).
            device = torch.device('cuda' if use_cuda else 'cpu')
            state_dict = torch.load(saved_model, map_location=device)
            self.model.load_state_dict(state_dict)

    def forward(self, z):
        """Generate images from latent vectors.

        Args:
            z: Latent tensor, expected shape
                ``(batch, latent_space_dimension)``.

        Returns:
            Tensor of shape ``(batch, *image_shape)`` with values in
            ``[-1, 1]`` (Tanh output).
        """
        img = self.model(z)
        # Reshape the flat network output to one image per batch element.
        img = img.view(img.shape[0], *self.image_shape)
        return img

    def save(self, to):
        """Write the state dict of the inner ``Sequential`` model to ``to``.

        Only ``self.model``'s parameters and buffers are stored, which is the
        format ``__init__`` expects for ``saved_model``.
        """
        torch.save(self.model.state_dict(), to)
|
|