"""File for all blocks which are parts of decoders """ import torch import torch.nn as nn import torch.nn.functional as F import climategan.strings as strings from climategan.norms import SPADE, AdaptiveInstanceNorm2d, LayerNorm, SpectralNorm class InterpolateNearest2d(nn.Module): """ Custom implementation of nn.Upsample because pytorch/xla does not yet support scale_factor and needs to be provided with the output_size """ def __init__(self, scale_factor=2): """ Create an InterpolateNearest2d module Args: scale_factor (int, optional): Output size multiplier. Defaults to 2. """ super().__init__() self.scale_factor = scale_factor def forward(self, x): """ Interpolate x in "nearest" mode on its last 2 dimensions Args: x (torch.Tensor): input to interpolate Returns: torch.Tensor: upsampled tensor with shape (...x.shape, x.shape[-2] * scale_factor, x.shape[-1] * scale_factor) """ return F.interpolate( x, size=(x.shape[-2] * self.scale_factor, x.shape[-1] * self.scale_factor), mode="nearest", ) # ----------------------------------------- # ----- Generic Convolutional Block ----- # ----------------------------------------- class Conv2dBlock(nn.Module): def __init__( self, input_dim, output_dim, kernel_size, stride=1, padding=0, dilation=1, norm="none", activation="relu", pad_type="zero", bias=True, ): super().__init__() self.use_bias = bias # initialize padding if pad_type == "reflect": self.pad = nn.ReflectionPad2d(padding) elif pad_type == "replicate": self.pad = nn.ReplicationPad2d(padding) elif pad_type == "zero": self.pad = nn.ZeroPad2d(padding) else: assert 0, "Unsupported padding type: {}".format(pad_type) # initialize normalization use_spectral_norm = False if norm.startswith("spectral_"): norm = norm.replace("spectral_", "") use_spectral_norm = True norm_dim = output_dim if norm == "batch": self.norm = nn.BatchNorm2d(norm_dim) elif norm == "instance": # self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True) self.norm = nn.InstanceNorm2d(norm_dim) elif norm == "layer": self.norm = LayerNorm(norm_dim) elif norm == "adain": self.norm = AdaptiveInstanceNorm2d(norm_dim) elif norm == "spectral" or norm.startswith("spectral_"): self.norm = None # dealt with later in the code elif norm == "none": self.norm = None else: raise ValueError("Unsupported normalization: {}".format(norm)) # initialize activation if activation == "relu": self.activation = nn.ReLU(inplace=False) elif activation == "lrelu": self.activation = nn.LeakyReLU(0.2, inplace=False) elif activation == "prelu": self.activation = nn.PReLU() elif activation == "selu": self.activation = nn.SELU(inplace=False) elif activation == "tanh": self.activation = nn.Tanh() elif activation == "sigmoid": self.activation = nn.Sigmoid() elif activation == "none": self.activation = None else: raise ValueError("Unsupported activation: {}".format(activation)) # initialize convolution if norm == "spectral" or use_spectral_norm: self.conv = SpectralNorm( nn.Conv2d( input_dim, output_dim, kernel_size, stride, dilation=dilation, bias=self.use_bias, ) ) else: self.conv = nn.Conv2d( input_dim, output_dim, kernel_size, stride, dilation=dilation, bias=self.use_bias if norm != "batch" else False, ) def forward(self, x): x = self.conv(self.pad(x)) if self.norm is not None: x = self.norm(x) if self.activation is not None: x = self.activation(x) return x def __str__(self): return strings.conv2dblock(self) # ----------------------------- # ----- Residual Blocks ----- # ----------------------------- class ResBlocks(nn.Module): """ From 
# -----------------------------
# ----- Residual Blocks -----
# -----------------------------
class ResBlocks(nn.Module):
    """
    From https://github.com/NVlabs/MUNIT/blob/master/networks.py
    """

    def __init__(self, num_blocks, dim, norm="in", activation="relu", pad_type="zero"):
        super().__init__()
        self.model = nn.Sequential(
            *[
                ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type)
                for _ in range(num_blocks)
            ]
        )

    def forward(self, x):
        return self.model(x)

    def __str__(self):
        return strings.resblocks(self)


class ResBlock(nn.Module):
    """Residual block: two 3x3 Conv2dBlocks with an additive skip connection."""

    def __init__(self, dim, norm="in", activation="relu", pad_type="zero"):
        super().__init__()
        self.dim = dim
        self.norm = norm
        self.activation = activation
        model = []
        model += [
            Conv2dBlock(
                dim, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type
            )
        ]
        model += [
            Conv2dBlock(
                dim, dim, 3, 1, 1, norm=norm, activation="none", pad_type=pad_type
            )
        ]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        residual = x
        out = self.model(x)
        out += residual
        return out

    def __str__(self):
        return strings.resblock(self)


# --------------------------
# ----- Base Decoder -----
# --------------------------
class BaseDecoder(nn.Module):
    """
    Generic decoder: 1x1 projection conv, residual blocks, then n_upsample
    (nearest-upsampling + conv) stages and a final conv to output_dim.
    Can optionally merge a low-level feature map into the projected features
    (low_level_feats_dim > 0) and modulate z with a depth tensor (use_dada).
    """

    def __init__(
        self,
        n_upsample=4,
        n_res=4,
        input_dim=2048,
        proj_dim=64,
        output_dim=3,
        norm="batch",
        activ="relu",
        pad_type="zero",
        output_activ="tanh",
        low_level_feats_dim=-1,
        use_dada=False,
    ):
        super().__init__()

        self.low_level_feats_dim = low_level_feats_dim
        self.use_dada = use_dada

        self.model = []
        if proj_dim != -1:
            self.proj_conv = Conv2dBlock(
                input_dim, proj_dim, 1, 1, 0, norm=norm, activation=activ
            )
        else:
            self.proj_conv = None
            proj_dim = input_dim

        if low_level_feats_dim > 0:
            self.low_level_conv = Conv2dBlock(
                input_dim=low_level_feats_dim,
                output_dim=proj_dim,
                kernel_size=3,
                stride=1,
                padding=1,
                pad_type=pad_type,
                norm=norm,
                activation=activ,
            )
            self.merge_feats_conv = Conv2dBlock(
                input_dim=2 * proj_dim,
                output_dim=proj_dim,
                kernel_size=1,
                stride=1,
                padding=0,
                pad_type=pad_type,
                norm=norm,
                activation=activ,
            )
        else:
            self.low_level_conv = None

        self.model += [ResBlocks(n_res, proj_dim, norm, activ, pad_type=pad_type)]

        dim = proj_dim
        # upsampling blocks
        for i in range(n_upsample):
            self.model += [
                InterpolateNearest2d(scale_factor=2),
                Conv2dBlock(
                    input_dim=dim,
                    output_dim=dim // 2,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    pad_type=pad_type,
                    norm=norm,
                    activation=activ,
                ),
            ]
            dim //= 2

        # final conv layer mapping to output_dim, without normalization
        self.model += [
            Conv2dBlock(
                input_dim=dim,
                output_dim=output_dim,
                kernel_size=3,
                stride=1,
                padding=1,
                pad_type=pad_type,
                norm="none",
                activation=output_activ,
            )
        ]
        self.model = nn.Sequential(*self.model)

    def forward(self, z, cond=None, z_depth=None):
        low_level_feat = None
        if isinstance(z, (list, tuple)):
            if self.low_level_conv is None:
                z = z[0]
            else:
                z, low_level_feat = z
                low_level_feat = self.low_level_conv(low_level_feat)
                low_level_feat = F.interpolate(
                    low_level_feat, size=z.shape[-2:], mode="bilinear"
                )

        if z_depth is not None and self.use_dada:
            z = z * z_depth

        if self.proj_conv is not None:
            z = self.proj_conv(z)

        if low_level_feat is not None:
            z = self.merge_feats_conv(torch.cat([low_level_feat, z], dim=1))

        return self.model(z)

    def __str__(self):
        return strings.basedecoder(self)
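# Usage sketch (illustrative only, not part of the original file): with the
# default settings BaseDecoder maps a (B, 2048, H, W) feature map to a
# (B, 3, 16H, 16W) image (4 upsampling stages). When low_level_feats_dim > 0,
# the forward pass expects a (z, low_level_feat) tuple, e.g.
#
#   decoder = BaseDecoder(input_dim=2048, proj_dim=64, low_level_feats_dim=256)
#   z = torch.randn(2, 2048, 8, 8)
#   skip = torch.randn(2, 256, 32, 32)
#   out = decoder((z, skip))  # skip is projected, resized to 8x8, then merged
#   # out.shape == (2, 3, 128, 128)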
# --------------------------
# ----- SPADE Blocks -----
# --------------------------
# https://github.com/NVlabs/SPADE/blob/0ff661e70131c9b85091d11a66e019c0f2062d4c
# /models/networks/generator.py
# 0ff661e on 13 Apr 2019
class SPADEResnetBlock(nn.Module):
    def __init__(
        self,
        fin,
        fout,
        cond_nc,
        spade_use_spectral_norm,
        spade_param_free_norm,
        spade_kernel_size,
        last_activation=None,
    ):
        super().__init__()
        # Attributes
        self.fin = fin
        self.fout = fout
        self.use_spectral_norm = spade_use_spectral_norm
        self.param_free_norm = spade_param_free_norm
        self.kernel_size = spade_kernel_size
        self.learned_shortcut = fin != fout
        self.last_activation = last_activation

        fmiddle = min(fin, fout)

        # create conv layers
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        # apply spectral norm if specified
        if spade_use_spectral_norm:
            self.conv_0 = SpectralNorm(self.conv_0)
            self.conv_1 = SpectralNorm(self.conv_1)
            if self.learned_shortcut:
                self.conv_s = SpectralNorm(self.conv_s)

        self.norm_0 = SPADE(spade_param_free_norm, spade_kernel_size, fin, cond_nc)
        self.norm_1 = SPADE(spade_param_free_norm, spade_kernel_size, fmiddle, cond_nc)
        if self.learned_shortcut:
            self.norm_s = SPADE(spade_param_free_norm, spade_kernel_size, fin, cond_nc)

    # note the resnet block with SPADE also takes in |seg|,
    # the semantic segmentation map as input
    def forward(self, x, seg):
        x_s = self.shortcut(x, seg)

        dx = self.conv_0(self.activation(self.norm_0(x, seg)))
        dx = self.conv_1(self.activation(self.norm_1(dx, seg)))

        out = x_s + dx
        if self.last_activation == "lrelu":
            return self.activation(out)
        elif self.last_activation is None:
            return out
        else:
            raise NotImplementedError(
                "The type of activation is not supported: {}".format(
                    self.last_activation
                )
            )

    def shortcut(self, x, seg):
        if self.learned_shortcut:
            x_s = self.conv_s(self.norm_s(x, seg))
        else:
            x_s = x
        return x_s

    def activation(self, x):
        return F.leaky_relu(x, 2e-1)

    def __str__(self):
        return strings.spaderesblock(self)
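if __name__ == "__main__":
    # Illustrative smoke test, not part of the original module: decode a dummy
    # ResNet-style feature map with the default BaseDecoder configuration.
    # With n_upsample=4, an 8x8 input is decoded to a 128x128, 3-channel output.
    decoder = BaseDecoder()
    z = torch.randn(2, 2048, 8, 8)
    out = decoder(z)
    print(out.shape)  # expected: torch.Size([2, 3, 128, 128])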