|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Implements a resized feature fuser for stacked decoders in MaX-DeepLab. |
|
|
|
Reference: |
|
MaX-DeepLab: End-to-End Panoptic Segmentation with Mask Transformers, |
|
CVPR 2021. https://arxiv.org/abs/2012.00759 |
|
Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen. |
|
""" |
|
|
|
import tensorflow as tf |
|
|
|
from deeplab2.model import utils |
|
from deeplab2.model.layers import activations |
|
from deeplab2.model.layers import convolutions |
|
|
|
|
|
class ResizedFuse(tf.keras.layers.Layer):
  """Fuses features by resizing and 1x1 convolutions.

  This function fuses all input features to a desired shape, by projecting the
  features to the desired number of channels, bilinear resizing the outputs
  (either upsampling or downsampling), and finally adding the outputs. If the
  input channel equals the desired output channels, the 1x1 convolutional
  projection is skipped. If the projection and bilinear resizing can be fused
  into a stride 2 convolution, we use this faster implementation. Other strides
  are also supported with the bilinear resizing, but are probably slower than
  strided convolutions.

  Reference:
    MaX-DeepLab: End-to-End Panoptic Segmentation with Mask Transformers,
    CVPR 2021. https://arxiv.org/abs/2012.00759
      Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen.
  """

  def __init__(self,
               name,
               height,
               width,
               num_channels,
               activation='relu',
               bn_layer=tf.keras.layers.BatchNormalization,
               conv_kernel_weight_decay=0.0):
    """Initializes a ResizedFuse layer.

    Args:
      name: A string, the name of this layer.
      height: An integer, the desired height of the output.
      width: An integer, the desired width of the output.
      num_channels: An integer, the num of output channels.
      activation: A string, type of activation function to apply.
      bn_layer: A tf.keras.layers.Layer that computes the normalization
        (default: tf.keras.layers.BatchNormalization).
      conv_kernel_weight_decay: A float, the weight decay for convolution
        kernels.
    """
    super(ResizedFuse, self).__init__(name=name)
    # Target spatial resolution and channel count that every input will be
    # resized/projected to before the element-wise addition in call().
    self._height = height
    self._width = width
    self._num_channels = num_channels
    self._activation_fn = activations.get_activation(activation)
    self._bn_layer = bn_layer
    self._conv_kernel_weight_decay = conv_kernel_weight_decay

  def build(self, input_shapes):
    """Creates one projection sub-layer per input that needs one.

    For each input feature, chooses one of three cases, mirroring the
    branch structure in call():
      1. Channels already match the target: no layer is created; call() only
         performs a bilinear resize for this input.
      2. The input spatial size is exactly twice the target (by 'SAME'-padding
         ceiling division, hence the `(size + 1) // 2` test): the 1x1
         projection and the 2x downsampling are fused into a single stride-2
         convolution, which is faster than conv followed by resize.
      3. Otherwise: a stride-1 1x1 projection is created, and call() resizes
         its output bilinearly to the target resolution.

    Args:
      input_shapes: A list of [batch, height, width, channels] shapes, one per
        input feature tensor.
    """
    for index, feature_shape in enumerate(input_shapes):
      _, feature_height, feature_width, feature_channels = feature_shape
      if feature_channels == self._num_channels:
        # Case 1: channels match, no projection needed.
        continue
      elif ((feature_height + 1) // 2 == self._height and
            (feature_width + 1) // 2 == self._width):
        # Case 2: fuse projection + 2x downsampling into one stride-2 conv.
        # Layer attributes are named per-input (1-based index) so that call()
        # can retrieve them with getattr using the same naming scheme.
        current_name = '_strided_conv_bn{}'.format(index + 1)
        # current_name[1:] strips the leading underscore so the public Keras
        # layer name does not start with '_'.
        utils.safe_setattr(
            self, current_name, convolutions.Conv2DSame(
                self._num_channels, 1, current_name[1:],
                strides=2,
                use_bias=False,
                use_bn=True,
                bn_layer=self._bn_layer,
                activation='none',
                conv_kernel_weight_decay=self._conv_kernel_weight_decay))
      else:
        # Case 3: generic path — stride-1 1x1 projection; the bilinear resize
        # to the target resolution happens in call().
        current_name = '_resized_conv_bn{}'.format(index + 1)
        utils.safe_setattr(
            self, current_name, convolutions.Conv2DSame(
                self._num_channels, 1, current_name[1:],
                use_bias=False,
                use_bn=True,
                bn_layer=self._bn_layer,
                activation='none',
                conv_kernel_weight_decay=self._conv_kernel_weight_decay))

  def call(self, inputs, training=False):
    """Performs a forward pass.

    Args:
      inputs: A list of input [batch, input_height, input_width, input_channels]
        tensors to fuse, where each input tensor may have different spatial
        resolutions and number of channels.
      training: A boolean, whether the model is in training mode.

    Returns:
      output: A fused feature [batch, height, width, num_channels] tensor.
    """
    output_features = []
    for index, feature in enumerate(inputs):
      _, feature_height, feature_width, feature_channels = (
          feature.get_shape().as_list())
      if feature_channels == self._num_channels:
        # Channels already match: resize only. Resizing to the same spatial
        # size is a no-op in effect, so this branch is safe even when the
        # input resolution already equals the target.
        output_features.append(utils.resize_bilinear(
            feature, [self._height, self._width],
            align_corners=True))
      elif ((feature_height + 1) // 2 == self._height and
            (feature_width + 1) // 2 == self._width):
        # Exactly 2x the target resolution: apply the pre-activation followed
        # by the fused stride-2 conv+BN created in build() for this index.
        current_name = '_strided_conv_bn{}'.format(index + 1)
        feature = self._activation_fn(feature)
        feature = getattr(self, current_name)(feature, training=training)
        output_features.append(feature)
      else:
        # Generic path: pre-activation, 1x1 conv+BN projection, then bilinear
        # resize to the target resolution.
        current_name = '_resized_conv_bn{}'.format(index + 1)
        feature = self._activation_fn(feature)
        feature = getattr(self, current_name)(feature, training=training)
        output_features.append(utils.resize_bilinear(
            feature, [self._height, self._width],
            align_corners=True))

      # Set the static shape so downstream layers see a fully-defined
      # [batch, height, width, channels] shape even when the input shapes
      # were partially unknown.
      output_features[-1].set_shape(
          [None,
           self._height,
           self._width,
           self._num_channels])
    # Fuse by element-wise summation, then apply the output activation.
    output = tf.add_n(output_features)
    return self._activation_fn(output)
|
|