import numpy as np import torch import torch.nn as nn import torchvision def weights_init(m): classname = m.__class__.__name__ if classname.find("Conv") != -1: m.weight.data.normal_(0.0, 0.02) elif classname.find("BatchNorm2d") != -1: m.weight.data.normal_(1.0, 0.02) m.bias.data.fill_(0) class MultiscaleDiscriminator(nn.Module): def __init__( self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, num_D=3, getIntermFeat=False, finetune=False, ): super(MultiscaleDiscriminator, self).__init__() self.num_D = num_D self.n_layers = n_layers self.getIntermFeat = getIntermFeat for i in range(num_D): netD = NLayerDiscriminator( input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat ) if getIntermFeat: for j in range(n_layers + 2): setattr( self, "scale" + str(i) + "_layer" + str(j), getattr(netD, "model" + str(j)), ) else: setattr(self, "layer" + str(i), netD.model) self.downsample = nn.AvgPool2d( 3, stride=2, padding=[1, 1], count_include_pad=False ) weights_init(self) if finetune: self.requires_grad_(False) for name, param in self.named_parameters(): if 'layer0' in name: param.requires_grad = True def singleD_forward(self, model, input): if self.getIntermFeat: result = [input] for i in range(len(model)): result.append(model[i](result[-1])) return result[1:] else: return [model(input)] def forward(self, input): num_D = self.num_D result = [] input_downsampled = input for i in range(num_D): if self.getIntermFeat: model = [ getattr(self, "scale" + str(num_D - 1 - i) + "_layer" + str(j)) for j in range(self.n_layers + 2) ] else: model = getattr(self, "layer" + str(num_D - 1 - i)) result.append(self.singleD_forward(model, input_downsampled)) if i != (num_D - 1): input_downsampled = self.downsample(input_downsampled) return result # Defines the PatchGAN discriminator with the specified arguments. 
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator built from 4x4 convolutions.

    The network is assembled from a list of stages:

    * stage 0: stride-2 conv (``input_nc`` -> ``ndf``) + LeakyReLU(0.2);
    * ``n_layers - 1`` stages of stride-2 conv + norm + LeakyReLU, doubling
      the channel count each time (capped at 512);
    * one stride-1 conv + norm + LeakyReLU stage (channels doubled, capped);
    * a stride-1 conv down to a single output channel;
    * optionally an ``nn.Sigmoid`` stage.

    Args:
        input_nc: Number of channels of the input tensor.
        ndf: Number of filters in the first conv.
        n_layers: Controls how many stride-2 stages are used (see above).
        norm_layer: Normalisation layer class, called with the channel count.
        use_sigmoid: Append an ``nn.Sigmoid`` stage at the end.
        getIntermFeat: If True, each stage is stored as its own
            ``model<n>`` submodule and ``forward`` returns every stage output.
            Otherwise all stages are flattened into ``self.model``.
    """

    def __init__(
        self,
        input_nc,
        ndf=64,
        n_layers=3,
        norm_layer=nn.BatchNorm2d,
        use_sigmoid=False,
        getIntermFeat=False,
    ):
        super(NLayerDiscriminator, self).__init__()
        self.getIntermFeat = getIntermFeat
        self.n_layers = n_layers

        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))
        sequence = [
            [
                nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
                nn.LeakyReLU(0.2, True),
            ]
        ]

        nf = ndf
        for _ in range(1, n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)
            sequence.append(
                [
                    nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
                    norm_layer(nf),
                    nn.LeakyReLU(0.2, True),
                ]
            )

        nf_prev = nf
        nf = min(nf * 2, 512)
        sequence.append(
            [
                nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
                norm_layer(nf),
                nn.LeakyReLU(0.2, True),
            ]
        )

        sequence.append([nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)])

        if use_sigmoid:
            sequence.append([nn.Sigmoid()])

        # The real number of stages: n_layers + 2 in general, one more with
        # the Sigmoid stage, and 3 when n_layers == 0. forward() must use this
        # rather than n_layers + 2 or it silently drops the last stage(s).
        self.num_blocks = len(sequence)

        if getIntermFeat:
            for n, block in enumerate(sequence):
                setattr(self, "model" + str(n), nn.Sequential(*block))
        else:
            self.model = nn.Sequential(
                *[layer for block in sequence for layer in block]
            )

    def forward(self, input):
        """Run the discriminator.

        Returns:
            With ``getIntermFeat``, a list holding the output of every stage in
            order (the last entry is the final prediction). Otherwise the
            final prediction tensor only.
        """
        if not self.getIntermFeat:
            return self.model(input)
        res = [input]
        for n in range(self.num_blocks):
            model = getattr(self, "model" + str(n))
            res.append(model(res[-1]))
        return res[1:]