Update dcgan_64.py
Browse files- dcgan_64.py +1 -10
dcgan_64.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import torch.nn as nn
|
|
|
2 |
|
3 |
|
4 |
class dcgan_conv(nn.Module):
|
@@ -95,10 +96,6 @@ class decoder_woSkip(nn.Module):
|
|
95 |
return output
|
96 |
|
97 |
|
98 |
-
"""
|
99 |
-
# Using Convolution and up_resize as the block to up-sample
|
100 |
-
"""
|
101 |
-
import torch.nn.functional as F
|
102 |
class upconv(nn.Module):
|
103 |
def __init__(self, nc_in, nc_out):
|
104 |
super().__init__()
|
@@ -119,15 +116,10 @@ class decoder_conv(nn.Module):
|
|
119 |
nn.ConvTranspose2d(dim, nf * 8, 4, 1, 0),
|
120 |
nn.BatchNorm2d(nf * 8),
|
121 |
nn.ReLU(),
|
122 |
-
# state size. (nf*8) x 4 x 4
|
123 |
upconv(nf * 8, nf * 4),
|
124 |
-
# state size. (nf*4) x 8 x 8
|
125 |
upconv(nf * 4, nf * 2),
|
126 |
-
# state size. (nf*2) x 16 x 16
|
127 |
upconv(nf * 2, nf * 2),
|
128 |
-
# state size. (nf*2) x 32 x 32
|
129 |
upconv(nf * 2, nf),
|
130 |
-
# state size. (nf) x 64 x 64
|
131 |
nn.Conv2d(nf, nc, 1, 1, 0),
|
132 |
nn.Sigmoid()
|
133 |
)
|
@@ -136,5 +128,4 @@ class decoder_conv(nn.Module):
|
|
136 |
def forward(self, input):
|
137 |
output = self.main(input.view(-1, self.dim, 1, 1))
|
138 |
output = output.view(input.shape[0], input.shape[1], output.shape[1], output.shape[2], output.shape[3])
|
139 |
-
|
140 |
return output
|
|
|
1 |
import torch.nn as nn
|
2 |
+
import torch.nn.functional as F
|
3 |
|
4 |
|
5 |
class dcgan_conv(nn.Module):
|
|
|
96 |
return output
|
97 |
|
98 |
|
|
|
|
|
|
|
|
|
99 |
class upconv(nn.Module):
|
100 |
def __init__(self, nc_in, nc_out):
|
101 |
super().__init__()
|
|
|
116 |
nn.ConvTranspose2d(dim, nf * 8, 4, 1, 0),
|
117 |
nn.BatchNorm2d(nf * 8),
|
118 |
nn.ReLU(),
|
|
|
119 |
upconv(nf * 8, nf * 4),
|
|
|
120 |
upconv(nf * 4, nf * 2),
|
|
|
121 |
upconv(nf * 2, nf * 2),
|
|
|
122 |
upconv(nf * 2, nf),
|
|
|
123 |
nn.Conv2d(nf, nc, 1, 1, 0),
|
124 |
nn.Sigmoid()
|
125 |
)
|
|
|
128 |
def forward(self, input):
|
129 |
output = self.main(input.view(-1, self.dim, 1, 1))
|
130 |
output = output.view(input.shape[0], input.shape[1], output.shape[1], output.shape[2], output.shape[3])
|
|
|
131 |
return output
|