Gladiator commited on
Commit
4065f3f
1 Parent(s): ea0864a

cache load model funcs for faster load times

Browse files
Files changed (1) hide show
  1. app.py +16 -6
app.py CHANGED
@@ -3,13 +3,12 @@ import streamlit as st
3
  from extractive_summarizer.model_processors import Summarizer
4
  from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config
5
 
6
- def abstractive_summarizer(text : str):
7
- abs_model = T5ForConditionalGeneration.from_pretrained('t5-large')
8
  tokenizer = T5Tokenizer.from_pretrained('t5-large')
9
  device = torch.device('cpu')
10
  preprocess_text = text.strip().replace("\n", "")
11
  t5_prepared_text = "summarize: " + preprocess_text
12
- tokenized_text = tokenizer.encode(t5_prepared_text, return_tensors="pt").to("cpu")
13
 
14
  # summmarize
15
  summary_ids = abs_model.generate(tokenized_text,
@@ -22,6 +21,17 @@ def abstractive_summarizer(text : str):
22
 
23
  return abs_summarized_text
24
 
 
 
 
 
 
 
 
 
 
 
 
25
  if __name__ == "__main__":
26
  # ---------------------------------
27
  # Main Application
@@ -42,12 +52,12 @@ if __name__ == "__main__":
42
  if summarize_type == "Extractive":
43
  # extractive summarizer
44
 
45
- ext_model = Summarizer()
46
  summarized_text = ext_model(inp_text, num_sentences=5)
47
 
48
  elif summarize_type == "Abstractive":
49
-
50
- summarized_text = abstractive_summarizer(inp_text)
51
 
52
  # final summarized output
53
  st.subheader("Summarized text")
 
3
  from extractive_summarizer.model_processors import Summarizer
4
  from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config
5
 
6
+ def abstractive_summarizer(text : str, model):
 
7
  tokenizer = T5Tokenizer.from_pretrained('t5-large')
8
  device = torch.device('cpu')
9
  preprocess_text = text.strip().replace("\n", "")
10
  t5_prepared_text = "summarize: " + preprocess_text
11
+ tokenized_text = tokenizer.encode(t5_prepared_text, return_tensors="pt").to(device)
12
 
13
  # summmarize
14
  summary_ids = abs_model.generate(tokenized_text,
 
21
 
22
  return abs_summarized_text
23
 
24
+ @st.cache()
25
+ def load_ext_model():
26
+ model = Summarizer()
27
+ return model
28
+
29
+ @st.cache()
30
+ def load_abs_model():
31
+ model = T5ForConditionalGeneration.from_pretrained('t5-large')
32
+ return model
33
+
34
+
35
  if __name__ == "__main__":
36
  # ---------------------------------
37
  # Main Application
 
52
  if summarize_type == "Extractive":
53
  # extractive summarizer
54
 
55
+ ext_model = load_ext_model()
56
  summarized_text = ext_model(inp_text, num_sentences=5)
57
 
58
  elif summarize_type == "Abstractive":
59
+ abs_model = load_abs_model()
60
+ summarized_text = abstractive_summarizer(inp_text, model=abs_model)
61
 
62
  # final summarized output
63
  st.subheader("Summarized text")