Spaces:
Runtime error
import streamlit as st | |
import wikipediaapi | |
from Article import Article | |
from QueryProcessor import QueryProcessor | |
from QuestionAnswer import QuestionAnswer | |
from transformers import AutoTokenizer, AutoModelForQuestionAnswering | |
# Extractive question-answering checkpoint on the Hugging Face Hub; the same name
# is used for both the model weights and the matching tokenizer.
# NOTE(review): Streamlit re-executes this script on every interaction (which is
# why st.session_state is used below), so these from_pretrained calls presumably
# run on each rerun; a resource cache (e.g. st.cache_resource on newer Streamlit)
# could avoid that — confirm the installed Streamlit version first.
QA_CHECKPOINT = 'Pennywise881/distilbert-base-uncased-finetuned-squad-v2'
model = AutoModelForQuestionAnswering.from_pretrained(QA_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(QA_CHECKPOINT)

st.write("""
# Wiki Chat
""")

# Single slot that first holds the article-title input and later the question input.
placeholder = st.empty()

# NOTE(review): wikipedia-api 0.6.0+ takes a user agent as its first positional
# argument, so 'en' would be read as the user agent rather than the language there;
# confirm the installed version (the Space reports a runtime error).
wiki_wiki = wikipediaapi.Wikipedia('en')

# Seed per-session state once; "found_article" acts as the "already initialized" marker.
if "found_article" not in st.session_state:
    for key, value in (
        ("page", 0),
        ("found_article", False),
        ("article", ''),
        ("conversation", []),
        ("article_data", {}),
    ):
        st.session_state[key] = value
def get_article():
    """Prompt for a Wikipedia article title and load the article if it exists.

    Renders a text input in the shared ``placeholder``. If the entered title
    names an existing page, the title and the result of
    ``Article(...).get_article_data()`` are stored in session state, the
    ``found_article`` flag is set, and the question prompt is shown in the same
    run. If the page does not exist, a "could not find" message is written.
    """
    title = placeholder.text_input('Enter the name of a Wikipedia article', '')
    if not title:
        # Nothing entered yet.
        return

    if not wiki_wiki.page(title).exists():
        st.write(f'Sorry, could not find Wikipedia article: {title}')
        return

    st.session_state.found_article = True
    st.session_state.article = title
    st.session_state.article_data = Article(article_name=title).get_article_data()
    # Switch to the Q&A view immediately rather than on the next rerun.
    ask_questions()
def ask_questions():
    """Show the question prompt for the loaded article and the Q&A history.

    Reads ``st.session_state.article`` for the prompt label and
    ``st.session_state.article_data``, which is expected to hold the keys
    ``'article_data'``, ``'num_docs'`` and ``'avg_doc_len'``. When a question is
    entered, ``QueryProcessor`` selects a context from the article sections and
    ``QuestionAnswer`` answers it on the CPU. The answer spans are joined with
    ", " and the question/answer pair is appended to
    ``st.session_state.conversation``. The history is then rendered newest first.
    """
    question = placeholder.text_input(f"Ask questions about {st.session_state.article}", '')
    st.header("Questions and Answers:")

    if question:
        article_data = st.session_state.article_data
        query_processor = QueryProcessor(
            question=question,
            section_texts=article_data['article_data'],
            N=article_data['num_docs'],
            avg_doc_len=article_data['avg_doc_len'],
        )
        context = query_processor.get_context()
        qa = QuestionAnswer({'question': question, 'context': context}, model, tokenizer, 'cpu')
        # Combine every returned answer span into one comma-separated string
        # (empty string when there are no results).
        answer = ", ".join(result['text'] for result in qa.get_results())
        # Keep the stored history chronological. The old code called .reverse() on
        # this persisted list after every append, which scrambled the order.
        st.session_state.conversation.append({'question': question, 'answer': answer})

    # Display newest first without mutating the stored history.
    for entry in reversed(st.session_state.conversation):
        st.text("Question: " + entry['question'] + "\n" + "Answer: " + entry['answer'])
# Route each script run: once an article has been found, show the question prompt
# for it; until then, keep asking for an article title.
if st.session_state.found_article:
    ask_questions()
else:
    get_article()