# Spaces: Runtime error (page-status banner captured with this file; not part of the program)
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
class Generator(nn.Module):
    """Fully connected generator that maps latent vectors to images.

    The network is a stack of ``Linear`` layers (128 -> 256 -> 512 -> 1024)
    with ``LeakyReLU(0.2)`` activations and batch normalization on every
    hidden block except the first, followed by a final ``Linear`` + ``Tanh``
    whose output is reshaped to ``image_shape``.
    """

    def __init__(
        self,
        image_shape: Tuple[int, int, int],
        latent_space_dimension: int,
        use_cuda: bool = False,
        saved_model: Optional[str] = None,
    ):
        """Build the network and optionally restore saved weights.

        Args:
            image_shape: Shape of a single generated image; the product of
                its entries is the width of the final linear layer.
            latent_space_dimension: Size of each input latent vector.
            use_cuda: Only selects the ``map_location`` ('cuda' or 'cpu')
                used when loading ``saved_model``. The module itself is not
                moved to any device here.
            saved_model: Path to a state dict previously written by
                :meth:`save`, or ``None`` to keep the fresh initialization.
        """
        super().__init__()
        self.image_shape = image_shape

        def block(in_feat, out_feat, normalize=True):
            """Return the layers of one Linear [-> BatchNorm] -> LeakyReLU block."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # NOTE(review): the second positional argument of
                # BatchNorm1d is `eps`, not `momentum`. It is kept at 0.8
                # (now spelled out) so checkpoints trained with this code
                # behave identically; confirm whether momentum was intended.
                layers.append(nn.BatchNorm1d(out_feat, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_space_dimension, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(image_shape))),
            nn.Tanh(),  # squashes outputs into [-1, 1]
        )

        if saved_model is not None:
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            device = torch.device('cuda' if use_cuda else 'cpu')
            state_dict = torch.load(saved_model, map_location=device)
            self.model.load_state_dict(state_dict)

    def forward(self, z):
        """Generate images from the latent batch ``z``.

        Returns a tensor of shape ``(z.shape[0], *image_shape)``.
        """
        img = self.model(z)
        # Restore the image layout from the flat output of the last layer.
        img = img.view(img.shape[0], *self.image_shape)
        return img

    def save(self, to):
        """Write the state dict of the inner ``model`` to ``to``."""
        torch.save(self.model.state_dict(), to)