# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import functools
import math
import types

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Upsample as NearestUpsample

from imaginaire.layers import Conv2dBlock, LinearBlock, Res2dBlock
from imaginaire.utils.data import (get_crop_h_w,
                                   get_paired_input_image_channel_number,
                                   get_paired_input_label_channel_number)
from imaginaire.utils.distributed import master_only_print as print


class Generator(nn.Module):
    r"""SPADE generator constructor.

    Args:
        gen_cfg (obj): Generator definition part of the yaml config file.
        data_cfg (obj): Data definition part of the yaml config file.
    """

    def __init__(self, gen_cfg, data_cfg):
        super(Generator, self).__init__()
        print('SPADE generator initialization.')
        # We assume the first datum is the ground truth image.
        image_channels = getattr(gen_cfg, 'image_channels', None)
        if image_channels is None:
            image_channels = get_paired_input_image_channel_number(data_cfg)
        num_labels = getattr(gen_cfg, 'num_labels', None)
        if num_labels is None:
            # Calculate the number of input label channels when not specified.
            num_labels = get_paired_input_label_channel_number(data_cfg)
        crop_h, crop_w = get_crop_h_w(data_cfg.train.augmentations)
        # Build the generator.
        out_image_small_side_size = crop_w if crop_w < crop_h else crop_h
        num_filters = getattr(gen_cfg, 'num_filters', 128)
        kernel_size = getattr(gen_cfg, 'kernel_size', 3)
        weight_norm_type = getattr(gen_cfg, 'weight_norm_type', 'spectral')
        cond_dims = 0
        # Check whether we use the style code.
        style_dims = getattr(gen_cfg, 'style_dims', None)
        self.style_dims = style_dims
        if style_dims is not None:
            print('\tStyle code dimensions: %d' % style_dims)
            cond_dims += style_dims
            self.use_style = True
        else:
            self.use_style = False
        # Check whether we use the attribute code.
        if hasattr(gen_cfg, 'attribute_dims'):
            self.use_attribute = True
            self.attribute_dims = gen_cfg.attribute_dims
            cond_dims += gen_cfg.attribute_dims
        else:
            self.use_attribute = False
        if not self.use_style and not self.use_attribute:
            self.use_style_encoder = False
        else:
            self.use_style_encoder = True
        print('\tBase filter number: %d' % num_filters)
        print('\tConvolution kernel size: %d' % kernel_size)
        print('\tWeight norm type: %s' % weight_norm_type)
        skip_activation_norm = \
            getattr(gen_cfg, 'skip_activation_norm', True)
        activation_norm_params = getattr(gen_cfg, 'activation_norm_params',
                                         None)
        if activation_norm_params is None:
            activation_norm_params = types.SimpleNamespace()
        if not hasattr(activation_norm_params, 'num_filters'):
            setattr(activation_norm_params, 'num_filters', 128)
        if not hasattr(activation_norm_params, 'kernel_size'):
            setattr(activation_norm_params, 'kernel_size', 3)
        if not hasattr(activation_norm_params, 'activation_norm_type'):
            setattr(activation_norm_params, 'activation_norm_type',
                    'sync_batch')
        if not hasattr(activation_norm_params, 'separate_projection'):
            setattr(activation_norm_params, 'separate_projection', False)
        if not hasattr(activation_norm_params, 'activation_norm_params'):
            activation_norm_params.activation_norm_params = \
                types.SimpleNamespace()
            activation_norm_params.activation_norm_params.affine = True
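        # The SPADE layers are conditioned on the label map, so their
        # conditional input dimension equals the number of label channels.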
        setattr(activation_norm_params, 'cond_dims', num_labels)
        if not hasattr(activation_norm_params, 'weight_norm_type'):
            setattr(activation_norm_params, 'weight_norm_type',
                    weight_norm_type)
        global_adaptive_norm_type = getattr(gen_cfg,
                                            'global_adaptive_norm_type',
                                            'sync_batch')
        use_posenc_in_input_layer = getattr(gen_cfg,
                                            'use_posenc_in_input_layer',
                                            True)
        output_multiplier = getattr(gen_cfg, 'output_multiplier', 1.0)
        print(activation_norm_params)
        self.spade_generator = SPADEGenerator(num_labels,
                                              out_image_small_side_size,
                                              image_channels,
                                              num_filters,
                                              kernel_size,
                                              cond_dims,
                                              activation_norm_params,
                                              weight_norm_type,
                                              global_adaptive_norm_type,
                                              skip_activation_norm,
                                              use_posenc_in_input_layer,
                                              self.use_style_encoder,
                                              output_multiplier)
        if self.use_style:
            # Build the style encoder.
            style_enc_cfg = getattr(gen_cfg, 'style_enc', None)
            if style_enc_cfg is None:
                style_enc_cfg = types.SimpleNamespace()
            if not hasattr(style_enc_cfg, 'num_filters'):
                setattr(style_enc_cfg, 'num_filters', 128)
            if not hasattr(style_enc_cfg, 'kernel_size'):
                setattr(style_enc_cfg, 'kernel_size', 3)
            if not hasattr(style_enc_cfg, 'weight_norm_type'):
                setattr(style_enc_cfg, 'weight_norm_type', weight_norm_type)
            setattr(style_enc_cfg, 'input_image_channels', image_channels)
            setattr(style_enc_cfg, 'style_dims', style_dims)
            self.style_encoder = StyleEncoder(style_enc_cfg)
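        # Cache for the random style code; reused across calls when
        # `use_fixed_random_style=True` is passed to `inference`.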
        self.z = None
        print('Done with the SPADE generator initialization.')
    def forward(self, data, random_style=False):
        r"""SPADE Generator forward.

        Args:
            data (dict):
              - images (N x C1 x H x W tensor): Ground truth images.
              - label (N x C2 x H x W tensor): Semantic representations.
              - z (N x style_dims tensor): Gaussian random noise.
            random_style (bool): Whether to sample a random style vector.
        Returns:
            (dict):
              - fake_images (N x 3 x H x W tensor): Fake images.
              - mu (N x style_dims tensor): Mean vectors.
              - logvar (N x style_dims tensor): Log-variance vectors.
        """
        if self.use_style_encoder:
            if random_style:
                bs = data['label'].size(0)
                z = torch.randn(
                    bs, self.style_dims, dtype=torch.float32).cuda()
                if data['label'].dtype == torch.float16:
                    z = z.half()
                mu = None
                logvar = None
            else:
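                # Encode the ground truth image into (mu, logvar) and draw z
                # with the reparameterization trick
                # (see StyleEncoder.forward).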
                mu, logvar, z = self.style_encoder(data['images'])
            if self.use_attribute:
                data['z'] = torch.cat(
                    (z, data['attributes'].squeeze(1)), dim=1)
            else:
                data['z'] = z
        output = self.spade_generator(data)
        if self.use_style_encoder:
            output['mu'] = mu
            output['logvar'] = logvar
        return output
    def inference(self,
                  data,
                  random_style=False,
                  use_fixed_random_style=False,
                  keep_original_size=False):
        r"""Compute the result images for a batch of input data.

        Args:
            data (dict):
              - images (N x C1 x H x W tensor): Ground truth images.
              - label (N x C2 x H x W tensor): Semantic representations.
              - z (N x style_dims tensor): Gaussian random noise.
            random_style (bool): Whether to sample a random style vector.
            use_fixed_random_style (bool): Sample a random style once and
                reuse it for all remaining inference calls.
            keep_original_size (bool): Keep the original size of the input.
        Returns:
            (tuple):
              - output_images (N x 3 x H x W tensor): Fake images.
              - file_names (list): Output file names, taken from the
                segmentation / edge map keys of the input data.
        """
        self.eval()
        self.spade_generator.eval()
        if self.use_style_encoder:
            if random_style:
                if self.z is None or not use_fixed_random_style:
                    bs = data['label'].size(0)
                    z = torch.randn(
                        bs, self.style_dims, dtype=torch.float32).to('cuda')
                    if data['label'].dtype == torch.float16:
                        z = z.half()
                    self.z = z
                else:
                    z = self.z
            else:
                mu, logvar, z = self.style_encoder(data['images'])
            data['z'] = z
        output = self.spade_generator(data)
        output_images = output['fake_images']
        if keep_original_size:
            height = data['original_h_w'][0][0]
            width = data['original_h_w'][0][1]
            output_images = torch.nn.functional.interpolate(
                output_images, size=[height, width])
        file_names = None
        for key in data['key'].keys():
            if 'segmaps' in key or 'seg_maps' in key:
                file_names = data['key'][key][0]
                break
        # Edge map keys, when present, take precedence over seg map keys.
        for key in data['key'].keys():
            if 'edgemaps' in key or 'edge_maps' in key:
                file_names = data['key'][key][0]
                break
        return output_images, file_names
class SPADEGenerator(nn.Module):
    r"""SPADE image generator constructor.

    Args:
        num_labels (int): Number of different labels.
        out_image_small_side_size (int): min(width, height) of the output.
        image_channels (int): Number of channels of the output image.
        num_filters (int): Base filter number.
        kernel_size (int): Convolution kernel size.
        style_dims (int): Dimensions of the style code.
        activation_norm_params (obj): Spatially adaptive normalization
            parameters.
        weight_norm_type (str): Type of weight normalization.
            ``'none'``, ``'spectral'``, or ``'weight'``.
        global_adaptive_norm_type (str): Type of normalization in SPADE.
        skip_activation_norm (bool): If ``True``, applies activation norm to
            the shortcut connection in residual blocks.
        use_posenc_in_input_layer (bool): If ``True``, concatenates a
            positional encoding grid to the input label map.
        use_style_encoder (bool): Whether to use a global adaptive norm,
            such as conditional batch norm or adaptive instance norm,
            driven by the style code.
        output_multiplier (float): A positive number multiplied with the
            output before the final tanh.
    """
    def __init__(self,
                 num_labels,
                 out_image_small_side_size,
                 image_channels,
                 num_filters,
                 kernel_size,
                 style_dims,
                 activation_norm_params,
                 weight_norm_type,
                 global_adaptive_norm_type,
                 skip_activation_norm,
                 use_posenc_in_input_layer,
                 use_style_encoder,
                 output_multiplier):
        super(SPADEGenerator, self).__init__()
        self.output_multiplier = output_multiplier
        self.use_style_encoder = use_style_encoder
        self.use_posenc_in_input_layer = use_posenc_in_input_layer
        self.out_image_small_side_size = out_image_small_side_size
        self.num_filters = num_filters
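        # 'same' padding for an odd convolution kernel size.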
        padding = int(np.ceil((kernel_size - 1.0) / 2))
        nonlinearity = 'leakyrelu'
        activation_norm_type = 'spatially_adaptive'
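        # In imaginaire blocks, the `order` string spells out the layer
        # sequence: C = convolution, N = normalization, A = activation.
        # 'NACNAC' therefore runs norm -> activation -> conv twice per
        # residual branch.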
        base_res2d_block = \
            functools.partial(Res2dBlock,
                              kernel_size=kernel_size,
                              padding=padding,
                              bias=[True, True, False],
                              weight_norm_type=weight_norm_type,
                              activation_norm_type=activation_norm_type,
                              activation_norm_params=activation_norm_params,
                              skip_activation_norm=skip_activation_norm,
                              nonlinearity=nonlinearity,
                              order='NACNAC')
        if self.use_style_encoder:
            self.fc_0 = LinearBlock(style_dims, 2 * style_dims,
                                    weight_norm_type=weight_norm_type,
                                    nonlinearity='relu',
                                    order='CAN')
            self.fc_1 = LinearBlock(2 * style_dims, 2 * style_dims,
                                    weight_norm_type=weight_norm_type,
                                    nonlinearity='relu',
                                    order='CAN')
            adaptive_norm_params = types.SimpleNamespace()
            if not hasattr(adaptive_norm_params, 'cond_dims'):
                setattr(adaptive_norm_params, 'cond_dims', 2 * style_dims)
            if not hasattr(adaptive_norm_params, 'activation_norm_type'):
                setattr(adaptive_norm_params, 'activation_norm_type',
                        global_adaptive_norm_type)
            if not hasattr(adaptive_norm_params, 'weight_norm_type'):
                setattr(adaptive_norm_params, 'weight_norm_type',
                        activation_norm_params.weight_norm_type)
            if not hasattr(adaptive_norm_params, 'separate_projection'):
                setattr(adaptive_norm_params, 'separate_projection',
                        activation_norm_params.separate_projection)
            adaptive_norm_params.activation_norm_params = \
                types.SimpleNamespace()
            setattr(adaptive_norm_params.activation_norm_params, 'affine',
                    activation_norm_params.activation_norm_params.affine)
            base_cbn2d_block = \
                functools.partial(Conv2dBlock,
                                  kernel_size=kernel_size,
                                  stride=1,
                                  padding=padding,
                                  bias=True,
                                  weight_norm_type=weight_norm_type,
                                  activation_norm_type='adaptive',
                                  activation_norm_params=adaptive_norm_params,
                                  nonlinearity=nonlinearity,
                                  order='NAC')
        else:
            base_conv2d_block = \
                functools.partial(Conv2dBlock,
                                  kernel_size=kernel_size,
                                  stride=1,
                                  padding=padding,
                                  bias=True,
                                  weight_norm_type=weight_norm_type,
                                  nonlinearity=nonlinearity,
                                  order='NAC')
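        # Two extra input channels carry the (x, y) positional encoding
        # grid that is concatenated to the downsampled label map.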
        in_num_labels = num_labels
        in_num_labels += 2 if self.use_posenc_in_input_layer else 0
        self.head_0 = Conv2dBlock(in_num_labels, 8 * num_filters,
                                  kernel_size=kernel_size, stride=1,
                                  padding=padding,
                                  weight_norm_type=weight_norm_type,
                                  activation_norm_type='none',
                                  nonlinearity=nonlinearity)
        if self.use_style_encoder:
            self.cbn_head_0 = base_cbn2d_block(
                8 * num_filters, 16 * num_filters)
        else:
            self.conv_head_0 = base_conv2d_block(
                8 * num_filters, 16 * num_filters)
        self.head_1 = base_res2d_block(16 * num_filters, 16 * num_filters)
        self.head_2 = base_res2d_block(16 * num_filters, 16 * num_filters)
        self.up_0a = base_res2d_block(16 * num_filters, 8 * num_filters)
        if self.use_style_encoder:
            self.cbn_up_0a = base_cbn2d_block(
                8 * num_filters, 8 * num_filters)
        else:
            self.conv_up_0a = base_conv2d_block(
                8 * num_filters, 8 * num_filters)
        self.up_0b = base_res2d_block(8 * num_filters, 8 * num_filters)
        self.up_1a = base_res2d_block(8 * num_filters, 4 * num_filters)
        if self.use_style_encoder:
            self.cbn_up_1a = base_cbn2d_block(
                4 * num_filters, 4 * num_filters)
        else:
            self.conv_up_1a = base_conv2d_block(
                4 * num_filters, 4 * num_filters)
        self.up_1b = base_res2d_block(4 * num_filters, 4 * num_filters)
        self.up_2a = base_res2d_block(4 * num_filters, 4 * num_filters)
        if self.use_style_encoder:
            self.cbn_up_2a = base_cbn2d_block(
                4 * num_filters, 4 * num_filters)
        else:
            self.conv_up_2a = base_conv2d_block(
                4 * num_filters, 4 * num_filters)
        self.up_2b = base_res2d_block(4 * num_filters, 2 * num_filters)
        self.conv_img256 = Conv2dBlock(2 * num_filters, image_channels,
                                       5, stride=1, padding=2,
                                       weight_norm_type=weight_norm_type,
                                       activation_norm_type='none',
                                       nonlinearity=nonlinearity,
                                       order='ANC')
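        # `base` is the total downsampling factor between the output
        # resolution and the head input, chosen so that the smaller side of
        # the head input is 16; it is raised below for 512 and 1024 outputs.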
        self.base = 16
        if self.out_image_small_side_size == 512:
            self.up_3a = base_res2d_block(2 * num_filters, 1 * num_filters)
            self.up_3b = base_res2d_block(1 * num_filters, 1 * num_filters)
            self.conv_img512 = Conv2dBlock(1 * num_filters, image_channels,
                                           5, stride=1, padding=2,
                                           weight_norm_type=weight_norm_type,
                                           activation_norm_type='none',
                                           nonlinearity=nonlinearity,
                                           order='ANC')
            self.base = 32
        if self.out_image_small_side_size == 1024:
            self.up_3a = base_res2d_block(2 * num_filters, 1 * num_filters)
            self.up_3b = base_res2d_block(1 * num_filters, 1 * num_filters)
            self.conv_img512 = Conv2dBlock(1 * num_filters, image_channels,
                                           5, stride=1, padding=2,
                                           weight_norm_type=weight_norm_type,
                                           activation_norm_type='none',
                                           nonlinearity=nonlinearity,
                                           order='ANC')
            self.up_4a = base_res2d_block(num_filters, num_filters // 2)
            self.up_4b = base_res2d_block(num_filters // 2, num_filters // 2)
            self.conv_img1024 = Conv2dBlock(num_filters // 2, image_channels,
                                            5, stride=1, padding=2,
                                            weight_norm_type=weight_norm_type,
                                            activation_norm_type='none',
                                            nonlinearity=nonlinearity,
                                            order='ANC')
            self.nearest_upsample4x = NearestUpsample(scale_factor=4,
                                                      mode='nearest')
            self.base = 64
        if self.out_image_small_side_size not in (256, 512, 1024):
            raise ValueError('Generation image size (%d, %d) not supported' %
                             (self.out_image_small_side_size,
                              self.out_image_small_side_size))
        self.nearest_upsample2x = NearestUpsample(scale_factor=2,
                                                  mode='nearest')
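        # Fixed 16x16 grid of normalized (x, y) coordinates in [-1, 1], used
        # as a positional encoding; it is bicubically resized to the head
        # input size in `forward`. Moving it to the GPU here assumes CUDA is
        # available, as the rest of this module does.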
        xv, yv = torch.meshgrid(
            [torch.arange(-1, 1.1, 2. / 15), torch.arange(-1, 1.1, 2. / 15)])
        self.xy = torch.cat((xv.unsqueeze(0), yv.unsqueeze(0)), 0).unsqueeze(0)
        self.xy = self.xy.cuda()
    def forward(self, data):
        r"""SPADE Generator forward.

        Args:
            data (dict):
              - images (N x C1 x H x W tensor): Ground truth images.
              - label (N x C2 x H x W tensor): Semantic representations.
              - z (N x style_dims tensor): Gaussian random noise.
        Returns:
            output (dict):
              - fake_images (N x 3 x H x W tensor): Fake images.
        """
        seg = data['label']
        if self.use_style_encoder:
            z = data['z']
            z = self.fc_0(z)
            z = self.fc_1(z)
        # Resize the label map so the smaller side of the head input is
        # always 16.
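        # Example: a 256x256 label map with base=16 becomes 16x16 here;
        # a 1024x1024 map with base=64 also becomes 16x16.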
        sy = math.floor(seg.size()[2] * 1.0 / self.base)
        sx = math.floor(seg.size()[3] * 1.0 / self.base)
        in_seg = F.interpolate(seg, size=[sy, sx], mode='nearest')
        if self.use_posenc_in_input_layer:
            in_xy = F.interpolate(self.xy, size=[sy, sx], mode='bicubic')
            in_seg_xy = torch.cat(
                (in_seg, in_xy.expand(in_seg.size()[0], 2, sy, sx)), 1)
        else:
            in_seg_xy = in_seg
        # 16x16
        x = self.head_0(in_seg_xy)
        if self.use_style_encoder:
            x = self.cbn_head_0(x, z)
        else:
            x = self.conv_head_0(x)
        x = self.head_1(x, seg)
        x = self.head_2(x, seg)
        x = self.nearest_upsample2x(x)
        # 32x32
        x = self.up_0a(x, seg)
        if self.use_style_encoder:
            x = self.cbn_up_0a(x, z)
        else:
            x = self.conv_up_0a(x)
        x = self.up_0b(x, seg)
        x = self.nearest_upsample2x(x)
        # 64x64
        x = self.up_1a(x, seg)
        if self.use_style_encoder:
            x = self.cbn_up_1a(x, z)
        else:
            x = self.conv_up_1a(x)
        x = self.up_1b(x, seg)
        x = self.nearest_upsample2x(x)
        # 128x128
        x = self.up_2a(x, seg)
        if self.use_style_encoder:
            x = self.cbn_up_2a(x, z)
        else:
            x = self.conv_up_2a(x)
        x = self.up_2b(x, seg)
        x = self.nearest_upsample2x(x)
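        # Multi-resolution RGB heads: for 512/1024 outputs, coarser RGB
        # predictions are upsampled and summed with finer ones before tanh.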
        # 256x256
        if self.out_image_small_side_size == 256:
            x256 = self.conv_img256(x)
            x = torch.tanh(self.output_multiplier * x256)
        # 512x512
        elif self.out_image_small_side_size == 512:
            x256 = self.conv_img256(x)
            x256 = self.nearest_upsample2x(x256)
            x = self.up_3a(x, seg)
            x = self.up_3b(x, seg)
            x = self.nearest_upsample2x(x)
            x512 = self.conv_img512(x)
            x = torch.tanh(self.output_multiplier * (x256 + x512))
        # 1024x1024
        elif self.out_image_small_side_size == 1024:
            x256 = self.conv_img256(x)
            x256 = self.nearest_upsample4x(x256)
            x = self.up_3a(x, seg)
            x = self.up_3b(x, seg)
            x = self.nearest_upsample2x(x)
            x512 = self.conv_img512(x)
            x512 = self.nearest_upsample2x(x512)
            x = self.up_4a(x, seg)
            x = self.up_4b(x, seg)
            x = self.nearest_upsample2x(x)
            x1024 = self.conv_img1024(x)
            x = torch.tanh(self.output_multiplier * (x256 + x512 + x1024))
        output = dict()
        output['fake_images'] = x
        return output
class StyleEncoder(nn.Module):
    r"""Style encoder constructor.

    Args:
        style_enc_cfg (obj): Style encoder definition part of the yaml
            config file.
    """

    def __init__(self, style_enc_cfg):
        super(StyleEncoder, self).__init__()
        input_image_channels = style_enc_cfg.input_image_channels
        num_filters = style_enc_cfg.num_filters
        kernel_size = style_enc_cfg.kernel_size
        padding = int(np.ceil((kernel_size - 1.0) / 2))
        style_dims = style_enc_cfg.style_dims
        weight_norm_type = style_enc_cfg.weight_norm_type
        activation_norm_type = 'none'
        nonlinearity = 'leakyrelu'
        base_conv2d_block = \
            functools.partial(Conv2dBlock,
                              kernel_size=kernel_size,
                              stride=2,
                              padding=padding,
                              weight_norm_type=weight_norm_type,
                              activation_norm_type=activation_norm_type,
                              # inplace_nonlinearity=True,
                              nonlinearity=nonlinearity)
        self.layer1 = base_conv2d_block(input_image_channels, num_filters)
        self.layer2 = base_conv2d_block(num_filters * 1, num_filters * 2)
        self.layer3 = base_conv2d_block(num_filters * 2, num_filters * 4)
        self.layer4 = base_conv2d_block(num_filters * 4, num_filters * 8)
        self.layer5 = base_conv2d_block(num_filters * 8, num_filters * 8)
        self.layer6 = base_conv2d_block(num_filters * 8, num_filters * 8)
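        # Six stride-2 convolutions reduce the 256x256 input to a
        # (num_filters * 8) x 4 x 4 feature map, which is flattened into
        # num_filters * 8 * 4 * 4 features for the linear heads below.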
        self.fc_mu = LinearBlock(num_filters * 8 * 4 * 4, style_dims)
        self.fc_var = LinearBlock(num_filters * 8 * 4 * 4, style_dims)
    def forward(self, input_x):
        r"""SPADE style encoder forward.

        Args:
            input_x (N x 3 x H x W tensor): Input images.
        Returns:
            (tuple):
              - mu (N x C tensor): Mean vectors.
              - logvar (N x C tensor): Log-variance vectors.
              - z (N x C tensor): Style code vectors.
        """
        if input_x.size(2) != 256 or input_x.size(3) != 256:
            input_x = F.interpolate(input_x, size=(256, 256), mode='bilinear')
        x = self.layer1(input_x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.layer6(x)
        x = x.view(x.size(0), -1)
        mu = self.fc_mu(x)
        logvar = self.fc_var(x)
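        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
        # keeping the sampling step differentiable w.r.t. mu and logvar.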
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = eps.mul(std) + mu
        return mu, logvar, z