# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implements dual path transformer layers proposed in MaX-DeepLab [1].

Dual-path transformer introduces a global memory path in addition to a CNN path,
allowing bi-directional communication with any CNN layer.

[1] MaX-DeepLab: End-to-End Panoptic Segmentation with Mask Transformers,
    CVPR 2021.
      Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen.
"""

import tensorflow as tf

from deeplab2.model import utils
from deeplab2.model.layers import activations
from deeplab2.model.layers import convolutions


class AttentionOperation(tf.keras.layers.Layer):
  """Computes standard 1D multi-head attention with query, key, and value."""

  def __init__(self,
               name,
               activation,
               transformer_activation,
               bn_layer=tf.keras.layers.BatchNormalization):
    """Initializes an AttentionOperation layer.

    Args:
      name: A string, the name of this layer.
      activation: A string, type of activation function to apply.
      transformer_activation: A string, type of activation function for
        self-attention. Supported values are 'sigmoid' and 'softmax'.
      bn_layer: An optional tf.keras.layers.Layer that computes the
        normalization (default: tf.keras.layers.BatchNormalization).
    """
    super(AttentionOperation, self).__init__(name=name)
    # batch_norm_similarity has shape [batch, num_heads, num_query, num_key],
    # where num_query and num_key usually equal the height, width, or length,
    # i.e., the spatial dimensions, so batch norm is applied to axis=1 only.
    self._batch_norm_similarity = bn_layer(axis=1, name='batch_norm_similarity')
    # batch_norm_retrieved_value is done on shape [batch, num_heads, length,
    # value_channels], which will be reshaped to the output shape [batch,
    # length, value_channels * num_heads], so we apply batch norm on the
    # effective channel dimension -- value_channels * num_heads.
    self._batch_norm_retrieved_value = bn_layer(
        axis=[1, 3], name='batch_norm_retrieved_value')
    self._activation_fn = activations.get_activation(activation)
    self._transformer_activation_fn = activations.get_activation(
        transformer_activation)

  def call(self, inputs, training=False):
    """Performs an AttentionOperation.

    Args:
      inputs: A tuple of (query, key, value), where query is a [batch,
        num_head, query_length, channels] tensor, key is a [batch, num_head,
        key_length, channels] tensor, and value is a [batch, key_length,
        num_head, value_channels] tensor.
      training: A boolean, whether the model is in training mode.

    Returns:
      output: A [batch, query_length, num_head * value_channels] tensor, the
        retrieved value.
    """
    # Decode query, key, and value from inputs.
    query, key, value = inputs
    # Compute attention similarity.
    similarity_logits = tf.einsum('bhld,bhmd->bhlm', query, key)
    similarity_logits = self._batch_norm_similarity(
        similarity_logits, training=training)
    # Apply a transformer attention activation function, e.g. softmax.
    attention_weights = self._transformer_activation_fn(similarity_logits)
    # Retrieve the value content.
    retrieved_value = tf.einsum(
        'bhlm,bmhd->bhld', attention_weights, value)
    retrieved_value = self._batch_norm_retrieved_value(
        retrieved_value, training=training)
    retrieved_value = self._activation_fn(retrieved_value)
    # Reshape the output.
    return utils.transpose_and_reshape_for_attention_operation(
        retrieved_value)
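
# A shape walk-through of AttentionOperation.call (an illustrative sketch; the
# concrete sizes below are assumptions chosen only to make the einsum
# contractions concrete, not values used elsewhere in this file):
#   query:  [batch=2, num_head=8, query_length=128, channels=16]
#   key:    [batch=2, num_head=8, key_length=1024, channels=16]
#   value:  [batch=2, key_length=1024, num_head=8, value_channels=32]
#   similarity_logits ('bhld,bhmd->bhlm'): [2, 8, 128, 1024]
#   retrieved_value   ('bhlm,bmhd->bhld'): [2, 8, 128, 32]
#   output (transpose and reshape):        [2, 128, 8 * 32] = [2, 128, 256]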


class DualPathTransformerLayer(tf.keras.layers.Layer):
  """Applies a dual path transformer layer, as proposed in MaX-DeepLab [1].

  Dual-path transformer layer takes a pixel space input and a memory space
  input, and performs memory2pixel attention, pixel2memory attention, and
  memory2memory self-attention. Note that the pixel2pixel self-attention or
  convolution in the pixel space is implemented in axial_layers.py and
  axial_blocks.py. Thus, the pixel2pixel operation is not included in this
  DualPathTransformerLayer implementation. Please use this class together with
  a residual block with axial-attention, global-attention, or convolution in
  order to construct the full dual path transformer in the paper.

  [1] MaX-DeepLab: End-to-End Panoptic Segmentation with Mask Transformers,
      CVPR 2021.
        Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen.
  """

  def __init__(self,
               name='dual_path_transformer_layer',
               activation='relu',
               filters=128,
               num_heads=8,
               bottleneck_expansion=2,
               key_expansion=1,
               value_expansion=2,
               feed_forward_network_channels=2048,
               use_memory_self_attention=True,
               use_pixel2memory_feedback_attention=True,
               transformer_activation='softmax',
               bn_layer=tf.keras.layers.BatchNormalization,
               conv_kernel_weight_decay=0.0):
    """Initializes a DualPathTransformerLayer.

    This function implements a dual path transformer layer between a pixel space
    and a memory space, as described in the MaX-DeepLab paper. In this dual path
    transformer, the memory2pixel cross attention and the memory self-attention
    share a single activation, e.g. softmax.

    Reference:
      MaX-DeepLab: "End-to-End Panoptic Segmentation with Mask Transformers",
        CVPR 2021. https://arxiv.org/abs/2012.00759
          Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen.

    Args:
      name: A string, the name of this dual path transformer layer.
      activation: A string, type of activation function to apply.
      filters: An integer, the base number of channels for the layer.
      num_heads: An integer, the number of heads in multi-head attention.
      bottleneck_expansion: A float, the channel expansion ratio for the
        bottleneck.
      key_expansion: A float, the channel expansion ratio for keys.
      value_expansion: A float, the channel expansion ratio for values.
      feed_forward_network_channels: An integer, the number of channels for the
        feed_forward_network. Zero means no feed_forward_network will be
        applied.
      use_memory_self_attention: A boolean, whether to apply the memory space
        self-attention.
      use_pixel2memory_feedback_attention: A boolean, whether to apply the
        pixel2memory feedback attention.
      transformer_activation: A string, type of activation function for
        self-attention. Supported values are 'sigmoid' and 'softmax'.
      bn_layer: A tf.keras.layers.Layer that computes the normalization
        (default: tf.keras.layers.BatchNormalization).
      conv_kernel_weight_decay: A float, the weight decay for convolution
        kernels.

    Raises:
      ValueError: If filters * key_expansion is not divisible by num_heads.
      ValueError: If filters * value_expansion is not divisible by num_heads.
    """
    super(DualPathTransformerLayer, self).__init__(name=name)

    bottleneck_channels = int(round(filters * bottleneck_expansion))
    total_key_depth = int(round(filters * key_expansion))
    total_value_depth = int(round(filters * value_expansion))

    if total_key_depth % num_heads:
      raise ValueError('Total_key_depth should be divisible by num_heads.')

    if total_value_depth % num_heads:
      raise ValueError('Total_value_depth should be divisible by num_heads.')

    # Compute query key value with one convolution and a batch norm layer. The
    # initialization std is standard transformer initialization (without batch
    # norm), as used in SASA and ViT. In our case, we use batch norm by default,
    # so it does not require careful tuning. If one wants to remove all batch
    # norms in axial attention, this standard initialization should still be
    # good, but a more careful initialization is encouraged.
    initialization_std = bottleneck_channels ** -0.5

    self._memory_conv1_bn_act = convolutions.Conv1D(
        bottleneck_channels, 'memory_conv1_bn_act',
        use_bias=False,
        use_bn=True,
        bn_layer=bn_layer,
        activation=activation,
        conv_kernel_weight_decay=conv_kernel_weight_decay)

    self._pixel_conv1_bn_act = convolutions.Conv1D(
        bottleneck_channels, 'pixel_conv1_bn_act',
        use_bias=False,
        use_bn=True,
        bn_layer=bn_layer,
        activation=activation,
        conv_kernel_weight_decay=conv_kernel_weight_decay)

    # We always compute the query for memory space, since it gathers information
    # from the pixel space and thus cannot be removed. We compute the key and
    # value for memory space only when they are necessary (i.e. either
    # use_memory_self_attention or use_pixel2memory_feedback_attention).
    if use_memory_self_attention or use_pixel2memory_feedback_attention:
      self._memory_qkv_conv_bn = convolutions.Conv1D(
          total_key_depth * 2 + total_value_depth, 'memory_qkv_conv_bn',
          use_bias=False,
          use_bn=True,
          bn_layer=bn_layer,
          activation='none',
          conv_kernel_weight_decay=conv_kernel_weight_decay,
          kernel_initializer=tf.keras.initializers.TruncatedNormal(
              stddev=initialization_std))
    else:
      # Compute memory query only if memory key and value are not used.
      self._memory_query_conv_bn = convolutions.Conv1D(
          total_key_depth, 'memory_query_conv_bn',
          use_bias=False,
          use_bn=True,
          bn_layer=bn_layer,
          activation='none',
          conv_kernel_weight_decay=conv_kernel_weight_decay,
          kernel_initializer=tf.keras.initializers.TruncatedNormal(
              stddev=initialization_std))

    # For the pixel space, we always compute the key and value, since they
    # provide information for the memory space and thus cannot be removed. We
    # compute the query for pixel space only when it is necessary (i.e.
    # use_pixel2memory_feedback_attention is True).
    if use_pixel2memory_feedback_attention:
      self._pixel_qkv_conv_bn = convolutions.Conv1D(
          total_key_depth * 2 + total_value_depth, 'pixel_qkv_conv_bn',
          use_bias=False,
          use_bn=True,
          bn_layer=bn_layer,
          activation='none',
          conv_kernel_weight_decay=conv_kernel_weight_decay,
          kernel_initializer=tf.keras.initializers.TruncatedNormal(
              stddev=initialization_std))
    else:
      self._pixel_kv_conv_bn = convolutions.Conv1D(
          total_key_depth + total_value_depth, 'pixel_kv_conv_bn',
          use_bias=False,
          use_bn=True,
          bn_layer=bn_layer,
          activation='none',
          conv_kernel_weight_decay=conv_kernel_weight_decay,
          kernel_initializer=tf.keras.initializers.TruncatedNormal(
              stddev=initialization_std))
    self._memory_attention = AttentionOperation(
        'memory_attention', activation, transformer_activation,
        bn_layer=bn_layer)
    if use_pixel2memory_feedback_attention:
      self._pixel_attention = AttentionOperation(
          'pixel_attention', activation, transformer_activation,
          bn_layer=bn_layer)

    self._use_memory_self_attention = use_memory_self_attention
    self._use_pixel2memory_feedback_attention = (
        use_pixel2memory_feedback_attention)
    self._total_key_depth = total_key_depth
    self._total_value_depth = total_value_depth
    self._num_heads = num_heads
    self._bn_layer = bn_layer
    self._conv_kernel_weight_decay = conv_kernel_weight_decay
    self._activation = activation
    self._activation_fn = activations.get_activation(activation)
    self._feed_forward_network_channels = feed_forward_network_channels

  def build(self, input_shape_list):
    pixel_shape, memory_shape = input_shape_list[:2]
    # Here we follow ResNet bottleneck blocks: we apply a batch norm with gamma
    # initialized at zero, followed by drop path and an activation function.
    # Initializing this gamma at zero ensures that at random initialization of
    # the model, the skip connections dominate all residual blocks. In this way,
    # all the skip connections construct an identity mapping that passes the
    # gradients (without any distortion from the randomly initialized blocks) to
    # all residual blocks. This helps training at early epochs.
    # Reference: "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour".
    # https://arxiv.org/abs/1706.02677
    self._memory_conv3_bn = convolutions.Conv1D(
        memory_shape[-1], 'memory_conv3_bn',
        use_bias=False,
        use_bn=True,
        bn_layer=self._bn_layer,
        bn_gamma_initializer='zeros',
        activation='none',
        conv_kernel_weight_decay=self._conv_kernel_weight_decay)

    if self._feed_forward_network_channels > 0:
      self._memory_ffn_conv1_bn_act = convolutions.Conv1D(
          self._feed_forward_network_channels, 'memory_ffn_conv1_bn_act',
          use_bias=False,
          use_bn=True,
          bn_layer=self._bn_layer,
          activation=self._activation,
          conv_kernel_weight_decay=self._conv_kernel_weight_decay)
      # Again, we follow ResNet bottleneck blocks: we apply a batch norm with
      # gamma initialized at zero, followed by drop path and an activation
      # function.
      self._memory_ffn_conv2_bn = convolutions.Conv1D(
          memory_shape[-1], 'memory_ffn_conv2_bn',
          use_bias=False,
          use_bn=True,
          bn_layer=self._bn_layer,
          bn_gamma_initializer='zeros',
          activation='none',
          conv_kernel_weight_decay=self._conv_kernel_weight_decay)
    if self._use_pixel2memory_feedback_attention:
      self._pixel_conv3_bn = convolutions.Conv1D(
          pixel_shape[-1], 'pixel_conv3_bn',
          use_bias=False,
          use_bn=True,
          bn_layer=self._bn_layer,
          bn_gamma_initializer='zeros',
          activation='none',
          conv_kernel_weight_decay=self._conv_kernel_weight_decay)

  def call(self, inputs):
    """Performs a forward pass.

    We have to define drop_path_masks outside the layer call and pass it into
    the layer call, because recompute_grad (gradient checkpointing) does not
    allow any randomness within the function call. In addition, recompute_grad
    only supports float tensors as inputs. For this reason, the training flag
    should be also passed as a float tensor. For the same reason, we cannot
    support passing drop_path_random_mask as None. Instead, we ask the users to
    pass only the first two tensors when drop path is not used.

    Args:
      inputs: A tuple of 3 or 6 tensors, containing
        pixel_space_input should be a [batch, num_pixel, pixel_space_channels]
          tensor.
        memory_space_input should be a [batch, num_memory,
          memory_space_channels] tensor.
        float_tensor_training should be a float tensor of 0.0 or 1.0, whether
          the model is in training mode.
        (optional) pixel_space_drop_path_mask is a drop path mask tensor of
          shape [batch, 1, 1] for the pixel space.
        (optional) memory_space_attention_drop_path_mask is a drop path mask
          tensor of shape [batch, 1, 1] for the memory space.
        (optional) memory_space_feed_forward_network_drop_path_mask is a drop
          path mask tensor of shape [batch, 1, 1] for the memory space feed
          forward network.

    Returns:
      pixel_space_output: A [batch, num_pixel, pixel_space_channels] tensor.
      activated_pixel_space_output: A [batch, num_pixel, pixel_space_channels]
        tensor, activated pixel_space_output.
      memory_space_output: A [batch, num_memory, memory_space_channels]
        tensor.

    Raises:
      ValueError: If the length of inputs is not 3 or 6.
    """
    if len(inputs) not in (3, 6):
      raise ValueError('The length of inputs should be either 3 or 6.')

    # Unpack the inputs.
    (pixel_space_input, memory_space_input, float_tensor_training,
     pixel_space_drop_path_mask, memory_space_attention_drop_path_mask,
     memory_space_feed_forward_network_drop_path_mask) = (
         utils.pad_sequence_with_none(inputs, target_length=6))

    # Recompute_grad takes only float tensors as inputs. It does not allow
    # bools or boolean tensors. For this reason, we cast training to a float
    # tensor outside this call, and now we cast it back to a boolean tensor.
    training = tf.cast(float_tensor_training, tf.bool)

    # Decode the inputs shapes.
    pixel_shape = pixel_space_input.get_shape().as_list()
    memory_shape = memory_space_input.get_shape().as_list()

    # Similar to the ResNet bottleneck design, we do an input down projection
    # in both the pixel space and the memory space.
    memory_space = self._memory_conv1_bn_act(memory_space_input,
                                             training=training)

    # Pixel space input is not activated.
    pixel_space = self._pixel_conv1_bn_act(
        self._activation_fn(pixel_space_input), training=training)

    if (self._use_memory_self_attention or
        self._use_pixel2memory_feedback_attention):
      memory_space_qkv = self._memory_qkv_conv_bn(memory_space,
                                                  training=training)
      # Split, reshape, and transpose the query, key, and value.
      memory_query, memory_key, memory_value = (
          tf.split(memory_space_qkv, [
              self._total_key_depth, self._total_key_depth,
              self._total_value_depth], axis=-1))
      memory_key = utils.reshape_and_transpose_for_attention_operation(
          memory_key, self._num_heads)
      memory_value = tf.reshape(memory_value, [
          -1, memory_shape[1], self._num_heads,
          self._total_value_depth // self._num_heads])
    else:
      # Compute memory query only if memory key and value are not used.
      memory_query = self._memory_query_conv_bn(memory_space,
                                                training=training)
    # Reshape and transpose the query.
    memory_query = utils.reshape_and_transpose_for_attention_operation(
        memory_query, self._num_heads)

    if self._use_pixel2memory_feedback_attention:
      pixel_space_qkv = self._pixel_qkv_conv_bn(pixel_space,
                                                training=training)
      # Split the query, key, and value.
      pixel_query, pixel_key, pixel_value = tf.split(
          pixel_space_qkv, [
              self._total_key_depth, self._total_key_depth,
              self._total_value_depth], axis=-1)
      pixel_query = utils.reshape_and_transpose_for_attention_operation(
          pixel_query, self._num_heads)
    else:
      pixel_space_kv = self._pixel_kv_conv_bn(pixel_space, training=training)
      # Split the key and the value.
      pixel_key, pixel_value = tf.split(pixel_space_kv, [
          self._total_key_depth, self._total_value_depth], axis=-1)
    # Reshape and transpose the key and the value.
    pixel_key = utils.reshape_and_transpose_for_attention_operation(
        pixel_key, self._num_heads)
    pixel_value = tf.reshape(pixel_value, [
        -1, pixel_shape[1], self._num_heads,
        self._total_value_depth // self._num_heads])

    # Compute memory space attention.
    if not self._use_memory_self_attention:
      # If memory self attention is not used, then only memory2pixel cross
      # attention is used for the memory space. In this case, the key and the
      # value are simply pixel_key and pixel_value.
      memory_attention_key = pixel_key
      memory_attention_value = pixel_value
    else:
      # If we also use memory self attention, the key and the value are the
      # concatenation of keys and values in both the pixel space and the
      # memory space.
      memory_attention_key = tf.concat([pixel_key, memory_key], axis=2)
      memory_attention_value = tf.concat([pixel_value, memory_value], axis=1)

    memory_space = self._memory_attention(
        (memory_query, memory_attention_key, memory_attention_value),
        training=training)
    memory_space = self._memory_conv3_bn(memory_space, training=training)

    if memory_space_attention_drop_path_mask is not None:
      memory_space = memory_space * memory_space_attention_drop_path_mask
    memory_space_output = self._activation_fn(
        memory_space_input + memory_space)

    # Apply an optional feed-forward network to the memory space.
    if self._feed_forward_network_channels > 0:
      memory_space = self._memory_ffn_conv1_bn_act(memory_space_output,
                                                   training=training)
      memory_space = self._memory_ffn_conv2_bn(memory_space,
                                               training=training)
      if memory_space_feed_forward_network_drop_path_mask is not None:
        memory_space = (memory_space *
                        memory_space_feed_forward_network_drop_path_mask)
      memory_space_output = self._activation_fn(
          memory_space_output + memory_space)

    # Compute pixel space attention and the output projection only when
    # pixel2memory_feedback_attention is used.
    if self._use_pixel2memory_feedback_attention:
      pixel_space = self._pixel_attention(
          (pixel_query, memory_key, memory_value), training=training)
      pixel_space = self._pixel_conv3_bn(pixel_space, training=training)
      if pixel_space_drop_path_mask is not None:
        pixel_space = pixel_space * pixel_space_drop_path_mask
      pixel_space_output = pixel_space_input + pixel_space
    else:
      # If pixel2memory_feedback_attention is not used, the pixel_space_input
      # is not changed.
      pixel_space_output = pixel_space_input
    activated_pixel_space_output = self._activation_fn(pixel_space_output)

    # Return the pixel space output and the memory space output. Note that we
    # return the pixel space output both with and without the activation
    # function, because our decoder might use non-activated features.
    return (pixel_space_output,
            activated_pixel_space_output,
            memory_space_output)
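

# A minimal usage sketch (illustrative only; the batch size, sequence lengths,
# and channel sizes below are assumptions rather than defaults of this module).
# The layer is called with a tuple of three tensors, or six tensors when the
# optional drop path masks are used (see DualPathTransformerLayer.call):
#
#   pixel_space = tf.zeros([2, 1024, 128])   # [batch, num_pixel, channels]
#   memory_space = tf.zeros([2, 128, 256])   # [batch, num_memory, channels]
#   training_flag = tf.constant(1.0)         # 1.0 = training, 0.0 = inference
#   layer = DualPathTransformerLayer(filters=128, num_heads=8)
#   pixel_out, activated_pixel_out, memory_out = layer(
#       (pixel_space, memory_space, training_flag))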