Spaces:
Sleeping
Sleeping
ansfarooq7
committed on
Commit
•
d6ce809
1
Parent(s):
0bb105b
Update app.py
Browse files
app.py
CHANGED
@@ -24,7 +24,7 @@ with open("wordFrequency.txt", 'r') as f:
|
|
24 |
line = f.readline()
|
25 |
|
26 |
def filter_rhymes(word):
|
27 |
-
filter_list = ['an', 'to', 'on', 'has', 'but', 'the', 'in', 'and', 'a', 'are', 'or', 'its', 'it''s']
|
28 |
if word in filter_list:
|
29 |
return False
|
30 |
else:
|
@@ -33,7 +33,7 @@ def filter_rhymes(word):
|
|
33 |
def remove_punctuation(text):
|
34 |
text = re.sub(r'[^\w\s]', '', text)
|
35 |
text = text.replace("\n", " ")
|
36 |
-
return text
|
37 |
|
38 |
def get_rhymes(inp, level):
|
39 |
entries = nltk.corpus.cmudict.entries()
|
@@ -53,7 +53,6 @@ def get_inputs_length(input):
|
|
53 |
return len(input_ids)
|
54 |
|
55 |
def get_prediction(sent):
|
56 |
-
|
57 |
token_ids = roberta_tokenizer.encode(sent, return_tensors='pt')
|
58 |
masked_position = (token_ids.squeeze() == roberta_tokenizer.mask_token_id).nonzero()
|
59 |
masked_pos = [mask.item() for mask in masked_position ]
|
@@ -90,7 +89,7 @@ def get_line(prompt, inputs_len):
|
|
90 |
clean_up_tokenization_spaces=True,
|
91 |
return_full_text=False
|
92 |
)
|
93 |
-
return remove_punctuation(output[0]['generated_text'])
|
94 |
|
95 |
def get_rhyming_line(prompt, rhyming_word, inputs_len):
|
96 |
output = gpt2_pipeline(
|
@@ -100,7 +99,7 @@ def get_rhyming_line(prompt, rhyming_word, inputs_len):
|
|
100 |
clean_up_tokenization_spaces=True,
|
101 |
return_full_text=False
|
102 |
)
|
103 |
-
gpt2_sentence = remove_punctuation(output[0]['generated_text'])
|
104 |
while len(gpt2_sentence) == 0:
|
105 |
output = gpt2_pipeline(
|
106 |
prompt + ".",
|
@@ -109,7 +108,7 @@ def get_rhyming_line(prompt, rhyming_word, inputs_len):
|
|
109 |
clean_up_tokenization_spaces=True,
|
110 |
return_full_text=False
|
111 |
)
|
112 |
-
gpt2_sentence = remove_punctuation(output[0]['generated_text'])
|
113 |
|
114 |
print(f"\nGetting rhyming line starting with '{gpt2_sentence}' and ending with rhyming word '{rhyming_word}'")
|
115 |
sentence = gpt2_sentence + " ___ ___ ___ " + rhyming_word
|
@@ -130,12 +129,12 @@ def get_rhyming_line(prompt, rhyming_word, inputs_len):
|
|
130 |
def gpt2_summary(topic):
|
131 |
output = gpt2_pipeline(
|
132 |
f"Here is some information about {topic}.",
|
133 |
-
min_length=
|
134 |
max_length=300,
|
135 |
clean_up_tokenization_spaces=True,
|
136 |
return_full_text=False
|
137 |
)
|
138 |
-
return remove_punctuation(output[0]['generated_text'])
|
139 |
|
140 |
def generate(topic, wiki=True):
|
141 |
if wiki:
|
@@ -144,7 +143,6 @@ def generate(topic, wiki=True):
|
|
144 |
print(f"Wikipedia search results for {topic} are: {topic_search}")
|
145 |
topic_summary = remove_punctuation(wikipedia.summary(topic_search[0], auto_suggest=False))
|
146 |
except wikipedia.DisambiguationError as e:
|
147 |
-
print("===================== DISAMBIGUATION ERROR =====================")
|
148 |
print(f"Wikipedia returned a disambiguation error for {topic}. Selecting the first option {e.options[0]} instead.")
|
149 |
page = e.options[0]
|
150 |
topic_summary = remove_punctuation(wikipedia.summary(page, auto_suggest=False))
|
|
|
24 |
line = f.readline()
|
25 |
|
26 |
def filter_rhymes(word):
|
27 |
+
filter_list = ['an', 'to', 'on', 'has', 'but', 'the', 'in', 'and', 'a', 'are', 'or', 'its', 'it''s']
|
28 |
if word in filter_list:
|
29 |
return False
|
30 |
else:
|
|
|
33 |
def remove_punctuation(text):
|
34 |
text = re.sub(r'[^\w\s]', '', text)
|
35 |
text = text.replace("\n", " ")
|
36 |
+
return text.strip()
|
37 |
|
38 |
def get_rhymes(inp, level):
|
39 |
entries = nltk.corpus.cmudict.entries()
|
|
|
53 |
return len(input_ids)
|
54 |
|
55 |
def get_prediction(sent):
|
|
|
56 |
token_ids = roberta_tokenizer.encode(sent, return_tensors='pt')
|
57 |
masked_position = (token_ids.squeeze() == roberta_tokenizer.mask_token_id).nonzero()
|
58 |
masked_pos = [mask.item() for mask in masked_position ]
|
|
|
89 |
clean_up_tokenization_spaces=True,
|
90 |
return_full_text=False
|
91 |
)
|
92 |
+
return remove_punctuation(output[0]['generated_text'])
|
93 |
|
94 |
def get_rhyming_line(prompt, rhyming_word, inputs_len):
|
95 |
output = gpt2_pipeline(
|
|
|
99 |
clean_up_tokenization_spaces=True,
|
100 |
return_full_text=False
|
101 |
)
|
102 |
+
gpt2_sentence = remove_punctuation(output[0]['generated_text'])
|
103 |
while len(gpt2_sentence) == 0:
|
104 |
output = gpt2_pipeline(
|
105 |
prompt + ".",
|
|
|
108 |
clean_up_tokenization_spaces=True,
|
109 |
return_full_text=False
|
110 |
)
|
111 |
+
gpt2_sentence = remove_punctuation(output[0]['generated_text'])
|
112 |
|
113 |
print(f"\nGetting rhyming line starting with '{gpt2_sentence}' and ending with rhyming word '{rhyming_word}'")
|
114 |
sentence = gpt2_sentence + " ___ ___ ___ " + rhyming_word
|
|
|
129 |
def gpt2_summary(topic):
|
130 |
output = gpt2_pipeline(
|
131 |
f"Here is some information about {topic}.",
|
132 |
+
min_length=100,
|
133 |
max_length=300,
|
134 |
clean_up_tokenization_spaces=True,
|
135 |
return_full_text=False
|
136 |
)
|
137 |
+
return remove_punctuation(output[0]['generated_text'])
|
138 |
|
139 |
def generate(topic, wiki=True):
|
140 |
if wiki:
|
|
|
143 |
print(f"Wikipedia search results for {topic} are: {topic_search}")
|
144 |
topic_summary = remove_punctuation(wikipedia.summary(topic_search[0], auto_suggest=False))
|
145 |
except wikipedia.DisambiguationError as e:
|
|
|
146 |
print(f"Wikipedia returned a disambiguation error for {topic}. Selecting the first option {e.options[0]} instead.")
|
147 |
page = e.options[0]
|
148 |
topic_summary = remove_punctuation(wikipedia.summary(page, auto_suggest=False))
|