import datasets
import faiss
import numpy as np
import streamlit as st
import torch
from elasticsearch import Elasticsearch
from eli5_utils import (
    embed_questions_for_retrieval,
    make_qa_s2s_model,
    qa_s2s_generate,
    query_es_index,
    query_qa_dense_index,
)

import transformers
from transformers import AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer


MODEL_TYPE = "bart"
LOAD_DENSE_INDEX = True
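# MODEL_TYPE selects the answer generator: "bart" loads the fine-tuned
# yjernite/bart_eli5 checkpoint, while any other value falls back to a T5-small
# model restored through make_qa_s2s_model. LOAD_DENSE_INDEX toggles the
# RetriBERT question encoder and the FAISS passage index; when it is False only
# the sparse ElasticSearch retriever is available. The app expects a CUDA
# device, a local ElasticSearch server on port 9200, and the pre-computed
# .dat / .pth files referenced below.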
# Cache the heavy model loads across Streamlit reruns.
@st.cache(allow_output_mutation=True)
def load_models():
    if LOAD_DENSE_INDEX:
        qar_tokenizer = AutoTokenizer.from_pretrained("yjernite/retribert-base-uncased")
        qar_model = AutoModel.from_pretrained("yjernite/retribert-base-uncased").to("cuda:0")
        _ = qar_model.eval()
    else:
        qar_tokenizer, qar_model = (None, None)
    if MODEL_TYPE == "bart":
        s2s_tokenizer = AutoTokenizer.from_pretrained("yjernite/bart_eli5")
        s2s_model = AutoModelForSeq2SeqLM.from_pretrained("yjernite/bart_eli5").to("cuda:0")
        save_dict = torch.load("seq2seq_models/eli5_bart_model_blm_2.pth")
        s2s_model.load_state_dict(save_dict["model"])
        _ = s2s_model.eval()
    else:
        s2s_tokenizer, s2s_model = make_qa_s2s_model(
            model_name="t5-small", from_file="seq2seq_models/eli5_t5_model_1024_4.pth", device="cuda:0"
        )
    return (qar_tokenizer, qar_model, s2s_tokenizer, s2s_model)
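# load_models returns the RetriBERT question encoder (tokenizer, model) and the
# seq2seq answer generator (tokenizer, model). On the BART path, the
# yjernite/bart_eli5 weights are further overwritten with a locally fine-tuned
# checkpoint before the model is put in eval mode.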
# Cache the index construction as well; the FAISS index and ES client are
# reused across reruns.
@st.cache(allow_output_mutation=True)
def load_indexes():
    if LOAD_DENSE_INDEX:
        faiss_res = faiss.StandardGpuResources()
        wiki40b_passages = datasets.load_dataset(path="wiki_snippets", name="wiki40b_en_100_0")["train"]
        wiki40b_passage_reps = np.memmap(
            "wiki40b_passages_reps_32_l-8_h-768_b-512-512.dat",
            dtype="float32",
            mode="r",
            shape=(wiki40b_passages.num_rows, 128),
        )
        wiki40b_index_flat = faiss.IndexFlatIP(128)
        wiki40b_gpu_index_flat = faiss.index_cpu_to_gpu(faiss_res, 1, wiki40b_index_flat)
        wiki40b_gpu_index_flat.add(wiki40b_passage_reps)  # TODO fix for larger GPU
    else:
        wiki40b_passages, wiki40b_gpu_index_flat = (None, None)
    es_client = Elasticsearch([{"host": "localhost", "port": "9200"}])
    return (wiki40b_passages, wiki40b_gpu_index_flat, es_client)
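# The dense index keeps one 128-dimensional float32 vector per wiki40b snippet
# in a read-only memory map and serves them from a flat inner-product FAISS
# index on the GPU; ElasticSearch provides the keyword-based alternative over
# the same snippets.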
@st.cache(allow_output_mutation=True)
def load_train_data():
    eli5 = datasets.load_dataset("eli5", name="LFQA_reddit")
    eli5_train = eli5["train_eli5"]
    eli5_train_q_reps = np.memmap(
        "eli5_questions_reps.dat", dtype="float32", mode="r", shape=(eli5_train.num_rows, 128)
    )
    eli5_train_q_index = faiss.IndexFlatIP(128)
    eli5_train_q_index.add(eli5_train_q_reps)
    return (eli5_train, eli5_train_q_index)
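# The ELI5 training questions get the same treatment as the passages: each one
# has a pre-computed 128-dimensional embedding, so the closest training
# question to a new query can be found with a single inner-product search.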
passages, gpu_dense_index, es_client = load_indexes()
qar_tokenizer, qar_model, s2s_tokenizer, s2s_model = load_models()
eli5_train, eli5_train_q_index = load_train_data()
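# Streamlit re-executes this script on every widget interaction; the cached
# loaders above keep the models, indexes, and training data resident instead
# of reloading them on each rerun.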
def find_nearest_training(question, n_results=10):
    q_rep = embed_questions_for_retrieval([question], qar_tokenizer, qar_model)
    D, I = eli5_train_q_index.search(q_rep, n_results)
    nn_examples = [eli5_train[int(i)] for i in I[0]]
    return nn_examples
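# Illustrative use only (hypothetical question):
#   nearest = find_nearest_training("Why is the sky blue?", n_results=3)
#   nearest[0]["title"]  # title of the most similar ELI5 training question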
def make_support(question, source="wiki40b", method="dense", n_results=10):
    if source == "none":
        support_doc, hit_lst = (" <P> ".join(["" for _ in range(11)]).strip(), [])
    else:
        if method == "dense":
            support_doc, hit_lst = query_qa_dense_index(
                question, qar_model, qar_tokenizer, passages, gpu_dense_index, n_results
            )
        else:
            support_doc, hit_lst = query_es_index(
                question,
                es_client,
                index_name="english_wiki40b_snippets_100w",
                n_results=n_results,
            )
    support_list = [
        (res["article_title"], res["section_title"].strip(), res["score"], res["passage_text"]) for res in hit_lst
    ]
    question_doc = "question: {} context: {}".format(question, support_doc)
    return question_doc, support_list
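# make_support returns the support document in the "question: ... context: ..."
# format the seq2seq model was trained on, plus a list of
# (article_title, section_title, score, passage_text) tuples used later to
# render the sources. With source="none" the context contains only empty <P>
# separators, so the generator has to answer from its parameters alone.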
def answer_question(
    question_doc, s2s_model, s2s_tokenizer, min_len=64, max_len=256, sampling=False, n_beams=2, top_p=0.95, temp=0.8
):
    with torch.no_grad():
        answer = qa_s2s_generate(
            question_doc,
            s2s_model,
            s2s_tokenizer,
            num_answers=1,
            num_beams=n_beams,
            min_len=min_len,
            max_len=max_len,
            do_sample=sampling,
            temp=temp,
            top_p=top_p,
            top_k=None,
            max_input_length=1024,
            device="cuda:0",
        )[0]
    return answer
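# answer_question is a thin wrapper around qa_s2s_generate: it truncates the
# input to 1024 tokens, decodes a single answer on the GPU, and uses beam
# search when sampling is False or nucleus sampling (top_p, temp) when it is
# True.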
st.title("Long Form Question Answering with ELI5")

# Start sidebar
header_html = "<img src='https://huggingface.co/front/assets/huggingface_logo.svg'>"
header_full = """
<html>
  <head>
    <style>
      .img-container {
        padding-left: 90px;
        padding-right: 90px;
        padding-top: 50px;
        padding-bottom: 50px;
        background-color: #f0f3f9;
      }
    </style>
  </head>
  <body>
    <span class="img-container"> <!-- Inline parent element -->
      %s
    </span>
  </body>
</html>
""" % (
    header_html,
)
st.sidebar.markdown(
    header_full,
    unsafe_allow_html=True,
)
# Long Form QA with ELI5 and Wikipedia
description = """
This demo presents a model trained to [provide long-form answers to open-domain questions](https://yjernite.github.io/lfqa.html).
First, a document retriever fetches a set of Wikipedia passages relevant to the question from the [Wiki40b](https://research.google/pubs/pub49029/) dataset,
a pre-processed fixed snapshot of Wikipedia.
"""
st.sidebar.markdown(description, unsafe_allow_html=True)
action_list = [
    "Answer the question",
    "View the retrieved document only",
    "View the most similar ELI5 question and answer",
    "Show me everything, please!",
]
demo_options = st.sidebar.checkbox("Demo options")
if demo_options:
    action_st = st.sidebar.selectbox(
        "",
        action_list,
        index=3,
    )
    action = action_list.index(action_st)
    show_type = st.sidebar.selectbox(
        "",
        ["Show full text of passages", "Show passage section titles"],
        index=0,
    )
    show_passages = show_type == "Show full text of passages"
else:
    action = 3
    show_passages = True
retrieval_options = st.sidebar.checkbox("Retrieval options")
if retrieval_options:
    retriever_info = """
    ### Information retriever options
    The **sparse** retriever uses ElasticSearch, while the **dense** retriever uses max-inner-product search between question and passage embeddings
    trained using the [ELI5](https://arxiv.org/abs/1907.09190) question-answer pairs.
    The answer is then generated by a sequence-to-sequence model which takes the question and retrieved document as input.
    """
    st.sidebar.markdown(retriever_info)
    wiki_source = st.sidebar.selectbox("Which Wikipedia format should the model use?", ["wiki40b", "none"])
    index_type = st.sidebar.selectbox("Which Wikipedia indexer should the model use?", ["dense", "sparse", "mixed"])
else:
    wiki_source = "wiki40b"
    index_type = "dense"
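# index_type maps directly onto make_support below: "dense" queries the FAISS
# index with the RetriBERT question embedding, "sparse" queries ElasticSearch,
# and "mixed" interleaves the top results of both retrievers.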
sampled = "beam"
n_beams = 2
min_len = 64
max_len = 256
top_p = None
temp = None
generate_options = st.sidebar.checkbox("Generation options")
if generate_options:
    generate_info = """
    ### Answer generation options
    The sequence-to-sequence model was initialized with [BART](https://huggingface.co/facebook/bart-large)
    weights and fine-tuned on the ELI5 QA pairs and retrieved documents. You can decode the answer with
    **beam** search (greedy when the beam size is 1), or **sample** from the decoder's output probabilities.
    """
    st.sidebar.markdown(generate_info)
    sampled = st.sidebar.selectbox("Would you like to use beam search or sample an answer?", ["beam", "sampled"])
    min_len = st.sidebar.slider(
        "Minimum generation length", min_value=8, max_value=256, value=64, step=8, format=None, key=None
    )
    max_len = st.sidebar.slider(
        "Maximum generation length", min_value=64, max_value=512, value=256, step=16, format=None, key=None
    )
    if sampled == "beam":
        n_beams = st.sidebar.slider("Beam size", min_value=1, max_value=8, value=2, step=None, format=None, key=None)
    else:
        top_p = st.sidebar.slider(
            "Nucleus sampling p", min_value=0.1, max_value=1.0, value=0.95, step=0.01, format=None, key=None
        )
        temp = st.sidebar.slider(
            "Temperature", min_value=0.1, max_value=1.0, value=0.7, step=0.01, format=None, key=None
        )
        n_beams = None
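# The sidebar selections translate one-to-one into answer_question arguments:
# "beam" keeps do_sample=False with the chosen beam size, while "sampled"
# switches to nucleus sampling controlled by top_p and temperature.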
# Start main text
questions_list = [
    "<MY QUESTION>",
    "How do people make chocolate?",
    "Why do we get a fever when we are sick?",
    "How can different animals perceive different colors?",
    "What is natural language processing?",
    "What's the best way to treat a sunburn?",
    "What exactly are vitamins ?",
    "How does nuclear energy provide electricity?",
    "What's the difference between viruses and bacteria?",
    "Why are flutes classified as woodwinds when most of them are made out of metal ?",
    "Why do people like drinking coffee even though it tastes so bad?",
    "What happens when wine ages? How does it make the wine taste better?",
    "If an animal is an herbivore, where does it get the protein that it needs to survive if it only eats grass?",
    "How can we set a date to the beginning or end of an artistic period? Doesn't the change happen gradually?",
    "How does New Zealand have so many large bird predators?",
]
question_s = st.selectbox(
    "What would you like to ask? ---- select <MY QUESTION> to enter a new query",
    questions_list,
    index=1,
)
if question_s == "<MY QUESTION>":
    question = st.text_input("Enter your question here:", "")
else:
    question = question_s
if st.button("Show me!"):
    if action in [0, 1, 3]:
        if index_type == "mixed":
            # Query both retrievers, interleave their hits (dense first),
            # drop duplicates, and keep the top 10 passages overall.
            _, support_list_dense = make_support(question, source=wiki_source, method="dense", n_results=10)
            _, support_list_sparse = make_support(question, source=wiki_source, method="sparse", n_results=10)
            support_list = []
            for res_d, res_s in zip(support_list_dense, support_list_sparse):
                if tuple(res_d) not in support_list:
                    support_list += [tuple(res_d)]
                if tuple(res_s) not in support_list:
                    support_list += [tuple(res_s)]
            support_list = support_list[:10]
            # Rebuild the input in the same "question: ... context: ..." format
            # that make_support produces for the single-retriever paths.
            question_doc = "question: {} context: {}".format(
                question, "<P> " + " <P> ".join([res[-1] for res in support_list])
            )
        else:
            question_doc, support_list = make_support(question, source=wiki_source, method=index_type, n_results=10)
    if action in [0, 3]:
        answer = answer_question(
            question_doc,
            s2s_model,
            s2s_tokenizer,
            min_len=min_len,
            max_len=int(max_len),
            sampling=(sampled == "sampled"),
            n_beams=n_beams,
            top_p=top_p,
            temp=temp,
        )
        st.markdown("### The model-generated answer is:")
        st.write(answer)
    if action in [0, 1, 3] and wiki_source != "none":
        st.markdown("--- \n ### The model is drawing information from the following Wikipedia passages:")
        for i, res in enumerate(support_list):
            wiki_url = "https://en.wikipedia.org/wiki/{}".format(res[0].replace(" ", "_"))
            sec_titles = res[1].strip()
            if sec_titles == "":
                sections = "[{}]({})".format(res[0], wiki_url)
            else:
                sec_list = sec_titles.split(" & ")
                sections = " & ".join(
                    ["[{}]({}#{})".format(sec.strip(), wiki_url, sec.strip().replace(" ", "_")) for sec in sec_list]
                )
            st.markdown(
                "{0:02d} - **Article**: {1:<18} <br> _Section_: {2}".format(i + 1, res[0], sections),
                unsafe_allow_html=True,
            )
            if show_passages:
                st.write(
                    '> <span style="font-family:arial; font-size:10pt;">' + res[-1] + "</span>", unsafe_allow_html=True
                )
    if action in [2, 3]:
        nn_train_list = find_nearest_training(question)
        train_exple = nn_train_list[0]
        st.markdown(
            "--- \n ### The most similar question in the ELI5 training set was: \n\n {}".format(train_exple["title"])
        )
        answers_st = [
            "{}. {}".format(i + 1, " \n".join([line.strip() for line in ans.split("\n") if line.strip() != ""]))
            for i, (ans, sc) in enumerate(zip(train_exple["answers"]["text"], train_exple["answers"]["score"]))
            if i == 0 or sc > 2
        ]
        st.markdown("##### Its answers were: \n\n {}".format("\n".join(answers_st)))
disclaimer = """
---
**Disclaimer**

*The intent of this app is to provide some (hopefully entertaining) insights into the behavior of a current LFQA system.
Evaluating biases of such a model and ensuring factual generations are still very much open research problems.
Therefore, until some significant progress is achieved, we caution against using the generated answers for practical purposes.*
"""
st.sidebar.markdown(disclaimer, unsafe_allow_html=True)