ansfarooq7 commited on
Commit
f5fde7c
1 Parent(s): a494241

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -8
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from transformers import RobertaTokenizer, RobertaForMaskedLM, GPT2Tokenizer, GPT2LMHeadModel, pipeline
2
  import torch
3
  import wikipedia
@@ -7,22 +8,26 @@ import nltk
7
  import gradio as gr
8
  nltk.download('cmudict')
9
 
 
10
  roberta_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
11
  roberta_model = RobertaForMaskedLM.from_pretrained('roberta-base')
12
 
 
13
  gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
14
  gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=gpt2_tokenizer.eos_token_id)
15
 
 
16
  gpt2_pipeline = pipeline('text-generation', model=gpt2_model, tokenizer=gpt2_tokenizer)
17
 
 
18
  frequent_words = set()
19
-
20
  with open("wordFrequency.txt", 'r') as f:
21
  line = f.readline()
22
  while line != '': # The EOF char is an empty string
23
  frequent_words.add(line.strip())
24
  line = f.readline()
25
 
 
26
  def filter_rhymes(word):
27
  filter_list = ['an', 'to', 'on', 'has', 'but', 'the', 'in', 'and', 'a', 'are', 'or', 'its', 'it''s']
28
  if word in filter_list:
@@ -30,11 +35,16 @@ def filter_rhymes(word):
30
  else:
31
  return True
32
 
 
33
  def remove_punctuation(text):
34
  text = re.sub(r'[^\w\s]', '', text)
35
  text = text.replace("\n", " ")
36
  return text.strip()
37
 
 
 
 
 
38
  def get_rhymes(inp, level):
39
  entries = nltk.corpus.cmudict.entries()
40
  syllables = [(word, syl) for word, syl in entries if word == inp]
@@ -48,10 +58,18 @@ def get_rhymes(inp, level):
48
  filtered_rhymes.add(word)
49
  return filtered_rhymes
50
 
 
 
51
  def get_inputs_length(input):
52
  input_ids = gpt2_tokenizer(input)['input_ids']
53
  return len(input_ids)
54
-
 
 
 
 
 
 
55
  def get_prediction(sent):
56
  token_ids = roberta_tokenizer.encode(sent, return_tensors='pt')
57
  masked_position = (token_ids.squeeze() == roberta_tokenizer.mask_token_id).nonzero()
@@ -68,6 +86,8 @@ def get_prediction(sent):
68
  while not words:
69
  mask_hidden_state = last_hidden_state[mask_index]
70
  idx = torch.topk(mask_hidden_state, k=5, dim=0)[1]
 
 
71
  for i in idx:
72
  word = roberta_tokenizer.decode(i.item()).strip()
73
  if (remove_punctuation(word) != "") and (word != '</s>'):
@@ -80,7 +100,9 @@ def get_prediction(sent):
80
  best_guess = best_guess+" "+j[0]
81
 
82
  return best_guess
83
-
 
 
84
  def get_line(prompt, inputs_len):
85
  output = gpt2_pipeline(
86
  prompt + ".",
@@ -91,6 +113,9 @@ def get_line(prompt, inputs_len):
91
  )
92
  return remove_punctuation(output[0]['generated_text'])
93
 
 
 
 
94
  def get_rhyming_line(prompt, rhyming_word, inputs_len):
95
  output = gpt2_pipeline(
96
  prompt + ".",
@@ -126,6 +151,8 @@ def get_rhyming_line(prompt, rhyming_word, inputs_len):
126
  print(f"Final Sentence: {final_sentence}")
127
  return final_sentence
128
 
 
 
129
  def gpt2_summary(topic):
130
  output = gpt2_pipeline(
131
  f"Here is some information about {topic}.",
@@ -135,8 +162,11 @@ def gpt2_summary(topic):
135
  return_full_text=False
136
  )
137
  return remove_punctuation(output[0]['generated_text'])
138
-
 
139
  def generate(topic, wiki=True):
 
 
140
  if wiki:
141
  try:
142
  topic_search = wikipedia.search(topic, results=3)
@@ -148,9 +178,12 @@ def generate(topic, wiki=True):
148
  topic_summary = remove_punctuation(wikipedia.summary(page, auto_suggest=False))
149
  except:
150
  return(f"Method A struggled to find information about {topic}, please try a different topic!")
 
 
151
  else:
152
  topic_summary = remove_punctuation(gpt2_summary(topic))
153
 
 
154
  word_list = topic_summary.split()
155
  topic_summary_len = len(topic_summary)
156
  no_of_words = len(word_list)
@@ -160,6 +193,7 @@ def generate(topic, wiki=True):
160
  print(f"No of Words in Summary: {no_of_words}")
161
  print(f"Length of Input IDs: {inputs_len}")
162
 
 
163
  rhyming_words_125 = []
164
  while len(rhyming_words_125) < 3 or valid_rhyme == False or len(first_line) == 0:
165
  first_line = get_line(topic_summary, inputs_len)
@@ -172,6 +206,7 @@ def generate(topic, wiki=True):
172
  print(f"Rhyming words for '{end_word}' are {rhyming_words_125}")
173
  limerick = first_line + "\n"
174
 
 
175
  rhyming_word = random.choice(rhyming_words_125)
176
  rhyming_words_125.remove(rhyming_word)
177
  prompt = topic_summary + " " + first_line
@@ -182,6 +217,7 @@ def generate(topic, wiki=True):
182
  print(f"\nSecond Line: {second_line}")
183
  limerick += second_line + "\n"
184
 
 
185
  rhyming_words_34 = []
186
  prompt = prompt + " " + second_line
187
  inputs_len = get_inputs_length(prompt)
@@ -199,6 +235,7 @@ def generate(topic, wiki=True):
199
  if valid_rhyme and len(rhyming_words_34) > 1:
200
  limerick += third_line + "\n"
201
 
 
202
  rhyming_word = random.choice(rhyming_words_34)
203
  rhyming_words_34.remove(rhyming_word)
204
  prompt = prompt + " " + third_line
@@ -209,6 +246,7 @@ def generate(topic, wiki=True):
209
  print(f"\nFourth Line: {fourth_line}")
210
  limerick += fourth_line + "\n"
211
 
 
212
  rhyming_word = random.choice(rhyming_words_125)
213
  rhyming_words_125.remove(rhyming_word)
214
  prompt = prompt + " " + fourth_line
@@ -223,7 +261,8 @@ def generate(topic, wiki=True):
223
  print(limerick)
224
 
225
  return limerick
226
-
 
227
  def compare_summaries(topic):
228
  wiki_limerick = generate(topic)
229
  gpt2_limerick = generate(topic, wiki=False)
@@ -233,8 +272,11 @@ def compare_summaries(topic):
233
  print(output1 + "\n" + output2)
234
 
235
  return output1, output2
236
-
237
- description = "Generates limericks (five-line poems with a rhyme scheme of AABBA) using two different methods, please be patient as it can take up to a minute to generate both limericks. You may have to generate multiple times or try different topics in order to produce something of good quality."
 
 
 
238
  article = '<center><big><strong>Limerick Generation</strong></big></center>'\
239
  '<center><strong>By Ans Farooq</strong></center>'\
240
  '<center><small>Level 4 Individual Project</small></center>'\
@@ -265,4 +307,4 @@ interface = gr.Interface(
265
  theme="peach",
266
  description=description,
267
  article=article)
268
- interface.launch(debug=True)
 
1
+ # Import required packages
2
  from transformers import RobertaTokenizer, RobertaForMaskedLM, GPT2Tokenizer, GPT2LMHeadModel, pipeline
3
  import torch
4
  import wikipedia
 
8
  import gradio as gr
9
  nltk.download('cmudict')
10
 
11
+ # Use the RoBERTa model from HuggingFace for masked language modelling
12
  roberta_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
13
  roberta_model = RobertaForMaskedLM.from_pretrained('roberta-base')
14
 
15
+ # Use the GPT-2 from HuggingFace for causal language modelling
16
  gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
17
  gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=gpt2_tokenizer.eos_token_id)
18
 
19
+ # Initialise a text generation pipeline using HuggingFace Transformers and the pre-trained GPT-2 model
20
  gpt2_pipeline = pipeline('text-generation', model=gpt2_model, tokenizer=gpt2_tokenizer)
21
 
22
+ # Hold all the words in wordFrequency.txt in a Python set
23
  frequent_words = set()
 
24
  with open("wordFrequency.txt", 'r') as f:
25
  line = f.readline()
26
  while line != '': # The EOF char is an empty string
27
  frequent_words.add(line.strip())
28
  line = f.readline()
29
 
30
+ # Used alongside the word frequency list to filter out problematic words for rhyming
31
  def filter_rhymes(word):
32
  filter_list = ['an', 'to', 'on', 'has', 'but', 'the', 'in', 'and', 'a', 'are', 'or', 'its', 'it''s']
33
  if word in filter_list:
 
35
  else:
36
  return True
37
 
38
+ # Used to remove any punctuation and new line characters from generated text
39
  def remove_punctuation(text):
40
  text = re.sub(r'[^\w\s]', '', text)
41
  text = text.replace("\n", " ")
42
  return text.strip()
43
 
44
+ # Used to find rhymes to a given word using NLTK
45
+ # where inp is a word and level means how good the rhyme should be.
46
+ # Adapted from the following Stack Overflow answer:
47
+ # https://stackoverflow.com/a/25714769/18559178
48
  def get_rhymes(inp, level):
49
  entries = nltk.corpus.cmudict.entries()
50
  syllables = [(word, syl) for word, syl in entries if word == inp]
 
58
  filtered_rhymes.add(word)
59
  return filtered_rhymes
60
 
61
+ # Used to get the length of the topic summary, to then determine max length for
62
+ # the text generation pipeline
63
  def get_inputs_length(input):
64
  input_ids = gpt2_tokenizer(input)['input_ids']
65
  return len(input_ids)
66
+
67
+ # Sized Fill-in-the-blank or Multi Mask filling with RoBERTa and Huggingface Transformers
68
+ # Used to fill in the blank words between the starting words of each line
69
+ # generated by GPT-2 and the end rhyming word
70
+ # Code adapted from the following Medium article:
71
+ # https://ramsrigoutham.medium.com/sized-fill-in-the-blank-or-multi-mask-filling-with-roberta-and-huggingface-transformers-58eb9e7fb0c
72
+
73
  def get_prediction(sent):
74
  token_ids = roberta_tokenizer.encode(sent, return_tensors='pt')
75
  masked_position = (token_ids.squeeze() == roberta_tokenizer.mask_token_id).nonzero()
 
86
  while not words:
87
  mask_hidden_state = last_hidden_state[mask_index]
88
  idx = torch.topk(mask_hidden_state, k=5, dim=0)[1]
89
+
90
+ # Discard predicted word if it is blank or end token
91
  for i in idx:
92
  word = roberta_tokenizer.decode(i.item()).strip()
93
  if (remove_punctuation(word) != "") and (word != '</s>'):
 
100
  best_guess = best_guess+" "+j[0]
101
 
102
  return best_guess
103
+
104
+ # Used to generate the 1st and 3rd lines of the limerick
105
+ # these are full lines, without RoBERTa being used
106
  def get_line(prompt, inputs_len):
107
  output = gpt2_pipeline(
108
  prompt + ".",
 
113
  )
114
  return remove_punctuation(output[0]['generated_text'])
115
 
116
+ # Used to generate the 2nd, 4th and 5th lines
117
+ # GPT-2 is used to generate starting few words of the lines
118
+ # RoBERTa is then used to fill in the rest of the words until the end rhyme word
119
  def get_rhyming_line(prompt, rhyming_word, inputs_len):
120
  output = gpt2_pipeline(
121
  prompt + ".",
 
151
  print(f"Final Sentence: {final_sentence}")
152
  return final_sentence
153
 
154
+ # Used for the second method, Method B, of limerick generation
155
+ # Uses GPT-2 to get information about the user's given topic
156
  def gpt2_summary(topic):
157
  output = gpt2_pipeline(
158
  f"Here is some information about {topic}.",
 
162
  return_full_text=False
163
  )
164
  return remove_punctuation(output[0]['generated_text'])
165
+
166
+ # Main logic for limerick generation is contained here
167
  def generate(topic, wiki=True):
168
+
169
+ # Search for the topic on Wikipedia and get an informational summary
170
  if wiki:
171
  try:
172
  topic_search = wikipedia.search(topic, results=3)
 
178
  topic_summary = remove_punctuation(wikipedia.summary(page, auto_suggest=False))
179
  except:
180
  return(f"Method A struggled to find information about {topic}, please try a different topic!")
181
+
182
+ # Use GPT-2 to get info about the topic if the wiki parameter is false
183
  else:
184
  topic_summary = remove_punctuation(gpt2_summary(topic))
185
 
186
+ # Log info about the topic summary data
187
  word_list = topic_summary.split()
188
  topic_summary_len = len(topic_summary)
189
  no_of_words = len(word_list)
 
193
  print(f"No of Words in Summary: {no_of_words}")
194
  print(f"Length of Input IDs: {inputs_len}")
195
 
196
+ # Generate the first line of the limerick
197
  rhyming_words_125 = []
198
  while len(rhyming_words_125) < 3 or valid_rhyme == False or len(first_line) == 0:
199
  first_line = get_line(topic_summary, inputs_len)
 
206
  print(f"Rhyming words for '{end_word}' are {rhyming_words_125}")
207
  limerick = first_line + "\n"
208
 
209
+ # Generate the second line of the limerick
210
  rhyming_word = random.choice(rhyming_words_125)
211
  rhyming_words_125.remove(rhyming_word)
212
  prompt = topic_summary + " " + first_line
 
217
  print(f"\nSecond Line: {second_line}")
218
  limerick += second_line + "\n"
219
 
220
+ # Generate the third line of the limerick
221
  rhyming_words_34 = []
222
  prompt = prompt + " " + second_line
223
  inputs_len = get_inputs_length(prompt)
 
235
  if valid_rhyme and len(rhyming_words_34) > 1:
236
  limerick += third_line + "\n"
237
 
238
+ # Generate the fourth line of the limerick
239
  rhyming_word = random.choice(rhyming_words_34)
240
  rhyming_words_34.remove(rhyming_word)
241
  prompt = prompt + " " + third_line
 
246
  print(f"\nFourth Line: {fourth_line}")
247
  limerick += fourth_line + "\n"
248
 
249
+ # Generate the fifth line of the limerick
250
  rhyming_word = random.choice(rhyming_words_125)
251
  rhyming_words_125.remove(rhyming_word)
252
  prompt = prompt + " " + fourth_line
 
261
  print(limerick)
262
 
263
  return limerick
264
+
265
+ # Helper function to generate two limericks via both methods to then compare
266
  def compare_summaries(topic):
267
  wiki_limerick = generate(topic)
268
  gpt2_limerick = generate(topic, wiki=False)
 
272
  print(output1 + "\n" + output2)
273
 
274
  return output1, output2
275
+
276
+ # Use Gradio to create an interface, which can be hosted on HuggingFace spaces
277
+ # https://huggingface.co/spaces/ansfarooq7/l4-project
278
+
279
+ description = "Generates limericks (five-line poems with a rhyme scheme of AABBA) using two different methods, please be patient as it can take up to a minute to generate both limericks."
280
  article = '<center><big><strong>Limerick Generation</strong></big></center>'\
281
  '<center><strong>By Ans Farooq</strong></center>'\
282
  '<center><small>Level 4 Individual Project</small></center>'\
 
307
  theme="peach",
308
  description=description,
309
  article=article)
310
+ interface.launch(debug=False)