import math
import random
from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import (
    BartConfig,
    BartPretrainedModel,
)
from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
)
from transformers.models.bart.modeling_bart import (
    BartLearnedPositionalEmbedding,
    _expand_mask,
    _make_causal_mask,
)
from transformers.utils import logging

from .config import BartCustomConfig
from .decoder_layer import BartCustomDecoderLayer

logger = logging.get_logger(__name__)
|
|
class BartCustomDecoder(BartPretrainedModel):
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a
    [`BartCustomDecoderLayer`].

    Args:
        config: BartCustomConfig
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(self, config: BartCustomConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config)
        self.dropout = config.dropout
        self.layerdrop = config.decoder_layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

        if embed_tokens is not None:
            self.embed_tokens = embed_tokens
        else:
            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)

        self.embed_positions = BartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
        )
        self.layers = nn.ModuleList([BartCustomDecoderLayer(config) for _ in range(config.decoder_layers)])
        self.layernorm_embedding = nn.LayerNorm(config.d_model)

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()
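
    # Hypothetical sharing example (not in the original): a parent model may pass its
    # own embedding so encoder and decoder reuse one matrix, e.g.
    #     shared = nn.Embedding(config.vocab_size, config.d_model, config.pad_token_id)
    #     decoder = BartCustomDecoder(config, embed_tokens=shared)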
|
    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value
|
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        # create causal mask: [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
            ).to(self.device)

        if attention_mask is not None:
            # expand the padding mask to [bsz, 1, tgt_seq_len, src_seq_len] and combine
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask
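
    # Shape/semantics sketch for `_prepare_decoder_attention_mask` (illustrative): for
    # input_shape == (bsz, tgt_len) the returned mask is additive and broadcasts as
    # (bsz, 1, tgt_len, past_key_values_length + tgt_len) -- 0.0 where attention is
    # allowed and the dtype's minimum value where it is masked, so it can be added
    # directly to the attention scores.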
|
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
r""" |
|
Args: |
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
|
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you |
|
provide it. |
|
|
|
Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and |
|
[`PreTrainedTokenizer.__call__`] for details. |
|
|
|
[What are input IDs?](../glossary#input-ids) |
|
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
|
|
|
- 1 for tokens that are **not masked**, |
|
- 0 for tokens that are **masked**. |
|
|
|
[What are attention masks?](../glossary#attention-mask) |
|
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): |
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention |
|
of the decoder. |
|
encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): |
|
Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values |
|
selected in `[0, 1]`: |
|
|
|
- 1 for tokens that are **not masked**, |
|
- 0 for tokens that are **masked**. |
|
|
|
[What are attention masks?](../glossary#attention-mask) |
|
head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): |
|
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: |
|
|
|
- 1 indicates the head is **not masked**, |
|
- 0 indicates the head is **masked**. |
|
|
|
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): |
|
Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing |
|
cross-attention on hidden heads. Mask values selected in `[0, 1]`: |
|
|
|
- 1 indicates the head is **not masked**, |
|
- 0 indicates the head is **masked**. |
|
|
|
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): |
|
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of |
|
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of |
|
shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. |
|
|
|
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the |
|
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. |
|
|
|
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those |
|
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of |
|
all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of |
|
shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing |
|
`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more |
|
control over how to convert `input_ids` indices into associated vectors than the model's internal |
|
embedding lookup matrix. |
|
output_attentions (`bool`, *optional*): |
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under |
|
returned tensors for more detail. |
|
output_hidden_states (`bool`, *optional*): |
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors |
|
for more detail. |
|
return_dict (`bool`, *optional*): |
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
|
""" |
|
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        # length of the cached prefix, if any
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
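        # Example: with a 7-token cached prefix and one new token per decoding step,
        # past_key_values_length == 7 while input_shape[-1] == 1, so positions and the
        # causal mask below are offset to start at index 7 instead of 0.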
|
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )
|
        # expand encoder attention mask: [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
|
        # embed positions
        positions = self.embed_positions(input_shape, past_key_values_length)

        hidden_states = inputs_embeds + positions
        hidden_states = self.layernorm_embedding(hidden_states)

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
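        # hidden_states is now (batch_size, tgt_len, d_model): scaled token embeddings
        # plus learned positional embeddings, layer-normalized and dropout-regularized.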
|
        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
        next_decoder_cache = () if use_cache else None

        # check that head_mask/cross_attn_head_mask have a correct number of layers specified if desired
        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
            if attn_mask is not None:
                if attn_mask.size()[0] != len(self.layers):
                    raise ValueError(
                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                        f" {attn_mask.size()[0]}."
                    )
|
        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            dropout_probability = random.uniform(0, 1)
            if self.training and (dropout_probability < self.layerdrop):
                continue

            past_key_value = past_key_values[idx] if past_key_values is not None else None
|
            if self.gradient_checkpointing and self.training:
                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, use_cache)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    head_mask[idx] if head_mask is not None else None,
                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    cross_attn_layer_head_mask=(
                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
                    ),
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )
            hidden_states = layer_outputs[0]
|
            # layer_outputs is (hidden_states, self_attn, cross_attn, present_key_value)
            # when output_attentions=True, otherwise (hidden_states, present_key_value),
            # hence the 3-vs-1 index for the cache entry.
            if use_cache:
                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)
|
        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
|
        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
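

# Minimal smoke-test sketch (not part of the original module). It assumes
# `BartCustomConfig` keeps the standard BART fields used above (`vocab_size`,
# `d_model`, `decoder_layers`, `pad_token_id`, ...) and that
# `BartCustomDecoderLayer` follows the `BartDecoderLayer` call signature.
if __name__ == "__main__":
    config = BartCustomConfig()
    decoder = BartCustomDecoder(config).eval()
    input_ids = torch.tensor([[2, 0, 8331, 232]])  # hypothetical token ids
    with torch.no_grad():
        out = decoder(input_ids=input_ids, use_cache=True)
    # Expect (1, 4, config.d_model) hidden states and one key/value tuple per layer.
    print(out.last_hidden_state.shape, len(out.past_key_values))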