Spaces:
Sleeping
Sleeping
ansfarooq7
commited on
Commit
•
b326a58
1
Parent(s):
4140a06
Update app.py
Browse files
app.py
CHANGED
@@ -1,41 +1,19 @@
|
|
1 |
-
from transformers import RobertaTokenizer, RobertaForMaskedLM, GPT2Tokenizer
|
2 |
import torch
|
3 |
import wikipedia
|
4 |
import re
|
5 |
import random
|
6 |
import nltk
|
7 |
-
import syllables
|
8 |
from aitextgen import aitextgen
|
9 |
nltk.download('cmudict')
|
10 |
|
11 |
-
|
12 |
-
|
13 |
|
14 |
-
|
15 |
-
|
16 |
-
gptneo_model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
|
17 |
-
|
18 |
-
# Without any parameters, aitextgen() will download, cache, and load the 124M GPT-2 "small" model
|
19 |
-
gpt2 = aitextgen()
|
20 |
|
21 |
frequent_words = set()
|
22 |
-
|
23 |
-
def set_seed(seed: int):
|
24 |
-
"""
|
25 |
-
Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf`` (if
|
26 |
-
installed).
|
27 |
-
|
28 |
-
Args:
|
29 |
-
seed (:obj:`int`): The seed to set.
|
30 |
-
"""
|
31 |
-
#random.seed(seed)
|
32 |
-
#np.random.seed(seed)
|
33 |
-
#if is_torch_available():
|
34 |
-
torch.manual_seed(seed)
|
35 |
-
torch.cuda.manual_seed_all(seed)
|
36 |
-
# ^^ safe to call this function even if cuda is not available
|
37 |
-
#if is_tf_available():
|
38 |
-
#tf.random.set_seed(seed)
|
39 |
|
40 |
with open("wordFrequency.txt", 'r') as f:
|
41 |
line = f.readline()
|
@@ -69,19 +47,17 @@ def get_rhymes(inp, level):
|
|
69 |
return filtered_rhymes
|
70 |
|
71 |
def get_inputs_length(input):
|
72 |
-
input_ids =
|
73 |
return len(input_ids)
|
74 |
-
|
75 |
-
set_seed(0)
|
76 |
|
77 |
def get_prediction(sent):
|
78 |
|
79 |
-
token_ids =
|
80 |
-
masked_position = (token_ids.squeeze() ==
|
81 |
masked_pos = [mask.item() for mask in masked_position ]
|
82 |
|
83 |
with torch.no_grad():
|
84 |
-
output =
|
85 |
|
86 |
last_hidden_state = output[0].squeeze()
|
87 |
|
@@ -92,10 +68,9 @@ def get_prediction(sent):
|
|
92 |
mask_hidden_state = last_hidden_state[mask_index]
|
93 |
idx = torch.topk(mask_hidden_state, k=5, dim=0)[1]
|
94 |
for i in idx:
|
95 |
-
word =
|
96 |
if (remove_punctuation(word) != "") and (word != '</s>'):
|
97 |
words.append(word)
|
98 |
-
#words = [masked_tokenizer.decode(i.item()).strip() for i in idx]
|
99 |
list_of_list.append(words)
|
100 |
print(f"Mask {index+1} Guesses: {words}")
|
101 |
|
@@ -104,13 +79,13 @@ def get_prediction(sent):
|
|
104 |
best_guess = best_guess+" "+j[0]
|
105 |
|
106 |
return best_guess
|
107 |
-
|
108 |
def get_line(prompt, inputs_len):
|
109 |
-
line =
|
110 |
return line
|
111 |
|
112 |
def get_rhyming_line(prompt, rhyming_word, inputs_len):
|
113 |
-
gpt2_sentence =
|
114 |
gpt2_sentence = gpt2_sentence.replace("\n", "")
|
115 |
print(f"\nGetting rhyming line starting with '{gpt2_sentence}' and ending with rhyming word '{rhyming_word}'")
|
116 |
sentence = gpt2_sentence + " ___ ___ ___ " + rhyming_word
|
@@ -128,17 +103,15 @@ def get_rhyming_line(prompt, rhyming_word, inputs_len):
|
|
128 |
print(f"Final Sentence: {final_sentence}")
|
129 |
return final_sentence
|
130 |
|
131 |
-
def
|
132 |
-
|
133 |
-
|
134 |
-
generated_text = gptneo_tokenizer.decode(gen_tokens[0])
|
135 |
-
return generated_text
|
136 |
-
|
137 |
def generate(topic, wiki=True):
|
138 |
if wiki:
|
139 |
topic_summary = remove_punctuation(wikipedia.summary(topic))
|
140 |
else:
|
141 |
-
topic_summary = remove_punctuation(
|
|
|
142 |
word_list = topic_summary.split()
|
143 |
topic_summary_len = len(topic_summary)
|
144 |
no_of_words = len(word_list)
|
@@ -160,7 +133,8 @@ def generate(topic, wiki=True):
|
|
160 |
print(f"Rhyming words for '{end_word}' are {rhyming_words_125}")
|
161 |
limerick = first_line + "\n"
|
162 |
|
163 |
-
rhyming_word = rhyming_words_125
|
|
|
164 |
prompt = topic_summary + " " + first_line
|
165 |
inputs_len = get_inputs_length(prompt)
|
166 |
print(f"Prompt: {prompt}")
|
@@ -186,7 +160,8 @@ def generate(topic, wiki=True):
|
|
186 |
if valid_rhyme and len(rhyming_words_34) > 1:
|
187 |
limerick += third_line + "\n"
|
188 |
|
189 |
-
rhyming_word = rhyming_words_34
|
|
|
190 |
prompt = prompt + " " + third_line
|
191 |
inputs_len = get_inputs_length(prompt)
|
192 |
print(f"Prompt: {prompt}")
|
@@ -195,7 +170,8 @@ def generate(topic, wiki=True):
|
|
195 |
print(f"\nFourth Line: {fourth_line}")
|
196 |
limerick += fourth_line + "\n"
|
197 |
|
198 |
-
rhyming_word = rhyming_words_125
|
|
|
199 |
prompt = prompt + " " + fourth_line
|
200 |
inputs_len = get_inputs_length(prompt)
|
201 |
print(f"Prompt: {prompt}")
|
@@ -210,20 +186,23 @@ def generate(topic, wiki=True):
|
|
210 |
return limerick
|
211 |
|
212 |
def compare_summaries(topic):
|
213 |
-
wiki_limerick = generate(topic
|
214 |
-
|
|
|
|
|
|
|
|
|
|
|
215 |
|
216 |
-
|
217 |
-
output += wiki_limerick + "\n"
|
218 |
-
output += f"Limerick with GPT Neo summary of topic as prompt: \n"
|
219 |
-
output += gptneo_limerick
|
220 |
|
221 |
-
return output
|
222 |
-
|
223 |
import gradio as gr
|
224 |
|
225 |
interface = gr.Interface(
|
226 |
fn=compare_summaries,
|
227 |
inputs="text",
|
228 |
-
outputs="text"
|
|
|
|
|
|
|
229 |
interface.launch(debug=True)
|
|
|
1 |
+
from transformers import RobertaTokenizer, RobertaForMaskedLM, GPT2Tokenizer
|
2 |
import torch
|
3 |
import wikipedia
|
4 |
import re
|
5 |
import random
|
6 |
import nltk
|
|
|
7 |
from aitextgen import aitextgen
|
8 |
nltk.download('cmudict')
|
9 |
|
10 |
+
roberta_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
|
11 |
+
roberta_model = RobertaForMaskedLM.from_pretrained('roberta-base')
|
12 |
|
13 |
+
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
14 |
+
gpt2_model = aitextgen(tf_gpt2="355M")
|
|
|
|
|
|
|
|
|
15 |
|
16 |
frequent_words = set()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
with open("wordFrequency.txt", 'r') as f:
|
19 |
line = f.readline()
|
|
|
47 |
return filtered_rhymes
|
48 |
|
49 |
def get_inputs_length(input):
|
50 |
+
input_ids = gpt2_tokenizer(input)['input_ids']
|
51 |
return len(input_ids)
|
|
|
|
|
52 |
|
53 |
def get_prediction(sent):
|
54 |
|
55 |
+
token_ids = roberta_tokenizer.encode(sent, return_tensors='pt')
|
56 |
+
masked_position = (token_ids.squeeze() == roberta_tokenizer.mask_token_id).nonzero()
|
57 |
masked_pos = [mask.item() for mask in masked_position ]
|
58 |
|
59 |
with torch.no_grad():
|
60 |
+
output = roberta_model(token_ids)
|
61 |
|
62 |
last_hidden_state = output[0].squeeze()
|
63 |
|
|
|
68 |
mask_hidden_state = last_hidden_state[mask_index]
|
69 |
idx = torch.topk(mask_hidden_state, k=5, dim=0)[1]
|
70 |
for i in idx:
|
71 |
+
word = roberta_tokenizer.decode(i.item()).strip()
|
72 |
if (remove_punctuation(word) != "") and (word != '</s>'):
|
73 |
words.append(word)
|
|
|
74 |
list_of_list.append(words)
|
75 |
print(f"Mask {index+1} Guesses: {words}")
|
76 |
|
|
|
79 |
best_guess = best_guess+" "+j[0]
|
80 |
|
81 |
return best_guess
|
82 |
+
|
83 |
def get_line(prompt, inputs_len):
|
84 |
+
line = gpt2_model.generate_one(prompt=prompt + ".", max_length=inputs_len + 7, min_length=4)[len(prompt)+2:]
|
85 |
return line
|
86 |
|
87 |
def get_rhyming_line(prompt, rhyming_word, inputs_len):
|
88 |
+
gpt2_sentence = gpt2_model.generate_one(prompt=prompt + ".", max_length=inputs_len + 4, min_length=2)[len(prompt)+2:]
|
89 |
gpt2_sentence = gpt2_sentence.replace("\n", "")
|
90 |
print(f"\nGetting rhyming line starting with '{gpt2_sentence}' and ending with rhyming word '{rhyming_word}'")
|
91 |
sentence = gpt2_sentence + " ___ ___ ___ " + rhyming_word
|
|
|
103 |
print(f"Final Sentence: {final_sentence}")
|
104 |
return final_sentence
|
105 |
|
106 |
+
def gpt2_summary(topic):
|
107 |
+
return gpt2_model.generate_one(prompt=f"Here is some information about {topic}", top_k=100, top_p=0.95)
|
108 |
+
|
|
|
|
|
|
|
109 |
def generate(topic, wiki=True):
|
110 |
if wiki:
|
111 |
topic_summary = remove_punctuation(wikipedia.summary(topic))
|
112 |
else:
|
113 |
+
topic_summary = remove_punctuation(gpt2_summary(topic))
|
114 |
+
|
115 |
word_list = topic_summary.split()
|
116 |
topic_summary_len = len(topic_summary)
|
117 |
no_of_words = len(word_list)
|
|
|
133 |
print(f"Rhyming words for '{end_word}' are {rhyming_words_125}")
|
134 |
limerick = first_line + "\n"
|
135 |
|
136 |
+
rhyming_word = random.choice(rhyming_words_125)
|
137 |
+
rhyming_words_125.remove(rhyming_word)
|
138 |
prompt = topic_summary + " " + first_line
|
139 |
inputs_len = get_inputs_length(prompt)
|
140 |
print(f"Prompt: {prompt}")
|
|
|
160 |
if valid_rhyme and len(rhyming_words_34) > 1:
|
161 |
limerick += third_line + "\n"
|
162 |
|
163 |
+
rhyming_word = random.choice(rhyming_words_34)
|
164 |
+
rhyming_words_34.remove(rhyming_word)
|
165 |
prompt = prompt + " " + third_line
|
166 |
inputs_len = get_inputs_length(prompt)
|
167 |
print(f"Prompt: {prompt}")
|
|
|
170 |
print(f"\nFourth Line: {fourth_line}")
|
171 |
limerick += fourth_line + "\n"
|
172 |
|
173 |
+
rhyming_word = random.choice(rhyming_words_125)
|
174 |
+
rhyming_words_125.remove(rhyming_word)
|
175 |
prompt = prompt + " " + fourth_line
|
176 |
inputs_len = get_inputs_length(prompt)
|
177 |
print(f"Prompt: {prompt}")
|
|
|
186 |
return limerick
|
187 |
|
188 |
def compare_summaries(topic):
|
189 |
+
wiki_limerick = generate(topic)
|
190 |
+
gpt2_limerick = generate(topic, wiki=False)
|
191 |
+
|
192 |
+
output1 = f"Limerick with Wikipedia summary of topic as prompt: \n"
|
193 |
+
output1 += wiki_limerick + "\n"
|
194 |
+
output2 = f"Limerick with GPT-2 summary of topic as prompt: \n"
|
195 |
+
output2 += gpt2_limerick
|
196 |
|
197 |
+
return output1, output2
|
|
|
|
|
|
|
198 |
|
|
|
|
|
199 |
import gradio as gr
|
200 |
|
201 |
interface = gr.Interface(
|
202 |
fn=compare_summaries,
|
203 |
inputs="text",
|
204 |
+
outputs=["text", "text"],
|
205 |
+
title="Text-generation with rhyme and rhythm",
|
206 |
+
layout="horizontal",
|
207 |
+
theme="peach")
|
208 |
interface.launch(debug=True)
|