# ReliableSwap_Demo / modules/layers/discriminator.py
# Last commit by gavinyuan: a104d3f "udpate: app.py import FSGenerator" (4.62 kB)
import numpy as np
import torch
import torch.nn as nn
import torchvision
def weights_init(m):
    """Initialize a single module's parameters; intended for ``net.apply(weights_init)``.

    Only ``m`` itself is inspected, not its children, so pass this function to
    ``nn.Module.apply`` to initialize a whole network.

    * Class name contains ``"Conv"``: ``weight ~ N(0, 0.02)`` (bias untouched).
    * Class name contains ``"BatchNorm2d"``: ``weight ~ N(1, 0.02)``, ``bias = 0``.
    * Any other module is left unchanged.

    Parameters that are absent or ``None`` (e.g. ``BatchNorm2d(affine=False)``)
    are skipped instead of raising ``AttributeError``.

    Args:
        m: the module to initialize.
    """
    classname = type(m).__name__
    if "Conv" in classname:
        weight = getattr(m, "weight", None)
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm2d" in classname:
        weight = getattr(m, "weight", None)
        bias = getattr(m, "bias", None)
        # affine=False batch norms carry no weight/bias tensors
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.zeros_(bias)
class MultiscaleDiscriminator(nn.Module):
    """Multi-scale PatchGAN discriminator made of ``num_D`` NLayerDiscriminators.

    ``forward`` evaluates the sub-discriminators on progressively downsampled
    copies of the input. The first entry of the returned list comes from the
    full-resolution input. Each later entry comes from an input passed once more
    through a 3x3, stride-2 average pool. Sub-discriminator ``layer{num_D-1}`` /
    ``scale{num_D-1}_*`` sees the full-resolution input; index 0 sees the
    coarsest one.

    Args:
        input_nc: number of channels of the input tensor.
        ndf: filters in the first conv of each sub-discriminator.
        n_layers: number of stride-2 conv stages in each sub-discriminator.
        norm_layer: normalization constructor, called as ``norm_layer(nf)``.
        use_sigmoid: append a ``nn.Sigmoid`` stage to each sub-discriminator.
        num_D: number of scales / sub-discriminators.
        getIntermFeat: if True, each scale returns the output of every stage
            instead of only the final prediction.
        finetune: if True, freeze every parameter whose name does not contain
            ``"layer0"``.

    ``forward`` returns a list (one entry per scale) of lists of tensors.
    """

    def __init__(
        self,
        input_nc,
        ndf=64,
        n_layers=3,
        norm_layer=nn.BatchNorm2d,
        use_sigmoid=False,
        num_D=3,
        getIntermFeat=False,
        finetune=False,
    ):
        super().__init__()
        self.num_D = num_D
        self.n_layers = n_layers
        self.getIntermFeat = getIntermFeat
        # Stages per scale in getIntermFeat mode. Defaults to n_layers + 2 and is
        # replaced by the real count below, because the sub-discriminator also
        # builds a Sigmoid stage when use_sigmoid=True.
        self.n_submodels = n_layers + 2

        for i in range(num_D):
            netD = NLayerDiscriminator(
                input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat
            )
            if getIntermFeat:
                # In this mode the sub-discriminator's only children are its
                # model0..modelK stages, so counting children gives K + 1.
                self.n_submodels = len(list(netD.children()))
                for j in range(self.n_submodels):
                    setattr(self, f"scale{i}_layer{j}", getattr(netD, f"model{j}"))
            else:
                setattr(self, f"layer{i}", netD.model)

        self.downsample = nn.AvgPool2d(
            3, stride=2, padding=[1, 1], count_include_pad=False
        )

        # weights_init only inspects the module it is given, so it has to be
        # applied recursively; calling it on `self` alone matches nothing.
        self.apply(weights_init)

        if finetune:
            self.requires_grad_(False)
            # NOTE(review): with getIntermFeat=True this substring test matches
            # "scale{i}_layer0" for every scale i, not a single sub-discriminator
            # -- confirm that is intended.
            for name, param in self.named_parameters():
                if "layer0" in name:
                    param.requires_grad = True

    def singleD_forward(self, model, input):
        """Run one scale's discriminator and always return a list of tensors.

        In getIntermFeat mode ``model`` is a list of stages and the output of
        every stage is returned. Otherwise ``model`` is a single module and its
        output is wrapped in a one-element list.
        """
        if self.getIntermFeat:
            result = [input]
            for stage in model:
                result.append(stage(result[-1]))
            return result[1:]
        return [model(input)]

    def forward(self, input):
        """Return one ``singleD_forward`` result per scale, finest scale first."""
        num_D = self.num_D
        result = []
        input_downsampled = input
        for i in range(num_D):
            scale = num_D - 1 - i
            if self.getIntermFeat:
                model = [
                    getattr(self, f"scale{scale}_layer{j}")
                    for j in range(self.n_submodels)
                ]
            else:
                model = getattr(self, f"layer{scale}")
            result.append(self.singleD_forward(model, input_downsampled))
            # Skip the pooling after the last scale; its output would be unused.
            if i != num_D - 1:
                input_downsampled = self.downsample(input_downsampled)
        return result
# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator: a stack of 4x4 conv stages ending in a 1-channel map.

    Stages, in order. All convs use kernel 4 and padding ``ceil((4 - 1) / 2) = 2``.

    * ``Conv(input_nc -> ndf, stride 2)`` + ``LeakyReLU(0.2)``.
    * ``n_layers - 1`` stages of ``Conv(stride 2)`` + ``norm_layer`` +
      ``LeakyReLU(0.2)``. Each doubles the channel count, capped at 512.
    * ``Conv(stride 1)`` + ``norm_layer`` + ``LeakyReLU(0.2)``. This doubles the
      channels once more, capped at 512.
    * ``Conv(-> 1 channel, stride 1)``.
    * ``Sigmoid``, only when ``use_sigmoid`` is True.

    Args:
        input_nc: number of channels of the input tensor.
        ndf: filters in the first conv.
        n_layers: number of stride-2 conv stages.
        norm_layer: normalization constructor, called as ``norm_layer(nf)``.
        use_sigmoid: append a final ``nn.Sigmoid`` stage.
        getIntermFeat: if True, each stage is registered as ``model0``,
            ``model1``, ... and ``forward`` returns the list of every stage's
            output. Otherwise all stages are flattened into ``self.model`` and
            ``forward`` returns only the final tensor.
    """

    def __init__(
        self,
        input_nc,
        ndf=64,
        n_layers=3,
        norm_layer=nn.BatchNorm2d,
        use_sigmoid=False,
        getIntermFeat=False,
    ):
        super().__init__()
        self.getIntermFeat = getIntermFeat
        self.n_layers = n_layers

        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))

        # Each inner list is one stage. Stages stay separate so that
        # getIntermFeat mode can expose every stage's output.
        sequence = [
            [
                nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
                nn.LeakyReLU(0.2, True),
            ]
        ]

        nf = ndf
        for _ in range(1, n_layers):
            nf_prev, nf = nf, min(nf * 2, 512)
            sequence.append(
                [
                    nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
                    norm_layer(nf),
                    nn.LeakyReLU(0.2, True),
                ]
            )

        nf_prev, nf = nf, min(nf * 2, 512)
        sequence.append(
            [
                nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
                norm_layer(nf),
                nn.LeakyReLU(0.2, True),
            ]
        )
        sequence.append([nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)])
        if use_sigmoid:
            sequence.append([nn.Sigmoid()])

        # forward() must visit exactly the stages that were built. The old
        # `n_layers + 2` skipped the Sigmoid stage, and for n_layers=0 it
        # skipped the final conv.
        self.n_submodels = len(sequence)

        if getIntermFeat:
            for n, stage in enumerate(sequence):
                setattr(self, f"model{n}", nn.Sequential(*stage))
        else:
            self.model = nn.Sequential(
                *[layer for stage in sequence for layer in stage]
            )

    def forward(self, input):
        """Return every stage's output (getIntermFeat) or just the final output."""
        if self.getIntermFeat:
            res = [input]
            for n in range(self.n_submodels):
                res.append(getattr(self, f"model{n}")(res[-1]))
            return res[1:]
        return self.model(input)