"""Contains the implementation of generator described in VolumeGAN. |
Paper: https://arxiv.org/pdf/2112.10759.pdf |
""" |
import numpy as np |
import torch |
import torch.nn as nn |
import torch.nn.functional as F |
from einops import rearrange |
from .stylegan2_generator import MappingNetwork |
from .stylegan2_generator import ModulateConvLayer |
from .stylegan2_generator import ConvLayer |
from .stylegan2_generator import DenseLayer |
from third_party.stylegan2_official_ops import upfirdn2d |
from .rendering import Renderer |
from .rendering import FeatureExtractor |
from .utils.ops import all_gather |
class VolumeGANGenerator(nn.Module):
    """Defines the generator network in VolumeGAN.

    The synthesis pipeline is:

    1. `mapping` turns `z` (and an optional `label`) into per-layer latent
       codes `wp`.
    2. The first `num_nerf_layers` codes drive the 3D branch: a feature volume
       is generated from them, and the renderer (with the feature extractor,
       position encoder, NeRF MLP and FC head plugged in) produces
       `composite_rgb`, which is reshaped into a `nerf_res x nerf_res` 2D
       feature map.
    3. The remaining `num_cnn_layers` codes drive the neural renderer
       (`post_neural_renderer`), which turns that feature map into the final
       image at `resolution`.
    """

    def __init__(
            self,
            z_dim=512,
            w_dim=512,
            repeat_w=True,
            normalize_z=True,
            mapping_layers=8,
            mapping_fmaps=512,
            mapping_use_wscale=True,
            mapping_wscale_gain=1.0,
            mapping_lr_mul=0.01,
            label_dim=0,
            embedding_dim=512,
            embedding_bias=True,
            embedding_use_wscale=True,
            embedding_wscale_gian=1.0,
            embedding_lr_mul=1.0,
            normalize_embedding=True,
            normalize_embedding_latent=False,
            resolution=-1,
            nerf_res=32,
            image_channels=3,
            final_tanh=False,
            demodulate=True,
            use_wscale=True,
            wscale_gain=1.0,
            lr_mul=1.0,
            noise_type='spatial',
            fmaps_base=32 << 10,
            fmaps_max=512,
            filter_kernel=(1, 3, 3, 1),
            conv_clamp=None,
            eps=1e-8,
            rgb_init_res_out=True,
            fv_cfg=dict(feat_res=32,
                        init_res=4,
                        base_channels=256,
                        output_channels=32,
                        w_dim=512),
            embed_cfg=dict(input_dim=3, max_freq_log2=10 - 1, N_freqs=10),
            fg_cfg=dict(num_layers=4, hidden_dim=256, activation_type='lrelu'),
            bg_cfg=None,
            out_dim=512,
            rendering_kwargs=None):
        """Initializes with basic settings.

        Most arguments are forwarded unchanged to `MappingNetwork` or
        `PostNeuralRendererNetwork`. The ones whose role is specific to this
        class are:

        Args:
            embedding_wscale_gian: (sic) Spelled this way because it is passed
                to `MappingNetwork` under exactly this keyword.
            resolution: Final image resolution. It must be at least
                `nerf_res`; the default `-1` is only a placeholder and is
                rejected.
            nerf_res: Spatial size of the 2D feature map rendered by the 3D
                branch, i.e. the initial resolution of the neural renderer.
            rgb_init_res_out: Forwarded to `PostNeuralRendererNetwork`.
            fv_cfg: Keyword arguments of `FeatureVolume`. Its
                `output_channels` entry also sizes the NeRF MLP input.
            embed_cfg: Keyword arguments of `PositionEncoder`.
            fg_cfg: Foreground MLP settings shared by `NeRFMLPNetwork` and
                `FCHead`. `fg_cfg['num_layers']` is also the number of latent
                codes consumed by the 3D branch.
            bg_cfg: Forwarded to `NeRFMLPNetwork` and `FCHead`, which do not
                use it.
            out_dim: Output channels of the `rgb` head of `FCHead`.
            rendering_kwargs: Options forwarded to the renderer. If it holds
                a truthy `debug_mode`, weights are copied from its
                `official_state` entry at construction time. `None` (default)
                means a fresh empty dict for this instance.

        Raises:
            ValueError: If `resolution` is smaller than `nerf_res`.
        """
        super().__init__()
        # The dict defaults above (`fv_cfg`, `embed_cfg`, `fg_cfg`) are only
        # read/unpacked in this class; `rendering_kwargs` is stored, so it
        # gets a per-instance default instead of a shared mutable one.
        if rendering_kwargs is None:
            rendering_kwargs = {}
        if resolution < nerf_res:
            # Otherwise `np.log2` below would receive a value <= 0 and fail
            # with an unhelpful NaN/inf conversion error.
            raise ValueError(
                f'`resolution` ({resolution}) must be a positive value no '
                f'smaller than `nerf_res` ({nerf_res})!')

        self.z_dim = z_dim
        self.w_dim = w_dim
        self.repeat_w = repeat_w
        self.normalize_z = normalize_z
        self.mapping_layers = mapping_layers
        self.mapping_fmaps = mapping_fmaps
        self.mapping_use_wscale = mapping_use_wscale
        self.mapping_wscale_gain = mapping_wscale_gain
        self.mapping_lr_mul = mapping_lr_mul
        self.latent_dim = (z_dim,)
        self.label_size = label_dim
        self.label_dim = label_dim
        self.embedding_dim = embedding_dim
        self.embedding_bias = embedding_bias
        self.embedding_use_wscale = embedding_use_wscale
        self.embedding_wscale_gain = embedding_wscale_gian
        self.embedding_lr_mul = embedding_lr_mul
        self.normalize_embedding = normalize_embedding
        self.normalize_embedding_latent = normalize_embedding_latent
        self.resolution = resolution
        self.nerf_res = nerf_res
        self.image_channels = image_channels
        self.final_tanh = final_tanh
        self.demodulate = demodulate
        self.use_wscale = use_wscale
        self.wscale_gain = wscale_gain
        self.lr_mul = lr_mul
        self.noise_type = noise_type.lower()
        self.fmaps_base = fmaps_base
        self.fmaps_max = fmaps_max
        self.filter_kernel = filter_kernel
        self.conv_clamp = conv_clamp
        self.eps = eps

        # Latent codes are split between the 3D branch (first
        # `num_nerf_layers`) and the neural renderer (the rest). The CNN count
        # matches `PostNeuralRendererNetwork.num_layers`, i.e.
        # 2 * (log2(resolution) - log2(nerf_res) + 1).
        self.num_nerf_layers = fg_cfg['num_layers']
        self.num_cnn_layers = int(np.log2(resolution // nerf_res * 2)) * 2
        self.num_layers = self.num_nerf_layers + self.num_cnn_layers

        if self.repeat_w:
            self.register_buffer('w_avg', torch.zeros(w_dim))
        else:
            self.register_buffer('w_avg', torch.zeros(self.num_layers * w_dim))

        self.mapping = MappingNetwork(
            input_dim=z_dim,
            output_dim=w_dim,
            num_outputs=self.num_layers,
            repeat_output=repeat_w,
            normalize_input=normalize_z,
            num_layers=mapping_layers,
            hidden_dim=mapping_fmaps,
            use_wscale=mapping_use_wscale,
            wscale_gain=mapping_wscale_gain,
            lr_mul=mapping_lr_mul,
            label_dim=label_dim,
            embedding_dim=embedding_dim,
            embedding_bias=embedding_bias,
            embedding_use_wscale=embedding_use_wscale,
            embedding_wscale_gian=embedding_wscale_gian,
            embedding_lr_mul=embedding_lr_mul,
            normalize_embedding=normalize_embedding,
            normalize_embedding_latent=normalize_embedding_latent,
            eps=eps)

        self.renderer = Renderer()
        self.ref_representation_generator = FeatureVolume(**fv_cfg)
        self.position_encoder = PositionEncoder(**embed_cfg)
        self.feature_extractor = FeatureExtractor(ref_mode='feature_volume')
        # The MLP consumes positional encodings concatenated with features
        # sampled from the feature volume.
        self.post_module = NeRFMLPNetwork(input_dim=self.position_encoder.out_dim +
                                          fv_cfg['output_channels'],
                                          fg_cfg=fg_cfg,
                                          bg_cfg=bg_cfg)
        self.fc_head = FCHead(fg_cfg=fg_cfg, bg_cfg=bg_cfg, out_dim=out_dim)
        self.post_neural_renderer = PostNeuralRendererNetwork(
            resolution=resolution,
            init_res=nerf_res,
            w_dim=w_dim,
            image_channels=image_channels,
            final_tanh=final_tanh,
            demodulate=demodulate,
            use_wscale=use_wscale,
            wscale_gain=wscale_gain,
            lr_mul=lr_mul,
            noise_type=noise_type,
            fmaps_base=fmaps_base,
            filter_kernel=filter_kernel,
            fmaps_max=fmaps_max,
            conv_clamp=conv_clamp,
            eps=eps,
            rgb_init_res_out=rgb_init_res_out)
        self.rendering_kwargs = rendering_kwargs

        # Prefixes of this model's parameter names -> prefixes used by the
        # official VolumeGAN checkpoint (only used in debug mode).
        self.cur_to_official_part_mapping = {
            'w_avg': 'w_avg',
            'mapping': 'mapping',
            'ref_representation_generator': 'nerfmlp.fv',
            'post_module.fg_mlp': 'nerfmlp.fg_mlps',
            'fc_head.fg_sigma_head': 'nerfmlp.fg_density',
            'fc_head.fg_rgb_head': 'nerfmlp.fg_color',
            'post_neural_renderer': 'synthesis'
        }
        if self.rendering_kwargs.get('debug_mode', False):
            self.set_weights_from_official(
                self.rendering_kwargs.get('cur_state', None),
                self.rendering_kwargs.get('official_state', None))

    def get_cur_to_official_full_mapping(self, keys_cur):
        """Maps full parameter names of this model to official-checkpoint names.

        A key is translated only when it equals a prefix of
        `self.cur_to_official_part_mapping` or starts with that prefix
        followed by `.`, and only that leading prefix is rewritten. Plain
        substring matching would also hit names such as
        `ref_representation_generator.mapping_network.weight` for the
        `mapping` prefix and emit bogus entries.

        Args:
            keys_cur: Iterable of full parameter/buffer names of this model.
                It is iterated once per prefix.

        Returns:
            A dict from each matched current name to its official name.
        """
        full_mapping = {}
        for cur_prefix, official_prefix in (
                self.cur_to_official_part_mapping.items()):
            for key in keys_cur:
                if key == cur_prefix or key.startswith(cur_prefix + '.'):
                    full_mapping[key] = official_prefix + key[len(cur_prefix):]
        return full_mapping

    def set_weights_from_official(self, cur_state, official_state):
        """Overwrites every parameter with its counterpart from the official model.

        Args:
            cur_state: Checkpoint dict whose
                `['models']['generator_smooth']` keys list this model's names.
            official_state: Official checkpoint dict, read from
                `['models']['generator_smooth']`.

        Raises:
            KeyError: If a parameter of this model has no mapped name, or the
                mapped name is missing from `official_state`.
        """
        keys_cur = cur_state['models']['generator_smooth'].keys()
        self.cur_to_official_full_mapping = (
            self.get_cur_to_official_full_mapping(keys_cur))
        for name, param in self.named_parameters():
            param.data = (official_state['models']['generator_smooth'][
                self.cur_to_official_full_mapping[name]])

    def forward(
            self,
            z,
            label=None,
            lod=None,
            w_moving_decay=None,
            sync_w_avg=False,
            style_mixing_prob=None,
            trunc_psi=None,
            trunc_layers=None,
            noise_mode='const',
            fused_modulate=False,
            impl='cuda',
            fp16_res=None,
    ):
        """Synthesizes images from latent codes.

        Args:
            z: Input latent codes fed to `mapping`.
            label: Optional conditioning labels fed to `mapping`.
            lod: Level of detail forwarded to the neural renderer. `None`
                uses the renderer's `lod` buffer.
            w_moving_decay: In training mode, if set, `w_avg` is updated as
                `lerp(batch_mean(w), w_avg, w_moving_decay)`.
            sync_w_avg: Whether the batch mean of `w` is gathered across
                processes before the update.
            style_mixing_prob: In training mode, the probability of replacing
                the neural-renderer codes from a random cutoff onward with
                codes from a fresh `z`. The 3D-branch codes are never mixed.
            trunc_psi: In eval mode, truncation strength (1.0 disables it).
            trunc_layers: In eval mode, number of leading codes to truncate.
            noise_mode, fused_modulate, impl, fp16_res: Forwarded to the
                neural renderer (`impl` also to `mapping`).

        Returns:
            A dict merging the outputs of `mapping` (including `w` and `wp`)
            with those of the neural renderer (including `image`). Style
            mixing and truncation modify `wp` in place, so the returned `wp`
            reflects them.
        """
        mapping_results = self.mapping(z, label, impl=impl)
        w = mapping_results['w']
        lod = self.post_neural_renderer.lod.item() if lod is None else lod

        # Track the running average of `w` used for truncation.
        if self.training and w_moving_decay is not None:
            if sync_w_avg:
                batch_w_avg = all_gather(w.detach()).mean(dim=0)
            else:
                batch_w_avg = w.detach().mean(dim=0)
            self.w_avg.copy_(batch_w_avg.lerp(self.w_avg, w_moving_decay))

        wp = mapping_results['wp']

        # Style mixing only touches codes of the neural renderer.
        if self.training and style_mixing_prob is not None:
            if np.random.uniform() < style_mixing_prob:
                new_z = torch.randn_like(z)
                new_wp = self.mapping(new_z, label, impl=impl)['wp']
                current_layers = self.num_layers
                if current_layers > self.num_nerf_layers:
                    mixing_cutoff = np.random.randint(self.num_nerf_layers,
                                                      current_layers)
                    wp[:, mixing_cutoff:] = new_wp[:, mixing_cutoff:]

        # Truncation trick (evaluation only).
        if not self.training:
            trunc_psi = 1.0 if trunc_psi is None else trunc_psi
            trunc_layers = 0 if trunc_layers is None else trunc_layers
            if trunc_psi < 1.0 and trunc_layers > 0:
                w_avg = self.w_avg.reshape(1, -1, self.w_dim)[:, :trunc_layers]
                wp[:, :trunc_layers] = w_avg.lerp(
                    wp[:, :trunc_layers], trunc_psi)

        nerf_w = wp[:, :self.num_nerf_layers]
        cnn_w = wp[:, self.num_nerf_layers:]

        feature_volume = self.ref_representation_generator(nerf_w)
        rendering_results = self.renderer(
            wp=nerf_w,
            feature_extractor=self.feature_extractor,
            rendering_options=self.rendering_kwargs,
            position_encoder=self.position_encoder,
            ref_representation=feature_volume,
            post_module=self.post_module,
            fc_head=self.fc_head)

        # Per-ray features -> `[N, C, nerf_res, nerf_res]` feature map.
        feature2d = rendering_results['composite_rgb']
        feature2d = feature2d.reshape(feature2d.shape[0], self.nerf_res,
                                      self.nerf_res, -1).permute(0, 3, 1, 2)

        final_results = self.post_neural_renderer(
            feature2d,
            cnn_w,
            lod=lod,  # Previously hard-coded to `None`, ignoring the argument.
            noise_mode=noise_mode,
            fused_modulate=fused_modulate,
            impl=impl,
            fp16_res=fp16_res)

        return {**mapping_results, **final_results}
class PositionEncoder(nn.Module):
    """Implements the class for positional encoding.

    Each input vector `p` is mapped to
    `[p, fn_1(f_1 * p), ..., fn_k(f_1 * p), fn_1(f_2 * p), ...]`, i.e. the
    optional raw input followed by every periodic function applied at every
    frequency band (frequency-major order). The result is concatenated along
    the last dimension.
    """

    def __init__(self,
                 input_dim,
                 max_freq_log2,
                 N_freqs,
                 log_sampling=True,
                 include_input=True,
                 periodic_fns=(torch.sin, torch.cos)):
        """Initializes with basic settings.

        Args:
            input_dim: Size of the last dimension of the input to embed.
            max_freq_log2: `log2` of the highest frequency; the lowest is 1.
            N_freqs: Number of frequency bands.
            log_sampling: If True, bands are spaced evenly in log2 space
                (`2**0 ... 2**max_freq_log2`); otherwise they are spaced
                linearly between `1` and `2**max_freq_log2`.
            include_input: If True, the raw input is prepended to the
                embedding. Defaults to True.
            periodic_fns: Periodic functions applied at each band.
                Defaults to (torch.sin, torch.cos).
        """
        super().__init__()
        self.input_dim = input_dim
        self.include_input = include_input
        self.periodic_fns = periodic_fns

        raw_part = self.input_dim if self.include_input else 0
        self.out_dim = (raw_part +
                        self.input_dim * N_freqs * len(self.periodic_fns))

        if log_sampling:
            bands = 2.**torch.linspace(0., max_freq_log2, N_freqs)
        else:
            bands = torch.linspace(2.**0., 2.**max_freq_log2, N_freqs)
        # Stored as plain Python floats (not a buffer), so nothing is moved
        # with the module between devices.
        self.freq_bands = bands.numpy().tolist()

    def forward(self, input):
        """Embeds `input` (last dim `input_dim`) into `out_dim` features."""
        assert (input.shape[-1] == self.input_dim)
        pieces = [input] if self.include_input else []
        pieces.extend(fn(input * band)
                      for band in self.freq_bands
                      for fn in self.periodic_fns)
        encoded = torch.cat(pieces, dim=-1)
        assert (encoded.shape[-1] == self.out_dim)
        return encoded
class FeatureVolume(nn.Module):
    """Defines feature volume in VolumeGAN.

    A learnable constant grid of shape
    `[1, base_channels, init_res, init_res, init_res]` is refined by
    `num_stages` stages. Each stage upsamples by 2 with nearest-neighbour
    interpolation (except the first), applies a 3x3x3 `nn.Conv3d`, and
    instance-normalizes the result with a per-sample, per-channel scale/shift
    predicted from the latent code. A LeakyReLU(0.2) follows. Channels halve
    from stage to stage (never below `output_channels`), and the final stage
    always outputs `output_channels`.
    """

    def __init__(self,
                 feat_res=32,
                 init_res=4,
                 base_channels=256,
                 output_channels=32,
                 w_dim=512,
                 **kwargs):
        """Initializes with basic settings.

        Args:
            feat_res: Target spatial size of the volume. It is assumed to be
                `init_res` times a power of two.
            init_res: Spatial size of the learnable constant volume.
            base_channels: Channels of the constant and of the first stage.
            output_channels: Channels of the last stage, i.e. of the volume
                returned by `forward`.
            w_dim: Dimension of the latent code driving the modulation.
            **kwargs: Ignored.
        """
        super().__init__()
        self.num_stages = int(np.log2(feat_res // init_res)) + 1
        self.const = nn.Parameter(
            torch.ones(1, base_channels, init_res, init_res, init_res))

        inplanes = base_channels
        outplanes = base_channels
        self.stage_channels = []
        for i in range(self.num_stages):
            if i == self.num_stages - 1:
                # Force the width of the final stage *before* building it so
                # the override takes effect. Otherwise the output width would
                # only equal `output_channels` when the halving schedule
                # happens to land on it, while consumers of this volume size
                # their inputs from `output_channels`.
                outplanes = output_channels
            conv = nn.Conv3d(inplanes,
                             outplanes,
                             kernel_size=(3, 3, 3),
                             padding=(1, 1, 1))
            self.stage_channels.append(outplanes)
            self.add_module(f'layer{i}', conv)
            instance_norm = InstanceNormLayer(num_features=outplanes,
                                              affine=False)
            self.add_module(f'instance_norm{i}', instance_norm)
            inplanes = outplanes
            outplanes = max(outplanes // 2, output_channels)

        # Predicts one scale and one shift per channel of every stage.
        self.mapping_network = nn.Linear(w_dim, sum(self.stage_channels) * 2)
        self.mapping_network.apply(kaiming_leaky_init)
        with torch.no_grad():
            # Scale the freshly initialized weights down by 4x.
            self.mapping_network.weight *= 0.25

        self.upsample = UpsamplingLayer()
        self.lrelu = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, w, **kwargs):
        """Generates the feature volume from latent codes.

        Args:
            w: Latent codes, either `[N, w_dim]` or per-layer `[N, L, w_dim]`.
                In the latter case only `w[:, 0]` is used.
            **kwargs: Ignored.

        Returns:
            The modulated volume with `output_channels` channels.
        """
        _w = w[:, 0] if w.ndim == 3 else w
        scale_shifts = self.mapping_network(_w)
        half = scale_shifts.shape[-1] // 2
        scales = scale_shifts[..., :half]
        shifts = scale_shifts[..., half:]

        # Channel offsets of each stage inside `scales` / `shifts`.
        offsets = [0]
        for channels in self.stage_channels:
            offsets.append(offsets[-1] + channels)

        x = self.const.repeat(w.shape[0], 1, 1, 1, 1)
        for idx in range(self.num_stages):
            if idx != 0:
                x = self.upsample(x)
            x = getattr(self, f'layer{idx}')(x)
            start, end = offsets[idx], offsets[idx + 1]
            scale = scales[:, start:end]
            shift = shifts[:, start:end]
            # Broadcast over the three spatial dims.
            scale = scale.view(scale.shape + (1, 1, 1))
            shift = shift.view(shift.shape + (1, 1, 1))
            x = getattr(self, f'instance_norm{idx}')(x, weight=scale,
                                                      bias=shift)
            x = self.lrelu(x)
        return x
def kaiming_leaky_init(m):
    """Re-initializes the weight of `Linear`-like modules in place.

    Intended for `nn.Module.apply`. Any module whose class name contains
    `Linear` has its `weight` redrawn with Kaiming-normal initialization
    (fan-in mode, leaky ReLU with negative slope 0.2). Biases and all other
    modules are left untouched.
    """
    if 'Linear' not in type(m).__name__:
        return
    torch.nn.init.kaiming_normal_(m.weight,
                                  a=0.2,
                                  mode='fan_in',
                                  nonlinearity='leaky_relu')
class InstanceNormLayer(nn.Module):
    """Implements instance normalization layer.

    Each (sample, channel) slice is centered and divided by its RMS over
    dims 2, 3 and 4 (with `epsilon` added inside the square root). An affine
    transform follows, taken either from the layer's own parameters
    (`affine=True`) or from `weight`/`bias` passed to `forward`
    (`affine=False`). Exactly one of the two sources must be used.
    """

    def __init__(self, num_features, epsilon=1e-8, affine=False):
        super().__init__()
        self.eps = epsilon
        self.affine = affine
        if not self.affine:
            return
        self.weight = nn.Parameter(torch.Tensor(1, num_features, 1, 1, 1))
        self.bias = nn.Parameter(torch.Tensor(1, num_features, 1, 1, 1))
        self.weight.data.uniform_()
        self.bias.data.zero_()

    def forward(self, x, weight=None, bias=None):
        """Normalizes `x` and applies the internal or external affine transform."""
        reduce_dims = [2, 3, 4]
        centered = x - x.mean(dim=reduce_dims, keepdim=True)
        rms = (centered.pow(2).mean(dim=reduce_dims, keepdim=True) +
               self.eps).sqrt()
        normed = centered / rms
        # External modulation must be given iff the layer is not affine.
        external = weight is not None and bias is not None
        assert external != bool(self.affine)
        if self.affine:
            return normed * self.weight + self.bias
        return normed * weight + bias
class UpsamplingLayer(nn.Module):
    """Nearest-neighbour upsampling by a fixed factor.

    A factor `<= 1` makes the layer an identity that returns its input
    object unchanged.
    """

    def __init__(self, scale_factor=2):
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        factor = self.scale_factor
        return x if factor <= 1 else F.interpolate(
            x, scale_factor=factor, mode='nearest')
class NeRFMLPNetwork(nn.Module):
    """Defines class of MLP Network described in VolumeGAN.

    It takes per-point features together with point coordinate encodings
    (concatenated along the channel dim) plus latent codes, and outputs a
    feature for each point. Two fully-connected heads (`FCHead`) are expected
    to follow.
    """

    def __init__(self, input_dim, fg_cfg, bg_cfg=None):
        super().__init__()
        # `bg_cfg` is accepted for interface compatibility but unused.
        self.fg_mlp = self.build_mlp(input_dim=input_dim, **fg_cfg)

    def build_mlp(self, input_dim, num_layers, hidden_dim, activation_type,
                  **kwargs):
        """Implements function to build the `MLP`.

        Here the `MLP` consists of a series of `ModulateConvLayer`s with
        `kernel_size=1`, which act as fully-connected layers. Convolutions
        usually take `[N, C, H, W]` inputs. Here the input is laid out as
        `[N, C, R*K, 1]` to emulate an `MLP`. The first layer maps
        `input_dim -> hidden_dim`; every later one maps
        `hidden_dim -> hidden_dim`. Extra `kwargs` are ignored.
        """
        shared_cfg = dict(resolution=32,
                          w_dim=512,
                          kernel_size=1,
                          add_bias=True,
                          scale_factor=1,
                          filter_kernel=None,
                          demodulate=True,
                          use_wscale=True,
                          wscale_gain=1,
                          lr_mul=1,
                          noise_type='none',
                          conv_clamp=None,
                          eps=1e-8)
        layers = nn.ModuleList()
        for layer_idx in range(num_layers):
            layers.append(
                ModulateConvLayer(
                    in_channels=input_dim if layer_idx == 0 else hidden_dim,
                    out_channels=hidden_dim,
                    activation_type=activation_type,
                    **shared_cfg))
        return layers

    def forward(self,
                pre_point_features,
                wp,
                points_encoding=None,
                fused_modulate=False,
                impl='cuda'):
        """Runs the modulated MLP over every point.

        Args:
            pre_point_features: Per-point features, concatenated with
                `points_encoding` along dim 1.
            wp: Latent codes. If 3D, layer `i` uses `wp[:, i]`; otherwise the
                same `wp` is fed to every layer.
            points_encoding: Positional encodings of the points. Required in
                practice; it must be a 4D tensor (its shape is unpacked into
                four values).
            fused_modulate, impl: Forwarded to each layer.

        Returns:
            The output of the last layer.
        """
        # Shape unpacking doubles as a rank check on `points_encoding`.
        _, _, _, _ = points_encoding.shape
        x = torch.cat([pre_point_features, points_encoding], dim=1)
        per_layer_codes = wp.ndim == 3
        for idx, layer in enumerate(self.fg_mlp):
            code = wp[:, idx] if per_layer_codes else wp
            x, _ = layer(x, code, fused_modulate=fused_modulate, impl=impl)
        return x
class FCHead(nn.Module):
    """Defines fully-connected layer head in VolumeGAN to decode `feature` into
    `sigma` and `rgb`.

    Both heads are linear `DenseLayer`s on top of the MLP features
    (`fg_cfg['hidden_dim']` channels). `sigma` has 1 output channel and `rgb`
    has `out_dim` output channels.
    """

    def __init__(self, fg_cfg, bg_cfg=None, out_dim=512):
        super().__init__()
        # `bg_cfg` is accepted for interface compatibility but unused.
        head_cfg = dict(add_bias=True,
                        init_bias=0.0,
                        use_wscale=True,
                        wscale_gain=1,
                        lr_mul=1,
                        activation_type='linear')
        in_channels = fg_cfg['hidden_dim']
        self.fg_sigma_head = DenseLayer(in_channels=in_channels,
                                        out_channels=1,
                                        **head_cfg)
        self.fg_rgb_head = DenseLayer(in_channels=in_channels,
                                      out_channels=out_dim,
                                      **head_cfg)

    def forward(self, post_point_features, wp=None, dirs=None):
        """Decodes per-point features; `wp` and `dirs` are ignored.

        `post_point_features` must be laid out as `[N, C, R_K, 1]`. It is
        flattened to `[N * R_K, C]` before both heads are applied.
        """
        flat_features = rearrange(
            post_point_features, 'N C (R_K) 1 -> (N R_K) C').contiguous()
        return {
            'sigma': self.fg_sigma_head(flat_features),
            'rgb': self.fg_rgb_head(flat_features),
        }
class PostNeuralRendererNetwork(nn.Module):
    """Implements the neural renderer in VolumeGAN to render high-resolution
    images.

    Basically, this network executes several convolutional layers in sequence.
    The resolutions run from `init_res` to `resolution`, doubling each time,
    and each resolution is one block `b`. Block 0 has no conv layers. Its
    image is either `rgb_init_res(x)` (a linear 1x1 `ConvLayer`) or the first
    `image_channels` channels of `x`. Every later block has an upsampling
    modulated conv `layer{2b-1}`, a modulated conv `layer{2b}`, and a linear
    modulated `output{b}`. The output of `output{b}` is added to the
    (upsampled) running image.
    """

    def __init__(
            self,
            resolution,
            init_res,
            w_dim,
            image_channels,
            final_tanh,
            demodulate,
            use_wscale,
            wscale_gain,
            lr_mul,
            noise_type,
            fmaps_base,
            fmaps_max,
            filter_kernel,
            conv_clamp,
            eps,
            rgb_init_res_out=False,
    ):
        """Initializes with basic settings.

        Args:
            resolution: Final image resolution.
            init_res: Resolution of the input feature map.
            rgb_init_res_out: If True, the image at `init_res` comes from a
                learned linear 1x1 conv (`rgb_init_res`). Otherwise it is a
                channel slice of the input.
            Remaining arguments configure the (modulated) conv layers.
        """
        super().__init__()
        self.init_res = init_res
        self.init_res_log2 = int(np.log2(init_res))
        self.resolution = resolution
        self.final_res_log2 = int(np.log2(resolution))
        self.w_dim = w_dim
        self.image_channels = image_channels
        self.final_tanh = final_tanh
        self.demodulate = demodulate
        self.use_wscale = use_wscale
        self.wscale_gain = wscale_gain
        self.lr_mul = lr_mul
        self.noise_type = noise_type.lower()
        self.fmaps_base = fmaps_base
        self.fmaps_max = fmaps_max
        self.filter_kernel = filter_kernel
        self.conv_clamp = conv_clamp
        self.eps = eps
        self.rgb_init_res_out = rgb_init_res_out

        # Two latent codes per block (conv + output, or upsample-conv + conv
        # with the output layer taking the following index).
        self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2

        self.register_buffer('lod', torch.zeros(()))

        for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1):
            res = 2**res_log2
            in_channels = self.get_nf(res // 2)
            out_channels = self.get_nf(res)
            block_idx = res_log2 - self.init_res_log2

            # Upsampling layer (absent at the initial resolution).
            if res > init_res:
                layer_name = f'layer{2 * block_idx - 1}'
                self.add_module(
                    layer_name,
                    ModulateConvLayer(in_channels=in_channels,
                                      out_channels=out_channels,
                                      resolution=res,
                                      w_dim=w_dim,
                                      kernel_size=1,
                                      add_bias=True,
                                      scale_factor=2,
                                      filter_kernel=filter_kernel,
                                      demodulate=demodulate,
                                      use_wscale=use_wscale,
                                      wscale_gain=wscale_gain,
                                      lr_mul=lr_mul,
                                      noise_type=noise_type,
                                      activation_type='lrelu',
                                      conv_clamp=conv_clamp,
                                      eps=eps))

            # The initial block has no conv/output layers of its own.
            if block_idx == 0:
                if self.rgb_init_res_out:
                    self.rgb_init_res = ConvLayer(
                        in_channels=out_channels,
                        out_channels=image_channels,
                        kernel_size=1,
                        add_bias=True,
                        scale_factor=1,
                        filter_kernel=None,
                        use_wscale=use_wscale,
                        wscale_gain=wscale_gain,
                        lr_mul=lr_mul,
                        activation_type='linear',
                        conv_clamp=conv_clamp,
                    )
                continue

            # Convolution layer at the current resolution.
            layer_name = f'layer{2 * block_idx}'
            self.add_module(
                layer_name,
                ModulateConvLayer(in_channels=out_channels,
                                  out_channels=out_channels,
                                  resolution=res,
                                  w_dim=w_dim,
                                  kernel_size=1,
                                  add_bias=True,
                                  scale_factor=1,
                                  filter_kernel=None,
                                  demodulate=demodulate,
                                  use_wscale=use_wscale,
                                  wscale_gain=wscale_gain,
                                  lr_mul=lr_mul,
                                  noise_type=noise_type,
                                  activation_type='lrelu',
                                  conv_clamp=conv_clamp,
                                  eps=eps))

            # Output layer producing the image residual at this resolution.
            layer_name = f'output{block_idx}'
            self.add_module(
                layer_name,
                ModulateConvLayer(in_channels=out_channels,
                                  out_channels=image_channels,
                                  resolution=res,
                                  w_dim=w_dim,
                                  kernel_size=1,
                                  add_bias=True,
                                  scale_factor=1,
                                  filter_kernel=None,
                                  demodulate=False,
                                  use_wscale=use_wscale,
                                  wscale_gain=wscale_gain,
                                  lr_mul=lr_mul,
                                  noise_type='none',
                                  activation_type='linear',
                                  conv_clamp=conv_clamp,
                                  eps=eps))

        # Filter used to upsample the running image between resolutions.
        self.register_buffer('filter', upfirdn2d.setup_filter(filter_kernel))

    def get_nf(self, res):
        """Gets number of feature maps according to current resolution."""
        return min(self.fmaps_base // res, self.fmaps_max)

    def set_space_of_latent(self, space_of_latent):
        """Sets the space to which the latent code belong.

        Args:
            space_of_latent: The space to which the latent code belong. Case
                insensitive. Support `W` and `Y`.
        """
        space_of_latent = space_of_latent.upper()
        for module in self.modules():
            if isinstance(module, ModulateConvLayer):
                setattr(module, 'space_of_latent', space_of_latent)

    def forward(self,
                x,
                wp,
                lod=None,
                noise_mode='const',
                fused_modulate=False,
                impl='cuda',
                fp16_res=None,
                nerf_out=False):
        """Renders the final image from the input feature map.

        Args:
            x: Feature map at `init_res`; assumed `[N, C, init_res, init_res]`
                (TODO confirm).
            wp: Per-layer latent codes indexed as `wp[:, layer_idx]`. It needs
                `self.num_layers` entries along dim 1.
            lod: Level of detail. `None` uses the `lod` buffer. Blocks with
                `lod >= cur_lod + 1` are skipped and the image is only
                nearest-upsampled; fractional values blend two resolutions.
            noise_mode, fused_modulate, impl: Forwarded to the conv layers.
            fp16_res: If set, features at resolutions `>= fp16_res` are cast
                to `torch.float16`.
            nerf_out: If True, the running image is added to each new output
                without being upsampled first.

        Returns:
            A dict with the styles of every executed layer (`style{i}`,
            `output_style{b}`) and the final `image` (tanh-squashed if
            `final_tanh`).
        """
        lod = self.lod.item() if lod is None else lod

        results = {}

        if fp16_res is not None and self.init_res >= fp16_res:
            x = x.to(torch.float16)

        for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1):
            cur_lod = self.final_res_log2 - res_log2
            block_idx = res_log2 - self.init_res_log2

            layer_idxs = ([2 * block_idx - 1, 2 * block_idx]
                          if block_idx > 0 else [2 * block_idx])

            if lod < cur_lod + 1:
                for layer_idx in layer_idxs:
                    if layer_idx == 0:
                        # Initial image; no conv layer exists at `init_res`.
                        if self.rgb_init_res_out:
                            cur_image = self.rgb_init_res(x,
                                                          runtime_gain=1,
                                                          impl=impl)
                        else:
                            # Use `image_channels` instead of a hard-coded 3
                            # so the skip image matches the `output{b}`
                            # residuals it is summed with.
                            cur_image = x[:, :self.image_channels]
                        continue

                    layer = getattr(self, f'layer{layer_idx}')
                    x, style = layer(
                        x,
                        wp[:, layer_idx],
                        noise_mode=noise_mode,
                        fused_modulate=fused_modulate,
                        impl=impl,
                    )
                    results[f'style{layer_idx}'] = style

                    if layer_idx % 2 == 0:
                        output_layer = getattr(self, f'output{layer_idx // 2}')
                        y, style = output_layer(
                            x,
                            wp[:, layer_idx + 1],
                            fused_modulate=fused_modulate,
                            impl=impl,
                        )
                        results[f'output_style{layer_idx // 2}'] = style
                        # `layer_idx` is never 0 here (that case `continue`s
                        # above), so there is always a previous image to add.
                        if not nerf_out:
                            cur_image = y.to(
                                torch.float32) + upfirdn2d.upsample2d(
                                    cur_image, self.filter, impl=impl)
                        else:
                            cur_image = y.to(torch.float32) + cur_image

                        # Cast features for the next (2x) resolution.
                        if layer_idx != self.num_layers - 2:
                            res = self.init_res * (2**(layer_idx // 2))
                            if fp16_res is not None and res * 2 >= fp16_res:
                                x = x.to(torch.float16)
                            else:
                                x = x.to(torch.float32)

            # Level-of-detail handling.
            if cur_lod - 1 < lod <= cur_lod:
                image = cur_image
            elif cur_lod < lod < cur_lod + 1:
                alpha = np.ceil(lod) - lod
                image = F.interpolate(image, scale_factor=2, mode='nearest')
                image = cur_image * alpha + image * (1 - alpha)
            elif lod >= cur_lod + 1:
                image = F.interpolate(image, scale_factor=2, mode='nearest')

        if self.final_tanh:
            image = torch.tanh(image)
        results['image'] = image

        return results