# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implements relative [1, 2, 3] and global [3, 4] positional encodings.

Our Axial-Deeplab [1] proposes position-sensitive self-attention which uses
relative positional encodings for query, key, and value.
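
As a rough illustration of what "position-sensitive" means, the sketch below
computes 1D self-attention over plain Python lists, adding a relative encoding
term for the query, the key, and the value, following the formulation in [1].
The function name and the dict-based encoding containers are hypothetical and
exist only for this example; the layers in this file provide the encodings, not
the attention itself.

```python
import math


def position_sensitive_attention(queries, keys, values, rel_q, rel_k, rel_v):
  # queries, keys, values: lists of length n, each entry a list of floats.
  # rel_q, rel_k, rel_v: hypothetical dicts mapping a relative offset (j - i)
  # to an encoding vector of matching depth.
  def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

  outputs = []
  for i, query in enumerate(queries):
    # Content term plus query-dependent and key-dependent positional terms.
    logits = [
        dot(query, key) + dot(query, rel_q[j - i]) + dot(key, rel_k[j - i])
        for j, key in enumerate(keys)
    ]
    # Numerically stable softmax over the keys.
    peak = max(logits)
    weights = [math.exp(logit - peak) for logit in logits]
    total = sum(weights)
    output = [0.0] * len(values[0])
    for j, weight in enumerate(weights):
      for c in range(len(output)):
        # The value also receives a relative positional encoding.
        output[c] += weight / total * (values[j][c] + rel_v[j - i][c])
    outputs.append(output)
  return outputs


zero = {offset: [0.0] for offset in (-1, 0, 1)}
outputs = position_sensitive_attention(
    queries=[[0.0], [0.0]], keys=[[1.0], [2.0]], values=[[2.0], [4.0]],
    rel_q=zero, rel_k=zero, rel_v=zero)
# With all-zero queries and encodings the attention is uniform, so each output
# is the mean of the values.
```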

[1] Axial-Deeplab: Stand-Alone Axial-Attention for Panoptic Segmentation,
    ECCV 2020 Spotlight.
      Huiyu Wang, Yukun Zhu, Bradley Green, Hartwig Adam, Alan Yuille,
      Liang-Chieh Chen.
[2] Self-Attention with Relative Position Representations, NAACL 2018.
      Peter Shaw, Jakob Uszkoreit, Ashish Vaswani.
[3] Tensor2Tensor for Neural Machine Translation, arXiv 2018,
    http://arxiv.org/abs/1803.07416.
      Ashish Vaswani, Samy Bengio, Eugene Brevdo, Francois Chollet,
      Aidan N. Gomez, Stephan Gouws, Llion Jones, Łukasz Kaiser,
      Nal Kalchbrenner, Niki Parmar, Ryan Sepassi, Noam Shazeer,
      Jakob Uszkoreit.
[4] Attention Is All You Need, NeurIPS 2017.
      Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
      Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin.
[5] An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale,
    ICLR 2021.
      Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn,
      Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer,
      Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
"""

import tensorflow as tf

# MAX_SPAN determines the size of the relative positional embedding table,
# which has shape (MAX_SPAN * 2 - 1, depth). It is set to a large constant so
# that models with global attention or with different local spans can be loaded
# and used interchangeably, but it should stay small enough to keep memory use
# reasonable. The value 255 exceeds almost all span choices (e.g., 65 for local
# attention, 129, 193), so it is large enough. 257 would be a good choice on
# GPU, but 255 is more efficient on TPU, which pads tensor dimensions to
# multiples of 128.
MAX_SPAN = 255


def _compute_relative_distance_matrix(query_length, key_length):
  """Computes a relative distance matrix between queries and keys.

  We assume that the queries and the keys are centered, i.e.,
  key_length = memory_flange + query_length + memory_flange.

  The function is based on the _generate_relative_positions_matrix function in
  common_attention.py of tensor2tensor codebase:
  https://github.com/tensorflow/tensor2tensor/blob/5623deb79cfcd28f8f8c5463b58b5bd76a81fd0d/tensor2tensor/layers/common_attention.py#L1670
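
  For intuition, the index arithmetic can be reproduced with plain Python
  lists. This is only a sketch: the helper name is hypothetical, and the
  clipping applied to spans larger than MAX_SPAN is omitted.

  ```python
  MAX_SPAN = 255  # Mirrors the module-level constant.

  def relative_distance_matrix(query_length, key_length):
    # Queries are centered within the keys, so query q sits at key position
    # q + memory_flange.
    memory_flange = (key_length - query_length) // 2
    return [[k - (q + memory_flange) + MAX_SPAN - 1
             for k in range(key_length)]
            for q in range(query_length)]

  matrix = relative_distance_matrix(query_length=2, key_length=4)
  # matrix == [[253, 254, 255, 256],
  #            [252, 253, 254, 255]]
  ```

  An entry of MAX_SPAN - 1 = 254 marks a key aligned with its query.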

  Args:
    query_length: An integer, the length of queries.
    key_length: An integer, the length of keys.

  Returns:
    distance_matrix: A [query_length, key_length] tensor.

  Raises:
    ValueError: If (key_length - query_length) is odd, i.e., the assumption does
      not hold.
  """
  if (key_length - query_length) % 2:
    raise ValueError('Key_length should be query_length + 2 * memory_flange.')
  key_index = tf.range(key_length)
  query_index = tf.range(query_length) + (key_length - query_length) // 2
  distance_matrix = key_index[None, :] - query_index[:, None]
  # Shift the distance_matrix so that it is >= 0. Each entry of the
  # distance_matrix indexes a relative positional embedding.
  distance_matrix = distance_matrix + MAX_SPAN - 1
  if query_length + (key_length - query_length) // 2 > MAX_SPAN:
    # tf.logging was removed from the TF2 namespace; use the compat.v1 logger.
    tf.compat.v1.logging.warning(
        'Axial attention span is larger than MAX_SPAN. In this case, we use a '
        'single shared embedding for all positions beyond this relative '
        'distance. Please make sure this behavior is intended.')
    distance_matrix = tf.clip_by_value(distance_matrix, 0, MAX_SPAN * 2 - 2)
  return distance_matrix


class RelativePositionalEncoding(tf.keras.layers.Layer):
  """Generates relative positional encoding.

  The function is based on the _generate_relative_positions_embeddings function
  in common_attention.py of tensor2tensor codebase:
  https://github.com/tensorflow/tensor2tensor/blob/5623deb79cfcd28f8f8c5463b58b5bd76a81fd0d/tensor2tensor/layers/common_attention.py#L1691
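
  The lookup performed by this layer can be sketched without TensorFlow. The
  example below assumes a hypothetical small span (max_span = 3) and a made-up
  embedding table whose rows simply store their own index, so that the gathered
  result is easy to read. The real table is a trainable weight of shape
  (MAX_SPAN * 2 - 1, depth).

  ```python
  max_span = 3  # Hypothetical small stand-in for MAX_SPAN.
  depth, num_heads = 2, 2
  # One row per relative offset in [-(max_span - 1), max_span - 1].
  table = [[float(row)] * depth for row in range(2 * max_span - 1)]
  # Shifted distances for query_length=2, key_length=2 (no memory flange).
  distances = [[k - q + max_span - 1 for k in range(2)] for q in range(2)]
  # distances == [[2, 3], [1, 2]]
  # Gather one embedding per query-key pair, then repeat it for every head.
  gathered = [[table[index] for index in row] for row in distances]
  tiled = [gathered for _ in range(num_heads)]
  ```

  All heads receive the same encodings, because the table is gathered once and
  then tiled along the head axis.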
  """

  def __init__(self, query_length, key_length, depth, num_heads, name,
               initialization_std=1.0, conv_kernel_weight_decay=0.0):
    """Initializes a relative position encoding layer.

    Args:
      query_length: An integer, the length of queries.
      key_length: An integer, the length of keys.
      depth: An integer, the number of embedding channels per head.
      num_heads: An integer, the number of heads in multi-head attention.
      name: A string, the name of the embedding.
      initialization_std: A float, the initialization std for the embedding.
      conv_kernel_weight_decay: A float, the weight decay for convolution
        kernels.

    Returns:
      output: A [num_heads, query, key, depth] tensor, the relative positional
        encodings for each head and each query-key-pair.
    """
    super(RelativePositionalEncoding, self).__init__(name=name)
    self._initializer = tf.keras.initializers.TruncatedNormal(
        stddev=initialization_std)
    self._regularizer = tf.keras.regularizers.l2(conv_kernel_weight_decay)

    self._relative_distance_matrix = _compute_relative_distance_matrix(
        query_length, key_length)
    self._num_heads = num_heads
    self._embedding_shape = (MAX_SPAN * 2 - 1, depth)

  def build(self, input_shape):
    """Builds the embedding weight."""
    del input_shape
    self._embeddings = self.add_weight(
        shape=self._embedding_shape,
        initializer=self._initializer, trainable=True,
        name='embeddings',
        regularizer=self._regularizer)

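  def call(self, inputs):
    """A forward pass that gathers the relative positional encoding.

    Args:
      inputs: Unused. The encodings depend only on the layer configuration.

    Returns:
      output: A [num_heads, query_length, key_length, depth] tensor, the
        relative positional encodings for each head and each query-key pair.
    """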
    del inputs
    # Gather the embeddings according to the relative distances.
    embeddings = tf.gather(self._embeddings, self._relative_distance_matrix)
    return tf.tile(tf.expand_dims(embeddings, axis=0),
                   [self._num_heads, 1, 1, 1])


class AddAbsolutePositionalEncoding(tf.keras.layers.Layer):
  """Adds a learnable absolute positional encoding to the input feature.

  Supports both 1D and 2D versions of the positional encoding: (1) 1D positional
  encoding represents each row index with an embedding, and represents each
  column index with another embedding. This results in a total of (height +
  width) learnable embedding vectors. (2) 2D positional encoding adds
  independent embeddings to each input grid position. This choice uses a total
  of (height * width) learnable embedding vectors.
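
  The difference between the two variants can be sketched with plain Python
  lists. The sizes and embedding values below are made up for illustration,
  and the batch normalization that this layer applies to the embeddings is
  omitted.

  ```python
  height, width, channel = 3, 4, 2
  # '1D': one vector per row and one per column, broadcast over the grid.
  row_embeddings = [[10.0 * h] * channel for h in range(height)]
  col_embeddings = [[1.0 * w] * channel for w in range(width)]
  encoding_1d = [[[row_embeddings[h][c] + col_embeddings[w][c]
                   for c in range(channel)]
                  for w in range(width)]
                 for h in range(height)]
  num_vectors_1d = height + width  # 7 learnable vectors.
  # '2D': every grid position owns an independent vector.
  num_vectors_2d = height * width  # 12 learnable vectors.
  ```

  Both variants yield a [height, width, channel] encoding. The 1D variant
  needs fewer parameters but can only express encodings that separate into a
  row term plus a column term.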
  """

  def __init__(self, name, positional_encoding_type=None,
               bn_layer=tf.keras.layers.BatchNormalization,
               conv_kernel_weight_decay=0.0):
    """Initializes an AddAbsolutePositionalEncoding layer.

    Args:
      name: A string specifying the name of the layer.
      positional_encoding_type: A string, the type of the positional encoding.
        Supports '2D', '1D', 'none', and None (case-insensitive). The feature
        is returned as is if positional_encoding_type is 'none' or None.
      bn_layer: An optional tf.keras.layers.Layer that computes the
        normalization (default: tf.keras.layers.BatchNormalization).
      conv_kernel_weight_decay: A float, the weight decay for convolution
        kernels.

    Raises:
      ValueError: If positional_encoding_type is not one of '1D', '2D', 'none',
        or None.
    """
    super(AddAbsolutePositionalEncoding, self).__init__(name=name)
    # Normalize once so that None and 'none' are handled identically. The
    # original check called .lower() on None, which raises AttributeError.
    if positional_encoding_type is None:
      positional_encoding_type = 'none'
    if positional_encoding_type.lower() not in ('none', '1d', '2d'):
      raise ValueError(positional_encoding_type + ' is not supported.')
    self._positional_encoding_type = positional_encoding_type.lower()
    # This initialization std is tuned for global attention, but it does not
    # seem to be a sensitive hyper-parameter, since we use batch norm on the
    # positional encodings.
    self._initializer = tf.keras.initializers.TruncatedNormal(stddev=0.2)
    self._kernel_regularizer = tf.keras.regularizers.l2(
        conv_kernel_weight_decay)
    self._bn_layer = bn_layer

  def build(self, input_shape):
    """Builds the layer weights whose shape depends on the 4D input shape."""
    _, height, width, channel = input_shape
    if self._positional_encoding_type.lower() == '2d':
      self._embeddings = self.add_weight(
          shape=(1, height, width, channel),
          initializer=self._initializer, trainable=True,
          name='embeddings',
          regularizer=self._kernel_regularizer)
      self._batch_norm = self._bn_layer(axis=-1, name='batch_norm')
    elif self._positional_encoding_type.lower() == '1d':
      # Generate separable positional encodings for the height axis and the
      # width axis.
      self._height_axis_embeddings = self.add_weight(
          shape=(1, height, 1, channel),
          initializer=self._initializer, trainable=True,
          name='height_axis_embeddings',
          regularizer=self._kernel_regularizer)
      self._height_axis_batch_norm = self._bn_layer(
          axis=-1, name='height_axis_batch_norm')
      self._width_axis_embeddings = self.add_weight(
          shape=(1, 1, width, channel),
          initializer=self._initializer, trainable=True,
          name='width_axis_embeddings',
          regularizer=self._kernel_regularizer)
      self._width_axis_batch_norm = self._bn_layer(
          axis=-1, name='width_axis_batch_norm')

  def call(self, features, training=False):
    """Performs a forward pass.

    Args:
      features: An input [batch, height, width, channels] tensor.
      training: A boolean, whether the model is in training mode.

    Returns:
      output: The sum of the input feature and learnable positional encodings.
    """
    if (self._positional_encoding_type is None or
        self._positional_encoding_type.lower() == 'none'):
      return features
    elif self._positional_encoding_type.lower() == '2d':
      positional_encoding = self._batch_norm(self._embeddings,
                                             training=training)
    elif self._positional_encoding_type.lower() == '1d':
      height_axis_positional_encoding = self._height_axis_batch_norm(
          self._height_axis_embeddings, training=training)
      width_axis_positional_encoding = self._width_axis_batch_norm(
          self._width_axis_embeddings, training=training)
      positional_encoding = (height_axis_positional_encoding +
                             width_axis_positional_encoding)
    return features + positional_encoding