chomayouni
commited on
Commit
•
1719da7
1
Parent(s):
394bbaa
The v1.2 commit
Browse files- sgg_app.py +5 -34
sgg_app.py
CHANGED
@@ -7,10 +7,6 @@ import random
|
|
7 |
|
8 |
# Load the dataset
|
9 |
dataset = load_dataset("SpartanCinder/song-lyrics-artist-classifier")
|
10 |
-
# print(dataset.column_names)
|
11 |
-
# print(dataset['train']['Artist'])
|
12 |
-
# artist_list = list(set(dataset['train']['Artist']))
|
13 |
-
# print(artist_list)
|
14 |
|
15 |
def generate_song(state, language_model, generate_song):
|
16 |
|
@@ -54,8 +50,6 @@ def generate_song(state, language_model, generate_song):
|
|
54 |
print(tokenizer.decode(encoded_output[0], skip_special_tokens=True))
|
55 |
# But this output is repeating, so I need ot adjust this so that it is not repeating.
|
56 |
elif language_model == "Custom Gpt2":
|
57 |
-
# tokenizer = AutoTokenizer.from_pretrained("SpartanCinder/GPT2-pretrained-lyric-generation")
|
58 |
-
# model = AutoModelForCausalLM.from_pretrained("SpartanCinder/GPT2-pretrained-lyric-generation")
|
59 |
# encoded_output = model.generate(input_ids, max_length=max_length, num_beams=5, num_return_sequences=5, do_sample=False, no_repeat_ngram_size=2) # Generate text
|
60 |
encoded_output = model.generate(input_ids, max_length=max_length, num_return_sequences=5, do_sample=True, top_p = 0.95, )
|
61 |
# Decode output
|
@@ -71,9 +65,8 @@ def generate_song(state, language_model, generate_song):
|
|
71 |
# Decode output
|
72 |
output = tokenizer.decode(encoded_output[0], skip_special_tokens=True)
|
73 |
# Remove the first line of the output if it contains newline characters
|
74 |
-
|
75 |
-
|
76 |
-
# formatted_output = output.split('\n')[0] # might have to remove this line
|
77 |
song_text = output
|
78 |
|
79 |
# Generate the multiple-choice options
|
@@ -88,21 +81,6 @@ def generate_song(state, language_model, generate_song):
|
|
88 |
return state, song_text, ', '.join(options)
|
89 |
|
90 |
#Check the selected artist and return whether it's correct
|
91 |
-
# def on_submit_answer(state, correct_choice, user_choice, submit_answer):
|
92 |
-
# if submit_answer:
|
93 |
-
# if not user_choice:
|
94 |
-
# return {"Error": "Please select an artist before submitting an answer."}
|
95 |
-
# # Check if 'correct_choice' is in the state keys
|
96 |
-
# if 'correct_choice' in state:
|
97 |
-
# correct_answer = state['correct_choice']
|
98 |
-
# if correct_answer == user_choice:
|
99 |
-
# return {"Result": f"You guessed the right artist: {correct_choice}"}
|
100 |
-
# else:
|
101 |
-
# return {"Result": f"You selected {user_choice}, but the correct answer is {correct_choice}"}
|
102 |
-
# else:
|
103 |
-
# print("The 'correct_choice' key does not exist in the state.")
|
104 |
-
# return None
|
105 |
-
|
106 |
def on_submit_answer(state, user_choice):
|
107 |
# Map the user's choice (A, B, C, or D) to an index
|
108 |
choice_to_index = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
|
@@ -125,10 +103,6 @@ def pick_artist(dataset):
|
|
125 |
artist_choice = random.choice(artist_choice)
|
126 |
return artist_choice
|
127 |
|
128 |
-
# print("The 'Artist' column does not exist in the dataset.")
|
129 |
-
# artist_choice = "Green Day"
|
130 |
-
# return artist_choice
|
131 |
-
|
132 |
def generate_artist_options(dataset, correct_artist):
|
133 |
# Generate 3 incorrect options
|
134 |
all_artists = list(set(dataset['train']['Artist']))
|
@@ -141,22 +115,19 @@ def generate_artist_options(dataset, correct_artist):
|
|
141 |
def generate_multiple_choice_check(options, correct_choice):
|
142 |
return {option: option == correct_choice for option in options}
|
143 |
|
144 |
-
def check_correct_choice(user_choice, correct_choice):
|
145 |
-
if user_choice == correct_choice:
|
146 |
-
return True
|
147 |
-
return user_choice == correct_choice
|
148 |
-
|
149 |
with gr.Blocks(title="Song Generator Guessing Game") as game_interface:
|
|
|
|
|
150 |
state = gr.State({'options': []})
|
151 |
language_model = gr.Radio(["Custom Gpt2", "Gpt2-Medium", "facebook/bart-base","Gpt-Neo", "Customized Models"], label="Model Selection", info="Select the language model to generate the song.")
|
152 |
generate_song_button = gr.Button("Generate Song")
|
153 |
generated_song = gr.Textbox(label="Generated Song")
|
154 |
artist_choice_display = gr.Textbox(interactive=False, label="Multiple-Choice Options")
|
155 |
user_choice = gr.Radio(["A", "B", "C", "D"], label="Updated Options", info="Select the artist that you suspect is the correct artist for the song.")
|
156 |
-
# timer = gr.HTML("<div id='progress-bar' style='width: 100%; background-color: #f3f3f3; border: 1px solid #bbb;'><div id='progress' style='height: 20px; width: 0%; background-color: #007bff;'></div></div><script>function startTimer() {var time = 30; var timer = setInterval(function() {time--; document.getElementById('progress').style.width = (time / 30 * 100) + '%'; if (time <= 0) {clearInterval(timer);}}, 1000);}</script>", label="Timer")
|
157 |
submit_answer_button = gr.Button("Submit Answer")
|
158 |
correct_answer = gr.Textbox(label="Results")
|
159 |
|
|
|
160 |
generate_song_button.click(
|
161 |
generate_song,
|
162 |
[state, language_model, generate_song_button],
|
|
|
7 |
|
8 |
# Load the dataset
|
9 |
dataset = load_dataset("SpartanCinder/song-lyrics-artist-classifier")
|
|
|
|
|
|
|
|
|
10 |
|
11 |
def generate_song(state, language_model, generate_song):
|
12 |
|
|
|
50 |
print(tokenizer.decode(encoded_output[0], skip_special_tokens=True))
|
51 |
# But this output is repeating, so I need ot adjust this so that it is not repeating.
|
52 |
elif language_model == "Custom Gpt2":
|
|
|
|
|
53 |
# encoded_output = model.generate(input_ids, max_length=max_length, num_beams=5, num_return_sequences=5, do_sample=False, no_repeat_ngram_size=2) # Generate text
|
54 |
encoded_output = model.generate(input_ids, max_length=max_length, num_return_sequences=5, do_sample=True, top_p = 0.95, )
|
55 |
# Decode output
|
|
|
65 |
# Decode output
|
66 |
output = tokenizer.decode(encoded_output[0], skip_special_tokens=True)
|
67 |
# Remove the first line of the output if it contains newline characters
|
68 |
+
if '\n' in output:
|
69 |
+
output = '\n'.join(output.split('\n')[1:])
|
|
|
70 |
song_text = output
|
71 |
|
72 |
# Generate the multiple-choice options
|
|
|
81 |
return state, song_text, ', '.join(options)
|
82 |
|
83 |
#Check the selected artist and return whether it's correct
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
def on_submit_answer(state, user_choice):
|
85 |
# Map the user's choice (A, B, C, or D) to an index
|
86 |
choice_to_index = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
|
|
|
103 |
artist_choice = random.choice(artist_choice)
|
104 |
return artist_choice
|
105 |
|
|
|
|
|
|
|
|
|
106 |
def generate_artist_options(dataset, correct_artist):
|
107 |
# Generate 3 incorrect options
|
108 |
all_artists = list(set(dataset['train']['Artist']))
|
|
|
115 |
def generate_multiple_choice_check(options, correct_choice):
|
116 |
return {option: option == correct_choice for option in options}
|
117 |
|
|
|
|
|
|
|
|
|
|
|
118 |
with gr.Blocks(title="Song Generator Guessing Game") as game_interface:
|
119 |
+
gr.Markdown(" # Song Generator Guessing Game")
|
120 |
+
gr.Image("https://huggingface.co/spaces/SpartanCinder/NLP_Song_Generator_Guessing_Game/blob/main/RobotSinger.png")
|
121 |
state = gr.State({'options': []})
|
122 |
language_model = gr.Radio(["Custom Gpt2", "Gpt2-Medium", "facebook/bart-base","Gpt-Neo", "Customized Models"], label="Model Selection", info="Select the language model to generate the song.")
|
123 |
generate_song_button = gr.Button("Generate Song")
|
124 |
generated_song = gr.Textbox(label="Generated Song")
|
125 |
artist_choice_display = gr.Textbox(interactive=False, label="Multiple-Choice Options")
|
126 |
user_choice = gr.Radio(["A", "B", "C", "D"], label="Updated Options", info="Select the artist that you suspect is the correct artist for the song.")
|
|
|
127 |
submit_answer_button = gr.Button("Submit Answer")
|
128 |
correct_answer = gr.Textbox(label="Results")
|
129 |
|
130 |
+
|
131 |
generate_song_button.click(
|
132 |
generate_song,
|
133 |
[state, language_model, generate_song_button],
|