File size: 3,908 Bytes
4de8fd3
57998d7
a10ed5c
57998d7
 
 
5b7126d
 
 
5467249
4de8fd3
4ef8a52
c0a1eea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
087827a
5b7126d
 
 
 
c0a1eea
5b7126d
 
 
 
 
 
86668bc
 
4ef8a52
5467249
 
 
 
 
 
 
 
 
 
5b7126d
4ef8a52
087827a
 
5b7126d
 
c0a1eea
2ad73ca
5b7126d
 
 
2ad73ca
5b7126d
 
 
 
 
 
 
 
 
 
 
 
 
b6ac152
087827a
c0a1eea
 
 
 
 
 
 
5b7126d
 
 
 
 
 
 
 
 
 
 
5467249
5b7126d
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import streamlit as st
from haystack.utils import fetch_archive_from_http, clean_wiki_text, convert_files_to_docs
from haystack.schema import Answer
from haystack.document_stores import InMemoryDocumentStore
from haystack.pipelines import ExtractiveQAPipeline
from haystack.nodes import FARMReader, TfidfRetriever
import logging
from markdown import markdown
from annotated_text import annotation
from PIL import Image

#Haystack Components
@st.cache(hash_funcs={"builtins.SwigPyObject": lambda _: None},allow_output_mutation=True)
def start_haystack():
    """Assemble the extractive question-answering pipeline.

    Creates an in-memory document store, fills it through
    ``load_and_write_data``, then wires a TF-IDF retriever and a FARM reader
    (``deepset/tinyroberta-squad2``) into an ``ExtractiveQAPipeline``.

    The function is wrapped in ``st.cache`` with ``allow_output_mutation=True``;
    ``builtins.SwigPyObject`` values are mapped to a hash function that
    returns ``None``.
    NOTE(review): presumably this keeps Streamlit from trying to hash SWIG
    objects held by the pipeline -- confirm.

    Returns:
        The constructed ``ExtractiveQAPipeline``.
    """
    store = InMemoryDocumentStore()
    load_and_write_data(store)

    # Retriever first, reader second: same construction order as before.
    tfidf = TfidfRetriever(document_store=store)
    farm_reader = FARMReader(
        model_name_or_path="deepset/tinyroberta-squad2",
        use_gpu=True,
    )
    return ExtractiveQAPipeline(reader=farm_reader, retriever=tfidf)

def load_and_write_data(document_store):
    """Convert the files under ``./article_txt_got`` to documents and store them.

    The files are converted with ``convert_files_to_docs`` using
    ``clean_wiki_text`` as the cleaning function and paragraph splitting
    enabled; the resulting documents are written to ``document_store``.

    Args:
        document_store: Object exposing ``write_documents``; it receives the
            converted documents.
    """
    document_store.write_documents(
        convert_files_to_docs(
            dir_path='./article_txt_got',
            clean_func=clean_wiki_text,
            split_paragraphs=True,
        )
    )

# Build the QA pipeline used by the rest of the script. start_haystack is
# wrapped in st.cache, so Streamlit decides whether the body runs again.
pipeline = start_haystack()

def set_state_if_absent(key, value):
    """Store ``value`` under ``key`` in ``st.session_state`` unless it is already set.

    Args:
        key: Session-state key to initialise.
        value: Value assigned only when ``key`` is not yet present.
    """
    if key in st.session_state:
        # An existing entry is left untouched.
        return
    st.session_state[key] = value

# Seed the session with a default question and an empty results slot.
# Values already present in st.session_state are kept as they are.
set_state_if_absent("question", "Who is Arya's father?")
set_state_if_absent("results", None)


def reset_results(*args):
    """Clear the stored results in ``st.session_state``.

    Used as the ``on_change`` callback of the question input. Positional
    arguments are accepted and ignored.
    """
    st.session_state["results"] = None

#Streamlit App

st.title('Game of Thrones QA with Haystack')

# Open the banner inside a context manager so the underlying file handle is
# released as soon as Streamlit has rendered the image.
with Image.open('got-haystack.png') as image:
    st.image(image)

# Intro text shown above the question box (typo and mis-encoded emoji fixed).
st.markdown( """
This QA demo uses a [Haystack Extractive QA Pipeline](https://haystack.deepset.ai/components/ready-made-pipelines#extractiveqapipeline) with
an [InMemoryDocumentStore](https://haystack.deepset.ai/components/document-store) which contains documents about Game of Thrones 👑
Go ahead and ask questions about the marvelous kingdom!
""", unsafe_allow_html=True)

# Question box, pre-filled from session state; editing it clears the stored
# results through the reset_results callback.
question = st.text_input("", value=st.session_state.question, max_chars=100, on_change=reset_results)

def ask_question(question):
    """Run ``question`` through the module-level ``pipeline`` and format the answers.

    The pipeline is queried with ``top_k`` 10 for the retriever and ``top_k`` 5
    for the reader.

    Args:
        question: Query string passed to ``pipeline.run``.

    Returns:
        A list with one dict per entry of ``prediction["answers"]``. Every dict
        has ``"context"``, ``"answer"`` and ``"relevance"`` (the score times
        100, rounded to two decimals). When the answer text is non-empty the
        context is wrapped in ``"..."`` and ``"offset_start_in_doc"`` holds the
        start of the first document offset; otherwise ``"context"`` and
        ``"answer"`` are ``None``.
    """
    prediction = pipeline.run(
        query=question,
        params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}},
    )

    def _as_row(raw_answer):
        # Flatten one answer object into the dict shape the UI consumes.
        fields = raw_answer.to_dict()
        text = fields["answer"]
        if not text:
            return {
                "context": None,
                "answer": None,
                "relevance": round(fields["score"] * 100, 2),
            }
        return {
            "context": "..." + fields["context"] + "...",
            "answer": text,
            "relevance": round(fields["score"] * 100, 2),
            "offset_start_in_doc": fields["offsets_in_document"][0]["start"],
        }

    return [_as_row(raw_answer) for raw_answer in prediction["answers"]]

# Run the query whenever the input box holds a non-empty question.
if question:
    with st.spinner("👑    Performing neural search on royal scripts..."):
        try:
            logging.info("Asked %s", question)
            st.session_state.results = ask_question(question)
        except Exception:
            # Top-level boundary of the script: record the full traceback and
            # tell the user, instead of failing silently with no feedback.
            logging.exception("Query failed for question %r", question)
            st.error("An error occurred while answering your question. Please try again.")
    


# Render the answers stored by the most recent query, if there are any.
if st.session_state.results:
    st.write('## Top Results')
    for result in st.session_state.results:
        if not result["answer"]:
            # Entry without answer text: show a hint instead of a context.
            st.info(
                "🤔    Haystack is unsure whether any of the documents contain an answer to your question. Try to reformulate it!"
            )
            continue

        answer, context = result["answer"], result["context"]
        start_idx = context.find(answer)
        if start_idx == -1:
            # Answer text not found verbatim in the context: show the context
            # unhighlighted rather than slicing at -1 and garbling the output.
            highlighted = context
        else:
            end_idx = start_idx + len(answer)
            highlighted = (
                context[:start_idx]
                + str(annotation(answer, "ANSWER", "#964448"))
                + context[end_idx:]
            )
        st.write(markdown(highlighted), unsafe_allow_html=True)
        st.markdown(f"**Relevance:** {result['relevance']}")