File size: 1,713 Bytes
2795690
 
 
d25c039
2424894
a8d7c56
2795690
 
d25c039
 
 
 
 
 
2795690
72d38f5
2795690
 
72d38f5
 
2795690
 
 
72d38f5
 
2795690
 
 
72d38f5
2795690
 
 
 
 
72d38f5
 
 
2795690
dbad9b8
2795690
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch
from transformers import RagRetriever, RagTokenizer, RagSequenceForGeneration
from datasets import load_dataset

# Checkpoint shared by the retriever, tokenizer and model. "rag-sequence-nq" is
# the RAG-Sequence checkpoint fine-tuned on Natural Questions, which matches the
# RagSequenceForGeneration class used below. The non-fine-tuned
# "facebook/rag-token-base" is a RAG-Token checkpoint meant for
# RagTokenForGeneration.
RAG_MODEL_NAME = "facebook/rag-sequence-nq"

# NOTE: the full wiki_dpr "psgs_w100.nq.exact" corpus is intentionally NOT
# loaded here. Nothing in this script reads it: with use_dummy_dataset=True the
# retriever loads its own small dummy index. The full corpus is a very large
# download.

# Step 1: Load the retriever. use_dummy_dataset=True swaps the full Wikipedia
# index for a small dummy subset. That is fine for a demo, but answers can only
# draw on the passages in that subset.
# NOTE(review): trust_remote_code is kept from the original, presumably for the
# wiki_dpr dataset loading script. Confirm it is still needed with the
# installed `datasets` version.
retriever = RagRetriever.from_pretrained(
    RAG_MODEL_NAME,
    index_name="exact",
    use_dummy_dataset=True,
    trust_remote_code=True,
)

# Step 2: Load the RAG tokenizer, which bundles the question-encoder and
# generator tokenizers.
tokenizer = RagTokenizer.from_pretrained(RAG_MODEL_NAME)

# Step 3: Initialize the model with the retriever attached, so that
# model.generate(input_ids=...) can also retrieve documents on its own.
model = RagSequenceForGeneration.from_pretrained(RAG_MODEL_NAME, retriever=retriever)

# Step 5: Define a function to generate an answer using the retriever and model
def generate_answer(question):
    """Answer *question* using retrieval-augmented generation.

    Uses the module-level ``tokenizer``, ``retriever`` and ``model``:

    1. Tokenize the question (RagTokenizer's question-encoder side).
    2. Embed it with ``model.question_encoder``.
    3. Call ``retriever`` to fetch the nearest passages. It returns generator
       inputs (``context_input_ids`` / ``context_attention_mask``) and the
       retrieved passage embeddings.
    4. Score each passage as the inner product of the question and passage
       embeddings, then let ``model.generate`` produce the answer from those
       contexts and scores.
    5. Decode the first generated sequence (RagTokenizer's generator side).

    Args:
        question: Natural-language question.

    Returns:
        The decoded answer text, with special tokens removed.

    Raises:
        ValueError: If *question* is an empty or whitespace-only string.
    """
    if isinstance(question, str) and not question.strip():
        raise ValueError("question must be a non-empty string")

    inputs = tokenizer(question, return_tensors="pt")
    input_ids = inputs["input_ids"]

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        # RagRetriever.retrieve() takes question *embeddings* and returns a raw
        # (embeds, doc_ids, doc_dicts) tuple. Calling the retriever object
        # instead also builds the tokenized generator inputs we need.
        question_hidden_states = model.question_encoder(input_ids)[0]
        docs_dict = retriever(
            input_ids.numpy(),
            question_hidden_states.numpy(),
            return_tensors="pt",
        )
        # One relevance score per retrieved passage: batched inner product of
        # the question embedding with each passage embedding.
        doc_scores = torch.bmm(
            question_hidden_states.unsqueeze(1),
            docs_dict["retrieved_doc_embeds"].float().transpose(1, 2),
        ).squeeze(1)

        # input_ids is deliberately not passed. When context_* and doc_scores
        # are supplied, generate() does not need a retriever attached to the
        # model.
        generated_ids = model.generate(
            context_input_ids=docs_dict["context_input_ids"],
            context_attention_mask=docs_dict["context_attention_mask"],
            doc_scores=doc_scores,
        )

    # Decode the generated answer back to text.
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)

# Step 6: Example usage
if __name__ == "__main__":
    # Demo: ask a single factual question and print the question and answer.
    sample_question = "Who was the first president of the United States?"
    print(f"Question: {sample_question}")

    sample_answer = generate_answer(sample_question)
    print(f"Answer: {sample_answer}")