# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Keras-based positional embedding layer."""
# pylint: disable=g-classes-have-attributes
import math
from typing import Optional

import tensorflow as tf, tf_keras

from official.modeling import tf_utils

Initializer = tf_keras.initializers.Initializer


class PositionEmbedding(tf_keras.layers.Layer):
  """Creates a positional embedding.

  Example:

  ```python
  position_embedding = PositionEmbedding(max_length=100)
  inputs = tf_keras.Input((100, 32), dtype=tf.float32)
  outputs = position_embedding(inputs)
  ```

  Args:
    max_length: The maximum size of the dynamic sequence.
    initializer: The initializer to use for the embedding weights. Defaults to
      "glorot_uniform".
    seq_axis: The axis of the input tensor where we add the embeddings.

  Reference: This layer creates a positional embedding as described in
  [BERT: Pre-training of Deep Bidirectional Transformers for Language
  Understanding](https://arxiv.org/abs/1810.04805).
  """

  def __init__(self,
               max_length,
               initializer="glorot_uniform",
               seq_axis=1,
               **kwargs):
    super().__init__(**kwargs)
    if max_length is None:
      raise ValueError("`max_length` must be an integer, not `None`.")
    self._max_length = max_length
    self._initializer = tf_keras.initializers.get(initializer)
    self._seq_axis = seq_axis

  def get_config(self):
    config = {
        "max_length": self._max_length,
        "initializer": tf_keras.initializers.serialize(self._initializer),
        "seq_axis": self._seq_axis,
    }
    base_config = super(PositionEmbedding, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def build(self, input_shape):
    dimension_list = input_shape
    width = dimension_list[-1]
    weight_sequence_length = self._max_length

    self._position_embeddings = self.add_weight(
        "embeddings",
        shape=[weight_sequence_length, width],
        initializer=self._initializer)

    super().build(input_shape)

  def call(self, inputs):
    input_shape = tf.shape(inputs)
    actual_seq_len = input_shape[self._seq_axis]
    position_embeddings = self._position_embeddings[:actual_seq_len, :]
    new_shape = [1 for _ in inputs.get_shape().as_list()]
    new_shape[self._seq_axis] = actual_seq_len
    new_shape[-1] = position_embeddings.get_shape().as_list()[-1]
    position_embeddings = tf.reshape(position_embeddings, new_shape)
    return tf.broadcast_to(position_embeddings, input_shape)
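
# A minimal usage sketch (not part of the original module; shapes are
# illustrative). The layer builds a [max_length, width] table, slices it to
# the actual sequence length, and broadcasts it over the batch:
#
#   layer = PositionEmbedding(max_length=128)
#   x = tf.zeros([2, 50, 32])   # [batch, seq_len, width]
#   y = layer(x)                # [2, 50, 32]; rows 0..49 of the table
#
# Because only the first `seq_len` rows are used, the same layer handles any
# sequence length up to `max_length`.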


class RelativePositionEmbedding(tf_keras.layers.Layer):
  """Creates a sinusoidal positional embedding.

  This layer calculates the position encoding as a mix of sine and cosine
  functions with geometrically increasing wavelengths, as defined in section
  3.5 of "Attention is All You Need" (https://arxiv.org/abs/1706.03762).

  Args:
    hidden_size: Size of the hidden layer.
    min_timescale: Minimum scale that will be applied at each position.
    max_timescale: Maximum scale that will be applied at each position.
  """

  def __init__(self,
               hidden_size: int,
               min_timescale: float = 1.0,
               max_timescale: float = 1.0e4,
               **kwargs):
    # We need to have a default dtype of float32, since the inputs (which
    # Keras usually uses to infer the dtype) will always be int32.
    # We compute the positional encoding in float32 even if the model uses
    # float16, as many of the ops used, like log and exp, are numerically
    # unstable in float16.
    if "dtype" not in kwargs:
      kwargs["dtype"] = "float32"

    super().__init__(**kwargs)
    self._hidden_size = hidden_size
    self._min_timescale = min_timescale
    self._max_timescale = max_timescale

  def get_config(self):
    config = {
        "hidden_size": self._hidden_size,
        "min_timescale": self._min_timescale,
        "max_timescale": self._max_timescale,
    }
    base_config = super(RelativePositionEmbedding, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs, length=None):
    """Implements call() for the layer.

    Args:
      inputs: A tensor whose second dimension will be used as `length`. If
        `None`, the other `length` argument must be specified.
      length: An optional integer specifying the number of positions. If both
        `inputs` and `length` are specified, `length` must equal the second
        dimension of `inputs`.

    Returns:
      A tensor in shape of `(length, hidden_size)`.
    """
    if inputs is None and length is None:
      raise ValueError("If inputs is None, `length` must be set in "
                       "RelativePositionEmbedding().")
    if inputs is not None:
      input_shape = tf_utils.get_shape_list(inputs)
      if length is not None and length != input_shape[1]:
        raise ValueError(
            "If inputs is not None, `length` must equal `input_shape[1]`.")
      length = input_shape[1]
    position = tf.cast(tf.range(length), tf.float32)
    num_timescales = self._hidden_size // 2
    min_timescale, max_timescale = self._min_timescale, self._max_timescale
    log_timescale_increment = (
        math.log(float(max_timescale) / float(min_timescale)) /
        (tf.cast(num_timescales, tf.float32) - 1))
    inv_timescales = min_timescale * tf.exp(
        tf.cast(tf.range(num_timescales), tf.float32) *
        -log_timescale_increment)
    scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(
        inv_timescales, 0)
    position_embeddings = tf.concat(
        [tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
    return position_embeddings
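
# A minimal usage sketch (not part of the original module; shapes are
# illustrative). The output has shape (length, hidden_size) regardless of the
# batch, so callers typically add it to their embeddings via broadcasting:
#
#   layer = RelativePositionEmbedding(hidden_size=8)
#   pos_encoding = layer(inputs=None, length=50)   # shape (50, 8)
#   # or infer the length from a [batch, seq_len, hidden] tensor:
#   pos_encoding = layer(tf.zeros([2, 50, 8]))     # shape (50, 8)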


def _relative_position_bucket(relative_position,
                              bidirectional=True,
                              num_buckets=32,
                              max_distance=128):
  """Translate relative position to a bucket number for relative attention.

  The relative position is defined as memory_position - query_position, i.e.
  the distance in tokens from the attending position to the attended-to
  position.

  If `bidirectional=False`, then positive relative positions are invalid.

  We use smaller buckets for small absolute relative positions and larger
  buckets for larger absolute relative positions. All relative positions
  >= max_distance map to the same bucket, and all relative positions
  <= -max_distance map to the same bucket. This should allow for more
  graceful generalization to longer sequences than the model has been
  trained on.

  Args:
    relative_position: An int32 Tensor.
    bidirectional: A boolean - whether the attention is bidirectional.
    num_buckets: An integer.
    max_distance: An integer.

  Returns:
    A Tensor with the same shape as `relative_position`, containing int32
    values in the range [0, num_buckets).
  """
  ret = 0
  n = -relative_position
  if bidirectional:
    num_buckets //= 2
    ret += tf.cast(tf.math.less(n, 0), tf.int32) * num_buckets
    n = tf.math.abs(n)
  else:
    n = tf.math.maximum(n, 0)
  # Now n is in the range [0, inf).
  max_exact = num_buckets // 2
  is_small = tf.math.less(n, max_exact)
  val_if_large = max_exact + tf.dtypes.cast(
      tf.math.log(tf.cast(n, tf.float32) / max_exact) /
      math.log(max_distance / max_exact) * (num_buckets - max_exact),
      tf.int32,
  )
  val_if_large = tf.math.minimum(val_if_large, num_buckets - 1)
  ret += tf.where(is_small, n, val_if_large)
  return ret
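
# A minimal usage sketch (not part of the original module; the values are
# illustrative). Small distances get their own buckets; larger distances
# share logarithmically spaced buckets, and everything at or beyond
# `max_distance` collapses into the last bucket on its side:
#
#   relative_position = tf.constant([[-2, 0, 3, 200]])
#   buckets = _relative_position_bucket(
#       relative_position, bidirectional=True, num_buckets=32,
#       max_distance=128)   # int32 tensor, same shape, values in [0, 32)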


class RelativePositionBias(tf_keras.layers.Layer):
  """Relative position embedding via per-head bias in T5 style.

  Reference implementation in MeshTF:
  https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L1000

  This layer implements the relative position bias used in "Exploring the
  Limits of Transfer Learning with a Unified Text-to-Text Transformer"
  (https://arxiv.org/abs/1910.10683).
  """

  def __init__(self,
               num_heads: int,
               relative_attention_num_buckets: int = 32,
               relative_attention_max_distance: int = 128,
               bidirectional: bool = True,
               embeddings_initializer: Optional[Initializer] = None,
               **kwargs):
    super().__init__(**kwargs)
    self.num_heads = num_heads
    self.relative_attention_num_buckets = relative_attention_num_buckets
    self.bidirectional = bidirectional
    self.relative_attention_max_distance = relative_attention_max_distance
    if embeddings_initializer:
      self._embed_init = embeddings_initializer
    else:
      self._embed_init = tf_keras.initializers.TruncatedNormal(stddev=1.0)
    with tf.name_scope(self.name):
      self._relative_attention_bias = self.add_weight(
          "rel_embedding",
          shape=[self.relative_attention_num_buckets, self.num_heads],
          initializer=self._embed_init,
          dtype=self.dtype,
          trainable=True)

  def get_config(self):
    config = {
        "num_heads":
            self.num_heads,
        "relative_attention_num_buckets":
            self.relative_attention_num_buckets,
        "relative_attention_max_distance":
            self.relative_attention_max_distance,
        "bidirectional":
            self.bidirectional,
        "embeddings_initializer":
            tf_keras.initializers.serialize(self._embed_init),
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, query: tf.Tensor, key: tf.Tensor):
    """Implements the forward pass.

    Args:
      query: Query input tensor of shape [batch, query_length, hidden_size].
      key: Key input tensor of shape [batch, key_length, hidden_size].

    Returns:
      A tensor of shape [batch, num_heads, query_length, key_length].
    """
    batch_size, qlen = tf_utils.get_shape_list(query)[:2]
    klen = tf_utils.get_shape_list(key)[1]
    context_position = tf.range(qlen)[:, None]
    memory_position = tf.range(klen)[None, :]
    relative_position = memory_position - context_position
    rp_bucket = _relative_position_bucket(
        relative_position,
        bidirectional=self.bidirectional,
        num_buckets=self.relative_attention_num_buckets,
        max_distance=self.relative_attention_max_distance)
    values = tf.nn.embedding_lookup(self._relative_attention_bias, rp_bucket)
    values = tf.expand_dims(
        tf.transpose(values, [2, 0, 1]),
        axis=0)  # shape (1, num_heads, qlen, klen)
    values = tf.tile(values, [batch_size, 1, 1, 1])
    return values
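
# A minimal usage sketch (not part of the original module; shapes are
# illustrative). The bias depends only on the sequence lengths, so it can be
# added directly to the pre-softmax attention logits:
#
#   layer = RelativePositionBias(num_heads=4)
#   query = tf.zeros([2, 10, 64])
#   key = tf.zeros([2, 12, 64])
#   bias = layer(query, key)   # shape [2, 4, 10, 12]
#   # attention_logits = attention_logits + bias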