from transformers import RobertaTokenizer, RobertaForMaskedLM, pipeline, GPT2TokenizerFast
import torch
import pronouncing
import wikipedia
import re
import random
import nltk
import syllables
import gradio as gr

nltk.download('cmudict')

frequent_words = set()


def set_seed(seed: int):
    """
    Helper function for reproducible behavior. Only the ``torch`` seeds are set here;
    the ``random``/``numpy``/``tf`` branches are left commented out.

    Args:
        seed (:obj:`int`): The seed to set.
    """
    #random.seed(seed)
    #np.random.seed(seed)
    #if is_torch_available():
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # ^^ safe to call this function even if cuda is not available
    #if is_tf_available():
    #tf.random.set_seed(seed)


# Load the list of frequent English words used to keep only common rhymes.
with open("wordFrequency.txt", 'r') as f:
    line = f.readline()
    while line != '':  # The EOF char is an empty string
        frequent_words.add(line.strip())
        line = f.readline()


def filter_rhymes(word):
    # Reject end words that are hard (or impossible) to rhyme, plus a few stop words.
    filter_list = ['to', 'on', 'has', 'but', 'the', 'in', 'and', 'a', 'aitch', 'angst', 'arugula',
                   'beige', 'blitzed', 'boing', 'bombed', 'cairn', 'chaos', 'chocolate', 'circle',
                   'circus', 'cleansed', 'coif', 'cusp', 'doth', 'else', 'eth', 'fiends', 'film',
                   'flange', 'fourths', 'grilse', 'gulf', 'kiln', 'loge', 'midst', 'month', 'music',
                   'neutron', 'ninja', 'oblige', 'oink', 'opus', 'orange', 'pint', 'plagued',
                   'plankton', 'plinth', 'poem', 'poet', 'purple', 'quaich', 'rhythm', 'rouged',
                   'silver', 'siren', 'soldier', 'sylph', 'thesp', 'toilet', 'torsk', 'tufts',
                   'waltzed', 'wasp', 'wharves', 'width', 'woman', 'yttrium']
    return word not in filter_list


def remove_punctuation(text):
    text = re.sub(r'[^\w\s]', '', text)
    return text


def get_rhymes(inp, level):
    # Find CMUdict words whose last `level` phonemes match those of `inp`,
    # keeping only frequent words and excluding `inp` itself.
    entries = nltk.corpus.cmudict.entries()
    syllables = [(word, syl) for word, syl in entries if word == inp]
    rhymes = []
    filtered_rhymes = set()
    for (word, syllable) in syllables:
        rhymes += [word for word, pron in entries if pron[-level:] == syllable[-level:]]
    for word in rhymes:
        if (word in frequent_words) and (word != inp):
            filtered_rhymes.add(word)
    return filtered_rhymes


def get_inputs_length(input):
    # Length of the input in GPT-2 tokens, used to size max_length for generation.
    gpt2_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    input_ids = gpt2_tokenizer(input)['input_ids']
    return len(input_ids)


tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForMaskedLM.from_pretrained('roberta-base')
text_generation = pipeline("text-generation")  # defaults to GPT-2
set_seed(0)


def get_prediction(sent):
    # Fill every mask token in `sent` with RoBERTa's top non-punctuation guess.
    token_ids = tokenizer.encode(sent, return_tensors='pt')
    masked_position = (token_ids.squeeze() == tokenizer.mask_token_id).nonzero()
    masked_pos = [mask.item() for mask in masked_position]

    with torch.no_grad():
        output = model(token_ids)

    last_hidden_state = output[0].squeeze()

    list_of_list = []
    for index, mask_index in enumerate(masked_pos):
        words = []
        mask_hidden_state = last_hidden_state[mask_index]
        idx = torch.topk(mask_hidden_state, k=5, dim=0)[1]
        for i in idx:
            word = tokenizer.decode(i.item()).strip()
            if (remove_punctuation(word) != "") and (word != ''):
                words.append(word)
        #words = [tokenizer.decode(i.item()).strip() for i in idx]
        list_of_list.append(words)
        print(f"Mask {index+1} Guesses: {words}")

    best_guess = ""
    for j in list_of_list:
        best_guess = best_guess + " " + j[0]

    return best_guess


def get_line(topic_summary, starting_words, inputs_len):
    # Draft a line by prompting GPT-2 with the topic summary plus a random starting word.
    starting_word = random.choice(starting_words)
    line = starting_word + text_generation(topic_summary + " " + starting_word,
                                           max_length=inputs_len + 6,
                                           do_sample=True,
                                           return_full_text=False)[0]['generated_text']
    return line
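# Illustration only (a hypothetical call, not used by the app): get_prediction fills
# each RoBERTa mask token with its top non-punctuation guess and returns one guessed
# word per mask, joined by spaces, e.g.
#
#     get_prediction(f"There once was a cat from Peru {tokenizer.mask_token} {tokenizer.mask_token} blue.")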
def get_rhyming_line(topic_summary, starting_words, rhyming_word, inputs_len):
    # Draft a line with GPT-2, then mask its last few words and let RoBERTa fill
    # them in so the line leads into the required rhyming word.
    #gpt2_sentence = text_generation(topic_summary + " " + starting_words[i][j], max_length=no_of_words + 4, do_sample=False)[0]
    starting_word = random.choice(starting_words)
    print(f"\nGetting rhyming line with starting word '{starting_word}' and rhyming word '{rhyming_word}'")
    gpt2_sentence = text_generation(topic_summary + " " + starting_word,
                                    max_length=inputs_len + 2,
                                    do_sample=True,
                                    return_full_text=False)[0]
    #sentence = gpt2_sentence['generated_text'] + " ___ ___ ___ " + rhyming_words[i][j]
    sentence = starting_word + gpt2_sentence['generated_text'] + " ___ ___ ___ " + rhyming_word
    print(f"Original Sentence: {sentence}")
    # Swap the "___" placeholders for RoBERTa's mask token and make sure the
    # sentence ends with a period before running the fill-in-the-blank model.
    if sentence[-1] != ".":
        sentence = sentence.replace("___", tokenizer.mask_token) + "."
    else:
        sentence = sentence.replace("___", tokenizer.mask_token)
    print(f"Original Sentence replaced with mask: {sentence}")
    print("\n")

    predicted_blanks = get_prediction(sentence)
    print(f"\nBest guess for fill in the blanks: {predicted_blanks}")

    return starting_word + gpt2_sentence['generated_text'] + predicted_blanks + " " + rhyming_word


def generate(topic):
    limericks = []
    #topic = input("Please enter a topic: ")
    topic_summary = remove_punctuation(wikipedia.summary(topic))
    # if len(topic_summary) > 2000:
    #     topic_summary = topic_summary[:2000]
    word_list = topic_summary.split()
    topic_summary_len = len(topic_summary)
    no_of_words = len(word_list)
    inputs_len = get_inputs_length(topic_summary)
    print(f"Topic Summary: {topic_summary}")
    print(f"Topic Summary Length: {topic_summary_len}")
    print(f"No of Words in Summary: {no_of_words}")
    print(f"Length of Input IDs: {inputs_len}")

    starting_words = ["That", "Had", "Not", "But", "With", "I", "Because", "There", "Who", "She", "He",
                      "To", "Whose", "In", "And", "When", "Or", "So", "The", "Of", "Every", "Whom"]

    # starting_words = [["That", "Had", "Not", "But", "That"],
    #                   ["There", "Who", "She", "Tormenting", "Til"],
    #                   ["Relentless", "This", "First", "and", "then"],
    #                   ["There", "Who", "That", "To", "She"],
    #                   ["There", "Who", "Two", "Four", "Have"]]

    # rhyming_words = [["told", "bold", "woodchuck", "truck", "road"],
    #                  ["Nice", "grease", "house", "spouse", "peace"],
    #                  ["deadlines", "lines", "edits", "credits", "wine"],
    #                  ["Lynn", "thin", "essayed", "lemonade", "in"],
    #                  ["beard", "feared", "hen", "wren", "beard"]]

    for i in range(5):
        print(f"\nGenerating limerick {i+1}")

        # Lines 1, 2 and 5 share one rhyme: keep sampling first lines until the
        # end word is rhymable and has at least three frequent rhymes.
        rhyming_words_125 = []
        valid_rhyme = False
        while len(rhyming_words_125) < 3 or valid_rhyme == False:
            first_line = get_line(topic_summary, starting_words, inputs_len)
            #rhyming_words = pronouncing.rhymes(first_line.split()[-1])
            end_word = remove_punctuation(first_line.split()[-1])
            valid_rhyme = filter_rhymes(end_word)
            if valid_rhyme:
                print(f"\nFirst Line: {first_line}")
                rhyming_words_125 = list(get_rhymes(end_word, 3))
                print(f"Rhyming words for '{end_word}' are {rhyming_words_125}")

        limerick = first_line + "\n"

        rhyming_word = rhyming_words_125[0]
        second_line = get_rhyming_line(topic_summary, starting_words, rhyming_word, inputs_len)
        limerick += second_line + "\n"

        # Lines 3 and 4 share a second rhyme.
        rhyming_words_34 = []
        valid_rhyme = False
        while len(rhyming_words_34) < 2 or valid_rhyme == False:
            third_line = get_line(topic_summary, starting_words, inputs_len)
            print(f"\nThird Line: {third_line}")
            #rhyming_words = pronouncing.rhymes(first_line.split()[-1])
            end_word = remove_punctuation(third_line.split()[-1])
            valid_rhyme = filter_rhymes(end_word)
            print(f"Does '{end_word}' have valid rhymes: {valid_rhyme}")
            rhyming_words_34 = list(get_rhymes(end_word, 3))
            print(f"Rhyming words for '{end_word}' are {rhyming_words_34}")
            if valid_rhyme and len(rhyming_words_34) > 1:
                limerick += third_line + "\n"

        rhyming_word = rhyming_words_34[0]
        fourth_line = get_rhyming_line(topic_summary, starting_words, rhyming_word, inputs_len)
        limerick += fourth_line + "\n"

        rhyming_word = rhyming_words_125[1]
        fifth_line = get_rhyming_line(topic_summary, starting_words, rhyming_word, inputs_len)
        limerick += fifth_line + "\n"

        limericks.append(limerick)
        print("\n")

    output = f"Generated {len(limericks)} limericks: \n"
    print(f"Generated {len(limericks)} limericks: \n")
    for limerick in limericks:
        print(limerick)
        output += limerick

    return output


interface = gr.Interface(fn=generate, inputs="text", outputs="text")
interface.launch(debug=True)
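# A minimal sketch of exercising the generator without the Gradio UI (hypothetical
# usage; "Limerick" is only an example topic and assumes its Wikipedia summary can
# be fetched). Run it instead of, or before, interface.launch(...) above:
#
#     print(generate("Limerick"))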