Source code for transformers.modeling_transfo_xl

# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Transformer XL model.
    Adapted from https://github.com/kimiyoung/transformer-xl.
    In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py
"""

from __future__ import absolute_import, division, print_function, unicode_literals

import os
import json
import math
import logging
import collections
import sys
from io import open

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter

from .modeling_utils import PreTrainedModel, Conv1D, prune_conv1d_layer, SequenceSummary
from .configuration_transfo_xl import TransfoXLConfig
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
from .file_utils import add_start_docstrings

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Shortcut name -> URL of the corresponding PyTorch weights file.
# Used below as `TransfoXLPreTrainedModel.pretrained_model_archive_map`.
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP = {
    'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-pytorch_model.bin",
}

def build_tf_to_pytorch_map(model, config):
    """ Build a map from TF checkpoint variable names to PyTorch parameters.

        A map is used (rather than renaming modules) to keep the PyTorch model as identical
        to the original PyTorch model as possible.

        Args:
            model: either the bare transformer, or a wrapper exposing it as ``model.transformer``
                together with an adaptive-softmax criterion ``model.crit``.
            config: configuration object; ``tie_projs``, ``tie_weight`` and ``untie_r`` are read.

        Returns:
            dict mapping each TF variable name to the PyTorch parameter it initializes.
            The two relative-position bias entries (``transformer/r_r_bias`` and
            ``transformer/r_w_bias``) always map to a *list* of parameters: one per layer
            when ``config.untie_r`` is set, otherwise a single-element list.

        Raises:
            NotImplementedError: if the adaptive softmax has output layers and
                ``config.tie_weight`` is false.
    """
    tf_to_pt_map = {}

    if hasattr(model, 'transformer'):
        # We are loading in a TransfoXLLMHeadModel => we will load also the Adaptive Softmax
        tf_to_pt_map.update({
            "transformer/adaptive_softmax/cutoff_0/cluster_W": model.crit.cluster_weight,
            "transformer/adaptive_softmax/cutoff_0/cluster_b": model.crit.cluster_bias})
        for i, (out_l, proj_l, tie_proj) in enumerate(zip(
                                model.crit.out_layers,
                                model.crit.out_projs,
                                config.tie_projs)):
            layer_str = "transformer/adaptive_softmax/cutoff_%d/" % i
            if not config.tie_weight:
                # Untied output embeddings would need a separate `lookup_table` entry;
                # this was never supported here (and probably not in the TF code either).
                raise NotImplementedError(
                    "Loading TF checkpoints with untied output embeddings is not supported")
            tf_to_pt_map[layer_str + 'b'] = out_l.bias
            if not tie_proj:
                tf_to_pt_map[layer_str + 'proj'] = proj_l
        # Now load the rest of the transformer
        model = model.transformer

    # Embeddings
    for i, (embed_l, proj_l) in enumerate(zip(model.word_emb.emb_layers, model.word_emb.emb_projs)):
        layer_str = "transformer/adaptive_embed/cutoff_%d/" % i
        tf_to_pt_map.update({
            layer_str + 'lookup_table': embed_l.weight,
            layer_str + 'proj_W': proj_l
            })

    # Transformer blocks
    for i, b in enumerate(model.layers):
        layer_str = "transformer/layer_%d/" % i
        tf_to_pt_map.update({
            layer_str + "rel_attn/LayerNorm/gamma": b.dec_attn.layer_norm.weight,
            layer_str + "rel_attn/LayerNorm/beta": b.dec_attn.layer_norm.bias,
            layer_str + "rel_attn/o/kernel": b.dec_attn.o_net.weight,
            layer_str + "rel_attn/qkv/kernel": b.dec_attn.qkv_net.weight,
            layer_str + "rel_attn/r/kernel": b.dec_attn.r_net.weight,
            layer_str + "ff/LayerNorm/gamma": b.pos_ff.layer_norm.weight,
            layer_str + "ff/LayerNorm/beta": b.pos_ff.layer_norm.bias,
            # CoreNet is [Linear, ReLU, Dropout, Linear, Dropout]: indices 0 and 3 are the Linears
            layer_str + "ff/layer_1/kernel": b.pos_ff.CoreNet[0].weight,
            layer_str + "ff/layer_1/bias": b.pos_ff.CoreNet[0].bias,
            layer_str + "ff/layer_2/kernel": b.pos_ff.CoreNet[3].weight,
            layer_str + "ff/layer_2/bias": b.pos_ff.CoreNet[3].bias,
        })

    # Relative positioning biases: per-layer when untied, otherwise shared at model level
    if config.untie_r:
        r_r_list = [b.dec_attn.r_r_bias for b in model.layers]
        r_w_list = [b.dec_attn.r_w_bias for b in model.layers]
    else:
        r_r_list = [model.r_r_bias]
        r_w_list = [model.r_w_bias]
    tf_to_pt_map.update({
        'transformer/r_r_bias': r_r_list,
        'transformer/r_w_bias': r_w_list})
    return tf_to_pt_map

def load_tf_weights_in_transfo_xl(model, config, tf_path):
    """ Load the weights of a TF checkpoint into a PyTorch model.

        Args:
            model: PyTorch model whose parameters are overwritten in place.
            config: model configuration, forwarded to `build_tf_to_pytorch_map`.
            tf_path: path of the TF checkpoint, as accepted by `tf.train.list_variables`.

        Returns:
            The same `model` object, with the mapped parameters replaced.

        Raises:
            ImportError: if numpy or TensorFlow cannot be imported.
            AssertionError: if a mapped variable is missing from the checkpoint or a
                shape does not match (the two shapes are appended to the error args).
    """
    try:
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions.")
        raise
    # Build TF to PyTorch weights loading map
    tf_to_pt_map = build_tf_to_pytorch_map(model, config)

    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    tf_weights = {}
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        tf_weights[name] = tf.train.load_variable(tf_path, name)

    for name, pointer in tf_to_pt_map.items():
        assert name in tf_weights
        array = tf_weights[name]
        if 'kernel' in name or 'proj' in name:
            array = np.transpose(array)

        # Collect (layer index or None, parameter, values) triples to copy.
        # The relative-position biases are always mapped to a *list* of parameters.
        # A one-element list whose parameter has the same rank as the TF array is a
        # shared bias and is copied whole (previously this case crashed on `list.shape`);
        # otherwise the TF array is split along its first axis, one slice per layer.
        if isinstance(pointer, list):
            if len(pointer) == 1 and pointer[0].dim() == array.ndim:
                targets = [(None, pointer[0], array)]
            else:
                # Here we will split the TF weights
                assert len(pointer) == array.shape[0]
                targets = [(i, p_i, array[i, ...]) for i, p_i in enumerate(pointer)]
        else:
            targets = [(None, pointer, array)]

        for layer_idx, param, values in targets:
            try:
                assert param.shape == values.shape
            except AssertionError as e:
                e.args += (param.shape, values.shape)
                raise
            if layer_idx is None:
                logger.info("Initialize PyTorch weight {}".format(name))
            else:
                logger.info("Initialize PyTorch weight {} for layer {}".format(name, layer_idx))
            param.data = torch.from_numpy(values)

        # Drop the variable and its Adam optimizer slots so that only the
        # genuinely unused checkpoint entries are reported below.
        tf_weights.pop(name, None)
        tf_weights.pop(name + '/Adam', None)
        tf_weights.pop(name + '/Adam_1', None)

    logger.info("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
    return model


class PositionalEmbedding(nn.Module):
    """ Sinusoidal embedding of a 1-D sequence of positions.

        The output concatenates the sines and then the cosines of ``pos * inv_freq``
        along the last dimension.
    """
    def __init__(self, demb):
        super(PositionalEmbedding, self).__init__()
        self.demb = demb

        # One inverse frequency per pair of embedding dimensions
        exponents = torch.arange(0.0, demb, 2.0) / demb
        self.register_buffer('inv_freq', 1 / (10000 ** exponents))

    def forward(self, pos_seq, bsz=None):
        """ Return embeddings of shape ``(len(pos_seq), 1, ...)``, or expanded to
            ``(len(pos_seq), bsz, ...)`` when ``bsz`` is given.
        """
        angles = torch.ger(pos_seq, self.inv_freq)
        table = torch.cat([angles.sin(), angles.cos()], dim=-1)[:, None, :]

        if bsz is None:
            return table
        return table.expand(-1, bsz, -1)



class PositionwiseFF(nn.Module):
    """ Two-layer feed-forward block with a residual connection and layer normalization.

        With ``pre_lnorm`` the layer norm is applied to the input of the feed-forward
        network; otherwise it is applied to the sum of input and feed-forward output.
    """
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, layer_norm_epsilon=1e-5):
        super(PositionwiseFF, self).__init__()

        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout
        self.pre_lnorm = pre_lnorm

        # Layout (Linear, ReLU, Dropout, Linear, Dropout) is relied upon by the
        # TF weight map, which addresses the two Linear layers as CoreNet[0] / CoreNet[3].
        feed_forward = [
            nn.Linear(d_model, d_inner),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout),
        ]
        self.CoreNet = nn.Sequential(*feed_forward)

        self.layer_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)

    def forward(self, inp):
        if not self.pre_lnorm:
            # feed-forward, then residual connection + layer normalization
            return self.layer_norm(inp + self.CoreNet(inp))

        # layer normalization + feed-forward, then residual connection
        return self.CoreNet(self.layer_norm(inp)) + inp


class RelPartialLearnableMultiHeadAttn(nn.Module):
    """ Multi-head attention whose score is the sum of a content term and a
        relative-position term, each with its own learnable query bias
        (``r_w_bias`` and ``r_r_bias`` respectively).

        ``tgt_len``, ``ext_len`` and ``mem_len`` are accepted but not used by this class.
    """
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
                 r_r_bias=None, r_w_bias=None, output_attentions=False,
                 layer_norm_epsilon=1e-5):
        super(RelPartialLearnableMultiHeadAttn, self).__init__()

        self.output_attentions = output_attentions
        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout
        self.pre_lnorm = pre_lnorm
        self.scale = 1 / (d_head ** 0.5)

        inner_dim = n_head * d_head
        self.qkv_net = nn.Linear(d_model, 3 * inner_dim, bias=False)

        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(inner_dim, d_model, bias=False)

        self.layer_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)

        biases_given = r_r_bias is not None and r_w_bias is not None
        if biases_given:
            # Biases are shared with the caller
            self.r_r_bias = r_r_bias
            self.r_w_bias = r_w_bias
        else:
            # Biases are not shared: allocate our own (left uninitialized here)
            self.r_r_bias = nn.Parameter(torch.FloatTensor(n_head, d_head))
            self.r_w_bias = nn.Parameter(torch.FloatTensor(n_head, d_head))

        self.r_net = nn.Linear(d_model, inner_dim, bias=False)

    def _rel_shift(self, x):
        """ Shift the entries of ``x`` by padding one zero column on dim 1,
            reinterpreting the buffer with the first two sizes swapped, and
            dropping the first row.
        """
        n_rows, n_cols = x.size(0), x.size(1)
        trailing = x.size()[2:]

        pad_column = torch.zeros((n_rows, 1) + trailing, device=x.device, dtype=x.dtype)
        padded = torch.cat([pad_column, x], dim=1)
        padded = padded.view(n_cols + 1, n_rows, *trailing)

        return padded[1:].view_as(x)

    def forward(self, w, r, attn_mask=None, mems=None, head_mask=None):
        """ Run attention over ``w`` (optionally preceded by ``mems`` along dim 0).

            ``w`` is read as ``[qlen, bsz, ...]`` and ``r`` supplies ``rlen`` positional
            rows along dim 0. Returns a list: the attended output, followed by the
            attention probabilities when ``output_attentions`` is set.
        """
        qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)

        # Keys/values are computed over memory + current segment; queries only over the segment
        hidden = w if mems is None else torch.cat([mems, w], 0)
        if self.pre_lnorm:
            hidden = self.layer_norm(hidden)
        w_heads = self.qkv_net(hidden)
        r_head_k = self.r_net(r)

        w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
        if mems is not None:
            w_head_q = w_head_q[-qlen:]

        klen = w_head_k.size(0)

        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)   # qlen x bsz x n_head x d_head
        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)   # klen x bsz x n_head x d_head
        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)   # klen x bsz x n_head x d_head
        r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)        # rlen x n_head x d_head

        #### compute attention score
        # content term
        AC = torch.einsum('ibnd,jbnd->ijbn', (w_head_q + self.r_w_bias, w_head_k))
        # position term
        BD = torch.einsum('ibnd,jnd->ijbn', (w_head_q + self.r_r_bias, r_head_k))
        BD = self._rel_shift(BD)

        # [qlen x klen x bsz x n_head]
        attn_score = AC + BD
        attn_score.mul_(self.scale)

        #### compute attention probability
        if attn_mask is not None and torch.sum(attn_mask).item():
            attn_mask = (attn_mask == 1)  # Switch to bool
            if attn_mask.dim() == 2:
                expanded_mask = attn_mask[None, :, :, None]
            elif attn_mask.dim() == 3:
                expanded_mask = attn_mask[:, :, :, None]
            else:
                # masks of any other rank are ignored
                expanded_mask = None

            if expanded_mask is not None:
                # -1e30 is not representable in half precision, hence the smaller fill value
                half_precision = next(self.parameters()).dtype == torch.float16
                fill_value = -65000 if half_precision else -1e30
                attn_score = attn_score.float().masked_fill(
                    expanded_mask, fill_value).type_as(attn_score)

        # [qlen x klen x bsz x n_head]
        attn_prob = self.dropatt(F.softmax(attn_score, dim=1))

        # Mask heads if we want to
        if head_mask is not None:
            attn_prob = attn_prob * head_mask

        #### compute attention vector
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))

        # [qlen x bsz x n_head x d_head] -> [qlen x bsz x (n_head * d_head)]
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        ##### linear projection
        attn_out = self.drop(self.o_net(attn_vec))

        ##### residual connection (+ layer normalization unless it was applied up front)
        residual = w + attn_out
        outputs = [residual if self.pre_lnorm else self.layer_norm(residual)]

        if self.output_attentions:
            outputs.append(attn_prob)

        return outputs


class RelPartialLearnableDecoderLayer(nn.Module):
    """ One decoder layer: relative multi-head attention followed by a
        position-wise feed-forward block.

        Extra keyword arguments are forwarded to the attention module; ``pre_lnorm``
        (if present among them) is also passed to the feed-forward block.
    """
    def __init__(self, n_head, d_model, d_head, d_inner, dropout, layer_norm_epsilon=1e-5,
                 **kwargs):
        super(RelPartialLearnableDecoderLayer, self).__init__()

        self.dec_attn = RelPartialLearnableMultiHeadAttn(
            n_head, d_model, d_head, dropout,
            layer_norm_epsilon=layer_norm_epsilon, **kwargs)
        self.pos_ff = PositionwiseFF(
            d_model, d_inner, dropout,
            pre_lnorm=kwargs.get('pre_lnorm'),
            layer_norm_epsilon=layer_norm_epsilon)

    def forward(self, dec_inp, r, dec_attn_mask=None, mems=None, head_mask=None):
        """ Return ``[layer_output]`` followed by whatever extra outputs the
            attention module produced after its first element.
        """
        attn_results = self.dec_attn(
            dec_inp, r, attn_mask=dec_attn_mask, mems=mems, head_mask=head_mask)

        attended, extras = attn_results[0], attn_results[1:]
        return [self.pos_ff(attended)] + extras


class AdaptiveEmbedding(nn.Module):
    """ Token embedding split into vocabulary clusters delimited by ``cutoffs``.

        With ``div_val == 1`` a single embedding table of width ``d_embed`` is used
        (projected to ``d_proj`` only if the two differ). Otherwise cluster ``i`` gets
        its own table of width ``d_embed // div_val ** i`` and its own projection to
        ``d_proj``. The result is always scaled by ``sqrt(d_proj)``.

        Projection parameters are allocated uninitialized here.
    """
    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                 sample_softmax=False):
        super(AdaptiveEmbedding, self).__init__()

        self.n_token = n_token
        self.d_embed = d_embed
        self.d_proj = d_proj
        self.div_val = div_val
        self.emb_scale = d_proj ** 0.5

        # Cluster i covers token ids in [cutoff_ends[i], cutoff_ends[i + 1])
        self.cutoffs = cutoffs + [n_token]
        self.cutoff_ends = [0] + self.cutoffs

        self.emb_layers = nn.ModuleList()
        self.emb_projs = nn.ParameterList()
        if div_val != 1:
            bounds = zip(self.cutoff_ends[:-1], self.cutoff_ends[1:])
            for cluster, (start, stop) in enumerate(bounds):
                cluster_dim = d_embed // (div_val ** cluster)
                self.emb_layers.append(nn.Embedding(stop - start, cluster_dim))
                self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, cluster_dim)))
        else:
            self.emb_layers.append(
                nn.Embedding(n_token, d_embed, sparse=sample_softmax > 0)
            )
            if d_proj != d_embed:
                self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_embed)))

    def forward(self, inp):
        """ Embed the token ids in ``inp``; the output has shape ``inp.size() + (d_proj,)``. """
        if self.div_val == 1:
            embed = self.emb_layers[0](inp)
            if self.d_proj != self.d_embed:
                embed = F.linear(embed, self.emb_projs[0])
        else:
            reference = next(self.parameters())
            flat_ids = inp.view(-1)
            flat_out = torch.zeros([flat_ids.size(0), self.d_proj],
                                   dtype=reference.dtype, device=reference.device)

            bounds = zip(self.cutoff_ends[:-1], self.cutoff_ends[1:])
            for cluster, (start, stop) in enumerate(bounds):
                in_cluster = (flat_ids >= start) & (flat_ids < stop)
                positions = in_cluster.nonzero().squeeze()

                if positions.numel() == 0:
                    continue

                # Re-base ids so they index into this cluster's own table
                local_ids = flat_ids.index_select(0, positions) - start
                projected = F.linear(self.emb_layers[cluster](local_ids),
                                     self.emb_projs[cluster])

                flat_out.index_copy_(0, positions, projected)

            embed = flat_out.view(inp.size() + (self.d_proj,))

        embed.mul_(self.emb_scale)

        return embed


class TransfoXLPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    config_class = TransfoXLConfig
    pretrained_model_archive_map = TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_transfo_xl
    base_model_prefix = "transformer"

    def _init_weight(self, weight):
        """ Initialize `weight` in place according to `config.init`
            ('uniform' or 'normal'); any other value leaves it untouched.
        """
        scheme = self.config.init
        if scheme == 'uniform':
            nn.init.uniform_(weight, -self.config.init_range, self.config.init_range)
        elif scheme == 'normal':
            nn.init.normal_(weight, 0.0, self.config.init_std)

    def _init_bias(self, bias):
        """ Set `bias` to zero in place. """
        nn.init.constant_(bias, 0.0)

    def _init_weights(self, m):
        """ Initialize the weights of module `m`, dispatching on its class name.

            The order of the checks matters: 'AdaptiveEmbedding' must be tested
            before the more general 'Embedding'.
        """
        classname = m.__class__.__name__

        if 'Linear' in classname:
            if hasattr(m, 'weight') and m.weight is not None:
                self._init_weight(m.weight)
            if hasattr(m, 'bias') and m.bias is not None:
                self._init_bias(m.bias)
            return

        if 'AdaptiveEmbedding' in classname:
            if hasattr(m, 'emb_projs'):
                for proj in m.emb_projs:
                    if proj is not None:
                        nn.init.normal_(proj, 0.0, self.config.proj_init_std)
            return

        if 'Embedding' in classname:
            if hasattr(m, 'weight'):
                self._init_weight(m.weight)
            return

        if 'ProjectedAdaptiveLogSoftmax' in classname:
            if hasattr(m, 'cluster_weight') and m.cluster_weight is not None:
                self._init_weight(m.cluster_weight)
            if hasattr(m, 'cluster_bias') and m.cluster_bias is not None:
                self._init_bias(m.cluster_bias)
            if hasattr(m, 'out_projs'):
                for proj in m.out_projs:
                    if proj is not None:
                        nn.init.normal_(proj, 0.0, self.config.proj_init_std)
            return

        if 'LayerNorm' in classname:
            if hasattr(m, 'weight'):
                nn.init.normal_(m.weight, 1.0, self.config.init_std)
            if hasattr(m, 'bias') and m.bias is not None:
                self._init_bias(m.bias)
            return

        # Any other module: initialize the relative-attention parameters it may carry
        for attr in ('r_emb', 'r_w_bias', 'r_r_bias'):
            if hasattr(m, attr):
                self._init_weight(getattr(m, attr))
        if hasattr(m, 'r_bias'):
            self._init_bias(m.r_bias)


TRANSFO_XL_START_DOCSTRING = r"""    The Transformer-XL model was proposed in
    `Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`_
    by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
    It's a causal (uni-directional) transformer with relative positioning (sinusoidal) embeddings which can reuse
    previously computed hidden-states to attend to longer context (memory).
    This model also uses adaptive softmax inputs and outputs (tied).

    This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
    refer to the PyTorch documentation for all matter related to general usage and behavior.

    .. _`Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`:
        https://arxiv.org/abs/1901.02860

    .. _`torch.nn.Module`:
        https://pytorch.org/docs/stable/nn.html#module

    Parameters:
        config (:class:`~transformers.TransfoXLConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

TRANSFO_XL_INPUTS_DOCSTRING = r"""
    Inputs:
        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.
            Transformer-XL is a model with relative position embeddings so you can either pad the inputs on
            the right or on the left.
            Indices can be obtained using :class:`transformers.TransfoXLTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
        **mems**: (`optional`)
            list of ``torch.FloatTensor`` (one for each layer):
            that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
            (see `mems` output below). Can be used to speed up sequential decoding and attend to longer context.
        **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
"""

[docs]@add_start_docstrings("The bare Bert Model transformer outputting raw hidden-states without any specific head on top.", TRANSFO_XL_START_DOCSTRING, TRANSFO_XL_INPUTS_DOCSTRING) class TransfoXLModel(TransfoXLPreTrainedModel): r""" Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)`` Sequence of hidden-states at the last layer of the model. **mems**: list of ``torch.FloatTensor`` (one for each layer): that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see `mems` input above). Can be used to speed up sequential decoding and attend to longer context. **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) of shape ``(batch_size, sequence_length, hidden_size)``: Hidden-states of the model at the output of each layer plus the initial embedding outputs. **attentions**: (`optional`, returned when ``config.output_attentions=True``) list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. 
Examples:: tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103') model = TransfoXLModel.from_pretrained('transfo-xl-wt103') input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 outputs = model(input_ids) last_hidden_states, mems = outputs[:2] """ def __init__(self, config): super(TransfoXLModel, self).__init__(config) self.output_attentions = config.output_attentions self.output_hidden_states = config.output_hidden_states self.n_token = config.n_token self.d_embed = config.d_embed self.d_model = config.d_model self.n_head = config.n_head self.d_head = config.d_head self.word_emb = AdaptiveEmbedding(config.n_token, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val) self.drop = nn.Dropout(config.dropout) self.n_layer = config.n_layer self.tgt_len = config.tgt_len self.mem_len = config.mem_len self.ext_len = config.ext_len self.max_klen = config.tgt_len + config.ext_len + config.mem_len self.attn_type = config.attn_type if not config.untie_r: self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head)) self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head)) self.layers = nn.ModuleList() if config.attn_type == 0: # the default attention for i in range(config.n_layer): self.layers.append( RelPartialLearnableDecoderLayer( config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout, tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len, dropatt=config.dropatt, pre_lnorm=config.pre_lnorm, r_w_bias=None if config.untie_r else self.r_w_bias, r_r_bias=None if config.untie_r else self.r_r_bias, output_attentions=self.output_attentions, layer_norm_epsilon=config.layer_norm_epsilon) ) else: # learnable embeddings and absolute embeddings are not used in our pretrained checkpoints raise NotImplementedError # Removed them to avoid maintaining dead code self.same_length = config.same_length self.clamp_len = config.clamp_len if self.attn_type 
== 0: # default attention self.pos_emb = PositionalEmbedding(self.d_model) else: # learnable embeddings and absolute embeddings raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint self.init_weights()
[docs] def get_input_embeddings(self): return self.word_emb
[docs] def set_input_embeddings(self, new_embeddings): self.word_emb = new_embeddings
def backward_compatible(self): self.sample_softmax = -1 def reset_length(self, tgt_len, ext_len, mem_len): self.tgt_len = tgt_len self.mem_len = mem_len self.ext_len = ext_len def _prune_heads(self, heads): logger.info("Head pruning is not implemented for Transformer-XL model") pass def init_mems(self, bsz): if self.mem_len > 0: mems = [] param = next(self.parameters()) for i in range(self.n_layer): empty = torch.zeros(self.mem_len, bsz, self.config.d_model, dtype=param.dtype, device=param.device) mems.append(empty) return mems else: return None def _update_mems(self, hids, mems, qlen, mlen): # does not deal with None if mems is None: return None # mems is not None assert len(hids) == len(mems), 'len(hids) != len(mems)' # There are `mlen + qlen` steps that can be cached into mems # For the next step, the last `ext_len` of the `qlen` tokens # will be used as the extended context. Hence, we only cache # the tokens from `mlen + qlen - self.ext_len - self.mem_len` # to `mlen + qlen - self.ext_len`. with torch.no_grad(): new_mems = [] end_idx = mlen + max(0, qlen - 0 - self.ext_len) beg_idx = max(0, end_idx - self.mem_len) for i in range(len(hids)): cat = torch.cat([mems[i], hids[i]], dim=0) new_mems.append(cat[beg_idx:end_idx].detach()) return new_mems
def forward(self, input_ids=None, mems=None, head_mask=None, inputs_embeds=None):
    """Run the decoder stack over one segment and refresh the memories.

    Exactly one of ``input_ids`` / ``inputs_embeds`` must be given; both are
    expected batch-first and are transposed to the ``[len, bsz]`` layout used
    internally.

    Args:
        input_ids: token ids, ``[bsz, len]``.
        mems: optional list of per-layer memory tensors; when ``None`` they
            are created with ``self.init_mems(bsz)``.
        head_mask: optional mask with 1 or 2 dimensions
            (``[n_head]`` or ``[n_layer, n_head]``).
        inputs_embeds: pre-computed embeddings, ``[bsz, len, ...]``.

    Returns:
        A list: last hidden state ``[bsz, len, hidden_dim]``, the new
        memories, then (if ``output_hidden_states``) the list of hidden
        states and (if ``output_attentions``) the list of attentions.

    Raises:
        ValueError: if both or neither of ``input_ids`` and
            ``inputs_embeds`` are provided.
        NotImplementedError: if ``self.attn_type`` is not 0.
    """
    # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
    # so we transpose here from shape [bsz, len] to shape [len, bsz]
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        input_ids = input_ids.transpose(0, 1).contiguous()
        qlen, bsz = input_ids.size()
    elif inputs_embeds is not None:
        inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
        qlen, bsz = inputs_embeds.shape[0], inputs_embeds.shape[1]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    if mems is None:
        mems = self.init_mems(bsz)

    # Prepare head mask if needed
    # 1.0 in head_mask indicate we keep the head
    # attention_probs has shape bsz x n_heads x N x N
    # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
    # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
    if head_mask is not None:
        if head_mask.dim() == 1:
            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
            head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
        elif head_mask.dim() == 2:
            head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
        # switch to float if needed + fp16 compatibility
        head_mask = head_mask.to(dtype=next(self.parameters()).dtype)
    else:
        head_mask = [None] * self.n_layer

    if inputs_embeds is not None:
        word_emb = inputs_embeds
    else:
        word_emb = self.word_emb(input_ids)

    mlen = mems[0].size(0) if mems is not None else 0
    klen = mlen + qlen
    if self.same_length:
        all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8)
        mask_len = klen - self.mem_len
        if mask_len > 0:
            mask_shift_len = qlen - mask_len
        else:
            mask_shift_len = qlen
        dec_attn_mask = (torch.triu(all_ones, 1+mlen)
                         + torch.tril(all_ones, -mask_shift_len))[:, :, None]  # -1
    else:
        dec_attn_mask = torch.triu(
            word_emb.new_ones((qlen, klen), dtype=torch.uint8), diagonal=1+mlen)[:, :, None]

    hids = []
    attentions = []
    if self.attn_type == 0:  # default
        pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
                               dtype=word_emb.dtype)
        if self.clamp_len > 0:
            pos_seq.clamp_(max=self.clamp_len)
        pos_emb = self.pos_emb(pos_seq)

        core_out = self.drop(word_emb)
        pos_emb = self.drop(pos_emb)

        for i, layer in enumerate(self.layers):
            # The input of every layer is what gets cached as memory.
            hids.append(core_out)
            mems_i = None if mems is None else mems[i]
            layer_outputs = layer(core_out, pos_emb, dec_attn_mask=dec_attn_mask,
                                  mems=mems_i, head_mask=head_mask[i])
            core_out = layer_outputs[0]
            if self.output_attentions:
                attentions.append(layer_outputs[1])
    else:  # learnable embeddings and absolute embeddings
        raise NotImplementedError  # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint

    core_out = self.drop(core_out)

    # `_update_mems` is declared as (hids, mems, qlen, mlen). Pass the two
    # lengths by keyword: the previous positional call handed them over in
    # (mlen, qlen) order, which swaps them and selects the wrong cache
    # window whenever ext_len > 0.
    new_mems = self._update_mems(hids, mems, qlen=qlen, mlen=mlen)

    # We transpose back here to shape [bsz, len, hidden_dim]
    outputs = [core_out.transpose(0, 1).contiguous(), new_mems]
    if self.output_hidden_states:
        # Add last layer and transpose to library standard shape [bsz, len, hidden_dim]
        hids.append(core_out)
        hids = [t.transpose(0, 1).contiguous() for t in hids]
        outputs.append(hids)
    if self.output_attentions:
        # Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
        attentions = [t.permute(2, 3, 0, 1).contiguous() for t in attentions]
        outputs.append(attentions)

    return outputs  # last hidden state, new_mems, (all hidden states), (all attentions)
@add_start_docstrings("""The Transformer-XL Model with a language modeling head on top (adaptive softmax with weights tied to the adaptive input embeddings)""",
                      TRANSFO_XL_START_DOCSTRING, TRANSFO_XL_INPUTS_DOCSTRING)
class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for language modeling.
            Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``
            Indices are selected in ``[-1, 0, ..., config.vocab_size]``
            All labels set to ``-1`` are ignored (masked), the loss is only
            computed for labels in ``[0, ..., config.vocab_size]``

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
            Language modeling loss (one value per position; not reduced).
        **prediction_scores**: ``None`` if ``labels`` is provided else ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
            We don't output them when the loss is computed to speedup adaptive softmax decoding.
        **mems**:
            list of ``torch.FloatTensor`` (one for each layer):
            that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
            (see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
        model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        prediction_scores, mems = outputs[:2]

    """
    def __init__(self, config):
        """Create the base transformer and the output criterion selected by ``config.sample_softmax``."""
        super(TransfoXLLMHeadModel, self).__init__(config)
        self.transformer = TransfoXLModel(config)
        self.sample_softmax = config.sample_softmax
        # use sampled softmax
        if config.sample_softmax > 0:
            # LogUniformSampler is not part of the module-level imports, so
            # without this local import the sampled-softmax path fails with
            # NameError. It is imported from the same utilities module that
            # already provides `sample_logits`.
            from .modeling_transfo_xl_utilities import LogUniformSampler

            self.out_layer = nn.Linear(config.d_model, config.n_token)
            self.sampler = LogUniformSampler(config.n_token, config.sample_softmax)
        # use adaptive softmax (including standard softmax)
        else:
            self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model,
                                                    config.cutoffs, div_val=config.div_val)
        self.init_weights()

    def tie_weights(self):
        """
        Run this to be sure output and input (adaptive) softmax weights are tied
        """
        # sampled softmax
        if self.sample_softmax > 0:
            if self.config.tie_weight:
                self.out_layer.weight = self.transformer.word_emb.weight
        # adaptive softmax (including standard softmax)
        else:
            if self.config.tie_weight:
                for i, out_layer in enumerate(self.crit.out_layers):
                    self._tie_or_clone_weights(out_layer,
                                               self.transformer.word_emb.emb_layers[i])
            if self.config.tie_projs:
                emb_projs = self.transformer.word_emb.emb_projs
                for i, tie_proj in enumerate(self.config.tie_projs):
                    if not tie_proj:
                        continue
                    if self.config.div_val == 1 and self.config.d_model != self.config.d_embed:
                        # With div_val == 1 every cluster ties to the first projection.
                        source = emb_projs[0]
                    elif self.config.div_val != 1:
                        source = emb_projs[i]
                    else:
                        continue
                    # Under torchscript the projection is cloned instead of shared.
                    if self.config.torchscript:
                        self.crit.out_projs[i] = nn.Parameter(source.clone())
                    else:
                        self.crit.out_projs[i] = source

    def reset_length(self, tgt_len, ext_len, mem_len):
        """Forward the new lengths to the underlying transformer."""
        self.transformer.reset_length(tgt_len, ext_len, mem_len)

    def init_mems(self, bsz):
        """Return fresh memories from the underlying transformer for batch size ``bsz``."""
        return self.transformer.init_mems(bsz)

    def forward(self, input_ids=None, mems=None, head_mask=None, inputs_embeds=None, labels=None):
        """Run the transformer and apply the language modeling head.

        Returns a list: ``[prediction_scores, *rest]`` when ``labels`` is
        ``None`` and ``[loss, None, *rest]`` otherwise, where ``rest`` is
        everything the base transformer returned after its last hidden state.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given.
            NotImplementedError: in training mode with ``sample_softmax > 0``
                when ``labels`` is provided.
        """
        if input_ids is not None:
            bsz, tgt_len = input_ids.size(0), input_ids.size(1)
        elif inputs_embeds is not None:
            bsz, tgt_len = inputs_embeds.size(0), inputs_embeds.size(1)
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        transformer_outputs = self.transformer(input_ids, mems=mems, head_mask=head_mask, inputs_embeds=inputs_embeds)

        last_hidden = transformer_outputs[0]
        pred_hid = last_hidden[:, -tgt_len:]
        # Force a list so the concatenations below work for any sequence type.
        outputs = list(transformer_outputs[1:])
        if self.sample_softmax > 0 and self.training:
            assert self.config.tie_weight
            if labels is not None:
                # TODO: This is not implemented
                # Fail before the sampling work; its result would be discarded anyway.
                raise NotImplementedError
            logit = sample_logits(self.transformer.word_emb, self.out_layer.bias, labels, pred_hid, self.sampler)
            softmax_output = -F.log_softmax(logit, -1)[:, :, 0]
            outputs = [softmax_output] + outputs
        else:
            softmax_output = self.crit(pred_hid.view(-1, pred_hid.size(-1)), labels)
            if labels is None:
                softmax_output = softmax_output.view(bsz, tgt_len, -1)
                outputs = [softmax_output] + outputs
            else:
                # NOTE(review): the loss is reshaped to one value per input
                # position, so confirm whether the criterion really shifts
                # the labels as the class docstring states.
                softmax_output = softmax_output.view(bsz, tgt_len)
                outputs = [softmax_output, None] + outputs

        return outputs  # (loss), logits or None if labels is not None (speed up adaptive softmax), new_mems, (all hidden states), (all attentions)