opetrova commited on
Commit
ebcdb64
1 Parent(s): 6d89473

Upload network.py

Browse files
Files changed (1) hide show
  1. network.py +101 -0
network.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.parallel
4
+ import torch.optim as optim
5
+ from torch.autograd import Variable
6
+
7
+
8
+ def weights_init(m):
9
+ classname = m.__class__.__name__
10
+
11
+ if classname.find('Conv') != -1:
12
+ m.weight.data.normal_(0.0, 0.02)
13
+
14
+ elif classname.find('BatchNorm') != -1:
15
+ m.weight.data.normal_(1.0, 0.02)
16
+ m.bias.data.fill_(0)
17
+
18
+
19
+ ''' Generator network for 128x128 RGB images '''
20
+ class G(nn.Module):
21
+
22
+ def __init__(self):
23
+ super(G, self).__init__()
24
+
25
+ self.main = nn.Sequential(
26
+ # Input HxW = 128x128
27
+ nn.Conv2d(3, 16, 4, 2, 1), # Output HxW = 64x64
28
+ nn.BatchNorm2d(16),
29
+ nn.ReLU(True),
30
+ nn.Conv2d(16, 32, 4, 2, 1), # Output HxW = 32x32
31
+ nn.BatchNorm2d(32),
32
+ nn.ReLU(True),
33
+ nn.Conv2d(32, 64, 4, 2, 1), # Output HxW = 16x16
34
+ nn.BatchNorm2d(64),
35
+ nn.ReLU(True),
36
+ nn.Conv2d(64, 128, 4, 2, 1), # Output HxW = 8x8
37
+ nn.BatchNorm2d(128),
38
+ nn.ReLU(True),
39
+ nn.Conv2d(128, 256, 4, 2, 1), # Output HxW = 4x4
40
+ nn.BatchNorm2d(256),
41
+ nn.ReLU(True),
42
+ nn.Conv2d(256, 512, 4, 2, 1), # Output HxW = 2x2
43
+ nn.MaxPool2d((2,2)),
44
+ # At this point, we arrive at our low D representation vector, which is 512 dimensional.
45
+
46
+ nn.ConvTranspose2d(512, 256, 4, 1, 0, bias = False), # Output HxW = 4x4
47
+ nn.BatchNorm2d(256),
48
+ nn.ReLU(True),
49
+ nn.ConvTranspose2d(256, 128, 4, 2, 1, bias = False), # Output HxW = 8x8
50
+ nn.BatchNorm2d(128),
51
+ nn.ReLU(True),
52
+ nn.ConvTranspose2d(128, 64, 4, 2, 1, bias = False), # Output HxW = 16x16
53
+ nn.BatchNorm2d(64),
54
+ nn.ReLU(True),
55
+ nn.ConvTranspose2d(64, 32, 4, 2, 1, bias = False), # Output HxW = 32x32
56
+ nn.BatchNorm2d(32),
57
+ nn.ReLU(True),
58
+ nn.ConvTranspose2d(32, 16, 4, 2, 1, bias = False), # Output HxW = 64x64
59
+ nn.BatchNorm2d(16),
60
+ nn.ReLU(True),
61
+ nn.ConvTranspose2d(16, 3, 4, 2, 1, bias = False), # Output HxW = 128x128
62
+ nn.Tanh()
63
+ )
64
+
65
+
66
+ def forward(self, input):
67
+ output = self.main(input)
68
+ return output
69
+
70
+
71
+ ''' Discriminator network for 128x128 RGB images '''
72
+ class D(nn.Module):
73
+
74
+ def __init__(self):
75
+ super(D, self).__init__()
76
+ self.main = nn.Sequential(
77
+ nn.Conv2d(3, 16, 4, 2, 1),
78
+ nn.LeakyReLU(0.2, inplace = True),
79
+ nn.Conv2d(16, 32, 4, 2, 1),
80
+ nn.BatchNorm2d(32),
81
+ nn.LeakyReLU(0.2, inplace = True),
82
+ nn.Conv2d(32, 64, 4, 2, 1),
83
+ nn.BatchNorm2d(64),
84
+ nn.LeakyReLU(0.2, inplace = True),
85
+ nn.Conv2d(64, 128, 4, 2, 1),
86
+ nn.BatchNorm2d(128),
87
+ nn.LeakyReLU(0.2, inplace = True),
88
+ nn.Conv2d(128, 256, 4, 2, 1),
89
+ nn.BatchNorm2d(256),
90
+ nn.LeakyReLU(0.2, inplace = True),
91
+ nn.Conv2d(256, 512, 4, 2, 1),
92
+ nn.BatchNorm2d(512),
93
+ nn.LeakyReLU(0.2, inplace = True),
94
+ nn.Conv2d(512, 1, 4, 2, 1, bias = False),
95
+ nn.Sigmoid()
96
+ )
97
+
98
+
99
+ def forward(self, input):
100
+ output = self.main(input)
101
+ return output.view(-1)