chomayouni commited on
Commit
c5e8d64
1 Parent(s): 6455751

Updated with initial logic

Browse files
Files changed (1) hide show
  1. sgg_app.py +65 -59
sgg_app.py CHANGED
@@ -1,65 +1,61 @@
1
  import gradio as gr
 
 
 
2
 
3
- # def game(difficulty, generate_song, artist_choice, submit_answer):
4
- # if generate_song:
5
- # # Generate the song and the options based on the difficulty
6
- # # In the actual implementation, you would generate the song text and the options based on the difficulty
7
- # if difficulty == "Demo":
8
- # song_text = "Generated song text for Demo"
9
- # options = ["Artist 1", "Artist 2", "Artist 3", "Artist 4"]
10
- # elif difficulty == "Medium":
11
- # song_text = "Generated song text for Medium"
12
- # options = ["Artist 5", "Artist 6", "Artist 7", "Artist 8"]
13
- # else: # Hard
14
- # song_text = "Generated song text for Hard"
15
- # options = ["Artist 9", "Artist 10", "Artist 11", "Artist 12"]
16
- # return {"Generated Song": song_text, "Options": options}
17
- # elif submit_answer:
18
- # # Check the selected artist and return whether it's correct
19
- # correct_answer = "Artist 1" # Placeholder
20
- # return {"Correct Answer": correct_answer == artist_choice}
21
 
22
- # game_interface = gr.Interface(
23
- # fn=game,
24
- # inputs=[
25
- # gr.Radio(["Demo", "Medium", "Hard"], label="Difficulty",
26
- # info="The higher the difficulty makes it so that the options for the artists are more similar to one another?"),
27
- # gr.Button("Generate Song"),
28
- # gr.Radio(["A", "B", "C", "D"], label="Multi-Choice Options",
29
- # info="Select the artist that you suspect is the correct artist for the song."),
30
- # gr.Button("Submit Answer"),
31
- # ],
32
-
33
- # outputs=[
34
- # gr.Textbox(label="Generated Song"),
35
- # gr.Textbox(label="Options"),
36
- # gr.Textbox(label="Correct Answer"),
37
- # ],
38
 
39
- # title="Song Generator Guessing Game",
40
- # )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- # game_interface.launch()
 
 
 
 
 
 
43
 
44
- def generate_song(state, difficulty, generate_song):
45
- if generate_song:
46
- if not difficulty:
47
- song_text = "Please select a difficulty level before generating a song."
48
- return state, song_text, ""
49
- # Generate the song and the options based on the difficulty
50
- if difficulty == "Demo":
51
- song_text = "Generated song text for Demo"
52
- options = ["Artist 1", "Artist 2", "Artist 3", "Artist 4"]
53
- elif difficulty == "Medium":
54
- song_text = "Generated song text for Medium"
55
- options = ["Artist 5", "Artist 6", "Artist 7", "Artist 8"]
56
- else: # Hard
57
- song_text = "Generated song text for Hard"
58
- options = ["Artist 9", "Artist 10", "Artist 11", "Artist 12"]
59
- state['options'] = options
60
- state['timer_finished'] = False
61
- timer_script = "<div id='progress-bar' style='width: 100%; background-color: #f3f3f3; border: 1px solid #bbb;'><div id='progress' style='height: 20px; width: 0%; background-color: #007bff;'></div></div><script>function startTimer() {var time = 30; var timer = setInterval(function() {time--; document.getElementById('progress').style.width = (time / 30 * 100) + '%'; if (time <= 0) {clearInterval(timer);}}, 1000);}</script>"
62
- return state, song_text, ', '.join(options), timer_script
 
 
 
 
 
63
 
64
  def submit_answer(state, artist_choice, submit_answer):
65
  if submit_answer:
@@ -69,10 +65,20 @@ def submit_answer(state, artist_choice, submit_answer):
69
  # Check the selected artist and return whether it's correct
70
  correct_answer = state['options'][0] # Placeholder
71
  return {"Correct Answer": correct_answer == artist_choice}
 
 
 
 
 
 
 
 
 
 
72
 
73
- with gr.Blocks(title="Song Genorator Guessing Game") as game_interface:
74
  state = gr.State({'options': []})
75
- difficulty = gr.Radio(["Demo", "Medium", "Hard"], label="Difficulty")
76
  generate_song_button = gr.Button("Generate Song")
77
  artist_choice_display = gr.Textbox(interactive=False, label="Multiple-Choice Options")
78
  artist_choice = gr.Radio(["A", "B", "C", "D"], label="Updated Options", info="Select the artist that you suspect is the correct artist for the song.")
@@ -83,7 +89,7 @@ with gr.Blocks(title="Song Genorator Guessing Game") as game_interface:
83
 
84
  generate_song_button.click(
85
  generate_song,
86
- [state, difficulty, generate_song_button],
87
  [state, generated_song, artist_choice_display, timer]
88
  )
89
  submit_answer_button.click(
 
1
  import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+ from transformers import TrainingArguments, Trainer
5
 
6
+ def generate_song(state, language_model, generate_song):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ # Logic that takes put the selected language model and generates a song
11
+ if generate_song:
12
+ if not language_model:
13
+ song_text = "Please select a language model before generating a song."
14
+ return state, song_text, "", ""
15
+ # Generate the song and the options based on the language_model
16
+ if language_model == "Custom Gpt2":
17
+ model_name = "SpartanCinder/GPT2-pretrained-lyric-generation"
18
+ elif language_model == "Gpt2-Medium":
19
+ model_name = "gpt2-medium"
20
+ elif language_model == "facebook/bart-base":
21
+ model_name = "facebook/bart-base"
22
+ elif language_model == "Gpt-Neo":
23
+ model_name = "EleutherAI/gpt-neo-1.3B"
24
+ else: # Customized Models
25
+ model_name = "customized-models"
26
 
27
+ #tokenzer and text generation logic
28
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
29
+ model = AutoModelForCausalLM.from_pretrained(model_name)
30
+ input_text = pick_artist()
31
+ max_length = 128
32
+ input_ids = tokenizer.encode(input_text, return_tensors="pt")
33
+ input_ids = input_ids.to(device)
34
 
35
+ if language_model != "customized-models":
36
+ ### Using Beam search to generate text###
37
+ # encoded data
38
+ output = model.generate(input_ids, max_length=max_length, num_beams=5, num_return_sequences=5, do_sample=False, no_repeat_ngram_size=2) # Generate text
39
+ # Decode output
40
+ print(tokenizer.decode(output[0], skip_special_tokens=True))
41
+ # But this output is repeating, so I need ot adjust this so that it is not repeating.
42
+ else:
43
+ ### Nucleas Sampling to generate text###
44
+ # Set the do_sample parameter to True because we are using nucleus sampling is a probabilistic sampling method
45
+ # top_p is the probability threshold for nucleus sampling
46
+ # So, we set top_p to 0.9, which means that the model will sample from the top 90% of the probability distribution
47
+ # This will help to generate more diverse text that is less repetitive
48
+ encoded_output = model.generate(input_ids, max_length=max_length, num_return_sequences=5, do_sample=True, top_p = 0.9, )
49
+
50
+ song_text = tokenizer.decode(encoded_output[0], skip_special_tokens=True)
51
+
52
+ # Generate the multiple-choice options
53
+ options = ["Artist 1", "Artist 2", "Artist 3", "Artist 4"]
54
+
55
+ state['options'] = options
56
+ state['timer_finished'] = False
57
+ timer_script = "<div id='progress-bar' style='width: 100%; background-color: #f3f3f3; border: 1px solid #bbb;'><div id='progress' style='height: 20px; width: 0%; background-color: #007bff;'></div></div><script>function startTimer() {var time = 30; var timer = setInterval(function() {time--; document.getElementById('progress').style.width = (time / 30 * 100) + '%'; if (time <= 0) {clearInterval(timer);}}, 1000);}</script>"
58
+ return state, song_text, ', '.join(options), timer_script
59
 
60
  def submit_answer(state, artist_choice, submit_answer):
61
  if submit_answer:
 
65
  # Check the selected artist and return whether it's correct
66
  correct_answer = state['options'][0] # Placeholder
67
  return {"Correct Answer": correct_answer == artist_choice}
68
+
69
+ def pick_artist():
70
+
71
+ return "A song in the style of Taylor Swift:"
72
+
73
+ def generate_artist_options(correct_artist):
74
+ # Generate 3 incorrect options
75
+ options = ["Artist 1", "Artist 2", "Artist 3", "Artist 4"]
76
+ options.remove(correct_artist)
77
+ return [correct_artist] + options
78
 
79
+ with gr.Blocks(title="Song Generator Guessing Game") as game_interface:
80
  state = gr.State({'options': []})
81
+ language_model = gr.Radio(["Custom Gpt2", "Gpt2-Medium", "facebook/bart-base","Gpt-Neo", "Customized Models"], label="Difficulty")
82
  generate_song_button = gr.Button("Generate Song")
83
  artist_choice_display = gr.Textbox(interactive=False, label="Multiple-Choice Options")
84
  artist_choice = gr.Radio(["A", "B", "C", "D"], label="Updated Options", info="Select the artist that you suspect is the correct artist for the song.")
 
89
 
90
  generate_song_button.click(
91
  generate_song,
92
+ [state, language_model, generate_song_button],
93
  [state, generated_song, artist_choice_display, timer]
94
  )
95
  submit_answer_button.click(