# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import warnings

from torch import nn
from torch.nn import Upsample as NearestUpsample

from imaginaire.layers import Conv2dBlock, Res2dBlock


class Generator(nn.Module):
    r"""Improved UNIT generator.

    Args:
        gen_cfg (obj): Generator definition part of the yaml config file.
        data_cfg (obj): Data definition part of the yaml config file.
    """

    def __init__(self, gen_cfg, data_cfg):
        super().__init__()
        self.autoencoder_a = AutoEncoder(**vars(gen_cfg))
        self.autoencoder_b = AutoEncoder(**vars(gen_cfg))

    def forward(self, data, image_recon=True, cycle_recon=True):
        r"""UNIT forward function.

        Args:
            data (dict): Training data at the current iteration.
                - images_a (tensor): Images from domain A.
                - images_b (tensor): Images from domain B.
            image_recon (bool): If ``True``, also computes the within-domain
                reconstructions ``images_aa`` and ``images_bb``.
            cycle_recon (bool): If ``True``, also computes the cycle
                reconstructions ``images_aba`` and ``images_bab``.
        """
        images_a = data['images_a']
        images_b = data['images_b']
        net_G_output = dict()
        # Encode the input images into latent content codes.
        content_a = self.autoencoder_a.content_encoder(images_a)
        content_b = self.autoencoder_b.content_encoder(images_b)
        # Decode (within domain).
        if image_recon:
            images_aa = self.autoencoder_a.decoder(content_a)
            images_bb = self.autoencoder_b.decoder(content_b)
            net_G_output.update(dict(images_aa=images_aa, images_bb=images_bb))
        # Decode (cross domain).
        images_ba = self.autoencoder_a.decoder(content_b)
        images_ab = self.autoencoder_b.decoder(content_a)
        # Cycle reconstruction.
        if cycle_recon:
            content_ba = self.autoencoder_a.content_encoder(images_ba)
            content_ab = self.autoencoder_b.content_encoder(images_ab)
            images_aba = self.autoencoder_a.decoder(content_ab)
            images_bab = self.autoencoder_b.decoder(content_ba)
            net_G_output.update(
                dict(content_ba=content_ba, content_ab=content_ab,
                     images_aba=images_aba, images_bab=images_bab))
        # Required outputs.
        net_G_output.update(dict(content_a=content_a, content_b=content_b,
                                 images_ba=images_ba, images_ab=images_ab))
        return net_G_output
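
    # For reference (derived from the code above): the returned dictionary
    # always holds content_a, content_b and the cross-domain translations
    # images_ab, images_ba; image_recon adds images_aa, images_bb, and
    # cycle_recon adds content_ab, content_ba, images_aba, images_bab.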

    def inference(self, data, a2b=True):
        r"""UNIT inference.

        Args:
            data (dict): Training data at the current iteration.
                - images_a (tensor): Images from domain A.
                - images_b (tensor): Images from domain B.
            a2b (bool): If ``True``, translates images from domain A to B,
                otherwise from B to A.

        Returns:
            (tuple):
              - output_images (tensor): Translated images.
              - filenames (list of str): Output filenames built from the
                input sequence name and filename.
        """
        if a2b:
            input_key = 'images_a'
            content_encode = self.autoencoder_a.content_encoder
            decode = self.autoencoder_b.decoder
        else:
            input_key = 'images_b'
            content_encode = self.autoencoder_b.content_encoder
            decode = self.autoencoder_a.decoder

        content_images = data[input_key]
        content = content_encode(content_images)
        output_images = decode(content)
        filename = '%s/%s' % (
            data['key'][input_key]['sequence_name'][0],
            data['key'][input_key]['filename'][0])
        filenames = [filename]
        return output_images, filenames
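

# Example usage, as a sketch only: `gen_cfg` and `data_cfg` come from an
# imaginaire yaml config, and the batch below uses hypothetical tensors
# and shapes.
#
#     net_G = Generator(gen_cfg, data_cfg)
#     data = dict(images_a=torch.randn(1, 3, 256, 256),
#                 images_b=torch.randn(1, 3, 256, 256))
#     net_G_output = net_G(data)             # training-style forward pass
#     images_ab = net_G_output['images_ab']  # A -> B translation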


class AutoEncoder(nn.Module):
    r"""Improved UNIT autoencoder.

    Args:
        num_filters (int): Base filter numbers.
        max_num_filters (int): Maximum number of filters in the encoder.
        num_res_blocks (int): Number of residual blocks at the end of the
            content encoder.
        num_downsamples_content (int): Number of times we reduce
            resolution by 2x2 for the content image.
        num_image_channels (int): Number of input image channels.
        content_norm_type (str): Type of activation normalization in the
            content encoder.
        decoder_norm_type (str): Type of activation normalization in the
            decoder.
        weight_norm_type (str): Type of weight normalization.
        output_nonlinearity (str): Type of nonlinearity before final output,
            ``'tanh'`` or ``'none'``.
        pre_act (bool): If ``True``, uses pre-activation residual blocks.
        apply_noise (bool): If ``True``, injects Gaussian noise in the decoder.
    """

    def __init__(self,
                 num_filters=64,
                 max_num_filters=256,
                 num_res_blocks=4,
                 num_downsamples_content=2,
                 num_image_channels=3,
                 content_norm_type='instance',
                 decoder_norm_type='instance',
                 weight_norm_type='',
                 output_nonlinearity='',
                 pre_act=False,
                 apply_noise=False,
                 **kwargs):
        super().__init__()
        # Warn about config entries this generator does not consume
        # ('type' is the module selector from the yaml config and expected).
        for key in kwargs:
            if key != 'type':
                warnings.warn(
                    "Generator argument '{}' is not used.".format(key))
        self.content_encoder = ContentEncoder(num_downsamples_content,
                                              num_res_blocks,
                                              num_image_channels,
                                              num_filters,
                                              max_num_filters,
                                              'reflect',
                                              content_norm_type,
                                              weight_norm_type,
                                              'relu',
                                              pre_act)
        self.decoder = Decoder(num_downsamples_content,
                               num_res_blocks,
                               self.content_encoder.output_dim,
                               num_image_channels,
                               'reflect',
                               decoder_norm_type,
                               weight_norm_type,
                               'relu',
                               output_nonlinearity,
                               pre_act,
                               apply_noise)

    def forward(self, images):
        r"""Reconstruct an image.

        Args:
            images (Tensor): Input images.
        Returns:
            images_recon (Tensor): Reconstructed images.
        """
        content = self.content_encoder(images)
        images_recon = self.decoder(content)
        return images_recon
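

# A reconstruction sketch for a single AutoEncoder (hypothetical input size;
# the shapes follow from the default hyperparameters above):
#
#     net = AutoEncoder()
#     x = torch.randn(1, 3, 256, 256)
#     x_recon = net(x)  # encodes to a 1 x 256 x 64 x 64 content code,
#                       # then decodes back to the input resolution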


class ContentEncoder(nn.Module):
    r"""Improved UNIT encoder. The network consists of:

    - input layers
    - $(num_downsamples) convolutional blocks
    - $(num_res_blocks) residual blocks
    - output layer.

    Args:
        num_downsamples (int): Number of times we reduce
            resolution by 2x2.
        num_res_blocks (int): Number of residual blocks at the end of the
            content encoder.
        num_image_channels (int): Number of input image channels.
        num_filters (int): Base filter numbers.
        max_num_filters (int): Maximum number of filters in the encoder.
        padding_mode (string): Type of padding.
        activation_norm_type (str): Type of activation normalization.
        weight_norm_type (str): Type of weight normalization.
        nonlinearity (str): Type of nonlinear activation function.
        pre_act (bool): If ``True``, uses pre-activation residual blocks.
    """

    def __init__(self,
                 num_downsamples,
                 num_res_blocks,
                 num_image_channels,
                 num_filters,
                 max_num_filters,
                 padding_mode,
                 activation_norm_type,
                 weight_norm_type,
                 nonlinearity,
                 pre_act=False):
        super().__init__()
        conv_params = dict(padding_mode=padding_mode,
                           activation_norm_type=activation_norm_type,
                           weight_norm_type=weight_norm_type,
                           nonlinearity=nonlinearity)
        # In-place nonlinearities are only unsafe for pre-activation residual
        # blocks without normalization: there the activation would overwrite
        # the block input, which is still needed for the skip connection.
        if not pre_act or (activation_norm_type != '' and
                           activation_norm_type != 'none'):
            conv_params['inplace_nonlinearity'] = True
        # The order of operations in residual blocks: 'CNACNA' spells out
        # (Convolution, Normalization, Activation) applied twice; 'pre_act'
        # moves normalization and activation before each convolution.
        order = 'pre_act' if pre_act else 'CNACNA'
        model = []
        # Input layer: 7x7 convolution that keeps the spatial resolution.
        model += [Conv2dBlock(num_image_channels, num_filters, 7, 1, 3,
                              **conv_params)]
        # Downsampling blocks: each halves the resolution and doubles the
        # channel count, capped at max_num_filters.
        for i in range(num_downsamples):
            num_filters_prev = num_filters
            num_filters = min(num_filters * 2, max_num_filters)
            model += [Conv2dBlock(num_filters_prev, num_filters, 4, 2, 1,
                                  **conv_params)]
        # Residual blocks.
        for _ in range(num_res_blocks):
            model += [Res2dBlock(num_filters, num_filters,
                                 **conv_params,
                                 order=order)]
        self.model = nn.Sequential(*model)
        self.output_dim = num_filters

    def forward(self, x):
        r"""
        Args:
            x (tensor): Input image.
        """
        return self.model(x)
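

# Shape bookkeeping, as a sketch with the AutoEncoder defaults above
# (num_downsamples=2, num_filters=64, max_num_filters=256): a 3 x 256 x 256
# image passes the 7x7 stem to 64 channels, then each stride-2 block halves
# the resolution and doubles the channels (capped at max_num_filters),
# giving a 256 x 64 x 64 content code. In general the encoder outputs
# min(num_filters * 2 ** num_downsamples, max_num_filters) channels at
# 1 / 2 ** num_downsamples of the input resolution.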


class Decoder(nn.Module):
    r"""Improved UNIT decoder. The network consists of:

    - $(num_res_blocks) residual blocks
    - $(num_upsamples) upsampling layers, each followed by a convolutional
      block
    - output layer.

    Args:
        num_upsamples (int): Number of times we increase resolution by 2x2.
        num_res_blocks (int): Number of residual blocks.
        num_filters (int): Number of channels in the content embedding, i.e.
            the decoder input.
        num_image_channels (int): Number of output image channels.
        padding_mode (string): Type of padding.
        activation_norm_type (str): Type of activation normalization.
        weight_norm_type (str): Type of weight normalization.
        nonlinearity (str): Type of nonlinear activation function.
        output_nonlinearity (str): Type of nonlinearity before final output,
            ``'tanh'`` or ``'none'``.
        pre_act (bool): If ``True``, uses pre-activation residual blocks.
        apply_noise (bool): If ``True``, injects Gaussian noise.
    """

    def __init__(self,
                 num_upsamples,
                 num_res_blocks,
                 num_filters,
                 num_image_channels,
                 padding_mode,
                 activation_norm_type,
                 weight_norm_type,
                 nonlinearity,
                 output_nonlinearity,
                 pre_act=False,
                 apply_noise=False):
        super().__init__()
        conv_params = dict(padding_mode=padding_mode,
                           nonlinearity=nonlinearity,
                           inplace_nonlinearity=True,
                           apply_noise=apply_noise,
                           weight_norm_type=weight_norm_type,
                           activation_norm_type=activation_norm_type)
        # The order of operations in residual blocks.
        order = 'pre_act' if pre_act else 'CNACNA'
        # Residual blocks.
        self.decoder = nn.ModuleList()
        for _ in range(num_res_blocks):
            self.decoder += [Res2dBlock(num_filters, num_filters,
                                        **conv_params,
                                        order=order)]
        # Convolutional blocks with upsampling: each nearest-neighbor
        # upsampling doubles the resolution, and the following convolution
        # halves the channel count.
        for i in range(num_upsamples):
            self.decoder += [NearestUpsample(scale_factor=2)]
            self.decoder += [Conv2dBlock(num_filters, num_filters // 2,
                                         5, 1, 2, **conv_params)]
            num_filters //= 2
        # Output layer: project to image channels.
        self.decoder += [Conv2dBlock(num_filters, num_image_channels, 7, 1, 3,
                                     nonlinearity=output_nonlinearity,
                                     padding_mode=padding_mode)]

    def forward(self, x):
        r"""
        Args:
            x (tensor): Content embedding of the content image.
        """
        for block in self.decoder:
            x = block(x)
        return x
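

if __name__ == '__main__':
    # Minimal shape check -- a sketch, not part of the original file. It
    # assumes imaginaire (and torch) are installed and uses the AutoEncoder
    # defaults defined above.
    import torch

    net = AutoEncoder()
    images = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        content = net.content_encoder(images)  # 1 x 256 x 64 x 64
        images_recon = net.decoder(content)    # 1 x 3 x 256 x 256
    assert images_recon.shape == images.shape
    print(content.shape, images_recon.shape)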