Spaces:
Running
Running
ansfarooq7
commited on
Commit
•
2acd461
1
Parent(s):
01025ef
Upload app.py
Browse files
app.py
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# NOTE(review): `!pip` is IPython/Jupyter shell magic — it is a SyntaxError in a
# plain `python app.py` run. On HF Spaces these dependencies belong in
# requirements.txt; kept here unchanged since the file is notebook-style.
!pip install transformers
!pip install pronouncing
!pip install wikipedia
!pip install syllables
!pip install gradio
|
6 |
+
from transformers import RobertaTokenizer, RobertaForMaskedLM, pipeline, GPT2TokenizerFast
|
7 |
+
import torch
|
8 |
+
import pronouncing
|
9 |
+
import wikipedia
|
10 |
+
import re
|
11 |
+
import random
|
12 |
+
import nltk
|
13 |
+
import syllables
|
14 |
+
import gradio as gr
|
15 |
+
nltk.download('cmudict')
|
16 |
+
|
# Vocabulary of common English words, populated at import time from
# wordFrequency.txt (see the `with open(...)` block below); used by
# get_rhymes() to keep only familiar rhyme candidates.
frequent_words = set()
18 |
+
|
19 |
+
def set_seed(seed: int):
    """
    Seed PyTorch's random number generators for reproducible behavior.

    Only the torch CPU and CUDA generators are seeded here — the original
    docstring mentioned ``random``/``numpy``/``tf``, but no code ever seeded
    them, so the docstring now matches the actual behavior.

    Args:
        seed (:obj:`int`): The seed to set.
    """
    torch.manual_seed(seed)
    # Safe to call even when CUDA is not available (it is a no-op then).
    torch.cuda.manual_seed_all(seed)
|
35 |
+
|
36 |
+
# Populate the module-level `frequent_words` set from the newline-delimited
# word-frequency file (one word per line, whitespace stripped).
with open("wordFrequency.txt", 'r') as f:
    for raw_line in f:
        frequent_words.add(raw_line.strip())
|
41 |
+
|
42 |
+
def filter_rhymes(word):
    """
    Return True when *word* is usable as a rhyme anchor.

    Returns False for common function words and for English words that are
    famously hard or impossible to rhyme (e.g. "orange", "month").

    Args:
        word (str): lower-cased word to check.

    Returns:
        bool: False if the word is on the block list, True otherwise.
    """
    # frozenset gives O(1) membership instead of scanning a 60+ element list,
    # and the direct boolean replaces the redundant if/else-return.
    filter_list = frozenset({
        'to', 'on', 'has', 'but', 'the', 'in', 'and', 'a', 'aitch', 'angst',
        'arugula', 'beige', 'blitzed', 'boing', 'bombed', 'cairn', 'chaos',
        'chocolate', 'circle', 'circus', 'cleansed', 'coif', 'cusp', 'doth',
        'else', 'eth', 'fiends', 'film', 'flange', 'fourths', 'grilse',
        'gulf', 'kiln', 'loge', 'midst', 'month', 'music', 'neutron',
        'ninja', 'oblige', 'oink', 'opus', 'orange', 'pint', 'plagued',
        'plankton', 'plinth', 'poem', 'poet', 'purple', 'quaich', 'rhythm',
        'rouged', 'silver', 'siren', 'soldier', 'sylph', 'thesp', 'toilet',
        'torsk', 'tufts', 'waltzed', 'wasp', 'wharves', 'width', 'woman',
        'yttrium',
    })
    return word not in filter_list
|
48 |
+
|
49 |
+
def remove_punctuation(text):
    """Return *text* with every character that is neither a word character nor whitespace removed."""
    return re.sub(r'[^\w\s]', '', text)
|
52 |
+
|
53 |
+
def get_rhymes(inp, level):
    """
    Find frequent words that rhyme with *inp*.

    A candidate rhymes when the last *level* phonemes of its CMUdict
    pronunciation match those of (any pronunciation of) *inp*. Only words
    present in the module-level ``frequent_words`` set are kept, and *inp*
    itself is excluded.

    Args:
        inp (str): lower-cased word to rhyme against.
        level (int): number of trailing phonemes that must match.

    Returns:
        set: rhyming words drawn from ``frequent_words``.
    """
    # NOTE: the local previously named `syllables` shadowed the imported
    # `syllables` module, and the inner comprehension reused `word` from the
    # outer loop — both renamed.
    entries = nltk.corpus.cmudict.entries()
    pronunciations = [(word, pron) for word, pron in entries if word == inp]
    candidates = []
    for _, pron in pronunciations:
        tail = pron[-level:]  # hoist the loop-invariant slice
        candidates += [cand for cand, cand_pron in entries
                       if cand_pron[-level:] == tail]
    return {word for word in candidates
            if word in frequent_words and word != inp}
|
65 |
+
|
66 |
+
def get_inputs_length(input):
    """
    Count GPT-2 tokens in *input*.

    The GPT-2 tokenizer is loaded once and cached on the function object —
    the original re-downloaded/instantiated it on every call, which is slow
    and may hit the network each time.

    Args:
        input (str): text to tokenize.

    Returns:
        int: number of GPT-2 token ids produced for the text.
    """
    gpt2_tokenizer = getattr(get_inputs_length, "_tokenizer", None)
    if gpt2_tokenizer is None:
        gpt2_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        get_inputs_length._tokenizer = gpt2_tokenizer
    input_ids = gpt2_tokenizer(input)['input_ids']
    return len(input_ids)
|
70 |
+
|
71 |
+
# Module-level models: RoBERTa (tokenizer + masked-LM head) fills the <mask>
# blanks in get_prediction(); the default text-generation pipeline (GPT-2)
# drafts line openings in get_line()/get_rhyming_line(). Seed torch so
# generation is reproducible across runs.
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForMaskedLM.from_pretrained('roberta-base')
text_generation = pipeline("text-generation")
set_seed(0)
|
75 |
+
|
76 |
+
def get_prediction(sent):
    """
    Fill every ``<mask>`` token in *sent* with RoBERTa's best guess.

    For each masked position the top-5 candidate tokens are fetched;
    punctuation-only tokens and the ``</s>`` token are discarded. The best
    surviving guess per mask is concatenated, space-separated and with a
    leading space (callers append the result directly after generated text).

    Args:
        sent (str): sentence containing one or more ``<mask>`` tokens.

    Returns:
        str: the best guesses joined by spaces, prefixed with a space.
    """
    token_ids = tokenizer.encode(sent, return_tensors='pt')
    masked_position = (token_ids.squeeze() == tokenizer.mask_token_id).nonzero()
    masked_pos = [mask.item() for mask in masked_position]

    with torch.no_grad():
        output = model(token_ids)

    last_hidden_state = output[0].squeeze()

    list_of_list = []
    for index, mask_index in enumerate(masked_pos):
        words = []
        mask_hidden_state = last_hidden_state[mask_index]
        idx = torch.topk(mask_hidden_state, k=5, dim=0)[1]
        for i in idx:
            word = tokenizer.decode(i.item()).strip()
            # Skip pure-punctuation tokens and the end-of-sequence marker.
            if (remove_punctuation(word) != "") and (word != '</s>'):
                words.append(word)
        list_of_list.append(words)
        print(f"Mask {index+1} Guesses: {words}")

    best_guess = ""
    for j in list_of_list:
        # Guard against a mask whose top-5 candidates were all filtered out —
        # the original `j[0]` raised IndexError in that case.
        if j:
            best_guess = best_guess + " " + j[0]

    return best_guess
|
105 |
+
|
106 |
+
def get_line(topic_summary, starting_words, inputs_len):
    """
    Generate one limerick line.

    Picks a random opener from *starting_words*, prompts the module-level
    text-generation pipeline with the topic summary plus that opener, and
    returns the opener concatenated with the generated continuation.
    """
    opener = random.choice(starting_words)
    continuation = text_generation(
        topic_summary + " " + opener,
        max_length=inputs_len + 6,
        do_sample=True,
        return_full_text=False,
    )[0]['generated_text']
    return opener + continuation
|
110 |
+
|
111 |
+
def get_rhyming_line(topic_summary, starting_words, rhyming_word, inputs_len):
    """
    Generate a limerick line that ends on *rhyming_word*.

    GPT-2 drafts the opening of the line, three ``___`` placeholders plus the
    rhyming word are appended, the placeholders are converted to RoBERTa
    ``<mask>`` tokens, and ``get_prediction`` fills them in.

    Args:
        topic_summary (str): topic text used to condition generation.
        starting_words (list): pool of candidate first words.
        rhyming_word (str): word the finished line must end with.
        inputs_len (int): GPT-2 token length of the topic summary.

    Returns:
        str: the assembled line ending in *rhyming_word*.
    """
    #gpt2_sentence = text_generation(topic_summary + " " + starting_words[i][j], max_length=no_of_words + 4, do_sample=False)[0]
    starting_word = random.choice(starting_words)
    print(f"\nGetting rhyming line with starting word '{starting_word}' and rhyming word '{rhyming_word}'")
    gpt2_sentence = text_generation(topic_summary + " " + starting_word, max_length=inputs_len + 2, do_sample=True, return_full_text=False)[0]
    #sentence = gpt2_sentence['generated_text'] + " ___ ___ ___ " + rhyming_words[i][j]
    # Three blanks bridge the generated fragment to the forced rhyme word.
    sentence = starting_word + gpt2_sentence['generated_text'] + " ___ ___ ___ " + rhyming_word
    print(f"Original Sentence: {sentence}")
    # Ensure a sentence-final period — presumably so RoBERTa sees a complete
    # sentence when filling the masks (TODO confirm).
    if sentence[-1] != ".":
        sentence = sentence.replace("___","<mask>") + "."
    else:
        sentence = sentence.replace("___","<mask>")
    print(f"Original Sentence replaced with mask: {sentence}")
    print("\n")

    predicted_blanks = get_prediction(sentence)
    print(f"\nBest guess for fill in the blanks: {predicted_blanks}")
    # Note: predicted_blanks carries a leading space, so no separator is
    # inserted before it here.
    return starting_word + gpt2_sentence['generated_text'] + predicted_blanks + " " + rhyming_word
|
129 |
+
|
130 |
+
from transformers import pipeline
|
131 |
+
|
132 |
+
def generate(topic):
    """
    Generate five limericks about *topic* and return them as one string.

    Fetches the Wikipedia summary for the topic, then for each limerick:
    generates a first line whose end word has enough rhymes (AABBA scheme —
    lines 1/2/5 share one rhyme, lines 3/4 another), and fills lines 2, 4
    and 5 via masked-LM completion so they end on the chosen rhyme words.

    Args:
        topic (str): user-supplied topic passed to wikipedia.summary().

    Returns:
        str: header line followed by the five generated limericks.
    """
    # Shadows the module-level pipeline with a fresh one for this request.
    text_generation = pipeline("text-generation")

    limericks = []

    #topic = input("Please enter a topic: ")
    topic_summary = remove_punctuation(wikipedia.summary(topic))
    # if len(topic_summary) > 2000:
    # topic_summary = topic_summary[:2000]
    word_list = topic_summary.split()
    topic_summary_len = len(topic_summary)
    no_of_words = len(word_list)
    inputs_len = get_inputs_length(topic_summary)
    print(f"Topic Summary: {topic_summary}")
    print(f"Topic Summary Length: {topic_summary_len}")
    print(f"No of Words in Summary: {no_of_words}")
    print(f"Length of Input IDs: {inputs_len}")

    # Pool of line-opening words sampled by get_line/get_rhyming_line.
    starting_words = ["That", "Had", "Not", "But", "With", "I", "Because", "There", "Who", "She", "He", "To", "Whose", "In", "And", "When", "Or", "So", "The", "Of", "Every", "Whom"]

    # starting_words = [["That", "Had", "Not", "But", "That"],
    # ["There", "Who", "She", "Tormenting", "Til"],
    # ["Relentless", "This", "First", "and", "then"],
    # ["There", "Who", "That", "To", "She"],
    # ["There", "Who", "Two", "Four", "Have"]]

    # rhyming_words = [["told", "bold", "woodchuck", "truck", "road"],
    # ["Nice", "grease", "house", "spouse", "peace"],
    # ["deadlines", "lines", "edits", "credits", "wine"],
    # ["Lynn", "thin", "essayed", "lemonade", "in"],
    # ["beard", "feared", "hen", "wren", "beard"]]

    for i in range(5):
        print(f"\nGenerating limerick {i+1}")
        # Retry first lines until the end word is rhymable AND has at least
        # 3 rhymes (needed for lines 2 and 5). NOTE: on the first iteration
        # `valid_rhyme` is unbound — only the short-circuiting `or` (the
        # empty-list test is True) prevents a NameError here.
        rhyming_words_125 = []
        while len(rhyming_words_125) < 3 or valid_rhyme == False:
            first_line = get_line(topic_summary, starting_words, inputs_len)
            #rhyming_words = pronouncing.rhymes(first_line.split()[-1])
            end_word = remove_punctuation(first_line.split()[-1])
            valid_rhyme = filter_rhymes(end_word)
            if valid_rhyme:
                print(f"\nFirst Line: {first_line}")
                rhyming_words_125 = list(get_rhymes(end_word, 3))
                print(f"Rhyming words for '{end_word}' are {rhyming_words_125}")
                limerick = first_line + "\n"

        rhyming_word = rhyming_words_125[0]
        second_line = get_rhyming_line(topic_summary, starting_words, rhyming_word, inputs_len)
        limerick += second_line + "\n"

        # Same retry scheme for line 3; its end word anchors lines 3 and 4,
        # so only 2 rhymes are required.
        rhyming_words_34 = []
        while len(rhyming_words_34) < 2 or valid_rhyme == False:
            third_line = get_line(topic_summary, starting_words, inputs_len)
            print(f"\nThird Line: {third_line}")
            #rhyming_words = pronouncing.rhymes(first_line.split()[-1])
            end_word = remove_punctuation(third_line.split()[-1])
            valid_rhyme = filter_rhymes(end_word)
            print(f"Does '{end_word}'' have valid rhymes: {valid_rhyme}")
            rhyming_words_34 = list(get_rhymes(end_word, 3))
            print(f"Rhyming words for '{end_word}' are {rhyming_words_34}")
            if valid_rhyme and len(rhyming_words_34) > 1:
                limerick += third_line + "\n"

        rhyming_word = rhyming_words_34[0]
        fourth_line = get_rhyming_line(topic_summary, starting_words, rhyming_word, inputs_len)
        limerick += fourth_line + "\n"

        rhyming_word = rhyming_words_125[1]
        fifth_line = get_rhyming_line(topic_summary, starting_words, rhyming_word, inputs_len)
        limerick += fifth_line + "\n"

        limericks.append(limerick)

    print("\n")
    output = f"Generated {len(limericks)} limericks: \n"

    print(f"Generated {len(limericks)} limericks: \n")
    for limerick in limericks:
        print(limerick)
        output += limerick

    return output
|
214 |
+
|
215 |
+
# Wire the generator into a simple text-in/text-out Gradio UI and launch it
# (debug=True blocks and streams errors to the console).
interface = gr.Interface(fn=generate, inputs="text", outputs="text")
interface.launch(debug=True)
|