ldkong commited on
Commit
3a316b0
β€’
1 Parent(s): a797152

Update dcgan_64.py

Browse files
Files changed (1) hide show
  1. dcgan_64.py +1 -10
dcgan_64.py CHANGED
@@ -1,4 +1,5 @@
1
  import torch.nn as nn
 
2
 
3
 
4
  class dcgan_conv(nn.Module):
@@ -95,10 +96,6 @@ class decoder_woSkip(nn.Module):
95
  return output
96
 
97
 
98
- """
99
- # Using Convolution and up_resize as the block to up-sample
100
- """
101
- import torch.nn.functional as F
102
  class upconv(nn.Module):
103
  def __init__(self, nc_in, nc_out):
104
  super().__init__()
@@ -119,15 +116,10 @@ class decoder_conv(nn.Module):
119
  nn.ConvTranspose2d(dim, nf * 8, 4, 1, 0),
120
  nn.BatchNorm2d(nf * 8),
121
  nn.ReLU(),
122
- # state size. (nf*8) x 4 x 4
123
  upconv(nf * 8, nf * 4),
124
- # state size. (nf*4) x 8 x 8
125
  upconv(nf * 4, nf * 2),
126
- # state size. (nf*2) x 16 x 16
127
  upconv(nf * 2, nf * 2),
128
- # state size. (nf*2) x 32 x 32
129
  upconv(nf * 2, nf),
130
- # state size. (nf) x 64 x 64
131
  nn.Conv2d(nf, nc, 1, 1, 0),
132
  nn.Sigmoid()
133
  )
@@ -136,5 +128,4 @@ class decoder_conv(nn.Module):
136
  def forward(self, input):
137
  output = self.main(input.view(-1, self.dim, 1, 1))
138
  output = output.view(input.shape[0], input.shape[1], output.shape[1], output.shape[2], output.shape[3])
139
-
140
  return output
1
  import torch.nn as nn
2
+ import torch.nn.functional as F
3
 
4
 
5
  class dcgan_conv(nn.Module):
96
  return output
97
 
98
 
 
 
 
 
99
  class upconv(nn.Module):
100
  def __init__(self, nc_in, nc_out):
101
  super().__init__()
116
  nn.ConvTranspose2d(dim, nf * 8, 4, 1, 0),
117
  nn.BatchNorm2d(nf * 8),
118
  nn.ReLU(),
 
119
  upconv(nf * 8, nf * 4),
 
120
  upconv(nf * 4, nf * 2),
 
121
  upconv(nf * 2, nf * 2),
 
122
  upconv(nf * 2, nf),
 
123
  nn.Conv2d(nf, nc, 1, 1, 0),
124
  nn.Sigmoid()
125
  )
128
  def forward(self, input):
129
  output = self.main(input.view(-1, self.dim, 1, 1))
130
  output = output.view(input.shape[0], input.shape[1], output.shape[1], output.shape[2], output.shape[3])
 
131
  return output