# python3.7
"""Contains the implementation of the generator described in StyleGAN.

Different from the official TensorFlow model in folder `stylegan_tf_official`,
this is a simple PyTorch version which only contains the generator part. This
class is specially used for inference.

For more details, please check the original paper:
https://arxiv.org/pdf/1812.04948.pdf
"""

from collections import OrderedDict

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['StyleGANGeneratorModel']
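
# A minimal usage sketch for inference (assuming a converted PyTorch
# checkpoint is available at the hypothetical path `stylegan_ffhq.pth`; no
# weights ship with this file):
#
#   model = StyleGANGeneratorModel(resolution=1024)
#   model.load_state_dict(torch.load('stylegan_ffhq.pth'))
#   model.eval()
#   with torch.no_grad():
#     z = torch.randn(1, 512)  # z-space latent code.
#     image = model(z)  # [1, 3, 1024, 1024] RGB tensor.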

# Defines a dictionary, which maps the target resolution of the final generated
# image to the number of filters used in each convolutional layer in sequence.
_RESOLUTIONS_TO_CHANNELS = {
    8: [512, 512, 512],
    16: [512, 512, 512, 512],
    32: [512, 512, 512, 512, 512],
    64: [512, 512, 512, 512, 512, 256],
    128: [512, 512, 512, 512, 512, 256, 128],
    256: [512, 512, 512, 512, 512, 256, 128, 64],
    512: [512, 512, 512, 512, 512, 256, 128, 64, 32],
    1024: [512, 512, 512, 512, 512, 256, 128, 64, 32, 16],
}

# pylint: disable=line-too-long
# Variable mapping from the PyTorch model to the official TensorFlow model.
_STYLEGAN_PTH_VARS_TO_TF_VARS = {
    # Statistic information of the disentangled latent feature, w.
    'truncation.w_avg': 'dlatent_avg',  # [512]

    # Noises.
    'synthesis.layer0.epilogue.apply_noise.noise': 'noise0',  # [1, 1, 4, 4]
    'synthesis.layer1.epilogue.apply_noise.noise': 'noise1',  # [1, 1, 4, 4]
    'synthesis.layer2.epilogue.apply_noise.noise': 'noise2',  # [1, 1, 8, 8]
    'synthesis.layer3.epilogue.apply_noise.noise': 'noise3',  # [1, 1, 8, 8]
    'synthesis.layer4.epilogue.apply_noise.noise': 'noise4',  # [1, 1, 16, 16]
    'synthesis.layer5.epilogue.apply_noise.noise': 'noise5',  # [1, 1, 16, 16]
    'synthesis.layer6.epilogue.apply_noise.noise': 'noise6',  # [1, 1, 32, 32]
    'synthesis.layer7.epilogue.apply_noise.noise': 'noise7',  # [1, 1, 32, 32]
    'synthesis.layer8.epilogue.apply_noise.noise': 'noise8',  # [1, 1, 64, 64]
    'synthesis.layer9.epilogue.apply_noise.noise': 'noise9',  # [1, 1, 64, 64]
    'synthesis.layer10.epilogue.apply_noise.noise': 'noise10',  # [1, 1, 128, 128]
    'synthesis.layer11.epilogue.apply_noise.noise': 'noise11',  # [1, 1, 128, 128]
    'synthesis.layer12.epilogue.apply_noise.noise': 'noise12',  # [1, 1, 256, 256]
    'synthesis.layer13.epilogue.apply_noise.noise': 'noise13',  # [1, 1, 256, 256]
    'synthesis.layer14.epilogue.apply_noise.noise': 'noise14',  # [1, 1, 512, 512]
    'synthesis.layer15.epilogue.apply_noise.noise': 'noise15',  # [1, 1, 512, 512]
    'synthesis.layer16.epilogue.apply_noise.noise': 'noise16',  # [1, 1, 1024, 1024]
    'synthesis.layer17.epilogue.apply_noise.noise': 'noise17',  # [1, 1, 1024, 1024]

    # Mapping blocks.
    'mapping.dense0.linear.weight': 'Dense0/weight',  # [512, 512]
    'mapping.dense0.wscale.bias': 'Dense0/bias',  # [512]
    'mapping.dense1.linear.weight': 'Dense1/weight',  # [512, 512]
    'mapping.dense1.wscale.bias': 'Dense1/bias',  # [512]
    'mapping.dense2.linear.weight': 'Dense2/weight',  # [512, 512]
    'mapping.dense2.wscale.bias': 'Dense2/bias',  # [512]
    'mapping.dense3.linear.weight': 'Dense3/weight',  # [512, 512]
    'mapping.dense3.wscale.bias': 'Dense3/bias',  # [512]
    'mapping.dense4.linear.weight': 'Dense4/weight',  # [512, 512]
    'mapping.dense4.wscale.bias': 'Dense4/bias',  # [512]
    'mapping.dense5.linear.weight': 'Dense5/weight',  # [512, 512]
    'mapping.dense5.wscale.bias': 'Dense5/bias',  # [512]
    'mapping.dense6.linear.weight': 'Dense6/weight',  # [512, 512]
    'mapping.dense6.wscale.bias': 'Dense6/bias',  # [512]
    'mapping.dense7.linear.weight': 'Dense7/weight',  # [512, 512]
    'mapping.dense7.wscale.bias': 'Dense7/bias',  # [512]

    # Synthesis blocks.
    'synthesis.lod': 'lod',  # []
    'synthesis.layer0.first_layer': '4x4/Const/const',  # [1, 512, 4, 4]
    'synthesis.layer0.epilogue.apply_noise.weight': '4x4/Const/Noise/weight',  # [512]
    'synthesis.layer0.epilogue.bias': '4x4/Const/bias',  # [512]
    'synthesis.layer0.epilogue.style_mod.dense.linear.weight': '4x4/Const/StyleMod/weight',  # [1024, 512]
    'synthesis.layer0.epilogue.style_mod.dense.wscale.bias': '4x4/Const/StyleMod/bias',  # [1024]
    'synthesis.layer1.conv.weight': '4x4/Conv/weight',  # [512, 512, 3, 3]
    'synthesis.layer1.epilogue.apply_noise.weight': '4x4/Conv/Noise/weight',  # [512]
    'synthesis.layer1.epilogue.bias': '4x4/Conv/bias',  # [512]
    'synthesis.layer1.epilogue.style_mod.dense.linear.weight': '4x4/Conv/StyleMod/weight',  # [1024, 512]
    'synthesis.layer1.epilogue.style_mod.dense.wscale.bias': '4x4/Conv/StyleMod/bias',  # [1024]
    'synthesis.layer2.conv.weight': '8x8/Conv0_up/weight',  # [512, 512, 3, 3]
    'synthesis.layer2.epilogue.apply_noise.weight': '8x8/Conv0_up/Noise/weight',  # [512]
    'synthesis.layer2.epilogue.bias': '8x8/Conv0_up/bias',  # [512]
    'synthesis.layer2.epilogue.style_mod.dense.linear.weight': '8x8/Conv0_up/StyleMod/weight',  # [1024, 512]
    'synthesis.layer2.epilogue.style_mod.dense.wscale.bias': '8x8/Conv0_up/StyleMod/bias',  # [1024]
    'synthesis.layer3.conv.weight': '8x8/Conv1/weight',  # [512, 512, 3, 3]
    'synthesis.layer3.epilogue.apply_noise.weight': '8x8/Conv1/Noise/weight',  # [512]
    'synthesis.layer3.epilogue.bias': '8x8/Conv1/bias',  # [512]
    'synthesis.layer3.epilogue.style_mod.dense.linear.weight': '8x8/Conv1/StyleMod/weight',  # [1024, 512]
    'synthesis.layer3.epilogue.style_mod.dense.wscale.bias': '8x8/Conv1/StyleMod/bias',  # [1024]
    'synthesis.layer4.conv.weight': '16x16/Conv0_up/weight',  # [512, 512, 3, 3]
    'synthesis.layer4.epilogue.apply_noise.weight': '16x16/Conv0_up/Noise/weight',  # [512]
    'synthesis.layer4.epilogue.bias': '16x16/Conv0_up/bias',  # [512]
    'synthesis.layer4.epilogue.style_mod.dense.linear.weight': '16x16/Conv0_up/StyleMod/weight',  # [1024, 512]
    'synthesis.layer4.epilogue.style_mod.dense.wscale.bias': '16x16/Conv0_up/StyleMod/bias',  # [1024]
    'synthesis.layer5.conv.weight': '16x16/Conv1/weight',  # [512, 512, 3, 3]
    'synthesis.layer5.epilogue.apply_noise.weight': '16x16/Conv1/Noise/weight',  # [512]
    'synthesis.layer5.epilogue.bias': '16x16/Conv1/bias',  # [512]
    'synthesis.layer5.epilogue.style_mod.dense.linear.weight': '16x16/Conv1/StyleMod/weight',  # [1024, 512]
    'synthesis.layer5.epilogue.style_mod.dense.wscale.bias': '16x16/Conv1/StyleMod/bias',  # [1024]
    'synthesis.layer6.conv.weight': '32x32/Conv0_up/weight',  # [512, 512, 3, 3]
    'synthesis.layer6.epilogue.apply_noise.weight': '32x32/Conv0_up/Noise/weight',  # [512]
    'synthesis.layer6.epilogue.bias': '32x32/Conv0_up/bias',  # [512]
    'synthesis.layer6.epilogue.style_mod.dense.linear.weight': '32x32/Conv0_up/StyleMod/weight',  # [1024, 512]
    'synthesis.layer6.epilogue.style_mod.dense.wscale.bias': '32x32/Conv0_up/StyleMod/bias',  # [1024]
    'synthesis.layer7.conv.weight': '32x32/Conv1/weight',  # [512, 512, 3, 3]
    'synthesis.layer7.epilogue.apply_noise.weight': '32x32/Conv1/Noise/weight',  # [512]
    'synthesis.layer7.epilogue.bias': '32x32/Conv1/bias',  # [512]
    'synthesis.layer7.epilogue.style_mod.dense.linear.weight': '32x32/Conv1/StyleMod/weight',  # [1024, 512]
    'synthesis.layer7.epilogue.style_mod.dense.wscale.bias': '32x32/Conv1/StyleMod/bias',  # [1024]
    'synthesis.layer8.conv.weight': '64x64/Conv0_up/weight',  # [256, 512, 3, 3]
    'synthesis.layer8.epilogue.apply_noise.weight': '64x64/Conv0_up/Noise/weight',  # [256]
    'synthesis.layer8.epilogue.bias': '64x64/Conv0_up/bias',  # [256]
    'synthesis.layer8.epilogue.style_mod.dense.linear.weight': '64x64/Conv0_up/StyleMod/weight',  # [512, 512]
    'synthesis.layer8.epilogue.style_mod.dense.wscale.bias': '64x64/Conv0_up/StyleMod/bias',  # [512]
    'synthesis.layer9.conv.weight': '64x64/Conv1/weight',  # [256, 256, 3, 3]
    'synthesis.layer9.epilogue.apply_noise.weight': '64x64/Conv1/Noise/weight',  # [256]
    'synthesis.layer9.epilogue.bias': '64x64/Conv1/bias',  # [256]
    'synthesis.layer9.epilogue.style_mod.dense.linear.weight': '64x64/Conv1/StyleMod/weight',  # [512, 512]
    'synthesis.layer9.epilogue.style_mod.dense.wscale.bias': '64x64/Conv1/StyleMod/bias',  # [512]
    'synthesis.layer10.conv.weight': '128x128/Conv0_up/weight',  # [128, 256, 3, 3]
    'synthesis.layer10.epilogue.apply_noise.weight': '128x128/Conv0_up/Noise/weight',  # [128]
    'synthesis.layer10.epilogue.bias': '128x128/Conv0_up/bias',  # [128]
    'synthesis.layer10.epilogue.style_mod.dense.linear.weight': '128x128/Conv0_up/StyleMod/weight',  # [256, 512]
    'synthesis.layer10.epilogue.style_mod.dense.wscale.bias': '128x128/Conv0_up/StyleMod/bias',  # [256]
    'synthesis.layer11.conv.weight': '128x128/Conv1/weight',  # [128, 128, 3, 3]
    'synthesis.layer11.epilogue.apply_noise.weight': '128x128/Conv1/Noise/weight',  # [128]
    'synthesis.layer11.epilogue.bias': '128x128/Conv1/bias',  # [128]
    'synthesis.layer11.epilogue.style_mod.dense.linear.weight': '128x128/Conv1/StyleMod/weight',  # [256, 512]
    'synthesis.layer11.epilogue.style_mod.dense.wscale.bias': '128x128/Conv1/StyleMod/bias',  # [256]
    'synthesis.layer12.conv.weight': '256x256/Conv0_up/weight',  # [64, 128, 3, 3]
    'synthesis.layer12.epilogue.apply_noise.weight': '256x256/Conv0_up/Noise/weight',  # [64]
    'synthesis.layer12.epilogue.bias': '256x256/Conv0_up/bias',  # [64]
    'synthesis.layer12.epilogue.style_mod.dense.linear.weight': '256x256/Conv0_up/StyleMod/weight',  # [128, 512]
    'synthesis.layer12.epilogue.style_mod.dense.wscale.bias': '256x256/Conv0_up/StyleMod/bias',  # [128]
    'synthesis.layer13.conv.weight': '256x256/Conv1/weight',  # [64, 64, 3, 3]
    'synthesis.layer13.epilogue.apply_noise.weight': '256x256/Conv1/Noise/weight',  # [64]
    'synthesis.layer13.epilogue.bias': '256x256/Conv1/bias',  # [64]
    'synthesis.layer13.epilogue.style_mod.dense.linear.weight': '256x256/Conv1/StyleMod/weight',  # [128, 512]
    'synthesis.layer13.epilogue.style_mod.dense.wscale.bias': '256x256/Conv1/StyleMod/bias',  # [128]
    'synthesis.layer14.conv.weight': '512x512/Conv0_up/weight',  # [32, 64, 3, 3]
    'synthesis.layer14.epilogue.apply_noise.weight': '512x512/Conv0_up/Noise/weight',  # [32]
    'synthesis.layer14.epilogue.bias': '512x512/Conv0_up/bias',  # [32]
    'synthesis.layer14.epilogue.style_mod.dense.linear.weight': '512x512/Conv0_up/StyleMod/weight',  # [64, 512]
    'synthesis.layer14.epilogue.style_mod.dense.wscale.bias': '512x512/Conv0_up/StyleMod/bias',  # [64]
    'synthesis.layer15.conv.weight': '512x512/Conv1/weight',  # [32, 32, 3, 3]
    'synthesis.layer15.epilogue.apply_noise.weight': '512x512/Conv1/Noise/weight',  # [32]
    'synthesis.layer15.epilogue.bias': '512x512/Conv1/bias',  # [32]
    'synthesis.layer15.epilogue.style_mod.dense.linear.weight': '512x512/Conv1/StyleMod/weight',  # [64, 512]
    'synthesis.layer15.epilogue.style_mod.dense.wscale.bias': '512x512/Conv1/StyleMod/bias',  # [64]
    'synthesis.layer16.conv.weight': '1024x1024/Conv0_up/weight',  # [16, 32, 3, 3]
    'synthesis.layer16.epilogue.apply_noise.weight': '1024x1024/Conv0_up/Noise/weight',  # [16]
    'synthesis.layer16.epilogue.bias': '1024x1024/Conv0_up/bias',  # [16]
    'synthesis.layer16.epilogue.style_mod.dense.linear.weight': '1024x1024/Conv0_up/StyleMod/weight',  # [32, 512]
    'synthesis.layer16.epilogue.style_mod.dense.wscale.bias': '1024x1024/Conv0_up/StyleMod/bias',  # [32]
    'synthesis.layer17.conv.weight': '1024x1024/Conv1/weight',  # [16, 16, 3, 3]
    'synthesis.layer17.epilogue.apply_noise.weight': '1024x1024/Conv1/Noise/weight',  # [16]
    'synthesis.layer17.epilogue.bias': '1024x1024/Conv1/bias',  # [16]
    'synthesis.layer17.epilogue.style_mod.dense.linear.weight': '1024x1024/Conv1/StyleMod/weight',  # [32, 512]
    'synthesis.layer17.epilogue.style_mod.dense.wscale.bias': '1024x1024/Conv1/StyleMod/bias',  # [32]
    'synthesis.output0.conv.weight': 'ToRGB_lod8/weight',  # [3, 512, 1, 1]
    'synthesis.output0.bias': 'ToRGB_lod8/bias',  # [3]
    'synthesis.output1.conv.weight': 'ToRGB_lod7/weight',  # [3, 512, 1, 1]
    'synthesis.output1.bias': 'ToRGB_lod7/bias',  # [3]
    'synthesis.output2.conv.weight': 'ToRGB_lod6/weight',  # [3, 512, 1, 1]
    'synthesis.output2.bias': 'ToRGB_lod6/bias',  # [3]
    'synthesis.output3.conv.weight': 'ToRGB_lod5/weight',  # [3, 512, 1, 1]
    'synthesis.output3.bias': 'ToRGB_lod5/bias',  # [3]
    'synthesis.output4.conv.weight': 'ToRGB_lod4/weight',  # [3, 256, 1, 1]
    'synthesis.output4.bias': 'ToRGB_lod4/bias',  # [3]
    'synthesis.output5.conv.weight': 'ToRGB_lod3/weight',  # [3, 128, 1, 1]
    'synthesis.output5.bias': 'ToRGB_lod3/bias',  # [3]
    'synthesis.output6.conv.weight': 'ToRGB_lod2/weight',  # [3, 64, 1, 1]
    'synthesis.output6.bias': 'ToRGB_lod2/bias',  # [3]
    'synthesis.output7.conv.weight': 'ToRGB_lod1/weight',  # [3, 32, 1, 1]
    'synthesis.output7.bias': 'ToRGB_lod1/bias',  # [3]
    'synthesis.output8.conv.weight': 'ToRGB_lod0/weight',  # [3, 16, 1, 1]
    'synthesis.output8.bias': 'ToRGB_lod0/bias',  # [3]
}
# pylint: enable=line-too-long

# Minimal resolution for `auto` fused-scale strategy.
_AUTO_FUSED_SCALE_MIN_RES = 128


class StyleGANGeneratorModel(nn.Module):
  """Defines the generator module in StyleGAN.

  Note that the generated images have RGB color channels.
  """

  def __init__(self,
               resolution=1024,
               w_space_dim=512,
               fused_scale='auto',
               output_channels=3,
               truncation_psi=0.7,
               truncation_layers=8,
               randomize_noise=False):
    """Initializes the generator with basic settings.

    Args:
      resolution: The resolution of the final output image. (default: 1024)
      w_space_dim: The dimension of the disentangled latent vectors, w.
        (default: 512)
      fused_scale: If set to `True`, `conv2d_transpose` is used for upscaling.
        If set to `False`, `upsample + conv2d` is used for upscaling. If set to
        `auto`, `upsample + conv2d` is used for bottom layers until the
        resolution reaches 128. (default: `auto`)
      output_channels: Number of channels of the output image. (default: 3)
      truncation_psi: Style strength multiplier for the truncation trick.
        `None` or `1.0` indicates no truncation. (default: 0.7)
      truncation_layers: Number of layers for which to apply the truncation
        trick. `None` indicates no truncation. (default: 8)
      randomize_noise: Whether to add random noise for each convolutional
        layer. (default: False)

    Raises:
      ValueError: If the input `resolution` is not supported.
    """
    super().__init__()
    self.resolution = resolution
    self.w_space_dim = w_space_dim
    self.fused_scale = fused_scale
    self.output_channels = output_channels
    self.truncation_psi = truncation_psi
    self.truncation_layers = truncation_layers
    self.randomize_noise = randomize_noise

    self.mapping = MappingModule(final_space_dim=self.w_space_dim)
    self.truncation = TruncationModule(resolution=self.resolution,
                                       w_space_dim=self.w_space_dim,
                                       truncation_psi=self.truncation_psi,
                                       truncation_layers=self.truncation_layers)
    self.synthesis = SynthesisModule(resolution=self.resolution,
                                     fused_scale=self.fused_scale,
                                     output_channels=self.output_channels,
                                     randomize_noise=self.randomize_noise)

    self.pth_to_tf_var_mapping = {}
    for pth_var_name, tf_var_name in _STYLEGAN_PTH_VARS_TO_TF_VARS.items():
      if 'Conv0_up' in tf_var_name:
        res = int(tf_var_name.split('x')[0])
        if ((self.fused_scale is True) or
            (self.fused_scale == 'auto' and res >= _AUTO_FUSED_SCALE_MIN_RES)):
          # Fused-scale layers store the raw kernel directly instead of an
          # `nn.Conv2d` sub-module, so the variable name differs.
          pth_var_name = pth_var_name.replace('conv.weight', 'weight')
      self.pth_to_tf_var_mapping[pth_var_name] = tf_var_name

  def forward(self, z):
    w = self.mapping(z)
    w = self.truncation(w)
    x = self.synthesis(w)
    return x
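
  # A minimal conversion sketch using `pth_to_tf_var_mapping` (assuming
  # `tf_vars` is a dict that maps the TF variable names above to numpy arrays
  # already read from the official checkpoint; the checkpoint reader itself is
  # not part of this file). Fused-scale kernels, mapped to plain `weight`,
  # already use the TF layout and need no transpose.
  #
  #   model = StyleGANGeneratorModel()
  #   state_dict = model.state_dict()
  #   for pth_name, tf_name in model.pth_to_tf_var_mapping.items():
  #     array = torch.from_numpy(tf_vars[tf_name])
  #     if pth_name.endswith('conv.weight'):
  #       array = array.permute(3, 2, 0, 1)  # [H, W, in, out] -> [out, in, H, W]
  #     elif pth_name.endswith('linear.weight'):
  #       array = array.permute(1, 0)  # [in, out] -> [out, in]
  #     state_dict[pth_name] = array
  #   model.load_state_dict(state_dict)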


class MappingModule(nn.Sequential):
  """Implements the latent space mapping module used in StyleGAN.

  Basically, this module executes several dense layers in sequence.
  """

  def __init__(self,
               normalize_input=True,
               input_space_dim=512,
               hidden_space_dim=512,
               final_space_dim=512,
               num_layers=8):
    sequence = OrderedDict()

    def _add_layer(layer, name=None):
      # Dense layers are numbered consecutively from `dense0`, whether or not
      # the `normalize` layer occupies the first slot of the sequence.
      name = name or f'dense{len(sequence) + (not normalize_input) - 1}'
      sequence[name] = layer

    if normalize_input:
      _add_layer(PixelNormLayer(), name='normalize')
    for i in range(num_layers):
      in_dim = input_space_dim if i == 0 else hidden_space_dim
      out_dim = final_space_dim if i == (num_layers - 1) else hidden_space_dim
      _add_layer(DenseBlock(in_dim, out_dim))
    super().__init__(sequence)

  def forward(self, x):
    if len(x.shape) != 2:
      raise ValueError(f'The input tensor should have shape [batch_size, '
                       f'noise_dim], but {x.shape} received!')
    return super().forward(x)


class TruncationModule(nn.Module):
  """Implements the truncation module used in StyleGAN."""

  def __init__(self,
               resolution=1024,
               w_space_dim=512,
               truncation_psi=0.7,
               truncation_layers=8):
    super().__init__()

    self.num_layers = int(np.log2(resolution)) * 2 - 2
    self.w_space_dim = w_space_dim
    if truncation_psi is not None and truncation_layers is not None:
      self.use_truncation = True
    else:
      self.use_truncation = False
      truncation_psi = 1.0
      truncation_layers = 0
    self.register_buffer('w_avg', torch.zeros(w_space_dim))
    layer_idx = np.arange(self.num_layers).reshape(1, self.num_layers, 1)
    coefs = np.ones_like(layer_idx, dtype=np.float32)
    coefs[layer_idx < truncation_layers] *= truncation_psi
    self.register_buffer('truncation', torch.from_numpy(coefs))

  def forward(self, w):
    if len(w.shape) == 2:
      w = w.view(-1, 1, self.w_space_dim).repeat(1, self.num_layers, 1)
    if self.use_truncation:
      w_avg = self.w_avg.view(1, 1, self.w_space_dim)
      w = w_avg + (w - w_avg) * self.truncation
    return w
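
  # For example, with the defaults `truncation_psi=0.7` and
  # `truncation_layers=8`, only the first 8 of the 18 per-layer codes (for a
  # 1024x1024 generator) are pulled toward the running average:
  #   w'[l] = w_avg + 0.7 * (w[l] - w_avg)   for l < 8
  #   w'[l] = w[l]                           for l >= 8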


class SynthesisModule(nn.Module):
  """Implements the image synthesis module used in StyleGAN.

  Basically, this module executes several convolutional layers in sequence.
  """

  def __init__(self,
               resolution=1024,
               fused_scale='auto',
               output_channels=3,
               randomize_noise=False):
    super().__init__()

    try:
      self.channels = _RESOLUTIONS_TO_CHANNELS[resolution]
    except KeyError:
      raise ValueError(f'Invalid resolution: {resolution}!\n'
                       f'Resolutions allowed: '
                       f'{list(_RESOLUTIONS_TO_CHANNELS)}.')
    assert len(self.channels) == int(np.log2(resolution))

    for block_idx in range(1, len(self.channels)):
      if block_idx == 1:
        self.add_module(
            f'layer{2 * block_idx - 2}',
            FirstConvBlock(in_channels=self.channels[block_idx - 1],
                           randomize_noise=randomize_noise))
      else:
        self.add_module(
            f'layer{2 * block_idx - 2}',
            UpConvBlock(layer_idx=2 * block_idx - 2,
                        in_channels=self.channels[block_idx - 1],
                        out_channels=self.channels[block_idx],
                        randomize_noise=randomize_noise,
                        fused_scale=fused_scale))
      self.add_module(
          f'layer{2 * block_idx - 1}',
          ConvBlock(layer_idx=2 * block_idx - 1,
                    in_channels=self.channels[block_idx],
                    out_channels=self.channels[block_idx],
                    randomize_noise=randomize_noise))
      self.add_module(
          f'output{block_idx - 1}',
          LastConvBlock(in_channels=self.channels[block_idx],
                        out_channels=output_channels))
    self.upsample = ResolutionScalingLayer()
    self.lod = nn.Parameter(torch.zeros(()))

  def forward(self, w):
    lod = self.lod.cpu().tolist()
    x = self.layer0(w[:, 0])
    for block_idx in range(1, len(self.channels)):
      if block_idx + lod < len(self.channels):
        layer_idx = 2 * block_idx - 2
        if layer_idx == 0:
          x = self.__getattr__(f'layer{layer_idx}')(w[:, layer_idx])
        else:
          x = self.__getattr__(f'layer{layer_idx}')(x, w[:, layer_idx])
        layer_idx = 2 * block_idx - 1
        x = self.__getattr__(f'layer{layer_idx}')(x, w[:, layer_idx])
        image = self.__getattr__(f'output{block_idx - 1}')(x)
      else:
        # Level-of-detail: blocks beyond the current `lod` are skipped, and
        # the last synthesized image is simply upsampled instead.
        image = self.upsample(image)
    return image
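
  # For example, with `resolution=1024` this module owns nine output heads
  # (`output0` ... `output8`). Setting `synthesis.lod` to 1.0 makes the loop
  # stop synthesizing after the 512x512 head (`output7`) and upsample that
  # image once, so the returned tensor is still [batch_size, 3, 1024, 1024].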


class PixelNormLayer(nn.Module):
  """Implements pixel-wise feature vector normalization layer."""

  def __init__(self, epsilon=1e-8):
    super().__init__()
    self.epsilon = epsilon

  def forward(self, x):
    return x / torch.sqrt(torch.mean(x**2, dim=1, keepdim=True) + self.epsilon)


class InstanceNormLayer(nn.Module):
  """Implements instance normalization layer."""

  def __init__(self, epsilon=1e-8):
    super().__init__()
    self.epsilon = epsilon

  def forward(self, x):
    if len(x.shape) != 4:
      raise ValueError(f'The input tensor should have shape [batch_size, '
                       f'num_channels, height, width], but {x.shape} '
                       f'received!')
    x = x - torch.mean(x, dim=[2, 3], keepdim=True)
    x = x / torch.sqrt(torch.mean(x**2, dim=[2, 3], keepdim=True) +
                       self.epsilon)
    return x


class ResolutionScalingLayer(nn.Module):
  """Implements the resolution scaling layer.

  Basically, this layer can be used to upsample or downsample feature maps in
  the spatial domain with nearest-neighbor interpolation.
  """

  def __init__(self, scale_factor=2):
    super().__init__()
    self.scale_factor = scale_factor

  def forward(self, x):
    return F.interpolate(x, scale_factor=self.scale_factor, mode='nearest')


class BlurLayer(nn.Module):
  """Implements the blur layer used in StyleGAN."""

  def __init__(self,
               channels,
               kernel=(1, 2, 1),
               normalize=True,
               flip=False):
    super().__init__()
    # The 2D blur kernel is the outer product of the 1D binomial kernel, e.g.
    # (1, 2, 1) gives a 3x3 kernel that smooths upsampling artifacts.
    kernel = np.array(kernel, dtype=np.float32).reshape(1, 3)
    kernel = kernel.T.dot(kernel)
    if normalize:
      kernel /= np.sum(kernel)
    if flip:
      kernel = kernel[::-1, ::-1]
    kernel = kernel.reshape(3, 3, 1, 1)
    kernel = np.tile(kernel, [1, 1, channels, 1])
    kernel = np.transpose(kernel, [2, 3, 0, 1])
    self.register_buffer('kernel', torch.from_numpy(kernel))
    self.channels = channels

  def forward(self, x):
    # Depth-wise convolution: each channel is blurred independently.
    return F.conv2d(x, self.kernel, stride=1, padding=1, groups=self.channels)


class NoiseApplyingLayer(nn.Module):
  """Implements the noise applying layer used in StyleGAN."""

  def __init__(self, layer_idx, channels, randomize_noise=False):
    super().__init__()
    self.randomize_noise = randomize_noise
    self.res = 2**(layer_idx // 2 + 2)
    self.register_buffer('noise', torch.randn(1, 1, self.res, self.res))
    self.weight = nn.Parameter(torch.zeros(channels))

  def forward(self, x):
    if len(x.shape) != 4:
      raise ValueError(f'The input tensor should have shape [batch_size, '
                       f'num_channels, height, width], but {x.shape} '
                       f'received!')
    if self.randomize_noise:
      noise = torch.randn(x.shape[0], 1, self.res, self.res).to(x)
    else:
      noise = self.noise
    return x + noise * self.weight.view(1, -1, 1, 1)


class StyleModulationLayer(nn.Module):
  """Implements the style modulation layer used in StyleGAN."""

  def __init__(self, channels, w_space_dim=512):
    super().__init__()
    self.channels = channels
    self.dense = DenseBlock(in_features=w_space_dim,
                            out_features=channels * 2,
                            wscale_gain=1.0,
                            wscale_lr_multiplier=1.0,
                            activation_type='linear')

  def forward(self, x, w):
    if len(w.shape) != 2:
      raise ValueError(f'The input tensor should have shape [batch_size, '
                       f'w_space_dim], but {w.shape} received!')
    style = self.dense(w)
    # Split the style code into a per-channel scale and shift; the `+ 1`
    # biases the scale toward identity.
    style = style.view(-1, 2, self.channels, 1, 1)
    return x * (style[:, 0] + 1) + style[:, 1]


class WScaleLayer(nn.Module):
  """Implements the layer to scale weight variable and add bias.

  Note that the weight variable is trained in the `nn.Conv2d` layer (or
  `nn.Linear` layer), and is only scaled by a constant, non-trainable factor
  in this layer. The bias variable, however, is trainable in this layer.
  """

  def __init__(self,
               in_channels,
               out_channels,
               kernel_size,
               gain=np.sqrt(2.0),
               lr_multiplier=1.0):
    super().__init__()
    fan_in = in_channels * kernel_size * kernel_size
    self.scale = gain / np.sqrt(fan_in) * lr_multiplier
    self.bias = nn.Parameter(torch.zeros(out_channels))
    self.lr_multiplier = lr_multiplier

  def forward(self, x):
    if len(x.shape) == 4:
      return x * self.scale + self.bias.view(1, -1, 1, 1) * self.lr_multiplier
    if len(x.shape) == 2:
      return x * self.scale + self.bias.view(1, -1) * self.lr_multiplier
    raise ValueError(f'The input tensor should have shape [batch_size, '
                     f'num_channels, height, width], or [batch_size, '
                     f'num_channels], but {x.shape} received!')
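
# A quick sanity check of the equalized-learning-rate trick implemented above
# (a sketch; the numbers are illustrative): for a 3x3 convolution with 512
# input channels, fan_in = 512 * 3 * 3 = 4608, so activations are multiplied
# by sqrt(2 / 4608) at every forward pass. Because the scaling happens at
# runtime rather than at initialization, all raw weights keep a comparable
# dynamic range, and so do their gradient updates.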


class EpilogueBlock(nn.Module):
  """Implements the epilogue block of each conv block."""

  def __init__(self,
               layer_idx,
               channels,
               randomize_noise=False,
               normalization_fn='instance'):
    super().__init__()
    self.apply_noise = NoiseApplyingLayer(layer_idx, channels, randomize_noise)
    self.bias = nn.Parameter(torch.zeros(channels))
    self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
    if normalization_fn == 'pixel':
      self.norm = PixelNormLayer()
    elif normalization_fn == 'instance':
      self.norm = InstanceNormLayer()
    else:
      raise NotImplementedError(f'Not implemented normalization function: '
                                f'{normalization_fn}!')
    self.style_mod = StyleModulationLayer(channels)

  def forward(self, x, w):
    x = self.apply_noise(x)
    x = x + self.bias.view(1, -1, 1, 1)
    x = self.activate(x)
    x = self.norm(x)
    x = self.style_mod(x, w)
    return x


class FirstConvBlock(nn.Module):
  """Implements the first convolutional block used in StyleGAN.

  Basically, this block starts from a constant input, which is
  `ones(1, in_channels, 4, 4)`.
  """

  def __init__(self, in_channels, randomize_noise=False):
    super().__init__()
    self.first_layer = nn.Parameter(torch.ones(1, in_channels, 4, 4))
    self.epilogue = EpilogueBlock(layer_idx=0,
                                  channels=in_channels,
                                  randomize_noise=randomize_noise)

  def forward(self, w):
    x = self.first_layer.repeat(w.shape[0], 1, 1, 1)
    x = self.epilogue(x, w)
    return x


class UpConvBlock(nn.Module):
  """Implements the convolutional block used in StyleGAN.

  Basically, this block is used as the first convolutional block for each
  resolution, and it also executes the upsampling.
  """

  def __init__(self,
               layer_idx,
               in_channels,
               out_channels,
               kernel_size=3,
               stride=1,
               padding=1,
               dilation=1,
               add_bias=False,
               fused_scale='auto',
               wscale_gain=np.sqrt(2.0),
               wscale_lr_multiplier=1.0,
               randomize_noise=False):
    """Initializes the class with block settings.

    Args:
      layer_idx: Index of the current convolutional layer.
      in_channels: Number of channels of the input tensor fed into this block.
      out_channels: Number of channels (kernels) of the output tensor.
      kernel_size: Size of the convolutional kernel.
      stride: Stride parameter for convolution operation.
      padding: Padding parameter for convolution operation.
      dilation: Dilation rate for convolution operation.
      add_bias: Whether to add bias onto the convolutional result.
      fused_scale: Whether to fuse `upsample` and `conv2d` together, resulting
        in `conv2d_transpose`.
      wscale_gain: The gain factor for the `wscale` layer.
      wscale_lr_multiplier: The learning rate multiplier factor for the
        `wscale` layer.
      randomize_noise: Whether to add random noise.

    Raises:
      ValueError: If the block is not used as the first block for a particular
        resolution, or if `fused_scale` is not one of `True`, `False`, and
        `auto`.
    """
    super().__init__()

    if layer_idx % 2 == 1:
      raise ValueError(f'This block is implemented as the first block of each '
                       f'resolution, but is applied to layer {layer_idx}!')
    if fused_scale not in [True, False, 'auto']:
      raise ValueError(f'`fused_scale` can only be `True`, `False`, or '
                       f'`auto`, but {fused_scale} received!')

    cur_res = 2 ** (layer_idx // 2 + 2)
    if fused_scale == 'auto':
      self.fused_scale = (cur_res >= _AUTO_FUSED_SCALE_MIN_RES)
    else:
      self.fused_scale = fused_scale

    if self.fused_scale:
      # The kernel is kept in TF layout [height, width, in, out] and converted
      # to the `conv_transpose2d` layout on the fly in `forward()`.
      self.weight = nn.Parameter(
          torch.randn(kernel_size, kernel_size, in_channels, out_channels))
    else:
      self.upsample = ResolutionScalingLayer()
      self.conv = nn.Conv2d(in_channels=in_channels,
                            out_channels=out_channels,
                            kernel_size=kernel_size,
                            stride=stride,
                            padding=padding,
                            dilation=dilation,
                            groups=1,
                            bias=add_bias)
    fan_in = in_channels * kernel_size * kernel_size
    self.scale = wscale_gain / np.sqrt(fan_in) * wscale_lr_multiplier
    self.blur = BlurLayer(channels=out_channels)
    self.epilogue = EpilogueBlock(layer_idx=layer_idx,
                                  channels=out_channels,
                                  randomize_noise=randomize_noise)

  def forward(self, x, w):
    if self.fused_scale:
      kernel = self.weight * self.scale
      # Pad and sum shifted copies of the kernel so that the stride-2
      # transposed convolution matches `upsample + conv2d`.
      kernel = F.pad(kernel, (0, 0, 0, 0, 1, 1, 1, 1), 'constant', 0.0)
      kernel = (kernel[1:, 1:] + kernel[:-1, 1:] +
                kernel[1:, :-1] + kernel[:-1, :-1])
      kernel = kernel.permute(2, 3, 0, 1)
      x = F.conv_transpose2d(x, kernel, stride=2, padding=1)
    else:
      x = self.upsample(x)
      x = self.conv(x) * self.scale
    x = self.blur(x)
    x = self.epilogue(x, w)
    return x


class ConvBlock(nn.Module):
  """Implements the convolutional block used in StyleGAN.

  Basically, this block is used as the second convolutional block for each
  resolution.
  """

  def __init__(self,
               layer_idx,
               in_channels,
               out_channels,
               kernel_size=3,
               stride=1,
               padding=1,
               dilation=1,
               add_bias=False,
               wscale_gain=np.sqrt(2.0),
               wscale_lr_multiplier=1.0,
               randomize_noise=False):
    """Initializes the class with block settings.

    Args:
      layer_idx: Index of the current convolutional layer.
      in_channels: Number of channels of the input tensor fed into this block.
      out_channels: Number of channels (kernels) of the output tensor.
      kernel_size: Size of the convolutional kernel.
      stride: Stride parameter for convolution operation.
      padding: Padding parameter for convolution operation.
      dilation: Dilation rate for convolution operation.
      add_bias: Whether to add bias onto the convolutional result.
      wscale_gain: The gain factor for the `wscale` layer.
      wscale_lr_multiplier: The learning rate multiplier factor for the
        `wscale` layer.
      randomize_noise: Whether to add random noise.

    Raises:
      ValueError: If the block is not used as the second block for a
        particular resolution.
    """
    super().__init__()

    if layer_idx % 2 == 0:
      raise ValueError(f'This block is implemented as the second block of '
                       f'each resolution, but is applied to layer '
                       f'{layer_idx}!')
    self.conv = nn.Conv2d(in_channels=in_channels,
                          out_channels=out_channels,
                          kernel_size=kernel_size,
                          stride=stride,
                          padding=padding,
                          dilation=dilation,
                          groups=1,
                          bias=add_bias)
    fan_in = in_channels * kernel_size * kernel_size
    self.scale = wscale_gain / np.sqrt(fan_in) * wscale_lr_multiplier
    self.epilogue = EpilogueBlock(layer_idx=layer_idx,
                                  channels=out_channels,
                                  randomize_noise=randomize_noise)

  def forward(self, x, w):
    x = self.conv(x) * self.scale
    x = self.epilogue(x, w)
    return x


class LastConvBlock(nn.Module):
  """Implements the last convolutional block used in StyleGAN.

  Basically, this block converts the final feature map to an RGB image.
  """

  def __init__(self, in_channels, out_channels=3):
    super().__init__()
    self.conv = nn.Conv2d(in_channels=in_channels,
                          out_channels=out_channels,
                          kernel_size=1,
                          bias=False)
    self.scale = 1 / np.sqrt(in_channels)
    self.bias = nn.Parameter(torch.zeros(out_channels))

  def forward(self, x):
    x = self.conv(x) * self.scale
    x = x + self.bias.view(1, -1, 1, 1)
    return x


class DenseBlock(nn.Module):
  """Implements the dense block used in StyleGAN.

  Basically, this block executes a fully-connected layer, a weight-scale
  layer, and an activation layer in sequence.
  """

  def __init__(self,
               in_features,
               out_features,
               add_bias=False,
               wscale_gain=np.sqrt(2.0),
               wscale_lr_multiplier=0.01,
               activation_type='lrelu'):
    """Initializes the class with block settings.

    Args:
      in_features: Number of channels of the input tensor fed into this block.
      out_features: Number of channels of the output tensor.
      add_bias: Whether to add bias onto the fully-connected result.
      wscale_gain: The gain factor for the `wscale` layer.
      wscale_lr_multiplier: The learning rate multiplier factor for the
        `wscale` layer.
      activation_type: Type of activation function. `linear` and `lrelu` are
        supported.

    Raises:
      NotImplementedError: If the input `activation_type` is not supported.
    """
    super().__init__()
    self.linear = nn.Linear(in_features=in_features,
                            out_features=out_features,
                            bias=add_bias)
    self.wscale = WScaleLayer(in_channels=in_features,
                              out_channels=out_features,
                              kernel_size=1,
                              gain=wscale_gain,
                              lr_multiplier=wscale_lr_multiplier)
    if activation_type == 'linear':
      self.activate = nn.Identity()
    elif activation_type == 'lrelu':
      self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
    else:
      raise NotImplementedError(f'Not implemented activation function: '
                                f'{activation_type}!')

  def forward(self, x):
    x = self.linear(x)
    x = self.wscale(x)
    x = self.activate(x)
    return x
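

if __name__ == '__main__':
  # A minimal smoke test (a sketch with randomly initialized weights, so the
  # output is noise rather than a realistic image): build a small generator
  # and verify the output shape end to end.
  model = StyleGANGeneratorModel(resolution=64)
  model.eval()
  with torch.no_grad():
    z = torch.randn(2, 512)
    images = model(z)
  print(images.shape)  # Expected: torch.Size([2, 3, 64, 64])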