Source code for transformers.modeling_xlm

# coding=utf-8
# Copyright 2019-present, Facebook, Inc and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch XLM model.
"""


import itertools
import logging
import math

import numpy as np
import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from torch.nn import functional as F

from .activations import gelu
from .configuration_xlm import XLMConfig
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_utils import (
    PreTrainedModel,
    SequenceSummary,
    SQuADHead,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Tokenizer class name passed to ``add_code_sample_docstrings`` for the ``forward`` methods in this file.
_TOKENIZER_FOR_DOC = "XLMTokenizer"

# Identifiers of the XLM checkpoints listed by this module.
XLM_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "xlm-mlm-en-2048",
    "xlm-mlm-ende-1024",
    "xlm-mlm-enfr-1024",
    "xlm-mlm-enro-1024",
    "xlm-mlm-tlm-xnli15-1024",
    "xlm-mlm-xnli15-1024",
    "xlm-clm-enfr-1024",
    "xlm-clm-ende-1024",
    "xlm-mlm-17-1280",
    "xlm-mlm-100-1280",
    # See all XLM models at https://huggingface.co/models?filter=xlm
]


def create_sinusoidal_embeddings(n_pos, dim, out):
    """
    Fill ``out`` in place with fixed sinusoidal position encodings.

    For position ``pos`` and column ``j`` the angle is ``pos / 10000 ** (2 * (j // 2) / dim)``; even columns
    receive its sine and odd columns its cosine. After the call ``out`` is detached from any autograd graph
    and has ``requires_grad == False``.

    Args:
        n_pos (int): number of positions, i.e. rows written into ``out``.
        dim (int): encoding dimension, i.e. columns written into ``out``.
        out (torch.Tensor): tensor of shape ``(n_pos, dim)`` that is overwritten in place
            (e.g. the weight of a position ``nn.Embedding``).

    Returns:
        None. The result is stored in ``out``.
    """
    position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
    # ``out`` is usually a freshly created ``nn.Parameter`` (a leaf that requires grad). Autograd rejects
    # in-place writes into such a tensor, so the writes must happen with gradient tracking disabled.
    with torch.no_grad():
        out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
        out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
    out.detach_()
    out.requires_grad = False


def get_masks(slen, lengths, causal, padding_mask=None):
    """
    Build the hidden-state mask and the mask handed to the attention layers.

    Args:
        slen: sequence length.
        lengths: 1-D tensor holding one length per batch element; also supplies the batch size and device.
        causal: when true, the attention mask is lower-triangular (position ``i`` sees positions ``<= i``).
        padding_mask: optional precomputed ``(bs, slen)`` mask used as-is instead of deriving one
            from ``lengths``.

    Returns:
        ``(mask, attn_mask)``: ``mask`` has shape ``(bs, slen)``; ``attn_mask`` is the very same object
        when ``causal`` is false, otherwise a ``(bs, slen, slen)`` triangular mask.
    """
    positions = torch.arange(slen, dtype=torch.long, device=lengths.device)

    if padding_mask is None:
        assert lengths.max().item() <= slen
        # position p is valid for a sequence iff p < its length
        mask = positions < lengths[:, None]
    else:
        mask = padding_mask

    batch_size = lengths.size(0)
    if not causal:
        attn_mask = mask
    else:
        # key position (last axis) must not exceed the query position (middle axis)
        key_positions = positions[None, None, :].repeat(batch_size, slen, 1)
        attn_mask = key_positions <= positions[None, :, None]

    # sanity check
    expected_2d = (batch_size, slen)
    assert mask.size() == expected_2d
    assert causal is False or attn_mask.size() == (batch_size, slen, slen)

    return mask, attn_mask


class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product attention with optional key/value caching.

    Used as self-attention when ``kv`` is None, or as attention over another sequence when ``kv`` is given.
    """

    # Process-wide counter: each instance draws a unique ``layer_id`` that keys its entry in ``cache``.
    NEW_ID = itertools.count()

    def __init__(self, n_heads, dim, config):
        """
        Args:
            n_heads: number of attention heads; must divide ``dim``.
            dim: model dimension of the inputs and outputs.
            config: configuration object; only ``config.attention_dropout`` is read here.
        """
        super().__init__()
        self.layer_id = next(MultiHeadAttention.NEW_ID)
        self.dim = dim
        self.n_heads = n_heads
        self.dropout = config.attention_dropout
        assert self.dim % self.n_heads == 0

        self.q_lin = nn.Linear(dim, dim)
        self.k_lin = nn.Linear(dim, dim)
        self.v_lin = nn.Linear(dim, dim)
        self.out_lin = nn.Linear(dim, dim)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """
        Remove the given heads from the projections and update ``n_heads``, ``dim`` and ``pruned_heads``.

        Args:
            heads: collection of head indices to prune; an empty collection is a no-op.
        """
        attention_head_size = self.dim // self.n_heads
        if not heads:
            return
        heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, attention_head_size, self.pruned_heads)
        # Prune linear layers
        self.q_lin = prune_linear_layer(self.q_lin, index)
        self.k_lin = prune_linear_layer(self.k_lin, index)
        self.v_lin = prune_linear_layer(self.v_lin, index)
        self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.dim = attention_head_size * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, input, mask, kv=None, cache=None, head_mask=None, output_attentions=False):
        """
        Self-attention (if kv is None) or attention over source sentence (provided by kv).

        Args:
            input: query tensor of shape ``(bs, qlen, dim)``.
            mask: ``(bs, klen)`` (non-causal) or ``(bs, klen, klen)`` mask; entries equal to 0 are masked out.
            kv: optional source tensor of shape ``(bs, klen, dim)`` to attend over instead of ``input``.
            cache: optional dict, modified in place. ``cache["slen"]`` is read as the number of already
                cached positions and ``cache[self.layer_id]`` stores this layer's ``(k, v)`` pair.
            head_mask: optional multiplicative mask applied to the attention weights.
            output_attentions: when true, the attention weights are appended to the returned tuple.

        Returns:
            ``(output,)`` or ``(output, weights)`` with ``output`` of shape ``(bs, qlen, dim)`` and
            ``weights`` of shape ``(bs, n_heads, qlen, klen)``.
        """
        # Input is (bs, qlen, dim)
        # Mask is (bs, klen) (non-causal) or (bs, klen, klen)
        bs, qlen, _ = input.size()
        if kv is None:
            klen = qlen if cache is None else cache["slen"] + qlen
        else:
            klen = kv.size(1)
        dim_per_head = self.dim // self.n_heads
        mask_reshape = (bs, 1, qlen, klen) if mask.dim() == 3 else (bs, 1, 1, klen)

        def shape(x):
            """Split the last dimension into heads: (bs, len, dim) -> (bs, n_heads, len, dim_per_head)."""
            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

        def unshape(x):
            """Merge the heads back: (bs, n_heads, len, dim_per_head) -> (bs, len, dim)."""
            return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)

        q = shape(self.q_lin(input))  # (bs, n_heads, qlen, dim_per_head)
        if kv is None:
            k = shape(self.k_lin(input))  # (bs, n_heads, qlen, dim_per_head)
            v = shape(self.v_lin(input))  # (bs, n_heads, qlen, dim_per_head)
        elif cache is None or self.layer_id not in cache:
            k = v = kv
            k = shape(self.k_lin(k))  # (bs, n_heads, klen, dim_per_head)
            v = shape(self.v_lin(v))  # (bs, n_heads, klen, dim_per_head)

        if cache is not None:
            if self.layer_id in cache:
                if kv is None:
                    # self-attention: extend the cached keys/values with the new positions
                    k_, v_ = cache[self.layer_id]
                    k = torch.cat([k_, k], dim=2)  # (bs, n_heads, klen, dim_per_head)
                    v = torch.cat([v_, v], dim=2)  # (bs, n_heads, klen, dim_per_head)
                else:
                    # attention over kv: the projections were cached by an earlier call
                    k, v = cache[self.layer_id]
            cache[self.layer_id] = (k, v)

        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, qlen, dim_per_head)
        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, qlen, klen)
        mask = (mask == 0).view(mask_reshape).expand_as(scores)  # (bs, n_heads, qlen, klen)
        # Fill with the most negative finite value instead of -inf: masked positions still get a softmax
        # weight of 0, but a row whose keys are all masked no longer turns into NaN.
        scores.masked_fill_(mask, torch.finfo(scores.dtype).min)  # (bs, n_heads, qlen, klen)

        weights = F.softmax(scores.float(), dim=-1).type_as(scores)  # (bs, n_heads, qlen, klen)
        weights = F.dropout(weights, p=self.dropout, training=self.training)  # (bs, n_heads, qlen, klen)

        # Mask heads if we want to
        if head_mask is not None:
            weights = weights * head_mask

        context = torch.matmul(weights, v)  # (bs, n_heads, qlen, dim_per_head)
        context = unshape(context)  # (bs, qlen, dim)

        outputs = (self.out_lin(context),)
        if output_attentions:
            outputs = outputs + (weights,)
        return outputs


class TransformerFFN(nn.Module):
    """Feed-forward sub-layer: linear -> activation -> linear -> dropout."""

    def __init__(self, in_dim, dim_hidden, out_dim, config):
        """
        Args:
            in_dim: size of the input features.
            dim_hidden: size of the intermediate projection.
            out_dim: size of the output features.
            config: configuration object; ``config.dropout`` and ``config.gelu_activation`` are read.
        """
        super().__init__()
        self.dropout = config.dropout
        self.lin1 = nn.Linear(in_dim, dim_hidden)
        self.lin2 = nn.Linear(dim_hidden, out_dim)
        # GELU when requested by the config, ReLU otherwise
        if config.gelu_activation:
            self.act = gelu
        else:
            self.act = F.relu

    def forward(self, input):
        """Apply the two projections with the activation in between, then dropout (active in training only)."""
        projected = self.lin2(self.act(self.lin1(input)))
        return F.dropout(projected, p=self.dropout, training=self.training)


class XLMPreTrainedModel(PreTrainedModel):
    """
    Abstract base class for the XLM models in this file: it carries the weight-initialization scheme and
    the class attributes used for downloading and loading pretrained models.
    """

    config_class = XLMConfig
    load_tf_weights = None
    base_model_prefix = "transformer"

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    @property
    def dummy_inputs(self):
        """Small fixed batch of ``input_ids`` / ``attention_mask`` / ``langs`` (``langs`` may be None)."""
        input_ids = torch.tensor([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
        attention_mask = torch.tensor([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
        # language ids are only supplied when the model actually uses language embeddings
        langs = None
        if self.config.use_lang_emb and self.config.n_langs > 1:
            langs = torch.tensor([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
        return {"input_ids": input_ids, "attention_mask": attention_mask, "langs": langs}

    def _init_weights(self, module):
        """
        Initialize the weights of ``module``.

        Embeddings and linear layers are drawn from a zero-mean normal distribution, but only when the
        corresponding standard deviation is set in the config; linear biases are then zeroed. LayerNorm
        is always reset to weight 1 / bias 0.
        """
        if (
            isinstance(module, nn.Embedding)
            and self.config is not None
            and self.config.embed_init_std is not None
        ):
            nn.init.normal_(module.weight, mean=0, std=self.config.embed_init_std)

        if isinstance(module, nn.Linear) and self.config is not None and self.config.init_std is not None:
            nn.init.normal_(module.weight, mean=0, std=self.config.init_std)
            if getattr(module, "bias", None) is not None:
                nn.init.constant_(module.bias, 0.0)

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


XLM_START_DOCSTRING = r"""

    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
    usage and behavior.

    Parameters:
        config (:class:`~transformers.XLMConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

XLM_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`transformers.XLMTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.__call__` for details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

            `What are attention masks? <../glossary.html#attention-mask>`__
        langs (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            A parallel sequence of tokens to be used to indicate the language of each token in the input.
            Indices are languages ids which can be obtained from the language names by using two conversion mappings
            provided in the configuration of the model (only provided for multilingual models).
            More precisely, the `language name -> language id` mapping is in `model.config.lang2id` (dict str -> int) and
            the `language id -> language name` mapping is `model.config.id2lang` (dict int -> str).

            See usage examples detailed in the `multilingual documentation <https://huggingface.co/transformers/multilingual.html>`__.
        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token

            `What are token type IDs? <../glossary.html#token-type-ids>`_
        position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`_
        lengths (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Length of each sentence that can be used to avoid performing attention on padding token indices.
            You can also use `attention_mask` for the same result (see above), kept here for compatibility.
            Indices selected in ``[0, ..., input_ids.size(-1)]``.
        cache (:obj:`Dict[str, torch.FloatTensor]`, `optional`, defaults to :obj:`None`):
            dictionary with ``torch.FloatTensor`` that contains pre-computed
            hidden-states (key and values in the attention blocks) as computed by the model
            (see `cache` output below). Can be used to speed up sequential decoding.
            The dictionary object will be modified in-place during the forward pass to add newly computed hidden-states.
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
"""


@add_start_docstrings(
    "The bare XLM Model transformer outputting raw hidden-states without any specific head on top.",
    XLM_START_DOCSTRING,
)
class XLMModel(XLMPreTrainedModel):
    def __init__(self, config):
        """
        Build the embeddings and the stack of attention / feed-forward layers described by ``config``.

        Raises:
            NotImplementedError: if ``config.is_encoder`` is false (only the encoder is implemented).
        """
        super().__init__(config)

        # encoder / decoder, output layer
        self.is_encoder = config.is_encoder
        self.is_decoder = not config.is_encoder
        if self.is_decoder:
            raise NotImplementedError("Currently XLM can only be used as an encoder")
        self.causal = config.causal

        # dictionary / languages
        self.n_langs = config.n_langs
        self.use_lang_emb = config.use_lang_emb
        self.n_words = config.n_words
        self.eos_index = config.eos_index
        self.pad_index = config.pad_index

        # model parameters
        self.dim = config.emb_dim  # 512 by default
        self.hidden_dim = self.dim * 4  # 2048 by default
        self.n_heads = config.n_heads  # 8 by default
        self.n_layers = config.n_layers
        self.dropout = config.dropout
        self.attention_dropout = config.attention_dropout
        assert self.dim % self.n_heads == 0, "transformer dim must be a multiple of n_heads"

        # embeddings
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.dim)
        if config.sinusoidal_embeddings:
            create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight)
        if config.n_langs > 1 and config.use_lang_emb:
            self.lang_embeddings = nn.Embedding(self.n_langs, self.dim)
        self.embeddings = nn.Embedding(self.n_words, self.dim, padding_idx=self.pad_index)
        self.layer_norm_emb = nn.LayerNorm(self.dim, eps=config.layer_norm_eps)

        # transformer layers (one entry per layer in each list)
        self.attentions = nn.ModuleList()
        self.layer_norm1 = nn.ModuleList()
        self.ffns = nn.ModuleList()
        self.layer_norm2 = nn.ModuleList()

        for _ in range(self.n_layers):
            self.attentions.append(MultiHeadAttention(self.n_heads, self.dim, config=config))
            self.layer_norm1.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
            self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, config=config))
            self.layer_norm2.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))

        # Re-apply head pruning recorded in the config. The config entry is reset first and only layers
        # that still have the full number of heads are pruned.
        if hasattr(config, "pruned_heads"):
            pruned_heads = config.pruned_heads.copy().items()
            config.pruned_heads = {}
            for layer, heads in pruned_heads:
                if self.attentions[int(layer)].n_heads == config.n_heads:
                    self.prune_heads({int(layer): list(map(int, heads))})

        self.init_weights()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embeddings

    def set_input_embeddings(self, new_embeddings):
        """Replace the token embedding module with ``new_embeddings``."""
        self.embeddings = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.attentions[layer].prune_heads(heads)

    @add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="xlm-mlm-en-2048")
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        langs=None,
        token_type_ids=None,
        position_ids=None,
        lengths=None,
        cache=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
    ):
        r"""
    Return:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.XLMConfig`) and inputs:
        last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if input_ids is not None:
            bs, slen = input_ids.size()
        else:
            bs, slen = inputs_embeds.size()[:-1]

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if lengths is None:
            if input_ids is not None:
                lengths = (input_ids != self.pad_index).sum(dim=1).long()
            else:
                # No ids to detect padding from: treat every sequence as full length. Built on the input
                # device so that the masks derived from it match the embeddings' device.
                lengths = torch.full((bs,), slen, dtype=torch.long, device=device)

        # check inputs
        assert lengths.size(0) == bs
        assert lengths.max().item() <= slen

        # generate masks
        mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask)

        # position_ids
        if position_ids is None:
            position_ids = torch.arange(slen, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).expand((bs, slen))
        else:
            assert position_ids.size() == (bs, slen)

        # langs
        if langs is not None:
            assert langs.size() == (bs, slen)

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.n_layers)

        # do not recompute cached elements: keep only the positions not yet covered by the cache
        if cache is not None and input_ids is not None:
            _slen = slen - cache["slen"]
            input_ids = input_ids[:, -_slen:]
            position_ids = position_ids[:, -_slen:]
            if langs is not None:
                langs = langs[:, -_slen:]
            mask = mask[:, -_slen:]
            attn_mask = attn_mask[:, -_slen:]

        # embeddings
        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        tensor = inputs_embeds + self.position_embeddings(position_ids).expand_as(inputs_embeds)
        if langs is not None and self.use_lang_emb and self.n_langs > 1:
            tensor = tensor + self.lang_embeddings(langs)
        if token_type_ids is not None:
            # token types are looked up in the word embedding table
            tensor = tensor + self.embeddings(token_type_ids)
        tensor = self.layer_norm_emb(tensor)
        tensor = F.dropout(tensor, p=self.dropout, training=self.training)
        # zero out the hidden states of masked (padding) positions
        tensor *= mask.unsqueeze(-1).to(tensor.dtype)

        # transformer layers
        hidden_states = ()
        attentions = ()
        for i in range(self.n_layers):
            if output_hidden_states:
                hidden_states = hidden_states + (tensor,)

            # self attention
            attn_outputs = self.attentions[i](
                tensor, attn_mask, cache=cache, head_mask=head_mask[i], output_attentions=output_attentions,
            )
            attn = attn_outputs[0]
            if output_attentions:
                attentions = attentions + (attn_outputs[1],)
            attn = F.dropout(attn, p=self.dropout, training=self.training)
            tensor = tensor + attn
            tensor = self.layer_norm1[i](tensor)

            # FFN
            tensor = tensor + self.ffns[i](tensor)
            tensor = self.layer_norm2[i](tensor)
            tensor *= mask.unsqueeze(-1).to(tensor.dtype)

        # Add last hidden state
        if output_hidden_states:
            hidden_states = hidden_states + (tensor,)

        # update cache length
        if cache is not None:
            cache["slen"] += tensor.size(1)

        outputs = (tensor,)
        if output_hidden_states:
            outputs = outputs + (hidden_states,)
        if output_attentions:
            outputs = outputs + (attentions,)
        return outputs  # outputs, (hidden_states), (attentions)
class XLMPredLayer(nn.Module):
    """
    Prediction layer (cross_entropy or adaptive_softmax).
    """

    def __init__(self, config):
        """
        Args:
            config: configuration object. ``asm``, ``n_words``, ``pad_index`` and ``emb_dim`` are always
                read; ``asm_cutoffs`` and ``asm_div_value`` are read unless ``config.asm is False``.
        """
        super().__init__()
        self.asm = config.asm
        self.n_words = config.n_words
        self.pad_index = config.pad_index
        dim = config.emb_dim

        if config.asm is False:
            self.proj = nn.Linear(dim, config.n_words, bias=True)
        else:
            self.proj = nn.AdaptiveLogSoftmaxWithLoss(
                in_features=dim,
                n_classes=config.n_words,
                cutoffs=config.asm_cutoffs,
                div_value=config.asm_div_value,
                head_bias=True,  # default is False
            )

    def forward(self, x, y=None):
        """
        Compute the scores, and the loss when targets are given.

        Args:
            x: hidden states whose last dimension is ``emb_dim``.
            y: optional target indices; when given, the loss is computed against the flattened scores.

        Returns:
            ``(scores,)`` when ``y`` is None, otherwise ``(loss, scores)``.
        """
        outputs = ()
        if self.asm is False:
            scores = self.proj(x)
            outputs = (scores,) + outputs
            if y is not None:
                # "mean" is the current spelling of the deprecated "elementwise_mean" reduction
                loss = F.cross_entropy(scores.view(-1, self.n_words), y.view(-1), reduction="mean")
                outputs = (loss,) + outputs
        else:
            scores = self.proj.log_prob(x)
            outputs = (scores,) + outputs
            if y is not None:
                _, loss = self.proj(x, y)
                outputs = (loss,) + outputs

        return outputs
@add_start_docstrings(
    """The XLM Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings). """,
    XLM_START_DOCSTRING,
)
class XLMWithLMHeadModel(XLMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.transformer = XLMModel(config)
        self.pred_layer = XLMPredLayer(config)

        self.init_weights()

    def get_output_embeddings(self):
        """Return the projection module of the prediction layer."""
        return self.pred_layer.proj

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """
        Append one ``config.mask_token_id`` column to ``input_ids`` and build the matching ``langs`` tensor
        (filled with ``config.lang_id``, or None when no language id is configured).
        """
        mask_token_id = self.config.mask_token_id
        lang_id = self.config.lang_id

        batch_size = input_ids.shape[0]
        mask_column = torch.full((batch_size, 1), mask_token_id, dtype=torch.long, device=input_ids.device)
        input_ids = torch.cat([input_ids, mask_column], dim=1)

        langs = None if lang_id is None else torch.full_like(input_ids, lang_id)
        return {"input_ids": input_ids, "langs": langs}

    @add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="xlm-mlm-en-2048")
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        langs=None,
        token_type_ids=None,
        position_ids=None,
        lengths=None,
        cache=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for language modeling.
            They are passed unchanged to the prediction layer, which scores each position against the
            label at the same position.
            Indices are selected in ``[-100, 0, ..., config.vocab_size]``
            All labels set to ``-100`` are ignored (masked), the loss is only
            computed for labels in ``[0, ..., config.vocab_size]``

    Return:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.XLMConfig`) and inputs:
        loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when ``labels`` is provided)
            Language modeling loss.
        prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        """
        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            langs=langs,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            lengths=lengths,
            cache=cache,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        last_hidden_state, *extras = transformer_outputs
        # (loss,) scores from the head, followed by hidden states / attentions when the transformer returned them
        return self.pred_layer(last_hidden_state, labels) + tuple(extras)
@add_start_docstrings(
    """XLM Model with a sequence classification/regression head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """,
    XLM_START_DOCSTRING,
)
class XLMForSequenceClassification(XLMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.transformer = XLMModel(config)
        self.sequence_summary = SequenceSummary(config)

        self.init_weights()

    @add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="xlm-mlm-en-2048")
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        langs=None,
        token_type_ids=None,
        position_ids=None,
        lengths=None,
        cache=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
            If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).

    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.XLMConfig`) and inputs:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        """
        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            langs=langs,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            lengths=lengths,
            cache=cache,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        logits = self.sequence_summary(transformer_outputs[0])
        # logits first, then hidden states / attentions when the transformer returned them
        outputs = (logits,) + transformer_outputs[1:]

        if labels is None:
            return outputs

        if self.num_labels == 1:
            # single output unit: regression
            loss = MSELoss()(logits.view(-1), labels.view(-1))
        else:
            loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))

        return (loss,) + outputs
@add_start_docstrings(
    """XLM Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    XLM_START_DOCSTRING,
)
class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # Base XLM transformer followed by a single linear layer whose output
        # channels are split into start / end logits in `forward`.
        self.transformer = XLMModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="xlm-mlm-en-2048")
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        langs=None,
        token_type_ids=None,
        position_ids=None,
        lengths=None,
        cache=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Position outside of the sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Position outside of the sequence are not taken into account for computing the loss.

    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.XLMConfig`) and inputs:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when both :obj:`start_positions` and :obj:`end_positions` are provided):
            Total span extraction loss, computed as the average of the Cross-Entropy losses for the start and end positions.
        start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
            Span-start scores (before SoftMax).
        end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
            Span-end scores (before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        """
        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            langs=langs,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            lengths=lengths,
            cache=cache,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        sequence_output = transformer_outputs[0]

        # Unpacking the size-1 chunks into exactly two names requires the last
        # dimension of `logits` (config.num_labels) to be 2.
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        outputs = (
            start_logits,
            end_logits,
        )
        if start_positions is not None and end_positions is not None:
            # Labels may carry an extra trailing dimension (the original comment
            # attributes this to multi-GPU splitting); drop it.
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Positions beyond the model inputs are clamped to `ignored_index`,
            # which the loss is configured to skip. Use the out-of-place `clamp`
            # so the caller's label tensors are never modified.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
            outputs = (total_loss,) + outputs

        # Forward whatever the transformer returned after its first element.
        outputs = outputs + transformer_outputs[1:]

        return outputs
@add_start_docstrings(
    """XLM Model with a beam-search span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    XLM_START_DOCSTRING,
)
class XLMForQuestionAnswering(XLMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # Base transformer plus the SQuAD head that consumes its first output.
        self.transformer = XLMModel(config)
        self.qa_outputs = SQuADHead(config)

        self.init_weights()

    @add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        langs=None,
        token_type_ids=None,
        position_ids=None,
        lengths=None,
        cache=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        is_impossible=None,
        cls_index=None,
        p_mask=None,
        output_attentions=None,
        output_hidden_states=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Position outside of the sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Position outside of the sequence are not taken into account for computing the loss.
        is_impossible (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`):
            Labels whether a question has an answer or no answer (SQuAD 2.0)
        cls_index (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the classification token to use as input for computing plausibility of the answer.
        p_mask (``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
            Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...).
            1.0 means token should be masked. 0.0 mean token is not masked.

    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.XLMConfig`) and inputs:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned if both :obj:`start_positions` and :obj:`end_positions` are provided):
            Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
        start_top_log_probs (``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
            Log probabilities for the top config.start_n_top start token possibilities (beam-search).
        start_top_index (``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
            Indices for the top config.start_n_top start token possibilities (beam-search).
        end_top_log_probs (``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
            Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
        end_top_index (``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
            Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
        cls_logits (``torch.FloatTensor`` of shape ``(batch_size,)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
            Log probabilities for the ``is_impossible`` label of the answers.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Example::

        >>> from transformers import XLMTokenizer, XLMForQuestionAnswering
        >>> import torch

        >>> tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
        >>> model = XLMForQuestionAnswering.from_pretrained('xlm-mlm-en-2048')

        >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        >>> start_positions = torch.tensor([1])
        >>> end_positions = torch.tensor([3])

        >>> outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
        >>> loss = outputs[0]
        """
        encoder_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            langs=langs,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            lengths=lengths,
            cache=cache,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        # First element feeds the QA head; everything after it is passed
        # through to the caller untouched.
        sequence_output = encoder_outputs[0]
        extras = encoder_outputs[1:]

        head_outputs = self.qa_outputs(
            sequence_output,
            start_positions=start_positions,
            end_positions=end_positions,
            cls_index=cls_index,
            is_impossible=is_impossible,
            p_mask=p_mask,
        )

        return head_outputs + extras
@add_start_docstrings(
    """XLM Model with a token classification head on top (a linear layer on top of
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
    XLM_START_DOCSTRING,
)
class XLMForTokenClassification(XLMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Base transformer, then dropout and a per-token linear classifier.
        self.transformer = XLMModel(config)
        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="xlm-mlm-en-2048")
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        langs=None,
        token_type_ids=None,
        position_ids=None,
        lengths=None,
        cache=None,
        head_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        inputs_embeds=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the token classification loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.

    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.XLMConfig`) and inputs:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
            Classification loss.
        scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
            Classification scores (before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        """
        # `inputs_embeds` is declared last in the signature so existing
        # positional callers keep working; it is forwarded to the transformer
        # the same way the question-answering heads in this file do.
        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            langs=langs,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            lengths=lengths,
            cache=cache,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        sequence_output = transformer_outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        # Forward everything after the first element. The question-answering
        # heads in this file slice the same transformer's output with `[1:]`;
        # slicing with `[2:]` here would silently drop the first optional extra.
        outputs = (logits,) + transformer_outputs[1:]
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)
                # Positions where the mask is not 1 get the loss's ignore_index
                # as their label, so they do not contribute to the loss.
                active_labels = torch.where(
                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
                )
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), scores, (hidden_states), (attentions)