# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implements convolutional and attentional residual block groups."""

import math

import tensorflow as tf

from deeplab2.model import utils
from deeplab2.model.layers import activations
from deeplab2.model.layers import axial_blocks
from deeplab2.model.layers import drop_path
from deeplab2.model.layers import dual_path_transformer
from deeplab2.model.layers import positional_encodings
from deeplab2.model.layers import recompute_grad as recompute_grad_lib

# We will apply 10x larger learning rates on transformer layers. This global
# variable name will be accessed when we build the optimizers. This keyword is
# reserved and should not be a part of the variable names in a classification
# pretrained backbone.
TRANSFORMER = 'transformer'


def _get_current_names(index):
  current_name = '_block{}'.format(index + 1)
  transformer_current_name = '_block{}_{}'.format(index + 1, TRANSFORMER)
  return current_name, transformer_current_name


class BlockGroup(tf.keras.layers.Layer):
  """Applies a group of residual blocks with dual path transformer layers [1].

  An optional dual-path transformer layer is inserted after each residual
  block. The transformer layer performs memory2pixel attention, pixel2memory
  attention, and memory2memory self-attention, while the standard residual
  block applies the pixel2pixel axial-attention, global-attention, or spatial
  convolution.

  Reference:
    [1] MaX-DeepLab: End-to-End Panoptic Segmentation with Mask Transformers,
        CVPR 2021. https://arxiv.org/abs/2012.00759
          Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen.
  """

  def __init__(self,
               filters,
               num_blocks,
               name,
               original_resnet_stride,
               original_resnet_input_stride,
               output_stride=16,
               backbone_type='resnet_beta',
               positional_encoding_type=None,
               use_global_beyond_stride=0,
               use_axial_beyond_stride=16,
               use_transformer_beyond_stride=32,
               use_sac_beyond_stride=0,
               use_squeeze_and_excite=False,
               conv_use_recompute_grad=False,
               axial_use_recompute_grad=True,
               recompute_within_stride=0,
               transformer_use_recompute_grad=False,
               transformer_expansion=1,
               drop_path_keep_prob=0.8,
               drop_path_beyond_stride=16,
               drop_path_schedule='constant',
               activation='relu',
               attention_bottleneck_expansion=2,
               axial_layer_config=None,
               dual_path_transformer_layer_config=None,
               bn_layer=tf.keras.layers.BatchNormalization,
               conv_kernel_weight_decay=0.0):
    """Initializes a BlockGroup layer.

    Args:
      filters: An integer, the base number of channels for this block group.
      num_blocks: An integer, the number of blocks for this block group.
      name: A string, the name of the block group.
      original_resnet_stride: An integer, the original resnet stride for this
        block, usually 1 or 2. The stride will be applied if
        original_resnet_input_stride is smaller than the desired output_stride.
        Otherwise, the stride will not be applied, and atrous convolution will
        be used after the first block.
      original_resnet_input_stride: An integer, the total input stride in the
        original resnet.
        For example, the total input stride for the last stage of the original
        resnet is 16, and the total output stride is 32. This stride differs
        from the true stride of the feature in that we might use atrous
        convolution to change both the input and output stride to, e.g. 8, but
        its original resnet input stride remains the same. In this case, we
        also use the original resnet input stride to compute the atrous rate.
      output_stride: An integer, the desired output_stride for the ResNet.
      backbone_type: A string, the type of the backbone. Supports 'resnet',
        'resnet_beta', and 'wider_resnet'. The 'resnet' refers to the original
        resnet with a 7x7 convolutional stem. The 'resnet_beta' means a resnet
        but with an inception stem. The 'wider_resnet' is a wider variant of
        resnet with extensively used 3x3 convolutions.
      positional_encoding_type: A string, type of the positional encoding.
        Support '2D', '1D', and None.
      use_global_beyond_stride: An integer, the stride beyond which we use
        global attention. Set to 0 if no global attention is desired. Defaults
        to 0, i.e. we do not use global attention.
      use_axial_beyond_stride: An integer, the stride beyond which we use
        axial attention. Note that use_global_beyond_stride has a higher
        priority, i.e. we use global attention if the stride is also beyond
        use_global_beyond_stride. Set to 0 if no axial attention is desired.
        Defaults to 16 as in MaX-DeepLab.
      use_transformer_beyond_stride: An integer, the stride beyond which we
        use a transformer layer. Set to 0 if no transformer is desired.
        Defaults to 32 as in MaX-DeepLab-S.
      use_sac_beyond_stride: An integer. Use the Switchable Atrous Convolution
        (SAC) beyond the specified stride. For example, if
        `use_sac_beyond_stride` = 16, SAC will be applied to the network stage
        whose output stride >= 16 (i.e., 16 and 32). Set to 0 or -1 to disable
        it. Defaults to 0 as SAC is not used in MaX-DeepLab.
      use_squeeze_and_excite: A boolean, whether squeeze-and-excite (SE) is
        used. Defaults to False as SE is not used in MaX-DeepLab.
      conv_use_recompute_grad: A boolean, whether to use the gradient
        checkpointing trick for convolutional blocks. This trick reduces
        accelerator memory usage, but takes longer to compute gradients.
        Defaults to False since convolutional layers are memory efficient.
      axial_use_recompute_grad: A boolean, whether to use the gradient
        checkpointing trick for axial blocks. This trick reduces accelerator
        memory usage, but takes longer to compute gradients. Defaults to True
        since it saves memory for axial blocks.
      recompute_within_stride: An integer, the stride within which we use the
        gradient checkpointing trick. This trick reduces accelerator memory
        usage, but takes longer to compute gradients. Defaults to 0 (do not
        recompute any layer).
      transformer_use_recompute_grad: A boolean, whether to use the gradient
        checkpointing trick for dual-path transformer blocks. This trick
        reduces accelerator memory usage, but takes longer to compute
        gradients. Defaults to False.
      transformer_expansion: An integer, the expansion ratio for the
        transformer bottleneck.
      drop_path_keep_prob: A float, the keep probability for dropping path.
        Defaults to 0.8 as in MaX-DeepLab-S.
      drop_path_beyond_stride: An integer, the stride beyond which we apply
        drop path augmentation. Defaults to 16 as in MaX-DeepLab-S.
      drop_path_schedule: A string, the drop path schedule.
        Currently, we support 'constant': use the same drop path keep
        probability for all stages, and 'linear': linearly decrease the drop
        path keep probability from 1.0 at the 0-th stage (or STEM) to
        `drop_path_keep_prob` at the last stage.
      activation: A string, type of activation function to apply. Support
        'relu', 'swish' (or 'silu'), 'gelu', 'approximated_gelu', and 'elu'.
      attention_bottleneck_expansion: An integer, the expansion ratio for
        axial attention blocks.
      axial_layer_config: A dict, an argument dictionary for the axial layer.
      dual_path_transformer_layer_config: A dict, an argument dictionary for
        the transformer.
      bn_layer: An optional tf.keras.layers.Layer that computes the
        normalization (default: tf.keras.layers.BatchNormalization).
      conv_kernel_weight_decay: A float, the weight decay for convolution
        kernels.

    Raises:
      ValueError: If backbone_type is not one of 'resnet', 'resnet_beta', or
        'wider_resnet'.
      ValueError: If original_resnet_input_stride is not a power of 2.
      ValueError: If output_stride is not a power of 2.
    """
    if original_resnet_input_stride & (original_resnet_input_stride - 1):
      raise ValueError('original_resnet_input_stride is not a power of 2.')
    if output_stride & (output_stride - 1):
      raise ValueError('output_stride is not a power of 2.')

    super(BlockGroup, self).__init__(name=name)
    self._add_absolute_positional_encoding = None
    self._activation_fn = activations.get_activation(activation)
    self._num_blocks = num_blocks
    self._drop_path_keep_prob = []
    self._recompute_grad = []
    self._transformer_use_recompute_grad = transformer_use_recompute_grad
    if dual_path_transformer_layer_config is None:
      dual_path_transformer_layer_config = {}
    original_resnet_current_stride = original_resnet_input_stride

    use_sac = (original_resnet_input_stride * original_resnet_stride >=
               use_sac_beyond_stride > 0)
    recompute_grad = (original_resnet_input_stride * original_resnet_stride <=
                      recompute_within_stride)

    for index in range(num_blocks):
      current_name, transformer_current_name = _get_current_names(index)

      # Compute the current strides. If there is a stride for this block
      # group, we do it in the first residual block.
      if index == 0 and original_resnet_input_stride < output_stride:
        current_strides = original_resnet_stride
      else:
        current_strides = 1

      # Compute the current atrous rate.
      if original_resnet_current_stride > output_stride:
        atrous_rate = original_resnet_current_stride // output_stride
      else:
        atrous_rate = 1

      # Compute the atrous rate for the second conv in the first basic block.
      if (index == 0 and
          original_resnet_input_stride * original_resnet_stride >
          output_stride):
        basic_block_second_conv_atrous_rate = (
            original_resnet_input_stride * original_resnet_stride //
            output_stride)
      else:
        basic_block_second_conv_atrous_rate = atrous_rate

      # Compute the current drop_path_keep_prob.
      current_stage = math.log2(original_resnet_current_stride) - 1
      if original_resnet_current_stride >= drop_path_beyond_stride:
        current_drop_path_keep_prob = drop_path.get_drop_path_keep_prob(
            drop_path_keep_prob, drop_path_schedule,
            current_stage=int(round(current_stage)),
            num_stages=4)
      else:
        current_drop_path_keep_prob = 1.0
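
      # Illustrative example (added here; the numbers are assumptions, not
      # taken from any released config): with original_resnet_input_stride=16,
      # original_resnet_stride=2, and output_stride=16, the first block keeps
      # current_strides=1 (the stride is not applied because the input stride
      # already equals the desired output_stride), the first basic block gets
      # basic_block_second_conv_atrous_rate=32/16=2, and every later block
      # gets atrous_rate=32/16=2. The stage therefore keeps stride-16
      # resolution while emulating the original stride-32 stage. With
      # drop_path_beyond_stride=16, drop path is also active in this stage at
      # current_stage=log2(16)-1=3.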
      # Compute which block_fn to use for this residual block.
      if original_resnet_current_stride >= use_global_beyond_stride > 0:
        attention_type = 'global'
        recompute_grad = axial_use_recompute_grad or recompute_grad
        filters_list = [filters * attention_bottleneck_expansion,
                        filters,
                        filters * 4]
      elif original_resnet_current_stride >= use_axial_beyond_stride > 0:
        attention_type = 'axial'
        recompute_grad = axial_use_recompute_grad or recompute_grad
        filters_list = [filters * attention_bottleneck_expansion,
                        filters,
                        filters * 4]
      elif backbone_type == 'resnet' or backbone_type == 'resnet_beta':
        attention_type = None
        recompute_grad = conv_use_recompute_grad or recompute_grad
        filters_list = [filters,
                        filters,
                        filters * 4]
      elif backbone_type == 'wider_resnet':
        if original_resnet_input_stride * original_resnet_stride < 32:
          # Wider-ResNet uses conv basic blocks except in the last stage.
          attention_type = None
          recompute_grad = conv_use_recompute_grad or recompute_grad
          filters_list = [filters * 4,
                          filters * 4]
        else:
          # Wider-ResNet uses an expanded bottleneck block in the last stage.
          attention_type = None
          recompute_grad = conv_use_recompute_grad or recompute_grad
          filters_list = [filters,
                          filters * 2,
                          filters * 4]
      else:
        raise ValueError(backbone_type + ' is not supported.')

      self._drop_path_keep_prob.append(current_drop_path_keep_prob)
      # Apply the residual block.
      # The inputs to block_fn should be activated features.
      block_fn = axial_blocks.AxialBlock(
          filters_list,
          kernel_size=3,
          strides=current_strides,
          atrous_rate=atrous_rate,
          use_squeeze_and_excite=use_squeeze_and_excite,
          use_sac=use_sac,
          bn_layer=bn_layer,
          activation=activation,
          name=current_name[1:],
          conv_kernel_weight_decay=conv_kernel_weight_decay,
          basic_block_second_conv_atrous_rate=(
              basic_block_second_conv_atrous_rate),
          attention_type=attention_type,
          axial_layer_config=axial_layer_config)

      self._recompute_grad.append(recompute_grad)
      utils.safe_setattr(self, current_name, block_fn)

      # Modify the original_resnet_stride according to the strides.
      if index == 0 and original_resnet_stride > 1:
        original_resnet_current_stride *= original_resnet_stride
        # Add absolute positional encoding if we will apply global attention
        # beyond this stride.
        if original_resnet_current_stride == use_global_beyond_stride > 0:
          self._add_absolute_positional_encoding = (
              positional_encodings.AddAbsolutePositionalEncoding(
                  'add_absolute_positional_encoding',
                  positional_encoding_type, bn_layer,
                  conv_kernel_weight_decay))

      if original_resnet_current_stride >= use_transformer_beyond_stride > 0:
        # Apply a dual-path transformer.
        transformer_block_fn = dual_path_transformer.DualPathTransformerLayer(
            name=transformer_current_name[1:],
            filters=int(128 * transformer_expansion),
            activation=activation,
            bn_layer=bn_layer,
            conv_kernel_weight_decay=conv_kernel_weight_decay,
            **dual_path_transformer_layer_config)
        utils.safe_setattr(self, transformer_current_name,
                           transformer_block_fn)
      else:
        utils.safe_setattr(self, transformer_current_name, None)

    # Avoid using recompute_grad for the first call that builds the
    # sub-layers. Otherwise, recompute_grad will not track newly built model
    # parameters.
    self._first_building_call = True

  def call(self, inputs, training=False):
    """Performs a forward pass.

    Args:
      inputs: Two tensors. The first tensor is a pixel_space_input with shape
        [batch, height, width, pixel_channels]. The second tensor is
        memory_space_input with shape [batch, length, memory_channels]. This
        input will be used only if a transformer is used. Otherwise, the input
        is returned unmodified.
      training: A boolean flag indicating whether training behavior should be
        used (default: False).

    Returns:
      output: An output [batch, height, width, filters * 4] tensor.
      activated_output: An activated output [batch, height, width,
        filters * 4] tensor.
      memory_space_output: A memory space output [batch, length,
        memory_channels] tensor.
    """
    # The pixel space inputs are activated features.
    activated_features, memory_space_output = inputs

    # Recompute_grad takes only float tensors as inputs. It does not allow
    # bools or boolean tensors. For this reason, we cast training to a float
    # tensor and cast it back after we go through the recompute_grad wrap.
    float_tensor_training = tf.cast(training, tf.float32)

    for index in range(self._num_blocks):
      current_name, transformer_current_name = _get_current_names(index)
      block_fn_no_recompute = getattr(self, current_name)
      transformer_block_fn_no_recompute = getattr(
          self, transformer_current_name)
      current_drop_path_keep_prob = self._drop_path_keep_prob[index]

      # Wrap the layer if we want to recompute it in the backward pass.
      if self._recompute_grad[index] and training:
        # The seed is not actually used since we do not have any random
        # operation in the recomputed function. The purpose of the provided
        # seed is to prevent recompute_grad from generating a new seed
        # variable which is not compatible with model exporting.
        block_fn = recompute_grad_lib.recompute_grad(
            block_fn_no_recompute, seed=tf.constant(0, tf.int32))
      else:
        block_fn = block_fn_no_recompute

      # The inputs to block_fn should be activated features.
      block_fn_inputs = [activated_features, float_tensor_training]

      # We have to define drop_path_masks outside the layer call and pass it
      # into the layer, because tf.recompute_grad (gradient checkpointing)
      # does not allow any randomness within the function call. In addition,
      # recompute_grad functions can only take Tensors as inputs, so we do
      # not pass the drop_path_random_mask (when it is None) into block_fn.
      if current_drop_path_keep_prob < 1.0 and training:
        drop_path_random_mask = drop_path.generate_drop_path_random_mask(
            activated_features, current_drop_path_keep_prob)
        block_fn_inputs.append(drop_path_random_mask)

      # Build the sub-layers when the block_fn is called for the first time.
      # Otherwise, recompute_grad will not track newly built model parameters.
      if self._first_building_call:
        _ = block_fn_no_recompute(tuple(block_fn_inputs))
      # Apply the residual block.
      features, activated_features = block_fn(tuple(block_fn_inputs))

      if index == 0 and self._add_absolute_positional_encoding is not None:
        features = self._add_absolute_positional_encoding(
            features, training=training)
        activated_features = self._activation_fn(features)

      if transformer_block_fn_no_recompute is not None:
        # Reshape pixel space features from 4D to 3D.
        _, height, width, channels = features.get_shape().as_list()
        features = tf.reshape(features, [-1, height * width, channels])

        # Wrap the layer if we want to recompute it in the backward pass.
        if self._transformer_use_recompute_grad and training:
          # The seed is not actually used since we do not have any random
          # operation in the recomputed function. The purpose of the provided
          # seed is to prevent recompute_grad from generating a new seed
          # variable which is not compatible with model exporting.
          transformer_block_fn = recompute_grad_lib.recompute_grad(
              transformer_block_fn_no_recompute,
              seed=tf.constant(0, tf.int32))
        else:
          transformer_block_fn = transformer_block_fn_no_recompute

        transformer_block_fn_input_list = [
            features, memory_space_output, float_tensor_training]

        # We have to define drop_path_masks outside the layer call and pass
        # them into the layer, because recompute_grad (gradient checkpointing)
        # does not allow any randomness within the function call. In addition,
        # recompute_grad functions can only take Tensors as inputs, so we do
        # not pass the drop_path_masks (when they are None) into
        # transformer_block_fn.
        if current_drop_path_keep_prob < 1.0 and training:
          # Drop path random mask for pixel space attention.
          pixel_space_drop_path_mask = (
              drop_path.generate_drop_path_random_mask(
                  memory_space_output, current_drop_path_keep_prob))
          # Drop path random mask for memory space attention.
          memory_space_attention_drop_path_mask = (
              drop_path.generate_drop_path_random_mask(
                  memory_space_output, current_drop_path_keep_prob))
          # Drop path random mask for memory space feed-forward network.
          memory_space_feed_forward_network_drop_path_mask = (
              drop_path.generate_drop_path_random_mask(
                  memory_space_output, current_drop_path_keep_prob))
          transformer_block_fn_input_list += [
              pixel_space_drop_path_mask,
              memory_space_attention_drop_path_mask,
              memory_space_feed_forward_network_drop_path_mask]

        # Build the sub-layers when the transformer_block_fn is called for
        # the first time. Otherwise, recompute_grad will not track newly
        # built model parameters.
        if self._first_building_call:
          _ = transformer_block_fn_no_recompute(
              tuple(transformer_block_fn_input_list))
        # Apply a dual-path transformer.
        features, activated_features, memory_space_output = (
            transformer_block_fn(tuple(transformer_block_fn_input_list)))

        # Reshape pixel space features back to 4D.
        features = tf.reshape(features, [-1, height, width, channels])
        activated_features = tf.reshape(activated_features,
                                        [-1, height, width, channels])

    # Now the first call has finished and the sub-layers have been built.
    self._first_building_call = False
    # We also return the non-activated output so that the function is
    # compatible with a decoder that takes a non-activated tensor as input.
    return features, activated_features, memory_space_output
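

# A minimal usage sketch (added for illustration; it is not part of the
# original library). It assumes a MaX-DeepLab-S-like last stage: 1024-channel
# activated pixel features at stride 16 and a memory of 128 slots with 256
# channels. All shapes and hyper-parameters below are illustrative
# assumptions, not values taken from a released config.
if __name__ == '__main__':
  block_group = BlockGroup(
      filters=512,
      num_blocks=3,
      name='block_group4',
      original_resnet_stride=2,
      original_resnet_input_stride=16,
      output_stride=16,
      use_axial_beyond_stride=16,
      use_transformer_beyond_stride=32)
  # Assumed inputs: activated pixel features and a memory space input.
  pixel_space_input = tf.zeros([2, 41, 41, 1024])
  memory_space_input = tf.zeros([2, 128, 256])
  output, activated_output, memory_space_output = block_group(
      (pixel_space_input, memory_space_input), training=False)
  # The non-activated `output` is what a decoder that expects a non-activated
  # tensor would consume, while `activated_output` feeds the next block group.
  print(output.shape, activated_output.shape, memory_space_output.shape)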