import torch

from torch_utils import persistence
from training.networks_stylegan2 import Generator as StyleGAN2Backbone
from training.volumetric_rendering.renderer import ImportanceRenderer
from training.volumetric_rendering.ray_sampler import RaySampler

import dnnlib

@persistence.persistent_class
class TriPlaneGenerator(torch.nn.Module):
    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality.
        c_dim,                      # Conditioning label (C) dimensionality.
        w_dim,                      # Intermediate latent (W) dimensionality.
        img_resolution,             # Output resolution of the super-resolved image.
        sr_num_fp16_res     = 0,    # Number of fp16 layers in the superresolution module.
        mapping_kwargs      = {},   # Arguments for MappingNetwork.
        rendering_kwargs    = {},   # Arguments for the volume renderer and superresolution module.
        sr_kwargs           = {},   # Arguments for the superresolution module.
        **synthesis_kwargs,         # Arguments for SynthesisNetwork.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.renderer = ImportanceRenderer()
        self.ray_sampler = RaySampler()
        # The backbone is a standard StyleGAN2 generator; its 32*3 output
        # channels are later split into three 32-channel feature planes.
        self.backbone = StyleGAN2Backbone(z_dim, c_dim, w_dim, img_resolution=256, img_channels=32*3, mapping_kwargs=mapping_kwargs, **synthesis_kwargs)
        self.superresolution = dnnlib.util.construct_class_by_name(class_name=rendering_kwargs['superresolution_module'], channels=32, img_resolution=img_resolution, sr_num_fp16_res=sr_num_fp16_res, sr_antialias=rendering_kwargs['sr_antialias'], **sr_kwargs)
        # The decoder maps 32 triplane features to 1 density value plus 32
        # colour-feature channels, matching the superresolution input channels.
        self.decoder = OSGDecoder(32, {'decoder_output_dim': 32})
        self.neural_rendering_resolution = 64
        self.rendering_kwargs = rendering_kwargs
        self._last_planes = None

    def mapping(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False):
        # Map noise z and conditioning c to intermediate latent codes ws.
        return self.backbone.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)

    def synthesis(self, ws, c, neural_rendering_resolution=None, update_emas=False, cache_backbone=False, use_cached_backbone=False, **synthesis_kwargs):
        # The conditioning vector c packs a flattened 4x4 cam2world matrix
        # followed by a flattened 3x3 camera intrinsics matrix.
        cam2world_matrix = c[:, :16].view(-1, 4, 4)
        intrinsics = c[:, 16:25].view(-1, 3, 3)

        if neural_rendering_resolution is None:
            neural_rendering_resolution = self.neural_rendering_resolution
        else:
            self.neural_rendering_resolution = neural_rendering_resolution

        # Create a batch of rays for volume rendering.
        ray_origins, ray_directions = self.ray_sampler(cam2world_matrix, intrinsics, neural_rendering_resolution)

        # Create triplanes by running the StyleGAN backbone, optionally reusing
        # planes cached by a previous call.
        N, M, _ = ray_origins.shape
        if use_cached_backbone and self._last_planes is not None:
            planes = self._last_planes
        else:
            planes = self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
        if cache_backbone:
            self._last_planes = planes

        # Reshape the backbone output into three 32-channel planes.
        planes = planes.view(len(planes), 3, 32, planes.shape[-2], planes.shape[-1])

        # Perform volume rendering; samples come back channels-last.
        feature_samples, depth_samples, weights_samples = self.renderer(planes, self.decoder, ray_origins, ray_directions, self.rendering_kwargs)

        # Reshape samples into 'raw' neural-rendered images.
        H = W = self.neural_rendering_resolution
        feature_image = feature_samples.permute(0, 2, 1).reshape(N, feature_samples.shape[-1], H, W).contiguous()
        depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)

        # Run superresolution to get the final image; the first three feature
        # channels double as the raw RGB image.
        rgb_image = feature_image[:, :3]
        sr_image = self.superresolution(rgb_image, feature_image, ws, noise_mode=self.rendering_kwargs['superresolution_noise_mode'], **{k: synthesis_kwargs[k] for k in synthesis_kwargs.keys() if k != 'noise_mode'})

        return {'image': sr_image, 'image_raw': rgb_image, 'image_depth': depth_image}
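        # Shape sketch (illustrative): with the default 64x64 neural rendering
        # resolution, 'image' is [N, 3, img_resolution, img_resolution],
        # 'image_raw' is [N, 3, 64, 64], and 'image_depth' is [N, 1, 64, 64].
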
    def sample(self, coordinates, directions, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
        # Compute RGB features and density for arbitrary 3D coordinates.
        # Mostly used for extracting shapes.
        ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
        planes = self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
        planes = planes.view(len(planes), 3, 32, planes.shape[-2], planes.shape[-1])
        return self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs)

    def sample_mixed(self, coordinates, directions, ws, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
        # Same as sample(), but expects intermediate latents ws instead of noise z.
        planes = self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
        planes = planes.view(len(planes), 3, 32, planes.shape[-2], planes.shape[-1])
        return self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs)

    def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, neural_rendering_resolution=None, update_emas=False, cache_backbone=False, use_cached_backbone=False, **synthesis_kwargs):
        # Render a batch of generated images.
        ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
        return self.synthesis(ws, c, update_emas=update_emas, neural_rendering_resolution=neural_rendering_resolution, cache_backbone=cache_backbone, use_cached_backbone=use_cached_backbone, **synthesis_kwargs)

from training.networks_stylegan2 import FullyConnectedLayer

class OSGDecoder(torch.nn.Module):
    def __init__(self, n_features, options):
        super().__init__()
        self.hidden_dim = 64

        # A small MLP mapping aggregated triplane features to one density
        # channel plus options['decoder_output_dim'] colour-feature channels.
        self.net = torch.nn.Sequential(
            FullyConnectedLayer(n_features, self.hidden_dim),
            torch.nn.Softplus(),
            FullyConnectedLayer(self.hidden_dim, 1 + options['decoder_output_dim'])
        )

    def forward(self, sampled_features, ray_directions=None):
        # Aggregate the features sampled from the three planes by averaging.
        sampled_features = sampled_features.mean(1)
        x = sampled_features

        N, M, C = x.shape
        x = x.view(N*M, C)

        x = self.net(x)
        x = x.view(N, M, -1)
        # Sigmoid clamping of the colour features, as in MipNeRF.
        rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001
        sigma = x[..., 0:1]
        return {'rgb': rgb, 'sigma': sigma}
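
# Minimal decoder smoke test (a sketch; values are illustrative):
#   dec = OSGDecoder(32, {'decoder_output_dim': 32})
#   feats = torch.rand([2, 3, 1024, 32])   # [N, planes, samples, channels]
#   out = dec(feats)                       # {'rgb': [2, 1024, 32], 'sigma': [2, 1024, 1]}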
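
# Usage sketch (hedged): `rendering_kwargs` must supply at least the keys this
# module reads (superresolution module name, noise mode, antialias flag); the
# class path and values below are illustrative, taken from typical configs.
#
#   rendering_kwargs = {
#       'superresolution_module': 'training.superresolution.SuperresolutionHybrid8XDC',
#       'superresolution_noise_mode': 'none',
#       'sr_antialias': True,
#       ...
#   }
#   G = TriPlaneGenerator(z_dim=512, c_dim=25, w_dim=512, img_resolution=512,
#                         rendering_kwargs=rendering_kwargs)
#   z = torch.randn([1, G.z_dim])
#   c = torch.cat([cam2world.view(1, 16), intrinsics.view(1, 9)], dim=1)  # camera params
#   out = G(z, c)   # {'image', 'image_raw', 'image_depth'}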