# coding=utf-8
# Copyright 2020 The Allen Institute for AI team and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Longformer model. """
import math
import warnings
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss
from torch.nn import functional as F
from .activations import ACT2FN, gelu
from .configuration_longformer import LongformerConfig
from .file_utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_callable,
replace_return_docstrings,
)
from .modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
MaskedLMOutput,
MultipleChoiceModelOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from .modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from .utils import logging
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LongformerConfig"
_TOKENIZER_FOR_DOC = "LongformerTokenizer"
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"allenai/longformer-base-4096",
"allenai/longformer-large-4096",
"allenai/longformer-large-4096-finetuned-triviaqa",
"allenai/longformer-base-4096-extra.pos.embd.only",
"allenai/longformer-large-4096-extra.pos.embd.only",
# See all Longformer models at https://huggingface.co/models?filter=longformer
]
def _get_question_end_index(input_ids, sep_token_id):
"""
Computes the index of the first occurance of `sep_token_id`.
"""
sep_token_indices = (input_ids == sep_token_id).nonzero()
batch_size = input_ids.shape[0]
assert sep_token_indices.shape[1] == 2, "`input_ids` should have two dimensions"
assert (
sep_token_indices.shape[0] == 3 * batch_size
), f"There should be exactly three separator tokens: {sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this error."
return sep_token_indices.view(batch_size, 3, 2)[:, 0, 1]
def _compute_global_attention_mask(input_ids, sep_token_id, before_sep_token=True):
    """
    Build a `torch.uint8` mask with 1 at the positions that should receive global attention.

    If `before_sep_token is True`, the mask selects every position strictly before the first
    separator of each sample. Otherwise it selects every position whose index is greater than
    the first separator index plus one.

    Args:
        input_ids: 2-D tensor of token ids.
        sep_token_id: id of the separator token.
        before_sep_token: selects which side of the separator is marked; only the literal
            `True` picks the "before" side.

    Returns:
        `torch.uint8` tensor with the same shape as `input_ids`.
    """
    first_sep = _get_question_end_index(input_ids, sep_token_id).unsqueeze(dim=1)  # size: batch_size x 1
    # column index of every token, broadcast to the shape of `input_ids`
    positions = torch.arange(input_ids.shape[1], device=input_ids.device).expand_as(input_ids)

    if before_sep_token is True:
        return (positions < first_sep).to(torch.uint8)

    # skip the first separator and the position right after it (two separators sit in the middle)
    after_separators = (positions > (first_sep + 1)).to(torch.uint8)
    # every position index is below the sequence length, so this factor is all ones
    within_sequence = (positions < input_ids.shape[-1]).to(torch.uint8)
    return after_separators * within_sequence
# Adapted from transformers.modeling_roberta.create_position_ids_from_input_ids
def create_position_ids_from_input_ids(input_ids, padding_idx):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at `padding_idx + 1`; padding symbols keep the value `padding_idx`.
    This is modified from fairseq's `utils.make_positions`.

    :param torch.Tensor input_ids: token ids; the cumulative count runs along dimension 1.
    :param int padding_idx: id of the padding token.
    :return torch.Tensor: `long` tensor of position ids with the same shape as `input_ids`.
    """
    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
    not_padding = input_ids.ne(padding_idx).int()
    # running count of real tokens seen so far in each row
    running_count = torch.cumsum(not_padding, dim=1).type_as(not_padding)
    # zero the count at padded slots, then shift everything by `padding_idx`
    return (running_count * not_padding).long() + padding_idx
class LongformerEmbeddings(nn.Module):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.

    Position ids are derived from the input so that they start at `padding_idx + 1` and padded
    tokens map to `padding_idx`, whose position embedding is the `padding_idx` row of the table.
    """

    def __init__(self, config):
        """
        Args:
            config: object providing `vocab_size`, `hidden_size`, `pad_token_id`,
                `max_position_embeddings`, `type_vocab_size`, `layer_norm_eps` and
                `hidden_dropout_prob`.
        """
        super().__init__()
        self.padding_idx = config.pad_token_id

        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        # Built once, directly with `padding_idx`, instead of allocating a plain table first and
        # immediately replacing it. The submodule keeps its place in the registration order.
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
        )
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
        """
        Sum word, position and token-type embeddings, then apply LayerNorm and dropout.

        Args:
            input_ids: token ids; used to look up word embeddings when `inputs_embeds` is None and to
                derive `position_ids` when those are not given.
            token_type_ids: optional segment ids; defaults to all zeros.
            position_ids: optional explicit position ids.
            inputs_embeds: optional pre-computed word embeddings, used instead of `input_ids`.

        Returns:
            Tensor of embeddings with a trailing `hidden_size` dimension.
        """
        if position_ids is None:
            if input_ids is not None:
                # Create the position ids from the input token ids. Any padded tokens remain padded.
                position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx).to(input_ids.device)
            else:
                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)

        # `position_ids` is always set at this point, so no fallback to the `position_ids` buffer is needed.
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        """We are provided embeddings directly. We cannot infer which are padded so just generate
        sequential position ids.

        :param torch.Tensor inputs_embeds: embeddings whose last dimension is the hidden size and
            whose dimension 1 is the sequence length.
        :return torch.Tensor: `long` position ids starting at `padding_idx + 1`, expanded to the shape
            of `inputs_embeds` without its last dimension.
        """
        input_shape = inputs_embeds.size()[:-1]
        sequence_length = input_shape[1]

        position_ids = torch.arange(
            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
        )
        return position_ids.unsqueeze(0).expand(input_shape)
class LongformerSelfAttention(nn.Module):
    """
    Longformer self-attention.

    Every token attends to a sliding window of `one_sided_attn_window_size` neighbours on each side.
    Tokens flagged as "global" in the attention mask are additionally used as keys for every query
    and have their own output recomputed with dedicated `*_global` projections over the whole sequence.
    """

    def __init__(self, config, layer_id):
        """
        Args:
            config: object providing `hidden_size`, `num_attention_heads`,
                `attention_probs_dropout_prob` and `attention_window` (indexable by layer id).
            layer_id: index of this layer; selects its entry of `config.attention_window`.

        Raises:
            ValueError: if `hidden_size` is not a multiple of `num_attention_heads`.
            AssertionError: if this layer's attention window is not a positive even number.
        """
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )
        self.num_heads = config.num_attention_heads
        self.head_dim = int(config.hidden_size / config.num_attention_heads)
        self.embed_dim = config.hidden_size

        self.query = nn.Linear(config.hidden_size, self.embed_dim)
        self.key = nn.Linear(config.hidden_size, self.embed_dim)
        self.value = nn.Linear(config.hidden_size, self.embed_dim)

        # separate projection layers for tokens with global attention
        self.query_global = nn.Linear(config.hidden_size, self.embed_dim)
        self.key_global = nn.Linear(config.hidden_size, self.embed_dim)
        self.value_global = nn.Linear(config.hidden_size, self.embed_dim)

        # dropout probability (a float, applied through `F.dropout`)
        self.dropout = config.attention_probs_dropout_prob

        self.layer_id = layer_id
        attention_window = config.attention_window[self.layer_id]
        assert (
            attention_window % 2 == 0
        ), f"`attention_window` for layer {self.layer_id} has to be an even value. Given {attention_window}"
        assert (
            attention_window > 0
        ), f"`attention_window` for layer {self.layer_id} has to be positive. Given {attention_window}"

        self.one_sided_attn_window_size = attention_window // 2

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        output_attentions=False,
    ):
        """
        LongformerSelfAttention expects the sequence length of `hidden_states` to be a multiple of
        `attention_window`. Padding to `attention_window` happens in LongformerModel.forward to avoid
        redoing the padding on each layer.

        The `attention_mask` is expected to have already been converted by the caller from 0, 1, 2 to
            -ve: no attention
              0: local attention
            +ve: global attention

        Args:
            hidden_states: tensor of shape (batch_size, seq_len, embed_dim).
            attention_mask: tensor that reduces to (batch_size, seq_len) once dimensions 2 and 1 are
                squeezed, i.e. expected as (batch_size, 1, 1, seq_len). Required despite the default.
            output_attentions: whether to also return attention probabilities.

        Returns:
            `(attn_output,)` or `(attn_output, attn_probs)`, where `attn_output` has shape
            (batch_size, seq_len, embed_dim). Without global attention `attn_probs` holds the local
            probabilities, shape (batch_size, num_heads, seq_len, 2 * one_sided_window + 1). With
            global attention it holds the probabilities of the global tokens over the full sequence,
            shape (batch_size, num_heads, max_num_global_attn_indices, seq_len).
        """
        attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1)

        # is index masked or global attention
        is_index_masked = attention_mask < 0
        is_index_global_attn = attention_mask > 0
        is_global_attn = is_index_global_attn.flatten().any().item()

        # work in (seq_len, batch_size, embed_dim) layout from here on
        hidden_states = hidden_states.transpose(0, 1)

        # project hidden states
        query_vectors = self.query(hidden_states)
        key_vectors = self.key(hidden_states)
        value_vectors = self.value(hidden_states)

        seq_len, batch_size, embed_dim = hidden_states.size()
        assert (
            embed_dim == self.embed_dim
        ), f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}"

        # normalize query
        query_vectors /= math.sqrt(self.head_dim)

        query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
        key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)

        # attn_scores = (batch_size, seq_len, num_heads, window*2+1)
        attn_scores = self._sliding_chunks_query_key_matmul(
            query_vectors, key_vectors, self.one_sided_attn_window_size
        )

        # values to pad for attention probs: masked AND global positions are removed from the windows
        remove_from_windowed_attention_mask = (attention_mask != 0)[:, :, None, None]

        # cast to the dtype of the queries, then replace the 1's with a large negative score
        float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
            remove_from_windowed_attention_mask, -10000.0
        )
        # diagonal mask with zeros everywhere and -10000.0 in place of removed positions
        # (multiplying an all-ones "query" with the mask lays the mask out in the windowed format)
        diagonal_mask = self._sliding_chunks_query_key_matmul(
            float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size
        )

        # pad local attention probs
        attn_scores += diagonal_mask

        assert list(attn_scores.size()) == [
            batch_size,
            seq_len,
            self.num_heads,
            self.one_sided_attn_window_size * 2 + 1,
        ], f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_scores.size()}"

        # compute local attention probs from global attention keys and concat over window dim
        if is_global_attn:
            # compute global attn indices required through out forward fn
            (
                max_num_global_attn_indices,
                is_index_global_attn_nonzero,
                is_local_index_global_attn_nonzero,
                is_local_index_no_global_attn_nonzero,
            ) = self._get_global_attn_indices(is_index_global_attn)
            # calculate global attn probs from global key
            global_key_attn_scores = self._concat_with_global_key_attn_probs(
                query_vectors=query_vectors,
                key_vectors=key_vectors,
                max_num_global_attn_indices=max_num_global_attn_indices,
                is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
                is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
            )
            # concat to attn_probs
            # (batch_size, seq_len, num_heads, extra attention count + 2*window+1)
            attn_scores = torch.cat((global_key_attn_scores, attn_scores), dim=-1)

            # free memory
            del global_key_attn_scores

        attn_probs_fp32 = F.softmax(attn_scores, dim=-1, dtype=torch.float32)  # use fp32 for numerical stability
        attn_probs = attn_probs_fp32.type_as(attn_scores)

        # free memory
        del attn_probs_fp32

        # softmax sometimes inserts NaN if all positions are masked, replace them with 0
        attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)

        # apply dropout
        attn_probs = F.dropout(attn_probs, p=self.dropout, training=self.training)

        value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)

        # compute local attention output with global attention value and add
        if is_global_attn:
            # compute sum of global and local attn
            attn_output = self._compute_attn_output_with_global_indices(
                value_vectors=value_vectors,
                attn_probs=attn_probs,
                max_num_global_attn_indices=max_num_global_attn_indices,
                is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
            )
        else:
            # compute local attn only
            attn_output = self._sliding_chunks_matmul_attn_probs_value(
                attn_probs, value_vectors, self.one_sided_attn_window_size
            )

        assert attn_output.size() == (batch_size, seq_len, self.num_heads, self.head_dim), "Unexpected size"
        attn_output = attn_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous()

        # compute value for global attention and overwrite to attention output
        # TODO: remove the redundant computation
        global_attn_probs = None
        if is_global_attn:
            global_attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden(
                hidden_states=hidden_states,
                max_num_global_attn_indices=max_num_global_attn_indices,
                is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
                is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
                is_index_masked=is_index_masked,
                return_attn_probs=True,
            )

            # get only non zero global attn output
            nonzero_global_attn_output = global_attn_output[
                is_local_index_global_attn_nonzero[0], :, is_local_index_global_attn_nonzero[1]
            ]

            # overwrite values with global attention
            # (`attn_output` is (seq_len, batch_size, ...), hence the reversed (batch, seq) index tuple)
            attn_output[is_index_global_attn_nonzero[::-1]] = nonzero_global_attn_output.view(
                len(is_local_index_global_attn_nonzero[0]), -1
            )

        attn_output = attn_output.transpose(0, 1)

        if output_attentions:
            if is_global_attn:
                # With global attention, return global attention probabilities only
                # batch_size x num_heads x max_num_global_attention_tokens x sequence_length
                # which is the attention weights from tokens with global attention to all tokens
                # It does not return local attention
                # In case of variable number of global attention in the rows of a batch,
                # the rows of the unused slots come from scores filled with -10000.0
                #
                # The local `attn_probs` has (num global + 2 * window + 1) entries per query and can
                # never be viewed as this shape, so the probabilities computed by the global attention
                # helper are returned here.
                attn_probs = global_attn_probs.view(
                    batch_size, self.num_heads, max_num_global_attn_indices, seq_len
                )
            else:
                # without global attention, return local attention probabilities
                # batch_size x num_heads x sequence_length x window_size
                # which is the attention weights of every token attending to its neighbours
                attn_probs = attn_probs.permute(0, 2, 1, 3)

        outputs = (attn_output, attn_probs) if output_attentions else (attn_output,)
        return outputs

    @staticmethod
    def _pad_and_transpose_last_two_dims(hidden_states_padded, padding):
        """pads rows and then flips rows and columns

        The "flip" is a `view` that swaps the sizes of the last two dimensions without moving data,
        which is what shifts each row relative to the previous one.
        """
        hidden_states_padded = F.pad(
            hidden_states_padded, padding
        )  # padding value is not important because it will be overwritten
        hidden_states_padded = hidden_states_padded.view(
            *hidden_states_padded.size()[:-2], hidden_states_padded.size(-1), hidden_states_padded.size(-2)
        )
        return hidden_states_padded

    @staticmethod
    def _pad_and_diagonalize(chunked_hidden_states):
        """shift every row 1 step right, converting columns into diagonals.
        Example:
              chunked_hidden_states: [ 0.4983,  2.6918, -0.0071,  1.0492,
                                       -1.8348,  0.7672,  0.2986,  0.0285,
                                       -0.7584,  0.4206, -0.0405,  0.1599,
                                       2.0514, -1.1600,  0.5372,  0.2629 ]
              window_overlap = num_rows = 4
             (pad & diagonalize) =>
             [ 0.4983,  2.6918, -0.0071,  1.0492, 0.0000,  0.0000,  0.0000
               0.0000,  -1.8348,  0.7672,  0.2986,  0.0285, 0.0000,  0.0000
               0.0000,  0.0000, -0.7584,  0.4206, -0.0405,  0.1599, 0.0000
               0.0000,  0.0000,  0.0000, 2.0514, -1.1600,  0.5372,  0.2629 ]
        """
        total_num_heads, num_chunks, window_overlap, hidden_dim = chunked_hidden_states.size()
        chunked_hidden_states = F.pad(
            chunked_hidden_states, (0, window_overlap + 1)
        )  # total_num_heads x num_chunks x window_overlap x (hidden_dim + window_overlap + 1). Padding value is not important because it'll be overwritten
        chunked_hidden_states = chunked_hidden_states.view(
            total_num_heads, num_chunks, -1
        )  # total_num_heads x num_chunks x window_overlap * (hidden_dim + window_overlap + 1)
        chunked_hidden_states = chunked_hidden_states[
            :, :, :-window_overlap
        ]  # total_num_heads x num_chunks x window_overlap * (hidden_dim + window_overlap)
        chunked_hidden_states = chunked_hidden_states.view(
            total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim
        )  # total_num_heads x num_chunks x window_overlap x (hidden_dim + window_overlap)
        chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
        return chunked_hidden_states

    @staticmethod
    def _chunk(hidden_states, window_overlap):
        """convert into overlapping chunks. Chunk size = 2w, overlap size = w

        `hidden_states` is 3-D with the sequence on dimension 1; the result has
        `2 * (seq_len // 2w) - 1` chunks on dimension 1.
        """

        # non-overlapping chunks of size = 2w
        hidden_states = hidden_states.view(
            hidden_states.size(0),
            hidden_states.size(1) // (window_overlap * 2),
            window_overlap * 2,
            hidden_states.size(2),
        )

        # use `as_strided` to make the chunks overlap with an overlap size = window_overlap
        chunk_size = list(hidden_states.size())
        chunk_size[1] = chunk_size[1] * 2 - 1

        chunk_stride = list(hidden_states.stride())
        chunk_stride[1] = chunk_stride[1] // 2
        return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)

    @staticmethod
    def _mask_invalid_locations(input_tensor, affected_seq_len) -> None:
        """
        Fill, in place, the window slots that fall outside the sequence with `-inf`.

        The first `affected_seq_len` positions have no full left window and the last
        `affected_seq_len` positions have no full right window. `input_tensor` is indexed as
        (batch, seq_len, heads, window) and is modified in place; nothing is returned.
        """
        beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
        beginning_mask = beginning_mask_2d[None, :, None, :]
        ending_mask = beginning_mask.flip(dims=(1, 3))
        beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
        beginning_mask = beginning_mask.expand(beginning_input.size())
        beginning_input.masked_fill_(beginning_mask == 1, -float("inf"))  # `== 1` converts to bool or uint8
        ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
        ending_mask = ending_mask.expand(ending_input.size())
        ending_input.masked_fill_(ending_mask == 1, -float("inf"))  # `== 1` converts to bool or uint8

    def _sliding_chunks_query_key_matmul(self, query: torch.Tensor, key: torch.Tensor, window_overlap: int):
        """Matrix multiplication of query and key tensors using a sliding window attention pattern.
        This implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer)
        with an overlap of size window_overlap.

        `query` and `key` are (batch_size, seq_len, num_heads, head_dim); the result is
        (batch_size, seq_len, num_heads, 2 * window_overlap + 1) with out-of-sequence slots set to `-inf`.
        """
        batch_size, seq_len, num_heads, head_dim = query.size()
        assert (
            seq_len % (window_overlap * 2) == 0
        ), f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}"
        assert query.size() == key.size()

        chunks_count = seq_len // window_overlap - 1

        # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2
        query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
        key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)

        chunked_query = self._chunk(query, window_overlap)
        chunked_key = self._chunk(key, window_overlap)

        # matrix multiplication
        # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim
        # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim
        # bcxy: batch_size * num_heads x chunks x 2window_overlap x 2window_overlap
        chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (chunked_query, chunked_key))  # multiply

        # convert diagonals into columns
        diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(
            chunked_attention_scores, padding=(0, 0, 0, 1)
        )

        # allocate space for the overall attention matrix where the chunks are combined. The last dimension
        # has (window_overlap * 2 + 1) columns. The first (window_overlap) columns are the window_overlap lower triangles (attention from a word to
        # window_overlap previous words). The following column is attention score from each word to itself, then
        # followed by window_overlap columns for the upper triangle.

        diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty(
            (batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap * 2 + 1)
        )

        # copy parts from diagonal_chunked_attention_scores into the combined matrix of attentions
        # - copying the main diagonal and the upper triangle
        diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[
            :, :, :window_overlap, : window_overlap + 1
        ]
        diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[
            :, -1, window_overlap:, : window_overlap + 1
        ]
        # - copying the lower triangle
        diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[
            :, :, -(window_overlap + 1) : -1, window_overlap + 1 :
        ]

        diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[
            :, 0, : window_overlap - 1, 1 - window_overlap :
        ]

        # separate batch_size and num_heads dimensions again
        diagonal_attention_scores = diagonal_attention_scores.view(
            batch_size, num_heads, seq_len, 2 * window_overlap + 1
        ).transpose(2, 1)

        # slots never written above (and slots outside the sequence) are overwritten with -inf here
        self._mask_invalid_locations(diagonal_attention_scores, window_overlap)
        return diagonal_attention_scores

    def _sliding_chunks_matmul_attn_probs_value(
        self, attn_probs: torch.Tensor, value: torch.Tensor, window_overlap: int
    ):
        """Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors.
        Returned tensor will be of the same shape as `value`:
        (batch_size, seq_len, num_heads, head_dim)."""
        batch_size, seq_len, num_heads, head_dim = value.size()

        assert seq_len % (window_overlap * 2) == 0
        assert attn_probs.size()[:3] == value.size()[:3]
        assert attn_probs.size(3) == 2 * window_overlap + 1
        chunks_count = seq_len // window_overlap - 1
        # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window overlap
        chunked_attn_probs = attn_probs.transpose(1, 2).reshape(
            batch_size * num_heads, seq_len // window_overlap, window_overlap, 2 * window_overlap + 1
        )

        # group batch_size and num_heads dimensions into one
        value = value.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)

        # pad seq_len with w at the beginning of the sequence and another window overlap at the end
        padded_value = F.pad(value, (0, 0, window_overlap, window_overlap), value=-1)

        # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap
        chunked_value_size = (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim)
        chunked_value_stride = padded_value.stride()
        chunked_value_stride = (
            chunked_value_stride[0],
            window_overlap * chunked_value_stride[1],
            chunked_value_stride[1],
            chunked_value_stride[2],
        )
        chunked_value = padded_value.as_strided(size=chunked_value_size, stride=chunked_value_stride)

        chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)

        context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value))
        return context.view(batch_size, num_heads, seq_len, head_dim).transpose(1, 2)

    @staticmethod
    def _get_global_attn_indices(is_index_global_attn):
        """ compute global attn indices required throughout forward pass

        Args:
            is_index_global_attn: bool tensor (batch_size, seq_len), True at global-attention tokens.

        Returns:
            Tuple of
            - the largest number of global tokens in any sample (a 0-dim tensor),
            - (batch, seq) index tuple of the global tokens,
            - (batch, slot) index tuple of the used slots in a (batch_size, max_num) layout,
            - (batch, slot) index tuple of the unused (padding) slots in that layout.
        """

        # helper variable
        num_global_attn_indices = is_index_global_attn.long().sum(dim=1)

        # max number of global attn indices in batch
        max_num_global_attn_indices = num_global_attn_indices.max()

        # indices of global attn
        is_index_global_attn_nonzero = is_index_global_attn.nonzero(as_tuple=True)

        # helper variable
        is_local_index_global_attn = torch.arange(
            max_num_global_attn_indices, device=is_index_global_attn.device
        ) < num_global_attn_indices.unsqueeze(dim=-1)

        # location of the non-padding values within global attention indices
        is_local_index_global_attn_nonzero = is_local_index_global_attn.nonzero(as_tuple=True)

        # location of the padding values within global attention indices
        is_local_index_no_global_attn_nonzero = (is_local_index_global_attn == 0).nonzero(as_tuple=True)
        return (
            max_num_global_attn_indices,
            is_index_global_attn_nonzero,
            is_local_index_global_attn_nonzero,
            is_local_index_no_global_attn_nonzero,
        )

    def _concat_with_global_key_attn_probs(
        self,
        key_vectors,
        query_vectors,
        max_num_global_attn_indices,
        is_index_global_attn_nonzero,
        is_local_index_global_attn_nonzero,
        is_local_index_no_global_attn_nonzero,
    ):
        """
        Score every query against the keys of the global tokens.

        Returns a tensor (batch_size, seq_len, num_heads, max_num_global_attn_indices); unused
        (padding) slots are set to -10000.0 so they get negligible probability after softmax.
        """
        batch_size = key_vectors.shape[0]

        # create only global key vectors
        key_vectors_only_global = key_vectors.new_zeros(
            batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim
        )

        key_vectors_only_global[is_local_index_global_attn_nonzero] = key_vectors[is_index_global_attn_nonzero]

        # (batch_size, seq_len, num_heads, max_num_global_attn_indices)
        attn_probs_from_global_key = torch.einsum("blhd,bshd->blhs", (query_vectors, key_vectors_only_global))

        attn_probs_from_global_key[
            is_local_index_no_global_attn_nonzero[0], :, :, is_local_index_no_global_attn_nonzero[1]
        ] = -10000.0

        return attn_probs_from_global_key

    def _compute_attn_output_with_global_indices(
        self,
        value_vectors,
        attn_probs,
        max_num_global_attn_indices,
        is_index_global_attn_nonzero,
        is_local_index_global_attn_nonzero,
    ):
        """
        Combine the contribution of the global keys with the sliding-window contribution.

        `attn_probs` carries the probabilities for the global keys in its first
        `max_num_global_attn_indices` columns and the windowed probabilities in the remaining ones.
        Returns a tensor (batch_size, seq_len, num_heads, head_dim).
        """
        batch_size = attn_probs.shape[0]

        # cut local attn probs to global only
        attn_probs_only_global = attn_probs.narrow(-1, 0, max_num_global_attn_indices)
        # get value vectors for global only
        value_vectors_only_global = value_vectors.new_zeros(
            batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim
        )
        value_vectors_only_global[is_local_index_global_attn_nonzero] = value_vectors[is_index_global_attn_nonzero]

        # use `matmul` because `einsum` crashes sometimes with fp16
        # attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v))
        # compute attn output only global
        attn_output_only_global = torch.matmul(
            attn_probs_only_global.transpose(1, 2), value_vectors_only_global.transpose(1, 2)
        ).transpose(1, 2)

        # reshape attn probs
        attn_probs_without_global = attn_probs.narrow(
            -1, max_num_global_attn_indices, attn_probs.size(-1) - max_num_global_attn_indices
        ).contiguous()

        # compute attn output with global
        attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value(
            attn_probs_without_global, value_vectors, self.one_sided_attn_window_size
        )
        return attn_output_only_global + attn_output_without_global

    def _compute_global_attn_output_from_hidden(
        self,
        hidden_states,
        max_num_global_attn_indices,
        is_local_index_global_attn_nonzero,
        is_index_global_attn_nonzero,
        is_local_index_no_global_attn_nonzero,
        is_index_masked,
        return_attn_probs=False,
    ):
        """
        Full attention of the global tokens over the whole sequence, using the `*_global` projections.

        Args:
            hidden_states: tensor (seq_len, batch_size, embed_dim).
            max_num_global_attn_indices: largest number of global tokens in any sample.
            is_local_index_global_attn_nonzero: (batch, slot) indices of used global slots.
            is_index_global_attn_nonzero: (batch, seq) indices of the global tokens.
            is_local_index_no_global_attn_nonzero: (batch, slot) indices of unused global slots.
            is_index_masked: bool tensor (batch_size, seq_len), True at masked positions.
            return_attn_probs: when True, also return the (post-dropout) attention probabilities.

        Returns:
            `global_attn_output` of shape (batch_size, num_heads, max_num_global_attn_indices, head_dim);
            or, when `return_attn_probs` is True, the tuple `(global_attn_output, global_attn_probs)`
            with `global_attn_probs` of shape
            (batch_size * num_heads, max_num_global_attn_indices, seq_len).
        """
        seq_len, batch_size = hidden_states.shape[:2]

        # prepare global hidden states
        global_attn_hidden_states = hidden_states.new_zeros(max_num_global_attn_indices, batch_size, self.embed_dim)
        global_attn_hidden_states[is_local_index_global_attn_nonzero[::-1]] = hidden_states[
            is_index_global_attn_nonzero[::-1]
        ]

        # global key, query, value
        global_query_vectors_only_global = self.query_global(global_attn_hidden_states)
        global_key_vectors = self.key_global(hidden_states)
        global_value_vectors = self.value_global(hidden_states)

        # normalize
        global_query_vectors_only_global /= math.sqrt(self.head_dim)

        # reshape
        global_query_vectors_only_global = (
            global_query_vectors_only_global.contiguous()
            .view(max_num_global_attn_indices, batch_size * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )  # (batch_size * self.num_heads, max_num_global_attn_indices, head_dim)
        global_key_vectors = (
            global_key_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
        )  # batch_size * self.num_heads, seq_len, head_dim)
        global_value_vectors = (
            global_value_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
        )  # batch_size * self.num_heads, seq_len, head_dim)

        # compute attn scores
        global_attn_scores = torch.bmm(global_query_vectors_only_global, global_key_vectors.transpose(1, 2))

        assert list(global_attn_scores.size()) == [
            batch_size * self.num_heads,
            max_num_global_attn_indices,
            seq_len,
        ], f"global_attn_scores have the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is {global_attn_scores.size()}."

        global_attn_scores = global_attn_scores.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)

        # rows of unused global slots get a constant large negative score
        global_attn_scores[
            is_local_index_no_global_attn_nonzero[0], :, is_local_index_no_global_attn_nonzero[1], :
        ] = -10000.0

        # masked key positions get a large negative score
        global_attn_scores = global_attn_scores.masked_fill(
            is_index_masked[:, None, None, :],
            -10000.0,
        )

        global_attn_scores = global_attn_scores.view(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)

        # compute global attn probs
        global_attn_probs_float = F.softmax(
            global_attn_scores, dim=-1, dtype=torch.float32
        )  # use fp32 for numerical stability

        global_attn_probs = F.dropout(
            global_attn_probs_float.type_as(global_attn_scores), p=self.dropout, training=self.training
        )

        # global attn output
        global_attn_output = torch.bmm(global_attn_probs, global_value_vectors)

        assert list(global_attn_output.size()) == [
            batch_size * self.num_heads,
            max_num_global_attn_indices,
            self.head_dim,
        ], f"global_attn_output tensor has the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is {global_attn_output.size()}."

        global_attn_output = global_attn_output.view(
            batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim
        )
        if return_attn_probs:
            return global_attn_output, global_attn_probs
        return global_attn_output
# Adapted from transformers.modeling_bert.BertSelfOutput
class LongformerSelfOutput(nn.Module):
    """Output projection of the attention block: dense -> dropout -> residual add -> LayerNorm."""

    def __init__(self, config):
        """
        Args:
            config: object providing `hidden_size`, `layer_norm_eps` and `hidden_dropout_prob`.
        """
        super().__init__()
        hidden_size = config.hidden_size
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """
        Project `hidden_states`, apply dropout, add the residual `input_tensor` and normalize.

        Returns:
            Tensor with the same trailing dimension as `input_tensor`.
        """
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
class LongformerAttention(nn.Module):
    """Attention block: `LongformerSelfAttention` followed by `LongformerSelfOutput`."""

    def __init__(self, config, layer_id=0):
        """
        Args:
            config: model configuration, forwarded to both sub-modules.
            layer_id: index of the layer, forwarded to `LongformerSelfAttention`.
        """
        super().__init__()
        self.self = LongformerSelfAttention(config, layer_id)
        self.output = LongformerSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """
        Prune the given attention heads from the query/key/value and output projections.

        NOTE(review): this reads and writes `num_attention_heads`, `attention_head_size` and
        `all_head_size` on `self.self`, while `LongformerSelfAttention.__init__` in this file sets
        `num_heads`, `head_dim` and `embed_dim` instead; as written this looks like it raises
        `AttributeError` for a non-empty `heads` -- confirm before relying on head pruning.
        """
        if len(heads) == 0:
            return

        attn = self.self
        heads, kept_index = find_pruneable_heads_and_indices(
            heads, attn.num_attention_heads, attn.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        for proj_name in ("query", "key", "value"):
            setattr(attn, proj_name, prune_linear_layer(getattr(attn, proj_name), kept_index))
        self.output.dense = prune_linear_layer(self.output.dense, kept_index, dim=1)

        # Update hyper params and store pruned heads
        attn.num_attention_heads = attn.num_attention_heads - len(heads)
        attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        output_attentions=False,
    ):
        """
        Run self-attention, then the output projection with `hidden_states` as the residual.

        Returns:
            Tuple whose first element is the block output, followed by whatever extra elements the
            self-attention module returned (the attention probabilities, if requested).
        """
        context, *extras = self.self(hidden_states, attention_mask, output_attentions)
        # add attentions if we output them
        return (self.output(context, hidden_states), *extras)
# Adapted from transformers.modeling_bert.BertIntermediate
class LongformerIntermediate(nn.Module):
    """First half of the feed-forward block: dense expansion followed by the activation."""

    def __init__(self, config):
        """
        Args:
            config: object providing `hidden_size`, `intermediate_size` and `hidden_act`, where
                `hidden_act` is either a key of `ACT2FN` or the activation callable itself.
        """
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        activation = config.hidden_act
        # a string names an entry of the activation registry; anything else is used as given
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states):
        """Return `intermediate_act_fn(dense(hidden_states))`."""
        return self.intermediate_act_fn(self.dense(hidden_states))
# Adapted from transformers.modeling_bert.BertOutput
class LongformerOutput(nn.Module):
    """Second half of the feed-forward block: dense -> dropout -> residual add -> LayerNorm."""

    def __init__(self, config):
        """
        Args:
            config: object providing `intermediate_size`, `hidden_size`, `layer_norm_eps` and
                `hidden_dropout_prob`.
        """
        super().__init__()
        hidden_size = config.hidden_size
        self.dense = nn.Linear(config.intermediate_size, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """
        Project `hidden_states` back to the hidden size, apply dropout, add the residual
        `input_tensor` and normalize.
        """
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
class LongformerLayer(nn.Module):
    """A single encoder layer: attention followed by a feed-forward block run through `apply_chunking_to_forward`."""

    def __init__(self, config, layer_id=0):
        super().__init__()
        self.attention = LongformerAttention(config, layer_id)
        self.intermediate = LongformerIntermediate(config)
        self.output = LongformerOutput(config)
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # Dimension handed to `apply_chunking_to_forward` as the one to chunk along.
        self.seq_len_dim = 1

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        """
        Apply attention and then the feed-forward block.

        Returns:
            A tuple whose first element is the layer output; anything the attention module returned after its
            first element (the self attentions, when attention weights are output) follows unchanged.
        """
        attention_results = self.attention(hidden_states, attention_mask, output_attentions=output_attentions)
        feed_forward_output = apply_chunking_to_forward(
            self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_results[0]
        )
        return (feed_forward_output,) + attention_results[1:]

    def ff_chunk(self, attn_output):
        """Feed-forward step for one chunk: `output(intermediate(attn_output), attn_output)`."""
        return self.output(self.intermediate(attn_output), attn_output)
class LongformerEncoder(nn.Module):
    """Applies `config.num_hidden_layers` :class:`LongformerLayer` modules one after another."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([LongformerLayer(config, layer_id=i) for i in range(config.num_hidden_layers)])

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False,
    ):
        """
        Run `hidden_states` through every layer in order.

        Args:
            hidden_states: Input to the first layer.
            attention_mask: Passed unchanged to every layer.
            output_attentions: If truthy, collect the second output of every layer.
            output_hidden_states: If truthy, collect the input of every layer plus the final output.
            return_dict: If truthy, return a :class:`BaseModelOutput`; otherwise return a tuple.

        Returns:
            Without `return_dict`, a tuple of `(last_hidden_state, all_hidden_states, all_attentions)` from
            which the entries that were not requested (i.e. are `None`) are dropped. With `return_dict`, the
            same three values wrapped in a :class:`BaseModelOutput`.
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # The config does not change while iterating, so look the flag up once.
        use_checkpointing = getattr(self.config, "gradient_checkpointing", False)
        if use_checkpointing:
            # Import the submodule explicitly: a bare `import torch` at file level does not by itself
            # guarantee that `torch.utils.checkpoint` has been loaded.
            from torch.utils.checkpoint import checkpoint

            def create_custom_forward(module):
                # `output_attentions` is bound through the closure instead of being a checkpointed input.
                def custom_forward(*inputs):
                    return module(*inputs, output_attentions)

                return custom_forward

        for layer_module in self.layer:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if use_checkpointing:
                layer_outputs = checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    output_attentions,
                )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )
# Adapted from transformers.modeling_bert.BertPooler (restyled here, so it is no longer a verbatim copy)
class LongformerPooler(nn.Module):
    """Pool a sequence by passing the hidden state at index 0 through a dense layer and a tanh."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """Return `activation(dense(hidden_states[:, 0]))`."""
        # "Pooling" here simply means selecting the hidden state of the first token.
        leading_token_state = hidden_states[:, 0]
        return self.activation(self.dense(leading_token_state))
# Adapted from transformers.modeling_roberta.RobertaLMHead with Roberta->Longformer (restyled here, so it is
# no longer a verbatim copy)
class LongformerLMHead(nn.Module):
    """Longformer Head for masked language modeling."""

    def __init__(self, config):
        super().__init__()
        hidden_size, vocab_size = config.hidden_size, config.vocab_size
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)

        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(vocab_size))

        # Share one parameter object between `self.bias` and the decoder so that the bias is correctly
        # resized with `resize_token_embeddings`.
        self.decoder.bias = self.bias

    def forward(self, features, **kwargs):
        """Return vocabulary-sized scores: dense -> gelu -> layer norm -> decoder (which adds the bias)."""
        transformed = self.layer_norm(gelu(self.dense(features)))
        return self.decoder(transformed)
class LongformerPreTrainedModel(PreTrainedModel):
    """
    Abstract base class that takes care of weight initialization and offers a simple interface for
    downloading and loading pretrained models.
    """

    config_class = LongformerConfig
    base_model_prefix = "longformer"
    authorized_missing_keys = [r"position_ids"]

    def _init_weights(self, module):
        """Initialize the weights of a single `module` in place."""
        is_linear = isinstance(module, nn.Linear)
        if is_linear or isinstance(module, nn.Embedding):
            # Plain normal init; slightly different from the TF version, which uses truncated_normal
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        # Linear layers additionally get a zero bias, when they have one.
        if is_linear and module.bias is not None:
            module.bias.data.zero_()
# Introductory documentation text for the Longformer model classes; it also documents the `config` parameter.
LONGFORMER_START_DOCSTRING = r"""
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
pruning heads etc.)
This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__ subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
usage and behavior.
Parameters:
config (:class:`~transformers.LongformerConfig`): Model configuration class with all the parameters of the
model. Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
weights.
"""
# Argument documentation for the Longformer models. `{0}` is left as a placeholder for the input shape
# (presumably filled in with `str.format` where the text is attached to a `forward` method).
LONGFORMER_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using :class:`~transformers.LongformerTokenizer`.
See :meth:`transformers.PreTrainedTokenizer.encode` and
:meth:`transformers.PreTrainedTokenizer.__call__` for details.
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
global_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
Mask to decide the attention given on each token, local attention or global attention.
Tokens with global attention attend to all other tokens, and all other tokens attend to them. This is important for
task-specific finetuning because it makes the model more flexible at representing the task. For example,
for classification, the <s> token should be given global attention. For QA, all question tokens should also have
global attention. Please refer to the `Longformer paper <https://arxiv.org/abs/2004.05150>`__ for more details.
Mask values selected in ``[0, 1]``:
- 0 for local attention (a sliding window attention),
- 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
Segment token indices to indicate first and second portions of the inputs.
Indices are selected in ``[0, 1]``:
- 0 corresponds to a `sentence A` token,
- 1 corresponds to a `sentence B` token.
`What are token type IDs? <../glossary.html#token-type-ids>`__
position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
Indices of positions of each input sequence tokens in the position embeddings.
Selected in the range ``[0, config.max_position_embeddings - 1]``.
`What are position IDs? <../glossary.html#position-ids>`__
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
vectors than the model's internal embedding lookup matrix.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""
class LongformerClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config):
        super().__init__()
        hidden_size = config.hidden_size
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(hidden_size, config.num_labels)

    def forward(self, hidden_states, **kwargs):
        """Return `out_proj(dropout(tanh(dense(dropout(first_token)))))` for the token at index 0."""
        # Only the <s> token (equiv. to [CLS]) is used for classification.
        first_token = hidden_states[:, 0, :]
        pooled = torch.tanh(self.dense(self.dropout(first_token)))
        return self.out_proj(self.dropout(pooled))