File size: 3,295 Bytes
e539b70
 
 
c16fec3
a9e7556
bfbd0a1
78a71e8
 
 
a9e7556
e539b70
78a71e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f37c0ba
bfbd0a1
78a71e8
a9e7556
e539b70
c16fec3
 
 
 
e539b70
c16fec3
a9e7556
c16fec3
 
 
 
a9e7556
c16fec3
 
 
e539b70
a9e7556
 
bfbd0a1
e539b70
 
bfbd0a1
 
 
c16fec3
 
bfbd0a1
a9e7556
bfbd0a1
 
 
 
 
 
 
a9e7556
bfbd0a1
c16fec3
a9e7556
c16fec3
a9e7556
 
 
bfbd0a1
 
 
 
 
a9e7556
 
bfbd0a1
c16fec3
a9e7556
c16fec3
a9e7556
bfbd0a1
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import streamlit as st
from extractor import extract, FewDocumentsError
from summarizer import summarize
from translation import translate
from utils.timing import Timer
import cProfile
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from os import environ

@st.cache(allow_output_mutation=True)
def init():
    """Load NLTK resources and the search/summarization models.

    Cached by Streamlit so the heavy model loads happen only once
    across app reruns.

    Returns:
        tuple: (search_model, summ_model, tokenizer).
    """
    from nltk import download

    # Resources needed for tokenization and stopword filtering.
    for resource in ('punkt', 'stopwords'):
        download(resource)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Sentence encoder used for semantic search over candidate documents.
    search_model = SentenceTransformer('msmarco-distilbert-base-v4', device=device)
    # Seq2seq model + tokenizer used for abstractive summarization.
    summ_model = AutoModelForSeq2SeqLM.from_pretrained('t5-base')
    tokenizer = AutoTokenizer.from_pretrained('t5-base')

    return search_model, summ_model, tokenizer

def _render_summary(result, summary, portuguese, query, query_pt):
    """Render the summary into the `result` placeholder, translating to pt if requested."""
    if portuguese:
        result.markdown(f'Seu resumo para "{query_pt}":\n\n> {translate(summary, "en", "pt")}')
    else:
        result.markdown(f'Your summary for "{query}":\n\n> {summary}')


def main():
    """Streamlit entry point: read a topic, extract documents, and show a summary.

    Keeps `few_documents` state (plus the pending query/documents) in
    `st.session_state` so the warning + 'Prosseguir' confirmation survives
    the rerun that a Streamlit button click triggers.
    """
    search_model, summ_model, tokenizer = init()
    Timer.reset()

    st.title("AutoSumm")
    st.subheader("Lucas Antunes & Matheus Vieira")

    portuguese = st.checkbox('Traduzir para o português.')

    query_pt = ''  # always bound so the render helper can be called in either language
    if portuguese:
        environ['PORTUGUESE'] = 'true' # work around (gambiarra)
        st.subheader("Digite o tópico sobre o qual você deseja gerar um resumo")
        query_pt = st.text_input('Digite o tópico') #text is stored in this variable
        button = st.button('Gerar resumo')
    else:
        environ['PORTUGUESE'] = 'false' # work around (gambiarra)
        st.subheader("Type the desired topic to generate the summary")
        query = st.text_input('Type your topic') #text is stored in this variable
        button = st.button('Generate summary')

    result = st.empty()

    if 'few_documents' not in st.session_state:
        st.session_state['few_documents'] = False
        few_documents = False
    else:
        few_documents = st.session_state['few_documents']

    if button:
        query = translate(query_pt, 'pt', 'en') if portuguese else query
        try:
            text = extract(query, search_model=search_model)
        except FewDocumentsError as e:
            few_documents = True
            st.session_state['few_documents'] = True
            st.session_state['documents'] = e.documents
            st.session_state['msg'] = e.msg
            # BUG FIX: remember the (possibly translated) query. On the
            # 'Prosseguir' rerun below `button` is False, so in Portuguese
            # mode `query` would otherwise be unbound -> NameError.
            st.session_state['query'] = query
        else:
            summary = summarize(text, summ_model, tokenizer)
            _render_summary(result, summary, portuguese, query, query_pt)
            Timer.show_total()

    if few_documents:
        st.warning(st.session_state['msg'])
        if st.button('Prosseguir'):
            # Recover the query stored when FewDocumentsError was raised;
            # fall back to the English text input when available.
            query = st.session_state.get('query', '' if portuguese else query)
            text = extract(query, search_model=search_model, extracted_documents=st.session_state['documents'])
            summary = summarize(text, summ_model, tokenizer)
            _render_summary(result, summary, portuguese, query, query_pt)

            st.session_state['few_documents'] = False
            
# Entry point: run the app under the profiler. NOTE(review): despite the
# '.txt' extension, cProfile.run(..., filename) writes *binary* pstats
# data — load it with pstats.Stats('stats.txt'), not a text editor.
if __name__ == '__main__':
    cProfile.run('main()', 'stats.txt')