|
|
|
|
|
|
|
""" |
|
|
|
This code is in part adapted from AllenAI's Longformer: |
|
https://github.com/allenai/longformer/ |
|
and in part adapted from: |
|
https://github.com/huggingface/transformers |
|
|
|
Author: Annette Rios (rios@cl.uzh.ch) |
|
|
|
""" |
|
from typing import List, Optional, Tuple, Dict, Union |
|
from torch import nn, Tensor, zeros |
|
import torch |
|
import math |
|
import random |
|
from .longformer import LongformerSelfAttention |
|
from transformers.models.mbart.modeling_mbart import MBartConfig, MBartForConditionalGeneration, MBartEncoder, MBartLearnedPositionalEmbedding, MBartEncoderLayer, MBartDecoder, MBartModel, _expand_mask |
|
from transformers.modeling_outputs import BaseModelOutput |
|
|
|
class MLongformerEncoderDecoderForConditionalGeneration(MBartForConditionalGeneration):
    """mBART conditional-generation model with optional Longformer encoder self-attention.

    Builds a ``LongMBartModel`` plus the LM head and ``final_logits_bias`` buffer.
    Unless ``config.attention_mode == 'n2'``, every encoder layer's ``self_attn``
    is replaced with a ``LongformerSelfAttentionForMBart``. Decoder layers are
    left untouched.
    """

    def __init__(self, config):
        # Deliberately call the class *after* MBartForConditionalGeneration in the
        # MRO, skipping MBartForConditionalGeneration.__init__, so a standard
        # MBartModel is not constructed only to be replaced right below.
        super(MBartForConditionalGeneration, self).__init__(config)

        self.model = LongMBartModel(config)
        self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
        self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)

        # 'n2' keeps the regular full self-attention; every other mode swaps in
        # Longformer attention, one module per encoder layer (layer_id = index).
        if config.attention_mode != 'n2':
            for i, layer in enumerate(self.model.encoder.layers):
                layer.self_attn = LongformerSelfAttentionForMBart(config, layer_id=i)

        self.post_init()
|
|
|
|
|
class MLongformerEncoderDecoderConfig(MBartConfig):
    """``MBartConfig`` extended with the Longformer attention settings used by this module."""

    _ATTENTION_MODES = ('tvm', 'sliding_chunks', 'n2')

    def __init__(self, attention_window: Optional[List[int]] = None,
                 attention_dilation: Optional[List[int]] = None,
                 autoregressive: bool = False, attention_mode: str = 'sliding_chunks',
                 gradient_checkpointing: bool = False, **kwargs):
        """
        Args:
            attention_window: list of attention window sizes of length = number of layers.
                window size = number of attention locations on each side.
                For an effective window size of 512, use `attention_window=[256]*num_layers`
                which is 256 on each side.
            attention_dilation: list of attention dilation of length = number of layers.
                attention dilation of `1` means no dilation.
            autoregressive: do autoregressive attention or have attention of both sides
            attention_mode: 'n2' for regular n^2 self-attention, 'tvm' for TVM implementation of
                Longformer selfattention, 'sliding_chunks' for another implementation of
                Longformer selfattention
            gradient_checkpointing: stored unchanged as ``self.gradient_checkpointing``.
            **kwargs: forwarded to ``MBartConfig.__init__``.

        Raises:
            ValueError: if ``attention_mode`` is not one of 'tvm', 'sliding_chunks', 'n2'.
        """
        super().__init__(**kwargs)
        self.attention_window = attention_window
        self.attention_dilation = attention_dilation
        self.autoregressive = autoregressive
        self.attention_mode = attention_mode
        self.gradient_checkpointing = gradient_checkpointing
        # Explicit check instead of `assert`, which is removed under `python -O`.
        if self.attention_mode not in self._ATTENTION_MODES:
            raise ValueError(
                f"attention_mode must be one of {self._ATTENTION_MODES}, got {self.attention_mode!r}"
            )
|
|
|
class LongformerSelfAttentionForMBart(nn.Module):
    """Adapter exposing ``LongformerSelfAttention`` through the MBart attention call signature.

    Holds one ``LongformerSelfAttention`` (built for ``layer_id``) and a linear
    output projection ``self.output`` applied to its context vectors.
    """

    def __init__(self, config, layer_id):
        super().__init__()
        self.embed_dim = config.d_model
        self.longformer_self_attn = LongformerSelfAttention(config, layer_id=layer_id)
        self.output = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: Tensor,
        key_value_states: Optional[Tensor] = None,
        past_key_value: Optional[Tuple[Tensor]] = None,
        attention_mask: Optional[Tensor] = None,
        layer_head_mask: Optional[Tensor] = None,
        output_attentions: bool = False
    ) -> Tuple[Tensor, Optional[Tensor], None]:
        """Run Longformer self-attention on ``hidden_states`` and project the result.

        Args:
            hidden_states: rank-3 tensor unpacked as ``(bsz, tgt_len, embed_dim)``;
                its last dimension must equal ``config.d_model``.
            key_value_states, past_key_value, layer_head_mask: accepted only for
                signature compatibility with MBart attention modules; ignored.
            attention_mask: optional mask, presumably in the convention documented
                on ``LongMBartEncoderLayer.forward`` (0=local, -1=global, 1=padding).
                It is negated before being passed to ``LongformerSelfAttention``;
                ``None`` is passed through unchanged.
            output_attentions: forwarded to ``LongformerSelfAttention``.

        Returns:
            ``(attn_output, attn_weights, None)``. ``attn_weights`` is the second
            element returned by ``LongformerSelfAttention`` when it returns exactly
            two items, otherwise ``None``. The trailing ``None`` fills the
            key/value-cache slot of the MBart attention return tuple.
        """
        _, _, embed_dim = hidden_states.size()
        assert embed_dim == self.embed_dim

        # Negate the encoder-side mask (presumably so it matches Longformer's
        # convention: negative = no attention, 0 = local, positive = global —
        # confirm against .longformer). A missing mask stays None instead of
        # raising TypeError on `None * -1`.
        longformer_mask = None if attention_mask is None else attention_mask * -1

        outputs = self.longformer_self_attn(
            hidden_states,
            attention_mask=longformer_mask,
            head_mask=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            output_attentions=output_attentions,
        )

        attn_output = self.output(outputs[0])
        # Return the attention tensor itself (not a 1-tuple slice) so callers
        # collecting per-layer attentions receive tensors.
        attn_weights = outputs[1] if len(outputs) == 2 else None
        return attn_output, attn_weights, None
|
|
|
|
|
class LongMBartEncoder(MBartEncoder):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`LongMBartEncoderLayer`].

    Compared with [`MBartEncoder`], the positional table is sized by
    ``config.max_encoder_position_embeddings`` (falling back to
    ``config.max_position_embeddings`` when absent), and each layer additionally
    receives a Longformer-style mask derived from ``attention_mask``.

    Args:
        config: MBartConfig
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config)

        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop

        embed_dim = config.d_model
        self.padding_idx = config.pad_token_id
        # Long-input configs carry their own (larger) encoder position limit; fall
        # back to the standard MBart value so a plain MBartConfig also works.
        self.max_source_positions = getattr(
            config, "max_encoder_position_embeddings", config.max_position_embeddings
        )
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        if embed_tokens is not None:
            self.embed_tokens = embed_tokens
        else:
            self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)

        self.embed_positions = MBartLearnedPositionalEmbedding(
            self.max_source_positions,
            embed_dim,
        )
        self.layers = nn.ModuleList([LongMBartEncoderLayer(config) for _ in range(config.encoder_layers)])
        self.layernorm_embedding = nn.LayerNorm(embed_dim)
        self.layer_norm = nn.LayerNorm(config.d_model)

        self.gradient_checkpointing = False

        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`MBartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                The Longformer mask is computed as `1 - attention_mask`; since the layer expects
                -1 for global tokens, global-attention tokens are presumably marked with 2 here.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Raises:
            ValueError: if both or neither of `input_ids` / `inputs_embeds` are given, or if
                `head_mask` does not have one entry per layer.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # `position_source` only supplies the (batch, seq_len) shape to the
        # positional embedding; the name avoids shadowing the `input` builtin.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            position_source = input_ids
            input_ids = input_ids.view(-1, input_ids.shape[-1])
        elif inputs_embeds is not None:
            position_source = inputs_embeds[:, :, -1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        embed_pos = self.embed_positions(position_source)

        hidden_states = inputs_embeds + embed_pos
        hidden_states = self.layernorm_embedding(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        longformer_attention_mask = None
        if attention_mask is not None:
            # 1 - mask: 1 -> 0 (local), 0 -> 1 (padding), 2 -> -1 (global).
            longformer_attention_mask = 1 - attention_mask
            # Additive mask for the regular (non-Longformer) attention path.
            attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        if head_mask is not None and head_mask.size()[0] != len(self.layers):
            raise ValueError(
                f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
                f" {head_mask.size()[0]}."
            )

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)

            layer_head_mask = head_mask[idx] if head_mask is not None else None

            dropout_probability = random.uniform(0, 1)
            if self.training and (dropout_probability < self.layerdrop):
                # LayerDrop: skip this layer; hidden_states passes through unchanged.
                layer_outputs = (None, None)
            else:
                if self.gradient_checkpointing and self.training:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs, output_attentions)

                        return custom_forward

                    layer_outputs = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(encoder_layer),
                        hidden_states,
                        attention_mask,
                        longformer_attention_mask,
                        layer_head_mask,
                    )
                else:
                    layer_outputs = encoder_layer(
                        hidden_states,
                        attention_mask,
                        longformer_attention_mask,
                        layer_head_mask=layer_head_mask,
                        output_attentions=output_attentions,
                    )

                # Only update when the layer actually ran; a dropped layer would
                # otherwise turn hidden_states into None.
                hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
|
|
|
|
|
class LongMBartModel(MBartModel):
    """``MBartModel`` whose encoder is a ``LongMBartEncoder``.

    A single ``nn.Embedding`` (``self.shared``) is passed to both the encoder and
    the standard ``MBartDecoder``.
    """

    def __init__(self, config: MBartConfig):
        # Skip MBartModel.__init__ (call the next class in the MRO instead), the
        # same trick MLongformerEncoderDecoderForConditionalGeneration uses: the
        # parent would build a vocab-sized embedding plus a full encoder/decoder
        # that are immediately replaced below, roughly doubling peak init memory.
        super(MBartModel, self).__init__(config)

        padding_idx, vocab_size = config.pad_token_id, config.vocab_size
        self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)

        self.encoder = LongMBartEncoder(config, self.shared)
        self.decoder = MBartDecoder(config, self.shared)

        self.post_init()
|
|
|
|
|
class LongMBartEncoderLayer(MBartEncoderLayer):
    """``MBartEncoderLayer`` that can feed a Longformer-style mask to its self-attention.

    When ``self.self_attn`` is a ``LongformerSelfAttentionForMBart``, the
    ``longformer_attention_mask`` is used instead of the expanded additive mask.
    Otherwise the layer behaves like the pre-norm MBart encoder layer it extends.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        longformer_attention_mask: torch.Tensor,
        layer_head_mask: torch.Tensor,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape *(batch, seq_len, embed_dim)*
            attention_mask (`torch.FloatTensor`): attention mask of size
                *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
            longformer_attention_mask (`torch.Tensor`): attention mask of size
                `(batch, src_len)` where 0=local, -1=global, 1=padding (may be None when no mask was given).
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                *(encoder_attention_heads,)*.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.

        Returns:
            `(hidden_states,)`, with `attn_weights` appended when `output_attentions` is true.
        """
        # Longformer attention consumes the compact (batch, src_len) mask rather
        # than the expanded additive one.
        if isinstance(self.self_attn, LongformerSelfAttentionForMBart):
            attention_mask = longformer_attention_mask

        # Self-attention sub-block (pre-LayerNorm, then residual add).
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states, attn_weights, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        # Feed-forward sub-block (pre-LayerNorm, then residual add).
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        # In fp16, replace inf/nan overflow by clamping to the representable range.
        if hidden_states.dtype == torch.float16 and (
            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs
|
|