ansfarooq7 committed on
Commit
1c3c84c
1 Parent(s): 84d0230

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -78
app.py CHANGED
@@ -1,13 +1,19 @@
1
- from transformers import RobertaTokenizer, RobertaForMaskedLM, pipeline, GPT2TokenizerFast
2
  import torch
3
  import wikipedia
4
  import re
5
  import random
6
  import nltk
7
  import syllables
8
- import gradio as gr
9
  nltk.download('cmudict')
10
 
 
 
 
 
 
 
 
11
  frequent_words = set()
12
 
13
  def set_seed(seed: int):
@@ -58,36 +64,33 @@ def get_rhymes(inp, level):
58
  return filtered_rhymes
59
 
60
  def get_inputs_length(input):
61
- gpt2_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
62
- input_ids = gpt2_tokenizer(input)['input_ids']
63
  return len(input_ids)
64
-
65
- tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
66
- model = RobertaForMaskedLM.from_pretrained('roberta-base')
67
- text_generation = pipeline("text-generation")
68
  set_seed(0)
69
 
70
  def get_prediction(sent):
71
 
72
- token_ids = tokenizer.encode(sent, return_tensors='pt')
73
- masked_position = (token_ids.squeeze() == tokenizer.mask_token_id).nonzero()
74
  masked_pos = [mask.item() for mask in masked_position ]
75
 
76
  with torch.no_grad():
77
- output = model(token_ids)
78
 
79
  last_hidden_state = output[0].squeeze()
80
 
81
  list_of_list =[]
82
  for index,mask_index in enumerate(masked_pos):
83
  words = []
84
- mask_hidden_state = last_hidden_state[mask_index]
85
- idx = torch.topk(mask_hidden_state, k=5, dim=0)[1]
86
- for i in idx:
87
- word = tokenizer.decode(i.item()).strip()
88
- if (remove_punctuation(word) != "") and (word != '</s>'):
89
- words.append(word)
90
- #words = [tokenizer.decode(i.item()).strip() for i in idx]
 
91
  list_of_list.append(words)
92
  print(f"Mask {index+1} Guesses: {words}")
93
 
@@ -97,18 +100,21 @@ def get_prediction(sent):
97
 
98
  return best_guess
99
 
100
- def get_line(topic_summary, starting_words, inputs_len):
101
- starting_word = random.choice(starting_words)
102
- line = starting_word + text_generation(topic_summary + " " + starting_word, max_length=inputs_len + 6, do_sample=True, return_full_text=False)[0]['generated_text']
 
 
 
 
 
 
103
  return line
104
 
105
- def get_rhyming_line(topic_summary, starting_words, rhyming_word, inputs_len):
106
- #gpt2_sentence = text_generation(topic_summary + " " + starting_words[i][j], max_length=no_of_words + 4, do_sample=False)[0]
107
- starting_word = random.choice(starting_words)
108
- print(f"\nGetting rhyming line with starting word '{starting_word}' and rhyming word '{rhyming_word}'")
109
- gpt2_sentence = text_generation(topic_summary + " " + starting_word, max_length=inputs_len + 2, do_sample=True, return_full_text=False)[0]
110
- #sentence = gpt2_sentence['generated_text'] + " ___ ___ ___ " + rhyming_words[i][j]
111
- sentence = starting_word + gpt2_sentence['generated_text'] + " ___ ___ ___ " + rhyming_word
112
  print(f"Original Sentence: {sentence}")
113
  if sentence[-1] != ".":
114
  sentence = sentence.replace("___","<mask>") + "."
@@ -119,20 +125,15 @@ def get_rhyming_line(topic_summary, starting_words, rhyming_word, inputs_len):
119
 
120
  predicted_blanks = get_prediction(sentence)
121
  print(f"\nBest guess for fill in the blanks: {predicted_blanks}")
122
- return starting_word + gpt2_sentence['generated_text'] + predicted_blanks + " " + rhyming_word
 
 
123
 
124
- from transformers import pipeline
125
-
126
  def generate(topic):
127
- text_generation = pipeline("text-generation")
128
 
129
  limericks = []
130
 
131
- #topic = input("Please enter a topic: ")
132
- ## topic_summary = remove_punctuation(wikipedia.summary(topic))
133
- topic_summary = topic
134
- # if len(topic_summary) > 2000:
135
- # topic_summary = topic_summary[:2000]
136
  word_list = topic_summary.split()
137
  topic_summary_len = len(topic_summary)
138
  no_of_words = len(word_list)
@@ -140,59 +141,64 @@ def generate(topic):
140
  print(f"Topic Summary: {topic_summary}")
141
  print(f"Topic Summary Length: {topic_summary_len}")
142
  print(f"No of Words in Summary: {no_of_words}")
143
- print(f"Length of Input IDs: {inputs_len}")
144
-
145
- starting_words = ["That", "Had", "Not", "But", "With", "I", "Because", "There", "Who", "She", "He", "To", "Whose", "In", "And", "When", "Or", "So", "The", "Of", "Every", "Whom"]
146
-
147
- # starting_words = [["That", "Had", "Not", "But", "That"],
148
- # ["There", "Who", "She", "Tormenting", "Til"],
149
- # ["Relentless", "This", "First", "and", "then"],
150
- # ["There", "Who", "That", "To", "She"],
151
- # ["There", "Who", "Two", "Four", "Have"]]
152
 
153
- # rhyming_words = [["told", "bold", "woodchuck", "truck", "road"],
154
- # ["Nice", "grease", "house", "spouse", "peace"],
155
- # ["deadlines", "lines", "edits", "credits", "wine"],
156
- # ["Lynn", "thin", "essayed", "lemonade", "in"],
157
- # ["beard", "feared", "hen", "wren", "beard"]]
158
-
159
- for i in range(5):
160
  print(f"\nGenerating limerick {i+1}")
161
  rhyming_words_125 = []
162
- while len(rhyming_words_125) < 3 or valid_rhyme == False:
163
- first_line = get_line(topic_summary, starting_words, inputs_len)
164
- #rhyming_words = pronouncing.rhymes(first_line.split()[-1])
165
- end_word = remove_punctuation(first_line.split()[-1])
166
- valid_rhyme = filter_rhymes(end_word)
167
- if valid_rhyme:
168
- print(f"\nFirst Line: {first_line}")
169
- rhyming_words_125 = list(get_rhymes(end_word, 3))
170
- print(f"Rhyming words for '{end_word}' are {rhyming_words_125}")
171
- limerick = first_line + "\n"
172
 
173
  rhyming_word = rhyming_words_125[0]
174
- second_line = get_rhyming_line(topic_summary, starting_words, rhyming_word, inputs_len)
 
 
 
 
 
175
  limerick += second_line + "\n"
176
 
177
  rhyming_words_34 = []
178
- while len(rhyming_words_34) < 2 or valid_rhyme == False:
179
- third_line = get_line(topic_summary, starting_words, inputs_len)
180
- print(f"\nThird Line: {third_line}")
181
- #rhyming_words = pronouncing.rhymes(first_line.split()[-1])
182
- end_word = remove_punctuation(third_line.split()[-1])
183
- valid_rhyme = filter_rhymes(end_word)
184
- print(f"Does '{end_word}'' have valid rhymes: {valid_rhyme}")
185
- rhyming_words_34 = list(get_rhymes(end_word, 3))
186
- print(f"Rhyming words for '{end_word}' are {rhyming_words_34}")
187
- if valid_rhyme and len(rhyming_words_34) > 1:
188
- limerick += third_line + "\n"
 
 
 
 
189
 
190
  rhyming_word = rhyming_words_34[0]
191
- fourth_line = get_rhyming_line(topic_summary, starting_words, rhyming_word, inputs_len)
 
 
 
 
 
192
  limerick += fourth_line + "\n"
193
 
194
  rhyming_word = rhyming_words_125[1]
195
- fifth_line = get_rhyming_line(topic_summary, starting_words, rhyming_word, inputs_len)
 
 
 
 
 
196
  limerick += fifth_line + "\n"
197
 
198
  limericks.append(limerick)
@@ -203,9 +209,11 @@ def generate(topic):
203
  print(f"Generated {len(limericks)} limericks: \n")
204
  for limerick in limericks:
205
  print(limerick)
206
- output += limerick
207
 
208
  return output
 
 
209
 
210
  interface = gr.Interface(fn=generate, inputs="text", outputs="text")
211
  interface.launch(debug=True)
 
1
+ from transformers import RobertaTokenizer, RobertaForMaskedLM, pipeline, GPT2Tokenizer, GPT2LMHeadModel
2
  import torch
3
  import wikipedia
4
  import re
5
  import random
6
  import nltk
7
  import syllables
 
8
  nltk.download('cmudict')
9
 
10
+ masked_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
11
+ masked_model = RobertaForMaskedLM.from_pretrained('roberta-base')
12
+
13
+ causal_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
14
+ # add the EOS token as PAD token to avoid warnings
15
+ causal_model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=causal_tokenizer.eos_token_id)
16
+
17
  frequent_words = set()
18
 
19
  def set_seed(seed: int):
 
64
  return filtered_rhymes
65
 
66
  def get_inputs_length(input):
67
+ input_ids = causal_tokenizer(input)['input_ids']
 
68
  return len(input_ids)
69
+
 
 
 
70
  set_seed(0)
71
 
72
  def get_prediction(sent):
73
 
74
+ token_ids = masked_tokenizer.encode(sent, return_tensors='pt')
75
+ masked_position = (token_ids.squeeze() == masked_tokenizer.mask_token_id).nonzero()
76
  masked_pos = [mask.item() for mask in masked_position ]
77
 
78
  with torch.no_grad():
79
+ output = masked_model(token_ids)
80
 
81
  last_hidden_state = output[0].squeeze()
82
 
83
  list_of_list =[]
84
  for index,mask_index in enumerate(masked_pos):
85
  words = []
86
+ while not words:
87
+ mask_hidden_state = last_hidden_state[mask_index]
88
+ idx = torch.topk(mask_hidden_state, k=5, dim=0)[1]
89
+ for i in idx:
90
+ word = masked_tokenizer.decode(i.item()).strip()
91
+ if (remove_punctuation(word) != "") and (word != '</s>'):
92
+ words.append(word)
93
+ #words = [masked_tokenizer.decode(i.item()).strip() for i in idx]
94
  list_of_list.append(words)
95
  print(f"Mask {index+1} Guesses: {words}")
96
 
 
100
 
101
  return best_guess
102
 
103
+ text_generation = pipeline("text-generation", model=causal_model, tokenizer=causal_tokenizer)
104
+ from aitextgen import aitextgen
105
+
106
+ # Without any parameters, aitextgen() will download, cache, and load the 124M GPT-2 "small" model
107
+ ai = aitextgen()
108
+
109
+
110
+ def get_line(prompt, inputs_len):
111
+ line = ai.generate_one(prompt=prompt + ".", max_length=inputs_len + 7)[len(prompt)+2:]
112
  return line
113
 
114
+ def get_rhyming_line(prompt, rhyming_word, inputs_len):
115
+ gpt2_sentence = ai.generate_one(prompt=prompt + ".", max_length=inputs_len + 4)[len(prompt)+2:]
116
+ print(f"\nGetting rhyming line starting with '{gpt2_sentence}' and ending with rhyming word '{rhyming_word}'")
117
+ sentence = gpt2_sentence + " ___ ___ ___ " + rhyming_word
 
 
 
118
  print(f"Original Sentence: {sentence}")
119
  if sentence[-1] != ".":
120
  sentence = sentence.replace("___","<mask>") + "."
 
125
 
126
  predicted_blanks = get_prediction(sentence)
127
  print(f"\nBest guess for fill in the blanks: {predicted_blanks}")
128
+ final_sentence = gpt2_sentence + predicted_blanks + " " + rhyming_word
129
+ print(f"Final Sentence: {final_sentence}")
130
+ return final_sentence
131
 
 
 
132
  def generate(topic):
 
133
 
134
  limericks = []
135
 
136
+ topic_summary = remove_punctuation(wikipedia.summary(topic))
 
 
 
 
137
  word_list = topic_summary.split()
138
  topic_summary_len = len(topic_summary)
139
  no_of_words = len(word_list)
 
141
  print(f"Topic Summary: {topic_summary}")
142
  print(f"Topic Summary Length: {topic_summary_len}")
143
  print(f"No of Words in Summary: {no_of_words}")
144
+ print(f"Length of Input IDs: {inputs_len}")
 
 
 
 
 
 
 
 
145
 
146
+ for i in range(1):
 
 
 
 
 
 
147
  print(f"\nGenerating limerick {i+1}")
148
  rhyming_words_125 = []
149
+ while len(rhyming_words_125) < 3 or valid_rhyme == False or len(first_line) == 0:
150
+ first_line = get_line(topic_summary, inputs_len)
151
+ if first_line:
152
+ end_word = remove_punctuation(first_line.split()[-1])
153
+ valid_rhyme = filter_rhymes(end_word)
154
+ if valid_rhyme:
155
+ print(f"\nFirst Line: {first_line}")
156
+ rhyming_words_125 = list(get_rhymes(end_word, 3))
157
+ print(f"Rhyming words for '{end_word}' are {rhyming_words_125}")
158
+ limerick = first_line + "\n"
159
 
160
  rhyming_word = rhyming_words_125[0]
161
+ prompt = topic_summary + " " + first_line
162
+ inputs_len = get_inputs_length(prompt)
163
+ print(f"Prompt: {prompt}")
164
+ print(f"Length of prompt: {inputs_len}")
165
+ second_line = get_rhyming_line(prompt, rhyming_word, inputs_len)
166
+ print(f"\nSecond Line: {second_line}")
167
  limerick += second_line + "\n"
168
 
169
  rhyming_words_34 = []
170
+ prompt = prompt + " " + second_line
171
+ inputs_len = get_inputs_length(prompt)
172
+ print(f"Prompt: {prompt}")
173
+ print(f"Length of prompt: {inputs_len}")
174
+ while len(rhyming_words_34) < 2 or valid_rhyme == False or len(third_line) == 0:
175
+ third_line = get_line(prompt, inputs_len)
176
+ if third_line:
177
+ print(f"\nThird Line: {third_line}")
178
+ end_word = remove_punctuation(third_line.split()[-1])
179
+ valid_rhyme = filter_rhymes(end_word)
180
+ print(f"Does '{end_word}'' have valid rhymes: {valid_rhyme}")
181
+ rhyming_words_34 = list(get_rhymes(end_word, 3))
182
+ print(f"Rhyming words for '{end_word}' are {rhyming_words_34}")
183
+ if valid_rhyme and len(rhyming_words_34) > 1:
184
+ limerick += third_line + "\n"
185
 
186
  rhyming_word = rhyming_words_34[0]
187
+ prompt = prompt + " " + third_line
188
+ inputs_len = get_inputs_length(prompt)
189
+ print(f"Prompt: {prompt}")
190
+ print(f"Length of prompt: {inputs_len}")
191
+ fourth_line = get_rhyming_line(prompt, rhyming_word, inputs_len)
192
+ print(f"\nFourth Line: {fourth_line}")
193
  limerick += fourth_line + "\n"
194
 
195
  rhyming_word = rhyming_words_125[1]
196
+ prompt = prompt + " " + fourth_line
197
+ inputs_len = get_inputs_length(prompt)
198
+ print(f"Prompt: {prompt}")
199
+ print(f"Length of prompt: {inputs_len}")
200
+ fifth_line = get_rhyming_line(prompt, rhyming_word, inputs_len)
201
+ print(f"\nFifth Line: {fifth_line}")
202
  limerick += fifth_line + "\n"
203
 
204
  limericks.append(limerick)
 
209
  print(f"Generated {len(limericks)} limericks: \n")
210
  for limerick in limericks:
211
  print(limerick)
212
+ output += "\n" + limerick
213
 
214
  return output
215
+
216
+ import gradio as gr
217
 
218
  interface = gr.Interface(fn=generate, inputs="text", outputs="text")
219
  interface.launch(debug=True)