ansfarooq7 committed on
Commit
b326a58
1 Parent(s): 4140a06

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -57
app.py CHANGED
@@ -1,41 +1,19 @@
1
- from transformers import RobertaTokenizer, RobertaForMaskedLM, GPT2Tokenizer, GPTNeoForCausalLM
2
  import torch
3
  import wikipedia
4
  import re
5
  import random
6
  import nltk
7
- import syllables
8
  from aitextgen import aitextgen
9
  nltk.download('cmudict')
10
 
11
- masked_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
12
- masked_model = RobertaForMaskedLM.from_pretrained('roberta-base')
13
 
14
- causal_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
15
- gptneo_tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
16
- gptneo_model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
17
-
18
- # Without any parameters, aitextgen() will download, cache, and load the 124M GPT-2 "small" model
19
- gpt2 = aitextgen()
20
 
21
  frequent_words = set()
22
-
23
- def set_seed(seed: int):
24
- """
25
- Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf`` (if
26
- installed).
27
-
28
- Args:
29
- seed (:obj:`int`): The seed to set.
30
- """
31
- #random.seed(seed)
32
- #np.random.seed(seed)
33
- #if is_torch_available():
34
- torch.manual_seed(seed)
35
- torch.cuda.manual_seed_all(seed)
36
- # ^^ safe to call this function even if cuda is not available
37
- #if is_tf_available():
38
- #tf.random.set_seed(seed)
39
 
40
  with open("wordFrequency.txt", 'r') as f:
41
  line = f.readline()
@@ -69,19 +47,17 @@ def get_rhymes(inp, level):
69
  return filtered_rhymes
70
 
71
  def get_inputs_length(input):
72
- input_ids = causal_tokenizer(input)['input_ids']
73
  return len(input_ids)
74
-
75
- set_seed(0)
76
 
77
  def get_prediction(sent):
78
 
79
- token_ids = masked_tokenizer.encode(sent, return_tensors='pt')
80
- masked_position = (token_ids.squeeze() == masked_tokenizer.mask_token_id).nonzero()
81
  masked_pos = [mask.item() for mask in masked_position ]
82
 
83
  with torch.no_grad():
84
- output = masked_model(token_ids)
85
 
86
  last_hidden_state = output[0].squeeze()
87
 
@@ -92,10 +68,9 @@ def get_prediction(sent):
92
  mask_hidden_state = last_hidden_state[mask_index]
93
  idx = torch.topk(mask_hidden_state, k=5, dim=0)[1]
94
  for i in idx:
95
- word = masked_tokenizer.decode(i.item()).strip()
96
  if (remove_punctuation(word) != "") and (word != '</s>'):
97
  words.append(word)
98
- #words = [masked_tokenizer.decode(i.item()).strip() for i in idx]
99
  list_of_list.append(words)
100
  print(f"Mask {index+1} Guesses: {words}")
101
 
@@ -104,13 +79,13 @@ def get_prediction(sent):
104
  best_guess = best_guess+" "+j[0]
105
 
106
  return best_guess
107
-
108
  def get_line(prompt, inputs_len):
109
- line = gpt2.generate_one(prompt=prompt + ".", max_length=inputs_len + 7)[len(prompt)+2:]
110
  return line
111
 
112
  def get_rhyming_line(prompt, rhyming_word, inputs_len):
113
- gpt2_sentence = gpt2.generate_one(prompt=prompt + ".", max_length=inputs_len + 4)[len(prompt)+2:]
114
  gpt2_sentence = gpt2_sentence.replace("\n", "")
115
  print(f"\nGetting rhyming line starting with '{gpt2_sentence}' and ending with rhyming word '{rhyming_word}'")
116
  sentence = gpt2_sentence + " ___ ___ ___ " + rhyming_word
@@ -128,17 +103,15 @@ def get_rhyming_line(prompt, rhyming_word, inputs_len):
128
  print(f"Final Sentence: {final_sentence}")
129
  return final_sentence
130
 
131
- def gptneo_summary(topic):
132
- input_ids = gptneo_tokenizer(f"Here is some information about {topic}", return_tensors="pt").input_ids
133
- gen_tokens = gptneo_model.generate(input_ids, do_sample=True, temperature=0.9, max_length=200)
134
- generated_text = gptneo_tokenizer.decode(gen_tokens[0])
135
- return generated_text
136
-
137
  def generate(topic, wiki=True):
138
  if wiki:
139
  topic_summary = remove_punctuation(wikipedia.summary(topic))
140
  else:
141
- topic_summary = remove_punctuation(gptneo_summary(topic))
 
142
  word_list = topic_summary.split()
143
  topic_summary_len = len(topic_summary)
144
  no_of_words = len(word_list)
@@ -160,7 +133,8 @@ def generate(topic, wiki=True):
160
  print(f"Rhyming words for '{end_word}' are {rhyming_words_125}")
161
  limerick = first_line + "\n"
162
 
163
- rhyming_word = rhyming_words_125[0]
 
164
  prompt = topic_summary + " " + first_line
165
  inputs_len = get_inputs_length(prompt)
166
  print(f"Prompt: {prompt}")
@@ -186,7 +160,8 @@ def generate(topic, wiki=True):
186
  if valid_rhyme and len(rhyming_words_34) > 1:
187
  limerick += third_line + "\n"
188
 
189
- rhyming_word = rhyming_words_34[0]
 
190
  prompt = prompt + " " + third_line
191
  inputs_len = get_inputs_length(prompt)
192
  print(f"Prompt: {prompt}")
@@ -195,7 +170,8 @@ def generate(topic, wiki=True):
195
  print(f"\nFourth Line: {fourth_line}")
196
  limerick += fourth_line + "\n"
197
 
198
- rhyming_word = rhyming_words_125[1]
 
199
  prompt = prompt + " " + fourth_line
200
  inputs_len = get_inputs_length(prompt)
201
  print(f"Prompt: {prompt}")
@@ -210,20 +186,23 @@ def generate(topic, wiki=True):
210
  return limerick
211
 
212
  def compare_summaries(topic):
213
- wiki_limerick = generate(topic, wiki=True)
214
- gptneo_limerick = generate(topic, wiki=False)
 
 
 
 
 
215
 
216
- output = f"Limerick with Wikipedia summary of topic as prompt: \n"
217
- output += wiki_limerick + "\n"
218
- output += f"Limerick with GPT Neo summary of topic as prompt: \n"
219
- output += gptneo_limerick
220
 
221
- return output
222
-
223
  import gradio as gr
224
 
225
  interface = gr.Interface(
226
  fn=compare_summaries,
227
  inputs="text",
228
- outputs="text")
 
 
 
229
  interface.launch(debug=True)
 
1
+ from transformers import RobertaTokenizer, RobertaForMaskedLM, GPT2Tokenizer
2
  import torch
3
  import wikipedia
4
  import re
5
  import random
6
  import nltk
 
7
  from aitextgen import aitextgen
8
  nltk.download('cmudict')
9
 
10
+ roberta_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
11
+ roberta_model = RobertaForMaskedLM.from_pretrained('roberta-base')
12
 
13
+ gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
14
+ gpt2_model = aitextgen(tf_gpt2="355M")
 
 
 
 
15
 
16
  frequent_words = set()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  with open("wordFrequency.txt", 'r') as f:
19
  line = f.readline()
 
47
  return filtered_rhymes
48
 
49
  def get_inputs_length(input):
50
+ input_ids = gpt2_tokenizer(input)['input_ids']
51
  return len(input_ids)
 
 
52
 
53
  def get_prediction(sent):
54
 
55
+ token_ids = roberta_tokenizer.encode(sent, return_tensors='pt')
56
+ masked_position = (token_ids.squeeze() == roberta_tokenizer.mask_token_id).nonzero()
57
  masked_pos = [mask.item() for mask in masked_position ]
58
 
59
  with torch.no_grad():
60
+ output = roberta_model(token_ids)
61
 
62
  last_hidden_state = output[0].squeeze()
63
 
 
68
  mask_hidden_state = last_hidden_state[mask_index]
69
  idx = torch.topk(mask_hidden_state, k=5, dim=0)[1]
70
  for i in idx:
71
+ word = roberta_tokenizer.decode(i.item()).strip()
72
  if (remove_punctuation(word) != "") and (word != '</s>'):
73
  words.append(word)
 
74
  list_of_list.append(words)
75
  print(f"Mask {index+1} Guesses: {words}")
76
 
 
79
  best_guess = best_guess+" "+j[0]
80
 
81
  return best_guess
82
+
83
  def get_line(prompt, inputs_len):
84
+ line = gpt2_model.generate_one(prompt=prompt + ".", max_length=inputs_len + 7, min_length=4)[len(prompt)+2:]
85
  return line
86
 
87
  def get_rhyming_line(prompt, rhyming_word, inputs_len):
88
+ gpt2_sentence = gpt2_model.generate_one(prompt=prompt + ".", max_length=inputs_len + 4, min_length=2)[len(prompt)+2:]
89
  gpt2_sentence = gpt2_sentence.replace("\n", "")
90
  print(f"\nGetting rhyming line starting with '{gpt2_sentence}' and ending with rhyming word '{rhyming_word}'")
91
  sentence = gpt2_sentence + " ___ ___ ___ " + rhyming_word
 
103
  print(f"Final Sentence: {final_sentence}")
104
  return final_sentence
105
 
106
+ def gpt2_summary(topic):
107
+ return gpt2_model.generate_one(prompt=f"Here is some information about {topic}", top_k=100, top_p=0.95)
108
+
 
 
 
109
  def generate(topic, wiki=True):
110
  if wiki:
111
  topic_summary = remove_punctuation(wikipedia.summary(topic))
112
  else:
113
+ topic_summary = remove_punctuation(gpt2_summary(topic))
114
+
115
  word_list = topic_summary.split()
116
  topic_summary_len = len(topic_summary)
117
  no_of_words = len(word_list)
 
133
  print(f"Rhyming words for '{end_word}' are {rhyming_words_125}")
134
  limerick = first_line + "\n"
135
 
136
+ rhyming_word = random.choice(rhyming_words_125)
137
+ rhyming_words_125.remove(rhyming_word)
138
  prompt = topic_summary + " " + first_line
139
  inputs_len = get_inputs_length(prompt)
140
  print(f"Prompt: {prompt}")
 
160
  if valid_rhyme and len(rhyming_words_34) > 1:
161
  limerick += third_line + "\n"
162
 
163
+ rhyming_word = random.choice(rhyming_words_34)
164
+ rhyming_words_34.remove(rhyming_word)
165
  prompt = prompt + " " + third_line
166
  inputs_len = get_inputs_length(prompt)
167
  print(f"Prompt: {prompt}")
 
170
  print(f"\nFourth Line: {fourth_line}")
171
  limerick += fourth_line + "\n"
172
 
173
+ rhyming_word = random.choice(rhyming_words_125)
174
+ rhyming_words_125.remove(rhyming_word)
175
  prompt = prompt + " " + fourth_line
176
  inputs_len = get_inputs_length(prompt)
177
  print(f"Prompt: {prompt}")
 
186
  return limerick
187
 
188
  def compare_summaries(topic):
189
+ wiki_limerick = generate(topic)
190
+ gpt2_limerick = generate(topic, wiki=False)
191
+
192
+ output1 = f"Limerick with Wikipedia summary of topic as prompt: \n"
193
+ output1 += wiki_limerick + "\n"
194
+ output2 = f"Limerick with GPT-2 summary of topic as prompt: \n"
195
+ output2 += gpt2_limerick
196
 
197
+ return output1, output2
 
 
 
198
 
 
 
199
  import gradio as gr
200
 
201
  interface = gr.Interface(
202
  fn=compare_summaries,
203
  inputs="text",
204
+ outputs=["text", "text"],
205
+ title="Text-generation with rhyme and rhythm",
206
+ layout="horizontal",
207
+ theme="peach")
208
  interface.launch(debug=True)