chomayouni
The v1.3 commit
5d59d15
raw
history blame contribute delete
No virus
8.73 kB
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import TrainingArguments, Trainer
from datasets import load_dataset
import random
# Hugging Face dataset of song lyrics labelled by artist. Its 'train' split's
# 'Artist' column supplies both the prompt artist and the quiz options.
DATASET_REPO_ID = "SpartanCinder/song-lyrics-artist-classifier"
dataset = load_dataset(DATASET_REPO_ID)
def _resolve_model_name(language_model):
    """Return the Hugging Face checkpoint id for a model-selection label.

    Returns None for labels with no real checkpoint behind them (the
    "Customized Models" placeholder, or anything unrecognised). The caller
    can then report that, instead of letting ``from_pretrained`` fail on a
    repo id that does not exist.
    """
    return {
        "Custom Gpt2": "SpartanCinder/GPT2-finetuned-lyric-generation",
        "Gpt2-Medium": "gpt2-medium",
        "facebook/bart-base": "facebook/bart-base",
        "Gpt-Neo": "EleutherAI/gpt-neo-1.3B",
    }.get(language_model)


def _generation_kwargs(language_model):
    """Return the decoding settings passed to ``model.generate`` for a label."""
    if language_model == "Custom Gpt2":
        # Nucleus sampling for the fine-tuned model. Beam-search output was
        # observed to repeat itself.
        return {"do_sample": True, "top_p": 0.95}
    if language_model in ("Gpt2-Medium", "facebook/bart-base", "Gpt-Neo"):
        # Deterministic beam search. no_repeat_ngram_size=2 forbids any
        # repeated bigram, which curbs looping output.
        return {"num_beams": 5, "do_sample": False, "no_repeat_ngram_size": 2}
    # Fallback: nucleus sampling from the smallest token set covering 90% of
    # the probability mass, for more varied, less repetitive text.
    return {"do_sample": True, "top_p": 0.9}


def _format_options(options):
    """Render options as "A: x, B: y, ..." so they line up with the A-D radio."""
    return ", ".join(
        f"{chr(ord('A') + i)}: {artist}" for i, artist in enumerate(options)
    )


def generate_song(state, language_model, generate_song):
    """Generate lyrics imitating a random artist and set up the guessing quiz.

    Click handler for the "Generate Song" button. The tokenizer and model are
    loaded from the Hugging Face hub on every call (there is no caching).

    Args:
        state: Per-session dict. It is updated in place with 'options',
            'multiple_choice_check' and 'correct_choice'.
        language_model: Label chosen in the model radio, or None when nothing
            is selected.
        generate_song: Value of the button component. Generation runs only
            when it is truthy. (This parameter shadows the function name
            inside the body.)

    Returns:
        A 3-tuple ``(state, song_text, options_text)``. The shape matches the
        three output components wired to the button, on every path.
    """
    if not generate_song:
        return state, "", ""
    if not language_model:
        return state, "Please select a language model before generating a song.", ""

    model_name = _resolve_model_name(language_model)
    if model_name is None:
        return (
            state,
            f"'{language_model}' has no model configured yet; please choose another model.",
            "",
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # from_pretrained loads onto the CPU. Move the model to the same device
    # as the inputs, otherwise generate() fails with a device mismatch on GPU.
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

    # Pick a random artist from the dataset to imitate.
    correct_choice = pick_artist(dataset)
    input_text = f"A Song in the style of {correct_choice}:"
    max_length = 128
    # Calling the tokenizer (rather than encode) also returns the attention mask.
    inputs = tokenizer(input_text, return_tensors="pt").to(device)
    # GPT-2 style tokenizers have no pad token, so fall back to EOS explicitly.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    # Only one sequence is displayed, so only one is generated. With beam
    # search this is still the best-scoring beam.
    encoded_output = model.generate(
        **inputs,
        max_length=max_length,
        num_return_sequences=1,
        pad_token_id=pad_token_id,
        **_generation_kwargs(language_model),
    )

    # Causal LMs echo the prompt at the start of the output. The prompt names
    # the correct artist, so decode only the newly generated tokens to avoid
    # giving the answer away.
    prompt_length = inputs["input_ids"].shape[-1]
    song_text = tokenizer.decode(
        encoded_output[0][prompt_length:], skip_special_tokens=True
    ).strip()

    # Build the multiple-choice options and remember the answer for scoring.
    options = generate_artist_options(dataset, correct_choice)
    state['options'] = options
    state['multiple_choice_check'] = generate_multiple_choice_check(options, correct_choice)
    state['correct_choice'] = correct_choice
    return state, song_text, _format_options(options)
def on_submit_answer(state, user_choice):
    """Check the player's A-D guess against the artist stored by generate_song.

    Args:
        state: Per-session dict. After a song is generated it holds 'options'
            (the artists in A-D order) and 'correct_choice'. Initially it only
            has an empty 'options' list.
        user_choice: "A", "B", "C" or "D" from the answer radio, or None when
            nothing is selected.

    Returns:
        A one-entry dict keyed "CORRECT", "INCORRECT" or "ERROR" with a
        message. "ERROR" is returned when no song has been generated yet or
        no valid option is selected. These cases previously raised
        KeyError/IndexError.
    """
    # Map the user's choice (A, B, C, or D) to an index into state['options'].
    choice_to_index = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
    state = state or {}
    options = state.get('options') or []
    correct_artist = state.get('correct_choice')
    if not options or correct_artist is None:
        return {"ERROR": "Please generate a song before submitting an answer."}

    index = choice_to_index.get(user_choice)
    if index is None or index >= len(options):
        return {"ERROR": "Please select one of the listed options before submitting."}

    user_artist = options[index]
    if user_artist == correct_artist:
        return {"CORRECT": f"You guessed the right artist: {correct_artist}"}
    return {"INCORRECT": f"You selected {user_choice}, but the correct answer is {correct_artist}"}
def pick_artist(dataset):
    """Return one artist chosen at random from the distinct values of the
    'Artist' column in the dataset's 'train' split.

    Each distinct artist is equally likely, regardless of how many rows it has.
    """
    unique_artists = set(dataset['train']['Artist'])
    return random.choice(list(unique_artists))
def generate_artist_options(dataset, correct_artist, num_distractors=3):
    """Return a shuffled list of answer options that includes correct_artist.

    Args:
        dataset: Mapping whose 'train' split exposes an 'Artist' column.
        correct_artist: The artist the generated song imitates. It is always
            included.
        num_distractors: How many other artists to add. The default of 3 gives
            four options for the A-D radio. It is capped at the number of
            other distinct artists available, so small datasets no longer
            raise ValueError from random.sample.

    Returns:
        A list of unique artist names in random order.
    """
    # dict.fromkeys de-duplicates while keeping dataset order. A set's order
    # depends on string hash randomization, which would make the results
    # irreproducible even under random.seed.
    distractor_pool = [
        artist
        for artist in dict.fromkeys(dataset['train']['Artist'])
        if artist != correct_artist
    ]
    k = max(0, min(num_distractors, len(distractor_pool)))
    options = random.sample(distractor_pool, k) + [correct_artist]
    random.shuffle(options)
    return options
def generate_multiple_choice_check(options, correct_choice):
    """Map every option to True when it equals correct_choice, otherwise False."""
    check = {}
    for candidate in options:
        check[candidate] = candidate == correct_choice
    return check
# Build the game UI. There are no Row/Column containers here, so Gradio lays
# the components out in the order they are created inside this Blocks
# context. The statement order below is the page layout.
with gr.Blocks(title="Song Generator Guessing Game") as game_interface:
    gr.Markdown(" # Song Generator Guessing Game")
    # Disabled header image, kept for reference.
    # gr.Markdown("![Image](https://huggingface.co/spaces/SpartanCinder/NLP_Song_Generator_Guessing_Game/raw/main/RobotSinger.png)")
    # gr.HTML("<img src='/NLP_Song_Generator_Guessing_Game/RobotSinger.png'")
    # NOTE(review): step 1 says "dropdown", but the model selector below is a gr.Radio.
    gr.Markdown("""
## Instructions
1. Select a language model from the dropdown.
2. Click the 'Generate Song' button to generate a song.
3. Guess the artist of the generated song by selecting an option from the radio buttons.
4. Click the 'Submit Answer' button to submit your guess.
""")
    # Per-session game state. It starts with an empty option list.
    # generate_song stores 'options', 'multiple_choice_check' and
    # 'correct_choice' in it, and on_submit_answer reads 'options' and
    # 'correct_choice'.
    state = gr.State({'options': []})
    # These labels are the exact strings generate_song compares against.
    language_model = gr.Radio(["Custom Gpt2", "Gpt2-Medium", "facebook/bart-base","Gpt-Neo", "Customized Models"], label="Model Selection", info="Select the language model to generate the song.")
    generate_song_button = gr.Button("Generate Song")
    generated_song = gr.Textbox(label="Generated Song")
    # Read-only text showing the candidate artists, filled by generate_song's
    # third return value.
    artist_choice_display = gr.Textbox(interactive=False, label="Multiple-Choice Options")
    # The letters index into state['options'] (A -> 0, B -> 1, C -> 2, D -> 3)
    # in on_submit_answer.
    user_choice = gr.Radio(["A", "B", "C", "D"], label="Updated Options", info="Select the artist that you suspect is the correct artist for the song.")
    submit_answer_button = gr.Button("Submit Answer")
    correct_answer = gr.Textbox(label="Results")
    # Developer notes shown to users.
    # NOTE(review): the text contains typos ("moodels", "workd", "Ocasionaly",
    # "where when is was", and "the Weekend" for The Weeknd); fix them in the
    # string if desired.
    gr.Markdown("""
## Developer Notes:
- The 'Custom Gpt2' model is a custom fine-tuned GPT-2 model for generating song lyrics.
-- It was trained using a custom dataset of song lyrics from various artists that I created using the Genius API.
-- It uses beam search to generate text, and is much faster compared to the other moodels
-- However, it still has trouble producing prompts
-- I found that artists like "Adele" and "Taylor Swift" are more likely to have coherent lyrics
- The 'Gpt2-Medium' model is the GPT-2 medium model from the Hugging Face model hub.
-- It uses beam search to generate text, and is slower compared to the custom GPT-2 model
-- Without tuning, it is more likely to produce a general response to the prompt
-- Oddly enough, had a tendency to produce lyrics that were more coherent than the full GPT-2 model
- The 'facebook/bart-base' model is the BART base model from the Hugging Face model hub.
-- The model only workd 20% of the time
- The 'Gpt-Neo' model is the GPT-Neo 1.3B model from the EleutherAI model hub.
-- It performs well, but is slower compared to the GPT-2 models
- The 'Customized Models' option is a placeholder for any other custom models that you may have.
#### Known Issues:
- The 'facebook/bart-base' model has a tendency to produce empty responses.
-- This is likely due to the model's architecture and the way it processes the input data.
- Ocasionaly, the Custom Gpt2 model will produce a result that is just numbers. This has only happened once or twice and both times where when is was generating a song for the Weekend.
""")
    # Wire the generate button. The button itself is passed as the third
    # input, so generate_song receives the button's value (presumably its
    # label text) and only generates when that value is truthy. The handler
    # must return exactly three values, one per output component listed here.
    generate_song_button.click(
        generate_song,
        [state, language_model, generate_song_button],
        [state, generated_song, artist_choice_display,]
    )
    # Score the selected letter against the stored answer and show the
    # result in the Results textbox.
    submit_answer_button.click(
        on_submit_answer,
        [state, user_choice,],
        [correct_answer]
    )
# Start the Gradio server.
# NOTE(review): this also runs when the module is imported; an
# `if __name__ == "__main__":` guard would allow importing without launching.
game_interface.launch()