# wiki-chat / app.py
# Author: Pennywise881 (commit e67127b, "Update app.py")
import streamlit as st
import wikipediaapi
from Article import Article
from QueryProcessor import QueryProcessor
from QuestionAnswer import QuestionAnswer
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
# Hugging Face model used for extractive question answering.
QA_MODEL_NAME = 'Pennywise881/distilbert-base-uncased-finetuned-squad-v2'

# Streamlit re-executes this whole script on every interaction. st.cache_resource
# (Streamlit >= 1.18) keeps a single loaded model/tokenizer instead of reloading the
# weights on each rerun; on older Streamlit versions without it, fall back to a
# no-op decorator so the model is simply loaded on every run as before.
_cache_resource = getattr(st, 'cache_resource', lambda func: func)


@_cache_resource
def _load_qa_model(model_name=QA_MODEL_NAME):
    """Load and return the ``(model, tokenizer)`` pair for *model_name*."""
    qa_model = AutoModelForQuestionAnswering.from_pretrained(model_name)
    qa_tokenizer = AutoTokenizer.from_pretrained(model_name)
    return qa_model, qa_tokenizer


model, tokenizer = _load_qa_model()

st.write("""
# Wiki Chat
""")

# Single slot that first holds the article-search box and is later reused for the
# question box.
placeholder = st.empty()
wiki_wiki = wikipediaapi.Wikipedia('en')

# Per-session state, initialised only when the session has no state yet.
if "found_article" not in st.session_state:
    st.session_state.page = 0
    st.session_state.found_article = False
    st.session_state.article = ''
    st.session_state.conversation = []
    st.session_state.article_data = {}
def get_article():
    """Show the article-search box and load the article once a matching page exists.

    Looks the entered title up with ``wiki_wiki``. If the page exists, its data is
    fetched with ``Article.get_article_data`` into ``st.session_state`` and the app
    switches to question mode by calling ``ask_questions``. Otherwise a
    "not found" message is written. Does nothing while the box is empty.
    """
    article_name = placeholder.text_input('Enter the name of a Wikipedia article', '')
    if not article_name:
        return

    if not wiki_wiki.page(article_name).exists():
        st.write(f'Sorry, could not find Wikipedia article: {article_name}')
        return

    # Load the article data BEFORE touching any session flags. If Article raises
    # (e.g. a fetch failure), the session must not be left with found_article=True
    # and an empty article_data dict; ask_questions() indexes article_data directly
    # and would then fail on every subsequent rerun.
    article_data = Article(article_name=article_name).get_article_data()
    st.session_state.article = article_name
    st.session_state.article_data = article_data
    st.session_state.found_article = True
    ask_questions()
def _answer_question(question):
    """Return the model's answer to *question* about the article in session state.

    The ``'text'`` fields of the results returned by ``QuestionAnswer.get_results``
    are joined with ``", "``. An empty string is returned when there are no results.
    """
    article_data = st.session_state.article_data
    query_processor = QueryProcessor(
        question=question,
        section_texts=article_data['article_data'],
        N=article_data['num_docs'],
        avg_doc_len=article_data['avg_doc_len'],
    )
    # QueryProcessor is defined elsewhere; presumably it picks the article text
    # relevant to the question, which becomes the QA model's context.
    context = query_processor.get_context()
    qa = QuestionAnswer({'question': question, 'context': context}, model, tokenizer, 'cpu')
    return ", ".join(result['text'] for result in qa.get_results())


def ask_questions():
    """Render the question box, answer a newly entered question, and show the history.

    Each answered question is stored as a ``{'question', 'answer'}`` dict in
    ``st.session_state.conversation``, newest first, and the whole conversation is
    rendered below the "Questions and Answers" header.
    """
    question = placeholder.text_input(f"Ask questions about {st.session_state.article}", '')
    st.header("Questions and Answers:")

    # Streamlit reruns the script on every interaction and the text input keeps its
    # value, so remember the last question handled. Otherwise a rerun with an
    # unchanged box would answer and append the same question again. Clearing the
    # box resets the marker, so the same question can be asked again later.
    if not question:
        st.session_state.last_question = None
    elif question != st.session_state.get('last_question'):
        answer = _answer_question(question)
        # Insert at the front to keep the list newest-first. Appending and then
        # reversing the whole list would scramble older entries.
        st.session_state.conversation.insert(0, {'question': question, 'answer': answer})
        st.session_state.last_question = question

    for exchange in st.session_state.conversation:
        st.text(f"Question: {exchange['question']}\nAnswer: {exchange['answer']}")
# Streamlit runs this script top to bottom on every interaction. Until an article
# has been loaded, show the article-search box; after that, show the Q&A view.
if not st.session_state.found_article:
    get_article()
else:
    ask_questions()