# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch DistilBERT model
adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
and in part from HuggingFace PyTorch version of Google AI Bert model (https://github.com/google-research/bert)
import copy
import logging
import math
import numpy as np
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from .activations import gelu
from .configuration_distilbert import DistilBertConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_utils import PreTrainedModel, prune_linear_layer
logger = logging.getLogger(__name__)
"distilbert-base-uncased": "https://cdn.huggingface.co/distilbert-base-uncased-pytorch_model.bin",
"distilbert-base-uncased-distilled-squad": "https://cdn.huggingface.co/distilbert-base-uncased-distilled-squad-pytorch_model.bin",
"distilbert-base-cased": "https://cdn.huggingface.co/distilbert-base-cased-pytorch_model.bin",
"distilbert-base-cased-distilled-squad": "https://cdn.huggingface.co/distilbert-base-cased-distilled-squad-pytorch_model.bin",
"distilbert-base-german-cased": "https://cdn.huggingface.co/distilbert-base-german-cased-pytorch_model.bin",
"distilbert-base-multilingual-cased": "https://cdn.huggingface.co/distilbert-base-multilingual-cased-pytorch_model.bin",
"distilbert-base-uncased-finetuned-sst-2-english": "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-pytorch_model.bin",
def create_sinusoidal_embeddings(n_pos, dim, out):
    """Fill ``out`` in place with fixed sinusoidal position encodings and freeze it.

    Args:
        n_pos: number of positions; one row of the encoding table is computed per position.
        dim: embedding dimension; one column of the encoding table is computed per feature.
        out: 2-D tensor that receives the table. Even columns (``out[:, 0::2]``) get the sine
            terms, odd columns (``out[:, 1::2]``) get the cosine terms.

    Returns:
        None. ``out`` is overwritten in place and ``out.requires_grad`` is set to ``False``.
    """
    position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
    # ``out`` may be a trainable leaf tensor (e.g. an embedding weight). Autograd rejects in-place
    # writes into a leaf that requires grad, so perform the writes with gradient tracking disabled.
    with torch.no_grad():
        out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
        out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
    out.requires_grad = False
class Embeddings(nn.Module):
    """Token embeddings plus learned-or-sinusoidal position embeddings, followed by LayerNorm and dropout."""

    def __init__(self, config):
        """Build the embedding tables from ``config``.

        Reads ``vocab_size``, ``dim``, ``pad_token_id``, ``max_position_embeddings``,
        ``sinusoidal_pos_embds`` and ``dropout`` from ``config``.
        """
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim)
        if config.sinusoidal_pos_embds:
            # Overwrite the position table with fixed sinusoids (the helper also freezes the weight).
            create_sinusoidal_embeddings(
                n_pos=config.max_position_embeddings, dim=config.dim, out=self.position_embeddings.weight
            )

        self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, input_ids):
        """
        Parameters
        ----------
        input_ids: torch.tensor(bs, max_seq_length)
            The token ids to embed.

        Outputs
        -------
        embeddings: torch.tensor(bs, max_seq_length, dim)
            The embedded tokens (plus position embeddings, no token_type embeddings)
        """
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)

        word_embeddings = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
        position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)

        embeddings = word_embeddings + position_embeddings  # (bs, max_seq_length, dim)
        embeddings = self.LayerNorm(embeddings)  # (bs, max_seq_length, dim)
        embeddings = self.dropout(embeddings)  # (bs, max_seq_length, dim)
        return embeddings
class MultiHeadSelfAttention(nn.Module):
    """Scaled dot-product multi-head attention with optional head pruning and head masking."""

    def __init__(self, config):
        """Build the q/k/v/output projections.

        Reads ``n_heads``, ``dim``, ``attention_dropout`` and ``output_attentions`` from ``config``.
        ``dim`` must be divisible by ``n_heads``.
        """
        super().__init__()

        self.n_heads = config.n_heads
        self.dim = config.dim
        self.dropout = nn.Dropout(p=config.attention_dropout)
        self.output_attentions = config.output_attentions

        assert self.dim % self.n_heads == 0

        self.q_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.k_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)

        # Indices (in the original numbering) of heads that have already been removed.
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from the projection layers.

        Args:
            heads: iterable of head indices (original numbering). Heads that were already pruned
                are ignored; an empty collection is a no-op.
        """
        attention_head_size = self.dim // self.n_heads
        if len(heads) == 0:
            return
        mask = torch.ones(self.n_heads, attention_head_size)
        heads = set(heads) - self.pruned_heads
        for head in heads:
            # Shift the index down by the number of previously pruned heads that preceded it.
            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()
        # Prune linear layers
        self.q_lin = prune_linear_layer(self.q_lin, index)
        self.k_lin = prune_linear_layer(self.k_lin, index)
        self.v_lin = prune_linear_layer(self.v_lin, index)
        self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.dim = attention_head_size * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, query, key, value, mask, head_mask=None):
        """
        Parameters
        ----------
        query: torch.tensor(bs, seq_length, dim)
        key: torch.tensor(bs, seq_length, dim)
        value: torch.tensor(bs, seq_length, dim)
        mask: torch.tensor(bs, seq_length)
            Positions where ``mask == 0`` receive a score of ``-inf`` before the softmax.
        head_mask: optional
            Multiplied into the attention weights when not None.

        Outputs
        -------
        context: torch.tensor(bs, seq_length, dim)
            Contextualized layer.
        weights: torch.tensor(bs, n_heads, seq_length, seq_length)
            Attention weights. Optional: only if `output_attentions=True`
        """
        bs, q_length, dim = query.size()
        k_length = key.size(1)
        # assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
        # assert key.size() == value.size()

        dim_per_head = self.dim // self.n_heads

        mask_reshp = (bs, 1, 1, k_length)

        def shape(x):
            """ separate heads """
            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

        def unshape(x):
            """ group heads """
            return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)

        q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
        k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)
        v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)

        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
        mask = (mask == 0).view(mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)
        scores.masked_fill_(mask, -float("inf"))  # (bs, n_heads, q_length, k_length)

        weights = nn.Softmax(dim=-1)(scores)  # (bs, n_heads, q_length, k_length)
        weights = self.dropout(weights)  # (bs, n_heads, q_length, k_length)

        # Mask heads if we want to
        if head_mask is not None:
            weights = weights * head_mask

        context = torch.matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)
        context = unshape(context)  # (bs, q_length, dim)
        context = self.out_lin(context)  # (bs, q_length, dim)

        if self.output_attentions:
            return (context, weights)
        return (context,)
class FFN(nn.Module):
    """Position-wise feed-forward network: linear -> activation -> linear -> dropout."""

    def __init__(self, config):
        """Build the two projections.

        Reads ``dropout``, ``dim``, ``hidden_dim`` and ``activation`` from ``config``;
        ``activation`` must be ``"relu"`` or ``"gelu"``.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=config.dropout)
        self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim)
        self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim)
        assert config.activation in ["relu", "gelu"], "activation ({}) must be in ['relu', 'gelu']".format(
            config.activation
        )
        self.activation = gelu if config.activation == "gelu" else nn.ReLU()

    def forward(self, input):
        """Apply the feed-forward network to ``input`` and return a tensor whose last dimension is ``dim``."""
        x = self.lin1(input)
        x = self.activation(x)
        x = self.lin2(x)
        x = self.dropout(x)
        return x
class TransformerBlock(nn.Module):
    """One encoder layer: self-attention and feed-forward, each followed by a residual LayerNorm."""

    def __init__(self, config):
        """Build the attention and feed-forward sub-layers from ``config``."""
        super().__init__()

        self.output_attentions = config.output_attentions

        assert config.dim % config.n_heads == 0

        self.attention = MultiHeadSelfAttention(config)
        self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)

        self.ffn = FFN(config)
        self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)

    def forward(self, x, attn_mask=None, head_mask=None):
        """
        Parameters
        ----------
        x: torch.tensor(bs, seq_length, dim)
        attn_mask: torch.tensor(bs, seq_length)
        head_mask: optional, forwarded unchanged to the attention layer

        Outputs
        -------
        sa_weights: torch.tensor(bs, n_heads, seq_length, seq_length)
            The attention weights. Only present (as the first element) when `output_attentions=True`.
        ffn_output: torch.tensor(bs, seq_length, dim)
            The output of the transformer block contextualization. Always the last element.
        """
        # Self-Attention
        sa_output = self.attention(query=x, key=x, value=x, mask=attn_mask, head_mask=head_mask)
        if self.output_attentions:
            sa_output, sa_weights = sa_output  # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
        else:  # To handle these `output_attention` or `output_hidden_states` cases returning tuples
            assert isinstance(sa_output, tuple)
            sa_output = sa_output[0]
        sa_output = self.sa_layer_norm(sa_output + x)  # (bs, seq_length, dim)

        # Feed Forward Network
        ffn_output = self.ffn(sa_output)  # (bs, seq_length, dim)
        ffn_output = self.output_layer_norm(ffn_output + sa_output)  # (bs, seq_length, dim)

        output = (ffn_output,)
        if self.output_attentions:
            output = (sa_weights,) + output
        return output
class Transformer(nn.Module):
    """Stack of ``config.n_layers`` identical :class:`TransformerBlock` layers."""

    def __init__(self, config):
        """Build the layer stack; every layer is a deep copy of one freshly built block."""
        super().__init__()
        self.n_layers = config.n_layers
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

        layer = TransformerBlock(config)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layers)])

    def forward(self, x, attn_mask=None, head_mask=None):
        """
        Parameters
        ----------
        x: torch.tensor(bs, seq_length, dim)
            Input sequence embedded.
        attn_mask: torch.tensor(bs, seq_length)
            Attention mask on the sequence.
        head_mask: optional
            Indexed per layer (``head_mask[i]``); when None, no head mask is passed to any layer.

        Outputs
        -------
        hidden_state: torch.tensor(bs, seq_length, dim)
            Sequence of hiddens states in the last (top) layer
        all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
            Tuple of length n_layers + 1 with the input embeddings followed by the hidden states of each layer.
            Optional: only if output_hidden_states=True
        all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
            Tuple of length n_layers with the attention weights from each layer
            Optional: only if output_attentions=True
        """
        all_hidden_states = ()
        all_attentions = ()

        hidden_state = x
        for i, layer_module in enumerate(self.layer):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_state,)

            # The signature defaults head_mask to None, so guard the per-layer indexing.
            layer_head_mask = head_mask[i] if head_mask is not None else None
            layer_outputs = layer_module(x=hidden_state, attn_mask=attn_mask, head_mask=layer_head_mask)
            hidden_state = layer_outputs[-1]

            if self.output_attentions:
                assert len(layer_outputs) == 2
                attentions = layer_outputs[0]
                all_attentions = all_attentions + (attentions,)
            else:
                assert len(layer_outputs) == 1

        # Add last layer
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_state,)

        outputs = (hidden_state,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # last-layer hidden state, (all hidden states), (all attentions)
class DistilBertPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    config_class = DistilBertConfig
    pretrained_model_archive_map = DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = None
    base_model_prefix = "distilbert"

    def _init_weights(self, module):
        """ Initialize the weights of a single sub-module.

        Embedding and Linear weights are drawn from a normal distribution with standard deviation
        ``config.initializer_range``. Embedding weights that do not require grad are left untouched.
        """
        if isinstance(module, nn.Embedding):
            # Frozen embedding tables (``requires_grad == False``) keep their current values.
            if module.weight.requires_grad:
                module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            # NOTE(review): this branch body was missing in the reviewed source; restored to the
            # conventional identity initialisation (zero bias, unit scale) - confirm.
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            # NOTE(review): branch body was missing in the reviewed source; zero bias assumed - confirm.
            module.bias.data.zero_()
# Shared documentation fragments attached to the model classes and their forward methods.
DISTILBERT_START_DOCSTRING = r"""

    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
    usage and behavior.

    Parameters:
        config (:class:`~transformers.DistilBertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

DISTILBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`transformers.DistilBertTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.encode_plus` for details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

            `What are attention masks? <../glossary.html#attention-mask>`__
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
"""
"The bare DistilBERT encoder/transformer outputting raw hidden-states without any specific head on top.",
class DistilBertModel(DistilBertPreTrainedModel):
def __init__(self, config):
self.embeddings = Embeddings(config) # Embeddings
self.transformer = Transformer(config) # Encoder
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
for layer, heads in heads_to_prune.items():
[docs] @add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
def forward(self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None):
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.DistilBertConfig`) and inputs:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
from transformers import DistilBertTokenizer, DistilBertModel
import torch
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
model = DistilBertModel.from_pretrained('distilbert-base-cased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if attention_mask is None:
attention_mask = torch.ones(input_shape, device=device) # (bs, seq_length)
# Prepare head mask if needed
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids) # (bs, seq_length, dim)
tfmr_output = self.transformer(x=inputs_embeds, attn_mask=attention_mask, head_mask=head_mask)
hidden_state = tfmr_output[0]
output = (hidden_state,) + tfmr_output[1:]
return output # last-layer hidden-state, (all hidden_states), (all attentions)
"""DistilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of
the pooled output) e.g. for GLUE tasks. """,
class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
def __init__(self, config):
self.num_labels = config.num_labels
self.distilbert = DistilBertModel(config)
self.pre_classifier = nn.Linear(config.dim, config.dim)
self.classifier = nn.Linear(config.dim, config.num_labels)
self.dropout = nn.Dropout(config.seq_classif_dropout)
[docs] @add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
def forward(self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, labels=None):
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the sequence classification/regression loss.
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.DistilBertConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-cased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, logits = outputs[:2]
distilbert_output = self.distilbert(
input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds
hidden_state = distilbert_output[0] # (bs, seq_len, dim)
pooled_output = hidden_state[:, 0] # (bs, dim)
pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
pooled_output = nn.ReLU()(pooled_output) # (bs, dim)
pooled_output = self.dropout(pooled_output) # (bs, dim)
logits = self.classifier(pooled_output) # (bs, dim)
outputs = (logits,) + distilbert_output[1:]
if labels is not None:
if self.num_labels == 1:
loss_fct = nn.MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs # (loss), logits, (hidden_states), (attentions)
"""DistilBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
the hidden-states output to compute `span start logits` and `span end logits`). """,
class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
def __init__(self, config):
self.distilbert = DistilBertModel(config)
self.qa_outputs = nn.Linear(config.dim, config.num_labels)
assert config.num_labels == 2
self.dropout = nn.Dropout(config.qa_dropout)
[docs] @add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
def forward(
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.DistilBertConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-start scores (before SoftMax).
end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering
import torch
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
model = DistilBertForQuestionAnswering.from_pretrained('distilbert-base-cased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
start_positions = torch.tensor([1])
end_positions = torch.tensor([3])
outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
loss, start_scores, end_scores = outputs[:3]
distilbert_output = self.distilbert(
input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds
hidden_states = distilbert_output[0] # (bs, max_query_len, dim)
hidden_states = self.dropout(hidden_states) # (bs, max_query_len, dim)
logits = self.qa_outputs(hidden_states) # (bs, max_query_len, 2)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1) # (bs, max_query_len)
end_logits = end_logits.squeeze(-1) # (bs, max_query_len)
outputs = (start_logits, end_logits,) + distilbert_output[1:]
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index)
end_positions.clamp_(0, ignored_index)
loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
outputs = (total_loss,) + outputs
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
"""DistilBert Model with a token classification head on top (a linear layer on top of
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
class DistilBertForTokenClassification(DistilBertPreTrainedModel):
def __init__(self, config):
self.num_labels = config.num_labels
self.distilbert = DistilBertModel(config)
self.dropout = nn.Dropout(config.dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
def forward(self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, labels=None):
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the token classification loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.DistilBertConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
Classification loss.
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`)
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
from transformers import DistilBertTokenizer, DistilBertForTokenClassification
import torch
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
model = DistilBertForTokenClassification.from_pretrained('distilbert-base-cased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, scores = outputs[:2]
outputs = self.distilbert(
input_ids, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
if labels is not None:
loss_fct = CrossEntropyLoss()
# Only keep active parts of the loss
if attention_mask is not None:
active_loss = attention_mask.view(-1) == 1
active_logits = logits.view(-1, self.num_labels)
active_labels = torch.where(
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
loss = loss_fct(active_logits, active_labels)
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs # (loss), scores, (hidden_states), (attentions)