|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
class dcgan_conv(nn.Module):
    """Downsampling block: 4x4 stride-2 convolution, batch norm, LeakyReLU(0.2).

    With kernel 4, stride 2 and padding 1 the spatial size is halved
    (rounded down) while the channel count goes from ``nin`` to ``nout``.
    """

    def __init__(self, nin, nout):
        """Build the block.

        Args:
            nin: number of input channels.
            nout: number of output channels.
        """
        super().__init__()
        layers = [
            nn.Conv2d(nin, nout, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(nout),
            # In-place activation overwrites the batch-norm output buffer.
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Apply the block to a 4-D tensor laid out as (N, nin, H, W)."""
        return self.main(input)
|
|
|
|
|
class dcgan_upconv(nn.Module):
    """Upsampling block: 4x4 stride-2 transposed conv, batch norm, LeakyReLU(0.2).

    With kernel 4, stride 2 and padding 1 the spatial size is exactly
    doubled while the channel count goes from ``nin`` to ``nout``.
    """

    def __init__(self, nin, nout):
        """Build the block.

        Args:
            nin: number of input channels.
            nout: number of output channels.
        """
        super().__init__()
        layers = [
            nn.ConvTranspose2d(nin, nout, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(nout),
            # In-place activation overwrites the batch-norm output buffer.
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Apply the block to a 4-D tensor laid out as (N, nin, H, W)."""
        return self.main(input)
|
|
|
|
|
class encoder(nn.Module):
    """Convolutional encoder mapping images to ``dim``-channel feature vectors.

    Four stride-2 ``dcgan_conv`` stages (64, 128, 256, 512 channels) each
    halve the spatial size; a final 4x4 un-padded convolution followed by
    batch norm and Tanh produces ``dim`` channels. For 64x64 inputs the
    final feature map is 1x1, so the flattened result is (N, dim).
    """

    def __init__(self, dim, nc=1):
        """Build the encoder.

        Args:
            dim: number of channels of the final feature vector.
            nc: number of channels of the input images.
        """
        super().__init__()
        self.dim = dim
        nf = 64  # base channel width, doubled at every stage
        self.c1 = dcgan_conv(nc, nf)
        self.c2 = dcgan_conv(nf, nf * 2)
        self.c3 = dcgan_conv(nf * 2, nf * 4)
        self.c4 = dcgan_conv(nf * 4, nf * 8)
        self.c5 = nn.Sequential(
            nn.Conv2d(nf * 8, dim, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(dim),
            nn.Tanh(),
        )

    def forward(self, input):
        """Encode ``input`` and also return the intermediate feature maps.

        Returns:
            A pair ``(code, feats)``: ``code`` is the output of the last
            stage viewed as ``(-1, dim)`` and ``feats`` is the list
            ``[h1, h2, h3, h4]`` of the four intermediate stage outputs.
        """
        feats = []
        hidden = input
        for stage in (self.c1, self.c2, self.c3, self.c4):
            hidden = stage(hidden)
            feats.append(hidden)
        code = self.c5(hidden)
        # NOTE(review): the flatten only yields one row per sample when the
        # final map is 1x1 (i.e. 64x64 inputs) - confirm callers guarantee this.
        return code.view(-1, self.dim), feats
|
|
|
|
|
class decoder_convT(nn.Module):
    """Transposed-convolution decoder from ``dim``-vectors to 64x64 images.

    The vector is treated as a 1x1 feature map, expanded to 4x4 by an
    un-padded 4x4 transposed convolution, then doubled four times
    (8, 16, 32, 64). The last layer maps to ``nc`` channels and applies
    a Sigmoid, so outputs lie in [0, 1].
    """

    def __init__(self, dim, nc=1):
        """Build the decoder.

        Args:
            dim: size of each input feature vector.
            nc: number of channels of the generated images.
        """
        super().__init__()
        self.dim = dim
        nf = 64  # base channel width, halved at every stage
        self.upc1 = nn.Sequential(
            nn.ConvTranspose2d(dim, nf * 8, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(nf * 8),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.upc2 = dcgan_upconv(nf * 8, nf * 4)
        self.upc3 = dcgan_upconv(nf * 4, nf * 2)
        self.upc4 = dcgan_upconv(nf * 2, nf)
        self.upc5 = nn.Sequential(
            nn.ConvTranspose2d(nf, nc, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Decode every vector in ``input`` into an image.

        Args:
            input: tensor whose first two axes are kept in the output and
                whose elements are grouped into vectors of length ``dim``
                (presumably (batch, steps, dim) - confirm against callers).

        Returns:
            Tensor of shape ``(input.shape[0], input.shape[1], nc, H, W)``.
        """
        # reshape instead of view: identical for contiguous tensors, but also
        # accepts non-contiguous inputs (e.g. after transpose/permute) for
        # which view raises a RuntimeError.
        flat = input.reshape(-1, self.dim, 1, 1)
        hidden = flat
        for stage in (self.upc1, self.upc2, self.upc3, self.upc4, self.upc5):
            hidden = stage(hidden)
        # Restore the two leading axes that were merged above.
        return hidden.reshape(input.shape[0], input.shape[1], *hidden.shape[1:])
|
|
|
|
|
class decoder_woSkip(nn.Module):
    """Transposed-convolution decoder that consumes only the feature vector.

    The name presumably stands for "without skip connections": forward
    takes a single tensor and no intermediate encoder features. The vector
    is treated as a 1x1 map, expanded to 4x4, then doubled four times to
    64x64; the last layer outputs ``nc`` channels through a Sigmoid.
    """

    def __init__(self, dim, nc=1):
        """Build the decoder.

        Args:
            dim: size of each input feature vector.
            nc: number of channels of the generated images.
        """
        super().__init__()
        self.dim = dim
        nf = 64  # base channel width, halved at every stage
        self.upc1 = nn.Sequential(
            nn.ConvTranspose2d(dim, nf * 8, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(nf * 8),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.upc2 = dcgan_upconv(nf * 8, nf * 4)
        self.upc3 = dcgan_upconv(nf * 4, nf * 2)
        self.upc4 = dcgan_upconv(nf * 2, nf)
        self.upc5 = nn.Sequential(
            nn.ConvTranspose2d(nf, nc, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Decode every vector in ``input`` into an image.

        Args:
            input: tensor whose first two axes are kept in the output and
                whose elements are grouped into vectors of length ``dim``
                (presumably (batch, steps, dim) - confirm against callers).

        Returns:
            Tensor of shape ``(input.shape[0], input.shape[1], nc, H, W)``.
        """
        # reshape instead of view: identical for contiguous tensors, but also
        # accepts non-contiguous inputs (e.g. after transpose/permute) for
        # which view raises a RuntimeError.
        vectors = input.reshape(-1, self.dim, 1, 1)
        d1 = self.upc1(vectors)
        d2 = self.upc2(d1)
        d3 = self.upc3(d2)
        d4 = self.upc4(d3)
        images = self.upc5(d4)
        # Restore the two leading axes that were merged above.
        lead = (input.shape[0], input.shape[1])
        return images.reshape(*lead, *images.shape[1:])
|
|
|
|
|
class upconv(nn.Module):
    """Upsampling block: bilinear 2x resize, 3x3 convolution, batch norm, ReLU.

    The 3x3 convolution uses stride 1 and padding 1, so it keeps the
    resized spatial size and only changes the channel count.
    """

    def __init__(self, nc_in, nc_out):
        """Build the block.

        Args:
            nc_in: number of input channels.
            nc_out: number of output channels.
        """
        super().__init__()
        self.conv = nn.Conv2d(nc_in, nc_out, kernel_size=3, stride=1, padding=1)
        self.norm = nn.BatchNorm2d(nc_out)

    def forward(self, input):
        """Double the spatial size of ``input`` and project to ``nc_out`` channels."""
        resized = F.interpolate(
            input, scale_factor=2, mode='bilinear', align_corners=False
        )
        projected = self.conv(resized)
        normalized = self.norm(projected)
        return F.relu(normalized)
|
|
|
class decoder_conv(nn.Module):
    """Resize-then-convolve decoder from ``dim``-vectors to 64x64 images.

    The vector is treated as a 1x1 feature map and expanded to 4x4 by an
    un-padded 4x4 transposed convolution; four ``upconv`` blocks then
    double the size to 64x64. A final 1x1 convolution maps to ``nc``
    channels and a Sigmoid bounds the outputs to [0, 1].
    """

    def __init__(self, dim, nc=1):
        """Build the decoder.

        Args:
            dim: size of each input feature vector.
            nc: number of channels of the generated images.
        """
        super().__init__()
        self.dim = dim
        nf = 64  # base channel width

        stages = [
            nn.ConvTranspose2d(dim, nf * 8, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(nf * 8),
            nn.ReLU(),
            upconv(nf * 8, nf * 4),
            upconv(nf * 4, nf * 2),
            # This stage keeps the channel count and only upsamples.
            upconv(nf * 2, nf * 2),
            upconv(nf * 2, nf),
            nn.Conv2d(nf, nc, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, input):
        """Decode every vector in ``input`` into an image.

        Args:
            input: tensor whose first two axes are kept in the output and
                whose elements are grouped into vectors of length ``dim``
                (presumably (batch, steps, dim) - confirm against callers).

        Returns:
            Tensor of shape ``(input.shape[0], input.shape[1], nc, H, W)``.
        """
        # reshape instead of view: identical for contiguous tensors, but also
        # accepts non-contiguous inputs (e.g. after transpose/permute) for
        # which view raises a RuntimeError.
        images = self.main(input.reshape(-1, self.dim, 1, 1))
        # Restore the two leading axes that were merged above.
        return images.reshape(input.shape[0], input.shape[1], *images.shape[1:])