# python3.7
"""PyTorch generator of ProgressiveGAN, intended for inference only.

Unlike the official tensorflow model in folder `pggan_tf_official`, this is a
lightweight pytorch re-implementation that contains only the generator part.

Paper: https://arxiv.org/pdf/1710.10196.pdf
"""

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['PGGANGeneratorModel']

# Maps the target resolution of the final generated image to the number of
# filters used by each convolutional layer, in order. A resolution of `2**n`
# gets `n` entries: values are capped at 512, and from the sixth entry on each
# value is half of the previous one. For example:
#   8:    [512, 512, 512]
#   1024: [512, 512, 512, 512, 512, 256, 128, 64, 32, 16]
_RESOLUTIONS_TO_CHANNELS = {
    2 ** log_res: [min(512, 2 ** (13 - idx)) for idx in range(log_res)]
    for log_res in range(3, 11)
}

# Variable mapping from pytorch model to official tensorflow model.
_PGGAN_PTH_VARS_TO_TF_VARS = {
    # Keys are PyTorch parameter names of `PGGANGeneratorModel`; values are the
    # corresponding tensorflow variable names. Trailing comments give the shape
    # of each PyTorch tensor.
    'lod': 'lod',  # []
    'layer0.conv.weight': '4x4/Dense/weight',  # [512, 512, 4, 4]
    'layer0.wscale.bias': '4x4/Dense/bias',  # [512]
    'layer1.conv.weight': '4x4/Conv/weight',  # [512, 512, 3, 3]
    'layer1.wscale.bias': '4x4/Conv/bias',  # [512]
    'layer2.conv.weight': '8x8/Conv0/weight',  # [512, 512, 3, 3]
    'layer2.wscale.bias': '8x8/Conv0/bias',  # [512]
    'layer3.conv.weight': '8x8/Conv1/weight',  # [512, 512, 3, 3]
    'layer3.wscale.bias': '8x8/Conv1/bias',  # [512]
    'layer4.conv.weight': '16x16/Conv0/weight',  # [512, 512, 3, 3]
    'layer4.wscale.bias': '16x16/Conv0/bias',  # [512]
    'layer5.conv.weight': '16x16/Conv1/weight',  # [512, 512, 3, 3]
    'layer5.wscale.bias': '16x16/Conv1/bias',  # [512]
    'layer6.conv.weight': '32x32/Conv0/weight',  # [512, 512, 3, 3]
    'layer6.wscale.bias': '32x32/Conv0/bias',  # [512]
    'layer7.conv.weight': '32x32/Conv1/weight',  # [512, 512, 3, 3]
    'layer7.wscale.bias': '32x32/Conv1/bias',  # [512]
    'layer8.conv.weight': '64x64/Conv0/weight',  # [256, 512, 3, 3]
    'layer8.wscale.bias': '64x64/Conv0/bias',  # [256]
    'layer9.conv.weight': '64x64/Conv1/weight',  # [256, 256, 3, 3]
    'layer9.wscale.bias': '64x64/Conv1/bias',  # [256]
    'layer10.conv.weight': '128x128/Conv0/weight',  # [128, 256, 3, 3]
    'layer10.wscale.bias': '128x128/Conv0/bias',  # [128]
    'layer11.conv.weight': '128x128/Conv1/weight',  # [128, 128, 3, 3]
    'layer11.wscale.bias': '128x128/Conv1/bias',  # [128]
    'layer12.conv.weight': '256x256/Conv0/weight',  # [64, 128, 3, 3]
    'layer12.wscale.bias': '256x256/Conv0/bias',  # [64]
    'layer13.conv.weight': '256x256/Conv1/weight',  # [64, 64, 3, 3]
    'layer13.wscale.bias': '256x256/Conv1/bias',  # [64]
    'layer14.conv.weight': '512x512/Conv0/weight',  # [32, 64, 3, 3]
    'layer14.wscale.bias': '512x512/Conv0/bias',  # [32]
    'layer15.conv.weight': '512x512/Conv1/weight',  # [32, 32, 3, 3]
    'layer15.wscale.bias': '512x512/Conv1/bias',  # [32]
    'layer16.conv.weight': '1024x1024/Conv0/weight',  # [16, 32, 3, 3]
    'layer16.wscale.bias': '1024x1024/Conv0/bias',  # [16]
    'layer17.conv.weight': '1024x1024/Conv1/weight',  # [16, 16, 3, 3]
    'layer17.wscale.bias': '1024x1024/Conv1/bias',  # [16]
    'output0.conv.weight': 'ToRGB_lod8/weight',  # [3, 512, 1, 1]
    'output0.wscale.bias': 'ToRGB_lod8/bias',  # [3]
    'output1.conv.weight': 'ToRGB_lod7/weight',  # [3, 512, 1, 1]
    'output1.wscale.bias': 'ToRGB_lod7/bias',  # [3]
    'output2.conv.weight': 'ToRGB_lod6/weight',  # [3, 512, 1, 1]
    'output2.wscale.bias': 'ToRGB_lod6/bias',  # [3]
    'output3.conv.weight': 'ToRGB_lod5/weight',  # [3, 512, 1, 1]
    'output3.wscale.bias': 'ToRGB_lod5/bias',  # [3]
    'output4.conv.weight': 'ToRGB_lod4/weight',  # [3, 256, 1, 1]
    'output4.wscale.bias': 'ToRGB_lod4/bias',  # [3]
    'output5.conv.weight': 'ToRGB_lod3/weight',  # [3, 128, 1, 1]
    'output5.wscale.bias': 'ToRGB_lod3/bias',  # [3]
    'output6.conv.weight': 'ToRGB_lod2/weight',  # [3, 64, 1, 1]
    'output6.wscale.bias': 'ToRGB_lod2/bias',  # [3]
    'output7.conv.weight': 'ToRGB_lod1/weight',  # [3, 32, 1, 1]
    'output7.wscale.bias': 'ToRGB_lod1/bias',  # [3]
    'output8.conv.weight': 'ToRGB_lod0/weight',  # [3, 16, 1, 1]
    'output8.wscale.bias': 'ToRGB_lod0/bias',  # [3]
}


class PGGANGeneratorModel(nn.Module):
    """Defines the generator module in ProgressiveGAN.

    The generated images have RGB color channels whose values are meant to lie
    in [-1, 1]. Note that the final output layer is linear, so values are not
    clipped here.

    Sub-modules are registered per block `b` (1-based) as `layer{2b-2}` (the
    first conv, which upsamples for b >= 2), `layer{2b-1}` (the second conv)
    and `output{b-1}` (the 1x1 to-RGB conv).
    """

    def __init__(self, resolution=1024, fused_scale=False, output_channels=3):
        """Initializes the generator with basic settings.

        Args:
            resolution: The resolution of the final output image. Must be a
                key of `_RESOLUTIONS_TO_CHANNELS`. (default: 1024)
            fused_scale: Whether to fuse `upsample` and `conv2d` together,
                resulting in `conv2d_transpose`. (default: False)
            output_channels: Number of channels of the output image.
                (default: 3)

        Raises:
            ValueError: If the input `resolution` is not supported.
        """
        super().__init__()

        try:
            self.channels = _RESOLUTIONS_TO_CHANNELS[resolution]
        except KeyError:
            raise ValueError(f'Invalid resolution: {resolution}!\n'
                             f'Resolutions allowed: '
                             f'{list(_RESOLUTIONS_TO_CHANNELS)}.') from None
        assert len(self.channels) == int(np.log2(resolution))

        self.resolution = resolution
        self.fused_scale = fused_scale
        self.output_channels = output_channels

        for block_idx in range(1, len(self.channels)):
            in_channels = self.channels[block_idx - 1]
            out_channels = self.channels[block_idx]
            if block_idx == 1:
                # The 1x1 latent is turned into a 4x4 feature map by a 4x4
                # convolution with padding 3 (the tensorflow `4x4/Dense`).
                first_conv = ConvBlock(in_channels=in_channels,
                                       out_channels=out_channels,
                                       kernel_size=4,
                                       padding=3)
            else:
                first_conv = ConvBlock(in_channels=in_channels,
                                       out_channels=out_channels,
                                       upsample=True,
                                       fused_scale=self.fused_scale)
            self.add_module(f'layer{2 * block_idx - 2}', first_conv)
            self.add_module(
                f'layer{2 * block_idx - 1}',
                ConvBlock(in_channels=out_channels,
                          out_channels=out_channels))
            self.add_module(
                f'output{block_idx - 1}',
                ConvBlock(in_channels=out_channels,
                          out_channels=self.output_channels,
                          kernel_size=1,
                          padding=0,
                          wscale_gain=1.0,
                          activation_type='linear'))

        self.upsample = ResolutionScalingLayer()
        self.lod = nn.Parameter(torch.zeros(()))

        self.pth_to_tf_var_mapping = {}
        for pth_var_name, tf_var_name in _PGGAN_PTH_VARS_TO_TF_VARS.items():
            if self.fused_scale and 'Conv0' in tf_var_name:
                # Fused blocks keep their raw kernel in `ConvBlock.weight`
                # (not `ConvBlock.conv.weight`), and the tensorflow variable
                # lives under `Conv0_up`.
                pth_var_name = pth_var_name.replace('conv.weight', 'weight')
                tf_var_name = tf_var_name.replace('Conv0', 'Conv0_up')
            self.pth_to_tf_var_mapping[pth_var_name] = tf_var_name

    def forward(self, x):
        """Synthesizes images from latent codes.

        Blocks with `block_idx + lod >= len(self.channels)` are skipped; each
        skipped block instead upsamples the most recent RGB output by 2, so
        the returned image always has spatial size `self.resolution`.

        Args:
            x: Latent codes with shape [batch_size, noise_dim], where
                `noise_dim` must equal `self.channels[0]`.

        Returns:
            Image tensor with shape
            [batch_size, output_channels, resolution, resolution].

        Raises:
            ValueError: If `x` is not 2-dimensional, or if `self.lod` is so
                large that even the first block would be skipped.
        """
        if len(x.shape) != 2:
            raise ValueError(f'The input tensor should be with shape [batch_size, '
                             f'noise_dim], but {x.shape} received!')
        num_blocks = len(self.channels)
        lod = self.lod.cpu().tolist()
        # The first block must run, otherwise there is no image to upsample.
        if not 1 + lod < num_blocks:
            raise ValueError(f'Invalid lod: {lod}! It should be smaller than '
                             f'{num_blocks - 1}.')

        x = x.view(x.shape[0], x.shape[1], 1, 1)
        for block_idx in range(1, num_blocks):
            if block_idx + lod < num_blocks:
                x = getattr(self, f'layer{2 * block_idx - 2}')(x)
                x = getattr(self, f'layer{2 * block_idx - 1}')(x)
                image = getattr(self, f'output{block_idx - 1}')(x)
            else:
                image = self.upsample(image)
        return image


class PixelNormLayer(nn.Module):
    """Implements pixel-wise feature vector normalization layer.

    Each spatial position is divided by the root-mean-square of its values
    along the channel dimension (`dim=1`).
    """

    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x):
        return x / torch.sqrt(torch.mean(x**2, dim=1, keepdim=True) + self.epsilon)


class ResolutionScalingLayer(nn.Module):
    """Implements the resolution scaling layer.

    Basically, this layer can be used to upsample or downsample feature maps
    from spatial domain with nearest neighbor interpolation.
    """

    def __init__(self, scale_factor=2):
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        return F.interpolate(x, scale_factor=self.scale_factor, mode='nearest')


class WScaleLayer(nn.Module):
    """Implements the layer to scale weight variable and add bias.

    Note that the weight variable is trained in the `nn.Conv2d` layer and is
    only scaled here by a constant, non-trainable factor
    `gain / sqrt(in_channels * kernel_size * kernel_size)`. The bias variable,
    however, is trainable in this layer.
    """

    def __init__(self, in_channels, out_channels, kernel_size, gain=np.sqrt(2.0)):
        super().__init__()
        fan_in = in_channels * kernel_size * kernel_size
        self.scale = gain / np.sqrt(fan_in)
        self.bias = nn.Parameter(torch.zeros(out_channels))

    def forward(self, x):
        return x * self.scale + self.bias.view(1, -1, 1, 1)


class ConvBlock(nn.Module):
    """Implements the convolutional block used in ProgressiveGAN.

    Basically, this block executes pixel-wise normalization layer, upsampling
    layer (if needed), convolutional layer, weight-scale layer, and activation
    layer in sequence.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 dilation=1,
                 add_bias=False,
                 upsample=False,
                 fused_scale=False,
                 wscale_gain=np.sqrt(2.0),
                 activation_type='lrelu'):
        """Initializes the class with block settings.

        Args:
            in_channels: Number of channels of the input tensor fed into this
                block.
            out_channels: Number of channels (kernels) of the output tensor.
            kernel_size: Size of the convolutional kernel.
            stride: Stride parameter for convolution operation.
            padding: Padding parameter for convolution operation.
            dilation: Dilation rate for convolution operation.
            add_bias: Whether to add bias onto the convolutional result.
            upsample: Whether to upsample the input tensor before convolution.
            fused_scale: Whether to fuse `upsample` and `conv2d` together,
                resulting in `conv2d_transpose`. Only takes effect when
                `upsample` is True; in that case `stride`, `padding`,
                `dilation` and `add_bias` are ignored.
            wscale_gain: The gain factor for `wscale` layer.
            activation_type: Type of activation function. Support `linear`,
                `lrelu` and `tanh`. Note that `tanh` is implemented with
                `nn.Hardtanh`, i.e. clamping to [-1, 1].

        Raises:
            NotImplementedError: If the input `activation_type` is not
                supported.
        """
        super().__init__()
        self.pixel_norm = PixelNormLayer()

        if upsample and not fused_scale:
            self.upsample = ResolutionScalingLayer()
        else:
            self.upsample = nn.Identity()

        if upsample and fused_scale:
            # Raw kernel in [k, k, in, out] layout; see `forward` for how it is
            # turned into a transposed-convolution weight.
            self.weight = nn.Parameter(
                torch.randn(kernel_size, kernel_size, in_channels, out_channels))
            fan_in = in_channels * kernel_size * kernel_size
            self.scale = wscale_gain / np.sqrt(fan_in)
        else:
            self.conv = nn.Conv2d(in_channels=in_channels,
                                  out_channels=out_channels,
                                  kernel_size=kernel_size,
                                  stride=stride,
                                  padding=padding,
                                  dilation=dilation,
                                  groups=1,
                                  bias=add_bias)

        self.wscale = WScaleLayer(in_channels=in_channels,
                                  out_channels=out_channels,
                                  kernel_size=kernel_size,
                                  gain=wscale_gain)

        if activation_type == 'linear':
            self.activate = nn.Identity()
        elif activation_type == 'lrelu':
            self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        elif activation_type == 'tanh':
            # NOTE: hard clamp to [-1, 1], not a smooth tanh.
            self.activate = nn.Hardtanh()
        else:
            raise NotImplementedError(f'Not implemented activation function: '
                                      f'{activation_type}!')

    def forward(self, x):
        x = self.pixel_norm(x)
        x = self.upsample(x)
        if hasattr(self, 'conv'):
            x = self.conv(x)
        else:
            # Fused upsample + conv: pad the kernel by one on both spatial
            # axes and sum its four shifted copies, giving a (k+1)x(k+1)
            # kernel that is applied as a stride-2 transposed convolution.
            kernel = self.weight * self.scale
            kernel = F.pad(kernel, (0, 0, 0, 0, 1, 1, 1, 1), 'constant', 0.0)
            kernel = (kernel[1:, 1:] + kernel[:-1, 1:] +
                      kernel[1:, :-1] + kernel[:-1, :-1])
            # [k+1, k+1, in, out] -> [in, out, k+1, k+1], the weight layout
            # expected by `F.conv_transpose2d`.
            kernel = kernel.permute(2, 3, 0, 1)
            x = F.conv_transpose2d(x, kernel, stride=2, padding=1)
            # The kernel already carries `self.scale`; divide it out because
            # `self.wscale` multiplies by the very same factor below.
            x = x / self.scale
        x = self.wscale(x)
        x = self.activate(x)
        return x