# autosumm / app.py
# mhsvieira's picture
# Add timer
# a9e7556
# raw
# history blame
# No virus
# 3.3 kB
import streamlit as st
from extractor import extract, FewDocumentsError
from summarizer import summarize
from translation import translate
from utils.timing import Timer
import cProfile
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from os import environ
@st.cache(allow_output_mutation=True)
def init():
    """Load the heavyweight resources the app needs.

    Decorated with ``st.cache(allow_output_mutation=True)`` so that Streamlit
    reruns reuse the cached result instead of re-downloading NLTK data and
    re-loading the models each time.

    Returns:
        tuple: ``(search_model, summ_model, tokenizer)`` -- a
        ``SentenceTransformer`` used for semantic search, a seq2seq model
        loaded from the 't5-base' checkpoint used for abstractive
        summarisation, and the tokenizer from the same checkpoint.
    """
    # Download the NLTK data packages the pipeline relies on
    # (presumably used by the extractor/summarizer -- TODO confirm).
    from nltk import download
    for nltk_resource in ('punkt', 'stopwords'):
        download(nltk_resource)

    # The semantic-search embedding model runs on the GPU when one is present.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    search_model = SentenceTransformer('msmarco-distilbert-base-v4', device=device)

    # Abstractive summarisation model and its tokenizer share one checkpoint.
    # NOTE(review): unlike search_model, this model is not placed on `device`
    # here; confirm whether summarize() handles device placement.
    checkpoint = 't5-base'
    summ_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    return search_model, summ_model, tokenizer
def _render_summary(result, summary, portuguese, query, query_pt):
    """Write *summary* into the ``result`` placeholder, then report timings.

    Args:
        result: the ``st.empty()`` placeholder the summary is written into.
        summary: the English summary produced by ``summarize``.
        portuguese: when true, the summary is translated to Portuguese and the
            Portuguese topic ``query_pt`` is echoed; otherwise ``query`` is.
        query: the English topic that was summarised.
        query_pt: the topic as typed in Portuguese (only read when
            ``portuguese`` is true).
    """
    if portuguese:
        result.markdown(f'Seu resumo para "{query_pt}":\n\n> {translate(summary, "en", "pt")}')
    else:
        result.markdown(f'Your summary for "{query}":\n\n> {summary}')
    # Presumably displays the timings accumulated since Timer.reset() -- see utils.timing.
    Timer.show_total()


def main():
    """Render the AutoSumm page for one Streamlit rerun.

    Flow:
    1. The user picks the UI language, types a topic and clicks the button.
    2. The topic (translated to English when needed) is passed to ``extract``,
       and the extracted text is passed to ``summarize``.
    3. If ``extract`` raises ``FewDocumentsError``, the error's ``documents`` and
       ``msg`` are stored in ``st.session_state`` together with the English
       query. On a later rerun, the user can then click 'Prosseguir' to
       summarise from those documents anyway.
    """
    search_model, summ_model, tokenizer = init()
    Timer.reset()

    st.title("AutoSumm")
    st.subheader("Lucas Antunes & Matheus Vieira")

    portuguese = st.checkbox('Traduzir para o português.')
    if portuguese:
        # Work-around ("gambiarra", i.e. a hack): the language choice is
        # published through the process environment, presumably for other
        # modules to read -- TODO confirm. NOTE(review): os.environ is shared
        # by every session served by this process, so concurrent users with
        # different language choices can overwrite each other's flag.
        environ['PORTUGUESE'] = 'true'
        st.subheader("Digite o tópico sobre o qual você deseja gerar um resumo")
        query_pt = st.text_input('Digite o tópico')
        query = None  # English query; derived from query_pt when the button is pressed
        button = st.button('Gerar resumo')
    else:
        environ['PORTUGUESE'] = 'false'  # same work-around as above
        st.subheader("Type the desired topic to generate the summary")
        query = st.text_input('Type your topic')
        query_pt = None  # unused in the English flow
        button = st.button('Generate summary')

    result = st.empty()

    # The few-documents prompt must survive reruns, so its flag lives in session state.
    if 'few_documents' not in st.session_state:
        st.session_state['few_documents'] = False
    few_documents = st.session_state['few_documents']

    if button:
        query = translate(query_pt, 'pt', 'en') if portuguese else query
        try:
            text = extract(query, search_model=search_model)
        except FewDocumentsError as e:
            few_documents = True
            st.session_state['few_documents'] = True
            st.session_state['documents'] = e.documents
            st.session_state['msg'] = e.msg
            # The 'Prosseguir' click happens on a later rerun where `button`
            # is False, so the English query must be remembered here (in the
            # Portuguese flow it would otherwise be undefined then).
            st.session_state['query'] = query
        else:
            # A successful search supersedes any pending few-documents prompt.
            few_documents = False
            st.session_state['few_documents'] = False
            summary = summarize(text, summ_model, tokenizer)
            _render_summary(result, summary, portuguese, query, query_pt)

    if few_documents:
        st.warning(st.session_state['msg'])
        if st.button('Prosseguir'):
            query = st.session_state['query']
            text = extract(query, search_model=search_model,
                           extracted_documents=st.session_state['documents'])
            summary = summarize(text, summ_model, tokenizer)
            _render_summary(result, summary, portuguese, query, query_pt)
            st.session_state['few_documents'] = False
            few_documents = False
if __name__ == '__main__':
    # Execute the app under cProfile and dump the raw profile to 'stats.txt'.
    # The file holds marshalled pstats data, not text, despite the extension;
    # inspect it with pstats.Stats('stats.txt').
    # NOTE(review): Streamlit runs this script as __main__ on every rerun, so
    # the profile is rewritten each time -- confirm this is intended outside
    # of debugging.
    cProfile.run('main()', filename='stats.txt')