# python3.7
"""Contains the implementation of generator described in StyleGAN.

Different from the official tensorflow model in folder `stylegan_tf_official`,
this is a simple pytorch version which only contains the generator part. This
class is specially used for inference.

For more details, please check the original paper:
https://arxiv.org/pdf/1812.04948.pdf
"""

from collections import OrderedDict

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['StyleGANGeneratorModel']

# Defines a dictionary, which maps the target resolution of the final generated
# image to numbers of filters used in each convolutional layer in sequence.
# A resolution of `2**k` has exactly `k` entries (one per resolution from 4x4
# upwards, preceded by the entry for the learned 4x4 constant input).
_RESOLUTIONS_TO_CHANNELS = {
    8: [512, 512, 512],
    16: [512, 512, 512, 512],
    32: [512, 512, 512, 512, 512],
    64: [512, 512, 512, 512, 512, 256],
    128: [512, 512, 512, 512, 512, 256, 128],
    256: [512, 512, 512, 512, 512, 256, 128, 64],
    512: [512, 512, 512, 512, 512, 256, 128, 64, 32],
    1024: [512, 512, 512, 512, 512, 256, 128, 64, 32, 16],
}

# pylint: disable=line-too-long
# Variable mapping from pytorch model to official tensorflow model.
# Keys are variable names of the pytorch model defined in this file; values are
# the corresponding variable names in the official tensorflow model. The shape
# comments correspond to the 1024x1024 configuration. Entries whose tensorflow
# name contains `Conv0_up` are renamed from `conv.weight` to `weight` by the
# generator when fused scaling is used for that resolution.
_STYLEGAN_PTH_VARS_TO_TF_VARS = {
    # Statistic information of disentangled latent feature, w.
    'truncation.w_avg': 'dlatent_avg',  # [512]

    # Noises.
    'synthesis.layer0.epilogue.apply_noise.noise': 'noise0',  # [1, 1, 4, 4]
    'synthesis.layer1.epilogue.apply_noise.noise': 'noise1',  # [1, 1, 4, 4]
    'synthesis.layer2.epilogue.apply_noise.noise': 'noise2',  # [1, 1, 8, 8]
    'synthesis.layer3.epilogue.apply_noise.noise': 'noise3',  # [1, 1, 8, 8]
    'synthesis.layer4.epilogue.apply_noise.noise': 'noise4',  # [1, 1, 16, 16]
    'synthesis.layer5.epilogue.apply_noise.noise': 'noise5',  # [1, 1, 16, 16]
    'synthesis.layer6.epilogue.apply_noise.noise': 'noise6',  # [1, 1, 32, 32]
    'synthesis.layer7.epilogue.apply_noise.noise': 'noise7',  # [1, 1, 32, 32]
    'synthesis.layer8.epilogue.apply_noise.noise': 'noise8',  # [1, 1, 64, 64]
    'synthesis.layer9.epilogue.apply_noise.noise': 'noise9',  # [1, 1, 64, 64]
    'synthesis.layer10.epilogue.apply_noise.noise': 'noise10',  # [1, 1, 128, 128]
    'synthesis.layer11.epilogue.apply_noise.noise': 'noise11',  # [1, 1, 128, 128]
    'synthesis.layer12.epilogue.apply_noise.noise': 'noise12',  # [1, 1, 256, 256]
    'synthesis.layer13.epilogue.apply_noise.noise': 'noise13',  # [1, 1, 256, 256]
    'synthesis.layer14.epilogue.apply_noise.noise': 'noise14',  # [1, 1, 512, 512]
    'synthesis.layer15.epilogue.apply_noise.noise': 'noise15',  # [1, 1, 512, 512]
    'synthesis.layer16.epilogue.apply_noise.noise': 'noise16',  # [1, 1, 1024, 1024]
    'synthesis.layer17.epilogue.apply_noise.noise': 'noise17',  # [1, 1, 1024, 1024]

    # Mapping blocks.
    'mapping.dense0.linear.weight': 'Dense0/weight',  # [512, 512]
    'mapping.dense0.wscale.bias': 'Dense0/bias',  # [512]
    'mapping.dense1.linear.weight': 'Dense1/weight',  # [512, 512]
    'mapping.dense1.wscale.bias': 'Dense1/bias',  # [512]
    'mapping.dense2.linear.weight': 'Dense2/weight',  # [512, 512]
    'mapping.dense2.wscale.bias': 'Dense2/bias',  # [512]
    'mapping.dense3.linear.weight': 'Dense3/weight',  # [512, 512]
    'mapping.dense3.wscale.bias': 'Dense3/bias',  # [512]
    'mapping.dense4.linear.weight': 'Dense4/weight',  # [512, 512]
    'mapping.dense4.wscale.bias': 'Dense4/bias',  # [512]
    'mapping.dense5.linear.weight': 'Dense5/weight',  # [512, 512]
    'mapping.dense5.wscale.bias': 'Dense5/bias',  # [512]
    'mapping.dense6.linear.weight': 'Dense6/weight',  # [512, 512]
    'mapping.dense6.wscale.bias': 'Dense6/bias',  # [512]
    'mapping.dense7.linear.weight': 'Dense7/weight',  # [512, 512]
    'mapping.dense7.wscale.bias': 'Dense7/bias',  # [512]

    # Synthesis blocks.
    'synthesis.lod': 'lod',  # []
    'synthesis.layer0.first_layer': '4x4/Const/const',  # [1, 512, 4, 4]
    'synthesis.layer0.epilogue.apply_noise.weight': '4x4/Const/Noise/weight',  # [512]
    'synthesis.layer0.epilogue.bias': '4x4/Const/bias',  # [512]
    'synthesis.layer0.epilogue.style_mod.dense.linear.weight': '4x4/Const/StyleMod/weight',  # [1024, 512]
    'synthesis.layer0.epilogue.style_mod.dense.wscale.bias': '4x4/Const/StyleMod/bias',  # [1024]
    'synthesis.layer1.conv.weight': '4x4/Conv/weight',  # [512, 512, 3, 3]
    'synthesis.layer1.epilogue.apply_noise.weight': '4x4/Conv/Noise/weight',  # [512]
    'synthesis.layer1.epilogue.bias': '4x4/Conv/bias',  # [512]
    'synthesis.layer1.epilogue.style_mod.dense.linear.weight': '4x4/Conv/StyleMod/weight',  # [1024, 512]
    'synthesis.layer1.epilogue.style_mod.dense.wscale.bias': '4x4/Conv/StyleMod/bias',  # [1024]
    'synthesis.layer2.conv.weight': '8x8/Conv0_up/weight',  # [512, 512, 3, 3]
    'synthesis.layer2.epilogue.apply_noise.weight': '8x8/Conv0_up/Noise/weight',  # [512]
    'synthesis.layer2.epilogue.bias': '8x8/Conv0_up/bias',  # [512]
    'synthesis.layer2.epilogue.style_mod.dense.linear.weight': '8x8/Conv0_up/StyleMod/weight',  # [1024, 512]
    'synthesis.layer2.epilogue.style_mod.dense.wscale.bias': '8x8/Conv0_up/StyleMod/bias',  # [1024]
    'synthesis.layer3.conv.weight': '8x8/Conv1/weight',  # [512, 512, 3, 3]
    'synthesis.layer3.epilogue.apply_noise.weight': '8x8/Conv1/Noise/weight',  # [512]
    'synthesis.layer3.epilogue.bias': '8x8/Conv1/bias',  # [512]
    'synthesis.layer3.epilogue.style_mod.dense.linear.weight': '8x8/Conv1/StyleMod/weight',  # [1024, 512]
    'synthesis.layer3.epilogue.style_mod.dense.wscale.bias': '8x8/Conv1/StyleMod/bias',  # [1024]
    'synthesis.layer4.conv.weight': '16x16/Conv0_up/weight',  # [512, 512, 3, 3]
    'synthesis.layer4.epilogue.apply_noise.weight': '16x16/Conv0_up/Noise/weight',  # [512]
    'synthesis.layer4.epilogue.bias': '16x16/Conv0_up/bias',  # [512]
    'synthesis.layer4.epilogue.style_mod.dense.linear.weight': '16x16/Conv0_up/StyleMod/weight',  # [1024, 512]
    'synthesis.layer4.epilogue.style_mod.dense.wscale.bias': '16x16/Conv0_up/StyleMod/bias',  # [1024]
    'synthesis.layer5.conv.weight': '16x16/Conv1/weight',  # [512, 512, 3, 3]
    'synthesis.layer5.epilogue.apply_noise.weight': '16x16/Conv1/Noise/weight',  # [512]
    'synthesis.layer5.epilogue.bias': '16x16/Conv1/bias',  # [512]
    'synthesis.layer5.epilogue.style_mod.dense.linear.weight': '16x16/Conv1/StyleMod/weight',  # [1024, 512]
    'synthesis.layer5.epilogue.style_mod.dense.wscale.bias': '16x16/Conv1/StyleMod/bias',  # [1024]
    'synthesis.layer6.conv.weight': '32x32/Conv0_up/weight',  # [512, 512, 3, 3]
    'synthesis.layer6.epilogue.apply_noise.weight': '32x32/Conv0_up/Noise/weight',  # [512]
    'synthesis.layer6.epilogue.bias': '32x32/Conv0_up/bias',  # [512]
    'synthesis.layer6.epilogue.style_mod.dense.linear.weight': '32x32/Conv0_up/StyleMod/weight',  # [1024, 512]
    'synthesis.layer6.epilogue.style_mod.dense.wscale.bias': '32x32/Conv0_up/StyleMod/bias',  # [1024]
    'synthesis.layer7.conv.weight': '32x32/Conv1/weight',  # [512, 512, 3, 3]
    'synthesis.layer7.epilogue.apply_noise.weight': '32x32/Conv1/Noise/weight',  # [512]
    'synthesis.layer7.epilogue.bias': '32x32/Conv1/bias',  # [512]
    'synthesis.layer7.epilogue.style_mod.dense.linear.weight': '32x32/Conv1/StyleMod/weight',  # [1024, 512]
    'synthesis.layer7.epilogue.style_mod.dense.wscale.bias': '32x32/Conv1/StyleMod/bias',  # [1024]
    'synthesis.layer8.conv.weight': '64x64/Conv0_up/weight',  # [256, 512, 3, 3]
    'synthesis.layer8.epilogue.apply_noise.weight': '64x64/Conv0_up/Noise/weight',  # [256]
    'synthesis.layer8.epilogue.bias': '64x64/Conv0_up/bias',  # [256]
    'synthesis.layer8.epilogue.style_mod.dense.linear.weight': '64x64/Conv0_up/StyleMod/weight',  # [512, 512]
    'synthesis.layer8.epilogue.style_mod.dense.wscale.bias': '64x64/Conv0_up/StyleMod/bias',  # [512]
    'synthesis.layer9.conv.weight': '64x64/Conv1/weight',  # [256, 256, 3, 3]
    'synthesis.layer9.epilogue.apply_noise.weight': '64x64/Conv1/Noise/weight',  # [256]
    'synthesis.layer9.epilogue.bias': '64x64/Conv1/bias',  # [256]
    'synthesis.layer9.epilogue.style_mod.dense.linear.weight': '64x64/Conv1/StyleMod/weight',  # [512, 512]
    'synthesis.layer9.epilogue.style_mod.dense.wscale.bias': '64x64/Conv1/StyleMod/bias',  # [512]
    'synthesis.layer10.conv.weight': '128x128/Conv0_up/weight',  # [128, 256, 3, 3]
    'synthesis.layer10.epilogue.apply_noise.weight': '128x128/Conv0_up/Noise/weight',  # [128]
    'synthesis.layer10.epilogue.bias': '128x128/Conv0_up/bias',  # [128]
    'synthesis.layer10.epilogue.style_mod.dense.linear.weight': '128x128/Conv0_up/StyleMod/weight',  # [256, 512]
    'synthesis.layer10.epilogue.style_mod.dense.wscale.bias': '128x128/Conv0_up/StyleMod/bias',  # [256]
    'synthesis.layer11.conv.weight': '128x128/Conv1/weight',  # [128, 128, 3, 3]
    'synthesis.layer11.epilogue.apply_noise.weight': '128x128/Conv1/Noise/weight',  # [128]
    'synthesis.layer11.epilogue.bias': '128x128/Conv1/bias',  # [128]
    'synthesis.layer11.epilogue.style_mod.dense.linear.weight': '128x128/Conv1/StyleMod/weight',  # [256, 512]
    'synthesis.layer11.epilogue.style_mod.dense.wscale.bias': '128x128/Conv1/StyleMod/bias',  # [256]
    'synthesis.layer12.conv.weight': '256x256/Conv0_up/weight',  # [64, 128, 3, 3]
    'synthesis.layer12.epilogue.apply_noise.weight': '256x256/Conv0_up/Noise/weight',  # [64]
    'synthesis.layer12.epilogue.bias': '256x256/Conv0_up/bias',  # [64]
    'synthesis.layer12.epilogue.style_mod.dense.linear.weight': '256x256/Conv0_up/StyleMod/weight',  # [128, 512]
    'synthesis.layer12.epilogue.style_mod.dense.wscale.bias': '256x256/Conv0_up/StyleMod/bias',  # [128]
    'synthesis.layer13.conv.weight': '256x256/Conv1/weight',  # [64, 64, 3, 3]
    'synthesis.layer13.epilogue.apply_noise.weight': '256x256/Conv1/Noise/weight',  # [64]
    'synthesis.layer13.epilogue.bias': '256x256/Conv1/bias',  # [64]
    'synthesis.layer13.epilogue.style_mod.dense.linear.weight': '256x256/Conv1/StyleMod/weight',  # [128, 512]
    'synthesis.layer13.epilogue.style_mod.dense.wscale.bias': '256x256/Conv1/StyleMod/bias',  # [128]
    'synthesis.layer14.conv.weight': '512x512/Conv0_up/weight',  # [32, 64, 3, 3]
    'synthesis.layer14.epilogue.apply_noise.weight': '512x512/Conv0_up/Noise/weight',  # [32]
    'synthesis.layer14.epilogue.bias': '512x512/Conv0_up/bias',  # [32]
    'synthesis.layer14.epilogue.style_mod.dense.linear.weight': '512x512/Conv0_up/StyleMod/weight',  # [64, 512]
    'synthesis.layer14.epilogue.style_mod.dense.wscale.bias': '512x512/Conv0_up/StyleMod/bias',  # [64]
    'synthesis.layer15.conv.weight': '512x512/Conv1/weight',  # [32, 32, 3, 3]
    'synthesis.layer15.epilogue.apply_noise.weight': '512x512/Conv1/Noise/weight',  # [32]
    'synthesis.layer15.epilogue.bias': '512x512/Conv1/bias',  # [32]
    'synthesis.layer15.epilogue.style_mod.dense.linear.weight': '512x512/Conv1/StyleMod/weight',  # [64, 512]
    'synthesis.layer15.epilogue.style_mod.dense.wscale.bias': '512x512/Conv1/StyleMod/bias',  # [64]
    'synthesis.layer16.conv.weight': '1024x1024/Conv0_up/weight',  # [16, 32, 3, 3]
    'synthesis.layer16.epilogue.apply_noise.weight': '1024x1024/Conv0_up/Noise/weight',  # [16]
    'synthesis.layer16.epilogue.bias': '1024x1024/Conv0_up/bias',  # [16]
    'synthesis.layer16.epilogue.style_mod.dense.linear.weight': '1024x1024/Conv0_up/StyleMod/weight',  # [32, 512]
    'synthesis.layer16.epilogue.style_mod.dense.wscale.bias': '1024x1024/Conv0_up/StyleMod/bias',  # [32]
    'synthesis.layer17.conv.weight': '1024x1024/Conv1/weight',  # [16, 16, 3, 3]
    'synthesis.layer17.epilogue.apply_noise.weight': '1024x1024/Conv1/Noise/weight',  # [16]
    'synthesis.layer17.epilogue.bias': '1024x1024/Conv1/bias',  # [16]
    'synthesis.layer17.epilogue.style_mod.dense.linear.weight': '1024x1024/Conv1/StyleMod/weight',  # [32, 512]
    'synthesis.layer17.epilogue.style_mod.dense.wscale.bias': '1024x1024/Conv1/StyleMod/bias',  # [32]

    # ToRGB outputs; `outputN` maps to `ToRGB_lod(8 - N)`.
    'synthesis.output0.conv.weight': 'ToRGB_lod8/weight',  # [3, 512, 1, 1]
    'synthesis.output0.bias': 'ToRGB_lod8/bias',  # [3]
    'synthesis.output1.conv.weight': 'ToRGB_lod7/weight',  # [3, 512, 1, 1]
    'synthesis.output1.bias': 'ToRGB_lod7/bias',  # [3]
    'synthesis.output2.conv.weight': 'ToRGB_lod6/weight',  # [3, 512, 1, 1]
    'synthesis.output2.bias': 'ToRGB_lod6/bias',  # [3]
    'synthesis.output3.conv.weight': 'ToRGB_lod5/weight',  # [3, 512, 1, 1]
    'synthesis.output3.bias': 'ToRGB_lod5/bias',  # [3]
    'synthesis.output4.conv.weight': 'ToRGB_lod4/weight',  # [3, 256, 1, 1]
    'synthesis.output4.bias': 'ToRGB_lod4/bias',  # [3]
    'synthesis.output5.conv.weight': 'ToRGB_lod3/weight',  # [3, 128, 1, 1]
    'synthesis.output5.bias': 'ToRGB_lod3/bias',  # [3]
    'synthesis.output6.conv.weight': 'ToRGB_lod2/weight',  # [3, 64, 1, 1]
    'synthesis.output6.bias': 'ToRGB_lod2/bias',  # [3]
    'synthesis.output7.conv.weight': 'ToRGB_lod1/weight',  # [3, 32, 1, 1]
    'synthesis.output7.bias': 'ToRGB_lod1/bias',  # [3]
    'synthesis.output8.conv.weight': 'ToRGB_lod0/weight',  # [3, 16, 1, 1]
    'synthesis.output8.bias': 'ToRGB_lod0/bias',  # [3]
}
# pylint: enable=line-too-long

# Minimal resolution for `auto` fused-scale strategy.
_AUTO_FUSED_SCALE_MIN_RES = 128


class StyleGANGeneratorModel(nn.Module):
    """Defines the generator module in StyleGAN.

    The generator chains three sub-modules: `mapping` (z -> w), `truncation`
    (broadcasts w to one code per synthesis layer and applies the truncation
    trick) and `synthesis` (per-layer w -> image).

    Note that the generated images are with RGB color channels.
    """

    def __init__(self,
                 resolution=1024,
                 w_space_dim=512,
                 fused_scale='auto',
                 output_channels=3,
                 truncation_psi=0.7,
                 truncation_layers=8,
                 randomize_noise=False):
        """Initializes the generator with basic settings.

        Args:
            resolution: The resolution of the final output image.
                (default: 1024)
            w_space_dim: The dimension of the disentangled latent vectors, w.
                It is used by the mapping output, the truncation module and the
                style modulation of every synthesis layer. (default: 512)
            fused_scale: If set as `True`, `conv2d_transpose` is used for
                upscaling. If set as `False`, `upsample + conv2d` is used for
                upscaling. If set as `auto`, `upsample + conv2d` is used for
                bottom layers until resolution reaches 128. (default: `auto`)
            output_channels: Number of channels of output image. (default: 3)
            truncation_psi: Style strength multiplier for the truncation trick.
                `None` or `1.0` indicates no truncation. (default: 0.7)
            truncation_layers: Number of layers for which to apply the
                truncation trick. `None` indicates no truncation. (default: 8)
            randomize_noise: Whether to add random noise for each convolutional
                layer. (default: False)

        Raises:
            ValueError: If the input `resolution` is not supported, or if
                `fused_scale` is not one of [True, False, `auto`].
        """
        super().__init__()
        self.resolution = resolution
        self.w_space_dim = w_space_dim
        self.fused_scale = fused_scale
        self.output_channels = output_channels
        self.truncation_psi = truncation_psi
        self.truncation_layers = truncation_layers
        self.randomize_noise = randomize_noise

        self.mapping = MappingModule(final_space_dim=self.w_space_dim)
        self.truncation = TruncationModule(
            resolution=self.resolution,
            w_space_dim=self.w_space_dim,
            truncation_psi=self.truncation_psi,
            truncation_layers=self.truncation_layers)
        self.synthesis = SynthesisModule(
            resolution=self.resolution,
            fused_scale=self.fused_scale,
            output_channels=self.output_channels,
            randomize_noise=self.randomize_noise,
            w_space_dim=self.w_space_dim)

        # Name mapping used to import official tensorflow weights. Layers that
        # use fused scaling store their kernel directly as `weight` instead of
        # inside a `conv` sub-module, so their pytorch names are rewritten.
        # Entries for layers beyond `resolution` are kept unchanged.
        self.pth_to_tf_var_mapping = {}
        for pth_var_name, tf_var_name in _STYLEGAN_PTH_VARS_TO_TF_VARS.items():
            if 'Conv0_up' in tf_var_name:
                res = int(tf_var_name.split('x')[0])
                if ((self.fused_scale is True) or
                        (self.fused_scale == 'auto' and
                         res >= _AUTO_FUSED_SCALE_MIN_RES)):
                    pth_var_name = pth_var_name.replace('conv.weight', 'weight')
            self.pth_to_tf_var_mapping[pth_var_name] = tf_var_name

    def forward(self, z):
        """Generates images from latent codes `z` of shape [batch_size, dim]."""
        w = self.mapping(z)
        w = self.truncation(w)
        x = self.synthesis(w)
        return x


class MappingModule(nn.Sequential):
    """Implements the latent space mapping module used in StyleGAN.

    Basically, this module executes several dense layers in sequence, optionally
    preceded by pixel-wise normalization of the input. Sub-modules are named
    `normalize` (only if `normalize_input`) followed by `dense0`, `dense1`, ...
    """

    def __init__(self,
                 normalize_input=True,
                 input_space_dim=512,
                 hidden_space_dim=512,
                 final_space_dim=512,
                 num_layers=8):
        sequence = OrderedDict()
        if normalize_input:
            sequence['normalize'] = PixelNormLayer()
        for i in range(num_layers):
            in_dim = input_space_dim if i == 0 else hidden_space_dim
            out_dim = (final_space_dim if i == num_layers - 1
                       else hidden_space_dim)
            sequence[f'dense{i}'] = DenseBlock(in_dim, out_dim)
        super().__init__(sequence)

    def forward(self, x):
        if len(x.shape) != 2:
            raise ValueError(f'The input tensor should be with shape [batch_size, '
                             f'noise_dim], but {x.shape} received!')
        return super().forward(x)


class TruncationModule(nn.Module):
    """Implements the truncation module used in StyleGAN.

    A 2-D input `w` is broadcast to shape [batch_size, num_layers, w_space_dim]
    (one code per synthesis layer); 3-D inputs are used as-is. When truncation
    is enabled, each layer's code becomes `w_avg + (w - w_avg) * coef`, where
    `coef` is `truncation_psi` for the first `truncation_layers` layers and 1
    for the rest.
    """

    def __init__(self,
                 resolution=1024,
                 w_space_dim=512,
                 truncation_psi=0.7,
                 truncation_layers=8):
        super().__init__()
        # Two synthesis layers per resolution, from 4x4 up to `resolution`.
        self.num_layers = int(np.log2(resolution)) * 2 - 2
        self.w_space_dim = w_space_dim
        self.use_truncation = (truncation_psi is not None and
                               truncation_layers is not None)
        if not self.use_truncation:
            truncation_psi = 1.0
            truncation_layers = 0

        self.register_buffer('w_avg', torch.zeros(w_space_dim))
        layer_idx = np.arange(self.num_layers).reshape(1, self.num_layers, 1)
        coefs = np.ones_like(layer_idx, dtype=np.float32)
        coefs[layer_idx < truncation_layers] *= truncation_psi
        self.register_buffer('truncation', torch.from_numpy(coefs))

    def forward(self, w):
        if len(w.shape) == 2:
            w = w.view(-1, 1, self.w_space_dim).repeat(1, self.num_layers, 1)
        if self.use_truncation:
            w_avg = self.w_avg.view(1, 1, self.w_space_dim)
            w = w_avg + (w - w_avg) * self.truncation
        return w


class SynthesisModule(nn.Module):
    """Implements the image synthesis module used in StyleGAN.

    Basically, this module executes several convolutional layers in sequence.
    Block `b` (starting from 1) owns layers `2b - 2` and `2b - 1` plus the ToRGB
    layer `output{b - 1}`.
    """

    def __init__(self,
                 resolution=1024,
                 fused_scale='auto',
                 output_channels=3,
                 randomize_noise=False,
                 w_space_dim=512):
        super().__init__()

        try:
            self.channels = _RESOLUTIONS_TO_CHANNELS[resolution]
        except KeyError:
            raise ValueError(f'Invalid resolution: {resolution}!\n'
                             f'Resolutions allowed: '
                             f'{list(_RESOLUTIONS_TO_CHANNELS)}.') from None
        assert len(self.channels) == int(np.log2(resolution))

        # Registration order (layer, layer, output per block, then `upsample`
        # and `lod`) is kept stable so parameter names/order do not change.
        for block_idx in range(1, len(self.channels)):
            first_idx = 2 * block_idx - 2
            second_idx = 2 * block_idx - 1
            if block_idx == 1:
                self.add_module(
                    f'layer{first_idx}',
                    FirstConvBlock(in_channels=self.channels[block_idx - 1],
                                   randomize_noise=randomize_noise,
                                   w_space_dim=w_space_dim))
            else:
                self.add_module(
                    f'layer{first_idx}',
                    UpConvBlock(layer_idx=first_idx,
                                in_channels=self.channels[block_idx - 1],
                                out_channels=self.channels[block_idx],
                                randomize_noise=randomize_noise,
                                fused_scale=fused_scale,
                                w_space_dim=w_space_dim))
            self.add_module(
                f'layer{second_idx}',
                ConvBlock(layer_idx=second_idx,
                          in_channels=self.channels[block_idx],
                          out_channels=self.channels[block_idx],
                          randomize_noise=randomize_noise,
                          w_space_dim=w_space_dim))
            self.add_module(
                f'output{block_idx - 1}',
                LastConvBlock(in_channels=self.channels[block_idx],
                              out_channels=output_channels))

        self.upsample = ResolutionScalingLayer()
        self.lod = nn.Parameter(torch.zeros(()))

    def forward(self, w):
        """Synthesizes images from per-layer codes `w` ([batch, layers, dim]).

        Blocks with `block_idx + lod >= len(self.channels)` are skipped; the
        last computed ToRGB output is nearest-upsampled instead. Fractional
        `lod` values are not blended.

        Raises:
            ValueError: If `lod` is so large that not even the first block runs.
        """
        lod = self.lod.cpu().tolist()
        num_blocks = len(self.channels)
        if 1 + lod >= num_blocks:
            raise ValueError(f'`lod` should be smaller than {num_blocks - 1}, '
                             f'but {lod} received!')

        x = None
        image = None
        for block_idx in range(1, num_blocks):
            if block_idx + lod < num_blocks:
                first_idx = 2 * block_idx - 2
                second_idx = 2 * block_idx - 1
                if first_idx == 0:
                    # The first layer starts from its learned constant input.
                    x = self.layer0(w[:, 0])
                else:
                    x = getattr(self, f'layer{first_idx}')(x, w[:, first_idx])
                x = getattr(self, f'layer{second_idx}')(x, w[:, second_idx])
                image = getattr(self, f'output{block_idx - 1}')(x)
            else:
                image = self.upsample(image)
        return image


class PixelNormLayer(nn.Module):
    """Implements pixel-wise feature vector normalization layer.

    Divides the input by the root mean square of its values along dim 1.
    """

    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x):
        return x / torch.sqrt(torch.mean(x**2, dim=1, keepdim=True) +
                              self.epsilon)


class InstanceNormLayer(nn.Module):
    """Implements instance normalization layer.

    Normalizes each channel of each sample to zero mean and unit variance over
    the spatial dims (no learnable affine parameters).
    """

    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x):
        if len(x.shape) != 4:
            raise ValueError(f'The input tensor should be with shape [batch_size, '
                             f'num_channels, height, width], but {x.shape} received!')
        x = x - torch.mean(x, dim=[2, 3], keepdim=True)
        x = x / torch.sqrt(torch.mean(x**2, dim=[2, 3], keepdim=True) +
                           self.epsilon)
        return x


class ResolutionScalingLayer(nn.Module):
    """Implements the resolution scaling layer.

    Basically, this layer can be used to upsample or downsample feature maps
    from spatial domain with nearest neighbor interpolation.
    """

    def __init__(self, scale_factor=2):
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        return F.interpolate(x, scale_factor=self.scale_factor, mode='nearest')


class BlurLayer(nn.Module):
    """Implements the blur layer used in StyleGAN.

    The 2-D blur kernel is the outer product of the 1-D `kernel` with itself,
    applied depthwise (one copy per channel). Odd-length kernels preserve the
    spatial size.
    """

    def __init__(self, channels, kernel=(1, 2, 1), normalize=True, flip=False):
        super().__init__()
        kernel = np.array(kernel, dtype=np.float32).reshape(1, -1)
        kernel_size = kernel.shape[1]
        kernel = kernel.T.dot(kernel)
        if normalize:
            kernel /= np.sum(kernel)
        if flip:
            kernel = kernel[::-1, ::-1]
        kernel = kernel.reshape(kernel_size, kernel_size, 1, 1)
        kernel = np.tile(kernel, [1, 1, channels, 1])
        # [k, k, channels, 1] -> [channels, 1, k, k], the depthwise layout
        # `F.conv2d(..., groups=channels)` expects.
        kernel = np.transpose(kernel, [2, 3, 0, 1])
        self.register_buffer('kernel',
                             torch.from_numpy(np.ascontiguousarray(kernel)))
        self.channels = channels
        self.padding = kernel_size // 2

    def forward(self, x):
        return F.conv2d(x, self.kernel, stride=1, padding=self.padding,
                        groups=self.channels)


class NoiseApplyingLayer(nn.Module):
    """Implements the noise applying layer used in StyleGAN.

    Adds a single-channel noise map, scaled by a learned per-channel weight.
    The map is either a fixed buffer of shape [1, 1, res, res] or, with
    `randomize_noise`, freshly sampled for every call.
    """

    def __init__(self, layer_idx, channels, randomize_noise=False):
        super().__init__()
        self.randomize_noise = randomize_noise
        # Spatial resolution of this layer: 4x4 for layers 0-1, 8x8 for 2-3, ...
        self.res = 2**(layer_idx // 2 + 2)
        self.register_buffer('noise', torch.randn(1, 1, self.res, self.res))
        self.weight = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        if len(x.shape) != 4:
            raise ValueError(f'The input tensor should be with shape [batch_size, '
                             f'num_channels, height, width], but {x.shape} received!')
        if self.randomize_noise:
            noise = torch.randn(x.shape[0], 1, self.res, self.res).to(x)
        else:
            noise = self.noise
        return x + noise * self.weight.view(1, -1, 1, 1)


class StyleModulationLayer(nn.Module):
    """Implements the style modulation layer used in StyleGAN.

    A dense layer maps `w` to a per-channel (scale, shift) pair; the output is
    `x * (scale + 1) + shift`.
    """

    def __init__(self, channels, w_space_dim=512):
        super().__init__()
        self.channels = channels
        self.dense = DenseBlock(in_features=w_space_dim,
                                out_features=channels * 2,
                                wscale_gain=1.0,
                                wscale_lr_multiplier=1.0,
                                activation_type='linear')

    def forward(self, x, w):
        if len(w.shape) != 2:
            # Report the shape of `w`, which is the tensor being validated.
            raise ValueError(f'The input tensor should be with shape [batch_size, '
                             f'num_channels], but {w.shape} received!')
        style = self.dense(w)
        style = style.view(-1, 2, self.channels, 1, 1)
        return x * (style[:, 0] + 1) + style[:, 1]


class WScaleLayer(nn.Module):
    """Implements the layer to scale weight variable and add bias.

    Note that, the weight variable is trained in `nn.Conv2d` layer (or
    `nn.Linear` layer), and only scaled with a constant number, which is not
    trainable, in this layer. However, the bias variable is trainable in this
    layer. The scale is `gain / sqrt(fan_in) * lr_multiplier` and the bias is
    multiplied by `lr_multiplier` before being added.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 gain=np.sqrt(2.0),
                 lr_multiplier=1.0):
        super().__init__()
        fan_in = in_channels * kernel_size * kernel_size
        self.scale = gain / np.sqrt(fan_in) * lr_multiplier
        self.bias = nn.Parameter(torch.zeros(out_channels))
        self.lr_multiplier = lr_multiplier

    def forward(self, x):
        if len(x.shape) == 4:
            return x * self.scale + self.bias.view(1, -1, 1, 1) * self.lr_multiplier
        if len(x.shape) == 2:
            return x * self.scale + self.bias.view(1, -1) * self.lr_multiplier
        raise ValueError(f'The input tensor should be with shape [batch_size, '
                         f'num_channels, height, width], or [batch_size, '
                         f'num_channels], but {x.shape} received!')


class EpilogueBlock(nn.Module):
    """Implements the epilogue block of each conv block.

    Applies, in order: noise, bias, leaky ReLU, normalization and style
    modulation.
    """

    def __init__(self,
                 layer_idx,
                 channels,
                 randomize_noise=False,
                 normalization_fn='instance',
                 w_space_dim=512):
        super().__init__()
        self.apply_noise = NoiseApplyingLayer(layer_idx, channels, randomize_noise)
        self.bias = nn.Parameter(torch.zeros(channels))
        self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        if normalization_fn == 'pixel':
            self.norm = PixelNormLayer()
        elif normalization_fn == 'instance':
            self.norm = InstanceNormLayer()
        else:
            raise NotImplementedError(f'Not implemented normalization function: '
                                      f'{normalization_fn}!')
        self.style_mod = StyleModulationLayer(channels, w_space_dim=w_space_dim)

    def forward(self, x, w):
        x = self.apply_noise(x)
        x = x + self.bias.view(1, -1, 1, 1)
        x = self.activate(x)
        x = self.norm(x)
        x = self.style_mod(x, w)
        return x


class FirstConvBlock(nn.Module):
    """Implements the first convolutional block used in StyleGAN.

    Basically, this block starts from a const input, which is
    `ones(512, 4, 4)` at initialization, followed by the epilogue.
    """

    def __init__(self, in_channels, randomize_noise=False, w_space_dim=512):
        super().__init__()
        self.first_layer = nn.Parameter(torch.ones(1, in_channels, 4, 4))
        self.epilogue = EpilogueBlock(layer_idx=0,
                                      channels=in_channels,
                                      randomize_noise=randomize_noise,
                                      w_space_dim=w_space_dim)

    def forward(self, w):
        x = self.first_layer.repeat(w.shape[0], 1, 1, 1)
        x = self.epilogue(x, w)
        return x


class UpConvBlock(nn.Module):
    """Implements the convolutional block used in StyleGAN.

    Basically, this block is used as the first convolutional block for each
    resolution, which will execute upsampling.
    """

    def __init__(self,
                 layer_idx,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 dilation=1,
                 add_bias=False,
                 fused_scale='auto',
                 wscale_gain=np.sqrt(2.0),
                 wscale_lr_multiplier=1.0,
                 randomize_noise=False,
                 w_space_dim=512):
        """Initializes the class with block settings.

        Args:
            layer_idx: Index of this layer in the synthesis network. Must be
                even; it determines the output resolution
                `2**(layer_idx // 2 + 2)`.
            in_channels: Number of channels of the input tensor fed into this
                block.
            out_channels: Number of channels (kernels) of the output tensor.
            kernel_size: Size of the convolutional kernel.
            stride: Stride parameter for convolution operation.
            padding: Padding parameter for convolution operation.
            dilation: Dilation rate for convolution operation.
            add_bias: Whether to add bias onto the convolutional result.
            fused_scale: Whether to fuse `upsample` and `conv2d` together,
                resulting in `conv2d_transpose`. `auto` fuses when the output
                resolution is at least `_AUTO_FUSED_SCALE_MIN_RES`.
            wscale_gain: The gain factor for `wscale` layer.
            wscale_lr_multiplier: The learning rate multiplier factor for
                `wscale` layer.
            randomize_noise: Whether to add random noise.
            w_space_dim: The dimension of the style code `w`. (default: 512)

        Raises:
            ValueError: If the block is not applied to the first block for a
                particular resolution. Or `fused_scale` does not belong to
                [True, False, `auto`].
        """
        super().__init__()
        if layer_idx % 2 == 1:
            raise ValueError(f'This block is implemented as the first block of each '
                             f'resolution, but is applied to layer {layer_idx}!')
        if fused_scale not in [True, False, 'auto']:
            raise ValueError(f'`fused_scale` can only be [True, False, `auto`], '
                             f'but {fused_scale} received!')

        cur_res = 2 ** (layer_idx // 2 + 2)
        if fused_scale == 'auto':
            self.fused_scale = (cur_res >= _AUTO_FUSED_SCALE_MIN_RES)
        else:
            self.fused_scale = fused_scale

        if self.fused_scale:
            # Kernel stored in tensorflow layout: [k, k, in_channels, out].
            self.weight = nn.Parameter(
                torch.randn(kernel_size, kernel_size, in_channels, out_channels))
        else:
            self.upsample = ResolutionScalingLayer()
            self.conv = nn.Conv2d(in_channels=in_channels,
                                  out_channels=out_channels,
                                  kernel_size=kernel_size,
                                  stride=stride,
                                  padding=padding,
                                  dilation=dilation,
                                  groups=1,
                                  bias=add_bias)

        fan_in = in_channels * kernel_size * kernel_size
        self.scale = wscale_gain / np.sqrt(fan_in) * wscale_lr_multiplier
        self.blur = BlurLayer(channels=out_channels)
        self.epilogue = EpilogueBlock(layer_idx=layer_idx,
                                      channels=out_channels,
                                      randomize_noise=randomize_noise,
                                      w_space_dim=w_space_dim)

    def forward(self, x, w):
        if self.fused_scale:
            kernel = self.weight * self.scale
            # Pad both spatial dims by one on each side and sum the four
            # shifted copies, giving a (k + 1) x (k + 1) kernel that combines
            # the 2x upsampling with the convolution.
            kernel = F.pad(kernel, (0, 0, 0, 0, 1, 1, 1, 1), 'constant', 0.0)
            kernel = (kernel[1:, 1:] + kernel[:-1, 1:] +
                      kernel[1:, :-1] + kernel[:-1, :-1])
            # [k, k, in, out] -> [in, out, k, k], as `conv_transpose2d` expects.
            kernel = kernel.permute(2, 3, 0, 1)
            x = F.conv_transpose2d(x, kernel, stride=2, padding=1)
        else:
            x = self.upsample(x)
            x = self.conv(x) * self.scale
        x = self.blur(x)
        x = self.epilogue(x, w)
        return x


class ConvBlock(nn.Module):
    """Implements the convolutional block used in StyleGAN.

    Basically, this block is used as the second convolutional block for each
    resolution.
    """

    def __init__(self,
                 layer_idx,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 dilation=1,
                 add_bias=False,
                 wscale_gain=np.sqrt(2.0),
                 wscale_lr_multiplier=1.0,
                 randomize_noise=False,
                 w_space_dim=512):
        """Initializes the class with block settings.

        Args:
            layer_idx: Index of this layer in the synthesis network. Must be
                odd.
            in_channels: Number of channels of the input tensor fed into this
                block.
            out_channels: Number of channels (kernels) of the output tensor.
            kernel_size: Size of the convolutional kernel.
            stride: Stride parameter for convolution operation.
            padding: Padding parameter for convolution operation.
            dilation: Dilation rate for convolution operation.
            add_bias: Whether to add bias onto the convolutional result.
            wscale_gain: The gain factor for `wscale` layer.
            wscale_lr_multiplier: The learning rate multiplier factor for
                `wscale` layer.
            randomize_noise: Whether to add random noise.
            w_space_dim: The dimension of the style code `w`. (default: 512)

        Raises:
            ValueError: If the block is not applied to the second block for a
                particular resolution.
        """
        super().__init__()
        if layer_idx % 2 == 0:
            raise ValueError(f'This block is implemented as the second block of each '
                             f'resolution, but is applied to layer {layer_idx}!')

        self.conv = nn.Conv2d(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              dilation=dilation,
                              groups=1,
                              bias=add_bias)
        fan_in = in_channels * kernel_size * kernel_size
        self.scale = wscale_gain / np.sqrt(fan_in) * wscale_lr_multiplier
        self.epilogue = EpilogueBlock(layer_idx=layer_idx,
                                      channels=out_channels,
                                      randomize_noise=randomize_noise,
                                      w_space_dim=w_space_dim)

    def forward(self, x, w):
        x = self.conv(x) * self.scale
        x = self.epilogue(x, w)
        return x


class LastConvBlock(nn.Module):
    """Implements the last convolutional block used in StyleGAN.

    Basically, this block converts the final feature map to an image with
    `out_channels` channels via a scaled 1x1 convolution plus bias.
    """

    def __init__(self, in_channels, out_channels=3):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=1,
                              bias=False)
        self.scale = 1 / np.sqrt(in_channels)
        # One bias per output channel (previously hard-coded to 3, which broke
        # any `out_channels != 3`).
        self.bias = nn.Parameter(torch.zeros(out_channels))

    def forward(self, x):
        x = self.conv(x) * self.scale
        x = x + self.bias.view(1, -1, 1, 1)
        return x


class DenseBlock(nn.Module):
    """Implements the dense block used in StyleGAN.

    Basically, this block executes fully-connected layer, weight-scale layer,
    and activation layer in sequence.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 add_bias=False,
                 wscale_gain=np.sqrt(2.0),
                 wscale_lr_multiplier=0.01,
                 activation_type='lrelu'):
        """Initializes the class with block settings.

        Args:
            in_features: Number of channels of the input tensor fed into this
                block.
            out_features: Number of channels of the output tensor.
            add_bias: Whether to add bias onto the fully-connected result.
            wscale_gain: The gain factor for `wscale` layer.
            wscale_lr_multiplier: The learning rate multiplier factor for
                `wscale` layer.
            activation_type: Type of activation function. Support `linear` and
                `lrelu`.

        Raises:
            NotImplementedError: If the input `activation_type` is not supported.
        """
        super().__init__()
        self.linear = nn.Linear(in_features=in_features,
                                out_features=out_features,
                                bias=add_bias)
        self.wscale = WScaleLayer(in_channels=in_features,
                                  out_channels=out_features,
                                  kernel_size=1,
                                  gain=wscale_gain,
                                  lr_multiplier=wscale_lr_multiplier)
        if activation_type == 'linear':
            self.activate = nn.Identity()
        elif activation_type == 'lrelu':
            self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            raise NotImplementedError(f'Not implemented activation function: '
                                      f'{activation_type}!')

    def forward(self, x):
        x = self.linear(x)
        x = self.wscale(x)
        x = self.activate(x)
        return x