"""Wiki Chat: a Streamlit app for asking questions about a Wikipedia article.

The user first enters the title of a Wikipedia article. Once the page is
found, its data is loaded via ``Article.get_article_data()``. Each question is
then answered by selecting context with ``QueryProcessor`` and running the
question-answering model over that context with ``QuestionAnswer``.
"""

import streamlit as st
import wikipediaapi
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

from Article import Article
from QueryProcessor import QueryProcessor
from QuestionAnswer import QuestionAnswer

MODEL_NAME = 'Pennywise881/distilbert-base-uncased-finetuned-squad-v2'

# Streamlit re-executes this whole script on every widget interaction, so an
# uncached model/tokenizer would be reloaded on every rerun. Use
# ``st.cache_resource`` when the installed Streamlit provides it; otherwise
# fall back to loading on each run (the previous behaviour).
_cache_resource = getattr(st, 'cache_resource', lambda func: func)


@_cache_resource
def _load_qa_model(model_name):
    """Load and return the ``(model, tokenizer)`` pair for ``model_name``."""
    loaded_model = AutoModelForQuestionAnswering.from_pretrained(model_name)
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    return loaded_model, loaded_tokenizer


model, tokenizer = _load_qa_model(MODEL_NAME)

st.write("""
# Wiki Chat
""")

# One slot reused for the active text input: first the article-title box,
# later replaced by the question box.
placeholder = st.empty()

wiki_wiki = wikipediaapi.Wikipedia('en')

# Initialise per-session state once; st.session_state survives reruns.
if "found_article" not in st.session_state:
    st.session_state.page = 0  # NOTE(review): not read anywhere in this file
    st.session_state.found_article = False
    st.session_state.article = ''
    st.session_state.conversation = []  # question/answer dicts, newest first
    st.session_state.article_data = {}


def get_article():
    """Prompt for a Wikipedia article title and load the article if it exists.

    On success, stores the title and the result of
    ``Article.get_article_data()`` in session state, marks the article as
    found, and immediately shows the question box. If the page does not
    exist, shows a "not found" message instead.
    """
    article_name = placeholder.text_input('Enter the name of a Wikipedia article', '')
    if not article_name:
        return

    page = wiki_wiki.page(article_name)
    if not page.exists():
        st.write(f'Sorry, could not find Wikipedia article: {article_name}')
        return

    # Load the data before flipping found_article, so a failure here cannot
    # leave the session "found" with an empty article_data dict.
    article = Article(article_name=article_name)
    st.session_state.article_data = article.get_article_data()
    st.session_state.article = article_name
    st.session_state.found_article = True
    ask_questions()


def _answer_question(question):
    """Return the model's answer to ``question`` about the loaded article.

    When the model returns several answer spans, they are joined with ", ".
    """
    article_data = st.session_state.article_data
    query_processor = QueryProcessor(
        question=question,
        section_texts=article_data['article_data'],
        N=article_data['num_docs'],
        avg_doc_len=article_data['avg_doc_len'],
    )
    context = query_processor.get_context()
    qa = QuestionAnswer({'question': question, 'context': context}, model, tokenizer, 'cpu')
    results = qa.get_results()
    return ", ".join(r['text'] for r in results)


def ask_questions():
    """Show the question box, answer a newly asked question, and render the history."""
    question = placeholder.text_input(f"Ask questions about {st.session_state.article}", '')
    st.header("Questions and Answers:")

    # text_input keeps returning its current value on every rerun. Answer a
    # question only once so reruns don't append duplicate entries.
    if question and question != st.session_state.get('last_question'):
        answer = _answer_question(question)
        # Insert at the front so the newest pair is always listed first.
        st.session_state.conversation.insert(0, {'question': question, 'answer': answer})
        st.session_state.last_question = question

    for entry in st.session_state.conversation:
        st.text("Question: " + entry['question'] + "\n" + "Answer: " + entry['answer'])


if not st.session_state.found_article:
    get_article()
else:
    ask_questions()