Spaces:
Sleeping
Sleeping
ansfarooq7
committed on
Commit
•
4677c24
1
Parent(s):
d1f0e51
Update app.py
Browse files
app.py
CHANGED
@@ -1,18 +1,26 @@
|
|
1 |
-
from transformers import RobertaTokenizer, RobertaForMaskedLM,
|
2 |
import torch
|
3 |
import wikipedia
|
4 |
import re
|
5 |
import random
|
6 |
import nltk
|
7 |
import syllables
|
|
|
8 |
nltk.download('cmudict')
|
9 |
|
|
|
|
|
10 |
masked_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
|
11 |
masked_model = RobertaForMaskedLM.from_pretrained('roberta-base')
|
12 |
|
13 |
causal_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
14 |
-
|
15 |
-
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
frequent_words = set()
|
18 |
|
@@ -48,6 +56,7 @@ def filter_rhymes(word):
|
|
48 |
|
49 |
def remove_punctuation(text):
|
50 |
text = re.sub(r'[^\w\s]', '', text)
|
|
|
51 |
return text
|
52 |
|
53 |
def get_rhymes(inp, level):
|
@@ -100,19 +109,13 @@ def get_prediction(sent):
|
|
100 |
|
101 |
return best_guess
|
102 |
|
103 |
-
text_generation = pipeline("text-generation", model=causal_model, tokenizer=causal_tokenizer)
|
104 |
-
from aitextgen import aitextgen
|
105 |
-
|
106 |
-
# Without any parameters, aitextgen() will download, cache, and load the 124M GPT-2 "small" model
|
107 |
-
ai = aitextgen()
|
108 |
-
|
109 |
-
|
110 |
def get_line(prompt, inputs_len):
|
111 |
-
line =
|
112 |
return line
|
113 |
|
114 |
def get_rhyming_line(prompt, rhyming_word, inputs_len):
|
115 |
-
gpt2_sentence =
|
|
|
116 |
print(f"\nGetting rhyming line starting with '{gpt2_sentence}' and ending with rhyming word '{rhyming_word}'")
|
117 |
sentence = gpt2_sentence + " ___ ___ ___ " + rhyming_word
|
118 |
print(f"Original Sentence: {sentence}")
|
@@ -128,12 +131,20 @@ def get_rhyming_line(prompt, rhyming_word, inputs_len):
|
|
128 |
final_sentence = gpt2_sentence + predicted_blanks + " " + rhyming_word
|
129 |
print(f"Final Sentence: {final_sentence}")
|
130 |
return final_sentence
|
131 |
-
|
132 |
-
def generate(topic):
|
133 |
|
134 |
-
|
|
|
|
|
135 |
|
136 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
word_list = topic_summary.split()
|
138 |
topic_summary_len = len(topic_summary)
|
139 |
no_of_words = len(word_list)
|
@@ -143,77 +154,82 @@ def generate(topic):
|
|
143 |
print(f"No of Words in Summary: {no_of_words}")
|
144 |
print(f"Length of Input IDs: {inputs_len}")
|
145 |
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
print(f"\nFifth Line: {fifth_line}")
|
202 |
-
limerick += fifth_line + "\n"
|
203 |
-
|
204 |
-
limericks.append(limerick)
|
205 |
|
206 |
print("\n")
|
207 |
-
|
|
|
|
|
208 |
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
|
|
|
|
|
|
|
|
213 |
|
214 |
return output
|
215 |
-
|
216 |
import gradio as gr
|
217 |
|
218 |
-
interface = gr.Interface(
|
|
|
|
|
|
|
219 |
interface.launch(debug=True)
|
|
|
1 |
+
from transformers import RobertaTokenizer, RobertaForMaskedLM, GPT2Tokenizer, AutoTokenizer, GPTJForCausalLM
|
2 |
import torch
|
3 |
import wikipedia
|
4 |
import re
|
5 |
import random
|
6 |
import nltk
|
7 |
import syllables
|
8 |
+
from aitextgen import aitextgen
|
9 |
nltk.download('cmudict')
|
10 |
|
11 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
12 |
+
|
13 |
masked_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
|
14 |
masked_model = RobertaForMaskedLM.from_pretrained('roberta-base')
|
15 |
|
16 |
causal_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
17 |
+
|
18 |
+
# Without any parameters, aitextgen() will download, cache, and load the 124M GPT-2 "small" model
|
19 |
+
gpt2 = aitextgen()
|
20 |
+
|
21 |
+
gptj_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
|
22 |
+
gptj_model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16", low_cpu_mem_usage=True)
|
23 |
+
gptj_model.to(device)
|
24 |
|
25 |
frequent_words = set()
|
26 |
|
|
|
56 |
|
57 |
def remove_punctuation(text):
    """Return *text* with punctuation stripped and newlines flattened to spaces.

    Keeps word characters and whitespace only, then replaces every newline
    with a single space so the result is one continuous line.
    """
    cleaned = re.sub(r'[^\w\s]', '', text)
    return cleaned.replace("\n", " ")
|
61 |
|
62 |
def get_rhymes(inp, level):
|
|
|
109 |
|
110 |
return best_guess
|
111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
def get_line(prompt, inputs_len):
    """Generate one continuation line from GPT-2 for the given prompt.

    A "." is appended to the prompt before generation; the returned text has
    the prompt and that separator (len(prompt) + 2 characters) sliced off, so
    only the newly generated continuation is returned.
    """
    seed = prompt + "."
    generated = gpt2.generate_one(prompt=seed, max_length=inputs_len + 7)
    return generated[len(prompt) + 2:]
|
115 |
|
116 |
def get_rhyming_line(prompt, rhyming_word, inputs_len):
|
117 |
+
gpt2_sentence = gpt2.generate_one(prompt=prompt + ".", max_length=inputs_len + 4)[len(prompt)+2:]
|
118 |
+
gpt2_sentence = gpt2_sentence.replace("\n", "")
|
119 |
print(f"\nGetting rhyming line starting with '{gpt2_sentence}' and ending with rhyming word '{rhyming_word}'")
|
120 |
sentence = gpt2_sentence + " ___ ___ ___ " + rhyming_word
|
121 |
print(f"Original Sentence: {sentence}")
|
|
|
131 |
final_sentence = gpt2_sentence + predicted_blanks + " " + rhyming_word
|
132 |
print(f"Final Sentence: {final_sentence}")
|
133 |
return final_sentence
|
|
|
|
|
134 |
|
135 |
+
def gptj_summary(topic):
    """Sample a short GPT-J passage describing *topic*.

    Builds a seed sentence about the topic, tokenizes it onto the active
    device, samples up to 200 tokens (temperature 0.9), and returns the
    decoded text (which includes the seed prompt itself).
    """
    prompt = f"Here is some information about {topic}"
    encoded = gptj_tokenizer(prompt, return_tensors="pt")
    input_ids = encoded.input_ids.to(device)

    sampled = gptj_model.generate(input_ids, do_sample=True, temperature=0.9, max_length=200)
    return gptj_tokenizer.decode(sampled[0])
|
142 |
+
|
143 |
+
def generate(topic, wiki=True):
|
144 |
+
if wiki:
|
145 |
+
topic_summary = remove_punctuation(wikipedia.summary(topic))
|
146 |
+
else:
|
147 |
+
topic_summary = remove_punctuation(gptj_summary(topic))
|
148 |
word_list = topic_summary.split()
|
149 |
topic_summary_len = len(topic_summary)
|
150 |
no_of_words = len(word_list)
|
|
|
154 |
print(f"No of Words in Summary: {no_of_words}")
|
155 |
print(f"Length of Input IDs: {inputs_len}")
|
156 |
|
157 |
+
rhyming_words_125 = []
|
158 |
+
while len(rhyming_words_125) < 3 or valid_rhyme == False or len(first_line) == 0:
|
159 |
+
first_line = get_line(topic_summary, inputs_len)
|
160 |
+
if first_line:
|
161 |
+
end_word = remove_punctuation(first_line.split()[-1])
|
162 |
+
valid_rhyme = filter_rhymes(end_word)
|
163 |
+
if valid_rhyme:
|
164 |
+
print(f"\nFirst Line: {first_line}")
|
165 |
+
rhyming_words_125 = list(get_rhymes(end_word, 3))
|
166 |
+
print(f"Rhyming words for '{end_word}' are {rhyming_words_125}")
|
167 |
+
limerick = first_line + "\n"
|
168 |
+
|
169 |
+
rhyming_word = rhyming_words_125[0]
|
170 |
+
prompt = topic_summary + " " + first_line
|
171 |
+
inputs_len = get_inputs_length(prompt)
|
172 |
+
print(f"Prompt: {prompt}")
|
173 |
+
print(f"Length of prompt: {inputs_len}")
|
174 |
+
second_line = get_rhyming_line(prompt, rhyming_word, inputs_len)
|
175 |
+
print(f"\nSecond Line: {second_line}")
|
176 |
+
limerick += second_line + "\n"
|
177 |
+
|
178 |
+
rhyming_words_34 = []
|
179 |
+
prompt = prompt + " " + second_line
|
180 |
+
inputs_len = get_inputs_length(prompt)
|
181 |
+
print(f"Prompt: {prompt}")
|
182 |
+
print(f"Length of prompt: {inputs_len}")
|
183 |
+
while len(rhyming_words_34) < 2 or valid_rhyme == False or len(third_line) == 0:
|
184 |
+
third_line = get_line(prompt, inputs_len)
|
185 |
+
if third_line:
|
186 |
+
print(f"\nThird Line: {third_line}")
|
187 |
+
end_word = remove_punctuation(third_line.split()[-1])
|
188 |
+
valid_rhyme = filter_rhymes(end_word)
|
189 |
+
print(f"Does '{end_word}' have valid rhymes: {valid_rhyme}")
|
190 |
+
rhyming_words_34 = list(get_rhymes(end_word, 3))
|
191 |
+
print(f"Rhyming words for '{end_word}' are {rhyming_words_34}")
|
192 |
+
if valid_rhyme and len(rhyming_words_34) > 1:
|
193 |
+
limerick += third_line + "\n"
|
194 |
+
|
195 |
+
rhyming_word = rhyming_words_34[0]
|
196 |
+
prompt = prompt + " " + third_line
|
197 |
+
inputs_len = get_inputs_length(prompt)
|
198 |
+
print(f"Prompt: {prompt}")
|
199 |
+
print(f"Length of prompt: {inputs_len}")
|
200 |
+
fourth_line = get_rhyming_line(prompt, rhyming_word, inputs_len)
|
201 |
+
print(f"\nFourth Line: {fourth_line}")
|
202 |
+
limerick += fourth_line + "\n"
|
203 |
+
|
204 |
+
rhyming_word = rhyming_words_125[1]
|
205 |
+
prompt = prompt + " " + fourth_line
|
206 |
+
inputs_len = get_inputs_length(prompt)
|
207 |
+
print(f"Prompt: {prompt}")
|
208 |
+
print(f"Length of prompt: {inputs_len}")
|
209 |
+
fifth_line = get_rhyming_line(prompt, rhyming_word, inputs_len)
|
210 |
+
print(f"\nFifth Line: {fifth_line}")
|
211 |
+
limerick += fifth_line + "\n"
|
|
|
|
|
|
|
|
|
212 |
|
213 |
print("\n")
|
214 |
+
print(limerick)
|
215 |
+
|
216 |
+
return limerick
|
217 |
|
218 |
+
def compare_summaries(topic):
    """Generate two limericks for *topic* — one seeded from a Wikipedia
    summary and one from a GPT-J summary — and return them as one labelled
    text block for side-by-side comparison.
    """
    wiki_limerick = generate(topic, wiki=True)
    gptj_limerick = generate(topic, wiki=False)

    sections = [
        f"Limerick with Wikipedia summary of topic as prompt: \n",
        wiki_limerick + "\n",
        f"Limerick with GPT-J summary of topic as prompt: \n",
        gptj_limerick,
    ]
    return "".join(sections)
|
228 |
+
|
229 |
import gradio as gr
|
230 |
|
231 |
+
interface = gr.Interface(
|
232 |
+
fn=compare_summaries,
|
233 |
+
inputs="text",
|
234 |
+
outputs="text")
|
235 |
interface.launch(debug=True)
|