# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import functools
from imaginaire.layers import Conv2dBlock
from imaginaire.utils.data import get_paired_input_label_channel_number, get_paired_input_image_channel_number
from imaginaire.utils.distributed import master_only_print as print
class Discriminator(nn.Module):
    r"""Multi-resolution patch discriminator. Based on FPSE discriminator but
    with N+1 labels.

    Args:
        dis_cfg (obj): Discriminator definition part of the yaml config file.
        data_cfg (obj): Data definition part of the yaml config file.
    """

    def __init__(self, dis_cfg, data_cfg):
        super(Discriminator, self).__init__()
        # We assume the first datum is the ground truth image.
        image_channels = get_paired_input_image_channel_number(data_cfg)
        # Calculate number of channels in the input label.
        num_labels = get_paired_input_label_channel_number(data_cfg)
        self.use_label = getattr(dis_cfg, 'use_label', True)
        # Apply the optional per-config channel overrides independently.
        # BUGFIX: the previous code attached an else-branch to the
        # 'num_labels' check that recomputed BOTH values from data_cfg, so a
        # dis_cfg.image_channels override was silently discarded whenever
        # dis_cfg had no 'num_labels' attribute. The data_cfg defaults are
        # already computed above, so no else-branch is needed.
        if hasattr(dis_cfg, 'image_channels'):
            image_channels = dis_cfg.image_channels
        if hasattr(dis_cfg, 'num_labels'):
            num_labels = dis_cfg.num_labels
        if not self.use_label:
            num_labels = 2  # ignore + true
        # Build the discriminator.
        num_filters = getattr(dis_cfg, 'num_filters', 128)
        weight_norm_type = getattr(dis_cfg, 'weight_norm_type', 'spectral')
        fpse_kernel_size = getattr(dis_cfg, 'fpse_kernel_size', 3)
        fpse_activation_norm_type = getattr(dis_cfg,
                                            'fpse_activation_norm_type',
                                            'none')
        do_multiscale = getattr(dis_cfg, 'do_multiscale', False)
        smooth_resample = getattr(dis_cfg, 'smooth_resample', False)
        no_label_except_largest_scale = getattr(dis_cfg, 'no_label_except_largest_scale', False)
        self.fpse_discriminator = FPSEDiscriminator(
            image_channels,
            num_labels,
            num_filters,
            fpse_kernel_size,
            weight_norm_type,
            fpse_activation_norm_type,
            do_multiscale,
            smooth_resample,
            no_label_except_largest_scale)

    def _single_forward(self, input_label, input_image, weights):
        # Run the single FPSE discriminator. Its feature list is wrapped in
        # an outer list so the output format matches the usual
        # one-list-per-patch-discriminator convention.
        output_list, features_list = self.fpse_discriminator(input_image, input_label, weights)
        return output_list, [features_list]

    def forward(self, data, net_G_output, weights=None, incl_real=False, incl_pseudo_real=False):
        r"""GANcraft discriminator forward.

        Args:
            data (dict):
                - data (N x C1 x H x W tensor) : Ground truth images.
                - label (N x C2 x H x W tensor) : Semantic representations.
                - z (N x style_dims tensor): Gaussian random noise.
            net_G_output (dict):
                - fake_images (N x C1 x H x W tensor) : Fake images.
            weights (tensor, optional): Per-class label weights forwarded to
                the patch discriminator (only used for the fake pass here;
                real/pseudo-real passes always receive None).
            incl_real (bool): Also run the discriminator on real images.
            incl_pseudo_real (bool): Also run it on pseudo-real images.
        Returns:
            output_x (dict):
                - real_outputs (list): list of output tensors produced by
                  individual patch discriminators for real images.
                - real_features (list): list of lists of features produced by
                  individual patch discriminators for real images.
                - fake_outputs (list): list of output tensors produced by
                  individual patch discriminators for fake images.
                - fake_features (list): list of lists of features produced by
                  individual patch discriminators for fake images.
        """
        output_x = dict()
        # Fake.
        fake_images = net_G_output['fake_images']
        if self.use_label:
            fake_labels = data['fake_masks']
        else:
            # No semantic labels: substitute a fixed 2-channel
            # (ignore, true) map with the 'true' channel set everywhere.
            fake_labels = torch.zeros([fake_images.size(0), 2, fake_images.size(
                2), fake_images.size(3)], device=fake_images.device, dtype=fake_images.dtype)
            fake_labels[:, 1, :, :] = 1
        output_x['fake_outputs'], output_x['fake_features'] = \
            self._single_forward(fake_labels, fake_images, None)
        # Real.
        if incl_real:
            real_images = data['images']
            if self.use_label:
                real_labels = data['real_masks']
            else:
                real_labels = torch.zeros([real_images.size(0), 2, real_images.size(
                    2), real_images.size(3)], device=real_images.device, dtype=real_images.dtype)
                real_labels[:, 1, :, :] = 1
            output_x['real_outputs'], output_x['real_features'] = \
                self._single_forward(real_labels, real_images, None)
        # pseudo-Real.
        if incl_pseudo_real:
            preal_images = data['pseudo_real_img']
            # Pseudo-real images share the fake masks (they are generated
            # from the same semantic layout).
            preal_labels = data['fake_masks']
            if not self.use_label:
                preal_labels = torch.zeros([preal_images.size(0), 2, preal_images.size(
                    2), preal_images.size(3)], device=preal_images.device, dtype=preal_images.dtype)
                preal_labels[:, 1, :, :] = 1
            output_x['pseudo_real_outputs'], output_x['pseudo_real_features'] = \
                self._single_forward(preal_labels, preal_images, None)
        return output_x
class FPSEDiscriminator(nn.Module):
    r"""Feature-pyramid (FPSE-style) patch discriminator with per-pixel
    ``num_labels + 1`` predictions.

    A 5-stage encoder downsamples the image (bottom-up pathway); a top-down
    pathway with lateral 1x1 convolutions rebuilds features at 1/4, 1/8 and
    1/16 resolution, and prediction heads output one logit map per label
    plus one extra channel (the N+1 'fake' class).

    Args:
        num_input_channels (int): Number of channels of the input image.
        num_labels (int): Number of channels of the segmentation map.
        num_filters (int): Base channel count of the conv blocks.
        kernel_size (int): Kernel size of the stride-1/stride-2 conv blocks.
        weight_norm_type (str): Weight normalization type (e.g. 'spectral').
        activation_norm_type (str): Activation normalization type.
        do_multiscale (bool): Also produce predictions at the 1/8 and 1/16
            scales.
        smooth_resample (bool): Resample segmentation maps with
            :meth:`smooth_interp` (area + re-one-hot) instead of
            nearest-neighbor interpolation.
        no_label_except_largest_scale (bool): At the coarser scales predict
            only 2 channels and pair them with an all-ones label map.
    """

    def __init__(self,
                 num_input_channels,
                 num_labels,
                 num_filters,
                 kernel_size,
                 weight_norm_type,
                 activation_norm_type,
                 do_multiscale,
                 smooth_resample,
                 no_label_except_largest_scale):
        super().__init__()
        self.do_multiscale = do_multiscale
        self.no_label_except_largest_scale = no_label_except_largest_scale
        # 'same' padding for odd kernel sizes.
        padding = int(np.ceil((kernel_size - 1.0) / 2))
        nonlinearity = 'leakyrelu'
        stride1_conv2d_block = \
            functools.partial(Conv2dBlock,
                              kernel_size=kernel_size,
                              stride=1,
                              padding=padding,
                              weight_norm_type=weight_norm_type,
                              activation_norm_type=activation_norm_type,
                              nonlinearity=nonlinearity,
                              order='CNA')
        down_conv2d_block = \
            functools.partial(Conv2dBlock,
                              kernel_size=kernel_size,
                              stride=2,
                              padding=padding,
                              weight_norm_type=weight_norm_type,
                              activation_norm_type=activation_norm_type,
                              nonlinearity=nonlinearity,
                              order='CNA')
        latent_conv2d_block = \
            functools.partial(Conv2dBlock,
                              kernel_size=1,
                              stride=1,
                              weight_norm_type=weight_norm_type,
                              activation_norm_type=activation_norm_type,
                              nonlinearity=nonlinearity,
                              order='CNA')
        # Bottom-up pathway (comments give receptive-field sizes).
        self.enc1 = down_conv2d_block(num_input_channels, num_filters)  # 3
        self.enc2 = down_conv2d_block(1 * num_filters, 2 * num_filters)  # 7
        self.enc3 = down_conv2d_block(2 * num_filters, 4 * num_filters)  # 15
        self.enc4 = down_conv2d_block(4 * num_filters, 8 * num_filters)  # 31
        self.enc5 = down_conv2d_block(8 * num_filters, 8 * num_filters)  # 63
        # Top-down pathway: 1x1 lateral connections onto a common
        # 4*num_filters channel width. (There is deliberately no lat1; the
        # finest pyramid level used is 1/4 resolution.)
        self.lat2 = latent_conv2d_block(2 * num_filters, 4 * num_filters)
        self.lat3 = latent_conv2d_block(4 * num_filters, 4 * num_filters)
        self.lat4 = latent_conv2d_block(8 * num_filters, 4 * num_filters)
        self.lat5 = latent_conv2d_block(8 * num_filters, 4 * num_filters)
        # Upsampling used to merge pyramid levels.
        self.upsample2x = nn.Upsample(scale_factor=2, mode='bilinear',
                                      align_corners=False)
        # Final prediction layers.
        self.final2 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)
        self.output = Conv2dBlock(num_filters * 2, num_labels + 1, kernel_size=1)
        if self.do_multiscale:
            self.final3 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)
            self.final4 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)
            if self.no_label_except_largest_scale:
                # Coarse scales only discriminate fake/true.
                self.output3 = Conv2dBlock(num_filters * 2, 2, kernel_size=1)
                self.output4 = Conv2dBlock(num_filters * 2, 2, kernel_size=1)
            else:
                self.output3 = Conv2dBlock(num_filters * 2, num_labels + 1, kernel_size=1)
                self.output4 = Conv2dBlock(num_filters * 2, num_labels + 1, kernel_size=1)
        self.interpolator = functools.partial(F.interpolate, mode='nearest')
        if smooth_resample:
            self.interpolator = self.smooth_interp

    # BUGFIX: this must be a @staticmethod. Without the decorator,
    # ``self.interpolator = self.smooth_interp`` creates a bound method that
    # passes ``self`` as ``x``, so every call with smooth_resample=True
    # raised a TypeError.
    @staticmethod
    def smooth_interp(x, size):
        r"""Smooth interpolation of segmentation maps.

        Area interpolation yields soft class scores, which are converted
        back to a hard one-hot map via channel-wise argmax.

        Args:
            x (4D tensor): Segmentation maps (one score map per class).
            size (2D list): Target size (H, W).
        Returns:
            (4D tensor): One-hot segmentation map at the target size.
        """
        x = F.interpolate(x, size=size, mode='area')
        # dim=-3 is the channel dimension of an NCHW tensor. Use the
        # documented 'keepdim' spelling (not the numpy alias 'keepdims').
        onehot_idx = torch.argmax(x, dim=-3, keepdim=True)
        x.fill_(0.0)
        x.scatter_(1, onehot_idx, 1.0)
        return x

    def forward(self, images, segmaps, weights=None):
        r"""Forward pass.

        Args:
            images (N x C x H x W tensor): Input images (spatial-size
                comments below assume 256x256 inputs).
            segmaps (N x num_labels x H x W tensor): Segmentation maps.
            weights (N x num_labels tensor, optional): Per-class weights,
                broadcast over the spatial dimensions of the label maps.
        Returns:
            (tuple):
              - results (list of dict): per-scale {'pred', 'label'} pairs.
              - features (list): intermediate encoder/decoder features.
        """
        # Bottom-up pathway.
        feat11 = self.enc1(images)  # 128
        feat12 = self.enc2(feat11)  # 64
        feat13 = self.enc3(feat12)  # 32
        feat14 = self.enc4(feat13)  # 16
        feat15 = self.enc5(feat14)  # 8
        # Top-down pathway and lateral connections.
        feat25 = self.lat5(feat15)  # 8
        feat24 = self.upsample2x(feat25) + self.lat4(feat14)  # 16
        feat23 = self.upsample2x(feat24) + self.lat3(feat13)  # 32
        feat22 = self.upsample2x(feat23) + self.lat2(feat12)  # 64
        # Final prediction layers.
        feat32 = self.final2(feat22)
        results = []
        label_map = self.interpolator(segmaps, size=feat32.size()[2:])
        pred2 = self.output(feat32)  # N, num_labels+1, H//4, W//4
        features = [feat11, feat12, feat13, feat14, feat15, feat25, feat24, feat23, feat22]
        if weights is not None:
            # Broadcast the [N, C] weights over the spatial dimensions.
            label_map = label_map * weights[..., None, None]
        results.append({'pred': pred2, 'label': label_map})
        if self.do_multiscale:
            feat33 = self.final3(feat23)
            pred3 = self.output3(feat33)
            feat34 = self.final4(feat24)
            pred4 = self.output4(feat34)
            if self.no_label_except_largest_scale:
                # Coarse scales get a constant all-ones 'label' channel.
                label_map3 = torch.ones([pred3.size(0), 1, pred3.size(2), pred3.size(3)], device=pred3.device)
                label_map4 = torch.ones([pred4.size(0), 1, pred4.size(2), pred4.size(3)], device=pred4.device)
            else:
                label_map3 = self.interpolator(segmaps, size=pred3.size()[2:])
                label_map4 = self.interpolator(segmaps, size=pred4.size()[2:])
                if weights is not None:
                    label_map3 = label_map3 * weights[..., None, None]
                    label_map4 = label_map4 * weights[..., None, None]
            results.append({'pred': pred3, 'label': label_map3})
            results.append({'pred': pred4, 'label': label_map4})
        return results, features