# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch RoBERTa model. """
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss
from pytorch_transformers.modeling_bert import (BertConfig, BertEmbeddings,
BertLayerNorm, BertModel,
BertPreTrainedModel, gelu)
from pytorch_transformers.modeling_utils import add_start_docstrings
logger = logging.getLogger(__name__)
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-pytorch_model.bin",
'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-pytorch_model.bin",
'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-pytorch_model.bin",
}
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json",
'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-config.json",
'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-config.json",
}
class RobertaEmbeddings(BertEmbeddings):
"""
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
"""
def __init__(self, config):
super(RobertaEmbeddings, self).__init__(config)
self.padding_idx = 1
def forward(self, input_ids, token_type_ids=None, position_ids=None):
seq_length = input_ids.size(1)
if position_ids is None:
# Position numbers begin at padding_idx+1. Padding symbols are ignored.
# cf. fairseq's `utils.make_positions`
position_ids = torch.arange(self.padding_idx+1, seq_length+self.padding_idx+1, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
return super(RobertaEmbeddings, self).forward(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
[docs]class RobertaConfig(BertConfig):
pretrained_config_archive_map = ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
ROBERTA_START_DOCSTRING = r""" The RoBERTa model was proposed in
`RoBERTa: A Robustly Optimized BERT Pretraining Approach`_
by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer,
Veselin Stoyanov. It is based on Google's BERT model released in 2018.
It builds on BERT and modifies key hyperparameters, removing the next-sentence pretraining
objective and training with much larger mini-batches and learning rates.
This implementation is the same as BertModel with a tiny embeddings tweak as well as a setup for Roberta pretrained
models.
This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
refer to the PyTorch documentation for all matter related to general usage and behavior.
.. _`RoBERTa: A Robustly Optimized BERT Pretraining Approach`:
https://arxiv.org/abs/1907.11692
.. _`torch.nn.Module`:
https://pytorch.org/docs/stable/nn.html#module
Parameters:
config (:class:`~pytorch_transformers.RobertaConfig`): Model configuration class with all the parameters of the
model. Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
ROBERTA_INPUTS_DOCSTRING = r"""
Inputs:
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
To match pre-training, RoBERTa input sequence should be formatted with <s> and </s> tokens as follows:
(a) For sequence pairs:
``tokens: <s> Is this Jacksonville ? </s> </s> No it is not . </s>``
(b) For single sequences:
``tokens: <s> the dog is hairy . </s>``
Fully encoded sequences or sequence pairs can be obtained using the RobertaTokenizer.encode function with
the ``add_special_tokens`` parameter set to ``True``.
RoBERTa is a model with absolute position embeddings so it's usually advised to pad the inputs on
the right rather than the left.
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
**position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of positions of each input sequence tokens in the position embeddings.
Selected in the range ``[0, config.max_position_embeddings - 1[``.
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""
[docs]@add_start_docstrings("The bare RoBERTa Model transformer outputing raw hidden-states without any specific head on top.",
ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
class RobertaModel(BertModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
Sequence of hidden-states at the output of the last layer of the model.
**pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
Last layer hidden-state of the first token of the sequence (classification token)
further processed by a Linear layer and a Tanh activation function. The Linear
layer weights are trained from the next sentence prediction (classification)
objective during Bert pretraining. This output is usually *not* a good summary
of the semantic content of the input, you're often better with averaging or pooling
the sequence of hidden-states for the whole input sequence.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaModel.from_pretrained('roberta-base')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
config_class = RobertaConfig
pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "roberta"
def __init__(self, config):
super(RobertaModel, self).__init__(config)
self.embeddings = RobertaEmbeddings(config)
self.init_weights()
[docs] def forward(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None, head_mask=None):
if input_ids[:, 0].sum().item() != 0:
logger.warning("A sequence with no special tokens has been passed to the RoBERTa model. "
"This model requires special tokens in order to work. "
"Please specify add_special_tokens=True in your encoding.")
return super(RobertaModel, self).forward(input_ids, token_type_ids, attention_mask, position_ids, head_mask)
class RobertaLMHead(nn.Module):
"""Roberta Head for masked language modeling."""
def __init__(self, config):
super(RobertaLMHead, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.layer_norm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
def forward(self, features, **kwargs):
x = self.dense(features)
x = gelu(x)
x = self.layer_norm(x)
# project back to size of vocabulary with bias
x = self.decoder(x) + self.bias
return x
[docs]@add_start_docstrings("""RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer
on top of the pooled output) e.g. for GLUE tasks. """,
ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
class RobertaForSequenceClassification(BertPreTrainedModel):
r"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for computing the sequence classification/regression loss.
Indices should be in ``[0, ..., config.num_labels]``.
If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Classification (or regression if config.num_labels==1) loss.
**logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
Classification (or regression if config.num_labels==1) scores (before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = RoertaTokenizer.from_pretrained('roberta-base')
model = RobertaForSequenceClassification.from_pretrained('roberta-base')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, logits = outputs[:2]
"""
config_class = RobertaConfig
pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "roberta"
def __init__(self, config):
super(RobertaForSequenceClassification, self).__init__(config)
self.num_labels = config.num_labels
self.roberta = RobertaModel(config)
self.classifier = RobertaClassificationHead(config)
[docs] def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
position_ids=None, head_mask=None):
outputs = self.roberta(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
attention_mask=attention_mask, head_mask=head_mask)
sequence_output = outputs[0]
logits = self.classifier(sequence_output)
outputs = (logits,) + outputs[2:]
if labels is not None:
if self.num_labels == 1:
# We are doing regression
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs # (loss), logits, (hidden_states), (attentions)
class RobertaClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
def __init__(self, config):
super(RobertaClassificationHead, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
def forward(self, features, **kwargs):
x = features[:, 0, :] # take <s> token (equiv. to [CLS])
x = self.dropout(x)
x = self.dense(x)
x = torch.tanh(x)
x = self.dropout(x)
x = self.out_proj(x)
return x