#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
mBART encoder-decoder whose encoder self-attention can be replaced by
Longformer self-attention.

This code is in part adapted from AllenAI's Longformer:
    https://github.com/allenai/longformer/
and in part adapted from:
    https://github.com/huggingface/transformers

Author: Annette Rios (rios@cl.uzh.ch)
"""
from typing import List, Optional, Tuple, Dict, Union
from torch import nn, Tensor, zeros
import torch
import torch.utils.checkpoint  # used explicitly by LongMBartEncoder.forward
import math
import random
from .longformer import LongformerSelfAttention
from transformers.models.mbart.modeling_mbart import (
    MBartConfig,
    MBartForConditionalGeneration,
    MBartEncoder,
    MBartLearnedPositionalEmbedding,
    MBartEncoderLayer,
    MBartDecoder,
    MBartModel,
    _expand_mask,
)
from transformers.modeling_outputs import BaseModelOutput


class MLongformerEncoderDecoderForConditionalGeneration(MBartForConditionalGeneration):
    """mBART for conditional generation built on top of `LongMBartModel`.

    Unless `config.attention_mode == 'n2'`, the self-attention module of every
    encoder layer is replaced by `LongformerSelfAttentionForMBart`.
    """

    def __init__(self, config):
        # Call the initializer of the class that follows
        # MBartForConditionalGeneration in the MRO, i.e. skip
        # MBartForConditionalGeneration.__init__ itself, because `self.model`,
        # `final_logits_bias` and `lm_head` are set up here instead.
        super(MBartForConditionalGeneration, self).__init__(config)
        self.model = LongMBartModel(config)
        self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
        self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)

        # 'n2' keeps the regular MBart self-attention; every other mode swaps
        # in the Longformer attention wrapper, one instance per encoder layer.
        if config.attention_mode != 'n2':
            for i, layer in enumerate(self.model.encoder.layers):
                layer.self_attn = LongformerSelfAttentionForMBart(config, layer_id=i)

        # Initialize weights and apply final processing
        self.post_init()


class MLongformerEncoderDecoderConfig(MBartConfig):
    """`MBartConfig` extended with the Longformer attention settings."""

    def __init__(self, attention_window: Optional[List[int]] = None,
                 attention_dilation: Optional[List[int]] = None,
                 autoregressive: bool = False, attention_mode: str = 'sliding_chunks',
                 gradient_checkpointing: bool = False, **kwargs):
        """
        Args:
            attention_window: list of attention window sizes of length = number of layers.
                window size = number of attention locations on each side.
                For an effective window size of 512, use `attention_window=[256]*num_layers`
                which is 256 on each side.
            attention_dilation: list of attention dilation of length = number of layers.
                attention dilation of `1` means no dilation.
            autoregressive: do autoregressive attention or have attention of both sides
            attention_mode: 'n2' for regular n^2 self-attention, 'tvm' for TVM implementation of
                Longformer self-attention, 'sliding_chunks' for another implementation of
                Longformer self-attention
            gradient_checkpointing: stored on the config as-is.
            **kwargs: forwarded unchanged to `MBartConfig.__init__`.

        Raises:
            AssertionError: if `attention_mode` is not one of 'tvm', 'sliding_chunks', 'n2'.
        """
        super().__init__(**kwargs)
        self.attention_window = attention_window
        self.attention_dilation = attention_dilation
        self.autoregressive = autoregressive
        self.attention_mode = attention_mode
        self.gradient_checkpointing = gradient_checkpointing
        assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2'], \
            f"unsupported attention_mode: {self.attention_mode!r}"


class LongformerSelfAttentionForMBart(nn.Module):
    """Adapter exposing `LongformerSelfAttention` through the call signature
    and 3-tuple return value that `LongMBartEncoderLayer` uses for
    `self_attn`.
    """

    def __init__(self, config, layer_id):
        super().__init__()
        self.embed_dim = config.d_model
        self.longformer_self_attn = LongformerSelfAttention(config, layer_id=layer_id)
        self.output = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: Tensor,  # 3-dimensional; unpacked below as (batch_size, q_len, model_size)
        key_value_states: Optional[Tensor] = None,  # cross-attention in transformers.models.mbart.modeling_mbart; unused here
        past_key_value: Optional[Tuple[Tensor]] = None,  # only for decoder; unused here
        attention_mask: Optional[Tensor] = None,  # the inverted original mask built by LongMBartEncoder (not the expanded MBart mask, which would lose the global-attention positions)
        layer_head_mask: Optional[Tensor] = None,  # accepted for signature compatibility; not forwarded
        output_attentions: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor], None]:
        """Run Longformer self-attention followed by the output projection.

        Args:
            hidden_states: 3-dimensional input whose last dimension must equal
                `config.d_model`.
            key_value_states: ignored.
            past_key_value: ignored.
            attention_mask: mask as produced by `LongMBartEncoder.forward`
                (`1 - original_mask`). It is negated before being handed to
                `LongformerSelfAttention`. May be None, in which case None is
                passed on (NOTE(review): handling of a None mask inside
                `LongformerSelfAttention` is not visible in this file).
            layer_head_mask: ignored; `head_mask=None` is always passed on.
            output_attentions: forwarded to `LongformerSelfAttention`.

        Returns:
            `(attn_output, attn_weights, None)`. `attn_weights` is the second
            element returned by `LongformerSelfAttention` when it returns
            exactly two values, otherwise None. The trailing None stands in
            for the `past_key_value` slot that only the decoder uses.

        Raises:
            AssertionError: if the last dimension of `hidden_states` differs
                from `self.embed_dim`.
        """
        # Unpacking into three names also enforces a 3-dimensional input.
        _, _, embed_dim = hidden_states.size()
        assert embed_dim == self.embed_dim

        # Flip the sign of the inverted mask (1 - m): values 0 / 1 / -1 become
        # 0 / -1 / +1. Guard against a missing mask: `None * -1` would raise
        # TypeError when the encoder is called without an attention mask.
        longformer_mask = attention_mask * -1 if attention_mask is not None else None

        outputs = self.longformer_self_attn(
            hidden_states,
            attention_mask=longformer_mask,
            head_mask=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            output_attentions=output_attentions,
        )

        # No transpose is applied; the attention output is projected in the
        # layout in which LongformerSelfAttention returns it.
        attn_output = self.output(outputs[0])

        # The encoder layer unpacks three values (attn_output, attn_weights,
        # past_key_value). Return the attention-weights object itself
        # (outputs[1]) rather than the 1-tuple slice outputs[1:], so that the
        # encoder collects one entry per layer instead of nested tuples.
        attn_weights = outputs[1] if len(outputs) == 2 else None
        return attn_output, attn_weights, None


class LongMBartEncoder(MBartEncoder):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`LongMBartEncoderLayer`].

    Args:
        config: MBartConfig
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config)

        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop

        embed_dim = config.d_model
        self.padding_idx = config.pad_token_id
        # Reads `max_encoder_position_embeddings` from the config for the
        # number of source positions.
        self.max_source_positions = config.max_encoder_position_embeddings
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        if embed_tokens is not None:
            self.embed_tokens = embed_tokens
        else:
            self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)

        self.embed_positions = MBartLearnedPositionalEmbedding(
            self.max_source_positions,
            embed_dim,
        )
        # Replace the layers built by the parent initializer with layers whose
        # forward() additionally accepts the Longformer attention mask.
        self.layers = nn.ModuleList([LongMBartEncoderLayer(config) for _ in range(config.encoder_layers)])
        self.layernorm_embedding = nn.LayerNorm(embed_dim)
        self.layer_norm = nn.LayerNorm(config.d_model)

        # NOTE(review): hard-coded to False; config.gradient_checkpointing is
        # not consulted here.
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`MBartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                For Longformer attention, the value 2 additionally marks global attention positions; the
                inverted mask `1 - attention_mask` is passed to every layer alongside the expanded mask.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Raises:
            ValueError: if both or neither of `input_ids` and `inputs_embeds` are given, or if `head_mask` does not
                have one entry per encoder layer.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        # `position_source` is what the positional embedding is called with.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            position_source = input_ids
            input_shape = position_source.shape
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            position_source = inputs_embeds[:, :, -1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        embed_pos = self.embed_positions(position_source)

        hidden_states = inputs_embeds + embed_pos
        hidden_states = self.layernorm_embedding(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        # expand attention_mask
        longformer_attention_mask = None
        if attention_mask is not None:
            # Keep the original mask in inverted form for Longformer
            # attention; otherwise the value marking global attention
            # (2 in the given mask, -1 after inversion) is lost.
            longformer_attention_mask = 1 - attention_mask
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            if head_mask.size()[0] != len(self.layers):
                raise ValueError(
                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
                    f" {head_mask.size()[0]}."
                )

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = random.uniform(0, 1)
            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                # hidden_states is left untouched when the layer is skipped.
                layer_outputs = (None, None)
            else:
                layer_head_mask = head_mask[idx] if head_mask is not None else None
                if self.gradient_checkpointing and self.training:

                    def create_custom_forward(module):
                        # The arguments are passed positionally, so
                        # output_attentions is appended as the fifth
                        # positional argument of the layer's forward().
                        def custom_forward(*inputs):
                            return module(*inputs, output_attentions)

                        return custom_forward

                    layer_outputs = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(encoder_layer),
                        hidden_states,
                        attention_mask,
                        longformer_attention_mask,
                        layer_head_mask,
                    )
                else:
                    layer_outputs = encoder_layer(
                        hidden_states,
                        attention_mask,
                        longformer_attention_mask,
                        layer_head_mask=layer_head_mask,
                        output_attentions=output_attentions,
                    )

                hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )


class LongMBartModel(MBartModel):
    """`MBartModel` whose encoder is a `LongMBartEncoder`; the decoder is a
    regular `MBartDecoder`. Both share one embedding matrix.
    """

    def __init__(self, config: MBartConfig):
        super().__init__(config)

        padding_idx, vocab_size = config.pad_token_id, config.vocab_size
        self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)

        # Rebuild encoder and decoder so that both use the embedding created
        # above and the encoder is the Longformer-aware variant.
        self.encoder = LongMBartEncoder(config, self.shared)
        self.decoder = MBartDecoder(config, self.shared)

        # Initialize weights and apply final processing
        self.post_init()


class LongMBartEncoderLayer(MBartEncoderLayer):
    """`MBartEncoderLayer` whose forward() takes both the expanded MBart mask
    and the Longformer mask and hands the matching one to `self.self_attn`.
    """

    def __init__(self, config: MBartConfig):
        super().__init__(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        longformer_attention_mask: Optional[torch.Tensor],
        layer_head_mask: Optional[torch.Tensor],
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer; presumably of shape
                *(batch, seq_len, embed_dim)*, the order in which `LongformerSelfAttentionForMBart`
                unpacks it (TODO confirm).
            attention_mask (`torch.FloatTensor`): attention mask of size
                *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
            longformer_attention_mask (`torch.FloatTensor`): attention mask of size
                *(batch, src_len)* where 0=local, -1=global, 1=padding.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                *(encoder_attention_heads,)*.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.

        Returns:
            `(hidden_states,)`, or `(hidden_states, attn_weights)` when `output_attentions` is true.
        """
        # if longformer attention instead of mbart self attention: use special mask
        if isinstance(self.self_attn, LongformerSelfAttentionForMBart):
            attention_mask = longformer_attention_mask

        # Self-attention sub-layer: layer norm is applied before attention,
        # followed by dropout and the residual connection.
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states, attn_weights, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        # Feed-forward sub-layer with the same norm / dropout / residual order.
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        # For float16 activations containing inf or nan, clamp to slightly
        # below the dtype's maximum magnitude.
        if hidden_states.dtype == torch.float16 and (
            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs