mhsvieira committed
Commit 78a71e8
1 Parent(s): 49e74b8

Pre-load models

app.py CHANGED
@@ -3,15 +3,30 @@ from extractor import extract, FewDocumentsError
 from summarizer import summarize
 import time
 import cProfile
+from sentence_transformers import SentenceTransformer
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+import torch
 
-# Download required NLTK resources
-from nltk import download
-download('punkt')
-download('stopwords')
+@st.cache(allow_output_mutation=True)
+def init():
+    # Download required NLTK resources
+    from nltk import download
+    download('punkt')
+    download('stopwords')
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    # Model for semantic searches
+    search_model = SentenceTransformer('msmarco-distilbert-base-v4', device=device)
+    # Model for abstraction
+    summ_model = AutoModelForSeq2SeqLM.from_pretrained('t5-base')
+    tokenizer = AutoTokenizer.from_pretrained('t5-base')
+
+    return search_model, summ_model, tokenizer
 
 # TODO: translation
 
 def main():
+    search_model, summ_model, tokenizer = init()
 
     st.title("Trabalho de Formatura - Construindo textos para a internet")
     st.subheader("Lucas Antunes e Matheus Vieira")
@@ -31,7 +46,7 @@ def main():
         start_time = time.time()
         try:
             with st.spinner('Extraindo textos relevantes...'):
-                text = extract(query)
+                text = extract(query, search_model=search_model)
         except FewDocumentsError as e:
             few_documents = True
             st.session_state['few_documents'] = True
@@ -41,7 +56,7 @@ def main():
 
         st.info(f'(Extraction) Elapsed time: {time.time() - start_time:.2f}s')
         with st.spinner('Gerando resumo...'):
-            summary = summarize(text)
+            summary = summarize(text, summ_model, tokenizer)
         st.info(f'(Total) Elapsed time: {time.time() - start_time:.2f}s')
 
         st.markdown(f'Seu resumo para "{query}":\n\n> {summary}')
@@ -52,10 +67,10 @@ def main():
     if st.button('Prosseguir'):
         start_time = time.time()
         with st.spinner('Extraindo textos relevantes...'):
-            text = extract(query, extracted_documents=st.session_state['documents'])
+            text = extract(query, search_model=search_model, extracted_documents=st.session_state['documents'])
         st.info(f'(Extraction) Elapsed time: {time.time() - start_time:.2f}s')
         with st.spinner('Gerando resumo...'):
-            summary = summarize(text)
+            summary = summarize(text, summ_model, tokenizer)
         st.info(f'(Total) Elapsed time: {time.time() - start_time:.2f}s')
 
         st.markdown(f'Seu resumo para "{query}":\n\n> {summary}')
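
The point of this change is that `init()` now downloads the NLTK data and builds the search encoder and the T5 summarizer once per process: Streamlit reruns the whole script on every interaction, so without caching the models would be reloaded on each request. Note that `@st.cache(allow_output_mutation=True)` is the legacy caching API; a minimal sketch of the same pattern with `st.cache_resource` (its replacement since Streamlit 1.18), reusing the model names from the diff:

```python
import streamlit as st
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

@st.cache_resource  # successor to @st.cache(allow_output_mutation=True)
def init():
    # Built once per process; later script reruns reuse these objects
    device = "cuda" if torch.cuda.is_available() else "cpu"
    search_model = SentenceTransformer('msmarco-distilbert-base-v4', device=device)
    summ_model = AutoModelForSeq2SeqLM.from_pretrained('t5-base')
    tokenizer = AutoTokenizer.from_pretrained('t5-base')
    return search_model, summ_model, tokenizer
```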
extractor/_utils.py CHANGED
@@ -4,8 +4,6 @@ import streamlit as st
 # import inflect
 import torch
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
 # p = inflect.engine()
 
 class FewDocumentsError(Exception):
@@ -90,8 +88,8 @@ def paragraph_extraction(documents, min_paragraph_size):
     return paragraphs
 
 def semantic_search(model, query, files, number_of_similar_files):
-    encoded_query = model.encode(query, device=device)
-    encoded_files = model.encode(files, device=device)
+    encoded_query = model.encode(query)
+    encoded_files = model.encode(files)
 
     model_index = nmslib.init(method='hnsw', space='angulardist')
    model_index.addDataPointBatch(encoded_files)
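
Dropping the `device=` argument from `encode()` is safe here because `init()` already places the SentenceTransformer on the chosen device, and `encode()` runs on the model's own device by default. The hunk ends before the index is queried; a hypothetical sketch of the remaining nmslib steps (the `createIndex`/`knnQuery` calls are my assumption of how `semantic_search` continues, not shown in the diff):

```python
import nmslib
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('msmarco-distilbert-base-v4')  # device fixed at construction

files = ['doc one ...', 'doc two ...', 'doc three ...']
encoded_query = model.encode('some query')  # uses the model's device by default
encoded_files = model.encode(files)

model_index = nmslib.init(method='hnsw', space='angulardist')
model_index.addDataPointBatch(encoded_files)
model_index.createIndex()  # assumption: the index must be built before querying
ids, distances = model_index.knnQuery(encoded_query, k=2)  # assumed continuation
```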
extractor/extract.py CHANGED
@@ -1,4 +1,3 @@
-from sentence_transformers import SentenceTransformer
 from ._utils import FewDocumentsError
 from ._utils import document_extraction, paragraph_extraction, semantic_search
 from corpora import gen_corpus
@@ -6,9 +5,7 @@ from nltk.corpus import stopwords
 from nltk.tokenize import word_tokenize
 import string
 
-from ._utils import device
-
-def extract(query: str, n: int=3, extracted_documents: list=None) -> str:
+def extract(query: str, search_model, n: int=3, extracted_documents: list=None) -> str:
     """Extract n paragraphs from the corpus using the given query.
 
     Parameters:
@@ -38,8 +35,6 @@ def extract(query: str, n: int=3, extracted_documents: list=None) -> str:
     )
 
     # First semantic search (over documents)
-    # Model for semantic searches
-    search_model = SentenceTransformer('msmarco-distilbert-base-v4', device=device)
     selected_documents, documents_distances = semantic_search(
         model=search_model,
         query=query,
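
With the constructor moved out, `extract()` now receives the encoder as an argument instead of building a fresh `SentenceTransformer` on every call, which keeps the module free of import-time side effects and lets every caller share one cached instance. A minimal usage sketch, assuming the cached `init()` defined in app.py supplies the model:

```python
# Hypothetical call site: init() is the cached loader from app.py
search_model, summ_model, tokenizer = init()
text = extract('my query', search_model=search_model, n=3)
```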
summarizer/summarize.py CHANGED
@@ -1,14 +1,9 @@
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
-def summarize(text: str) -> str:
+def summarize(text: str, model, tokenizer) -> str:
     """
     Generate a summary from the given text
     """
 
-    # Model for abstraction
-    model = AutoModelForSeq2SeqLM.from_pretrained('t5-base')
-    tokenizer = AutoTokenizer.from_pretrained('t5-base')
-
     input_tokens = tokenizer.encode(
         f'summarize: {text}',
         return_tensors='pt',
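
`summarize()` likewise receives the preloaded model and tokenizer rather than calling `from_pretrained` on every invocation. The hunk cuts off inside the `tokenizer.encode` call; a sketch of how such a T5 summarizer typically finishes, where the generation parameters are illustrative assumptions rather than the repository's actual values:

```python
def summarize(text: str, model, tokenizer) -> str:
    """Generate a summary from the given text."""
    input_tokens = tokenizer.encode(
        f'summarize: {text}',  # T5 expects a task prefix
        return_tensors='pt',
        max_length=512,        # assumption: cap input at T5's usual limit
        truncation=True,
    )
    # Assumed generation settings; the actual ones are truncated out of the diff
    output_tokens = model.generate(input_tokens, max_length=150, num_beams=4)
    return tokenizer.decode(output_tokens[0], skip_special_tokens=True)
```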