# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import numpy as np
import torch
from torch_utils import misc
from torch_utils import persistence
from training.models import *

#----------------------------------------------------------------------------

@persistence.persistent_class
class MappingNetwork(torch.nn.Module):
    """Maps a latent `z` and/or a conditioning label `c` to intermediate latents.

    The inputs are normalized with `normalize_2nd_moment`, concatenated, and
    passed through `num_layers` `FullyConnectedLayer` modules (`fc0`, `fc1`, ...).
    Optionally the result is broadcast to `num_ws` copies and truncated towards
    the tracked moving average `w_avg`.
    """

    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality, 0 = no latent.
        c_dim,                      # Conditioning label (C) dimensionality, 0 = no label.
        w_dim,                      # Intermediate latent (W) dimensionality.
        num_ws,                     # Number of intermediate latents to output, None = do not broadcast.
        num_layers      = 8,        # Number of mapping layers.
        embed_features  = None,     # Label embedding dimensionality, None = same as w_dim.
        layer_features  = None,     # Number of intermediate features in the mapping layers, None = same as w_dim.
        activation      = 'lrelu',  # Activation function: 'relu', 'lrelu', etc.
        lr_multiplier   = 0.01,     # Learning rate multiplier for the mapping layers.
        w_avg_beta      = 0.995,    # Decay for tracking the moving average of W during training, None = do not track.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.num_ws = num_ws
        self.num_layers = num_layers
        self.w_avg_beta = w_avg_beta

        if embed_features is None:
            embed_features = w_dim
        if c_dim == 0:
            embed_features = 0
        if layer_features is None:
            layer_features = w_dim
        # Layer widths: concatenated input -> hidden layers -> w_dim.
        features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]

        if c_dim > 0:
            self.embed = FullyConnectedLayer(c_dim, embed_features)
        for idx in range(num_layers):
            in_features = features_list[idx]
            out_features = features_list[idx + 1]
            layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
            setattr(self, f'fc{idx}', layer)

        # NOTE(review): `w_avg` is only registered when `num_ws` is not None, yet
        # forward() reads `self.w_avg` whenever `w_avg_beta` is not None (training
        # update and truncation). Constructing with num_ws=None and a non-None
        # w_avg_beta would therefore fail with AttributeError in those paths;
        # EncoderNetwork below avoids this by passing w_avg_beta=None.
        if num_ws is not None and w_avg_beta is not None:
            self.register_buffer('w_avg', torch.zeros([w_dim]))

    def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
        """Compute intermediate latents from `z` and `c`.

        Args:
            z: Latent input, checked to be `[N, z_dim]`; ignored when `z_dim == 0`.
            c: Label input, checked to be `[N, c_dim]`; ignored when `c_dim == 0`.
            truncation_psi: Interpolation weight between `w_avg` (0) and the
                computed latents (1). 1 disables truncation.
            truncation_cutoff: When broadcasting, truncate only the first
                `truncation_cutoff` latents; None truncates all of them.
            skip_w_avg_update: If True, do not update `w_avg` in training mode.

        Returns:
            The mapped latents; after broadcasting there is an extra dimension
            of size `num_ws` at index 1.
        """
        # Embed, normalize, and concat inputs.
        x = None
        with torch.autograd.profiler.record_function('input'):
            if self.z_dim > 0:
                misc.assert_shape(z, [None, self.z_dim])
                x = normalize_2nd_moment(z.to(torch.float32))
            if self.c_dim > 0:
                misc.assert_shape(c, [None, self.c_dim])
                y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
                x = torch.cat([x, y], dim=1) if x is not None else y

        # Main layers.
        for idx in range(self.num_layers):
            layer = getattr(self, f'fc{idx}')
            x = layer(x)

        # Update moving average of W.
        if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
            with torch.autograd.profiler.record_function('update_w_avg'):
                self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))

        # Broadcast.
        if self.num_ws is not None:
            with torch.autograd.profiler.record_function('broadcast'):
                x = x.unsqueeze(1).repeat([1, self.num_ws, 1])

        # Apply truncation.
        if truncation_psi != 1:
            with torch.autograd.profiler.record_function('truncate'):
                assert self.w_avg_beta is not None
                if self.num_ws is None or truncation_cutoff is None:
                    x = self.w_avg.lerp(x, truncation_psi)
                else:
                    # In-place write is safe here: `x` is the fresh tensor produced by repeat().
                    x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
        return x

#----------------------------------------------------------------------------

@persistence.persistent_class
class EncoderNetwork(torch.nn.Module):
    """Image encoder built from `EncoderBlock` modules and an `EncoderEpilogue`.

    One block `b{res}` is created per resolution from `img_resolution` down to
    8, followed by the epilogue `b4`. The per-block features returned by the
    blocks are collected and returned alongside the epilogue output.
    """

    def __init__(self,
        c_dim,                          # Conditioning label (C) dimensionality.
        z_dim,                          # Input latent (Z) dimensionality.
        img_resolution,                 # Input resolution.
        img_channels,                   # Number of input color channels.
        architecture        = 'orig',   # Architecture: 'orig', 'skip', 'resnet'.
        channel_base        = 16384,    # Overall multiplier for the number of channels.
        channel_max         = 512,      # Maximum number of channels in any layer.
        num_fp16_res        = 0,        # Use FP16 for the N highest resolutions.
        conv_clamp          = None,     # Clamp the output of convolution layers to +-X, None = disable clamping.
        cmap_dim            = None,     # Dimensionality of mapped conditioning label, None = default.
        block_kwargs        = {},       # Arguments for EncoderBlock.
        mapping_kwargs      = {},       # Arguments for MappingNetwork.
        epilogue_kwargs     = {},       # Arguments for EncoderEpilogue.
    ):
        super().__init__()
        self.c_dim = c_dim
        self.z_dim = z_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        if cmap_dim is None:
            cmap_dim = channels_dict[4]
        if c_dim == 0:
            cmap_dim = 0

        common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
        cur_layer_idx = 0
        for res in self.block_resolutions:
            # The block at the full input resolution is given in_channels=0.
            in_channels = channels_dict[res] if res < img_resolution else 0
            tmp_channels = channels_dict[res]
            out_channels = channels_dict[res // 2]
            use_fp16 = (res >= fp16_resolution)
            block = EncoderBlock(in_channels, tmp_channels, out_channels, resolution=res,
                first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
            setattr(self, f'b{res}', block)
            cur_layer_idx += block.num_layers
        if c_dim > 0:
            self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
        self.b4 = EncoderEpilogue(channels_dict[4], cmap_dim=cmap_dim, z_dim=z_dim * 2, resolution=4, **epilogue_kwargs, **common_kwargs)

    def forward(self, img, c, **block_kwargs):
        """Encode `img` (optionally conditioned on `c`).

        Returns:
            A tuple `(x, z, feats)`: `x` is the first output of the epilogue
            (unpacked below as a 2-D tensor), `z` is freshly sampled Gaussian
            noise of shape `[B, z_dim]` on the same dtype/device as `x`, and
            `feats` maps each block resolution (plus 4 for the epilogue) to
            the feature returned at that stage.
        """
        x = None
        feats = {}
        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            x, img, feat = block(x, img, **block_kwargs)
            feats[res] = feat

        cmap = None
        if self.c_dim > 0:
            cmap = self.mapping(None, c)
        x, const_e = self.b4(x, cmap)
        feats[4] = const_e

        B, _ = x.shape  # Also enforces that the epilogue output is 2-D.
        z = torch.randn((B, self.z_dim), requires_grad=False, dtype=x.dtype, device=x.device) ## Noise for Co-Modulation
        return x, z, feats ## 1/2, 1/4, 1/8, 1/16, 1/32, 1/64

#----------------------------------------------------------------------------

@persistence.persistent_class
class SynthesisNetwork(torch.nn.Module):
    """Image synthesis network: a `SynthesisForeword` followed by `SynthesisBlock`s.

    One block `b{res}` is created per resolution from 8 up to `img_resolution`.
    Each block receives three modulation vectors, each the concatenation of one
    intermediate latent from `ws` with the global encoder output `x_global`.
    """

    def __init__(self,
        w_dim,                      # Intermediate latent (W) dimensionality.
        z_dim,                      # Output Latent (Z) dimensionality.
        img_resolution,             # Output image resolution.
        img_channels,               # Number of color channels.
        channel_base    = 16384,    # Overall multiplier for the number of channels.
        channel_max     = 512,      # Maximum number of channels in any layer.
        num_fp16_res    = 0,        # Use FP16 for the N highest resolutions.
        **block_kwargs,             # Arguments for SynthesisBlock.
    ):
        assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
        super().__init__()
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.block_resolutions = [2 ** i for i in range(3, self.img_resolution_log2 + 1)]
        # Include resolution 4 so the foreword and the first block share the same channel table.
        channels_dict = {res: min(channel_base // res, channel_max) for res in [4] + self.block_resolutions}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        self.foreword = SynthesisForeword(img_channels=img_channels, in_channels=channels_dict[4], z_dim=z_dim * 2, resolution=4)

        self.num_ws = self.img_resolution_log2 * 2 - 2
        for res in self.block_resolutions:
            in_channels = channels_dict[res // 2]
            out_channels = channels_dict[res]
            use_fp16 = (res >= fp16_resolution)
            is_last = (res == self.img_resolution)
            block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res,
                img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs)
            setattr(self, f'b{res}', block)

    def forward(self, x_global, mask, feats, ws, fname=None, **block_kwargs):
        """Synthesize an image.

        Args:
            x_global: Global encoder output; concatenated (dim 1) onto every
                modulation vector and passed to the foreword.
            mask: Forwarded unchanged to every block.
            feats: Encoder features, forwarded to the foreword and every block.
            ws: Intermediate latents, indexed as `ws[:, i]`.
                Presumably `[N, num_ws, w_dim]` -- TODO confirm against caller.
            fname: Forwarded unchanged to every block.
            **block_kwargs: Extra keyword arguments for every block.

        Returns:
            The `img` output of the last block (or of the foreword when there
            are no blocks, i.e. `img_resolution == 4`).
        """
        img = None
        x, img = self.foreword(x_global, ws, feats, img)
        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            # Each block consumes three consecutive latents: conv0, conv1, to-RGB.
            first_ws_idx = int(np.log2(res)) * 2 - 5
            # torch.cat returns a new tensor, so x_global needs no explicit clone.
            mod_vectors = tuple(
                torch.cat([ws[:, first_ws_idx + offset], x_global], dim=1)
                for offset in range(3)
            )
            x, img = block(x, mask, feats, img, mod_vectors, fname=fname, **block_kwargs)
        return img

#----------------------------------------------------------------------------

@persistence.persistent_class
class Generator(torch.nn.Module):
    """Full generator: `EncoderNetwork` -> `MappingNetwork` -> `SynthesisNetwork`."""

    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality.
        c_dim,                      # Conditioning label (C) dimensionality.
        w_dim,                      # Intermediate latent (W) dimensionality.
        img_resolution,             # Output resolution.
        img_channels,               # Number of output color channels.
        encoder_kwargs      = {},   # Arguments for EncoderNetwork.
        mapping_kwargs      = {},   # Arguments for MappingNetwork.
        synthesis_kwargs    = {},   # Arguments for SynthesisNetwork.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels
        self.encoder = EncoderNetwork(c_dim=c_dim, z_dim=z_dim, img_resolution=img_resolution, img_channels=img_channels, **encoder_kwargs)
        self.synthesis = SynthesisNetwork(z_dim=z_dim, w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs)
        self.num_ws = self.synthesis.num_ws
        self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)

    def forward(self, img, c, fname=None, truncation_psi=1, truncation_cutoff=None, **synthesis_kwargs):
        """Run the encoder, mapping and synthesis networks on `img`.

        The last channel of `img` is split off (keeping a channel dimension)
        and handed to the synthesis network as `mask`; the full `img` is fed
        to the encoder. The latent `z` is sampled inside the encoder.

        Returns:
            The image produced by the synthesis network.
        """
        mask = img[:, -1].unsqueeze(1)
        x_global, z, feats = self.encoder(img, c)
        ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)
        img = self.synthesis(x_global, mask, feats, ws, fname=fname, **synthesis_kwargs)
        return img

#----------------------------------------------------------------------------
@persistence.persistent_class
class Discriminator(torch.nn.Module):
    """Discriminator: a chain of `DiscriminatorBlock`s ending in a `DiscriminatorEpilogue`.

    A block named `b{res}` is registered for every power-of-two resolution
    from `img_resolution` down to 8; the epilogue is registered as `b4`. When
    `c_dim > 0` a `MappingNetwork` turns the label into the conditioning
    vector handed to the epilogue.
    """

    def __init__(self,
        c_dim,                          # Conditioning label (C) dimensionality.
        img_resolution,                 # Input resolution.
        img_channels,                   # Number of input color channels.
        architecture        = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
        channel_base        = 16384,    # Overall multiplier for the number of channels.
        channel_max         = 512,      # Maximum number of channels in any layer.
        num_fp16_res        = 0,        # Use FP16 for the N highest resolutions.
        conv_clamp          = None,     # Clamp the output of convolution layers to +-X, None = disable clamping.
        cmap_dim            = None,     # Dimensionality of mapped conditioning label, None = default.
        block_kwargs        = {},       # Arguments for DiscriminatorBlock.
        mapping_kwargs      = {},       # Arguments for MappingNetwork.
        epilogue_kwargs     = {},       # Arguments for DiscriminatorEpilogue.
    ):
        super().__init__()
        log2_res = int(np.log2(img_resolution))
        self.c_dim = c_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = log2_res
        self.img_channels = img_channels
        self.block_resolutions = [2 ** exponent for exponent in range(log2_res, 2, -1)]

        # Channel count per resolution, including the 4x4 epilogue.
        channels_dict = {}
        for res in self.block_resolutions + [4]:
            channels_dict[res] = min(channel_base // res, channel_max)

        # Blocks at or above this resolution run in FP16 (never below 8).
        fp16_resolution = max(2 ** (log2_res + 1 - num_fp16_res), 8)

        # Unconditional models get no conditioning map; otherwise default to the 4x4 width.
        if c_dim == 0:
            cmap_dim = 0
        elif cmap_dim is None:
            cmap_dim = channels_dict[4]

        common_kwargs = {
            'img_channels': img_channels,
            'architecture': architecture,
            'conv_clamp': conv_clamp,
        }

        layer_offset = 0
        for res in self.block_resolutions:
            # The block at the full input resolution is given in_channels=0.
            if res < img_resolution:
                in_channels = channels_dict[res]
            else:
                in_channels = 0
            stage = DiscriminatorBlock(
                in_channels,
                channels_dict[res],
                channels_dict[res // 2],
                resolution=res,
                first_layer_idx=layer_offset,
                use_fp16=(res >= fp16_resolution),
                **block_kwargs,
                **common_kwargs,
            )
            setattr(self, f'b{res}', stage)
            layer_offset += stage.num_layers

        if c_dim > 0:
            self.mapping = MappingNetwork(
                z_dim=0,
                c_dim=c_dim,
                w_dim=cmap_dim,
                num_ws=None,
                w_avg_beta=None,
                **mapping_kwargs,
            )
        self.b4 = DiscriminatorEpilogue(
            channels_dict[4],
            cmap_dim=cmap_dim,
            resolution=4,
            **epilogue_kwargs,
            **common_kwargs,
        )

    def forward(self, img, c, **block_kwargs):
        """Run `img` through every block in order, then through the epilogue.

        Args:
            img: Input image tensor, passed to the first block.
            c: Conditioning label; only used when `c_dim > 0`.
            **block_kwargs: Extra keyword arguments forwarded to every block.

        Returns:
            The output of the epilogue `b4`.
        """
        x = None
        for res in self.block_resolutions:
            x, img = getattr(self, f'b{res}')(x, img, **block_kwargs)

        cmap = self.mapping(None, c) if self.c_dim > 0 else None
        return self.b4(x, img, cmap)

#----------------------------------------------------------------------------