Spaces:
Runtime error
Runtime error
# python3.7 | |
"""Contains the implementation of generator described in ProgressiveGAN. | |
Different from the official tensorflow model in folder `pggan_tf_official`, this | |
is a simple pytorch version which only contains the generator part. This class | |
is specially used for inference. | |
For more details, please check the original paper: | |
https://arxiv.org/pdf/1710.10196.pdf | |
""" | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
__all__ = ['PGGANGeneratorModel'] | |
# Maps each supported output resolution to the channel widths used by the
# generator. The list for resolution `R` has `log2(R)` entries: entry 0 is the
# width of the latent code fed into the first block, and entry `i` (i >= 1) is
# the width of block `i`. Widths follow `min(8192 / 2**i, 512)`, i.e. 512 for
# the first five entries and halving for every entry after that.
_RESOLUTIONS_TO_CHANNELS = {
    2 ** res_log2: [min(8192 // 2 ** stage, 512) for stage in range(res_log2)]
    for res_log2 in range(3, 11)  # 8x8 ... 1024x1024
}
# Variable mapping from the pytorch model to the official tensorflow model,
# enumerated for the largest (1024x1024) generator:
#
#   lod                                -> lod                         (scalar)
#   layer0.{conv.weight, wscale.bias}  -> 4x4/Dense/{weight, bias}
#   layer1.{conv.weight, wscale.bias}  -> 4x4/Conv/{weight, bias}
#   layer{2k}.*,  layer{2k+1}.*        -> {R}x{R}/Conv0/*, {R}x{R}/Conv1/*
#                                         with R = 2**(k+2), k = 1..8
#   output{i}.{conv.weight, wscale.bias} -> ToRGB_lod{8-i}/{weight, bias}
#                                         with i = 0..8
#
# Expected weight shapes ([out, in, kh, kw]): layer0 is [512, 512, 4, 4]; the
# 3x3 convolutions narrow from 512 channels down to 16 at 1024x1024; every
# ToRGB weight is [3, C, 1, 1]. Biases are 1-D with `out` elements.
#
# NOTE(review): the `ToRGB_lod` index counts back from the 1024x1024 output
# (lod0), and entries exist for all 18 layers / 9 outputs whatever the model's
# resolution. For smaller models the converting caller presumably shifts the
# lod index and skips missing variables -- confirm before using this table
# directly.
def _build_pth_vars_to_tf_vars():
    """Returns the insertion-ordered pytorch -> tensorflow variable-name table."""
    # TF scopes of the convolutional layers, listed in pytorch `layer{idx}` order.
    tf_scopes = ['4x4/Dense', '4x4/Conv']
    for res_log2 in range(3, 11):
        res = 2 ** res_log2
        tf_scopes.append(f'{res}x{res}/Conv0')
        tf_scopes.append(f'{res}x{res}/Conv1')

    table = {'lod': 'lod'}
    for idx, scope in enumerate(tf_scopes):
        table[f'layer{idx}.conv.weight'] = f'{scope}/weight'
        table[f'layer{idx}.wscale.bias'] = f'{scope}/bias'

    # output0 is the 4x4 ToRGB head (`ToRGB_lod8`); output8 is the 1024x1024
    # head (`ToRGB_lod0`).
    num_outputs = 9
    for idx in range(num_outputs):
        scope = f'ToRGB_lod{num_outputs - 1 - idx}'
        table[f'output{idx}.conv.weight'] = f'{scope}/weight'
        table[f'output{idx}.wscale.bias'] = f'{scope}/bias'
    return table


_PGGAN_PTH_VARS_TO_TF_VARS = _build_pth_vars_to_tf_vars()
class PGGANGeneratorModel(nn.Module):
    """Defines the generator module in ProgressiveGAN.

    The generator maps latent codes of shape [batch_size, noise_dim] (where
    `noise_dim` must equal `self.channels[0]`) to images of shape
    [batch_size, output_channels, resolution, resolution]. Generated images use
    RGB channel order and are intended to lie in [-1, 1]; note that the ToRGB
    heads are linear, so values are not clamped to that range.
    """

    def __init__(self,
                 resolution=1024,
                 fused_scale=False,
                 output_channels=3):
        """Initializes the generator with basic settings.

        Args:
            resolution: The resolution of the final output image. Must be a key
                of `_RESOLUTIONS_TO_CHANNELS`. (default: 1024)
            fused_scale: Whether to fuse `upsample` and `conv2d` together,
                resulting in `conv2d_transpose`. (default: False)
            output_channels: Number of channels of the output image.
                (default: 3)

        Raises:
            ValueError: If the input `resolution` is not supported.
        """
        super().__init__()
        try:
            self.channels = _RESOLUTIONS_TO_CHANNELS[resolution]
        except KeyError:
            raise ValueError(f'Invalid resolution: {resolution}!\n'
                             f'Resolutions allowed: '
                             f'{list(_RESOLUTIONS_TO_CHANNELS)}.') from None
        # Internal consistency check of the channel table: one entry per
        # doubling from 1x1 up to `resolution`.
        assert len(self.channels) == int(np.log2(resolution))

        self.resolution = resolution
        self.fused_scale = fused_scale
        self.output_channels = output_channels

        # Block `b` (1-based) owns `layer{2b-2}`, `layer{2b-1}` and
        # `output{b-1}`. Block 1 turns the 1x1 latent into a 4x4 map (4x4 kernel,
        # padding 3); every later block upsamples by 2x first.
        # Modules are created in the same order as before, so the random
        # initialization and parameter order are unchanged.
        for block_idx in range(1, len(self.channels)):
            in_ch = self.channels[block_idx - 1]
            out_ch = self.channels[block_idx]
            if block_idx == 1:
                first_layer = ConvBlock(in_channels=in_ch,
                                        out_channels=out_ch,
                                        kernel_size=4,
                                        padding=3)
            else:
                first_layer = ConvBlock(in_channels=in_ch,
                                        out_channels=out_ch,
                                        upsample=True,
                                        fused_scale=self.fused_scale)
            self.add_module(f'layer{2 * block_idx - 2}', first_layer)
            self.add_module(
                f'layer{2 * block_idx - 1}',
                ConvBlock(in_channels=out_ch, out_channels=out_ch))
            self.add_module(
                f'output{block_idx - 1}',
                ConvBlock(in_channels=out_ch,
                          out_channels=self.output_channels,
                          kernel_size=1,
                          padding=0,
                          wscale_gain=1.0,
                          activation_type='linear'))
        self.upsample = ResolutionScalingLayer()
        # Level of detail: while `block_idx + lod` reaches the block count,
        # trailing blocks are skipped (see `forward`).
        self.lod = nn.Parameter(torch.zeros(()))

        # Map this module's parameter names to TF variable names (presumably
        # consumed by a checkpoint converter). With `fused_scale`, the
        # upsampling `Conv0` layers store their kernel directly as
        # `ConvBlock.weight` (no `conv` submodule), and the TF scope is
        # `Conv0_up`.
        self.pth_to_tf_var_mapping = {}
        for pth_var_name, tf_var_name in _PGGAN_PTH_VARS_TO_TF_VARS.items():
            if self.fused_scale and 'Conv0' in tf_var_name:
                pth_var_name = pth_var_name.replace('conv.weight', 'weight')
                tf_var_name = tf_var_name.replace('Conv0', 'Conv0_up')
            self.pth_to_tf_var_mapping[pth_var_name] = tf_var_name

    def forward(self, x):
        """Synthesizes images from latent codes.

        Args:
            x: Latent codes with shape [batch_size, noise_dim], where
                `noise_dim` must equal `self.channels[0]`. Non-contiguous
                tensors are accepted.

        Returns:
            Images with shape [batch_size, output_channels, resolution,
            resolution].

        Raises:
            ValueError: If `x` is not 2-D, or if `self.lod` is so large (or NaN)
                that no block would produce an image.
        """
        if len(x.shape) != 2:
            raise ValueError(f'The input tensor should be with shape [batch_size, '
                             f'noise_dim], but {x.shape} received!')
        num_blocks = len(self.channels) - 1
        lod = self.lod.item()
        # Block 1 runs only when `1 + lod < len(self.channels)`, i.e.
        # `lod < num_blocks`. Otherwise `image` would never be assigned.
        if not lod < num_blocks:
            raise ValueError(f'The level of detail `lod` should be smaller than '
                             f'{num_blocks}, but {lod} received!')

        # `reshape` (unlike `view`) also works for non-contiguous latents.
        x = x.reshape(x.shape[0], x.shape[1], 1, 1)
        for block_idx in range(1, num_blocks + 1):
            if block_idx + lod < len(self.channels):
                x = getattr(self, f'layer{2 * block_idx - 2}')(x)
                x = getattr(self, f'layer{2 * block_idx - 1}')(x)
                image = getattr(self, f'output{block_idx - 1}')(x)
            else:
                # Skipped blocks only upsample the current image, so the output
                # size is always `resolution`. Fractional `lod` values are not
                # blended between heads.
                image = self.upsample(image)
        return image
class PixelNormLayer(nn.Module):
    """Pixel-wise feature vector normalization.

    Every spatial position's feature vector (dimension 1) is divided by its
    root-mean-square value; `epsilon` guards against division by zero.
    """

    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x):
        mean_square = x.pow(2).mean(dim=1, keepdim=True)
        return x / (mean_square + self.epsilon).sqrt()
class ResolutionScalingLayer(nn.Module):
    """Resizes feature maps by a constant factor with nearest-neighbor sampling.

    With the default `scale_factor=2` the spatial size doubles (upsampling); a
    factor below 1 shrinks it (downsampling).
    """

    def __init__(self, scale_factor=2):
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        factor = self.scale_factor
        return F.interpolate(x, mode='nearest', scale_factor=factor)
class WScaleLayer(nn.Module):
    """Applies the runtime weight scale and adds a per-channel bias.

    The convolution weights are trained unscaled in the preceding layer. This
    layer multiplies that layer's output by the constant
    `gain / sqrt(in_channels * kernel_size**2)`, which is not trainable, and
    then adds a trainable bias of `out_channels` elements broadcast over the
    batch and spatial dimensions.
    """

    def __init__(self, in_channels, out_channels, kernel_size, gain=np.sqrt(2.0)):
        super().__init__()
        receptive_inputs = in_channels * kernel_size ** 2
        self.scale = gain / np.sqrt(receptive_inputs)
        self.bias = nn.Parameter(torch.zeros(out_channels))

    def forward(self, x):
        # Keep the tensor on the left: `x * scale` lets torch handle the numpy
        # scalar.
        per_channel_bias = self.bias.reshape(1, -1, 1, 1)
        return x * self.scale + per_channel_bias
class ConvBlock(nn.Module):
    """Implements the convolutional block used in ProgressiveGAN.

    The block applies, in sequence: pixel-wise feature normalization, optional
    2x upsampling, a convolution, the runtime weight scale plus bias
    (`WScaleLayer`), and an activation. Inputs are expected in NCHW layout.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 dilation=1,
                 add_bias=False,
                 upsample=False,
                 fused_scale=False,
                 wscale_gain=np.sqrt(2.0),
                 activation_type='lrelu'):
        """Initializes the class with block settings.

        Args:
            in_channels: Number of channels of the input tensor fed into this
                block.
            out_channels: Number of channels (kernels) of the output tensor.
            kernel_size: Size of the convolutional kernel. (default: 3)
            stride: Stride parameter for convolution operation. Ignored in the
                fused path (`upsample` and `fused_scale` both set). (default: 1)
            padding: Padding parameter for convolution operation. Ignored in
                the fused path, which always uses padding 1. (default: 1)
            dilation: Dilation rate for convolution operation. Ignored in the
                fused path. (default: 1)
            add_bias: Whether `nn.Conv2d` adds its own bias in addition to the
                `WScaleLayer` bias. Ignored in the fused path. (default: False)
            upsample: Whether to upsample the input tensor by 2x before
                convolution. (default: False)
            fused_scale: Only used together with `upsample=True`. Whether to
                fuse `upsample` and `conv2d` together, resulting in
                `conv2d_transpose`. (default: False)
            wscale_gain: Gain of the runtime weight scale. The scale factor is
                `wscale_gain / sqrt(in_channels * kernel_size**2)`.
                (default: sqrt(2))
            activation_type: Type of activation function. Supports `linear`,
                `lrelu` (leaky ReLU, slope 0.2) and `tanh`, where `tanh` is
                implemented as `nn.Hardtanh`, i.e. clamping to [-1, 1].
                (default: `lrelu`)

        Raises:
            NotImplementedError: If the input `activation_type` is not supported.
        """
        super().__init__()
        self.pixel_norm = PixelNormLayer()
        if upsample and not fused_scale:
            self.upsample = ResolutionScalingLayer()
        else:
            self.upsample = nn.Identity()

        if upsample and fused_scale:
            # The raw kernel is kept in [kernel, kernel, in_channels,
            # out_channels] order, presumably to match the TensorFlow checkpoint
            # layout (confirm). `forward` permutes it for `conv_transpose2d`.
            self.weight = nn.Parameter(
                torch.randn(kernel_size, kernel_size, in_channels, out_channels))
            fan_in = in_channels * kernel_size * kernel_size
            self.scale = wscale_gain / np.sqrt(fan_in)
        else:
            self.conv = nn.Conv2d(in_channels=in_channels,
                                  out_channels=out_channels,
                                  kernel_size=kernel_size,
                                  stride=stride,
                                  padding=padding,
                                  dilation=dilation,
                                  groups=1,
                                  bias=add_bias)
        self.wscale = WScaleLayer(in_channels=in_channels,
                                  out_channels=out_channels,
                                  kernel_size=kernel_size,
                                  gain=wscale_gain)

        if activation_type == 'linear':
            self.activate = nn.Identity()
        elif activation_type == 'lrelu':
            self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        elif activation_type == 'tanh':
            # NOTE(review): this is a hard clamp to [-1, 1], not a smooth
            # `nn.Tanh`. Kept as-is for compatibility; confirm it is intended.
            self.activate = nn.Hardtanh()
        else:
            raise NotImplementedError(f'Not implemented activation function: '
                                      f'{activation_type}!')

    def forward(self, x):
        """Applies the block to an NCHW tensor and returns the activated result."""
        x = self.pixel_norm(x)
        x = self.upsample(x)
        if hasattr(self, 'conv'):
            x = self.conv(x)
        else:
            # Fused 2x upsample + convolution.
            kernel = self.weight * self.scale
            # Zero-pad the two spatial (leading) dims by 1 on each side, then
            # sum the four one-pixel shifts, giving a (k+1)x(k+1) kernel.
            kernel = F.pad(kernel, (0, 0, 0, 0, 1, 1, 1, 1), 'constant', 0.0)
            kernel = (kernel[1:, 1:] + kernel[:-1, 1:] +
                      kernel[1:, :-1] + kernel[:-1, :-1])
            # [k+1, k+1, in, out] -> [in, out, k+1, k+1], the weight layout of
            # `conv_transpose2d`.
            kernel = kernel.permute(2, 3, 0, 1)
            # With the default 3x3 kernel (4x4 after fusing), stride 2 and
            # padding 1 exactly double the spatial size.
            x = F.conv_transpose2d(x, kernel, stride=2, padding=1)
            # The kernel was already multiplied by `self.scale`; undo it here
            # because `self.wscale` applies the same factor again below.
            x = x / self.scale
        x = self.wscale(x)
        x = self.activate(x)
        return x