import datasets
import faiss
import numpy as np
import streamlit as st
import torch
from elasticsearch import Elasticsearch
from eli5_utils import (
    embed_questions_for_retrieval,
    make_qa_s2s_model,
    qa_s2s_generate,
    query_es_index,
    query_qa_dense_index,
)

import transformers
from transformers import AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer
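

# Configuration flags: MODEL_TYPE selects the answer generator ("bart" loads the
# fine-tuned BART checkpoint below; any other value falls back to the T5 model
# built by eli5_utils.make_qa_s2s_model). LOAD_DENSE_INDEX toggles the RetriBERT
# question encoder and the dense faiss index; when False, only the sparse
# ElasticSearch index is available.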
MODEL_TYPE = "bart"
LOAD_DENSE_INDEX = True


@st.cache(allow_output_mutation=True)
def load_models():
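    """Load the RetriBERT question encoder and the seq2seq answer generator.

    Both models are placed on GPU 0 and put in eval mode; st.cache keeps them
    in memory across Streamlit reruns.
    """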
    if LOAD_DENSE_INDEX:
        # RetriBERT question encoder used for dense retrieval
        qar_tokenizer = AutoTokenizer.from_pretrained("yjernite/retribert-base-uncased")
        qar_model = AutoModel.from_pretrained("yjernite/retribert-base-uncased").to("cuda:0")
        _ = qar_model.eval()
    else:
        qar_tokenizer, qar_model = (None, None)
    if MODEL_TYPE == "bart":
        # BART generator fine-tuned on ELI5, with weights restored from a local checkpoint
        s2s_tokenizer = AutoTokenizer.from_pretrained("yjernite/bart_eli5")
        s2s_model = AutoModelForSeq2SeqLM.from_pretrained("yjernite/bart_eli5").to("cuda:0")
        save_dict = torch.load("seq2seq_models/eli5_bart_model_blm_2.pth")
        s2s_model.load_state_dict(save_dict["model"])
        _ = s2s_model.eval()
    else:
        s2s_tokenizer, s2s_model = make_qa_s2s_model(
            model_name="t5-small", from_file="seq2seq_models/eli5_t5_model_1024_4.pth", device="cuda:0"
        )
    return (qar_tokenizer, qar_model, s2s_tokenizer, s2s_model)


@st.cache(allow_output_mutation=True)
def load_indexes():
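    """Load the wiki40b passages and their dense faiss index, plus an ElasticSearch client.

    Pre-computed 128-d passage embeddings are read from a memmap and added to a
    GPU-backed inner-product index for max-inner-product search.
    """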
    if LOAD_DENSE_INDEX:
        faiss_res = faiss.StandardGpuResources()
        wiki40b_passages = datasets.load_dataset(path="wiki_snippets", name="wiki40b_en_100_0")["train"]
        wiki40b_passage_reps = np.memmap(
            "wiki40b_passages_reps_32_l-8_h-768_b-512-512.dat",
            dtype="float32",
            mode="r",
            shape=(wiki40b_passages.num_rows, 128),
        )
        wiki40b_index_flat = faiss.IndexFlatIP(128)
        # The index is placed on GPU 1, leaving GPU 0 ("cuda:0") to the models
        wiki40b_gpu_index_flat = faiss.index_cpu_to_gpu(faiss_res, 1, wiki40b_index_flat)
        wiki40b_gpu_index_flat.add(wiki40b_passage_reps)
    else:
        wiki40b_passages, wiki40b_gpu_index_flat = (None, None)
    es_client = Elasticsearch([{"host": "localhost", "port": 9200}])
    return (wiki40b_passages, wiki40b_gpu_index_flat, es_client)


@st.cache(allow_output_mutation=True)
def load_train_data():
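    """Load the ELI5 training set and index its pre-computed question embeddings."""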
    eli5 = datasets.load_dataset("eli5", name="LFQA_reddit")
    eli5_train = eli5["train_eli5"]
    eli5_train_q_reps = np.memmap(
        "eli5_questions_reps.dat", dtype="float32", mode="r", shape=(eli5_train.num_rows, 128)
    )
    eli5_train_q_index = faiss.IndexFlatIP(128)
    eli5_train_q_index.add(eli5_train_q_reps)
    return (eli5_train, eli5_train_q_index)
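

# Load models, indexes, and training data once: st.cache memoizes these calls
# across Streamlit reruns, so user interactions do not trigger a reload.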
passages, gpu_dense_index, es_client = load_indexes()
qar_tokenizer, qar_model, s2s_tokenizer, s2s_model = load_models()
eli5_train, eli5_train_q_index = load_train_data()


def find_nearest_training(question, n_results=10):
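    """Return the n_results training examples whose questions are most similar to `question`."""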
    q_rep = embed_questions_for_retrieval([question], qar_tokenizer, qar_model)
    # D holds inner-product scores, I the indices of the nearest training questions
    D, I = eli5_train_q_index.search(q_rep, n_results)
    nn_examples = [eli5_train[int(i)] for i in I[0]]
    return nn_examples


def make_support(question, source="wiki40b", method="dense", n_results=10):
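    """Retrieve supporting passages for `question` and format them for the generator.

    Returns the "question: ... context: ..." input string for the seq2seq model
    and a list of (article_title, section_title, score, passage_text) hits.
    """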
    if source == "none":
        # No retrieval: feed the generator an empty support document
        support_doc, hit_lst = (" <P> ".join(["" for _ in range(11)]).strip(), [])
    else:
        if method == "dense":
            support_doc, hit_lst = query_qa_dense_index(
                question, qar_model, qar_tokenizer, passages, gpu_dense_index, n_results
            )
        else:
            support_doc, hit_lst = query_es_index(
                question,
                es_client,
                index_name="english_wiki40b_snippets_100w",
                n_results=n_results,
            )
    support_list = [
        (res["article_title"], res["section_title"].strip(), res["score"], res["passage_text"]) for res in hit_lst
    ]
    question_doc = "question: {} context: {}".format(question, support_doc)
    return question_doc, support_list


# The model and tokenizer arguments are not hashable, so tell st.cache to skip hashing them
@st.cache(
    hash_funcs={
        torch.Tensor: (lambda _: None),
        transformers.models.bart.tokenization_bart.BartTokenizer: (lambda _: None),
    }
)
def answer_question(
    question_doc, s2s_model, s2s_tokenizer, min_len=64, max_len=256, sampling=False, n_beams=2, top_p=0.95, temp=0.8
):
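    """Generate a single long-form answer for a prepared question document."""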
    with torch.no_grad():
        answer = qa_s2s_generate(
            question_doc,
            s2s_model,
            s2s_tokenizer,
            num_answers=1,
            num_beams=n_beams,
            min_len=min_len,
            max_len=max_len,
            do_sample=sampling,
            temp=temp,
            top_p=top_p,
            top_k=None,
            max_input_length=1024,
            device="cuda:0",
        )[0]
    return answer


st.title("Long Form Question Answering with ELI5")
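
# Sidebar header: the Hugging Face logo rendered through inline HTML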
header_html = "<img src='https://huggingface.co/front/assets/huggingface_logo.svg'>"
header_full = """
<html>
  <head>
    <style>
      .img-container {
        padding-left: 90px;
        padding-right: 90px;
        padding-top: 50px;
        padding-bottom: 50px;
        background-color: #f0f3f9;
      }
    </style>
  </head>
  <body>
    <span class="img-container"> <!-- Inline parent element -->
      %s
    </span>
  </body>
</html>
""" % (
    header_html,
)
st.sidebar.markdown(
    header_full,
    unsafe_allow_html=True,
)

description = """
This demo presents a model trained to [provide long-form answers to open-domain questions](https://yjernite.github.io/lfqa.html).
First, a document retriever fetches a set of Wikipedia passages relevant to the question from the [Wiki40b](https://research.google/pubs/pub49029/) dataset,
a pre-processed, fixed snapshot of Wikipedia.
"""
st.sidebar.markdown(description, unsafe_allow_html=True)
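
# Sidebar: choose what the app displays (answer, retrieved passages, nearest training examples)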
action_list = [
    "Answer the question",
    "View the retrieved document only",
    "View the most similar ELI5 question and answer",
    "Show me everything, please!",
]
demo_options = st.sidebar.checkbox("Demo options")
if demo_options:
    action_st = st.sidebar.selectbox(
        "",
        action_list,
        index=3,
    )
    action = action_list.index(action_st)
    show_type = st.sidebar.selectbox(
        "",
        ["Show full text of passages", "Show passage section titles"],
        index=0,
    )
    show_passages = show_type == "Show full text of passages"
else:
    action = 3
    show_passages = True
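
# Sidebar: retrieval options (sparse ElasticSearch vs dense faiss index)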
retrieval_options = st.sidebar.checkbox("Retrieval options")
if retrieval_options:
    retriever_info = """
    ### Information retriever options

    The **sparse** retriever uses ElasticSearch, while the **dense** retriever uses max-inner-product search between a question and passage embedding
    trained using the [ELI5](https://arxiv.org/abs/1907.09190) question-answer pairs.
    The answer is then generated by a sequence-to-sequence model which takes the question and retrieved document as input.
    """
    st.sidebar.markdown(retriever_info)
    wiki_source = st.sidebar.selectbox("Which Wikipedia format should the model use?", ["wiki40b", "none"])
    index_type = st.sidebar.selectbox("Which Wikipedia indexer should the model use?", ["dense", "sparse", "mixed"])
else:
    wiki_source = "wiki40b"
    index_type = "dense"
sampled = "beam"
n_beams = 2
min_len = 64
max_len = 256
top_p = None
temp = None
generate_options = st.sidebar.checkbox("Generation options")
if generate_options:
    generate_info = """
    ### Answer generation options

    The sequence-to-sequence model was initialized with [BART](https://huggingface.co/facebook/bart-large)
    weights and fine-tuned on the ELI5 QA pairs and retrieved documents. You can decode the answer with
    **beam** search, or **sample** from the decoder's output probabilities.
    """
    st.sidebar.markdown(generate_info)
    sampled = st.sidebar.selectbox("Would you like to use beam search or sample an answer?", ["beam", "sampled"])
    min_len = st.sidebar.slider(
        "Minimum generation length", min_value=8, max_value=256, value=64, step=8, format=None, key=None
    )
    max_len = st.sidebar.slider(
        "Maximum generation length", min_value=64, max_value=512, value=256, step=16, format=None, key=None
    )
    if sampled == "beam":
        n_beams = st.sidebar.slider("Beam size", min_value=1, max_value=8, value=2, step=None, format=None, key=None)
    else:
        top_p = st.sidebar.slider(
            "Nucleus sampling p", min_value=0.1, max_value=1.0, value=0.95, step=0.01, format=None, key=None
        )
        temp = st.sidebar.slider(
            "Temperature", min_value=0.1, max_value=1.0, value=0.7, step=0.01, format=None, key=None
        )
        n_beams = None
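

# Main panel: pick one of the sample questions or enter a new one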
questions_list = [
    "<MY QUESTION>",
    "How do people make chocolate?",
    "Why do we get a fever when we are sick?",
    "How can different animals perceive different colors?",
    "What is natural language processing?",
    "What's the best way to treat a sunburn?",
    "What exactly are vitamins?",
    "How does nuclear energy provide electricity?",
    "What's the difference between viruses and bacteria?",
    "Why are flutes classified as woodwinds when most of them are made out of metal?",
    "Why do people like drinking coffee even though it tastes so bad?",
    "What happens when wine ages? How does it make the wine taste better?",
    "If an animal is an herbivore, where does it get the protein that it needs to survive if it only eats grass?",
    "How can we set a date to the beginning or end of an artistic period? Doesn't the change happen gradually?",
    "How does New Zealand have so many large bird predators?",
]
question_s = st.selectbox(
    "What would you like to ask? ---- select <MY QUESTION> to enter a new query",
    questions_list,
    index=1,
)
if question_s == "<MY QUESTION>":
    question = st.text_input("Enter your question here:", "")
else:
    question = question_s
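
# Run retrieval and generation when the user presses the button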
if st.button("Show me!"):
    if action in [0, 1, 3]:
        if index_type == "mixed":
            # Interleave dense and sparse hits, de-duplicate, and keep the top 10
            _, support_list_dense = make_support(question, source=wiki_source, method="dense", n_results=10)
            _, support_list_sparse = make_support(question, source=wiki_source, method="sparse", n_results=10)
            support_list = []
            for res_d, res_s in zip(support_list_dense, support_list_sparse):
                if tuple(res_d) not in support_list:
                    support_list += [tuple(res_d)]
                if tuple(res_s) not in support_list:
                    support_list += [tuple(res_s)]
            support_list = support_list[:10]
            question_doc = "<P> " + " <P> ".join([res[-1] for res in support_list])
        else:
            question_doc, support_list = make_support(question, source=wiki_source, method=index_type, n_results=10)
    if action in [0, 3]:
        answer = answer_question(
            question_doc,
            s2s_model,
            s2s_tokenizer,
            min_len=min_len,
            max_len=int(max_len),
            sampling=(sampled == "sampled"),
            n_beams=n_beams,
            top_p=top_p,
            temp=temp,
        )
        st.markdown("### The model-generated answer is:")
        st.write(answer)
    if action in [0, 1, 3] and wiki_source != "none":
        st.markdown("--- \n ### The model is drawing information from the following Wikipedia passages:")
        for i, res in enumerate(support_list):
            # Link each passage back to its Wikipedia article and section anchors
            wiki_url = "https://en.wikipedia.org/wiki/{}".format(res[0].replace(" ", "_"))
            sec_titles = res[1].strip()
            if sec_titles == "":
                sections = "[{}]({})".format(res[0], wiki_url)
            else:
                sec_list = sec_titles.split(" & ")
                sections = " & ".join(
                    ["[{}]({}#{})".format(sec.strip(), wiki_url, sec.strip().replace(" ", "_")) for sec in sec_list]
                )
            st.markdown(
                "{0:02d} - **Article**: {1:<18} <br> _Section_: {2}".format(i + 1, res[0], sections),
                unsafe_allow_html=True,
            )
            if show_passages:
                st.write(
                    '> <span style="font-family:arial; font-size:10pt;">' + res[-1] + "</span>", unsafe_allow_html=True
                )
    if action in [2, 3]:
        nn_train_list = find_nearest_training(question)
        train_exple = nn_train_list[0]
        st.markdown(
            "--- \n ### The most similar question in the ELI5 training set was: \n\n {}".format(train_exple["title"])
        )
        # Show the top-rated answer, plus any other answer that scored above 2
        answers_st = [
            "{}. {}".format(i + 1, " \n".join([line.strip() for line in ans.split("\n") if line.strip() != ""]))
            for i, (ans, sc) in enumerate(zip(train_exple["answers"]["text"], train_exple["answers"]["score"]))
            if i == 0 or sc > 2
        ]
        st.markdown("##### Its answers were: \n\n {}".format("\n".join(answers_st)))


disclaimer = """
---

**Disclaimer**

*The intent of this app is to provide some (hopefully entertaining) insights into the behavior of a current LFQA system.
Evaluating biases of such a model and ensuring factual generations are still very much open research problems.
Therefore, until some significant progress is achieved, we caution against using the generated answers for practical purposes.*
"""
st.sidebar.markdown(disclaimer, unsafe_allow_html=True)