ansfarooq7 committed on
Commit
d6ce809
1 Parent(s): 0bb105b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -9
app.py CHANGED
@@ -24,7 +24,7 @@ with open("wordFrequency.txt", 'r') as f:
24
  line = f.readline()
25
 
26
  def filter_rhymes(word):
27
- filter_list = ['an', 'to', 'on', 'has', 'but', 'the', 'in', 'and', 'a', 'are', 'or', 'its', 'it''s']
28
  if word in filter_list:
29
  return False
30
  else:
@@ -33,7 +33,7 @@ def filter_rhymes(word):
33
  def remove_punctuation(text):
34
  text = re.sub(r'[^\w\s]', '', text)
35
  text = text.replace("\n", " ")
36
- return text
37
 
38
  def get_rhymes(inp, level):
39
  entries = nltk.corpus.cmudict.entries()
@@ -53,7 +53,6 @@ def get_inputs_length(input):
53
  return len(input_ids)
54
 
55
  def get_prediction(sent):
56
-
57
  token_ids = roberta_tokenizer.encode(sent, return_tensors='pt')
58
  masked_position = (token_ids.squeeze() == roberta_tokenizer.mask_token_id).nonzero()
59
  masked_pos = [mask.item() for mask in masked_position ]
@@ -90,7 +89,7 @@ def get_line(prompt, inputs_len):
90
  clean_up_tokenization_spaces=True,
91
  return_full_text=False
92
  )
93
- return remove_punctuation(output[0]['generated_text']).strip()
94
 
95
  def get_rhyming_line(prompt, rhyming_word, inputs_len):
96
  output = gpt2_pipeline(
@@ -100,7 +99,7 @@ def get_rhyming_line(prompt, rhyming_word, inputs_len):
100
  clean_up_tokenization_spaces=True,
101
  return_full_text=False
102
  )
103
- gpt2_sentence = remove_punctuation(output[0]['generated_text']).strip()
104
  while len(gpt2_sentence) == 0:
105
  output = gpt2_pipeline(
106
  prompt + ".",
@@ -109,7 +108,7 @@ def get_rhyming_line(prompt, rhyming_word, inputs_len):
109
  clean_up_tokenization_spaces=True,
110
  return_full_text=False
111
  )
112
- gpt2_sentence = remove_punctuation(output[0]['generated_text']).strip()
113
 
114
  print(f"\nGetting rhyming line starting with '{gpt2_sentence}' and ending with rhyming word '{rhyming_word}'")
115
  sentence = gpt2_sentence + " ___ ___ ___ " + rhyming_word
@@ -130,12 +129,12 @@ def get_rhyming_line(prompt, rhyming_word, inputs_len):
130
  def gpt2_summary(topic):
131
  output = gpt2_pipeline(
132
  f"Here is some information about {topic}.",
133
- min_length=200,
134
  max_length=300,
135
  clean_up_tokenization_spaces=True,
136
  return_full_text=False
137
  )
138
- return remove_punctuation(output[0]['generated_text']).strip()
139
 
140
  def generate(topic, wiki=True):
141
  if wiki:
@@ -144,7 +143,6 @@ def generate(topic, wiki=True):
144
  print(f"Wikipedia search results for {topic} are: {topic_search}")
145
  topic_summary = remove_punctuation(wikipedia.summary(topic_search[0], auto_suggest=False))
146
  except wikipedia.DisambiguationError as e:
147
- print("===================== DISAMBIGUATION ERROR =====================")
148
  print(f"Wikipedia returned a disambiguation error for {topic}. Selecting the first option {e.options[0]} instead.")
149
  page = e.options[0]
150
  topic_summary = remove_punctuation(wikipedia.summary(page, auto_suggest=False))
 
24
  line = f.readline()
25
 
26
  def filter_rhymes(word):
27
+ filter_list = ['an', 'to', 'on', 'has', 'but', 'the', 'in', 'and', 'a', 'are', 'or', 'its', 'it''s']
28
  if word in filter_list:
29
  return False
30
  else:
 
33
  def remove_punctuation(text):
34
  text = re.sub(r'[^\w\s]', '', text)
35
  text = text.replace("\n", " ")
36
+ return text.strip()
37
 
38
  def get_rhymes(inp, level):
39
  entries = nltk.corpus.cmudict.entries()
 
53
  return len(input_ids)
54
 
55
  def get_prediction(sent):
 
56
  token_ids = roberta_tokenizer.encode(sent, return_tensors='pt')
57
  masked_position = (token_ids.squeeze() == roberta_tokenizer.mask_token_id).nonzero()
58
  masked_pos = [mask.item() for mask in masked_position ]
 
89
  clean_up_tokenization_spaces=True,
90
  return_full_text=False
91
  )
92
+ return remove_punctuation(output[0]['generated_text'])
93
 
94
  def get_rhyming_line(prompt, rhyming_word, inputs_len):
95
  output = gpt2_pipeline(
 
99
  clean_up_tokenization_spaces=True,
100
  return_full_text=False
101
  )
102
+ gpt2_sentence = remove_punctuation(output[0]['generated_text'])
103
  while len(gpt2_sentence) == 0:
104
  output = gpt2_pipeline(
105
  prompt + ".",
 
108
  clean_up_tokenization_spaces=True,
109
  return_full_text=False
110
  )
111
+ gpt2_sentence = remove_punctuation(output[0]['generated_text'])
112
 
113
  print(f"\nGetting rhyming line starting with '{gpt2_sentence}' and ending with rhyming word '{rhyming_word}'")
114
  sentence = gpt2_sentence + " ___ ___ ___ " + rhyming_word
 
129
  def gpt2_summary(topic):
130
  output = gpt2_pipeline(
131
  f"Here is some information about {topic}.",
132
+ min_length=100,
133
  max_length=300,
134
  clean_up_tokenization_spaces=True,
135
  return_full_text=False
136
  )
137
+ return remove_punctuation(output[0]['generated_text'])
138
 
139
  def generate(topic, wiki=True):
140
  if wiki:
 
143
  print(f"Wikipedia search results for {topic} are: {topic_search}")
144
  topic_summary = remove_punctuation(wikipedia.summary(topic_search[0], auto_suggest=False))
145
  except wikipedia.DisambiguationError as e:
 
146
  print(f"Wikipedia returned a disambiguation error for {topic}. Selecting the first option {e.options[0]} instead.")
147
  page = e.options[0]
148
  topic_summary = remove_punctuation(wikipedia.summary(page, auto_suggest=False))