# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements dual path transformer layers proposed in MaX-DeepLab [1].
Dual-path transformer introduces a global memory path in addition to a CNN path,
allowing bi-directional communication with any CNN layers.
[1] MaX-DeepLab: End-to-End Panoptic Segmentation with Mask Transformers,
CVPR 2021.
Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen.
"""
import tensorflow as tf
from deeplab2.model import utils
from deeplab2.model.layers import activations
from deeplab2.model.layers import convolutions
class AttentionOperation(tf.keras.layers.Layer):
"""Computes standard 1D multi-head attention with query, key, and value."""
def __init__(self,
name,
activation,
transformer_activation,
bn_layer=tf.keras.layers.BatchNormalization):
"""Initializes an AttentionOperation layer.
Args:
name: A string, the name of this layer.
activation: A string, type of activation function to apply.
transformer_activation: A string, type of activation function for
        self-attention. Supported values are 'sigmoid' and 'softmax'.
bn_layer: An optional tf.keras.layers.Layer that computes the
normalization (default: tf.keras.layers.BatchNormalization).
"""
super(AttentionOperation, self).__init__(name=name)
# batch_norm_similarity has shape [batch, num_heads, num_query, num_key],
    # where num_query and num_key usually equal the height, width, or length,
# i.e., spatial dimensions, so batch norm is applied to axis=1 only.
self._batch_norm_similarity = bn_layer(axis=1, name='batch_norm_similarity')
# batch_norm_retrieved_value is done on shape [batch, num_heads, length,
# value_channels], which will be reshaped to the output shape [batch,
# length, value_channels * num_heads], so we apply batch norm on the
# effective channel dimension -- value_channels * num_heads.
self._batch_norm_retrieved_value = bn_layer(
axis=[1, 3], name='batch_norm_retrieved_value')
self._activation_fn = activations.get_activation(activation)
self._transformer_activation_fn = activations.get_activation(
transformer_activation)
def call(self, inputs, training=False):
"""Performs an AttentionOperation.
Args:
      inputs: A tuple of (query, key, value), where query is a [batch,
        num_head, query_length, channels] tensor, key is a [batch, num_head,
        key_length, channels] tensor, and value is a [batch, key_length,
        num_head, value_channels] tensor.
training: A boolean, whether the model is in training mode.
Returns:
output: A [batch, query_length, num_head * value_channels] tensor, the
retrieved value.
"""
# Decode query, key, and value from inputs.
query, key, value = inputs
# Compute attention similarity.
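    # The einsum below computes per-head dot-product logits, i.e.,
    # similarity_logits[b, h, l, m] = sum_d query[b, h, l, d] * key[b, h, m, d].
    # Note that no 1/sqrt(d) scaling is applied to the logits; the batch norm
    # that follows serves as the normalization instead.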
similarity_logits = tf.einsum('bhld,bhmd->bhlm', query, key)
similarity_logits = self._batch_norm_similarity(
similarity_logits, training=training)
# Apply a transformer attention activation function, e.g. softmax.
attention_weights = self._transformer_activation_fn(similarity_logits)
# Retrieve the value content.
retrieved_value = tf.einsum(
'bhlm,bmhd->bhld', attention_weights, value)
retrieved_value = self._batch_norm_retrieved_value(
retrieved_value, training=training)
retrieved_value = self._activation_fn(retrieved_value)
# Reshape the output.
return utils.transpose_and_reshape_for_attention_operation(
retrieved_value)
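

# A minimal usage sketch for AttentionOperation (the shapes below are
# illustrative assumptions, not values taken from a DeepLab2 config):
#   attention = AttentionOperation('attention', 'relu', 'softmax')
#   query = tf.zeros([2, 8, 100, 16])  # [batch, num_head, query_length, channels]
#   key = tf.zeros([2, 8, 50, 16])  # [batch, num_head, key_length, channels]
#   value = tf.zeros([2, 50, 8, 32])  # [batch, key_length, num_head, value_channels]
#   output = attention((query, key, value), training=False)
#   # output has shape [2, 100, 8 * 32].
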
class DualPathTransformerLayer(tf.keras.layers.Layer):
"""Applies a dual path transformer layer, as proposed in MaX-DeepLab [1].
Dual-path transformer layer takes a pixel space input and a memory space
input, and performs memory2pixel attention, pixel2memory attention, and
memory2memory self-attention. Note that the pixel2pixel self-attention or
convolution in the pixel space is implemented in axial_layers.py and
axial_blocks.py. Thus, the pixel2pixel operation is not included in this
DualPathTransformerLayer implementation. Please use this class together with
a residual block with axial-attention, global-attention, or convolution in
order to construct the full dual path transformer in the paper.
[1] MaX-DeepLab: End-to-End Panoptic Segmentation with Mask Transformers,
CVPR 2021.
Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen.
"""
def __init__(self,
name='dual_path_transformer_layer',
activation='relu',
filters=128,
num_heads=8,
bottleneck_expansion=2,
key_expansion=1,
value_expansion=2,
feed_forward_network_channels=2048,
use_memory_self_attention=True,
use_pixel2memory_feedback_attention=True,
transformer_activation='softmax',
bn_layer=tf.keras.layers.BatchNormalization,
conv_kernel_weight_decay=0.0):
"""Initializes a DualPathTransformerLayer.
This function implements a dual path transformer layer between a pixel space
and a memory space, as described in the MaX-DeepLab paper. In this dual path
transformer, the memory2pixel cross attention and the memory self-attention
share a single activation, e.g. softmax.
Reference:
MaX-DeepLab: "End-to-End Panoptic Segmentation with Mask Transformers",
CVPR 2021. https://arxiv.org/abs/2012.00759
Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen.
Args:
name: A string, the name of this dual path transformer layer.
activation: A string, type of activation function to apply.
filters: An integer, the base number of channels for the layer.
num_heads: An integer, the number of heads in multi-head attention.
bottleneck_expansion: A float, the channel expansion ratio for the
bottleneck.
key_expansion: A float, the channel expansion ratio for keys.
value_expansion: A float, the channel expansion ratio for values.
feed_forward_network_channels: An integer, the number of channels for the
feed_forward_network. Zero means no feed_forward_network will be
applied.
use_memory_self_attention: A boolean, whether to apply the memory space
self-attention.
use_pixel2memory_feedback_attention: A boolean, whether to apply the
pixel2memory feedback attention.
transformer_activation: A string, type of activation function for
        self-attention. Supported values are 'sigmoid' and 'softmax'.
bn_layer: A tf.keras.layers.Layer that computes the normalization
(default: tf.keras.layers.BatchNormalization).
conv_kernel_weight_decay: A float, the weight decay for convolution
kernels.
Raises:
ValueError: If filters * key_expansion is not divisible by num_heads.
ValueError: If filters * value_expansion is not divisible by num_heads.
"""
super(DualPathTransformerLayer, self).__init__(name=name)
bottleneck_channels = int(round(filters * bottleneck_expansion))
total_key_depth = int(round(filters * key_expansion))
total_value_depth = int(round(filters * value_expansion))
if total_key_depth % num_heads:
raise ValueError('Total_key_depth should be divisible by num_heads.')
if total_value_depth % num_heads:
raise ValueError('Total_value_depth should be divisible by num_heads.')
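    # As a concrete example of the channel arithmetic, the default arguments
    # (filters=128, bottleneck_expansion=2, key_expansion=1, value_expansion=2,
    # num_heads=8) give bottleneck_channels=256, total_key_depth=128, and
    # total_value_depth=256, i.e. 16 key channels and 32 value channels per
    # head.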
# Compute query key value with one convolution and a batch norm layer. The
# initialization std is standard transformer initialization (without batch
# norm), as used in SASA and ViT. In our case, we use batch norm by default,
# so it does not require careful tuning. If one wants to remove all batch
# norms in axial attention, this standard initialization should still be
# good, but a more careful initialization is encouraged.
initialization_std = bottleneck_channels ** -0.5
self._memory_conv1_bn_act = convolutions.Conv1D(
bottleneck_channels, 'memory_conv1_bn_act',
use_bias=False,
use_bn=True,
bn_layer=bn_layer,
activation=activation,
conv_kernel_weight_decay=conv_kernel_weight_decay)
self._pixel_conv1_bn_act = convolutions.Conv1D(
bottleneck_channels, 'pixel_conv1_bn_act',
use_bias=False,
use_bn=True,
bn_layer=bn_layer,
activation=activation,
conv_kernel_weight_decay=conv_kernel_weight_decay)
# We always compute the query for memory space, since it gathers information
# from the pixel space and thus cannot be removed. We compute the key and
# value for memory space only when they are necessary (i.e. either
# use_memory_self_attention or use_pixel2memory_feedback_attention).
if use_memory_self_attention or use_pixel2memory_feedback_attention:
self._memory_qkv_conv_bn = convolutions.Conv1D(
total_key_depth * 2 + total_value_depth, 'memory_qkv_conv_bn',
use_bias=False,
use_bn=True,
bn_layer=bn_layer,
activation='none',
conv_kernel_weight_decay=conv_kernel_weight_decay,
kernel_initializer=tf.keras.initializers.TruncatedNormal(
stddev=initialization_std))
else:
# Compute memory query only if memory key and value are not used.
self._memory_query_conv_bn = convolutions.Conv1D(
total_key_depth, 'memory_query_conv_bn',
use_bias=False,
use_bn=True,
bn_layer=bn_layer,
activation='none',
conv_kernel_weight_decay=conv_kernel_weight_decay,
kernel_initializer=tf.keras.initializers.TruncatedNormal(
stddev=initialization_std))
# For the pixel space, we always compute the key and value, since they
# provide information for the memory space and thus cannot be removed. We
# compute the query for pixel space only when it is necessary (i.e.
# use_pixel2memory_feedback_attention is True).
if use_pixel2memory_feedback_attention:
self._pixel_qkv_conv_bn = convolutions.Conv1D(
total_key_depth * 2 + total_value_depth, 'pixel_qkv_conv_bn',
use_bias=False,
use_bn=True,
bn_layer=bn_layer,
activation='none',
conv_kernel_weight_decay=conv_kernel_weight_decay,
kernel_initializer=tf.keras.initializers.TruncatedNormal(
stddev=initialization_std))
else:
self._pixel_kv_conv_bn = convolutions.Conv1D(
total_key_depth + total_value_depth, 'pixel_kv_conv_bn',
use_bias=False,
use_bn=True,
bn_layer=bn_layer,
activation='none',
conv_kernel_weight_decay=conv_kernel_weight_decay,
kernel_initializer=tf.keras.initializers.TruncatedNormal(
stddev=initialization_std))
self._memory_attention = AttentionOperation(
'memory_attention', activation, transformer_activation,
bn_layer=bn_layer)
if use_pixel2memory_feedback_attention:
self._pixel_attention = AttentionOperation(
'pixel_attention', activation, transformer_activation,
bn_layer=bn_layer)
self._use_memory_self_attention = use_memory_self_attention
self._use_pixel2memory_feedback_attention = (
use_pixel2memory_feedback_attention)
self._total_key_depth = total_key_depth
self._total_value_depth = total_value_depth
self._num_heads = num_heads
self._bn_layer = bn_layer
self._conv_kernel_weight_decay = conv_kernel_weight_decay
self._activation = activation
self._activation_fn = activations.get_activation(activation)
self._feed_forward_network_channels = feed_forward_network_channels
def build(self, input_shape_list):
pixel_shape, memory_shape = input_shape_list[:2]
# Here we follow ResNet bottleneck blocks: we apply a batch norm with gamma
# initialized at zero, followed by drop path and an activation function.
# Initializing this gamma at zero ensures that at random initialization of
# the model, the skip connections dominate all residual blocks. In this way,
# all the skip connections construct an identity mapping that passes the
# gradients (without any distortion from the randomly initialized blocks) to
# all residual blocks. This helps training at early epochs.
# Reference: "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour".
# https://arxiv.org/abs/1706.02677
self._memory_conv3_bn = convolutions.Conv1D(
memory_shape[-1], 'memory_conv3_bn',
use_bias=False,
use_bn=True,
bn_layer=self._bn_layer,
bn_gamma_initializer='zeros',
activation='none',
conv_kernel_weight_decay=self._conv_kernel_weight_decay)
if self._feed_forward_network_channels > 0:
self._memory_ffn_conv1_bn_act = convolutions.Conv1D(
self._feed_forward_network_channels, 'memory_ffn_conv1_bn_act',
use_bias=False,
use_bn=True,
bn_layer=self._bn_layer,
activation=self._activation,
conv_kernel_weight_decay=self._conv_kernel_weight_decay)
# Again, we follow ResNet bottleneck blocks: we apply a batch norm with
# gamma initialized at zero, followed by drop path and an activation
# function.
self._memory_ffn_conv2_bn = convolutions.Conv1D(
memory_shape[-1], 'memory_ffn_conv2_bn',
use_bias=False,
use_bn=True,
bn_layer=self._bn_layer,
bn_gamma_initializer='zeros',
activation='none',
conv_kernel_weight_decay=self._conv_kernel_weight_decay)
if self._use_pixel2memory_feedback_attention:
self._pixel_conv3_bn = convolutions.Conv1D(
pixel_shape[-1], 'pixel_conv3_bn',
use_bias=False,
use_bn=True,
bn_layer=self._bn_layer,
bn_gamma_initializer='zeros',
activation='none',
conv_kernel_weight_decay=self._conv_kernel_weight_decay)
def call(self, inputs):
"""Performs a forward pass.
We have to define drop_path_masks outside the layer call and pass it into
the layer call, because recompute_grad (gradient checkpointing) does not
allow any randomness within the function call. In addition, recompute_grad
only supports float tensors as inputs. For this reason, the training flag
    should also be passed as a float tensor. For the same reason, we cannot
support passing drop_path_random_mask as None. Instead, we ask the users to
pass only the first two tensors when drop path is not used.
Args:
inputs: A tuple of 3 or 6 tensors, containing
pixel_space_input should be a [batch, num_pixel, pixel_space_channels]
tensor.
memory_space_input should be a [batch, num_memory,
memory_space_channels] tensor.
float_tensor_training should be a float tensor of 0.0 or 1.0, whether
the model is in training mode.
(optional) pixel_space_drop_path_mask is a drop path mask tensor of
shape [batch, 1, 1] for the pixel space.
(optional) memory_space_attention_drop_path_mask is a drop path mask
tensor of shape [batch, 1, 1] for the memory space.
(optional) memory_space_feed_forward_network_drop_path_mask is a drop
path mask tensor of shape [batch, 1, 1] for the memory space feed
forward network.
Returns:
pixel_space_output: A [batch, num_pixel, pixel_space_channels] tensor.
activated_pixel_space_output: A [batch, num_pixel, pixel_space_channels]
tensor, activated pixel_space_output.
memory_space_output: A [batch, num_memory, memory_space_channels]
tensor.
Raises:
ValueError: If the length of inputs is not 3 or 6.
"""
if len(inputs) not in (3, 6):
raise ValueError('The length of inputs should be either 3 or 6.')
# Unpack the inputs.
(pixel_space_input, memory_space_input, float_tensor_training,
pixel_space_drop_path_mask, memory_space_attention_drop_path_mask,
memory_space_feed_forward_network_drop_path_mask) = (
utils.pad_sequence_with_none(inputs, target_length=6))
# Recompute_grad takes only float tensors as inputs. It does not allow
# bools or boolean tensors. For this reason, we cast training to a float
# tensor outside this call, and now we cast it back to a boolean tensor.
training = tf.cast(float_tensor_training, tf.bool)
    # Decode the input shapes.
pixel_shape = pixel_space_input.get_shape().as_list()
memory_shape = memory_space_input.get_shape().as_list()
# Similar to the ResNet bottleneck design, we do an input down projection
# in both the pixel space and the memory space.
memory_space = self._memory_conv1_bn_act(memory_space_input,
training=training)
    # The pixel space input is not activated, so we apply the activation
    # function before the first convolution.
pixel_space = self._pixel_conv1_bn_act(
self._activation_fn(pixel_space_input), training=training)
if (self._use_memory_self_attention or
self._use_pixel2memory_feedback_attention):
memory_space_qkv = self._memory_qkv_conv_bn(memory_space,
training=training)
# Split, reshape, and transpose the query, key, and value.
memory_query, memory_key, memory_value = (
tf.split(memory_space_qkv, [
self._total_key_depth, self._total_key_depth,
self._total_value_depth], axis=-1))
memory_key = utils.reshape_and_transpose_for_attention_operation(
memory_key, self._num_heads)
memory_value = tf.reshape(memory_value, [
-1, memory_shape[1], self._num_heads,
self._total_value_depth // self._num_heads])
else:
# Compute memory query only if memory key and value are not used.
memory_query = self._memory_query_conv_bn(memory_space,
training=training)
# Reshape and transpose the query.
memory_query = utils.reshape_and_transpose_for_attention_operation(
memory_query, self._num_heads)
if self._use_pixel2memory_feedback_attention:
pixel_space_qkv = self._pixel_qkv_conv_bn(pixel_space,
training=training)
# Split the query, key, and value.
pixel_query, pixel_key, pixel_value = tf.split(
pixel_space_qkv, [
self._total_key_depth, self._total_key_depth,
self._total_value_depth], axis=-1)
pixel_query = utils.reshape_and_transpose_for_attention_operation(
pixel_query, self._num_heads)
else:
pixel_space_kv = self._pixel_kv_conv_bn(pixel_space, training=training)
# Split the key and the value.
pixel_key, pixel_value = tf.split(pixel_space_kv, [
self._total_key_depth, self._total_value_depth], axis=-1)
# Reshape and transpose the key and the value.
pixel_key = utils.reshape_and_transpose_for_attention_operation(
pixel_key, self._num_heads)
pixel_value = tf.reshape(pixel_value, [
-1, pixel_shape[1], self._num_heads,
self._total_value_depth // self._num_heads])
# Compute memory space attention.
if not self._use_memory_self_attention:
# If memory self attention is not used, then only memory2pixel cross
# attention is used for the memory space. In this case, the key and the
# value are simply pixel_key and pixel_value.
memory_attention_key = pixel_key
memory_attention_value = pixel_value
else:
# If we also use memory self attention, the key and the value are the
# concatenation of keys and values in both the pixel space and the
# memory space.
memory_attention_key = tf.concat([pixel_key, memory_key], axis=2)
memory_attention_value = tf.concat([pixel_value, memory_value], axis=1)
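    # At this point, memory_attention_key has shape [batch, num_heads,
    # num_pixel (+ num_memory), total_key_depth // num_heads], and
    # memory_attention_value has shape [batch, num_pixel (+ num_memory),
    # num_heads, total_value_depth // num_heads], matching the (query, key,
    # value) layout expected by AttentionOperation.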
memory_space = self._memory_attention(
(memory_query, memory_attention_key, memory_attention_value),
training=training)
memory_space = self._memory_conv3_bn(memory_space, training=training)
if memory_space_attention_drop_path_mask is not None:
memory_space = memory_space * memory_space_attention_drop_path_mask
memory_space_output = self._activation_fn(
memory_space_input + memory_space)
# Apply an optional feed-forward network to the memory space.
if self._feed_forward_network_channels > 0:
memory_space = self._memory_ffn_conv1_bn_act(memory_space_output,
training=training)
memory_space = self._memory_ffn_conv2_bn(memory_space,
training=training)
if memory_space_feed_forward_network_drop_path_mask is not None:
memory_space = (memory_space *
memory_space_feed_forward_network_drop_path_mask)
memory_space_output = self._activation_fn(
memory_space_output + memory_space)
# Compute pixel space attention and the output projection only when
# pixel2memory_feedback_attention is used.
if self._use_pixel2memory_feedback_attention:
pixel_space = self._pixel_attention(
(pixel_query, memory_key, memory_value), training=training)
pixel_space = self._pixel_conv3_bn(pixel_space, training=training)
if pixel_space_drop_path_mask is not None:
pixel_space = pixel_space * pixel_space_drop_path_mask
pixel_space_output = pixel_space_input + pixel_space
else:
# If pixel2memory_feedback_attention is not used, the pixel_space_input
# is not changed.
pixel_space_output = pixel_space_input
activated_pixel_space_output = self._activation_fn(pixel_space_output)
    # Return the pixel space output and memory space output. Note that we
    # return the pixel space output with and without the activation function,
    # because our decoder might use non-activated features.
return (pixel_space_output,
activated_pixel_space_output,
memory_space_output)
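

# A minimal usage sketch for DualPathTransformerLayer (the shapes and the float
# training flag below are illustrative assumptions, not values taken from a
# DeepLab2 config). The layer is called with a tuple of 3 tensors when drop
# path is not used, or 6 tensors when the three drop path masks are provided:
#   layer = DualPathTransformerLayer(filters=128, num_heads=8)
#   pixel_space_input = tf.zeros([2, 1024, 512])  # [batch, num_pixel, channels]
#   memory_space_input = tf.zeros([2, 128, 256])  # [batch, num_memory, channels]
#   float_tensor_training = tf.constant(0.0)  # 0.0 = inference, 1.0 = training
#   pixel_output, activated_pixel_output, memory_output = layer(
#       (pixel_space_input, memory_space_input, float_tensor_training))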