# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of multiheaded attention and self-attention layers."""

import math

import tensorflow as tf, tf_keras

from official.modeling import tf_utils


class Attention(tf_keras.layers.Layer):
  """Multi-headed attention layer."""

  def __init__(self, hidden_size, num_heads, attention_dropout):
    """Initialize Attention.

    Args:
      hidden_size: int, output dim of hidden layer.
      num_heads: int, number of heads to repeat the same attention structure.
      attention_dropout: float, dropout rate inside attention for training.
    """
    if hidden_size % num_heads:
      raise ValueError(
          "Hidden size ({}) must be divisible by the number of heads ({})."
          .format(hidden_size, num_heads))

    super(Attention, self).__init__()
    self.hidden_size = hidden_size
    self.num_heads = num_heads
    self.attention_dropout = attention_dropout

  def build(self, input_shape):
    """Builds the layer."""
    # Layers for linearly projecting the queries, keys, and values.
    size_per_head = self.hidden_size // self.num_heads

    def _glorot_initializer(fan_in, fan_out):
      limit = math.sqrt(6.0 / (fan_in + fan_out))
      return tf_keras.initializers.RandomUniform(minval=-limit, maxval=limit)

    attention_initializer = _glorot_initializer(input_shape.as_list()[-1],
                                                self.hidden_size)
    self.query_dense_layer = tf_keras.layers.EinsumDense(
        "BTE,ENH->BTNH",
        output_shape=(None, self.num_heads, size_per_head),
        kernel_initializer=tf_utils.clone_initializer(attention_initializer),
        bias_axes=None,
        name="query")
    self.key_dense_layer = tf_keras.layers.EinsumDense(
        "BTE,ENH->BTNH",
        output_shape=(None, self.num_heads, size_per_head),
        kernel_initializer=tf_utils.clone_initializer(attention_initializer),
        bias_axes=None,
        name="key")
    self.value_dense_layer = tf_keras.layers.EinsumDense(
        "BTE,ENH->BTNH",
        output_shape=(None, self.num_heads, size_per_head),
        kernel_initializer=tf_utils.clone_initializer(attention_initializer),
        bias_axes=None,
        name="value")

    output_initializer = _glorot_initializer(self.hidden_size,
                                             self.hidden_size)
    self.output_dense_layer = tf_keras.layers.EinsumDense(
        "BTNH,NHE->BTE",
        output_shape=(None, self.hidden_size),
        kernel_initializer=output_initializer,
        bias_axes=None,
        name="output_transform")
    super(Attention, self).build(input_shape)

  def get_config(self):
    return {
        "hidden_size": self.hidden_size,
        "num_heads": self.num_heads,
        "attention_dropout": self.attention_dropout,
    }

  def call(self,
           query_input,
           source_input,
           bias,
           training,
           cache=None,
           decode_loop_step=None):
    """Apply attention mechanism to query_input and source_input.

    Args:
      query_input: A tensor with shape [batch_size, length_query, hidden_size].
      source_input: A tensor with shape [batch_size, length_source,
        hidden_size].
      bias: A tensor with shape [batch_size, 1, length_query, length_source],
        the attention bias that will be added to the result of the dot product.
      training: A bool, whether in training mode or not.
      cache: (Used during prediction) A dictionary with tensors containing
        results of previous attentions. The dictionary must have the items:
            {"k": tensor with shape [batch_size, i, heads, dim_per_head],
             "v": tensor with shape [batch_size, i, heads, dim_per_head]}
        where i is the current decoded length for non-padded decode, or max
        sequence length for padded decode.
      decode_loop_step: An integer, step number of the decoding loop. Used only
        for autoregressive inference on TPU.

    Returns:
      Attention layer output with shape [batch_size, length_query, hidden_size]
    """
    # Linearly project the query, key and value using different learned
    # projections. Splitting heads is automatically done during the linear
    # projections --> [batch_size, length, num_heads, dim_per_head].
    query = self.query_dense_layer(query_input)
    key = self.key_dense_layer(source_input)
    value = self.value_dense_layer(source_input)

    if cache is not None:
      # Combine cached keys and values with new keys and values.
      if decode_loop_step is not None:
        cache_k_shape = cache["k"].shape.as_list()
        indices = tf.reshape(
            tf.one_hot(decode_loop_step, cache_k_shape[1], dtype=key.dtype),
            [1, cache_k_shape[1], 1, 1])
        key = cache["k"] + key * indices
        cache_v_shape = cache["v"].shape.as_list()
        indices = tf.reshape(
            tf.one_hot(decode_loop_step, cache_v_shape[1], dtype=value.dtype),
            [1, cache_v_shape[1], 1, 1])
        value = cache["v"] + value * indices
      else:
        key = tf.concat([tf.cast(cache["k"], key.dtype), key], axis=1)
        value = tf.concat([tf.cast(cache["v"], value.dtype), value], axis=1)

      # Update cache
      cache["k"] = key
      cache["v"] = value

    # Scale query to prevent the dot product between query and key from
    # growing too large.
    depth = (self.hidden_size // self.num_heads)
    query *= depth**-0.5

    # Calculate dot product attention
    logits = tf.einsum("BTNH,BFNH->BNFT", key, query)
    logits += bias
    # Note that softmax internally performs math operations using float32
    # for numeric stability. When training with float16, we keep the input
    # and output in float16 for better performance.
    weights = tf.nn.softmax(logits, name="attention_weights")

    if training:
      weights = tf.nn.dropout(weights, rate=self.attention_dropout)
    attention_output = tf.einsum("BNFT,BTNH->BFNH", weights, value)

    # Run the outputs through another linear projection layer. Recombining
    # heads is automatically done --> [batch_size, length, hidden_size].
    attention_output = self.output_dense_layer(attention_output)
    return attention_output


class SelfAttention(Attention):
  """Multiheaded self-attention layer."""

  def call(self,
           query_input,
           bias,
           training,
           cache=None,
           decode_loop_step=None):
    return super(SelfAttention, self).call(query_input, query_input, bias,
                                           training, cache, decode_loop_step)
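

# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the library API). It
# assumes eager execution and made-up example shapes; the zero-length "k"/"v"
# cache entries mirror the non-padded decode layout described in the
# `Attention.call` docstring. Guarded by __main__ so importing the module is
# unaffected.
if __name__ == "__main__":
  batch_size, length, hidden_size, num_heads = 2, 5, 16, 4
  layer = SelfAttention(
      hidden_size=hidden_size, num_heads=num_heads, attention_dropout=0.1)

  # Full-sequence self-attention: a zero bias masks nothing; a causal mask
  # would instead place large negative values at disallowed positions.
  x = tf.random.uniform([batch_size, length, hidden_size])
  bias = tf.zeros([batch_size, 1, length, length])
  out = layer(x, bias, training=False)
  print("self-attention output shape:", out.shape)  # (2, 5, 16)

  # Incremental (non-padded) decoding: start from empty cached keys/values and
  # feed one new token; `call` concatenates the new projections onto the cache.
  size_per_head = hidden_size // num_heads
  cache = {
      "k": tf.zeros([batch_size, 0, num_heads, size_per_head]),
      "v": tf.zeros([batch_size, 0, num_heads, size_per_head]),
  }
  step_input = tf.random.uniform([batch_size, 1, hidden_size])
  step_bias = tf.zeros([batch_size, 1, 1, 1])
  step_out = layer(step_input, step_bias, training=False, cache=cache)
  print("incremental output shape:", step_out.shape)  # (2, 1, 16)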