import torch
from transformers import RagRetriever, RagSequenceForGeneration, RagTokenizer

# Step 1: Load the retriever. With index_name="exact" and use_dummy_dataset=True,
# RagRetriever downloads a small dummy slice of the wiki_dpr index on its own, so
# no separate load_dataset("wiki_dpr", "psgs_w100.nq.exact") call is needed (the
# full config is tens of GB). trust_remote_code=True is required on recent
# datasets versions because wiki_dpr ships a loading script.
retriever = RagRetriever.from_pretrained(
    "facebook/rag-sequence-nq",
    index_name="exact",
    use_dummy_dataset=True,
    trust_remote_code=True,
)

# Step 2: Load the tokenizer for the RAG model
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")

# Step 3: Initialize the RAG model, handing it the retriever so that generate()
# performs retrieval internally. facebook/rag-sequence-nq is fine-tuned on
# Natural Questions (the *-base checkpoints are not fine-tuned for QA) and
# matches the RagSequenceForGeneration class.
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)

# Step 4: Define a function to generate an answer using the retriever and model
def generate_answer(question):
    # Tokenize the question
    inputs = tokenizer(question, return_tensors="pt")

    # generate() uses the attached retriever to fetch supporting passages and
    # conditions the generator on them. (RagRetriever has no .retrieve() method;
    # either attach it to the model like this, or call retriever(...) with the
    # question hidden states, as in the sketch at the bottom of the file.)
    generated_ids = model.generate(input_ids=inputs["input_ids"])

    # Decode the generated answer back to text
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

# Step 5: Example usage
if __name__ == "__main__":
    question = "Who was the first president of the United States?"
    print(f"Question: {question}")

    # Generate and print the answer
    answer = generate_answer(question)
    print(f"Answer: {answer}")
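
# Optional: the explicit retrieval path that the original retriever.retrieve()
# call was reaching for. A minimal sketch following the manual-retrieval example
# in the Hugging Face RAG docs, assuming the tokenizer, retriever, and model
# objects defined above; the helper name generate_answer_manual is ours.
def generate_answer_manual(question):
    input_ids = tokenizer(question, return_tensors="pt")["input_ids"]

    # 1. Encode the question with the model's question encoder.
    question_hidden_states = model.question_encoder(input_ids)[0]

    # 2. Retrieve passages. The retriever expects numpy arrays and returns the
    #    tokenized contexts plus the embeddings of the retrieved documents.
    docs_dict = retriever(
        input_ids.numpy(),
        question_hidden_states.detach().numpy(),
        return_tensors="pt",
    )

    # 3. Score each retrieved document against the question embedding.
    doc_scores = torch.bmm(
        question_hidden_states.unsqueeze(1),
        docs_dict["retrieved_doc_embeds"].float().transpose(1, 2),
    ).squeeze(1)

    # 4. Generate from the retrieved contexts; doc_scores must be passed
    #    explicitly because retrieval happened outside of generate().
    generated_ids = model.generate(
        context_input_ids=docs_dict["context_input_ids"],
        context_attention_mask=docs_dict["context_attention_mask"],
        doc_scores=doc_scores,
    )
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]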