ansfarooq7 committed on
Commit
2acd461
1 Parent(s): 01025ef

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +216 -0
app.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Third-party dependencies: transformers, pronouncing, wikipedia, syllables, gradio
# (torch and nltk are imported below as well).
#
# Install them before running this script, e.g. by listing them in a
# requirements.txt. The notebook-style "!pip install <package>" lines that used
# to live here are IPython shell escapes; they are a SyntaxError when this file
# is executed as a plain Python module.
6
+ from transformers import RobertaTokenizer, RobertaForMaskedLM, pipeline, GPT2TokenizerFast
7
+ import torch
8
+ import pronouncing
9
+ import wikipedia
10
+ import re
11
+ import random
12
+ import nltk
13
+ import syllables
14
+ import gradio as gr
15
# Fetch the 'cmudict' corpus that get_rhymes() reads through nltk.corpus.cmudict.
nltk.download('cmudict')

# Words that are allowed to be used as rhymes; populated from
# wordFrequency.txt further down in this file.
frequent_words = set()
18
+
19
def set_seed(seed: int):
    """
    Helper function for reproducible behavior: seeds ``random`` and ``torch``
    (the CPU generator and every CUDA device).

    Args:
        seed (:obj:`int`): The seed to set.
    """
    # The line generators pick their starting word with random.choice, so the
    # stdlib generator has to be seeded too for runs to be repeatable.
    random.seed(seed)
    torch.manual_seed(seed)
    # Safe to call even if CUDA is not available.
    torch.cuda.manual_seed_all(seed)
35
+
36
# Fill frequent_words from wordFrequency.txt: one entry per line, with the
# surrounding whitespace (including the trailing newline) stripped.
with open("wordFrequency.txt", 'r') as f:
    for raw_line in f:
        frequent_words.add(raw_line.strip())
41
+
42
# Line-ending words that must not be used as a rhyme target. Presumably common
# function words plus words that are hard to rhyme -- TODO confirm the origin
# of this list. Built once at import time so every lookup is O(1).
_RHYME_FILTER = frozenset({
    'to', 'on', 'has', 'but', 'the', 'in', 'and', 'a', 'aitch', 'angst',
    'arugula', 'beige', 'blitzed', 'boing', 'bombed', 'cairn', 'chaos',
    'chocolate', 'circle', 'circus', 'cleansed', 'coif', 'cusp', 'doth',
    'else', 'eth', 'fiends', 'film', 'flange', 'fourths', 'grilse', 'gulf',
    'kiln', 'loge', 'midst', 'month', 'music', 'neutron', 'ninja', 'oblige',
    'oink', 'opus', 'orange', 'pint', 'plagued', 'plankton', 'plinth', 'poem',
    'poet', 'purple', 'quaich', 'rhythm', 'rouged', 'silver', 'siren',
    'soldier', 'sylph', 'thesp', 'toilet', 'torsk', 'tufts', 'waltzed', 'wasp',
    'wharves', 'width', 'woman', 'yttrium',
})


def filter_rhymes(word):
    """Return True when *word* may be used as a rhyme target.

    Args:
        word: The candidate line-ending word. The comparison is exact and
            case-sensitive.

    Returns:
        False if *word* is on the exclusion list, True otherwise.
    """
    return word not in _RHYME_FILTER
48
+
49
# Matches every character that is neither a word character (letters, digits,
# underscore) nor whitespace. Compiled once instead of on every call.
_PUNCTUATION_RE = re.compile(r'[^\w\s]')


def remove_punctuation(text):
    """Return *text* with all punctuation characters removed.

    Word characters (including ``_``) and whitespace are kept unchanged.

    Args:
        text: The string to clean.

    Returns:
        The cleaned string.
    """
    return _PUNCTUATION_RE.sub('', text)
52
+
53
def get_rhymes(inp, level):
    """Return the frequent words whose pronunciation ends like that of *inp*.

    Two words are treated as rhyming when the last *level* symbols of their
    cmudict pronunciations are equal. A word with several pronunciations
    contributes the ending of each of them.

    Args:
        inp: The word to rhyme with; compared to the cmudict entries exactly.
        level: How many trailing pronunciation symbols have to match.

    Returns:
        A set of words that are in ``frequent_words``, differ from *inp* and
        share an ending with it. Empty when *inp* is not in the dictionary.
    """
    entries = nltk.corpus.cmudict.entries()

    # Endings of every pronunciation of inp, stored as tuples so they are
    # hashable. (The previous local name for this shadowed the imported
    # `syllables` module.)
    target_endings = {tuple(pron[-level:]) for word, pron in entries if word == inp}
    if not target_endings:
        return set()

    # One pass over the dictionary, instead of one full scan per pronunciation.
    return {
        word
        for word, pron in entries
        if word != inp
        and word in frequent_words
        and tuple(pron[-level:]) in target_endings
    }
65
+
66
# Lazily created GPT-2 tokenizer, shared by all get_inputs_length() calls.
_gpt2_tokenizer = None


def get_inputs_length(input):
    """Return the number of GPT-2 token ids that *input* is encoded to.

    The tokenizer is loaded on first use and then reused; previously it was
    re-created from the pretrained files on every call.

    Args:
        input: The text to tokenize. (The name shadows the builtin but is kept
            so that existing keyword callers keep working.)

    Returns:
        The length of the ``input_ids`` list produced by the tokenizer.
    """
    global _gpt2_tokenizer
    if _gpt2_tokenizer is None:
        _gpt2_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    return len(_gpt2_tokenizer(input)['input_ids'])
70
+
71
# RoBERTa tokenizer/model pair used by get_prediction() to fill <mask> tokens.
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForMaskedLM.from_pretrained('roberta-base')

# Text generator used by get_line() and get_rhyming_line(). The checkpoint is
# named explicitly rather than relying on the pipeline's implicit default, so
# it stays in step with the "gpt2" tokenizer that get_inputs_length() uses to
# size max_length.
text_generation = pipeline("text-generation", model="gpt2")

# Fixed seed so that sampling is repeatable from one start of the app to the next.
set_seed(0)
75
+
76
def get_prediction(sent):
    """Predict a word for every mask token in *sent* with the RoBERTa model.

    For each mask position the five highest-scoring vocabulary entries are
    decoded; candidates that are empty once punctuation is removed, or that
    are the ``</s>`` token, are discarded. The first remaining candidate is
    the guess for that mask.

    Args:
        sent: A sentence containing zero or more ``<mask>`` tokens.

    Returns:
        The guesses in mask order, each preceded by a single space (so the
        result starts with a space whenever at least one mask was filled).
        A mask whose five candidates were all discarded contributes nothing;
        previously that case raised IndexError.
    """
    token_ids = tokenizer.encode(sent, return_tensors='pt')
    mask_positions = (token_ids.squeeze() == tokenizer.mask_token_id).nonzero()
    mask_indices = [position.item() for position in mask_positions]

    with torch.no_grad():
        output = model(token_ids)

    # For a masked-LM head the first output holds the per-position vocabulary
    # scores that topk ranks below (the old name `last_hidden_state` was
    # misleading).
    scores = output[0].squeeze()

    best_guesses = []
    for mask_number, mask_index in enumerate(mask_indices, start=1):
        top_token_ids = torch.topk(scores[mask_index], k=5, dim=0)[1]
        words = []
        for token_id in top_token_ids:
            word = tokenizer.decode(token_id.item()).strip()
            # remove_punctuation('</s>') leaves 's', hence the explicit check.
            if remove_punctuation(word) != "" and word != '</s>':
                words.append(word)
        print(f"Mask {mask_number} Guesses: {words}")
        if words:
            best_guesses.append(words[0])

    return "".join(" " + guess for guess in best_guesses)
105
+
106
def get_line(topic_summary, starting_words, inputs_len):
    """Generate one free-form limerick line.

    A starting word is picked at random, appended to the topic summary as the
    prompt, and the text generator samples a continuation.

    Args:
        topic_summary: Text used as the generation context.
        starting_words: Candidate first words for the line.
        inputs_len: Token count of the summary; ``max_length`` is this plus 6.

    Returns:
        The chosen starting word followed by the generated continuation.
    """
    opener = random.choice(starting_words)
    prompt = topic_summary + " " + opener
    generations = text_generation(
        prompt,
        max_length=inputs_len + 6,
        do_sample=True,
        return_full_text=False,
    )
    return opener + generations[0]['generated_text']
110
+
111
def get_rhyming_line(topic_summary, starting_words, rhyming_word, inputs_len):
    """Generate a limerick line that ends with *rhyming_word*.

    The line is assembled from a random starting word, a short sampled
    continuation, three words predicted by get_prediction() for three mask
    slots, and finally the rhyming word.

    Args:
        topic_summary: Text used as the generation context.
        starting_words: Candidate first words for the line.
        rhyming_word: The word the line has to end with.
        inputs_len: Token count of the summary; ``max_length`` is this plus 2.

    Returns:
        The assembled line (without a trailing period).
    """
    opener = random.choice(starting_words)
    print(f"\nGetting rhyming line with starting word '{opener}' and rhyming word '{rhyming_word}'")

    generation = text_generation(
        topic_summary + " " + opener,
        max_length=inputs_len + 2,
        do_sample=True,
        return_full_text=False,
    )[0]
    line_start = opener + generation['generated_text']

    # Three blanks between the generated start and the rhyming word.
    sentence = line_start + " ___ ___ ___ " + rhyming_word
    print(f"Original Sentence: {sentence}")

    # Turn the blanks into mask tokens and make sure the sentence handed to
    # the masked-LM ends with a period.
    needs_period = sentence[-1] != "."
    sentence = sentence.replace("___", "<mask>")
    if needs_period:
        sentence = sentence + "."
    print(f"Original Sentence replaced with mask: {sentence}")
    print("\n")

    predicted_blanks = get_prediction(sentence)
    print(f"\nBest guess for fill in the blanks: {predicted_blanks}")
    return line_start + predicted_blanks + " " + rhyming_word
129
+
130
+ from transformers import pipeline
131
+
132
def generate(topic):
    """Generate five limericks about *topic* and return them as one string.

    The Wikipedia summary of the topic (punctuation removed) is the context
    for text generation. Each limerick is built as follows:

    * lines 1 and 3 are free-form lines from get_line(); a line is re-drawn,
      without an attempt limit, until its last word passes filter_rhymes()
      and has enough rhymes (at least 3 for line 1, at least 2 for line 3);
    * lines 2 and 5 end with rhymes of line 1, line 4 with a rhyme of line 3,
      all produced by get_rhyming_line().

    Args:
        topic: The subject, passed to ``wikipedia.summary``.

    Returns:
        A header line followed by the limericks; every limerick line ends
        with a newline.
    """
    # NOTE: the module-level `text_generation` pipeline is what get_line() and
    # get_rhyming_line() use. A second pipeline used to be built here on every
    # request; it was never used and only cost a model load.

    topic_summary = remove_punctuation(wikipedia.summary(topic))
    word_list = topic_summary.split()
    topic_summary_len = len(topic_summary)
    no_of_words = len(word_list)
    inputs_len = get_inputs_length(topic_summary)
    print(f"Topic Summary: {topic_summary}")
    print(f"Topic Summary Length: {topic_summary_len}")
    print(f"No of Words in Summary: {no_of_words}")
    print(f"Length of Input IDs: {inputs_len}")

    starting_words = ["That", "Had", "Not", "But", "With", "I", "Because", "There", "Who", "She", "He", "To", "Whose", "In", "And", "When", "Or", "So", "The", "Of", "Every", "Whom"]

    limericks = []
    for i in range(5):
        print(f"\nGenerating limerick {i+1}")

        # Line 1: re-draw until the end word is allowed and has >= 3 rhymes
        # (two of them are used, for lines 2 and 5).
        rhyming_words_125 = []
        valid_rhyme = False  # explicit start value; no reliance on short-circuiting
        while len(rhyming_words_125) < 3 or not valid_rhyme:
            first_line = get_line(topic_summary, starting_words, inputs_len)
            end_word = remove_punctuation(first_line.split()[-1])
            valid_rhyme = filter_rhymes(end_word)
            if valid_rhyme:
                print(f"\nFirst Line: {first_line}")
                rhyming_words_125 = list(get_rhymes(end_word, 3))
                print(f"Rhyming words for '{end_word}' are {rhyming_words_125}")

        second_line = get_rhyming_line(topic_summary, starting_words, rhyming_words_125[0], inputs_len)

        # Line 3: re-draw until the end word is allowed and has >= 2 rhymes.
        rhyming_words_34 = []
        valid_rhyme = False
        while len(rhyming_words_34) < 2 or not valid_rhyme:
            third_line = get_line(topic_summary, starting_words, inputs_len)
            print(f"\nThird Line: {third_line}")
            end_word = remove_punctuation(third_line.split()[-1])
            valid_rhyme = filter_rhymes(end_word)
            print(f"Does '{end_word}'' have valid rhymes: {valid_rhyme}")
            rhyming_words_34 = list(get_rhymes(end_word, 3))
            print(f"Rhyming words for '{end_word}' are {rhyming_words_34}")

        fourth_line = get_rhyming_line(topic_summary, starting_words, rhyming_words_34[0], inputs_len)
        fifth_line = get_rhyming_line(topic_summary, starting_words, rhyming_words_125[1], inputs_len)

        limerick_lines = [first_line, second_line, third_line, fourth_line, fifth_line]
        limericks.append("".join(line + "\n" for line in limerick_lines))

    print("\n")
    header = f"Generated {len(limericks)} limericks: \n"

    print(header)
    for limerick in limericks:
        print(limerick)

    return header + "".join(limericks)
214
+
215
# Text-in / text-out Gradio UI around generate().
interface = gr.Interface(fn=generate, inputs="text", outputs="text")

# Only start the (blocking) server when this file is run as a script, so that
# importing the module does not launch it as a side effect.
if __name__ == "__main__":
    interface.launch(debug=True)