# Source: OFA-Image_Caption/fairseq/fairseq/modules/sparse_multihead_attention.py
# Last commit: "update" (8437114) by JustinLin610; scan: no virus; size: 4.93 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
from .multihead_attention import MultiheadAttention
class SparseMultiheadAttention(MultiheadAttention):
    """Sparse Multi-Headed Attention.

    "Generating Long Sequences with Sparse Transformers". Implements
    fixed factorized self attention, where l=stride and c=expressivity.
    A(1) includes all words in the stride window and A(2) takes a summary of c
    words from the end of each stride window.
    If is_bidirectional=False, we do not include any words past the current word,
    as in the paper.

    Args:
        stride: window length ``l``; must be positive.
        expressivity: summary width ``c``; must not exceed ``stride``.
        is_bidirectional: when False, positions after the current word are
            excluded from both subsets.

    All other constructor arguments are forwarded positionally, unchanged, to
    ``MultiheadAttention.__init__``.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        self_attention=False,
        encoder_decoder_attention=False,
        stride=32,
        expressivity=8,
        is_bidirectional=True,
    ):
        super().__init__(
            embed_dim,
            num_heads,
            kdim,
            vdim,
            dropout,
            bias,
            add_bias_kv,
            add_zero_attn,
            self_attention,
            encoder_decoder_attention,
        )

        self.is_bidirectional = is_bidirectional
        self.stride = stride
        self.expressivity = expressivity
        assert self.stride > 0 and self.stride >= self.expressivity

    # Used for Ai(2) calculations - beginning of [l-c, l] range
    def compute_checkpoint(self, word_index):
        """Return the first index of the summary range [l-c, l] for ``word_index``.

        For a non-zero multiple of the stride, ``l`` is ``word_index`` itself.
        For every other index, ``l`` is the next multiple of the stride above
        ``word_index``.
        """
        if word_index % self.stride == 0 and word_index != 0:
            return word_index - self.expressivity
        window_end = (word_index // self.stride) * self.stride + self.stride
        return window_end - self.expressivity

    # Computes Ai(2)
    def compute_subset_summaries(self, absolute_max):
        """Return the set of summary indices (Ai(2)) below ``absolute_max``.

        Each stride window contributes the inclusive range [l-c, l], clipped so
        that no index reaches ``absolute_max``.
        """
        subset_two = set()
        checkpoint_index = self.compute_checkpoint(0)
        while checkpoint_index <= absolute_max - 1:
            # "+ 1" makes the half-open range include l itself.
            summary_end = min(
                checkpoint_index + self.expressivity + 1, absolute_max
            )
            subset_two.update(range(checkpoint_index, summary_end))
            # Successive range starts are exactly one stride apart. Stepping
            # directly gives the same value as
            # compute_checkpoint(checkpoint_index + stride) whenever
            # 0 <= expressivity < stride. When expressivity == stride that
            # call returns the current checkpoint again, so the loop would
            # never terminate.
            checkpoint_index += self.stride
        return subset_two

    # Sparse Transformer Fixed Attention Pattern: https://arxiv.org/pdf/1904.10509.pdf
    def compute_fixed_attention_subset(self, word_index, tgt_len):
        """Return the set of positions that ``word_index`` may attend to.

        In bidirectional mode only subset 1 (the stride window) is returned;
        the shared subset 2 is added by ``buffered_sparse_mask``. In
        unidirectional mode both subsets are returned, limited to positions
        up to and including ``word_index``.
        """
        # +1s account for range function; [min, max) -> [min, max]
        if self.is_bidirectional:
            absolute_max = tgt_len
        else:
            absolute_max = word_index + 1

        # Subset 1 - whole window
        if word_index % self.stride == 0 and word_index != 0:
            # A non-zero multiple of the stride closes the preceding window.
            window_start = word_index - self.stride
            window_end = word_index
        else:
            window_end = ((word_index + self.stride) // self.stride) * self.stride
            window_start = max(0, window_end - self.stride)
        subset_one = set(range(window_start, min(absolute_max, window_end + 1)))

        # Subset 2 - summary per window
        # If bidirectional, subset 2 is the same for every index
        if self.is_bidirectional:
            return subset_one
        return subset_one | self.compute_subset_summaries(absolute_max)

    # Compute sparse mask - if bidirectional, can pre-compute and store
    def buffered_sparse_mask(self, tensor, tgt_len, src_len):
        """Build the additive sparse attention mask.

        Returns a ``(tgt_len, src_len)`` tensor holding 0 at attended positions
        and ``-inf`` everywhere else, converted with ``type_as(tensor)``.
        Despite the name, nothing is cached: the mask is rebuilt on every call.

        Raises:
            AssertionError: if ``tgt_len`` is not greater than the stride.
        """
        assert tgt_len > self.stride
        sparse_mask = torch.full((tgt_len, src_len), float("-inf"))

        # If bidirectional, subset 2 is the same for every index
        subset_summaries = set()
        if self.is_bidirectional:
            subset_summaries = self.compute_subset_summaries(tgt_len)

        for i in range(tgt_len):
            attended = self.compute_fixed_attention_subset(i, tgt_len)
            attended |= subset_summaries
            included_word_indices = torch.tensor(sorted(attended), dtype=torch.long)
            sparse_mask[i].index_fill_(0, included_word_indices, 0)
        return sparse_mask.type_as(tensor)

    def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
        """Add the sparse mask to ``attn_weights`` in place and return it.

        The mask is broadcast to ``(bsz * self.num_heads, tgt_len, src_len)``
        before the in-place addition, so ``attn_weights`` must have that shape.
        The tensor is returned as well as mutated so that a caller which
        rebinds the result does not end up with ``None``.
        """
        sparse_mask = self.buffered_sparse_mask(attn_weights, tgt_len, src_len)
        sparse_mask = sparse_mask.unsqueeze(0).expand(
            bsz * self.num_heads, tgt_len, src_len
        )
        attn_weights += sparse_mask
        return attn_weights