|
import gradio as gr |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
from transformers import TrainingArguments, Trainer |
|
from datasets import load_dataset |
|
import random |
|
|
|
|
|
# Lyrics dataset; its 'train' split's 'Artist' column supplies the artist
# names used as prompts and multiple-choice options.
dataset = load_dataset(path="SpartanCinder/song-lyrics-artist-classifier")
|
|
|
|
|
|
|
|
|
|
|
# Hub ids for each choice offered by the "Model Selection" radio.
_MODEL_NAMES = {
    "Custom Gpt2": "SpartanCinder/GPT2-finetuned-lyric-generation",
    "Gpt2-Medium": "gpt2-medium",
    "facebook/bart-base": "facebook/bart-base",
    "Gpt-Neo": "EleutherAI/gpt-neo-1.3B",
}
# Used for any other choice (e.g. "Customized Models").
# NOTE(review): presumably a local checkpoint directory — confirm it exists.
_DEFAULT_MODEL_NAME = "customized-models"

# Maximum total token length (prompt + generated text).
_MAX_LENGTH = 128

# Holds at most one (model_name, device) -> (tokenizer, model) entry. Repeated
# clicks reuse the loaded weights without keeping every model resident.
_MODEL_CACHE = {}


def _load_model(model_name, device):
    """Return ``(tokenizer, model)`` for *model_name*, placed on *device* and cached."""
    key = (model_name, str(device))
    if key not in _MODEL_CACHE:
        _MODEL_CACHE.clear()  # drop the previous model before loading another
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Move the model to the same device as the input ids; otherwise
        # generate() fails with a device mismatch on CUDA machines.
        model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
        model.eval()
        _MODEL_CACHE[key] = (tokenizer, model)
    return _MODEL_CACHE[key]


def _generation_kwargs(language_model):
    """Return the decoding settings used for the selected model choice."""
    if language_model == "Custom Gpt2":
        return {"do_sample": True, "top_p": 0.95}
    if language_model in _MODEL_NAMES:
        # Stock pretrained models: deterministic beam search.
        return {"num_beams": 5, "do_sample": False, "no_repeat_ngram_size": 2}
    # Customized models: nucleus sampling.
    return {"do_sample": True, "top_p": 0.9}


def generate_song(state, language_model, generate_song):
    """Generate a song in the style of a random artist and set up the quiz.

    Args:
        state: Per-session dict. Updated in place with 'options',
            'multiple_choice_check' and 'correct_choice'.
        language_model: Label selected in the "Model Selection" radio.
        generate_song: Value of the "Generate Song" button. Falsy means
            nothing was requested.

    Returns:
        A 3-tuple ``(state, song_text, options_text)`` matching the three
        outputs wired to the button.
    """
    if not generate_song:
        return state, "", ""

    if not language_model:
        # Exactly three values: the click handler has three outputs.
        return state, "Please select a language model before generating a song.", ""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_name = _MODEL_NAMES.get(language_model, _DEFAULT_MODEL_NAME)
    tokenizer, model = _load_model(model_name, device)

    correct_choice = pick_artist(dataset)
    input_text = f"Write a song in the style of {correct_choice}:"
    inputs = tokenizer(input_text, return_tensors="pt").to(device)

    # GPT-2 style tokenizers have no pad token; fall back to EOS so generate()
    # does not have to guess.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    # Only the first sequence is shown, so generate just one. For beam search
    # the top beam is the same either way.
    encoded_output = model.generate(
        inputs["input_ids"],
        attention_mask=inputs.get("attention_mask"),
        max_length=_MAX_LENGTH,
        num_return_sequences=1,
        pad_token_id=pad_token_id,
        **_generation_kwargs(language_model),
    )
    song_text = tokenizer.decode(encoded_output[0], skip_special_tokens=True)

    options = generate_artist_options(dataset, correct_choice)
    state['options'] = options
    state['multiple_choice_check'] = generate_multiple_choice_check(options, correct_choice)
    state['correct_choice'] = correct_choice

    # Label the options with the letters the answer radio offers (A-D).
    options_text = ', '.join(
        f"{letter}. {artist}" for letter, artist in zip("ABCD", options)
    )
    return state, song_text, options_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def on_submit_answer(state, user_choice):
    """Score the player's lettered pick against the artist stored in *state*.

    Args:
        state: Per-session dict, expected to hold 'options' (artist list in
            A-D order) and 'correct_choice' (set by ``generate_song``).
        user_choice: 'A', 'B', 'C' or 'D' from the answer radio, or ``None``
            when nothing is selected.

    Returns:
        A single-entry dict: "CORRECT", "INCORRECT", or "ERROR" mapped to a
        message. "ERROR" is used when no song has been generated yet or the
        choice does not match an available option.
    """
    choice_to_index = {'A': 0, 'B': 1, 'C': 2, 'D': 3}

    options = state.get('options') or []
    correct_artist = state.get('correct_choice')
    if correct_artist is None or not options:
        return {"ERROR": "Please generate a song before submitting an answer."}

    index = choice_to_index.get(user_choice)
    # Reject a missing choice, and letters beyond the number of options shown.
    if index is None or index >= len(options):
        return {"ERROR": "Please select one of the available options (A-D)."}

    user_artist = options[index]
    if user_artist == correct_artist:
        return {"CORRECT": f"You guessed the right artist: {correct_artist}"}
    return {"INCORRECT": f"You selected {user_choice}, but the correct answer is {correct_artist}"}
|
|
|
def pick_artist(dataset):
    """Return one distinct artist from ``dataset['train']['Artist']`` at random.

    Candidates are de-duplicated first, so every artist is equally likely
    regardless of how many rows it has.
    """
    # Sort the unique names so the candidate order does not depend on the
    # interpreter's hash seed. This makes random.seed() reproducible across
    # runs. key=str keeps sorting safe if the column holds non-string values.
    artists = sorted(set(dataset['train']['Artist']), key=str)
    return random.choice(artists)
|
|
|
|
|
|
|
|
|
|
|
def generate_artist_options(dataset, correct_artist, num_distractors=3):
    """Return a shuffled list of answer options that includes *correct_artist*.

    Args:
        dataset: Mapping whose ``['train']['Artist']`` column lists artist names.
        correct_artist: The right answer; always present in the result.
        num_distractors: How many other distinct artists to add. Defaults to 3,
            giving four options for the A-D radio.

    Returns:
        A shuffled list of ``min(num_distractors, available) + 1`` distinct
        options. With too few other artists it returns fewer options instead
        of raising ``ValueError`` from ``random.sample``.
    """
    # Sorted for a hash-seed-independent order, so random.seed() is reproducible.
    distractors = sorted(
        {artist for artist in dataset['train']['Artist'] if artist != correct_artist},
        key=str,
    )
    count = min(num_distractors, len(distractors))
    options = random.sample(distractors, count) + [correct_artist]
    random.shuffle(options)
    return options
|
|
|
def generate_multiple_choice_check(options, correct_choice):
    """Build a lookup telling, for each option, whether it is the right answer."""
    check = {}
    for candidate in options:
        check[candidate] = candidate == correct_choice
    return check
|
|
|
def check_correct_choice(user_choice, correct_choice):
    """Return True when *user_choice* equals *correct_choice*."""
    # The original early ``if ...: return True`` duplicated this comparison.
    return user_choice == correct_choice
|
|
|
with gr.Blocks(title="Song Generator Guessing Game") as game_interface:
    # Per-session game state, filled in by generate_song.
    state = gr.State({'options': []})

    language_model = gr.Radio(
        ["Custom Gpt2", "Gpt2-Medium", "facebook/bart-base", "Gpt-Neo", "Customized Models"],
        label="Model Selection",
        info="Select the language model to generate the song.",
    )
    generate_song_button = gr.Button("Generate Song")
    generated_song = gr.Textbox(label="Generated Song")
    artist_choice_display = gr.Textbox(interactive=False, label="Multiple-Choice Options")
    user_choice = gr.Radio(
        ["A", "B", "C", "D"],
        label="Updated Options",
        info="Select the artist that you suspect is the correct artist for the song.",
    )
    submit_answer_button = gr.Button("Submit Answer")
    correct_answer = gr.Textbox(label="Results")

    # The button itself is an input, so generate_song receives its value as
    # the ``generate_song`` argument.
    generate_song_button.click(
        generate_song,
        [state, language_model, generate_song_button],
        [state, generated_song, artist_choice_display],
    )
    submit_answer_button.click(
        on_submit_answer,
        [state, user_choice],
        [correct_answer],
    )

# Start the server only when run as a script, so importing the module does
# not block on launch().
if __name__ == "__main__":
    game_interface.launch()