""" EncT5 model (based on HuggingFace T5 Model) """ |
|
|
|
from typing import Optional, List, Tuple, Union |
|
|
|
import torch |
|
from torch import nn |
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
|
from transformers.models.t5.modeling_t5 import T5Config, T5PreTrainedModel, T5Model |
|
from transformers.modeling_outputs import Seq2SeqSequenceClassifierOutput |
|
|
|
from .configuration_enct5 import EncT5Config |
|
|
|
|
|
class EncT5ClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config: EncT5Config):
        super().__init__()
        self.dropout = nn.Dropout(p=config.classifier_dropout)
        self.out_proj = nn.Linear(config.d_model, config.num_labels)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.out_proj(hidden_states)
        return hidden_states


class EncT5MultiLabelClassificationHead(nn.Module):
    """Head for multi-label sentence-level classification tasks."""

    def __init__(self, config: EncT5Config):
        super().__init__()
        # One weight vector and one bias per label; each label's logit is computed from its own
        # decoder output position (see EncT5ForSequenceClassification.forward).
        self.weights = nn.Parameter(torch.Tensor(config.num_labels, config.d_model))
        self.biases = nn.Parameter(torch.Tensor(config.num_labels))
        self.dropout = nn.Dropout(p=config.classifier_dropout)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dropout(hidden_states)
        # hidden_states has shape (batch_size, num_labels, d_model): one decoder output per label.
        # Project each label position with its own weight vector and bias to obtain per-label logits
        # of shape (batch_size, num_labels).
        hidden_states = torch.sum(hidden_states * self.weights, dim=-1) + self.biases
        return hidden_states


class EncT5PreTrainedModel(T5PreTrainedModel):
    def _init_weights(self, module):
        """Initialize the weights"""
        factor = self.config.initializer_factor
        if isinstance(module, EncT5ClassificationHead):
            module.out_proj.weight.data.normal_(mean=0.0, std=factor * (self.config.d_model ** -0.5))
            if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None:
                module.out_proj.bias.data.zero_()
        elif isinstance(module, EncT5MultiLabelClassificationHead):
            module.weights.data.normal_(mean=0.0, std=factor * (self.config.d_model ** -0.5))
            module.biases.data.zero_()
        super()._init_weights(module)


class EncT5ForSequenceClassification(EncT5PreTrainedModel):
    r"""
    The EncT5 model was proposed in [EncT5: A Framework for Fine-tuning T5 as Non-autoregressive
    Models](https://arxiv.org/abs/2110.08426) by Frederick Liu, Terry Huang, Shihang Lyu, Siamak Shakeri, Hongkun Yu,
    Jing Li.

    EncT5 is a variant of T5 that mainly uses the encoder for non-autoregressive tasks. EncT5 has several special
    features: 1) it uses fewer decoder layers (defaulting to a single decoder layer), 2) it has a separate decoder
    word embedding, with the decoder input ids being predefined constants, and 3) it adds a classification head on
    top of the decoder output. Research has shown that this model can be more efficient and effective than T5 and
    BERT for non-autoregressive tasks such as classification and regression.
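
    Example (a minimal, illustrative setup; the `t5-base` checkpoint name and `num_labels=2` are placeholder
    choices, and `EncT5Config.from_pretrained` is assumed to accept a vanilla T5 checkpoint):

    ```python
    >>> config = EncT5Config.from_pretrained("t5-base", num_labels=2)
    >>> model = EncT5ForSequenceClassification(config)
    >>> model.load_weights_from_pretrained_t5("t5-base")  # copy encoder/decoder weights from vanilla T5
    >>> model.prepare_for_fine_tuning()  # re-initialize the decoder and the classification head
    ```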
    """
    config_class = EncT5Config
    _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]

    def __init__(self, config: EncT5Config):
        super().__init__(config)

        # The T5 backbone is built from a plain T5Config derived from the EncT5 config.
        self.transformer = T5Model(T5Config.from_dict(config.to_dict()))

        # EncT5 uses a separate, small decoder embedding table: the decoder inputs are predefined
        # constant "query" token ids rather than tokenized text.
        self.decoder_embeddings = nn.Embedding(config.decoder_vocab_size, config.d_model)
        self.transformer.get_decoder().set_input_embeddings(self.decoder_embeddings)

        if config.problem_type == "multi_label_classification":
            self.classification_head = EncT5MultiLabelClassificationHead(config)
        else:
            self.classification_head = EncT5ClassificationHead(config)

        # Initialize weights and apply final processing.
        self.post_init()

        self.model_parallel = False

    def load_weights_from_pretrained_t5(self, model_path: str):
        pretrained_t5_model = T5Model.from_pretrained(model_path)

        # Swap in EncT5's own decoder embedding weights so that shapes match the smaller decoder
        # vocabulary, then load everything else from the pre-trained T5 checkpoint.
        pretrained_state_dict = pretrained_t5_model.state_dict()
        pretrained_state_dict["decoder.embed_tokens.weight"] = self.decoder_embeddings.state_dict()["weight"]

        self.transformer.load_state_dict(pretrained_state_dict, strict=False)

    def prepare_for_fine_tuning(self):
        r"""
        Prepares the model for fine-tuning by re-initializing the decoder and classification-head weights. This step
        should be performed after loading the pre-trained T5 weights but before fine-tuning.
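
        Example (illustrative call order; `"t5-base"` is only a placeholder checkpoint name):

        ```python
        >>> model.load_weights_from_pretrained_t5("t5-base")
        >>> model.prepare_for_fine_tuning()
        >>> # ... proceed with standard fine-tuning
        ```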
        """
        self.transformer.get_decoder().apply(self._init_weights)
        self._init_weights(self.classification_head)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
r""" |
|
Arguments: |
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
|
Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so |
|
you should be able to pad the inputs on both the right and the left. |
|
|
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and |
|
[`PreTrainedTokenizer.__call__`] for detail. |
|
|
|
[What are input IDs?](../glossary#input-ids) |
|
|
|
To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training). |
|
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
|
|
|
- 1 for tokens that are **not masked**, |
|
- 0 for tokens that are **masked**. |
|
|
|
[What are attention masks?](../glossary#attention-mask) |
|
            decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
                Indices of decoder input sequence tokens in the vocabulary.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are decoder input IDs?](../glossary#decoder-input-ids)

                For EncT5 this argument can normally be left unset: the model fills in predefined constant query ids
                (a single `0` token, or `0 .. num_labels - 1` for multi-label classification) drawn from the separate
                decoder embedding table.

                To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
                Training](./t5#training).
            decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
                Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will
                also be used by default.
            head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
                Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in
                `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
                Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in
                `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
                Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected
                in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
                Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*).
                `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden
                states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated
                vectors than the model's internal embedding lookup matrix.
            decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
                representation. This is useful if you want more control over how to convert `decoder_input_ids`
                indices into associated vectors than the model's internal embedding lookup matrix.

                If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, EncT5 constructs constant decoder
                query ids itself (see `decoder_input_ids`).
            labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
                config.num_labels - 1]`. If `config.num_labels == 1`, a regression loss is computed (Mean-Square
                loss); if `config.num_labels > 1`, a classification loss is computed (Cross-Entropy). When
                `config.problem_type == "multi_label_classification"`, pass multi-hot float labels of shape
                `(batch_size, config.num_labels)` and a BCE-with-logits loss is computed. See the example below.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up
                decoding (see `past_key_values`).
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        Returns:
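
        Example (an illustrative sketch of the expected label formats; `model`, `inputs` and the label values are
        placeholders):

        ```python
        >>> import torch

        >>> # Single-label classification (or regression with num_labels == 1): integer class indices.
        >>> outputs = model(**inputs, labels=torch.tensor([1]))

        >>> # Multi-label classification (requires problem_type="multi_label_classification" at init):
        >>> # multi-hot float labels of shape (batch_size, num_labels).
        >>> outputs = model(**inputs, labels=torch.tensor([[1.0, 0.0, 1.0]]))
        >>> loss, logits = outputs.loss, outputs.logits
        ```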
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            use_cache = False

        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds.")
        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # EncT5 feeds the decoder predefined constant query ids instead of tokenized targets.
        if decoder_input_ids is None and decoder_inputs_embeds is None:
            if self.config.problem_type == "multi_label_classification":
                # One query token per label, so the decoder emits one output position per label.
                decoder_input_ids = torch.arange(end=self.config.num_labels, device=device, dtype=torch.long)
                decoder_input_ids = decoder_input_ids.repeat(batch_size, 1)

                if decoder_attention_mask is None:
                    # A full 3D mask of ones, shape (batch_size, num_labels, num_labels), lets every label query
                    # attend to every other one, bypassing the decoder's default causal mask.
                    decoder_attention_mask = torch.ones(
                        (batch_size, self.config.num_labels, self.config.num_labels), device=device, dtype=torch.long
                    )
            else:
                # A single constant query token for single-label classification and regression.
                decoder_input_ids = torch.zeros(batch_size, 1, device=device, dtype=torch.long)

        outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Decoder output of shape (batch_size, num_query_tokens, d_model): one position per decoder query token.
        sequence_output = outputs[0]

        logits = self.classification_head(sequence_output)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.config.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    # Multi-label classification needs the multi-label head, which is chosen at construction time,
                    # so the problem type cannot be inferred here from the labels alone.
                    raise ValueError(
                        "For multi-label classification, the config.problem_type must be set to "
                        "'multi_label_classification' when initializing the model."
                    )

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.config.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
            else:
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )