import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss

from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from transformers.models.bert.modeling_bert import (
    BERT_INPUTS_DOCSTRING,
    BERT_START_DOCSTRING,
    BertEmbeddings,
    BertLayer,
    BertPooler,
    BertPreTrainedModel,
)

def entropy(x):
    """Calculate the entropy of a pre-softmax logit Tensor (one entropy value per row)."""
    exp_x = torch.exp(x)
    A = torch.sum(exp_x, dim=1)  # softmax normalizer: sum_i exp(x_i)
    B = torch.sum(x * exp_x, dim=1)  # sum_i x_i * exp(x_i)
    return torch.log(A) - B / A  # equals -sum_i p_i * log(p_i) with p = softmax(x)

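# A minimal sanity check for `entropy` above (kept in comments so importing this module has no
# side effects): the closed form log(A) - B / A is an algebraic rearrangement of the softmax
# entropy H(p) = -sum_i p_i * log(p_i) with p = softmax(x), e.g.
#
#     x = torch.randn(4, 3)
#     p = torch.softmax(x, dim=1)
#     reference = -(p * p.log()).sum(dim=1)
#     assert torch.allclose(entropy(x), reference, atol=1e-5)
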
class DeeBertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
        self.highway = nn.ModuleList([BertHighway(config) for _ in range(config.num_hidden_layers)])

        # -1 disables early exiting for that layer (entropy is always >= 0).
        self.early_exit_entropy = [-1 for _ in range(config.num_hidden_layers)]

    def set_early_exit_entropy(self, x):
        if (type(x) is float) or (type(x) is int):
            for i in range(len(self.early_exit_entropy)):
                self.early_exit_entropy[i] = x
        else:
            self.early_exit_entropy = x

    def init_highway_pooler(self, pooler):
        # Copy the weights of the main BertPooler into every highway pooler.
        loaded_model = pooler.state_dict()
        for highway in self.highway:
            for name, param in highway.pooler.state_dict().items():
                param.copy_(loaded_model[name])

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
        all_hidden_states = ()
        all_attentions = ()
        all_highway_exits = ()
        for i, layer_module in enumerate(self.layer):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask
            )
            hidden_states = layer_outputs[0]

            if self.output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

            current_outputs = (hidden_states,)
            if self.output_hidden_states:
                current_outputs = current_outputs + (all_hidden_states,)
            if self.output_attentions:
                current_outputs = current_outputs + (all_attentions,)

            highway_exit = self.highway[i](current_outputs)  # (logits, pooled_output)

            if not self.training:
                highway_logits = highway_exit[0]
                highway_entropy = entropy(highway_logits)
                highway_exit = highway_exit + (highway_entropy,)  # (logits, pooled_output, entropy)
                all_highway_exits = all_highway_exits + (highway_exit,)

                # Exit early if this layer's prediction is confident enough; the exception carries
                # the partial outputs up to the caller (see HighwayException below).
                if highway_entropy < self.early_exit_entropy[i]:
                    new_output = (highway_logits,) + current_outputs[1:] + (all_highway_exits,)
                    raise HighwayException(new_output, i + 1)
            else:
                all_highway_exits = all_highway_exits + (highway_exit,)

        # Add the last layer's hidden states.
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            outputs = outputs + (all_attentions,)

        outputs = outputs + (all_highway_exits,)
        return outputs  # last hidden state, (all hidden states), (all attentions), all highway exits

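# Usage sketch for the entropy thresholds (the values here are illustrative, not shipped
# defaults): a single number applies the same threshold to every layer, a list sets one
# threshold per layer, and -1 leaves early exiting disabled.
#
#     model = DeeBertModel(config)
#     model.encoder.set_early_exit_entropy(0.3)                    # one threshold for all layers
#     model.encoder.set_early_exit_entropy([0.6] * 6 + [0.2] * 6)  # per-layer thresholds (12-layer BERT)
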
@add_start_docstrings(
    "The Bert Model transformer with early exiting (DeeBERT). ",
    BERT_START_DOCSTRING,
)
class DeeBertModel(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)
        self.encoder = DeeBertEncoder(config)
        self.pooler = BertPooler(config)

        self.init_weights()

    def init_highway_pooler(self):
        self.encoder.init_highway_pooler(self.pooler)

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """Prunes heads of the model.

        heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        See base class PreTrainedModel.
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
        r"""
        Return:
            :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
            last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
                Sequence of hidden-states at the output of the last layer of the model.
            pooler_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`):
                Last layer hidden-state of the first token of the sequence (classification token)
                further processed by a Linear layer and a Tanh activation function. The Linear
                layer weights are trained from the next sentence prediction (classification)
                objective during pre-training.

                This output is usually *not* a good summary of the semantic content of the input;
                you're often better off averaging or pooling the sequence of hidden-states for the
                whole input sequence.
            hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
                Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
                of shape :obj:`(batch_size, sequence_length, hidden_size)`.

                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
                Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

                Attention weights after the attention softmax, used to compute the weighted average in the
                self-attention heads.
            highway_exits (:obj:`tuple(tuple(torch.Tensor))`):
                Tuple of each early exit's results (total length: number of layers).
                Each entry is itself a tuple of length 2: the first element is the logits and the
                second element is the hidden state of that exit.
        """
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if encoder_attention_mask is None:
            encoder_attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # Make the self-attention mask broadcastable to [batch_size, num_heads, seq_length, seq_length].
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # If a 2D or 3D attention mask is provided for the cross-attention,
        # make it broadcastable to [batch_size, num_heads, seq_length, seq_length] as well.
        if encoder_attention_mask.dim() == 3:
            encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
        if encoder_attention_mask.dim() == 2:
            encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]

        encoder_extended_attention_mask = encoder_extended_attention_mask.to(
            dtype=next(self.parameters()).dtype
        )  # fp16 compatibility
        encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0

        # Prepare the head mask if needed (1.0 in head_mask indicates we keep the head).
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        outputs = (sequence_output, pooled_output) + encoder_outputs[1:]
        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions), highway_exits

class HighwayException(Exception):
    """Raised by DeeBertEncoder to stop computation as soon as a highway exit is confident enough."""

    def __init__(self, message, exit_layer):
        self.message = message  # the partial model outputs assembled at the exit
        self.exit_layer = exit_layer  # starts from 1

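# Control-flow sketch (illustration only): downstream heads wrap the DeeBERT call in
# try/except and treat a HighwayException as a successful early exit, exactly as
# DeeBertForSequenceClassification does further down in this file:
#
#     try:
#         outputs = self.bert(input_ids, attention_mask=attention_mask)
#     except HighwayException as e:
#         outputs = e.message        # partial outputs assembled at the exit layer
#         exit_layer = e.exit_layer  # 1-indexed layer at which inference stopped
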
class BertHighway(nn.Module):
    """A module that provides a shortcut from the output of a non-final BertLayer in the encoder
    to the cross-entropy computation in the sequence-classification head."""

    def __init__(self, config):
        super().__init__()
        self.pooler = BertPooler(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, encoder_outputs):
        # Pooler
        pooler_input = encoder_outputs[0]
        pooler_output = self.pooler(pooler_input)

        # Mimic the (sequence_output, pooled_output, ...) tuple that the full model would return.
        bmodel_output = (pooler_input, pooler_output) + encoder_outputs[1:]

        # Dropout and classification
        pooled_output = bmodel_output[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        return logits, pooled_output

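# Shape sketch (the sizes are illustrative assumptions): given hidden states of shape
# (batch_size, seq_len, hidden_size), a BertHighway produces logits of shape
# (batch_size, num_labels) plus the pooled state of shape (batch_size, hidden_size):
#
#     highway = BertHighway(config)
#     logits, pooled = highway((hidden_states,))
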
@add_start_docstrings(
    """Bert Model (with early exiting - DeeBERT) with a classifier on top,
    also takes care of multi-layer training. """,
    BERT_START_DOCSTRING,
)
class DeeBertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.num_layers = config.num_hidden_layers

        self.bert = DeeBertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_layer=-1,
        train_highway=False,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
            If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            if :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Returns:
            :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
            loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
                Classification (or regression if config.num_labels==1) loss.
            logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
                Classification (or regression if config.num_labels==1) scores (before SoftMax).
            hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
                Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
                of shape :obj:`(batch_size, sequence_length, hidden_size)`.

                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
                Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

                Attention weights after the attention softmax, used to compute the weighted average in the
                self-attention heads.
            highway_exits (:obj:`tuple(tuple(torch.Tensor))`):
                Tuple of each early exit's results (total length: number of layers).
                Each entry is itself a tuple of length 2: the first element is the logits and the
                second element is the hidden state of that exit.
        """
        exit_layer = self.num_layers
        try:
            outputs = self.bert(
                input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
            )
            # sequence_output, pooled_output, (hidden_states), (attentions), highway_exits

            pooled_output = outputs[1]

            pooled_output = self.dropout(pooled_output)
            logits = self.classifier(pooled_output)
            outputs = (logits,) + outputs[2:]
        except HighwayException as e:
            # A highway exit fired early: use its partial outputs and record the exit layer.
            outputs = e.message
            exit_layer = e.exit_layer
            logits = outputs[0]

        if not self.training:
            original_entropy = entropy(logits)
            highway_entropy = []
            highway_logits_all = []
        if labels is not None:
            if self.num_labels == 1:
                # We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

            # Also compute a loss for every highway exit.
            highway_losses = []
            for highway_exit in outputs[-1]:
                highway_logits = highway_exit[0]
                if not self.training:
                    highway_logits_all.append(highway_logits)
                    highway_entropy.append(highway_exit[2])
                if self.num_labels == 1:
                    # We are doing regression
                    loss_fct = MSELoss()
                    highway_loss = loss_fct(highway_logits.view(-1), labels.view(-1))
                else:
                    loss_fct = CrossEntropyLoss()
                    highway_loss = loss_fct(highway_logits.view(-1, self.num_labels), labels.view(-1))
                highway_losses.append(highway_loss)

            if train_highway:
                # Exclude the final highway, which has the same depth as the full model.
                outputs = (sum(highway_losses[:-1]),) + outputs
            else:
                outputs = (loss,) + outputs
        if not self.training:
            outputs = outputs + ((original_entropy, highway_entropy), exit_layer)
            if output_layer >= 0:
                outputs = (
                    (outputs[0],) + (highway_logits_all[output_layer],) + outputs[2:]
                )  # use the logits of the requested highway layer

        return outputs  # (loss), logits, (hidden_states), (attentions), (highway_exits), [(entropies), exit_layer in eval mode]
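
# End-to-end inference sketch (the checkpoint name and tokenizer wiring below are assumptions
# for illustration, not something this module provides):
#
#     from transformers import BertConfig, BertTokenizer
#
#     config = BertConfig.from_pretrained("bert-base-uncased", num_labels=2)
#     tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
#     model = DeeBertForSequenceClassification.from_pretrained("bert-base-uncased", config=config)
#
#     model.eval()  # highway exits only fire outside training mode
#     model.bert.encoder.set_early_exit_entropy(0.3)
#
#     inputs = tokenizer("DeeBERT exits early on easy examples.", return_tensors="pt")
#     with torch.no_grad():
#         outputs = model(**inputs)
#     logits, exit_layer = outputs[0], outputs[-1]  # see the forward() docstring above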