# Source: BerfScene / models/bev3d_generator.py (commit 2f85de4, "init").
# python3.8
"""Contains the implementation of generator described in BEV3D."""
import torch
import torch.nn as nn
from models.utils.official_stylegan2_model_helper import Generator as StyleGAN2Backbone
from models.utils.official_stylegan2_model_helper import FullyConnectedLayer
from models.utils.eg3d_superres import SuperresolutionHybrid2X
from models.utils.eg3d_superres import SuperresolutionHybrid4X
from models.utils.eg3d_superres import SuperresolutionHybrid4X_conststyle
from models.utils.eg3d_superres import SuperresolutionHybrid8XDC
from models.rendering.renderer import Renderer
from models.rendering.feature_extractor import FeatureExtractor
from models.utils.spade import SPADEGenerator
class BEV3DGenerator(nn.Module):
    """Generator described in BEV3D.

    Data flow, as wired in this class:

    1. `self.backbone` (a `SPADEGenerator`) is called with the latent code
       `z` and the layout `seg` and returns the reference representation
       (`xy_planes`).
    2. `self.renderer` consumes that representation together with
       `self.feature_extractor` and `self.fc_head` and returns a dict that
       contains `'composite_rgb'` and `'composite_depth'`.
    3. The first three feature channels are taken as the raw RGB image, and
       the final image is produced either by plain bilinear interpolation
       (`interpolate_sr=True`) or by the super-resolution network.
    """

    def __init__(
        self,
        z_dim,
        semantic_nc,
        ngf,
        bev_grid_size,
        aspect_ratio,
        num_upsampling_layers,
        not_use_vae,
        norm_G,
        img_resolution,
        interpolate_sr,
        segmask=False,
        dim_seq='16,8,4,2,1',
        xyz_pe=False,
        hidden_dim=64,
        additional_layer_num=0,
        sr_num_fp16_res=0,  # Number of fp16 layers of SR Network.
        rendering_kwargs=None,  # Arguments for rendering.
        sr_kwargs=None,  # Arguments for SuperResolution Network.
    ):
        """Builds the backbone, renderer, decoder head and SR network.

        Args:
            z_dim: Latent dimension, forwarded to `SPADEGenerator`.
            semantic_nc, ngf, bev_grid_size, aspect_ratio,
            num_upsampling_layers, not_use_vae, norm_G, dim_seq: Forwarded
                unchanged to `SPADEGenerator`.
            img_resolution: Final image resolution. Must be 128, 256 or 512;
                it selects the super-resolution network.
            interpolate_sr: If true, `synthesis()` upsamples the raw RGB
                image bilinearly instead of running the SR network.
            segmask: If true, `synthesis()` multiplies the backbone output
                by the first channel of `seg`.
            xyz_pe: Forwarded to `FeatureExtractor`; also switches the
                decoder input width from 64 to 128.
            hidden_dim, additional_layer_num: Forwarded to `OSGDecoder`.
            sr_num_fp16_res: Forwarded to the SR network.
            rendering_kwargs: Rendering options. The key `'sr_antialias'`
                is required here, and `'superresolution_noise_mode'` is
                required by `synthesis()` when the SR network is used.
                `'decoder_lr_mul'` (default 1), `'resolution'` (default 64)
                and `'random_pose'` (default False) are optional.
            sr_kwargs: Extra keyword arguments overriding the SR network
                defaults assembled below.

        Raises:
            TypeError: If `img_resolution` is not 128, 256 or 512.
            KeyError: If `rendering_kwargs` lacks `'sr_antialias'`.
        """
        super().__init__()

        # `None` defaults avoid sharing one dict object across instances.
        if rendering_kwargs is None:
            rendering_kwargs = {}
        if sr_kwargs is None:
            sr_kwargs = {}

        self.z_dim = z_dim
        self.img_resolution = img_resolution
        self.interpolate_sr = interpolate_sr
        self.segmask = segmask

        # Set up the overall renderer.
        self.renderer = Renderer()

        # Set up the feature extractor.
        self.feature_extractor = FeatureExtractor(ref_mode='bev_plane_clevr',
                                                  xyz_pe=xyz_pe)

        # Set up the reference representation generator.
        self.backbone = SPADEGenerator(
            z_dim=z_dim,
            semantic_nc=semantic_nc,
            ngf=ngf,
            dim_seq=dim_seq,
            bev_grid_size=bev_grid_size,
            aspect_ratio=aspect_ratio,
            num_upsampling_layers=num_upsampling_layers,
            not_use_vae=not_use_vae,
            norm_G=norm_G)
        print('backbone SPADEGenerator set up!')

        # No post module is used in the feature extractor.
        self.post_module = None

        # Set up the post neural renderer (super-resolution network).
        sr_kwargs_total = dict(
            channels=32,
            img_resolution=img_resolution,
            sr_num_fp16_res=sr_num_fp16_res,
            sr_antialias=rendering_kwargs['sr_antialias'],
        )
        sr_kwargs_total.update(**sr_kwargs)
        if img_resolution == 128:
            self.post_neural_renderer = SuperresolutionHybrid2X(
                **sr_kwargs_total)
        elif img_resolution == 256:
            self.post_neural_renderer = SuperresolutionHybrid4X_conststyle(
                **sr_kwargs_total)
        elif img_resolution == 512:
            self.post_neural_renderer = SuperresolutionHybrid8XDC(
                **sr_kwargs_total)
        else:
            raise TypeError(f'Unsupported image resolution: {img_resolution}!')

        # Set up the fully-connected layer head.
        self.fc_head = OSGDecoder(
            128 if xyz_pe else 64,
            {
                'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1),
                'decoder_output_dim': 32,
            },
            hidden_dim=hidden_dim,
            additional_layer_num=additional_layer_num,
        )

        # Set up some rendering related arguments.
        self.neural_rendering_resolution = rendering_kwargs.get(
            'resolution', 64)
        self.rendering_kwargs = rendering_kwargs

    def synthesis(self,
                  z,
                  c,
                  seg,
                  neural_rendering_resolution=None,
                  update_emas=False,
                  **synthesis_kwargs):
        """Renders a batch of images.

        Args:
            z: Latent codes; passed to the backbone and, as `wp`, to the
                renderer. Its first dimension is used as the batch size.
            c: Conditioning tensor. The first 16 entries of each row are
                reshaped into a 4x4 cam2world matrix, unless the rendering
                option `'random_pose'` is set, in which case `None` is
                passed to the renderer instead.
            seg: Layout fed to the backbone as `input`.
            neural_rendering_resolution: Side length of the raw rendered
                image. When given, it also overwrites
                `self.neural_rendering_resolution` for subsequent calls.
            update_emas: Accepted for interface compatibility; unused.
            **synthesis_kwargs: Forwarded to the SR network, except
                `'noise_mode'`, which is taken from the rendering options.

        Returns:
            A dict with `'image'` (final image), `'image_raw'` (first three
            channels of the rendered feature image) and `'image_depth'`.
        """
        if self.rendering_kwargs.get('random_pose', False):
            cam2world_matrix = None
        else:
            cam2world_matrix = c[:, :16].view(-1, 4, 4)

        if neural_rendering_resolution is None:
            neural_rendering_resolution = self.neural_rendering_resolution
        else:
            self.neural_rendering_resolution = neural_rendering_resolution

        xy_planes = self.backbone(z=z, input=seg)
        if self.segmask:
            # Mask the planes with the first channel of `seg`, broadcast
            # over the plane channels.
            xy_planes = xy_planes * seg[:, 0, ...][:, None, ...]

        wp = z  # in our case, we do not use wp.
        rendering_result = self.renderer(
            wp=wp,
            feature_extractor=self.feature_extractor,
            rendering_options=self.rendering_kwargs,
            cam2world_matrix=cam2world_matrix,
            position_encoder=None,
            ref_representation=xy_planes,
            post_module=self.post_module,
            fc_head=self.fc_head)
        feature_samples = rendering_result['composite_rgb']
        depth_samples = rendering_result['composite_depth']

        # Reshape to keep consistent with 'raw' neural-rendered image.
        N = wp.shape[0]
        H = W = self.neural_rendering_resolution
        feature_image = feature_samples.permute(0, 2, 1).reshape(
            N, feature_samples.shape[-1], H, W).contiguous()
        depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)

        # Run the post neural renderer to get final image.
        # Here, the post neural renderer is a super-resolution network.
        rgb_image = feature_image[:, :3]
        if self.interpolate_sr:
            # Upsample to the configured output resolution rather than a
            # hard-coded 256. The fallback keeps instances restored without
            # the `img_resolution` attribute on the previous behaviour.
            out_res = getattr(self, 'img_resolution', 256)
            sr_image = torch.nn.functional.interpolate(
                rgb_image,
                size=(out_res, out_res),
                mode='bilinear',
                align_corners=False)
        else:
            extra_kwargs = {
                key: value
                for key, value in synthesis_kwargs.items()
                if key != 'noise_mode'
            }
            sr_image = self.post_neural_renderer(
                rgb_image,
                feature_image,
                noise_mode=self.rendering_kwargs['superresolution_noise_mode'],
                **extra_kwargs)

        return {
            'image': sr_image,
            'image_raw': rgb_image,
            'image_depth': depth_image
        }

    def _query_sigma_rgb(self, coordinates, directions, z, c, seg):
        """Queries the renderer at explicit 3D points (shared by samplers)."""
        cam2world_matrix = c[:, :16].view(-1, 4, 4)
        xy_planes = self.backbone(z=z, input=seg)
        return self.renderer.get_sigma_rgb(
            wp=z,
            points=coordinates,
            feature_extractor=self.feature_extractor,
            fc_head=self.fc_head,
            rendering_options=self.rendering_kwargs,
            ref_representation=xy_planes,
            post_module=self.post_module,
            ray_dirs=directions,
            cam_matrix=cam2world_matrix)

    def sample(self,
               coordinates,
               directions,
               z,
               c,
               seg,
               truncation_psi=1,
               truncation_cutoff=None,
               update_emas=False,
               **synthesis_kwargs):
        """Computes RGB features and density for arbitrary 3D coordinates.

        Mostly used for extracting shapes. `truncation_psi`,
        `truncation_cutoff`, `update_emas` and `**synthesis_kwargs` are
        accepted for interface compatibility and are not used.

        NOTE(review): unlike `synthesis()`, the backbone output is not
        multiplied by the `seg` mask here even when `self.segmask` is set;
        confirm whether that difference is intended.

        Returns:
            Whatever `self.renderer.get_sigma_rgb()` returns.
        """
        return self._query_sigma_rgb(coordinates, directions, z, c, seg)

    def sample_mixed(self,
                     coordinates,
                     directions,
                     z, c, seg,
                     truncation_psi=1,
                     truncation_cutoff=None,
                     update_emas=False,
                     **synthesis_kwargs):
        """Same computation as `sample()`.

        This generator has no separate `wp` latent space (`z` is used
        directly), so this method performs exactly the same query as
        `sample()`; it is kept as a distinct entry point for callers.
        """
        return self._query_sigma_rgb(coordinates, directions, z, c, seg)

    def forward(self,
                z,
                c,
                seg,
                c_swapped=None,  # `c_swapped` is swapped pose conditioning.
                style_mixing_prob=0,
                truncation_psi=1,
                truncation_cutoff=None,
                neural_rendering_resolution=None,
                update_emas=False,
                sample_mixed=False,
                coordinates=None,
                **synthesis_kwargs):
        """Renders images, or samples densities when `sample_mixed` is set.

        Args:
            z, c, seg: See `synthesis()`.
            c_swapped, style_mixing_prob, truncation_psi, truncation_cutoff:
                Accepted for interface compatibility; not used.
            neural_rendering_resolution, update_emas, **synthesis_kwargs:
                Forwarded to `synthesis()` when rendering.
            sample_mixed: If true, skip rendering and return densities at
                `coordinates` (used for density regularization in training).
            coordinates: Query points; required when `sample_mixed` is true.

        Returns:
            `{'wp': z, 'gen_output': ...}` when rendering, otherwise
            `{'wp': z, 'sample_sigma': ...}`.
        """
        if not sample_mixed:
            # Render a batch of generated images.
            gen_output = self.synthesis(
                z,
                c,
                seg,
                update_emas=update_emas,
                neural_rendering_resolution=neural_rendering_resolution,
                **synthesis_kwargs)
            return {
                'wp': z,
                'gen_output': gen_output,
            }

        # Only for density regularization in training process.
        assert coordinates is not None, (
            '`coordinates` is required when `sample_mixed` is True.')
        # Ray directions are drawn at random for this query.
        sample_sigma = self.sample_mixed(coordinates,
                                         torch.randn_like(coordinates),
                                         z, c, seg,
                                         update_emas=False)['sigma']
        return {
            'wp': z,
            'sample_sigma': sample_sigma
        }
class OSGDecoder(nn.Module):
    """Fully-connected layer head (as in EG3D).

    Maps per-point features to one density value plus
    `options['decoder_output_dim']` colour/feature channels.
    """

    def __init__(self, n_features, options, hidden_dim=64, additional_layer_num=0):
        """Builds the MLP.

        Args:
            n_features: Width of the input point features.
            options: Dict providing `'decoder_lr_mul'` (learning-rate
                multiplier for every layer) and `'decoder_output_dim'`.
            hidden_dim: Width of every hidden layer.
            additional_layer_num: Number of extra hidden layers inserted
                between the input layer and the output layer.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        lr_mul = options['decoder_lr_mul']

        # Layer widths up to (and including) the last hidden layer: one
        # input projection followed by the requested extra hidden layers.
        extra = max(len(range(additional_layer_num)), 0)
        widths = [n_features] + [self.hidden_dim] * (extra + 1)

        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.append(
                FullyConnectedLayer(fan_in, fan_out, lr_multiplier=lr_mul))
            modules.append(nn.Softplus())
        # Output layer: channel 0 is density, the remainder is colour.
        modules.append(
            FullyConnectedLayer(self.hidden_dim,
                                1 + options['decoder_output_dim'],
                                lr_multiplier=lr_mul))
        self.net = nn.Sequential(*modules)

    def forward(self, point_features, wp=None, dirs=None):
        """Decodes point features into colour and density.

        Args:
            point_features: 4-D tensor; the leading dimension is kept as the
                batch dimension and the trailing one is the feature width.
            wp: Unused.
            dirs: Unused.

        Returns:
            Dict with `'rgb'` (all output channels but the first, squashed
            by a slightly widened sigmoid) and `'sigma'` (the first output
            channel, left unbounded).
        """
        # The unpacking also enforces that the input is 4-dimensional.
        batch, _, _, channels = point_features.shape

        flat_features = point_features.reshape(-1, channels)
        decoded = self.net(flat_features)
        decoded = decoded.view(batch, -1, decoded.shape[-1])

        sigma = decoded[..., 0:1]
        # Uses sigmoid clamping from MipNeRF.
        rgb = torch.sigmoid(decoded[..., 1:]) * (1 + 2 * 0.001) - 0.001
        return {'rgb': rgb, 'sigma': sigma}