Gladiator committed
Commit 4b21134
1 Parent(s): 85ebc15

modularized code

app.py CHANGED
@@ -1,35 +1,19 @@
 import torch
 import streamlit as st
-from extractive_summarizer.model_processors import Summarizer
-from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config
-
-def abstractive_summarizer(text: str, model):
-    tokenizer = T5Tokenizer.from_pretrained('t5-large')
-    device = torch.device('cpu')
-    preprocess_text = text.strip().replace("\n", "")
-    t5_prepared_text = "summarize: " + preprocess_text
-    tokenized_text = tokenizer.encode(t5_prepared_text, return_tensors="pt").to(device)
+from transformers import T5Tokenizer, T5ForConditionalGeneration
 
-    # summarize
-    summary_ids = abs_model.generate(tokenized_text,
-                                     num_beams=4,
-                                     no_repeat_ngram_size=2,
-                                     min_length=30,
-                                     max_length=100,
-                                     early_stopping=True)
-    abs_summarized_text = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
-
-    return abs_summarized_text
+# local modules
+from extractive_summarizer.model_processors import Summarizer
+from src.utils import clean_text
+from src.abstractive_summarizer import abstractive_summarizer
 
-# @st.cache()
-# def load_ext_model():
-#     model = Summarizer()
-#     return model
 
+# abstractive summarizer model
 @st.cache()
 def load_abs_model():
-    model = T5ForConditionalGeneration.from_pretrained('t5-base')
-    return model
+    tokenizer = T5Tokenizer.from_pretrained("t5-large")
+    model = T5ForConditionalGeneration.from_pretrained("t5-base")
+    return tokenizer, model
 
 
 if __name__ == "__main__":
@@ -37,10 +21,14 @@ if __name__ == "__main__":
     # Main Application
     # ---------------------------------
     st.title("Text Summarizer 📝")
-    summarize_type = st.sidebar.selectbox("Summarization type", options=["Extractive", "Abstractive"])
+    summarize_type = st.sidebar.selectbox(
+        "Summarization type", options=["Extractive", "Abstractive"]
+    )
 
     inp_text = st.text_input("Enter the text here")
 
+    inp_text = clean_text(inp_text)
+
     # view summarized text (expander)
     with st.expander("View input text"):
         st.write(inp_text)
@@ -51,16 +39,22 @@
     if summarize:
        if summarize_type == "Extractive":
            # extractive summarizer
-
-            with st.spinner(text="Creating extractive summary. This might take a few seconds ..."):
+
+            with st.spinner(
+                text="Creating extractive summary. This might take a few seconds ..."
+            ):
                 ext_model = Summarizer()
                 summarized_text = ext_model(inp_text, num_sentences=5)
-
-        elif summarize_type == "Abstractive":
-            with st.spinner(text="Creating abstractive summary. This might take a few seconds ..."):
-                abs_model = load_abs_model()
-                summarized_text = abstractive_summarizer(inp_text, model=abs_model)
 
-    # final summarized output
+        elif summarize_type == "Abstractive":
+            with st.spinner(
+                text="Creating abstractive summary. This might take a few seconds ..."
+            ):
+                abs_tokenizer, abs_model = load_abs_model()
+                summarized_text = abstractive_summarizer(
+                    abs_tokenizer, abs_model, inp_text
+                )
+
+    # final summarized output
     st.subheader("Summarized text")
     st.info(summarized_text)
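Note: app.py now imports clean_text from src.utils, but src/utils.py is not touched by this commit, so its contents are not shown here. A minimal sketch of what such a helper might look like follows; the implementation is an assumption for illustration only, not the committed code.

# src/utils.py -- hypothetical sketch, not part of this commit
import re

def clean_text(text: str) -> str:
    # collapse runs of whitespace and trim the raw user input
    text = re.sub(r"\s+", " ", text)
    return text.strip()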
src/abstractive_summarizer.py ADDED
@@ -0,0 +1,22 @@
+import torch
+from transformers import T5Tokenizer
+
+
+def abstractive_summarizer(tokenizer, model, text):
+    device = torch.device("cpu")
+    preprocess_text = text.strip().replace("\n", "")
+    t5_prepared_text = "summarize: " + preprocess_text
+    tokenized_text = tokenizer.encode(t5_prepared_text, return_tensors="pt").to(device)
+
+    # summarize
+    summary_ids = model.generate(
+        tokenized_text,
+        num_beams=4,
+        no_repeat_ngram_size=2,
+        min_length=30,
+        max_length=100,
+        early_stopping=True,
+    )
+    abs_summarized_text = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+
+    return abs_summarized_text
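For reference, the new module can be exercised outside Streamlit roughly as below. This is only a usage sketch: the t5-large tokenizer paired with the t5-base model mirrors what load_abs_model in app.py does in this commit, and the sample input string is made up.

# standalone usage sketch of the committed helper (not itself part of the commit)
from transformers import T5Tokenizer, T5ForConditionalGeneration
from src.abstractive_summarizer import abstractive_summarizer

tokenizer = T5Tokenizer.from_pretrained("t5-large")
model = T5ForConditionalGeneration.from_pretrained("t5-base")
print(abstractive_summarizer(tokenizer, model, "Long input text to summarize ..."))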
src/vanilla_summarizer.py DELETED
File without changes