Gladiator committed
Commit 4354680
1 parent: b916752

add abs preprocess func

Files changed (2)
  1. app.py +24 -16
  2. src/abstractive_summarizer.py +37 -0
app.py CHANGED
@@ -1,6 +1,8 @@
 import torch
+import nltk
 import validators
 import streamlit as st
+from nltk.tokenize import sent_tokenize
 from transformers import pipeline, T5Tokenizer, T5ForConditionalGeneration
 
 # local modules
@@ -11,7 +13,7 @@ from src.abstractive_summarizer import abstractive_summarizer
 # abstractive summarizer model
 @st.cache()
 def load_abs_model():
-    tokenizer = T5Tokenizer.from_pretrained("t5-large")
+    tokenizer = T5Tokenizer.from_pretrained("t5-base")
     model = T5ForConditionalGeneration.from_pretrained("t5-base")
     return tokenizer, model
 
@@ -24,27 +26,30 @@ if __name__ == "__main__":
     summarize_type = st.sidebar.selectbox(
         "Summarization type", options=["Extractive", "Abstractive"]
     )
+    nltk.download("punkt")
 
     inp_text = st.text_input("Enter text or a url here")
 
     is_url = validators.url(inp_text)
     if is_url:
         # complete text, chunks to summarize (list of sentences for long docs)
-        text, text_to_summarize = fetch_article_text(url=inp_text)
+        text, clean_txt = fetch_article_text(url=inp_text)
     else:
-        text_to_summarize = clean_text(inp_text)
+        clean_txt = clean_text(inp_text)
 
     # view summarized text (expander)
     with st.expander("View input text"):
-        st.write(text_to_summarize)
-
+        if is_url:
+            st.write(clean_txt[0])
+        else:
+            st.write(clean_txt)
     summarize = st.button("Summarize")
 
     # called on toggle button [summarize]
     if summarize:
         if summarize_type == "Extractive":
             if is_url:
-                text_to_summarize = " ".join([txt for txt in text_to_summarize])
+                text_to_summarize = " ".join([txt for txt in clean_txt])
             # extractive summarizer
 
             with st.spinner(
@@ -57,16 +62,19 @@ if __name__ == "__main__":
             with st.spinner(
                 text="Creating abstractive summary. This might take a few seconds ..."
             ):
-                abs_tokenizer, abs_model = load_abs_model()
-                summarized_text = abstractive_summarizer(
-                    abs_tokenizer, abs_model, text_to_summarize
-                )
-        elif summarize_type == "Abstractive" and is_url:
-            abs_url_summarizer = pipeline("summarization")
-            tmp_sum = abs_url_summarizer(
-                text_to_summarize, max_length=120, min_length=30, do_sample=False
-            )
-            summarized_text = " ".join([summ["summary_text"] for summ in tmp_sum])
+                if not is_url:
+                    text_to_summarize = sent_tokenize(clean_txt)
+
+                # abs_tokenizer, abs_model = load_abs_model()
+                # summarized_text = abstractive_summarizer(
+                #     abs_tokenizer, abs_model, text_to_summarize
+                # )
+                # elif summarize_type == "Abstractive" and is_url:
+                #     abs_url_summarizer = pipeline("summarization")
+                #     tmp_sum = abs_url_summarizer(
+                #         text_to_summarize, max_length=120, min_length=30, do_sample=False
+                #     )
+                #     summarized_text = " ".join([summ["summary_text"] for summ in tmp_sum])
 
     # final summarized output
     st.subheader("Summarized text")
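
Note on the app.py changes: the new sent_tokenize path depends on NLTK's pretrained punkt sentence-boundary model, which is why the commit calls nltk.download("punkt") at startup. A minimal standalone sketch of that dependency (the sample string is invented for illustration):

import nltk
from nltk.tokenize import sent_tokenize

# sent_tokenize needs the punkt model; without this download
# it raises a LookupError on first use
nltk.download("punkt")

sample = "T5 caps its input at 512 tokens. Long articles must be split first."
print(sent_tokenize(sample))
# ['T5 caps its input at 512 tokens.', 'Long articles must be split first.']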
src/abstractive_summarizer.py CHANGED
@@ -1,4 +1,5 @@
 import torch
+from nltk.tokenize import sent_tokenize
 from transformers import T5Tokenizer
 
 
@@ -20,3 +21,39 @@ def abstractive_summarizer(tokenizer, model, text):
     abs_summarized_text = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
 
     return abs_summarized_text
+
+
+def preprocess_text_for_abstractive_summarization(tokenizer, text):
+    sentences = sent_tokenize(text)
+
+    # initialize
+    length = 0
+    chunk = ""
+    chunks = []
+    count = -1
+    for sentence in sentences:
+        count += 1
+        combined_length = (
+            len(tokenizer.tokenize(sentence)) + length
+        )  # add the no. of sentence tokens to the length counter
+
+        if combined_length <= tokenizer.max_len_single_sentence:  # if it doesn't exceed
+            chunk += sentence + " "  # add the sentence to the chunk
+            length = combined_length  # update the length counter
+
+            # if it is the last sentence
+            if count == len(sentences) - 1:
+                chunks.append(chunk.strip())  # save the chunk
+
+        else:
+            chunks.append(chunk.strip())  # save the chunk
+
+            # reset
+            length = 0
+            chunk = ""
+
+            # take care of the overflow sentence
+            chunk += sentence + " "
+            length = len(tokenizer.tokenize(sentence))
+
+    return chunks
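
A minimal usage sketch for the new helper, assuming the same t5-base tokenizer that load_abs_model in app.py uses; the article string is a stand-in for the cleaned text produced by clean_text or fetch_article_text. One caveat in the function as committed: if the final sentence overflows into a fresh chunk via the else branch, that trailing chunk is never appended, so callers may want to flush a non-empty remainder themselves.

import nltk
from transformers import T5Tokenizer

from src.abstractive_summarizer import preprocess_text_for_abstractive_summarization

nltk.download("punkt")  # sent_tokenize inside the helper needs punkt
tokenizer = T5Tokenizer.from_pretrained("t5-base")  # same checkpoint as load_abs_model

# stand-in article; real input comes from clean_text / fetch_article_text
article = (
    "T5 models accept at most 512 tokens per input. "
    "Longer articles therefore have to be split before summarization. "
    "This helper packs whole sentences into chunks that fit that budget."
)

chunks = preprocess_text_for_abstractive_summarization(tokenizer, article)

# every returned chunk stays within the tokenizer's single-sequence budget
for chunk in chunks:
    assert len(tokenizer.tokenize(chunk)) <= tokenizer.max_len_single_sentence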