# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import functools

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from imaginaire.layers import Conv2dBlock


class FPSEDiscriminator(nn.Module):
    r"""Feature-Pyramid Semantics Embedding (FPSE) Discriminator. This is a
    copy of the discriminator proposed in
    https://arxiv.org/pdf/1910.06809.pdf

    Args:
        num_input_channels (int): Number of channels in the input image.
        num_labels (int): Number of semantic labels in the segmentation map.
        num_filters (int): Base number of convolutional filters.
        kernel_size (int): Kernel size of the convolutional layers.
        weight_norm_type (str): Type of weight normalization.
        activation_norm_type (str): Type of activation normalization.
    """

    def __init__(self,
                 num_input_channels,
                 num_labels,
                 num_filters,
                 kernel_size,
                 weight_norm_type,
                 activation_norm_type):
        super().__init__()
        padding = int(np.ceil((kernel_size - 1.0) / 2))
        nonlinearity = 'leakyrelu'
        # Three block templates: stride-1 blocks for the per-scale heads,
        # stride-2 blocks for the bottom-up encoder, and 1x1 blocks for the
        # lateral connections.
        stride1_conv2d_block = \
            functools.partial(Conv2dBlock,
                              kernel_size=kernel_size,
                              stride=1,
                              padding=padding,
                              weight_norm_type=weight_norm_type,
                              activation_norm_type=activation_norm_type,
                              nonlinearity=nonlinearity,
                              # inplace_nonlinearity=True,
                              order='CNA')
        down_conv2d_block = \
            functools.partial(Conv2dBlock,
                              kernel_size=kernel_size,
                              stride=2,
                              padding=padding,
                              weight_norm_type=weight_norm_type,
                              activation_norm_type=activation_norm_type,
                              nonlinearity=nonlinearity,
                              # inplace_nonlinearity=True,
                              order='CNA')
        latent_conv2d_block = \
            functools.partial(Conv2dBlock,
                              kernel_size=1,
                              stride=1,
                              weight_norm_type=weight_norm_type,
                              activation_norm_type=activation_norm_type,
                              nonlinearity=nonlinearity,
                              # inplace_nonlinearity=True,
                              order='CNA')
        # Bottom-up pathway: each block halves the spatial resolution, so
        # enc1..enc5 produce features at 1/2, 1/4, 1/8, 1/16 and 1/32 of
        # the input resolution.
        self.enc1 = down_conv2d_block(num_input_channels, num_filters)
        self.enc2 = down_conv2d_block(1 * num_filters, 2 * num_filters)
        self.enc3 = down_conv2d_block(2 * num_filters, 4 * num_filters)
        self.enc4 = down_conv2d_block(4 * num_filters, 8 * num_filters)
        self.enc5 = down_conv2d_block(8 * num_filters, 8 * num_filters)
        # Top-down pathway: 1x1 lateral convolutions project every encoder
        # feature to a common width of 4 * num_filters.
        self.lat2 = latent_conv2d_block(2 * num_filters, 4 * num_filters)
        self.lat3 = latent_conv2d_block(4 * num_filters, 4 * num_filters)
        self.lat4 = latent_conv2d_block(8 * num_filters, 4 * num_filters)
        self.lat5 = latent_conv2d_block(8 * num_filters, 4 * num_filters)
        # Upsampling used to merge top-down features with lateral ones.
        self.upsample2x = nn.Upsample(scale_factor=2, mode='bilinear',
                                      align_corners=False)
        # Final per-scale layers.
        self.final2 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)
        self.final3 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)
        self.final4 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)
        # True/false prediction and semantic alignment prediction.
        self.output = Conv2dBlock(num_filters * 2, 1, kernel_size=1)
        self.seg = Conv2dBlock(num_filters * 2, num_filters * 2,
                               kernel_size=1)
        self.embedding = Conv2dBlock(num_labels, num_filters * 2,
                                     kernel_size=1)

    def forward(self, images, segmaps):
        r"""
        Args:
            images (tensor): Input image tensors.
            segmaps (tensor): Segmentation map tensors, with one channel
                per semantic label.
        """
        # Bottom-up pathway.
        feat11 = self.enc1(images)
        feat12 = self.enc2(feat11)
        feat13 = self.enc3(feat12)
        feat14 = self.enc4(feat13)
        feat15 = self.enc5(feat14)
        # Top-down pathway and lateral connections.
        feat25 = self.lat5(feat15)
        feat24 = self.upsample2x(feat25) + self.lat4(feat14)
        feat23 = self.upsample2x(feat24) + self.lat3(feat13)
        feat22 = self.upsample2x(feat23) + self.lat2(feat12)
        # Final prediction layers.
        feat32 = self.final2(feat22)
        feat33 = self.final3(feat23)
        feat34 = self.final4(feat24)
        # Patch-based true/false prediction at three scales.
        pred2 = self.output(feat32)
        pred3 = self.output(feat33)
        pred4 = self.output(feat34)
        seg2 = self.seg(feat32)
        seg3 = self.seg(feat33)
        seg4 = self.seg(feat34)
        # Segmentation map embedding, average-pooled to match the spatial
        # resolution of each prediction scale (1/4, 1/8 and 1/16).
        segembs = self.embedding(segmaps)
        segembs = F.avg_pool2d(segembs, kernel_size=2, stride=2)
        segembs2 = F.avg_pool2d(segembs, kernel_size=2, stride=2)
        segembs3 = F.avg_pool2d(segembs2, kernel_size=2, stride=2)
        segembs4 = F.avg_pool2d(segembs3, kernel_size=2, stride=2)
        # Semantics embedding discriminator score: add the channel-wise
        # inner product between the label embedding and the seg features
        # to each patch prediction.
        pred2 += torch.mul(segembs2, seg2).sum(dim=1, keepdim=True)
        pred3 += torch.mul(segembs3, seg3).sum(dim=1, keepdim=True)
        pred4 += torch.mul(segembs4, seg4).sum(dim=1, keepdim=True)
        # Return the predictions from the three resolutions.
        return pred2, pred3, pred4
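

if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the original module). It
    # assumes imaginaire is importable so that Conv2dBlock resolves; the
    # settings below (64 base filters, 3x3 kernels, spectral weight norm,
    # no activation norm, 35 labels, 256x256 inputs) are illustrative
    # assumptions, not values prescribed by the paper.
    net = FPSEDiscriminator(num_input_channels=3,
                            num_labels=35,
                            num_filters=64,
                            kernel_size=3,
                            weight_norm_type='spectral',
                            activation_norm_type='none')
    images = torch.randn(2, 3, 256, 256)
    # In practice segmaps would be one-hot label maps; random values are
    # enough to exercise the shapes here.
    segmaps = torch.randn(2, 35, 256, 256)
    pred2, pred3, pred4 = net(images, segmaps)
    # Scores come out at 1/4, 1/8 and 1/16 of the input resolution:
    # (2, 1, 64, 64), (2, 1, 32, 32) and (2, 1, 16, 16).
    print(pred2.shape, pred3.shape, pred4.shape)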