# deeplab2/model/layers/resized_fuse.py (akhaliq3 spaces demo, revision 506da10)
# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements a resized feature fuser for stacked decoders in MaX-DeepLab.
Reference:
MaX-DeepLab: End-to-End Panoptic Segmentation with Mask Transformers,
CVPR 2021. https://arxiv.org/abs/2012.00759
Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen.
"""
import tensorflow as tf
from deeplab2.model import utils
from deeplab2.model.layers import activations
from deeplab2.model.layers import convolutions
class ResizedFuse(tf.keras.layers.Layer):
  """Fuses features by resizing and 1x1 convolutions.

  This layer fuses all input features to a desired shape by projecting the
  features to the desired number of channels, bilinearly resizing the outputs
  (either upsampling or downsampling), and finally summing the outputs. If an
  input's channel count equals the desired number of output channels, the 1x1
  convolutional projection is skipped. If the projection and bilinear resizing
  can be fused into a stride 2 convolution, this faster implementation is
  used. Other strides are also supported via bilinear resizing, but are
  probably slower than strided convolutions.

  Reference:
    MaX-DeepLab: End-to-End Panoptic Segmentation with Mask Transformers,
      CVPR 2021. https://arxiv.org/abs/2012.00759
        Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen.
  """

  # Per-input fusion paths. The strided/resized values are also embedded in
  # the sub-layer attribute names ('_strided_conv_bn{i}' / '_resized_conv_bn{i}')
  # and, with the leading underscore stripped, passed as the sub-layer names.
  # Do not change them (presumably existing checkpoints depend on these names).
  _MODE_IDENTITY = 'identity'
  _MODE_STRIDED = 'strided'
  _MODE_RESIZED = 'resized'

  def __init__(self,
               name,
               height,
               width,
               num_channels,
               activation='relu',
               bn_layer=tf.keras.layers.BatchNormalization,
               conv_kernel_weight_decay=0.0):
    """Initializes a ResizedFuse layer.

    Args:
      name: A string, the name of this layer.
      height: An integer, the desired height of the output.
      width: An integer, the desired width of the output.
      num_channels: An integer, the num of output channels.
      activation: A string, type of activation function to apply.
      bn_layer: A tf.keras.layers.Layer that computes the normalization
        (default: tf.keras.layers.BatchNormalization).
      conv_kernel_weight_decay: A float, the weight decay for convolution
        kernels.
    """
    super(ResizedFuse, self).__init__(name=name)
    self._height = height
    self._width = width
    self._num_channels = num_channels
    self._activation_fn = activations.get_activation(activation)
    self._bn_layer = bn_layer
    self._conv_kernel_weight_decay = conv_kernel_weight_decay

  def _fuse_mode(self, feature_height, feature_width, feature_channels):
    """Selects the fusion path for one input feature.

    Both `build` and `call` dispatch through this method, so the sub-layers
    created in `build` are exactly the ones looked up in `call`.

    Args:
      feature_height: An integer, the height of the input feature.
      feature_width: An integer, the width of the input feature.
      feature_channels: An integer, the number of channels of the input.

    Returns:
      _MODE_IDENTITY if the channels already equal `num_channels` (no
      projection, bilinear resize only); _MODE_STRIDED if a stride 2 1x1
      convolution yields the desired spatial shape; _MODE_RESIZED otherwise
      (1x1 convolution followed by bilinear resizing).
    """
    if feature_channels == self._num_channels:
      return self._MODE_IDENTITY
    # (size + 1) // 2 == ceil(size / 2), the spatial size expected from the
    # stride 2 convolution.
    if ((feature_height + 1) // 2 == self._height and
        (feature_width + 1) // 2 == self._width):
      return self._MODE_STRIDED
    return self._MODE_RESIZED

  @staticmethod
  def _conv_bn_name(mode, index):
    """Returns the attribute name of the conv-bn sub-layer for input `index`."""
    return '_{}_conv_bn{}'.format(mode, index + 1)

  def build(self, input_shapes):
    """Creates one conv-bn sub-layer per input that needs a projection.

    Args:
      input_shapes: A list of [batch, height, width, channels] shapes, one per
        input feature.
    """
    for index, feature_shape in enumerate(input_shapes):
      _, feature_height, feature_width, feature_channels = feature_shape
      mode = self._fuse_mode(feature_height, feature_width, feature_channels)
      if mode == self._MODE_IDENTITY:
        # Channels already match: the feature is only resized in `call`.
        continue
      current_name = self._conv_bn_name(mode, index)
      if mode == self._MODE_STRIDED:
        # Use a stride 2 convolution to accelerate the operation because it
        # directly generates the desired spatial shape.
        # In a stacked decoder, we follow relu-conv-bn because we do the
        # feature summation before relu and after bn (following the ResNet
        # bottleneck design). This ordering makes it easier to implement.
        # Besides, it avoids using many 1x1 convolutions when the input has a
        # correct shape.
        conv_bn = convolutions.Conv2DSame(
            self._num_channels, 1, current_name[1:],
            strides=2,
            use_bias=False,
            use_bn=True,
            bn_layer=self._bn_layer,
            activation='none',
            conv_kernel_weight_decay=self._conv_kernel_weight_decay)
      else:
        # The input channels do not match the output, and a stride 2
        # convolution cannot produce the target shape, so we use a flexible
        # path: project to the desired number of channels here, then
        # bilinearly resize to the desired spatial resolution in `call`.
        conv_bn = convolutions.Conv2DSame(
            self._num_channels, 1, current_name[1:],
            use_bias=False,
            use_bn=True,
            bn_layer=self._bn_layer,
            activation='none',
            conv_kernel_weight_decay=self._conv_kernel_weight_decay)
      utils.safe_setattr(self, current_name, conv_bn)

  def call(self, inputs, training=False):
    """Performs a forward pass.

    Args:
      inputs: A list of input [batch, input_height, input_width, input_channels]
        tensors to fuse, where each input tensor may have different spatial
        resolutions and number of channels.
      training: A boolean, whether the model is in training mode.

    Returns:
      output: A fused feature [batch, height, width, num_channels] tensor.
    """
    output_features = []
    for index, feature in enumerate(inputs):
      _, feature_height, feature_width, feature_channels = (
          feature.get_shape().as_list())
      mode = self._fuse_mode(feature_height, feature_width, feature_channels)
      if mode == self._MODE_IDENTITY:
        # Only resize when the number of channels already equals the output.
        # We do not use a 1x1 convolution here because the previous and next
        # operations are usually also 1x1 convolutions. Besides, in a stacked
        # decoder a feature can be reused many times, so skipping those 1x1
        # convolutions saves parameters.
        output_feature = utils.resize_bilinear(
            feature, [self._height, self._width],
            align_corners=True)
      else:
        # activation -> conv -> bn; the sum below is taken after bn and the
        # final activation is applied to the sum.
        feature = self._activation_fn(feature)
        conv_bn = getattr(self, self._conv_bn_name(mode, index))
        output_feature = conv_bn(feature, training=training)
        if mode == self._MODE_RESIZED:
          output_feature = utils.resize_bilinear(
              output_feature, [self._height, self._width],
              align_corners=True)
      # Set the spatial shape of each output feature if possible.
      output_feature.set_shape(
          [None,
           self._height,
           self._width,
           self._num_channels])
      output_features.append(output_feature)
    output = tf.add_n(output_features)
    return self._activation_fn(output)