import torch.nn as nn
import torch.nn.functional as F


def _flatten_latents(latents, dim):
    """Collapse every leading axis of *latents* into one batch axis.

    Returns a ``(N, dim, 1, 1)`` tensor suitable for a 2-D (transposed)
    convolution.  ``reshape`` is used instead of ``view`` so that
    non-contiguous inputs (e.g. the result of ``transpose``/``permute``)
    are accepted; for contiguous inputs no copy is made.
    """
    return latents.reshape(-1, dim, 1, 1)


def _restore_leading_dims(frames, latents):
    """Split the batch axis of *frames* back into the first two axes of *latents*.

    *frames* is ``(N, C, H, W)``; the result is
    ``(latents.shape[0], latents.shape[1], C, H, W)``.
    """
    return frames.reshape(latents.shape[0], latents.shape[1], *frames.shape[1:])


class dcgan_conv(nn.Module):
    """4x4 stride-2 convolution -> batch norm -> LeakyReLU(0.2).

    With kernel 4, stride 2 and padding 1 an even spatial size is halved.

    Args:
        nin: number of input channels.
        nout: number of output channels.
    """

    def __init__(self, nin, nout):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(nin, nout, 4, 2, 1),
            nn.BatchNorm2d(nout),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, input):
        # NOTE: the parameter name shadows the builtin but is kept because
        # callers may pass it by keyword.
        return self.main(input)


class dcgan_upconv(nn.Module):
    """4x4 stride-2 transposed convolution -> batch norm -> LeakyReLU(0.2).

    With kernel 4, stride 2 and padding 1 the spatial size is doubled.

    Args:
        nin: number of input channels.
        nout: number of output channels.
    """

    def __init__(self, nin, nout):
        super().__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(nin, nout, 4, 2, 1),
            nn.BatchNorm2d(nout),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, input):
        return self.main(input)


class encoder(nn.Module):
    """Convolutional encoder producing a ``dim``-channel code per image.

    Four stride-2 blocks (64, 128, 256, 512 channels) are followed by an
    unpadded 4x4 convolution, batch norm and ``tanh``.

    Args:
        dim: number of channels of the final code.
        nc: number of input image channels.
    """

    def __init__(self, dim, nc=1):
        super().__init__()
        self.dim = dim
        nf = 64
        self.c1 = dcgan_conv(nc, nf)
        self.c2 = dcgan_conv(nf, nf * 2)
        self.c3 = dcgan_conv(nf * 2, nf * 4)
        self.c4 = dcgan_conv(nf * 4, nf * 8)
        self.c5 = nn.Sequential(
            nn.Conv2d(nf * 8, dim, 4, 1, 0),
            nn.BatchNorm2d(dim),
            nn.Tanh(),
        )

    def forward(self, input):
        """Encode *input* and also return the intermediate feature maps.

        Returns:
            A tuple ``(code, [h1, h2, h3, h4])`` where ``code`` is the output
            of the last stage flattened to ``(-1, dim)`` and ``h1``..``h4``
            are the outputs of the four stride-2 blocks, in order.
        """
        # assumes 64x64 images so that c5 yields a 1x1 map and the flattened
        # code has one row per image - TODO confirm against callers
        h1 = self.c1(input)
        h2 = self.c2(h1)
        h3 = self.c3(h2)
        h4 = self.c4(h3)
        h5 = self.c5(h4)
        return h5.reshape(-1, self.dim), [h1, h2, h3, h4]


class _dcgan_decoder_base(nn.Module):
    """Shared implementation of the transposed-convolution decoders.

    A 1x1 latent is expanded to 4x4 by an unpadded 4x4 transposed
    convolution, then doubled four times (8, 16, 32, 64); the last stage
    maps to ``nc`` channels through a sigmoid.

    Submodule names (``upc1``..``upc5``) and their creation order match the
    original standalone classes, so existing ``state_dict`` checkpoints load
    unchanged.
    """

    def __init__(self, dim, nc=1):
        super().__init__()
        self.dim = dim
        nf = 64
        self.upc1 = nn.Sequential(
            nn.ConvTranspose2d(dim, nf * 8, 4, 1, 0),
            nn.BatchNorm2d(nf * 8),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.upc2 = dcgan_upconv(nf * 8, nf * 4)
        self.upc3 = dcgan_upconv(nf * 4, nf * 2)
        self.upc4 = dcgan_upconv(nf * 2, nf)
        self.upc5 = nn.Sequential(
            nn.ConvTranspose2d(nf, nc, 4, 2, 1),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Decode latents into images.

        Args:
            input: tensor whose elements are regrouped into vectors of
                length ``dim``; its first two axes become the first two axes
                of the result (in practice a ``(A, B, dim)`` tensor).
                May be non-contiguous.

        Returns:
            Tensor of shape ``(input.shape[0], input.shape[1], nc, 64, 64)``.
        """
        d1 = self.upc1(_flatten_latents(input, self.dim))
        d2 = self.upc2(d1)
        d3 = self.upc3(d2)
        d4 = self.upc4(d3)
        frames = self.upc5(d4)
        return _restore_leading_dims(frames, input)


class decoder_convT(_dcgan_decoder_base):
    """Transposed-convolution decoder (see ``_dcgan_decoder_base``).

    Args:
        dim: length of each latent vector.
        nc: number of output image channels.
    """


class decoder_woSkip(_dcgan_decoder_base):
    """Transposed-convolution decoder; it consumes only the latent tensor.

    Architecturally identical to ``decoder_convT``; kept as a distinct class
    so existing references to either name keep working.

    Args:
        dim: length of each latent vector.
        nc: number of output image channels.
    """


class upconv(nn.Module):
    """Bilinear 2x upsampling -> 3x3 convolution -> batch norm -> ReLU.

    Args:
        nc_in: number of input channels.
        nc_out: number of output channels.
    """

    def __init__(self, nc_in, nc_out):
        super().__init__()
        self.conv = nn.Conv2d(nc_in, nc_out, 3, 1, 1)
        self.norm = nn.BatchNorm2d(nc_out)

    def forward(self, input):
        upsampled = F.interpolate(
            input, scale_factor=2, mode='bilinear', align_corners=False
        )
        return F.relu(self.norm(self.conv(upsampled)))


class decoder_conv(nn.Module):
    """Decoder built from upsample-then-convolve blocks.

    A 1x1 latent is expanded to 4x4 by a transposed convolution, doubled
    four times by ``upconv`` blocks (to 64x64), and projected to ``nc``
    channels by a 1x1 convolution followed by a sigmoid.

    Args:
        dim: length of each latent vector.
        nc: number of output image channels.
    """

    def __init__(self, dim, nc=1):
        super().__init__()
        self.dim = dim
        nf = 64
        # The layer order fixes the Sequential indices used as state_dict
        # keys; do not reorder.
        self.main = nn.Sequential(
            nn.ConvTranspose2d(dim, nf * 8, 4, 1, 0),
            nn.BatchNorm2d(nf * 8),
            nn.ReLU(),
            upconv(nf * 8, nf * 4),
            upconv(nf * 4, nf * 2),
            upconv(nf * 2, nf * 2),
            upconv(nf * 2, nf),
            nn.Conv2d(nf, nc, 1, 1, 0),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Decode latents into images.

        Args:
            input: tensor whose elements are regrouped into vectors of
                length ``dim``; its first two axes become the first two axes
                of the result.  May be non-contiguous.

        Returns:
            Tensor of shape ``(input.shape[0], input.shape[1], nc, 64, 64)``.
        """
        frames = self.main(_flatten_latents(input, self.dim))
        return _restore_leading_dims(frames, input)