# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
IMPORTANT:
This code was copied from
https://github.com/google-research/google-research/blob/master/performer/fast_self_attention/fast_self_attention.py on
6/11/2020. This is very new code, so it might be prone to change soon -> make sure to check the original code and
update accordingly.

Core Fast Attention Module for Flax. Implementation of the approximate fast softmax and generalized attention
mechanism leveraging structured random feature maps [RFM] techniques and low rank decomposition of the attention
matrix.
"""
# pylint: disable=invalid-name, missing-function-docstring, line-too-long

import abc
import functools
from collections.abc import Iterable  # pylint: disable=g-importing-member

import jax
import jax.numpy as jnp
import numpy as onp
from absl import logging
from jax import lax, random

def nonnegative_softmax_kernel_feature_creator(
    data, projection_matrix, attention_dims_t, batch_dims_t, precision, is_query, normalize_data=True, eps=0.0001
):
    """
    Constructs nonnegative kernel features for fast softmax attention.

    Args:
        data: input for which features are computed
        projection_matrix: random matrix used to compute features
        attention_dims_t: tuple of attention dimensions
        batch_dims_t: tuple of batch dimensions
        precision: precision parameter
        is_query: predicate indicating whether input data corresponds to queries or keys
        normalize_data: predicate indicating whether data should be normalized
        eps: numerical stabilizer

    Returns:
        Random features for fast softmax attention.
    """
    del attention_dims_t
    if normalize_data:
        # We have e^{qk^T/sqrt{d}} = e^{q_norm k_norm^T}, where
        # w_norm = w * data_normalizer for w in {q,k}.
        data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1])))
    else:
        data_normalizer = 1.0
    ratio = 1.0 / jnp.sqrt(projection_matrix.shape[0])
    data_mod_shape = data.shape[0 : len(batch_dims_t)] + projection_matrix.shape
    data_thick_random_matrix = jnp.zeros(data_mod_shape) + projection_matrix

    data_dash = lax.dot_general(
        data_normalizer * data,
        data_thick_random_matrix,
        (((data.ndim - 1,), (data_thick_random_matrix.ndim - 1,)), (batch_dims_t, batch_dims_t)),
        precision=precision,
    )

    diag_data = jnp.square(data)
    diag_data = jnp.sum(diag_data, axis=data.ndim - 1)
    diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer
    diag_data = jnp.expand_dims(diag_data, axis=data.ndim - 1)

    if is_query:
        last_dims_t = (len(data_dash.shape) - 1,)
        data_dash = ratio * (
            jnp.exp(data_dash - diag_data - jnp.max(data_dash, axis=last_dims_t, keepdims=True)) + eps
        )
    else:
        data_dash = ratio * (jnp.exp(data_dash - diag_data - jnp.max(data_dash)) + eps)

    return data_dash
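
# Note: the map above is the positive random-feature estimator of the softmax kernel used by the Performer's
# FAVOR+ mechanism. With rows w_i of `projection_matrix` drawn from N(0, I_d) and c = d^{-1/4},
# phi(x)_i = exp(w_i^T (c x) - ||c x||^2 / 2) / sqrt(m) satisfies E[phi(q)^T phi(k)] = exp(q^T k / sqrt(d)),
# so softmax attention is recovered in expectation; the max-subtraction above only rescales the features for
# numerical stability and cancels once the attention rows are renormalized.
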
def sincos_softmax_kernel_feature_creator(
    data, projection_matrix, attention_dims_t, batch_dims_t, precision, normalize_data=True
):
    """
    Constructs kernel sin-cos features for fast softmax attention.

    Args:
        data: input for which features are computed
        projection_matrix: random matrix used to compute features
        attention_dims_t: tuple of attention dimensions
        batch_dims_t: tuple of batch dimensions
        precision: precision parameter
        normalize_data: predicate indicating whether data should be normalized

    Returns:
        Random features for fast softmax attention.
    """
    if normalize_data:
        # We have: exp(qk^T/sqrt{d}) = exp(|q|^2/2sqrt{d}) * exp(|k|^2/2sqrt{d}) *
        # exp(-(|q*c-k*c|^2)/2), where c = 1.0 / sqrt{sqrt{d}}.
        data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1])))
    else:
        data_normalizer = 1.0
    ratio = 1.0 / jnp.sqrt(projection_matrix.shape[0])
    data_mod_shape = data.shape[0 : len(batch_dims_t)] + projection_matrix.shape
    data_thick_random_matrix = jnp.zeros(data_mod_shape) + projection_matrix

    data_dash = lax.dot_general(
        data_normalizer * data,
        data_thick_random_matrix,
        (((data.ndim - 1,), (data_thick_random_matrix.ndim - 1,)), (batch_dims_t, batch_dims_t)),
        precision=precision,
    )
    data_dash_cos = ratio * jnp.cos(data_dash)
    data_dash_sin = ratio * jnp.sin(data_dash)
    data_dash = jnp.concatenate((data_dash_cos, data_dash_sin), axis=-1)

    # Constructing D_data and data^{'}
    diag_data = jnp.square(data)
    diag_data = jnp.sum(diag_data, axis=data.ndim - 1)
    diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer
    diag_data = jnp.expand_dims(diag_data, axis=data.ndim - 1)

    # Additional renormalization for numerical stability
    data_renormalizer = jnp.max(diag_data, attention_dims_t, keepdims=True)
    diag_data -= data_renormalizer
    diag_data = jnp.exp(diag_data)
    data_prime = data_dash * diag_data
    return data_prime
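
# Note: these are the classical trigonometric random features for the Gaussian kernel (Rahimi & Recht, 2007):
# the cos/sin features estimate exp(-||c*q - c*k||^2 / 2), and the exp(||c*x||^2 / 2) prefactor computed via
# `diag_data` recovers the softmax kernel exp(q^T k / sqrt(d)), as spelled out in the comment inside the function.
# Unlike the nonnegative features above, individual features here can be negative, which tends to be less
# numerically stable when the attention output is renormalized.
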
def generalized_kernel_feature_creator(
    data, projection_matrix, batch_dims_t, precision, kernel_fn, kernel_epsilon, normalize_data
):
    """
    Constructs kernel features for fast generalized attention.

    Args:
        data: input for which features are computed
        projection_matrix: matrix used to compute features
        batch_dims_t: tuple of batch dimensions
        precision: precision parameter
        kernel_fn: kernel function used
        kernel_epsilon: additive positive term added to every feature for numerical stability
        normalize_data: predicate indicating whether data should be normalized

    Returns:
        Random features for fast generalized attention.
    """
    if normalize_data:
        data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1])))
    else:
        data_normalizer = 1.0
    if projection_matrix is None:
        return kernel_fn(data_normalizer * data) + kernel_epsilon
    else:
        data_mod_shape = data.shape[0 : len(batch_dims_t)] + projection_matrix.shape
        data_thick_random_matrix = jnp.zeros(data_mod_shape) + projection_matrix
        data_dash = lax.dot_general(
            data_normalizer * data,
            data_thick_random_matrix,
            (((data.ndim - 1,), (data_thick_random_matrix.ndim - 1,)), (batch_dims_t, batch_dims_t)),
            precision=precision,
        )
    data_prime = kernel_fn(data_dash) + kernel_epsilon
    return data_prime
def make_fast_softmax_attention(
    qkv_dim,
    renormalize_attention=True,
    numerical_stabilizer=0.000001,
    nb_features=256,
    ortho_features=True,
    ortho_scaling=0.0,
    redraw_features=True,
    unidirectional=False,
    nonnegative_features=True,
    lax_scan_unroll=1,
):
    """Construct a fast softmax attention method."""
    logging.info(
        "Fast softmax attention: %s features and orthogonal=%s, renormalize=%s",
        nb_features,
        ortho_features,
        renormalize_attention,
    )
    if ortho_features:
        matrix_creator = functools.partial(GaussianOrthogonalRandomMatrix, nb_features, qkv_dim, scaling=ortho_scaling)
    else:
        matrix_creator = functools.partial(GaussianUnstructuredRandomMatrix, nb_features, qkv_dim)

    if nonnegative_features:

        def kernel_feature_creator(
            data, projection_matrix, attention_dims_t, batch_dims_t, precision, is_query, normalize_data=True
        ):
            return nonnegative_softmax_kernel_feature_creator(
                data,
                projection_matrix,
                attention_dims_t,
                batch_dims_t,
                precision,
                is_query,
                normalize_data,
                numerical_stabilizer,
            )

    else:

        def kernel_feature_creator(
            data, projection_matrix, attention_dims_t, batch_dims_t, precision, is_query, normalize_data=True
        ):
            del is_query
            return sincos_softmax_kernel_feature_creator(
                data, projection_matrix, attention_dims_t, batch_dims_t, precision, normalize_data
            )

    attention_fn = FastAttentionviaLowRankDecomposition(
        matrix_creator,
        kernel_feature_creator,
        renormalize_attention=renormalize_attention,
        numerical_stabilizer=numerical_stabilizer,
        redraw_features=redraw_features,
        unidirectional=unidirectional,
        lax_scan_unroll=lax_scan_unroll,
    ).dot_product_attention
    return attention_fn
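
# Illustrative usage: the returned `attention_fn` has the signature of FastAttention.dot_product_attention defined
# below, with inputs shaped [batch, ..., num_heads, channels]. See also the runnable example at the end of this
# file:
#
#   fast_attn = make_fast_softmax_attention(qkv_dim=64, nb_features=256, unidirectional=True)
#   out = fast_attn(query, key, value)  # query/key/value: [batch, length, num_heads, 64]
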
def make_fast_generalized_attention(
    qkv_dim,
    renormalize_attention=True,
    numerical_stabilizer=0.0,
    nb_features=256,
    features_type="deterministic",
    kernel_fn=jax.nn.relu,
    kernel_epsilon=0.001,
    redraw_features=False,
    unidirectional=False,
    lax_scan_unroll=1,
):
    """Construct a fast generalized attention method."""
    logging.info("Fast generalized attention: %s features and renormalize=%s", nb_features, renormalize_attention)
    if features_type == "ortho":
        matrix_creator = functools.partial(GaussianOrthogonalRandomMatrix, nb_features, qkv_dim, scaling=False)
    elif features_type == "iid":
        matrix_creator = functools.partial(GaussianUnstructuredRandomMatrix, nb_features, qkv_dim)
    elif features_type == "deterministic":
        matrix_creator = None
    else:
        raise ValueError("Unknown feature value type")

    def kernel_feature_creator(
        data, projection_matrix, attention_dims_t, batch_dims_t, precision, is_query, normalize_data=False
    ):
        del attention_dims_t
        del is_query
        return generalized_kernel_feature_creator(
            data, projection_matrix, batch_dims_t, precision, kernel_fn, kernel_epsilon, normalize_data
        )

    attention_fn = FastAttentionviaLowRankDecomposition(
        matrix_creator,
        kernel_feature_creator,
        renormalize_attention=renormalize_attention,
        numerical_stabilizer=numerical_stabilizer,
        redraw_features=redraw_features,
        unidirectional=unidirectional,
        lax_scan_unroll=lax_scan_unroll,
    ).dot_product_attention
    return attention_fn
class RandomMatrix(object):
    r"""
    Abstract class providing a method for constructing 2D random arrays.
    """

    __metaclass__ = abc.ABCMeta

    def get_2d_array(self):
        raise NotImplementedError("Abstract method")


class GaussianUnstructuredRandomMatrix(RandomMatrix):
    def __init__(self, nb_rows, nb_columns, key):
        self.nb_rows = nb_rows
        self.nb_columns = nb_columns
        self.key = key

    def get_2d_array(self):
        return random.normal(self.key, (self.nb_rows, self.nb_columns))
class GaussianOrthogonalRandomMatrix(RandomMatrix):
    r"""
    Class providing a method to create Gaussian orthogonal matrices, i.e. 2D arrays whose rows are obtained by
    orthogonalizing blocks of Gaussian samples.
    """

    def __init__(self, nb_rows, nb_columns, key, scaling=0):
        self.nb_rows = nb_rows
        self.nb_columns = nb_columns
        self.key = key
        self.scaling = scaling

    def get_2d_array(self):
        nb_full_blocks = int(self.nb_rows / self.nb_columns)
        block_list = []
        rng = self.key
        for _ in range(nb_full_blocks):
            rng, rng_input = jax.random.split(rng)
            unstructured_block = random.normal(rng_input, (self.nb_columns, self.nb_columns))
            q, _ = jnp.linalg.qr(unstructured_block)
            q = jnp.transpose(q)
            block_list.append(q)
        remaining_rows = self.nb_rows - nb_full_blocks * self.nb_columns
        if remaining_rows > 0:
            rng, rng_input = jax.random.split(rng)
            unstructured_block = random.normal(rng_input, (self.nb_columns, self.nb_columns))
            q, _ = jnp.linalg.qr(unstructured_block)
            q = jnp.transpose(q)
            block_list.append(q[0:remaining_rows])
        final_matrix = jnp.vstack(block_list)

        if self.scaling == 0:
            multiplier = jnp.linalg.norm(random.normal(self.key, (self.nb_rows, self.nb_columns)), axis=1)
        elif self.scaling == 1:
            multiplier = jnp.sqrt(float(self.nb_columns)) * jnp.ones((self.nb_rows))
        else:
            raise ValueError("Scaling must be one of {0, 1}. Was %s" % self.scaling)

        return jnp.matmul(jnp.diag(multiplier), final_matrix)
class FastAttention(object):
    r"""
    Abstract class for fast attention. Subclasses are responsible for providing a method <dot_product_attention>
    for fast approximate attention.
    """

    __metaclass__ = abc.ABCMeta

    def dot_product_attention(
        self,
        query,
        key,
        value,
        dtype=jnp.float32,
        bias=None,
        axis=None,
        broadcast_dropout=True,
        dropout_rng=None,
        dropout_rate=0.0,
        deterministic=False,
        precision=None,
    ):
        """
        Computes dot-product attention given query, key, and value. This is the core function for applying fast
        approximate dot-product attention. It calculates the attention weights given query and key and combines the
        values using the attention weights. This function supports multi-dimensional inputs.

        Args:
            query: queries for calculating attention with shape of [batch_size, dim1, dim2, ..., dimN, num_heads,
                mem_channels].
            key: keys for calculating attention with shape of [batch_size, dim1, dim2, ..., dimN, num_heads,
                mem_channels].
            value: values to be used in attention with shape of [batch_size, dim1, dim2, ..., dimN, num_heads,
                value_channels].
            dtype: the dtype of the computation (default: float32)
            bias: bias for the attention weights. This can be used for incorporating an autoregressive mask,
                padding mask, or proximity bias.
            axis: axes over which the attention is applied.
            broadcast_dropout: bool: use a broadcasted dropout along batch dims.
            dropout_rng: JAX PRNGKey: to be used for dropout.
            dropout_rate: dropout rate.
            deterministic: bool, deterministic or not (to apply dropout).
            precision: numerical precision of the computation, see `jax.lax.Precision` for details.

        Returns:
            Output of shape [bs, dim1, dim2, ..., dimN, num_heads, value_channels].
        """
        raise NotImplementedError("Abstract method")
def _numerator(z_slice_shape, precision, unroll=1):
    """Builds the prefix-sum numerator Q'(K'^T V) for unidirectional attention, with a memory-efficient custom VJP."""

    def fwd(qs, ks, vs):
        def body(p, qkv):
            (q, k, v) = qkv
            p += jnp.einsum("...m,...d->...md", k, v, precision=precision)
            X_slice = jnp.einsum("...m,...md->...d", q, p, precision=precision)
            return p, X_slice

        init_value = jnp.zeros(z_slice_shape)
        p, W = lax.scan(body, init_value, (qs, ks, vs), unroll=unroll)
        return W, (p, qs, ks, vs)

    def bwd(pqkv, W_ct):
        def body(carry, qkv_xct):
            p, p_ct = carry
            q, k, v, x_ct = qkv_xct
            q_ct = jnp.einsum("...d,...md->...m", x_ct, p, precision=precision)
            p_ct += jnp.einsum("...d,...m->...md", x_ct, q, precision=precision)
            k_ct = jnp.einsum("...md,...d->...m", p_ct, v, precision=precision)
            v_ct = jnp.einsum("...md,...m->...d", p_ct, k, precision=precision)
            p -= jnp.einsum("...m,...d->...md", k, v, precision=precision)
            return (p, p_ct), (q_ct, k_ct, v_ct)

        p, qs, ks, vs = pqkv
        _, (qs_ct, ks_ct, vs_ct) = lax.scan(
            body, (p, jnp.zeros_like(p)), (qs, ks, vs, W_ct), reverse=True, unroll=unroll
        )
        return qs_ct, ks_ct, vs_ct

    @jax.custom_vjp
    def _numerator_impl(qs, ks, vs):
        W, _ = fwd(qs, ks, vs)
        return W

    _numerator_impl.defvjp(fwd, bwd)

    return _numerator_impl
def _denominator(t_slice_shape, precision, unroll=1):
    """Builds the prefix-sum denominator Q'(K'^T 1_L) for unidirectional attention, with a memory-efficient custom VJP."""

    def fwd(qs, ks):
        def body(p, qk):
            q, k = qk
            p += k
            x = jnp.einsum("...m,...m->...", q, p, precision=precision)
            return p, x

        p = jnp.zeros(t_slice_shape)
        p, R = lax.scan(body, p, (qs, ks), unroll=unroll)
        return R, (qs, ks, p)

    def bwd(qkp, R_ct):
        def body(carry, qkx):
            p, p_ct = carry
            q, k, x_ct = qkx
            q_ct = jnp.einsum("...,...m->...m", x_ct, p, precision=precision)
            p_ct += jnp.einsum("...,...m->...m", x_ct, q, precision=precision)
            k_ct = p_ct
            p -= k
            return (p, p_ct), (q_ct, k_ct)

        qs, ks, p = qkp
        _, (qs_ct, ks_ct) = lax.scan(body, (p, jnp.zeros_like(p)), (qs, ks, R_ct), reverse=True, unroll=unroll)
        return (qs_ct, ks_ct)

    @jax.custom_vjp
    def _denominator_impl(qs, ks):
        R, _ = fwd(qs, ks)
        return R

    _denominator_impl.defvjp(fwd, bwd)

    return _denominator_impl
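
# Note on the two helpers above: for unidirectional (causal) attention, `_numerator` computes the prefix sums
# W_i = q_i^T (sum_{j<=i} k_j v_j^T) and `_denominator` the prefix sums R_i = q_i^T (sum_{j<=i} k_j) with lax.scan.
# This is equivalent to masking Q'(K')^T before multiplying by V (or by 1_L) but never materializes the L x L
# attention matrix. The custom VJPs re-derive the running sums in reverse during the backward pass instead of
# storing every intermediate prefix sum, trading a little recomputation for constant extra memory per step.
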
class FastAttentionviaLowRankDecomposition(FastAttention):
    r"""
    Class providing fast attention via low rank decomposition. It provides the method <dot_product_attention> for
    fast dot-product attention with the use of low rank decomposition (e.g. with random feature maps).
    """

    def __init__(
        self,
        matrix_creator,
        kernel_feature_creator,
        renormalize_attention,
        numerical_stabilizer,
        redraw_features,
        unidirectional,
        lax_scan_unroll=1,
    ):  # For optimal GPU performance, set lax_scan_unroll to 16.
        rng = random.PRNGKey(0)
        self.matrix_creator = matrix_creator
        self.projection_matrix = self.draw_weights(rng)
        self.kernel_feature_creator = kernel_feature_creator
        self.renormalize_attention = renormalize_attention
        self.numerical_stabilizer = numerical_stabilizer
        self.redraw_features = redraw_features
        self.unidirectional = unidirectional
        self.lax_scan_unroll = lax_scan_unroll

    def draw_weights(self, key):
        if self.matrix_creator is None:
            return None
        matrixrng, _ = random.split(key)
        projection_matrix = self.matrix_creator(key=matrixrng).get_2d_array()
        return projection_matrix
    def dot_product_attention(
        self,
        query,
        key,
        value,
        dtype=jnp.float32,
        bias=None,
        axis=None,
        broadcast_dropout=True,
        dropout_rng=None,
        dropout_rate=0.0,
        deterministic=False,
        precision=None,
    ):
        assert key.shape[:-1] == value.shape[:-1]
        assert query.shape[0:1] == key.shape[0:1] and query.shape[-1] == key.shape[-1]
        if axis is None:
            axis = tuple(range(1, key.ndim - 2))
        if not isinstance(axis, Iterable):
            axis = (axis,)
        assert key.ndim == query.ndim
        assert key.ndim == value.ndim
        for ax in axis:
            if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2):
                raise ValueError("Attention axis must be between the batch axis and the last-two axes.")
        n = key.ndim

        # Constructing projection tensor.
        if self.redraw_features:
            # TODO(kchoro): Get rid of the constant below.
            query_seed = lax.convert_element_type(jnp.ceil(jnp.sum(query) * 10000000.0), jnp.int32)
            rng = random.PRNGKey(query_seed)
            self.projection_matrix = self.draw_weights(rng)

        # batch_dims is <bs, <non-attention dims>, num_heads>
        batch_dims = tuple(onp.delete(range(n), axis + (n - 1,)))
        # q & k -> (bs, <non-attention dims>, num_heads, <attention dims>, channels)
        qk_perm = batch_dims + axis + (n - 1,)
        k_extra_perm = axis + batch_dims + (n - 1,)
        key_extra = key.transpose(k_extra_perm)
        key = key.transpose(qk_perm)
        query = query.transpose(qk_perm)
        # v -> (bs, <non-attention dims>, num_heads, <attention dims>, channels)
        v_perm = batch_dims + axis + (n - 1,)
        value = value.transpose(v_perm)
        batch_dims_t = tuple(range(len(batch_dims)))
        attention_dims_t = tuple(range(len(batch_dims), len(batch_dims) + len(axis)))

        # Constructing tensors Q^{'} and K^{'}.
        query_prime = self.kernel_feature_creator(
            query, self.projection_matrix, attention_dims_t, batch_dims_t, precision, True
        )
        key_prime = self.kernel_feature_creator(
            key, self.projection_matrix, attention_dims_t, batch_dims_t, precision, False
        )

        if self.unidirectional:
            index = attention_dims_t[0]
            z_slice_shape = key_prime.shape[0 : len(batch_dims_t)] + (key_prime.shape[-1],) + (value.shape[-1],)

            numerator_fn = _numerator(z_slice_shape, precision, self.lax_scan_unroll)
            W = numerator_fn(
                jnp.moveaxis(query_prime, index, 0), jnp.moveaxis(key_prime, index, 0), jnp.moveaxis(value, index, 0)
            )

            # Constructing W = (Q^{'}(K^{'})^{T})_{masked}V
            W = jnp.moveaxis(W, 0, index)

            if not self.renormalize_attention:
                # Unidirectional, not-normalized attention.
                perm_inv = _invert_perm(qk_perm)
                result = W.transpose(perm_inv)
                return result
            else:
                # Unidirectional, normalized attention.
                thick_all_ones = jnp.zeros(key.shape[0:-1]) + jnp.ones(key_extra.shape[0 : len(axis)])

                index = attention_dims_t[0]
                t_slice_shape = key_prime.shape[0 : len(batch_dims_t)] + (key_prime.shape[-1],)
                denominator_fn = _denominator(t_slice_shape, precision, self.lax_scan_unroll)
                R = denominator_fn(jnp.moveaxis(query_prime, index, 0), jnp.moveaxis(key_prime, index, 0))

                R = jnp.moveaxis(R, 0, index)
        else:
            contract_query = tuple(range(len(batch_dims) + len(axis), len(batch_dims) + len(axis) + 1))
            contract_z = tuple(range(len(batch_dims), len(batch_dims) + 1))
            # Constructing Z = (K^{'})^{T}V
            # Z (bs, <non-attention dims>, num_heads, channels_m, channels_v)
            Z = lax.dot_general(
                key_prime,
                value,
                ((attention_dims_t, attention_dims_t), (batch_dims_t, batch_dims_t)),
                precision=precision,
            )
            # Constructing W = Q^{'}Z = Q^{'}(K^{'})^{T}V
            # q (bs, <non-attention dims>, num_heads, <attention dims>, channels_m)
            # Z (bs, <non-attention dims>, num_heads, channels_m, channels_v)
            # W (bs, <non-attention dims>, num_heads, <attention dims>, channels_v)
            W = lax.dot_general(
                query_prime, Z, ((contract_query, contract_z), (batch_dims_t, batch_dims_t)), precision=precision
            )
            if not self.renormalize_attention:
                # Bidirectional, not-normalized attention.
                perm_inv = _invert_perm(qk_perm)
                result = W.transpose(perm_inv)
                return result
            else:
                # Bidirectional, normalized attention.
                thick_all_ones = jnp.zeros(key.shape[0:-1]) + jnp.ones(key_extra.shape[0 : len(axis)])
                contract_key = tuple(range(len(batch_dims), len(batch_dims) + len(axis)))
                contract_thick_all_ones = tuple(range(thick_all_ones.ndim - len(axis), thick_all_ones.ndim))
                # Construct T = (K^{'})^{T} 1_L
                # k (bs, <non-attention dims>, num_heads, <attention dims>, channels)
                T = lax.dot_general(
                    key_prime,
                    thick_all_ones,
                    ((contract_key, contract_thick_all_ones), (batch_dims_t, batch_dims_t)),
                    precision=precision,
                )

                # Construct partition function: R = Q^{'} T = Q^{'}(K^{'})^{T} 1_L
                # q_p (bs, <non-attention dims>, num_heads, <attention dims>, channs_m)
                # T   (bs, <non-attention dims>, num_heads, channels_m)
                R = lax.dot_general(
                    query_prime,
                    T,
                    (((query_prime.ndim - 1,), (T.ndim - 1,)), (batch_dims_t, range(0, len(T.shape) - 1))),
                    precision=precision,
                )

        R = R + 2 * self.numerical_stabilizer * (jnp.abs(R) <= self.numerical_stabilizer)
        R = jnp.reciprocal(R)
        R = jnp.expand_dims(R, len(R.shape))
        # W (bs, <non-attention dims>, num_heads, <attention dims>, channels_v)
        # R (bs, <non-attention dims>, num_heads, <attention dims>, extra_channel)
        result = W * R
        # back to (bs, dim1, dim2, ..., dimN, num_heads, channels)
        perm_inv = _invert_perm(qk_perm)
        result = result.transpose(perm_inv)
        return result
def _invert_perm(perm):
    perm_inv = [0] * len(perm)
    for i, j in enumerate(perm):
        perm_inv[j] = i
    return tuple(perm_inv)
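

if __name__ == "__main__":
    # Minimal illustrative smoke test: it only checks output shapes and runs nothing at import time. Inputs follow
    # the [batch, length, num_heads, channels] layout expected by dot_product_attention; the names below
    # (example_query, etc.) are local to this example.
    example_rng = random.PRNGKey(0)
    q_rng, k_rng, v_rng = random.split(example_rng, 3)
    batch, length, heads, dim = 2, 16, 4, 64
    example_query = random.normal(q_rng, (batch, length, heads, dim))
    example_key = random.normal(k_rng, (batch, length, heads, dim))
    example_value = random.normal(v_rng, (batch, length, heads, dim))

    # Softmax kernel approximated with 256 positive orthogonal random features.
    fast_softmax_attn = make_fast_softmax_attention(qkv_dim=dim, nb_features=256)
    out = fast_softmax_attn(example_query, example_key, example_value)
    logging.info("fast softmax attention output shape: %s", out.shape)  # (2, 16, 4, 64)

    # Generalized attention with a deterministic ReLU feature map (no random projection).
    fast_relu_attn = make_fast_generalized_attention(qkv_dim=dim, features_type="deterministic", kernel_fn=jax.nn.relu)
    out = fast_relu_attn(example_query, example_key, example_value)
    logging.info("fast generalized attention output shape: %s", out.shape)  # (2, 16, 4, 64)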