"""Gradio demo for open-domain question answering with REALM.

Loads the ``google/realm-orqa-nq-openqa`` checkpoint and serves a two-field
form: a question, and an optional list of extra documents (one per line) that
are appended to the default retrieval index for that request.
"""

import gradio as gr
import numpy as np
import torch
from transformers import RealmForOpenQA, RealmRetriever

model_name = "google/realm-orqa-nq-openqa"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

retriever = RealmRetriever.from_pretrained(model_name)
tokenizer = retriever.tokenizer
openqa = RealmForOpenQA.from_pretrained(model_name, retriever=retriever)
openqa.to(device)

# Size of the index shipped with the checkpoint. Everything stored past this
# offset in ``block_records`` / ``block_emb`` was appended by a request.
default_num_block_records = openqa.config.num_block_records


def _parse_documents(additional_documents):
    """Split the textbox content into a list of non-empty documents.

    Each line is one document. Surrounding whitespace (including a stray
    ``\\r`` from CRLF input) is stripped, and blank lines are dropped so that
    trailing newlines or whitespace-only input do not create empty index
    entries. ``None`` and ``""`` both yield an empty list.
    """
    if not additional_documents:
        return []
    stripped_lines = (line.strip() for line in additional_documents.split("\n"))
    return [line for line in stripped_lines if line]


def add_additional_documents(openqa, additional_documents):
    """Make the index equal to the default documents plus the supplied ones.

    Any documents appended by an earlier call are discarded first, so the
    resulting index never depends on previous requests. When
    ``additional_documents`` contains no non-empty line, the index is simply
    restored to the default documents.

    Args:
        openqa: The ``RealmForOpenQA`` model whose ``retriever.block_records``,
            ``block_emb`` and ``config.num_block_records`` are updated in place.
        additional_documents: Newline-separated documents, one per line.
            May be empty or ``None``.
    """
    documents = _parse_documents(additional_documents)

    retriever = openqa.retriever
    tokenizer = retriever.tokenizer

    # Slice off whatever a previous call appended.
    base_records = retriever.block_records[:default_num_block_records]
    base_emb = openqa.block_emb[:default_num_block_records]

    if not documents:
        retriever.block_records = base_records
        openqa.block_emb = base_emb
        openqa.config.num_block_records = default_num_block_records
        return

    # docs: the existing records are concatenated with the new ones as bytes.
    np_documents = np.array([doc.encode() for doc in documents], dtype=object)
    retriever.block_records = np.concatenate((base_records, np_documents), axis=0)

    # embeds: project the new documents with the model's own embedder.
    inputs = tokenizer(documents, padding=True, truncation=True, return_tensors="pt").to(device)
    with torch.no_grad():
        projected_score = openqa.embedder(**inputs, return_dict=True).projected_score
    # Keep the new rows on the same device as the existing embedding matrix so
    # torch.cat cannot fail on a device mismatch.
    projected_score = projected_score.to(base_emb.device)
    openqa.block_emb = torch.cat((base_emb, projected_score), dim=0)

    openqa.config.num_block_records = default_num_block_records + len(documents)


def question_answer(question, additional_documents):
    """Answer ``question`` using the default index plus optional documents.

    Args:
        question: The question text.
        additional_documents: Newline-separated extra documents. Blank means
            "use only the default documents".

    Returns:
        The decoded text of the model's predicted answer token ids.
    """
    # Always resync the index: the model is shared global state, so skipping
    # this for blank input would leave a previous request's documents in place.
    add_additional_documents(openqa, additional_documents)

    question_ids = tokenizer(question, return_tensors="pt").input_ids
    with torch.no_grad():
        outputs = openqa(input_ids=question_ids.to(device), return_dict=True)

    return tokenizer.decode(outputs.predicted_answer_ids)


# NOTE(review): ``gr.inputs.Textbox`` and ``launch(enable_queue=...)`` belong to
# the older Gradio API; kept unchanged because the installed Gradio version is
# not visible from this file. Confirm before upgrading Gradio.
additional_documents_input = gr.inputs.Textbox(
    lines=5,
    placeholder="Each line represents a document entry. Leave blank to use default wiki documents.",
)

iface = gr.Interface(
    fn=question_answer,
    inputs=["text", additional_documents_input],
    outputs=["textbox"],
    allow_flagging="never",
)
iface.launch(enable_queue=True)