# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implements dual path transformer layers proposed in MaX-DeepLab [1].

Dual-path transformer introduces a global memory path in addition to a CNN
path, allowing bi-directional communication with any CNN layers.

[1] MaX-DeepLab: End-to-End Panoptic Segmentation with Mask Transformers,
    CVPR 2021.
      Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen.
"""

import tensorflow as tf

from deeplab2.model import utils
from deeplab2.model.layers import activations
from deeplab2.model.layers import convolutions


class AttentionOperation(tf.keras.layers.Layer):
  """Computes standard 1D multi-head attention with query, key, and value."""

  def __init__(self, name, activation, transformer_activation,
               bn_layer=tf.keras.layers.BatchNormalization):
    """Initializes an AttentionOperation layer.

    Args:
      name: A string, the name of this layer.
      activation: A string, type of activation function to apply.
      transformer_activation: A string, type of activation function for
        self-attention. Supports 'sigmoid' and 'softmax'.
      bn_layer: An optional tf.keras.layers.Layer that computes the
        normalization (default: tf.keras.layers.BatchNormalization).
    """
    super(AttentionOperation, self).__init__(name=name)
    # batch_norm_similarity has shape [batch, num_heads, num_query, num_key],
    # where num_query and num_key usually equal the height, width, or length,
    # i.e., spatial dimensions, so batch norm is applied to axis=1 only.
    self._batch_norm_similarity = bn_layer(axis=1,
                                           name='batch_norm_similarity')
    # batch_norm_retrieved_value is done on shape [batch, num_heads, length,
    # value_channels], which will be reshaped to the output shape [batch,
    # length, value_channels * num_heads], so we apply batch norm on the
    # effective channel dimension -- value_channels * num_heads.
    self._batch_norm_retrieved_value = bn_layer(
        axis=[1, 3], name='batch_norm_retrieved_value')
    self._activation_fn = activations.get_activation(activation)
    self._transformer_activation_fn = activations.get_activation(
        transformer_activation)

  def call(self, inputs, training=False):
    """Performs an AttentionOperation.

    Args:
      inputs: A tuple of (query, key, value), where query is a [batch,
        num_head, query_length, channels] tensor, key is a [batch, num_head,
        key_length, channels] tensor, and value is a [batch, key_length,
        num_head, value_channels] tensor.
      training: A boolean, whether the model is in training mode.

    Returns:
      output: A [batch, query_length, num_head * value_channels] tensor, the
        retrieved value.
    """
    # Decode query, key, and value from inputs.
    query, key, value = inputs
    # Compute attention similarity.
    similarity_logits = tf.einsum('bhld,bhmd->bhlm', query, key)
    similarity_logits = self._batch_norm_similarity(
        similarity_logits, training=training)
    # Apply a transformer attention activation function, e.g. softmax.
    attention_weights = self._transformer_activation_fn(similarity_logits)
    # Retrieve the value content.
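    # attention_weights has shape [batch, num_heads, query_length, key_length]
    # and value has shape [batch, key_length, num_heads, value_channels], so
    # the einsum below retrieves a [batch, num_heads, query_length,
    # value_channels] tensor.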
    retrieved_value = tf.einsum(
        'bhlm,bmhd->bhld', attention_weights, value)
    retrieved_value = self._batch_norm_retrieved_value(
        retrieved_value, training=training)
    retrieved_value = self._activation_fn(retrieved_value)
    # Reshape the output.
    return utils.transpose_and_reshape_for_attention_operation(
        retrieved_value)


class DualPathTransformerLayer(tf.keras.layers.Layer):
  """Applies a dual path transformer layer, as proposed in MaX-DeepLab [1].

  Dual-path transformer layer takes a pixel space input and a memory space
  input, and performs memory2pixel attention, pixel2memory attention, and
  memory2memory self-attention. Note that the pixel2pixel self-attention or
  convolution in the pixel space is implemented in axial_layers.py and
  axial_blocks.py. Thus, the pixel2pixel operation is not included in this
  DualPathTransformerLayer implementation. Please use this class together with
  a residual block with axial-attention, global-attention, or convolution in
  order to construct the full dual path transformer in the paper.

  [1] MaX-DeepLab: End-to-End Panoptic Segmentation with Mask Transformers,
      CVPR 2021.
        Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen.
  """

  def __init__(self,
               name='dual_path_transformer_layer',
               activation='relu',
               filters=128,
               num_heads=8,
               bottleneck_expansion=2,
               key_expansion=1,
               value_expansion=2,
               feed_forward_network_channels=2048,
               use_memory_self_attention=True,
               use_pixel2memory_feedback_attention=True,
               transformer_activation='softmax',
               bn_layer=tf.keras.layers.BatchNormalization,
               conv_kernel_weight_decay=0.0):
    """Initializes a DualPathTransformerLayer.

    This function implements a dual path transformer layer between a pixel
    space and a memory space, as described in the MaX-DeepLab paper. In this
    dual path transformer, the memory2pixel cross attention and the memory
    self-attention share a single activation, e.g. softmax.

    Reference:
      MaX-DeepLab: "End-to-End Panoptic Segmentation with Mask Transformers",
        CVPR 2021. https://arxiv.org/abs/2012.00759
          Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen.

    Args:
      name: A string, the name of this dual path transformer layer.
      activation: A string, type of activation function to apply.
      filters: An integer, the base number of channels for the layer.
      num_heads: An integer, the number of heads in multi-head attention.
      bottleneck_expansion: A float, the channel expansion ratio for the
        bottleneck.
      key_expansion: A float, the channel expansion ratio for keys.
      value_expansion: A float, the channel expansion ratio for values.
      feed_forward_network_channels: An integer, the number of channels for
        the feed_forward_network. Zero means no feed_forward_network will be
        applied.
      use_memory_self_attention: A boolean, whether to apply the memory space
        self-attention.
      use_pixel2memory_feedback_attention: A boolean, whether to apply the
        pixel2memory feedback attention.
      transformer_activation: A string, type of activation function for
        self-attention. Supports 'sigmoid' and 'softmax'.
      bn_layer: A tf.keras.layers.Layer that computes the normalization
        (default: tf.keras.layers.BatchNormalization).
      conv_kernel_weight_decay: A float, the weight decay for convolution
        kernels.

    Raises:
      ValueError: If filters * key_expansion is not divisible by num_heads.
      ValueError: If filters * value_expansion is not divisible by num_heads.
""" super(DualPathTransformerLayer, self).__init__(name=name) bottleneck_channels = int(round(filters * bottleneck_expansion)) total_key_depth = int(round(filters * key_expansion)) total_value_depth = int(round(filters * value_expansion)) if total_key_depth % num_heads: raise ValueError('Total_key_depth should be divisible by num_heads.') if total_value_depth % num_heads: raise ValueError('Total_value_depth should be divisible by num_heads.') # Compute query key value with one convolution and a batch norm layer. The # initialization std is standard transformer initialization (without batch # norm), as used in SASA and ViT. In our case, we use batch norm by default, # so it does not require careful tuning. If one wants to remove all batch # norms in axial attention, this standard initialization should still be # good, but a more careful initialization is encouraged. initialization_std = bottleneck_channels ** -0.5 self._memory_conv1_bn_act = convolutions.Conv1D( bottleneck_channels, 'memory_conv1_bn_act', use_bias=False, use_bn=True, bn_layer=bn_layer, activation=activation, conv_kernel_weight_decay=conv_kernel_weight_decay) self._pixel_conv1_bn_act = convolutions.Conv1D( bottleneck_channels, 'pixel_conv1_bn_act', use_bias=False, use_bn=True, bn_layer=bn_layer, activation=activation, conv_kernel_weight_decay=conv_kernel_weight_decay) # We always compute the query for memory space, since it gathers information # from the pixel space and thus cannot be removed. We compute the key and # value for memory space only when they are necessary (i.e. either # use_memory_self_attention or use_pixel2memory_feedback_attention). if use_memory_self_attention or use_pixel2memory_feedback_attention: self._memory_qkv_conv_bn = convolutions.Conv1D( total_key_depth * 2 + total_value_depth, 'memory_qkv_conv_bn', use_bias=False, use_bn=True, bn_layer=bn_layer, activation='none', conv_kernel_weight_decay=conv_kernel_weight_decay, kernel_initializer=tf.keras.initializers.TruncatedNormal( stddev=initialization_std)) else: # Compute memory query only if memory key and value are not used. self._memory_query_conv_bn = convolutions.Conv1D( total_key_depth, 'memory_query_conv_bn', use_bias=False, use_bn=True, bn_layer=bn_layer, activation='none', conv_kernel_weight_decay=conv_kernel_weight_decay, kernel_initializer=tf.keras.initializers.TruncatedNormal( stddev=initialization_std)) # For the pixel space, we always compute the key and value, since they # provide information for the memory space and thus cannot be removed. We # compute the query for pixel space only when it is necessary (i.e. # use_pixel2memory_feedback_attention is True). 
    if use_pixel2memory_feedback_attention:
      self._pixel_qkv_conv_bn = convolutions.Conv1D(
          total_key_depth * 2 + total_value_depth, 'pixel_qkv_conv_bn',
          use_bias=False,
          use_bn=True,
          bn_layer=bn_layer,
          activation='none',
          conv_kernel_weight_decay=conv_kernel_weight_decay,
          kernel_initializer=tf.keras.initializers.TruncatedNormal(
              stddev=initialization_std))
    else:
      self._pixel_kv_conv_bn = convolutions.Conv1D(
          total_key_depth + total_value_depth, 'pixel_kv_conv_bn',
          use_bias=False,
          use_bn=True,
          bn_layer=bn_layer,
          activation='none',
          conv_kernel_weight_decay=conv_kernel_weight_decay,
          kernel_initializer=tf.keras.initializers.TruncatedNormal(
              stddev=initialization_std))

    self._memory_attention = AttentionOperation(
        'memory_attention', activation, transformer_activation,
        bn_layer=bn_layer)
    if use_pixel2memory_feedback_attention:
      self._pixel_attention = AttentionOperation(
          'pixel_attention', activation, transformer_activation,
          bn_layer=bn_layer)

    self._use_memory_self_attention = use_memory_self_attention
    self._use_pixel2memory_feedback_attention = (
        use_pixel2memory_feedback_attention)
    self._total_key_depth = total_key_depth
    self._total_value_depth = total_value_depth
    self._num_heads = num_heads
    self._bn_layer = bn_layer
    self._conv_kernel_weight_decay = conv_kernel_weight_decay
    self._activation = activation
    self._activation_fn = activations.get_activation(activation)
    self._feed_forward_network_channels = feed_forward_network_channels

  def build(self, input_shape_list):
    pixel_shape, memory_shape = input_shape_list[:2]
    # Here we follow ResNet bottleneck blocks: we apply a batch norm with
    # gamma initialized at zero, followed by drop path and an activation
    # function. Initializing this gamma at zero ensures that at random
    # initialization of the model, the skip connections dominate all residual
    # blocks. In this way, all the skip connections construct an identity
    # mapping that passes the gradients (without any distortion from the
    # randomly initialized blocks) to all residual blocks. This helps training
    # at early epochs.
    # Reference: "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour".
    # https://arxiv.org/abs/1706.02677
    self._memory_conv3_bn = convolutions.Conv1D(
        memory_shape[-1], 'memory_conv3_bn',
        use_bias=False,
        use_bn=True,
        bn_layer=self._bn_layer,
        bn_gamma_initializer='zeros',
        activation='none',
        conv_kernel_weight_decay=self._conv_kernel_weight_decay)

    if self._feed_forward_network_channels > 0:
      self._memory_ffn_conv1_bn_act = convolutions.Conv1D(
          self._feed_forward_network_channels, 'memory_ffn_conv1_bn_act',
          use_bias=False,
          use_bn=True,
          bn_layer=self._bn_layer,
          activation=self._activation,
          conv_kernel_weight_decay=self._conv_kernel_weight_decay)
      # Again, we follow ResNet bottleneck blocks: we apply a batch norm with
      # gamma initialized at zero, followed by drop path and an activation
      # function.
      self._memory_ffn_conv2_bn = convolutions.Conv1D(
          memory_shape[-1], 'memory_ffn_conv2_bn',
          use_bias=False,
          use_bn=True,
          bn_layer=self._bn_layer,
          bn_gamma_initializer='zeros',
          activation='none',
          conv_kernel_weight_decay=self._conv_kernel_weight_decay)
    if self._use_pixel2memory_feedback_attention:
      self._pixel_conv3_bn = convolutions.Conv1D(
          pixel_shape[-1], 'pixel_conv3_bn',
          use_bias=False,
          use_bn=True,
          bn_layer=self._bn_layer,
          bn_gamma_initializer='zeros',
          activation='none',
          conv_kernel_weight_decay=self._conv_kernel_weight_decay)

  def call(self, inputs):
    """Performs a forward pass.

    We have to define drop_path_masks outside the layer call and pass them
    into the layer call, because recompute_grad (gradient checkpointing) does
    not allow any randomness within the function call. In addition,
    recompute_grad only supports float tensors as inputs. For this reason, the
    training flag should also be passed as a float tensor. For the same
    reason, we cannot support passing drop_path_random_mask as None. Instead,
    we ask the users to pass only the first three tensors when drop path is
    not used.

    Args:
      inputs: A tuple of 3 or 6 tensors, containing
        pixel_space_input should be a [batch, num_pixel, pixel_space_channels]
          tensor.
        memory_space_input should be a [batch, num_memory,
          memory_space_channels] tensor.
        float_tensor_training should be a float tensor of 0.0 or 1.0, whether
          the model is in training mode.
        (optional) pixel_space_drop_path_mask is a drop path mask tensor of
          shape [batch, 1, 1] for the pixel space.
        (optional) memory_space_attention_drop_path_mask is a drop path mask
          tensor of shape [batch, 1, 1] for the memory space.
        (optional) memory_space_feed_forward_network_drop_path_mask is a drop
          path mask tensor of shape [batch, 1, 1] for the memory space feed
          forward network.

    Returns:
      pixel_space_output: A [batch, num_pixel, pixel_space_channels] tensor.
      activated_pixel_space_output: A [batch, num_pixel, pixel_space_channels]
        tensor, activated pixel_space_output.
      memory_space_output: A [batch, num_memory, memory_space_channels]
        tensor.

    Raises:
      ValueError: If the length of inputs is not 3 or 6.
    """
    if len(inputs) not in (3, 6):
      raise ValueError('The length of inputs should be either 3 or 6.')

    # Unpack the inputs.
    (pixel_space_input, memory_space_input, float_tensor_training,
     pixel_space_drop_path_mask, memory_space_attention_drop_path_mask,
     memory_space_feed_forward_network_drop_path_mask) = (
         utils.pad_sequence_with_none(inputs, target_length=6))

    # Recompute_grad takes only float tensors as inputs. It does not allow
    # bools or boolean tensors. For this reason, we cast training to a float
    # tensor outside this call, and now we cast it back to a boolean tensor.
    training = tf.cast(float_tensor_training, tf.bool)

    # Decode the input shapes.
    pixel_shape = pixel_space_input.get_shape().as_list()
    memory_shape = memory_space_input.get_shape().as_list()

    # Similar to the ResNet bottleneck design, we do an input down projection
    # in both the pixel space and the memory space.
    memory_space = self._memory_conv1_bn_act(memory_space_input,
                                             training=training)
    # Pixel space input is not activated.
    pixel_space = self._pixel_conv1_bn_act(
        self._activation_fn(pixel_space_input), training=training)

    if (self._use_memory_self_attention or
        self._use_pixel2memory_feedback_attention):
      memory_space_qkv = self._memory_qkv_conv_bn(memory_space,
                                                  training=training)
      # Split, reshape, and transpose the query, key, and value.
      memory_query, memory_key, memory_value = (
          tf.split(memory_space_qkv, [
              self._total_key_depth, self._total_key_depth,
              self._total_value_depth], axis=-1))
      memory_key = utils.reshape_and_transpose_for_attention_operation(
          memory_key, self._num_heads)
      memory_value = tf.reshape(memory_value, [
          -1, memory_shape[1], self._num_heads,
          self._total_value_depth // self._num_heads])
    else:
      # Compute memory query only if memory key and value are not used.
      memory_query = self._memory_query_conv_bn(memory_space,
                                                training=training)
    # Reshape and transpose the query.
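    # The memory query is reshaped from [batch, num_memory, total_key_depth]
    # to [batch, num_heads, num_memory, total_key_depth // num_heads], the
    # query layout that AttentionOperation expects.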
    memory_query = utils.reshape_and_transpose_for_attention_operation(
        memory_query, self._num_heads)

    if self._use_pixel2memory_feedback_attention:
      pixel_space_qkv = self._pixel_qkv_conv_bn(pixel_space,
                                                training=training)
      # Split the query, key, and value.
      pixel_query, pixel_key, pixel_value = tf.split(
          pixel_space_qkv, [
              self._total_key_depth, self._total_key_depth,
              self._total_value_depth], axis=-1)
      pixel_query = utils.reshape_and_transpose_for_attention_operation(
          pixel_query, self._num_heads)
    else:
      pixel_space_kv = self._pixel_kv_conv_bn(pixel_space, training=training)
      # Split the key and the value.
      pixel_key, pixel_value = tf.split(pixel_space_kv, [
          self._total_key_depth, self._total_value_depth], axis=-1)
    # Reshape and transpose the key and the value.
    pixel_key = utils.reshape_and_transpose_for_attention_operation(
        pixel_key, self._num_heads)
    pixel_value = tf.reshape(pixel_value, [
        -1, pixel_shape[1], self._num_heads,
        self._total_value_depth // self._num_heads])

    # Compute memory space attention.
    if not self._use_memory_self_attention:
      # If memory self attention is not used, then only memory2pixel cross
      # attention is used for the memory space. In this case, the key and the
      # value are simply pixel_key and pixel_value.
      memory_attention_key = pixel_key
      memory_attention_value = pixel_value
    else:
      # If we also use memory self attention, the key and the value are the
      # concatenation of keys and values in both the pixel space and the
      # memory space.
      memory_attention_key = tf.concat([pixel_key, memory_key], axis=2)
      memory_attention_value = tf.concat([pixel_value, memory_value], axis=1)

    memory_space = self._memory_attention(
        (memory_query, memory_attention_key, memory_attention_value),
        training=training)
    memory_space = self._memory_conv3_bn(memory_space, training=training)
    if memory_space_attention_drop_path_mask is not None:
      memory_space = memory_space * memory_space_attention_drop_path_mask
    memory_space_output = self._activation_fn(
        memory_space_input + memory_space)

    # Apply an optional feed-forward network to the memory space.
    if self._feed_forward_network_channels > 0:
      memory_space = self._memory_ffn_conv1_bn_act(memory_space_output,
                                                   training=training)
      memory_space = self._memory_ffn_conv2_bn(memory_space,
                                               training=training)
      if memory_space_feed_forward_network_drop_path_mask is not None:
        memory_space = (memory_space *
                        memory_space_feed_forward_network_drop_path_mask)
      memory_space_output = self._activation_fn(
          memory_space_output + memory_space)

    # Compute pixel space attention and the output projection only when
    # pixel2memory_feedback_attention is used.
    if self._use_pixel2memory_feedback_attention:
      pixel_space = self._pixel_attention(
          (pixel_query, memory_key, memory_value), training=training)
      pixel_space = self._pixel_conv3_bn(pixel_space, training=training)
      if pixel_space_drop_path_mask is not None:
        pixel_space = pixel_space * pixel_space_drop_path_mask
      pixel_space_output = pixel_space_input + pixel_space
    else:
      # If pixel2memory_feedback_attention is not used, the pixel_space_input
      # is not changed.
      pixel_space_output = pixel_space_input
    activated_pixel_space_output = self._activation_fn(pixel_space_output)

    # Return the pixel space output and memory space output. Note that we
    # return the pixel space output with and without the activation function,
    # because our decoder might use non-activated features.
    return (pixel_space_output,
            activated_pixel_space_output,
            memory_space_output)
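

if __name__ == '__main__':
  # A minimal usage sketch, not part of the MaX-DeepLab implementation: it
  # assumes the deeplab2 package is importable and simply runs one forward
  # pass in eager mode with the default layer configuration. The batch size,
  # sequence lengths, and channel sizes below are illustrative only.
  demo_layer = DualPathTransformerLayer(name='dual_path_transformer_demo')
  # Pixel space input: [batch, num_pixel, pixel_space_channels].
  demo_pixel_space_input = tf.random.normal([2, 1024, 256])
  # Memory space input: [batch, num_memory, memory_space_channels].
  demo_memory_space_input = tf.random.normal([2, 128, 256])
  # recompute_grad compatibility requires the training flag as a float tensor.
  demo_float_tensor_training = tf.constant(0.0, dtype=tf.float32)
  # Passing only the first three tensors means drop path is not used.
  demo_pixel_out, demo_activated_pixel_out, demo_memory_out = demo_layer(
      (demo_pixel_space_input, demo_memory_space_input,
       demo_float_tensor_training))
  print(demo_pixel_out.shape, demo_activated_pixel_out.shape,
        demo_memory_out.shape)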