Spaces:
Running
Running
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torchvision | |
def weights_init(m):
    """Apply DCGAN-style initialisation to a single module, in place.

    Meant to be passed to ``nn.Module.apply`` so that it visits every
    submodule; called on a container directly it only inspects the
    container's own class name.

    * Class name contains ``"Conv"``: weight ~ N(0, 0.02) (bias untouched).
    * Class name contains ``"BatchNorm2d"``: weight ~ N(1, 0.02), bias = 0.
    * Anything else is left unchanged.

    Parameters that are absent (e.g. ``nn.BatchNorm2d(affine=False)`` has
    ``weight`` and ``bias`` set to ``None``) are skipped instead of raising.

    Args:
        m: the module to initialise.
    """
    classname = m.__class__.__name__
    weight = getattr(m, "weight", None)
    if "Conv" in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm2d" in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        bias = getattr(m, "bias", None)
        if bias is not None:
            nn.init.zeros_(bias)
class MultiscaleDiscriminator(nn.Module):
    """Bank of ``num_D`` independent NLayerDiscriminators applied at several scales.

    Between scales the input is average-pooled with a 3x3 window, stride 2,
    padding 1 and ``count_include_pad=False``.

    Args:
        input_nc: number of input channels.
        ndf: channel count of the first conv layer of each scale.
        n_layers: forwarded to ``NLayerDiscriminator``.
        norm_layer: normalisation constructor, forwarded to ``NLayerDiscriminator``.
        use_sigmoid: forwarded to ``NLayerDiscriminator``.
        num_D: number of scales.
        getIntermFeat: if True, ``forward`` returns every stage's output for
            each scale instead of only the final one.
        finetune: if True, freeze every parameter except those whose name
            contains ``"layer0"``.
    """

    def __init__(
        self,
        input_nc,
        ndf=64,
        n_layers=3,
        norm_layer=nn.BatchNorm2d,
        use_sigmoid=False,
        num_D=3,
        getIntermFeat=False,
        finetune=False,
    ):
        super(MultiscaleDiscriminator, self).__init__()
        self.num_D = num_D
        self.n_layers = n_layers
        self.getIntermFeat = getIntermFeat
        # Number of per-scale stages registered in getIntermFeat mode. It is
        # discovered from the sub-discriminator instead of assumed to be
        # n_layers + 2, so an optional trailing Sigmoid stage is not dropped.
        self.n_stages = 0
        for i in range(num_D):
            netD = NLayerDiscriminator(
                input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat
            )
            if getIntermFeat:
                j = 0
                while hasattr(netD, "model" + str(j)):
                    setattr(
                        self,
                        "scale" + str(i) + "_layer" + str(j),
                        getattr(netD, "model" + str(j)),
                    )
                    j += 1
                self.n_stages = j
            else:
                setattr(self, "layer" + str(i), netD.model)
        self.downsample = nn.AvgPool2d(
            3, stride=2, padding=[1, 1], count_include_pad=False
        )
        # weights_init dispatches on the class name of the module it is given,
        # so it must be applied recursively; weights_init(self) would only see
        # "MultiscaleDiscriminator" and do nothing.
        self.apply(weights_init)
        if finetune:
            self.requires_grad_(False)
            for name, param in self.named_parameters():
                # NOTE(review): with getIntermFeat=False this re-enables only the
                # "layer0.*" parameters (the scale fed the most-downsampled
                # input); with getIntermFeat=True it matches the first stage of
                # every scale ("scale{i}_layer0.*"). Confirm which is intended.
                if "layer0" in name:
                    param.requires_grad = True

    def singleD_forward(self, model, input):
        """Run one scale.

        Args:
            model: a list of stage modules (getIntermFeat mode) or a single
                module.
            input: the input for this scale.

        Returns:
            In getIntermFeat mode, the list of every stage's output in order;
            otherwise a one-element list holding ``model(input)``.
        """
        if self.getIntermFeat:
            outputs = [input]
            for stage in model:
                outputs.append(stage(outputs[-1]))
            return outputs[1:]
        return [model(input)]

    def forward(self, input):
        """Apply all scales, starting with the un-downsampled input.

        Returns:
            A list of ``num_D`` lists (see ``singleD_forward``). Element 0 comes
            from the full-resolution input, and each later element from an input
            downsampled once more. The full-resolution input is handled by the
            last-registered scale (index ``num_D - 1``).
        """
        num_D = self.num_D
        result = []
        input_downsampled = input
        for i in range(num_D):
            scale = num_D - 1 - i
            if self.getIntermFeat:
                model = [
                    getattr(self, "scale" + str(scale) + "_layer" + str(j))
                    for j in range(self.n_stages)
                ]
            else:
                model = getattr(self, "layer" + str(scale))
            result.append(self.singleD_forward(model, input_downsampled))
            # No need to downsample after the coarsest scale has run.
            if i != num_D - 1:
                input_downsampled = self.downsample(input_downsampled)
        return result
# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    """PatchGAN-style stack of 4x4 convolutions ending in a 1-channel output map.

    The layers are grouped into stages:

    * a stride-2 conv followed by LeakyReLU(0.2);
    * ``n_layers - 1`` stride-2 conv + norm + LeakyReLU stages, doubling the
      channel count each time (capped at 512);
    * one stride-1 conv + norm + LeakyReLU stage, again doubling (capped at 512);
    * a final stride-1 conv to 1 channel;
    * an optional Sigmoid stage when ``use_sigmoid`` is True.

    Args:
        input_nc: number of input channels.
        ndf: channel count of the first conv.
        n_layers: controls how many stride-2 stages are built (see above).
        norm_layer: normalisation constructor, called as ``norm_layer(nf)``.
        use_sigmoid: append a Sigmoid stage.
        getIntermFeat: if True, each stage is registered as ``model0``,
            ``model1``, ... and ``forward`` returns every stage's output;
            otherwise all layers form one ``self.model`` Sequential.
    """

    def __init__(
        self,
        input_nc,
        ndf=64,
        n_layers=3,
        norm_layer=nn.BatchNorm2d,
        use_sigmoid=False,
        getIntermFeat=False,
    ):
        super(NLayerDiscriminator, self).__init__()
        self.getIntermFeat = getIntermFeat
        self.n_layers = n_layers
        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))
        # Each inner list is one stage. Modules are created in the same order
        # as before, so default-init RNG consumption is unchanged.
        sequence = [
            [
                nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
                nn.LeakyReLU(0.2, True),
            ]
        ]
        nf = ndf
        for _ in range(1, n_layers):
            nf_prev, nf = nf, min(nf * 2, 512)
            sequence.append(
                [
                    nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
                    norm_layer(nf),
                    nn.LeakyReLU(0.2, True),
                ]
            )
        nf_prev, nf = nf, min(nf * 2, 512)
        sequence.append(
            [
                nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
                norm_layer(nf),
                nn.LeakyReLU(0.2, True),
            ]
        )
        sequence.append([nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)])
        if use_sigmoid:
            sequence.append([nn.Sigmoid()])
        # The real number of stages, including the optional Sigmoid stage.
        # forward() iterates exactly this many. The old hard-coded
        # n_layers + 2 skipped the Sigmoid and was wrong for n_layers < 1.
        self.n_stages = len(sequence)
        if getIntermFeat:
            for n, stage in enumerate(sequence):
                setattr(self, "model" + str(n), nn.Sequential(*stage))
        else:
            self.model = nn.Sequential(*[layer for stage in sequence for layer in stage])

    def forward(self, input):
        """Run the discriminator.

        Returns:
            In getIntermFeat mode, a list with the output of every stage, in
            order; otherwise the output of ``self.model``.
        """
        if not self.getIntermFeat:
            return self.model(input)
        outputs = [input]
        for n in range(self.n_stages):
            stage = getattr(self, "model" + str(n))
            outputs.append(stage(outputs[-1]))
        return outputs[1:]