RetriBERT¶
Overview¶
The RetriBERT model was proposed in the blog post Explain Anything Like I’m Five: A Model for Open Domain Long Form Question Answering. RetriBERT is a small model that uses either a single BERT encoder or a pair of BERT encoders with a lower-dimension projection for dense semantic indexing of text.
Code to train and use the model can be found here.
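The snippet below is a minimal sketch of loading the model and producing dense query embeddings. The checkpoint name yjernite/retribert-base-uncased and the embed_questions helper come from the training and usage scripts linked above rather than from this page, so treat them as assumptions.

    import torch
    from transformers import RetriBertModel, RetriBertTokenizer

    # Checkpoint name assumed from the long form QA example scripts
    tokenizer = RetriBertTokenizer.from_pretrained("yjernite/retribert-base-uncased")
    model = RetriBertModel.from_pretrained("yjernite/retribert-base-uncased")

    inputs = tokenizer(["Why is the sky blue?"], padding=True, return_tensors="pt")

    # embed_questions projects the query encoder output down to projection_dim
    # (this helper lives in the modeling code; it is not documented on this page)
    with torch.no_grad():
        query_embeddings = model.embed_questions(inputs["input_ids"], inputs["attention_mask"])
    print(query_embeddings.shape)  # (1, projection_dim), 128 by default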
RetriBertConfig¶
class transformers.RetriBertConfig(vocab_size=30522, hidden_size=768, num_hidden_layers=8, num_attention_heads=12, intermediate_size=3072, hidden_act='gelu', hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, share_encoders=True, projection_dim=128, pad_token_id=0, **kwargs)[source]¶

This is the configuration class to store the configuration of a RetriBertModel. It is used to instantiate a RetriBertModel according to the specified arguments, defining the model architecture. Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

Parameters
- vocab_size (int, optional, defaults to 30522) – Vocabulary size of the BERT model. Defines the different tokens that can be represented by the inputs_ids passed to the forward method of BertModel.
- hidden_size (int, optional, defaults to 768) – Dimensionality of the encoder layers and the pooler layer.
- num_hidden_layers (int, optional, defaults to 8) – Number of hidden layers in the Transformer encoder.
- num_attention_heads (int, optional, defaults to 12) – Number of attention heads for each attention layer in the Transformer encoder.
- intermediate_size (int, optional, defaults to 3072) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.
- hidden_act (str or function, optional, defaults to “gelu”) – The non-linear activation function (function or string) in the encoder and pooler. If string, “gelu”, “relu”, “swish” and “gelu_new” are supported.
- hidden_dropout_prob (float, optional, defaults to 0.1) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
- attention_probs_dropout_prob (float, optional, defaults to 0.1) – The dropout ratio for the attention probabilities.
- max_position_embeddings (int, optional, defaults to 512) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
- type_vocab_size (int, optional, defaults to 2) – The vocabulary size of the token_type_ids passed into BertModel.
- initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- layer_norm_eps (float, optional, defaults to 1e-12) – The epsilon used by the layer normalization layers.
- share_encoders (bool, optional, defaults to True) – Whether to use the same Bert-type encoder for the queries and documents.
- projection_dim (int, optional, defaults to 128) – Final dimension of the query and document representations after projection.
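As a quick illustration, here is a minimal sketch of building a model from a configuration; it only uses the classes documented on this page and the parameters listed above.

    from transformers import RetriBertConfig, RetriBertModel

    # Default configuration: shared query/document encoder, 128-dim projection
    config = RetriBertConfig()

    # Override a few hyperparameters, e.g. use separate encoders and a larger projection
    config = RetriBertConfig(share_encoders=False, projection_dim=256)

    # Initializing a model from the configuration creates random (untrained) weights
    model = RetriBertModel(config)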
RetriBertTokenizer¶
class transformers.RetriBertTokenizer(vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None, unk_token='[UNK]', sep_token='[SEP]', pad_token='[PAD]', cls_token='[CLS]', mask_token='[MASK]', tokenize_chinese_chars=True, strip_accents=None, **kwargs)[source]¶

Constructs a RetriBERT tokenizer. RetriBertTokenizer is identical to BertTokenizer and runs end-to-end tokenization: punctuation splitting + wordpiece.

Refer to superclass BertTokenizer for usage examples and documentation concerning parameters.
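A minimal usage sketch is shown below; the checkpoint name is an assumption (a local BERT-style vocabulary can also be passed via the vocab_file argument).

    from transformers import RetriBertTokenizer

    # Checkpoint name assumed; not specified on this page
    tokenizer = RetriBertTokenizer.from_pretrained("yjernite/retribert-base-uncased")

    encoded = tokenizer(
        "How do airplanes stay in the air?",
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    print(encoded["input_ids"].shape)       # (1, sequence_length)
    print(encoded["attention_mask"].shape)  # (1, sequence_length)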
RetriBertTokenizerFast¶
class transformers.RetriBertTokenizerFast(vocab_file, do_lower_case=True, unk_token='[UNK]', sep_token='[SEP]', pad_token='[PAD]', cls_token='[CLS]', mask_token='[MASK]', clean_text=True, tokenize_chinese_chars=True, strip_accents=None, wordpieces_prefix='##', **kwargs)[source]¶

Constructs a “Fast” RetriBERT tokenizer (backed by HuggingFace’s tokenizers library). RetriBertTokenizerFast is identical to BertTokenizerFast and runs end-to-end tokenization: punctuation splitting + wordpiece.

Refer to superclass BertTokenizerFast for usage examples and documentation concerning parameters.
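The fast tokenizer is used in exactly the same way; the sketch below (checkpoint name assumed) simply swaps the class and encodes a padded batch.

    from transformers import RetriBertTokenizerFast

    # Checkpoint name assumed; the fast tokenizer is a drop-in replacement for the slow one
    tokenizer = RetriBertTokenizerFast.from_pretrained("yjernite/retribert-base-uncased")

    batch = tokenizer(
        ["What causes tides?", "Why is the sky blue?"],
        padding=True,
        return_tensors="pt",
    )
    print(batch["input_ids"].shape)  # (2, longest_sequence_length)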
RetriBertModel¶
class transformers.RetriBertModel(config)[source]¶

Bert-based model to embed queries or documents for document retrieval.

This model is a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

Parameters

- config (RetriBertConfig) – Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
forward(input_ids_query, attention_mask_query, input_ids_doc, attention_mask_doc, checkpoint_batch_size=-1)[source]¶

Parameters

- input_ids_query (torch.LongTensor of shape (batch_size, sequence_length)) – Indices of input sequence tokens in the vocabulary for the queries in a batch. Indices can be obtained using transformers.RetriBertTokenizer. See transformers.PreTrainedTokenizer.encode() and transformers.PreTrainedTokenizer.__call__() for details.
- attention_mask_query (torch.FloatTensor of shape (batch_size, sequence_length), optional) – Mask to avoid performing attention on query padding token indices. Mask values selected in [0, 1]: 1 for tokens that are NOT MASKED, 0 for MASKED tokens.
- input_ids_doc (torch.LongTensor of shape (batch_size, sequence_length)) – Indices of input sequence tokens in the vocabulary for the documents in a batch.
- attention_mask_doc (torch.FloatTensor of shape (batch_size, sequence_length), optional) – Mask to avoid performing attention on document padding token indices.
- checkpoint_batch_size (int, optional, defaults to -1) – If greater than 0, uses gradient checkpointing to only compute sequence representations for checkpoint_batch_size examples at a time on the GPU. All query representations are still compared to all document representations in the batch.

Returns

torch.FloatTensor – The bi-directional cross-entropy loss obtained while trying to match each query to its corresponding document and each document to its corresponding query in the batch.
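A minimal sketch of a training-style forward pass is shown below. The pretrained checkpoint name is an assumption; the returned value is the in-batch bi-directional cross-entropy loss described above.

    import torch
    from transformers import RetriBertModel, RetriBertTokenizer

    # Checkpoint name assumed; not specified on this page
    tokenizer = RetriBertTokenizer.from_pretrained("yjernite/retribert-base-uncased")
    model = RetriBertModel.from_pretrained("yjernite/retribert-base-uncased")

    queries = ["Why is the sky blue?", "How do magnets work?"]
    documents = [
        "The sky looks blue because air scatters short (blue) wavelengths of sunlight more strongly.",
        "Magnetism arises from moving electric charges and the alignment of electron spins in a material.",
    ]

    query_inputs = tokenizer(queries, padding=True, truncation=True, return_tensors="pt")
    doc_inputs = tokenizer(documents, padding=True, truncation=True, return_tensors="pt")

    # Each query is matched against every document in the batch (and vice versa)
    loss = model(
        input_ids_query=query_inputs["input_ids"],
        attention_mask_query=query_inputs["attention_mask"],
        input_ids_doc=doc_inputs["input_ids"],
        attention_mask_doc=doc_inputs["attention_mask"],
    )
    print(loss)  # scalar torch.FloatTensor, suitable for loss.backward() during training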