# python3.7 """Contains the implementation of generator described in PiGAN.""" import math import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.cuda.amp import autocast from .utils.ops import all_gather from .rendering.renderer import Renderer from .rendering.feature_extractor import FeatureExtractor __all__ = ['PiGANGenerator'] class PiGANGenerator(nn.Module): """Defines the generator network in PiGAN.""" def __init__(self, # Settings for mapping network. z_dim=256, w_dim=256, repeat_w=False, normalize_z=False, mapping_layers=3, mapping_hidden_dim=256, # Settings for conditional generation. label_dim=0, embedding_dim=512, normalize_embedding=True, normalize_embedding_latent=False, label_concat=True, # Settings for synthesis network. resolution=-1, synthesis_input_dim=3, synthesis_output_dim=256, synthesis_layers=8, grid_scale=0.24, eps=1e-8, # Settings for rendering module. rendering_kwargs={}): """Initializes with basic settings.""" super().__init__() self.z_dim = z_dim self.w_dim = w_dim self.repeat_w = repeat_w self.normalize_z = normalize_z self.mapping_layers = mapping_layers self.latent_dim = (z_dim,) self.label_dim = label_dim self.embedding_dim = embedding_dim self.normalize_embedding = normalize_embedding self.normalize_embedding_latent = normalize_embedding_latent self.resolution = resolution self.num_layers = synthesis_layers self.eps = eps if self.repeat_w: self.mapping_space_dim = self.w_dim else: self.mapping_space_dim = self.w_dim * (self.num_layers + 1) # Mapping Network to tranform latent codes from Z-Space into W-Space. self.mapping = MappingNetwork( input_dim=z_dim, output_dim=w_dim, num_outputs=synthesis_layers + 1, repeat_output=repeat_w, normalize_input=normalize_z, num_layers=mapping_layers, hidden_dim=mapping_hidden_dim, label_dim=label_dim, embedding_dim=embedding_dim, normalize_embedding=normalize_embedding, normalize_embedding_latent=normalize_embedding_latent, eps=eps, label_concat=label_concat, lr=None) # Set up the overall renderer. self.renderer = Renderer() # Set up the reference representation generator. self.ref_representation_generator = None # Set up the feature extractor. self.feature_extractor = FeatureExtractor(ref_mode='none') # Set up the post module in the feature extractor. self.post_module = MLPNetwork(w_dim=w_dim, in_channels=synthesis_input_dim, num_layers=synthesis_layers, out_channels=synthesis_output_dim, grid_scale=grid_scale) # Set up the fully-connected layer head. self.fc_head = FCHead(w_dim=w_dim, channels=synthesis_output_dim, mlp_length=self.post_module.mlp_length) # Set up the post neural renderer. self.post_neural_renderer = None # This is used for truncation trick. if self.repeat_w: self.register_buffer('w_avg', torch.zeros(w_dim)) else: self.register_buffer('w_avg', torch.zeros(self.num_layers * w_dim)) # Set up some rendering related arguments. self.rendering_kwargs = rendering_kwargs # Initialize weights. self.init_weights() def init_weights(self): self.mapping.init_weights() self.post_module.init_weights() self.fc_head.init_weights() def forward(self, z, label=None, lod=None, w_moving_decay=None, sync_w_avg=False, style_mixing_prob=None, noise_std=None, trunc_psi=None, trunc_layers=None, enable_amp=False): if noise_std is not None: self.rendering_kwargs.update(noise_std=noise_std) lod = self.post_module.lod.cpu().tolist() if lod is None else lod mapping_results = self.mapping(z, label) w = mapping_results['w'] wp = mapping_results.pop('wp') if self.training and w_moving_decay is not None: if sync_w_avg: batch_w_avg = all_gather(w.detach()).mean(dim=0) else: batch_w_avg = w.detach().mean(dim=0) self.w_avg.copy_(batch_w_avg.lerp(self.w_avg, w_moving_decay)) # Truncation. if not self.training: trunc_psi = 1.0 if trunc_psi is None else trunc_psi trunc_layers = 0 if trunc_layers is None else trunc_layers if trunc_psi < 1.0 and trunc_layers > 0: w_avg = self.w_avg.reshape(1, -1, self.w_dim)[:, :trunc_layers] wp[:, :trunc_layers] = w_avg.lerp( wp[:, :trunc_layers], trunc_psi) with autocast(enabled=enable_amp): rendering_result = self.renderer( wp=wp, feature_extractor=self.feature_extractor, rendering_options=self.rendering_kwargs, position_encoder=None, ref_representation=None, post_module=self.post_module, post_module_kwargs=dict(lod=lod), fc_head=self.fc_head) image = rendering_result['composite_rgb'].reshape( z.shape[0], self.resolution, self.resolution, -1).permute(0, 3, 1, 2) camera = torch.cat([ rendering_result['camera_polar'], rendering_result['camera_azimuthal'] ], -1) return { **mapping_results, 'image': image, 'camera': camera, 'latent': z } class MappingNetwork(nn.Module): """Implements the latent space mapping module. Basically, this module executes several dense layers in sequence, and the label embedding if needed. """ def __init__(self, input_dim, output_dim, num_outputs, repeat_output, normalize_input, num_layers, hidden_dim, label_dim, embedding_dim, normalize_embedding, normalize_embedding_latent, eps, label_concat, lr=None): super().__init__() self.input_dim = input_dim self.output_dim = output_dim self.num_outputs = num_outputs self.repeat_output = repeat_output self.normalize_input = normalize_input self.num_layers = num_layers # self.out_channels = out_channels # TODO # self.lr_mul = lr_mul self.label_dim = label_dim self.embedding_dim = embedding_dim self.normalize_embedding = normalize_embedding self.normalize_embedding_latent = normalize_embedding_latent self.eps = eps self.label_concat = label_concat self.norm = PixelNormLayer(dim=1, eps=eps) if num_outputs is not None and not repeat_output: output_dim = output_dim * num_outputs if self.label_dim > 0: if self.label_concat: input_dim = input_dim + embedding_dim self.embedding = EqualLinear(label_dim, embedding_dim, bias=True, bias_init=0, lr_mul=1) else: self.embedding = EqualLinear(label_dim, output_dim, bias=True, bias_init=0, lr_mul=1) network = [] for i in range(num_layers): in_channels = (input_dim if i == 0 else hidden_dim) out_channels = (output_dim if i == (num_layers - 1) else hidden_dim) network.append(nn.Linear(in_channels, out_channels)) network.append(nn.LeakyReLU(0.2, inplace=True)) self.network = nn.Sequential(*network) def init_weights(self): for module in self.network.modules(): if isinstance(module, nn.Linear): nn.init.kaiming_normal_(module.weight, a=0.2, mode='fan_in', nonlinearity='leaky_relu') def forward(self, z, label=None): if z.ndim != 2 or z.shape[1] != self.input_dim: raise ValueError(f'Input latent code should be with shape ' f'[batch_size, input_dim], where ' f'`input_dim` equals to {self.input_dim}!\n' f'But `{z.shape}` is received!') if self.normalize_input: z = self.norm(z) if self.label_dim > 0: if label is None: raise ValueError(f'Model requires an additional label ' f'(with dimension {self.label_dim}) as input, ' f'but no label is received!') if label.ndim != 2 or label.shape != (z.shape[0], self.label_dim): raise ValueError(f'Input label should be with shape ' f'[batch_size, label_dim], where ' f'`batch_size` equals to that of ' f'latent codes ({z.shape[0]}) and ' f'`label_dim` equals to {self.label_dim}!\n' f'But `{label.shape}` is received!') label = label.to(dtype=torch.float32) embedding = self.embedding(label) if self.normalize_embedding and self.label_concat: embedding = self.norm(embedding) if self.label_concat: w = torch.cat((z, embedding), dim=1) else: w = z else: w = z if (self.label_dim > 0 and self.normalize_embedding_latent and self.label_concat): w = self.norm(w) for layer in self.network: w = layer(w) if self.label_dim > 0 and (not self.label_concat): w = w * embedding wp = None if self.num_outputs is not None: if self.repeat_output: wp = w.unsqueeze(1).repeat((1, self.num_outputs, 1)) else: wp = w.reshape(-1, self.num_outputs, self.output_dim) results = { 'z': z, 'label': label, 'w': w, 'wp': wp, } if self.label_dim > 0: results['embedding'] = embedding return results class MLPNetwork(nn.Module): """Defines MLP Network in Pi-GAN.""" def __init__(self, w_dim, in_channels, num_layers, out_channels, grid_scale=0.24): super().__init__() self.in_channels = in_channels self.w_dim = w_dim self.out_channels = out_channels self.register_buffer('lod', torch.zeros(())) self.grid_warper = UniformBoxWarp(grid_scale) network = [] for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels out_channels = out_channels film = FiLMLayer(in_channels, out_channels, w_dim) network.append(film) self.mlp_network = nn.Sequential(*network) self.mlp_length = len(self.mlp_network) def init_weights(self): for module in self.modules(): if isinstance(module, FiLMLayer): module.init_weights() self.mlp_network[0].init_weights(first=True) def forward(self, pts, wp, lod=None): num_dims = pts.ndim assert num_dims in [3, 4, 5] if num_dims == 5: N, H, W, K, C = pts.shape pts = pts.reshape(N, H * W * K, C) elif num_dims == 4: N, R, K, C = pts.shape pts = pts.reshape(N, R * K, C) x = self.grid_warper(pts) for idx, layer in enumerate(self.mlp_network): x = layer(x, wp[:, idx]) return x class FCHead(nn.Module): """Defines fully-connected layer head in Pi-GAN to decode `feature` into `sigma` and `rgb`.""" def __init__(self, w_dim, channels, mlp_length): super().__init__() self.w_dim = w_dim self.channels = channels self.mlp_length = mlp_length self.sigma_head = nn.Linear(channels, 1) self.rgb_film = FiLMLayer(channels + 3, channels, w_dim) self.rgb_head = nn.Linear(channels, 3) def init_weights(self,): self.sigma_head.apply(freq_init(25)) self.rgb_head.apply(freq_init(25)) self.rgb_film.init_weights() def forward(self, point_features, wp, dirs): sigma = self.sigma_head(point_features) dirs = torch.cat([point_features, dirs], dim=-1) rgb = self.rgb_film(dirs, wp[:, self.mlp_length]) rgb = self.rgb_head(rgb).sigmoid() results = {'sigma': sigma, 'rgb': rgb} return results class FiLMLayer(nn.Module): def __init__(self, input_dim, output_dim, w_dim, **kwargs): super().__init__() self.input_dim = input_dim self.output_dim = output_dim self.w_dim = w_dim self.layer = nn.Linear(input_dim, output_dim) self.style = nn.Linear(w_dim, output_dim*2) def init_weights(self, first=False): # initial with 25 frequency if not first: self.layer.apply(freq_init(25)) else: self.layer.apply(first_film_init) # kaiming initial && scale 1/4 nn.init.kaiming_normal_(self.style.weight, a=0.2, mode='fan_in', nonlinearity='leaky_relu') with torch.no_grad(): self.style.weight *= 0.25 def extra_repr(self): return (f'in_ch={self.input_dim}, ' f'out_ch={self.output_dim}, ' f'w_ch={self.w_dim}') def forward(self, x, wp): x = self.layer(x) style = self.style(wp) style_split = style.unsqueeze(1).chunk(2, dim=2) freq = style_split[0] # Scale for sin activation freq = freq*15 + 30 phase_shift = style_split[1] return torch.sin(freq * x + phase_shift) class PixelNormLayer(nn.Module): """Implements pixel-wise feature vector normalization layer.""" def __init__(self, dim, eps): super().__init__() self.dim = dim self.eps = eps def extra_repr(self): return f'dim={self.dim}, epsilon={self.eps}' def forward(self, x): scale = (x.square().mean(dim=self.dim, keepdim=True) + self.eps).rsqrt() return x * scale class UniformBoxWarp(nn.Module): def __init__(self, sidelength): super().__init__() self.scale_factor = 2 / sidelength def forward(self, coordinates): return coordinates * self.scale_factor def first_film_init(m): with torch.no_grad(): if isinstance(m, nn.Linear): num_input = m.weight.size(-1) m.weight.uniform_(-1/num_input, 1/num_input) def freq_init(freq): def init(m): with torch.no_grad(): if isinstance(m, nn.Linear): num_input = m.weight.size(-1) m.weight.uniform_(-np.sqrt(6/num_input)/freq, np.sqrt(6/num_input)/freq) return init class EqualLinear(nn.Module): def __init__( self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, ): super().__init__() self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) if bias: self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) else: self.bias = None self.scale = (1 / math.sqrt(in_dim)) * lr_mul self.lr_mul = lr_mul def forward(self, input): out = F.linear( input, self.weight * self.scale, bias=self.bias * self.lr_mul ) return out def __repr__(self): return ( f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' )