|
import torch |
|
from transformers import RagRetriever, RagTokenizer, RagSequenceForGeneration |
|
from datasets import load_dataset |
|
|
|
|
|
# Load the dummy wiki_dpr passage subset.
# BUGFIX(review): the original loaded the full "psgs_w100.nq.exact" config
# (~78 GB download + FAISS index) even though the retriever below runs with
# use_dummy_dataset=True, so the full corpus was never used. The dummy config
# keeps the script runnable on a normal machine.
dataset = load_dataset(
    "wiki_dpr",
    "dummy.psgs_w100.nq.no_index",
    trust_remote_code=True,
)

# Retriever backed by the small dummy index; index_name="exact" is the pairing
# documented for use_dummy_dataset=True in the transformers RAG docs.
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-base",
    index_name="exact",
    use_dummy_dataset=True,
    trust_remote_code=True,
)

# RagTokenizer bundles the question-encoder and generator tokenizers.
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")

# BUGFIX: attach the retriever so model.generate() can fetch supporting
# passages on its own; without it, callers must supply context_input_ids,
# context_attention_mask and doc_scores manually.
# NOTE(review): "facebook/rag-token-base" is a RAG-*token* checkpoint loaded
# into RagSequenceForGeneration — it loads, but consider RagTokenForGeneration
# (or the "facebook/rag-sequence-nq" checkpoint) for matched marginalization.
model = RagSequenceForGeneration.from_pretrained(
    "facebook/rag-token-base",
    retriever=retriever,
)
|
|
|
|
|
def generate_answer(question):
    """Answer a natural-language question with the RAG model.

    Parameters
    ----------
    question : str
        The question to answer.

    Returns
    -------
    str
        The generated answer text, with special tokens stripped.
    """
    # BUGFIX: the original called retriever.retrieve(input_ids) and then indexed
    # the result like a dict — RagRetriever.retrieve() takes question *hidden
    # states* (not token ids) and returns a tuple, so that call raised at
    # runtime. Attaching the retriever to the model lets generate() perform
    # retrieval internally, which is the documented usage. set_retriever() is
    # idempotent, so this is safe even if the retriever was attached at load
    # time.
    model.set_retriever(retriever)

    # Tokenize the question (RagTokenizer routes this through the
    # question-encoder tokenizer).
    inputs = tokenizer(question, return_tensors="pt")
    input_ids = inputs["input_ids"]

    # Inference only — no gradients needed.
    with torch.no_grad():
        generated_ids = model.generate(input_ids)

    # batch_decode decodes with the generator tokenizer; take the first (and
    # only) sequence in the batch.
    answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    return answer
|
|
|
|
|
if __name__ == "__main__":
    # Demo: run a single factual question through the RAG pipeline.
    demo_question = "Who was the first president of the United States?"
    print(f"Question: {demo_question}")
    print(f"Answer: {generate_answer(demo_question)}")
|
|