import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import random

# Load the lyrics dataset used to pick artists and build the multiple-choice options
dataset = load_dataset("SpartanCinder/song-lyrics-artist-classifier")


def generate_song(state, language_model, generate_clicked):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Take the selected language model and generate a song in a random artist's style
    if generate_clicked:
        if not language_model:
            song_text = "Please select a language model before generating a song."
            return state, song_text, ""

        # Map the radio selection to a Hugging Face model id
        if language_model == "Custom Gpt2":
            model_name = "SpartanCinder/GPT2-finetuned-lyric-generation"
        elif language_model == "Gpt2-Medium":
            model_name = "gpt2-medium"
        elif language_model == "facebook/bart-base":
            model_name = "facebook/bart-base"
        elif language_model == "Gpt-Neo":
            model_name = "EleutherAI/gpt-neo-1.3B"
        else:  # Customized Models (placeholder id)
            model_name = "customized-models"

        # Tokenizer and text-generation setup
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

        # Pick a random artist from the dataset and build the prompt
        correct_choice = pick_artist(dataset)
        input_text = f"Write a song in the style of {correct_choice}:"

        # Tuning settings
        max_length = 128
        input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

        if language_model not in ("Custom Gpt2", "Customized Models"):
            ### Beam search to generate text ###
            # no_repeat_ngram_size=2 reduces the repetition that plain beam search produces
            encoded_output = model.generate(
                input_ids,
                max_length=max_length,
                num_beams=5,
                num_return_sequences=5,
                do_sample=False,
                no_repeat_ngram_size=2,
            )
        elif language_model == "Custom Gpt2":
            ### Top-p sampling for the fine-tuned GPT-2 ###
            encoded_output = model.generate(
                input_ids,
                max_length=max_length,
                num_return_sequences=5,
                do_sample=True,
                top_p=0.95,
            )
        else:
            ### Nucleus sampling to generate text ###
            # do_sample=True enables probabilistic sampling; with top_p=0.9 the model
            # samples from the smallest set of tokens whose cumulative probability
            # reaches 90%, which gives more diverse, less repetitive text.
            encoded_output = model.generate(
                input_ids,
                max_length=max_length,
                num_return_sequences=5,
                do_sample=True,
                top_p=0.9,
            )

        # Decode the first generated sequence
        song_text = tokenizer.decode(encoded_output[0], skip_special_tokens=True)

        # Generate the multiple-choice options and store everything needed to grade the answer
        options = generate_artist_options(dataset, correct_choice)
        state['options'] = options
        state['multiple_choice_check'] = generate_multiple_choice_check(options, correct_choice)
        state['correct_choice'] = correct_choice

        return state, song_text, ', '.join(options)


def on_submit_answer(state, user_choice):
    # Guard against submitting before selecting an option or generating a song
    if not user_choice:
        return {"Error": "Please select an answer before submitting."}
    if not state.get('options'):
        return {"Error": "Please generate a song before submitting an answer."}
    # Map the user's choice (A, B, C, or D) to an index into the options list
    choice_to_index = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
    index = choice_to_index[user_choice]
    # Retrieve the chosen artist and the correct artist from the state
    user_artist = state['options'][index]
    correct_artist = state['correct_choice']
    # Compare the user's choice with the correct choice
    if user_artist == correct_artist:
        return {"CORRECT": f"You guessed the right artist: {correct_artist}"}
    else:
        return {"INCORRECT": f"You selected {user_artist}, but the correct answer is {correct_artist}"}


def pick_artist(dataset):
    # Pick a random artist from the dataset's 'Artist' column
    all_artists = list(set(dataset['train']['Artist']))
    return random.choice(all_artists)


def generate_artist_options(dataset, correct_artist):
    # Generate 3 incorrect options plus the correct artist, in random order
    all_artists = list(set(dataset['train']['Artist']))
    if correct_artist in all_artists:
        all_artists.remove(correct_artist)
    options = random.sample(all_artists, 3) + [correct_artist]
    random.shuffle(options)
    return options


def generate_multiple_choice_check(options, correct_choice):
    return {option: option == correct_choice for option in options}


def check_correct_choice(user_choice, correct_choice):
    return user_choice == correct_choice


with gr.Blocks(title="Song Generator Guessing Game") as game_interface:
    state = gr.State({'options': []})
    language_model = gr.Radio(
        ["Custom Gpt2", "Gpt2-Medium", "facebook/bart-base", "Gpt-Neo", "Customized Models"],
        label="Model Selection",
        info="Select the language model to generate the song.",
    )
    generate_song_button = gr.Button("Generate Song")
    generated_song = gr.Textbox(label="Generated Song")
    artist_choice_display = gr.Textbox(interactive=False, label="Multiple-Choice Options")
    user_choice = gr.Radio(
        ["A", "B", "C", "D"],
        label="Updated Options",
        info="Select the artist that you suspect is the correct artist for the song.",
    )
    # timer = gr.HTML("
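
    # --- Hypothetical event wiring (a minimal sketch, not part of the original file) ---
    # This shows one way the callbacks above could be connected, assuming
    # generate_song returns (state, song text, comma-separated options) and
    # on_submit_answer returns a dict suitable for a JSON component. The
    # submit_answer_button and answer_result components are introduced here
    # purely for illustration.
    submit_answer_button = gr.Button("Submit Answer")
    answer_result = gr.JSON(label="Result")

    generate_song_button.click(
        fn=lambda state_value, model_choice: generate_song(state_value, model_choice, True),
        inputs=[state, language_model],
        outputs=[state, generated_song, artist_choice_display],
    )
    submit_answer_button.click(
        fn=on_submit_answer,
        inputs=[state, user_choice],
        outputs=answer_result,
    )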