|
|
|
|
|
import time |
|
import streamlit as st |
|
import logging |
|
from json import JSONDecodeError |
|
from markdown import markdown |
|
from annotated_text import annotation |
|
from urllib.parse import unquote |
|
import random |
|
|
|
from app_utils.backend_utils import load_questions, query |
|
from app_utils.frontend_utils import (set_state_if_absent, reset_results, |
|
SIDEBAR_STYLE, TWIN_PEAKS_IMG_SRC, LAURA_PALMER_IMG_SRC, SPOTIFY_IFRAME) |
|
from app_utils.config import RETRIEVER_TOP_K, READER_TOP_K, LOW_RELEVANCE_THRESHOLD |
|
|
|
def _pick_random_question(questions, current):
    """Return a random entry of *questions* that differs from *current*.

    Filtering first (instead of re-drawing until the pick differs) guarantees
    termination: when no alternative exists -- empty list, or every entry
    equals *current* -- *current* is returned unchanged.
    """
    candidates = [q for q in questions if q != current]
    if not candidates:
        return current
    return random.choice(candidates)


def _split_context(context, answer):
    """Split *context* into the text before and after *answer*.

    Returns a ``(before, after)`` pair. ``str.find`` yields ``-1`` when the
    answer does not occur verbatim in the context; slicing with that value
    would silently cut the context at the wrong positions, so in that case
    the whole context is returned as ``before`` and ``after`` is empty.
    A ``None`` context is treated as an empty string.
    """
    context = context or ""
    start_idx = context.find(answer)
    if start_idx < 0:
        return (context + " " if context else ""), ""
    return context[:start_idx], context[start_idx + len(answer):]


def _rerun():
    """Restart the script run through Streamlit's public rerun API.

    Uses ``st.rerun`` when available, then ``st.experimental_rerun``; only if
    neither exists does it fall back to the private runner attributes the
    original code relied on.
    """
    rerun = getattr(st, "rerun", None) or getattr(st, "experimental_rerun", None)
    if rerun is not None:
        rerun()
        return
    raise st.script_runner.RerunException(
        st.script_request_queue.RerunData(None))


def _render_sidebar():
    """Draw the sidebar: title, image, project links and the Spotify iframe."""
    st.markdown(SIDEBAR_STYLE, unsafe_allow_html=True)
    st.sidebar.header("Who killed Laura Palmer?")
    st.sidebar.image(TWIN_PEAKS_IMG_SRC)
    st.sidebar.markdown(f"""
<p align="center"><b>Twin Peaks Question Answering system</b></p>
<div class="haystack-footer">
<p><a href="https://github.com/anakin87/who-killed-laura-palmer" target="_blank">GitHub</a> -
Built with <a href="https://github.com/deepset-ai/haystack/" target="_blank">Haystack</a><br/>
<small>Data crawled from <a href="https://twinpeaks.fandom.com/wiki/Twin_Peaks_Wiki" target="_blank">
Twin Peaks Wiki</a>.</small>
</p><img src="{LAURA_PALMER_IMG_SRC}"/><br/></div>
""", unsafe_allow_html=True)
    st.sidebar.markdown(SPOTIFY_IFRAME, unsafe_allow_html=True)


def _render_header():
    """Draw the page title and the short usage instructions."""
    st.write("# Who killed Laura Palmer?")
    st.write("### The first Twin Peaks Question Answering system!")
    # The link text and its URL must be adjacent: with a line break between
    # "]" and "(" the Markdown renderer does not produce a link.
    st.markdown("""
Ask any question about [Twin Peaks](https://twinpeaks.fandom.com/wiki/Twin_Peaks)
and see if the AI can find an answer...

*Note: do not use keywords, but full-fledged questions.*
""")


def _render_answer(result):
    """Render one answer dict: highlighted context, score and source link.

    *result* is the dict produced by ``to_dict()`` on an answer object; the
    keys read here are ``answer``, ``context``, ``score`` and
    ``meta['name']`` / ``meta['url']``.
    """
    answer = result["answer"]
    before, after = _split_context(result["context"], answer)
    highlighted = str(annotation(answer, "ANSWER", "#3e1c21", "white"))
    st.write(markdown("- ..." + before + highlighted + after + "..."),
             unsafe_allow_html=True)

    name = unquote(result['meta']['name']).replace('_', ' ')
    url = result['meta']['url']
    st.markdown(
        f"**Score:** {result['score']:.2f} - **Source:** [{name}]({url})")


def _render_results(results):
    """Render the "Results" section for a non-empty query result.

    Shows an info box when there are no answers, and a single low-relevance
    warning right before the first answer whose score is below
    ``LOW_RELEVANCE_THRESHOLD``. Answers with an empty ``answer`` are skipped.
    """
    st.write("## Results:")
    answers = results['answers']
    if not answers:
        st.info("""🤔 Haystack is unsure whether any of
the documents contain an answer to your question. Try to reformulate it!""")

    alert_irrelevance = True
    for item in answers:
        result = item.to_dict()
        if not result["answer"]:
            continue
        if alert_irrelevance and result['score'] < LOW_RELEVANCE_THRESHOLD:
            alert_irrelevance = False  # warn only once
            st.write("""
<h4 style='color: darkred'>Attention, the
following answers have low relevance:</h4>""",
                     unsafe_allow_html=True)
        _render_answer(result)


def main():
    """Build the Streamlit page: sidebar, search box, query run and results.

    The function is executed top to bottom on every script run. It seeds the
    session-state keys it uses, draws the static UI, runs ``query`` when the
    user presses "Run" or edits the question, and finally renders whatever
    results are stored in ``st.session_state.results``.
    """
    questions = load_questions()

    # Session-state keys used below, initialised through set_state_if_absent.
    set_state_if_absent('question', "Where is Twin Peaks?")
    set_state_if_absent('answer', '')
    set_state_if_absent('results', None)
    set_state_if_absent('raw_json', None)
    set_state_if_absent('random_question_requested', False)

    _render_sidebar()
    _render_header()

    question = st.text_input("", value=st.session_state.question,
                             max_chars=100, on_change=reset_results)
    col1, col2 = st.columns(2)
    # The same CSS is deliberately written into both columns, as in the
    # original layout, so each column holds the same elements above its button.
    button_css = "<style>.stButton button {width:100%;}</style>"
    col1.markdown(button_css, unsafe_allow_html=True)
    col2.markdown(button_css, unsafe_allow_html=True)

    run_pressed = col1.button("Run")

    if col2.button("Random question"):
        reset_results()
        st.session_state.question = _pick_random_question(
            questions, st.session_state.question)
        st.session_state.random_question_requested = True
        # The text box above was already drawn with the previous question;
        # a rerun is needed for it to pick up the new session-state value.
        _rerun()
        return
    st.session_state.random_question_requested = False

    run_query = (run_pressed or question != st.session_state.question) \
        and not st.session_state.random_question_requested

    if run_query and question:
        time_start = time.time()
        reset_results()
        st.session_state.question = question
        with st.spinner("🧠 Performing neural search on documents..."):
            try:
                st.session_state.results = query(
                    question, RETRIEVER_TOP_K, READER_TOP_K)
            except JSONDecodeError:
                # NOTE(review): the leading glyph of the two error messages
                # looks like a mis-encoded emoji whose original bytes are not
                # recoverable from this file -- restore the intended one.
                st.error(
                    "π An error occurred reading the results. Is the document store working?")
                return
            except Exception as e:
                # Top-level UI boundary: log the traceback, show a generic error.
                logging.exception(e)
                st.error("π An error occurred during the request.")
                return
            else:
                time_end = time.time()
                print(time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))
                print(f'elapsed time: {time_end - time_start}')

    if st.session_state.results:
        _render_results(st.session_state.results)
|
|
|
# Script entry point. Streamlit runs the script file as ``__main__``, so the
# guard keeps the app working under ``streamlit run`` while preventing the
# whole UI from being executed as a side effect of a plain import.
if __name__ == "__main__":
    main()
|
|