ansfarooq7 committed on
Commit
8856002
1 Parent(s): 3f464e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -15
app.py CHANGED
@@ -1,4 +1,4 @@
1
- from transformers import RobertaTokenizer, RobertaForMaskedLM, GPT2Tokenizer, GPTJForCausalLM
2
  import torch
3
  import wikipedia
4
  import re
@@ -8,16 +8,12 @@ import syllables
8
  from aitextgen import aitextgen
9
  nltk.download('cmudict')
10
 
11
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
-
13
  masked_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
14
  masked_model = RobertaForMaskedLM.from_pretrained('roberta-base')
15
 
16
  causal_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
17
- gptj_tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-j-6B")
18
- gptj_model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16", low_cpu_mem_usage=True)
19
- gptj_model.to(device)
20
- #gpt_neo = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=causal_tokenizer.eos_token_id)
21
 
22
  # Without any parameters, aitextgen() will download, cache, and load the 124M GPT-2 "small" model
23
  gpt2 = aitextgen()
@@ -132,17 +128,17 @@ def get_rhyming_line(prompt, rhyming_word, inputs_len):
132
  print(f"Final Sentence: {final_sentence}")
133
  return final_sentence
134
 
135
- def gptj_summary(topic):
136
- input_ids = gptj_tokenizer(f"Here is some information about {topic}", return_tensors="pt").input_ids.to(device)
137
- generated_ids = gptj_model.generate(input_ids, do_sample=True, temperature=0.9, max_length=200)
138
- generated_text = gptj_tokenizer.decode(generated_ids[0])
139
  return generated_text
140
 
141
  def generate(topic, wiki=True):
142
  if wiki:
143
  topic_summary = remove_punctuation(wikipedia.summary(topic))
144
  else:
145
- topic_summary = remove_punctuation(gptj_summary(topic))
146
  word_list = topic_summary.split()
147
  topic_summary_len = len(topic_summary)
148
  no_of_words = len(word_list)
@@ -215,12 +211,12 @@ def generate(topic, wiki=True):
215
 
216
  def compare_summaries(topic):
217
  wiki_limerick = generate(topic, wiki=True)
218
- gptj_limerick = generate(topic, wiki=False)
219
 
220
  output = f"Limerick with Wikipedia summary of topic as prompt: \n"
221
  output += wiki_limerick + "\n"
222
- output += f"Limerick with GPT-J summary of topic as prompt: \n"
223
- output += gptj_limerick
224
 
225
  return output
226
 
 
1
+ from transformers import RobertaTokenizer, RobertaForMaskedLM, GPT2Tokenizer, GPTNeoForCausalLM
2
  import torch
3
  import wikipedia
4
  import re
 
8
  from aitextgen import aitextgen
9
  nltk.download('cmudict')
10
 
 
 
11
# Pick the compute device once at module level.  NOTE(review): this commit
# removed the GPT-J-era `device` definition, but gptneo_summary() below still
# calls `.to(device)` — without this line that call raises NameError.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# RoBERTa tokenizer/model pair for masked-LM predictions.
masked_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
masked_model = RobertaForMaskedLM.from_pretrained('roberta-base')

# GPT-2 tokenizer for the causal models; GPT-Neo 1.3B replaces the former
# GPT-J 6B as the summary generator.
causal_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
gptneo_tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
gptneo_model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
# Keep model and inputs on the same device (the GPT-J version did this too;
# dropping it would crash on CUDA machines once inputs are moved to `device`).
gptneo_model.to(device)

# Without any parameters, aitextgen() will download, cache, and load the 124M GPT-2 "small" model
gpt2 = aitextgen()
 
128
  print(f"Final Sentence: {final_sentence}")
129
  return final_sentence
130
 
131
def gptneo_summary(topic):
    """Return GPT-Neo-generated text about *topic*.

    Serves as the non-Wikipedia summary source for generate(topic,
    wiki=False).  NOTE: the decoded output normally includes the prompt
    ("Here is some information about {topic} ...") since the full
    generated sequence is decoded.
    """
    # BUG FIX: the module-level `device` was deleted along with the GPT-J
    # code in this commit, so referencing it here raised NameError.  Derive
    # the device from the model itself — always consistent with wherever
    # the model actually lives.
    device = next(gptneo_model.parameters()).device
    prompt = f"Here is some information about {topic}"
    input_ids = gptneo_tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    # Sampled decoding; max_length=200 bounds the total token count
    # (prompt tokens included).
    generated_ids = gptneo_model.generate(input_ids, do_sample=True, temperature=0.9, max_length=200)
    generated_text = gptneo_tokenizer.decode(generated_ids[0])
    return generated_text
136
 
137
  def generate(topic, wiki=True):
138
  if wiki:
139
  topic_summary = remove_punctuation(wikipedia.summary(topic))
140
  else:
141
+ topic_summary = remove_punctuation(gptneo_summary(topic))
142
  word_list = topic_summary.split()
143
  topic_summary_len = len(topic_summary)
144
  no_of_words = len(word_list)
 
211
 
212
def compare_summaries(topic):
    """Generate one limerick from a Wikipedia summary of *topic* and one
    from a GPT Neo summary, returning both in a single labelled string."""
    from_wiki = generate(topic, wiki=True)
    from_gptneo = generate(topic, wiki=False)
    sections = (
        "Limerick with Wikipedia summary of topic as prompt: \n",
        from_wiki + "\n",
        "Limerick with GPT Neo summary of topic as prompt: \n",
        from_gptneo,
    )
    return "".join(sections)
222