|
import shutil |
|
|
|
from haystack.document_stores import FAISSDocumentStore |
|
from haystack.nodes import EmbeddingRetriever |
|
from haystack.pipelines import Pipeline |
|
import streamlit as st |
|
|
|
from app_utils.entailment_checker import EntailmentChecker |
|
from app_utils.config import ( |
|
STATEMENTS_PATH, |
|
INDEX_DIR, |
|
RETRIEVER_MODEL, |
|
RETRIEVER_MODEL_FORMAT, |
|
NLI_MODEL, |
|
) |
|
|
|
|
|
@st.cache()
def load_statements():
    """Load the example statements from ``STATEMENTS_PATH``.

    The file is read as UTF-8 text, one statement per line. Surrounding
    whitespace is removed from every line; lines that are then empty, or
    that start with ``#`` (comments), are skipped so callers never receive
    an empty or comment-only statement.

    Returns:
        list[str]: The remaining statements, in file order.
    """
    # Explicit encoding: the default used by open() depends on the locale,
    # which would make decoding differ between machines.
    with open(STATEMENTS_PATH, encoding="utf-8") as fin:
        stripped_lines = (line.strip() for line in fin)
        return [
            text for text in stripped_lines if text and not text.startswith("#")
        ]
|
|
|
|
|
|
|
@st.cache(
    allow_output_mutation=True,
    hash_funcs={"builtins.SwigPyObject": lambda _: None},
)
def start_haystack():
    """Assemble the retrieval and entailment-checking pipeline.

    Copies the document store's ``.db`` file from ``INDEX_DIR`` into the
    current working directory, loads the FAISS index and its config from
    ``INDEX_DIR``, prints the number of indexed documents, and chains an
    embedding retriever into an entailment checker.

    Returns:
        A ``Pipeline`` with two nodes: "retriever" (input "Query") and
        "ec" (input "retriever").
    """
    # Copy the .db file before building the store; presumably the saved
    # index config refers to it relative to the working directory
    # -- TODO confirm.
    db_file = f"{INDEX_DIR}/faiss_document_store.db"
    shutil.copy(db_file, ".")

    index_file = f"{INDEX_DIR}/my_faiss_index.faiss"
    config_file = f"{INDEX_DIR}/my_faiss_index.json"
    store = FAISSDocumentStore(
        faiss_index_path=index_file,
        faiss_config_path=config_file,
    )
    doc_total = store.get_document_count()
    print(f"Index size: {doc_total}")

    # (node name, component, upstream inputs), in wiring order.
    nodes = (
        (
            "retriever",
            EmbeddingRetriever(
                document_store=store,
                embedding_model=RETRIEVER_MODEL,
                model_format=RETRIEVER_MODEL_FORMAT,
            ),
            ["Query"],
        ),
        (
            "ec",
            EntailmentChecker(
                model_name_or_path=NLI_MODEL,
                use_gpu=False,
                entailment_contradiction_threshold=0.5,
            ),
            ["retriever"],
        ),
    )

    pipeline = Pipeline()
    for node_name, component, upstream in nodes:
        pipeline.add_node(component=component, name=node_name, inputs=upstream)
    return pipeline
|
|
|
|
|
# Module-level pipeline instance, built when this module is imported.
# query() reads this name; start_haystack() is wrapped in st.cache.
pipe = start_haystack()
|
|
|
|
|
|
|
@st.cache(allow_output_mutation=True)
def query(statement: str, retriever_top_k: int = 5):
    """Run a statement through the module-level pipeline.

    Args:
        statement: Text handed to ``pipe.run`` as its first argument.
        retriever_top_k: Value forwarded as ``top_k`` to the "retriever"
            node.

    Returns:
        Whatever ``pipe.run`` returns for this statement and parameters.
    """
    retriever_settings = {"top_k": retriever_top_k}
    return pipe.run(statement, params={"retriever": retriever_settings})
|
|