chomayouni commited on
Commit
1719da7
1 Parent(s): 394bbaa

The v1.2 commit

Browse files
Files changed (1) hide show
  1. sgg_app.py +5 -34
sgg_app.py CHANGED
@@ -7,10 +7,6 @@ import random
7
 
8
  # Load the dataset
9
  dataset = load_dataset("SpartanCinder/song-lyrics-artist-classifier")
10
- # print(dataset.column_names)
11
- # print(dataset['train']['Artist'])
12
- # artist_list = list(set(dataset['train']['Artist']))
13
- # print(artist_list)
14
 
15
  def generate_song(state, language_model, generate_song):
16
 
@@ -54,8 +50,6 @@ def generate_song(state, language_model, generate_song):
54
  print(tokenizer.decode(encoded_output[0], skip_special_tokens=True))
55
  # But this output is repeating, so I need ot adjust this so that it is not repeating.
56
  elif language_model == "Custom Gpt2":
57
- # tokenizer = AutoTokenizer.from_pretrained("SpartanCinder/GPT2-pretrained-lyric-generation")
58
- # model = AutoModelForCausalLM.from_pretrained("SpartanCinder/GPT2-pretrained-lyric-generation")
59
  # encoded_output = model.generate(input_ids, max_length=max_length, num_beams=5, num_return_sequences=5, do_sample=False, no_repeat_ngram_size=2) # Generate text
60
  encoded_output = model.generate(input_ids, max_length=max_length, num_return_sequences=5, do_sample=True, top_p = 0.95, )
61
  # Decode output
@@ -71,9 +65,8 @@ def generate_song(state, language_model, generate_song):
71
  # Decode output
72
  output = tokenizer.decode(encoded_output[0], skip_special_tokens=True)
73
  # Remove the first line of the output if it contains newline characters
74
- # if '\n' in output:
75
- # output = '\n'.join(output.split('\n')[1:])
76
- # formatted_output = output.split('\n')[0] # might have to remove this line
77
  song_text = output
78
 
79
  # Generate the multiple-choice options
@@ -88,21 +81,6 @@ def generate_song(state, language_model, generate_song):
88
  return state, song_text, ', '.join(options)
89
 
90
  #Check the selected artist and return whether it's correct
91
- # def on_submit_answer(state, correct_choice, user_choice, submit_answer):
92
- # if submit_answer:
93
- # if not user_choice:
94
- # return {"Error": "Please select an artist before submitting an answer."}
95
- # # Check if 'correct_choice' is in the state keys
96
- # if 'correct_choice' in state:
97
- # correct_answer = state['correct_choice']
98
- # if correct_answer == user_choice:
99
- # return {"Result": f"You guessed the right artist: {correct_choice}"}
100
- # else:
101
- # return {"Result": f"You selected {user_choice}, but the correct answer is {correct_choice}"}
102
- # else:
103
- # print("The 'correct_choice' key does not exist in the state.")
104
- # return None
105
-
106
  def on_submit_answer(state, user_choice):
107
  # Map the user's choice (A, B, C, or D) to an index
108
  choice_to_index = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
@@ -125,10 +103,6 @@ def pick_artist(dataset):
125
  artist_choice = random.choice(artist_choice)
126
  return artist_choice
127
 
128
- # print("The 'Artist' column does not exist in the dataset.")
129
- # artist_choice = "Green Day"
130
- # return artist_choice
131
-
132
  def generate_artist_options(dataset, correct_artist):
133
  # Generate 3 incorrect options
134
  all_artists = list(set(dataset['train']['Artist']))
@@ -141,22 +115,19 @@ def generate_artist_options(dataset, correct_artist):
141
  def generate_multiple_choice_check(options, correct_choice):
142
  return {option: option == correct_choice for option in options}
143
 
144
- def check_correct_choice(user_choice, correct_choice):
145
- if user_choice == correct_choice:
146
- return True
147
- return user_choice == correct_choice
148
-
149
  with gr.Blocks(title="Song Generator Guessing Game") as game_interface:
 
 
150
  state = gr.State({'options': []})
151
  language_model = gr.Radio(["Custom Gpt2", "Gpt2-Medium", "facebook/bart-base","Gpt-Neo", "Customized Models"], label="Model Selection", info="Select the language model to generate the song.")
152
  generate_song_button = gr.Button("Generate Song")
153
  generated_song = gr.Textbox(label="Generated Song")
154
  artist_choice_display = gr.Textbox(interactive=False, label="Multiple-Choice Options")
155
  user_choice = gr.Radio(["A", "B", "C", "D"], label="Updated Options", info="Select the artist that you suspect is the correct artist for the song.")
156
- # timer = gr.HTML("<div id='progress-bar' style='width: 100%; background-color: #f3f3f3; border: 1px solid #bbb;'><div id='progress' style='height: 20px; width: 0%; background-color: #007bff;'></div></div><script>function startTimer() {var time = 30; var timer = setInterval(function() {time--; document.getElementById('progress').style.width = (time / 30 * 100) + '%'; if (time <= 0) {clearInterval(timer);}}, 1000);}</script>", label="Timer")
157
  submit_answer_button = gr.Button("Submit Answer")
158
  correct_answer = gr.Textbox(label="Results")
159
 
 
160
  generate_song_button.click(
161
  generate_song,
162
  [state, language_model, generate_song_button],
 
7
 
8
  # Load the dataset
9
  dataset = load_dataset("SpartanCinder/song-lyrics-artist-classifier")
 
 
 
 
10
 
11
  def generate_song(state, language_model, generate_song):
12
 
 
50
  print(tokenizer.decode(encoded_output[0], skip_special_tokens=True))
51
  # But this output is repeating, so I need ot adjust this so that it is not repeating.
52
  elif language_model == "Custom Gpt2":
 
 
53
  # encoded_output = model.generate(input_ids, max_length=max_length, num_beams=5, num_return_sequences=5, do_sample=False, no_repeat_ngram_size=2) # Generate text
54
  encoded_output = model.generate(input_ids, max_length=max_length, num_return_sequences=5, do_sample=True, top_p = 0.95, )
55
  # Decode output
 
65
  # Decode output
66
  output = tokenizer.decode(encoded_output[0], skip_special_tokens=True)
67
  # Remove the first line of the output if it contains newline characters
68
+ if '\n' in output:
69
+ output = '\n'.join(output.split('\n')[1:])
 
70
  song_text = output
71
 
72
  # Generate the multiple-choice options
 
81
  return state, song_text, ', '.join(options)
82
 
83
  #Check the selected artist and return whether it's correct
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  def on_submit_answer(state, user_choice):
85
  # Map the user's choice (A, B, C, or D) to an index
86
  choice_to_index = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
 
103
  artist_choice = random.choice(artist_choice)
104
  return artist_choice
105
 
 
 
 
 
106
  def generate_artist_options(dataset, correct_artist):
107
  # Generate 3 incorrect options
108
  all_artists = list(set(dataset['train']['Artist']))
 
115
  def generate_multiple_choice_check(options, correct_choice):
116
  return {option: option == correct_choice for option in options}
117
 
 
 
 
 
 
118
  with gr.Blocks(title="Song Generator Guessing Game") as game_interface:
119
+ gr.Markdown(" # Song Generator Guessing Game")
120
+ gr.Image("https://huggingface.co/spaces/SpartanCinder/NLP_Song_Generator_Guessing_Game/blob/main/RobotSinger.png")
121
  state = gr.State({'options': []})
122
  language_model = gr.Radio(["Custom Gpt2", "Gpt2-Medium", "facebook/bart-base","Gpt-Neo", "Customized Models"], label="Model Selection", info="Select the language model to generate the song.")
123
  generate_song_button = gr.Button("Generate Song")
124
  generated_song = gr.Textbox(label="Generated Song")
125
  artist_choice_display = gr.Textbox(interactive=False, label="Multiple-Choice Options")
126
  user_choice = gr.Radio(["A", "B", "C", "D"], label="Updated Options", info="Select the artist that you suspect is the correct artist for the song.")
 
127
  submit_answer_button = gr.Button("Submit Answer")
128
  correct_answer = gr.Textbox(label="Results")
129
 
130
+
131
  generate_song_button.click(
132
  generate_song,
133
  [state, language_model, generate_song_button],