ansfarooq7 committed on
Commit
4677c24
1 Parent(s): d1f0e51

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -82
app.py CHANGED
@@ -1,18 +1,26 @@
1
- from transformers import RobertaTokenizer, RobertaForMaskedLM, pipeline, GPT2Tokenizer, GPT2LMHeadModel
2
  import torch
3
  import wikipedia
4
  import re
5
  import random
6
  import nltk
7
  import syllables
 
8
  nltk.download('cmudict')
9
 
 
 
10
  masked_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
11
  masked_model = RobertaForMaskedLM.from_pretrained('roberta-base')
12
 
13
  causal_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
14
- # add the EOS token as PAD token to avoid warnings
15
- causal_model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=causal_tokenizer.eos_token_id)
 
 
 
 
 
16
 
17
  frequent_words = set()
18
 
@@ -48,6 +56,7 @@ def filter_rhymes(word):
48
 
49
  def remove_punctuation(text):
50
  text = re.sub(r'[^\w\s]', '', text)
 
51
  return text
52
 
53
  def get_rhymes(inp, level):
@@ -100,19 +109,13 @@ def get_prediction(sent):
100
 
101
  return best_guess
102
 
103
- text_generation = pipeline("text-generation", model=causal_model, tokenizer=causal_tokenizer)
104
- from aitextgen import aitextgen
105
-
106
- # Without any parameters, aitextgen() will download, cache, and load the 124M GPT-2 "small" model
107
- ai = aitextgen()
108
-
109
-
110
  def get_line(prompt, inputs_len):
111
- line = ai.generate_one(prompt=prompt + ".", max_length=inputs_len + 7)[len(prompt)+2:]
112
  return line
113
 
114
  def get_rhyming_line(prompt, rhyming_word, inputs_len):
115
- gpt2_sentence = ai.generate_one(prompt=prompt + ".", max_length=inputs_len + 4)[len(prompt)+2:]
 
116
  print(f"\nGetting rhyming line starting with '{gpt2_sentence}' and ending with rhyming word '{rhyming_word}'")
117
  sentence = gpt2_sentence + " ___ ___ ___ " + rhyming_word
118
  print(f"Original Sentence: {sentence}")
@@ -128,12 +131,20 @@ def get_rhyming_line(prompt, rhyming_word, inputs_len):
128
  final_sentence = gpt2_sentence + predicted_blanks + " " + rhyming_word
129
  print(f"Final Sentence: {final_sentence}")
130
  return final_sentence
131
-
132
- def generate(topic):
133
 
134
- limericks = []
 
 
135
 
136
- topic_summary = remove_punctuation(wikipedia.summary(topic))
 
 
 
 
 
 
 
 
137
  word_list = topic_summary.split()
138
  topic_summary_len = len(topic_summary)
139
  no_of_words = len(word_list)
@@ -143,77 +154,82 @@ def generate(topic):
143
  print(f"No of Words in Summary: {no_of_words}")
144
  print(f"Length of Input IDs: {inputs_len}")
145
 
146
- for i in range(1):
147
- print(f"\nGenerating limerick {i+1}")
148
- rhyming_words_125 = []
149
- while len(rhyming_words_125) < 3 or valid_rhyme == False or len(first_line) == 0:
150
- first_line = get_line(topic_summary, inputs_len)
151
- if first_line:
152
- end_word = remove_punctuation(first_line.split()[-1])
153
- valid_rhyme = filter_rhymes(end_word)
154
- if valid_rhyme:
155
- print(f"\nFirst Line: {first_line}")
156
- rhyming_words_125 = list(get_rhymes(end_word, 3))
157
- print(f"Rhyming words for '{end_word}' are {rhyming_words_125}")
158
- limerick = first_line + "\n"
159
-
160
- rhyming_word = rhyming_words_125[0]
161
- prompt = topic_summary + " " + first_line
162
- inputs_len = get_inputs_length(prompt)
163
- print(f"Prompt: {prompt}")
164
- print(f"Length of prompt: {inputs_len}")
165
- second_line = get_rhyming_line(prompt, rhyming_word, inputs_len)
166
- print(f"\nSecond Line: {second_line}")
167
- limerick += second_line + "\n"
168
-
169
- rhyming_words_34 = []
170
- prompt = prompt + " " + second_line
171
- inputs_len = get_inputs_length(prompt)
172
- print(f"Prompt: {prompt}")
173
- print(f"Length of prompt: {inputs_len}")
174
- while len(rhyming_words_34) < 2 or valid_rhyme == False or len(third_line) == 0:
175
- third_line = get_line(prompt, inputs_len)
176
- if third_line:
177
- print(f"\nThird Line: {third_line}")
178
- end_word = remove_punctuation(third_line.split()[-1])
179
- valid_rhyme = filter_rhymes(end_word)
180
- print(f"Does '{end_word}'' have valid rhymes: {valid_rhyme}")
181
- rhyming_words_34 = list(get_rhymes(end_word, 3))
182
- print(f"Rhyming words for '{end_word}' are {rhyming_words_34}")
183
- if valid_rhyme and len(rhyming_words_34) > 1:
184
- limerick += third_line + "\n"
185
-
186
- rhyming_word = rhyming_words_34[0]
187
- prompt = prompt + " " + third_line
188
- inputs_len = get_inputs_length(prompt)
189
- print(f"Prompt: {prompt}")
190
- print(f"Length of prompt: {inputs_len}")
191
- fourth_line = get_rhyming_line(prompt, rhyming_word, inputs_len)
192
- print(f"\nFourth Line: {fourth_line}")
193
- limerick += fourth_line + "\n"
194
-
195
- rhyming_word = rhyming_words_125[1]
196
- prompt = prompt + " " + fourth_line
197
- inputs_len = get_inputs_length(prompt)
198
- print(f"Prompt: {prompt}")
199
- print(f"Length of prompt: {inputs_len}")
200
- fifth_line = get_rhyming_line(prompt, rhyming_word, inputs_len)
201
- print(f"\nFifth Line: {fifth_line}")
202
- limerick += fifth_line + "\n"
203
-
204
- limericks.append(limerick)
205
 
206
  print("\n")
207
- output = f"Generated {len(limericks)} limericks: \n"
 
 
208
 
209
- print(f"Generated {len(limericks)} limericks: \n")
210
- for limerick in limericks:
211
- print(limerick)
212
- output += "\n" + limerick
 
 
 
 
213
 
214
  return output
215
-
216
  import gradio as gr
217
 
218
- interface = gr.Interface(fn=generate, inputs="text", outputs="text")
 
 
 
219
  interface.launch(debug=True)
 
1
+ from transformers import RobertaTokenizer, RobertaForMaskedLM, GPT2Tokenizer, AutoTokenizer, GPTJForCausalLM
2
  import torch
3
  import wikipedia
4
  import re
5
  import random
6
  import nltk
7
  import syllables
8
+ from aitextgen import aitextgen
9
  nltk.download('cmudict')
10
 
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+
13
  masked_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
14
  masked_model = RobertaForMaskedLM.from_pretrained('roberta-base')
15
 
16
  causal_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
17
+
18
+ # Without any parameters, aitextgen() will download, cache, and load the 124M GPT-2 "small" model
19
+ gpt2 = aitextgen()
20
+
21
# GPT-J 6B tokenizer and model, used by gptj_summary() to write a topic summary.
gptj_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
# The "float16" revision only stores half-precision weights; unless torch_dtype
# is passed as well, from_pretrained loads them into the default float32 and
# doubles the memory footprint. Half precision is requested only on CUDA; on
# CPU the dtype stays float32, exactly as before.
gptj_model = GPTJForCausalLM.from_pretrained(
    "EleutherAI/gpt-j-6B",
    revision="float16",
    torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
    low_cpu_mem_usage=True,
)
gptj_model.to(device)
24
 
25
  frequent_words = set()
26
 
 
56
 
57
def remove_punctuation(text):
    """Delete punctuation from *text* and flatten it onto a single line.

    Args:
        text: The string to clean.

    Returns:
        *text* with every character that is neither a regex word character
        (letters, digits, underscore) nor whitespace removed, and with each
        newline replaced by a single space.
    """
    text = re.sub(r'[^\w\s]', '', text)
    # Newlines count as whitespace, so the pattern above keeps them; turn
    # them into spaces so the result is a single line.
    text = text.replace("\n", " ")
    return text
61
 
62
  def get_rhymes(inp, level):
 
109
 
110
  return best_guess
111
 
 
 
 
 
 
 
 
112
def get_line(prompt, inputs_len):
    """Generate a short GPT-2 continuation of *prompt* to use as a limerick line.

    Args:
        prompt: Text to continue; a "." is appended before generation.
        inputs_len: Length budget of the prompt; generation is capped at
            ``inputs_len + 7`` (presumably a token count -- confirm against
            get_inputs_length).

    Returns:
        The continuation with whitespace runs collapsed to single spaces and
        leading/trailing whitespace removed, or "" when the model produced
        nothing but whitespace.
    """
    generated = gpt2.generate_one(prompt=prompt + ".", max_length=inputs_len + 7)
    # Drop the prompt, the appended "." and the one character after it
    # (assumes the generated text starts with the prompt -- TODO confirm).
    continuation = generated[len(prompt) + 2:]
    # A whitespace-only continuation is truthy, so callers doing
    # `line.split()[-1]` would hit an IndexError. Normalising whitespace
    # returns "" in that case and also removes embedded newlines.
    return " ".join(continuation.split())
115
 
116
  def get_rhyming_line(prompt, rhyming_word, inputs_len):
117
+ gpt2_sentence = gpt2.generate_one(prompt=prompt + ".", max_length=inputs_len + 4)[len(prompt)+2:]
118
+ gpt2_sentence = gpt2_sentence.replace("\n", "")
119
  print(f"\nGetting rhyming line starting with '{gpt2_sentence}' and ending with rhyming word '{rhyming_word}'")
120
  sentence = gpt2_sentence + " ___ ___ ___ " + rhyming_word
121
  print(f"Original Sentence: {sentence}")
 
131
  final_sentence = gpt2_sentence + predicted_blanks + " " + rhyming_word
132
  print(f"Final Sentence: {final_sentence}")
133
  return final_sentence
 
 
134
 
135
def gptj_summary(topic):
    """Sample a free-text summary of *topic* from GPT-J.

    Args:
        topic: Subject inserted into the prompt
            "Here is some information about {topic}".

    Returns:
        The decoded generated sequence as a string, with special tokens
        removed.
    """
    prompt = f"Here is some information about {topic}"
    input_ids = gptj_tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    generated_ids = gptj_model.generate(input_ids, do_sample=True, temperature=0.9, max_length=200)
    # Skip special tokens: otherwise an end-of-text marker can survive
    # decoding and, once punctuation is stripped downstream, end up as the
    # stray word "endoftext" inside the summary.
    generated_text = gptj_tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return generated_text
142
+
143
+ def generate(topic, wiki=True):
144
+ if wiki:
145
+ topic_summary = remove_punctuation(wikipedia.summary(topic))
146
+ else:
147
+ topic_summary = remove_punctuation(gptj_summary(topic))
148
  word_list = topic_summary.split()
149
  topic_summary_len = len(topic_summary)
150
  no_of_words = len(word_list)
 
154
  print(f"No of Words in Summary: {no_of_words}")
155
  print(f"Length of Input IDs: {inputs_len}")
156
 
157
+ rhyming_words_125 = []
158
+ while len(rhyming_words_125) < 3 or valid_rhyme == False or len(first_line) == 0:
159
+ first_line = get_line(topic_summary, inputs_len)
160
+ if first_line:
161
+ end_word = remove_punctuation(first_line.split()[-1])
162
+ valid_rhyme = filter_rhymes(end_word)
163
+ if valid_rhyme:
164
+ print(f"\nFirst Line: {first_line}")
165
+ rhyming_words_125 = list(get_rhymes(end_word, 3))
166
+ print(f"Rhyming words for '{end_word}' are {rhyming_words_125}")
167
+ limerick = first_line + "\n"
168
+
169
+ rhyming_word = rhyming_words_125[0]
170
+ prompt = topic_summary + " " + first_line
171
+ inputs_len = get_inputs_length(prompt)
172
+ print(f"Prompt: {prompt}")
173
+ print(f"Length of prompt: {inputs_len}")
174
+ second_line = get_rhyming_line(prompt, rhyming_word, inputs_len)
175
+ print(f"\nSecond Line: {second_line}")
176
+ limerick += second_line + "\n"
177
+
178
+ rhyming_words_34 = []
179
+ prompt = prompt + " " + second_line
180
+ inputs_len = get_inputs_length(prompt)
181
+ print(f"Prompt: {prompt}")
182
+ print(f"Length of prompt: {inputs_len}")
183
+ while len(rhyming_words_34) < 2 or valid_rhyme == False or len(third_line) == 0:
184
+ third_line = get_line(prompt, inputs_len)
185
+ if third_line:
186
+ print(f"\nThird Line: {third_line}")
187
+ end_word = remove_punctuation(third_line.split()[-1])
188
+ valid_rhyme = filter_rhymes(end_word)
189
+ print(f"Does '{end_word}' have valid rhymes: {valid_rhyme}")
190
+ rhyming_words_34 = list(get_rhymes(end_word, 3))
191
+ print(f"Rhyming words for '{end_word}' are {rhyming_words_34}")
192
+ if valid_rhyme and len(rhyming_words_34) > 1:
193
+ limerick += third_line + "\n"
194
+
195
+ rhyming_word = rhyming_words_34[0]
196
+ prompt = prompt + " " + third_line
197
+ inputs_len = get_inputs_length(prompt)
198
+ print(f"Prompt: {prompt}")
199
+ print(f"Length of prompt: {inputs_len}")
200
+ fourth_line = get_rhyming_line(prompt, rhyming_word, inputs_len)
201
+ print(f"\nFourth Line: {fourth_line}")
202
+ limerick += fourth_line + "\n"
203
+
204
+ rhyming_word = rhyming_words_125[1]
205
+ prompt = prompt + " " + fourth_line
206
+ inputs_len = get_inputs_length(prompt)
207
+ print(f"Prompt: {prompt}")
208
+ print(f"Length of prompt: {inputs_len}")
209
+ fifth_line = get_rhyming_line(prompt, rhyming_word, inputs_len)
210
+ print(f"\nFifth Line: {fifth_line}")
211
+ limerick += fifth_line + "\n"
 
 
 
 
212
 
213
  print("\n")
214
+ print(limerick)
215
+
216
+ return limerick
217
 
218
def compare_summaries(topic):
    """Build two limericks for *topic* and return them in one labelled string.

    The first limerick is generated with ``wiki=True`` and the second with
    ``wiki=False``; each is preceded by a heading naming its prompt source.

    Args:
        topic: Subject passed through to generate().

    Returns:
        A single string: Wikipedia heading, its limerick, a newline, the
        GPT-J heading, its limerick.
    """
    # Keep the call order: Wikipedia-backed limerick first, then GPT-J.
    from_wikipedia = generate(topic, wiki=True)
    from_gptj = generate(topic, wiki=False)

    sections = [
        "Limerick with Wikipedia summary of topic as prompt: \n",
        from_wikipedia,
        "\n",
        "Limerick with GPT-J summary of topic as prompt: \n",
        from_gptj,
    ]
    return "".join(sections)
228
+
229
  import gradio as gr
230
 
231
# Gradio front end: one text input (the topic) fed to compare_summaries,
# one text output showing both limericks.
interface = gr.Interface(fn=compare_summaries, inputs="text", outputs="text")
interface.launch(debug=True)