# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implements Axial-Blocks proposed in Axial-DeepLab [1].

Axial-Blocks are based on residual bottleneck blocks, but with the 3x3
convolution replaced with two axial-attention layers, one on the height-axis,
followed by the other on the width-axis.

[1] Axial-Deeplab: Stand-Alone Axial-Attention for Panoptic Segmentation,
    ECCV 2020 Spotlight.
      Huiyu Wang, Yukun Zhu, Bradley Green, Hartwig Adam, Alan Yuille,
      Liang-Chieh Chen.
"""

import tensorflow as tf

from deeplab2.model import utils
from deeplab2.model.layers import activations
from deeplab2.model.layers import axial_layers
from deeplab2.model.layers import convolutions
from deeplab2.model.layers import squeeze_and_excite


class AxialBlock(tf.keras.layers.Layer):
  """An AxialBlock as a building block for an Axial-ResNet model.

  We implement the Axial-Block proposed in [1] in a general way that also
  includes convolutional residual blocks, such as the basic block and the
  bottleneck block (w/ and w/o Switchable Atrous Convolution).

  A basic block consists of two 3x3 convolutions and a residual connection. It
  is the main building block for wide-resnet variants.

  A bottleneck block consists of consecutive 1x1, 3x3, 1x1 convolutions and a
  residual connection. It is the main building block for standard resnet
  variants.

  An axial block consists of a 1x1 input convolution, a self-attention layer
  (either axial-attention or global attention), a 1x1 output convolution, and a
  residual connection. It is the main building block for axial-resnet variants.

  Note: We apply the striding in the first spatial operation (i.e. 3x3
  convolution or self-attention layer).
  """

  def __init__(self,
               filters_list,
               kernel_size=3,
               strides=1,
               atrous_rate=1,
               use_squeeze_and_excite=False,
               use_sac=False,
               bn_layer=tf.keras.layers.BatchNormalization,
               activation='relu',
               name=None,
               conv_kernel_weight_decay=0.0,
               basic_block_second_conv_atrous_rate=None,
               attention_type=None,
               axial_layer_config=None):
    """Initializes an AxialBlock.

    Args:
      filters_list: A list of filter numbers in the residual block. We currently
        support filters_list with two or three elements. Two elements specify
        the filters for two consecutive 3x3 convolutions, while three elements
        specify the filters for three convolutions (1x1, 3x3, and 1x1).
      kernel_size: The size of the convolution kernels (default: 3).
      strides: The strides of the block (default: 1).
      atrous_rate: The atrous rate of the 3x3 convolutions (default: 1). If this
        residual block is a basic block, it is recommended to specify correct
        basic_block_second_conv_atrous_rate for the second 3x3 convolution.
        Otherwise, the second conv will also use atrous rate, which might cause
        atrous inconsistency with different output strides, as tested in
        axial_block_groups_test.test_atrous_consistency_basic_block.
      use_squeeze_and_excite: A boolean flag indicating whether
        squeeze-and-excite (SE) is used.
      use_sac: A boolean, using the Switchable Atrous Convolution (SAC) or not.
      bn_layer: A tf.keras.layers.Layer that computes the normalization
        (default: tf.keras.layers.BatchNormalization).
      activation: A string specifying the activation function to apply.
      name: A string specifying the name of the layer (default: None).
      conv_kernel_weight_decay: A float, the weight decay for convolution
        kernels.
      basic_block_second_conv_atrous_rate: An integer, the atrous rate for the
        second convolution of basic block. This is necessary to ensure atrous
        consistency with different output_strides. Defaults to atrous_rate.
      attention_type: A string, type of attention to apply. Supports 'axial'
        and 'global' (matched case-insensitively); None or 'none' selects a
        plain 3x3 convolution instead. Only used when filters_list has three
        elements.
      axial_layer_config: A dict of keyword arguments forwarded to the
        attention layer (AxialAttention2D or GlobalAttention2D). The dict is
        copied; the caller's object is never modified.

    Raises:
      ValueError: If filters_list does not have two or three elements.
      ValueError: If attention_type is not supported (only checked when
        filters_list has three elements).
      ValueError: If attention_type is 'axial' and
        axial_layer_config['double_global_attention'] is True.
    """
    super(AxialBlock, self).__init__(name=name)

    self._filters_list = filters_list
    self._strides = strides
    self._use_squeeze_and_excite = use_squeeze_and_excite
    self._bn_layer = bn_layer
    self._activate_fn = activations.get_activation(activation)
    self._attention_type = attention_type

    # Work on a shallow copy: the 'double_global_attention' entry may be
    # removed below, and the caller's dict is often shared by several blocks.
    if axial_layer_config is None:
      axial_layer_config = {}
    else:
      axial_layer_config = dict(axial_layer_config)

    if basic_block_second_conv_atrous_rate is None:
      basic_block_second_conv_atrous_rate = atrous_rate

    if len(filters_list) == 3:
      # Three consecutive convolutions: 1x1, 3x3, and 1x1.
      self._conv1_bn_act = convolutions.Conv2DSame(
          filters_list[0], 1, 'conv1_bn_act',
          use_bias=False,
          use_bn=True,
          bn_layer=bn_layer,
          activation=activation,
          conv_kernel_weight_decay=conv_kernel_weight_decay)

      # 'none' was always matched case-insensitively (here and in call());
      # match the attention names the same way for consistency.
      normalized_attention_type = (
          None if attention_type is None else attention_type.lower())

      if normalized_attention_type in (None, 'none'):
        self._conv2_bn_act = convolutions.Conv2DSame(
            filters_list[1], kernel_size, 'conv2_bn_act',
            strides=strides,
            atrous_rate=atrous_rate,
            use_bias=False,
            use_bn=True,
            bn_layer=bn_layer,
            activation=activation,
            use_switchable_atrous_conv=use_sac,
            # We default to use global context in SAC if use_sac is True. This
            # setting is experimentally found effective.
            use_global_context_in_sac=use_sac,
            conv_kernel_weight_decay=conv_kernel_weight_decay)
      elif normalized_attention_type == 'axial':
        # AxialAttention2D does not take this argument, so strip it from the
        # (copied) config; a True value is a configuration error.
        if axial_layer_config.pop('double_global_attention', False):
          raise ValueError('Double_global_attention takes no effect in '
                           'AxialAttention2D.')
        self._attention = axial_layers.AxialAttention2D(
            strides=strides,
            filters=filters_list[1],
            name='attention',
            bn_layer=bn_layer,
            conv_kernel_weight_decay=conv_kernel_weight_decay,
            **axial_layer_config)
      elif normalized_attention_type == 'global':
        self._attention = axial_layers.GlobalAttention2D(
            strides=strides,
            filters=filters_list[1],
            name='attention',
            bn_layer=bn_layer,
            conv_kernel_weight_decay=conv_kernel_weight_decay,
            **axial_layer_config)
      else:
        raise ValueError(attention_type + ' is not supported.')

      # Here we apply a batch norm with gamma initialized at zero. This ensures
      # that at random initialization of the model, the skip connections
      # dominate all residual blocks. In this way, all the skip connections
      # construct an identity mapping that passes the gradients (without any
      # distortion from the randomly initialized blocks) to all residual blocks.
      # This trick helps training at early epochs.
      # Reference: "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour".
      # https://arxiv.org/abs/1706.02677
      self._conv3_bn = convolutions.Conv2DSame(
          filters_list[2], 1, 'conv3_bn',
          use_bias=False,
          use_bn=True,
          bn_layer=bn_layer,
          bn_gamma_initializer='zeros',
          activation='none',
          conv_kernel_weight_decay=conv_kernel_weight_decay)
    elif len(filters_list) == 2:
      # Two consecutive convolutions: 3x3 and 3x3.
      self._conv1_bn_act = convolutions.Conv2DSame(
          filters_list[0], kernel_size, 'conv1_bn_act',
          strides=strides,
          atrous_rate=atrous_rate,
          use_bias=False,
          use_bn=True,
          bn_layer=bn_layer,
          activation=activation,
          use_switchable_atrous_conv=use_sac,
          use_global_context_in_sac=use_sac,
          conv_kernel_weight_decay=conv_kernel_weight_decay)
      # Here we apply a batch norm with gamma initialized at zero. This ensures
      # that at random initialization of the model, the skip connections
      # dominate all residual blocks. In this way, all the skip connections
      # construct an identity mapping that passes the gradients (without any
      # distortion from the randomly initialized blocks) to all residual blocks.
      # This trick helps training at early epochs.
      # Reference: "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour".
      # https://arxiv.org/abs/1706.02677
      self._conv2_bn = convolutions.Conv2DSame(
          filters_list[1], kernel_size, 'conv2_bn',
          strides=1,
          atrous_rate=basic_block_second_conv_atrous_rate,
          use_bias=False,
          use_bn=True,
          bn_layer=bn_layer,
          bn_gamma_initializer='zeros',
          activation='none',
          use_switchable_atrous_conv=use_sac,
          use_global_context_in_sac=use_sac,
          conv_kernel_weight_decay=conv_kernel_weight_decay)
    else:
      raise ValueError('Expect filters_list to have length 2 or 3; got %d' %
                       len(filters_list))

    if self._use_squeeze_and_excite:
      self._squeeze_and_excite = squeeze_and_excite.SimplifiedSqueezeAndExcite(
          filters_list[-1])
    self._conv_kernel_weight_decay = conv_kernel_weight_decay

  def build(self, input_shape_list):
    """Creates the projection shortcut if the channel count changes.

    A strided 1x1 conv + BN shortcut is created only when the input channel
    dimension (index 3 of the first shape) differs from filters_list[-1].
    Otherwise the identity shortcut is used, and call() subsamples it
    spatially when strides != 1.

    Args:
      input_shape_list: A list of shapes matching the inputs of call(); only
        the first entry (the input tensor shape) is read.
    """
    input_tensor_shape = input_shape_list[0]
    self._shortcut = None
    if input_tensor_shape[3] != self._filters_list[-1]:
      self._shortcut = convolutions.Conv2DSame(
          self._filters_list[-1], 1, 'shortcut',
          strides=self._strides,
          use_bias=False,
          use_bn=True,
          bn_layer=self._bn_layer,
          activation='none',
          conv_kernel_weight_decay=self._conv_kernel_weight_decay)

  def call(self, inputs):
    """Performs a forward pass.

    We have to define drop_path_random_mask outside the layer call and pass it
    into the layer, because recompute_grad (gradient checkpointing) does not
    allow any randomness within the function call. In addition, recompute_grad
    only supports float tensors as inputs. For this reason, the training flag
    should be also passed as a float tensor. For the same reason, we cannot
    support passing drop_path_random_mask as None. Instead, we ask the users to
    pass only the first two tensors when drop path is not used.

    Args:
      inputs: A tuple of 2 or 3 tensors, containing
        input_tensor should be an input tensor of type tf.Tensor with shape
          [batch, height, width, channels].
        float_tensor_training should be a float tensor of 0.0 or 1.0, whether
          the model is in training mode.
        (optional) drop_path_random_mask is a drop path random mask of type
          tf.Tensor with shape [batch, 1, 1, 1].

    Returns:
      outputs: two tensors. The first tensor does not use the last activation
        function. The second tensor uses the activation. We return
        non-activated output to support MaX-DeepLab which uses non-activated
        feature for the stacked decoders.

    Raises:
      ValueError: If the length of inputs is not 2 or 3.
    """
    if len(inputs) not in (2, 3):
      raise ValueError('The length of inputs should be either 2 or 3.')

    # Unpack the inputs.
    input_tensor, float_tensor_training, drop_path_random_mask = (
        utils.pad_sequence_with_none(inputs, target_length=3))

    # Recompute_grad takes only float tensors as inputs. It does not allow
    # bools or boolean tensors. For this reason, we cast training to a float
    # tensor outside this call, and now we cast it back to a boolean tensor.
    training = tf.cast(float_tensor_training, tf.bool)

    shortcut = input_tensor
    if self._shortcut is not None:
      shortcut = self._shortcut(shortcut, training=training)
    elif self._strides != 1:
      # Identity shortcut: match the spatial striding of the residual branch.
      shortcut = shortcut[:, ::self._strides, ::self._strides, :]

    if len(self._filters_list) == 3:
      x = self._conv1_bn_act(input_tensor, training=training)
      if (self._attention_type is None or
          self._attention_type.lower() == 'none'):
        x = self._conv2_bn_act(x, training=training)
      else:
        # The attention layer output is activated here, mirroring the
        # activation built into conv2_bn_act in the convolutional path.
        x = self._attention(x, training=training)
        x = self._activate_fn(x)
      x = self._conv3_bn(x, training=training)
    elif len(self._filters_list) == 2:
      x = self._conv1_bn_act(input_tensor, training=training)
      x = self._conv2_bn(x, training=training)

    if self._use_squeeze_and_excite:
      x = self._squeeze_and_excite(x)

    if drop_path_random_mask is not None:
      x = x * drop_path_random_mask
    x = x + shortcut
    return x, self._activate_fn(x)