import torch
from transformers import RagRetriever, RagTokenizer, RagSequenceForGeneration

# Step 1: Load the retriever. With use_dummy_dataset=True it loads only a small
# dummy slice of the wiki_dpr index, so the full multi-gigabyte dataset does not
# need to be downloaded and no separate load_dataset call is required.
retriever = RagRetriever.from_pretrained(
    "facebook/rag-sequence-nq",  # NQ fine-tuned checkpoint matching RagSequenceForGeneration
    index_name="exact",          # same index flavor as the psgs_w100.nq.exact config
    use_dummy_dataset=True,
    trust_remote_code=True,      # wiki_dpr ships a dataset loading script
)
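# Note: for real retrieval quality, drop use_dummy_dataset=True so the retriever
# uses the full wiki_dpr index ("exact", or the smaller "compressed" flavor);
# expect a very large download and index build on first use.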

# Step 2: Load the tokenizer for the RAG model
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")

# Step 3: Initialize the RAG model and attach the retriever, so that
# model.generate() can query the index internally
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)

# Step 4: Define a function that generates an answer to a question
def generate_answer(question):
    # Tokenize the question with the question-encoder tokenizer
    inputs = tokenizer(question, return_tensors="pt")
    # Generate an answer; the attached retriever fetches supporting passages
    # internally, so no separate retrieve() call is needed
    with torch.no_grad():
        generated_ids = model.generate(input_ids=inputs["input_ids"])
    # Decode the generated token IDs back to text
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
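
# Optional: a minimal sketch (not in the original app) of calling the retriever
# by hand to inspect which passages back an answer. show_retrieved_titles is a
# hypothetical helper; it relies on the documented RagRetriever.__call__ and
# retriever.index.get_doc_dicts APIs.
def show_retrieved_titles(question, n_docs=5):
    inputs = tokenizer(question, return_tensors="pt")
    # Embed the question with the DPR question encoder
    question_hidden_states = model.question_encoder(inputs["input_ids"])[0]
    # Query the index; the retriever expects numpy arrays
    docs = retriever(
        inputs["input_ids"].numpy(),
        question_hidden_states.detach().numpy(),
        return_tensors="pt",
        n_docs=n_docs,
    )
    # Map the returned doc ids back to the raw passage records
    doc_dicts = retriever.index.get_doc_dicts(docs["doc_ids"])
    return doc_dicts[0]["title"]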

# Step 5: Example usage
if __name__ == "__main__":
    question = "Who was the first president of the United States?"
    print(f"Question: {question}")
    # Generate and print the answer
    answer = generate_answer(question)
    print(f"Answer: {answer}")