# coding=utf-8
# Copyright 2020 The Allen Institute for AI team and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tensorflow Longformer model. """
from dataclasses import dataclass
from typing import Optional, Tuple
import tensorflow as tf
from transformers.activations_tf import get_tf_activation
from .configuration_longformer import LongformerConfig
from .file_utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
)
from .modeling_tf_outputs import TFMaskedLMOutput, TFQuestionAnsweringModelOutput
from .modeling_tf_utils import (
TFMaskedLanguageModelingLoss,
TFPreTrainedModel,
TFQuestionAnsweringLoss,
get_initializer,
keras_serializable,
shape_list,
)
from .tokenization_utils import BatchEncoding
from .utils import logging
# Module-level logger, obtained from the package's own logging utilities.
logger = logging.get_logger(__name__)

# Class names kept as plain strings; presumably consumed by the docstring decorators imported above
# (`add_code_sample_docstrings` & co.) -- TODO confirm at the use sites.
_CONFIG_FOR_DOC = "LongformerConfig"
_TOKENIZER_FOR_DOC = "LongformerTokenizer"

# Identifiers of the pretrained Longformer checkpoints listed for this model.
TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "allenai/longformer-base-4096",
    "allenai/longformer-large-4096",
    "allenai/longformer-large-4096-finetuned-triviaqa",
    "allenai/longformer-base-4096-extra.pos.embd.only",
    "allenai/longformer-large-4096-extra.pos.embd.only",
    # See all Longformer models at https://huggingface.co/models?filter=longformer
]
def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_sep_token=True):
    """
    Build an int32 global attention mask relative to the first separator token of every row.

    Args:
        input_ids_shape: shape of the input ids; index 0 is used as the batch size and index 1 as the sequence
            length.
        sep_token_indices: tensor with two columns that is reshaped to `(batch_size, 3, 2)`, i.e. exactly three
            separator positions per row are expected. The second coordinate of the first one is taken as the
            end of the question.
        before_sep_token: when `True`, mark every position strictly before the first separator; otherwise mark
            every position strictly after `first separator + 1`.

    Returns:
        int32 tensor of shape `input_ids_shape` holding 1 at marked positions and 0 elsewhere.
    """
    assert sep_token_indices.shape[1] == 2, "`input_ids` should have two dimensions"

    # position of the first separator of each row, kept as a column: batch_size x 1
    first_sep_position = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[:, 0, 1]
    first_sep_position = tf.cast(first_sep_position[:, None], tf.dtypes.int32)

    # running position index, repeated for every row of the batch
    positions = tf.broadcast_to(tf.range(input_ids_shape[1]), input_ids_shape)

    if before_sep_token is True:
        is_before = positions < tf.broadcast_to(first_sep_position, input_ids_shape)
        return tf.cast(is_before, tf.dtypes.int32)

    # last token is separation token and should not be counted and in the middle are two separation tokens
    is_after = tf.cast(
        positions > tf.broadcast_to(first_sep_position + 1, input_ids_shape),
        tf.dtypes.int32,
    )
    is_inside_sequence = tf.cast(positions < input_ids_shape[-1], tf.dtypes.int32)
    return is_after * is_inside_sequence
# Copied from transformers.modeling_tf_roberta.TFRobertaLMHead
class TFLongformerLMHead(tf.keras.layers.Layer):
    """
    Masked-language-modeling head (RoBERTa style).

    Applies dense -> GELU -> layer norm and then projects onto the vocabulary by calling the shared input
    embedding layer in `"linear"` mode, adding an output-only bias per vocabulary entry.
    """

    def __init__(self, config, input_embeddings, **kwargs):
        super().__init__(**kwargs)

        self.vocab_size = config.vocab_size
        self.dense = tf.keras.layers.Dense(
            config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="dense",
        )
        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
        self.act = get_tf_activation("gelu")

        # The decoder reuses the input embedding layer (tied weights); only the bias below is specific
        # to the output side.
        self.decoder = input_embeddings

    def build(self, input_shape):
        """Create the per-token output bias, then finish the regular Keras build."""
        self.bias = self.add_weight(
            name="bias",
            shape=(self.vocab_size,),
            initializer="zeros",
            trainable=True,
        )
        super().build(input_shape)

    def call(self, features):
        """Return vocabulary logits for `features`."""
        transformed = self.layer_norm(self.act(self.dense(features)))

        # project back to size of vocabulary with bias
        return self.decoder(transformed, mode="linear") + self.bias
# Copied from transformers.modeling_tf_roberta.TFRobertaEmbeddings
class TFLongformerEmbeddings(tf.keras.layers.Layer):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.

    The layer sums word, position and token-type embeddings, then applies layer normalization and dropout.
    The word-embedding matrix is also reused as an output projection through `mode="linear"`.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        # id treated as padding when position ids are derived from input ids
        self.padding_idx = 1
        self.vocab_size = config.vocab_size
        self.hidden_size = config.hidden_size
        self.initializer_range = config.initializer_range
        self.position_embeddings = tf.keras.layers.Embedding(
            config.max_position_embeddings,
            config.hidden_size,
            embeddings_initializer=get_initializer(self.initializer_range),
            name="position_embeddings",
        )
        self.token_type_embeddings = tf.keras.layers.Embedding(
            config.type_vocab_size,
            config.hidden_size,
            embeddings_initializer=get_initializer(self.initializer_range),
            name="token_type_embeddings",
        )

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def build(self, input_shape):
        """Build shared word embedding layer """
        with tf.name_scope("word_embeddings"):
            # Create and initialize weights. The random normal initializer was chosen
            # arbitrarily, and works well.
            self.word_embeddings = self.add_weight(
                "weight",
                shape=[self.vocab_size, self.hidden_size],
                initializer=get_initializer(self.initializer_range),
            )

        super().build(input_shape)

    def create_position_ids_from_input_ids(self, x):
        """
        Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
        symbols are ignored. This is modified from fairseq's `utils.make_positions`.

        Args:
            x: tf.Tensor of token ids; the cumulative sum runs along axis 1.

        Returns: tf.Tensor of int32 position ids; padding positions receive `padding_idx`.
        """
        mask = tf.cast(tf.math.not_equal(x, self.padding_idx), dtype=tf.int32)
        # cumulative count of non-padding tokens, zeroed again at the padding positions
        incremental_indices = tf.math.cumsum(mask, axis=1) * mask

        return incremental_indices + self.padding_idx

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        """
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.

        Args:
            inputs_embeds: tf.Tensor whose second dimension is the sequence length.

        Returns: tf.Tensor of shape `(1, seq_length)` counting up from `padding_idx + 1`.
        """
        seq_length = shape_list(inputs_embeds)[1]
        position_ids = tf.range(self.padding_idx + 1, seq_length + self.padding_idx + 1, dtype=tf.int32)[tf.newaxis, :]

        return position_ids

    def call(
        self,
        input_ids=None,
        position_ids=None,
        token_type_ids=None,
        inputs_embeds=None,
        mode="embedding",
        training=False,
    ):
        """
        Either embed the inputs or project hidden states onto the vocabulary with the shared weights.

        Args:
            input_ids: token ids in "embedding" mode. In "linear" mode this argument carries the hidden states
                that are projected onto the vocabulary.
            position_ids: optional position ids; derived from `input_ids` / `inputs_embeds` when omitted.
            token_type_ids: optional token type ids; zeros when omitted.
            inputs_embeds: optional pre-computed word embeddings used instead of `input_ids`.
            mode: string, a valid value is one of "embedding" and "linear".
            training: forwarded to the dropout layer in "embedding" mode.

        Returns:
            outputs: If mode == "embedding", output embedding tensor with shape [batch_size, length,
            embedding_size]; if mode == "linear", output linear tensor with shape [batch_size, length,
            vocab_size].

        Raises:
            ValueError: if mode is not valid.

        Shared weights logic adapted from
        https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
        if mode == "embedding":
            return self._embedding(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
        elif mode == "linear":
            return self._linear(input_ids)
        else:
            raise ValueError("mode {} is not valid.".format(mode))

    def _embedding(self, input_ids, position_ids, token_type_ids, inputs_embeds, training=False):
        """
        Applies embedding based on inputs tensor.

        At least one of `input_ids` and `inputs_embeds` must be given. Missing `position_ids` are derived from
        whichever of the two is available and missing `token_type_ids` default to zeros.
        """
        assert not (
            input_ids is None and inputs_embeds is None
        ), "Either `input_ids` or `inputs_embeds` has to be provided"

        if position_ids is None:
            if input_ids is not None:
                # Create the position ids from the input token ids. Any padded tokens remain padded.
                position_ids = self.create_position_ids_from_input_ids(input_ids)
            else:
                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)

        # `position_ids` is always set from here on, so no further fallback is required.
        if input_ids is not None:
            input_shape = shape_list(input_ids)
        else:
            input_shape = shape_list(inputs_embeds)[:-1]

        if token_type_ids is None:
            token_type_ids = tf.fill(input_shape, 0)

        if inputs_embeds is None:
            inputs_embeds = tf.gather(self.word_embeddings, input_ids)

        # cast so that the three embeddings can be summed in the dtype of the word embeddings
        position_embeddings = tf.cast(self.position_embeddings(position_ids), inputs_embeds.dtype)
        token_type_embeddings = tf.cast(self.token_type_embeddings(token_type_ids), inputs_embeds.dtype)
        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings, training=training)

        return embeddings

    def _linear(self, inputs):
        """
        Computes logits by running inputs through a linear layer.

        Args:
            inputs: A float32 tensor with shape [batch_size, length, hidden_size]

        Returns:
            float32 tensor with shape [batch_size, length, vocab_size].
        """
        batch_size = shape_list(inputs)[0]
        length = shape_list(inputs)[1]
        # flatten to 2D, multiply with the transposed word-embedding matrix, then restore the batch layout
        x = tf.reshape(inputs, [-1, self.hidden_size])
        logits = tf.matmul(x, self.word_embeddings, transpose_b=True)

        return tf.reshape(logits, [batch_size, length, self.vocab_size])
# Copied from transformers.modeling_tf_bert.TFBertIntermediate
class TFLongformerIntermediate(tf.keras.layers.Layer):
    """Dense projection to `config.intermediate_size` followed by the configured activation."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.dense = tf.keras.layers.Dense(
            config.intermediate_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="dense",
        )

        # `hidden_act` may be given by name or directly as a callable
        activation = config.hidden_act
        self.intermediate_act_fn = get_tf_activation(activation) if isinstance(activation, str) else activation

    def call(self, hidden_states):
        """Return `activation(dense(hidden_states))`."""
        return self.intermediate_act_fn(self.dense(hidden_states))
# Copied from transformers.modeling_tf_bert.TFBertOutput
class TFLongformerOutput(tf.keras.layers.Layer):
    """Dense projection back to `config.hidden_size` with dropout, a residual connection and layer norm."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.dense = tf.keras.layers.Dense(
            config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="dense",
        )
        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def call(self, hidden_states, input_tensor, training=False):
        """Project `hidden_states`, apply dropout, add `input_tensor` and normalize the sum."""
        projected = self.dense(hidden_states)
        projected = self.dropout(projected, training=training)

        return self.LayerNorm(projected + input_tensor)
# Copied from transformers.modeling_tf_bert.TFBertPooler
class TFLongformerPooler(tf.keras.layers.Layer):
    """Pools a sequence by passing the hidden state of its first token through a tanh-activated dense layer."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.dense = tf.keras.layers.Dense(
            config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            activation="tanh",
            name="dense",
        )

    def call(self, hidden_states):
        """Return the pooled representation computed from `hidden_states[:, 0]`."""
        # "Pooling" here simply means picking the hidden state that belongs to the first token.
        return self.dense(hidden_states[:, 0])
# Copied from transformers.modeling_tf_bert.TFBertSelfOutput
class TFLongformerSelfOutput(tf.keras.layers.Layer):
    """Output projection of the self-attention block: dense, dropout, residual connection and layer norm."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.dense = tf.keras.layers.Dense(
            config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="dense",
        )
        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def call(self, hidden_states, input_tensor, training=False):
        """Project `hidden_states`, apply dropout, add `input_tensor` and normalize the sum."""
        projected = self.dense(hidden_states)
        projected = self.dropout(projected, training=training)

        return self.LayerNorm(projected + input_tensor)
class TFLongformerSelfAttention(tf.keras.layers.Layer):
def __init__(self, config, layer_id, **kwargs):
    """
    Create the projection and dropout layers of the Longformer self-attention for one encoder layer.

    Args:
        config: model configuration; reads `hidden_size`, `num_attention_heads`, `initializer_range`,
            `attention_probs_dropout_prob` and `attention_window`.
        layer_id: index into `config.attention_window` selecting this layer's window size.

    Raises:
        ValueError: if `hidden_size` is not a multiple of `num_attention_heads`.
    """
    super().__init__(**kwargs)

    hidden_size = config.hidden_size
    head_count = config.num_attention_heads
    if hidden_size % head_count != 0:
        raise ValueError(
            "The hidden size (%d) is not a multiple of the number of attention "
            "heads (%d)" % (hidden_size, head_count)
        )

    self.num_heads = head_count
    self.head_dim = int(hidden_size / head_count)
    self.embed_dim = hidden_size

    def make_projection(layer_name):
        # every projection gets its own, freshly created initializer
        return tf.keras.layers.Dense(
            self.embed_dim,
            kernel_initializer=get_initializer(config.initializer_range),
            name=layer_name,
        )

    # projections used for the local (sliding window) attention
    self.query = make_projection("query")
    self.key = make_projection("key")
    self.value = make_projection("value")

    # separate projection layers for tokens with global attention
    self.query_global = make_projection("query_global")
    self.key_global = make_projection("key_global")
    self.value_global = make_projection("value_global")

    dropout_rate = config.attention_probs_dropout_prob
    self.dropout = tf.keras.layers.Dropout(dropout_rate)
    self.global_dropout = tf.keras.layers.Dropout(dropout_rate)

    self.layer_id = layer_id
    attention_window = config.attention_window[self.layer_id]

    assert (
        attention_window % 2 == 0
    ), f"`attention_window` for layer {self.layer_id} has to be an even value. Given {attention_window}"
    assert (
        attention_window > 0
    ), f"`attention_window` for layer {self.layer_id} has to be positive. Given {attention_window}"

    # half of the window lies on each side of a token
    self.one_sided_attn_window_size = attention_window // 2
def call(
    self,
    inputs,
    training=False,
):
    """
    Run Longformer self-attention: sliding-window (local) attention for every token and, when `is_global_attn`
    is true, additional global attention.

    `_sliding_chunks_query_key_matmul` asserts that the sequence length is a multiple of
    `2 * self.one_sided_attn_window_size`, so the caller has to pad the sequence beforehand; padding is not
    done here.

    Args:
        inputs: 5-tuple `(hidden_states, attention_mask, is_index_masked, is_index_global_attn,
            is_global_attn)`:

            - hidden_states: unpacked as `(batch_size, seq_len, embed_dim)`; `embed_dim` has to equal
              `self.embed_dim`.
            - attention_mask: used as the `key` operand of a sliding-chunk product with a tensor of ones of
              the same shape (four dimensions are unpacked from it); the result is added to the local
              scores. According to the note carried over from the previous docstring, the mask is
              converted upstream so that negative means no attention, 0 local attention and positive
              global attention -- TODO confirm against the caller.
            - is_index_masked: boolean tensor indexed as `[:, :, None, None]`; attention probabilities of
              positions where it is True are replaced by 0.
            - is_index_global_attn: boolean tensor indexed as `[:, :, None, None]`; marks the tokens with
              global attention.
            - is_global_attn: predicate given to `tf.cond` to switch the global-attention code paths on.
        training: forwarded to the dropout layer and to the global attention computation.

    Returns:
        Tuple `(attn_output, attn_probs, global_attn_probs)`. `attn_output` is reshaped to
        `(batch_size, seq_len, embed_dim)` before the global overwrite. `attn_probs` has the local
        `2 * one_sided_attn_window_size + 1` columns, preceded by the global columns when `is_global_attn`
        is true. `global_attn_probs` is a zero tensor of shape
        `(batch_size, num_heads, max_num_global_attn_indices, seq_len)` when `is_global_attn` is false.
    """
    # retrieve input args
    (
        hidden_states,
        attention_mask,
        is_index_masked,
        is_index_global_attn,
        is_global_attn,
    ) = inputs

    # project hidden states
    query_vectors = self.query(hidden_states)
    key_vectors = self.key(hidden_states)
    value_vectors = self.value(hidden_states)
    batch_size, seq_len, embed_dim = shape_list(hidden_states)

    tf.debugging.assert_equal(
        embed_dim,
        self.embed_dim,
        message=f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}",
    )

    # normalize query (scale by 1 / sqrt(head_dim))
    query_vectors /= tf.math.sqrt(tf.constant(self.head_dim, dtype=tf.dtypes.float32))
    # split the embedding dimension into heads
    query_vectors = tf.reshape(query_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
    key_vectors = tf.reshape(key_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))

    # attn_probs = (batch_size, seq_len, num_heads, window*2+1)
    attn_scores = self._sliding_chunks_query_key_matmul(
        query_vectors, key_vectors, self.one_sided_attn_window_size
    )

    # diagonal mask with zeros everywhere and -inf inplace of padding
    diagonal_mask = self._sliding_chunks_query_key_matmul(
        tf.ones(shape_list(attention_mask), dtype=tf.float32),
        attention_mask,
        self.one_sided_attn_window_size,
    )

    # pad local attention probs
    attn_scores += diagonal_mask

    tf.debugging.assert_equal(
        shape_list(attn_scores),
        [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
        message=f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}",
    )

    # compute global attn indices required through out forward fn
    (
        max_num_global_attn_indices,
        is_index_global_attn_nonzero,
        is_local_index_global_attn_nonzero,
        is_local_index_no_global_attn_nonzero,
    ) = self._get_global_attn_indices(is_index_global_attn)

    # this function is only relevant for global attention: it prepends the scores against the global keys
    attn_scores = tf.cond(
        is_global_attn,
        lambda: self._concat_with_global_key_attn_probs(
            attn_scores=attn_scores,
            query_vectors=query_vectors,
            key_vectors=key_vectors,
            max_num_global_attn_indices=max_num_global_attn_indices,
            is_index_global_attn_nonzero=is_index_global_attn_nonzero,
            is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
            is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
        ),
        lambda: attn_scores,
    )

    attn_probs = tf.nn.softmax(attn_scores, axis=-1)

    # softmax sometimes inserts NaN if all positions are masked, replace them with 0
    attn_probs = tf.where(
        tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)),
        0.0,
        attn_probs,
    )

    # apply dropout
    attn_probs = self.dropout(attn_probs, training=training)

    value_vectors = tf.reshape(value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))

    # if global attention, compute sum of global and local attn
    attn_output = tf.cond(
        is_global_attn,
        lambda: self._compute_attn_output_with_global_indices(
            value_vectors=value_vectors,
            attn_probs=attn_probs,
            max_num_global_attn_indices=max_num_global_attn_indices,
            is_index_global_attn_nonzero=is_index_global_attn_nonzero,
            is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
        ),
        lambda: self._sliding_chunks_matmul_attn_probs_value(
            attn_probs, value_vectors, self.one_sided_attn_window_size
        ),
    )

    tf.debugging.assert_equal(
        shape_list(attn_output),
        [batch_size, seq_len, self.num_heads, self.head_dim],
        message="Unexpected size",
    )

    # merge the heads back into the embedding dimension
    attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim))

    # compute value for global attention and overwrite to attention output
    # TODO: remove the redundant computation
    attn_output, global_attn_probs = tf.cond(
        is_global_attn,
        lambda: self._compute_global_attn_output_from_hidden(
            attn_output=attn_output,
            hidden_states=hidden_states,
            max_num_global_attn_indices=max_num_global_attn_indices,
            is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
            is_index_global_attn_nonzero=is_index_global_attn_nonzero,
            is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
            is_index_masked=is_index_masked,
            training=training,
        ),
        lambda: (attn_output, tf.zeros((batch_size, self.num_heads, max_num_global_attn_indices, seq_len))),
    )

    # make sure that local attention probabilities are set to 0 for indices of global attn
    attn_probs = tf.where(
        tf.broadcast_to(is_index_global_attn[:, :, None, None], shape_list(attn_probs)),
        tf.zeros(shape_list(attn_probs), dtype=tf.dtypes.float32),
        attn_probs,
    )

    outputs = (attn_output, attn_probs, global_attn_probs)

    return outputs
def _sliding_chunks_query_key_matmul(self, query, key, window_overlap):
    """
    Matrix multiplication of query and key tensors using with a sliding window attention pattern. This
    implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer) with an
    overlap of size window_overlap

    Args:
        query: tensor unpacked as `(batch_size, seq_len, num_heads, head_dim)`; `seq_len` has to be a multiple
            of `2 * window_overlap`.
        key: tensor with the same shape as `query`.
        window_overlap: one-sided window size `w`.

    Returns:
        Tensor of shape `(batch_size, seq_len, num_heads, 2 * window_overlap + 1)` in which the last
        dimension holds the scores against the `window_overlap` previous positions, the position itself and
        the `window_overlap` following positions. Entries selected by `_mask_invalid_locations` are `-inf`.
    """
    batch_size, seq_len, num_heads, head_dim = shape_list(query)

    tf.debugging.assert_equal(
        seq_len % (window_overlap * 2),
        0,
        message=f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}",
    )
    tf.debugging.assert_equal(
        shape_list(query),
        shape_list(key),
        message=f"Shape of query and key should be equal, but got query: {shape_list(query)} and key: {shape_list(key)}",
    )

    # `chunks_count + 1` equals the number of blocks of `window_overlap` positions in the sequence
    chunks_count = seq_len // window_overlap - 1

    # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2
    query = tf.reshape(
        tf.transpose(query, (0, 2, 1, 3)),
        (batch_size * num_heads, seq_len, head_dim),
    )
    key = tf.reshape(tf.transpose(key, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim))

    chunked_query = self._chunk(query, window_overlap)
    chunked_key = self._chunk(key, window_overlap)

    # matrix multiplication
    # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim
    # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim
    # bcxy: batch_size * num_heads x chunks x 2window_overlap x 2window_overlap
    chunked_attention_scores = tf.einsum("bcxd,bcyd->bcxy", chunked_query, chunked_key)  # multiply

    # convert diagonals into columns: one extra row is padded, then the last two dimensions are reshaped
    paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)
    diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(chunked_attention_scores, paddings)

    # allocate space for the overall attention matrix where the chunks are combined. The last dimension
    # has (window_overlap * 2 + 1) columns. The first (window_overlap) columns are the window_overlap lower triangles (attention from a word to
    # window_overlap previous words). The following column is attention score from each word to itself, then
    # followed by window_overlap columns for the upper triangle.

    # copy parts from diagonal_chunked_attention_scores into the combined matrix of attentions
    # - copying the main diagonal and the upper triangle
    # TODO: This code is most likely not very efficient and should be improved
    diagonal_attn_scores_up_triang = tf.concat(
        [
            diagonal_chunked_attention_scores[:, :, :window_overlap, : window_overlap + 1],
            diagonal_chunked_attention_scores[:, -1:, window_overlap:, : window_overlap + 1],
        ],
        axis=1,
    )

    # - copying the lower triangle (a zero block is prepended for the first chunk)
    diagonal_attn_scores_low_triang = tf.concat(
        [
            tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)),
            diagonal_chunked_attention_scores[:, :, -(window_overlap + 1) : -1, window_overlap + 1 :],
        ],
        axis=1,
    )

    # candidate values for the lower triangle of the first chunk, taken from the rolled score tensor;
    # the trailing zero block only makes the chunk dimension match
    diagonal_attn_scores_first_chunk = tf.concat(
        [
            tf.roll(
                diagonal_chunked_attention_scores,
                shift=[1, window_overlap],
                axis=[2, 3],
            )[:, :, :window_overlap, :window_overlap],
            tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)),
        ],
        axis=1,
    )

    # True only for chunk index 0
    first_chunk_mask = (
        tf.broadcast_to(
            tf.range(chunks_count + 1)[None, :, None, None],
            shape=(
                batch_size * num_heads,
                chunks_count + 1,
                window_overlap,
                window_overlap,
            ),
        )
        < 1
    )

    # use the rolled values for the first chunk and the regular lower triangle everywhere else
    diagonal_attn_scores_low_triang = tf.where(
        first_chunk_mask,
        diagonal_attn_scores_first_chunk,
        diagonal_attn_scores_low_triang,
    )

    # merging upper and lower triangle
    diagonal_attention_scores = tf.concat(
        [diagonal_attn_scores_low_triang, diagonal_attn_scores_up_triang], axis=-1
    )

    # separate batch_size and num_heads dimensions again
    diagonal_attention_scores = tf.transpose(
        tf.reshape(
            diagonal_attention_scores,
            (batch_size, num_heads, seq_len, 2 * window_overlap + 1),
        ),
        (0, 2, 1, 3),
    )

    # set the window entries selected by `_mask_invalid_locations` to -inf
    diagonal_attention_scores = self._mask_invalid_locations(diagonal_attention_scores, window_overlap)

    return diagonal_attention_scores
@staticmethod
def _mask_invalid_locations(input_tensor, window_overlap):
    """
    Overwrite two triangular corner regions of the banded score tensor with `-inf`.

    A triangular mask is built for the first `window_overlap` rows and first `window_overlap + 1` columns; the
    same mask, reversed along both axes, covers the opposite corner. The mask is indexed along dimensions 1
    and 3 of `input_tensor` and broadcast over dimensions 0 and 2.

    Args:
        input_tensor: 4D tensor; dimension 1 and dimension 3 are the ones being masked.
        window_overlap: one-sided window size.

    Returns:
        Tensor shaped like `input_tensor` with the masked entries replaced by `-inf`.
    """
    full_shape = shape_list(input_tensor)
    num_rows, num_columns = full_shape[1], full_shape[3]

    # lower triangle of ones, flipped along the row axis: this is the mask of the first corner
    lower_triangle = tf.linalg.band_part(tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0)
    first_corner = tf.reverse(lower_triangle, axis=[0])

    # grow the small corner mask with zeros up to the full (rows, columns) extent
    zero_padding = tf.constant([[0, num_rows - window_overlap], [0, num_columns - window_overlap - 1]])
    corner_mask = tf.pad(first_corner, zero_padding)

    # the opposite corner is the same mask reversed along both axes
    corner_mask = corner_mask + tf.reverse(corner_mask, axis=[0, 1])

    # broadcast over the two remaining dimensions and turn into a boolean selector
    is_invalid = tf.math.greater(tf.broadcast_to(corner_mask[None, :, None, :], full_shape), 0)

    minus_infinity = -float("inf") * tf.ones_like(input_tensor, dtype=tf.dtypes.float32)

    return tf.where(is_invalid, minus_infinity, input_tensor)
def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value, window_overlap):
    """
    Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors: multiplies the banded
    attention probabilities with the matching window of value vectors.

    Args:
        attn_probs: tensor whose first three dimensions equal those of `value` and whose last dimension is
            `2 * window_overlap + 1`.
        value: tensor unpacked as `(batch_size, seq_len, num_heads, head_dim)`; `seq_len` has to be a
            multiple of `2 * window_overlap`.
        window_overlap: one-sided window size.

    Returns:
        Tensor of shape `(batch_size, seq_len, num_heads, head_dim)`, i.e. the same shape as `value`.
    """
    batch_size, seq_len, num_heads, head_dim = shape_list(value)

    tf.debugging.assert_equal(
        seq_len % (window_overlap * 2),
        0,
        message="Seq_len has to be multiple of 2 * window_overlap",
    )
    tf.debugging.assert_equal(
        shape_list(attn_probs)[:3],
        shape_list(value)[:3],
        message="value and attn_probs must have same dims (except head_dim)",
    )
    tf.debugging.assert_equal(
        shape_list(attn_probs)[3],
        2 * window_overlap + 1,
        message="attn_probs last dim has to be 2 * window_overlap + 1",
    )

    # `chunks_count + 1` equals the number of blocks of `window_overlap` positions in the sequence
    chunks_count = seq_len // window_overlap - 1

    # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap
    chunked_attn_probs = tf.reshape(
        tf.transpose(attn_probs, (0, 2, 1, 3)),
        (
            batch_size * num_heads,
            seq_len // window_overlap,
            window_overlap,
            2 * window_overlap + 1,
        ),
    )

    # group batch_size and num_heads dimensions into one
    value = tf.reshape(
        tf.transpose(value, (0, 2, 1, 3)),
        (batch_size * num_heads, seq_len, head_dim),
    )

    # pad seq_len with w at the beginning of the sequence and another window overlap at the end
    # (the padded entries are filled with -1)
    paddings = tf.constant([[0, 0], [window_overlap, window_overlap], [0, 0]], dtype=tf.dtypes.int32)
    padded_value = tf.pad(value, paddings, constant_values=-1)

    # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap
    # (framing is done on the flattened `seq_len * head_dim` axis, hence the multiplication by head_dim)
    frame_size = 3 * window_overlap * head_dim
    frame_hop_size = (shape_list(padded_value)[1] * head_dim - frame_size) // chunks_count
    chunked_value = tf.signal.frame(
        tf.reshape(padded_value, (batch_size * num_heads, -1)),
        frame_size,
        frame_hop_size,
    )
    chunked_value = tf.reshape(
        chunked_value,
        (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim),
    )

    tf.debugging.assert_equal(
        shape_list(chunked_value),
        [batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],
        message="Chunked value has the wrong shape",
    )

    # shift each row of probabilities so that its columns line up with the rows of `chunked_value`
    chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)

    context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value)

    # separate batch_size and num_heads again and move seq_len in front of num_heads
    context = tf.transpose(
        tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)),
        (0, 2, 1, 3),
    )

    return context
@staticmethod
def _pad_and_transpose_last_two_dims(hidden_states_padded, paddings):
    """
    Pad a 4D tensor and reinterpret it with the sizes of its last two dimensions swapped.

    Note that the swap is done with `tf.reshape`, not `tf.transpose`: the elements keep their memory order,
    only the dimension sizes change places.

    Args:
        hidden_states_padded: 4D tensor to pad.
        paddings: paddings handed to `tf.pad`.

    Returns:
        The padded tensor reshaped to `(dim0, dim1, dim3, dim2)` of the padded shape.
    """
    # padding value is not important because it will be overwritten
    padded = tf.pad(hidden_states_padded, paddings)
    dim0, dim1, dim2, dim3 = shape_list(padded)

    return tf.reshape(padded, (dim0, dim1, dim3, dim2))
@staticmethod
def _pad_and_diagonalize(chunked_hidden_states):
    """
    Shift row `i` of every chunk `i` steps to the right, which turns columns into diagonals.

    Example::

          chunked_hidden_states: [ 0.4983,  2.6918, -0.0071,  1.0492,
                                   -1.8348,  0.7672,  0.2986,  0.0285,
                                   -0.7584,  0.4206, -0.0405,  0.1599,
                                   2.0514, -1.1600,  0.5372,  0.2629 ]
          window_overlap = num_rows = 4
         (pad & diagonalize) =>
         [ 0.4983,  2.6918, -0.0071,  1.0492, 0.0000,  0.0000,  0.0000
           0.0000,  -1.8348,  0.7672,  0.2986,  0.0285, 0.0000,  0.0000
           0.0000,  0.0000, -0.7584,  0.4206, -0.0405,  0.1599, 0.0000
           0.0000,  0.0000,  0.0000, 2.0514, -1.1600,  0.5372,  0.2629 ]

    Args:
        chunked_hidden_states: tensor unpacked as `(total_num_heads, num_chunks, window_overlap, hidden_dim)`.

    Returns:
        Tensor of shape `(total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim - 1)`.
    """
    heads, chunks, rows, columns = shape_list(chunked_hidden_states)

    # widen every row by `rows + 1` entries: heads x chunks x rows x (columns + rows + 1)
    right_padding = tf.constant([[0, 0], [0, 0], [0, 0], [0, rows + 1]])
    widened = tf.pad(chunked_hidden_states, right_padding)

    # flatten the rows of each chunk into one long vector and drop its last `rows` entries, leaving
    # rows * (columns + rows) values
    flattened = tf.reshape(widened, (heads, chunks, -1))
    trimmed = flattened[:, :, :-rows]

    # re-cutting the vector into rows that are one entry shorter moves row i to the right by i entries
    skewed = tf.reshape(trimmed, (heads, chunks, rows, rows + columns))

    # the last column is not needed
    return skewed[:, :, :, :-1]
@staticmethod
def _chunk(hidden_states, window_overlap):
    """
    Convert into overlapping chunks. Chunk size = 2w, overlap size = w.

    Args:
        hidden_states: tensor unpacked as `(batch_size, seq_length, hidden_dim)`.
        window_overlap: one-sided window size `w`.

    Returns:
        Tensor of shape `(batch_size, num_output_chunks, 2 * window_overlap, hidden_dim)` with
        `num_output_chunks = 2 * (seq_length // (2 * window_overlap)) - 1`.
    """
    batch_size, seq_length, hidden_dim = shape_list(hidden_states)
    num_output_chunks = 2 * (seq_length // (2 * window_overlap)) - 1

    # define frame size and frame stride (similar to convolution); both are expressed on the flattened
    # `seq_length * hidden_dim` axis, hence the multiplication by hidden_dim
    frame_hop_size = window_overlap * hidden_dim
    frame_size = 2 * frame_hop_size
    hidden_states = tf.reshape(hidden_states, (batch_size, seq_length * hidden_dim))

    # chunk with overlap
    chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size)

    # the message reports the expected shape in the same order as the check itself
    tf.debugging.assert_equal(
        shape_list(chunked_hidden_states),
        [batch_size, num_output_chunks, frame_size],
        message=f"Make sure chunking is correctly applied. Chunked hidden states should have output dimension {[batch_size, num_output_chunks, frame_size]}, but got {shape_list(chunked_hidden_states)}.",
    )

    # split every frame back into (positions, hidden_dim)
    chunked_hidden_states = tf.reshape(
        chunked_hidden_states,
        (batch_size, num_output_chunks, 2 * window_overlap, hidden_dim),
    )

    return chunked_hidden_states
@staticmethod
def _get_global_attn_indices(is_index_global_attn):
    """
    Derive the index tensors describing where global attention tokens sit.

    Args:
        is_index_global_attn: boolean tensor whose axis 1 is summed to count the global tokens of each row.

    Returns:
        Tuple of four tensors:

        - the largest number of global tokens found in any row,
        - the indices at which `is_index_global_attn` is True,
        - the indices `(row, slot)` of the used slots, where a row with `k` global tokens uses slots `0..k-1`
          out of the maximum,
        - the indices `(row, slot)` of the remaining, unused (padding) slots.
    """
    # number of global tokens in every row
    global_tokens_per_row = tf.reduce_sum(tf.cast(is_index_global_attn, dtype=tf.dtypes.int32), axis=1)

    # max number of global attn indices in batch
    max_num_global_attn_indices = tf.reduce_max(global_tokens_per_row)

    # indices of global attn
    is_index_global_attn_nonzero = tf.where(is_index_global_attn)

    # True for the slots a row actually fills, False for its padding slots
    slot_is_used = tf.range(max_num_global_attn_indices) < tf.expand_dims(global_tokens_per_row, axis=-1)

    used_slot_indices = tf.where(slot_is_used)
    unused_slot_indices = tf.where(tf.math.logical_not(slot_is_used))

    return (
        max_num_global_attn_indices,
        is_index_global_attn_nonzero,
        used_slot_indices,
        unused_slot_indices,
    )
def _concat_with_global_key_attn_probs(
    self,
    attn_scores,
    key_vectors,
    query_vectors,
    max_num_global_attn_indices,
    is_index_global_attn_nonzero,
    is_local_index_global_attn_nonzero,
    is_local_index_no_global_attn_nonzero,
):
    """
    Prepend the scores of every query against the global keys to the local attention scores.

    The key vectors of the global tokens are gathered into a dense
    `(batch_size, max_num_global_attn_indices, num_heads, head_dim)` tensor. Scores that belong to unused
    (padding) slots of that tensor are overwritten with -10000.0.

    Returns:
        `attn_scores` with `max_num_global_attn_indices` additional leading entries in its last dimension.
    """
    batch_size = shape_list(key_vectors)[0]

    # pick the key vectors of the global tokens and pack them into per-row slots
    gathered_global_keys = tf.gather_nd(key_vectors, is_index_global_attn_nonzero)
    packed_global_keys = tf.scatter_nd(
        is_local_index_global_attn_nonzero,
        gathered_global_keys,
        shape=(batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim),
    )

    # (batch_size, seq_len, num_heads, max_num_global_attn_indices)
    global_key_scores = tf.einsum("blhd,bshd->blhs", query_vectors, packed_global_keys)

    # (batch_size, max_num_global_attn_indices, seq_len, num_heads): slots first, so whole slots can be updated
    slot_major_scores = tf.transpose(global_key_scores, (0, 3, 1, 2))

    # one (seq_len, num_heads) block of -10000.0 for every unused slot
    num_unused_slots = shape_list(is_local_index_no_global_attn_nonzero)[0]
    fill_shape = (num_unused_slots, *shape_list(slot_major_scores)[-2:])
    fill_values = tf.ones(fill_shape) * -10000.0

    # scatter mask
    slot_major_scores = tf.tensor_scatter_nd_update(
        slot_major_scores,
        is_local_index_no_global_attn_nonzero,
        fill_values,
    )

    # back to (batch_size, seq_len, num_heads, max_num_global_attn_indices)
    global_key_scores = tf.transpose(slot_major_scores, (0, 2, 3, 1))

    # (batch_size, seq_len, num_heads, extra attention count + 2*window+1)
    return tf.concat((global_key_scores, attn_scores), axis=-1)
def _compute_attn_output_with_global_indices(
    self,
    value_vectors,
    attn_probs,
    max_num_global_attn_indices,
    is_index_global_attn_nonzero,
    is_local_index_global_attn_nonzero,
):
    """
    Compute the attention output when the probabilities carry global columns in front of the local ones.

    The first `max_num_global_attn_indices` entries of the last dimension of `attn_probs` are multiplied with
    the value vectors of the global tokens; the remaining entries go through the sliding-window product.

    Returns:
        Sum of the global and the sliding-window contributions.
    """
    batch_size = shape_list(attn_probs)[0]

    # leading columns: probabilities assigned to the global tokens
    global_probs = attn_probs[:, :, :, :max_num_global_attn_indices]

    # pick the value vectors of the global tokens and pack them into per-row slots
    gathered_global_values = tf.gather_nd(value_vectors, is_index_global_attn_nonzero)
    packed_global_values = tf.scatter_nd(
        is_local_index_global_attn_nonzero,
        gathered_global_values,
        shape=(batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim),
    )

    # contribution of the global tokens
    global_contribution = tf.einsum("blhs,bshd->blhd", global_probs, packed_global_values)

    # remaining columns: regular sliding-window probabilities
    local_probs = attn_probs[:, :, :, max_num_global_attn_indices:]
    local_contribution = self._sliding_chunks_matmul_attn_probs_value(
        local_probs, value_vectors, self.one_sided_attn_window_size
    )

    return global_contribution + local_contribution
def _compute_global_attn_output_from_hidden(
    self,
    attn_output,
    hidden_states,
    max_num_global_attn_indices,
    is_local_index_global_attn_nonzero,
    is_index_global_attn_nonzero,
    is_local_index_no_global_attn_nonzero,
    is_index_masked,
    training,
):
    """
    Compute full attention for the global tokens and write the result into `attn_output`.

    Queries are built only from the hidden states of the global tokens (via `query_global`), keys and values
    from all hidden states (via `key_global` / `value_global`). Unused global slots and masked positions are set
    to -10000.0 before the softmax.

    Args:
        attn_output: attention output that is overwritten at the positions of the global tokens.
        hidden_states: hidden states; the first two entries of its shape are used as `(batch_size, seq_len)`.
        max_num_global_attn_indices: number of global slots reserved per batch entry.
        is_local_index_global_attn_nonzero: indices of the used slots in the packed global buffer.
        is_index_global_attn_nonzero: indices of the global tokens inside `hidden_states` / `attn_output`.
        is_local_index_no_global_attn_nonzero: indices of the unused slots in the packed global buffer.
        is_index_masked: boolean mask, broadcast over heads and global slots, of positions to ignore.
        training: forwarded to the dropout layer.

    Returns:
        A tuple `(attn_output, global_attn_probs)` where `global_attn_probs` has shape
        `(batch_size, num_heads, max_num_global_attn_indices, seq_len)`.
    """
    batch_size, seq_len = shape_list(hidden_states)[:2]

    # pack the hidden states of the global tokens into (batch_size, max_num_global_attn_indices, embed_dim)
    packed_global_hidden = tf.scatter_nd(
        is_local_index_global_attn_nonzero,
        tf.gather_nd(hidden_states, is_index_global_attn_nonzero),
        shape=(batch_size, max_num_global_attn_indices, self.embed_dim),
    )

    # projections: queries from the global tokens only, keys / values from every position
    queries = self.query_global(packed_global_hidden)
    keys = self.key_global(hidden_states)
    values = self.value_global(hidden_states)

    # scale the queries by sqrt(head_dim)
    queries = queries / tf.math.sqrt(tf.constant(self.head_dim, dtype=tf.dtypes.float32))

    # fold the heads into the batch axis
    queries = self.reshape_and_transpose(queries, batch_size)
    keys = self.reshape_and_transpose(keys, batch_size)
    values = self.reshape_and_transpose(values, batch_size)

    # raw scores of every global query against every key
    global_attn_scores = tf.matmul(queries, keys, transpose_b=True)

    tf.debugging.assert_equal(
        shape_list(global_attn_scores),
        [batch_size * self.num_heads, max_num_global_attn_indices, seq_len],
        message=f"global_attn_scores have the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is {shape_list(global_attn_scores)}.",
    )

    # unfold the heads again, then bring the global-slot axis next to the batch axis for the scatter below
    per_head_shape = (batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
    scores_slot_major = tf.transpose(tf.reshape(global_attn_scores, per_head_shape), (0, 2, 1, 3))

    # one (num_heads, seq_len) block of -10000.0 per unused global slot
    num_unused_slots = shape_list(is_local_index_no_global_attn_nonzero)[0]
    fill_shape = (num_unused_slots,) + tuple(shape_list(scores_slot_major)[-2:])
    unused_slot_fill = tf.ones(fill_shape) * -10000.0

    # blank out the unused global slots
    scores_slot_major = tf.tensor_scatter_nd_update(
        scores_slot_major,
        is_local_index_no_global_attn_nonzero,
        unused_slot_fill,
    )
    global_attn_scores = tf.transpose(scores_slot_major, (0, 2, 1, 3))

    # blank out the masked key positions for every head and every global slot
    masked_keys = tf.broadcast_to(is_index_masked[:, None, None, :], shape_list(global_attn_scores))
    global_attn_scores = tf.where(masked_keys, -10000.0, global_attn_scores)

    # fold the heads back into the batch axis
    folded_shape = (batch_size * self.num_heads, max_num_global_attn_indices, seq_len)
    global_attn_scores = tf.reshape(global_attn_scores, folded_shape)

    # probabilities over the keys, followed by dropout
    global_attn_probs = self.global_dropout(tf.nn.softmax(global_attn_scores, axis=-1), training=training)

    # weighted sum of the value vectors
    global_attn_output = tf.matmul(global_attn_probs, values)

    tf.debugging.assert_equal(
        shape_list(global_attn_output),
        [batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],
        message=f"global_attn_output tensor has the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is {shape_list(global_attn_output)}.",
    )

    global_attn_output = tf.reshape(
        global_attn_output,
        (batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim),
    )

    # keep only the outputs of the used global slots and merge the head axes into one feature axis
    num_used_slots = shape_list(is_local_index_global_attn_nonzero)[0]
    used_slot_output = tf.gather_nd(
        tf.transpose(global_attn_output, (0, 2, 1, 3)),
        is_local_index_global_attn_nonzero,
    )
    used_slot_output = tf.reshape(used_slot_output, (num_used_slots, -1))

    # overwrite the positions of the global tokens in the attention output
    attn_output = tf.tensor_scatter_nd_update(attn_output, is_index_global_attn_nonzero, used_slot_output)

    global_attn_probs = tf.reshape(global_attn_probs, per_head_shape)

    return attn_output, global_attn_probs
def reshape_and_transpose(self, vector, batch_size):
    """
    Split the last axis of `vector` into heads and fold the head axis into the batch axis.

    `vector` is viewed as `(batch_size, -1, num_heads, head_dim)`, the head axis is moved in front of the
    sequence axis, and the result is flattened to `(batch_size * num_heads, -1, head_dim)`.
    """
    split_heads = tf.reshape(vector, (batch_size, -1, self.num_heads, self.head_dim))
    heads_before_seq = tf.transpose(split_heads, (0, 2, 1, 3))
    return tf.reshape(heads_before_seq, (batch_size * self.num_heads, -1, self.head_dim))
class TFLongformerAttention(tf.keras.layers.Layer):
    """Longformer self-attention followed by its output sub-layer (`TFLongformerSelfOutput`)."""

    def __init__(self, config, layer_id=0, **kwargs):
        super().__init__(**kwargs)
        self.self_attention = TFLongformerSelfAttention(config, layer_id, name="self")
        self.dense_output = TFLongformerSelfOutput(config, name="output")

    def prune_heads(self, heads):
        """Head pruning is not implemented for this layer."""
        raise NotImplementedError

    def call(self, inputs, training=False):
        """
        Run self-attention and the output sub-layer.

        Args:
            inputs: sequence of exactly five items: `hidden_states`, `attention_mask`, `is_index_masked`,
                `is_index_global_attn`, `is_global_attn`.
            training: forwarded to both sub-layers.

        Returns:
            A tuple whose first item is the output of `dense_output`, followed by every additional output of
            the self-attention layer.
        """
        hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn = inputs

        attn_results = self.self_attention(
            [hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn],
            training=training,
        )
        # the output sub-layer receives the attention result and the layer input
        projected = self.dense_output(attn_results[0], hidden_states, training=training)

        return (projected,) + attn_results[1:]
class TFLongformerLayer(tf.keras.layers.Layer):
    """One Longformer encoder block: attention, then the intermediate and the output sub-layers."""

    def __init__(self, config, layer_id=0, **kwargs):
        super().__init__(**kwargs)
        self.attention = TFLongformerAttention(config, layer_id, name="attention")
        self.intermediate = TFLongformerIntermediate(config, name="intermediate")
        self.longformer_output = TFLongformerOutput(config, name="output")

    def call(self, inputs, training=False):
        """
        Apply the block to `inputs`.

        Args:
            inputs: sequence of exactly five items: `hidden_states`, `attention_mask`, `is_index_masked`,
                `is_index_global_attn`, `is_global_attn`.
            training: forwarded to the attention and the output sub-layers.

        Returns:
            A tuple whose first item is the block output, followed by every additional output of the attention
            sub-layer.
        """
        hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn = inputs

        attn_results = self.attention(
            [hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn],
            training=training,
        )
        attended = attn_results[0]

        # the output sub-layer receives the intermediate activations and the attention output
        block_output = self.longformer_output(self.intermediate(attended), attended, training=training)

        # add attentions if we output them
        return (block_output,) + attn_results[1:]
class TFLongformerEncoder(tf.keras.layers.Layer):
    """Stack of `config.num_hidden_layers` `TFLongformerLayer` blocks applied one after the other."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions
        self.layer = [
            TFLongformerLayer(config, i, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)
        ]

    def call(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        padding_len=0,
        is_index_masked=None,
        is_index_global_attn=None,
        is_global_attn=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        training=False,
    ):
        """
        Run all layers and optionally collect per-layer hidden states and attentions.

        Args:
            hidden_states: input to the first layer.
            attention_mask: forwarded unchanged to every layer.
            head_mask: accepted for interface compatibility; not used by this method.
            padding_len: number of trailing positions to strip from the collected hidden states (the returned
                last hidden state is NOT stripped here).
            is_index_masked, is_index_global_attn, is_global_attn: forwarded unchanged to every layer.
            output_attentions: if truthy, collect the attentions of every layer (and the global attentions when
                `is_global_attn` is truthy).
            output_hidden_states: if truthy, collect the input of every layer plus the final output.
            return_dict: if truthy, return a `TFLongformerBaseModelOutput`, otherwise a plain tuple.
            training: forwarded to every layer.

        Returns:
            Either a `TFLongformerBaseModelOutput` or a tuple of the non-`None` values among
            `(hidden_states, all_hidden_states, all_attentions, all_global_attentions)`.
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_global_attentions = () if (output_attentions and is_global_attn) else None

        for layer_module in self.layer:
            if output_hidden_states:
                hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
                all_hidden_states = all_hidden_states + (hidden_states_to_add,)

            layer_outputs = layer_module(
                [
                    hidden_states,
                    attention_mask,
                    is_index_masked,
                    is_index_global_attn,
                    is_global_attn,
                ],
                training=training,
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                # bzs x seq_len x num_attn_heads x (num_global_attn + attention_window_len + 1) => bzs x num_attn_heads x seq_len x (num_global_attn + attention_window_len + 1)
                all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),)

                if is_global_attn:
                    # bzs x num_attn_heads x num_global_attn x seq_len => bzs x num_attn_heads x seq_len x num_global_attn
                    # NOTE: the trailing comma is required so that a one-element tuple (not a bare tensor) is
                    # appended to the accumulator.
                    all_global_attentions = all_global_attentions + (
                        tf.transpose(layer_outputs[2], (0, 1, 3, 2)),
                    )

        # Add last layer
        if output_hidden_states:
            hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
            all_hidden_states = all_hidden_states + (hidden_states_to_add,)

        if not return_dict:
            return tuple(
                v for v in [hidden_states, all_hidden_states, all_attentions, all_global_attentions] if v is not None
            )

        return TFLongformerBaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            global_attentions=all_global_attentions,
        )
@keras_serializable
class TFLongformerMainLayer(tf.keras.layers.Layer):
    """Embeddings, encoder and pooler of the Longformer model bundled into one Keras layer."""

    config_class = LongformerConfig

    def __init__(self, config, **kwargs):
        """
        Build the sub-layers from `config`.

        `config.attention_window` is normalised IN PLACE to a list with one entry per hidden layer: an `int` is
        validated (even, positive) and repeated, a sequence must already have `config.num_hidden_layers` entries.
        """
        super().__init__(**kwargs)

        if isinstance(config.attention_window, int):
            assert config.attention_window % 2 == 0, "`config.attention_window` has to be an even value"
            assert config.attention_window > 0, "`config.attention_window` has to be positive"
            config.attention_window = [config.attention_window] * config.num_hidden_layers  # one value per layer
        else:
            assert len(config.attention_window) == config.num_hidden_layers, (
                "`len(config.attention_window)` should equal `config.num_hidden_layers`. "
                f"Expected {config.num_hidden_layers}, given {len(config.attention_window)}"
            )

        self.num_hidden_layers = config.num_hidden_layers
        self.initializer_range = config.initializer_range
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.return_dict = config.use_return_dict
        self.pad_token_id = config.pad_token_id
        self.attention_window = config.attention_window
        self.embeddings = TFLongformerEmbeddings(config, name="embeddings")
        self.encoder = TFLongformerEncoder(config, name="encoder")
        self.pooler = TFLongformerPooler(config, name="pooler")

    def get_input_embeddings(self):
        """Return the embeddings layer."""
        return self.embeddings

    def set_input_embeddings(self, value):
        """Replace the word embeddings with `value` and update `vocab_size` from its first dimension."""
        self.embeddings.word_embeddings = value
        self.embeddings.vocab_size = value.shape[0]

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        raise NotImplementedError

    def call(
        self,
        inputs,
        attention_mask=None,
        global_attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        training=False,
    ):
        """
        Run embeddings, encoder and pooler.

        Args:
            inputs: either the `input_ids` tensor, a list/tuple holding (in this order) `input_ids`,
                `attention_mask`, `global_attention_mask`, `token_type_ids`, `position_ids`, `inputs_embeds`,
                `output_attentions`, `output_hidden_states`, `return_dict`, or a dict / `BatchEncoding` with
                those names as keys. Values found in `inputs` take precedence over the keyword arguments.
            attention_mask: defaults to all ones when not given.
            global_attention_mask: when given, merged into `attention_mask` (0: no attention, 1: local
                attention, 2: global attention).
            token_type_ids: defaults to all zeros when not given.
            position_ids, inputs_embeds: forwarded to the embeddings layer.
            output_attentions, output_hidden_states, return_dict: default to the values read from the config.
            training: forwarded to the embeddings and the encoder.

        Returns:
            A `TFLongformerBaseModelOutputWithPooling` if `return_dict` is truthy, otherwise the tuple
            `(sequence_output, pooled_output) + encoder_outputs[1:]`.

        Raises:
            ValueError: if both or neither of `input_ids` and `inputs_embeds` are specified.
        """
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
            # fall back to the `global_attention_mask` keyword argument; falling back to `attention_mask` would
            # merge the attention mask with itself and turn every attended token into a global one
            global_attention_mask = inputs[2] if len(inputs) > 2 else global_attention_mask
            token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
            position_ids = inputs[4] if len(inputs) > 4 else position_ids
            inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
            output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
            output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
            return_dict = inputs[8] if len(inputs) > 8 else return_dict
            assert len(inputs) <= 9, "Too many inputs."
        elif isinstance(inputs, (dict, BatchEncoding)):
            input_ids = inputs.get("input_ids")
            attention_mask = inputs.get("attention_mask", attention_mask)
            global_attention_mask = inputs.get("global_attention_mask", global_attention_mask)
            token_type_ids = inputs.get("token_type_ids", token_type_ids)
            position_ids = inputs.get("position_ids", position_ids)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            output_attentions = inputs.get("output_attentions", output_attentions)
            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
            return_dict = inputs.get("return_dict", return_dict)
            assert len(inputs) <= 9, "Too many inputs."
        else:
            input_ids = inputs

        output_attentions = output_attentions if output_attentions is not None else self.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if attention_mask is None:
            attention_mask = tf.fill(input_shape, 1)

        if token_type_ids is None:
            token_type_ids = tf.fill(input_shape, 0)

        # merge `global_attention_mask` and `attention_mask`
        if global_attention_mask is not None:
            attention_mask = self._merge_to_attention_mask(attention_mask, global_attention_mask)

        (
            padding_len,
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            inputs_embeds,
        ) = self._pad_to_window_size(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            pad_token_id=self.pad_token_id,
        )

        # is index masked or global attention
        is_index_masked = tf.math.less(attention_mask, 1)
        is_index_global_attn = tf.math.greater(attention_mask, 1)
        is_global_attn = tf.math.reduce_any(is_index_global_attn)

        # We create a 4D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, to_seq_length, 1, 1]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask[:, :, tf.newaxis, tf.newaxis]

        # `attention_mask` is 1 for positions that attend locally, 0 for masked positions and 2 for global
        # attention positions, so `abs(1 - mask) * -10000.0` is 0.0 for locally attending positions and
        # -10000.0 for masked and global attention positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = tf.cast(tf.math.abs(1 - extended_attention_mask), tf.dtypes.float32) * -10000.0
        embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            padding_len=padding_len,
            is_index_masked=is_index_masked,
            is_index_global_attn=is_index_global_attn,
            is_global_attn=is_global_attn,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        sequence_output = encoder_outputs[0]
        # the pooler runs on the sequence output before the padding is removed
        pooled_output = self.pooler(sequence_output)

        # undo padding
        if padding_len > 0:
            # unpad `sequence_output` because the calling function is expecting a length == input_ids.size(1)
            sequence_output = sequence_output[:, :-padding_len]

        if not return_dict:
            return (
                sequence_output,
                pooled_output,
            ) + encoder_outputs[1:]

        return TFLongformerBaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            global_attentions=encoder_outputs.global_attentions,
        )

    def _pad_to_window_size(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        inputs_embeds,
        pad_token_id,
    ):
        """
        A helper function to pad tokens and mask to work with implementation of Longformer selfattention.

        The sequence axis is padded on the right up to the next multiple of the (largest) attention window.
        `input_ids` and `position_ids` are padded with `pad_token_id`, `attention_mask` and `token_type_ids`
        with zeros, and `inputs_embeds` with the embeddings of `self.pad_token_id`.

        Returns:
            The tuple `(padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds)`;
            the inputs are returned unchanged when no padding is needed.
        """
        # padding
        attention_window = (
            self.attention_window if isinstance(self.attention_window, int) else max(self.attention_window)
        )

        assert attention_window % 2 == 0, f"`attention_window` should be an even value. Given {attention_window}"

        input_shape = shape_list(input_ids) if input_ids is not None else shape_list(inputs_embeds)
        batch_size, seq_len = input_shape[:2]
        # the outer modulo maps an already aligned length to a padding of 0 instead of a full window
        padding_len = (attention_window - seq_len % attention_window) % attention_window

        if padding_len > 0:
            logger.info(
                "Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format(
                    seq_len, seq_len + padding_len, attention_window
                )
            )

            # pad only at the end of the sequence axis
            paddings = tf.constant([[0, 0], [0, padding_len]])

            if input_ids is not None:
                input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id)

            if position_ids is not None:
                # pad with position_id = pad_token_id as in modeling_roberta.RobertaEmbeddings
                position_ids = tf.pad(position_ids, paddings, constant_values=pad_token_id)

            if inputs_embeds is not None:
                input_ids_padding = tf.fill((batch_size, padding_len), self.pad_token_id)
                inputs_embeds_padding = self.embeddings(input_ids_padding)
                inputs_embeds = tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)

            attention_mask = tf.pad(
                attention_mask, paddings, constant_values=False
            )  # no attention on the padding tokens
            token_type_ids = tf.pad(token_type_ids, paddings, constant_values=0)  # pad with token_type_id = 0

        return (
            padding_len,
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            inputs_embeds,
        )

    @staticmethod
    def _merge_to_attention_mask(attention_mask: tf.Tensor, global_attention_mask: tf.Tensor):
        """
        Merge the two masks into one mask with 0 (no attention), 1 (local attention) and 2 (global attention).

        When `attention_mask` is `None`, `global_attention_mask + 1` is returned.
        """
        # longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn)
        # (global_attention_mask + 1) => 1 for local attention, 2 for global attention
        # => final attention_mask => 0 for no attention, 1 for local attention 2 for global attention
        if attention_mask is not None:
            attention_mask = attention_mask * (global_attention_mask + 1)
        else:
            # simply use `global_attention_mask` as `attention_mask`
            # if no `attention_mask` is given
            attention_mask = global_attention_mask + 1

        return attention_mask
class TFLongformerPreTrainedModel(TFPreTrainedModel):
    """
    Abstract base class that takes care of weights initialization and offers a simple interface for downloading
    and loading pretrained models.
    """

    config_class = LongformerConfig
    base_model_prefix = "longformer"

    @property
    def dummy_inputs(self):
        """Small constant batch (3 sequences of 5 tokens) with ids, attention mask and global attention mask."""
        token_ids = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
        local_mask = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
        # every row contains a global token: make sure global layers are initialized
        global_mask = tf.constant([[0, 0, 0, 0, 1], [0, 0, 1, 0, 0], [0, 0, 0, 0, 1]])

        return dict(
            input_ids=token_ids,
            attention_mask=local_mask,
            global_attention_mask=global_mask,
        )
# Shared introductory docstring text for the Longformer TF model classes (inheritance, accepted input formats
# and the `config` parameter). Where it is attached is outside this chunk; presumably via the
# `add_start_docstrings` decorator imported at the top of the file -- TODO confirm at the usage sites.
LONGFORMER_START_DOCSTRING = r"""
This model inherits from :class:`~transformers.TFPreTrainedModel`. Check the superclass documentation for the
generic methods the library implements for all its model (such as downloading or saving, resizing the input
embeddings, pruning heads etc.)
This model is also a `tf.keras.Model <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`__ subclass. Use
it as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage
and behavior.
.. note::
TF 2.0 models accepts two formats as inputs:
- having all inputs as keyword arguments (like PyTorch models), or
- having all inputs as a list, tuple or dict in the first positional arguments.
This second option is useful when using :meth:`tf.keras.Model.fit` method which currently requires having all
the tensors in the first argument of the model call function: :obj:`model(inputs)`.
If you choose this second option, there are three possibilities you can use to gather all the input Tensors in
the first positional argument :
- a single Tensor with :obj:`input_ids` only and nothing else: :obj:`model(inputs_ids)`
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
:obj:`model([input_ids, attention_mask])` or :obj:`model([input_ids, attention_mask, token_type_ids])`
- a dictionary with one or several input Tensors associated to the input names given in the docstring:
:obj:`model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
Parameters:
config (:class:`~transformers.LongformerConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
weights.
"""
# Shared description of the forward-pass arguments of the Longformer TF models. The `{0}` placeholders stand
# for the input shape; presumably they are filled with `str.format` where the text is attached (e.g. by
# `add_start_docstrings_to_model_forward`, imported at the top of the file) -- TODO confirm at the usage sites.
LONGFORMER_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`tf.Tensor` of shape :obj:`({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using :class:`~transformers.LongformerTokenizer`. See
:func:`transformers.PreTrainedTokenizer.__call__` and :func:`transformers.PreTrainedTokenizer.encode` for
details.
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`tf.Tensor` of shape :obj:`({0})`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
global_attention_mask (:obj:`tf.Tensor` of shape :obj:`({0})`, `optional`):
Mask to decide the attention given on each token, local attention or global attention. Tokens with global
attention attends to all other tokens, and all other tokens attend to them. This is important for
task-specific finetuning because it makes the model more flexible at representing the task. For example,
for classification, the <s> token should be given global attention. For QA, all question tokens should also
have global attention. Please refer to the `Longformer paper <https://arxiv.org/abs/2004.05150>`__ for more
details. Mask values selected in ``[0, 1]``:
- 0 for local attention (a sliding window attention),
- 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
token_type_ids (:obj:`tf.Tensor` of shape :obj:`({0})`, `optional`):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
1]``:
- 0 corresponds to a `sentence A` token,
- 1 corresponds to a `sentence B` token.
`What are token type IDs? <../glossary.html#token-type-ids>`__
position_ids (:obj:`tf.Tensor` of shape :obj:`({0})`, `optional`):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
config.max_position_embeddings - 1]``.
`What are position IDs? <../glossary.html#position-ids>`__
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`({0}, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
vectors than the model's internal embedding lookup matrix.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
training (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to use the model in training mode (some modules like dropout modules have different
behaviors between training and evaluation).
"""