import torch
import numpy as np

from torch_utils.ops import bias_act
from torch_utils import misc


def normalize_2nd_moment(x, dim=1, eps=1e-8):
    """Rescale ``x`` so the mean of its squared entries along ``dim`` is ~1.

    Args:
        x: Input tensor.
        dim: Dimension over which the mean of squares is taken.
        eps: Added before the reciprocal square root so an all-zero slice
            does not divide by zero.

    Returns:
        Tensor with the same shape as ``x``.
    """
    return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt()


class FullyConnectedLayer_normal(torch.nn.Module):
    """Thin wrapper around ``torch.nn.Linear`` with a constant bias initializer."""

    def __init__(self,
        in_features,        # Number of input features.
        out_features,       # Number of output features.
        bias        = True, # Include an additive bias term?
        bias_init   = 0,    # Initial value for the additive bias.
    ):
        super().__init__()
        self.fc = torch.nn.Linear(in_features, out_features, bias=bias)
        if bias:
            # Replace PyTorch's default bias initialization with a constant.
            with torch.no_grad():
                self.fc.bias.fill_(bias_init)

    def forward(self, x):
        return self.fc(x)


class MappingNetwork_normal(torch.nn.Module):
    """Plain MLP of ``num_layers`` (Linear, LeakyReLU(0.2)) pairs.

    The first Linear maps ``in_features -> int_dim``; the remaining
    ``num_layers - 1`` map ``int_dim -> int_dim``. Every Linear, including the
    last one, is followed by LeakyReLU(0.2).
    """

    def __init__(self,
        in_features,                    # Number of input features.
        int_dim,                        # Output width of every Linear layer.
        num_layers              = 8,    # Number of mapping layers.
        mapping_normalization   = False # Apply normalize_2nd_moment (dim=1) to the input?
    ):
        super().__init__()
        layers = [torch.nn.Linear(in_features, int_dim), torch.nn.LeakyReLU(0.2)]
        for _ in range(num_layers - 1):
            layers.append(torch.nn.Linear(int_dim, int_dim))
            layers.append(torch.nn.LeakyReLU(0.2))
        self.net = torch.nn.Sequential(*layers)
        self.normalization = mapping_normalization

    def forward(self, x):
        if self.normalization:
            x = normalize_2nd_moment(x)
        return self.net(x)


class DecodingNetwork(torch.nn.Module):
    """MLP that L2-normalizes its input (dim 1) and then decodes it.

    Layout: ``num_layers - 1`` (Linear(in_features, in_features), ReLU) pairs
    followed by a final Linear(in_features, out_dim) with no activation.
    """

    def __init__(self,
        in_features,        # Number of input features.
        out_dim,            # Number of output features.
        num_layers  = 8,    # Number of Linear layers, including the output layer.
    ):
        super().__init__()
        layers = []
        for _ in range(num_layers - 1):
            layers.append(torch.nn.Linear(in_features, in_features))
            layers.append(torch.nn.ReLU())
        layers.append(torch.nn.Linear(in_features, out_dim))
        self.net = torch.nn.Sequential(*layers)

    def forward(self, x):
        x = torch.nn.functional.normalize(x, dim=1)
        return self.net(x)


class FullyConnectedLayer(torch.nn.Module):
    """Fully connected layer whose parameters are rescaled at runtime.

    The stored weight is initialized as ``randn / lr_multiplier`` and is
    multiplied by ``lr_multiplier / sqrt(in_features)`` in ``forward``; the
    bias is multiplied by ``lr_multiplier``. Non-linear activations (and the
    bias in that case) are applied through the project helper
    ``bias_act.bias_act``.
    """

    def __init__(self,
        in_features,                # Number of input features.
        out_features,               # Number of output features.
        bias            = True,     # Apply additive bias before the activation function?
        activation      = 'linear', # Activation function: 'relu', 'lrelu', etc.
        lr_multiplier   = 1,        # Learning rate multiplier.
        bias_init       = 0,        # Initial value for the additive bias.
    ):
        super().__init__()
        self.activation = activation
        self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
        self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
        self.weight_gain = lr_multiplier / np.sqrt(in_features)
        self.bias_gain = lr_multiplier

    def forward(self, x):
        # NOTE(review): the addmm path requires a 2-D x, presumably
        # (batch, in_features) — confirm for other callers.
        w = self.weight.to(x.dtype) * self.weight_gain
        b = self.bias
        if b is not None:
            b = b.to(x.dtype)
            if self.bias_gain != 1:
                b = b * self.bias_gain

        if self.activation == 'linear' and b is not None:
            # Fused bias + matmul for the common linear case.
            x = torch.addmm(b.unsqueeze(0), x, w.t())
        else:
            x = x.matmul(w.t())
            x = bias_act.bias_act(x, b, act=self.activation)
        return x


class MappingNetwork(torch.nn.Module):
    """Map latents ``z`` to intermediate latents ``w`` through FC layers.

    The input ``z`` (shape ``[batch, z_dim]``, checked in ``forward``) is
    optionally normalized, passed through ``num_layers`` FullyConnectedLayers,
    and optionally broadcast to ``num_ws`` copies. When a ``w_avg`` buffer
    exists, a moving average of ``w`` is tracked during training and can be
    used for truncation. Label conditioning is not supported: ``forward``
    raises ``ValueError`` when ``c_dim > 0``.
    """

    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality, 0 = no latent.
        c_dim,                      # Conditioning label (C) dimensionality, 0 = no label.
        w_dim,                      # Intermediate latent (W) dimensionality.
        num_ws,                     # Number of intermediate latents to output, None = do not broadcast.
        num_layers      = 8,        # Number of mapping layers.
        embed_features  = None,     # Label embedding dimensionality, None = same as w_dim.
        layer_features  = None,     # Number of intermediate features in the mapping layers, None = same as w_dim.
        activation      = 'lrelu',  # Activation function: 'relu', 'lrelu', etc.
        lr_multiplier   = 0.01,     # Learning rate multiplier for the mapping layers.
        w_avg_beta      = 0.995,    # Decay for tracking the moving average of W during training, None = do not track.
        normalization   = None      # Normalize the z input using normalize_2nd_moment?
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.num_ws = num_ws
        self.num_layers = num_layers
        self.w_avg_beta = w_avg_beta
        self.normalization = normalization

        if embed_features is None:
            embed_features = w_dim
        if c_dim == 0:
            embed_features = 0
        if layer_features is None:
            layer_features = w_dim
        features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]

        if c_dim > 0:
            # Still created even though forward() rejects c_dim > 0,
            # presumably so parameter names match existing checkpoints.
            self.embed = FullyConnectedLayer(c_dim, embed_features)
        for idx in range(num_layers):
            in_features = features_list[idx]
            out_features = features_list[idx + 1]
            layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
            setattr(self, f'fc{idx}', layer)

        # The buffer is registered only when both are set; forward() checks for
        # its presence instead of assuming it exists.
        if num_ws is not None and w_avg_beta is not None:
            self.register_buffer('w_avg', torch.zeros([w_dim]))

    def forward(self, z, c=None, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
        # Embed, normalize, and concat inputs.
        x = None
        with torch.autograd.profiler.record_function('input'):
            if self.z_dim > 0:
                misc.assert_shape(z, [None, self.z_dim])
                x = z.to(torch.float32)
                if self.normalization:
                    x = normalize_2nd_moment(x)
            if self.c_dim > 0:
                raise ValueError("This implementation does not need class index")

        # Main layers.
        for idx in range(self.num_layers):
            layer = getattr(self, f'fc{idx}')
            x = layer(x)

        # The w_avg buffer is absent when num_ws is None (see __init__), so look
        # it up with a default instead of crashing with AttributeError.
        w_avg = getattr(self, 'w_avg', None)

        # Update moving average of W.
        if w_avg is not None and self.w_avg_beta is not None and self.training and not skip_w_avg_update:
            with torch.autograd.profiler.record_function('update_w_avg'):
                w_avg.copy_(x.detach().mean(dim=0).lerp(w_avg, self.w_avg_beta))

        # Broadcast.
        if self.num_ws is not None:
            with torch.autograd.profiler.record_function('broadcast'):
                x = x.unsqueeze(1).repeat([1, self.num_ws, 1])

        # Apply truncation.
        if truncation_psi != 1:
            with torch.autograd.profiler.record_function('truncate'):
                assert self.w_avg_beta is not None
                assert w_avg is not None, 'truncation requires a tracked w_avg (num_ws and w_avg_beta must be set)'
                if self.num_ws is None or truncation_cutoff is None:
                    x = w_avg.lerp(x, truncation_psi)
                else:
                    x[:, :truncation_cutoff] = w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
        return x