|
|
|
"""Contains the implementation of generator described in StyleGAN2. |
|
|
|
Compared to that of StyleGAN, the generator in StyleGAN2 mainly introduces style |
|
demodulation, adds skip connections, increases model size, and disables |
|
progressive growth. This script ONLY supports config F in the original paper. |
|
|
|
Paper: https://arxiv.org/pdf/1912.04958.pdf |
|
|
|
Official TensorFlow implementation: https://github.com/NVlabs/stylegan2 |
|
""" |
|
|
|
import numpy as np |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from third_party.stylegan2_official_ops import fma |
|
from third_party.stylegan2_official_ops import bias_act |
|
from third_party.stylegan2_official_ops import upfirdn2d |
|
from third_party.stylegan2_official_ops import conv2d_gradfix |
|
from .utils.ops import all_gather |
|
|
|
__all__ = ['StyleGAN2Generator'] |
|
|
|
|
|
_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024] |
|
|
|
|
|
_ARCHITECTURES_ALLOWED = ['resnet', 'skip', 'origin'] |
|
|
|
|
|
|
|
class StyleGAN2Generator(nn.Module): |
|
"""Defines the generator network in StyleGAN2. |
|
|
|
    NOTE: The synthesized images are in `RGB` channel order, with pixel range

    [-1, 1].
|
|
|
Settings for the mapping network: |
|
|
|
(1) z_dim: Dimension of the input latent space, Z. (default: 512) |
|
(2) w_dim: Dimension of the output latent space, W. (default: 512) |
|
    (3) repeat_w: Whether to repeat the w-code for different layers.
        (default: True)
|
(4) normalize_z: Whether to normalize the z-code. (default: True) |
|
(5) mapping_layers: Number of layers of the mapping network. (default: 8) |
|
(6) mapping_fmaps: Number of hidden channels of the mapping network. |
|
(default: 512) |
|
(7) mapping_use_wscale: Whether to use weight scaling for the mapping |
|
network. (default: True) |
|
(8) mapping_wscale_gain: The factor to control weight scaling for the |
|
mapping network (default: 1.0) |
|
(9) mapping_lr_mul: Learning rate multiplier for the mapping network. |
|
(default: 0.01) |
|
|
|
Settings for conditional generation: |
|
|
|
(1) label_dim: Dimension of the additional label for conditional generation. |
|
        In the one-hot conditioning case, it equals the number of classes. If

        set to 0, conditional training will be disabled. (default: 0)
|
(2) embedding_dim: Dimension of the embedding space, if needed. |
|
(default: 512) |
|
(3) embedding_bias: Whether to add bias to embedding learning. |
|
(default: True) |
|
(4) embedding_use_wscale: Whether to use weight scaling for embedding |
|
learning. (default: True) |
|
(5) embedding_wscale_gain: The factor to control weight scaling for |
|
embedding. (default: 1.0) |
|
(6) embedding_lr_mul: Learning rate multiplier for the embedding learning. |
|
(default: 1.0) |
|
(7) normalize_embedding: Whether to normalize the embedding. (default: True) |
|
(8) normalize_embedding_latent: Whether to normalize the embedding together |
|
with the latent. (default: False) |
|
|
|
Settings for the synthesis network: |
|
|
|
(1) resolution: The resolution of the output image. (default: -1) |
|
    (2) init_res: The initial resolution to start convolution with.
        (default: 4)
|
(3) image_channels: Number of channels of the output image. (default: 3) |
|
(4) final_tanh: Whether to use `tanh` to control the final pixel range. |
|
(default: False) |
|
(5) const_input: Whether to use a constant in the first convolutional layer. |
|
(default: True) |
|
    (6) architecture: Type of architecture. Supported: `origin`, `skip`, and
|
`resnet`. (default: `skip`) |
|
(7) demodulate: Whether to perform style demodulation. (default: True) |
|
(8) use_wscale: Whether to use weight scaling. (default: True) |
|
(9) wscale_gain: The factor to control weight scaling. (default: 1.0) |
|
(10) lr_mul: Learning rate multiplier for the synthesis network. |
|
(default: 1.0) |
|
(11) noise_type: Type of noise added to the convolutional results at each |
|
layer. (default: `spatial`) |
|
    (12) fmaps_base: Factor to control the number of feature maps for each

        layer. (default: 32 << 10)
|
(13) fmaps_max: Maximum number of feature maps in each layer. (default: 512) |
|
(14) filter_kernel: Kernel used for filtering (e.g., downsampling). |
|
(default: (1, 3, 3, 1)) |
|
(15) conv_clamp: A threshold to clamp the output of convolution layers to |
|
avoid overflow under FP16 training. (default: None) |
|
    (16) eps: A small value to avoid division by zero. (default: 1e-8)
|
|
|
Runtime settings: |
|
|
|
(1) w_moving_decay: Decay factor for updating `w_avg`, which is used for |
|
training only. Set `None` to disable. (default: None) |
|
    (2) sync_w_avg: Whether to synchronize the stats of `w_avg` across

        replicas. If set to `True`, the stats will be more accurate, yet the

        speed may be slightly slower. (default: False)
|
(3) style_mixing_prob: Probability to perform style mixing as a training |
|
regularization. Set `None` to disable. (default: None) |
|
(4) trunc_psi: Truncation psi, set `None` to disable. (default: None) |
|
(5) trunc_layers: Number of layers to perform truncation. (default: None) |
|
    (6) noise_mode: Mode of the layer-wise noise. Supported: `none`,

        `random`, and `const`. (default: `const`)
|
(7) fused_modulate: Whether to fuse `style_modulate` and `conv2d` together. |
|
(default: False) |
|
(8) fp16_res: Layers at resolution higher than (or equal to) this field will |
|
use `float16` precision for computation. This is merely used for |
|
acceleration. If set as `None`, all layers will use `float32` by |
|
default. (default: None) |
|
(9) impl: Implementation mode of some particular ops, e.g., `filtering`, |
|
`bias_act`, etc. `cuda` means using the official CUDA implementation |
|
from StyleGAN2, while `ref` means using the native PyTorch ops. |
|
(default: `cuda`) |
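
    Example (a minimal sketch; `impl='ref'` falls back to native PyTorch ops,
    so no custom CUDA build is assumed):

        G = StyleGAN2Generator(resolution=256).eval()
        z = torch.randn(4, G.z_dim)
        results = G(z, trunc_psi=0.7, trunc_layers=8, impl='ref')
        image = results['image']  # shape [4, 3, 256, 256], roughly in [-1, 1]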
|
""" |
|
|
|
def __init__(self, |
|
|
|
z_dim=512, |
|
w_dim=512, |
|
repeat_w=True, |
|
normalize_z=True, |
|
mapping_layers=8, |
|
mapping_fmaps=512, |
|
mapping_use_wscale=True, |
|
mapping_wscale_gain=1.0, |
|
mapping_lr_mul=0.01, |
|
|
|
label_dim=0, |
|
embedding_dim=512, |
|
embedding_bias=True, |
|
embedding_use_wscale=True, |
|
                 embedding_wscale_gain=1.0,
|
embedding_lr_mul=1.0, |
|
normalize_embedding=True, |
|
normalize_embedding_latent=False, |
|
|
|
resolution=-1, |
|
init_res=4, |
|
image_channels=3, |
|
final_tanh=False, |
|
const_input=True, |
|
architecture='skip', |
|
demodulate=True, |
|
use_wscale=True, |
|
wscale_gain=1.0, |
|
lr_mul=1.0, |
|
noise_type='spatial', |
|
fmaps_base=32 << 10, |
|
fmaps_max=512, |
|
filter_kernel=(1, 3, 3, 1), |
|
conv_clamp=None, |
|
eps=1e-8): |
|
"""Initializes with basic settings. |
|
|
|
Raises: |
|
            ValueError: If `resolution` or `architecture` is not

                supported.
|
""" |
|
super().__init__() |
|
|
|
if resolution not in _RESOLUTIONS_ALLOWED: |
|
raise ValueError(f'Invalid resolution: `{resolution}`!\n' |
|
f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.') |
|
architecture = architecture.lower() |
|
if architecture not in _ARCHITECTURES_ALLOWED: |
|
raise ValueError(f'Invalid architecture: `{architecture}`!\n' |
|
f'Architectures allowed: ' |
|
f'{_ARCHITECTURES_ALLOWED}.') |
|
|
|
self.z_dim = z_dim |
|
self.w_dim = w_dim |
|
self.repeat_w = repeat_w |
|
self.normalize_z = normalize_z |
|
self.mapping_layers = mapping_layers |
|
self.mapping_fmaps = mapping_fmaps |
|
self.mapping_use_wscale = mapping_use_wscale |
|
self.mapping_wscale_gain = mapping_wscale_gain |
|
self.mapping_lr_mul = mapping_lr_mul |
|
|
|
self.label_dim = label_dim |
|
self.embedding_dim = embedding_dim |
|
self.embedding_bias = embedding_bias |
|
self.embedding_use_wscale = embedding_use_wscale |
|
        self.embedding_wscale_gain = embedding_wscale_gain
|
self.embedding_lr_mul = embedding_lr_mul |
|
self.normalize_embedding = normalize_embedding |
|
self.normalize_embedding_latent = normalize_embedding_latent |
|
|
|
self.resolution = resolution |
|
self.init_res = init_res |
|
self.image_channels = image_channels |
|
self.final_tanh = final_tanh |
|
self.const_input = const_input |
|
self.architecture = architecture |
|
self.demodulate = demodulate |
|
self.use_wscale = use_wscale |
|
self.wscale_gain = wscale_gain |
|
self.lr_mul = lr_mul |
|
self.noise_type = noise_type.lower() |
|
self.fmaps_base = fmaps_base |
|
self.fmaps_max = fmaps_max |
|
self.filter_kernel = filter_kernel |
|
self.conv_clamp = conv_clamp |
|
self.eps = eps |
|
|
|
|
|
self.latent_dim = (z_dim,) |
|
|
|
|
|
        # Number of synthesis layers, which equals the number of w-codes
        # required by the synthesis network.
        self.num_layers = int(np.log2(resolution // init_res * 2)) * 2
|
|
|
self.mapping = MappingNetwork( |
|
input_dim=z_dim, |
|
output_dim=w_dim, |
|
num_outputs=self.num_layers, |
|
repeat_output=repeat_w, |
|
normalize_input=normalize_z, |
|
num_layers=mapping_layers, |
|
hidden_dim=mapping_fmaps, |
|
use_wscale=mapping_use_wscale, |
|
wscale_gain=mapping_wscale_gain, |
|
lr_mul=mapping_lr_mul, |
|
label_dim=label_dim, |
|
embedding_dim=embedding_dim, |
|
embedding_bias=embedding_bias, |
|
embedding_use_wscale=embedding_use_wscale, |
|
            embedding_wscale_gain=embedding_wscale_gain,
|
embedding_lr_mul=embedding_lr_mul, |
|
normalize_embedding=normalize_embedding, |
|
normalize_embedding_latent=normalize_embedding_latent, |
|
eps=eps) |
|
|
|
|
|
if self.repeat_w: |
|
self.register_buffer('w_avg', torch.zeros(w_dim)) |
|
else: |
|
self.register_buffer('w_avg', torch.zeros(self.num_layers * w_dim)) |
|
|
|
self.synthesis = SynthesisNetwork(resolution=resolution, |
|
init_res=init_res, |
|
w_dim=w_dim, |
|
image_channels=image_channels, |
|
final_tanh=final_tanh, |
|
const_input=const_input, |
|
architecture=architecture, |
|
demodulate=demodulate, |
|
use_wscale=use_wscale, |
|
wscale_gain=wscale_gain, |
|
lr_mul=lr_mul, |
|
noise_type=noise_type, |
|
fmaps_base=fmaps_base, |
|
filter_kernel=filter_kernel, |
|
fmaps_max=fmaps_max, |
|
conv_clamp=conv_clamp, |
|
eps=eps) |
|
|
|
self.pth_to_tf_var_mapping = {'w_avg': 'dlatent_avg'} |
|
for key, val in self.mapping.pth_to_tf_var_mapping.items(): |
|
self.pth_to_tf_var_mapping[f'mapping.{key}'] = val |
|
for key, val in self.synthesis.pth_to_tf_var_mapping.items(): |
|
self.pth_to_tf_var_mapping[f'synthesis.{key}'] = val |
|
|
|
def set_space_of_latent(self, space_of_latent): |
|
"""Sets the space to which the latent code belong. |
|
|
|
See `SynthesisNetwork` for more details. |
|
""" |
|
self.synthesis.set_space_of_latent(space_of_latent) |
|
|
|
def forward(self, |
|
z, |
|
label=None, |
|
w_moving_decay=None, |
|
sync_w_avg=False, |
|
style_mixing_prob=None, |
|
trunc_psi=None, |
|
trunc_layers=None, |
|
noise_mode='const', |
|
fused_modulate=False, |
|
fp16_res=None, |
|
impl='cuda'): |
|
"""Connects mapping network and synthesis network. |
|
|
|
        This forward function will also update the moving average of the

        w-code, perform style mixing as a training regularization, and apply

        the truncation trick, which is specifically designed for inference.
|
|
|
Concretely, the truncation trick acts as follows: |
|
|
|
        For layers in range [0, trunc_layers), the truncated w-code is
|
computed as |
|
|
|
w_new = w_avg + (w - w_avg) * trunc_psi |
|
|
|
To disable truncation, please set |
|
|
|
(1) trunc_psi = 1.0 (None) OR |
|
(2) trunc_layers = 0 (None) |
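
        For example, `trunc_psi = 0.7` with `trunc_layers = 8` pulls the
        w-codes of the first 8 layers 30% of the way toward `w_avg`, which
        typically trades sample diversity for fidelity.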
|
""" |
|
|
|
mapping_results = self.mapping(z, label, impl=impl) |
|
|
|
w = mapping_results['w'] |
|
        # Update the moving average of the w-code during training.
        if self.training and w_moving_decay is not None:
|
if sync_w_avg: |
|
batch_w_avg = all_gather(w.detach()).mean(dim=0) |
|
else: |
|
batch_w_avg = w.detach().mean(dim=0) |
|
self.w_avg.copy_(batch_w_avg.lerp(self.w_avg, w_moving_decay)) |
|
|
|
wp = mapping_results.pop('wp') |
|
        # Style mixing regularization: with some probability, replace the
        # w-codes after a random cutoff layer with those of a fresh latent.
        if self.training and style_mixing_prob is not None:
|
if np.random.uniform() < style_mixing_prob: |
|
new_z = torch.randn_like(z) |
|
new_wp = self.mapping(new_z, label, impl=impl)['wp'] |
|
mixing_cutoff = np.random.randint(1, self.num_layers) |
|
wp[:, mixing_cutoff:] = new_wp[:, mixing_cutoff:] |
|
|
|
        # Truncation trick (inference only).
        if not self.training:
|
trunc_psi = 1.0 if trunc_psi is None else trunc_psi |
|
trunc_layers = 0 if trunc_layers is None else trunc_layers |
|
if trunc_psi < 1.0 and trunc_layers > 0: |
|
w_avg = self.w_avg.reshape(1, -1, self.w_dim)[:, :trunc_layers] |
|
wp[:, :trunc_layers] = w_avg.lerp( |
|
wp[:, :trunc_layers], trunc_psi) |
|
|
|
synthesis_results = self.synthesis(wp, |
|
noise_mode=noise_mode, |
|
fused_modulate=fused_modulate, |
|
impl=impl, |
|
fp16_res=fp16_res) |
|
|
|
return {**mapping_results, **synthesis_results} |
|
|
|
|
|
class MappingNetwork(nn.Module): |
|
"""Implements the latent space mapping network. |
|
|
|
    Basically, this network executes several dense layers in sequence,

    preceded by a label embedding branch if needed.
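
    A minimal shape sketch (the values below are illustrative, matching the
    defaults used by `StyleGAN2Generator` at resolution 256):

        mapping = MappingNetwork(
            input_dim=512, output_dim=512, num_outputs=14,
            repeat_output=True, normalize_input=True, num_layers=8,
            hidden_dim=512, use_wscale=True, wscale_gain=1.0, lr_mul=0.01,
            label_dim=0, embedding_dim=512, embedding_bias=True,
            embedding_use_wscale=True, embedding_wscale_gain=1.0,
            embedding_lr_mul=1.0, normalize_embedding=True,
            normalize_embedding_latent=False, eps=1e-8)
        wp = mapping(torch.randn(4, 512), impl='ref')['wp']  # [4, 14, 512]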
|
""" |
|
|
|
def __init__(self, |
|
input_dim, |
|
output_dim, |
|
num_outputs, |
|
repeat_output, |
|
normalize_input, |
|
num_layers, |
|
hidden_dim, |
|
use_wscale, |
|
wscale_gain, |
|
lr_mul, |
|
label_dim, |
|
embedding_dim, |
|
embedding_bias, |
|
embedding_use_wscale, |
|
                 embedding_wscale_gain,
|
embedding_lr_mul, |
|
normalize_embedding, |
|
normalize_embedding_latent, |
|
eps): |
|
super().__init__() |
|
|
|
self.input_dim = input_dim |
|
self.output_dim = output_dim |
|
self.num_outputs = num_outputs |
|
self.repeat_output = repeat_output |
|
self.normalize_input = normalize_input |
|
self.num_layers = num_layers |
|
self.hidden_dim = hidden_dim |
|
self.use_wscale = use_wscale |
|
self.wscale_gain = wscale_gain |
|
self.lr_mul = lr_mul |
|
self.label_dim = label_dim |
|
self.embedding_dim = embedding_dim |
|
self.embedding_bias = embedding_bias |
|
self.embedding_use_wscale = embedding_use_wscale |
|
        self.embedding_wscale_gain = embedding_wscale_gain
|
self.embedding_lr_mul = embedding_lr_mul |
|
self.normalize_embedding = normalize_embedding |
|
self.normalize_embedding_latent = normalize_embedding_latent |
|
self.eps = eps |
|
|
|
self.pth_to_tf_var_mapping = {} |
|
|
|
self.norm = PixelNormLayer(dim=1, eps=eps) |
|
|
|
if self.label_dim > 0: |
|
input_dim = input_dim + embedding_dim |
|
self.embedding = DenseLayer(in_channels=label_dim, |
|
out_channels=embedding_dim, |
|
add_bias=embedding_bias, |
|
init_bias=0.0, |
|
use_wscale=embedding_use_wscale, |
|
                                        wscale_gain=embedding_wscale_gain,
|
lr_mul=embedding_lr_mul, |
|
activation_type='linear') |
|
self.pth_to_tf_var_mapping['embedding.weight'] = 'LabelEmbed/weight' |
|
if self.embedding_bias: |
|
self.pth_to_tf_var_mapping['embedding.bias'] = 'LabelEmbed/bias' |
|
|
|
if num_outputs is not None and not repeat_output: |
|
output_dim = output_dim * num_outputs |
|
for i in range(num_layers): |
|
in_channels = (input_dim if i == 0 else hidden_dim) |
|
out_channels = (output_dim if i == (num_layers - 1) else hidden_dim) |
|
self.add_module(f'dense{i}', |
|
DenseLayer(in_channels=in_channels, |
|
out_channels=out_channels, |
|
add_bias=True, |
|
init_bias=0.0, |
|
use_wscale=use_wscale, |
|
wscale_gain=wscale_gain, |
|
lr_mul=lr_mul, |
|
activation_type='lrelu')) |
|
self.pth_to_tf_var_mapping[f'dense{i}.weight'] = f'Dense{i}/weight' |
|
self.pth_to_tf_var_mapping[f'dense{i}.bias'] = f'Dense{i}/bias' |
|
|
|
def forward(self, z, label=None, impl='cuda'): |
|
if z.ndim != 2 or z.shape[1] != self.input_dim: |
|
            raise ValueError(f'Input latent code should have shape '

                             f'[batch_size, input_dim], where '

                             f'`input_dim` equals {self.input_dim}!\n'

                             f'But `{z.shape}` was received!')
|
if self.normalize_input: |
|
z = self.norm(z) |
|
|
|
if self.label_dim > 0: |
|
if label is None: |
|
                raise ValueError(f'Model requires an additional label '

                                 f'(with dimension {self.label_dim}) as '

                                 f'input, but no label was received!')
|
if label.ndim != 2 or label.shape != (z.shape[0], self.label_dim): |
|
                raise ValueError(f'Input label should have shape '

                                 f'[batch_size, label_dim], where '

                                 f'`batch_size` equals that of '

                                 f'the latent codes ({z.shape[0]}) and '

                                 f'`label_dim` equals {self.label_dim}!\n'

                                 f'But `{label.shape}` was received!')
|
label = label.to(dtype=torch.float32) |
|
embedding = self.embedding(label, impl=impl) |
|
if self.normalize_embedding: |
|
embedding = self.norm(embedding) |
|
w = torch.cat((z, embedding), dim=1) |
|
else: |
|
w = z |
|
|
|
if self.label_dim > 0 and self.normalize_embedding_latent: |
|
w = self.norm(w) |
|
|
|
for i in range(self.num_layers): |
|
w = getattr(self, f'dense{i}')(w, impl=impl) |
|
|
|
wp = None |
|
if self.num_outputs is not None: |
|
if self.repeat_output: |
|
wp = w.unsqueeze(1).repeat((1, self.num_outputs, 1)) |
|
else: |
|
wp = w.reshape(-1, self.num_outputs, self.output_dim) |
|
|
|
results = { |
|
'z': z, |
|
'label': label, |
|
'w': w, |
|
'wp': wp, |
|
} |
|
if self.label_dim > 0: |
|
results['embedding'] = embedding |
|
return results |
|
|
|
|
|
class SynthesisNetwork(nn.Module): |
|
"""Implements the image synthesis network. |
|
|
|
Basically, this network executes several convolutional layers in sequence. |
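
    The number of layers is determined by `resolution` and `init_res`. For
    example, with `init_res=4` and `resolution=256`, the network consumes
    (log2(256) - log2(4) + 1) * 2 = 14 w-codes: 13 by convolutional layers
    and 1 by the final ToRGB layer.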
|
""" |
|
|
|
def __init__(self, |
|
resolution, |
|
init_res, |
|
w_dim, |
|
image_channels, |
|
final_tanh, |
|
const_input, |
|
architecture, |
|
demodulate, |
|
use_wscale, |
|
wscale_gain, |
|
lr_mul, |
|
noise_type, |
|
fmaps_base, |
|
fmaps_max, |
|
filter_kernel, |
|
conv_clamp, |
|
eps): |
|
super().__init__() |
|
|
|
self.init_res = init_res |
|
self.init_res_log2 = int(np.log2(init_res)) |
|
self.resolution = resolution |
|
self.final_res_log2 = int(np.log2(resolution)) |
|
self.w_dim = w_dim |
|
self.image_channels = image_channels |
|
self.final_tanh = final_tanh |
|
self.const_input = const_input |
|
self.architecture = architecture.lower() |
|
self.demodulate = demodulate |
|
self.use_wscale = use_wscale |
|
self.wscale_gain = wscale_gain |
|
self.lr_mul = lr_mul |
|
self.noise_type = noise_type.lower() |
|
self.fmaps_base = fmaps_base |
|
self.fmaps_max = fmaps_max |
|
self.filter_kernel = filter_kernel |
|
self.conv_clamp = conv_clamp |
|
self.eps = eps |
|
|
|
self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2 |
|
|
|
self.pth_to_tf_var_mapping = {} |
|
|
|
for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1): |
|
res = 2 ** res_log2 |
|
in_channels = self.get_nf(res // 2) |
|
out_channels = self.get_nf(res) |
|
block_idx = res_log2 - self.init_res_log2 |
|
|
|
|
|
            # Input block.
            if res == init_res:
|
if self.const_input: |
|
self.add_module('early_layer', |
|
InputLayer(init_res=res, |
|
channels=out_channels)) |
|
self.pth_to_tf_var_mapping['early_layer.const'] = ( |
|
f'{res}x{res}/Const/const') |
|
else: |
|
channels = out_channels * res * res |
|
self.add_module('early_layer', |
|
DenseLayer(in_channels=w_dim, |
|
out_channels=channels, |
|
add_bias=True, |
|
init_bias=0.0, |
|
use_wscale=use_wscale, |
|
wscale_gain=wscale_gain, |
|
lr_mul=lr_mul, |
|
activation_type='lrelu')) |
|
self.pth_to_tf_var_mapping['early_layer.weight'] = ( |
|
f'{res}x{res}/Dense/weight') |
|
self.pth_to_tf_var_mapping['early_layer.bias'] = ( |
|
f'{res}x{res}/Dense/bias') |
|
else: |
|
|
|
|
|
                # Residual branch (for the `resnet` architecture) with
                # upsampling.
                if self.architecture == 'resnet':
|
layer_name = f'residual{block_idx}' |
|
self.add_module(layer_name, |
|
ConvLayer(in_channels=in_channels, |
|
out_channels=out_channels, |
|
kernel_size=1, |
|
add_bias=False, |
|
scale_factor=2, |
|
filter_kernel=filter_kernel, |
|
use_wscale=use_wscale, |
|
wscale_gain=wscale_gain, |
|
lr_mul=lr_mul, |
|
activation_type='linear', |
|
conv_clamp=None)) |
|
self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( |
|
f'{res}x{res}/Skip/weight') |
|
|
|
|
|
                # First convolutional layer (with upsampling).
                layer_name = f'layer{2 * block_idx - 1}'
|
self.add_module(layer_name, |
|
ModulateConvLayer(in_channels=in_channels, |
|
out_channels=out_channels, |
|
resolution=res, |
|
w_dim=w_dim, |
|
kernel_size=3, |
|
add_bias=True, |
|
scale_factor=2, |
|
filter_kernel=filter_kernel, |
|
demodulate=demodulate, |
|
use_wscale=use_wscale, |
|
wscale_gain=wscale_gain, |
|
lr_mul=lr_mul, |
|
noise_type=noise_type, |
|
activation_type='lrelu', |
|
conv_clamp=conv_clamp, |
|
eps=eps)) |
|
self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( |
|
f'{res}x{res}/Conv0_up/weight') |
|
self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( |
|
f'{res}x{res}/Conv0_up/bias') |
|
self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = ( |
|
f'{res}x{res}/Conv0_up/mod_weight') |
|
self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = ( |
|
f'{res}x{res}/Conv0_up/mod_bias') |
|
self.pth_to_tf_var_mapping[f'{layer_name}.noise_strength'] = ( |
|
f'{res}x{res}/Conv0_up/noise_strength') |
|
self.pth_to_tf_var_mapping[f'{layer_name}.noise'] = ( |
|
f'noise{2 * block_idx - 1}') |
|
|
|
|
|
            # Second convolutional layer (without upsampling); for the input
            # block this is the only convolutional layer.
            layer_name = f'layer{2 * block_idx}'
|
self.add_module(layer_name, |
|
ModulateConvLayer(in_channels=out_channels, |
|
out_channels=out_channels, |
|
resolution=res, |
|
w_dim=w_dim, |
|
kernel_size=3, |
|
add_bias=True, |
|
scale_factor=1, |
|
filter_kernel=None, |
|
demodulate=demodulate, |
|
use_wscale=use_wscale, |
|
wscale_gain=wscale_gain, |
|
lr_mul=lr_mul, |
|
noise_type=noise_type, |
|
activation_type='lrelu', |
|
conv_clamp=conv_clamp, |
|
eps=eps)) |
|
tf_layer_name = 'Conv' if res == self.init_res else 'Conv1' |
|
self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( |
|
f'{res}x{res}/{tf_layer_name}/weight') |
|
self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( |
|
f'{res}x{res}/{tf_layer_name}/bias') |
|
self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = ( |
|
f'{res}x{res}/{tf_layer_name}/mod_weight') |
|
self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = ( |
|
f'{res}x{res}/{tf_layer_name}/mod_bias') |
|
self.pth_to_tf_var_mapping[f'{layer_name}.noise_strength'] = ( |
|
f'{res}x{res}/{tf_layer_name}/noise_strength') |
|
self.pth_to_tf_var_mapping[f'{layer_name}.noise'] = ( |
|
f'noise{2 * block_idx}') |
|
|
|
|
|
            # Output (ToRGB) layer, i.e., converting feature maps to an
            # image.
            if res_log2 == self.final_res_log2 or self.architecture == 'skip':
|
layer_name = f'output{block_idx}' |
|
self.add_module(layer_name, |
|
ModulateConvLayer(in_channels=out_channels, |
|
out_channels=image_channels, |
|
resolution=res, |
|
w_dim=w_dim, |
|
kernel_size=1, |
|
add_bias=True, |
|
scale_factor=1, |
|
filter_kernel=None, |
|
demodulate=False, |
|
use_wscale=use_wscale, |
|
wscale_gain=wscale_gain, |
|
lr_mul=lr_mul, |
|
noise_type='none', |
|
activation_type='linear', |
|
conv_clamp=conv_clamp, |
|
eps=eps)) |
|
self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( |
|
f'{res}x{res}/ToRGB/weight') |
|
self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( |
|
f'{res}x{res}/ToRGB/bias') |
|
self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = ( |
|
f'{res}x{res}/ToRGB/mod_weight') |
|
self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = ( |
|
f'{res}x{res}/ToRGB/mod_bias') |
|
|
|
|
|
if self.architecture == 'skip': |
|
self.register_buffer( |
|
'filter', upfirdn2d.setup_filter(filter_kernel)) |
|
|
|
def get_nf(self, res): |
|
"""Gets number of feature maps according to the given resolution.""" |
|
return min(self.fmaps_base // res, self.fmaps_max) |
|
|
|
def set_space_of_latent(self, space_of_latent): |
|
"""Sets the space to which the latent code belong. |
|
|
|
        This function determines how the latent code is injected into the

        convolutional layers. The original generator takes a W-Space code and

        applies it for style modulation after an affine transformation.

        Sometimes, however, an already affine-transformed code may need to be

        fed directly into the convolutional layers, e.g., when training an

        encoder for GAN inversion. We term the transformed space the Style

        Space (or Y-Space). This function tells the convolutional layers how

        to interpret the input code.
|
|
|
Args: |
|
            space_of_latent: The space to which the latent code belongs.

                Case insensitive. Supported: `W` and `Y`.
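
        For example, after calling `set_space_of_latent('y')`, each
        `ModulateConvLayer` will treat its input code as an already
        affine-transformed style and slice out the first `in_channels`
        entries instead of applying its own affine head.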
|
""" |
|
space_of_latent = space_of_latent.upper() |
|
for module in self.modules(): |
|
if isinstance(module, ModulateConvLayer): |
|
setattr(module, 'space_of_latent', space_of_latent) |
|
|
|
def forward(self, |
|
wp, |
|
noise_mode='const', |
|
fused_modulate=False, |
|
fp16_res=None, |
|
impl='cuda'): |
|
results = {'wp': wp} |
|
|
|
if self.const_input: |
|
x = self.early_layer(wp[:, 0]) |
|
else: |
|
x = self.early_layer(wp[:, 0], impl=impl) |
|
|
|
|
|
        # Cast to `torch.float16` if needed.
        if fp16_res is not None and self.init_res >= fp16_res:
|
x = x.to(torch.float16) |
|
|
|
if self.architecture == 'origin': |
|
for layer_idx in range(self.num_layers - 1): |
|
layer = getattr(self, f'layer{layer_idx}') |
|
x, style = layer(x, |
|
wp[:, layer_idx], |
|
noise_mode=noise_mode, |
|
fused_modulate=fused_modulate, |
|
impl=impl) |
|
results[f'style{layer_idx}'] = style |
|
|
|
|
|
                # Cast to `torch.float16` if needed.
                if layer_idx % 2 == 0 and layer_idx != self.num_layers - 2:
|
res = self.init_res * (2 ** (layer_idx // 2)) |
|
if fp16_res is not None and res * 2 >= fp16_res: |
|
x = x.to(torch.float16) |
|
else: |
|
x = x.to(torch.float32) |
|
output_layer = getattr(self, f'output{layer_idx // 2}') |
|
image, style = output_layer(x, |
|
wp[:, layer_idx + 1], |
|
fused_modulate=fused_modulate, |
|
impl=impl) |
|
image = image.to(torch.float32) |
|
results[f'output_style{layer_idx // 2}'] = style |
|
|
|
elif self.architecture == 'skip': |
|
for layer_idx in range(self.num_layers - 1): |
|
layer = getattr(self, f'layer{layer_idx}') |
|
x, style = layer(x, |
|
wp[:, layer_idx], |
|
noise_mode=noise_mode, |
|
fused_modulate=fused_modulate, |
|
impl=impl) |
|
results[f'style{layer_idx}'] = style |
|
if layer_idx % 2 == 0: |
|
output_layer = getattr(self, f'output{layer_idx // 2}') |
|
y, style = output_layer(x, |
|
wp[:, layer_idx + 1], |
|
fused_modulate=fused_modulate, |
|
impl=impl) |
|
results[f'output_style{layer_idx // 2}'] = style |
|
if layer_idx == 0: |
|
image = y.to(torch.float32) |
|
else: |
|
image = y.to(torch.float32) + upfirdn2d.upsample2d( |
|
image, self.filter, impl=impl) |
|
|
|
|
|
                    # Cast to `torch.float16` if needed.
                    if layer_idx != self.num_layers - 2:
|
res = self.init_res * (2 ** (layer_idx // 2)) |
|
if fp16_res is not None and res * 2 >= fp16_res: |
|
x = x.to(torch.float16) |
|
else: |
|
x = x.to(torch.float32) |
|
|
|
elif self.architecture == 'resnet': |
|
x, style = self.layer0(x, |
|
wp[:, 0], |
|
noise_mode=noise_mode, |
|
fused_modulate=fused_modulate, |
|
impl=impl) |
|
results['style0'] = style |
|
for layer_idx in range(1, self.num_layers - 1, 2): |
|
|
|
                # Cast to `torch.float16` if needed (a new resolution block
                # starts at each odd `layer_idx`).
                res = self.init_res * (2 ** (layer_idx // 2))

                if fp16_res is not None and res * 2 >= fp16_res:

                    x = x.to(torch.float16)

                else:

                    x = x.to(torch.float32)
|
|
|
skip_layer = getattr(self, f'residual{layer_idx // 2 + 1}') |
|
residual = skip_layer(x, runtime_gain=np.sqrt(0.5), impl=impl) |
|
layer = getattr(self, f'layer{layer_idx}') |
|
x, style = layer(x, |
|
wp[:, layer_idx], |
|
noise_mode=noise_mode, |
|
fused_modulate=fused_modulate, |
|
impl=impl) |
|
results[f'style{layer_idx}'] = style |
|
layer = getattr(self, f'layer{layer_idx + 1}') |
|
x, style = layer(x, |
|
wp[:, layer_idx + 1], |
|
runtime_gain=np.sqrt(0.5), |
|
noise_mode=noise_mode, |
|
fused_modulate=fused_modulate, |
|
impl=impl) |
|
results[f'style{layer_idx + 1}'] = style |
|
x = x + residual |
|
output_layer = getattr(self, f'output{layer_idx // 2 + 1}') |
|
image, style = output_layer(x, |
|
wp[:, layer_idx + 2], |
|
fused_modulate=fused_modulate, |
|
impl=impl) |
|
image = image.to(torch.float32) |
|
            results[f'output_style{layer_idx // 2 + 1}'] = style
|
|
|
if self.final_tanh: |
|
image = torch.tanh(image) |
|
results['image'] = image |
|
return results |
|
|
|
|
|
class PixelNormLayer(nn.Module): |
|
"""Implements pixel-wise feature vector normalization layer.""" |
|
|
|
def __init__(self, dim, eps): |
|
super().__init__() |
|
self.dim = dim |
|
self.eps = eps |
|
|
|
def extra_repr(self): |
|
return f'dim={self.dim}, epsilon={self.eps}' |
|
|
|
def forward(self, x): |
|
scale = (x.square().mean(dim=self.dim, keepdim=True) + self.eps).rsqrt() |
|
return x * scale |
|
|
|
|
|
class InputLayer(nn.Module): |
|
"""Implements the input layer to start convolution with. |
|
|
|
    Basically, this block starts from a const input, which has shape

    `(channels, init_res, init_res)`.
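
    The constant is shared across samples and simply repeated along the
    batch dimension in `forward`; the input `w` is only used to infer the
    batch size.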
|
""" |
|
|
|
def __init__(self, init_res, channels): |
|
super().__init__() |
|
self.const = nn.Parameter(torch.randn(1, channels, init_res, init_res)) |
|
|
|
def forward(self, w): |
|
x = self.const.repeat(w.shape[0], 1, 1, 1) |
|
return x |
|
|
|
|
|
class ConvLayer(nn.Module): |
|
"""Implements the convolutional layer. |
|
|
|
If upsampling is needed (i.e., `scale_factor = 2`), the feature map will |
|
be filtered with `filter_kernel` after convolution. This layer will only be |
|
    used for the skip connections in the `resnet` architecture.
|
""" |
|
|
|
def __init__(self, |
|
in_channels, |
|
out_channels, |
|
kernel_size, |
|
add_bias, |
|
scale_factor, |
|
filter_kernel, |
|
use_wscale, |
|
wscale_gain, |
|
lr_mul, |
|
activation_type, |
|
conv_clamp): |
|
"""Initializes with layer settings. |
|
|
|
Args: |
|
in_channels: Number of channels of the input tensor. |
|
out_channels: Number of channels of the output tensor. |
|
kernel_size: Size of the convolutional kernels. |
|
add_bias: Whether to add bias onto the convolutional result. |
|
scale_factor: Scale factor for upsampling. |
|
filter_kernel: Kernel used for filtering. |
|
use_wscale: Whether to use weight scaling. |
|
wscale_gain: Gain factor for weight scaling. |
|
            lr_mul: Learning rate multiplier for both weight and bias.
|
activation_type: Type of activation. |
|
conv_clamp: A threshold to clamp the output of convolution layers to |
|
avoid overflow under FP16 training. |
|
""" |
|
super().__init__() |
|
self.in_channels = in_channels |
|
self.out_channels = out_channels |
|
self.kernel_size = kernel_size |
|
self.add_bias = add_bias |
|
self.scale_factor = scale_factor |
|
self.filter_kernel = filter_kernel |
|
self.use_wscale = use_wscale |
|
self.wscale_gain = wscale_gain |
|
self.lr_mul = lr_mul |
|
self.activation_type = activation_type |
|
self.conv_clamp = conv_clamp |
|
|
|
weight_shape = (out_channels, in_channels, kernel_size, kernel_size) |
|
fan_in = kernel_size * kernel_size * in_channels |
|
wscale = wscale_gain / np.sqrt(fan_in) |
|
        # Equalized learning rate trick: when `use_wscale` is enabled, the
        # weight is stored unscaled and multiplied by `wscale` at runtime;
        # otherwise, the scaling is baked into the initialization.
        if use_wscale:
|
self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul) |
|
self.wscale = wscale * lr_mul |
|
else: |
|
self.weight = nn.Parameter( |
|
torch.randn(*weight_shape) * wscale / lr_mul) |
|
self.wscale = lr_mul |
|
|
|
if add_bias: |
|
self.bias = nn.Parameter(torch.zeros(out_channels)) |
|
self.bscale = lr_mul |
|
else: |
|
self.bias = None |
|
self.act_gain = bias_act.activation_funcs[activation_type].def_gain |
|
|
|
if scale_factor > 1: |
|
assert filter_kernel is not None |
|
self.register_buffer( |
|
'filter', upfirdn2d.setup_filter(filter_kernel)) |
|
fh, fw = self.filter.shape |
|
self.filter_padding = ( |
|
kernel_size // 2 + (fw + scale_factor - 1) // 2, |
|
kernel_size // 2 + (fw - scale_factor) // 2, |
|
kernel_size // 2 + (fh + scale_factor - 1) // 2, |
|
kernel_size // 2 + (fh - scale_factor) // 2) |
|
|
|
def extra_repr(self): |
|
return (f'in_ch={self.in_channels}, ' |
|
f'out_ch={self.out_channels}, ' |
|
f'ksize={self.kernel_size}, ' |
|
f'wscale_gain={self.wscale_gain:.3f}, ' |
|
f'bias={self.add_bias}, ' |
|
f'lr_mul={self.lr_mul:.3f}, ' |
|
f'upsample={self.scale_factor}, ' |
|
f'upsample_filter={self.filter_kernel}, ' |
|
f'act={self.activation_type}, ' |
|
f'clamp={self.conv_clamp}') |
|
|
|
def forward(self, x, runtime_gain=1.0, impl='cuda'): |
|
dtype = x.dtype |
|
|
|
weight = self.weight |
|
if self.wscale != 1.0: |
|
weight = weight * self.wscale |
|
bias = None |
|
if self.bias is not None: |
|
bias = self.bias.to(dtype) |
|
if self.bscale != 1.0: |
|
bias = bias * self.bscale |
|
|
|
if self.scale_factor == 1: |
|
padding = self.kernel_size // 2 |
|
x = conv2d_gradfix.conv2d( |
|
x, weight.to(dtype), stride=1, padding=padding, impl=impl) |
|
else: |
|
up = self.scale_factor |
|
f = self.filter |
|
|
|
if self.kernel_size == 1: |
|
padding = self.filter_padding |
|
x = conv2d_gradfix.conv2d( |
|
x, weight.to(dtype), stride=1, padding=0, impl=impl) |
|
x = upfirdn2d.upfirdn2d( |
|
x, f, up=up, padding=padding, gain=up ** 2, impl=impl) |
|
|
|
else: |
|
|
|
|
|
                # Fused upsampling: run a transposed convolution with
                # adjusted padding, then apply the FIR filter.
                px0, px1, py0, py1 = self.filter_padding
|
kh, kw = weight.shape[2:] |
|
px0 = px0 - (kw - 1) |
|
px1 = px1 - (kw - up) |
|
py0 = py0 - (kh - 1) |
|
py1 = py1 - (kh - up) |
|
pxt = max(min(-px0, -px1), 0) |
|
pyt = max(min(-py0, -py1), 0) |
|
weight = weight.transpose(0, 1) |
|
padding = (pyt, pxt) |
|
x = conv2d_gradfix.conv_transpose2d( |
|
x, weight.to(dtype), stride=up, padding=padding, impl=impl) |
|
padding = (px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt) |
|
x = upfirdn2d.upfirdn2d( |
|
x, f, up=1, padding=padding, gain=up ** 2, impl=impl) |
|
|
|
act_gain = self.act_gain * runtime_gain |
|
act_clamp = None |
|
if self.conv_clamp is not None: |
|
act_clamp = self.conv_clamp * runtime_gain |
|
x = bias_act.bias_act(x, bias, |
|
act=self.activation_type, |
|
gain=act_gain, |
|
clamp=act_clamp, |
|
impl=impl) |
|
|
|
assert x.dtype == dtype |
|
return x |
|
|
|
|
|
class ModulateConvLayer(nn.Module): |
|
"""Implements the convolutional layer with style modulation.""" |
|
|
|
def __init__(self, |
|
in_channels, |
|
out_channels, |
|
resolution, |
|
w_dim, |
|
kernel_size, |
|
add_bias, |
|
scale_factor, |
|
filter_kernel, |
|
demodulate, |
|
use_wscale, |
|
wscale_gain, |
|
lr_mul, |
|
noise_type, |
|
activation_type, |
|
conv_clamp, |
|
eps): |
|
"""Initializes with layer settings. |
|
|
|
Args: |
|
in_channels: Number of channels of the input tensor. |
|
out_channels: Number of channels of the output tensor. |
|
resolution: Resolution of the output tensor. |
|
w_dim: Dimension of W space for style modulation. |
|
kernel_size: Size of the convolutional kernels. |
|
add_bias: Whether to add bias onto the convolutional result. |
|
scale_factor: Scale factor for upsampling. |
|
filter_kernel: Kernel used for filtering. |
|
demodulate: Whether to perform style demodulation. |
|
use_wscale: Whether to use weight scaling. |
|
wscale_gain: Gain factor for weight scaling. |
|
            lr_mul: Learning rate multiplier for both weight and bias.
|
noise_type: Type of noise added to the feature map after the |
|
convolution (if needed). Support `none`, `spatial` and |
|
`channel`. |
|
activation_type: Type of activation. |
|
conv_clamp: A threshold to clamp the output of convolution layers to |
|
avoid overflow under FP16 training. |
|
            eps: A small value to avoid division by zero.
|
""" |
|
super().__init__() |
|
|
|
self.in_channels = in_channels |
|
self.out_channels = out_channels |
|
self.resolution = resolution |
|
self.w_dim = w_dim |
|
self.kernel_size = kernel_size |
|
self.add_bias = add_bias |
|
self.scale_factor = scale_factor |
|
self.filter_kernel = filter_kernel |
|
self.demodulate = demodulate |
|
self.use_wscale = use_wscale |
|
self.wscale_gain = wscale_gain |
|
self.lr_mul = lr_mul |
|
self.noise_type = noise_type.lower() |
|
self.activation_type = activation_type |
|
self.conv_clamp = conv_clamp |
|
self.eps = eps |
|
|
|
        # The latent space is `W` by default; see `forward_style` for the
        # Y-Space (Style Space) alternative.
        self.space_of_latent = 'W'
|
|
|
|
|
weight_shape = (out_channels, in_channels, kernel_size, kernel_size) |
|
fan_in = kernel_size * kernel_size * in_channels |
|
wscale = wscale_gain / np.sqrt(fan_in) |
|
if use_wscale: |
|
self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul) |
|
self.wscale = wscale * lr_mul |
|
else: |
|
self.weight = nn.Parameter( |
|
torch.randn(*weight_shape) * wscale / lr_mul) |
|
self.wscale = lr_mul |
|
|
|
|
|
if add_bias: |
|
self.bias = nn.Parameter(torch.zeros(out_channels)) |
|
self.bscale = lr_mul |
|
else: |
|
self.bias = None |
|
self.act_gain = bias_act.activation_funcs[activation_type].def_gain |
|
|
|
|
|
self.style = DenseLayer(in_channels=w_dim, |
|
out_channels=in_channels, |
|
add_bias=True, |
|
init_bias=1.0, |
|
use_wscale=use_wscale, |
|
wscale_gain=wscale_gain, |
|
lr_mul=lr_mul, |
|
activation_type='linear') |
|
|
|
|
|
if self.noise_type != 'none': |
|
self.noise_strength = nn.Parameter(torch.zeros(())) |
|
if self.noise_type == 'spatial': |
|
self.register_buffer( |
|
'noise', torch.randn(1, 1, resolution, resolution)) |
|
elif self.noise_type == 'channel': |
|
self.register_buffer( |
|
'noise', torch.randn(1, out_channels, 1, 1)) |
|
else: |
|
raise NotImplementedError(f'Not implemented noise type: ' |
|
f'`{self.noise_type}`!') |
|
|
|
if scale_factor > 1: |
|
assert filter_kernel is not None |
|
self.register_buffer( |
|
'filter', upfirdn2d.setup_filter(filter_kernel)) |
|
fh, fw = self.filter.shape |
|
self.filter_padding = ( |
|
kernel_size // 2 + (fw + scale_factor - 1) // 2, |
|
kernel_size // 2 + (fw - scale_factor) // 2, |
|
kernel_size // 2 + (fh + scale_factor - 1) // 2, |
|
kernel_size // 2 + (fh - scale_factor) // 2) |
|
|
|
def extra_repr(self): |
|
return (f'in_ch={self.in_channels}, ' |
|
f'out_ch={self.out_channels}, ' |
|
f'ksize={self.kernel_size}, ' |
|
f'wscale_gain={self.wscale_gain:.3f}, ' |
|
f'bias={self.add_bias}, ' |
|
f'lr_mul={self.lr_mul:.3f}, ' |
|
f'upsample={self.scale_factor}, ' |
|
f'upsample_filter={self.filter_kernel}, ' |
|
f'demodulate={self.demodulate}, ' |
|
f'noise_type={self.noise_type}, ' |
|
f'act={self.activation_type}, ' |
|
f'clamp={self.conv_clamp}') |
|
|
|
def forward_style(self, w, impl='cuda'): |
|
"""Gets style code from the given input. |
|
|
|
More specifically, if the input is from W-Space, it will be projected by |
|
an affine transformation. If it is from the Style Space (Y-Space), no |
|
operation is required. |
|
|
|
        NOTE: For codes from Y-Space, we use slicing to make sure the

        dimension is correct, in case the code is padded before being fed

        into this layer.
|
""" |
|
space_of_latent = self.space_of_latent.upper() |
|
if space_of_latent == 'W': |
|
if w.ndim != 2 or w.shape[1] != self.w_dim: |
|
                raise ValueError(f'The input tensor should have shape '

                                 f'[batch_size, w_dim], where '

                                 f'`w_dim` equals {self.w_dim}!\n'

                                 f'But `{w.shape}` was received!')
|
style = self.style(w, impl=impl) |
|
elif space_of_latent == 'Y': |
|
            if w.ndim != 2 or w.shape[1] < self.in_channels:

                raise ValueError(f'The input tensor should have shape '

                                 f'[batch_size, y_dim], where '

                                 f'`y_dim` is at least {self.in_channels}!\n'

                                 f'But `{w.shape}` was received!')
|
style = w[:, :self.in_channels] |
|
else: |
|
raise NotImplementedError(f'Not implemented `space_of_latent`: ' |
|
f'`{space_of_latent}`!') |
|
return style |
|
|
|
def forward(self, |
|
x, |
|
w, |
|
runtime_gain=1.0, |
|
noise_mode='const', |
|
fused_modulate=False, |
|
impl='cuda'): |
|
dtype = x.dtype |
|
N, C, H, W = x.shape |
|
|
|
        # Fused modulation (i.e., scaling the weight instead of the feature
        # map) is only enabled at inference time, and only when numerically
        # safe.
        fused_modulate = (fused_modulate and
|
not self.training and |
|
(dtype == torch.float32 or N == 1)) |
|
|
|
weight = self.weight |
|
out_ch, in_ch, kh, kw = weight.shape |
|
assert in_ch == C |
|
|
|
|
|
        style = self.forward_style(w, impl=impl)

        # With demodulation enabled, any constant scaling of the weight and
        # style cancels out in the demodulation coefficient, so the runtime
        # weight scale only needs to be applied when demodulation is
        # disabled.
        if not self.demodulate:
|
_style = style * self.wscale |
|
else: |
|
_style = style |
|
|
|
|
|
noise = None |
|
noise_mode = noise_mode.lower() |
|
        # Prepare the layer-wise noise (if any).
        if self.noise_type != 'none' and noise_mode != 'none':
|
if noise_mode == 'random': |
|
noise = torch.randn((N, *self.noise.shape[1:]), device=x.device) |
|
elif noise_mode == 'const': |
|
noise = self.noise |
|
else: |
|
raise ValueError(f'Unknown noise mode `{noise_mode}`!') |
|
noise = (noise * self.noise_strength).to(dtype) |
|
|
|
|
|
        # For FP16 computation, pre-normalize the weight and style to avoid
        # overflow when accumulating the demodulation coefficient.
        if dtype == torch.float16 and self.demodulate:
|
weight_max = weight.norm(float('inf'), dim=(1, 2, 3), keepdim=True) |
|
weight = weight * (self.wscale / weight_max) |
|
style_max = _style.norm(float('inf'), dim=1, keepdim=True) |
|
_style = _style / style_max |
|
|
|
        # Modulate (and optionally demodulate) the convolutional weight.
        # `decoef` is the reciprocal L2 norm of the modulated weight over the
        # input-channel and spatial axes of each output channel.
        if self.demodulate or fused_modulate:
|
_weight = weight.unsqueeze(0) |
|
_weight = _weight * _style.reshape(N, 1, in_ch, 1, 1) |
|
if self.demodulate: |
|
decoef = (_weight.square().sum(dim=(2, 3, 4)) + self.eps).rsqrt() |
|
if self.demodulate and fused_modulate: |
|
_weight = _weight * decoef.reshape(N, out_ch, 1, 1, 1) |
|
|
|
        # When not fused, scale the input feature map by the style and defer
        # demodulation until after the convolution; when fused, fold the
        # per-sample weight into a grouped convolution.
        if not fused_modulate:
|
x = x * _style.to(dtype).reshape(N, in_ch, 1, 1) |
|
w = weight.to(dtype) |
|
groups = 1 |
|
else: |
|
x = x.reshape(1, N * in_ch, H, W) |
|
w = _weight.reshape(N * out_ch, in_ch, kh, kw).to(dtype) |
|
groups = N |
|
|
|
if self.scale_factor == 1: |
|
up = 1 |
|
padding = self.kernel_size // 2 |
|
x = conv2d_gradfix.conv2d( |
|
x, w, stride=1, padding=padding, groups=groups, impl=impl) |
|
else: |
|
up = self.scale_factor |
|
f = self.filter |
|
|
|
if self.kernel_size == 1: |
|
padding = self.filter_padding |
|
x = conv2d_gradfix.conv2d( |
|
x, w, stride=1, padding=0, groups=groups, impl=impl) |
|
x = upfirdn2d.upfirdn2d( |
|
x, f, up=up, padding=padding, gain=up ** 2, impl=impl) |
|
|
|
else: |
|
|
|
|
|
                # Fused upsampling via transposed convolution (as in
                # `ConvLayer`), with optional grouping for the fused path.
                px0, px1, py0, py1 = self.filter_padding
|
px0 = px0 - (kw - 1) |
|
px1 = px1 - (kw - up) |
|
py0 = py0 - (kh - 1) |
|
py1 = py1 - (kh - up) |
|
pxt = max(min(-px0, -px1), 0) |
|
pyt = max(min(-py0, -py1), 0) |
|
if groups == 1: |
|
w = w.transpose(0, 1) |
|
else: |
|
w = w.reshape(N, out_ch, in_ch, kh, kw) |
|
w = w.transpose(1, 2) |
|
w = w.reshape(N * in_ch, out_ch, kh, kw) |
|
padding = (pyt, pxt) |
|
x = conv2d_gradfix.conv_transpose2d( |
|
x, w, stride=up, padding=padding, groups=groups, impl=impl) |
|
padding = (px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt) |
|
x = upfirdn2d.upfirdn2d( |
|
x, f, up=1, padding=padding, gain=up ** 2, impl=impl) |
|
|
|
        # Apply the deferred demodulation and add noise. `fma` fuses the
        # multiply-add into a single op when possible.
        if not fused_modulate:
|
if self.demodulate: |
|
decoef = decoef.to(dtype).reshape(N, out_ch, 1, 1) |
|
if self.demodulate and noise is not None: |
|
x = fma.fma(x, decoef, noise, impl=impl) |
|
else: |
|
if self.demodulate: |
|
x = x * decoef |
|
if noise is not None: |
|
x = x + noise |
|
else: |
|
x = x.reshape(N, out_ch, H * up, W * up) |
|
if noise is not None: |
|
x = x + noise |
|
|
|
bias = None |
|
if self.bias is not None: |
|
bias = self.bias.to(dtype) |
|
if self.bscale != 1.0: |
|
bias = bias * self.bscale |
|
|
|
if self.activation_type == 'linear': |
|
x = bias_act.bias_act( |
|
x, bias, act='linear', clamp=self.conv_clamp, impl=impl) |
|
else: |
|
act_gain = self.act_gain * runtime_gain |
|
act_clamp = None |
|
if self.conv_clamp is not None: |
|
act_clamp = self.conv_clamp * runtime_gain |
|
x = bias_act.bias_act(x, bias, |
|
act=self.activation_type, |
|
gain=act_gain, |
|
clamp=act_clamp, |
|
impl=impl) |
|
|
|
assert x.dtype == dtype |
|
assert style.dtype == torch.float32 |
|
return x, style |
|
|
|
|
|
class DenseLayer(nn.Module): |
|
"""Implements the dense layer.""" |
|
|
|
def __init__(self, |
|
in_channels, |
|
out_channels, |
|
add_bias, |
|
init_bias, |
|
use_wscale, |
|
wscale_gain, |
|
lr_mul, |
|
activation_type): |
|
"""Initializes with layer settings. |
|
|
|
Args: |
|
in_channels: Number of channels of the input tensor. |
|
out_channels: Number of channels of the output tensor. |
|
add_bias: Whether to add bias onto the fully-connected result. |
|
init_bias: The initial bias value before training. |
|
use_wscale: Whether to use weight scaling. |
|
wscale_gain: Gain factor for weight scaling. |
|
            lr_mul: Learning rate multiplier for both weight and bias.
|
activation_type: Type of activation. |
|
""" |
|
super().__init__() |
|
self.in_channels = in_channels |
|
self.out_channels = out_channels |
|
self.add_bias = add_bias |
|
self.init_bias = init_bias |
|
self.use_wscale = use_wscale |
|
self.wscale_gain = wscale_gain |
|
self.lr_mul = lr_mul |
|
self.activation_type = activation_type |
|
|
|
weight_shape = (out_channels, in_channels) |
|
wscale = wscale_gain / np.sqrt(in_channels) |
|
if use_wscale: |
|
self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul) |
|
self.wscale = wscale * lr_mul |
|
else: |
|
self.weight = nn.Parameter( |
|
torch.randn(*weight_shape) * wscale / lr_mul) |
|
self.wscale = lr_mul |
|
|
|
if add_bias: |
|
init_bias = np.float32(init_bias) / lr_mul |
|
self.bias = nn.Parameter(torch.full([out_channels], init_bias)) |
|
self.bscale = lr_mul |
|
else: |
|
self.bias = None |
|
|
|
def extra_repr(self): |
|
return (f'in_ch={self.in_channels}, ' |
|
f'out_ch={self.out_channels}, ' |
|
f'wscale_gain={self.wscale_gain:.3f}, ' |
|
f'bias={self.add_bias}, ' |
|
f'init_bias={self.init_bias}, ' |
|
f'lr_mul={self.lr_mul:.3f}, ' |
|
f'act={self.activation_type}') |
|
|
|
def forward(self, x, impl='cuda'): |
|
dtype = x.dtype |
|
|
|
if x.ndim != 2: |
|
x = x.flatten(start_dim=1) |
|
|
|
weight = self.weight.to(dtype) * self.wscale |
|
bias = None |
|
if self.bias is not None: |
|
bias = self.bias.to(dtype) |
|
if self.bscale != 1.0: |
|
bias = bias * self.bscale |
|
|
|
|
|
        # Fast path: fuse the bias addition into `addmm` for a purely linear
        # layer.
        if self.activation_type == 'linear' and bias is not None:
|
x = torch.addmm(bias.unsqueeze(0), x, weight.t()) |
|
else: |
|
x = x.matmul(weight.t()) |
|
x = bias_act.bias_act(x, bias, act=self.activation_type, impl=impl) |
|
|
|
assert x.dtype == dtype |
|
return x |
|
|
|
|
|
|