Christian Koch commited on
Commit
9ed5930
1 Parent(s): cc3c391

paraphrase

Browse files
Files changed (2) hide show
  1. app.py +17 -9
  2. paraphrase.py +45 -0
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import streamlit as st
2
  from transformers import pipeline, PegasusForConditionalGeneration, PegasusTokenizer
3
  from fill_in_summary import FillInSummary
 
4
 
5
  def paraphrase(text):
6
  return text
@@ -16,36 +17,43 @@ if select == "Summarization":
16
  # left_column.selectbox('Type', ['Question Generator', 'Paraphrasing'])
17
  #st.selectbox('Model', ['T5', 'GPT Neo-X'])
18
 
19
- input = st.text_area("Input Text")
20
 
21
  submitted = st.form_submit_button("Generate")
22
 
23
  if submitted:
24
- st.write(FillInSummary().summarize(input))
 
 
25
 
26
 
27
  if select == "Fill in the blank":
28
  with st.form("summarization"):
29
- input = st.text_area("Input Text")
30
 
31
  submitted = st.form_submit_button("Generate")
32
 
33
  if submitted:
34
- fill = FillInSummary()
35
- summarized = fill.summarize(input)
36
- st.write(fill.blank_ne_out(summarized))
 
 
37
 
38
 
39
  if select == "Paraphrasing":
40
  with st.form("paraphrasing"):
41
- st.selectbox('Model', ['T5', 'GPT Neo-X'])
42
 
43
- input = st.text_area("Input Text")
44
 
45
  submitted = st.form_submit_button("Generate")
46
 
47
  if submitted:
48
- st.write(paraphrase(input))
 
 
 
49
 
50
 
51
 
1
  import streamlit as st
2
  from transformers import pipeline, PegasusForConditionalGeneration, PegasusTokenizer
3
  from fill_in_summary import FillInSummary
4
+ from paraphrase import PegasusParaphraser
5
 
6
  def paraphrase(text):
7
  return text
17
  # left_column.selectbox('Type', ['Question Generator', 'Paraphrasing'])
18
  #st.selectbox('Model', ['T5', 'GPT Neo-X'])
19
 
20
+ text_input = st.text_area("Input Text")
21
 
22
  submitted = st.form_submit_button("Generate")
23
 
24
  if submitted:
25
+ with st.spinner('Wait for it...'):
26
+ result = FillInSummary().summarize(text_input)
27
+ st.write(text_input)
28
 
29
 
30
  if select == "Fill in the blank":
31
  with st.form("summarization"):
32
+ text_input = st.text_area("Input Text")
33
 
34
  submitted = st.form_submit_button("Generate")
35
 
36
  if submitted:
37
+ with st.spinner('Wait for it...'):
38
+ fill = FillInSummary()
39
+ result = fill.summarize(text_input)
40
+ result = fill.blank_ne_out(result)
41
+ st.write(result)
42
 
43
 
44
  if select == "Paraphrasing":
45
  with st.form("paraphrasing"):
46
+ # st.selectbox('Model', ['T5', 'GPT Neo-X'])
47
 
48
+ text_input = st.text_area("Input Text")
49
 
50
  submitted = st.form_submit_button("Generate")
51
 
52
  if submitted:
53
+ with st.spinner('Wait for it...'):
54
+ paraphrase_model = PegasusParaphraser()
55
+ result = paraphrase_model.paraphrase(text_input)
56
+ st.write(result)
57
 
58
 
59
 
paraphrase.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PegasusForConditionalGeneration, PegasusTokenizer
2
+
3
+ class PegasusParaphraser:
4
+ """ Pegasus Model for Paraphrase"""
5
+
6
+ def __init__(self, num_return_sequences=3, num_beams=10, max_length=60,temperature=1.5, device="cpu"):
7
+ self.model_name = "tuner007/pegasus_paraphrase"
8
+ self.device = device
9
+ self.model = self.load_model()
10
+ self.tokenizer = PegasusTokenizer.from_pretrained(self.model_name)
11
+ self.num_return_sequences = num_return_sequences
12
+ self.num_beams = num_beams
13
+ self.max_length=max_length
14
+ self.temperature=temperature
15
+
16
+
17
+ def load_model(self):
18
+ model = PegasusForConditionalGeneration.from_pretrained(self.model_name).to(self.device)
19
+ return model
20
+
21
+
22
+ def paraphrase(self,input_text ):
23
+
24
+ batch = self.tokenizer(
25
+ [input_text],
26
+ truncation=True,
27
+ padding="longest",
28
+ max_length=self.max_length,
29
+ return_tensors="pt",
30
+ ).to(self.device)
31
+ translated = self.model.generate(
32
+ **batch,
33
+ max_length=self.max_length,
34
+ num_beams=self.num_beams,
35
+ num_return_sequences=self.num_return_sequences,
36
+ temperature=self.temperature
37
+ )
38
+ tgt_text = self.tokenizer.batch_decode(translated, skip_special_tokens=True)
39
+ return tgt_text
40
+
41
+
42
+
43
+
44
+
45
+