Spaces:
Running
Running
File size: 2,994 Bytes
4c2a969 5b26a96 1434337 5b26a96 4c2a969 5b26a96 4c2a969 1434337 5b26a96 1434337 4c2a969 1434337 4c2a969 a147158 4c2a969 1434337 4c2a969 1434337 4c2a969 1434337 4c2a969 4c41de2 4c2a969 5b26a96 1434337 5b26a96 1434337 4c2a969 4c5a7dc 5b26a96 35f0167 4c2a969 4c41de2 5b26a96 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import shutil
from typing import List
from haystack import Document
from haystack.document_stores import FAISSDocumentStore
from haystack.nodes import EmbeddingRetriever, PromptNode
from haystack.pipelines import Pipeline
import streamlit as st
from app_utils.entailment_checker import EntailmentChecker
from app_utils.config import (
STATEMENTS_PATH,
INDEX_DIR,
RETRIEVER_MODEL,
RETRIEVER_MODEL_FORMAT,
NLI_MODEL,
PROMPT_MODEL,
)
@st.cache()
def load_statements():
    """Return the lines of the statements file, stripped, skipping "#" comment lines."""
    statements = []
    with open(STATEMENTS_PATH) as statements_file:
        for raw_line in statements_file:
            # a line whose very first character is "#" is treated as a comment
            if raw_line.startswith("#"):
                continue
            statements.append(raw_line.strip())
    return statements
# cached to make index and models load only at start
@st.cache(
    hash_funcs={"builtins.SwigPyObject": lambda _: None}, allow_output_mutation=True
)
def start_haystack():
    """
    Create the objects the app works with and return them.

    Opens the FAISS document store from INDEX_DIR, builds an embedding
    retriever and an entailment checker, and chains them in a pipeline
    ("Query" -> "retriever" -> "ec"). A prompt node is created separately.

    Returns a tuple (pipeline, prompt_node).
    """
    # the database file is copied into the current directory before the
    # document store is opened
    shutil.copy(f"{INDEX_DIR}/faiss_document_store.db", ".")

    index_file = f"{INDEX_DIR}/my_faiss_index.faiss"
    config_file = f"{INDEX_DIR}/my_faiss_index.json"
    store = FAISSDocumentStore(
        faiss_index_path=index_file,
        faiss_config_path=config_file,
    )
    print(f"Index size: {store.get_document_count()}")

    embedding_retriever = EmbeddingRetriever(
        document_store=store,
        embedding_model=RETRIEVER_MODEL,
        model_format=RETRIEVER_MODEL_FORMAT,
    )
    checker = EntailmentChecker(
        model_name_or_path=NLI_MODEL,
        use_gpu=False,
        entailment_contradiction_threshold=0.5,
    )

    # retrieval feeds the entailment checker
    pipeline = Pipeline()
    pipeline.add_node(component=embedding_retriever, name="retriever", inputs=["Query"])
    pipeline.add_node(component=checker, name="ec", inputs=["retriever"])

    llm_node = PromptNode(model_name_or_path=PROMPT_MODEL, max_length=150)
    return pipeline, llm_node
# module-level pipeline and prompt node used by the functions below;
# start_haystack is wrapped in st.cache
pipe, prompt_node = start_haystack()
# the pipeline is not included as parameter of the following function,
# because it is difficult to cache
@st.cache(allow_output_mutation=True)
def check_statement(statement: str, retriever_top_k: int = 5):
    """Run the module-level pipeline on `statement` and return its output.

    `retriever_top_k` is passed to the "retriever" node as its `top_k` parameter.
    """
    return pipe.run(statement, params={"retriever": {"top_k": retriever_top_k}})
@st.cache(
    hash_funcs={"tokenizers.Tokenizer": lambda _: None}, allow_output_mutation=True
)
def explain_using_llm(
    statement: str, documents: List[Document], entailment_or_contradiction: str
) -> str:
    """Explain entailment/contradiction, by prompting a LLM.

    Args:
        statement: the hypothesis to be explained.
        documents: documents whose contents are joined to form the premise.
        entailment_or_contradiction: either "entailment" or "contradiction";
            selects the verb used in the prompt.

    Returns:
        The first element of the prompt node's output for the built prompt.

    Raises:
        ValueError: if `entailment_or_contradiction` is any other value
            (previously this surfaced as an UnboundLocalError on `verb`).
    """
    if entailment_or_contradiction == "entailment":
        verb = "entails"
    elif entailment_or_contradiction == "contradiction":
        verb = "contradicts"
    else:
        raise ValueError(
            "entailment_or_contradiction must be 'entailment' or 'contradiction', "
            f"got {entailment_or_contradiction!r}"
        )

    # newlines inside a document are flattened so each document stays on one line
    premise = " \n".join([doc.content.replace("\n", ". ") for doc in documents])
    prompt = f"Premise: {premise}; Hypothesis: {statement}; Please explain in detail why the Premise {verb} the Hypothesis. Step by step Explanation:"
    print(prompt)
    return prompt_node(prompt)[0]
|