|
|
|
"""Contains the implementation of generator described in PiGAN.""" |
|
|
|
import math |
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.cuda.amp import autocast |
|
|
|
from .utils.ops import all_gather |
|
from .rendering.renderer import Renderer |
|
from .rendering.feature_extractor import FeatureExtractor |
|
|
|
__all__ = ['PiGANGenerator'] |
|
|
|
|
|
class PiGANGenerator(nn.Module): |
|
"""Defines the generator network in PiGAN.""" |
|
def __init__(self, |
|
|
|
z_dim=256, |
|
w_dim=256, |
|
repeat_w=False, |
|
normalize_z=False, |
|
mapping_layers=3, |
|
mapping_hidden_dim=256, |
|
|
|
label_dim=0, |
|
embedding_dim=512, |
|
normalize_embedding=True, |
|
normalize_embedding_latent=False, |
|
label_concat=True, |
|
|
|
resolution=-1, |
|
synthesis_input_dim=3, |
|
synthesis_output_dim=256, |
|
synthesis_layers=8, |
|
grid_scale=0.24, |
|
eps=1e-8, |
|
|
|
                 rendering_kwargs=None):
|
"""Initializes with basic settings.""" |
|
super().__init__() |
|
|
|
self.z_dim = z_dim |
|
self.w_dim = w_dim |
|
self.repeat_w = repeat_w |
|
self.normalize_z = normalize_z |
|
self.mapping_layers = mapping_layers |
|
|
|
self.latent_dim = (z_dim,) |
|
self.label_dim = label_dim |
|
self.embedding_dim = embedding_dim |
|
self.normalize_embedding = normalize_embedding |
|
self.normalize_embedding_latent = normalize_embedding_latent |
|
|
|
self.resolution = resolution |
|
self.num_layers = synthesis_layers |
|
self.eps = eps |
|
|
|
if self.repeat_w: |
|
self.mapping_space_dim = self.w_dim |
|
else: |
|
self.mapping_space_dim = self.w_dim * (self.num_layers + 1) |
|
|
|
|
|
self.mapping = MappingNetwork( |
|
input_dim=z_dim, |
|
output_dim=w_dim, |
|
num_outputs=synthesis_layers + 1, |
|
repeat_output=repeat_w, |
|
normalize_input=normalize_z, |
|
num_layers=mapping_layers, |
|
hidden_dim=mapping_hidden_dim, |
|
label_dim=label_dim, |
|
embedding_dim=embedding_dim, |
|
normalize_embedding=normalize_embedding, |
|
normalize_embedding_latent=normalize_embedding_latent, |
|
eps=eps, |
|
label_concat=label_concat, |
|
lr=None) |
|
|
|
|
|
self.renderer = Renderer() |
|
|
|
|
|
self.ref_representation_generator = None |
|
|
|
|
|
self.feature_extractor = FeatureExtractor(ref_mode='none') |
|
|
|
|
|
self.post_module = MLPNetwork(w_dim=w_dim, |
|
in_channels=synthesis_input_dim, |
|
num_layers=synthesis_layers, |
|
out_channels=synthesis_output_dim, |
|
grid_scale=grid_scale) |
|
|
|
|
|
self.fc_head = FCHead(w_dim=w_dim, |
|
channels=synthesis_output_dim, |
|
mlp_length=self.post_module.mlp_length) |
|
|
|
|
|
self.post_neural_renderer = None |
|
|
|
|
|
if self.repeat_w: |
|
self.register_buffer('w_avg', torch.zeros(w_dim)) |
|
else: |
|
            # One `w` per synthesis layer plus one extra for the FC head,
            # matching `mapping_space_dim` above.
            self.register_buffer('w_avg',
                                 torch.zeros((self.num_layers + 1) * w_dim))
|
|
|
|
|
        self.rendering_kwargs = rendering_kwargs or dict()
|
|
|
|
|
self.init_weights() |
|
|
|
def init_weights(self): |
|
self.mapping.init_weights() |
|
self.post_module.init_weights() |
|
self.fc_head.init_weights() |
|
|
|
def forward(self, |
|
z, |
|
label=None, |
|
lod=None, |
|
w_moving_decay=None, |
|
sync_w_avg=False, |
|
style_mixing_prob=None, |
|
noise_std=None, |
|
trunc_psi=None, |
|
trunc_layers=None, |
|
enable_amp=False): |
|
if noise_std is not None: |
|
self.rendering_kwargs.update(noise_std=noise_std) |
|
|
|
lod = self.post_module.lod.cpu().tolist() if lod is None else lod |
|
|
|
mapping_results = self.mapping(z, label) |
|
w = mapping_results['w'] |
|
wp = mapping_results.pop('wp') |
|
|
|
if self.training and w_moving_decay is not None: |
|
if sync_w_avg: |
|
batch_w_avg = all_gather(w.detach()).mean(dim=0) |
|
else: |
|
batch_w_avg = w.detach().mean(dim=0) |
|
self.w_avg.copy_(batch_w_avg.lerp(self.w_avg, w_moving_decay)) |
|
|
|
|
|
if not self.training: |
|
trunc_psi = 1.0 if trunc_psi is None else trunc_psi |
|
trunc_layers = 0 if trunc_layers is None else trunc_layers |
|
if trunc_psi < 1.0 and trunc_layers > 0: |
|
w_avg = self.w_avg.reshape(1, -1, self.w_dim)[:, :trunc_layers] |
|
wp[:, :trunc_layers] = w_avg.lerp( |
|
wp[:, :trunc_layers], trunc_psi) |
|
|
|
with autocast(enabled=enable_amp): |
|
rendering_result = self.renderer( |
|
wp=wp, |
|
feature_extractor=self.feature_extractor, |
|
rendering_options=self.rendering_kwargs, |
|
position_encoder=None, |
|
ref_representation=None, |
|
post_module=self.post_module, |
|
post_module_kwargs=dict(lod=lod), |
|
fc_head=self.fc_head) |
|
|
|
image = rendering_result['composite_rgb'].reshape( |
|
z.shape[0], self.resolution, self.resolution, |
|
-1).permute(0, 3, 1, 2) |
|
|
|
camera = torch.cat([ |
|
rendering_result['camera_polar'], |
|
rendering_result['camera_azimuthal'] |
|
], -1) |
|
|
|
return { |
|
**mapping_results, |
|
'image': image, |
|
'camera': camera, |
|
'latent': z |
|
} |
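# NOTE: a rough sketch of the data flow in `PiGANGenerator.forward()` (the
# exact point sampling and compositing are handled by `Renderer`, which is
# configured via `rendering_kwargs` outside this file):
#
#   z [N, z_dim] --mapping--> wp [N, synthesis_layers + 1, w_dim]
#   wp --renderer(feature_extractor, post_module, fc_head)--> composite_rgb
#   composite_rgb --reshape/permute--> image [N, C, resolution, resolution]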
|
|
|
|
|
class MappingNetwork(nn.Module): |
|
"""Implements the latent space mapping module. |
|
|
|
    Basically, this module executes several dense layers in sequence, and
    applies the label embedding if needed.
|
""" |
|
|
|
def __init__(self, |
|
input_dim, |
|
output_dim, |
|
num_outputs, |
|
repeat_output, |
|
normalize_input, |
|
num_layers, |
|
hidden_dim, |
|
label_dim, |
|
embedding_dim, |
|
normalize_embedding, |
|
normalize_embedding_latent, |
|
eps, |
|
label_concat, |
|
lr=None): |
|
super().__init__() |
|
|
|
self.input_dim = input_dim |
|
self.output_dim = output_dim |
|
self.num_outputs = num_outputs |
|
self.repeat_output = repeat_output |
|
self.normalize_input = normalize_input |
|
self.num_layers = num_layers |
|
|
|
|
|
|
|
|
|
self.label_dim = label_dim |
|
self.embedding_dim = embedding_dim |
|
self.normalize_embedding = normalize_embedding |
|
self.normalize_embedding_latent = normalize_embedding_latent |
|
self.eps = eps |
|
self.label_concat = label_concat |
|
|
|
self.norm = PixelNormLayer(dim=1, eps=eps) |
|
|
|
if num_outputs is not None and not repeat_output: |
|
output_dim = output_dim * num_outputs |
|
|
|
if self.label_dim > 0: |
|
if self.label_concat: |
|
input_dim = input_dim + embedding_dim |
|
self.embedding = EqualLinear(label_dim, |
|
embedding_dim, |
|
bias=True, |
|
bias_init=0, |
|
lr_mul=1) |
|
else: |
|
self.embedding = EqualLinear(label_dim, |
|
output_dim, |
|
bias=True, |
|
bias_init=0, |
|
lr_mul=1) |
|
|
|
network = [] |
|
for i in range(num_layers): |
|
in_channels = (input_dim if i == 0 else hidden_dim) |
|
out_channels = (output_dim if i == (num_layers - 1) else hidden_dim) |
|
network.append(nn.Linear(in_channels, out_channels)) |
|
network.append(nn.LeakyReLU(0.2, inplace=True)) |
|
self.network = nn.Sequential(*network) |
|
|
|
def init_weights(self): |
|
for module in self.network.modules(): |
|
if isinstance(module, nn.Linear): |
|
nn.init.kaiming_normal_(module.weight, |
|
a=0.2, |
|
mode='fan_in', |
|
nonlinearity='leaky_relu') |
|
|
|
def forward(self, z, label=None): |
|
if z.ndim != 2 or z.shape[1] != self.input_dim: |
|
            raise ValueError(f'Input latent code should have shape '
                             f'[batch_size, input_dim], where `input_dim` '
                             f'equals {self.input_dim}!\n'
                             f'But `{z.shape}` is received!')
|
if self.normalize_input: |
|
z = self.norm(z) |
|
if self.label_dim > 0: |
|
if label is None: |
|
raise ValueError(f'Model requires an additional label ' |
|
f'(with dimension {self.label_dim}) as input, ' |
|
f'but no label is received!') |
|
if label.ndim != 2 or label.shape != (z.shape[0], self.label_dim): |
|
                raise ValueError(f'Input label should have shape '
                                 f'[batch_size, label_dim], where '
                                 f'`batch_size` equals that of the '
                                 f'latent codes ({z.shape[0]}) and '
                                 f'`label_dim` equals {self.label_dim}!\n'
                                 f'But `{label.shape}` is received!')
|
label = label.to(dtype=torch.float32) |
|
|
|
embedding = self.embedding(label) |
|
if self.normalize_embedding and self.label_concat: |
|
embedding = self.norm(embedding) |
|
if self.label_concat: |
|
w = torch.cat((z, embedding), dim=1) |
|
else: |
|
w = z |
|
else: |
|
w = z |
|
|
|
if (self.label_dim > 0 and self.normalize_embedding_latent |
|
and self.label_concat): |
|
w = self.norm(w) |
|
|
|
for layer in self.network: |
|
w = layer(w) |
|
|
|
if self.label_dim > 0 and (not self.label_concat): |
|
w = w * embedding |
|
|
|
wp = None |
|
if self.num_outputs is not None: |
|
if self.repeat_output: |
|
wp = w.unsqueeze(1).repeat((1, self.num_outputs, 1)) |
|
else: |
|
wp = w.reshape(-1, self.num_outputs, self.output_dim) |
|
|
|
results = { |
|
'z': z, |
|
'label': label, |
|
'w': w, |
|
'wp': wp, |
|
} |
|
if self.label_dim > 0: |
|
results['embedding'] = embedding |
|
return results |
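# NOTE: a usage sketch (shapes assume the generator defaults above): feeding a
# latent batch `z` of shape [N, input_dim] returns a dict containing `z`,
# `label`, `w`, and `wp`, where `wp` has shape [N, num_outputs, output_dim]
# (repeated copies of `w` when `repeat_output` is set, otherwise a reshaped
# per-layer code).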
|
|
|
|
|
class MLPNetwork(nn.Module): |
|
"""Defines MLP Network in Pi-GAN.""" |
|
def __init__(self, |
|
w_dim, |
|
in_channels, |
|
num_layers, |
|
out_channels, |
|
grid_scale=0.24): |
|
super().__init__() |
|
|
|
self.in_channels = in_channels |
|
self.w_dim = w_dim |
|
self.out_channels = out_channels |
|
|
|
self.register_buffer('lod', torch.zeros(())) |
|
|
|
self.grid_warper = UniformBoxWarp(grid_scale) |
|
|
|
network = [] |
|
for i in range(num_layers): |
|
in_channels = in_channels if i == 0 else out_channels |
|
|
film = FiLMLayer(in_channels, out_channels, w_dim) |
|
network.append(film) |
|
self.mlp_network = nn.Sequential(*network) |
|
|
|
self.mlp_length = len(self.mlp_network) |
|
|
|
def init_weights(self): |
|
for module in self.modules(): |
|
if isinstance(module, FiLMLayer): |
|
module.init_weights() |
|
|
|
self.mlp_network[0].init_weights(first=True) |
|
|
|
def forward(self, pts, wp, lod=None): |
|
num_dims = pts.ndim |
|
assert num_dims in [3, 4, 5] |
|
if num_dims == 5: |
|
N, H, W, K, C = pts.shape |
|
pts = pts.reshape(N, H * W * K, C) |
|
elif num_dims == 4: |
|
N, R, K, C = pts.shape |
|
pts = pts.reshape(N, R * K, C) |
|
|
|
x = self.grid_warper(pts) |
|
|
|
for idx, layer in enumerate(self.mlp_network): |
|
x = layer(x, wp[:, idx]) |
|
|
|
return x |
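# NOTE: `MLPNetwork.forward()` accepts points of shape [N, R, K, C] (rays x
# samples per ray) or [N, H, W, K, C] and flattens them to [N, num_points, C]
# before the FiLM backbone. `wp[:, i]` conditions the i-th FiLM layer, while
# the last entry, `wp[:, mlp_length]`, is reserved for `FCHead` below; the
# `lod` argument is accepted but currently unused.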
|
|
|
|
|
class FCHead(nn.Module): |
|
"""Defines fully-connected layer head in Pi-GAN to decode `feature` into |
|
`sigma` and `rgb`.""" |
|
|
|
def __init__(self, w_dim, channels, mlp_length): |
|
super().__init__() |
|
|
|
self.w_dim = w_dim |
|
self.channels = channels |
|
self.mlp_length = mlp_length |
|
|
|
self.sigma_head = nn.Linear(channels, 1) |
|
self.rgb_film = FiLMLayer(channels + 3, channels, w_dim) |
|
self.rgb_head = nn.Linear(channels, 3) |
|
|
|
    def init_weights(self):
|
self.sigma_head.apply(freq_init(25)) |
|
self.rgb_head.apply(freq_init(25)) |
|
|
|
self.rgb_film.init_weights() |
|
|
|
def forward(self, point_features, wp, dirs): |
|
sigma = self.sigma_head(point_features) |
|
|
|
dirs = torch.cat([point_features, dirs], dim=-1) |
|
rgb = self.rgb_film(dirs, wp[:, self.mlp_length]) |
|
rgb = self.rgb_head(rgb).sigmoid() |
|
|
|
results = {'sigma': sigma, 'rgb': rgb} |
|
|
|
return results |
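# NOTE: `FCHead` reads density directly from the backbone features
# (`sigma = Linear(features)`), while color is view-dependent:
# `rgb = sigmoid(Linear(FiLM(cat(features, dirs), wp[:, mlp_length])))`.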
|
|
|
|
|
class FiLMLayer(nn.Module): |
|
def __init__(self, input_dim, output_dim, w_dim, **kwargs): |
|
super().__init__() |
|
self.input_dim = input_dim |
|
self.output_dim = output_dim |
|
self.w_dim = w_dim |
|
|
|
self.layer = nn.Linear(input_dim, output_dim) |
|
        self.style = nn.Linear(w_dim, output_dim * 2)
|
|
|
def init_weights(self, first=False): |
|
|
|
if not first: |
|
self.layer.apply(freq_init(25)) |
|
else: |
|
self.layer.apply(first_film_init) |
|
|
|
nn.init.kaiming_normal_(self.style.weight, |
|
a=0.2, |
|
mode='fan_in', |
|
nonlinearity='leaky_relu') |
|
        with torch.no_grad():
            self.style.weight *= 0.25
|
|
|
def extra_repr(self): |
|
return (f'in_ch={self.input_dim}, ' |
|
f'out_ch={self.output_dim}, ' |
|
f'w_ch={self.w_dim}') |
|
|
|
def forward(self, x, wp): |
|
        x = self.layer(x)
        # Map `wp` to per-channel frequencies and phase shifts (FiLM).
        freq, phase_shift = self.style(wp).unsqueeze(1).chunk(2, dim=2)
        freq = freq * 15 + 30
        return torch.sin(freq * x + phase_shift)
|
|
|
class PixelNormLayer(nn.Module): |
|
"""Implements pixel-wise feature vector normalization layer.""" |
|
|
|
def __init__(self, dim, eps): |
|
super().__init__() |
|
self.dim = dim |
|
self.eps = eps |
|
|
|
def extra_repr(self): |
|
return f'dim={self.dim}, epsilon={self.eps}' |
|
|
|
def forward(self, x): |
|
scale = (x.square().mean(dim=self.dim, keepdim=True) + self.eps).rsqrt() |
|
return x * scale |
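# NOTE: the layer above computes x / sqrt(mean(x ** 2, dim) + eps), i.e. it
# rescales each feature vector to approximately unit RMS along `dim`.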
|
|
|
class UniformBoxWarp(nn.Module): |
|
def __init__(self, sidelength): |
|
super().__init__() |
|
self.scale_factor = 2 / sidelength |
|
|
|
def forward(self, coordinates): |
|
return coordinates * self.scale_factor |
|
|
|
def first_film_init(m): |
|
with torch.no_grad(): |
|
if isinstance(m, nn.Linear): |
|
num_input = m.weight.size(-1) |
|
m.weight.uniform_(-1/num_input, 1/num_input) |
|
|
|
def freq_init(freq): |
|
def init(m): |
|
with torch.no_grad(): |
|
if isinstance(m, nn.Linear): |
|
num_input = m.weight.size(-1) |
|
m.weight.uniform_(-np.sqrt(6/num_input)/freq, |
|
np.sqrt(6/num_input)/freq) |
|
return init |
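# NOTE: `freq_init(freq)` is the SIREN-style initializer
# U(-sqrt(6 / fan_in) / freq, sqrt(6 / fan_in) / freq), used with freq=25 in
# this file, whereas `first_film_init` uses the wider U(-1 / fan_in, 1 / fan_in)
# range for the first FiLM layer.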
|
|
|
class EqualLinear(nn.Module): |
|
def __init__( |
|
self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, |
|
): |
|
super().__init__() |
|
|
|
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) |
|
|
|
if bias: |
|
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) |
|
|
|
else: |
|
self.bias = None |
|
|
|
self.scale = (1 / math.sqrt(in_dim)) * lr_mul |
|
self.lr_mul = lr_mul |
|
|
|
def forward(self, input): |
|
        bias = self.bias * self.lr_mul if self.bias is not None else None
        out = F.linear(input, self.weight * self.scale, bias=bias)
|
return out |
|
|
|
def __repr__(self): |
|
return ( |
|
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' |
|
) |
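# ----------------------------------------------------------------------------
# Minimal smoke test (a sketch only, not part of the training pipeline). It
# exercises the sub-modules defined in this file with random tensors and
# assumed default dimensions; the full `PiGANGenerator` additionally needs a
# configured `Renderer`, `FeatureExtractor`, and `rendering_kwargs`, which are
# defined outside this file.
if __name__ == '__main__':
    batch_size, num_points = 2, 64
    z_dim, w_dim, synthesis_layers = 256, 256, 8

    mapping = MappingNetwork(input_dim=z_dim,
                             output_dim=w_dim,
                             num_outputs=synthesis_layers + 1,
                             repeat_output=True,
                             normalize_input=True,
                             num_layers=3,
                             hidden_dim=256,
                             label_dim=0,
                             embedding_dim=512,
                             normalize_embedding=True,
                             normalize_embedding_latent=False,
                             eps=1e-8,
                             label_concat=True)
    mlp = MLPNetwork(w_dim=w_dim,
                     in_channels=3,
                     num_layers=synthesis_layers,
                     out_channels=256)
    head = FCHead(w_dim=w_dim, channels=256, mlp_length=mlp.mlp_length)

    mapping.init_weights()
    mlp.init_weights()
    head.init_weights()

    z = torch.randn(batch_size, z_dim)
    pts = torch.randn(batch_size, num_points, 1, 3)   # [N, R, K, C]
    dirs = torch.randn(batch_size, num_points, 3)

    wp = mapping(z)['wp']               # [N, synthesis_layers + 1, w_dim]
    features = mlp(pts, wp)             # [N, R * K, out_channels]
    outputs = head(features, wp, dirs)  # sigma: [N, R * K, 1], rgb: [N, R * K, 3]
    print(outputs['sigma'].shape, outputs['rgb'].shape)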