# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """Class for T5 relative position biases. | |
| Adapted from flaxformer.components.relative_position_biases.py | |
| """ | |
| from typing import Any, Callable, Optional | |
| from flax import linen as nn | |
| import gin | |
| from jax import lax | |
| import jax.numpy as jnp | |
| from transformer import position | |
| import numpy as np | |
| Array = Any | |


class T5RelativePositionBiases(nn.Module):
  """Adds T5-style relative positional embeddings to the attention logits.

  Attributes:
    num_buckets: Number of buckets to bucket distances between key and query
      positions into.
    max_distance: Maximum distance before everything is lumped into the last
      distance bucket.
    num_heads: Number of heads in the attention layer. Each head will get a
      different relative position weighting.
    dtype: Data type of arrays passed through this module.
    embedding_init: Initializer for the relative embedding table.
  """
  num_buckets: int
  max_distance: int
  num_heads: int
  dtype: Any
  embedding_init: Callable[..., Array] = nn.linear.default_embed_init

  @staticmethod
  def _relative_position_bucket(relative_position,
                                bidirectional=True,
                                num_buckets=32,
                                max_distance=128):
    """Translates relative position to a bucket number for relative attention.

    The relative position is defined as memory_position - query_position, i.e.
    the distance in tokens from the attending position to the attended-to
    position. If bidirectional=False, then positive relative positions are
    invalid.

    We use smaller buckets for small absolute values of relative_position and
    larger buckets for larger absolute values. All relative positions
    >= max_distance map to the same bucket, and all relative positions
    <= -max_distance map to the same bucket. This should allow for more
    graceful generalization to longer sequences than the model has been
    trained on.

    Args:
      relative_position: An int32 array.
      bidirectional: A boolean - whether the attention is bidirectional.
      num_buckets: An integer - total number of buckets.
      max_distance: An integer - distances at or beyond this magnitude all map
        to the last bucket on their side.

    Returns:
      An array with the same shape as relative_position, containing int32
      values in the range [0, num_buckets).
    """
    ret = 0
    n = -relative_position
    if bidirectional:
      num_buckets //= 2
      ret += (n < 0).astype(np.int32) * num_buckets
      n = np.abs(n)
    else:
      n = np.maximum(n, 0)
    # Now n is in the range [0, inf).
    max_exact = num_buckets // 2
    is_small = (n < max_exact)
    val_if_large = max_exact + (
        np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps) /
        np.log(max_distance / max_exact) *
        (num_buckets - max_exact)).astype(np.int32)
    val_if_large = np.minimum(val_if_large, num_buckets - 1)
    ret += np.where(is_small, n, val_if_large)
    return ret
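
  # A small worked example of the bucketing above (illustrative values derived
  # by hand from the formula; with bidirectional=True, num_buckets=32 and
  # max_distance=128, the buckets split into 16 for keys before the query and
  # 16 for keys after it, with max_exact=8 exact buckets per side):
  #   relative_position = 0     -> bucket 0
  #   relative_position = -3    -> bucket 3   (exact range, |n| < 8)
  #   relative_position = +3    -> bucket 19  (16 + 3, positive side)
  #   relative_position = -20   -> bucket 10  (8 + int(log(20/8)/log(16) * 8))
  #   relative_position <= -128 -> bucket 15  (clamped; half-range is 0..15)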

  def __call__(self, num_queries, num_keys, offset: Optional[int] = None,
               bidirectional=True):
    """Produces relative position embedding attention biases.

    Args:
      num_queries: Number of queries.
      num_keys: Number of keys.
      offset: Offset of the first query with respect to the first key.
        (See position.relative_positions() for more info.)
      bidirectional: Whether to allow positive memory-query relative position
        embeddings.

    Returns:
      output: `(1, num_heads, num_queries, num_keys)` attention bias.
    """
    # Find the distance between each query and each key.
    # This is where this implementation differs from the T5 implementation:
    # when the numbers of queries and keys differ, this version lines up the
    # /last/ N queries with the /last/ N keys (which is appropriate for
    # XL/sliding-window attention), while the T5 version lines up the /first/
    # N queries with the /first/ N keys.
    relative_position = position.relative_positions_np(
        num_queries=num_queries, num_keys=num_keys, offset=offset)

    rp_bucket = self._relative_position_bucket(
        relative_position,
        bidirectional=bidirectional,
        num_buckets=self.num_buckets,
        max_distance=self.max_distance)

    relative_attention_bias = self.param('rel_embedding', self.embedding_init,
                                         (self.num_heads, self.num_buckets),
                                         jnp.float32)
    relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)

    # Instead of using a slow gather, we create a leading-dimension one-hot
    # array from rp_bucket and use it to perform the gather-equivalent via a
    # contraction, i.e.:
    # (num_heads, num_buckets) x (num_buckets one-hot, num_queries, num_keys).
    # This is equivalent to relative_attention_bias[:, rp_bucket].
    bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
    rp_bucket_one_hot = jnp.array(
        rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype)
    # --> shape (num_heads, num_queries, num_keys)
    values = lax.dot_general(
        relative_attention_bias,
        rp_bucket_one_hot,
        (
            ((1,), (0,)),  # lhs, rhs contracting dims
            ((), ())))     # no batched dims

    # Add a singleton batch dimension.
    # --> shape (1, num_heads, num_queries, num_keys)
    out = values[jnp.newaxis, ...]
    return out
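

# Example usage: a minimal sketch, assuming `transformer.position` (imported
# above) is importable, since __call__ relies on position.relative_positions_np.
# The hyperparameter values below are illustrative, not tied to any particular
# training config.
if __name__ == "__main__":
  import jax

  rel_bias = T5RelativePositionBiases(
      num_buckets=32,
      max_distance=128,
      num_heads=8,
      dtype=jnp.float32)

  # init() traces __call__ once to create the single 'rel_embedding' parameter
  # of shape (num_heads, num_buckets).
  params = rel_bias.init(
      jax.random.PRNGKey(0), num_queries=16, num_keys=16)

  # apply() returns a (1, num_heads, num_queries, num_keys) bias that can be
  # added directly to attention logits.
  bias = rel_bias.apply(params, num_queries=16, num_keys=16)
  assert bias.shape == (1, 8, 16, 16)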