# deeplab2/model/layers/axial_blocks.py (akhaliq3, spaces demo, rev 506da10)
# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements Axial-Blocks proposed in Axial-DeepLab [1].
Axial-Blocks are based on residual bottleneck blocks, but with the 3x3
convolution replaced with two axial-attention layers, one on the height-axis,
followed by the other on the width-axis.
[1] Axial-Deeplab: Stand-Alone Axial-Attention for Panoptic Segmentation,
ECCV 2020 Spotlight.
Huiyu Wang, Yukun Zhu, Bradley Green, Hartwig Adam, Alan Yuille,
Liang-Chieh Chen.
"""
import tensorflow as tf
from deeplab2.model import utils
from deeplab2.model.layers import activations
from deeplab2.model.layers import axial_layers
from deeplab2.model.layers import convolutions
from deeplab2.model.layers import squeeze_and_excite
class AxialBlock(tf.keras.layers.Layer):
  """An AxialBlock as a building block for an Axial-ResNet model.

  We implement the Axial-Block proposed in [1] in a general way that also
  includes convolutional residual blocks, such as the basic block and the
  bottleneck block (w/ and w/o Switchable Atrous Convolution).

  A basic block consists of two 3x3 convolutions and a residual connection. It
  is the main building block for wide-resnet variants.

  A bottleneck block consists of consecutive 1x1, 3x3, 1x1 convolutions and a
  residual connection. It is the main building block for standard resnet
  variants.

  An axial block consists of a 1x1 input convolution, a self-attention layer
  (either axial-attention or global attention), a 1x1 output convolution, and a
  residual connection. It is the main building block for axial-resnet variants.

  Note: We apply the striding in the first spatial operation (i.e. 3x3
  convolution or self-attention layer).
  """

  def __init__(self,
               filters_list,
               kernel_size=3,
               strides=1,
               atrous_rate=1,
               use_squeeze_and_excite=False,
               use_sac=False,
               bn_layer=tf.keras.layers.BatchNormalization,
               activation='relu',
               name=None,
               conv_kernel_weight_decay=0.0,
               basic_block_second_conv_atrous_rate=None,
               attention_type=None,
               axial_layer_config=None):
    """Initializes an AxialBlock.

    Args:
      filters_list: A list of filter numbers in the residual block. We currently
        support filters_list with two or three elements. Two elements specify
        the filters for two consecutive 3x3 convolutions, while three elements
        specify the filters for three convolutions (1x1, 3x3, and 1x1).
      kernel_size: The size of the convolution kernels (default: 3).
      strides: The strides of the block (default: 1).
      atrous_rate: The atrous rate of the 3x3 convolutions (default: 1). If this
        residual block is a basic block, it is recommended to specify correct
        basic_block_second_conv_atrous_rate for the second 3x3 convolution.
        Otherwise, the second conv will also use atrous rate, which might cause
        atrous inconsistency with different output strides, as tested in
        axial_block_groups_test.test_atrous_consistency_basic_block.
      use_squeeze_and_excite: A boolean flag indicating whether
        squeeze-and-excite (SE) is used.
      use_sac: A boolean, using the Switchable Atrous Convolution (SAC) or not.
      bn_layer: A tf.keras.layers.Layer that computes the normalization
        (default: tf.keras.layers.BatchNormalization).
      activation: A string specifying the activation function to apply.
      name: A string specifying the name of the layer (default: None).
      conv_kernel_weight_decay: A float, the weight decay for convolution
        kernels.
      basic_block_second_conv_atrous_rate: An integer, the atrous rate for the
        second convolution of basic block. This is necessary to ensure atrous
        consistency with different output_strides. Defaults to atrous_rate.
      attention_type: A string, type of attention to apply. Supports 'axial'
        and 'global' (case-insensitive); None or 'none' disables attention and
        uses a 3x3 convolution instead. Only used when filters_list has three
        elements; a two-element (basic) block ignores it.
      axial_layer_config: A dict, an argument dictionary for the axial layer.
        It is copied, never modified in place.

    Raises:
      ValueError: If filters_list does not have two or three elements.
      ValueError: If attention_type is not supported.
      ValueError: If double_global_attention is True in axial_layer_config.
    """
    super(AxialBlock, self).__init__(name=name)
    self._filters_list = filters_list
    self._strides = strides
    self._use_squeeze_and_excite = use_squeeze_and_excite
    self._bn_layer = bn_layer
    self._activate_fn = activations.get_activation(activation)
    self._attention_type = attention_type
    # Only meaningful for the three-element (bottleneck/axial) variant; the
    # basic block never reads it.
    self._use_attention = False

    # Shallow copy: keys are popped below, and the caller's dict may be shared
    # with other blocks, so it must not be mutated.
    axial_layer_config = dict(axial_layer_config or {})
    if basic_block_second_conv_atrous_rate is None:
      basic_block_second_conv_atrous_rate = atrous_rate

    if len(filters_list) == 3:
      # Three consecutive operations: 1x1 conv, 3x3 conv (or attention), 1x1.
      self._conv1_bn_act = convolutions.Conv2DSame(
          filters_list[0], 1, 'conv1_bn_act',
          use_bias=False,
          use_bn=True,
          bn_layer=bn_layer,
          activation=activation,
          conv_kernel_weight_decay=conv_kernel_weight_decay)
      self._use_attention = not (attention_type is None or
                                 attention_type.lower() == 'none')
      if not self._use_attention:
        self._conv2_bn_act = convolutions.Conv2DSame(
            filters_list[1], kernel_size, 'conv2_bn_act',
            strides=strides,
            atrous_rate=atrous_rate,
            use_bias=False,
            use_bn=True,
            bn_layer=bn_layer,
            activation=activation,
            use_switchable_atrous_conv=use_sac,
            # We default to use global context in SAC if use_sac is True. This
            # setting is experimentally found effective.
            use_global_context_in_sac=use_sac,
            conv_kernel_weight_decay=conv_kernel_weight_decay)
      elif attention_type.lower() == 'axial':
        # An explicit request for double global attention is rejected; the key
        # is dropped before the remaining config is forwarded to
        # AxialAttention2D (presumably it does not accept this argument).
        if axial_layer_config.pop('double_global_attention', False):
          raise ValueError('Double_global_attention takes no effect in '
                           'AxialAttention2D.')
        self._attention = axial_layers.AxialAttention2D(
            strides=strides,
            filters=filters_list[1],
            name='attention',
            bn_layer=bn_layer,
            conv_kernel_weight_decay=conv_kernel_weight_decay,
            **axial_layer_config)
      elif attention_type.lower() == 'global':
        self._attention = axial_layers.GlobalAttention2D(
            strides=strides,
            filters=filters_list[1],
            name='attention',
            bn_layer=bn_layer,
            conv_kernel_weight_decay=conv_kernel_weight_decay,
            **axial_layer_config)
      else:
        raise ValueError(attention_type + ' is not supported.')
      # Here we apply a batch norm with gamma initialized at zero. This ensures
      # that at random initialization of the model, the skip connections
      # dominate all residual blocks. In this way, all the skip connections
      # construct an identity mapping that passes the gradients (without any
      # distortion from the randomly initialized blocks) to all residual blocks.
      # This trick helps training at early epochs.
      # Reference: "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour".
      # https://arxiv.org/abs/1706.02677
      self._conv3_bn = convolutions.Conv2DSame(
          filters_list[2], 1, 'conv3_bn',
          use_bias=False,
          use_bn=True,
          bn_layer=bn_layer,
          bn_gamma_initializer='zeros',
          activation='none',
          conv_kernel_weight_decay=conv_kernel_weight_decay)
    elif len(filters_list) == 2:
      # Two consecutive convolutions: 3x3 and 3x3.
      self._conv1_bn_act = convolutions.Conv2DSame(
          filters_list[0], kernel_size, 'conv1_bn_act',
          strides=strides,
          atrous_rate=atrous_rate,
          use_bias=False,
          use_bn=True,
          bn_layer=bn_layer,
          activation=activation,
          use_switchable_atrous_conv=use_sac,
          use_global_context_in_sac=use_sac,
          conv_kernel_weight_decay=conv_kernel_weight_decay)
      # Zero-initialized BN gamma so the residual branch starts as a no-op and
      # the identity shortcut dominates early in training (see the reference in
      # the three-element branch above: https://arxiv.org/abs/1706.02677).
      self._conv2_bn = convolutions.Conv2DSame(
          filters_list[1], kernel_size, 'conv2_bn',
          strides=1,
          atrous_rate=basic_block_second_conv_atrous_rate,
          use_bias=False,
          use_bn=True,
          bn_layer=bn_layer,
          bn_gamma_initializer='zeros',
          activation='none',
          use_switchable_atrous_conv=use_sac,
          use_global_context_in_sac=use_sac,
          conv_kernel_weight_decay=conv_kernel_weight_decay)
    else:
      raise ValueError('Expect filters_list to have length 2 or 3; got %d' %
                       len(filters_list))

    if self._use_squeeze_and_excite:
      self._squeeze_and_excite = squeeze_and_excite.SimplifiedSqueezeAndExcite(
          filters_list[-1])
    self._conv_kernel_weight_decay = conv_kernel_weight_decay

  def build(self, input_shape_list):
    """Creates a projection shortcut if the channel count changes.

    Args:
      input_shape_list: Shapes of the call inputs; only the first one (the
        input tensor, indexed as [batch, height, width, channels]) is used.
    """
    input_tensor_shape = input_shape_list[0]
    self._shortcut = None
    # A 1x1 conv + BN projection is needed only when the residual branch's
    # output channels differ from the input channels. Otherwise call() uses
    # the identity (subsampled when strides != 1).
    if input_tensor_shape[3] != self._filters_list[-1]:
      self._shortcut = convolutions.Conv2DSame(
          self._filters_list[-1], 1, 'shortcut',
          strides=self._strides,
          use_bias=False,
          use_bn=True,
          bn_layer=self._bn_layer,
          activation='none',
          conv_kernel_weight_decay=self._conv_kernel_weight_decay)

  def call(self, inputs):
    """Performs a forward pass.

    We have to define drop_path_random_mask outside the layer call and pass it
    into the layer, because recompute_grad (gradient checkpointing) does not
    allow any randomness within the function call. In addition, recompute_grad
    only supports float tensors as inputs. For this reason, the training flag
    should be also passed as a float tensor. For the same reason, we cannot
    support passing drop_path_random_mask as None. Instead, we ask the users to
    pass only the first two tensors when drop path is not used.

    Args:
      inputs: A tuple of 2 or 3 tensors, containing
        input_tensor should be an input tensor of type tf.Tensor with shape
          [batch, height, width, channels].
        float_tensor_training should be a float tensor of 0.0 or 1.0, whether
          the model is in training mode.
        (optional) drop_path_random_mask is a drop path random mask of type
          tf.Tensor with shape [batch, 1, 1, 1].

    Returns:
      outputs: two tensors. The first tensor does not use the last activation
        function. The second tensor uses the activation. We return non-activated
        output to support MaX-DeepLab which uses non-activated feature for the
        stacked decoders.

    Raises:
      ValueError: If the length of inputs is not 2 or 3.
    """
    if len(inputs) not in (2, 3):
      raise ValueError('The length of inputs should be either 2 or 3.')

    # Unpack the inputs; a missing drop-path mask becomes None.
    input_tensor, float_tensor_training, drop_path_random_mask = (
        utils.pad_sequence_with_none(inputs, target_length=3))

    # Recompute_grad takes only float tensors as inputs. It does not allow
    # bools or boolean tensors. For this reason, we cast training to a float
    # tensor outside this call, and now we cast it back to a boolean tensor.
    training = tf.cast(float_tensor_training, tf.bool)

    shortcut = input_tensor
    if self._shortcut is not None:
      shortcut = self._shortcut(shortcut, training=training)
    elif self._strides != 1:
      # Identity shortcut: subsample spatially to match the strided branch.
      shortcut = shortcut[:, ::self._strides, ::self._strides, :]

    if len(self._filters_list) == 3:
      x = self._conv1_bn_act(input_tensor, training=training)
      if self._use_attention:
        x = self._attention(x, training=training)
        # The attention layer's output is activated here explicitly, whereas
        # the convolutional alternative (conv2_bn_act) applies it internally.
        x = self._activate_fn(x)
      else:
        x = self._conv2_bn_act(x, training=training)
      x = self._conv3_bn(x, training=training)
    else:
      # Basic block; __init__ guarantees filters_list has exactly 2 elements.
      x = self._conv1_bn_act(input_tensor, training=training)
      x = self._conv2_bn(x, training=training)

    if self._use_squeeze_and_excite:
      x = self._squeeze_and_excite(x)

    if drop_path_random_mask is not None:
      x = x * drop_path_random_mask
    x = x + shortcut
    return x, self._activate_fn(x)