# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import functools
import re

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from imaginaire.layers import Conv2dBlock, LinearBlock
from imaginaire.model_utils.gancraft.layers import AffineMod, ModLinear
import imaginaire.model_utils.gancraft.mc_utils as mc_utils
import imaginaire.model_utils.gancraft.voxlib as voxlib
from imaginaire.utils.distributed import master_only_print as print


class RenderMLP(nn.Module):
    r"""MLP with affine modulation."""

    def __init__(self, in_channels, style_dim, viewdir_dim, mask_dim=680,
                 out_channels_s=1, out_channels_c=3, hidden_channels=256,
                 use_seg=True):
        super(RenderMLP, self).__init__()

        self.use_seg = use_seg
        if self.use_seg:
            self.fc_m_a = nn.Linear(mask_dim, hidden_channels, bias=False)

        self.fc_viewdir = None
        if viewdir_dim > 0:
            self.fc_viewdir = nn.Linear(viewdir_dim, hidden_channels, bias=False)

        self.fc_1 = nn.Linear(in_channels, hidden_channels)

        self.fc_2 = ModLinear(hidden_channels, hidden_channels, style_dim,
                              bias=False, mod_bias=True, output_mode=True)
        self.fc_3 = ModLinear(hidden_channels, hidden_channels, style_dim,
                              bias=False, mod_bias=True, output_mode=True)
        self.fc_4 = ModLinear(hidden_channels, hidden_channels, style_dim,
                              bias=False, mod_bias=True, output_mode=True)

        self.fc_sigma = nn.Linear(hidden_channels, out_channels_s)

        if viewdir_dim > 0:
            self.fc_5 = nn.Linear(hidden_channels, hidden_channels, bias=False)
            self.mod_5 = AffineMod(hidden_channels, style_dim, mod_bias=True)
        else:
            self.fc_5 = ModLinear(hidden_channels, hidden_channels, style_dim,
                                  bias=False, mod_bias=True, output_mode=True)
        self.fc_6 = ModLinear(hidden_channels, hidden_channels, style_dim,
                              bias=False, mod_bias=True, output_mode=True)
        self.fc_out_c = nn.Linear(hidden_channels, out_channels_c)

        self.act = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x, raydir, z, m):
        r"""Forward network.

        Args:
            x (N x H x W x M x in_channels tensor): Projected features.
            raydir (N x H x W x 1 x viewdir_dim tensor): Ray directions.
            z (N x style_dim tensor): Style codes.
            m (N x H x W x M x mask_dim tensor): One-hot segmentation maps.
        Returns:
            sigma (N x H x W x M x out_channels_s tensor): Densities.
            c (N x H x W x M x out_channels_c tensor): Color embeddings.
        """
        b, h, w, n, _ = x.size()
        z = z[:, None, None, None, :]

        f = self.fc_1(x)
        if self.use_seg:
            f = f + self.fc_m_a(m)
        # Common MLP
        f = self.act(f)
        f = self.act(self.fc_2(f, z))
        f = self.act(self.fc_3(f, z))
        f = self.act(self.fc_4(f, z))

        # Sigma MLP
        sigma = self.fc_sigma(f)

        # Color MLP
        if self.fc_viewdir is not None:
            f = self.fc_5(f)
            f = f + self.fc_viewdir(raydir)
            f = self.act(self.mod_5(f, z))
        else:
            f = self.act(self.fc_5(f, z))
        f = self.act(self.fc_6(f, z))
        c = self.fc_out_c(f)
        return sigma, c
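
# Usage sketch (illustrative only; all shapes below are hypothetical and not
# taken from any shipped config):
#     mlp = RenderMLP(in_channels=64, style_dim=128, viewdir_dim=27, mask_dim=680)
#     x = torch.randn(2, 16, 16, 4, 64)        # N x H x W x M x in_channels
#     raydir = torch.randn(2, 16, 16, 1, 27)   # N x H x W x 1 x viewdir_dim
#     z = torch.randn(2, 128)                  # N x style_dim
#     m = torch.randn(2, 16, 16, 4, 680)       # N x H x W x M x mask_dim
#     sigma, c = mlp(x, raydir, z, m)          # sigma: ... x 1, c: ... x 3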


class StyleMLP(nn.Module):
    r"""MLP converting style code to intermediate style representation."""

    def __init__(self, style_dim, out_dim, hidden_channels=256, leaky_relu=True,
                 num_layers=5, normalize_input=True, output_act=True):
        super(StyleMLP, self).__init__()

        self.normalize_input = normalize_input
        self.output_act = output_act
        fc_layers = []
        fc_layers.append(nn.Linear(style_dim, hidden_channels, bias=True))
        for i in range(num_layers - 1):
            fc_layers.append(nn.Linear(hidden_channels, hidden_channels, bias=True))
        self.fc_layers = nn.ModuleList(fc_layers)

        self.fc_out = nn.Linear(hidden_channels, out_dim, bias=True)

        if leaky_relu:
            self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            self.act = functools.partial(F.relu, inplace=True)

    def forward(self, z):
        r"""Forward network.

        Args:
            z (N x style_dim tensor): Style codes.
        """
        if self.normalize_input:
            z = F.normalize(z, p=2, dim=-1)
        for fc_layer in self.fc_layers:
            z = self.act(fc_layer(z))
        z = self.fc_out(z)
        if self.output_act:
            z = self.act(z)
        return z
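
# Usage sketch (hypothetical dimensions): mapping a raw style code to the
# intermediate style space consumed by the render and sky networks:
#     style_net = StyleMLP(style_dim=128, out_dim=256)
#     z = torch.randn(4, 128)
#     w = style_net(z)   # (4, 256)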


class SKYMLP(nn.Module):
    r"""MLP converting ray directions to sky features."""

    def __init__(self, in_channels, style_dim, out_channels_c=3,
                 hidden_channels=256, leaky_relu=True):
        super(SKYMLP, self).__init__()
        self.fc_z_a = nn.Linear(style_dim, hidden_channels, bias=False)

        self.fc1 = nn.Linear(in_channels, hidden_channels)
        self.fc2 = nn.Linear(hidden_channels, hidden_channels)
        self.fc3 = nn.Linear(hidden_channels, hidden_channels)
        self.fc4 = nn.Linear(hidden_channels, hidden_channels)
        self.fc5 = nn.Linear(hidden_channels, hidden_channels)

        self.fc_out_c = nn.Linear(hidden_channels, out_channels_c)

        if leaky_relu:
            self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            self.act = functools.partial(F.relu, inplace=True)

    def forward(self, x, z):
        r"""Forward network.

        Args:
            x (... x in_channels tensor): Ray direction embeddings.
            z (... x style_dim tensor): Style codes.
        """
        z = self.fc_z_a(z)
        while z.dim() < x.dim():
            z = z.unsqueeze(1)

        y = self.act(self.fc1(x) + z)
        y = self.act(self.fc2(y))
        y = self.act(self.fc3(y))
        y = self.act(self.fc4(y))
        y = self.act(self.fc5(y))
        c = self.fc_out_c(y)

        return c
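
# Usage sketch (hypothetical shapes): the style code is injected once, as an
# additive bias after the first layer, then broadcast over the spatial dims:
#     sky = SKYMLP(in_channels=27, style_dim=256, out_channels_c=16)
#     x = torch.randn(2, 16, 16, 1, 27)   # positionally encoded ray directions
#     z = torch.randn(2, 256)
#     c = sky(x, z)                       # (2, 16, 16, 1, 16)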


class RenderCNN(nn.Module):
    r"""CNN converting intermediate feature map to final image."""

    def __init__(self, in_channels, style_dim, hidden_channels=256,
                 leaky_relu=True):
        super(RenderCNN, self).__init__()
        self.fc_z_cond = nn.Linear(style_dim, 2 * 2 * hidden_channels)

        self.conv1 = nn.Conv2d(in_channels, hidden_channels, 1, stride=1, padding=0)
        self.conv2a = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1)
        self.conv2b = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1, bias=False)

        self.conv3a = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1)
        self.conv3b = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1, bias=False)

        self.conv4a = nn.Conv2d(hidden_channels, hidden_channels, 1, stride=1, padding=0)
        self.conv4b = nn.Conv2d(hidden_channels, hidden_channels, 1, stride=1, padding=0)

        self.conv4 = nn.Conv2d(hidden_channels, 3, 1, stride=1, padding=0)

        if leaky_relu:
            self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            self.act = functools.partial(F.relu, inplace=True)

    def modulate(self, x, w, b):
        # FiLM-style affine modulation: per-channel scale and shift with
        # style-conditioned parameters.
        w = w[..., None, None]
        b = b[..., None, None]
        return x * (w + 1) + b

    def forward(self, x, z):
        r"""Forward network.

        Args:
            x (N x in_channels x H x W tensor): Intermediate feature map.
            z (N x style_dim tensor): Style codes.
        """
        z = self.fc_z_cond(z)
        adapt = torch.chunk(z, 2 * 2, dim=-1)

        y = self.act(self.conv1(x))

        y = y + self.conv2b(self.act(self.conv2a(y)))
        y = self.act(self.modulate(y, adapt[0], adapt[1]))

        y = y + self.conv3b(self.act(self.conv3a(y)))
        y = self.act(self.modulate(y, adapt[2], adapt[3]))

        y = y + self.conv4b(self.act(self.conv4a(y)))
        y = self.act(y)

        y = self.conv4(y)

        return y
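
# Usage sketch (hypothetical shapes): two residual blocks, each followed by a
# style-modulated activation, then a 1x1 projection to RGB:
#     cnn = RenderCNN(in_channels=16, style_dim=256)
#     x = torch.randn(2, 16, 64, 64)
#     z = torch.randn(2, 256)
#     y = cnn(x, z)   # (2, 3, 64, 64)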


class StyleEncoder(nn.Module):
    r"""Style Encoder constructor.

    Args:
        style_enc_cfg (obj): Style encoder definition file.
    """

    def __init__(self, style_enc_cfg):
        super(StyleEncoder, self).__init__()
        input_image_channels = style_enc_cfg.input_image_channels
        num_filters = style_enc_cfg.num_filters
        kernel_size = style_enc_cfg.kernel_size
        padding = int(np.ceil((kernel_size - 1.0) / 2))
        style_dims = style_enc_cfg.style_dims
        weight_norm_type = style_enc_cfg.weight_norm_type
        self.no_vae = getattr(style_enc_cfg, 'no_vae', False)
        activation_norm_type = 'none'
        nonlinearity = 'leakyrelu'
        base_conv2d_block = \
            functools.partial(Conv2dBlock,
                              kernel_size=kernel_size,
                              stride=2,
                              padding=padding,
                              weight_norm_type=weight_norm_type,
                              activation_norm_type=activation_norm_type,
                              # inplace_nonlinearity=True,
                              nonlinearity=nonlinearity)
        self.layer1 = base_conv2d_block(input_image_channels, num_filters)
        self.layer2 = base_conv2d_block(num_filters * 1, num_filters * 2)
        self.layer3 = base_conv2d_block(num_filters * 2, num_filters * 4)
        self.layer4 = base_conv2d_block(num_filters * 4, num_filters * 8)
        self.layer5 = base_conv2d_block(num_filters * 8, num_filters * 8)
        self.layer6 = base_conv2d_block(num_filters * 8, num_filters * 8)
        self.fc_mu = LinearBlock(num_filters * 8 * 4 * 4, style_dims)
        if not self.no_vae:
            self.fc_var = LinearBlock(num_filters * 8 * 4 * 4, style_dims)

    def forward(self, input_x):
        r"""SPADE Style Encoder forward.

        Args:
            input_x (N x 3 x H x W tensor): Input images.
        Returns:
            mu (N x C tensor): Mean vectors.
            logvar (N x C tensor): Log-variance vectors.
            z (N x C tensor): Style code vectors.
        """
        if input_x.size(2) != 256 or input_x.size(3) != 256:
            input_x = F.interpolate(input_x, size=(256, 256), mode='bilinear')
        x = self.layer1(input_x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.layer6(x)
        x = x.view(x.size(0), -1)
        mu = self.fc_mu(x)
        if not self.no_vae:
            logvar = self.fc_var(x)
            # Reparameterization trick: z = mu + eps * std, with eps ~ N(0, I).
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            z = eps.mul(std) + mu
        else:
            z = mu
            logvar = torch.zeros_like(mu)
        return mu, logvar, z
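
# Usage sketch (hypothetical config values; 'none' for weight_norm_type is an
# assumption). The six stride-2 blocks reduce a 256x256 input to 4x4, which is
# why fc_mu expects num_filters * 8 * 4 * 4 flattened features:
#     from types import SimpleNamespace
#     cfg = SimpleNamespace(input_image_channels=3, num_filters=64, kernel_size=3,
#                           style_dims=128, weight_norm_type='none')
#     enc = StyleEncoder(cfg)
#     mu, logvar, z = enc(torch.randn(2, 3, 256, 256))   # each (2, 128)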


class Base3DGenerator(nn.Module):
    r"""Minecraft 3D generator constructor.

    Args:
        gen_cfg (obj): Generator definition part of the yaml config file.
        data_cfg (obj): Data definition part of the yaml config file.
    """

    def __init__(self, gen_cfg, data_cfg):
        super(Base3DGenerator, self).__init__()
        print('Base3DGenerator initialization.')

        # ---------------------- Main Network ------------------------
        # Exclude some of the features from positional encoding
        self.pe_no_pe_feat_dim = getattr(gen_cfg, 'pe_no_pe_feat_dim', 0)

        # blk_feat passes through PE
        input_dim = (gen_cfg.blk_feat_dim - self.pe_no_pe_feat_dim) * (gen_cfg.pe_lvl_feat * 2) + self.pe_no_pe_feat_dim
        if gen_cfg.pe_incl_orig_feat:
            input_dim += (gen_cfg.blk_feat_dim - self.pe_no_pe_feat_dim)
        print('[Base3DGenerator] Expected input dimensions: ', input_dim)
        self.input_dim = input_dim
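        # Worked example (hypothetical values, not from a shipped config): with
        # blk_feat_dim=32, pe_no_pe_feat_dim=0, pe_lvl_feat=4 and
        # pe_incl_orig_feat=True, each feature channel expands to 4 sin/cos
        # pairs plus the raw value: input_dim = 32 * (4 * 2) + 32 = 288.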
        self.mlp_model_kwargs = gen_cfg.mlp_model_kwargs
        self.pe_lvl_localcoords = getattr(gen_cfg, 'pe_lvl_localcoords', 0)
        if self.pe_lvl_localcoords > 0:
            self.mlp_model_kwargs['poscode_dim'] = self.pe_lvl_localcoords * 2 * 3

        # Set pe_lvl_raydir=0 and pe_incl_orig_raydir=False to disable view direction input
        input_dim_viewdir = 3 * (gen_cfg.pe_lvl_raydir * 2)
        if gen_cfg.pe_incl_orig_raydir:
            input_dim_viewdir += 3
        print('[Base3DGenerator] Expected viewdir input dimensions: ', input_dim_viewdir)
        self.input_dim_viewdir = input_dim_viewdir

        self.pe_params = [gen_cfg.pe_lvl_feat, gen_cfg.pe_incl_orig_feat,
                          gen_cfg.pe_lvl_raydir, gen_cfg.pe_incl_orig_raydir]

        # Style input dimension
        style_dims = gen_cfg.style_dims
        self.style_dims = style_dims
        interm_style_dims = getattr(gen_cfg, 'interm_style_dims', style_dims)
        self.interm_style_dims = interm_style_dims

        # ---------------------- Style MLP --------------------------
        self.style_net = globals()[gen_cfg.stylenet_model](
            style_dims, interm_style_dims, **gen_cfg.stylenet_model_kwargs)

        # Number of output channels for MLP (before blending)
        final_feat_dim = getattr(gen_cfg, 'final_feat_dim', 16)
        self.final_feat_dim = final_feat_dim

        # ----------------------- Sky Network -------------------------
        sky_input_dim_base = 3
        # Dedicated sky network input dimensions
        sky_input_dim = sky_input_dim_base * (gen_cfg.pe_lvl_raydir_sky * 2)
        if gen_cfg.pe_incl_orig_raydir_sky:
            sky_input_dim += sky_input_dim_base
        print('[Base3DGenerator] Expected sky input dimensions: ', sky_input_dim)
        self.pe_params_sky = [gen_cfg.pe_lvl_raydir_sky, gen_cfg.pe_incl_orig_raydir_sky]
        self.sky_net = SKYMLP(sky_input_dim, style_dim=interm_style_dims, out_channels_c=final_feat_dim)

        # ----------------------- Style Encoder -------------------------
        style_enc_cfg = getattr(gen_cfg, 'style_enc', None)
        setattr(style_enc_cfg, 'input_image_channels', 3)
        setattr(style_enc_cfg, 'style_dims', gen_cfg.style_dims)
        self.style_encoder = StyleEncoder(style_enc_cfg)

        # ---------------------- Ray Caster -------------------------
        self.num_blocks_early_stop = gen_cfg.num_blocks_early_stop
        self.num_samples = gen_cfg.num_samples
        self.sample_depth = gen_cfg.sample_depth
        self.coarse_deterministic_sampling = getattr(gen_cfg, 'coarse_deterministic_sampling', True)
        self.sample_use_box_boundaries = getattr(gen_cfg, 'sample_use_box_boundaries', True)

        # ---------------------- Blender -------------------------
        self.raw_noise_std = getattr(gen_cfg, 'raw_noise_std', 0.0)
        self.dists_scale = getattr(gen_cfg, 'dists_scale', 0.25)
        self.clip_feat_map = getattr(gen_cfg, 'clip_feat_map', True)
        self.keep_sky_out = getattr(gen_cfg, 'keep_sky_out', False)
        self.keep_sky_out_avgpool = getattr(gen_cfg, 'keep_sky_out_avgpool', False)
        keep_sky_out_learnbg = getattr(gen_cfg, 'keep_sky_out_learnbg', False)
        self.sky_global_avgpool = getattr(gen_cfg, 'sky_global_avgpool', False)
        if self.keep_sky_out:
            self.sky_replace_color = None
            if keep_sky_out_learnbg:
                sky_replace_color = torch.zeros([final_feat_dim])
                sky_replace_color.requires_grad = True
                self.sky_replace_color = torch.nn.Parameter(sky_replace_color)

        # ---------------------- Render CNN -------------------------
        self.denoiser = RenderCNN(final_feat_dim, style_dim=interm_style_dims)

        self.pad = gen_cfg.pad
    def get_param_groups(self, cfg_opt):
        print('[Generator] get_param_groups')

        if hasattr(cfg_opt, 'ignore_parameters'):
            print('[Generator::get_param_groups] Parameters marked with [x] are ignored.')
            optimize_parameters = []
            for k, x in self.named_parameters():
                match = False
                for m in cfg_opt.ignore_parameters:
                    if re.match(m, k) is not None:
                        match = True
                        print('    [x]', k)
                        break
                if match is False:
                    print('    [v]', k)
                    optimize_parameters.append(x)
        else:
            optimize_parameters = self.parameters()

        param_groups = []
        param_groups.append({'params': optimize_parameters})

        if hasattr(cfg_opt, 'param_groups'):
            optimized_param_names = []
            all_param_names = [k for k, v in self.named_parameters()]
            param_groups = []
            for k, v in cfg_opt.param_groups.items():
                print('[Generator::get_param_groups] Adding param group from config:', k, v)
                params = getattr(self, k)
                named_parameters = [k]
                if issubclass(type(params), nn.Module):
                    named_parameters = [k + '.' + pname for pname, _ in params.named_parameters()]
                    params = params.parameters()
                param_groups.append({'params': params, **v})
                optimized_param_names.extend(named_parameters)
            print('[Generator::get_param_groups] UNOPTIMIZED PARAMETERS:\n    ',
                  set(all_param_names) - set(optimized_param_names))

        return param_groups
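
    # Usage sketch (hypothetical config values): exclude the style encoder from
    # optimization, or give the sky network its own learning rate. The config
    # attributes mirror the hasattr() checks above:
    #     from types import SimpleNamespace
    #     cfg_opt = SimpleNamespace(ignore_parameters=['style_encoder.*'],
    #                               param_groups={'sky_net': {'lr': 1e-4}})
    #     optimizer = torch.optim.Adam(generator.get_param_groups(cfg_opt))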
    def _forward_perpix_sub(self, blk_feats, worldcoord2, raydirs_in, z, mc_masks_onehot=None):
        r"""Forward the per-point MLP.

        Args:
            blk_feats (K x C1 tensor): Sparse block features.
            worldcoord2 (N x H x W x L x 3 tensor): 3D world coordinates of sampled points.
            raydirs_in (N x H x W x 1 x C2 tensor or None): Ray direction embeddings.
            z (N x C3 tensor): Intermediate style vectors.
            mc_masks_onehot (N x H x W x L x C4 tensor): One-hot segmentation maps.
        Returns:
            net_out_s (N x H x W x L x 1 tensor): Opacities.
            net_out_c (N x H x W x L x C5 tensor): Color embeddings.
        """
        proj_feature = voxlib.sparse_trilinear_interp_worldcoord(
            blk_feats, self.voxel.corner_t, worldcoord2, ign_zero=True)

        render_net_extra_kwargs = {}
        if self.pe_lvl_localcoords > 0:
            local_coords = torch.remainder(worldcoord2, 1.0) * 2.0
            # Scale to [0, 2], as the positional encoding function doesn't have internal x2
            local_coords[torch.isnan(local_coords)] = 0.0
            local_coords = local_coords.contiguous()
            poscode = voxlib.positional_encoding(local_coords, self.pe_lvl_localcoords, -1, False)
            render_net_extra_kwargs['poscode'] = poscode

        if self.pe_params[0] == 0 and self.pe_params[1] is True:  # No-PE shortcut, saves ~400MB
            feature_in = proj_feature
        else:
            if self.pe_no_pe_feat_dim > 0:
                feature_in = voxlib.positional_encoding(
                    proj_feature[..., :-self.pe_no_pe_feat_dim].contiguous(), self.pe_params[0], -1, self.pe_params[1])
                feature_in = torch.cat([feature_in, proj_feature[..., -self.pe_no_pe_feat_dim:]], dim=-1)
            else:
                feature_in = voxlib.positional_encoding(
                    proj_feature.contiguous(), self.pe_params[0], -1, self.pe_params[1])

        net_out_s, net_out_c = self.render_net(feature_in, raydirs_in, z, mc_masks_onehot, **render_net_extra_kwargs)

        if self.raw_noise_std > 0.:
            noise = torch.randn_like(net_out_s) * self.raw_noise_std
            net_out_s = net_out_s + noise

        return net_out_s, net_out_c
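
    # For reference, a NeRF-style encoding maps each input scalar x to
    # [sin(2^0 x), cos(2^0 x), ..., sin(2^{L-1} x), cos(2^{L-1} x)], optionally
    # appending x itself. The custom voxlib.positional_encoding CUDA kernel is
    # assumed to follow this scheme; its exact semantics (including channel
    # ordering) live in imaginaire/model_utils/gancraft/voxlib. A pure-PyTorch
    # sketch under that assumption:
    #     def positional_encoding_ref(x, n_lvl, incl_orig):
    #         feats = [f((2.0 ** i) * x) for i in range(n_lvl)
    #                  for f in (torch.sin, torch.cos)]
    #         if incl_orig:
    #             feats.append(x)
    #         return torch.cat(feats, dim=-1)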
    def _forward_perpix(self, blk_feats, voxel_id, depth2, raydirs, cam_ori_t, z):
        r"""Sample points along rays, forward the per-point MLP, and aggregate per-pixel features.

        Args:
            blk_feats (K x C1 tensor): Sparse block features.
            voxel_id (N x H x W x M x 1 tensor): Voxel ids from the ray-voxel intersection test. M: num intersected voxels.
            depth2 (N x 2 x H x W x M x 1 tensor): Depths of entrance and exit points for each ray-voxel intersection.
            raydirs (N x H x W x 1 x 3 tensor): The direction of each ray.
            cam_ori_t (N x 3 tensor): Camera origins.
            z (N x C3 tensor): Intermediate style vectors.
        """
        # Generate sky_mask; PE transform on ray direction.
        with torch.no_grad():
            raydirs_in = raydirs.expand(-1, -1, -1, 1, -1).contiguous()
            if self.pe_params[2] == 0 and self.pe_params[3] is True:
                raydirs_in = raydirs_in
            elif self.pe_params[2] == 0 and self.pe_params[3] is False:  # Not using raydir at all
                raydirs_in = None
            else:
                raydirs_in = voxlib.positional_encoding(raydirs_in, self.pe_params[2], -1, self.pe_params[3])

            # sky_mask: when True, the ray eventually hits the sky.
            sky_mask = voxel_id[:, :, :, [-1], :] == 0
            # sky_only_mask: when True, the ray hits nothing but the sky.
            sky_only_mask = voxel_id[:, :, :, [0], :] == 0

        with torch.no_grad():
            # Randomly sample points along the ray.
            num_samples = self.num_samples + 1
            if self.sample_use_box_boundaries:
                num_samples = self.num_samples - self.num_blocks_early_stop

            # 10 samples per ray + 4 intersections - 2
            rand_depth, new_dists, new_idx = mc_utils.sample_depth_batched(
                depth2, num_samples, deterministic=self.coarse_deterministic_sampling,
                use_box_boundaries=self.sample_use_box_boundaries, sample_depth=self.sample_depth)

            worldcoord2 = raydirs * rand_depth + cam_ori_t[:, None, None, None, :]

            # Generate per-sample segmentation label
            voxel_id_reduced = self.label_trans.mc2reduced(voxel_id, ign2dirt=True)
            mc_masks = torch.gather(voxel_id_reduced, -2, new_idx)  # B 256 256 N 1
            mc_masks = mc_masks.long()
            mc_masks_onehot = torch.zeros([mc_masks.size(0), mc_masks.size(1), mc_masks.size(
                2), mc_masks.size(3), self.num_reduced_labels], dtype=torch.float, device=voxel_id.device)
            # mc_masks_onehot: [B H W Nlayer 680]
            mc_masks_onehot.scatter_(-1, mc_masks, 1.0)

        net_out_s, net_out_c = self._forward_perpix_sub(blk_feats, worldcoord2, raydirs_in, z, mc_masks_onehot)

        # Handle sky
        sky_raydirs_in = raydirs.expand(-1, -1, -1, 1, -1).contiguous()
        sky_raydirs_in = voxlib.positional_encoding(sky_raydirs_in, self.pe_params_sky[0], -1, self.pe_params_sky[1])
        skynet_out_c = self.sky_net(sky_raydirs_in, z)

        # Blending
        weights = mc_utils.volum_rendering_relu(net_out_s, new_dists * self.dists_scale, dim=-2)

        # If a ray exclusively hits the sky (no intersection with the voxels), set its weight to zero.
        weights = weights * torch.logical_not(sky_only_mask).float()
        total_weights_raw = torch.sum(weights, dim=-2, keepdim=True)  # 256 256 1 1
        total_weights = total_weights_raw

        is_gnd = worldcoord2[..., [0]] <= 1.0  # Y X Z, [256, 256, 4, 3], nan < 1.0 == False
        is_gnd = is_gnd.any(dim=-2, keepdim=True)
        nosky_mask = torch.logical_or(torch.logical_not(sky_mask), is_gnd)
        nosky_mask = nosky_mask.float()

        # Avoid sky leakage
        sky_weight = 1.0 - total_weights
        if self.keep_sky_out:
            # keep_sky_out_avgpool overrides sky_replace_color
            if self.sky_replace_color is None or self.keep_sky_out_avgpool:
                if self.keep_sky_out_avgpool:
                    if hasattr(self, 'sky_avg'):
                        sky_avg = self.sky_avg
                    else:
                        if self.sky_global_avgpool:
                            sky_avg = torch.mean(skynet_out_c, dim=[1, 2], keepdim=True)
                        else:
                            skynet_out_c_nchw = skynet_out_c.permute(0, 4, 1, 2, 3).squeeze(-1)
                            sky_avg = F.avg_pool2d(skynet_out_c_nchw, 31, stride=1, padding=15, count_include_pad=False)
                            sky_avg = sky_avg.permute(0, 2, 3, 1).unsqueeze(-2)
                    # Replace non-sky regions of the sky feature with its (pooled) average.
                    skynet_out_c = skynet_out_c * (1.0 - nosky_mask) + sky_avg * nosky_mask
                else:
                    sky_weight = sky_weight * (1.0 - nosky_mask)
            else:
                skynet_out_c = skynet_out_c * (1.0 - nosky_mask) + self.sky_replace_color * nosky_mask

        if self.clip_feat_map is True:  # Clip the intermediate feature before blending & CNN
            rgbs = torch.clamp(net_out_c, -1, 1) + 1
            rgbs_sky = torch.clamp(skynet_out_c, -1, 1) + 1
            net_out = torch.sum(weights * rgbs, dim=-2, keepdim=True) + sky_weight * \
                rgbs_sky  # 576, 768, 4, 3 -> 576, 768, 3
            net_out = net_out.squeeze(-2)
            net_out = net_out - 1
        elif self.clip_feat_map is False:
            rgbs = net_out_c
            rgbs_sky = skynet_out_c
            net_out = torch.sum(weights * rgbs, dim=-2, keepdim=True) + sky_weight * \
                rgbs_sky  # 576, 768, 4, 3 -> 576, 768, 3
            net_out = net_out.squeeze(-2)
        elif self.clip_feat_map == 'tanh':
            rgbs = torch.tanh(net_out_c)
            rgbs_sky = torch.tanh(skynet_out_c)
            net_out = torch.sum(weights * rgbs, dim=-2, keepdim=True) + sky_weight * \
                rgbs_sky  # 576, 768, 4, 3 -> 576, 768, 3
            net_out = net_out.squeeze(-2)
        else:
            raise NotImplementedError

        return net_out, new_dists, weights, total_weights_raw, rand_depth, net_out_s, net_out_c, skynet_out_c, \
            nosky_mask, sky_mask, sky_only_mask, new_idx
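
    # mc_utils.volum_rendering_relu is assumed to follow the standard NeRF
    # quadrature; a minimal pure-PyTorch sketch of the weights used above
    # (sigma: ... x L x 1 densities, dists: matching sample spacings):
    #     alpha = 1.0 - torch.exp(-F.relu(sigma) * dists)
    #     ones = torch.ones_like(alpha[..., :1, :])
    #     trans = torch.cumprod(torch.cat([ones, 1.0 - alpha + 1e-10], dim=-2),
    #                           dim=-2)[..., :-1, :]   # transmittance T_i
    #     weights = alpha * trans                      # w_i = T_i * alpha_i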
    def _forward_global(self, net_out, z):
        r"""Forward the CNN.

        Args:
            net_out (N x H x W x C5 tensor): Intermediate feature maps (channel-last).
            z (N x C3 tensor): Intermediate style vectors.
        Returns:
            fake_images (N x 3 x H x W tensor): Output image.
            fake_images_raw (N x 3 x H x W tensor): Output image before tanh.
        """
        # Permute the channel-last blender output to NCHW for the CNN.
        fake_images = net_out.permute(0, 3, 1, 2)
        fake_images_raw = self.denoiser(fake_images, z)
        fake_images = torch.tanh(fake_images_raw)

        return fake_images, fake_images_raw
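
# Usage sketch (hypothetical shapes; `gen` is a concrete Base3DGenerator
# subclass instance):
#     net_out = torch.randn(1, 256, 256, 16)   # channel-last, C5 = final_feat_dim
#     z = torch.randn(1, gen.interm_style_dims)
#     fake, fake_raw = gen._forward_global(net_out, z)   # fake: NCHW, in [-1, 1]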