"""Discriminator architecture for ClimateGAN's GAN components (a and t) """ import functools import torch import torch.nn as nn from climategan.blocks import SpectralNorm from climategan.tutils import init_weights # from torch.optim import lr_scheduler # mainly from https://github.com/sangwoomo/instagan/blob/master/models/networks.py def create_discriminator(opts, device, no_init=False, verbose=0): disc = OmniDiscriminator(opts) if no_init: return disc for task, model in disc.items(): if isinstance(model, nn.ModuleDict): for domain, domain_model in model.items(): init_weights( domain_model, init_type=opts.dis[task].init_type, init_gain=opts.dis[task].init_gain, verbose=verbose, caller=f"create_discriminator {task} {domain}", ) else: init_weights( model, init_type=opts.dis[task].init_type, init_gain=opts.dis[task].init_gain, verbose=verbose, caller=f"create_discriminator {task}", ) return disc.to(device) def define_D( input_nc, ndf, n_layers=3, norm="batch", use_sigmoid=False, get_intermediate_features=False, num_D=1, ): norm_layer = get_norm_layer(norm_type=norm) net = MultiscaleDiscriminator( input_nc, ndf, n_layers=n_layers, norm_layer=norm_layer, use_sigmoid=use_sigmoid, get_intermediate_features=get_intermediate_features, num_D=num_D, ) return net def get_norm_layer(norm_type="instance"): if not norm_type: print("norm_type is {}, defaulting to instance") norm_type = "instance" if norm_type == "batch": norm_layer = functools.partial(nn.BatchNorm2d, affine=True) elif norm_type == "instance": norm_layer = functools.partial( nn.InstanceNorm2d, affine=False, track_running_stats=False ) elif norm_type == "none": norm_layer = None else: raise NotImplementedError("normalization layer [%s] is not found" % norm_type) return norm_layer # Defines the PatchGAN discriminator with the specified arguments. 
# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    def __init__(
        self,
        input_nc=3,
        ndf=64,
        n_layers=3,
        norm_layer=nn.BatchNorm2d,
        use_sigmoid=False,
        get_intermediate_features=True,
    ):
        super(NLayerDiscriminator, self).__init__()
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        self.get_intermediate_features = get_intermediate_features

        kw = 4
        padw = 1
        sequence = [
            [
                # Use spectral normalization
                SpectralNorm(
                    nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw)
                ),
                nn.LeakyReLU(0.2, True),
            ]
        ]

        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                [
                    # Use spectral normalization
                    SpectralNorm(
                        # TODO replace with Conv2dBlock
                        nn.Conv2d(
                            ndf * nf_mult_prev,
                            ndf * nf_mult,
                            kernel_size=kw,
                            stride=2,
                            padding=padw,
                            bias=use_bias,
                        )
                    ),
                    norm_layer(ndf * nf_mult),
                    nn.LeakyReLU(0.2, True),
                ]
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            [
                # Use spectral normalization
                SpectralNorm(
                    nn.Conv2d(
                        ndf * nf_mult_prev,
                        ndf * nf_mult,
                        kernel_size=kw,
                        stride=1,
                        padding=padw,
                        bias=use_bias,
                    )
                ),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True),
            ]
        ]

        # Use spectral normalization
        sequence += [
            [
                SpectralNorm(
                    nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
                )
            ]
        ]

        if use_sigmoid:
            sequence += [[nn.Sigmoid()]]

        # We divide the layers into groups to extract intermediate layer outputs
        for n in range(len(sequence)):
            self.add_module("model" + str(n), nn.Sequential(*sequence[n]))

        # self.model = nn.Sequential(*sequence)

    def forward(self, input):
        results = [input]
        for submodel in self.children():
            intermediate_output = submodel(results[-1])
            results.append(intermediate_output)

        get_intermediate_features = self.get_intermediate_features
        if get_intermediate_features:
            return results[1:]
        else:
            return results[-1]

    # def forward(self, input):
    #     return self.model(input)


# Source: https://github.com/NVIDIA/pix2pixHD
class MultiscaleDiscriminator(nn.Module):
    def __init__(
        self,
        input_nc=3,
        ndf=64,
        n_layers=3,
        norm_layer=nn.BatchNorm2d,
        use_sigmoid=False,
        get_intermediate_features=True,
        num_D=3,
    ):
        super(MultiscaleDiscriminator, self).__init__()
        # self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
        # use_sigmoid=False, num_D=3, getIntermFeat=False

        self.n_layers = n_layers
        self.ndf = ndf
        self.norm_layer = norm_layer
        self.use_sigmoid = use_sigmoid
        self.get_intermediate_features = get_intermediate_features
        self.num_D = num_D

        for i in range(self.num_D):
            netD = NLayerDiscriminator(
                input_nc=input_nc,
                ndf=self.ndf,
                n_layers=self.n_layers,
                norm_layer=self.norm_layer,
                use_sigmoid=self.use_sigmoid,
                get_intermediate_features=self.get_intermediate_features,
            )
            self.add_module("discriminator_%d" % i, netD)

        self.downsample = nn.AvgPool2d(
            3, stride=2, padding=[1, 1], count_include_pad=False
        )

    def forward(self, input):
        result = []
        get_intermediate_features = self.get_intermediate_features
        for name, D in self.named_children():
            if "discriminator" not in name:
                continue
            out = D(input)
            if not get_intermediate_features:
                out = [out]
            result.append(out)
            input = self.downsample(input)

        return result
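# Shape sketch for the default PatchGAN (n_layers=3, kw=4, padw=1), assuming a
# 256x256 input; the numbers follow from the conv arithmetic above and are
# illustrative, not a documented contract:
#
#   d = NLayerDiscriminator(input_nc=3, get_intermediate_features=False)
#   y = d(torch.randn(1, 3, 256, 256))
#   # Three stride-2 convs: 256 -> 128 -> 64 -> 32, then two stride-1 4x4
#   # convs: 32 -> 31 -> 30, so y has shape (1, 1, 30, 30). Each output logit
#   # scores one overlapping ~70x70 patch of the input, which is why
#   # MultiscaleDiscriminator runs copies of this net on a downsampled
#   # pyramid to also cover coarser structure.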
class OmniDiscriminator(nn.ModuleDict):
    def __init__(self, opts):
        super().__init__()
        if "p" in opts.tasks:
            if opts.dis.p.use_local_discriminator:
                self["p"] = nn.ModuleDict(
                    {
                        "global": define_D(
                            input_nc=3,
                            ndf=opts.dis.p.ndf,
                            n_layers=opts.dis.p.n_layers,
                            norm=opts.dis.p.norm,
                            use_sigmoid=opts.dis.p.use_sigmoid,
                            get_intermediate_features=opts.dis.p.get_intermediate_features,  # noqa: E501
                            num_D=opts.dis.p.num_D,
                        ),
                        "local": define_D(
                            input_nc=3,
                            ndf=opts.dis.p.ndf,
                            n_layers=opts.dis.p.n_layers,
                            norm=opts.dis.p.norm,
                            use_sigmoid=opts.dis.p.use_sigmoid,
                            get_intermediate_features=opts.dis.p.get_intermediate_features,  # noqa: E501
                            num_D=opts.dis.p.num_D,
                        ),
                    }
                )
            else:
                self["p"] = define_D(
                    input_nc=4,  # image + mask
                    ndf=opts.dis.p.ndf,
                    n_layers=opts.dis.p.n_layers,
                    norm=opts.dis.p.norm,
                    use_sigmoid=opts.dis.p.use_sigmoid,
                    get_intermediate_features=opts.dis.p.get_intermediate_features,
                    num_D=opts.dis.p.num_D,
                )
        if "m" in opts.tasks:
            if opts.gen.m.use_advent:
                if opts.dis.m.architecture == "base":
                    if opts.dis.m.gan_type == "WGAN_norm":
                        self["m"] = nn.ModuleDict(
                            {
                                "Advent": get_fc_discriminator(
                                    num_classes=2, use_norm=True
                                )
                            }
                        )
                    else:
                        self["m"] = nn.ModuleDict(
                            {
                                "Advent": get_fc_discriminator(
                                    num_classes=2, use_norm=False
                                )
                            }
                        )
                elif opts.dis.m.architecture == "OmniDiscriminator":
                    self["m"] = nn.ModuleDict(
                        {
                            "Advent": define_D(
                                input_nc=2,
                                ndf=opts.dis.m.ndf,
                                n_layers=opts.dis.m.n_layers,
                                norm=opts.dis.m.norm,
                                use_sigmoid=opts.dis.m.use_sigmoid,
                                get_intermediate_features=opts.dis.m.get_intermediate_features,  # noqa: E501
                                num_D=opts.dis.m.num_D,
                            )
                        }
                    )
                else:
                    raise Exception("This Discriminator is currently not supported!")
        if "s" in opts.tasks:
            if opts.gen.s.use_advent:
                if opts.dis.s.gan_type == "WGAN_norm":
                    self["s"] = nn.ModuleDict(
                        {"Advent": get_fc_discriminator(num_classes=11, use_norm=True)}
                    )
                else:
                    self["s"] = nn.ModuleDict(
                        {"Advent": get_fc_discriminator(num_classes=11, use_norm=False)}
                    )


def get_fc_discriminator(num_classes=2, ndf=64, use_norm=False):
    if use_norm:
        return torch.nn.Sequential(
            SpectralNorm(
                torch.nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1)
            ),
            torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
            SpectralNorm(
                torch.nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1)
            ),
            torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
            SpectralNorm(
                torch.nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1)
            ),
            torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
            SpectralNorm(
                torch.nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1)
            ),
            torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
            SpectralNorm(
                torch.nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=2, padding=1)
            ),
        )
    else:
        return torch.nn.Sequential(
            torch.nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1),
            torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
            torch.nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1),
            torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
            torch.nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1),
            torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
            torch.nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1),
            torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
            torch.nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=2, padding=1),
        )
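# Usage sketch for the fully-convolutional AdvEnt-style discriminator above
# (assumption: the softmax input is illustrative; exactly what the "m" and "s"
# heads feed this module is defined by the training code, not by this file):
#
#   d_adv = get_fc_discriminator(num_classes=2, use_norm=True)
#   probs = torch.softmax(torch.randn(1, 2, 256, 256), dim=1)
#   logits = d_adv(probs)
#   # Five stride-2 4x4 convs halve the resolution each time:
#   # 256 -> 128 -> 64 -> 32 -> 16 -> 8, so logits has shape (1, 1, 8, 8).
#   # use_norm=True wraps every conv in SpectralNorm, which is what the
#   # "WGAN_norm" gan_type selects in OmniDiscriminator.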