interfacegan_pp / models /stylegan_generator_model.py
ybelkada's picture
commit files
4d6b877
raw
history blame
36 kB
# python3.7
"""Contains the implementation of generator described in StyleGAN.
Different from the official tensorflow model in folder `stylegan_tf_official`,
this is a simple pytorch version which only contains the generator part. This
class is specially used for inference.
For more details, please check the original paper:
https://arxiv.org/pdf/1812.04948.pdf
"""
from collections import OrderedDict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = ['StyleGANGeneratorModel']
# Maps the resolution of the final generated image to the number of filters
# used by each convolutional stage, in order. Every supported resolution
# `2 ** n` (8 up to 1024) uses the first `n` entries of the same channel
# schedule, so the table is derived from that schedule instead of spelled out.
_RESOLUTIONS_TO_CHANNELS = {
  2 ** num_stages: ([512] * 5 + [256, 128, 64, 32, 16])[:num_stages]
  for num_stages in range(3, 11)
}
# pylint: disable=line-too-long
# Variable mapping from pytorch model to official tensorflow model.
#
# Keys are dotted attribute paths inside `StyleGANGeneratorModel` (sub-modules
# `mapping`, `truncation` and `synthesis`); values are the corresponding
# variable names of the tensorflow model. The trailing comment on each entry
# records the shape of the variable.
#
# NOTE: The `synthesis.layerN.conv.weight` keys of the `Conv0_up` entries
# describe the non-fused (`upsample + conv2d`) variant of the up-sampling block.
# `StyleGANGeneratorModel.__init__` rewrites them to `synthesis.layerN.weight`
# for the blocks that use the fused transposed convolution.
_STYLEGAN_PTH_VARS_TO_TF_VARS = {
  # Statistic information of disentangled latent feature, w.
  'truncation.w_avg':'dlatent_avg', # [512]
  # Noises.
  'synthesis.layer0.epilogue.apply_noise.noise': 'noise0', # [1, 1, 4, 4]
  'synthesis.layer1.epilogue.apply_noise.noise': 'noise1', # [1, 1, 4, 4]
  'synthesis.layer2.epilogue.apply_noise.noise': 'noise2', # [1, 1, 8, 8]
  'synthesis.layer3.epilogue.apply_noise.noise': 'noise3', # [1, 1, 8, 8]
  'synthesis.layer4.epilogue.apply_noise.noise': 'noise4', # [1, 1, 16, 16]
  'synthesis.layer5.epilogue.apply_noise.noise': 'noise5', # [1, 1, 16, 16]
  'synthesis.layer6.epilogue.apply_noise.noise': 'noise6', # [1, 1, 32, 32]
  'synthesis.layer7.epilogue.apply_noise.noise': 'noise7', # [1, 1, 32, 32]
  'synthesis.layer8.epilogue.apply_noise.noise': 'noise8', # [1, 1, 64, 64]
  'synthesis.layer9.epilogue.apply_noise.noise': 'noise9', # [1, 1, 64, 64]
  'synthesis.layer10.epilogue.apply_noise.noise': 'noise10', # [1, 1, 128, 128]
  'synthesis.layer11.epilogue.apply_noise.noise': 'noise11', # [1, 1, 128, 128]
  'synthesis.layer12.epilogue.apply_noise.noise': 'noise12', # [1, 1, 256, 256]
  'synthesis.layer13.epilogue.apply_noise.noise': 'noise13', # [1, 1, 256, 256]
  'synthesis.layer14.epilogue.apply_noise.noise': 'noise14', # [1, 1, 512, 512]
  'synthesis.layer15.epilogue.apply_noise.noise': 'noise15', # [1, 1, 512, 512]
  'synthesis.layer16.epilogue.apply_noise.noise': 'noise16', # [1, 1, 1024, 1024]
  'synthesis.layer17.epilogue.apply_noise.noise': 'noise17', # [1, 1, 1024, 1024]
  # Mapping blocks.
  'mapping.dense0.linear.weight': 'Dense0/weight', # [512, 512]
  'mapping.dense0.wscale.bias': 'Dense0/bias', # [512]
  'mapping.dense1.linear.weight': 'Dense1/weight', # [512, 512]
  'mapping.dense1.wscale.bias': 'Dense1/bias', # [512]
  'mapping.dense2.linear.weight': 'Dense2/weight', # [512, 512]
  'mapping.dense2.wscale.bias': 'Dense2/bias', # [512]
  'mapping.dense3.linear.weight': 'Dense3/weight', # [512, 512]
  'mapping.dense3.wscale.bias': 'Dense3/bias', # [512]
  'mapping.dense4.linear.weight': 'Dense4/weight', # [512, 512]
  'mapping.dense4.wscale.bias': 'Dense4/bias', # [512]
  'mapping.dense5.linear.weight': 'Dense5/weight', # [512, 512]
  'mapping.dense5.wscale.bias': 'Dense5/bias', # [512]
  'mapping.dense6.linear.weight': 'Dense6/weight', # [512, 512]
  'mapping.dense6.wscale.bias': 'Dense6/bias', # [512]
  'mapping.dense7.linear.weight': 'Dense7/weight', # [512, 512]
  'mapping.dense7.wscale.bias': 'Dense7/bias', # [512]
  # Synthesis blocks.
  'synthesis.lod': 'lod', # []
  'synthesis.layer0.first_layer': '4x4/Const/const', # [1, 512, 4, 4]
  'synthesis.layer0.epilogue.apply_noise.weight': '4x4/Const/Noise/weight', # [512]
  'synthesis.layer0.epilogue.bias': '4x4/Const/bias', # [512]
  'synthesis.layer0.epilogue.style_mod.dense.linear.weight': '4x4/Const/StyleMod/weight', # [1024, 512]
  'synthesis.layer0.epilogue.style_mod.dense.wscale.bias': '4x4/Const/StyleMod/bias', # [1024]
  'synthesis.layer1.conv.weight': '4x4/Conv/weight', # [512, 512, 3, 3]
  'synthesis.layer1.epilogue.apply_noise.weight': '4x4/Conv/Noise/weight', # [512]
  'synthesis.layer1.epilogue.bias': '4x4/Conv/bias', # [512]
  'synthesis.layer1.epilogue.style_mod.dense.linear.weight': '4x4/Conv/StyleMod/weight', # [1024, 512]
  'synthesis.layer1.epilogue.style_mod.dense.wscale.bias': '4x4/Conv/StyleMod/bias', # [1024]
  'synthesis.layer2.conv.weight': '8x8/Conv0_up/weight', # [512, 512, 3, 3]
  'synthesis.layer2.epilogue.apply_noise.weight': '8x8/Conv0_up/Noise/weight', # [512]
  'synthesis.layer2.epilogue.bias': '8x8/Conv0_up/bias', # [512]
  'synthesis.layer2.epilogue.style_mod.dense.linear.weight': '8x8/Conv0_up/StyleMod/weight', # [1024, 512]
  'synthesis.layer2.epilogue.style_mod.dense.wscale.bias': '8x8/Conv0_up/StyleMod/bias', # [1024]
  'synthesis.layer3.conv.weight': '8x8/Conv1/weight', # [512, 512, 3, 3]
  'synthesis.layer3.epilogue.apply_noise.weight': '8x8/Conv1/Noise/weight', # [512]
  'synthesis.layer3.epilogue.bias': '8x8/Conv1/bias', # [512]
  'synthesis.layer3.epilogue.style_mod.dense.linear.weight': '8x8/Conv1/StyleMod/weight', # [1024, 512]
  'synthesis.layer3.epilogue.style_mod.dense.wscale.bias': '8x8/Conv1/StyleMod/bias', # [1024]
  'synthesis.layer4.conv.weight': '16x16/Conv0_up/weight', # [512, 512, 3, 3]
  'synthesis.layer4.epilogue.apply_noise.weight': '16x16/Conv0_up/Noise/weight', # [512]
  'synthesis.layer4.epilogue.bias': '16x16/Conv0_up/bias', # [512]
  'synthesis.layer4.epilogue.style_mod.dense.linear.weight': '16x16/Conv0_up/StyleMod/weight', # [1024, 512]
  'synthesis.layer4.epilogue.style_mod.dense.wscale.bias': '16x16/Conv0_up/StyleMod/bias', # [1024]
  'synthesis.layer5.conv.weight': '16x16/Conv1/weight', # [512, 512, 3, 3]
  'synthesis.layer5.epilogue.apply_noise.weight': '16x16/Conv1/Noise/weight', # [512]
  'synthesis.layer5.epilogue.bias': '16x16/Conv1/bias', # [512]
  'synthesis.layer5.epilogue.style_mod.dense.linear.weight': '16x16/Conv1/StyleMod/weight', # [1024, 512]
  'synthesis.layer5.epilogue.style_mod.dense.wscale.bias': '16x16/Conv1/StyleMod/bias', # [1024]
  'synthesis.layer6.conv.weight': '32x32/Conv0_up/weight', # [512, 512, 3, 3]
  'synthesis.layer6.epilogue.apply_noise.weight': '32x32/Conv0_up/Noise/weight', # [512]
  'synthesis.layer6.epilogue.bias': '32x32/Conv0_up/bias', # [512]
  'synthesis.layer6.epilogue.style_mod.dense.linear.weight': '32x32/Conv0_up/StyleMod/weight', # [1024, 512]
  'synthesis.layer6.epilogue.style_mod.dense.wscale.bias': '32x32/Conv0_up/StyleMod/bias', # [1024]
  'synthesis.layer7.conv.weight': '32x32/Conv1/weight', # [512, 512, 3, 3]
  'synthesis.layer7.epilogue.apply_noise.weight': '32x32/Conv1/Noise/weight', # [512]
  'synthesis.layer7.epilogue.bias': '32x32/Conv1/bias', # [512]
  'synthesis.layer7.epilogue.style_mod.dense.linear.weight': '32x32/Conv1/StyleMod/weight', # [1024, 512]
  'synthesis.layer7.epilogue.style_mod.dense.wscale.bias': '32x32/Conv1/StyleMod/bias', # [1024]
  'synthesis.layer8.conv.weight': '64x64/Conv0_up/weight', # [256, 512, 3, 3]
  'synthesis.layer8.epilogue.apply_noise.weight': '64x64/Conv0_up/Noise/weight', # [256]
  'synthesis.layer8.epilogue.bias': '64x64/Conv0_up/bias', # [256]
  'synthesis.layer8.epilogue.style_mod.dense.linear.weight': '64x64/Conv0_up/StyleMod/weight', # [512, 512]
  'synthesis.layer8.epilogue.style_mod.dense.wscale.bias': '64x64/Conv0_up/StyleMod/bias', # [512]
  'synthesis.layer9.conv.weight': '64x64/Conv1/weight', # [256, 256, 3, 3]
  'synthesis.layer9.epilogue.apply_noise.weight': '64x64/Conv1/Noise/weight', # [256]
  'synthesis.layer9.epilogue.bias': '64x64/Conv1/bias', # [256]
  'synthesis.layer9.epilogue.style_mod.dense.linear.weight': '64x64/Conv1/StyleMod/weight', # [512, 512]
  'synthesis.layer9.epilogue.style_mod.dense.wscale.bias': '64x64/Conv1/StyleMod/bias', # [512]
  'synthesis.layer10.conv.weight': '128x128/Conv0_up/weight', # [128, 256, 3, 3]
  'synthesis.layer10.epilogue.apply_noise.weight': '128x128/Conv0_up/Noise/weight', # [128]
  'synthesis.layer10.epilogue.bias': '128x128/Conv0_up/bias', # [128]
  'synthesis.layer10.epilogue.style_mod.dense.linear.weight': '128x128/Conv0_up/StyleMod/weight', # [256, 512]
  'synthesis.layer10.epilogue.style_mod.dense.wscale.bias': '128x128/Conv0_up/StyleMod/bias', # [256]
  'synthesis.layer11.conv.weight': '128x128/Conv1/weight', # [128, 128, 3, 3]
  'synthesis.layer11.epilogue.apply_noise.weight': '128x128/Conv1/Noise/weight', # [128]
  'synthesis.layer11.epilogue.bias': '128x128/Conv1/bias', # [128]
  'synthesis.layer11.epilogue.style_mod.dense.linear.weight': '128x128/Conv1/StyleMod/weight', # [256, 512]
  'synthesis.layer11.epilogue.style_mod.dense.wscale.bias': '128x128/Conv1/StyleMod/bias', # [256]
  'synthesis.layer12.conv.weight': '256x256/Conv0_up/weight', # [64, 128, 3, 3]
  'synthesis.layer12.epilogue.apply_noise.weight': '256x256/Conv0_up/Noise/weight', # [64]
  'synthesis.layer12.epilogue.bias': '256x256/Conv0_up/bias', # [64]
  'synthesis.layer12.epilogue.style_mod.dense.linear.weight': '256x256/Conv0_up/StyleMod/weight', # [128, 512]
  'synthesis.layer12.epilogue.style_mod.dense.wscale.bias': '256x256/Conv0_up/StyleMod/bias', # [128]
  'synthesis.layer13.conv.weight': '256x256/Conv1/weight', # [64, 64, 3, 3]
  'synthesis.layer13.epilogue.apply_noise.weight': '256x256/Conv1/Noise/weight', # [64]
  'synthesis.layer13.epilogue.bias': '256x256/Conv1/bias', # [64]
  'synthesis.layer13.epilogue.style_mod.dense.linear.weight': '256x256/Conv1/StyleMod/weight', # [128, 512]
  'synthesis.layer13.epilogue.style_mod.dense.wscale.bias': '256x256/Conv1/StyleMod/bias', # [128]
  'synthesis.layer14.conv.weight': '512x512/Conv0_up/weight', # [32, 64, 3, 3]
  'synthesis.layer14.epilogue.apply_noise.weight': '512x512/Conv0_up/Noise/weight', # [32]
  'synthesis.layer14.epilogue.bias': '512x512/Conv0_up/bias', # [32]
  'synthesis.layer14.epilogue.style_mod.dense.linear.weight': '512x512/Conv0_up/StyleMod/weight', # [64, 512]
  'synthesis.layer14.epilogue.style_mod.dense.wscale.bias': '512x512/Conv0_up/StyleMod/bias', # [64]
  'synthesis.layer15.conv.weight': '512x512/Conv1/weight', # [32, 32, 3, 3]
  'synthesis.layer15.epilogue.apply_noise.weight': '512x512/Conv1/Noise/weight', # [32]
  'synthesis.layer15.epilogue.bias': '512x512/Conv1/bias', # [32]
  'synthesis.layer15.epilogue.style_mod.dense.linear.weight': '512x512/Conv1/StyleMod/weight', # [64, 512]
  'synthesis.layer15.epilogue.style_mod.dense.wscale.bias': '512x512/Conv1/StyleMod/bias', # [64]
  'synthesis.layer16.conv.weight': '1024x1024/Conv0_up/weight', # [16, 32, 3, 3]
  'synthesis.layer16.epilogue.apply_noise.weight': '1024x1024/Conv0_up/Noise/weight', # [16]
  'synthesis.layer16.epilogue.bias': '1024x1024/Conv0_up/bias', # [16]
  'synthesis.layer16.epilogue.style_mod.dense.linear.weight': '1024x1024/Conv0_up/StyleMod/weight', # [32, 512]
  'synthesis.layer16.epilogue.style_mod.dense.wscale.bias': '1024x1024/Conv0_up/StyleMod/bias', # [32]
  'synthesis.layer17.conv.weight': '1024x1024/Conv1/weight', # [16, 16, 3, 3]
  'synthesis.layer17.epilogue.apply_noise.weight': '1024x1024/Conv1/Noise/weight', # [16]
  'synthesis.layer17.epilogue.bias': '1024x1024/Conv1/bias', # [16]
  'synthesis.layer17.epilogue.style_mod.dense.linear.weight': '1024x1024/Conv1/StyleMod/weight', # [32, 512]
  'synthesis.layer17.epilogue.style_mod.dense.wscale.bias': '1024x1024/Conv1/StyleMod/bias', # [32]
  'synthesis.output0.conv.weight': 'ToRGB_lod8/weight', # [3, 512, 1, 1]
  'synthesis.output0.bias': 'ToRGB_lod8/bias', # [3]
  'synthesis.output1.conv.weight': 'ToRGB_lod7/weight', # [3, 512, 1, 1]
  'synthesis.output1.bias': 'ToRGB_lod7/bias', # [3]
  'synthesis.output2.conv.weight': 'ToRGB_lod6/weight', # [3, 512, 1, 1]
  'synthesis.output2.bias': 'ToRGB_lod6/bias', # [3]
  'synthesis.output3.conv.weight': 'ToRGB_lod5/weight', # [3, 512, 1, 1]
  'synthesis.output3.bias': 'ToRGB_lod5/bias', # [3]
  'synthesis.output4.conv.weight': 'ToRGB_lod4/weight', # [3, 256, 1, 1]
  'synthesis.output4.bias': 'ToRGB_lod4/bias', # [3]
  'synthesis.output5.conv.weight': 'ToRGB_lod3/weight', # [3, 128, 1, 1]
  'synthesis.output5.bias': 'ToRGB_lod3/bias', # [3]
  'synthesis.output6.conv.weight': 'ToRGB_lod2/weight', # [3, 64, 1, 1]
  'synthesis.output6.bias': 'ToRGB_lod2/bias', # [3]
  'synthesis.output7.conv.weight': 'ToRGB_lod1/weight', # [3, 32, 1, 1]
  'synthesis.output7.bias': 'ToRGB_lod1/bias', # [3]
  'synthesis.output8.conv.weight': 'ToRGB_lod0/weight', # [3, 16, 1, 1]
  'synthesis.output8.bias': 'ToRGB_lod0/bias', # [3]
}
# pylint: enable=line-too-long
# Minimal resolution for `auto` fused-scale strategy: with `fused_scale='auto'`,
# up-sampling blocks whose resolution is at least this value use the fused
# transposed convolution, while lower resolutions use `upsample + conv2d`.
_AUTO_FUSED_SCALE_MIN_RES = 128
class StyleGANGeneratorModel(nn.Module):
  """Generator of StyleGAN, assembled from three sub-modules.

  A latent code is passed through `mapping`, `truncation` and `synthesis`, in
  that order. Note that the generated images are with RGB color channels.
  """

  def __init__(self,
               resolution=1024,
               w_space_dim=512,
               fused_scale='auto',
               output_channels=3,
               truncation_psi=0.7,
               truncation_layers=8,
               randomize_noise=False):
    """Builds the sub-modules and the pytorch-to-tensorflow variable mapping.

    Args:
      resolution: The resolution of the final output image. (default: 1024)
      w_space_dim: The dimension of the disentangled latent vectors, w.
        (default: 512)
      fused_scale: If set as `True`, `conv2d_transpose` is used for upscaling.
        If set as `False`, `upsample + conv2d` is used for upscaling. If set as
        `auto`, `upsample + conv2d` is used for bottom layers until resolution
        reaches 128. (default: `auto`)
      output_channels: Number of channels of output image. (default: 3)
      truncation_psi: Style strength multiplier for the truncation trick.
        `None` or `1.0` indicates no truncation. (default: 0.7)
      truncation_layers: Number of layers for which to apply the truncation
        trick. `None` indicates no truncation. (default: 8)
      randomize_noise: Whether to add random noise for each convolutional layer.
        (default: False)

    Raises:
      ValueError: If the input `resolution` is not supported.
    """
    super().__init__()

    # Keep the raw settings around for inspection.
    self.resolution = resolution
    self.w_space_dim = w_space_dim
    self.fused_scale = fused_scale
    self.output_channels = output_channels
    self.truncation_psi = truncation_psi
    self.truncation_layers = truncation_layers
    self.randomize_noise = randomize_noise

    self.mapping = MappingModule(final_space_dim=w_space_dim)
    self.truncation = TruncationModule(resolution=resolution,
                                       w_space_dim=w_space_dim,
                                       truncation_psi=truncation_psi,
                                       truncation_layers=truncation_layers)
    self.synthesis = SynthesisModule(resolution=resolution,
                                     fused_scale=fused_scale,
                                     output_channels=output_channels,
                                     randomize_noise=randomize_noise)

    # Fused up-sampling blocks own a bare `weight` parameter rather than a
    # `conv` sub-module, hence their key in the mapping has to be renamed.
    always_fused = fused_scale is True
    auto_fused = fused_scale == 'auto'
    var_mapping = {}
    for pth_name, tf_name in _STYLEGAN_PTH_VARS_TO_TF_VARS.items():
      if 'Conv0_up' in tf_name:
        block_res = int(tf_name.split('x')[0])
        if always_fused or (auto_fused and
                            block_res >= _AUTO_FUSED_SCALE_MIN_RES):
          pth_name = pth_name.replace('conv.weight', 'weight')
      var_mapping[pth_name] = tf_name
    self.pth_to_tf_var_mapping = var_mapping

  def forward(self, z):
    """Runs mapping, truncation and synthesis on the latent code `z`."""
    return self.synthesis(self.truncation(self.mapping(z)))
class MappingModule(nn.Sequential):
  """Latent space mapping module used in StyleGAN.

  An optional `normalize` layer followed by `num_layers` dense blocks, which
  are registered as `dense0`, `dense1`, ... and executed in sequence.
  """

  def __init__(self,
               normalize_input=True,
               input_space_dim=512,
               hidden_space_dim=512,
               final_space_dim=512,
               num_layers=8):
    layers = OrderedDict()
    if normalize_input:
      layers['normalize'] = PixelNormLayer()

    last_idx = num_layers - 1
    for idx in range(num_layers):
      # Only the first block reads the input space, and only the last block
      # writes the final space; everything in between is `hidden_space_dim`.
      fan_in = hidden_space_dim if idx else input_space_dim
      fan_out = final_space_dim if idx == last_idx else hidden_space_dim
      layers[f'dense{idx}'] = DenseBlock(fan_in, fan_out)

    super().__init__(layers)

  def forward(self, x):
    """Maps a batch of 2D latent codes; raises `ValueError` otherwise."""
    if x.ndim == 2:
      return super().forward(x)
    raise ValueError(f'The input tensor should be with shape [batch_size, '
                     f'noise_dim], but {x.shape} received!')
class TruncationModule(nn.Module):
  """Truncation module used in StyleGAN.

  Broadcasts a latent code to all layers and pulls the codes of the first
  `truncation_layers` layers towards the `w_avg` buffer by `truncation_psi`.
  """

  def __init__(self,
               resolution=1024,
               w_space_dim=512,
               truncation_psi=0.7,
               truncation_layers=8):
    super().__init__()

    self.num_layers = int(np.log2(resolution)) * 2 - 2
    self.w_space_dim = w_space_dim

    # Truncation is active only when both settings are provided.
    self.use_truncation = (truncation_psi is not None and
                           truncation_layers is not None)
    if not self.use_truncation:
      truncation_psi = 1.0
      truncation_layers = 0

    self.register_buffer('w_avg', torch.zeros(w_space_dim))

    # Per-layer coefficient with shape [1, num_layers, 1]: `truncation_psi` for
    # the truncated layers and 1 for the remaining ones.
    layer_idx = np.arange(self.num_layers).reshape(1, self.num_layers, 1)
    coefs = np.where(layer_idx < truncation_layers, truncation_psi, 1.0)
    self.register_buffer('truncation',
                         torch.from_numpy(coefs.astype(np.float32)))

  def forward(self, w):
    """Returns layer-wise codes, truncated when truncation is enabled."""
    if w.ndim == 2:
      # One code per sample: replicate it for every layer.
      w = w.view(-1, 1, self.w_space_dim).repeat(1, self.num_layers, 1)
    if not self.use_truncation:
      return w
    center = self.w_avg.view(1, 1, self.w_space_dim)
    return center + (w - center) * self.truncation
class SynthesisModule(nn.Module):
  """Implements the image synthesis module used in StyleGAN.

  Basically, this module executes several convolutional layers in sequence.
  Each resolution block consists of two conv layers (`layer{2k}` and
  `layer{2k+1}`) plus an `output{k}` layer that converts the feature map to an
  image. The `lod` parameter controls how many blocks are actually executed.
  """

  def __init__(self,
               resolution=1024,
               fused_scale='auto',
               output_channels=3,
               randomize_noise=False):
    """Builds all conv blocks for the target resolution.

    Args:
      resolution: The resolution of the final output image.
      fused_scale: Up-sampling strategy forwarded to every `UpConvBlock`.
      output_channels: Number of channels of the output image.
      randomize_noise: Whether the conv blocks use fresh random noise.

    Raises:
      ValueError: If the input `resolution` is not supported.
    """
    super().__init__()

    try:
      self.channels = _RESOLUTIONS_TO_CHANNELS[resolution]
    except KeyError:
      # The lookup failure itself carries no extra information.
      raise ValueError(f'Invalid resolution: {resolution}!\n'
                       f'Resolutions allowed: '
                       f'{list(_RESOLUTIONS_TO_CHANNELS)}.') from None
    assert len(self.channels) == int(np.log2(resolution))

    for block_idx in range(1, len(self.channels)):
      first_idx = 2 * block_idx - 2
      second_idx = 2 * block_idx - 1
      in_channels = self.channels[block_idx - 1]
      out_channels = self.channels[block_idx]

      # The very first block starts from a constant; all others up-sample.
      if block_idx == 1:
        first_block = FirstConvBlock(in_channels=in_channels,
                                     randomize_noise=randomize_noise)
      else:
        first_block = UpConvBlock(layer_idx=first_idx,
                                  in_channels=in_channels,
                                  out_channels=out_channels,
                                  randomize_noise=randomize_noise,
                                  fused_scale=fused_scale)
      self.add_module(f'layer{first_idx}', first_block)
      self.add_module(f'layer{second_idx}',
                      ConvBlock(layer_idx=second_idx,
                                in_channels=out_channels,
                                out_channels=out_channels,
                                randomize_noise=randomize_noise))
      self.add_module(f'output{block_idx - 1}',
                      LastConvBlock(in_channels=out_channels,
                                    out_channels=output_channels))

    self.upsample = ResolutionScalingLayer()
    self.lod = nn.Parameter(torch.zeros(()))

  def forward(self, w):
    """Synthesizes images from layer-wise latent codes.

    Args:
      w: Latent codes, indexed as `w[:, layer_idx]` for each conv layer.

    Returns:
      The image produced by the last executed block, up-sampled once for every
      block skipped because of `lod`.

    Raises:
      ValueError: If `lod` is so large that no block is executed at all.
    """
    lod = self.lod.item()
    num_stages = len(self.channels)

    x = None
    image = None
    for block_idx in range(1, num_stages):
      if block_idx + lod < num_stages:
        first_idx = 2 * block_idx - 2
        second_idx = first_idx + 1
        first_block = getattr(self, f'layer{first_idx}')
        if first_idx == 0:
          # `layer0` starts from its own constant and only takes the style. It
          # is evaluated here exactly once.
          x = first_block(w[:, first_idx])
        else:
          x = first_block(x, w[:, first_idx])
        x = getattr(self, f'layer{second_idx}')(x, w[:, second_idx])
        image = getattr(self, f'output{block_idx - 1}')(x)
      else:
        if image is None:
          raise ValueError(f'`lod` ({lod}) is too large: no synthesis block '
                           f'can be executed with {num_stages - 1} blocks!')
        # Skipped block: keep the output size by up-sampling the last image.
        image = self.upsample(image)
    return image
class PixelNormLayer(nn.Module):
  """Pixel-wise feature vector normalization layer.

  Divides the input by the root of its mean square taken over dim 1.
  """

  def __init__(self, epsilon=1e-8):
    super().__init__()
    self.epsilon = epsilon

  def forward(self, x):
    mean_square = x.pow(2).mean(dim=1, keepdim=True)
    # `epsilon` keeps the division well-defined for all-zero vectors.
    return x / torch.sqrt(mean_square + self.epsilon)
class InstanceNormLayer(nn.Module):
  """Instance normalization layer.

  Each channel of each sample is shifted to zero mean and scaled to unit
  variance over its spatial dimensions.
  """

  def __init__(self, epsilon=1e-8):
    super().__init__()
    self.epsilon = epsilon

  def forward(self, x):
    if x.ndim != 4:
      raise ValueError(f'The input tensor should be with shape [batch_size, '
                       f'num_channels, height, width], but {x.shape} received!')
    spatial_dims = [2, 3]
    centered = x - torch.mean(x, dim=spatial_dims, keepdim=True)
    variance = torch.mean(centered**2, dim=spatial_dims, keepdim=True)
    return centered / torch.sqrt(variance + self.epsilon)
class ResolutionScalingLayer(nn.Module):
  """Resolution scaling layer.

  Resizes feature maps in the spatial domain by `scale_factor` with nearest
  neighbor interpolation, so it can be used to upsample or downsample.
  """

  def __init__(self, scale_factor=2):
    super().__init__()
    self.scale_factor = scale_factor

  def forward(self, x):
    factor = self.scale_factor
    return F.interpolate(x, scale_factor=factor, mode='nearest')
class BlurLayer(nn.Module):
  """Blur layer used in StyleGAN.

  Filters every channel independently with the same 3x3 kernel, which is the
  outer product of the 3-tap `kernel` with itself.
  """

  def __init__(self,
               channels,
               kernel=(1, 2, 1),
               normalize=True,
               flip=False):
    super().__init__()
    taps = np.array(kernel, dtype=np.float32).reshape(3)
    blur = np.outer(taps, taps)
    if normalize:
      blur = blur / np.sum(blur)
    if flip:
      blur = blur[::-1, ::-1]
    # One copy of the 3x3 filter per channel: [channels, 1, 3, 3], which is the
    # weight layout of a depth-wise convolution.
    stacked = np.ascontiguousarray(np.broadcast_to(blur, (channels, 1, 3, 3)))
    self.register_buffer('kernel', torch.from_numpy(stacked))
    self.channels = channels

  def forward(self, x):
    # `groups=channels` keeps channels separated; `padding=1` keeps the size.
    return F.conv2d(x, self.kernel, stride=1, padding=1, groups=self.channels)
class NoiseApplyingLayer(nn.Module):
  """Noise applying layer used in StyleGAN.

  Adds a single-channel noise map, scaled by a per-channel `weight`, onto the
  input. The noise is either the fixed `noise` buffer or freshly sampled.
  """

  def __init__(self, layer_idx, channels, randomize_noise=False):
    super().__init__()
    self.randomize_noise = randomize_noise
    # Two consecutive layers share one resolution, starting from 4x4.
    self.res = 2**(layer_idx // 2 + 2)
    self.register_buffer('noise', torch.randn(1, 1, self.res, self.res))
    self.weight = nn.Parameter(torch.zeros(channels))

  def forward(self, x):
    if x.ndim != 4:
      raise ValueError(f'The input tensor should be with shape [batch_size, '
                       f'num_channels, height, width], but {x.shape} received!')
    if not self.randomize_noise:
      noise = self.noise
    else:
      # Sample one noise map per sample, matching dtype and device of `x`.
      noise = torch.randn(x.shape[0], 1, self.res, self.res).to(x)
    per_channel_weight = self.weight.view(1, -1, 1, 1)
    return x + noise * per_channel_weight
class StyleModulationLayer(nn.Module):
  """Implements the style modulation layer used in StyleGAN.

  A dense layer turns the latent code `w` into a per-channel scale and shift,
  which are then applied onto the feature map.
  """

  def __init__(self, channels, w_space_dim=512):
    """Initializes the layer.

    Args:
      channels: Number of channels of the feature map to modulate.
      w_space_dim: Dimension of the latent code `w`.
    """
    super().__init__()
    self.channels = channels
    # `channels * 2` outputs: the first half scales, the second half shifts.
    self.dense = DenseBlock(in_features=w_space_dim,
                            out_features=channels*2,
                            wscale_gain=1.0,
                            wscale_lr_multiplier=1.0,
                            activation_type='linear')

  def forward(self, x, w):
    """Modulates the feature map `x` with the style derived from `w`.

    Raises:
      ValueError: If `w` is not a 2D tensor.
    """
    if len(w.shape) != 2:
      # Report the shape of the offending tensor, i.e. `w`, not `x`.
      raise ValueError(f'The input tensor should be with shape [batch_size, '
                       f'num_channels], but {w.shape} received!')
    style = self.dense(w)
    style = style.view(-1, 2, self.channels, 1, 1)
    # The scale is learned as an offset from 1.
    return x * (style[:, 0] + 1) + style[:, 1]
class WScaleLayer(nn.Module):
  """Layer that scales its input by a constant and adds a trainable bias.

  Note that the weight itself lives in the preceding `nn.Conv2d` (or
  `nn.Linear`) layer. Here the result is only multiplied by a fixed,
  non-trainable number derived from the fan-in; the bias, however, is a
  trainable parameter of this layer.
  """

  def __init__(self,
               in_channels,
               out_channels,
               kernel_size,
               gain=np.sqrt(2.0),
               lr_multiplier=1.0):
    super().__init__()
    fan_in = in_channels * kernel_size * kernel_size
    self.scale = gain / np.sqrt(fan_in) * lr_multiplier
    self.bias = nn.Parameter(torch.zeros(out_channels))
    self.lr_multiplier = lr_multiplier

  def forward(self, x):
    rank = len(x.shape)
    if rank == 4:
      bias_shape = (1, -1, 1, 1)
    elif rank == 2:
      bias_shape = (1, -1)
    else:
      raise ValueError(f'The input tensor should be with shape [batch_size, '
                       f'num_channels, height, width], or [batch_size, '
                       f'num_channels], but {x.shape} received!')
    return x * self.scale + self.bias.view(*bias_shape) * self.lr_multiplier
class EpilogueBlock(nn.Module):
  """Epilogue shared by all conv blocks.

  Applies, in order: noise, bias, leaky ReLU, normalization and style
  modulation.
  """

  def __init__(self,
               layer_idx,
               channels,
               randomize_noise=False,
               normalization_fn='instance'):
    super().__init__()
    self.apply_noise = NoiseApplyingLayer(layer_idx, channels, randomize_noise)
    self.bias = nn.Parameter(torch.zeros(channels))
    self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    if normalization_fn not in ('pixel', 'instance'):
      raise NotImplementedError(f'Not implemented normalization function: '
                                f'{normalization_fn}!')
    if normalization_fn == 'instance':
      self.norm = InstanceNormLayer()
    else:
      self.norm = PixelNormLayer()

    self.style_mod = StyleModulationLayer(channels)

  def forward(self, x, w):
    noisy = self.apply_noise(x)
    biased = noisy + self.bias.view(1, -1, 1, 1)
    normalized = self.norm(self.activate(biased))
    return self.style_mod(normalized, w)
class FirstConvBlock(nn.Module):
  """First convolutional block used in StyleGAN.

  Rather than convolving an input, this block starts from a learnable constant
  initialized as `ones(1, in_channels, 4, 4)` and only runs the epilogue on it.
  """

  def __init__(self, in_channels, randomize_noise=False):
    super().__init__()
    self.first_layer = nn.Parameter(torch.ones(1, in_channels, 4, 4))
    self.epilogue = EpilogueBlock(layer_idx=0,
                                  channels=in_channels,
                                  randomize_noise=randomize_noise)

  def forward(self, w):
    batch_size = w.shape[0]
    # Replicate the constant for every sample in the batch.
    const_input = self.first_layer.repeat(batch_size, 1, 1, 1)
    return self.epilogue(const_input, w)
class UpConvBlock(nn.Module):
  """Up-sampling convolutional block used in StyleGAN.

  This is the first convolutional block of each resolution. It doubles the
  spatial size, convolves, blurs the result and finally runs the epilogue.
  """

  def __init__(self,
               layer_idx,
               in_channels,
               out_channels,
               kernel_size=3,
               stride=1,
               padding=1,
               dilation=1,
               add_bias=False,
               fused_scale='auto',
               wscale_gain=np.sqrt(2.0),
               wscale_lr_multiplier=1.0,
               randomize_noise=False):
    """Initializes the class with block settings.

    Args:
      layer_idx: Index of this layer, which must be even. The resolution of
        the block is `2 ** (layer_idx // 2 + 2)`.
      in_channels: Number of channels of the input tensor fed into this block.
      out_channels: Number of channels (kernels) of the output tensor.
      kernel_size: Size of the convolutional kernel.
      stride: Stride parameter for convolution operation.
      padding: Padding parameter for convolution operation.
      dilation: Dilation rate for convolution operation.
      add_bias: Whether to add bias onto the convolutional result.
      fused_scale: Whether to fuse `upsample` and `conv2d` together, resulting
        in `conv2d_transpose`.
      wscale_gain: The gain factor for `wscale` layer.
      wscale_lr_multiplier: The learning rate multiplier factor for `wscale`
        layer.
      randomize_noise: Whether to add random noise.

    Raises:
      ValueError: If the block is not applied to the first block for a
        particular resolution. Or `fused_scale` does not belong to [True, False,
        `auto`].
    """
    super().__init__()
    if layer_idx % 2 == 1:
      raise ValueError(f'This block is implemented as the first block of each '
                       f'resolution, but is applied to layer {layer_idx}!')
    if fused_scale not in [True, False, 'auto']:
      raise ValueError(f'`fused_scale` can only be [True, False, `auto`], '
                       f'but {fused_scale} received!')

    # Resolve `auto` into a concrete choice based on the block resolution.
    block_res = 2 ** (layer_idx // 2 + 2)
    if fused_scale != 'auto':
      self.fused_scale = fused_scale
    else:
      self.fused_scale = (block_res >= _AUTO_FUSED_SCALE_MIN_RES)

    if not self.fused_scale:
      self.upsample = ResolutionScalingLayer()
      self.conv = nn.Conv2d(in_channels=in_channels,
                            out_channels=out_channels,
                            kernel_size=kernel_size,
                            stride=stride,
                            padding=padding,
                            dilation=dilation,
                            groups=1,
                            bias=add_bias)
    else:
      # The fused path owns a raw kernel, laid out as [k, k, in, out].
      self.weight = nn.Parameter(
          torch.randn(kernel_size, kernel_size, in_channels, out_channels))

    fan_in = in_channels * kernel_size * kernel_size
    self.scale = wscale_gain / np.sqrt(fan_in) * wscale_lr_multiplier
    self.blur = BlurLayer(channels=out_channels)
    self.epilogue = EpilogueBlock(layer_idx=layer_idx,
                                  channels=out_channels,
                                  randomize_noise=randomize_noise)

  def _fused_kernel(self):
    """Returns the kernel fed to the transposed convolution."""
    padded = F.pad(self.weight * self.scale,
                   (0, 0, 0, 0, 1, 1, 1, 1), 'constant', 0.0)
    # Sum the four one-pixel shifts of the zero-padded kernel.
    merged = (padded[1:, 1:] + padded[:-1, 1:] +
              padded[1:, :-1] + padded[:-1, :-1])
    # [k, k, in, out] -> [in, out, k, k], as `conv_transpose2d` expects.
    return merged.permute(2, 3, 0, 1)

  def forward(self, x, w):
    if not self.fused_scale:
      x = self.conv(self.upsample(x)) * self.scale
    else:
      x = F.conv_transpose2d(x, self._fused_kernel(), stride=2, padding=1)
    return self.epilogue(self.blur(x), w)
class ConvBlock(nn.Module):
  """Plain convolutional block used in StyleGAN.

  This is the second convolutional block of each resolution: a scaled
  convolution followed by the epilogue, without any up-sampling.
  """

  def __init__(self,
               layer_idx,
               in_channels,
               out_channels,
               kernel_size=3,
               stride=1,
               padding=1,
               dilation=1,
               add_bias=False,
               wscale_gain=np.sqrt(2.0),
               wscale_lr_multiplier=1.0,
               randomize_noise=False):
    """Initializes the class with block settings.

    Args:
      layer_idx: Index of this layer, which must be odd.
      in_channels: Number of channels of the input tensor fed into this block.
      out_channels: Number of channels (kernels) of the output tensor.
      kernel_size: Size of the convolutional kernel.
      stride: Stride parameter for convolution operation.
      padding: Padding parameter for convolution operation.
      dilation: Dilation rate for convolution operation.
      add_bias: Whether to add bias onto the convolutional result.
      wscale_gain: The gain factor for `wscale` layer.
      wscale_lr_multiplier: The learning rate multiplier factor for `wscale`
        layer.
      randomize_noise: Whether to add random noise.

    Raises:
      ValueError: If the block is not applied to the second block for a
        particular resolution.
    """
    super().__init__()
    if layer_idx % 2 == 0:
      raise ValueError(f'This block is implemented as the second block of each '
                       f'resolution, but is applied to layer {layer_idx}!')

    conv_settings = dict(in_channels=in_channels,
                         out_channels=out_channels,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         dilation=dilation,
                         groups=1,
                         bias=add_bias)
    self.conv = nn.Conv2d(**conv_settings)

    fan_in = in_channels * kernel_size * kernel_size
    self.scale = wscale_gain / np.sqrt(fan_in) * wscale_lr_multiplier
    self.epilogue = EpilogueBlock(layer_idx=layer_idx,
                                  channels=out_channels,
                                  randomize_noise=randomize_noise)

  def forward(self, x, w):
    scaled = self.conv(x) * self.scale
    return self.epilogue(scaled, w)
class LastConvBlock(nn.Module):
  """Implements the last convolutional block used in StyleGAN.

  Basically, this block converts the final feature map to an image (RGB by
  default) with a scaled 1x1 convolution plus a per-channel bias.
  """

  def __init__(self, in_channels, out_channels=3):
    """Initializes the block.

    Args:
      in_channels: Number of channels of the input feature map.
      out_channels: Number of channels of the output image. (default: 3)
    """
    super().__init__()
    self.conv = nn.Conv2d(in_channels=in_channels,
                          out_channels=out_channels,
                          kernel_size=1,
                          bias=False)
    # Fan-in of a 1x1 convolution is just `in_channels`.
    self.scale = 1 / np.sqrt(in_channels)
    # One bias per output channel, so that `out_channels != 3` works as well.
    self.bias = nn.Parameter(torch.zeros(out_channels))

  def forward(self, x):
    """Returns the image with shape [batch_size, out_channels, height, width]."""
    x = self.conv(x) * self.scale
    x = x + self.bias.view(1, -1, 1, 1)
    return x
class DenseBlock(nn.Module):
  """Dense block used in StyleGAN.

  Chains a fully-connected layer, a weight-scale layer (which also holds the
  bias) and an activation layer.
  """

  def __init__(self,
               in_features,
               out_features,
               add_bias=False,
               wscale_gain=np.sqrt(2.0),
               wscale_lr_multiplier=0.01,
               activation_type='lrelu'):
    """Initializes the class with block settings.

    Args:
      in_features: Number of channels of the input tensor fed into this block.
      out_features: Number of channels of the output tensor.
      add_bias: Whether to add bias onto the fully-connected result.
      wscale_gain: The gain factor for `wscale` layer.
      wscale_lr_multiplier: The learning rate multiplier factor for `wscale`
        layer.
      activation_type: Type of activation function. Support `linear` and
        `lrelu`.

    Raises:
      NotImplementedError: If the input `activation_type` is not supported.
    """
    super().__init__()
    self.linear = nn.Linear(in_features=in_features,
                            out_features=out_features,
                            bias=add_bias)
    self.wscale = WScaleLayer(in_channels=in_features,
                              out_channels=out_features,
                              kernel_size=1,
                              gain=wscale_gain,
                              lr_multiplier=wscale_lr_multiplier)

    if activation_type not in ('linear', 'lrelu'):
      raise NotImplementedError(f'Not implemented activation function: '
                                f'{activation_type}!')
    if activation_type == 'lrelu':
      self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
    else:
      self.activate = nn.Identity()

  def forward(self, x):
    return self.activate(self.wscale(self.linear(x)))