File size: 3,056 Bytes
de5d292
 
109a51e
de5d292
6f2256e
6583354
6f2256e
 
 
 
 
de5d292
6f2256e
080ccca
03f073b
de5d292
6f2256e
3d99e17
63163f7
109a51e
63163f7
de5d292
 
6f2256e
2da57d4
 
6f2256e
03f073b
 
0f5dc43
 
03f073b
6f2256e
080ccca
 
 
 
6f2256e
080ccca
 
 
0f5dc43
 
 
 
 
 
787cf57
0b4cb0b
109a51e
6f2256e
 
2da57d4
0f5dc43
 
 
 
 
de5d292
 
03f073b
de5d292
 
 
 
6b58923
de5d292
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import streamlit as st
from transformers import pipeline
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

# --- Page header -----------------------------------------------------------
# Configure the browser-tab title (must run before any other st.* call).
st.set_page_config(page_title="Automated Question Answering System")
# Centered grey page heading.
st.markdown("<h2 style='text-align: center; color:grey;'>Question Answering on Academic Essays</h2>", unsafe_allow_html=True)
# Sub-heading + one-paragraph explanation of extractive QA for the user.
st.markdown("<h3 style='text-align: left; color:#F63366; font-size:18px;'><b>What is extractive question answering about?<b></h3>", unsafe_allow_html=True)
st.write("Extractive question answering is a Natural Language Processing task where text is provided for a model so that the model can refer to it and make predictions about where the answer to a question is.")
# st.markdown('___')

# Cache the QA pipeline as a shared resource so the model weights are
# downloaded and loaded only once per server process, not on every rerun.
# ref: https://docs.streamlit.io/library/advanced-features/caching
@st.cache_resource(show_spinner=True)
def question_model():
    """Build (once) and return the extractive question-answering pipeline.

    Loads the fine-tuned DeBERTa-v3 checkpoint and its tokenizer from the
    Hugging Face Hub and wraps them in a `question-answering` pipeline.
    """
    checkpoint = "kxx-kkk/FYP_deberta-v3-base-squad2_mrqa"
    qa_pipeline = pipeline(
        "question-answering",
        model=AutoModelForQuestionAnswering.from_pretrained(checkpoint),
        tokenizer=AutoTokenizer.from_pretrained(checkpoint),
    )
    return qa_pipeline

# choose the source with different tabs
tab1, tab2 = st.tabs(["Input text", "Upload File"])

# if type the text as input
with tab1:
    sample_question = "What is NLP?"
    # Sample essay shown when the user clicks "Try example".
    with open("sample.txt", "r", encoding="utf-8") as text_file:
        sample_text = text_file.read()

    example = st.button("Try example", key="tab1_example")

    # BUGFIX: a session_state entry bound to a widget key must be written
    # BEFORE that widget is instantiated in the same script run, otherwise
    # Streamlit raises StreamlitAPIException. The original assigned these
    # after st.text_area/st.text_input, so "Try example" crashed the app.
    if example:
        st.session_state.context = sample_text
        st.session_state.question = sample_question

    context = st.text_area("Enter the essay below:", key="context", height=330)
    question = st.text_input(label="Enter the question: ", key="question")

    # Explicit key avoids DuplicateWidgetID with the identically-labelled
    # "Get answer" button rendered in tab2 once a file is uploaded.
    button = st.button("Get answer", key="tab1_answer")
    if button:
        with st.spinner(text="Loading question model..."):
            question_answerer = question_model()
        with st.spinner(text="Getting answer..."):
            answer = question_answerer(context=context, question=question)
            answer = answer["answer"]
            container = st.container(border=True)
            container.write("<h5><b>Answer:</b></h5>" + answer, unsafe_allow_html=True)

# if upload file as input
with tab2:
    uploaded_file = st.file_uploader("Choose a .txt file to upload", type=["txt"])
    if uploaded_file is not None:
        # Decode the uploaded bytes explicitly; uploads are assumed UTF-8 text.
        raw_text = uploaded_file.read().decode("utf-8")
        # A real (but visually collapsed) label avoids Streamlit's
        # empty-label warning while rendering identically to before.
        context = st.text_area(
            "Uploaded essay", value=raw_text, height=330,
            label_visibility="collapsed",
        )
        question = st.text_input(label="Enter your question", value=sample_question)
        # BUGFIX: explicit key — a second st.button("Get answer") with the
        # same arguments as tab1's button raises DuplicateWidgetID as soon
        # as a file is uploaded and this branch renders.
        button = st.button("Get answer", key="tab2_answer")
        if button:
            with st.spinner(text="Loading question model..."):
                question_answerer = question_model()
            with st.spinner(text="Getting answer..."):
                answer = question_answerer(context=context, question=question)
                answer = answer["answer"]
                st.success(answer)