deanna-emery's picture
updates
93528c6
raw
history blame
11.3 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras-based positional embedding layer."""
# pylint: disable=g-classes-have-attributes
import math
from typing import Optional
import tensorflow as tf, tf_keras
from official.modeling import tf_utils
# Alias for the Keras initializer base class, used in type annotations in this
# module (e.g. `RelativePositionBias`'s `embeddings_initializer`).
Initializer = tf_keras.initializers.Initializer
@tf_keras.utils.register_keras_serializable(package="Text")
class PositionEmbedding(tf_keras.layers.Layer):
  """Adds a learned absolute position embedding table (BERT style).

  The layer owns a `[max_length, width]` weight, where `width` is the size of
  the last input axis. On each call, the first `L` rows are taken, where `L` is
  the runtime length of `seq_axis`. They are reshaped to line up with
  `seq_axis` and the last axis, then broadcast to the full input shape. `L`
  must not exceed `max_length`.

  Example:

  ```python
  position_embedding = PositionEmbedding(max_length=100)
  inputs = tf_keras.Input((100, 32), dtype=tf.float32)
  outputs = position_embedding(inputs)
  ```

  Args:
    max_length: Largest sequence length the embedding table can cover (number
      of rows in the table). Must not be None.
    initializer: Initializer for the embedding table; anything accepted by
      `tf_keras.initializers.get`. Defaults to "glorot_uniform".
    seq_axis: Input axis along which positions are enumerated.

  Reference: This layer creates a positional embedding as described in
  [BERT: Pre-training of Deep Bidirectional Transformers for Language
  Understanding](https://arxiv.org/abs/1810.04805).
  """

  def __init__(self,
               max_length,
               initializer="glorot_uniform",
               seq_axis=1,
               **kwargs):
    super().__init__(**kwargs)
    if max_length is None:
      raise ValueError("`max_length` must be an Integer, not `None`.")
    self._seq_axis = seq_axis
    self._initializer = tf_keras.initializers.get(initializer)
    self._max_length = max_length

  def get_config(self):
    base_config = super(PositionEmbedding, self).get_config()
    layer_config = {
        "max_length": self._max_length,
        "initializer": tf_keras.initializers.serialize(self._initializer),
        "seq_axis": self._seq_axis,
    }
    # Layer-specific keys take precedence over the base Layer config.
    return {**base_config, **layer_config}

  def build(self, input_shape):
    # Table width follows the feature (last) axis of the input.
    self._position_embeddings = self.add_weight(
        "embeddings",
        shape=[self._max_length, input_shape[-1]],
        initializer=self._initializer)
    super().build(input_shape)

  def call(self, inputs):
    dynamic_shape = tf.shape(inputs)
    seq_len = dynamic_shape[self._seq_axis]
    table = self._position_embeddings[:seq_len, :]
    # Singleton everywhere except the sequence and feature axes, so the slice
    # broadcasts over every other input dimension.
    target_shape = [1] * len(inputs.get_shape().as_list())
    target_shape[self._seq_axis] = seq_len
    target_shape[-1] = table.get_shape().as_list()[-1]
    return tf.broadcast_to(tf.reshape(table, target_shape), dynamic_shape)
@tf_keras.utils.register_keras_serializable(package="Text")
class RelativePositionEmbedding(tf_keras.layers.Layer):
  """Creates a sinusoidal positional embedding.

  This layer calculates the position encoding as a mix of sine and cosine
  functions with geometrically increasing wavelengths. Defined and formulized in
  "Attention is All You Need", section 3.5.
  (https://arxiv.org/abs/1706.03762).

  The first `hidden_size // 2` output channels hold sines and the next
  `hidden_size // 2` hold cosines of the same scaled positions. When
  `hidden_size` is odd, a trailing zero channel is appended so the output
  width always equals `hidden_size`.

  Args:
    hidden_size: Size of the hidden layer, i.e. the width of the output.
    min_timescale: Minimum scale that will be applied at each position.
    max_timescale: Maximum scale that will be applied at each position.
    **kwargs: Forwarded to `tf_keras.layers.Layer`; `dtype` defaults to
      "float32" when not given.
  """

  def __init__(self,
               hidden_size: int,
               min_timescale: float = 1.0,
               max_timescale: float = 1.0e4,
               **kwargs):
    # We need to have a default dtype of float32, since the inputs (which Keras
    # usually uses to infer the dtype) will always be int32.
    # We compute the positional encoding in float32 even if the model uses
    # float16, as many of the ops used, like log and exp, are numerically
    # unstable in float16.
    if "dtype" not in kwargs:
      kwargs["dtype"] = "float32"
    super().__init__(**kwargs)
    self._hidden_size = hidden_size
    self._min_timescale = min_timescale
    self._max_timescale = max_timescale

  def get_config(self):
    config = {
        "hidden_size": self._hidden_size,
        "min_timescale": self._min_timescale,
        "max_timescale": self._max_timescale,
    }
    base_config = super(RelativePositionEmbedding, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs, length=None):
    """Implements call() for the layer.

    Args:
      inputs: A tensor whose second dimension will be used as `length`. If
        `None`, the other `length` argument must be specified.
      length: An optional integer specifying the number of positions. If both
        `inputs` and `length` are specified, `length` must be equal to the
        second dimension of `inputs`.

    Returns:
      A float32 tensor in shape of `(length, hidden_size)`.

    Raises:
      ValueError: If both `inputs` and `length` are None, or if both are given
        and `length` differs from the second dimension of `inputs`.
    """
    if inputs is None and length is None:
      raise ValueError("If inputs is None, `length` must be set in "
                       "RelativePositionEmbedding().")
    if inputs is not None:
      input_shape = tf_utils.get_shape_list(inputs)
      if length is not None and length != input_shape[1]:
        raise ValueError(
            "If inputs is not None, `length` must equal to input_shape[1].")
      length = input_shape[1]
    position = tf.cast(tf.range(length), tf.float32)
    num_timescales = self._hidden_size // 2
    min_timescale, max_timescale = self._min_timescale, self._max_timescale
    # Clamp the denominator to at least 1. With a single timescale
    # (hidden_size of 2 or 3), `num_timescales - 1` is 0 and the division
    # would give inf, turning every embedding into NaN below (0 * -inf).
    log_timescale_increment = (
        math.log(float(max_timescale) / float(min_timescale)) /
        tf.maximum(tf.cast(num_timescales, tf.float32) - 1, 1.0))
    inv_timescales = min_timescale * tf.exp(
        tf.cast(tf.range(num_timescales), tf.float32) *
        -log_timescale_increment)
    # (length, 1) * (1, num_timescales) -> (length, num_timescales).
    scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(
        inv_timescales, 0)
    position_embeddings = tf.concat(
        [tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
    if self._hidden_size % 2:
      # sin/cos together cover only 2 * (hidden_size // 2) channels; append a
      # zero channel so the width matches the documented `hidden_size`.
      position_embeddings = tf.pad(position_embeddings, [[0, 0], [0, 1]])
    return position_embeddings
def _relative_position_bucket(relative_position,
                              bidirectional=True,
                              num_buckets=32,
                              max_distance=128):
  """Maps relative positions to bucket ids for T5-style relative attention.

  `relative_position` is memory_position - query_position, i.e. the signed
  token distance from the attending position to the attended-to position.

  Small distances each get their own bucket. Larger distances share
  logarithmically wider buckets, and every distance at or beyond
  `max_distance` (in either direction) collapses into the last bucket for its
  direction. This lets a model generalize more gracefully to sequences longer
  than those seen in training.

  When `bidirectional` is False, positive relative positions are treated as
  invalid and all land in bucket 0. When it is True, half of the buckets are
  reserved for positive relative positions.

  Args:
    relative_position: An int32 Tensor.
    bidirectional: Whether the attention is bidirectional.
    num_buckets: Total number of buckets.
    max_distance: Distance at which bucketing saturates.

  Returns:
    An int32 Tensor with the same shape as `relative_position`, holding values
    in the range [0, num_buckets).
  """
  # Flip the sign: distance is query_position - memory_position.
  distance = -relative_position
  buckets = num_buckets
  if bidirectional:
    buckets //= 2
    # Keys that come after the query use the upper half of the bucket range.
    direction_offset = tf.cast(tf.math.less(distance, 0), tf.int32) * buckets
    distance = tf.math.abs(distance)
  else:
    direction_offset = 0
    distance = tf.math.maximum(distance, 0)

  # From here on `distance` is non-negative.
  exact_limit = buckets // 2
  log_range = math.log(max_distance / exact_limit)
  # For distance 0 the log is -inf, but such entries are always taken from the
  # exact branch of the `tf.where` below.
  log_bucket = exact_limit + tf.cast(
      tf.math.log(tf.cast(distance, tf.float32) / exact_limit) / log_range *
      (buckets - exact_limit),
      tf.int32,
  )
  log_bucket = tf.math.minimum(log_bucket, buckets - 1)
  return direction_offset + tf.where(
      tf.math.less(distance, exact_limit), distance, log_bucket)
@tf_keras.utils.register_keras_serializable(package="Text")
class RelativePositionBias(tf_keras.layers.Layer):
  """Relative position embedding via per-head bias in T5 style.

  Reference implementation in MeshTF:
  https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L1000

  This layer implements the relative position bias used in "Exploring the Limits
  of Transfer Learning with a Unified Text-to-Text Transformer"
  (https://arxiv.org/abs/1910.10683)

  Args:
    num_heads: Number of attention heads; the bias table has one column per
      head.
    relative_attention_num_buckets: Number of relative-position buckets (rows
      of the bias table).
    relative_attention_max_distance: Relative distance at which bucketing
      saturates into the outermost bucket.
    bidirectional: Whether positive and negative relative positions are
      bucketed separately.
    embeddings_initializer: Initializer for the `[num_buckets, num_heads]` bias
      table. Accepts an initializer instance, a name, or a serialized config
      dict (anything `tf_keras.initializers.get` accepts). Defaults to
      `TruncatedNormal(stddev=1.0)`.
    **kwargs: Forwarded to `tf_keras.layers.Layer`.
  """

  def __init__(self,
               num_heads: int,
               relative_attention_num_buckets: int = 32,
               relative_attention_max_distance: int = 128,
               bidirectional: bool = True,
               embeddings_initializer: Optional[Initializer] = None,
               **kwargs):
    super().__init__(**kwargs)
    self.num_heads = num_heads
    self.relative_attention_num_buckets = relative_attention_num_buckets
    self.bidirectional = bidirectional
    self.relative_attention_max_distance = relative_attention_max_distance
    if embeddings_initializer:
      # Normalize names and serialized config dicts (such as the value
      # `get_config` emits, fed back via `from_config`) into an initializer
      # object. `initializers.get` returns initializer instances unchanged.
      self._embed_init = tf_keras.initializers.get(embeddings_initializer)
    else:
      self._embed_init = tf_keras.initializers.TruncatedNormal(stddev=1.0)
    with tf.name_scope(self.name):
      self._relative_attention_bias = self.add_weight(
          "rel_embedding",
          shape=[self.relative_attention_num_buckets, self.num_heads],
          initializer=self._embed_init,
          dtype=self.dtype,
          trainable=True)

  def get_config(self):
    config = {
        "num_heads":
            self.num_heads,
        "relative_attention_num_buckets":
            self.relative_attention_num_buckets,
        "relative_attention_max_distance":
            self.relative_attention_max_distance,
        "bidirectional":
            self.bidirectional,
        "embeddings_initializer":
            tf_keras.initializers.serialize(self._embed_init),
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, query: tf.Tensor, key: tf.Tensor):
    """Implements the forward pass.

    Args:
      query: query input tensor shape [batch, query length, hidden size].
      key: key input tensor shape [batch, key length, hidden size].

    Returns:
      A tensor in shape of [batch, heads, query length, key length].
    """
    batch_size, qlen = tf_utils.get_shape_list(query)[:2]
    klen = tf_utils.get_shape_list(key)[1]
    # (qlen, 1) and (1, klen) broadcast to a (qlen, klen) grid of offsets.
    context_position = tf.range(qlen)[:, None]
    memory_position = tf.range(klen)[None, :]
    relative_position = memory_position - context_position
    rp_bucket = _relative_position_bucket(
        relative_position,
        bidirectional=self.bidirectional,
        num_buckets=self.relative_attention_num_buckets,
        max_distance=self.relative_attention_max_distance)
    # (qlen, klen) bucket ids -> (qlen, klen, num_heads) bias values.
    values = tf.nn.embedding_lookup(self._relative_attention_bias, rp_bucket)
    values = tf.expand_dims(
        tf.transpose(values, [2, 0, 1]),
        axis=0)  # shape (1, num_heads, qlen, klen)
    values = tf.tile(values, [batch_size, 1, 1, 1])
    return values