Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from transformers import RealmForOpenQA, RealmRetriever | |
# Checkpoint name shared by the retriever and the open-domain QA model.
model_name = "google/realm-orqa-nq-openqa"

# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# The retriever is loaded first because the QA model is constructed around it;
# its tokenizer is reused for both questions and user-supplied documents.
retriever = RealmRetriever.from_pretrained(model_name)
openqa = RealmForOpenQA.from_pretrained(model_name, retriever=retriever).to(device)
tokenizer = retriever.tokenizer

# Block count taken from the freshly loaded config. Later code slices the
# records/embeddings back to this length before appending extra documents.
default_num_block_records = openqa.config.num_block_records
def add_additional_documents(openqa, additional_documents):
    """Replace the user-supplied part of the retrieval corpus.

    The first ``default_num_block_records`` records and embeddings are kept,
    anything appended by an earlier call is discarded, and one record plus one
    embedding is appended for every non-blank line of ``additional_documents``.

    Relies on the module-level ``default_num_block_records`` and ``device``.

    Args:
        openqa: The open-QA model; its ``retriever``, ``embedder``,
            ``block_emb`` and ``config.num_block_records`` are read and updated
            in place.
        additional_documents: Text with one document per line. Blank lines and
            surrounding whitespace are ignored.

    Returns:
        None. ``openqa`` and its retriever are mutated.
    """
    # splitlines() also copes with "\r\n"; blank entries are dropped so empty
    # strings are never embedded and offered as retrieval candidates.
    documents = [
        line.strip() for line in additional_documents.splitlines() if line.strip()
    ]
    retriever = openqa.retriever

    # Always rebuild from the default prefix so repeated calls do not pile up
    # documents from previous requests.
    base_records = retriever.block_records[:default_num_block_records]
    base_emb = openqa.block_emb[:default_num_block_records]

    if not documents:
        # Nothing usable was supplied: fall back to the default corpus instead
        # of sending an empty batch to the tokenizer.
        retriever.block_records = base_records
        openqa.block_emb = base_emb
        openqa.config.num_block_records = default_num_block_records
        return

    # docs: the existing code stores records as an object array of bytes.
    np_documents = np.array([doc.encode() for doc in documents], dtype=object)
    retriever.block_records = np.concatenate((base_records, np_documents), axis=0)

    # embeds: one projected embedding per new document, appended in the same
    # order as the records above so indices stay aligned.
    inputs = retriever.tokenizer(
        documents, padding=True, truncation=True, return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        projected_score = openqa.embedder(**inputs, return_dict=True).projected_score
    openqa.block_emb = torch.cat((base_emb, projected_score), dim=0)

    openqa.config.num_block_records = default_num_block_records + len(documents)
def question_answer(question, additional_documents):
    """Answer ``question`` with the REALM open-QA model.

    Uses the module-level ``tokenizer``, ``openqa``, ``device`` and
    ``default_num_block_records``.

    Args:
        question: The question text.
        additional_documents: Optional extra documents, one per line. When it
            is empty, whitespace-only or ``None``, the default corpus is used.

    Returns:
        The decoded predicted answer string.
    """
    question_ids = tokenizer(question, return_tensors="pt").input_ids

    if additional_documents and additional_documents.strip():
        add_additional_documents(openqa, additional_documents)
    elif openqa.config.num_block_records != default_num_block_records:
        # A previous request appended documents. A blank box means "default
        # documents only", so drop the appended records and embeddings rather
        # than silently searching stale user data.
        openqa.retriever.block_records = openqa.retriever.block_records[
            :default_num_block_records
        ]
        openqa.block_emb = openqa.block_emb[:default_num_block_records]
        openqa.config.num_block_records = default_num_block_records

    with torch.no_grad():
        outputs = openqa(input_ids=question_ids.to(device), return_dict=True)
    return tokenizer.decode(outputs.predicted_answer_ids)
# The `gr.inputs` namespace was removed in Gradio 4, so prefer the top-level
# component and only fall back to the legacy class on releases that lack it.
_textbox_cls = getattr(gr, "Textbox", None) or gr.inputs.Textbox
additional_documents_input = _textbox_cls(
    lines=5,
    placeholder="Each line represents a document entry. Leave blank to use default wiki documents.",
)

iface = gr.Interface(
    fn=question_answer,
    inputs=["text", additional_documents_input],
    outputs=["textbox"],
    allow_flagging="never",
)

# `launch(enable_queue=...)` is no longer accepted by current Gradio; queueing
# is enabled with `.queue()` instead. Keep the old keyword for releases whose
# Interface has no `queue` method.
if hasattr(iface, "queue"):
    iface.queue().launch()
else:
    iface.launch(enable_queue=True)