ansfarooq7 committed on
Commit
651a92a
1 Parent(s): e37de85

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -13
app.py CHANGED
@@ -1,18 +1,19 @@
1
- from transformers import RobertaTokenizer, RobertaForMaskedLM, GPT2Tokenizer
2
  import torch
3
  import wikipedia
4
  import re
5
  import random
6
  import nltk
7
  import gradio as gr
8
- from aitextgen import aitextgen
9
  nltk.download('cmudict')
10
 
11
  roberta_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
12
  roberta_model = RobertaForMaskedLM.from_pretrained('roberta-base')
13
 
14
  gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
15
- gpt2_model = aitextgen()
 
 
16
 
17
  frequent_words = set()
18
 
@@ -23,7 +24,7 @@ with open("wordFrequency.txt", 'r') as f:
23
  line = f.readline()
24
 
25
  def filter_rhymes(word):
26
- filter_list = ['to', 'on', 'has', 'but', 'the', 'in', 'and', 'a', 'aitch', 'angst', 'arugula', 'beige', 'blitzed', 'boing', 'bombed', 'cairn', 'chaos', 'chocolate', 'circle', 'circus', 'cleansed', 'coif', 'cusp', 'doth', 'else', 'eth', 'fiends', 'film', 'flange', 'fourths', 'grilse', 'gulf', 'kiln', 'loge', 'midst', 'month', 'music', 'neutron', 'ninja', 'oblige', 'oink', 'opus', 'orange', 'pint', 'plagued', 'plankton', 'plinth', 'poem', 'poet', 'purple', 'quaich', 'rhythm', 'rouged', 'silver', 'siren', 'soldier', 'sylph', 'thesp', 'toilet', 'torsk', 'tufts', 'waltzed', 'wasp', 'wharves', 'width', 'woman', 'yttrium']
27
  if word in filter_list:
28
  return False
29
  else:
@@ -80,15 +81,35 @@ def get_prediction(sent):
80
  best_guess = best_guess+" "+j[0]
81
 
82
  return best_guess
83
-
84
  def get_line(prompt, inputs_len):
85
- line = gpt2_model.generate_one(prompt=prompt + ".", max_length=inputs_len + 7, min_length=4)[len(prompt)+2:]
86
- return line
 
 
 
 
 
 
87
 
88
  def get_rhyming_line(prompt, rhyming_word, inputs_len):
89
- gpt2_sentence = gpt2_model.generate_one(prompt=prompt + ".", max_length=inputs_len + 4, min_length=2)[len(prompt)+2:]
 
 
 
 
 
 
 
90
  while len(gpt2_sentence) == 0:
91
- gpt2_sentence = gpt2_model.generate_one(prompt=prompt + ".", max_length=inputs_len + 4, min_length=2)[len(prompt)+2:]
 
 
 
 
 
 
 
92
 
93
  gpt2_sentence = gpt2_sentence.replace("\n", "")
94
  print(f"\nGetting rhyming line starting with '{gpt2_sentence}' and ending with rhyming word '{rhyming_word}'")
@@ -108,16 +129,26 @@ def get_rhyming_line(prompt, rhyming_word, inputs_len):
108
  return final_sentence
109
 
110
  def gpt2_summary(topic):
111
- return gpt2_model.generate_one(prompt=f"Here is some information about {topic}.", top_k=50, top_p=0.95, min_length=200)
112
-
 
 
 
 
 
 
 
113
  def generate(topic, wiki=True):
114
  if wiki:
115
  try:
116
  topic_search = wikipedia.search(topic, results=3)
 
117
  topic_summary = remove_punctuation(wikipedia.summary(topic_search[0], auto_suggest=False))
118
  except wikipedia.DisambiguationError as e:
 
 
119
  page = e.options[0]
120
- topic_summary = remove_punctuation(wikipedia.summary(page))
121
  except:
122
  return(f"Method A struggled to find information about {topic}, please try a different topic!")
123
  else:
@@ -205,7 +236,7 @@ def compare_summaries(topic):
205
  print(output1 + "\n" + output2)
206
 
207
  return output1, output2
208
-
209
  description = "Generates limericks (five-line poems with a rhyme scheme of AABBA) using two different methods, please be patient as it can take up to a minute to generate both limericks."
210
  article = '<center><big><strong>Limerick Generation</strong></big></center>'\
211
  '<center><strong>By Ans Farooq</strong></center>'\
 
1
+ from transformers import RobertaTokenizer, RobertaForMaskedLM, GPT2Tokenizer, GPT2LMHeadModel, pipeline
2
  import torch
3
  import wikipedia
4
  import re
5
  import random
6
  import nltk
7
  import gradio as gr
 
8
  nltk.download('cmudict')
9
 
10
  roberta_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
11
  roberta_model = RobertaForMaskedLM.from_pretrained('roberta-base')
12
 
13
  gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
14
+ gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=gpt2_tokenizer.eos_token_id)
15
+
16
+ gpt2_pipeline = pipeline('text-generation', model=gpt2_model, tokenizer=gpt2_tokenizer)
17
 
18
  frequent_words = set()
19
 
 
24
  line = f.readline()
25
 
26
  def filter_rhymes(word):
27
+ filter_list = ['an', 'to', 'on', 'has', 'but', 'the', 'in', 'and', 'a']
28
  if word in filter_list:
29
  return False
30
  else:
 
81
  best_guess = best_guess+" "+j[0]
82
 
83
  return best_guess
84
+
85
  def get_line(prompt, inputs_len):
86
+ output = gpt2_pipeline(
87
+ prompt + ".",
88
+ min_length=4,
89
+ max_length=inputs_len + 7,
90
+ clean_up_tokenization_spaces=True,
91
+ return_full_text=False
92
+ )
93
+ return remove_punctuation(output[0]['generated_text'])
94
 
95
  def get_rhyming_line(prompt, rhyming_word, inputs_len):
96
+ output = gpt2_pipeline(
97
+ prompt + ".",
98
+ min_length=4,
99
+ max_length=inputs_len + 3,
100
+ clean_up_tokenization_spaces=True,
101
+ return_full_text=False
102
+ )
103
+ gpt2_sentence = remove_punctuation(output[0]['generated_text'])
104
  while len(gpt2_sentence) == 0:
105
+ output = gpt2_pipeline(
106
+ prompt + ".",
107
+ min_length=4,
108
+ max_length=inputs_len + 3,
109
+ clean_up_tokenization_spaces=True,
110
+ return_full_text=False
111
+ )
112
+ gpt2_sentence = remove_punctuation(output[0]['generated_text'])
113
 
114
  gpt2_sentence = gpt2_sentence.replace("\n", "")
115
  print(f"\nGetting rhyming line starting with '{gpt2_sentence}' and ending with rhyming word '{rhyming_word}'")
 
129
  return final_sentence
130
 
131
  def gpt2_summary(topic):
132
+ output = gpt2_pipeline(
133
+ f"Here is some information about {topic}.",
134
+ min_length=200,
135
+ max_length=300,
136
+ clean_up_tokenization_spaces=True,
137
+ return_full_text=False
138
+ )
139
+ return remove_punctuation(output[0]['generated_text'])
140
+
141
  def generate(topic, wiki=True):
142
  if wiki:
143
  try:
144
  topic_search = wikipedia.search(topic, results=3)
145
+ print(f"Wikipedia search results for {topic} are: {topic_search}")
146
  topic_summary = remove_punctuation(wikipedia.summary(topic_search[0], auto_suggest=False))
147
  except wikipedia.DisambiguationError as e:
148
+ print("===================== DISAMBIGUATION ERROR =====================")
149
+ print(f"Wikipedia returned a disambiguation error for {topic}. Selecting the first option {e.options[0]} instead.")
150
  page = e.options[0]
151
+ topic_summary = remove_punctuation(wikipedia.summary(page, auto_suggest=False))
152
  except:
153
  return(f"Method A struggled to find information about {topic}, please try a different topic!")
154
  else:
 
236
  print(output1 + "\n" + output2)
237
 
238
  return output1, output2
239
+
240
  description = "Generates limericks (five-line poems with a rhyme scheme of AABBA) using two different methods, please be patient as it can take up to a minute to generate both limericks."
241
  article = '<center><big><strong>Limerick Generation</strong></big></center>'\
242
  '<center><strong>By Ans Farooq</strong></center>'\