ldkong commited on
Commit
a797152
β€’
1 Parent(s): 52b383d

Upload dcgan_64.py

Browse files
Files changed (1) hide show
  1. dcgan_64.py +140 -0
dcgan_64.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+
4
+ class dcgan_conv(nn.Module):
5
+ def __init__(self, nin, nout):
6
+ super(dcgan_conv, self).__init__()
7
+ self.main = nn.Sequential(nn.Conv2d(nin, nout, 4, 2, 1), nn.BatchNorm2d(nout), nn.LeakyReLU(0.2, inplace=True))
8
+
9
+ def forward(self, input):
10
+ return self.main(input)
11
+
12
+
13
+ class dcgan_upconv(nn.Module):
14
+ def __init__(self, nin, nout):
15
+ super(dcgan_upconv, self).__init__()
16
+ self.main = nn.Sequential(nn.ConvTranspose2d(nin, nout, 4, 2, 1), nn.BatchNorm2d(nout), nn.LeakyReLU(0.2, inplace=True))
17
+
18
+ def forward(self, input):
19
+ return self.main(input)
20
+
21
+
22
+ class encoder(nn.Module):
23
+ def __init__(self, dim, nc=1):
24
+ super(encoder, self).__init__()
25
+ self.dim = dim
26
+ nf = 64
27
+ self.c1 = dcgan_conv(nc, nf)
28
+ self.c2 = dcgan_conv(nf, nf * 2)
29
+ self.c3 = dcgan_conv(nf * 2, nf * 4)
30
+ self.c4 = dcgan_conv(nf * 4, nf * 8)
31
+ self.c5 = nn.Sequential(nn.Conv2d(nf * 8, dim, 4, 1, 0), nn.BatchNorm2d(dim), nn.Tanh())
32
+
33
+ def forward(self, input):
34
+ h1 = self.c1(input)
35
+ h2 = self.c2(h1)
36
+ h3 = self.c3(h2)
37
+ h4 = self.c4(h3)
38
+ h5 = self.c5(h4)
39
+ return h5.view(-1, self.dim), [h1, h2, h3, h4]
40
+
41
+
42
+ class decoder_convT(nn.Module):
43
+ def __init__(self, dim, nc=1):
44
+ super(decoder_convT, self).__init__()
45
+ self.dim = dim
46
+ nf = 64
47
+ self.upc1 = nn.Sequential(
48
+ nn.ConvTranspose2d(dim, nf * 8, 4, 1, 0),
49
+ nn.BatchNorm2d(nf * 8),
50
+ nn.LeakyReLU(0.2, inplace=True)
51
+ )
52
+ self.upc2 = dcgan_upconv(nf * 8, nf * 4)
53
+ self.upc3 = dcgan_upconv(nf * 4, nf * 2)
54
+ self.upc4 = dcgan_upconv(nf * 2, nf)
55
+ self.upc5 = nn.Sequential(
56
+ nn.ConvTranspose2d(nf, nc, 4, 2, 1),
57
+ nn.Sigmoid()
58
+ )
59
+
60
+ def forward(self, input):
61
+ d1 = self.upc1(input.view(-1, self.dim, 1, 1))
62
+ d2 = self.upc2(d1)
63
+ d3 = self.upc3(d2)
64
+ d4 = self.upc4(d3)
65
+ output = self.upc5(d4)
66
+ output = output.view(input.shape[0], input.shape[1], output.shape[1], output.shape[2], output.shape[3])
67
+ return output
68
+
69
+
70
+ class decoder_woSkip(nn.Module):
71
+ def __init__(self, dim, nc=1):
72
+ super(decoder_woSkip, self).__init__()
73
+ self.dim = dim
74
+ nf = 64
75
+ self.upc1 = nn.Sequential(
76
+ nn.ConvTranspose2d(dim, nf * 8, 4, 1, 0),
77
+ nn.BatchNorm2d(nf * 8),
78
+ nn.LeakyReLU(0.2, inplace=True)
79
+ )
80
+ self.upc2 = dcgan_upconv(nf * 8, nf * 4)
81
+ self.upc3 = dcgan_upconv(nf * 4, nf * 2)
82
+ self.upc4 = dcgan_upconv(nf * 2, nf)
83
+ self.upc5 = nn.Sequential(
84
+ nn.ConvTranspose2d(nf, nc, 4, 2, 1),
85
+ nn.Sigmoid()
86
+ )
87
+
88
+ def forward(self, input):
89
+ d1 = self.upc1(input.view(-1, self.dim, 1, 1))
90
+ d2 = self.upc2(d1)
91
+ d3 = self.upc3(d2)
92
+ d4 = self.upc4(d3)
93
+ output = self.upc5(d4)
94
+ output = output.view(input.shape[0], input.shape[1], output.shape[1], output.shape[2], output.shape[3])
95
+ return output
96
+
97
+
98
+ """
99
+ # Using Convolution and up_resize as the block to up-sample
100
+ """
101
+ import torch.nn.functional as F
102
+ class upconv(nn.Module):
103
+ def __init__(self, nc_in, nc_out):
104
+ super().__init__()
105
+ self.conv = nn.Conv2d(nc_in, nc_out, 3, 1, 1)
106
+ self.norm = nn.BatchNorm2d(nc_out)
107
+
108
+ def forward(self, input):
109
+ out = F.interpolate(input, scale_factor=2, mode='bilinear', align_corners=False)
110
+ return F.relu(self.norm(self.conv(out)))
111
+
112
+ class decoder_conv(nn.Module):
113
+ def __init__(self, dim, nc=1):
114
+ super(decoder_conv, self).__init__()
115
+ self.dim = dim
116
+ nf = 64
117
+
118
+ self.main = nn.Sequential(
119
+ nn.ConvTranspose2d(dim, nf * 8, 4, 1, 0),
120
+ nn.BatchNorm2d(nf * 8),
121
+ nn.ReLU(),
122
+ # state size. (nf*8) x 4 x 4
123
+ upconv(nf * 8, nf * 4),
124
+ # state size. (nf*4) x 8 x 8
125
+ upconv(nf * 4, nf * 2),
126
+ # state size. (nf*2) x 16 x 16
127
+ upconv(nf * 2, nf * 2),
128
+ # state size. (nf*2) x 32 x 32
129
+ upconv(nf * 2, nf),
130
+ # state size. (nf) x 64 x 64
131
+ nn.Conv2d(nf, nc, 1, 1, 0),
132
+ nn.Sigmoid()
133
+ )
134
+
135
+
136
+ def forward(self, input):
137
+ output = self.main(input.view(-1, self.dim, 1, 1))
138
+ output = output.view(input.shape[0], input.shape[1], output.shape[1], output.shape[2], output.shape[3])
139
+
140
+ return output