# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import numpy as np
import torch
from torch_utils import misc
from torch_utils import persistence
from training.models import *

#----------------------------------------------------------------------------
class MappingNetwork(torch.nn.Module):
    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality, 0 = no latent.
        c_dim,                      # Conditioning label (C) dimensionality, 0 = no label.
        w_dim,                      # Intermediate latent (W) dimensionality.
        num_ws,                     # Number of intermediate latents to output, None = do not broadcast.
        num_layers      = 8,        # Number of mapping layers.
        embed_features  = None,     # Label embedding dimensionality, None = same as w_dim.
        layer_features  = None,     # Number of intermediate features in the mapping layers, None = same as w_dim.
        activation      = 'lrelu',  # Activation function: 'relu', 'lrelu', etc.
        lr_multiplier   = 0.01,     # Learning rate multiplier for the mapping layers.
        w_avg_beta      = 0.995,    # Decay for tracking the moving average of W during training, None = do not track.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.num_ws = num_ws
        self.num_layers = num_layers
        self.w_avg_beta = w_avg_beta

        if embed_features is None:
            embed_features = w_dim
        if c_dim == 0:
            embed_features = 0
        if layer_features is None:
            layer_features = w_dim
        features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]

        if c_dim > 0:
            self.embed = FullyConnectedLayer(c_dim, embed_features)
        for idx in range(num_layers):
            in_features = features_list[idx]
            out_features = features_list[idx + 1]
            layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
            setattr(self, f'fc{idx}', layer)

        if num_ws is not None and w_avg_beta is not None:
            self.register_buffer('w_avg', torch.zeros([w_dim]))
    def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
        # Embed, normalize, and concat inputs.
        x = None
        with torch.autograd.profiler.record_function('input'):
            if self.z_dim > 0:
                misc.assert_shape(z, [None, self.z_dim])
                x = normalize_2nd_moment(z.to(torch.float32))
            if self.c_dim > 0:
                misc.assert_shape(c, [None, self.c_dim])
                y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
                x = torch.cat([x, y], dim=1) if x is not None else y

        # Main layers.
        for idx in range(self.num_layers):
            layer = getattr(self, f'fc{idx}')
            x = layer(x)

        # Update moving average of W.
        if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
            with torch.autograd.profiler.record_function('update_w_avg'):
                self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))

        # Broadcast.
        if self.num_ws is not None:
            with torch.autograd.profiler.record_function('broadcast'):
                x = x.unsqueeze(1).repeat([1, self.num_ws, 1])

        # Apply truncation.
        if truncation_psi != 1:
            with torch.autograd.profiler.record_function('truncate'):
                assert self.w_avg_beta is not None
                if self.num_ws is None or truncation_cutoff is None:
                    x = self.w_avg.lerp(x, truncation_psi)
                else:
                    x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
        return x
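
#----------------------------------------------------------------------------
# Usage sketch (not part of the original module): exercises MappingNetwork in
# isolation. The dimensions and truncation value below are illustrative
# assumptions, not values from any training config; running this requires
# FullyConnectedLayer / normalize_2nd_moment from training.models.

def _demo_mapping_network():
    mapping = MappingNetwork(z_dim=512, c_dim=0, w_dim=512, num_ws=14)
    z = torch.randn(4, 512)                    # batch of input latents
    ws = mapping(z, None, truncation_psi=0.7)  # lerp toward the tracked w_avg
    assert ws.shape == (4, 14, 512)            # broadcast to one w per layer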
#----------------------------------------------------------------------------
class EncoderNetwork(torch.nn.Module):
    def __init__(self,
        c_dim,                          # Conditioning label (C) dimensionality.
        z_dim,                          # Input latent (Z) dimensionality.
        img_resolution,                 # Input resolution.
        img_channels,                   # Number of input color channels.
        architecture    = 'orig',       # Architecture: 'orig', 'skip', 'resnet'.
        channel_base    = 16384,        # Overall multiplier for the number of channels.
        channel_max     = 512,          # Maximum number of channels in any layer.
        num_fp16_res    = 0,            # Use FP16 for the N highest resolutions.
        conv_clamp      = None,         # Clamp the output of convolution layers to +-X, None = disable clamping.
        cmap_dim        = None,         # Dimensionality of mapped conditioning label, None = default.
        block_kwargs    = {},           # Arguments for EncoderBlock.
        mapping_kwargs  = {},           # Arguments for MappingNetwork.
        epilogue_kwargs = {},           # Arguments for EncoderEpilogue.
    ):
        super().__init__()
        self.c_dim = c_dim
        self.z_dim = z_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        if cmap_dim is None:
            cmap_dim = channels_dict[4]
        if c_dim == 0:
            cmap_dim = 0

        common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
        cur_layer_idx = 0
        for res in self.block_resolutions:
            in_channels = channels_dict[res] if res < img_resolution else 0
            tmp_channels = channels_dict[res]
            out_channels = channels_dict[res // 2]
            use_fp16 = (res >= fp16_resolution)
            block = EncoderBlock(in_channels, tmp_channels, out_channels, resolution=res,
                first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
            setattr(self, f'b{res}', block)
            cur_layer_idx += block.num_layers
        if c_dim > 0:
            self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
        self.b4 = EncoderEpilogue(channels_dict[4], cmap_dim=cmap_dim, z_dim=z_dim * 2, resolution=4, **epilogue_kwargs, **common_kwargs)
    def forward(self, img, c, **block_kwargs):
        x = None
        feats = {}
        # Run the downsampling blocks, keeping each block's feature map for
        # the skip connections into the synthesis network.
        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            x, img, feat = block(x, img, **block_kwargs)
            feats[res] = feat
        cmap = None
        if self.c_dim > 0:
            cmap = self.mapping(None, c)
        x, const_e = self.b4(x, cmap)
        feats[4] = const_e
        B, _ = x.shape
        z = torch.randn((B, self.z_dim), requires_grad=False, dtype=x.dtype, device=x.device)  # Noise for co-modulation.
        return x, z, feats  # Encoder features at 1/2, 1/4, 1/8, 1/16, 1/32, 1/64 scale.
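
#----------------------------------------------------------------------------
# Usage sketch (not part of the original module): runs the encoder on a random
# masked image. img_channels=4 (RGB + mask) and the other sizes are
# illustrative assumptions; EncoderBlock / EncoderEpilogue must be available
# from training.models.

def _demo_encoder_network():
    enc = EncoderNetwork(c_dim=0, z_dim=512, img_resolution=256, img_channels=4)
    img = torch.randn(2, 4, 256, 256)   # e.g. RGB image with the mask as last channel
    x_global, z, feats = enc(img, None)
    # x_global: per-image global code from EncoderEpilogue.
    # z: fresh Gaussian noise for co-modulation, one vector per image.
    # feats: skip features keyed by block resolution (256, 128, ..., 8) plus 4.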
#----------------------------------------------------------------------------
class SynthesisNetwork(torch.nn.Module):
    def __init__(self,
        w_dim,                      # Intermediate latent (W) dimensionality.
        z_dim,                      # Output latent (Z) dimensionality.
        img_resolution,             # Output image resolution.
        img_channels,               # Number of color channels.
        channel_base    = 16384,    # Overall multiplier for the number of channels.
        channel_max     = 512,      # Maximum number of channels in any layer.
        num_fp16_res    = 0,        # Use FP16 for the N highest resolutions.
        **block_kwargs,             # Arguments for SynthesisBlock.
    ):
        assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0  # Power of two.
        super().__init__()
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.block_resolutions = [2 ** i for i in range(3, self.img_resolution_log2 + 1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        self.foreword = SynthesisForeword(img_channels=img_channels, in_channels=min(channel_base // 4, channel_max), z_dim=z_dim * 2, resolution=4)

        self.num_ws = self.img_resolution_log2 * 2 - 2
        for res in self.block_resolutions:
            if res // 2 in channels_dict.keys():
                in_channels = channels_dict[res // 2] if res > 4 else 0
            else:
                in_channels = min(channel_base // (res // 2), channel_max)
            out_channels = channels_dict[res]
            use_fp16 = (res >= fp16_resolution)
            is_last = (res == self.img_resolution)
            block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res,
                img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs)
            setattr(self, f'b{res}', block)
    def forward(self, x_global, mask, feats, ws, fname=None, **block_kwargs):
        img = None
        x, img = self.foreword(x_global, ws, feats, img)
        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            # Each block receives three modulation vectors: two for its
            # convolutions and one for the ToRGB layer. Each is the matching
            # entry of ws co-modulated (concatenated) with the global encoder code.
            idx = int(np.log2(res)) * 2 - 5
            mod_vector0 = torch.cat([ws[:, idx], x_global.clone()], dim=1)
            mod_vector1 = torch.cat([ws[:, idx + 1], x_global.clone()], dim=1)
            mod_vector_rgb = torch.cat([ws[:, idx + 2], x_global.clone()], dim=1)
            x, img = block(x, mask, feats, img, (mod_vector0, mod_vector1, mod_vector_rgb), fname=fname, **block_kwargs)
        return img
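
#----------------------------------------------------------------------------
# Sketch (not part of the original module): prints which entries of ws each
# synthesis resolution consumes, under the indexing used in forward() above.
# Purely illustrative; the default resolution is an assumption.

def _demo_ws_indexing(img_resolution=256):
    num_ws = int(np.log2(img_resolution)) * 2 - 2
    for res in [2 ** i for i in range(3, int(np.log2(img_resolution)) + 1)]:
        idx = int(np.log2(res)) * 2 - 5
        print(f'res {res:4d} uses ws[:, {idx}..{idx + 2}] of {num_ws}')
    # For img_resolution=256: num_ws = 14, the 8x8 block starts at index 1, and
    # the 256x256 block ends at index 13, the last entry.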
#----------------------------------------------------------------------------
class Generator(torch.nn.Module):
    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality.
        c_dim,                      # Conditioning label (C) dimensionality.
        w_dim,                      # Intermediate latent (W) dimensionality.
        img_resolution,             # Output resolution.
        img_channels,               # Number of output color channels.
        encoder_kwargs   = {},      # Arguments for EncoderNetwork.
        mapping_kwargs   = {},      # Arguments for MappingNetwork.
        synthesis_kwargs = {},      # Arguments for SynthesisNetwork.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels
        self.encoder = EncoderNetwork(c_dim=c_dim, z_dim=z_dim, img_resolution=img_resolution, img_channels=img_channels, **encoder_kwargs)
        self.synthesis = SynthesisNetwork(z_dim=z_dim, w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs)
        self.num_ws = self.synthesis.num_ws
        self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)

    def forward(self, img, c, fname=None, truncation_psi=1, truncation_cutoff=None, **synthesis_kwargs):
        mask = img[:, -1].unsqueeze(1)  # Last input channel is the inpainting mask.
        x_global, z, feats = self.encoder(img, c)
        ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)
        img = self.synthesis(x_global, mask, feats, ws, fname=fname, **synthesis_kwargs)
        return img
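
#----------------------------------------------------------------------------
# Usage sketch (not part of the original module): end-to-end forward pass.
# All dimensions are illustrative assumptions (4 input channels = RGB + mask);
# running it requires the building blocks imported from training.models.

def _demo_generator():
    G = Generator(z_dim=512, c_dim=0, w_dim=512, img_resolution=256, img_channels=4)
    img = torch.randn(1, 4, 256, 256)   # masked image; mask in the last channel
    out = G(img, None, truncation_psi=0.7)
    # out: synthesized (inpainted) image at img_resolution.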
#----------------------------------------------------------------------------
class Discriminator(torch.nn.Module):
    def __init__(self,
        c_dim,                          # Conditioning label (C) dimensionality.
        img_resolution,                 # Input resolution.
        img_channels,                   # Number of input color channels.
        architecture    = 'resnet',     # Architecture: 'orig', 'skip', 'resnet'.
        channel_base    = 16384,        # Overall multiplier for the number of channels.
        channel_max     = 512,          # Maximum number of channels in any layer.
        num_fp16_res    = 0,            # Use FP16 for the N highest resolutions.
        conv_clamp      = None,         # Clamp the output of convolution layers to +-X, None = disable clamping.
        cmap_dim        = None,         # Dimensionality of mapped conditioning label, None = default.
        block_kwargs    = {},           # Arguments for DiscriminatorBlock.
        mapping_kwargs  = {},           # Arguments for MappingNetwork.
        epilogue_kwargs = {},           # Arguments for DiscriminatorEpilogue.
    ):
        super().__init__()
        self.c_dim = c_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        if cmap_dim is None:
            cmap_dim = channels_dict[4]
        if c_dim == 0:
            cmap_dim = 0

        common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
        cur_layer_idx = 0
        for res in self.block_resolutions:
            in_channels = channels_dict[res] if res < img_resolution else 0
            tmp_channels = channels_dict[res]
            out_channels = channels_dict[res // 2]
            use_fp16 = (res >= fp16_resolution)
            block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
                first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
            setattr(self, f'b{res}', block)
            cur_layer_idx += block.num_layers
        if c_dim > 0:
            self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
        self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)

    def forward(self, img, c, **block_kwargs):
        x = None
        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            x, img = block(x, img, **block_kwargs)
        cmap = None
        if self.c_dim > 0:
            cmap = self.mapping(None, c)
        x = self.b4(x, img, cmap)
        return x
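
#----------------------------------------------------------------------------
# Usage sketch (not part of the original module): scores a batch of images.
# Sizes are illustrative assumptions; DiscriminatorBlock / DiscriminatorEpilogue
# come from training.models.

def _demo_discriminator():
    D = Discriminator(c_dim=0, img_resolution=256, img_channels=4)
    logits = D(torch.randn(2, 4, 256, 256), None)  # one realism score per image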
#----------------------------------------------------------------------------
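
# Smoke test for the illustrative sketches above (assumes the building blocks
# in training.models are importable and runnable on CPU).
if __name__ == '__main__':
    _demo_mapping_network()
    _demo_encoder_network()
    _demo_ws_indexing()
    _demo_generator()
    _demo_discriminator()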