facebook/rag-sequence-base

Tags: pytorch, tf

Contributed by Facebook AI

How to use this model directly from the 🤗/transformers library:

from transformers import AutoTokenizer, RagSequenceForGeneration

tokenizer = AutoTokenizer.from_pretrained("facebook/rag-sequence-base")
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-base")

RAG

This is a non-finetuned version of the RAG-Sequence model from the paper Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks by Patrick Lewis, Ethan Perez, Aleksandra Piktus et al.
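For reference, RAG-Sequence as defined in the paper conditions the whole output sequence y on a single retrieved document z and then marginalizes over the top-k retrieved documents. Here p_η(z | x) is the retriever's distribution over documents for input x, p_θ is the generator, and N is the length of y:

```latex
p_{\text{RAG-Sequence}}(y \mid x) \approx \sum_{z \in \operatorname{top-}k(p_\eta(\cdot \mid x))} p_\eta(z \mid x) \prod_{i=1}^{N} p_\theta(y_i \mid x, z, y_{1:i-1})
```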

RAG consists of a question encoder, a retriever, and a generator. The retriever should be a RagRetriever instance. The question encoder can be any model that can be loaded with AutoModel, and the generator can be any model that can be loaded with AutoModelForSeq2SeqLM.
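To illustrate how the question encoder and retriever interact: in the paper, each document z is scored by the inner product of its embedding d(z) with the question embedding q(x), the top-k documents are kept, and their scores are softmax-normalized into p(z | x). The pure-Python sketch below does this by brute force over a toy index. The embeddings are made-up 3-dimensional vectors and `retrieve` is a hypothetical helper, not part of the RagRetriever API (which searches a FAISS index of precomputed DPR passage embeddings).

```python
import math


def retrieve(question_emb, doc_embs, k):
    """Return the top-k (doc_index, p(z|x)) pairs for one question (brute force)."""
    # Inner-product score d(z) . q(x) for every document in the toy index.
    scores = [sum(q * d for q, d in zip(question_emb, emb)) for emb in doc_embs]
    # Exact search: keep the indices of the k highest scores.
    top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    # Softmax over the retrieved scores only (shifted by the max for stability).
    best = scores[top[0]]
    weights = [math.exp(scores[i] - best) for i in top]
    total = sum(weights)
    return [(i, w / total) for i, w in zip(top, weights)]


# Hypothetical 3-dimensional embeddings standing in for DPR vectors.
question = [1.0, 0.0, 0.5]
docs = [
    [0.9, 0.1, 0.4],  # score 1.1
    [0.0, 1.0, 0.0],  # score 0.0
    [1.0, 0.0, 1.0],  # score 1.5
]
print(retrieve(question, docs, k=2))  # doc 2 (p ~ 0.599) then doc 0 (p ~ 0.401)
```

Normalizing only over the retrieved top-k, rather than the whole index, is what keeps p(z | x) tractable when the index holds millions of passages.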

This model is a non-finetuned RAG-Sequence model and was created as follows:

from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration, AutoTokenizer

# Combine a DPR question encoder and a BART generator into one RAG-Sequence model
model = RagSequenceForGeneration.from_pretrained_question_encoder_generator("facebook/dpr-question_encoder-single-nq-base", "facebook/bart-large")

# Each sub-model keeps its own tokenizer; RagTokenizer wraps both
question_encoder_tokenizer = AutoTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
generator_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")

tokenizer = RagTokenizer(question_encoder_tokenizer, generator_tokenizer)

# Configure the retriever to use the small dummy dataset with the "exact" index
model.config.use_dummy_dataset = True
model.config.index_name = "exact"
retriever = RagRetriever(model.config, question_encoder_tokenizer, generator_tokenizer)

# Save the model, tokenizer and retriever to the current directory
model.save_pretrained("./")
tokenizer.save_pretrained("./")
retriever.save_pretrained("./")

Note that the model is uncased: all uppercase input letters are converted to lowercase.

Usage:

Note: by default, the model uses the dummy retriever. Better results are obtained with the full retriever, which is enabled by setting config.index_name="legacy" and config.use_dummy_dataset=False. The model can be fine-tuned as follows:

from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration

tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-base")
retriever = RagRetriever.from_pretrained("facebook/rag-sequence-base")
# Use the RAG-Sequence head to match this RAG-Sequence checkpoint
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-base", retriever=retriever)

# Tokenize the question (model input) and the answer (labels)
input_dict = tokenizer.prepare_seq2seq_batch("who holds the record in 100m freestyle", "michael phelps", return_tensors="pt")

outputs = model(input_dict["input_ids"], labels=input_dict["labels"])

loss = outputs.loss

# train on loss (e.g. backpropagate it and step an optimizer)
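Conceptually, the `outputs.loss` above corresponds to the RAG-Sequence training objective from the paper: the negative log of the target likelihood marginalized over the retrieved documents, -log Σ_z p(z | x) p(y | x, z). The pure-Python sketch below computes that quantity for a single example from made-up probabilities. It is an illustration under those assumptions, not the library's implementation, which operates on batched logits and also handles padding and optional label smoothing. The helpers `logsumexp` and `rag_sequence_nll` are hypothetical names.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def rag_sequence_nll(doc_logprobs, token_logprobs_per_doc):
    """Negative log marginal likelihood of one target sequence y.

    doc_logprobs[j]           -- log p(z_j | x) for the j-th retrieved document
    token_logprobs_per_doc[j] -- list of log p(y_i | x, z_j, y_<i) over target tokens i
    """
    # log p(y | x, z_j): sum of the per-token log-probabilities for document j.
    seq_logprobs = [sum(tokens) for tokens in token_logprobs_per_doc]
    # Marginalize over documents in log space: log sum_j p(z_j | x) p(y | x, z_j).
    joint = [d + s for d, s in zip(doc_logprobs, seq_logprobs)]
    return -logsumexp(joint)


# Hypothetical numbers: two retrieved documents, a two-token target sequence.
doc_logprobs = [math.log(0.6), math.log(0.4)]
token_logprobs = [
    [math.log(0.5), math.log(0.5)],  # p(y | x, z_0) = 0.25
    [math.log(0.1), math.log(0.5)],  # p(y | x, z_1) = 0.05
]
toy_loss = rag_sequence_nll(doc_logprobs, token_logprobs)
print(toy_loss)  # -log(0.6 * 0.25 + 0.4 * 0.05) = -log(0.17), about 1.772
```

Working in log space with `logsumexp` avoids underflow when per-token probabilities are multiplied over long target sequences.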