anakin87 committed on
Commit
4c2a969
1 Parent(s): d6bdb02

various improvements

Browse files
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: purple
5
  colorTo: blue
6
  sdk: streamlit
7
  sdk_version: 1.10.0
8
- app_file: rock_fact_checker.py
9
  pinned: false
10
  license: apache-2.0
11
  ---
 
5
  colorTo: blue
6
  sdk: streamlit
7
  sdk_version: 1.10.0
8
+ app_file: Rock_fact_checker.py
9
  pinned: false
10
  license: apache-2.0
11
  ---
Rock_fact_checker.py CHANGED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ import time
4
+ import streamlit as st
5
+ import logging
6
+ from json import JSONDecodeError
7
+ # from markdown import markdown
8
+ # from annotated_text import annotation
9
+ # from urllib.parse import unquote
10
+ import random
11
+
12
+ from app_utils.backend_utils import load_questions, query
13
+ from app_utils.frontend_utils import set_state_if_absent, reset_results
14
+ from app_utils.config import RETRIEVER_TOP_K
15
+
16
+
17
def _pick_random_question(questions, current):
    """Return a random entry of *questions* that differs from *current*.

    Falls back to *current* when no different entry exists (the list is
    empty, or every entry equals *current*), so the caller can neither loop
    forever nor hit ``random.choice`` on an empty sequence.
    """
    candidates = [q for q in questions if q != current]
    if not candidates:
        return current
    return random.choice(candidates)


def main():
    """Draw the fact-checking page and run the query for the entered statement.

    The statement comes from the text box or from the "Random question"
    button. A query is run when "Run" is pressed or the text differs from the
    stored question, unless a random question was just requested. The
    pipeline output is stored in ``st.session_state.results``; failures are
    reported on the page with ``st.error``.
    """
    questions = load_questions()

    # Persistent state
    set_state_if_absent('question', "Elvis Presley is alive")
    set_state_if_absent('answer', '')
    set_state_if_absent('results', None)
    set_state_if_absent('raw_json', None)
    set_state_if_absent('random_question_requested', False)

    ## MAIN CONTAINER
    st.write("# Fact checking 🎸 Rocks!")
    st.write()
    st.markdown("""
##### Enter a factual statement about [Rock music](https://en.wikipedia.org/wiki/List_of_mainstream_rock_performers) and let the AI check it out for you...
    """)
    # Search bar
    question = st.text_input("", value=st.session_state.question,
                             max_chars=100, on_change=reset_results)
    col1, col2 = st.columns(2)
    col1.markdown(
        "<style>.stButton button {width:100%;}</style>", unsafe_allow_html=True)
    col2.markdown(
        "<style>.stButton button {width:100%;}</style>", unsafe_allow_html=True)
    # Run button
    run_pressed = col1.button("Run")
    # Random question button
    if col2.button("Random question"):
        reset_results()
        # Avoid picking the same question twice (the change is not visible on the UI).
        # The helper filters the candidates instead of re-drawing in a loop, so a
        # statements file with no alternative entry cannot hang the app.
        question = _pick_random_question(questions, st.session_state.question)
        st.session_state.question = question
        st.session_state.random_question_requested = True
        # Re-runs the script setting the random question as the textbox value
        # Unfortunately necessary as the Random Question button is _below_ the textbox
        # raise st.script_runner.RerunException(
        #     st.script_request_queue.RerunData(None))
    else:
        st.session_state.random_question_requested = False
    run_query = (run_pressed or question != st.session_state.question) \
        and not st.session_state.random_question_requested

    # Get results for query
    if run_query and question:
        time_start = time.time()
        reset_results()
        st.session_state.question = question
        with st.spinner("🧠 &nbsp;&nbsp; Performing neural search on documents..."):
            try:
                st.session_state.results = query(
                    question, RETRIEVER_TOP_K)
                time_end = time.time()
                print(time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))
                print(f'elapsed time: {time_end - time_start}')
            except JSONDecodeError:
                st.error(
                    "👓 &nbsp;&nbsp; An error occurred reading the results. Is the document store working?")
                return
            except Exception as e:
                # Top-level UI boundary: log the traceback and show a generic message
                logging.exception(e)
                st.error("🐞 &nbsp;&nbsp; An error occurred during the request.")
                return
84
+
85
+ # # Display results
86
+ # if st.session_state.results:
87
+ # st.write("## Results:")
88
+ # alert_irrelevance = True
89
+ # if len(st.session_state.results['answers']) == 0:
90
+ # st.info("""🤔 &nbsp;&nbsp; Haystack is unsure whether any of
91
+ # the documents contain an answer to your question. Try to reformulate it!""")
92
+
93
+ # for result in st.session_state.results['answers']:
94
+ # result = result.to_dict()
95
+ # if result["answer"]:
96
+ # if alert_irrelevance and result['score'] < LOW_RELEVANCE_THRESHOLD:
97
+ # alert_irrelevance = False
98
+ # st.write("""
99
+ # <h4 style='color: darkred'>Attention, the
100
+ # following answers have low relevance:</h4>""",
101
+ # unsafe_allow_html=True)
102
+
103
+ # answer, context = result["answer"], result["context"]
104
+ # start_idx = context.find(answer)
105
+ # end_idx = start_idx + len(answer)
106
+ # # Hack due to this bug: https://github.com/streamlit/streamlit/issues/3190
107
+ # st.write(markdown("- ..."+context[:start_idx] +
108
+ # str(annotation(answer, "ANSWER", "#3e1c21", "white")) +
109
+ # context[end_idx:]+"..."), unsafe_allow_html=True)
110
+ # source = ""
111
+ # name = unquote(result['meta']['name']).replace('_', ' ')
112
+ # url = result['meta']['url']
113
+ # source = f"[{name}]({url})"
114
+ # st.markdown(
115
+ # f"**Score:** {result['score']:.2f} - **Source:** {source}")
116
+
117
+ main()
app_utils/backend_utils.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import shutil
2
+ from haystack.document_stores import FAISSDocumentStore
3
+ from haystack.nodes import EmbeddingRetriever
4
+ from haystack.pipelines import Pipeline
5
+
6
+ import streamlit as st
7
+
8
+ from app_utils.entailment_checker import EntailmentChecker
9
+
10
+ from app_utils.config import STATEMENTS_PATH, INDEX_DIR, RETRIEVER_MODEL, RETRIEVER_MODEL_FORMAT, NLI_MODEL
11
+
12
# Wrapped in st.cache so that the index and the models are loaded only at start
@st.cache(hash_funcs={"builtins.SwigPyObject": lambda _: None}, allow_output_mutation=True)
def start_haystack():
    """
    Load the FAISS document store, create the embedding retriever and the
    entailment checker, and return a pipeline chaining retriever -> checker.
    """
    # Copy the document-store database next to the app.
    # NOTE(review): presumably the FAISS config refers to this file relative
    # to the working directory -- confirm.
    shutil.copy(f'{INDEX_DIR}/faiss_document_store.db', '.')

    index_path = f'{INDEX_DIR}/my_faiss_index.faiss'
    config_path = f'{INDEX_DIR}/my_faiss_index.json'
    store = FAISSDocumentStore(faiss_index_path=index_path,
                               faiss_config_path=config_path)
    print(f'Index size: {store.get_document_count()}')

    embedding_retriever = EmbeddingRetriever(
        document_store=store,
        embedding_model=RETRIEVER_MODEL,
        model_format=RETRIEVER_MODEL_FORMAT,
    )
    checker = EntailmentChecker(model_name_or_path=NLI_MODEL, use_gpu=False)

    # "Query" -> retriever -> entailment checker ("ec")
    pipeline = Pipeline()
    pipeline.add_node(component=embedding_retriever, name="retriever",
                      inputs=["Query"])
    pipeline.add_node(component=checker, name="ec", inputs=["retriever"])
    return pipeline
38
+
39
+ pipe = start_haystack()
40
+ # the pipeline is not included as parameter of the following function,
41
+ # because it is difficult to cache
42
@st.cache(persist=True, allow_output_mutation=True)
def query(question: str, retriever_top_k: int = 5):
    """Run the module-level pipeline on ``question`` and return its output.

    ``retriever_top_k`` is forwarded as the ``top_k`` parameter of the
    "retriever" node. The raw output is also printed to stdout.
    """
    results = pipe.run(
        question,
        params={"retriever": {"top_k": retriever_top_k}},
    )
    print(results)
    return results
49
+
50
@st.cache()
def load_questions():
    """Load the example statements from ``STATEMENTS_PATH``.

    Returns:
        A list with one whitespace-stripped statement per line of the file.
        Blank lines and comment lines (first non-blank character ``#``) are
        skipped, so callers never receive an empty statement.
    """
    # Explicit encoding: do not depend on the platform default
    with open(STATEMENTS_PATH, encoding="utf-8") as fin:
        stripped_lines = (line.strip() for line in fin)
        return [line for line in stripped_lines
                if line and not line.startswith('#')]
57
+
58
+
app_utils/config.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+
2
# Locations of the index files and of the example statements
INDEX_DIR = "data/index"
STATEMENTS_PATH = "data/statements.txt"

# Retriever settings
RETRIEVER_MODEL = "sentence-transformers/msmarco-distilbert-base-tas-b"
RETRIEVER_MODEL_FORMAT = "sentence_transformers"
RETRIEVER_TOP_K = 5

# Model name used for the entailment (NLI) step
NLI_MODEL = "valhalla/distilbart-mnli-12-1"
app_utils/frontend_utils.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+
4
def set_state_if_absent(key, value):
    """Store ``value`` under ``key`` in the session state unless the key already exists."""
    if key in st.session_state:
        # Already initialised: keep the existing value
        return
    st.session_state[key] = value
7
+
8
# Callback that clears the previous outcome, e.g. when the question text changes
def reset_results(*args):
    """Set the ``answer``, ``results`` and ``raw_json`` session entries to ``None``.

    Positional arguments are accepted and ignored.
    """
    for state_key in ("answer", "results", "raw_json"):
        st.session_state[state_key] = None
13
+
14
+
15
+
data/statements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ Kurt Cobain died in 1994
2
+ Kurt Cobain died in 2008
3
+ Green Day are a heavy metal band
4
+ Green Day are a punk rock band
5
+ The Beatles' first album was released in 1985
pages/Info.py CHANGED
@@ -1,3 +1,3 @@
1
  import streamlit as st
2
 
3
- st.title("Test")
 
1
  import streamlit as st
2
 
3
+