Spaces:
Build error
Build error
# python3.7 | |
"""Contains the implementation of discriminator described in StyleGAN. | |
Paper: https://arxiv.org/pdf/1812.04948.pdf | |
Official TensorFlow implementation: https://github.com/NVlabs/stylegan | |
""" | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
__all__ = ['StyleGANDiscriminator'] | |
# Resolutions allowed. | |
_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024] | |
# Initial resolution. | |
_INIT_RES = 4 | |
# Fused-scale options allowed. | |
_FUSED_SCALE_ALLOWED = [True, False, 'auto'] | |
# Minimal resolution for `auto` fused-scale strategy. | |
_AUTO_FUSED_SCALE_MIN_RES = 128 | |
# Default gain factor for weight scaling. | |
_WSCALE_GAIN = np.sqrt(2.0) | |
class StyleGANDiscriminator(nn.Module): | |
"""Defines the discriminator network in StyleGAN. | |
NOTE: The discriminator takes images with `RGB` channel order and pixel | |
range [-1, 1] as inputs. | |
Settings for the network: | |
(1) resolution: The resolution of the input image. | |
(2) image_channels: Number of channels of the input image. (default: 3) | |
(3) label_size: Size of the additional label for conditional generation. | |
(default: 0) | |
(4) fused_scale: Whether to fused `conv2d` and `downsample` together, | |
resulting in `conv2d` with strides. (default: `auto`) | |
(5) use_wscale: Whether to use weight scaling. (default: True) | |
(6) minibatch_std_group_size: Group size for the minibatch standard | |
deviation layer. 0 means disable. (default: 4) | |
(7) minibatch_std_channels: Number of new channels after the minibatch | |
standard deviation layer. (default: 1) | |
(8) fmaps_base: Factor to control number of feature maps for each layer. | |
(default: 16 << 10) | |
(9) fmaps_max: Maximum number of feature maps in each layer. (default: 512) | |
""" | |
def __init__(self, | |
resolution, | |
image_channels=3, | |
label_size=0, | |
fused_scale='auto', | |
use_wscale=True, | |
minibatch_std_group_size=4, | |
minibatch_std_channels=1, | |
fmaps_base=16 << 10, | |
fmaps_max=512): | |
"""Initializes with basic settings. | |
Raises: | |
ValueError: If the `resolution` is not supported, or `fused_scale` | |
is not supported. | |
""" | |
super().__init__() | |
if resolution not in _RESOLUTIONS_ALLOWED: | |
raise ValueError(f'Invalid resolution: `{resolution}`!\n' | |
f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.') | |
if fused_scale not in _FUSED_SCALE_ALLOWED: | |
raise ValueError(f'Invalid fused-scale option: `{fused_scale}`!\n' | |
f'Options allowed: {_FUSED_SCALE_ALLOWED}.') | |
self.init_res = _INIT_RES | |
self.init_res_log2 = int(np.log2(self.init_res)) | |
self.resolution = resolution | |
self.final_res_log2 = int(np.log2(self.resolution)) | |
self.image_channels = image_channels | |
self.label_size = label_size | |
self.fused_scale = fused_scale | |
self.use_wscale = use_wscale | |
self.minibatch_std_group_size = minibatch_std_group_size | |
self.minibatch_std_channels = minibatch_std_channels | |
self.fmaps_base = fmaps_base | |
self.fmaps_max = fmaps_max | |
# Level of detail (used for progressive training). | |
self.register_buffer('lod', torch.zeros(())) | |
self.pth_to_tf_var_mapping = {'lod': 'lod'} | |
for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1): | |
res = 2 ** res_log2 | |
block_idx = self.final_res_log2 - res_log2 | |
# Input convolution layer for each resolution. | |
self.add_module( | |
f'input{block_idx}', | |
ConvBlock(in_channels=self.image_channels, | |
out_channels=self.get_nf(res), | |
kernel_size=1, | |
padding=0, | |
use_wscale=self.use_wscale)) | |
self.pth_to_tf_var_mapping[f'input{block_idx}.weight'] = ( | |
f'FromRGB_lod{block_idx}/weight') | |
self.pth_to_tf_var_mapping[f'input{block_idx}.bias'] = ( | |
f'FromRGB_lod{block_idx}/bias') | |
# Convolution block for each resolution (except the last one). | |
if res != self.init_res: | |
if self.fused_scale == 'auto': | |
fused_scale = (res >= _AUTO_FUSED_SCALE_MIN_RES) | |
else: | |
fused_scale = self.fused_scale | |
self.add_module( | |
f'layer{2 * block_idx}', | |
ConvBlock(in_channels=self.get_nf(res), | |
out_channels=self.get_nf(res), | |
use_wscale=self.use_wscale)) | |
tf_layer0_name = 'Conv0' | |
self.add_module( | |
f'layer{2 * block_idx + 1}', | |
ConvBlock(in_channels=self.get_nf(res), | |
out_channels=self.get_nf(res // 2), | |
downsample=True, | |
fused_scale=fused_scale, | |
use_wscale=self.use_wscale)) | |
tf_layer1_name = 'Conv1_down' | |
# Convolution block for last resolution. | |
else: | |
self.add_module( | |
f'layer{2 * block_idx}', | |
ConvBlock(in_channels=self.get_nf(res), | |
out_channels=self.get_nf(res), | |
use_wscale=self.use_wscale, | |
minibatch_std_group_size=minibatch_std_group_size, | |
minibatch_std_channels=minibatch_std_channels)) | |
tf_layer0_name = 'Conv' | |
self.add_module( | |
f'layer{2 * block_idx + 1}', | |
DenseBlock(in_channels=self.get_nf(res) * res * res, | |
out_channels=self.get_nf(res // 2), | |
use_wscale=self.use_wscale)) | |
tf_layer1_name = 'Dense0' | |
self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.weight'] = ( | |
f'{res}x{res}/{tf_layer0_name}/weight') | |
self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.bias'] = ( | |
f'{res}x{res}/{tf_layer0_name}/bias') | |
self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.weight'] = ( | |
f'{res}x{res}/{tf_layer1_name}/weight') | |
self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.bias'] = ( | |
f'{res}x{res}/{tf_layer1_name}/bias') | |
# Final dense block. | |
self.add_module( | |
f'layer{2 * block_idx + 2}', | |
DenseBlock(in_channels=self.get_nf(res // 2), | |
out_channels=max(self.label_size, 1), | |
use_wscale=self.use_wscale, | |
wscale_gain=1.0, | |
activation_type='linear')) | |
self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 2}.weight'] = ( | |
f'{res}x{res}/Dense1/weight') | |
self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 2}.bias'] = ( | |
f'{res}x{res}/Dense1/bias') | |
self.downsample = DownsamplingLayer() | |
def get_nf(self, res): | |
"""Gets number of feature maps according to current resolution.""" | |
return min(self.fmaps_base // res, self.fmaps_max) | |
def forward(self, image, label=None, lod=None, **_unused_kwargs): | |
expected_shape = (self.image_channels, self.resolution, self.resolution) | |
if image.ndim != 4 or image.shape[1:] != expected_shape: | |
raise ValueError(f'The input tensor should be with shape ' | |
f'[batch_size, channel, height, width], where ' | |
f'`channel` equals to {self.image_channels}, ' | |
f'`height`, `width` equal to {self.resolution}!\n' | |
f'But `{image.shape}` is received!') | |
lod = self.lod.cpu().tolist() if lod is None else lod | |
if lod + self.init_res_log2 > self.final_res_log2: | |
raise ValueError(f'Maximum level-of-detail (lod) is ' | |
f'{self.final_res_log2 - self.init_res_log2}, ' | |
f'but `{lod}` is received!') | |
if self.label_size: | |
if label is None: | |
raise ValueError(f'Model requires an additional label ' | |
f'(with size {self.label_size}) as input, ' | |
f'but no label is received!') | |
batch_size = image.shape[0] | |
if label.ndim != 2 or label.shape != (batch_size, self.label_size): | |
raise ValueError(f'Input label should be with shape ' | |
f'[batch_size, label_size], where ' | |
f'`batch_size` equals to that of ' | |
f'images ({image.shape[0]}) and ' | |
f'`label_size` equals to {self.label_size}!\n' | |
f'But `{label.shape}` is received!') | |
for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1): | |
block_idx = current_lod = self.final_res_log2 - res_log2 | |
if current_lod <= lod < current_lod + 1: | |
x = self.__getattr__(f'input{block_idx}')(image) | |
elif current_lod - 1 < lod < current_lod: | |
alpha = lod - np.floor(lod) | |
x = (self.__getattr__(f'input{block_idx}')(image) * alpha + | |
x * (1 - alpha)) | |
if lod < current_lod + 1: | |
x = self.__getattr__(f'layer{2 * block_idx}')(x) | |
x = self.__getattr__(f'layer{2 * block_idx + 1}')(x) | |
if lod > current_lod: | |
image = self.downsample(image) | |
x = self.__getattr__(f'layer{2 * block_idx + 2}')(x) | |
if self.label_size: | |
x = torch.sum(x * label, dim=1, keepdim=True) | |
return x | |
class MiniBatchSTDLayer(nn.Module): | |
"""Implements the minibatch standard deviation layer.""" | |
def __init__(self, group_size=4, new_channels=1, epsilon=1e-8): | |
super().__init__() | |
self.group_size = group_size | |
self.new_channels = new_channels | |
self.epsilon = epsilon | |
def forward(self, x): | |
if self.group_size <= 1: | |
return x | |
ng = min(self.group_size, x.shape[0]) | |
nc = self.new_channels | |
temp_c = x.shape[1] // nc # [NCHW] | |
y = x.view(ng, -1, nc, temp_c, x.shape[2], x.shape[3]) # [GMncHW] | |
y = y - torch.mean(y, dim=0, keepdim=True) # [GMncHW] | |
y = torch.mean(y ** 2, dim=0) # [MncHW] | |
y = torch.sqrt(y + self.epsilon) # [MncHW] | |
y = torch.mean(y, dim=[2, 3, 4], keepdim=True) # [Mn111] | |
y = torch.mean(y, dim=2) # [Mn11] | |
y = y.repeat(ng, 1, x.shape[2], x.shape[3]) # [NnHW] | |
return torch.cat([x, y], dim=1) | |
class DownsamplingLayer(nn.Module): | |
"""Implements the downsampling layer. | |
Basically, this layer can be used to downsample feature maps with average | |
pooling. | |
""" | |
def __init__(self, scale_factor=2): | |
super().__init__() | |
self.scale_factor = scale_factor | |
def forward(self, x): | |
if self.scale_factor <= 1: | |
return x | |
return F.avg_pool2d(x, | |
kernel_size=self.scale_factor, | |
stride=self.scale_factor, | |
padding=0) | |
class Blur(torch.autograd.Function): | |
"""Defines blur operation with customized gradient computation.""" | |
def forward(ctx, x, kernel): | |
ctx.save_for_backward(kernel) | |
y = F.conv2d(input=x, | |
weight=kernel, | |
bias=None, | |
stride=1, | |
padding=1, | |
groups=x.shape[1]) | |
return y | |
def backward(ctx, dy): | |
kernel, = ctx.saved_tensors | |
dx = BlurBackPropagation.apply(dy, kernel) | |
return dx, None, None | |
class BlurBackPropagation(torch.autograd.Function): | |
"""Defines the back propagation of blur operation. | |
NOTE: This is used to speed up the backward of gradient penalty. | |
""" | |
def forward(ctx, dy, kernel): | |
ctx.save_for_backward(kernel) | |
dx = F.conv2d(input=dy, | |
weight=kernel.flip((2, 3)), | |
bias=None, | |
stride=1, | |
padding=1, | |
groups=dy.shape[1]) | |
return dx | |
def backward(ctx, ddx): | |
kernel, = ctx.saved_tensors | |
ddy = F.conv2d(input=ddx, | |
weight=kernel, | |
bias=None, | |
stride=1, | |
padding=1, | |
groups=ddx.shape[1]) | |
return ddy, None, None | |
class BlurLayer(nn.Module): | |
"""Implements the blur layer.""" | |
def __init__(self, | |
channels, | |
kernel=(1, 2, 1), | |
normalize=True): | |
super().__init__() | |
kernel = np.array(kernel, dtype=np.float32).reshape(1, -1) | |
kernel = kernel.T.dot(kernel) | |
if normalize: | |
kernel = kernel / np.sum(kernel) | |
kernel = kernel[np.newaxis, np.newaxis] | |
kernel = np.tile(kernel, [channels, 1, 1, 1]) | |
self.register_buffer('kernel', torch.from_numpy(kernel)) | |
def forward(self, x): | |
return Blur.apply(x, self.kernel) | |
class ConvBlock(nn.Module): | |
"""Implements the convolutional block. | |
Basically, this block executes minibatch standard deviation layer (if | |
needed), convolutional layer, activation layer, and downsampling layer ( | |
if needed) in sequence. | |
""" | |
def __init__(self, | |
in_channels, | |
out_channels, | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
add_bias=True, | |
downsample=False, | |
fused_scale=False, | |
use_wscale=True, | |
wscale_gain=_WSCALE_GAIN, | |
lr_mul=1.0, | |
activation_type='lrelu', | |
minibatch_std_group_size=0, | |
minibatch_std_channels=1): | |
"""Initializes with block settings. | |
Args: | |
in_channels: Number of channels of the input tensor. | |
out_channels: Number of channels of the output tensor. | |
kernel_size: Size of the convolutional kernels. (default: 3) | |
stride: Stride parameter for convolution operation. (default: 1) | |
padding: Padding parameter for convolution operation. (default: 1) | |
add_bias: Whether to add bias onto the convolutional result. | |
(default: True) | |
downsample: Whether to downsample the result after convolution. | |
(default: False) | |
fused_scale: Whether to fused `conv2d` and `downsample` together, | |
resulting in `conv2d` with strides. (default: False) | |
use_wscale: Whether to use weight scaling. (default: True) | |
wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN) | |
lr_mul: Learning multiplier for both weight and bias. (default: 1.0) | |
activation_type: Type of activation. Support `linear` and `lrelu`. | |
(default: `lrelu`) | |
minibatch_std_group_size: Group size for the minibatch standard | |
deviation layer. 0 means disable. (default: 0) | |
minibatch_std_channels: Number of new channels after the minibatch | |
standard deviation layer. (default: 1) | |
Raises: | |
NotImplementedError: If the `activation_type` is not supported. | |
""" | |
super().__init__() | |
if minibatch_std_group_size > 1: | |
in_channels = in_channels + minibatch_std_channels | |
self.mbstd = MiniBatchSTDLayer(group_size=minibatch_std_group_size, | |
new_channels=minibatch_std_channels) | |
else: | |
self.mbstd = nn.Identity() | |
if downsample: | |
self.blur = BlurLayer(channels=in_channels) | |
else: | |
self.blur = nn.Identity() | |
if downsample and not fused_scale: | |
self.downsample = DownsamplingLayer() | |
else: | |
self.downsample = nn.Identity() | |
if downsample and fused_scale: | |
self.use_stride = True | |
self.stride = 2 | |
self.padding = 1 | |
else: | |
self.use_stride = False | |
self.stride = stride | |
self.padding = padding | |
weight_shape = (out_channels, in_channels, kernel_size, kernel_size) | |
fan_in = kernel_size * kernel_size * in_channels | |
wscale = wscale_gain / np.sqrt(fan_in) | |
if use_wscale: | |
self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul) | |
self.wscale = wscale * lr_mul | |
else: | |
self.weight = nn.Parameter( | |
torch.randn(*weight_shape) * wscale / lr_mul) | |
self.wscale = lr_mul | |
if add_bias: | |
self.bias = nn.Parameter(torch.zeros(out_channels)) | |
self.bscale = lr_mul | |
else: | |
self.bias = None | |
if activation_type == 'linear': | |
self.activate = nn.Identity() | |
elif activation_type == 'lrelu': | |
self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True) | |
else: | |
raise NotImplementedError(f'Not implemented activation function: ' | |
f'`{activation_type}`!') | |
def forward(self, x): | |
x = self.mbstd(x) | |
x = self.blur(x) | |
weight = self.weight * self.wscale | |
bias = self.bias * self.bscale if self.bias is not None else None | |
if self.use_stride: | |
weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0.0) | |
weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] + | |
weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1]) * 0.25 | |
x = F.conv2d(x, | |
weight=weight, | |
bias=bias, | |
stride=self.stride, | |
padding=self.padding) | |
x = self.downsample(x) | |
x = self.activate(x) | |
return x | |
class DenseBlock(nn.Module): | |
"""Implements the dense block. | |
Basically, this block executes fully-connected layer and activation layer. | |
""" | |
def __init__(self, | |
in_channels, | |
out_channels, | |
add_bias=True, | |
use_wscale=True, | |
wscale_gain=_WSCALE_GAIN, | |
lr_mul=1.0, | |
activation_type='lrelu'): | |
"""Initializes with block settings. | |
Args: | |
in_channels: Number of channels of the input tensor. | |
out_channels: Number of channels of the output tensor. | |
add_bias: Whether to add bias onto the fully-connected result. | |
(default: True) | |
use_wscale: Whether to use weight scaling. (default: True) | |
wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN) | |
lr_mul: Learning multiplier for both weight and bias. (default: 1.0) | |
activation_type: Type of activation. Support `linear` and `lrelu`. | |
(default: `lrelu`) | |
Raises: | |
NotImplementedError: If the `activation_type` is not supported. | |
""" | |
super().__init__() | |
weight_shape = (out_channels, in_channels) | |
wscale = wscale_gain / np.sqrt(in_channels) | |
if use_wscale: | |
self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul) | |
self.wscale = wscale * lr_mul | |
else: | |
self.weight = nn.Parameter( | |
torch.randn(*weight_shape) * wscale / lr_mul) | |
self.wscale = lr_mul | |
if add_bias: | |
self.bias = nn.Parameter(torch.zeros(out_channels)) | |
self.bscale = lr_mul | |
else: | |
self.bias = None | |
if activation_type == 'linear': | |
self.activate = nn.Identity() | |
elif activation_type == 'lrelu': | |
self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True) | |
else: | |
raise NotImplementedError(f'Not implemented activation function: ' | |
f'`{activation_type}`!') | |
def forward(self, x): | |
if x.ndim != 2: | |
x = x.view(x.shape[0], -1) | |
bias = self.bias * self.bscale if self.bias is not None else None | |
x = F.linear(x, weight=self.weight * self.wscale, bias=bias) | |
x = self.activate(x) | |
return x | |