Spaces:
Sleeping
Sleeping
ansfarooq7
committed on
Commit
•
1c3c84c
1
Parent(s):
84d0230
Update app.py
Browse files
app.py
CHANGED
@@ -1,13 +1,19 @@
|
|
1 |
-
from transformers import RobertaTokenizer, RobertaForMaskedLM, pipeline,
|
2 |
import torch
|
3 |
import wikipedia
|
4 |
import re
|
5 |
import random
|
6 |
import nltk
|
7 |
import syllables
|
8 |
-
import gradio as gr
|
9 |
nltk.download('cmudict')
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
frequent_words = set()
|
12 |
|
13 |
def set_seed(seed: int):
|
@@ -58,36 +64,33 @@ def get_rhymes(inp, level):
|
|
58 |
return filtered_rhymes
|
59 |
|
60 |
def get_inputs_length(input):
|
61 |
-
|
62 |
-
input_ids = gpt2_tokenizer(input)['input_ids']
|
63 |
return len(input_ids)
|
64 |
-
|
65 |
-
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
|
66 |
-
model = RobertaForMaskedLM.from_pretrained('roberta-base')
|
67 |
-
text_generation = pipeline("text-generation")
|
68 |
set_seed(0)
|
69 |
|
70 |
def get_prediction(sent):
|
71 |
|
72 |
-
token_ids =
|
73 |
-
masked_position = (token_ids.squeeze() ==
|
74 |
masked_pos = [mask.item() for mask in masked_position ]
|
75 |
|
76 |
with torch.no_grad():
|
77 |
-
output =
|
78 |
|
79 |
last_hidden_state = output[0].squeeze()
|
80 |
|
81 |
list_of_list =[]
|
82 |
for index,mask_index in enumerate(masked_pos):
|
83 |
words = []
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
|
|
91 |
list_of_list.append(words)
|
92 |
print(f"Mask {index+1} Guesses: {words}")
|
93 |
|
@@ -97,18 +100,21 @@ def get_prediction(sent):
|
|
97 |
|
98 |
return best_guess
|
99 |
|
100 |
-
|
101 |
-
|
102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
return line
|
104 |
|
105 |
-
def get_rhyming_line(
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
gpt2_sentence = text_generation(topic_summary + " " + starting_word, max_length=inputs_len + 2, do_sample=True, return_full_text=False)[0]
|
110 |
-
#sentence = gpt2_sentence['generated_text'] + " ___ ___ ___ " + rhyming_words[i][j]
|
111 |
-
sentence = starting_word + gpt2_sentence['generated_text'] + " ___ ___ ___ " + rhyming_word
|
112 |
print(f"Original Sentence: {sentence}")
|
113 |
if sentence[-1] != ".":
|
114 |
sentence = sentence.replace("___","<mask>") + "."
|
@@ -119,20 +125,15 @@ def get_rhyming_line(topic_summary, starting_words, rhyming_word, inputs_len):
|
|
119 |
|
120 |
predicted_blanks = get_prediction(sentence)
|
121 |
print(f"\nBest guess for fill in the blanks: {predicted_blanks}")
|
122 |
-
|
|
|
|
|
123 |
|
124 |
-
from transformers import pipeline
|
125 |
-
|
126 |
def generate(topic):
|
127 |
-
text_generation = pipeline("text-generation")
|
128 |
|
129 |
limericks = []
|
130 |
|
131 |
-
|
132 |
-
## topic_summary = remove_punctuation(wikipedia.summary(topic))
|
133 |
-
topic_summary = topic
|
134 |
-
# if len(topic_summary) > 2000:
|
135 |
-
# topic_summary = topic_summary[:2000]
|
136 |
word_list = topic_summary.split()
|
137 |
topic_summary_len = len(topic_summary)
|
138 |
no_of_words = len(word_list)
|
@@ -140,59 +141,64 @@ def generate(topic):
|
|
140 |
print(f"Topic Summary: {topic_summary}")
|
141 |
print(f"Topic Summary Length: {topic_summary_len}")
|
142 |
print(f"No of Words in Summary: {no_of_words}")
|
143 |
-
print(f"Length of Input IDs: {inputs_len}")
|
144 |
-
|
145 |
-
starting_words = ["That", "Had", "Not", "But", "With", "I", "Because", "There", "Who", "She", "He", "To", "Whose", "In", "And", "When", "Or", "So", "The", "Of", "Every", "Whom"]
|
146 |
-
|
147 |
-
# starting_words = [["That", "Had", "Not", "But", "That"],
|
148 |
-
# ["There", "Who", "She", "Tormenting", "Til"],
|
149 |
-
# ["Relentless", "This", "First", "and", "then"],
|
150 |
-
# ["There", "Who", "That", "To", "She"],
|
151 |
-
# ["There", "Who", "Two", "Four", "Have"]]
|
152 |
|
153 |
-
|
154 |
-
# ["Nice", "grease", "house", "spouse", "peace"],
|
155 |
-
# ["deadlines", "lines", "edits", "credits", "wine"],
|
156 |
-
# ["Lynn", "thin", "essayed", "lemonade", "in"],
|
157 |
-
# ["beard", "feared", "hen", "wren", "beard"]]
|
158 |
-
|
159 |
-
for i in range(5):
|
160 |
print(f"\nGenerating limerick {i+1}")
|
161 |
rhyming_words_125 = []
|
162 |
-
while len(rhyming_words_125) < 3 or valid_rhyme == False:
|
163 |
-
first_line = get_line(topic_summary,
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
|
173 |
rhyming_word = rhyming_words_125[0]
|
174 |
-
|
|
|
|
|
|
|
|
|
|
|
175 |
limerick += second_line + "\n"
|
176 |
|
177 |
rhyming_words_34 = []
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
|
|
|
|
|
|
|
|
189 |
|
190 |
rhyming_word = rhyming_words_34[0]
|
191 |
-
|
|
|
|
|
|
|
|
|
|
|
192 |
limerick += fourth_line + "\n"
|
193 |
|
194 |
rhyming_word = rhyming_words_125[1]
|
195 |
-
|
|
|
|
|
|
|
|
|
|
|
196 |
limerick += fifth_line + "\n"
|
197 |
|
198 |
limericks.append(limerick)
|
@@ -203,9 +209,11 @@ def generate(topic):
|
|
203 |
print(f"Generated {len(limericks)} limericks: \n")
|
204 |
for limerick in limericks:
|
205 |
print(limerick)
|
206 |
-
output += limerick
|
207 |
|
208 |
return output
|
|
|
|
|
209 |
|
210 |
interface = gr.Interface(fn=generate, inputs="text", outputs="text")
|
211 |
interface.launch(debug=True)
|
|
|
1 |
+
from transformers import RobertaTokenizer, RobertaForMaskedLM, pipeline, GPT2Tokenizer, GPT2LMHeadModel
|
2 |
import torch
|
3 |
import wikipedia
|
4 |
import re
|
5 |
import random
|
6 |
import nltk
|
7 |
import syllables
|
|
|
8 |
nltk.download('cmudict')
|
9 |
|
10 |
+
# Masked language model (RoBERTa) and its tokenizer, both loaded from the
# 'roberta-base' checkpoint.
masked_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
masked_model = RobertaForMaskedLM.from_pretrained('roberta-base')

# Causal language model (GPT-2) and its tokenizer, both loaded from "gpt2".
causal_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# Use the EOS token id as the PAD token id to avoid warnings.
causal_model = GPT2LMHeadModel.from_pretrained(
    "gpt2",
    pad_token_id=causal_tokenizer.eos_token_id,
)
|
16 |
+
|
17 |
frequent_words = set()
|
18 |
|
19 |
def set_seed(seed: int):
|
|
|
64 |
return filtered_rhymes
|
65 |
|
66 |
def get_inputs_length(input):
    """Return how many token ids the causal tokenizer produces for `input`.

    Args:
        input: Text handed to the module-level ``causal_tokenizer``.
            (The name shadows the builtin but is kept because it is the
            caller-visible parameter name.)

    Returns:
        The length of the ``'input_ids'`` entry of the tokenizer output.
    """
    encoding = causal_tokenizer(input)
    return len(encoding['input_ids'])
|
69 |
+
|
|
|
|
|
|
|
70 |
set_seed(0)
|
71 |
|
72 |
def get_prediction(sent):
|
73 |
|
74 |
+
token_ids = masked_tokenizer.encode(sent, return_tensors='pt')
|
75 |
+
masked_position = (token_ids.squeeze() == masked_tokenizer.mask_token_id).nonzero()
|
76 |
masked_pos = [mask.item() for mask in masked_position ]
|
77 |
|
78 |
with torch.no_grad():
|
79 |
+
output = masked_model(token_ids)
|
80 |
|
81 |
last_hidden_state = output[0].squeeze()
|
82 |
|
83 |
list_of_list =[]
|
84 |
for index,mask_index in enumerate(masked_pos):
|
85 |
words = []
|
86 |
+
while not words:
|
87 |
+
mask_hidden_state = last_hidden_state[mask_index]
|
88 |
+
idx = torch.topk(mask_hidden_state, k=5, dim=0)[1]
|
89 |
+
for i in idx:
|
90 |
+
word = masked_tokenizer.decode(i.item()).strip()
|
91 |
+
if (remove_punctuation(word) != "") and (word != '</s>'):
|
92 |
+
words.append(word)
|
93 |
+
#words = [masked_tokenizer.decode(i.item()).strip() for i in idx]
|
94 |
list_of_list.append(words)
|
95 |
print(f"Mask {index+1} Guesses: {words}")
|
96 |
|
|
|
100 |
|
101 |
return best_guess
|
102 |
|
103 |
+
# Hugging Face text-generation pipeline built on the already-loaded causal
# model and tokenizer.
text_generation = pipeline(
    "text-generation",
    model=causal_model,
    tokenizer=causal_tokenizer,
)
from aitextgen import aitextgen

# Without any parameters, aitextgen() will download, cache, and load the
# 124M GPT-2 "small" model.
ai = aitextgen()
|
108 |
+
|
109 |
+
|
110 |
+
def get_line(prompt, inputs_len):
    """Generate one continuation of `prompt` with the module-level `ai` model.

    Args:
        prompt: Text to continue; a "." is appended before generation.
        inputs_len: Length budget for the prompt (the visible callers pass
            the result of ``get_inputs_length``). Generation is capped at
            ``inputs_len + 7`` through ``max_length``.

    Returns:
        The string returned by ``ai.generate_one`` with its first
        ``len(prompt) + 2`` characters removed.
    """
    generated = ai.generate_one(prompt=prompt + ".", max_length=inputs_len + 7)
    # Presumably skips the echoed prompt, the appended "." and one separator
    # character -- TODO confirm against what generate_one actually returns.
    skip = len(prompt) + 2
    return generated[skip:]
|
113 |
|
114 |
+
def get_rhyming_line(prompt, rhyming_word, inputs_len):
|
115 |
+
gpt2_sentence = ai.generate_one(prompt=prompt + ".", max_length=inputs_len + 4)[len(prompt)+2:]
|
116 |
+
print(f"\nGetting rhyming line starting with '{gpt2_sentence}' and ending with rhyming word '{rhyming_word}'")
|
117 |
+
sentence = gpt2_sentence + " ___ ___ ___ " + rhyming_word
|
|
|
|
|
|
|
118 |
print(f"Original Sentence: {sentence}")
|
119 |
if sentence[-1] != ".":
|
120 |
sentence = sentence.replace("___","<mask>") + "."
|
|
|
125 |
|
126 |
predicted_blanks = get_prediction(sentence)
|
127 |
print(f"\nBest guess for fill in the blanks: {predicted_blanks}")
|
128 |
+
final_sentence = gpt2_sentence + predicted_blanks + " " + rhyming_word
|
129 |
+
print(f"Final Sentence: {final_sentence}")
|
130 |
+
return final_sentence
|
131 |
|
|
|
|
|
132 |
def generate(topic):
|
|
|
133 |
|
134 |
limericks = []
|
135 |
|
136 |
+
topic_summary = remove_punctuation(wikipedia.summary(topic))
|
|
|
|
|
|
|
|
|
137 |
word_list = topic_summary.split()
|
138 |
topic_summary_len = len(topic_summary)
|
139 |
no_of_words = len(word_list)
|
|
|
141 |
print(f"Topic Summary: {topic_summary}")
|
142 |
print(f"Topic Summary Length: {topic_summary_len}")
|
143 |
print(f"No of Words in Summary: {no_of_words}")
|
144 |
+
print(f"Length of Input IDs: {inputs_len}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
|
146 |
+
for i in range(1):
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
print(f"\nGenerating limerick {i+1}")
|
148 |
rhyming_words_125 = []
|
149 |
+
while len(rhyming_words_125) < 3 or valid_rhyme == False or len(first_line) == 0:
|
150 |
+
first_line = get_line(topic_summary, inputs_len)
|
151 |
+
if first_line:
|
152 |
+
end_word = remove_punctuation(first_line.split()[-1])
|
153 |
+
valid_rhyme = filter_rhymes(end_word)
|
154 |
+
if valid_rhyme:
|
155 |
+
print(f"\nFirst Line: {first_line}")
|
156 |
+
rhyming_words_125 = list(get_rhymes(end_word, 3))
|
157 |
+
print(f"Rhyming words for '{end_word}' are {rhyming_words_125}")
|
158 |
+
limerick = first_line + "\n"
|
159 |
|
160 |
rhyming_word = rhyming_words_125[0]
|
161 |
+
prompt = topic_summary + " " + first_line
|
162 |
+
inputs_len = get_inputs_length(prompt)
|
163 |
+
print(f"Prompt: {prompt}")
|
164 |
+
print(f"Length of prompt: {inputs_len}")
|
165 |
+
second_line = get_rhyming_line(prompt, rhyming_word, inputs_len)
|
166 |
+
print(f"\nSecond Line: {second_line}")
|
167 |
limerick += second_line + "\n"
|
168 |
|
169 |
rhyming_words_34 = []
|
170 |
+
prompt = prompt + " " + second_line
|
171 |
+
inputs_len = get_inputs_length(prompt)
|
172 |
+
print(f"Prompt: {prompt}")
|
173 |
+
print(f"Length of prompt: {inputs_len}")
|
174 |
+
while len(rhyming_words_34) < 2 or valid_rhyme == False or len(third_line) == 0:
|
175 |
+
third_line = get_line(prompt, inputs_len)
|
176 |
+
if third_line:
|
177 |
+
print(f"\nThird Line: {third_line}")
|
178 |
+
end_word = remove_punctuation(third_line.split()[-1])
|
179 |
+
valid_rhyme = filter_rhymes(end_word)
|
180 |
+
print(f"Does '{end_word}'' have valid rhymes: {valid_rhyme}")
|
181 |
+
rhyming_words_34 = list(get_rhymes(end_word, 3))
|
182 |
+
print(f"Rhyming words for '{end_word}' are {rhyming_words_34}")
|
183 |
+
if valid_rhyme and len(rhyming_words_34) > 1:
|
184 |
+
limerick += third_line + "\n"
|
185 |
|
186 |
rhyming_word = rhyming_words_34[0]
|
187 |
+
prompt = prompt + " " + third_line
|
188 |
+
inputs_len = get_inputs_length(prompt)
|
189 |
+
print(f"Prompt: {prompt}")
|
190 |
+
print(f"Length of prompt: {inputs_len}")
|
191 |
+
fourth_line = get_rhyming_line(prompt, rhyming_word, inputs_len)
|
192 |
+
print(f"\nFourth Line: {fourth_line}")
|
193 |
limerick += fourth_line + "\n"
|
194 |
|
195 |
rhyming_word = rhyming_words_125[1]
|
196 |
+
prompt = prompt + " " + fourth_line
|
197 |
+
inputs_len = get_inputs_length(prompt)
|
198 |
+
print(f"Prompt: {prompt}")
|
199 |
+
print(f"Length of prompt: {inputs_len}")
|
200 |
+
fifth_line = get_rhyming_line(prompt, rhyming_word, inputs_len)
|
201 |
+
print(f"\nFifth Line: {fifth_line}")
|
202 |
limerick += fifth_line + "\n"
|
203 |
|
204 |
limericks.append(limerick)
|
|
|
209 |
print(f"Generated {len(limericks)} limericks: \n")
|
210 |
for limerick in limericks:
|
211 |
print(limerick)
|
212 |
+
output += "\n" + limerick
|
213 |
|
214 |
return output
|
215 |
+
|
216 |
+
import gradio as gr
|
217 |
|
218 |
# Wrap `generate` in a text-in / text-out Gradio interface and start it with
# debug mode enabled.
interface = gr.Interface(
    fn=generate,
    inputs="text",
    outputs="text",
)
interface.launch(debug=True)
|