""" |
This code is adapted from AllenAI's Longformer:
    https://github.com/allenai/longformer/
"""
|
from typing import List |
|
import math |
|
import torch |
|
from torch import nn |
|
import torch.nn.functional as F |
|
from .diagonaled_mm_tvm import diagonaled_mm as diagonaled_mm_tvm, mask_invalid_locations |
|
from .sliding_chunks import sliding_chunks_matmul_qk, sliding_chunks_matmul_pv |
|
from .sliding_chunks import sliding_chunks_no_overlap_matmul_qk, sliding_chunks_no_overlap_matmul_pv |
|
from transformers.models.roberta.modeling_roberta import RobertaConfig, RobertaModel, RobertaForMaskedLM |
|
|
|
|
|
class Longformer(RobertaModel): |
|
def __init__(self, config): |
|
super(Longformer, self).__init__(config) |
|
if config.attention_mode == 'n2': |
|
            pass  # 'n2' keeps the default full (n^2) self-attention implementation
|
else: |
|
for i, layer in enumerate(self.encoder.layer): |
|
layer.attention.self = LongformerSelfAttention(config, layer_id=i) |
|
|
|
|
|
class LongformerForMaskedLM(RobertaForMaskedLM): |
|
def __init__(self, config): |
|
super(LongformerForMaskedLM, self).__init__(config) |
|
if config.attention_mode == 'n2': |
|
            pass  # 'n2' keeps the default full (n^2) self-attention implementation
|
else: |
|
for i, layer in enumerate(self.roberta.encoder.layer): |
|
layer.attention.self = LongformerSelfAttention(config, layer_id=i) |
|
|
|
|
|
class LongformerConfig(RobertaConfig): |
|
def __init__(self, attention_window: List[int] = None, attention_dilation: List[int] = None, |
|
autoregressive: bool = False, attention_mode: str = 'sliding_chunks', **kwargs): |
|
""" |
|
Args: |
|
attention_window: list of attention window sizes of length = number of layers. |
|
window size = number of attention locations on each side. |
|
                For an effective window size of 512, use `attention_window=[256]*num_layers`
|
which is 256 on each side. |
|
attention_dilation: list of attention dilation of length = number of layers. |
|
attention dilation of `1` means no dilation. |
|
            autoregressive: do autoregressive attention or have attention on both sides
|
            attention_mode: 'n2' for regular n^2 self-attention, 'tvm' for the TVM implementation of Longformer

                self-attention, 'sliding_chunks' or 'sliding_chunks_no_overlap' for the chunked implementations of Longformer self-attention
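
            Example (a minimal usage sketch; the 12-layer list lengths below are illustrative, not defaults):

                config = LongformerConfig(attention_window=[256] * 12, attention_dilation=[1] * 12,
                                          autoregressive=False, attention_mode='sliding_chunks')
                model = Longformer(config)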
|
""" |
|
super().__init__(**kwargs) |
|
self.attention_window = attention_window |
|
self.attention_dilation = attention_dilation |
|
self.autoregressive = autoregressive |
|
self.attention_mode = attention_mode |
|
assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2', 'sliding_chunks_no_overlap'] |
|
|
|
|
|
class LongformerSelfAttention(nn.Module): |
|
def __init__(self, config, layer_id): |
|
super(LongformerSelfAttention, self).__init__() |
|
if config.hidden_size % config.num_attention_heads != 0: |
|
raise ValueError( |
|
"The hidden size (%d) is not a multiple of the number of attention " |
|
"heads (%d)" % (config.hidden_size, config.num_attention_heads)) |
|
self.num_heads = config.num_attention_heads |
|
        self.head_dim = config.hidden_size // config.num_attention_heads
|
self.embed_dim = config.hidden_size |
|
|
|
self.query = nn.Linear(config.hidden_size, self.embed_dim) |
|
self.key = nn.Linear(config.hidden_size, self.embed_dim) |
|
self.value = nn.Linear(config.hidden_size, self.embed_dim) |
|
|
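        # separate projection matrices for the tokens that receive global attention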
|
self.query_global = nn.Linear(config.hidden_size, self.embed_dim) |
|
self.key_global = nn.Linear(config.hidden_size, self.embed_dim) |
|
self.value_global = nn.Linear(config.hidden_size, self.embed_dim) |
|
|
|
self.dropout = config.attention_probs_dropout_prob |
|
|
|
self.layer_id = layer_id |
|
self.attention_window = config.attention_window[self.layer_id] |
|
self.attention_dilation = config.attention_dilation[self.layer_id] |
|
self.attention_mode = config.attention_mode |
|
self.autoregressive = config.autoregressive |
|
assert self.attention_window > 0 |
|
assert self.attention_dilation > 0 |
|
assert self.attention_mode in ['tvm', 'sliding_chunks', 'sliding_chunks_no_overlap'] |
|
if self.attention_mode in ['sliding_chunks', 'sliding_chunks_no_overlap']: |
|
assert not self.autoregressive |
|
assert self.attention_dilation == 1 |
|
|
|
def forward( |
|
self, |
|
hidden_states, |
|
attention_mask=None, |
|
head_mask=None, |
|
encoder_hidden_states=None, |
|
encoder_attention_mask=None, |
|
output_attentions=False, |
|
): |
|
''' |
|
The `attention_mask` is changed in `BertModel.forward` from 0, 1, 2 to |
|
-ve: no attention |
|
0: local attention |
|
+ve: global attention |
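
        For illustration (a sketch, not part of this module): before that remapping, a user-level mask
        uses 0 for padding, 1 for local attention, and 2 for global attention, e.g. giving the first
        token global attention:

            attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)  # local attention everywhere
            attention_mask[:, 0] = 2                                            # global attention on the first token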
|
''' |
|
assert encoder_hidden_states is None, "`encoder_hidden_states` is not supported and should be None" |
|
assert encoder_attention_mask is None, "`encoder_attention_mask` is not supported and should be None" |
|
|
|
if attention_mask is not None: |
|
key_padding_mask = attention_mask < 0 |
|
extra_attention_mask = attention_mask > 0 |
|
remove_from_windowed_attention_mask = attention_mask != 0 |
|
|
|
num_extra_indices_per_batch = extra_attention_mask.long().sum(dim=1) |
|
max_num_extra_indices_per_batch = num_extra_indices_per_batch.max() |
|
if max_num_extra_indices_per_batch <= 0: |
|
extra_attention_mask = None |
|
else: |
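                # A batch row may have a different number of global-attention tokens, so gather the
                # global-attention entries into tensors padded to `max_num_extra_indices_per_batch`:
                #   1) `extra_attention_mask_nonzeros`: positions of the global-attention tokens,
                #   2) `selection_padding_mask_nonzeros`: the non-padding slots of the gathered tensor,
                #   3) `selection_padding_mask_zeros`: the padding slots of the gathered tensor.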
extra_attention_mask_nonzeros = extra_attention_mask.nonzero(as_tuple=True) |
|
zero_to_max_range = torch.arange(0, max_num_extra_indices_per_batch, |
|
device=num_extra_indices_per_batch.device) |
|
|
|
selection_padding_mask = zero_to_max_range < num_extra_indices_per_batch.unsqueeze(dim=-1) |
|
|
|
selection_padding_mask_nonzeros = selection_padding_mask.nonzero(as_tuple=True) |
|
|
|
selection_padding_mask_zeros = (selection_padding_mask == 0).nonzero(as_tuple=True) |
|
else: |
|
remove_from_windowed_attention_mask = None |
|
extra_attention_mask = None |
|
key_padding_mask = None |
|
|
|
hidden_states = hidden_states.transpose(0, 1) |
|
seq_len, bsz, embed_dim = hidden_states.size() |
|
assert embed_dim == self.embed_dim |
|
q = self.query(hidden_states) |
|
k = self.key(hidden_states) |
|
v = self.value(hidden_states) |
|
q /= math.sqrt(self.head_dim) |
|
|
|
q = q.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1) |
|
k = k.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1) |
|
|
|
if self.attention_mode == 'tvm': |
|
q = q.float().contiguous() |
|
k = k.float().contiguous() |
|
attn_weights = diagonaled_mm_tvm(q, k, self.attention_window, self.attention_dilation, False, 0, False) |
|
elif self.attention_mode == "sliding_chunks": |
|
attn_weights = sliding_chunks_matmul_qk(q, k, self.attention_window, padding_value=0) |
|
elif self.attention_mode == "sliding_chunks_no_overlap": |
|
attn_weights = sliding_chunks_no_overlap_matmul_qk(q, k, self.attention_window, padding_value=0) |
|
else: |
|
            raise ValueError(f"Unsupported attention mode: {self.attention_mode}")
|
mask_invalid_locations(attn_weights, self.attention_window, self.attention_dilation, False) |
|
if remove_from_windowed_attention_mask is not None: |
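            # reshape from (bsz, seq_len) to (bsz, seq_len, num_heads=1, head_dim=1); the singleton
            # head dimensions keep the masking matmul below cheap and let the result broadcast over heads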
|
|
|
|
|
remove_from_windowed_attention_mask = remove_from_windowed_attention_mask.unsqueeze(dim=-1).unsqueeze(dim=-1) |
|
|
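            # cast to the query's dtype and put -10000.0 wherever windowed attention must be removed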
|
float_mask = remove_from_windowed_attention_mask.type_as(q).masked_fill(remove_from_windowed_attention_mask, -10000.0) |
|
repeat_size = 1 if isinstance(self.attention_dilation, int) else len(self.attention_dilation) |
|
float_mask = float_mask.repeat(1, 1, repeat_size, 1) |
|
ones = float_mask.new_ones(size=float_mask.size()) |
|
|
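            # multiply a tensor of ones against the float mask with the same banded matmul used for
            # q x k, giving a diagonal mask that is 0 at allowed locations and -10000.0 at masked ones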
|
if self.attention_mode == 'tvm': |
|
d_mask = diagonaled_mm_tvm(ones, float_mask, self.attention_window, self.attention_dilation, False, 0, False) |
|
elif self.attention_mode == "sliding_chunks": |
|
d_mask = sliding_chunks_matmul_qk(ones, float_mask, self.attention_window, padding_value=0) |
|
elif self.attention_mode == "sliding_chunks_no_overlap": |
|
d_mask = sliding_chunks_no_overlap_matmul_qk(ones, float_mask, self.attention_window, padding_value=0) |
|
|
|
attn_weights += d_mask |
|
assert list(attn_weights.size())[:3] == [bsz, seq_len, self.num_heads] |
|
assert attn_weights.size(dim=3) in [self.attention_window * 2 + 1, self.attention_window * 3] |
|
|
|
|
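        # scores from every token to the tokens marked for global attention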
|
if extra_attention_mask is not None: |
|
selected_k = k.new_zeros(bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim) |
|
selected_k[selection_padding_mask_nonzeros] = k[extra_attention_mask_nonzeros] |
|
|
|
            # (bsz, seq_len, num_heads, max_num_extra_indices_per_batch)
            selected_attn_weights = torch.einsum('blhd,bshd->blhs', q, selected_k)
|
selected_attn_weights[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000 |
|
|
|
|
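            # prepend the global-attention scores: (bsz, seq_len, num_heads, max_num_extra_indices_per_batch + window_len)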
|
attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1) |
|
        attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32)  # softmax in fp32 for numerical stability
|
if key_padding_mask is not None: |
|
|
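            # softmax can produce NaN for rows where every position is masked; zero those weights out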
|
attn_weights_float = torch.masked_fill(attn_weights_float, key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0.0) |
|
attn_weights = attn_weights_float.type_as(attn_weights) |
|
attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training) |
|
v = v.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1) |
|
attn = 0 |
|
if extra_attention_mask is not None: |
|
selected_attn_probs = attn_probs.narrow(-1, 0, max_num_extra_indices_per_batch) |
|
selected_v = v.new_zeros(bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim) |
|
selected_v[selection_padding_mask_nonzeros] = v[extra_attention_mask_nonzeros] |
|
|
|
|
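            # equivalent to einsum('blhs,bshd->blhd'); written with matmul for better fp16 support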
|
attn = torch.matmul(selected_attn_probs.transpose(1, 2), selected_v.transpose(1, 2).type_as(selected_attn_probs)).transpose(1, 2) |
|
attn_probs = attn_probs.narrow(-1, max_num_extra_indices_per_batch, attn_probs.size(-1) - max_num_extra_indices_per_batch).contiguous() |
|
|
|
if self.attention_mode == 'tvm': |
|
v = v.float().contiguous() |
|
attn += diagonaled_mm_tvm(attn_probs, v, self.attention_window, self.attention_dilation, True, 0, False) |
|
elif self.attention_mode == "sliding_chunks": |
|
attn += sliding_chunks_matmul_pv(attn_probs, v, self.attention_window) |
|
elif self.attention_mode == "sliding_chunks_no_overlap": |
|
attn += sliding_chunks_no_overlap_matmul_pv(attn_probs, v, self.attention_window) |
|
else: |
|
            raise ValueError(f"Unsupported attention mode: {self.attention_mode}")
|
|
|
attn = attn.type_as(hidden_states) |
|
assert list(attn.size()) == [bsz, seq_len, self.num_heads, self.head_dim] |
|
attn = attn.transpose(0, 1).reshape(seq_len, bsz, embed_dim).contiguous() |
|
|
|
|
|
|
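        # for tokens with global attention, recompute full attention with the dedicated global
        # projections and overwrite the corresponding rows of `attn`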
|
if extra_attention_mask is not None: |
|
selected_hidden_states = hidden_states.new_zeros(max_num_extra_indices_per_batch, bsz, embed_dim) |
|
selected_hidden_states[selection_padding_mask_nonzeros[::-1]] = hidden_states[extra_attention_mask_nonzeros[::-1]] |
|
|
|
q = self.query_global(selected_hidden_states) |
|
k = self.key_global(hidden_states) |
|
v = self.value_global(hidden_states) |
|
q /= math.sqrt(self.head_dim) |
|
|
|
q = q.contiguous().view(max_num_extra_indices_per_batch, bsz * self.num_heads, self.head_dim).transpose(0, 1) |
|
k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) |
|
v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) |
|
attn_weights = torch.bmm(q, k.transpose(1, 2)) |
|
assert list(attn_weights.size()) == [bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len] |
|
|
|
attn_weights = attn_weights.view(bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len) |
|
attn_weights[selection_padding_mask_zeros[0], :, selection_padding_mask_zeros[1], :] = -10000.0 |
|
if key_padding_mask is not None: |
|
attn_weights = attn_weights.masked_fill( |
|
key_padding_mask.unsqueeze(1).unsqueeze(2), |
|
-10000.0, |
|
) |
|
attn_weights = attn_weights.view(bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len) |
|
attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32) |
|
attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training) |
|
selected_attn = torch.bmm(attn_probs, v) |
|
assert list(selected_attn.size()) == [bsz * self.num_heads, max_num_extra_indices_per_batch, self.head_dim] |
|
|
|
selected_attn_4d = selected_attn.view(bsz, self.num_heads, max_num_extra_indices_per_batch, self.head_dim) |
|
nonzero_selected_attn = selected_attn_4d[selection_padding_mask_nonzeros[0], :, selection_padding_mask_nonzeros[1]] |
|
attn[extra_attention_mask_nonzeros[::-1]] = nonzero_selected_attn.view(len(selection_padding_mask_nonzeros[0]), -1).type_as(hidden_states) |
|
|
|
context_layer = attn.transpose(0, 1) |
|
if output_attentions: |
|
if extra_attention_mask is not None: |
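                # with global attention, return only the attention weights of the global-attention
                # tokens, shaped (bsz, num_heads, max_num_extra_indices_per_batch, seq_len): attention
                # from each global token to every token. When the number of global tokens varies across
                # the batch, the extra rows are padded with -10000.0 scores. Local attention weights are
                # not returned in this case.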
attn_weights = attn_weights.view(bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len) |
|
else: |
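                # without global attention, return the local attention weights, permuted to
                # (bsz, num_heads, seq_len, window_len): each token's attention to its neighbours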
attn_weights = attn_weights.permute(0, 2, 1, 3) |
|
outputs = (context_layer, attn_weights) if output_attentions else (context_layer,) |
|
return outputs |
|
|