|
|
|
"""Contains the implementation of generator described in VolumeGAN. |
|
|
|
Paper: https://arxiv.org/pdf/2112.10759.pdf |
|
""" |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from einops import rearrange |
|
|
|
from .stylegan2_generator import MappingNetwork |
|
from .stylegan2_generator import ModulateConvLayer |
|
from .stylegan2_generator import ConvLayer |
|
from .stylegan2_generator import DenseLayer |
|
from third_party.stylegan2_official_ops import upfirdn2d |
|
from .rendering import Renderer |
|
from .rendering import FeatureExtractor |
|
from .utils.ops import all_gather |
|
|
|
|
|
class VolumeGANGenerator(nn.Module): |
|
"""Defines the generator network in VoumeGAN.""" |
|
|
|
def __init__( |
|
self, |
|
|
|
z_dim=512, |
|
w_dim=512, |
|
repeat_w=True, |
|
normalize_z=True, |
|
mapping_layers=8, |
|
mapping_fmaps=512, |
|
mapping_use_wscale=True, |
|
mapping_wscale_gain=1.0, |
|
mapping_lr_mul=0.01, |
|
|
|
label_dim=0, |
|
embedding_dim=512, |
|
embedding_bias=True, |
|
embedding_use_wscale=True, |
|
        embedding_wscale_gain=1.0,
|
embedding_lr_mul=1.0, |
|
normalize_embedding=True, |
|
normalize_embedding_latent=False, |
|
|
|
resolution=-1, |
|
nerf_res=32, |
|
image_channels=3, |
|
final_tanh=False, |
|
demodulate=True, |
|
use_wscale=True, |
|
wscale_gain=1.0, |
|
lr_mul=1.0, |
|
noise_type='spatial', |
|
fmaps_base=32 << 10, |
|
fmaps_max=512, |
|
filter_kernel=(1, 3, 3, 1), |
|
conv_clamp=None, |
|
eps=1e-8, |
|
rgb_init_res_out=True, |
|
|
|
fv_cfg=dict(feat_res=32, |
|
init_res=4, |
|
base_channels=256, |
|
output_channels=32, |
|
w_dim=512), |
|
|
|
embed_cfg=dict(input_dim=3, max_freq_log2=10 - 1, N_freqs=10), |
|
|
|
fg_cfg=dict(num_layers=4, hidden_dim=256, activation_type='lrelu'), |
|
bg_cfg=None, |
|
out_dim=512, |
|
|
|
rendering_kwargs={}): |
|
|
|
super().__init__() |
|
|
|
self.z_dim = z_dim |
|
self.w_dim = w_dim |
|
self.repeat_w = repeat_w |
|
self.normalize_z = normalize_z |
|
self.mapping_layers = mapping_layers |
|
self.mapping_fmaps = mapping_fmaps |
|
self.mapping_use_wscale = mapping_use_wscale |
|
self.mapping_wscale_gain = mapping_wscale_gain |
|
self.mapping_lr_mul = mapping_lr_mul |
|
|
|
self.latent_dim = (z_dim,) |
|
self.label_size = label_dim |
|
self.label_dim = label_dim |
|
self.embedding_dim = embedding_dim |
|
self.embedding_bias = embedding_bias |
|
self.embedding_use_wscale = embedding_use_wscale |
|
        self.embedding_wscale_gain = embedding_wscale_gain
|
self.embedding_lr_mul = embedding_lr_mul |
|
self.normalize_embedding = normalize_embedding |
|
self.normalize_embedding_latent = normalize_embedding_latent |
|
|
|
self.resolution = resolution |
|
self.nerf_res = nerf_res |
|
self.image_channels = image_channels |
|
self.final_tanh = final_tanh |
|
self.demodulate = demodulate |
|
self.use_wscale = use_wscale |
|
self.wscale_gain = wscale_gain |
|
self.lr_mul = lr_mul |
|
self.noise_type = noise_type.lower() |
|
self.fmaps_base = fmaps_base |
|
self.fmaps_max = fmaps_max |
|
self.filter_kernel = filter_kernel |
|
self.conv_clamp = conv_clamp |
|
self.eps = eps |
|
|
|
self.num_nerf_layers = fg_cfg['num_layers'] |
|
self.num_cnn_layers = int(np.log2(resolution // nerf_res * 2)) * 2 |
|
self.num_layers = self.num_nerf_layers + self.num_cnn_layers |
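        # For example, with the default fg_cfg (num_layers=4), resolution=256
        # and nerf_res=32: num_cnn_layers = int(log2(256 // 32 * 2)) * 2 = 8,
        # so the mapping network must produce 4 + 8 = 12 w codes in total.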
|
|
|
|
|
if self.repeat_w: |
|
self.register_buffer('w_avg', torch.zeros(w_dim)) |
|
else: |
|
self.register_buffer('w_avg', torch.zeros(self.num_layers * w_dim)) |
|
|
|
|
|
self.mapping = MappingNetwork( |
|
input_dim=z_dim, |
|
output_dim=w_dim, |
|
num_outputs=self.num_layers, |
|
repeat_output=repeat_w, |
|
normalize_input=normalize_z, |
|
num_layers=mapping_layers, |
|
hidden_dim=mapping_fmaps, |
|
use_wscale=mapping_use_wscale, |
|
wscale_gain=mapping_wscale_gain, |
|
lr_mul=mapping_lr_mul, |
|
label_dim=label_dim, |
|
embedding_dim=embedding_dim, |
|
embedding_bias=embedding_bias, |
|
embedding_use_wscale=embedding_use_wscale, |
|
            embedding_wscale_gian=embedding_wscale_gain,  # kwarg spelling follows MappingNetwork's API.
|
embedding_lr_mul=embedding_lr_mul, |
|
normalize_embedding=normalize_embedding, |
|
normalize_embedding_latent=normalize_embedding_latent, |
|
eps=eps) |
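
        # The modules below implement the VolumeGAN pipeline: a template
        # feature volume is generated from w, queried at sampled points by
        # the feature extractor, decoded into density and color features by
        # a NeRF-like MLP plus FC heads, volume-rendered by the renderer,
        # and finally upsampled to the target resolution by a 2D CNN
        # (the post neural renderer).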
|
|
|
|
|
self.renderer = Renderer() |
|
|
|
|
|
self.ref_representation_generator = FeatureVolume(**fv_cfg) |
|
|
|
|
|
self.position_encoder = PositionEncoder(**embed_cfg) |
|
|
|
|
|
self.feature_extractor = FeatureExtractor(ref_mode='feature_volume') |
|
|
|
|
|
self.post_module = NeRFMLPNetwork(input_dim=self.position_encoder.out_dim + |
|
fv_cfg['output_channels'], |
|
fg_cfg=fg_cfg, |
|
bg_cfg=bg_cfg) |
|
|
|
|
|
self.fc_head = FCHead(fg_cfg=fg_cfg, bg_cfg=bg_cfg, out_dim=out_dim) |
|
|
|
|
|
self.post_neural_renderer = PostNeuralRendererNetwork( |
|
resolution=resolution, |
|
init_res=nerf_res, |
|
w_dim=w_dim, |
|
image_channels=image_channels, |
|
final_tanh=final_tanh, |
|
demodulate=demodulate, |
|
use_wscale=use_wscale, |
|
wscale_gain=wscale_gain, |
|
lr_mul=lr_mul, |
|
noise_type=noise_type, |
|
fmaps_base=fmaps_base, |
|
filter_kernel=filter_kernel, |
|
fmaps_max=fmaps_max, |
|
conv_clamp=conv_clamp, |
|
eps=eps, |
|
rgb_init_res_out=rgb_init_res_out) |
|
|
|
|
|
self.rendering_kwargs = rendering_kwargs |
|
|
|
|
|
|
|
self.cur_to_official_part_mapping = { |
|
'w_avg': 'w_avg', |
|
'mapping': 'mapping', |
|
'ref_representation_generator': 'nerfmlp.fv', |
|
'post_module.fg_mlp': 'nerfmlp.fg_mlps', |
|
'fc_head.fg_sigma_head': 'nerfmlp.fg_density', |
|
'fc_head.fg_rgb_head': 'nerfmlp.fg_color', |
|
'post_neural_renderer': 'synthesis' |
|
} |
|
|
|
|
|
if self.rendering_kwargs.get('debug_mode', False): |
|
self.set_weights_from_official( |
|
rendering_kwargs.get('cur_state', None), |
|
rendering_kwargs.get('official_state', None)) |
|
|
|
    def get_cur_to_official_full_mapping(self, keys_cur):
        """Expands the module-level name mapping to full parameter names."""
        cur_to_official_full_mapping = {}
|
for key, val in self.cur_to_official_part_mapping.items(): |
|
for key_cur_full in keys_cur: |
|
if key in key_cur_full: |
|
sub_key = key_cur_full.replace(key, '') |
|
cur_to_official_full_mapping[key + sub_key] = val + sub_key |
|
return cur_to_official_full_mapping |
|
|
|
    def set_weights_from_official(self, cur_state, official_state):
        """Loads parameters from an official checkpoint via the name mapping.

        Note that only parameters are copied; buffers (e.g. `w_avg`) keep
        their current values.
        """
        keys_cur = cur_state['models']['generator_smooth'].keys()
|
self.cur_to_official_full_mapping = ( |
|
self.get_cur_to_official_full_mapping(keys_cur)) |
|
for name, param in self.named_parameters(): |
|
param.data = (official_state['models']['generator_smooth'][ |
|
self.cur_to_official_full_mapping[name]]) |
|
|
|
def forward( |
|
self, |
|
z, |
|
label=None, |
|
lod=None, |
|
w_moving_decay=None, |
|
sync_w_avg=False, |
|
style_mixing_prob=None, |
|
trunc_psi=None, |
|
trunc_layers=None, |
|
noise_mode='const', |
|
fused_modulate=False, |
|
impl='cuda', |
|
fp16_res=None, |
|
): |
|
mapping_results = self.mapping(z, label, impl=impl) |
|
w = mapping_results['w'] |
|
lod = self.post_neural_renderer.lod.item() if lod is None else lod |
|
|
|
if self.training and w_moving_decay is not None: |
|
if sync_w_avg: |
|
batch_w_avg = all_gather(w.detach()).mean(dim=0) |
|
else: |
|
batch_w_avg = w.detach().mean(dim=0) |
|
self.w_avg.copy_(batch_w_avg.lerp(self.w_avg, w_moving_decay)) |
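            # Equivalent to:
            # w_avg = decay * w_avg + (1 - decay) * batch_w_avg,
            # with decay = w_moving_decay.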
|
|
|
wp = mapping_results['wp'] |
|
|
|
if self.training and style_mixing_prob is not None: |
|
if np.random.uniform() < style_mixing_prob: |
|
new_z = torch.randn_like(z) |
|
new_wp = self.mapping(new_z, label, impl=impl)['wp'] |
|
current_layers = self.num_layers |
|
if current_layers > self.num_nerf_layers: |
|
mixing_cutoff = np.random.randint(self.num_nerf_layers, |
|
current_layers) |
|
wp[:, mixing_cutoff:] = new_wp[:, mixing_cutoff:] |
|
|
|
if not self.training: |
|
trunc_psi = 1.0 if trunc_psi is None else trunc_psi |
|
trunc_layers = 0 if trunc_layers is None else trunc_layers |
|
if trunc_psi < 1.0 and trunc_layers > 0: |
|
w_avg = self.w_avg.reshape(1, -1, self.w_dim)[:, :trunc_layers] |
|
wp[:, :trunc_layers] = w_avg.lerp( |
|
wp[:, :trunc_layers], trunc_psi) |
|
|
|
        # Split the w codes: the leading layers condition the NeRF MLP and
        # the rest condition the 2D post neural renderer.
        nerf_w = wp[:, :self.num_nerf_layers]
        cnn_w = wp[:, self.num_nerf_layers:]
|
|
|
feature_volume = self.ref_representation_generator(nerf_w) |
|
|
|
rendering_results = self.renderer( |
|
wp=nerf_w, |
|
feature_extractor=self.feature_extractor, |
|
rendering_options=self.rendering_kwargs, |
|
position_encoder=self.position_encoder, |
|
ref_representation=feature_volume, |
|
post_module=self.post_module, |
|
fc_head=self.fc_head) |
|
|
|
feature2d = rendering_results['composite_rgb'] |
|
feature2d = feature2d.reshape(feature2d.shape[0], self.nerf_res, |
|
self.nerf_res, -1).permute(0, 3, 1, 2) |
|
|
|
final_results = self.post_neural_renderer( |
|
feature2d, |
|
cnn_w, |
|
            lod=lod,
|
noise_mode=noise_mode, |
|
fused_modulate=fused_modulate, |
|
impl=impl, |
|
fp16_res=fp16_res) |
|
|
|
return {**mapping_results, **final_results} |
|
|
|
|
|
class PositionEncoder(nn.Module): |
|
"""Implements the class for positional encoding.""" |
|
|
|
def __init__(self, |
|
input_dim, |
|
max_freq_log2, |
|
N_freqs, |
|
log_sampling=True, |
|
include_input=True, |
|
periodic_fns=(torch.sin, torch.cos)): |
|
"""Initializes with basic settings. |
|
|
|
Args: |
|
input_dim: Dimension of input to be embedded. |
|
max_freq_log2: `log2` of max freq; min freq is 1 by default. |
|
N_freqs: Number of frequency bands. |
|
            log_sampling: If True, frequency bands are linearly sampled in
|
log-space. |
|
include_input: If True, raw input is included in the embedding. |
|
Defaults to True. |
|
periodic_fns: Periodic functions used to embed input. |
|
Defaults to (torch.sin, torch.cos). |
|
""" |
|
super().__init__() |
|
|
|
self.input_dim = input_dim |
|
self.include_input = include_input |
|
self.periodic_fns = periodic_fns |
|
|
|
self.out_dim = 0 |
|
if self.include_input: |
|
self.out_dim += self.input_dim |
|
|
|
self.out_dim += self.input_dim * N_freqs * len(self.periodic_fns) |
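        # With the default embed_cfg (input_dim=3, N_freqs=10, sin & cos):
        # out_dim = 3 + 3 * 10 * 2 = 63.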
|
|
|
if log_sampling: |
|
self.freq_bands = 2.**torch.linspace(0., max_freq_log2, N_freqs) |
|
else: |
|
self.freq_bands = torch.linspace(2.**0., 2.**max_freq_log2, |
|
N_freqs) |
|
|
|
self.freq_bands = self.freq_bands.numpy().tolist() |
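        # E.g., max_freq_log2=9, N_freqs=10 and log_sampling=True give the
        # bands [1, 2, 4, ..., 512].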
|
|
|
def forward(self, input): |
|
assert (input.shape[-1] == self.input_dim) |
|
|
|
out = [] |
|
if self.include_input: |
|
out.append(input) |
|
|
|
        for freq in self.freq_bands:
            for p_fn in self.periodic_fns:
                out.append(p_fn(input * freq))
|
out = torch.cat(out, dim=-1) |
|
|
|
assert (out.shape[-1] == self.out_dim) |
|
|
|
return out |
|
|
|
|
|
class FeatureVolume(nn.Module): |
|
"""Defines feature volume in VolumeGAN.""" |
|
|
|
def __init__(self, |
|
feat_res=32, |
|
init_res=4, |
|
base_channels=256, |
|
output_channels=32, |
|
w_dim=512, |
|
**kwargs): |
|
super().__init__() |
|
self.num_stages = int(np.log2(feat_res // init_res)) + 1 |
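        # E.g., feat_res=32 and init_res=4 give num_stages = log2(8) + 1 = 4
        # stages, with the volume upsampled 4 -> 8 -> 16 -> 32.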
|
|
|
self.const = nn.Parameter( |
|
torch.ones(1, base_channels, init_res, init_res, init_res)) |
|
inplanes = base_channels |
|
outplanes = base_channels |
|
|
|
self.stage_channels = [] |
|
for i in range(self.num_stages): |
|
conv = nn.Conv3d(inplanes, |
|
outplanes, |
|
kernel_size=(3, 3, 3), |
|
padding=(1, 1, 1)) |
|
self.stage_channels.append(outplanes) |
|
self.add_module(f'layer{i}', conv) |
|
instance_norm = InstanceNormLayer(num_features=outplanes, |
|
affine=False) |
|
|
|
self.add_module(f'instance_norm{i}', instance_norm) |
|
inplanes = outplanes |
|
outplanes = max(outplanes // 2, output_channels) |
|
if i == self.num_stages - 1: |
|
outplanes = output_channels |
|
|
|
self.mapping_network = nn.Linear(w_dim, sum(self.stage_channels) * 2) |
|
self.mapping_network.apply(kaiming_leaky_init) |
|
with torch.no_grad(): |
|
self.mapping_network.weight *= 0.25 |
|
self.upsample = UpsamplingLayer() |
|
self.lrelu = nn.LeakyReLU(negative_slope=0.2) |
|
|
|
def forward(self, w, **kwargs): |
|
if w.ndim == 3: |
|
_w = w[:, 0] |
|
else: |
|
_w = w |
|
scale_shifts = self.mapping_network(_w) |
|
scales = scale_shifts[..., :scale_shifts.shape[-1] // 2] |
|
shifts = scale_shifts[..., scale_shifts.shape[-1] // 2:] |
|
|
|
x = self.const.repeat(w.shape[0], 1, 1, 1, 1) |
|
for idx in range(self.num_stages): |
|
if idx != 0: |
|
x = self.upsample(x) |
|
conv_layer = self.__getattr__(f'layer{idx}') |
|
x = conv_layer(x) |
|
instance_norm = self.__getattr__(f'instance_norm{idx}') |
|
scale = scales[:, |
|
sum(self.stage_channels[:idx] |
|
):sum(self.stage_channels[:idx + 1])] |
|
shift = shifts[:, |
|
sum(self.stage_channels[:idx] |
|
):sum(self.stage_channels[:idx + 1])] |
|
scale = scale.view(scale.shape + (1, 1, 1)) |
|
shift = shift.view(shift.shape + (1, 1, 1)) |
|
x = instance_norm(x, weight=scale, bias=shift) |
|
x = self.lrelu(x) |
|
|
|
        # With the default config, the returned feature volume has shape
        # [N, output_channels, feat_res, feat_res, feat_res],
        # i.e. [N, 32, 32, 32, 32].
        return x
|
|
|
|
|
def kaiming_leaky_init(m):
    """Applies Kaiming-normal initialization (for a LeakyReLU with slope
    0.2) to `Linear` modules."""
    classname = m.__class__.__name__
|
if classname.find('Linear') != -1: |
|
torch.nn.init.kaiming_normal_(m.weight, |
|
a=0.2, |
|
mode='fan_in', |
|
nonlinearity='leaky_relu') |
|
|
|
|
|
class InstanceNormLayer(nn.Module): |
|
"""Implements instance normalization layer.""" |
|
|
|
def __init__(self, num_features, epsilon=1e-8, affine=False): |
|
super().__init__() |
|
self.eps = epsilon |
|
self.affine = affine |
|
if self.affine: |
|
self.weight = nn.Parameter(torch.Tensor(1, num_features, 1, 1, 1)) |
|
self.bias = nn.Parameter(torch.Tensor(1, num_features, 1, 1, 1)) |
|
self.weight.data.uniform_() |
|
self.bias.data.zero_() |
|
|
|
def forward(self, x, weight=None, bias=None): |
|
x = x - torch.mean(x, dim=[2, 3, 4], keepdim=True) |
|
norm = torch.sqrt( |
|
torch.mean(x**2, dim=[2, 3, 4], keepdim=True) + self.eps) |
|
x = x / norm |
|
        # Affine parameters must come from exactly one source: external
        # `weight`/`bias` when `affine=False`, or the layer's own parameters
        # when `affine=True`.
        has_external_affine = weight is not None and bias is not None
        assert has_external_affine != self.affine
|
if self.affine: |
|
x = x * self.weight + self.bias |
|
else: |
|
x = x * weight + bias |
|
return x |
|
|
|
|
|
class UpsamplingLayer(nn.Module):
    """Upsamples the input by `scale_factor` with nearest-neighbor
    interpolation."""

    def __init__(self, scale_factor=2):
|
super().__init__() |
|
self.scale_factor = scale_factor |
|
|
|
def forward(self, x): |
|
if self.scale_factor <= 1: |
|
return x |
|
return F.interpolate(x, scale_factor=self.scale_factor, mode='nearest') |
|
|
|
|
|
class NeRFMLPNetwork(nn.Module): |
|
"""Defines class of MLP Network described in VolumeGAN. |
|
|
|
Basically, this class takes in latent codes and point coodinates as input, |
|
and outputs features of each point, which is followed by two fully-connected |
|
layer heads. |
|
""" |
|
|
|
def __init__(self, input_dim, fg_cfg, bg_cfg=None): |
|
super().__init__() |
|
self.fg_mlp = self.build_mlp(input_dim=input_dim, **fg_cfg) |
|
|
|
def build_mlp(self, input_dim, num_layers, hidden_dim, activation_type, |
|
**kwargs): |
|
"""Implements function to build the `MLP`. |
|
|
|
        Note that the `MLP` network consists of a series of
        `ModulateConvLayer` with `kernel_size=1` to simulate fully-connected
        layers. Convolutional layers typically take input of shape
        `[N, C, H, W]`; here the input has shape `[N, C, R*K, 1]`, i.e. all
        sampled points are flattened along one spatial axis so that each
        1x1 convolution acts as a fully-connected layer.
|
""" |
|
default_conv_cfg = dict(resolution=32, |
|
w_dim=512, |
|
kernel_size=1, |
|
add_bias=True, |
|
scale_factor=1, |
|
filter_kernel=None, |
|
demodulate=True, |
|
use_wscale=True, |
|
wscale_gain=1, |
|
lr_mul=1, |
|
noise_type='none', |
|
conv_clamp=None, |
|
eps=1e-8) |
|
mlp_list = nn.ModuleList() |
|
in_ch = input_dim |
|
out_ch = hidden_dim |
|
for _ in range(num_layers): |
|
mlp = ModulateConvLayer(in_channels=in_ch, |
|
out_channels=out_ch, |
|
activation_type=activation_type, |
|
**default_conv_cfg) |
|
mlp_list.append(mlp) |
|
in_ch = out_ch |
|
out_ch = hidden_dim |
|
|
|
return mlp_list |
|
|
|
def forward(self, |
|
pre_point_features, |
|
wp, |
|
points_encoding=None, |
|
fused_modulate=False, |
|
impl='cuda'): |
|
N, C, R_K, _ = points_encoding.shape |
|
x = torch.cat([pre_point_features, points_encoding], dim=1) |
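        # x: [N, C_feat + C_enc, R_K, 1]; with the default configs this is
        # [N, 32 + 63, R_K, 1], matching `input_dim` of the MLP.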
|
|
|
for idx, mlp in enumerate(self.fg_mlp): |
|
if wp.ndim == 3: |
|
_w = wp[:, idx] |
|
else: |
|
_w = wp |
|
x, _ = mlp(x, _w, fused_modulate=fused_modulate, impl=impl) |
|
|
|
return x |
|
|
|
|
|
class FCHead(nn.Module): |
|
"""Defines fully-connected layer head in VolumeGAN to decode `feature` into |
|
`sigma` and `rgb`.""" |
|
|
|
def __init__(self, fg_cfg, bg_cfg=None, out_dim=512): |
|
super().__init__() |
|
self.fg_sigma_head = DenseLayer(in_channels=fg_cfg['hidden_dim'], |
|
out_channels=1, |
|
add_bias=True, |
|
init_bias=0.0, |
|
use_wscale=True, |
|
wscale_gain=1, |
|
lr_mul=1, |
|
activation_type='linear') |
|
self.fg_rgb_head = DenseLayer(in_channels=fg_cfg['hidden_dim'], |
|
out_channels=out_dim, |
|
add_bias=True, |
|
init_bias=0.0, |
|
use_wscale=True, |
|
wscale_gain=1, |
|
lr_mul=1, |
|
activation_type='linear') |
|
|
|
def forward(self, post_point_features, wp=None, dirs=None): |
|
post_point_features = rearrange( |
|
post_point_features, 'N C (R_K) 1 -> (N R_K) C').contiguous() |
|
fg_sigma = self.fg_sigma_head(post_point_features) |
|
fg_rgb = self.fg_rgb_head(post_point_features) |
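        # Shapes: fg_sigma is [N * R_K, 1] (density); fg_rgb is
        # [N * R_K, out_dim], a feature vector that the renderer composites
        # as "color".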
|
|
|
results = {'sigma': fg_sigma, 'rgb': fg_rgb} |
|
|
|
return results |
|
|
|
|
|
class PostNeuralRendererNetwork(nn.Module): |
|
"""Implements the neural renderer in VolumeGAN to render high-resolution |
|
images. |
|
|
|
Basically, this network executes several convolutional layers in sequence. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
resolution, |
|
init_res, |
|
w_dim, |
|
image_channels, |
|
final_tanh, |
|
demodulate, |
|
use_wscale, |
|
wscale_gain, |
|
lr_mul, |
|
noise_type, |
|
fmaps_base, |
|
fmaps_max, |
|
filter_kernel, |
|
conv_clamp, |
|
eps, |
|
rgb_init_res_out=False, |
|
): |
|
super().__init__() |
|
|
|
self.init_res = init_res |
|
self.init_res_log2 = int(np.log2(init_res)) |
|
self.resolution = resolution |
|
self.final_res_log2 = int(np.log2(resolution)) |
|
self.w_dim = w_dim |
|
self.image_channels = image_channels |
|
self.final_tanh = final_tanh |
|
self.demodulate = demodulate |
|
self.use_wscale = use_wscale |
|
self.wscale_gain = wscale_gain |
|
self.lr_mul = lr_mul |
|
self.noise_type = noise_type.lower() |
|
self.fmaps_base = fmaps_base |
|
self.fmaps_max = fmaps_max |
|
self.filter_kernel = filter_kernel |
|
self.conv_clamp = conv_clamp |
|
self.eps = eps |
|
self.rgb_init_res_out = rgb_init_res_out |
|
|
|
self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2 |
|
|
|
self.register_buffer('lod', torch.zeros(())) |
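        # `lod` (level of detail) enables progressive growing: 0 runs all
        # stages (full resolution); larger values skip the finest stages and
        # upsample the coarser output instead.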
|
|
|
for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1): |
|
res = 2**res_log2 |
|
in_channels = self.get_nf(res // 2) |
|
out_channels = self.get_nf(res) |
|
block_idx = res_log2 - self.init_res_log2 |
|
|
|
|
|
if res > init_res: |
|
layer_name = f'layer{2 * block_idx - 1}' |
|
self.add_module( |
|
layer_name, |
|
ModulateConvLayer(in_channels=in_channels, |
|
out_channels=out_channels, |
|
resolution=res, |
|
w_dim=w_dim, |
|
kernel_size=1, |
|
add_bias=True, |
|
scale_factor=2, |
|
filter_kernel=filter_kernel, |
|
demodulate=demodulate, |
|
use_wscale=use_wscale, |
|
wscale_gain=wscale_gain, |
|
lr_mul=lr_mul, |
|
noise_type=noise_type, |
|
activation_type='lrelu', |
|
conv_clamp=conv_clamp, |
|
eps=eps)) |
|
if block_idx == 0: |
|
if self.rgb_init_res_out: |
|
self.rgb_init_res = ConvLayer( |
|
in_channels=out_channels, |
|
out_channels=image_channels, |
|
kernel_size=1, |
|
add_bias=True, |
|
scale_factor=1, |
|
filter_kernel=None, |
|
use_wscale=use_wscale, |
|
wscale_gain=wscale_gain, |
|
lr_mul=lr_mul, |
|
activation_type='linear', |
|
conv_clamp=conv_clamp, |
|
) |
|
continue |
|
|
|
layer_name = f'layer{2 * block_idx}' |
|
self.add_module( |
|
layer_name, |
|
ModulateConvLayer(in_channels=out_channels, |
|
out_channels=out_channels, |
|
resolution=res, |
|
w_dim=w_dim, |
|
kernel_size=1, |
|
add_bias=True, |
|
scale_factor=1, |
|
filter_kernel=None, |
|
demodulate=demodulate, |
|
use_wscale=use_wscale, |
|
wscale_gain=wscale_gain, |
|
lr_mul=lr_mul, |
|
noise_type=noise_type, |
|
activation_type='lrelu', |
|
conv_clamp=conv_clamp, |
|
eps=eps)) |
|
|
|
|
|
layer_name = f'output{block_idx}' |
|
self.add_module( |
|
layer_name, |
|
ModulateConvLayer(in_channels=out_channels, |
|
out_channels=image_channels, |
|
resolution=res, |
|
w_dim=w_dim, |
|
kernel_size=1, |
|
add_bias=True, |
|
scale_factor=1, |
|
filter_kernel=None, |
|
demodulate=False, |
|
use_wscale=use_wscale, |
|
wscale_gain=wscale_gain, |
|
lr_mul=lr_mul, |
|
noise_type='none', |
|
activation_type='linear', |
|
conv_clamp=conv_clamp, |
|
eps=eps)) |
|
|
|
|
|
self.register_buffer('filter', upfirdn2d.setup_filter(filter_kernel)) |
|
|
|
def get_nf(self, res): |
|
"""Gets number of feature maps according to current resolution.""" |
|
return min(self.fmaps_base // res, self.fmaps_max) |
|
|
|
def set_space_of_latent(self, space_of_latent): |
|
"""Sets the space to which the latent code belong. |
|
|
|
Args: |
|
space_of_latent: The space to which the latent code belong. Case |
|
insensitive. Support `W` and `Y`. |
|
""" |
|
space_of_latent = space_of_latent.upper() |
|
for module in self.modules(): |
|
if isinstance(module, ModulateConvLayer): |
|
setattr(module, 'space_of_latent', space_of_latent) |
|
|
|
def forward(self, |
|
x, |
|
wp, |
|
lod=None, |
|
noise_mode='const', |
|
fused_modulate=False, |
|
impl='cuda', |
|
fp16_res=None, |
|
nerf_out=False): |
|
lod = self.lod.item() if lod is None else lod |
|
|
|
results = {} |
|
|
|
|
|
if fp16_res is not None and self.init_res >= fp16_res: |
|
x = x.to(torch.float16) |
|
|
|
for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1): |
|
cur_lod = self.final_res_log2 - res_log2 |
|
block_idx = res_log2 - self.init_res_log2 |
|
|
|
            if block_idx > 0:
                layer_idxs = [2 * block_idx - 1, 2 * block_idx]
            else:
                layer_idxs = [2 * block_idx]
|
|
|
if lod < cur_lod + 1: |
|
for layer_idx in layer_idxs: |
|
if layer_idx == 0: |
|
|
|
if self.rgb_init_res_out: |
|
cur_image = self.rgb_init_res(x, |
|
runtime_gain=1, |
|
impl=impl) |
|
else: |
|
cur_image = x[:, :3] |
|
continue |
|
layer = getattr(self, f'layer{layer_idx}') |
|
x, style = layer( |
|
x, |
|
wp[:, layer_idx], |
|
noise_mode=noise_mode, |
|
fused_modulate=fused_modulate, |
|
impl=impl, |
|
) |
|
results[f'style{layer_idx}'] = style |
|
if layer_idx % 2 == 0: |
|
output_layer = getattr(self, f'output{layer_idx // 2}') |
|
y, style = output_layer( |
|
x, |
|
wp[:, layer_idx + 1], |
|
fused_modulate=fused_modulate, |
|
impl=impl, |
|
) |
|
results[f'output_style{layer_idx // 2}'] = style |
|
if layer_idx == 0: |
|
cur_image = y.to(torch.float32) |
|
else: |
|
if not nerf_out: |
|
cur_image = y.to( |
|
torch.float32) + upfirdn2d.upsample2d( |
|
cur_image, self.filter, impl=impl) |
|
else: |
|
cur_image = y.to(torch.float32) + cur_image |
|
|
|
|
|
if layer_idx != self.num_layers - 2: |
|
res = self.init_res * (2**(layer_idx // 2)) |
|
if fp16_res is not None and res * 2 >= fp16_res: |
|
x = x.to(torch.float16) |
|
else: |
|
x = x.to(torch.float32) |
|
|
|
|
|
            # Progressive-growing-style blending: when `lod` is fractional,
            # blend the image at the current stage with the upsampled image
            # from the coarser stage.
            if cur_lod - 1 < lod <= cur_lod:
|
image = cur_image |
|
elif cur_lod < lod < cur_lod + 1: |
|
alpha = np.ceil(lod) - lod |
|
image = F.interpolate(image, scale_factor=2, mode='nearest') |
|
image = cur_image * alpha + image * (1 - alpha) |
|
elif lod >= cur_lod + 1: |
|
image = F.interpolate(image, scale_factor=2, mode='nearest') |
|
|
|
if self.final_tanh: |
|
image = torch.tanh(image) |
|
results['image'] = image |
|
|
|
return results |
|
|