from dataclasses import dataclass

from transformers.models.t5.modeling_t5 import (
    T5Stack, T5Block, T5LayerNorm, T5LayerSelfAttention, T5LayerFF, T5LayerCrossAttention,
    T5PreTrainedModel, T5ForConditionalGeneration
)

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
import copy

from transformers.modeling_outputs import (
    ModelOutput, BaseModelOutput, BaseModelOutputWithPast,
    BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput
)
from transformers.modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from transformers.utils import logging
from transformers import BeamScorer, BeamSearchScorer

logger = logging.get_logger(__name__)


class JointEncoder(T5Stack):
    """T5 encoder stack extended with a whole-word embedding table: whole-word
    ids aligned with the sub-word tokens are embedded and added to the token
    embeddings before the transformer blocks."""

    def __init__(self, config, embed_tokens=None):
        super(T5Stack, self).__init__(config)
        self.config = config

        self.embed_tokens = embed_tokens
        self.is_decoder = self.config.is_decoder
        assert self.config.is_decoder is False

        self.block = nn.ModuleList(
            [T5Block(config, has_relative_attention_bias=(i == 0))
             for i in range(config.num_layers)]
        )
        self.final_layer_norm = T5LayerNorm(
            config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

        # Whole-word embedding table (hard-coded to 512 distinct whole-word ids).
        self.whole_word_embeddings = nn.Embedding(
            512, config.d_model
        )
        self.init_weights()
        self.model_parallel = False
        self.device_map = None

    def set_input_embeddings(self, new_embeddings):
        self.embed_tokens = new_embeddings
    def forward(
        self,
        input_ids=None,
        whole_word_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        if inputs_embeds is None:
            assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
            inputs_embeds = self.embed_tokens(input_ids)
        if whole_word_ids is not None:
            whole_word_embeds = self.whole_word_embeddings(whole_word_ids)
            assert whole_word_embeds.shape[-1] == inputs_embeds.shape[-1]
            # Add the shared whole-word embedding to every sub-word piece of that word.
            inputs_embeds = inputs_embeds + whole_word_embeds

        B, L = inputs_embeds.size()[:-1]

        if attention_mask is None:
            attention_mask = input_ids.ne(self.config.pad_token_id).to(
                dtype=inputs_embeds.dtype, device=inputs_embeds.device)

        extended_attention_mask = self.get_extended_attention_mask(
            attention_mask,
            (B, L),
            inputs_embeds.device)

        if past_key_values is None:
            past_key_values = [None] * len(self.block)

        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
        present_key_value_states = () if use_cache else None
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and self.is_decoder) else None

        hidden_states = self.dropout(inputs_embeds)

        if self.config.num_layers > 0:

            assert self.block[0].layer[0].SelfAttention.has_relative_attention_bias

            seq_length = L
            q_len = seq_length
            k_len = seq_length

            # The relative position bias is computed once from the first layer's
            # self-attention and shared across all layers, as in standard T5.
            text_position_bias = self.block[0].layer[0].SelfAttention.compute_bias(
                L, L)
            num_heads = text_position_bias.size(1)
            position_bias = text_position_bias.new_zeros(
                1, num_heads, seq_length, seq_length)
            position_bias[:, :, :L, :L] = text_position_bias

            position_bias = position_bias + extended_attention_mask
            for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
                layer_head_mask = head_mask[i]
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask=extended_attention_mask,
                    position_bias=position_bias,
                    encoder_hidden_states=None,
                    encoder_attention_mask=None,
                    encoder_decoder_position_bias=None,
                    layer_head_mask=layer_head_mask,
                    past_key_value=past_key_value,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

                # layer_outputs is (hidden_states, present_key_value_state, ...)
                hidden_states, present_key_value_state = layer_outputs[:2]

                if use_cache:
                    present_key_value_states = present_key_value_states + \
                        (present_key_value_state,)

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    present_key_value_states,
                    all_hidden_states,
                    all_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=present_key_value_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )
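

# A minimal sketch, not part of the original model code, of how `whole_word_ids`
# for the JointEncoder above could be built: every sub-word piece of the same
# surface word shares one id, so all pieces receive the same whole-word embedding
# on top of their token embeddings. The '▁' word-start marker is assumed from a
# SentencePiece-style tokenizer; id 0 is reserved here for special/pad tokens.
def build_whole_word_ids_sketch(tokens):
    # e.g. ['▁New', '▁Y', 'ork', '</s>'] -> [1, 2, 2, 0]
    ids, current_word = [], 0
    for tok in tokens:
        if tok.startswith('▁'):
            current_word += 1
        ids.append(0 if tok in ('</s>', '<pad>') else current_word)
    return ids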


class P5(T5ForConditionalGeneration):
    _keys_to_ignore_on_load_missing = [
        r"encoder\.embed_tokens\.weight",
        r"decoder\.embed_tokens\.weight",
        r"lm_head\.weight",
    ]
    _keys_to_ignore_on_load_unexpected = [
        r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
    ]

    def __init__(self, config):
        super(T5ForConditionalGeneration, self).__init__(config)

        self.config = config

        self.model_dim = config.d_model

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False

        self.encoder = JointEncoder(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False

        self.decoder = T5Stack(decoder_config, self.shared)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        self.init_weights()

        self.model_parallel = False
        self.device_map = None

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def extend_vocab(self, vocab_size):
        # Grow the shared embedding table, copying the existing rows and leaving
        # the newly added rows randomly initialized.
        new_shared = nn.Embedding(vocab_size, self.config.d_model)
        old_weight = self.shared.weight.data.detach().clone()
        old_vocab_size = old_weight.size(0)
        new_shared.weight.data[:old_vocab_size, :] = old_weight
        self.shared = new_shared

        new_lm_head = nn.Linear(self.config.d_model, vocab_size, bias=False)
        old_weight = self.lm_head.weight.data.detach().clone()
        old_vocab_size = old_weight.size(0)
        new_lm_head.weight.data[:old_vocab_size, :] = old_weight
        self.lm_head = new_lm_head

        self.encoder.embed_tokens = self.shared
        self.decoder.embed_tokens = self.shared

        # Tie the LM head to the shared embeddings (this overrides the copied
        # lm_head weights above).
        self.lm_head.weight = self.shared.weight

        self.config.vocab_size = vocab_size
        self.encoder.config.vocab_size = vocab_size
        self.decoder.config.vocab_size = vocab_size
    def forward(
        self,
        input_ids=None,
        whole_word_ids=None,
        attention_mask=None,
        encoder_outputs=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        labels=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        reduce_loss=False,
        return_hidden_state=False,
        **kwargs,
    ):
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                whole_word_ids=whole_word_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(
                    encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(
                    encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # Prepare decoder inputs by shifting the labels to the right.
            decoder_input_ids = self._shift_right(labels)

        # When decoding with cached key/value states, only the last decoder token is needed.
        if past_key_values is not None:
            assert labels is None, "Decoder should not use cached key value states when training."
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

        if attention_mask is None:
            attention_mask = input_ids.ne(self.config.pad_token_id).to(
                dtype=hidden_states.dtype, device=hidden_states.device)
        encoder_attention_mask = attention_mask

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        assert self.config.tie_word_embeddings is True

        if self.config.tie_word_embeddings:
            # Rescale the decoder output before projecting onto the vocabulary,
            # as T5 does when the word embeddings are tied.
            sequence_output = sequence_output * (self.model_dim ** -0.5)

        if return_hidden_state:
            return sequence_output

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            if reduce_loss:
                loss_fct = CrossEntropyLoss(ignore_index=-100)
            else:
                # Per-token losses; callers can aggregate them as needed.
                loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
            loss = loss_fct(
                lm_logits.view(-1, lm_logits.size(-1)),
                labels.view(-1))

        return P5Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_last_hidden_state=decoder_outputs.last_hidden_state,
            decoder_hidden_states=decoder_outputs.hidden_states,
        )
    def prepare_inputs_for_generation(
            self, input_ids, past=None, attention_mask=None, use_cache=None,
            encoder_outputs=None,
            **kwargs):

        if past is not None:
            input_ids = input_ids[:, -1:]

        output = {
            "decoder_input_ids": input_ids,
            "past_key_values": past,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "use_cache": use_cache,
        }

        return output

    @staticmethod
    def _expand_inputs_for_generation(
        input_ids: torch.LongTensor,
        expand_size: int = 1,
        is_encoder_decoder: bool = False,
        attention_mask: torch.LongTensor = None,
        encoder_outputs: ModelOutput = None,
        **model_kwargs
    ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
        expanded_return_idx = (
            torch.arange(input_ids.shape[0]).view(-1, 1).repeat(
                1, expand_size).view(-1).to(input_ids.device)
        )
        input_ids = input_ids.index_select(0, expanded_return_idx)

        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = token_type_ids.index_select(
                0, expanded_return_idx)

        if attention_mask is not None:
            model_kwargs["attention_mask"] = attention_mask.index_select(
                0, expanded_return_idx)

        if is_encoder_decoder:
            assert encoder_outputs is not None
            encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
                0, expanded_return_idx
            )
            model_kwargs["encoder_outputs"] = encoder_outputs

        return input_ids, model_kwargs


@dataclass
class P5Seq2SeqLMOutput(ModelOutput):
    """
    Base class for sequence-to-sequence language model outputs.

    Args:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
            Language modeling loss.
        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
            List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
            :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
            used (see ``past_key_values`` input) to speed up sequential decoding.
        decoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the decoder of the model.
        decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
        decoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder of the model.
        encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
        encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    decoder_last_hidden_state: Optional[Tuple[torch.FloatTensor]] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
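

# Minimal usage sketch, not part of the original module: it only illustrates the
# forward interface defined above. Assumptions: a standard `t5-small` checkpoint
# and tokenizer are reachable, and the dummy `whole_word_ids` below are all zeros
# (see `build_whole_word_ids_sketch` above for a more realistic construction).
if __name__ == "__main__":
    from transformers import T5Config, T5TokenizerFast

    config = T5Config.from_pretrained("t5-small")
    tokenizer = T5TokenizerFast.from_pretrained("t5-small")
    model = P5(config)

    enc = tokenizer("a short example input", return_tensors="pt")
    labels = tokenizer("a short target", return_tensors="pt").input_ids
    whole_word_ids = torch.zeros_like(enc.input_ids)

    out = model(
        input_ids=enc.input_ids,
        whole_word_ids=whole_word_ids,
        attention_mask=enc.attention_mask,
        labels=labels,
        reduce_loss=True,
    )
    print(out.loss.item(), out.logits.shape)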