# lfqa.py -- long-form question answering (LFQA) with Haystack
from haystack.document_stores import FAISSDocumentStore
from haystack.nodes import DensePassageRetriever, Seq2SeqGenerator
from haystack.pipelines import DocumentSearchPipeline, GenerativeQAPipeline
from haystack.utils import clean_wiki_text, convert_files_to_docs, fetch_archive_from_http, print_answers, print_documents
# %% Save/Load FAISS and embeddings
# Try out this script. Make sure you have deleted any old saves of the document
# store first, including the file called faiss_document_store.db that is saved
# and loaded by default.
# # Convert files to documents (EmbeddingRetriever would need to be imported from haystack.nodes)
# docs = convert_files_to_docs(dir_path=doc_dir, clean_func=clean_wiki_text, split_paragraphs=True)[:10]
# document_store = FAISSDocumentStore(faiss_index_factory_str="Flat", vector_dim=128)
# # document_store = FAISSDocumentStore(sql_url="sqlite:///faiss_document_store.db")
# retriever = EmbeddingRetriever(document_store=document_store,
#                                embedding_model="yjernite/retribert-base-uncased",
#                                model_format="retribert",
#                                use_gpu=False)
# # Now, let's write the documents to our DB.
# document_store.write_documents(docs)
# document_store.update_embeddings(retriever)
# document_store.save("my_faiss_index.faiss")
# new_document_store = FAISSDocumentStore.load("my_faiss_index.faiss")
# # new_document_store = FAISSDocumentStore.load(faiss_file_path="testfile_path", sql_url="sqlite:///faiss_document_store.db")
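# %% Building "faiss_index.faiss" (sketch)
# A minimal, hypothetical sketch of how the "faiss_index.faiss" file loaded in
# prepare() below could be built: embed the documents with the same DPR models
# so the index and retriever dimensions match. The "data/wiki" path is an
# assumption, not part of this repo.
# docs = convert_files_to_docs(dir_path="data/wiki", clean_func=clean_wiki_text, split_paragraphs=True)
# document_store = FAISSDocumentStore(faiss_index_factory_str="Flat")  # default 768-dim, matches DPR
# document_store.write_documents(docs)
# document_store.update_embeddings(DensePassageRetriever(
#     document_store=document_store,
#     query_embedding_model="vblagoje/dpr-question_encoder-single-lfqa-wiki",
#     passage_embedding_model="vblagoje/dpr-ctx_encoder-single-lfqa-wiki",
#     use_gpu=False,
# ))
# document_store.save("faiss_index.faiss")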
# %% ------------------------------------------------------------------------------------------------------------
def prepare():
    # %% Document Store
    document_store = FAISSDocumentStore.load("faiss_index.faiss")

    # %% Initialize Retriever and Reader/Generator
    # Retriever (DPR)
    retriever = DensePassageRetriever(
        document_store=document_store,
        query_embedding_model="vblagoje/dpr-question_encoder-single-lfqa-wiki",
        passage_embedding_model="vblagoje/dpr-ctx_encoder-single-lfqa-wiki",
        use_gpu=False,
    )

    # # Test DPR
    # p_retrieval = DocumentSearchPipeline(retriever)
    # res = p_retrieval.run(query="Tell me something about Arya Stark?", params={"Retriever": {"top_k": 5}})
    # print_documents(res, max_text_len=512)

    # Reader/Generator
    # Here we use a Seq2SeqGenerator with the vblagoje/bart_lfqa model (https://huggingface.co/vblagoje/bart_lfqa)
    generator = Seq2SeqGenerator(model_name_or_path="vblagoje/bart_lfqa", use_gpu=False)

    # %% Pipeline
    pipe = GenerativeQAPipeline(generator, retriever)
    return pipe
def answer(pipe, question, k_retriever=3):
    res = pipe.run(query=question, params={"Retriever": {"top_k": k_retriever}})
    # # Question
    # pipe.run(
    #     query="How did Arya Stark's character get portrayed in a television adaptation?",
    #     params={"Retriever": {"top_k": 3}},
    # )
    # # Answer
    # res = pipe.run(query="Why is Arya Stark an unusual character?", params={"Retriever": {"top_k": 3}})
    return res
if __name__ == '__main__':
    question = 'Tell me something about Arya Stark?'
    pipe = prepare()
    res = answer(pipe, question)
    print_answers(res, details="minimum")
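    # A minimal sketch (assuming Haystack 1.x's result schema, where
    # res["answers"] is a list of Answer objects) for reading the generated
    # text directly instead of pretty-printing it:
    # for ans in res["answers"]:
    #     print(ans.answer)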