anakin87 committed
Commit: 5b26a96
Parent: 5fe5c67

add LLM explanation feature

Rock_fact_checker.py CHANGED

@@ -5,7 +5,7 @@ from json import JSONDecodeError
 
 import streamlit as st
 
-from app_utils.backend_utils import load_statements, query
+from app_utils.backend_utils import load_statements, check_statement, explain_using_llm
 from app_utils.frontend_utils import (
     set_state_if_absent,
     reset_results,
@@ -80,7 +80,7 @@ def main():
         st.session_state.statement = statement
         with st.spinner("🧠    Performing neural search on documents..."):
             try:
-                st.session_state.results = query(statement, RETRIEVER_TOP_K)
+                st.session_state.results = check_statement(statement, RETRIEVER_TOP_K)
                 print(f"S: {statement}")
                 time_end = time.time()
                 print(time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))
@@ -121,5 +121,12 @@ def main():
             str_wiki_pages += f"[{doc}]({url}) "
         st.markdown(str_wiki_pages)
 
+        if max_key != "neutral":
+            explanation = explain_using_llm(
+                statement=statement, documents=docs, entailment_or_contradiction=max_key
+            )
+            explanation = "#### Explanation 🧠 (experimental):\n" + explanation
+            st.markdown(explanation)
+
 
 main()
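
The added block references max_key and docs, which are defined earlier in main() and do not appear in this diff. Below is a minimal sketch of how they are plausibly derived from the pipeline output; the aggregate-score structure shown is an assumption for illustration, not part of this commit.

# Hypothetical sketch (names assumed, not from this commit): how docs and
# max_key could be computed before the new explanation block runs.
results = check_statement(statement, RETRIEVER_TOP_K)
docs = results["documents"]  # retrieved Haystack Documents
# the entailment checker is assumed to attach aggregate scores, e.g.:
agg = {"entailment": 0.8, "neutral": 0.1, "contradiction": 0.1}
max_key = max(agg, key=agg.get)  # label with the highest aggregate score
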
app_utils/backend_utils.py CHANGED

@@ -1,7 +1,9 @@
 import shutil
+from typing import List
 
+from haystack import Document
 from haystack.document_stores import FAISSDocumentStore
-from haystack.nodes import EmbeddingRetriever
+from haystack.nodes import EmbeddingRetriever, PromptNode
 from haystack.pipelines import Pipeline
 import streamlit as st
 
@@ -12,6 +14,7 @@ from app_utils.config import (
     RETRIEVER_MODEL,
     RETRIEVER_MODEL_FORMAT,
     NLI_MODEL,
+    PROMPT_MODEL,
 )
 
 
@@ -53,15 +56,37 @@ def start_haystack():
     pipe = Pipeline()
     pipe.add_node(component=retriever, name="retriever", inputs=["Query"])
     pipe.add_node(component=entailment_checker, name="ec", inputs=["retriever"])
-    return pipe
 
+    prompt_node = PromptNode(model_name_or_path=PROMPT_MODEL, max_length=150)
 
+    return pipe, prompt_node
+
+
-pipe = start_haystack()
+pipe, prompt_node = start_haystack()
 
 # the pipeline is not included as parameter of the following function,
 # because it is difficult to cache
 @st.cache(allow_output_mutation=True)
-def query(statement: str, retriever_top_k: int = 5):
+def check_statement(statement: str, retriever_top_k: int = 5):
     """Run query and verify statement"""
     params = {"retriever": {"top_k": retriever_top_k}}
     return pipe.run(statement, params=params)
+
+
+@st.cache(
+    hash_funcs={"tokenizers.Tokenizer": lambda _: None}, allow_output_mutation=True
+)
+def explain_using_llm(
+    statement: str, documents: List[Document], entailment_or_contradiction: str
+) -> str:
+    """Explain entailment/contradiction by prompting an LLM"""
+    premise = " \n".join([doc.content.replace("\n", ". ") for doc in documents])
+    if entailment_or_contradiction == "entailment":
+        verb = "entails"
+    elif entailment_or_contradiction == "contradiction":
+        verb = "contradicts"
+
+    prompt = f"Premise: {premise}; Hypothesis: {statement}; Please explain in detail why the Premise {verb} the Hypothesis. Step by step Explanation:"
+
+    print(prompt)
+    return prompt_node(prompt)[0]
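
Two things worth noting about explain_using_llm: verb stays unbound if the function is ever called with "neutral" (the if max_key != "neutral" guard in Rock_fact_checker.py prevents this), and calling a PromptNode returns a list of generated strings, hence the [0]. A minimal usage sketch with made-up statement and document text:

# Usage sketch; Document is the Haystack class imported in the diff above,
# and the content strings are purely illustrative.
from haystack import Document

docs = [Document(content="The Rolling Stones formed in London in 1962.")]
explanation = explain_using_llm(
    statement="The Rolling Stones are a British band.",
    documents=docs,
    entailment_or_contradiction="entailment",
)
print(explanation)  # free-text explanation generated by the prompt model
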
app_utils/config.py CHANGED

@@ -14,3 +14,12 @@ try:
 except:
     NLI_MODEL = "valhalla/distilbart-mnli-12-1"
 print(f"Used NLI model: {NLI_MODEL}")
+
+
+# In HF Space, we use google/flan-t5-large
+# for local testing, a smaller model is better
+try:
+    PROMPT_MODEL = st.secrets["PROMPT_MODEL"]
+except:
+    PROMPT_MODEL = "google/flan-t5-small"
+print(f"Used Prompt model: {PROMPT_MODEL}")