# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import functools
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from imaginaire.layers import Conv2dBlock
class FPSEDiscriminator(nn.Module):
    r"""Feature-Pyramid Semantics Embedding Discriminator.

    This is a copy of the discriminator described in
    https://arxiv.org/pdf/1910.06809.pdf
    """

    def __init__(self,
                 num_input_channels,
                 num_labels,
                 num_filters,
                 kernel_size,
                 weight_norm_type,
                 activation_norm_type):
        r"""Build the FPN-style encoder, lateral, and prediction layers.

        Args:
            num_input_channels (int): Number of channels in the input images.
            num_labels (int): Number of channels in the segmentation maps
                fed to :meth:`forward` (one per semantic label).
            num_filters (int): Base number of filters; deeper encoder stages
                use multiples of this value.
            kernel_size (int): Kernel size for the encoder and final
                convolution blocks (lateral blocks use 1x1 convolutions).
            weight_norm_type (str): Weight normalization type passed to
                ``Conv2dBlock``.
            activation_norm_type (str): Activation normalization type passed
                to ``Conv2dBlock``.
        """
        super().__init__()
        # 'same'-style padding for the given kernel size.
        padding = int(np.ceil((kernel_size - 1.0) / 2))
        nonlinearity = 'leakyrelu'
        # Shared Conv2dBlock configurations: stride-1 blocks for the final
        # prediction path, stride-2 blocks for the bottom-up encoder, and
        # 1x1 blocks for the FPN lateral connections.
        stride1_conv2d_block = \
            functools.partial(Conv2dBlock,
                              kernel_size=kernel_size,
                              stride=1,
                              padding=padding,
                              weight_norm_type=weight_norm_type,
                              activation_norm_type=activation_norm_type,
                              nonlinearity=nonlinearity,
                              order='CNA')
        down_conv2d_block = \
            functools.partial(Conv2dBlock,
                              kernel_size=kernel_size,
                              stride=2,
                              padding=padding,
                              weight_norm_type=weight_norm_type,
                              activation_norm_type=activation_norm_type,
                              nonlinearity=nonlinearity,
                              order='CNA')
        latent_conv2d_block = \
            functools.partial(Conv2dBlock,
                              kernel_size=1,
                              stride=1,
                              weight_norm_type=weight_norm_type,
                              activation_norm_type=activation_norm_type,
                              nonlinearity=nonlinearity,
                              order='CNA')
        # Bottom-up pathway: each stage halves the spatial resolution.
        self.enc1 = down_conv2d_block(num_input_channels, num_filters)
        self.enc2 = down_conv2d_block(1 * num_filters, 2 * num_filters)
        self.enc3 = down_conv2d_block(2 * num_filters, 4 * num_filters)
        self.enc4 = down_conv2d_block(4 * num_filters, 8 * num_filters)
        self.enc5 = down_conv2d_block(8 * num_filters, 8 * num_filters)
        # Top-down pathway: 1x1 lateral blocks projecting every encoder
        # stage to a common 4 * num_filters channel width.
        self.lat2 = latent_conv2d_block(2 * num_filters, 4 * num_filters)
        self.lat3 = latent_conv2d_block(4 * num_filters, 4 * num_filters)
        self.lat4 = latent_conv2d_block(8 * num_filters, 4 * num_filters)
        self.lat5 = latent_conv2d_block(8 * num_filters, 4 * num_filters)
        # Upsampling used to merge top-down features with lateral features.
        self.upsample2x = nn.Upsample(scale_factor=2, mode='bilinear',
                                      align_corners=False)
        # Final per-scale feature layers.
        self.final2 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)
        self.final3 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)
        self.final4 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)
        # True/false patch prediction and semantic alignment heads.
        # ``output`` and ``seg`` are shared across all scales.
        self.output = Conv2dBlock(num_filters * 2, 1, kernel_size=1)
        self.seg = Conv2dBlock(num_filters * 2, num_filters * 2, kernel_size=1)
        self.embedding = Conv2dBlock(num_labels, num_filters * 2, kernel_size=1)

    def forward(self, images, segmaps):
        r"""Compute multi-scale patch discriminator scores.

        Args:
            images: Image tensors of shape (N, num_input_channels, H, W).
            segmaps: Segmentation map tensors of shape (N, num_labels, H, W).

        Returns:
            Tuple of three patch-score tensors with a single channel each,
            at 1/4, 1/8 and 1/16 of the input resolution respectively. Each
            score is the sum of the true/false prediction and the semantics
            embedding alignment score.
        """
        # Bottom-up pathway.
        feat11 = self.enc1(images)
        feat12 = self.enc2(feat11)
        feat13 = self.enc3(feat12)
        feat14 = self.enc4(feat13)
        feat15 = self.enc5(feat14)
        # Top-down pathway and lateral connections.
        feat25 = self.lat5(feat15)
        feat24 = self.upsample2x(feat25) + self.lat4(feat14)
        feat23 = self.upsample2x(feat24) + self.lat3(feat13)
        feat22 = self.upsample2x(feat23) + self.lat2(feat12)
        # Final prediction layers (one per output scale).
        feat32 = self.final2(feat22)
        feat33 = self.final3(feat23)
        feat34 = self.final4(feat24)
        # Patch-based true/false predictions.
        pred2 = self.output(feat32)
        pred3 = self.output(feat33)
        pred4 = self.output(feat34)
        seg2 = self.seg(feat32)
        seg3 = self.seg(feat33)
        seg4 = self.seg(feat34)
        # Segmentation map embedding, average-pooled down to match the
        # resolution of each prediction scale (1/4, 1/8, 1/16).
        segembs = self.embedding(segmaps)
        segembs = F.avg_pool2d(segembs, kernel_size=2, stride=2)
        segembs2 = F.avg_pool2d(segembs, kernel_size=2, stride=2)
        segembs3 = F.avg_pool2d(segembs2, kernel_size=2, stride=2)
        segembs4 = F.avg_pool2d(segembs3, kernel_size=2, stride=2)
        # Semantics embedding discriminator score: channel-wise inner
        # product between the pooled embedding and the per-scale features.
        pred2 += torch.mul(segembs2, seg2).sum(dim=1, keepdim=True)
        pred3 += torch.mul(segembs3, seg3).sum(dim=1, keepdim=True)
        pred4 += torch.mul(segembs4, seg4).sum(dim=1, keepdim=True)
        return pred2, pred3, pred4