Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import wikipediaapi | |
| from Article import Article | |
| from QueryProcessor import QueryProcessor | |
| from QuestionAnswer import QuestionAnswer | |
| from transformers import AutoTokenizer, AutoModelForQuestionAnswering | |
# Hugging Face checkpoint for extractive question answering. The model and its
# tokenizer must come from the same checkpoint, so the id is defined once.
MODEL_NAME = 'Pennywise881/distilbert-base-uncased-finetuned-squad-v2'


def _load_qa_components(model_name):
    """Load the question-answering model and its tokenizer.

    Args:
        model_name: Hugging Face model id to load both objects from.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    return (
        AutoModelForQuestionAnswering.from_pretrained(model_name),
        AutoTokenizer.from_pretrained(model_name),
    )


# Streamlit re-executes this whole script on every widget interaction, so
# without caching the weights would be reloaded on each rerun.
# ``st.cache_resource`` only exists in newer Streamlit releases; when it is
# missing we fall back to the original uncached behaviour.
if hasattr(st, 'cache_resource'):
    _load_qa_components = st.cache_resource(_load_qa_components)

model, tokenizer = _load_qa_components(MODEL_NAME)

st.write("""
# Wiki Chat
""")

# Single slot that is reused for whichever text input is currently active
# (article prompt or question prompt).
placeholder = st.empty()
# NOTE(review): wikipedia-api 0.6.0+ expects a user-agent string as the first
# constructor argument; confirm the pinned version still accepts Wikipedia('en').
wiki_wiki = wikipediaapi.Wikipedia('en')

# Initialise per-session state only once; later reruns keep these values.
if "found_article" not in st.session_state:
    st.session_state.page = 0
    st.session_state.found_article = False
    st.session_state.article = ''
    st.session_state.conversation = []
    st.session_state.article_data = {}
def get_article():
    """Prompt for a Wikipedia article title and load it if the page exists.

    An empty input does nothing. For an existing page, the title and the data
    returned by ``Article.get_article_data()`` are stored in
    ``st.session_state`` and the question view is rendered immediately.
    Otherwise a not-found message is written.
    """
    requested_title = placeholder.text_input('Enter the name of a Wikipedia article', '')
    if not requested_title:
        return

    if not wiki_wiki.page(requested_title).exists():
        st.write(f'Sorry, could not find Wikipedia article: {requested_title}')
        return

    state = st.session_state
    state.found_article = True
    state.article = requested_title
    state.article_data = Article(article_name=requested_title).get_article_data()
    # Render the question prompt in this same run rather than on the next rerun.
    ask_questions()
def ask_questions():
    """Answer a question about the loaded article and render the conversation.

    Reads a question from the shared ``placeholder`` text input. A new,
    non-empty question is answered as follows:

    1. ``QueryProcessor`` retrieves context from
       ``st.session_state.article_data``.
    2. ``QuestionAnswer`` runs with the module-level ``model`` and
       ``tokenizer`` on ``'cpu'``.

    The resulting Q&A pair is stored newest-first in
    ``st.session_state.conversation``. The whole conversation is then
    displayed, whether or not a question was asked on this run.
    """
    question = placeholder.text_input(f"Ask questions about {st.session_state.article}", '')
    st.header("Questions and Answers:")

    # The text input keeps its value across reruns. Only handle a question that
    # differs from the last one processed; otherwise a rerun would append the
    # same Q&A pair to the history again.
    if question and question != st.session_state.get('last_question'):
        article_data = st.session_state.article_data
        query_processor = QueryProcessor(
            question=question,
            section_texts=article_data['article_data'],
            N=article_data['num_docs'],
            avg_doc_len=article_data['avg_doc_len'],
        )
        context = query_processor.get_context()
        qa = QuestionAnswer({'question': question, 'context': context}, model, tokenizer, 'cpu')
        # Several answer spans are shown as one comma-separated string.
        answer = ", ".join(r['text'] for r in qa.get_results())
        # Insert at the front so the newest exchange is listed first.
        # (append() followed by reverse() interleaves older entries once the
        # history holds three or more items.)
        st.session_state.conversation.insert(0, {'question': question, 'answer': answer})
        st.session_state.last_question = question

    for entry in st.session_state.conversation:
        st.text("Question: " + entry['question'] + "\n" + "Answer: " + entry['answer'])
# Top-level dispatch, executed on every script run: show the article prompt
# until an article has been loaded, then switch to the question view.
if not st.session_state.found_article:
    get_article()
else:
    ask_questions()