# Provenance (from the hosting page): uploaded by SpartanCinder,
# commit 609f19c "Rename sgg_app.py to app.py" (verified), file size 8.29 kB.
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import TrainingArguments, Trainer
from datasets import load_dataset
import random
# Load the lyrics dataset once at import time. The helpers below read the
# 'Artist' column of its 'train' split to pick the target artist and to
# build the multiple-choice options.
dataset = load_dataset("SpartanCinder/song-lyrics-artist-classifier")
# Disabled debugging aids for inspecting the dataset columns and artists:
# print(dataset.column_names)
# print(dataset['train']['Artist'])
# artist_list = list(set(dataset['train']['Artist']))
# print(artist_list)
def generate_song(state, language_model, generate_song):
    """Generate lyrics with the selected model and set up a guessing round.

    A random artist is drawn from the module-level ``dataset``, the chosen
    language model is prompted to write a song in that artist's style, and
    shuffled multiple-choice options (one of them correct) are stored in
    ``state`` so that ``on_submit_answer`` can check the guess later.

    Args:
        state: Dict carrying the game state between handlers. On success the
            keys 'options', 'multiple_choice_check' and 'correct_choice'
            are written.
        language_model: Label chosen in the "Model Selection" radio, or a
            falsy value when nothing has been selected.
        generate_song: Value of the triggering button; nothing is generated
            when it is falsy.

    Returns:
        A 3-tuple ``(state, song_text, options_text)`` matching the three
        output components wired to the click handler.
    """
    if not generate_song:
        # Nothing was requested: hand back the state with blank text fields.
        return state, "", ""

    if not language_model:
        # Exactly three values, because the click handler has three outputs.
        song_text = "Please select a language model before generating a song."
        return state, song_text, ""

    # Radio label -> Hugging Face model id.
    model_names = {
        "Custom Gpt2": "SpartanCinder/GPT2-finetuned-lyric-generation",
        "Gpt2-Medium": "gpt2-medium",
        "facebook/bart-base": "facebook/bart-base",
        "Gpt-Neo": "EleutherAI/gpt-neo-1.3B",
    }
    # Any other label (e.g. "Customized Models") falls back to this name.
    # NOTE(review): "customized-models" does not look like a resolvable
    # model id - confirm what should be loaded for this choice.
    model_name = model_names.get(language_model, "customized-models")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Tokenizer and text-generation setup.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # The input ids are moved to `device`, so the model must live there too.
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

    # Pick a random artist from the dataset and build the prompt.
    correct_choice = pick_artist(dataset)
    input_text = f"Write a song in the style of {correct_choice}:"

    # Tuning settings.
    max_length = 128
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

    # Choose the decoding strategy. The labels are compared explicitly; the
    # previous `!= "customized-models" or "Custom Gpt2"` test was always
    # true, which made the two sampling branches unreachable.
    if language_model == "Custom Gpt2":
        # Fine-tuned GPT-2: nucleus sampling with top_p = 0.95.
        decoding = {"do_sample": True, "top_p": 0.95}
    elif language_model in model_names:
        # Stock checkpoints: beam search; no_repeat_ngram_size limits
        # repeated phrases in the output.
        decoding = {"num_beams": 5, "do_sample": False, "no_repeat_ngram_size": 2}
    else:
        # Nucleus sampling: top_p is the probability threshold, so the model
        # samples from the top 90% of the distribution, which gives more
        # diverse and less repetitive text.
        decoding = {"do_sample": True, "top_p": 0.9}

    encoded_output = model.generate(
        input_ids,
        max_length=max_length,
        num_return_sequences=5,
        **decoding,
    )

    # Only the first returned sequence is shown to the player.
    song_text = tokenizer.decode(encoded_output[0], skip_special_tokens=True)

    # Generate the multiple-choice options.
    options = generate_artist_options(dataset, correct_choice)
    state['options'] = options
    # Generate the multiple-choice check.
    multiple_choice_check = generate_multiple_choice_check(options, correct_choice)
    state['multiple_choice_check'] = multiple_choice_check
    state['correct_choice'] = correct_choice

    return state, song_text, ', '.join(options)
#Check the selected artist and return whether it's correct
# def on_submit_answer(state, correct_choice, user_choice, submit_answer):
# if submit_answer:
# if not user_choice:
# return {"Error": "Please select an artist before submitting an answer."}
# # Check if 'correct_choice' is in the state keys
# if 'correct_choice' in state:
# correct_answer = state['correct_choice']
# if correct_answer == user_choice:
# return {"Result": f"You guessed the right artist: {correct_choice}"}
# else:
# return {"Result": f"You selected {user_choice}, but the correct answer is {correct_choice}"}
# else:
# print("The 'correct_choice' key does not exist in the state.")
# return None
def on_submit_answer(state, user_choice):
    """Check the player's multiple-choice answer against the stored artist.

    Args:
        state: Game-state dict; 'options' (list of artist names) and
            'correct_choice' are read from it.
        user_choice: One of 'A', 'B', 'C' or 'D', or None when nothing is
            selected.

    Returns:
        A single-entry dict keyed "CORRECT" or "INCORRECT" with a message,
        or keyed "ERROR" when no answer was selected or no song has been
        generated yet.
    """
    # Map the user's choice (A, B, C, or D) to an index.
    choice_to_index = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
    if user_choice not in choice_to_index:
        # No radio option selected: report it instead of raising KeyError.
        return {"ERROR": "Please select an artist before submitting an answer."}
    index = choice_to_index[user_choice]

    options = state.get('options') or []
    if 'correct_choice' not in state or index >= len(options):
        # No round in progress, or the chosen letter has no matching option.
        return {"ERROR": "Please generate a song before submitting an answer."}

    # Retrieve the user's choice and the correct choice from the state.
    user_artist = options[index]
    correct_artist = state['correct_choice']

    # Compare the user's choice with the correct choice.
    if user_artist == correct_artist:
        return {"CORRECT": f"You guessed the right artist: {correct_artist}"}
    return {"INCORRECT": f"You selected {user_choice}, but the correct answer is {correct_artist}"}
def pick_artist(dataset):
    """Return one artist name chosen at random from the dataset.

    The 'Artist' column of the 'train' split is de-duplicated first, so
    every distinct artist is equally likely to be picked no matter how many
    rows mention them.
    """
    unique_artists = set(dataset['train']['Artist'])
    return random.choice(list(unique_artists))
def generate_artist_options(dataset, correct_artist):
    """Build a shuffled list of four artists containing the correct one.

    Three distinct wrong answers are sampled from the unique values of the
    'train' split's 'Artist' column (with ``correct_artist`` excluded), the
    correct artist is appended, and the list is shuffled. ``random.sample``
    raises ValueError when fewer than three other artists are available.
    """
    # Every distinct artist except the right answer.
    wrong_candidates = [
        artist
        for artist in set(dataset['train']['Artist'])
        if artist != correct_artist
    ]
    choices = random.sample(wrong_candidates, 3)
    choices.append(correct_artist)
    random.shuffle(choices)
    return choices
def generate_multiple_choice_check(options, correct_choice):
    """Map each option to True if it equals ``correct_choice``, else False."""
    answer_key = {}
    for candidate in options:
        answer_key[candidate] = candidate == correct_choice
    return answer_key
def check_correct_choice(user_choice, correct_choice):
    """Return whether ``user_choice`` equals ``correct_choice``.

    The comparison is evaluated once and returned directly; the earlier
    ``if ...: return True`` branch duplicated it without changing the result.
    """
    return user_choice == correct_choice
# Build the game UI: model picker, song generation, and answer submission.
with gr.Blocks(title="Song Generator Guessing Game") as game_interface:
    # Game state handed to both handlers; starts with an empty option list.
    state = gr.State({'options': []})

    language_model = gr.Radio(
        ["Custom Gpt2", "Gpt2-Medium", "facebook/bart-base", "Gpt-Neo", "Customized Models"],
        label="Model Selection",
        info="Select the language model to generate the song.",
    )
    generate_song_button = gr.Button("Generate Song")
    generated_song = gr.Textbox(label="Generated Song")
    artist_choice_display = gr.Textbox(
        interactive=False,
        label="Multiple-Choice Options",
    )
    user_choice = gr.Radio(
        ["A", "B", "C", "D"],
        label="Updated Options",
        info="Select the artist that you suspect is the correct artist for the song.",
    )
    # Disabled prototype of a 30-second progress-bar timer:
    # timer = gr.HTML("<div id='progress-bar' style='width: 100%; background-color: #f3f3f3; border: 1px solid #bbb;'><div id='progress' style='height: 20px; width: 0%; background-color: #007bff;'></div></div><script>function startTimer() {var time = 30; var timer = setInterval(function() {time--; document.getElementById('progress').style.width = (time / 30 * 100) + '%'; if (time <= 0) {clearInterval(timer);}}, 1000);}</script>", label="Timer")
    submit_answer_button = gr.Button("Submit Answer")
    correct_answer = gr.Textbox(label="Results")

    # "Generate Song" fills the song text and the option list, and updates
    # the state. The button itself is passed as the third input.
    generate_song_button.click(
        fn=generate_song,
        inputs=[state, language_model, generate_song_button],
        outputs=[state, generated_song, artist_choice_display],
    )

    # "Submit Answer" checks the selected letter and shows the result.
    submit_answer_button.click(
        fn=on_submit_answer,
        inputs=[state, user_choice],
        outputs=[correct_answer],
    )

game_interface.launch()