# Hugging Face Space status: Sleeping
from transformers import RobertaTokenizer, RobertaForMaskedLM, pipeline, GPT2TokenizerFast | |
import torch | |
import pronouncing | |
import wikipedia | |
import re | |
import random | |
import nltk | |
import syllables | |
import gradio as gr | |
# Make sure NLTK's CMU Pronouncing Dictionary corpus is available; get_rhymes()
# reads it through nltk.corpus.cmudict.
nltk.download('cmudict')

# Whitelist of common words: rhyme candidates not in this set are discarded by
# get_rhymes(). It is filled from wordFrequency.txt further down the module.
frequent_words: set = set()
def set_seed(seed: int):
    """Seed PyTorch's random number generators.

    Only ``torch`` is seeded: the default generator via ``torch.manual_seed``
    and every CUDA device via ``torch.cuda.manual_seed_all``. Python's
    ``random`` module, ``numpy`` and TensorFlow are NOT seeded here.

    Args:
        seed (:obj:`int`): The seed to set.
    """
    torch.manual_seed(seed)
    # Safe to call even when CUDA is not available.
    torch.cuda.manual_seed_all(seed)
# Populate the common-word whitelist from wordFrequency.txt, one word per line.
# Iterating the file object replaces the manual readline()/EOF loop; blank
# lines are skipped so "" never ends up in the set. The encoding is explicit
# so the result does not depend on the platform locale.
# NOTE(review): assumes the file is UTF-8 (or plain ASCII) — confirm.
with open("wordFrequency.txt", 'r', encoding="utf-8") as f:
    frequent_words.update(word for word in map(str.strip, f) if word)
# Words never accepted as a line-ending rhyme anchor: short function words
# plus words that appear to be hard or impossible to rhyme.
# A frozenset is built once at import and gives O(1) membership tests; the
# original rebuilt a list on every call.
_RHYME_BLACKLIST = frozenset({
    'to', 'on', 'has', 'but', 'the', 'in', 'and', 'a', 'aitch', 'angst',
    'arugula', 'beige', 'blitzed', 'boing', 'bombed', 'cairn', 'chaos',
    'chocolate', 'circle', 'circus', 'cleansed', 'coif', 'cusp', 'doth',
    'else', 'eth', 'fiends', 'film', 'flange', 'fourths', 'grilse', 'gulf',
    'kiln', 'loge', 'midst', 'month', 'music', 'neutron', 'ninja', 'oblige',
    'oink', 'opus', 'orange', 'pint', 'plagued', 'plankton', 'plinth', 'poem',
    'poet', 'purple', 'quaich', 'rhythm', 'rouged', 'silver', 'siren',
    'soldier', 'sylph', 'thesp', 'toilet', 'torsk', 'tufts', 'waltzed',
    'wasp', 'wharves', 'width', 'woman', 'yttrium',
})


def filter_rhymes(word):
    """Return True if ``word`` may end a line that others must rhyme with.

    Returns False when ``word`` is in the blacklist above. Matching is exact
    and case-sensitive, so e.g. ``"The"`` is not rejected while ``"the"`` is.
    """
    return word not in _RHYME_BLACKLIST
def remove_punctuation(text):
    """Drop every character that is neither a word character nor whitespace.

    Word characters follow ``re``'s ``\\w`` (letters, digits, underscore), so
    ``"it's"`` becomes ``"its"`` while ``"a_b"`` is left untouched.
    """
    non_word_non_space = re.compile(r"[^\w\s]")
    return non_word_non_space.sub("", text)
def get_rhymes(inp, level):
    """Return common words that rhyme with ``inp`` according to NLTK's cmudict.

    Two words count as rhyming when the last ``level`` phonemes of one of
    ``inp``'s pronunciations equal the last ``level`` phonemes of a
    pronunciation of the candidate (compared via ``pron[-level:]``). Only
    candidates present in ``frequent_words`` are kept, and ``inp`` itself is
    excluded.

    Args:
        inp: Word to rhyme with. It is lowercased before lookup because
            NLTK's cmudict headwords are lowercase; without this a capitalised
            word (such as a sentence-final proper noun) matched no entry at
            all and always yielded an empty result.
        level: Number of trailing phonemes that must match.

    Returns:
        A ``set`` of rhyming words, possibly empty.
    """
    inp = inp.lower()
    entries = nltk.corpus.cmudict.entries()
    # Trailing-phoneme endings of every pronunciation of ``inp``; stored as
    # tuples so they can live in a set.
    target_endings = {tuple(pron[-level:]) for word, pron in entries if word == inp}
    if not target_endings:
        return set()
    # One pass over the dictionary for all pronunciations, instead of one full
    # scan per pronunciation.
    return {
        word
        for word, pron in entries
        if tuple(pron[-level:]) in target_endings
        and word in frequent_words
        and word != inp
    }
# GPT-2 tokenizer shared by all get_inputs_length() calls; created on first use.
_gpt2_tokenizer = None


def get_inputs_length(input):
    """Return the number of GPT-2 tokens that ``input`` encodes to.

    The tokenizer is loaded once and then reused, rather than being rebuilt
    with ``from_pretrained`` on every call.
    (The parameter name shadows the ``input`` builtin; it is kept for
    backward compatibility with keyword callers.)
    """
    global _gpt2_tokenizer
    if _gpt2_tokenizer is None:
        _gpt2_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    return len(_gpt2_tokenizer(input)['input_ids'])
# Checkpoint shared by the masked-LM tokenizer and model below.
_ROBERTA_CHECKPOINT = 'roberta-base'

# RoBERTa masked language model; get_prediction() uses it to fill <mask> tokens.
tokenizer = RobertaTokenizer.from_pretrained(_ROBERTA_CHECKPOINT)
model = RobertaForMaskedLM.from_pretrained(_ROBERTA_CHECKPOINT)

# Text-generation pipeline (library default model) that get_line() and
# get_rhyming_line() sample continuations from.
text_generation = pipeline(task="text-generation")

# Seed torch's RNGs once, after the models above have been loaded.
set_seed(0)
def get_prediction(sent, top_k=5):
    """Fill each ``<mask>`` in ``sent`` with the masked LM's best word guess.

    For every mask position the ``top_k`` highest-scoring vocabulary entries
    are decoded. Candidates that are empty after ``remove_punctuation`` (pure
    punctuation) or equal to ``</s>`` are dropped, and the highest-ranked
    survivor is chosen. The surviving candidates are printed for each mask.

    Args:
        sent: Text containing zero or more ``<mask>`` tokens.
        top_k: Number of candidates considered per mask (default 5).

    Returns:
        The chosen words in mask order, each preceded by one space
        (e.g. ``" a big red"``), or ``""`` if ``sent`` contains no mask.
        A mask whose candidates were all rejected contributes nothing; the
        original indexed ``[0]`` into an empty list and raised IndexError.
    """
    token_ids = tokenizer.encode(sent, return_tensors='pt')
    # Sequence positions of every <mask> token (batch of one, so drop dim 0).
    masked_pos = (token_ids.squeeze(0) == tokenizer.mask_token_id).nonzero().flatten().tolist()
    with torch.no_grad():
        output = model(token_ids)
    # For a masked-LM head, output[0] holds per-position vocabulary scores.
    logits = output[0].squeeze(0)

    best_words = []
    for index, mask_index in enumerate(masked_pos):
        top_ids = torch.topk(logits[mask_index], k=top_k, dim=0).indices
        words = []
        for token_id in top_ids.tolist():
            word = tokenizer.decode(token_id).strip()
            if (remove_punctuation(word) != "") and (word != '</s>'):
                words.append(word)
        print(f"Mask {index+1} Guesses: {words}")
        if words:
            best_words.append(words[0])
    # Leading space per word: callers append the result directly after text.
    return "".join(" " + word for word in best_words)
def get_line(topic_summary, starting_words, inputs_len):
    """Sample one free-form line about the topic.

    A starting word is picked at random from ``starting_words`` and appended
    (after a space) to ``topic_summary`` to form the prompt. The sampled
    continuation, without the prompt, is returned prefixed by that starting
    word. Generation is capped at ``max_length=inputs_len + 6``; ``inputs_len``
    is presumably the token length of ``topic_summary`` — confirm at callers.
    """
    opener = random.choice(starting_words)
    prompt = topic_summary + " " + opener
    samples = text_generation(
        prompt,
        max_length=inputs_len + 6,
        do_sample=True,
        return_full_text=False,
    )
    continuation = samples[0]['generated_text']
    return opener + continuation
def get_rhyming_line(topic_summary, starting_words, rhyming_word, inputs_len):
    """Produce a line that ends with ``rhyming_word``.

    The head of the line is a random starting word plus a short sampled
    continuation of ``topic_summary`` (capped at ``max_length=inputs_len + 2``).
    Three ``<mask>`` tokens are placed between that head and ``rhyming_word``,
    and get_prediction() fills them in.

    Returns:
        ``head + predicted blanks + " " + rhyming_word``, where the predicted
        blanks already start with a space.
    """
    starting_word = random.choice(starting_words)
    print(f"\nGetting rhyming line with starting word '{starting_word}' and rhyming word '{rhyming_word}'")
    prompt = topic_summary + " " + starting_word
    generated = text_generation(
        prompt,
        max_length=inputs_len + 2,
        do_sample=True,
        return_full_text=False,
    )[0]['generated_text']
    line_head = starting_word + generated

    sentence = line_head + " ___ ___ ___ " + rhyming_word
    print(f"Original Sentence: {sentence}")
    # Terminate with a period for the masked LM unless one is already there.
    needs_period = not sentence.endswith(".")
    sentence = sentence.replace("___", "<mask>")
    if needs_period:
        sentence += "."
    print(f"Original Sentence replaced with mask: {sentence}")
    print("\n")

    predicted_blanks = get_prediction(sentence)
    print(f"\nBest guess for fill in the blanks: {predicted_blanks}")
    return line_head + predicted_blanks + " " + rhyming_word
from transformers import pipeline | |
def _first_line_with_rhymes(topic_summary, starting_words, inputs_len):
    """Sample opening lines until one ends in an allowed word with >= 3 rhymes.

    Returns ``(first_line, rhymes)`` where ``rhymes`` is a list of rhyming words.
    NOTE(review): there is no attempt limit; a topic whose lines never end in a
    rhymable word retries forever.
    """
    while True:
        first_line = get_line(topic_summary, starting_words, inputs_len)
        end_word = remove_punctuation(first_line.split()[-1])
        if not filter_rhymes(end_word):
            continue
        print(f"\nFirst Line: {first_line}")
        rhymes = list(get_rhymes(end_word, 3))
        print(f"Rhyming words for '{end_word}' are {rhymes}")
        if len(rhymes) >= 3:
            return first_line, rhymes


def _third_line_with_rhymes(topic_summary, starting_words, inputs_len):
    """Sample third lines until one ends in an allowed word with >= 2 rhymes.

    Returns ``(third_line, rhymes)``. Same unbounded-retry caveat as above.
    """
    while True:
        third_line = get_line(topic_summary, starting_words, inputs_len)
        print(f"\nThird Line: {third_line}")
        end_word = remove_punctuation(third_line.split()[-1])
        valid_rhyme = filter_rhymes(end_word)
        print(f"Does '{end_word}' have valid rhymes: {valid_rhyme}")
        rhymes = list(get_rhymes(end_word, 3))
        print(f"Rhyming words for '{end_word}' are {rhymes}")
        if valid_rhyme and len(rhymes) > 1:
            return third_line, rhymes


def _generate_limerick(topic_summary, starting_words, inputs_len):
    """Build one AABBA five-line limerick, each line terminated by a newline."""
    first_line, a_rhymes = _first_line_with_rhymes(topic_summary, starting_words, inputs_len)
    second_line = get_rhyming_line(topic_summary, starting_words, a_rhymes[0], inputs_len)
    third_line, b_rhymes = _third_line_with_rhymes(topic_summary, starting_words, inputs_len)
    fourth_line = get_rhyming_line(topic_summary, starting_words, b_rhymes[0], inputs_len)
    fifth_line = get_rhyming_line(topic_summary, starting_words, a_rhymes[1], inputs_len)
    lines = (first_line, second_line, third_line, fourth_line, fifth_line)
    return "".join(line + "\n" for line in lines)


def generate(topic, num_limericks=5):
    """Generate limericks about ``topic`` from its Wikipedia summary.

    The summary returned by ``wikipedia.summary(topic)`` is stripped of
    punctuation and used as the generation prompt. The module-level
    ``text_generation`` pipeline (via get_line / get_rhyming_line) is used;
    the original also built a second, unused pipeline here on every call,
    which reloaded a model per request.

    Args:
        topic: Topic string typed by the user (passed to ``wikipedia.summary``).
        num_limericks: How many limericks to produce (default 5).

    Returns:
        ``"Generated N limericks: \\n"`` followed by the limericks concatenated.
        Any exception raised by ``wikipedia.summary`` propagates.
    """
    topic_summary = remove_punctuation(wikipedia.summary(topic))
    # NOTE(review): very long summaries may exceed the generator's context
    # window; truncation was once tried here (commented out) — confirm.
    word_list = topic_summary.split()
    inputs_len = get_inputs_length(topic_summary)
    print(f"Topic Summary: {topic_summary}")
    print(f"Topic Summary Length: {len(topic_summary)}")
    print(f"No of Words in Summary: {len(word_list)}")
    print(f"Length of Input IDs: {inputs_len}")

    starting_words = ["That", "Had", "Not", "But", "With", "I", "Because", "There", "Who", "She", "He", "To", "Whose", "In", "And", "When", "Or", "So", "The", "Of", "Every", "Whom"]

    limericks = []
    for i in range(num_limericks):
        print(f"\nGenerating limerick {i+1}")
        limericks.append(_generate_limerick(topic_summary, starting_words, inputs_len))
        print("\n")

    header = f"Generated {len(limericks)} limericks: \n"
    print(header)
    for limerick in limericks:
        print(limerick)
    return header + "".join(limericks)
# Gradio UI: a single text box (the topic) in, the generated limericks out.
interface = gr.Interface(fn=generate, inputs="text", outputs="text")

# Launch only when executed as a script. launch(debug=True) blocks, so doing
# it unconditionally made importing this module start a server and hang.
if __name__ == "__main__":
    interface.launch(debug=True)