import datasets
import faiss
import numpy as np
import streamlit as st
import torch
from elasticsearch import Elasticsearch
from eli5_utils import (
    embed_questions_for_retrieval,
    make_qa_s2s_model,
    qa_s2s_generate,
    query_es_index,
    query_qa_dense_index,
)

import transformers
from transformers import AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer


MODEL_TYPE = "bart"
LOAD_DENSE_INDEX = True


@st.cache(allow_output_mutation=True)
def load_models():
    """Load the question-retrieval encoder and the seq2seq answer generator.

    Returns:
        A ``(qar_tokenizer, qar_model, s2s_tokenizer, s2s_model)`` tuple. The
        first two entries are ``None`` when ``LOAD_DENSE_INDEX`` is false.
        Models instantiated directly in this function are moved to ``cuda:0``
        and switched to eval mode; the non-BART generator is built entirely by
        ``make_qa_s2s_model``.
    """
    device = "cuda:0"

    # Dense retriever: only needed when the dense index is enabled.
    qar_tokenizer = qar_model = None
    if LOAD_DENSE_INDEX:
        qar_tokenizer = AutoTokenizer.from_pretrained("yjernite/retribert-base-uncased")
        qar_model = AutoModel.from_pretrained("yjernite/retribert-base-uncased").to(device)
        qar_model.eval()

    # Generator: anything other than BART is delegated to the helper.
    if MODEL_TYPE != "bart":
        s2s_tokenizer, s2s_model = make_qa_s2s_model(
            model_name="t5-small", from_file="seq2seq_models/eli5_t5_model_1024_4.pth", device=device
        )
        return (qar_tokenizer, qar_model, s2s_tokenizer, s2s_model)

    s2s_tokenizer = AutoTokenizer.from_pretrained("yjernite/bart_eli5")
    s2s_model = AutoModelForSeq2SeqLM.from_pretrained("yjernite/bart_eli5").to(device)
    # The hub weights are overwritten by the locally saved checkpoint.
    checkpoint = torch.load("seq2seq_models/eli5_bart_model_blm_2.pth")
    s2s_model.load_state_dict(checkpoint["model"])
    s2s_model.eval()
    return (qar_tokenizer, qar_model, s2s_tokenizer, s2s_model)


@st.cache(allow_output_mutation=True)
def load_indexes():
    """Open the Wikipedia passage store, its dense index, and the Elasticsearch client.

    Returns:
        A ``(wiki40b_passages, wiki40b_gpu_index_flat, es_client)`` tuple. The
        first two entries are ``None`` when ``LOAD_DENSE_INDEX`` is false; the
        Elasticsearch client is always created.
    """
    embedding_dim = 128
    passages = None
    gpu_index = None

    if LOAD_DENSE_INDEX:
        gpu_resources = faiss.StandardGpuResources()
        passages = datasets.load_dataset(path="wiki_snippets", name="wiki40b_en_100_0")["train"]
        # Passage embeddings are memory-mapped read-only, one row per passage.
        passage_reps = np.memmap(
            "wiki40b_passages_reps_32_l-8_h-768_b-512-512.dat",
            dtype="float32",
            mode="r",
            shape=(passages.num_rows, embedding_dim),
        )
        cpu_index = faiss.IndexFlatIP(embedding_dim)
        gpu_index = faiss.index_cpu_to_gpu(gpu_resources, 1, cpu_index)
        gpu_index.add(passage_reps)  # TODO fix for larger GPU

    es_client = Elasticsearch([{"host": "localhost", "port": "9200"}])
    return (passages, gpu_index, es_client)
@st.cache(allow_output_mutation=True)
def load_train_data():
    """Load the ELI5 training split and an index over its question embeddings.

    Returns:
        A ``(eli5_train, eli5_train_q_index)`` tuple: the ``train_eli5`` split
        of the ``eli5`` dataset and a ``faiss.IndexFlatIP`` filled from the
        memory-mapped ``eli5_questions_reps.dat`` file (128 float32 values per
        training question).
    """
    eli5 = datasets.load_dataset("eli5", name="LFQA_reddit")
    eli5_train = eli5["train_eli5"]
    eli5_train_q_reps = np.memmap(
        "eli5_questions_reps.dat", dtype="float32", mode="r", shape=(eli5_train.num_rows, 128)
    )
    eli5_train_q_index = faiss.IndexFlatIP(128)
    eli5_train_q_index.add(eli5_train_q_reps)
    return (eli5_train, eli5_train_q_index)


# Shared resources used by the helpers below; each loader is wrapped in st.cache.
passages, gpu_dense_index, es_client = load_indexes()
qar_tokenizer, qar_model, s2s_tokenizer, s2s_model = load_models()
eli5_train, eli5_train_q_index = load_train_data()


def find_nearest_training(question, n_results=10):
    """Return the ELI5 training examples whose questions are closest to ``question``.

    Args:
        question: Question text to embed with the retrieval model.
        n_results: Number of neighbours requested from the question index.

    Returns:
        A list of rows from ``eli5_train``, in the order returned by the index search.
    """
    q_rep = embed_questions_for_retrieval([question], qar_tokenizer, qar_model)
    # Only the neighbour ids are needed; the similarity scores are discarded.
    _, neighbor_ids = eli5_train_q_index.search(q_rep, n_results)
    return [eli5_train[int(i)] for i in neighbor_ids[0]]


def make_support(question, source="wiki40b", method="dense", n_results=10):
    """Build the generator input for ``question`` together with its supporting passages.

    Args:
        question: Question text.
        source: ``"none"`` skips retrieval; any other value queries an index.
        method: ``"dense"`` uses the dense index; any other value uses Elasticsearch.
        n_results: Number of passages to retrieve.

    Returns:
        A ``(question_doc, support_list)`` tuple. ``question_doc`` is the string
        ``"question: ... context: ..."`` and ``support_list`` holds one
        ``(article_title, section_title, score, passage_text)`` tuple per hit.
    """
    if source == "none":
        # NOTE(review): this separator literal was split across a physical line
        # break in the source. It is kept as a newline, which makes the whole
        # expression evaluate to "". A non-whitespace passage marker may have
        # been lost here; confirm against the upstream file.
        support_doc, hit_lst = ("\n".join([""] * 11).strip(), [])
    elif method == "dense":
        support_doc, hit_lst = query_qa_dense_index(
            question, qar_model, qar_tokenizer, passages, gpu_dense_index, n_results
        )
    else:
        support_doc, hit_lst = query_es_index(
            question,
            es_client,
            index_name="english_wiki40b_snippets_100w",
            n_results=n_results,
        )
    support_list = [
        (res["article_title"], res["section_title"].strip(), res["score"], res["passage_text"]) for res in hit_lst
    ]
    question_doc = "question: {} context: {}".format(question, support_doc)
    return question_doc, support_list


# The hash functions return a constant, so every tensor and every BART
# tokenizer hashes identically as far as st.cache is concerned.
@st.cache(
    hash_funcs={
        torch.Tensor: (lambda _: None),
        transformers.models.bart.tokenization_bart.BartTokenizer: (lambda _: None),
    }
)
def answer_question(
    question_doc, s2s_model, s2s_tokenizer, min_len=64, max_len=256, sampling=False, n_beams=2, top_p=0.95, temp=0.8
):
    """Generate one answer for ``question_doc`` with the seq2seq model.

    Args:
        question_doc: Generator input as produced by ``make_support``.
        s2s_model: Sequence-to-sequence model passed to ``qa_s2s_generate``.
        s2s_tokenizer: Tokenizer matching ``s2s_model``.
        min_len: Forwarded as ``min_len``.
        max_len: Forwarded as ``max_len``.
        sampling: Forwarded as ``do_sample``.
        n_beams: Forwarded as ``num_beams``.
        top_p: Forwarded as ``top_p``.
        temp: Forwarded as ``temp``.

    Returns:
        A ``(answer, support_list)`` tuple, where ``answer`` is the first
        element of the ``qa_s2s_generate`` result.
    """
    with torch.no_grad():
        answer = qa_s2s_generate(
            question_doc,
            s2s_model,
            s2s_tokenizer,
            num_answers=1,
            num_beams=n_beams,
            min_len=min_len,
            max_len=max_len,
            do_sample=sampling,
            temp=temp,
            top_p=top_p,
            top_k=None,
            max_input_length=1024,
            device="cuda:0",
        )[0]
    # NOTE(review): `support_list` is not a parameter; it resolves to the
    # module-level name assigned by the main script before this call.
    return (answer, support_list)


st.title("Long Form Question Answering with ELI5")

# Start sidebar
# NOTE(review): `header_html` is empty and the template holds only "%s";
# markup may have been lost from both. Confirm against the upstream file.
header_html = ""
header_full = "\n%s " % (header_html,)
st.sidebar.markdown(
    header_full,
    unsafe_allow_html=True,
)

# Long Form QA with ELI5 and Wikipedia
description = """ This demo presents a model trained to [provide long-form answers to open-domain questions](https://yjernite.github.io/lfqa.html). First, a document retriever fetches a set of relevant Wikipedia passages given the question from the [Wiki40b](https://research.google/pubs/pub49029/) dataset, a pre-processed fixed snapshot of Wikipedia. """
st.sidebar.markdown(description, unsafe_allow_html=True)

# The position in this list is the numeric `action` code used by the main script.
action_list = [
    "Answer the question",
    "View the retrieved document only",
    "View the most similar ELI5 question and answer",
    "Show me everything, please!",
]
demo_options = st.sidebar.checkbox("Demo options")
if demo_options:
    action_st = st.sidebar.selectbox(
        "",
        action_list,
        index=3,
    )
    action = action_list.index(action_st)
    show_type = st.sidebar.selectbox(
        "",
        ["Show full text of passages", "Show passage section titles"],
        index=0,
    )
    show_passages = show_type == "Show full text of passages"
else:
    action = 3
    show_passages = True

retrieval_options = st.sidebar.checkbox("Retrieval options")
if retrieval_options:
    # NOTE(review): the markdown below arrived as a single line, so the "###"
    # heading runs into the paragraph. Line breaks were probably collapsed;
    # confirm the intended layout before relying on the rendering.
    retriever_info = """ ### Information retriever options The **sparse** retriever uses ElasticSearch, while the **dense** retriever uses max-inner-product search between a question and passage embedding trained using the [ELI5](https://arxiv.org/abs/1907.09190) questions-answer pairs. The answer is then generated by sequence to sequence model which takes the question and retrieved document as input. \n"""
    st.sidebar.markdown(retriever_info)
    wiki_source = st.sidebar.selectbox("Which Wikipedia format should the model use?", ["wiki40b", "none"])
    index_type = st.sidebar.selectbox("Which Wikipedia indexer should the model use?", ["dense", "sparse", "mixed"])
else:
    wiki_source = "wiki40b"
    index_type = "dense"

# Generation defaults; overridden below only when the options panel is open.
sampled = "beam"
n_beams = 2
min_len = 64
max_len = 256
top_p = None
temp = None
generate_options = st.sidebar.checkbox("Generation options")
if generate_options:
    generate_info = """ ### Answer generation options The sequence-to-sequence model was initialized with [BART](https://huggingface.co/facebook/bart-large) weights and fine-tuned on the ELI5 QA pairs and retrieved documents. You can use the model for greedy decoding with **beam** search, or **sample** from the decoder's output probabilities. """
    st.sidebar.markdown(generate_info)
    sampled = st.sidebar.selectbox("Would you like to use beam search or sample an answer?", ["beam", "sampled"])
    min_len = st.sidebar.slider(
        "Minimum generation length", min_value=8, max_value=256, value=64, step=8, format=None, key=None
    )
    max_len = st.sidebar.slider(
        "Maximum generation length", min_value=64, max_value=512, value=256, step=16, format=None, key=None
    )
    if sampled == "beam":
        n_beams = st.sidebar.slider("Beam size", min_value=1, max_value=8, value=2, step=None, format=None, key=None)
    else:
        top_p = st.sidebar.slider(
            "Nucleus sampling p", min_value=0.1, max_value=1.0, value=0.95, step=0.01, format=None, key=None
        )
        temp = st.sidebar.slider(
            "Temperature", min_value=0.1, max_value=1.0, value=0.7, step=0.01, format=None, key=None
        )
        n_beams = None

# start main text
# NOTE(review): the statement below is a damaged fragment. The opening of the
# `questions_list` literal is fused with the tail of a later expression, and
# the code that originally sat between them (including the `if` that the
# following `else:` belongs to) is missing from this file. The closing bracket
# was added only so this section parses; restore the lost code from upstream.
questions_list = ["" + "\n".join([res[-1] for res in support_list])]
else:
question_doc, support_list = make_support(question, source=wiki_source, method=index_type, n_results=10)
if action in (0, 3):
    # Codes 0 and 3 are the two menu entries that include a generated answer.
    generation_kwargs = dict(
        min_len=min_len,
        max_len=int(max_len),
        sampling=(sampled == "sampled"),
        n_beams=n_beams,
        top_p=top_p,
        temp=temp,
    )
    answer, support_list = answer_question(question_doc, s2s_model, s2s_tokenizer, **generation_kwargs)
    st.markdown("### The model generated answer is:")
    st.write(answer)
if action in [0, 1, 3] and wiki_source != "none":
    from urllib.parse import quote

    st.markdown("--- \n ### The model is drawing information from the following Wikipedia passages:")
    # Each entry is (article_title, section_title, score, passage_text).
    for i, res in enumerate(support_list):
        article_title = res[0]
        # Percent-encode the page name so characters such as "?", "#" or "%"
        # in a title cannot change how the URL is parsed.
        wiki_url = "https://en.wikipedia.org/wiki/{}".format(quote(article_title.replace(" ", "_")))
        sec_titles = res[1].strip()
        if sec_titles == "":
            # No section information: link to the article itself.
            sections = "[{}]({})".format(article_title, wiki_url)
        else:
            # Several section titles are packed into one " & "-separated field.
            sec_list = sec_titles.split(" & ")
            sections = " & ".join(
                "[{}]({}#{})".format(sec.strip(), wiki_url, quote(sec.strip().replace(" ", "_")))
                for sec in sec_list
            )
        # NOTE(review): this format string was split across a physical line
        # break in the source and is restored with an explicit newline. A
        # single newline does not force a line break in Markdown, so an inline
        # break tag may have been lost; confirm against the upstream file.
        st.markdown(
            "{0:02d} - **Article**: {1:<18}\n_Section_: {2}".format(i + 1, article_title, sections),
            unsafe_allow_html=True,
        )
        if show_passages:
            st.write(
                '> ' + res[-1] + "", unsafe_allow_html=True
            )
if action in (2, 3):
    # Look up the training questions closest to the query and show the top one.
    nn_train_list = find_nearest_training(question)
    train_exple = nn_train_list[0]
    st.markdown(
        "--- \n ### The most similar question in the ELI5 training set was: \n\n {}".format(train_exple["title"])
    )

    answers_st = []
    scored_answers = zip(train_exple["answers"]["text"], train_exple["answers"]["score"])
    for position, (answer_text, answer_score) in enumerate(scored_answers):
        # The first answer is always shown; later ones only when their score exceeds 2.
        if not (position == 0 or answer_score > 2):
            continue
        # Drop blank lines and surrounding whitespace before re-joining.
        kept_lines = [raw.strip() for raw in answer_text.split("\n") if raw.strip() != ""]
        # The displayed number is the answer's original position, not a running count.
        answers_st.append("{}. {}".format(position + 1, " \n".join(kept_lines)))

    st.markdown("##### Its answers were: \n\n {}".format("\n".join(answers_st)))
# Closing caveat rendered at the bottom of the sidebar. Assembled line by line
# so each markdown line of the notice is visible on its own.
disclaimer = (
    "\n"
    "---\n"
    "**Disclaimer**\n"
    "*The intent of this app is to provide some (hopefully entertaining) insights into the behavior of a current LFQA system.\n"
    "Evaluating biases of such a model and ensuring factual generations are still very much open research problems.\n"
    "Therefore, until some significant progress is achieved, we caution against using the generated answers for practical purposes.*\n"
)
st.sidebar.markdown(disclaimer, unsafe_allow_html=True)