# Source: interfacegan_pp/models/pggan_generator_model.py
# (commit 4d6b877, "commit files", by ybelkada).
# python3.7
"""Contains the implementation of generator described in ProgressiveGAN.
Different from the official tensorflow model in folder `pggan_tf_official`, this
is a simple pytorch version which only contains the generator part. This class
is specially used for inference.
For more details, please check the original paper:
https://arxiv.org/pdf/1710.10196.pdf
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# Only the generator model is part of this module's public API.
__all__ = [
    'PGGANGeneratorModel',
]
# Maps the target resolution of the final generated image to the channel
# widths used by the generator, in sequence. Entry 0 is the width of the input
# latent code; entry i (i >= 1) is the feature width at resolution
# 4 * 2 ** (i - 1). A generator for resolution 2 ** k therefore uses the first
# k entries of the full 1024x1024 list below.
_RESOLUTIONS_TO_CHANNELS = {
    2 ** num_entries:
        [512, 512, 512, 512, 512, 256, 128, 64, 32, 16][:num_entries]
    for num_entries in range(3, 11)
}
# Variable mapping from pytorch model to official tensorflow model.
#
# Keys are parameter names of `PGGANGeneratorModel` (its `lod` parameter and
# the `layer*` / `output*` ConvBlock submodules); values are variable names in
# the official tensorflow checkpoint. The trailing comment on each entry is
# the shape of the pytorch tensor when `fused_scale=False`; with
# `fused_scale=True` the upsampling convs (`Conv0`) instead own a `weight`
# parameter laid out as [kernel, kernel, in, out], and the model rewrites the
# names accordingly when building `pth_to_tf_var_mapping`.
#
# The table always lists every layer of the full 1024x1024 generator; a model
# built for a lower resolution only owns a prefix of these `layer*`/`output*`
# modules. NOTE(review): `outputK` maps to `ToRGB_lod{8 - K}`, i.e. lod indices
# of the 1024x1024 checkpoint; loaders for lower-resolution checkpoints
# presumably shift these indices — verify against the caller before editing.
_PGGAN_PTH_VARS_TO_TF_VARS = {
    'lod': 'lod', # []
    # Block 1 (4x4): latent -> 4x4 "Dense" (a 4x4 conv here), then a 3x3 conv.
    'layer0.conv.weight': '4x4/Dense/weight', # [512, 512, 4, 4]
    'layer0.wscale.bias': '4x4/Dense/bias', # [512]
    'layer1.conv.weight': '4x4/Conv/weight', # [512, 512, 3, 3]
    'layer1.wscale.bias': '4x4/Conv/bias', # [512]
    # Blocks 2..9: `Conv0` upsamples 2x then convolves, `Conv1` convolves.
    'layer2.conv.weight': '8x8/Conv0/weight', # [512, 512, 3, 3]
    'layer2.wscale.bias': '8x8/Conv0/bias', # [512]
    'layer3.conv.weight': '8x8/Conv1/weight', # [512, 512, 3, 3]
    'layer3.wscale.bias': '8x8/Conv1/bias', # [512]
    'layer4.conv.weight': '16x16/Conv0/weight', # [512, 512, 3, 3]
    'layer4.wscale.bias': '16x16/Conv0/bias', # [512]
    'layer5.conv.weight': '16x16/Conv1/weight', # [512, 512, 3, 3]
    'layer5.wscale.bias': '16x16/Conv1/bias', # [512]
    'layer6.conv.weight': '32x32/Conv0/weight', # [512, 512, 3, 3]
    'layer6.wscale.bias': '32x32/Conv0/bias', # [512]
    'layer7.conv.weight': '32x32/Conv1/weight', # [512, 512, 3, 3]
    'layer7.wscale.bias': '32x32/Conv1/bias', # [512]
    'layer8.conv.weight': '64x64/Conv0/weight', # [256, 512, 3, 3]
    'layer8.wscale.bias': '64x64/Conv0/bias', # [256]
    'layer9.conv.weight': '64x64/Conv1/weight', # [256, 256, 3, 3]
    'layer9.wscale.bias': '64x64/Conv1/bias', # [256]
    'layer10.conv.weight': '128x128/Conv0/weight', # [128, 256, 3, 3]
    'layer10.wscale.bias': '128x128/Conv0/bias', # [128]
    'layer11.conv.weight': '128x128/Conv1/weight', # [128, 128, 3, 3]
    'layer11.wscale.bias': '128x128/Conv1/bias', # [128]
    'layer12.conv.weight': '256x256/Conv0/weight', # [64, 128, 3, 3]
    'layer12.wscale.bias': '256x256/Conv0/bias', # [64]
    'layer13.conv.weight': '256x256/Conv1/weight', # [64, 64, 3, 3]
    'layer13.wscale.bias': '256x256/Conv1/bias', # [64]
    'layer14.conv.weight': '512x512/Conv0/weight', # [32, 64, 3, 3]
    'layer14.wscale.bias': '512x512/Conv0/bias', # [32]
    'layer15.conv.weight': '512x512/Conv1/weight', # [32, 32, 3, 3]
    'layer15.wscale.bias': '512x512/Conv1/bias', # [32]
    'layer16.conv.weight': '1024x1024/Conv0/weight', # [16, 32, 3, 3]
    'layer16.wscale.bias': '1024x1024/Conv0/bias', # [16]
    'layer17.conv.weight': '1024x1024/Conv1/weight', # [16, 16, 3, 3]
    'layer17.wscale.bias': '1024x1024/Conv1/bias', # [16]
    # 1x1 to-RGB heads, one per block: `output0` is the 4x4 head.
    'output0.conv.weight': 'ToRGB_lod8/weight', # [3, 512, 1, 1]
    'output0.wscale.bias': 'ToRGB_lod8/bias', # [3]
    'output1.conv.weight': 'ToRGB_lod7/weight', # [3, 512, 1, 1]
    'output1.wscale.bias': 'ToRGB_lod7/bias', # [3]
    'output2.conv.weight': 'ToRGB_lod6/weight', # [3, 512, 1, 1]
    'output2.wscale.bias': 'ToRGB_lod6/bias', # [3]
    'output3.conv.weight': 'ToRGB_lod5/weight', # [3, 512, 1, 1]
    'output3.wscale.bias': 'ToRGB_lod5/bias', # [3]
    'output4.conv.weight': 'ToRGB_lod4/weight', # [3, 256, 1, 1]
    'output4.wscale.bias': 'ToRGB_lod4/bias', # [3]
    'output5.conv.weight': 'ToRGB_lod3/weight', # [3, 128, 1, 1]
    'output5.wscale.bias': 'ToRGB_lod3/bias', # [3]
    'output6.conv.weight': 'ToRGB_lod2/weight', # [3, 64, 1, 1]
    'output6.wscale.bias': 'ToRGB_lod2/bias', # [3]
    'output7.conv.weight': 'ToRGB_lod1/weight', # [3, 32, 1, 1]
    'output7.wscale.bias': 'ToRGB_lod1/bias', # [3]
    'output8.conv.weight': 'ToRGB_lod0/weight', # [3, 16, 1, 1]
    'output8.wscale.bias': 'ToRGB_lod0/bias', # [3]
}
class PGGANGeneratorModel(nn.Module):
    """Defines the generator module in ProgressiveGAN.

    The generated images have `output_channels` channels (RGB by default) and
    are expected to lie in range [-1, 1]. The final to-RGB layer is linear, so
    values are not clipped to that range.
    """

    def __init__(self,
                 resolution=1024,
                 fused_scale=False,
                 output_channels=3):
        """Initializes the generator with basic settings.

        Args:
            resolution: The resolution of the final output image. Must be one
                of the keys of `_RESOLUTIONS_TO_CHANNELS`. (default: 1024)
            fused_scale: Whether to fuse `upsample` and `conv2d` together,
                resulting in `conv2d_transpose`. (default: False)
            output_channels: Number of channels of the output image.
                (default: 3)

        Raises:
            ValueError: If the input `resolution` is not supported.
        """
        super().__init__()
        try:
            self.channels = _RESOLUTIONS_TO_CHANNELS[resolution]
        except KeyError:
            raise ValueError(f'Invalid resolution: {resolution}!\n'
                             f'Resolutions allowed: '
                             f'{list(_RESOLUTIONS_TO_CHANNELS)}.') from None
        # `channels[0]` is the latent width; `channels[i]` (i >= 1) is the
        # feature width of block i, which works at resolution 4 * 2 ** (i - 1).
        assert len(self.channels) == int(np.log2(resolution))

        self.resolution = resolution
        self.fused_scale = fused_scale
        self.output_channels = output_channels

        # Block i owns `layer{2i-2}`, `layer{2i-1}` and the to-RGB head
        # `output{i-1}`. Registration order matches the parameter names
        # listed in `_PGGAN_PTH_VARS_TO_TF_VARS`.
        for block_idx in range(1, len(self.channels)):
            in_channels = self.channels[block_idx - 1]
            out_channels = self.channels[block_idx]
            if block_idx == 1:
                # 1x1 latent -> 4x4 map: a 4x4 kernel with padding 3.
                first_layer = ConvBlock(in_channels=in_channels,
                                        out_channels=out_channels,
                                        kernel_size=4,
                                        padding=3)
            else:
                # Later blocks double the spatial resolution first.
                first_layer = ConvBlock(in_channels=in_channels,
                                        out_channels=out_channels,
                                        upsample=True,
                                        fused_scale=self.fused_scale)
            self.add_module(f'layer{2 * block_idx - 2}', first_layer)
            self.add_module(
                f'layer{2 * block_idx - 1}',
                ConvBlock(in_channels=out_channels,
                          out_channels=out_channels))
            self.add_module(
                f'output{block_idx - 1}',
                ConvBlock(in_channels=out_channels,
                          out_channels=self.output_channels,
                          kernel_size=1,
                          padding=0,
                          wscale_gain=1.0,
                          activation_type='linear'))
        self.upsample = ResolutionScalingLayer()
        # Level of detail: 0 renders at full resolution; each unit coarser
        # halves the rendered resolution (the result is upsampled back).
        self.lod = nn.Parameter(torch.zeros(()))

        # With fused upsampling, `Conv0` layers own a raw `weight` parameter
        # and the tensorflow variables live under `Conv0_up`.
        self.pth_to_tf_var_mapping = {}
        for pth_var_name, tf_var_name in _PGGAN_PTH_VARS_TO_TF_VARS.items():
            if self.fused_scale and 'Conv0' in tf_var_name:
                pth_var_name = pth_var_name.replace('conv.weight', 'weight')
                tf_var_name = tf_var_name.replace('Conv0', 'Conv0_up')
            self.pth_to_tf_var_mapping[pth_var_name] = tf_var_name

    def forward(self, x):
        """Synthesizes images from latent codes.

        The level of detail `self.lod` is clamped to [0, num_blocks - 1].
        Block i is shown as-is when `lod` equals `num_blocks - i`. A
        fractional `lod` linearly blends that block's image with the
        upsampled image of the previous block, as the official `lerp_clip`
        growing scheme does. Integer `lod` values reproduce the plain
        "render up to this block, then upsample" behavior.

        Args:
            x: Latent codes with shape [batch_size, noise_dim]; `noise_dim`
                must equal `self.channels[0]`.

        Returns:
            Images with shape [batch_size, output_channels, resolution,
            resolution].

        Raises:
            ValueError: If `x` is not a 2-D tensor.
        """
        if len(x.shape) != 2:
            raise ValueError(f'The input tensor should be with shape [batch_size, '
                             f'noise_dim], but {x.shape} received!')
        x = x.view(x.shape[0], x.shape[1], 1, 1)

        num_blocks = len(self.channels) - 1
        # Clamp so the 4x4 block always runs and `image` is always defined
        # before it is upsampled (an out-of-range lod used to raise
        # UnboundLocalError).
        lod = min(max(self.lod.item(), 0.0), float(num_blocks - 1))

        image = None
        for block_idx in range(1, num_blocks + 1):
            # Level of detail at which this block's output is shown as-is.
            block_lod = num_blocks - block_idx
            # Weight of the upsampled coarser image (<= 0: this block only,
            # >= 1: this block is skipped entirely).
            alpha = lod - block_lod
            if alpha < 1:
                x = getattr(self, f'layer{2 * block_idx - 2}')(x)
                x = getattr(self, f'layer{2 * block_idx - 1}')(x)
                new_image = getattr(self, f'output{block_idx - 1}')(x)
                if alpha > 0:
                    new_image = torch.lerp(new_image, self.upsample(image),
                                           alpha)
                image = new_image
            else:
                image = self.upsample(image)
        return image
class PixelNormLayer(nn.Module):
    """Pixel-wise feature vector normalization.

    Divides the input by the square root of (the mean of its squares taken
    over dim 1, the channel dim for [N, C, H, W] inputs) plus `epsilon`.
    """

    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x):
        mean_square = x.pow(2).mean(dim=1, keepdim=True)
        denominator = mean_square.add(self.epsilon).sqrt()
        return x / denominator
class ResolutionScalingLayer(nn.Module):
    """Rescales feature maps spatially with nearest-neighbor interpolation.

    A `scale_factor` greater than 1 upsamples and one below 1 downsamples;
    the default of 2 doubles both spatial dimensions.
    """

    def __init__(self, scale_factor=2):
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        factor = self.scale_factor
        return F.interpolate(input=x, mode='nearest', scale_factor=factor)
class WScaleLayer(nn.Module):
    """Scales a convolution's output by a constant and adds a trainable bias.

    The preceding `nn.Conv2d` holds the trainable, unscaled weights. This
    layer multiplies the convolution result by the fixed, non-trainable
    factor `gain / sqrt(in_channels * kernel_size ** 2)` and then adds a
    per-channel bias, which is trainable.
    """

    def __init__(self, in_channels, out_channels, kernel_size, gain=np.sqrt(2.0)):
        super().__init__()
        self.scale = gain / np.sqrt(in_channels * kernel_size ** 2)
        self.bias = nn.Parameter(torch.zeros(out_channels))

    def forward(self, x):
        per_channel_bias = self.bias.reshape(1, -1, 1, 1)
        return x * self.scale + per_channel_bias
class ConvBlock(nn.Module):
    """Implements the convolutional block used in ProgressiveGAN.

    The block executes, in sequence: pixel-wise normalization, upsampling
    (if requested), convolution, weight scaling plus bias, and activation.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 dilation=1,
                 add_bias=False,
                 upsample=False,
                 fused_scale=False,
                 wscale_gain=np.sqrt(2.0),
                 activation_type='lrelu'):
        """Initializes the class with block settings.

        Args:
            in_channels: Number of channels of the input tensor fed into this
                block.
            out_channels: Number of channels (kernels) of the output tensor.
            kernel_size: Size of the convolutional kernel.
            stride: Stride parameter for convolution operation.
            padding: Padding parameter for convolution operation.
            dilation: Dilation rate for convolution operation.
            add_bias: Whether to add bias onto the convolutional result (in
                addition to the bias always added by the `wscale` layer).
            upsample: Whether to upsample the input tensor 2x before
                convolution.
            fused_scale: Whether to fuse `upsample` and `conv2d` together,
                resulting in `conv2d_transpose`. Only has an effect when
                `upsample` is True; the fused path always uses stride 2 and
                padding 1 and ignores `stride`, `padding`, `dilation` and
                `add_bias`.
            wscale_gain: The gain factor for `wscale` layer.
            activation_type: Type of activation function. Supports `linear`,
                `lrelu` (negative slope 0.2) and `tanh`. Note that `tanh` is
                implemented with `nn.Hardtanh`, which clips to [-1, 1] rather
                than applying a smooth tanh.

        Raises:
            NotImplementedError: If the input `activation_type` is not
                supported.
        """
        super().__init__()
        self.pixel_norm = PixelNormLayer()

        if upsample and not fused_scale:
            self.upsample = ResolutionScalingLayer()
        else:
            self.upsample = nn.Identity()

        if upsample and fused_scale:
            # Raw kernel laid out as [kernel, kernel, in, out]; it is turned
            # into a transposed-convolution kernel in `forward`.
            self.weight = nn.Parameter(
                torch.randn(kernel_size, kernel_size, in_channels, out_channels))
            fan_in = in_channels * kernel_size * kernel_size
            self.scale = wscale_gain / np.sqrt(fan_in)
        else:
            self.conv = nn.Conv2d(in_channels=in_channels,
                                  out_channels=out_channels,
                                  kernel_size=kernel_size,
                                  stride=stride,
                                  padding=padding,
                                  dilation=dilation,
                                  groups=1,
                                  bias=add_bias)
        self.wscale = WScaleLayer(in_channels=in_channels,
                                  out_channels=out_channels,
                                  kernel_size=kernel_size,
                                  gain=wscale_gain)

        if activation_type == 'linear':
            self.activate = nn.Identity()
        elif activation_type == 'lrelu':
            self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        elif activation_type == 'tanh':
            self.activate = nn.Hardtanh()
        else:
            raise NotImplementedError(f'Not implemented activation function: '
                                      f'{activation_type}!')

    def forward(self, x):
        """Applies pixel norm, (up)sampling + conv, weight scale and activation.

        Args:
            x: Input feature map with `in_channels` channels (assumed layout
                [N, C, H, W]).

        Returns:
            The activated output with `out_channels` channels; its spatial
            size is doubled when the block upsamples.
        """
        x = self.pixel_norm(x)
        x = self.upsample(x)
        if hasattr(self, 'conv'):
            x = self.conv(x)
        else:
            kernel = self.weight * self.scale
            # Pad the two spatial dims by 1 on each side, then sum four
            # shifted copies: a [k+1, k+1] kernel that, used with a stride-2
            # transposed convolution, folds the 2x nearest-neighbor upsampling
            # into the convolution.
            kernel = F.pad(kernel, (0, 0, 0, 0, 1, 1, 1, 1), 'constant', 0.0)
            kernel = (kernel[1:, 1:] + kernel[:-1, 1:] +
                      kernel[1:, :-1] + kernel[:-1, :-1])
            # conv_transpose2d expects weights as [in, out, kH, kW].
            kernel = kernel.permute(2, 3, 0, 1)
            x = F.conv_transpose2d(x, kernel, stride=2, padding=1)
            # `self.wscale` multiplies by the same factor again (same gain and
            # fan-in), so divide it out here to apply the scale only once.
            x = x / self.scale
        x = self.wscale(x)
        x = self.activate(x)
        return x