# realm-demo/app.py
import gradio as gr
import numpy as np
import torch
from transformers import RealmForOpenQA, RealmRetriever
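
# Load the pretrained REALM open-QA checkpoint: the retriever (wiki block records
# plus tokenizer) and the end-to-end open-QA model.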
model_name = "google/realm-orqa-nq-openqa"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
retriever = RealmRetriever.from_pretrained(model_name)
tokenizer = retriever.tokenizer
openqa = RealmForOpenQA.from_pretrained(model_name, retriever=retriever)
openqa.to(device)

# Number of wiki blocks shipped with the checkpoint; user-added documents are appended after this index.
default_num_block_records = openqa.config.num_block_records
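
# Add user-supplied documents to the retriever's block records and to the model's
# block embeddings so they are searchable alongside the default wiki blocks.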
def add_additional_documents(openqa, additional_documents):
    documents = additional_documents.split("\n")
    np_documents = np.array([doc.encode() for doc in documents], dtype=object)
    total_documents = np_documents.shape[0]

    retriever = openqa.retriever
    tokenizer = openqa.retriever.tokenizer

    # Docs: keep only the default wiki blocks, then append the new documents (as bytes).
    retriever.block_records = np.concatenate(
        (retriever.block_records[:default_num_block_records], np_documents), axis=0
    )

    # Embeds: project the new documents with the REALM embedder and append the
    # resulting vectors to the cached block embeddings.
    inputs = tokenizer(documents, padding=True, truncation=True, return_tensors="pt").to(device)
    with torch.no_grad():
        projected_score = openqa.embedder(**inputs, return_dict=True).projected_score
    openqa.block_emb = torch.cat((openqa.block_emb[:default_num_block_records], projected_score), dim=0)

    openqa.config.num_block_records = default_num_block_records + total_documents
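
# Gradio callback: optionally extend the retrieval index with the user's documents,
# then run open-domain QA on the question and decode the predicted answer span.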
def question_answer(question, additional_documents):
    question_ids = tokenizer(question, return_tensors="pt").input_ids

    if additional_documents != "":
        add_additional_documents(openqa, additional_documents)

    with torch.no_grad():
        outputs = openqa(input_ids=question_ids.to(device), return_dict=True)

    return tokenizer.decode(outputs.predicted_answer_ids)
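
# Gradio UI: a text input for the question plus an optional textbox of extra documents,
# one document per line.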
additional_documents_input = gr.inputs.Textbox(
    lines=5,
    placeholder="Each line represents a document entry. Leave blank to use default wiki documents.",
)

iface = gr.Interface(
    fn=question_answer,
    inputs=["text", additional_documents_input],
    outputs=["textbox"],
    allow_flagging="never",
)
iface.launch(enable_queue=True)