meeww commited on
Commit
146cf41
1 Parent(s): 09fccfd

Upload generator.py

Browse files
Files changed (1) hide show
  1. generator.py +48 -0
generator.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ import torch.nn as nn
4
+ import torch
5
+
6
+
7
+ class Generator(nn.Module):
8
+ def __init__(
9
+ self,
10
+ image_shape: (int, int, int),
11
+ latent_space_dimension: int,
12
+ use_cuda: bool = False,
13
+ saved_model: str or None = None
14
+ ):
15
+ super(Generator, self).__init__()
16
+
17
+ self.image_shape = image_shape
18
+
19
+ def block(in_feat, out_feat, normalize=True):
20
+ layers = [nn.Linear(in_feat, out_feat)]
21
+ if normalize:
22
+ layers.append(nn.BatchNorm1d(out_feat, 0.8))
23
+ layers.append(nn.LeakyReLU(0.2, inplace=True))
24
+ return layers
25
+
26
+ self.model = nn.Sequential(
27
+ *block(latent_space_dimension, 128, normalize=False),
28
+ *block(128, 256),
29
+ *block(256, 512),
30
+ *block(512, 1024),
31
+ nn.Linear(1024, int(np.prod(image_shape))),
32
+ nn.Tanh()
33
+ )
34
+ if saved_model is not None:
35
+ self.model.load_state_dict(
36
+ torch.load(
37
+ saved_model,
38
+ map_location=torch.device('cuda' if use_cuda else 'cpu')
39
+ )
40
+ )
41
+
42
+ def forward(self, z):
43
+ img = self.model(z)
44
+ img = img.view(img.shape[0], *self.image_shape)
45
+ return img
46
+
47
+ def save(self, to):
48
+ torch.save(self.model.state_dict(), to)