import cProfile
from os import environ

import streamlit as st
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

from extractor import extract, FewDocumentsError
from summarizer import summarize
from translation import translate
from utils.timing import Timer


@st.cache(allow_output_mutation=True)
def init():
    """Download the NLTK resources and load the models used by the app.

    Wrapped in ``st.cache(allow_output_mutation=True)`` so the returned
    objects are reused between Streamlit reruns.

    Returns:
        A ``(search_model, summ_model, tokenizer)`` tuple: the
        ``msmarco-distilbert-base-v4`` sentence-transformer, and the
        ``t5-base`` seq2seq model with its tokenizer.
    """
    # Download required NLTK resources
    from nltk import download
    download('punkt')
    download('stopwords')

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Model for semantic searches
    search_model = SentenceTransformer('msmarco-distilbert-base-v4', device=device)
    # Model for abstraction
    summ_model = AutoModelForSeq2SeqLM.from_pretrained('t5-base')
    tokenizer = AutoTokenizer.from_pretrained('t5-base')

    return search_model, summ_model, tokenizer


def _show_summary(result, summary, shown_query, portuguese):
    """Write the summary into the ``result`` placeholder, in the UI language.

    Args:
        result: Placeholder created with ``st.empty()``.
        summary: Summary text as returned by ``summarize``.
        shown_query: Topic exactly as the user typed it (echoed back).
        portuguese: When true, the summary is translated from English to
            Portuguese before being displayed.
    """
    if portuguese:
        result.markdown(f'Seu resumo para "{shown_query}":\n\n> {translate(summary, "en", "pt")}')
    else:
        result.markdown(f'Your summary for "{shown_query}":\n\n> {summary}')


def main():
    """Render the AutoSumm page and run the extract -> summarize pipeline.

    When ``extract`` raises ``FewDocumentsError``, the error message, the
    documents carried by the exception and the search query are kept in
    ``st.session_state`` so that a later rerun (triggered by the
    'Prosseguir' button) can continue with those documents.
    """
    search_model, summ_model, tokenizer = init()
    Timer.reset()

    st.title("AutoSumm")
    st.subheader("Lucas Antunes & Matheus Vieira")

    portuguese = st.checkbox('Traduzir para o português.')

    # Work around ("gambiarra"): the language choice is published through an
    # environment variable -- presumably read by other modules; verify.
    environ['PORTUGUESE'] = 'true' if portuguese else 'false'

    if portuguese:
        st.subheader("Digite o tópico sobre o qual você deseja gerar um resumo")
        typed_query = st.text_input('Digite o tópico')  # topic as typed by the user
        button = st.button('Gerar resumo')
    else:
        st.subheader("Type the desired topic to generate the summary")
        typed_query = st.text_input('Type your topic')  # topic as typed by the user
        button = st.button('Generate summary')

    result = st.empty()

    if 'few_documents' not in st.session_state:
        st.session_state['few_documents'] = False
    few_documents = st.session_state['few_documents']

    if button:
        # The search is always performed with an English query.
        query = translate(typed_query, 'pt', 'en') if portuguese else typed_query
        try:
            text = extract(query, search_model=search_model)
        except FewDocumentsError as e:
            few_documents = True
            st.session_state['few_documents'] = True
            st.session_state['documents'] = e.documents
            st.session_state['msg'] = e.msg
            # Remember the query: on the rerun triggered by 'Prosseguir' the
            # main button is False, so this branch will not run again.
            st.session_state['query'] = query
        else:
            summary = summarize(text, summ_model, tokenizer)
            _show_summary(result, summary, typed_query, portuguese)

            # NOTE(review): the original nesting of this call was lost when
            # the file was collapsed; assumed to belong to the success path.
            Timer.show_total()

    if few_documents:
        st.warning(st.session_state['msg'])
        if st.button('Prosseguir'):
            # Use the query stored with the documents; a local `query` is not
            # bound on this rerun when the page is in Portuguese mode.
            query = st.session_state['query']
            text = extract(
                query,
                search_model=search_model,
                extracted_documents=st.session_state['documents'],
            )
            summary = summarize(text, summ_model, tokenizer)
            _show_summary(result, summary, typed_query, portuguese)

            st.session_state['few_documents'] = False


if __name__ == '__main__':
    cProfile.run('main()', 'stats.txt')