|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Implements dual path transformer layers proposed in MaX-DeepLab [1]. |
|
|
|
Dual-path transformer introduces a global memory path in addition to a CNN path, |
|
allowing bi-directional communication with any CNN layers. |
|
|
|
[1] MaX-DeepLab: End-to-End Panoptic Segmentation with Mask Transformers, |
|
CVPR 2021. |
|
Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen. |
|
""" |
|
|
|
import tensorflow as tf |
|
|
|
from deeplab2.model import utils |
|
from deeplab2.model.layers import activations |
|
from deeplab2.model.layers import convolutions |
|
|
|
|
|
class AttentionOperation(tf.keras.layers.Layer):
  """Computes standard 1D multi-head attention with query, key, and value."""

  def __init__(self,
               name,
               activation,
               transformer_activation,
               bn_layer=tf.keras.layers.BatchNormalization):
    """Initializes an AttentionOperation layer.

    Args:
      name: A string, the name of this layer.
      activation: A string, type of activation function applied to the
        retrieved value.
      transformer_activation: A string, type of activation function that
        converts similarity logits into attention weights. Support 'sigmoid'
        and 'softmax'.
      bn_layer: An optional tf.keras.layers.Layer that computes the
        normalization (default: tf.keras.layers.BatchNormalization).
    """
    super(AttentionOperation, self).__init__(name=name)
    # Batch norm over the head axis (axis=1) normalizes the similarity
    # logits, in place of the common 1/sqrt(d) attention scaling.
    self._batch_norm_similarity = bn_layer(axis=1, name='batch_norm_similarity')
    # Batch norm over both the head axis and the channel axis of the
    # retrieved value.
    self._batch_norm_retrieved_value = bn_layer(
        axis=[1, 3], name='batch_norm_retrieved_value')
    self._transformer_activation_fn = activations.get_activation(
        transformer_activation)
    self._activation_fn = activations.get_activation(activation)

  def call(self, inputs, training=False):
    """Performs an AttentionOperation.

    Args:
      inputs: A tuple of (query, key, value), where query is a [batch,
        num_head, query_length, channels] tensor, key is a [batch, num_head,
        key_length, channels] tensor, and value is a [batch, key_length,
        num_head, value_channels] tensor.
      training: A boolean, whether the model is in training mode.

    Returns:
      output: A [batch, query_length, num_head * value_channels] tensor, the
        retrieved value.
    """
    query, key, value = inputs

    # Per-head dot-product similarity between every query and every key,
    # normalized with batch norm before the transformer activation.
    logits = tf.einsum('bhld,bhmd->bhlm', query, key)
    logits = self._batch_norm_similarity(logits, training=training)
    weights = self._transformer_activation_fn(logits)

    # Aggregate the values with the attention weights, then normalize and
    # activate the aggregation.
    retrieved = tf.einsum('bhlm,bmhd->bhld', weights, value)
    retrieved = self._batch_norm_retrieved_value(retrieved, training=training)
    retrieved = self._activation_fn(retrieved)

    # Fold the head and channel axes back into a single channel axis.
    return utils.transpose_and_reshape_for_attention_operation(retrieved)
|
|
|
|
|
class DualPathTransformerLayer(tf.keras.layers.Layer):
  """Applies a dual path transformer layer, as proposed in MaX-DeepLab [1].

  Dual-path transformer layer takes a pixel space input and a memory space
  input, and performs memory2pixel attention, pixel2memory attention, and
  memory2memory self-attention. Note that the pixel2pixel self-attention or
  convolution in the pixel space is implemented in axial_layers.py and
  axial_blocks.py. Thus, the pixel2pixel operation is not included in this
  DualPathTransformerLayer implementation. Please use this class together with
  a residual block with axial-attention, global-attention, or convolution in
  order to construct the full dual path transformer in the paper.

  [1] MaX-DeepLab: End-to-End Panoptic Segmentation with Mask Transformers,
      CVPR 2021.
        Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen.
  """

  def __init__(self,
               name='dual_path_transformer_layer',
               activation='relu',
               filters=128,
               num_heads=8,
               bottleneck_expansion=2,
               key_expansion=1,
               value_expansion=2,
               feed_forward_network_channels=2048,
               use_memory_self_attention=True,
               use_pixel2memory_feedback_attention=True,
               transformer_activation='softmax',
               bn_layer=tf.keras.layers.BatchNormalization,
               conv_kernel_weight_decay=0.0):
    """Initializes a DualPathTransformerLayer.

    This function implements a dual path transformer layer between a pixel
    space and a memory space, as described in the MaX-DeepLab paper. In this
    dual path transformer, the memory2pixel cross attention and the memory
    self-attention share a single activation, e.g. softmax.

    Reference:
      MaX-DeepLab: "End-to-End Panoptic Segmentation with Mask Transformers",
        CVPR 2021. https://arxiv.org/abs/2012.00759
          Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen.

    Args:
      name: A string, the name of this dual path transformer layer.
      activation: A string, type of activation function to apply.
      filters: An integer, the base number of channels for the layer.
      num_heads: An integer, the number of heads in multi-head attention.
      bottleneck_expansion: A float, the channel expansion ratio for the
        bottleneck.
      key_expansion: A float, the channel expansion ratio for keys.
      value_expansion: A float, the channel expansion ratio for values.
      feed_forward_network_channels: An integer, the number of channels for
        the feed_forward_network. Zero means no feed_forward_network will be
        applied.
      use_memory_self_attention: A boolean, whether to apply the memory space
        self-attention.
      use_pixel2memory_feedback_attention: A boolean, whether to apply the
        pixel2memory feedback attention.
      transformer_activation: A string, type of activation function for
        self-attention. Support 'sigmoid' and 'softmax'.
      bn_layer: A tf.keras.layers.Layer that computes the normalization
        (default: tf.keras.layers.BatchNormalization).
      conv_kernel_weight_decay: A float, the weight decay for convolution
        kernels.

    Raises:
      ValueError: If filters * key_expansion is not divisible by num_heads.
      ValueError: If filters * value_expansion is not divisible by num_heads.
    """
    super(DualPathTransformerLayer, self).__init__(name=name)

    # Expand the base filter count into the bottleneck, key, and value
    # channel totals used by the projections below.
    bottleneck_channels = int(round(filters * bottleneck_expansion))
    total_key_depth = int(round(filters * key_expansion))
    total_value_depth = int(round(filters * value_expansion))

    if total_key_depth % num_heads:
      raise ValueError('Total_key_depth should be divisible by num_heads.')

    if total_value_depth % num_heads:
      raise ValueError('Total_value_depth should be divisible by num_heads.')

    # Stddev for the truncated-normal initializer of the qkv projections,
    # scaled as 1 / sqrt(bottleneck_channels).
    initialization_std = bottleneck_channels ** -0.5

    # Bottleneck projections (1x1 conv + BN + activation) for the memory and
    # pixel inputs.
    self._memory_conv1_bn_act = convolutions.Conv1D(
        bottleneck_channels, 'memory_conv1_bn_act',
        use_bias=False,
        use_bn=True,
        bn_layer=bn_layer,
        activation=activation,
        conv_kernel_weight_decay=conv_kernel_weight_decay)

    self._pixel_conv1_bn_act = convolutions.Conv1D(
        bottleneck_channels, 'pixel_conv1_bn_act',
        use_bias=False,
        use_bn=True,
        bn_layer=bn_layer,
        activation=activation,
        conv_kernel_weight_decay=conv_kernel_weight_decay)

    # Memory keys and values are only needed when the memory attends to
    # itself (memory2memory) or when pixels attend to the memory
    # (pixel2memory). Otherwise only the memory query is projected.
    if use_memory_self_attention or use_pixel2memory_feedback_attention:
      # Fused query/key/value projection, split later in call().
      self._memory_qkv_conv_bn = convolutions.Conv1D(
          total_key_depth * 2 + total_value_depth, 'memory_qkv_conv_bn',
          use_bias=False,
          use_bn=True,
          bn_layer=bn_layer,
          activation='none',
          conv_kernel_weight_decay=conv_kernel_weight_decay,
          kernel_initializer=tf.keras.initializers.TruncatedNormal(
              stddev=initialization_std))
    else:
      # Query-only projection for memory2pixel attention.
      self._memory_query_conv_bn = convolutions.Conv1D(
          total_key_depth, 'memory_query_conv_bn',
          use_bias=False,
          use_bn=True,
          bn_layer=bn_layer,
          activation='none',
          conv_kernel_weight_decay=conv_kernel_weight_decay,
          kernel_initializer=tf.keras.initializers.TruncatedNormal(
              stddev=initialization_std))

    # A pixel query is only needed when pixels attend back to the memory;
    # pixel keys and values are always needed for memory2pixel attention.
    if use_pixel2memory_feedback_attention:
      self._pixel_qkv_conv_bn = convolutions.Conv1D(
          total_key_depth * 2 + total_value_depth, 'pixel_qkv_conv_bn',
          use_bias=False,
          use_bn=True,
          bn_layer=bn_layer,
          activation='none',
          conv_kernel_weight_decay=conv_kernel_weight_decay,
          kernel_initializer=tf.keras.initializers.TruncatedNormal(
              stddev=initialization_std))
    else:
      self._pixel_kv_conv_bn = convolutions.Conv1D(
          total_key_depth + total_value_depth, 'pixel_kv_conv_bn',
          use_bias=False,
          use_bn=True,
          bn_layer=bn_layer,
          activation='none',
          conv_kernel_weight_decay=conv_kernel_weight_decay,
          kernel_initializer=tf.keras.initializers.TruncatedNormal(
              stddev=initialization_std))
    self._memory_attention = AttentionOperation(
        'memory_attention', activation, transformer_activation,
        bn_layer=bn_layer)
    if use_pixel2memory_feedback_attention:
      self._pixel_attention = AttentionOperation(
          'pixel_attention', activation, transformer_activation,
          bn_layer=bn_layer)

    self._use_memory_self_attention = use_memory_self_attention
    self._use_pixel2memory_feedback_attention = (
        use_pixel2memory_feedback_attention)
    self._total_key_depth = total_key_depth
    self._total_value_depth = total_value_depth
    self._num_heads = num_heads
    self._bn_layer = bn_layer
    self._conv_kernel_weight_decay = conv_kernel_weight_decay
    self._activation = activation
    self._activation_fn = activations.get_activation(activation)
    self._feed_forward_network_channels = feed_forward_network_channels

  def build(self, input_shape_list):
    # The output projections depend on the input channel counts, so they are
    # created here once the pixel and memory input shapes are known.
    pixel_shape, memory_shape = input_shape_list[:2]

    # Back-projection of the retrieved memory value to the memory input
    # channels. BN gamma is zero-initialized, so this residual branch
    # outputs zeros at initialization (the layer starts close to identity).
    self._memory_conv3_bn = convolutions.Conv1D(
        memory_shape[-1], 'memory_conv3_bn',
        use_bias=False,
        use_bn=True,
        bn_layer=self._bn_layer,
        bn_gamma_initializer='zeros',
        activation='none',
        conv_kernel_weight_decay=self._conv_kernel_weight_decay)

    if self._feed_forward_network_channels > 0:
      self._memory_ffn_conv1_bn_act = convolutions.Conv1D(
          self._feed_forward_network_channels, 'memory_ffn_conv1_bn_act',
          use_bias=False,
          use_bn=True,
          bn_layer=self._bn_layer,
          activation=self._activation,
          conv_kernel_weight_decay=self._conv_kernel_weight_decay)

      # Second FFN layer projects back to the memory input channels; gamma
      # is also zero-initialized for an identity residual branch at init.
      self._memory_ffn_conv2_bn = convolutions.Conv1D(
          memory_shape[-1], 'memory_ffn_conv2_bn',
          use_bias=False,
          use_bn=True,
          bn_layer=self._bn_layer,
          bn_gamma_initializer='zeros',
          activation='none',
          conv_kernel_weight_decay=self._conv_kernel_weight_decay)
    if self._use_pixel2memory_feedback_attention:
      # Back-projection of the retrieved pixel value to the pixel input
      # channels, likewise zero-initialized.
      self._pixel_conv3_bn = convolutions.Conv1D(
          pixel_shape[-1], 'pixel_conv3_bn',
          use_bias=False,
          use_bn=True,
          bn_layer=self._bn_layer,
          bn_gamma_initializer='zeros',
          activation='none',
          conv_kernel_weight_decay=self._conv_kernel_weight_decay)

  def call(self, inputs):
    """Performs a forward pass.

    We have to define drop_path_masks outside the layer call and pass it into
    the layer call, because recompute_grad (gradient checkpointing) does not
    allow any randomness within the function call. In addition, recompute_grad
    only supports float tensors as inputs. For this reason, the training flag
    should be also passed as a float tensor. For the same reason, we cannot
    support passing drop_path_random_mask as None. Instead, we ask the users
    to pass only the first two tensors when drop path is not used.

    Args:
      inputs: A tuple of 3 or 6 tensors, containing
        pixel_space_input should be a [batch, num_pixel, pixel_space_channels]
          tensor.
        memory_space_input should be a [batch, num_memory,
          memory_space_channels] tensor.
        float_tensor_training should be a float tensor of 0.0 or 1.0, whether
          the model is in training mode.
        (optional) pixel_space_drop_path_mask is a drop path mask tensor of
          shape [batch, 1, 1] for the pixel space.
        (optional) memory_space_attention_drop_path_mask is a drop path mask
          tensor of shape [batch, 1, 1] for the memory space.
        (optional) memory_space_feed_forward_network_drop_path_mask is a drop
          path mask tensor of shape [batch, 1, 1] for the memory space feed
          forward network.

    Returns:
      pixel_space_output: A [batch, num_pixel, pixel_space_channels] tensor.
      activated_pixel_space_output: A [batch, num_pixel, pixel_space_channels]
        tensor, activated pixel_space_output.
      memory_space_output: A [batch, num_memory, memory_space_channels]
        tensor.

    Raises:
      ValueError: If the length of inputs is not 3 or 6.
    """
    if len(inputs) not in (3, 6):
      raise ValueError('The length of inputs should be either 3 or 6.')

    # Pad the optional drop path masks with None so the tuple always unpacks
    # into six elements.
    (pixel_space_input, memory_space_input, float_tensor_training,
     pixel_space_drop_path_mask, memory_space_attention_drop_path_mask,
     memory_space_feed_forward_network_drop_path_mask) = (
         utils.pad_sequence_with_none(inputs, target_length=6))

    # recompute_grad only supports float tensor inputs, so the training flag
    # arrives as a float tensor and is cast back to a boolean here.
    training = tf.cast(float_tensor_training, tf.bool)

    pixel_shape = pixel_space_input.get_shape().as_list()
    memory_shape = memory_space_input.get_shape().as_list()

    # Project both spaces into the shared bottleneck dimensionality.
    memory_space = self._memory_conv1_bn_act(memory_space_input,
                                             training=training)

    # The pixel input is activated before the projection; the pixel residual
    # output below is returned pre-activation.
    pixel_space = self._pixel_conv1_bn_act(
        self._activation_fn(pixel_space_input), training=training)

    if (self._use_memory_self_attention or
        self._use_pixel2memory_feedback_attention):
      memory_space_qkv = self._memory_qkv_conv_bn(memory_space,
                                                  training=training)

      # Split the fused projection into query, key, and value, and reshape
      # key/value into the multi-head layouts AttentionOperation expects:
      # key [batch, heads, length, d], value [batch, length, heads, d].
      memory_query, memory_key, memory_value = (
          tf.split(memory_space_qkv, [
              self._total_key_depth, self._total_key_depth,
              self._total_value_depth], axis=-1))
      memory_key = utils.reshape_and_transpose_for_attention_operation(
          memory_key, self._num_heads)
      memory_value = tf.reshape(memory_value, [
          -1, memory_shape[1], self._num_heads,
          self._total_value_depth // self._num_heads])
    else:
      # Only memory2pixel attention is used, so only the query is computed.
      memory_query = self._memory_query_conv_bn(memory_space,
                                                training=training)

    # The memory query is reshaped the same way on both branches.
    memory_query = utils.reshape_and_transpose_for_attention_operation(
        memory_query, self._num_heads)

    if self._use_pixel2memory_feedback_attention:
      pixel_space_qkv = self._pixel_qkv_conv_bn(pixel_space,
                                                training=training)

      # Split into query, key, and value; the pixel query feeds the
      # pixel2memory attention below.
      pixel_query, pixel_key, pixel_value = tf.split(
          pixel_space_qkv, [
              self._total_key_depth, self._total_key_depth,
              self._total_value_depth], axis=-1)
      pixel_query = utils.reshape_and_transpose_for_attention_operation(
          pixel_query, self._num_heads)
    else:
      pixel_space_kv = self._pixel_kv_conv_bn(pixel_space, training=training)

      pixel_key, pixel_value = tf.split(pixel_space_kv, [
          self._total_key_depth, self._total_value_depth], axis=-1)

    # Pixel key/value reshaping is shared by both branches.
    pixel_key = utils.reshape_and_transpose_for_attention_operation(
        pixel_key, self._num_heads)
    pixel_value = tf.reshape(pixel_value, [
        -1, pixel_shape[1], self._num_heads,
        self._total_value_depth // self._num_heads])

    if not self._use_memory_self_attention:
      # Memory queries attend to the pixel space only (memory2pixel).
      memory_attention_key = pixel_key
      memory_attention_value = pixel_value
    else:
      # Concatenate pixel and memory keys/values along the length axis
      # (axis=2 for keys [batch, heads, length, d]; axis=1 for values
      # [batch, length, heads, d]) so one attention computes memory2pixel
      # cross-attention and memory2memory self-attention with a single
      # shared transformer activation.
      memory_attention_key = tf.concat([pixel_key, memory_key], axis=2)
      memory_attention_value = tf.concat([pixel_value, memory_value], axis=1)

    # Memory path: attention, back-projection, optional drop path, and a
    # residual connection followed by the activation.
    memory_space = self._memory_attention(
        (memory_query, memory_attention_key, memory_attention_value),
        training=training)
    memory_space = self._memory_conv3_bn(memory_space, training=training)

    if memory_space_attention_drop_path_mask is not None:
      memory_space = memory_space * memory_space_attention_drop_path_mask
    memory_space_output = self._activation_fn(
        memory_space_input + memory_space)

    # Optional feed-forward network on the memory path with its own
    # residual connection and optional drop path.
    if self._feed_forward_network_channels > 0:
      memory_space = self._memory_ffn_conv1_bn_act(memory_space_output,
                                                   training=training)
      memory_space = self._memory_ffn_conv2_bn(memory_space,
                                               training=training)
      if memory_space_feed_forward_network_drop_path_mask is not None:
        memory_space = (memory_space *
                        memory_space_feed_forward_network_drop_path_mask)
      memory_space_output = self._activation_fn(
          memory_space_output + memory_space)

    # Pixel path: pixels attend to the memory keys/values (pixel2memory
    # feedback), with a pre-activation residual output.
    if self._use_pixel2memory_feedback_attention:
      pixel_space = self._pixel_attention(
          (pixel_query, memory_key, memory_value), training=training)
      pixel_space = self._pixel_conv3_bn(pixel_space, training=training)
      if pixel_space_drop_path_mask is not None:
        pixel_space = pixel_space * pixel_space_drop_path_mask
      pixel_space_output = pixel_space_input + pixel_space
    else:
      # No feedback attention: the pixel input passes through unchanged.
      pixel_space_output = pixel_space_input
    activated_pixel_space_output = self._activation_fn(pixel_space_output)

    # The pixel output is returned both pre-activation and activated so a
    # downstream residual block can consume either form.
    return (pixel_space_output,
            activated_pixel_space_output,
            memory_space_output)
|
|