# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch

from .multihead_attention import MultiheadAttention


class SparseMultiheadAttention(MultiheadAttention):
    """Sparse Multi-Headed Attention.

    "Generating Long Sequences with Sparse Transformers". Implements
    fixed factorized self attention, where l=stride and c=expressivity.
    A(1) includes all words in the stride window and A(2) takes a summary of
    c words from the end of each stride window. If is_bidirectional=False,
    we do not include any words past the current word, as in the paper.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        self_attention=False,
        encoder_decoder_attention=False,
        stride=32,
        expressivity=8,
        is_bidirectional=True,
    ):
        super().__init__(
            embed_dim,
            num_heads,
            kdim,
            vdim,
            dropout,
            bias,
            add_bias_kv,
            add_zero_attn,
            self_attention,
            encoder_decoder_attention,
        )

        self.is_bidirectional = is_bidirectional
        self.stride = stride
        self.expressivity = expressivity
        assert self.stride > 0 and self.stride >= self.expressivity

    # Used for Ai(2) calculations - beginning of [l-c, l] range
    def compute_checkpoint(self, word_index):
        if word_index % self.stride == 0 and word_index != 0:
            checkpoint_index = word_index - self.expressivity
        else:
            checkpoint_index = (
                math.floor(word_index / self.stride) * self.stride
                + self.stride
                - self.expressivity
            )
        return checkpoint_index

    # Computes Ai(2)
    def compute_subset_summaries(self, absolute_max):
        checkpoint_index = self.compute_checkpoint(0)
        subset_two = set()
        while checkpoint_index <= absolute_max - 1:
            summary = set(
                range(
                    checkpoint_index,
                    min(checkpoint_index + self.expressivity + 1, absolute_max),
                )
            )
            subset_two = subset_two.union(summary)
            checkpoint_index = self.compute_checkpoint(checkpoint_index + self.stride)
        return subset_two

    # Sparse Transformer Fixed Attention Pattern: https://arxiv.org/pdf/1904.10509.pdf
    def compute_fixed_attention_subset(self, word_index, tgt_len):
        # +1s account for range function; [min, max) -> [min, max]
        if not self.is_bidirectional:
            absolute_max = word_index + 1
        else:
            absolute_max = tgt_len

        # Subset 1 - whole window
        rounded_index = (
            math.floor((word_index + self.stride) / self.stride) * self.stride
        )
        if word_index % self.stride == 0 and word_index != 0:
            subset_one = set(
                range(word_index - self.stride, min(absolute_max, word_index + 1))
            )
        else:
            subset_one = set(
                range(
                    max(0, rounded_index - self.stride),
                    min(absolute_max, rounded_index + 1),
                )
            )

        # Subset 2 - summary per window
        # If bidirectional, subset 2 is the same for every index
        subset_two = set()
        if not self.is_bidirectional:
            subset_two = self.compute_subset_summaries(absolute_max)

        return subset_one.union(subset_two)
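
    # Worked example (illustrative; values derived by tracing the methods
    # above, not part of the upstream module): with stride=4, expressivity=2
    # and is_bidirectional=False, word 5 may attend to {2, 3, 4, 5}. Subset
    # one contributes {4, 5}, the portion of word 5's stride window [4, 8)
    # clipped at the current word; subset two contributes the summary
    # {2, 3, 4} taken from the end of the previous stride window.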

    # Compute sparse mask - if bidirectional, can pre-compute and store
    def buffered_sparse_mask(self, tensor, tgt_len, src_len):
        assert tgt_len > self.stride
        sparse_mask = torch.empty((tgt_len, src_len)).float().fill_(float("-inf"))

        # If bidirectional, subset 2 is the same for every index
        subset_summaries = set()
        if self.is_bidirectional:
            subset_summaries = self.compute_subset_summaries(tgt_len)

        for i in range(tgt_len):
            fixed_attention_subset = self.compute_fixed_attention_subset(i, tgt_len)
            fixed_attention_subset = fixed_attention_subset.union(subset_summaries)
            included_word_indices = torch.LongTensor(list(fixed_attention_subset))
            # Zero out the allowed positions; everything else stays -inf
            sparse_mask[i].index_fill_(0, included_word_indices, 0)
        return sparse_mask.type_as(tensor)

    # Expand the mask across every batch/head row and add it to the raw
    # attention weights so excluded positions become -inf before softmax
    def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
        sparse_mask = self.buffered_sparse_mask(attn_weights, tgt_len, src_len)
        sparse_mask = sparse_mask.unsqueeze(0).expand(
            bsz * self.num_heads, tgt_len, src_len
        )
        attn_weights += sparse_mask
        # The in-place add already mutates the input, but the base class
        # reassigns the result of this call, so return the updated weights
        return attn_weights
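

# A minimal usage sketch, not part of the upstream module: it builds the
# fixed sparse mask for a toy sequence so the attention pattern can be
# inspected directly. The small stride/expressivity values are assumptions
# chosen only for readability, and the demo expects to be run as a module
# (e.g. python -m fairseq.modules.sparse_multihead_attention) so the relative
# import at the top of the file resolves.
if __name__ == "__main__":
    attn = SparseMultiheadAttention(
        embed_dim=16,
        num_heads=2,
        self_attention=True,
        stride=4,  # l in the paper
        expressivity=2,  # c in the paper
        is_bidirectional=False,
    )
    tgt_len = src_len = 8  # buffered_sparse_mask asserts tgt_len > stride
    mask = attn.buffered_sparse_mask(torch.empty(1), tgt_len, src_len)
    # Row i holds 0 at the source positions word i may attend to and -inf
    # elsewhere; the subsequent softmax zeroes out the -inf entries
    for i in range(tgt_len):
        allowed = (mask[i] == 0).nonzero(as_tuple=True)[0].tolist()
        print(f"word {i} attends to {allowed}")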