# coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ IMPORTANT: This code was copied from https://github.com/google-research/google-research/blob/master/performer/fast_self_attention/fast_self_attention.py on 6/11/2020. This is very new code, so it might be prone to change soon -> make sure to check the original code and update accordingly Core Fast Attention Module for Flax. Implementation of the approximate fast softmax and generalized attention mechanism leveraging structured random feature maps [RFM] techniques and low rank decomposition of the attention matrix. """ # pylint: disable=invalid-name, missing-function-docstring, line-too-long import abc import functools from collections.abc import Iterable # pylint: disable=g-importing-member import jax import jax.numpy as jnp import numpy as onp from absl import logging from jax import lax, random def nonnegative_softmax_kernel_feature_creator( data, projection_matrix, attention_dims_t, batch_dims_t, precision, is_query, normalize_data=True, eps=0.0001 ): """ Constructs nonnegative kernel features for fast softmax attention Args: data: input for which features are computes projection_matrix: random matrix used to compute features attention_dims_t: tuple of attention dimensions batch_dims_t: tuple of batch dimensions precision: precision parameter is_query: predicate indicating whether input data corresponds to queries or keys normalize_data: predicate indicating whether data should be normalized, eps: numerical stabilizer Returns: Random features for fast softmax attention. """ del attention_dims_t if normalize_data: # We have e^{qk^T/sqrt{d}} = e^{q_norm k_norm^T}, where # w_norm = w * data_normalizer for w in {q,k}. data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1]))) else: data_normalizer = 1.0 ratio = 1.0 / jnp.sqrt(projection_matrix.shape[0]) data_mod_shape = data.shape[0 : len(batch_dims_t)] + projection_matrix.shape data_thick_random_matrix = jnp.zeros(data_mod_shape) + projection_matrix data_dash = lax.dot_general( data_normalizer * data, data_thick_random_matrix, (((data.ndim - 1,), (data_thick_random_matrix.ndim - 1,)), (batch_dims_t, batch_dims_t)), precision=precision, ) diag_data = jnp.square(data) diag_data = jnp.sum(diag_data, axis=data.ndim - 1) diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer diag_data = jnp.expand_dims(diag_data, axis=data.ndim - 1) if is_query: last_dims_t = (len(data_dash.shape) - 1,) data_dash = ratio * ( jnp.exp(data_dash - diag_data - jnp.max(data_dash, axis=last_dims_t, keepdims=True)) + eps ) else: data_dash = ratio * (jnp.exp(data_dash - diag_data - jnp.max(data_dash)) + eps) return data_dash def sincos_softmax_kernel_feature_creator( data, projection_matrix, attention_dims_t, batch_dims_t, precision, normalize_data=True ): """ Constructs kernel sin-cos features for fast softmax attention Args: data: input for which features are computes projection_matrix: random matrix used to compute features attention_dims_t: tuple of attention dimensions batch_dims_t: tuple of batch dimensions precision: precision parameter normalize_data: predicate indicating whether data should be normalized Returns: Random features for fast softmax attention. """ if normalize_data: # We have: exp(qk^T/sqrt{d}) = exp(|q|^2/2sqrt{d}) * exp(|k|^2/2sqrt{d}) * # exp(-(|q*c-k*c|^2)/2), where c = 1.0 / sqrt{sqrt{d}}. data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1]))) else: data_normalizer = 1.0 ratio = 1.0 / jnp.sqrt(projection_matrix.shape[0]) data_mod_shape = data.shape[0 : len(batch_dims_t)] + projection_matrix.shape data_thick_random_matrix = jnp.zeros(data_mod_shape) + projection_matrix data_dash = lax.dot_general( data_normalizer * data, data_thick_random_matrix, (((data.ndim - 1,), (data_thick_random_matrix.ndim - 1,)), (batch_dims_t, batch_dims_t)), precision=precision, ) data_dash_cos = ratio * jnp.cos(data_dash) data_dash_sin = ratio * jnp.sin(data_dash) data_dash = jnp.concatenate((data_dash_cos, data_dash_sin), axis=-1) # Constructing D_data and data^{'} diag_data = jnp.square(data) diag_data = jnp.sum(diag_data, axis=data.ndim - 1) diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer diag_data = jnp.expand_dims(diag_data, axis=data.ndim - 1) # Additional renormalization for numerical stability data_renormalizer = jnp.max(diag_data, attention_dims_t, keepdims=True) diag_data -= data_renormalizer diag_data = jnp.exp(diag_data) data_prime = data_dash * diag_data return data_prime def generalized_kernel_feature_creator( data, projection_matrix, batch_dims_t, precision, kernel_fn, kernel_epsilon, normalize_data ): """ Constructs kernel features for fast generalized attention Args: data: input for which features are computes projection_matrix: matrix used to compute features batch_dims_t: tuple of batch dimensions precision: precision parameter kernel_fn: kernel function used kernel_epsilon: additive positive term added to every feature for numerical stability normalize_data: predicate indicating whether data should be normalized Returns: Random features for fast generalized attention. """ if normalize_data: data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1]))) else: data_normalizer = 1.0 if projection_matrix is None: return kernel_fn(data_normalizer * data) + kernel_epsilon else: data_mod_shape = data.shape[0 : len(batch_dims_t)] + projection_matrix.shape data_thick_random_matrix = jnp.zeros(data_mod_shape) + projection_matrix data_dash = lax.dot_general( data_normalizer * data, data_thick_random_matrix, (((data.ndim - 1,), (data_thick_random_matrix.ndim - 1,)), (batch_dims_t, batch_dims_t)), precision=precision, ) data_prime = kernel_fn(data_dash) + kernel_epsilon return data_prime def make_fast_softmax_attention( qkv_dim, renormalize_attention=True, numerical_stabilizer=0.000001, nb_features=256, ortho_features=True, ortho_scaling=0.0, redraw_features=True, unidirectional=False, nonnegative_features=True, lax_scan_unroll=1, ): """Construct a fast softmax attention method.""" logging.info( "Fast softmax attention: %s features and orthogonal=%s, renormalize=%s", nb_features, ortho_features, renormalize_attention, ) if ortho_features: matrix_creator = functools.partial(GaussianOrthogonalRandomMatrix, nb_features, qkv_dim, scaling=ortho_scaling) else: matrix_creator = functools.partial(GaussianUnstructuredRandomMatrix, nb_features, qkv_dim) if nonnegative_features: def kernel_feature_creator( data, projection_matrix, attention_dims_t, batch_dims_t, precision, is_query, normalize_data=True ): return nonnegative_softmax_kernel_feature_creator( data, projection_matrix, attention_dims_t, batch_dims_t, precision, is_query, normalize_data, numerical_stabilizer, ) else: def kernel_feature_creator( data, projection_matrix, attention_dims_t, batch_dims_t, precision, is_query, normalize_data=True ): del is_query return sincos_softmax_kernel_feature_creator( data, projection_matrix, attention_dims_t, batch_dims_t, precision, normalize_data ) attention_fn = FastAttentionviaLowRankDecomposition( matrix_creator, kernel_feature_creator, renormalize_attention=renormalize_attention, numerical_stabilizer=numerical_stabilizer, redraw_features=redraw_features, unidirectional=unidirectional, lax_scan_unroll=lax_scan_unroll, ).dot_product_attention return attention_fn def make_fast_generalized_attention( qkv_dim, renormalize_attention=True, numerical_stabilizer=0.0, nb_features=256, features_type="deterministic", kernel_fn=jax.nn.relu, kernel_epsilon=0.001, redraw_features=False, unidirectional=False, lax_scan_unroll=1, ): """Construct a fast generalized attention menthod.""" logging.info("Fast generalized attention.: %s features and renormalize=%s", nb_features, renormalize_attention) if features_type == "ortho": matrix_creator = functools.partial(GaussianOrthogonalRandomMatrix, nb_features, qkv_dim, scaling=False) elif features_type == "iid": matrix_creator = functools.partial(GaussianUnstructuredRandomMatrix, nb_features, qkv_dim) elif features_type == "deterministic": matrix_creator = None else: raise ValueError("Unknown feature value type") def kernel_feature_creator( data, projection_matrix, attention_dims_t, batch_dims_t, precision, is_query, normalize_data=False ): del attention_dims_t del is_query return generalized_kernel_feature_creator( data, projection_matrix, batch_dims_t, precision, kernel_fn, kernel_epsilon, normalize_data ) attention_fn = FastAttentionviaLowRankDecomposition( matrix_creator, kernel_feature_creator, renormalize_attention=renormalize_attention, numerical_stabilizer=numerical_stabilizer, redraw_features=redraw_features, unidirectional=unidirectional, lax_scan_unroll=lax_scan_unroll, ).dot_product_attention return attention_fn class RandomMatrix(object): r""" Abstract class providing a method for constructing 2D random arrays. Class is responsible for constructing 2D random arrays. """ __metaclass__ = abc.ABCMeta @abc.abstractmethod def get_2d_array(self): raise NotImplementedError("Abstract method") class GaussianUnstructuredRandomMatrix(RandomMatrix): def __init__(self, nb_rows, nb_columns, key): self.nb_rows = nb_rows self.nb_columns = nb_columns self.key = key def get_2d_array(self): return random.normal(self.key, (self.nb_rows, self.nb_columns)) class GaussianOrthogonalRandomMatrix(RandomMatrix): r""" Class providing a method to create Gaussian orthogonal matrix. Class is responsible for constructing 2D Gaussian orthogonal arrays. """ def __init__(self, nb_rows, nb_columns, key, scaling=0): self.nb_rows = nb_rows self.nb_columns = nb_columns self.key = key self.scaling = scaling def get_2d_array(self): nb_full_blocks = int(self.nb_rows / self.nb_columns) block_list = [] rng = self.key for _ in range(nb_full_blocks): rng, rng_input = jax.random.split(rng) unstructured_block = random.normal(rng_input, (self.nb_columns, self.nb_columns)) q, _ = jnp.linalg.qr(unstructured_block) q = jnp.transpose(q) block_list.append(q) remaining_rows = self.nb_rows - nb_full_blocks * self.nb_columns if remaining_rows > 0: rng, rng_input = jax.random.split(rng) unstructured_block = random.normal(rng_input, (self.nb_columns, self.nb_columns)) q, _ = jnp.linalg.qr(unstructured_block) q = jnp.transpose(q) block_list.append(q[0:remaining_rows]) final_matrix = jnp.vstack(block_list) if self.scaling == 0: multiplier = jnp.linalg.norm(random.normal(self.key, (self.nb_rows, self.nb_columns)), axis=1) elif self.scaling == 1: multiplier = jnp.sqrt(float(self.nb_columns)) * jnp.ones((self.nb_rows)) else: raise ValueError("Scaling must be one of {0, 1}. Was %s" % self._scaling) return jnp.matmul(jnp.diag(multiplier), final_matrix) class FastAttention(object): r""" Abstract class providing a method for fast attention. Class is responsible for providing a method for fast approximate attention. """ __metaclass__ = abc.ABCMeta @abc.abstractmethod def dot_product_attention( self, query, key, value, dtype=jnp.float32, bias=None, axis=None, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, precision=None, ): """ Computes dot-product attention given query, key, and value. This is the core function for applying fast approximate dot-product attention. It calculates the attention weights given query and key and combines the values using the attention weights. This function supports multi-dimensional inputs Args: query: queries for calculating attention with shape of [batch_size, dim1, dim2, ..., dimN, num_heads, mem_channels]. key: keys for calculating attention with shape of [batch_size, dim1, dim2, ..., dimN, num_heads, mem_channels]. value: values to be used in attention with shape of [batch_size, dim1, dim2,..., dimN, num_heads, value_channels]. dtype: the dtype of the computation (default: float32) bias: bias for the attention weights. This can be used for incorporating autoregressive mask, padding mask, proximity bias. axis: axises over which the attention is applied. broadcast_dropout: bool: use a broadcasted dropout along batch dims. dropout_rng: JAX PRNGKey: to be used for dropout. dropout_rate: dropout rate. deterministic: bool, deterministic or not (to apply dropout). precision: numerical precision of the computation see `jax.lax.Precision` for details Returns: Output of shape [bs, dim1, dim2, ..., dimN,, num_heads, value_channels]. """ raise NotImplementedError("Abstract method") def _numerator(z_slice_shape, precision, unroll=1): def fwd(qs, ks, vs): def body(p, qkv): (q, k, v) = qkv p += jnp.einsum("...m,...d->...md", k, v, precision=precision) X_slice = jnp.einsum("...m,...md->...d", q, p, precision=precision) return p, X_slice init_value = jnp.zeros(z_slice_shape) p, W = lax.scan(body, init_value, (qs, ks, vs), unroll=unroll) return W, (p, qs, ks, vs) def bwd(pqkv, W_ct): def body(carry, qkv_xct): p, p_ct = carry q, k, v, x_ct = qkv_xct q_ct = jnp.einsum("...d,...md->...m", x_ct, p, precision=precision) p_ct += jnp.einsum("...d,...m->...md", x_ct, q, precision=precision) k_ct = jnp.einsum("...md,...d->...m", p_ct, v, precision=precision) v_ct = jnp.einsum("...md,...m->...d", p_ct, k, precision=precision) p -= jnp.einsum("...m,...d->...md", k, v, precision=precision) return (p, p_ct), (q_ct, k_ct, v_ct) p, qs, ks, vs = pqkv _, (qs_ct, ks_ct, vs_ct) = lax.scan( body, (p, jnp.zeros_like(p)), (qs, ks, vs, W_ct), reverse=True, unroll=unroll ) return qs_ct, ks_ct, vs_ct @jax.custom_vjp def _numerator_impl(qs, ks, vs): W, _ = fwd(qs, ks, vs) return W _numerator_impl.defvjp(fwd, bwd) return _numerator_impl def _denominator(t_slice_shape, precision, unroll=1): def fwd(qs, ks): def body(p, qk): q, k = qk p += k x = jnp.einsum("...m,...m->...", q, p, precision=precision) return p, x p = jnp.zeros(t_slice_shape) p, R = lax.scan(body, p, (qs, ks), unroll=unroll) return R, (qs, ks, p) def bwd(qkp, R_ct): def body(carry, qkx): p, p_ct = carry q, k, x_ct = qkx q_ct = jnp.einsum("...,...m->...m", x_ct, p, precision=precision) p_ct += jnp.einsum("...,...m->...m", x_ct, q, precision=precision) k_ct = p_ct p -= k return (p, p_ct), (q_ct, k_ct) qs, ks, p = qkp _, (qs_ct, ks_ct) = lax.scan(body, (p, jnp.zeros_like(p)), (qs, ks, R_ct), reverse=True, unroll=unroll) return (qs_ct, ks_ct) @jax.custom_vjp def _denominator_impl(qs, ks): R, _ = fwd(qs, ks) return R _denominator_impl.defvjp(fwd, bwd) return _denominator_impl class FastAttentionviaLowRankDecomposition(FastAttention): r""" Class providing a method for fast attention via low rank decomposition. Class is responsible for providing a method for fast dot-product attention with the use of low rank decomposition (e.g. with random feature maps). """ def __init__( self, matrix_creator, kernel_feature_creator, renormalize_attention, numerical_stabilizer, redraw_features, unidirectional, lax_scan_unroll=1, ): # For optimal GPU performance, set to 16. rng = random.PRNGKey(0) self.matrix_creator = matrix_creator self.projection_matrix = self.draw_weights(rng) self.kernel_feature_creator = kernel_feature_creator self.renormalize_attention = renormalize_attention self.numerical_stabilizer = numerical_stabilizer self.redraw_features = redraw_features self.unidirectional = unidirectional self.lax_scan_unroll = lax_scan_unroll def draw_weights(self, key): if self.matrix_creator is None: return None matrixrng, _ = random.split(key) projection_matrix = self.matrix_creator(key=matrixrng).get_2d_array() return projection_matrix def dot_product_attention( self, query, key, value, dtype=jnp.float32, bias=None, axis=None, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, precision=None, ): assert key.shape[:-1] == value.shape[:-1] assert query.shape[0:1] == key.shape[0:1] and query.shape[-1] == key.shape[-1] if axis is None: axis = tuple(range(1, key.ndim - 2)) if not isinstance(axis, Iterable): axis = (axis,) assert key.ndim == query.ndim assert key.ndim == value.ndim for ax in axis: if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2): raise ValueError("Attention axis must be between the batch axis and the last-two axes.") n = key.ndim # Constructing projection tensor. if self.redraw_features: # TODO(kchoro): Get rid of the constant below. query_seed = lax.convert_element_type(jnp.ceil(jnp.sum(query) * 10000000.0), jnp.int32) rng = random.PRNGKey(query_seed) self.projection_matrix = self.draw_weights(rng) # batch_dims is , num_heads> batch_dims = tuple(onp.delete(range(n), axis + (n - 1,))) # q & k -> (bs, , num_heads, , channels) qk_perm = batch_dims + axis + (n - 1,) k_extra_perm = axis + batch_dims + (n - 1,) key_extra = key.transpose(k_extra_perm) key = key.transpose(qk_perm) query = query.transpose(qk_perm) # v -> (bs, , num_heads, , channels) v_perm = batch_dims + axis + (n - 1,) value = value.transpose(v_perm) batch_dims_t = tuple(range(len(batch_dims))) attention_dims_t = tuple(range(len(batch_dims), len(batch_dims) + len(axis))) # Constructing tensors Q^{'} and K^{'}. query_prime = self.kernel_feature_creator( query, self.projection_matrix, attention_dims_t, batch_dims_t, precision, True ) key_prime = self.kernel_feature_creator( key, self.projection_matrix, attention_dims_t, batch_dims_t, precision, False ) if self.unidirectional: index = attention_dims_t[0] z_slice_shape = key_prime.shape[0 : len(batch_dims_t)] + (key_prime.shape[-1],) + (value.shape[-1],) numerator_fn = _numerator(z_slice_shape, precision, self.lax_scan_unroll) W = numerator_fn( jnp.moveaxis(query_prime, index, 0), jnp.moveaxis(key_prime, index, 0), jnp.moveaxis(value, index, 0) ) # Constructing W = (Q^{'}(K^{'})^{T})_{masked}V W = jnp.moveaxis(W, 0, index) if not self.renormalize_attention: # Unidirectional, not-normalized attention. perm_inv = _invert_perm(qk_perm) result = W.transpose(perm_inv) return result else: # Unidirectional, normalized attention. thick_all_ones = jnp.zeros(key.shape[0:-1]) + jnp.ones(key_extra.shape[0 : len(axis)]) index = attention_dims_t[0] t_slice_shape = key_prime.shape[0 : len(batch_dims_t)] + (key_prime.shape[-1],) denominator_fn = _denominator(t_slice_shape, precision, self.lax_scan_unroll) R = denominator_fn(jnp.moveaxis(query_prime, index, 0), jnp.moveaxis(key_prime, index, 0)) R = jnp.moveaxis(R, 0, index) else: contract_query = tuple(range(len(batch_dims) + len(axis), len(batch_dims) + len(axis) + 1)) contract_z = tuple(range(len(batch_dims), len(batch_dims) + 1)) # Constructing Z = (K^{'})^{T}V # Z (bs, , num_heads, channels_m, channels_v) Z = lax.dot_general( key_prime, value, ((attention_dims_t, attention_dims_t), (batch_dims_t, batch_dims_t)), precision=precision, ) # Constructing W = Q^{'}Z = Q^{'}(K^{'})^{T}V # q (bs, , num_heads, , channels_m) # Z (bs, , num_heads, channels_m, channels_v) # W (bs, , num_heads, , channels_v) W = lax.dot_general( query_prime, Z, ((contract_query, contract_z), (batch_dims_t, batch_dims_t)), precision=precision ) if not self.renormalize_attention: # Bidirectional, not-normalized attention. perm_inv = _invert_perm(qk_perm) result = W.transpose(perm_inv) return result else: # Bidirectional, normalized attention. thick_all_ones = jnp.zeros(key.shape[0:-1]) + jnp.ones(key_extra.shape[0 : len(axis)]) contract_key = tuple(range(len(batch_dims), len(batch_dims) + len(axis))) contract_thick_all_ones = tuple(range(thick_all_ones.ndim - len(axis), thick_all_ones.ndim)) # Construct T = (K^{'})^{T} 1_L # k (bs, , num_heads, , channels) T = lax.dot_general( key_prime, thick_all_ones, ((contract_key, contract_thick_all_ones), (batch_dims_t, batch_dims_t)), precision=precision, ) # Construct partition function: R = Q^{'} T = Q^{'}(K^{'})^{T} 1_L # q_p (bs, , num_heads, , channs_m) # T (bs, , num_heads, channels_m) R = lax.dot_general( query_prime, T, (((query_prime.ndim - 1,), (T.ndim - 1,)), (batch_dims_t, range(0, len(T.shape) - 1))), precision=precision, ) R = R + 2 * self.numerical_stabilizer * (jnp.abs(R) <= self.numerical_stabilizer) R = jnp.reciprocal(R) R = jnp.expand_dims(R, len(R.shape)) # W (bs, , num_heads, , channels_v) # R (bs, , num_heads, , extra_channel) result = W * R # back to (bs, dim1, dim2, ..., dimN, num_heads, channels) perm_inv = _invert_perm(qk_perm) result = result.transpose(perm_inv) return result def _invert_perm(perm): perm_inv = [0] * len(perm) for i, j in enumerate(perm): perm_inv[j] = i return tuple(perm_inv)