ansfarooq7 committed • f5fde7c
Parent(s): a494241
Update app.py

app.py CHANGED
@@ -1,3 +1,4 @@
+# Import required packages
 from transformers import RobertaTokenizer, RobertaForMaskedLM, GPT2Tokenizer, GPT2LMHeadModel, pipeline
 import torch
 import wikipedia
@@ -7,22 +8,26 @@ import nltk
 import gradio as gr
 nltk.download('cmudict')
 
+# Use the RoBERTa model from HuggingFace for masked language modelling
 roberta_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
 roberta_model = RobertaForMaskedLM.from_pretrained('roberta-base')
 
+# Use the GPT-2 from HuggingFace for causal language modelling
 gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
 gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=gpt2_tokenizer.eos_token_id)
 
+# Initialise a text generation pipeline using HuggingFace Transformers and the pre-trained GPT-2 model
 gpt2_pipeline = pipeline('text-generation', model=gpt2_model, tokenizer=gpt2_tokenizer)
 
+# Hold all the words in wordFrequency.txt in a Python set
 frequent_words = set()
-
 with open("wordFrequency.txt", 'r') as f:
     line = f.readline()
     while line != '': # The EOF char is an empty string
         frequent_words.add(line.strip())
         line = f.readline()
 
+# Used alongside the word frequency list to filter out problematic words for rhyming
 def filter_rhymes(word):
     filter_list = ['an', 'to', 'on', 'has', 'but', 'the', 'in', 'and', 'a', 'are', 'or', 'its', 'it''s']
     if word in filter_list:
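For context on the pipeline set up above, here is a minimal, self-contained sketch of how a Transformers text-generation pipeline like gpt2_pipeline is typically called. The prompt and generation parameters below are illustrative and not taken from this commit.

from transformers import GPT2Tokenizer, GPT2LMHeadModel, pipeline

# Same pre-trained GPT-2 checkpoint as app.py uses
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Each returned item is a dict with a 'generated_text' key
result = generator("Here is some information about limericks.", max_length=40, num_return_sequences=1)
print(result[0]["generated_text"])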
@@ -30,11 +35,16 @@ def filter_rhymes(word):
     else:
         return True
 
+# Used to remove any punctuation and new line characters from generated text
 def remove_punctuation(text):
     text = re.sub(r'[^\w\s]', '', text)
     text = text.replace("\n", " ")
     return text.strip()
 
+# Used to find rhymes to a given word using NLTK
+# where inp is a word and level means how good the rhyme should be.
+# Adapted from the following Stack Overflow answer:
+# https://stackoverflow.com/a/25714769/18559178
 def get_rhymes(inp, level):
     entries = nltk.corpus.cmudict.entries()
     syllables = [(word, syl) for word, syl in entries if word == inp]
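As a quick illustration of remove_punctuation: the regex [^\w\s] deletes every character that is neither a word character nor whitespace, and the newline replacement flattens multi-line generations into one line. A small usage example (the sample string is illustrative):

import re

def remove_punctuation(text):
    text = re.sub(r'[^\w\s]', '', text)  # drop punctuation, keep letters, digits and whitespace
    text = text.replace("\n", " ")       # flatten newlines into spaces
    return text.strip()

print(remove_punctuation("Hello, world!\nIt's a test."))  # prints: Hello world Its a test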
@@ -48,10 +58,18 @@ def get_rhymes(inp, level):
             filtered_rhymes.add(word)
     return filtered_rhymes
 
+# Used to get the length of the topic summary, to then determine max length for
+# the text generation pipeline
 def get_inputs_length(input):
     input_ids = gpt2_tokenizer(input)['input_ids']
     return len(input_ids)
-
+
+# Sized Fill-in-the-blank or Multi Mask filling with RoBERTa and Huggingface Transformers
+# Used to fill in the blank words between the starting words of each line
+# generated by GPT-2 and the end rhyming word
+# Code adapted from the following Medium article:
+# https://ramsrigoutham.medium.com/sized-fill-in-the-blank-or-multi-mask-filling-with-roberta-and-huggingface-transformers-58eb9e7fb0c
+
 def get_prediction(sent):
     token_ids = roberta_tokenizer.encode(sent, return_tensors='pt')
     masked_position = (token_ids.squeeze() == roberta_tokenizer.mask_token_id).nonzero()
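get_rhymes follows the CMU Pronouncing Dictionary approach from the Stack Overflow answer credited in the comments: two words rhyme, at a chosen strictness level, when the last `level` phonemes of their pronunciations match. A minimal sketch of that idea, without the app's frequency and filter checks (function and variable names here are illustrative):

import nltk

nltk.download('cmudict')
entries = nltk.corpus.cmudict.entries()

def simple_rhymes(inp, level):
    # All pronunciations of the input word
    syllables = [(word, syl) for word, syl in entries if word == inp]
    rhymes = []
    for (word, syllable) in syllables:
        # A word rhymes if its final `level` phonemes match the input's
        rhymes += [word for word, pron in entries if pron[-level:] == syllable[-level:]]
    return set(rhymes)

print(sorted(simple_rhymes("cat", 2))[:10])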
@@ -68,6 +86,8 @@ def get_prediction(sent):
     while not words:
         mask_hidden_state = last_hidden_state[mask_index]
         idx = torch.topk(mask_hidden_state, k=5, dim=0)[1]
+
+        # Discard predicted word if it is blank or end token
         for i in idx:
             word = roberta_tokenizer.decode(i.item()).strip()
             if (remove_punctuation(word) != "") and (word != '</s>'):
@@ -80,7 +100,9 @@ def get_prediction(sent):
         best_guess = best_guess+" "+j[0]
 
     return best_guess
-
+
+# Used to generate the 1st and 3rd lines of the limerick
+# these are full lines, without RoBERTa being used
 def get_line(prompt, inputs_len):
     output = gpt2_pipeline(
         prompt + ".",
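get_prediction applies the sized fill-in-the-blank technique from the Medium article credited above: place one <mask> token per word to fill, run the masked language model once, and pick high-scoring tokens at each masked position. A simplified, self-contained sketch of that technique (the sentence and variable names are illustrative, and it keeps only the single top token per mask rather than the app's top-5 filtering):

import torch
from transformers import RobertaTokenizer, RobertaForMaskedLM

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForMaskedLM.from_pretrained('roberta-base')

sent = "The cat <mask> on the <mask>."
token_ids = tokenizer.encode(sent, return_tensors='pt')

# Positions of every <mask> token in the encoded sentence
masked_positions = (token_ids.squeeze() == tokenizer.mask_token_id).nonzero().flatten().tolist()

with torch.no_grad():
    logits = model(token_ids).logits.squeeze(0)

# Keep the single best token at each masked position
filled = []
for pos in masked_positions:
    top_id = torch.topk(logits[pos], k=1, dim=0).indices.item()
    filled.append(tokenizer.decode([top_id]).strip())

print(filled)  # output depends on the model, e.g. ['sat', 'floor']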
@@ -91,6 +113,9 @@ def get_line(prompt, inputs_len):
     )
     return remove_punctuation(output[0]['generated_text'])
 
+# Used to generate the 2nd, 4th and 5th lines
+# GPT-2 is used to generate starting few words of the lines
+# RoBERTa is then used to fill in the rest of the words until the end rhyme word
 def get_rhyming_line(prompt, rhyming_word, inputs_len):
     output = gpt2_pipeline(
         prompt + ".",
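The keyword arguments of these gpt2_pipeline calls fall outside the visible hunks, but the comment on get_inputs_length describes the general pattern: measure the prompt in GPT-2 tokens, then set max_length relative to that so only a short continuation is generated. A hedged sketch of that pattern, reusing the gpt2_tokenizer and gpt2_pipeline objects defined at the top of app.py (the +10 budget and the other arguments are assumptions, not values from this commit):

# Assumes gpt2_tokenizer and gpt2_pipeline from the top of app.py are in scope
def prompt_length_in_tokens(text):
    return len(gpt2_tokenizer(text)['input_ids'])

prompt = "There once was a cat from Peru"
inputs_len = prompt_length_in_tokens(prompt)

output = gpt2_pipeline(
    prompt + ".",
    max_length=inputs_len + 10,   # assumed budget: a few tokens beyond the prompt
    num_return_sequences=1,
    return_full_text=False)       # keep only the newly generated continuation
print(output[0]['generated_text'])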
@@ -126,6 +151,8 @@ def get_rhyming_line(prompt, rhyming_word, inputs_len):
     print(f"Final Sentence: {final_sentence}")
     return final_sentence
 
+# Used for the second method, Method B, of limerick generation
+# Uses GPT-2 to get information about the user's given topic
 def gpt2_summary(topic):
     output = gpt2_pipeline(
         f"Here is some information about {topic}.",
@@ -135,8 +162,11 @@ def gpt2_summary(topic):
         return_full_text=False
     )
     return remove_punctuation(output[0]['generated_text'])
-
+
+# Main logic for limerick generation is contained here
 def generate(topic, wiki=True):
+
+    # Search for the topic on Wikipedia and get an informational summary
     if wiki:
         try:
             topic_search = wikipedia.search(topic, results=3)
@@ -148,9 +178,12 @@ def generate(topic, wiki=True):
             topic_summary = remove_punctuation(wikipedia.summary(page, auto_suggest=False))
         except:
             return(f"Method A struggled to find information about {topic}, please try a different topic!")
+
+    # Use GPT-2 to get info about the topic if the wiki parameter is false
     else:
         topic_summary = remove_punctuation(gpt2_summary(topic))
 
+    # Log info about the topic summary data
     word_list = topic_summary.split()
     topic_summary_len = len(topic_summary)
     no_of_words = len(word_list)
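Method A uses the wikipedia package: search for candidate pages, then fetch a plain-text summary with auto_suggest=False so the query is not silently rewritten. A minimal sketch of that flow; picking the first search hit is an assumption here, since the app's actual page-selection logic sits in lines not shown in this hunk:

import wikipedia

topic = "limerick"
try:
    topic_search = wikipedia.search(topic, results=3)   # up to three candidate page titles
    page = topic_search[0]                               # assumption: take the first hit
    topic_summary = wikipedia.summary(page, auto_suggest=False)
    print(topic_summary[:200])
except Exception:
    print(f"Could not find information about {topic}, please try a different topic!")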
@@ -160,6 +193,7 @@
     print(f"No of Words in Summary: {no_of_words}")
     print(f"Length of Input IDs: {inputs_len}")
 
+    # Generate the first line of the limerick
     rhyming_words_125 = []
     while len(rhyming_words_125) < 3 or valid_rhyme == False or len(first_line) == 0:
         first_line = get_line(topic_summary, inputs_len)
@@ -172,6 +206,7 @@
     print(f"Rhyming words for '{end_word}' are {rhyming_words_125}")
     limerick = first_line + "\n"
 
+    # Generate the second line of the limerick
     rhyming_word = random.choice(rhyming_words_125)
     rhyming_words_125.remove(rhyming_word)
     prompt = topic_summary + " " + first_line
@@ -182,6 +217,7 @@
     print(f"\nSecond Line: {second_line}")
     limerick += second_line + "\n"
 
+    # Generate the third line of the limerick
     rhyming_words_34 = []
     prompt = prompt + " " + second_line
     inputs_len = get_inputs_length(prompt)
@@ -199,6 +235,7 @@
         if valid_rhyme and len(rhyming_words_34) > 1:
             limerick += third_line + "\n"
 
+    # Generate the fourth line of the limerick
     rhyming_word = random.choice(rhyming_words_34)
     rhyming_words_34.remove(rhyming_word)
     prompt = prompt + " " + third_line
@@ -209,6 +246,7 @@
     print(f"\nFourth Line: {fourth_line}")
     limerick += fourth_line + "\n"
 
+    # Generate the fifth line of the limerick
     rhyming_word = random.choice(rhyming_words_125)
     rhyming_words_125.remove(rhyming_word)
     prompt = prompt + " " + fourth_line
@@ -223,7 +261,8 @@
     print(limerick)
 
     return limerick
-
+
+# Helper function to generate two limericks via both methods to then compare
 def compare_summaries(topic):
     wiki_limerick = generate(topic)
     gpt2_limerick = generate(topic, wiki=False)
@@ -233,8 +272,11 @@
     print(output1 + "\n" + output2)
 
     return output1, output2
-
-
+
+# Use Gradio to create an interface, which can be hosted on HuggingFace spaces
+# https://huggingface.co/spaces/ansfarooq7/l4-project
+
+description = "Generates limericks (five-line poems with a rhyme scheme of AABBA) using two different methods, please be patient as it can take up to a minute to generate both limericks."
 article = '<center><big><strong>Limerick Generation</strong></big></center>'\
           '<center><strong>By Ans Farooq</strong></center>'\
           '<center><small>Level 4 Individual Project</small></center>'\
@@ -265,4 +307,4 @@ interface = gr.Interface(
     theme="peach",
     description=description,
     article=article)
-interface.launch(debug=
+interface.launch(debug=False)
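Only the theme, description and article arguments of the gr.Interface call are visible in this final hunk. As a rough sketch of how compare_summaries is typically wired into a two-output Gradio interface; the input and output component choices below are assumptions, not the app's exact configuration:

import gradio as gr

interface = gr.Interface(
    fn=compare_summaries,         # helper defined earlier in app.py
    inputs="text",                # topic entered by the user (assumed component)
    outputs=["text", "text"],     # one limerick per method (assumed components)
    theme="peach",
    description=description,
    article=article)
interface.launch(debug=False)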