Transformer XL

TransfoXLConfig

class transformers.TransfoXLConfig(vocab_size_or_config_json_file=267735, cutoffs=[20000, 40000, 200000], d_model=1024, d_embed=1024, n_head=16, d_head=64, d_inner=4096, div_val=4, pre_lnorm=False, n_layer=18, tgt_len=128, ext_len=0, mem_len=1600, clamp_len=1000, same_length=True, proj_share_all_but_first=True, attn_type=0, sample_softmax=-1, adaptive=True, tie_weight=True, dropout=0.1, dropatt=0.0, untie_r=True, init='normal', init_range=0.01, proj_init_std=0.01, init_std=0.02, layer_norm_epsilon=1e-05, **kwargs)[source]

Configuration class to store the configuration of a TransfoXLModel.

Parameters
  • vocab_size_or_config_json_file – Vocabulary size of input_ids in TransfoXLModel, or the path to a configuration JSON file.

  • cutoffs – cutoffs for the adaptive softmax

  • d_model – Dimensionality of the model’s hidden states.

  • d_embed – Dimensionality of the embeddings

  • d_head – Dimensionality of the model’s heads.

  • div_val – divisor value for the adaptive input and softmax

  • pre_lnorm – apply LayerNorm to the input instead of the output

  • d_inner – Inner dimension of the feed-forward (FF) layers

  • n_layer – Number of hidden layers in the Transformer encoder.

  • n_head – Number of attention heads for each attention layer in the Transformer encoder.

  • tgt_len – number of tokens to predict

  • ext_len – length of the extended context

  • mem_len – length of the retained previous hidden states (the memory)

  • same_length – use the same attn length for all tokens

  • proj_share_all_but_first – True to share all but the first projection layers, False to not share.

  • attn_type – attention type. 0 for Transformer-XL, 1 for Shaw et al, 2 for Vaswani et al, 3 for Al Rfou et al.

  • clamp_len – use the same pos embeddings after clamp_len

  • sample_softmax – number of samples in sampled softmax

  • adaptive – use adaptive softmax

  • tie_weight – tie the word embedding and softmax weights

  • dropout – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.

  • dropatt – The dropout ratio for the attention probabilities.

  • untie_r – untie relative position biases

  • init – parameter initializer to use

  • init_range – parameters initialized by U(-init_range, init_range).

  • proj_init_std – parameters initialized by N(0, proj_init_std)

  • init_std – parameters initialized by N(0, init_std)
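
The default values above correspond to the transfo-xl-wt103 checkpoint. As an illustrative sketch (the smaller values below are arbitrary, not tuned), a custom configuration can be built and used to instantiate a randomly initialized model:

from transformers import TransfoXLConfig, TransfoXLModel

# Illustrative, untrained settings for a small model; adjust to your task.
config = TransfoXLConfig(
    vocab_size_or_config_json_file=100000,
    cutoffs=[20000, 40000],   # adaptive softmax clusters, must stay below the vocabulary size
    d_model=512,
    d_embed=512,
    n_head=8,
    d_head=64,
    d_inner=2048,
    n_layer=6,
    mem_len=400,              # number of previous hidden states kept as memory
)
model = TransfoXLModel(config)  # randomly initialized; use from_pretrained() to load trained weights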

TransfoXLTokenizer

class transformers.TransfoXLTokenizer(special=None, min_freq=0, max_size=None, lower_case=False, delimiter=None, vocab_file=None, pretrained_vocab_file=None, never_split=None, unk_token='<unk>', eos_token='<eos>', additional_special_tokens=['<formula>'], **kwargs)[source]

Transformer-XL tokenizer adapted from Vocab class in https://github.com/kimiyoung/transformer-xl

convert_tokens_to_string(tokens)[source]

Converts a sequence of tokens (string) into a single string.

count_sents(sents, verbose=False)[source]

sents: a list of sentences, each a list of tokenized symbols

save_vocabulary(vocab_path)[source]

Save the tokenizer vocabulary to a directory or file.

property vocab_size

Size of the base vocabulary (without the added tokens)
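
A minimal usage sketch of the tokenizer (tokenization is word-level, so words outside the WikiText-103 vocabulary are mapped to <unk>):

from transformers import TransfoXLTokenizer

tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')

tokens = tokenizer.tokenize("Hello, my dog is cute")   # word-level tokens
ids = tokenizer.convert_tokens_to_ids(tokens)          # vocabulary indices
text = tokenizer.convert_tokens_to_string(tokens)      # back to a single string
print(tokens, ids, text)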

TransfoXLModel

class transformers.TransfoXLModel(config)[source]

The bare Transformer-XL Model transformer outputting raw hidden-states without any specific head on top. The Transformer-XL model was proposed in Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov. It’s a causal (uni-directional) transformer with relative positioning (sinusoidal) embeddings which can reuse previously computed hidden-states to attend to a longer context (memory). This model also uses adaptive softmax inputs and outputs (tied).

This model is a PyTorch torch.nn.Module sub-class. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

Parameters

config (TransfoXLConfig) – Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

Inputs:
input_ids: torch.LongTensor of shape (batch_size, sequence_length):

Indices of input sequence tokens in the vocabulary. Transformer-XL is a model with relative position embeddings so you can either pad the inputs on the right or on the left. Indices can be obtained using transformers.TransfoXLTokenizer. See transformers.PreTrainedTokenizer.encode() and transformers.PreTrainedTokenizer.convert_tokens_to_ids() for details.

mems: (optional)

list of torch.FloatTensor (one for each layer) containing pre-computed hidden-states (keys and values in the attention blocks) as computed by the model (see mems output below). Can be used to speed up sequential decoding and attend to a longer context.

head_mask: (optional) torch.FloatTensor of shape (num_heads,) or (num_layers, num_heads):

Mask to nullify selected heads of the self-attention modules. Mask values selected in [0, 1]: 1 indicates the head is not masked, 0 indicates the head is masked.

inputs_embeds: (optional) torch.FloatTensor of shape (batch_size, sequence_length, embedding_dim):

Optionally, instead of passing input_ids you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert input_ids indices into associated vectors than the model’s internal embedding lookup matrix.

Outputs: Tuple comprising various elements depending on the configuration (config) and inputs:
last_hidden_state: torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)

Sequence of hidden-states at the last layer of the model.

mems:

list of torch.FloatTensor (one for each layer) containing pre-computed hidden-states (keys and values in the attention blocks) as computed by the model (see mems input above). Can be used to speed up sequential decoding and attend to a longer context.

hidden_states: (optional, returned when config.output_hidden_states=True)

list of torch.FloatTensor (one for the output of each layer + the output of the embeddings) of shape (batch_size, sequence_length, hidden_size): Hidden-states of the model at the output of each layer plus the initial embedding outputs.

attentions: (optional, returned when config.output_attentions=True)

list of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length): Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

Examples:

import torch
from transformers import TransfoXLTokenizer, TransfoXLModel

tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
outputs = model(input_ids)
last_hidden_states, mems = outputs[:2]
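
The mems returned by the model can be fed back in on the next call, which is how Transformer-XL extends its effective context beyond a single segment. A minimal sketch (the chunk size and text below are illustrative):

import torch
from transformers import TransfoXLTokenizer, TransfoXLModel

tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
model = TransfoXLModel.from_pretrained('transfo-xl-wt103')

input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute. " * 20)).unsqueeze(0)
chunk_size = 32   # illustrative segment length
mems = None       # no memory before the first segment

for start in range(0, input_ids.size(1), chunk_size):
    chunk = input_ids[:, start:start + chunk_size]
    outputs = model(chunk, mems=mems)
    last_hidden_states, mems = outputs[:2]   # carry the memory over to the next segment
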
forward(input_ids=None, mems=None, head_mask=None, inputs_embeds=None)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

get_input_embeddings()[source]

Get model’s input embeddings

set_input_embeddings(new_embeddings)[source]

Set model’s input embeddings
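
A short sketch of the two accessors above; here the same module is read and re-assigned, so the call is a no-op:

from transformers import TransfoXLModel

model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
embeddings = model.get_input_embeddings()   # the (adaptive) input embedding module
model.set_input_embeddings(embeddings)      # re-assign; pass a compatible module to replace it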

TransfoXLLMHeadModel

class transformers.TransfoXLLMHeadModel(config)[source]

The Transformer-XL Model with a language modeling head on top (adaptive softmax with weights tied to the adaptive input embeddings). The Transformer-XL model was proposed in Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov. It’s a causal (uni-directional) transformer with relative positioning (sinusoidal) embeddings which can reuse previously computed hidden-states to attend to a longer context (memory). This model also uses adaptive softmax inputs and outputs (tied).

This model is a PyTorch torch.nn.Module sub-class. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

Parameters

config (TransfoXLConfig) – Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

Inputs:
input_ids: torch.LongTensor of shape (batch_size, sequence_length):

Indices of input sequence tokens in the vocabulary. Transformer-XL is a model with relative position embeddings so you can either pad the inputs on the right or on the left. Indices can be obtained using transformers.TransfoXLTokenizer. See transformers.PreTrainedTokenizer.encode() and transformers.PreTrainedTokenizer.convert_tokens_to_ids() for details.

mems: (optional)

list of torch.FloatTensor (one for each layer) containing pre-computed hidden-states (keys and values in the attention blocks) as computed by the model (see mems output below). Can be used to speed up sequential decoding and attend to a longer context.

head_mask: (optional) torch.FloatTensor of shape (num_heads,) or (num_layers, num_heads):

Mask to nullify selected heads of the self-attention modules. Mask values selected in [0, 1]: 1 indicates the head is not masked, 0 indicates the head is masked.

inputs_embeds: (optional) torch.FloatTensor of shape (batch_size, sequence_length, embedding_dim):

Optionally, instead of passing input_ids you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert input_ids indices into associated vectors than the model’s internal embedding lookup matrix.

labels: (optional) torch.LongTensor of shape (batch_size, sequence_length):

Labels for language modeling. Note that the labels are shifted inside the model, i.e. you can set labels = input_ids. Indices are selected in [-1, 0, ..., config.vocab_size]. All labels set to -1 are ignored (masked); the loss is only computed for labels in [0, ..., config.vocab_size].

Outputs: Tuple comprising various elements depending on the configuration (config) and inputs:
loss: (optional, returned when labels is provided) torch.FloatTensor of shape (1,):

Language modeling loss.

prediction_scores: None if labels is provided, else torch.FloatTensor of shape (batch_size, sequence_length, config.vocab_size)

Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). We don’t output them when the loss is computed, to speed up adaptive softmax decoding.

mems:

list of torch.FloatTensor (one for each layer) containing pre-computed hidden-states (keys and values in the attention blocks) as computed by the model (see mems input above). Can be used to speed up sequential decoding and attend to a longer context.

hidden_states: (optional, returned when config.output_hidden_states=True)

list of torch.FloatTensor (one for the output of each layer + the output of the embeddings) of shape (batch_size, sequence_length, hidden_size): Hidden-states of the model at the output of each layer plus the initial embedding outputs.

attentions: (optional, returned when config.output_attentions=True)

list of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length): Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

Examples:

import torch
from transformers import TransfoXLTokenizer, TransfoXLLMHeadModel

tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
outputs = model(input_ids)
prediction_scores, mems = outputs[:2]
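
To obtain the language modeling loss, pass the targets through the labels argument of forward() (documented below); a minimal sketch, using the inputs themselves as labels since they are shifted inside the model:

import torch
from transformers import TransfoXLTokenizer, TransfoXLLMHeadModel

tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')

input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
outputs = model(input_ids, labels=input_ids)   # labels are shifted inside the model
loss = outputs[0]
perplexity = torch.exp(loss.mean())   # .mean() also covers versions returning per-token losses
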
forward(input_ids=None, mems=None, head_mask=None, inputs_embeds=None, labels=None)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

tie_weights()[source]

Run this to be sure output and input (adaptive) softmax weights are tied

TFTransfoXLModel

class transformers.TFTransfoXLModel(config, *inputs, **kwargs)[source]

The bare Transformer-XL Model transformer outputting raw hidden-states without any specific head on top. The Transformer-XL model was proposed in Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov. It’s a causal (uni-directional) transformer with relative positioning (sinusoidal) embeddings which can reuse previously computed hidden-states to attend to a longer context (memory). This model also uses adaptive softmax inputs and outputs (tied).

This model is a tf.keras.Model sub-class. Use it as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and behavior.

Note on the model inputs:

TF 2.0 models accept two formats as inputs:

  • having all inputs as keyword arguments (like PyTorch models), or

  • having all inputs as a list, tuple or dict in the first positional argument.

This second option is useful when using the tf.keras.Model.fit() method, which currently requires having all the tensors in the first argument of the model call function: model(inputs).

If you choose this second option, there are three possibilities you can use to gather all the input Tensors in the first positional argument:

  • a single Tensor with input_ids only and nothing else: model(input_ids)

  • a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:

    model([input_ids, attention_mask]) or model([input_ids, attention_mask, token_type_ids])

  • a dictionary with one or several input Tensors associated to the input names given in the docstring:

    model({'input_ids': input_ids, 'token_type_ids': token_type_ids})
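
For Transformer-XL specifically, the docstring order is input_ids, mems, head_mask, inputs_embeds (there is no attention_mask or token_type_ids); a minimal sketch of the three call styles:

import tensorflow as tf
from transformers import TransfoXLTokenizer, TFTransfoXLModel

tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
model = TFTransfoXLModel.from_pretrained('transfo-xl-wt103')
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1

outputs = model(input_ids)                   # 1. a single tensor: input_ids only
outputs = model([input_ids])                 # 2. a list, in the docstring order
outputs = model({'input_ids': input_ids})    # 3. a dict keyed by the input names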

Parameters

config (TransfoXLConfig) – Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

Inputs:
input_ids: Numpy array or tf.Tensor of shape (batch_size, sequence_length):

Indices of input sequence tokens in the vocabulary. Transformer-XL is a model with relative position embeddings so you can either pad the inputs on the right or on the left. Indices can be obtained using transformers.TransfoXLTokenizer. See transformers.PreTrainedTokenizer.encode() and transformers.PreTrainedTokenizer.convert_tokens_to_ids() for details.

mems: (optional)

list of Numpy array or tf.Tensor (one for each layer) containing pre-computed hidden-states (keys and values in the attention blocks) as computed by the model (see mems output below). Can be used to speed up sequential decoding and attend to a longer context.

head_mask: (optional) Numpy array or tf.Tensor of shape (num_heads,) or (num_layers, num_heads):

Mask to nullify selected heads of the self-attention modules. Mask values selected in [0, 1]: 1 indicates the head is not masked, 0 indicates the head is masked.

inputs_embeds: (optional) Numpy array or tf.Tensor of shape (batch_size, sequence_length, embedding_dim):

Optionally, instead of passing input_ids you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert input_ids indices into associated vectors than the model’s internal embedding lookup matrix.

Outputs: Tuple comprising various elements depending on the configuration (config) and inputs:
last_hidden_state: tf.Tensor of shape (batch_size, sequence_length, hidden_size)

Sequence of hidden-states at the last layer of the model.

mems:

list of tf.Tensor (one for each layer) containing pre-computed hidden-states (keys and values in the attention blocks) as computed by the model (see mems input above). Can be used to speed up sequential decoding and attend to a longer context.

hidden_states: (optional, returned when config.output_hidden_states=True)

list of tf.Tensor (one for the output of each layer + the output of the embeddings) of shape (batch_size, sequence_length, hidden_size): Hidden-states of the model at the output of each layer plus the initial embedding outputs.

attentions: (optional, returned when config.output_attentions=True)

list of tf.Tensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length): Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

Examples:

import tensorflow as tf
from transformers import TransfoXLTokenizer, TFTransfoXLModel

tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
model = TFTransfoXLModel.from_pretrained('transfo-xl-wt103')
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
outputs = model(input_ids)
last_hidden_states, mems = outputs[:2]
call(inputs, **kwargs)[source]

Calls the model on new inputs.

In this case call just reapplies all ops in the graph to the new inputs (e.g. build a new computational graph from the provided inputs).

Parameters
  • inputs – A tensor or list of tensors.

  • training – Boolean or boolean scalar tensor, indicating whether to run the Network in training mode or inference mode.

  • mask – A mask or list of masks. A mask can be either a tensor or None (no mask).

Returns

A tensor if there is a single output, or a list of tensors if there is more than one output.

TFTransfoXLLMHeadModel

class transformers.TFTransfoXLLMHeadModel(config)[source]

The Transformer-XL Model with a language modeling head on top (adaptive softmax with weights tied to the adaptive input embeddings). The Transformer-XL model was proposed in Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov. It’s a causal (uni-directional) transformer with relative positioning (sinusoidal) embeddings which can reuse previously computed hidden-states to attend to a longer context (memory). This model also uses adaptive softmax inputs and outputs (tied).

This model is a tf.keras.Model sub-class. Use it as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and behavior.

Note on the model inputs:

TF 2.0 models accept two formats as inputs:

  • having all inputs as keyword arguments (like PyTorch models), or

  • having all inputs as a list, tuple or dict in the first positional argument.

This second option is useful when using the tf.keras.Model.fit() method, which currently requires having all the tensors in the first argument of the model call function: model(inputs).

If you choose this second option, there are three possibilities you can use to gather all the input Tensors in the first positional argument:

  • a single Tensor with input_ids only and nothing else: model(input_ids)

  • a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:

    model([input_ids, attention_mask]) or model([input_ids, attention_mask, token_type_ids])

  • a dictionary with one or several input Tensors associated to the input names given in the docstring:

    model({'input_ids': input_ids, 'token_type_ids': token_type_ids})

Parameters

config (TransfoXLConfig) – Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

Inputs:
input_ids: Numpy array or tf.Tensor of shape (batch_size, sequence_length):

Indices of input sequence tokens in the vocabulary. Transformer-XL is a model with relative position embeddings so you can either pad the inputs on the right or on the left. Indices can be obtained using transformers.TransfoXLTokenizer. See transformers.PreTrainedTokenizer.encode() and transformers.PreTrainedTokenizer.convert_tokens_to_ids() for details.

mems: (optional)

list of Numpy array or tf.Tensor (one for each layer) containing pre-computed hidden-states (keys and values in the attention blocks) as computed by the model (see mems output below). Can be used to speed up sequential decoding and attend to a longer context.

head_mask: (optional) Numpy array or tf.Tensor of shape (num_heads,) or (num_layers, num_heads):

Mask to nullify selected heads of the self-attention modules. Mask values selected in [0, 1]: 1 indicates the head is not masked, 0 indicates the head is masked.

inputs_embeds: (optional) Numpy array or tf.Tensor of shape (batch_size, sequence_length, embedding_dim):

Optionally, instead of passing input_ids you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert input_ids indices into associated vectors than the model’s internal embedding lookup matrix.

Outputs: Tuple comprising various elements depending on the configuration (config) and inputs:
prediction_scores: None if labels is provided, else tf.Tensor of shape (batch_size, sequence_length, config.vocab_size)

Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). We don’t output them when the loss is computed, to speed up adaptive softmax decoding.

mems:

list of tf.Tensor (one for each layer) containing pre-computed hidden-states (keys and values in the attention blocks) as computed by the model (see mems input above). Can be used to speed up sequential decoding and attend to a longer context.

hidden_states: (optional, returned when config.output_hidden_states=True)

list of tf.Tensor (one for the output of each layer + the output of the embeddings) of shape (batch_size, sequence_length, hidden_size): Hidden-states of the model at the output of each layer plus the initial embedding outputs.

attentions: (optional, returned when config.output_attentions=True)

list of tf.Tensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length): Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

Examples:

import tensorflow as tf
from transformers import TransfoXLTokenizer, TFTransfoXLLMHeadModel

tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
model = TFTransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
outputs = model(input_ids)
prediction_scores, mems = outputs[:2]
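
As a follow-up to the example above, the prediction scores can be turned into a next-token guess; a minimal sketch using a greedy argmax over the last position:

import tensorflow as tf
from transformers import TransfoXLTokenizer, TFTransfoXLLMHeadModel

tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
model = TFTransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')

input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
prediction_scores, mems = model(input_ids)[:2]

next_token_id = int(tf.argmax(prediction_scores[0, -1, :]))   # scores for the next position
print(tokenizer.convert_ids_to_tokens([next_token_id]))
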
call(inputs, mems=None, head_mask=None, inputs_embeds=None, labels=None, training=False)[source]

Calls the model on new inputs.

In this case call just reapplies all ops in the graph to the new inputs (e.g. build a new computational graph from the provided inputs).

Parameters
  • inputs – A tensor or list of tensors.

  • training – Boolean or boolean scalar tensor, indicating whether to run the Network in training mode or inference mode.

  • mask – A mask or list of masks. A mask can be either a tensor or None (no mask).

Returns

A tensor if there is a single output, or a list of tensors if there is more than one output.