fffiloni committed on
Commit
314205d
·
verified ·
1 Parent(s): c4c3648

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -46
app.py CHANGED
@@ -5,23 +5,19 @@ import wavio
5
  import numpy as np
6
  from tqdm import tqdm
7
  from huggingface_hub import snapshot_download
8
-
9
  from audioldm.audio.stft import TacotronSTFT
10
  from audioldm.variational_autoencoder import AutoencoderKL
11
-
12
  from transformers import AutoTokenizer, T5ForConditionalGeneration
13
  from modelling_deberta_v2 import DebertaV2ForTokenClassificationRegression
14
-
15
  import sys
16
- sys.path.insert(0, "diffusers/src")
17
 
 
18
  from diffusers import DDPMScheduler
19
  from models import MusicAudioDiffusion
20
-
21
  from gradio import Markdown
22
-
23
  import spaces
24
 
 
25
  # Automatic device detection
26
  if torch.cuda.is_available():
27
  device_type = "cuda"
@@ -29,7 +25,8 @@ if torch.cuda.is_available():
29
  else:
30
  device_type = "cpu"
31
  device_selection = "cpu"
32
-
 
33
  class MusicFeaturePredictor:
34
  def __init__(self, path, device=device_selection, cache_dir=None, local_files_only=False):
35
  self.beats_tokenizer = AutoTokenizer.from_pretrained(
@@ -69,7 +66,11 @@ class MusicFeaturePredictor:
69
 
70
  def generate_beats(self, prompt):
71
  tokenized = self.beats_tokenizer(
72
- prompt, max_length=512, padding=True, truncation=True, return_tensors="pt"
 
 
 
 
73
  )
74
  tokenized = {k: v.to(self.beats_model.device) for k, v in tokenized.items()}
75
 
@@ -79,6 +80,7 @@ class MusicFeaturePredictor:
79
  max_beat = (
80
  1 + torch.argmax(out["logits"][:, 0, :], -1).detach().cpu().numpy()
81
  ).tolist()[0]
 
82
  intervals = (
83
  out["values"][:, :, 0]
84
  .detach()
@@ -88,14 +90,15 @@ class MusicFeaturePredictor:
88
  .round(4)
89
  .tolist()
90
  )
91
-
92
  intervals = np.cumsum(intervals)
 
93
  predicted_beats_times = []
94
  for t in intervals:
95
  if t < 10:
96
  predicted_beats_times.append(round(t, 2))
97
  else:
98
  break
 
99
  predicted_beats_times = list(np.array(predicted_beats_times)[:50])
100
 
101
  if len(predicted_beats_times) == 0:
@@ -111,7 +114,7 @@ class MusicFeaturePredictor:
111
  def generate(self, prompt):
112
  max_beat, predicted_beats_times, predicted_beats = self.generate_beats(prompt)
113
 
114
- chords_prompt = "Caption: {} \\n Timestamps: {} \\n Max Beat: {}".format(
115
  prompt,
116
  " , ".join([str(round(t, 2)) for t in predicted_beats_times]),
117
  max_beat,
@@ -162,7 +165,10 @@ class Mustango:
162
  path = snapshot_download(repo_id=name, cache_dir=cache_dir)
163
 
164
  self.music_model = MusicFeaturePredictor(
165
- path, device, cache_dir=cache_dir, local_files_only=local_files_only
 
 
 
166
  )
167
 
168
  vae_config = json.load(open(f"{path}/configs/vae_config.json"))
@@ -170,7 +176,6 @@ class Mustango:
170
  main_config = json.load(open(f"{path}/configs/main_config.json"))
171
 
172
  ALT_SD21 = "sd2-community/stable-diffusion-2-1" # <-- mets TON repo alternatif exact
173
-
174
  if main_config.get("scheduler_name") == "stabilityai/stable-diffusion-2-1":
175
  main_config["scheduler_name"] = ALT_SD21
176
 
@@ -181,16 +186,20 @@ class Mustango:
181
  main_config["scheduler_name"],
182
  unet_model_config_path=f"{path}/configs/music_diffusion_model_config.json",
183
  ).to(device)
 
184
  # self.model.device = device
185
 
186
  vae_weights = torch.load(
187
- f"{path}/vae/pytorch_model_vae.bin", map_location=device
 
188
  )
189
  stft_weights = torch.load(
190
- f"{path}/stft/pytorch_model_stft.bin", map_location=device
 
191
  )
192
  main_weights = torch.load(
193
- f"{path}/ldm/pytorch_model_ldm.bin", map_location=device
 
194
  )
195
 
196
  self.vae.load_state_dict(vae_weights)
@@ -204,14 +213,15 @@ class Mustango:
204
  self.model.eval()
205
 
206
  self.scheduler = DDPMScheduler.from_pretrained(
207
- main_config["scheduler_name"], subfolder="scheduler"
 
208
  )
209
 
210
  def generate(self, prompt, steps=100, guidance=3, samples=1, disable_progress=True):
211
  """Genrate music for a single prompt string."""
212
-
213
  with torch.no_grad():
214
  beats, chords, chords_times = self.music_model.generate(prompt)
 
215
  latents = self.model.inference(
216
  [prompt],
217
  beats,
@@ -223,10 +233,11 @@ class Mustango:
223
  samples,
224
  disable_progress,
225
  )
 
226
  mel = self.vae.decode_first_stage(latents)
227
  wave = self.vae.decode_to_waveform(mel)
228
 
229
- return wave[0]
230
 
231
 
232
  # Initialize Mustango
@@ -236,36 +247,52 @@ mustango.stft.to(device_type)
236
  mustango.model.to(device_type)
237
  mustango.music_model.beats_model.to(device_type)
238
  mustango.music_model.chords_model.to(device_type)
239
-
240
  # if torch.cuda.is_available():
241
  # mustango = Mustango(device=device_selection)
242
  # else:
243
  # mustango = Mustango(device="CPU")
244
 
245
  # mustango = Mustango(device=device_selection)
246
-
247
  mustango.model.device = device_selection
248
 
249
-
250
  # output_wave = mustango.generate("This techno song features a synth lead playing the main melody.", 5, 3, disable_progress=False)
 
 
251
  @spaces.GPU(duration=120)
252
- def gradio_generate(prompt, steps, guidance):
253
- output_wave = mustango.generate(prompt, steps, guidance)
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  # output_filename = f"{prompt.replace(' ', '_')}_{steps}_{guidance}"[:250] + ".wav"
255
  output_filename = "temp.wav"
256
  wavio.write(output_filename, output_wave, rate=16000, sampwidth=2)
257
-
258
  return output_filename
259
 
260
 
261
- title="Mustango: Toward Controllable Text-to-Music Generation"
 
262
  description_text = """
263
- <p><a href="https://huggingface.co/spaces/declare-lab/mustango/blob/main/app.py?duplicate=true"> <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a> For faster inference without waiting in queue, you may duplicate the space and upgrade to a GPU in the settings. <br/><br/>
 
264
  Generate music using Mustango by providing a text prompt.
265
- <br/><br/> This is the demo for Mustango for controllable text to music generation: <a href="https://arxiv.org/abs/2311.08355">Read our paper.</a>
266
- <p/>
267
  """
268
- #description_text = ""
 
 
269
  # Gradio input and output components
270
  input_text = gr.Textbox(lines=2, label="Prompt")
271
  output_audio = gr.Audio(label="Generated Music", type="filepath")
@@ -273,37 +300,69 @@ denoising_steps = gr.Slider(minimum=100, maximum=200, value=100, step=1, label="
273
  guidance_scale = gr.Slider(minimum=1, maximum=10, value=3, step=0.1, label="Guidance Scale", interactive=True)
274
 
275
  # CSS styling for the Duplicate button
276
- css = '''
277
  #duplicate-button {
278
- margin: auto;
279
- color: white;
280
- background: #1565c0;
281
- border-radius: 100vh;
282
  }
283
- '''
284
 
285
  # Gradio interface
286
  gr_interface = gr.Interface(
287
- fn=gradio_generate,
288
  inputs=[input_text, denoising_steps, guidance_scale],
289
  outputs=[output_audio],
290
  description=description_text,
291
  examples=[
292
- ["This techno song features a synth lead playing the main melody. This is accompanied by programmed percussion playing a simple kick focused beat. The hi-hat is accented in an open position on the 3-and count of every bar. The synth plays the bass part with a voicing that sounds like a cello. This techno song can be played in a club. The chord sequence is Gm, A7, Eb, Bb, C, F, Gm. The beat counts to 2. The tempo of this song is 128.0 beats per minute. The key of this song is G minor.", 200, 3],
293
- ["This is a new age piece. There is a flute playing the main melody with a lot of staccato notes. The rhythmic background consists of a medium tempo electronic drum beat with percussive elements all over the spectrum. There is a playful atmosphere to the piece. This piece can be used in the soundtrack of a children's TV show or an advertisement jingle.", 200, 3],
294
- ["The song is an instrumental. The song is in medium tempo with a classical guitar playing a lilting melody in accompaniment style. The song is emotional and romantic. The song is a romantic instrumental song. The chord sequence is Gm, F6, Ebm. The time signature is 4/4. This song is in Adagio. The key of this song is G minor.", 200, 3],
295
- ["This folk song features a female voice singing the main melody. This is accompanied by a tabla playing the percussion. A guitar strums chords. For most parts of the song, only one chord is played. At the last bar, a different chord is played. This song has minimal instruments. This song has a story-telling mood. This song can be played in a village scene in an Indian movie. The chord sequence is Bbm, Ab. The beat is 3. The tempo of this song is Allegro. The key of this song is Bb minor.", 200, 3],
296
- ["This is a live performance of a classical music piece. There is an orchestra performing the piece with a violin lead playing the main melody. The atmosphere is sentimental and heart-touching. This piece could be playing in the background at a classy restaurant. The chord progression in this song is Am7, Gm, Dm, A7, Dm. The beat is 3. This song is in Largo. The key of this song is D minor.", 200, 3],
297
- ["This is a techno piece with drums and beats and a leading melody. A synth plays chords. The music kicks off with a powerful and relentless drumbeat. Over the pounding beats, a leading melody emerges. In the middle of the song, a flock of seagulls flies over the venue and make loud bird sounds. It has strong danceability and can be played in a club. The tempo is 120 bpm. The chords played by the synth are Am, Cm, Dm, Gm.", 200, 3],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
  ],
299
  cache_examples=False,
300
  )
301
 
302
- with gr.Blocks(css=css) as demo:
303
- title=gr.HTML(f"<h1><center>{title}</center></h1>")
 
 
304
  dupe = gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
305
  gr_interface.render()
306
-
307
 
308
  # Launch Gradio app.
309
- demo.queue().launch(ssr_mode=False)
 
5
  import numpy as np
6
  from tqdm import tqdm
7
  from huggingface_hub import snapshot_download
 
8
  from audioldm.audio.stft import TacotronSTFT
9
  from audioldm.variational_autoencoder import AutoencoderKL
 
10
  from transformers import AutoTokenizer, T5ForConditionalGeneration
11
  from modelling_deberta_v2 import DebertaV2ForTokenClassificationRegression
 
12
  import sys
 
13
 
14
+ sys.path.insert(0, "diffusers/src")
15
  from diffusers import DDPMScheduler
16
  from models import MusicAudioDiffusion
 
17
  from gradio import Markdown
 
18
  import spaces
19
 
20
+
21
  # Automatic device detection
22
  if torch.cuda.is_available():
23
  device_type = "cuda"
 
25
  else:
26
  device_type = "cpu"
27
  device_selection = "cpu"
28
+
29
+
30
  class MusicFeaturePredictor:
31
  def __init__(self, path, device=device_selection, cache_dir=None, local_files_only=False):
32
  self.beats_tokenizer = AutoTokenizer.from_pretrained(
 
66
 
67
  def generate_beats(self, prompt):
68
  tokenized = self.beats_tokenizer(
69
+ prompt,
70
+ max_length=512,
71
+ padding=True,
72
+ truncation=True,
73
+ return_tensors="pt",
74
  )
75
  tokenized = {k: v.to(self.beats_model.device) for k, v in tokenized.items()}
76
 
 
80
  max_beat = (
81
  1 + torch.argmax(out["logits"][:, 0, :], -1).detach().cpu().numpy()
82
  ).tolist()[0]
83
+
84
  intervals = (
85
  out["values"][:, :, 0]
86
  .detach()
 
90
  .round(4)
91
  .tolist()
92
  )
 
93
  intervals = np.cumsum(intervals)
94
+
95
  predicted_beats_times = []
96
  for t in intervals:
97
  if t < 10:
98
  predicted_beats_times.append(round(t, 2))
99
  else:
100
  break
101
+
102
  predicted_beats_times = list(np.array(predicted_beats_times)[:50])
103
 
104
  if len(predicted_beats_times) == 0:
 
114
  def generate(self, prompt):
115
  max_beat, predicted_beats_times, predicted_beats = self.generate_beats(prompt)
116
 
117
+ chords_prompt = "Caption: {} \n Timestamps: {} \n Max Beat: {}".format(
118
  prompt,
119
  " , ".join([str(round(t, 2)) for t in predicted_beats_times]),
120
  max_beat,
 
165
  path = snapshot_download(repo_id=name, cache_dir=cache_dir)
166
 
167
  self.music_model = MusicFeaturePredictor(
168
+ path,
169
+ device,
170
+ cache_dir=cache_dir,
171
+ local_files_only=local_files_only,
172
  )
173
 
174
  vae_config = json.load(open(f"{path}/configs/vae_config.json"))
 
176
  main_config = json.load(open(f"{path}/configs/main_config.json"))
177
 
178
  ALT_SD21 = "sd2-community/stable-diffusion-2-1" # <-- mets TON repo alternatif exact
 
179
  if main_config.get("scheduler_name") == "stabilityai/stable-diffusion-2-1":
180
  main_config["scheduler_name"] = ALT_SD21
181
 
 
186
  main_config["scheduler_name"],
187
  unet_model_config_path=f"{path}/configs/music_diffusion_model_config.json",
188
  ).to(device)
189
+
190
  # self.model.device = device
191
 
192
  vae_weights = torch.load(
193
+ f"{path}/vae/pytorch_model_vae.bin",
194
+ map_location=device,
195
  )
196
  stft_weights = torch.load(
197
+ f"{path}/stft/pytorch_model_stft.bin",
198
+ map_location=device,
199
  )
200
  main_weights = torch.load(
201
+ f"{path}/ldm/pytorch_model_ldm.bin",
202
+ map_location=device,
203
  )
204
 
205
  self.vae.load_state_dict(vae_weights)
 
213
  self.model.eval()
214
 
215
  self.scheduler = DDPMScheduler.from_pretrained(
216
+ main_config["scheduler_name"],
217
+ subfolder="scheduler",
218
  )
219
 
220
  def generate(self, prompt, steps=100, guidance=3, samples=1, disable_progress=True):
221
  """Genrate music for a single prompt string."""
 
222
  with torch.no_grad():
223
  beats, chords, chords_times = self.music_model.generate(prompt)
224
+
225
  latents = self.model.inference(
226
  [prompt],
227
  beats,
 
233
  samples,
234
  disable_progress,
235
  )
236
+
237
  mel = self.vae.decode_first_stage(latents)
238
  wave = self.vae.decode_to_waveform(mel)
239
 
240
+ return wave[0]
241
 
242
 
243
  # Initialize Mustango
 
247
  mustango.model.to(device_type)
248
  mustango.music_model.beats_model.to(device_type)
249
  mustango.music_model.chords_model.to(device_type)
250
+
251
  # if torch.cuda.is_available():
252
  # mustango = Mustango(device=device_selection)
253
  # else:
254
  # mustango = Mustango(device="CPU")
255
 
256
  # mustango = Mustango(device=device_selection)
 
257
  mustango.model.device = device_selection
258
 
 
259
  # output_wave = mustango.generate("This techno song features a synth lead playing the main melody.", 5, 3, disable_progress=False)
260
+
261
+
262
@spaces.GPU(duration=120)
def generate_music_from_prompt(music_prompt: str, denoising_steps: int, guidance_scale: float) -> str:
    """
    Generate a short music audio file from a text prompt using Mustango.

    Use this tool when a user wants to create controllable text-to-music audio
    from a detailed music description.

    Args:
        music_prompt (str): Text description of the desired music, including
            instruments, mood, tempo, key, chords, or scene.
        denoising_steps (int): Number of diffusion denoising steps to use for
            generation.
        guidance_scale (float): Classifier-free guidance scale controlling
            prompt adherence.

    Returns:
        str: Filepath to the generated WAV audio file.
    """
    import tempfile  # local import: only used here

    output_wave = mustango.generate(music_prompt, denoising_steps, guidance_scale)
    # Write each request to its own unique file: a shared, fixed "temp.wav"
    # gets clobbered when two requests run concurrently in the same Space.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        output_filename = tmp.name
    wavio.write(output_filename, output_wave, rate=16000, sampwidth=2)
    return output_filename
282
 
283
 
284
+ title = "Mustango: Toward Controllable Text-to-Music Generation"
285
+
286
  description_text = """
287
+ For faster inference without waiting in queue, you may duplicate the space and upgrade to a GPU in the settings.
288
+
289
  Generate music using Mustango by providing a text prompt.
290
+
291
+ This is the demo for Mustango for controllable text to music generation: [Read our paper.](https://arxiv.org/abs/2311.08355)
292
  """
293
+
294
+ # description_text = ""
295
+
296
  # Gradio input and output components
297
  input_text = gr.Textbox(lines=2, label="Prompt")
298
  output_audio = gr.Audio(label="Generated Music", type="filepath")
 
300
  guidance_scale = gr.Slider(minimum=1, maximum=10, value=3, step=0.1, label="Guidance Scale", interactive=True)
301
 
302
  # CSS styling for the Duplicate button
303
+ css = """
304
  #duplicate-button {
305
+ margin: auto;
306
+ color: white;
307
+ background: #1565c0;
308
+ border-radius: 100vh;
309
  }
310
+ """
311
 
312
  # Gradio interface
313
# Gradio interface.
# NOTE: example prompts must stay single-line string literals. Indented
# triple-quoted strings embed the newline AND the continuation line's leading
# indentation into the prompt itself, silently changing what is fed to the model.
gr_interface = gr.Interface(
    fn=generate_music_from_prompt,
    inputs=[input_text, denoising_steps, guidance_scale],
    outputs=[output_audio],
    description=description_text,
    examples=[
        ["This techno song features a synth lead playing the main melody. This is accompanied by programmed percussion playing a simple kick focused beat. The hi-hat is accented in an open position on the 3-and count of every bar. The synth plays the bass part with a voicing that sounds like a cello. This techno song can be played in a club. The chord sequence is Gm, A7, Eb, Bb, C, F, Gm. The beat counts to 2. The tempo of this song is 128.0 beats per minute. The key of this song is G minor.", 200, 3],
        ["This is a new age piece. There is a flute playing the main melody with a lot of staccato notes. The rhythmic background consists of a medium tempo electronic drum beat with percussive elements all over the spectrum. There is a playful atmosphere to the piece. This piece can be used in the soundtrack of a children's TV show or an advertisement jingle.", 200, 3],
        ["The song is an instrumental. The song is in medium tempo with a classical guitar playing a lilting melody in accompaniment style. The song is emotional and romantic. The song is a romantic instrumental song. The chord sequence is Gm, F6, Ebm. The time signature is 4/4. This song is in Adagio. The key of this song is G minor.", 200, 3],
        ["This folk song features a female voice singing the main melody. This is accompanied by a tabla playing the percussion. A guitar strums chords. For most parts of the song, only one chord is played. At the last bar, a different chord is played. This song has minimal instruments. This song has a story-telling mood. This song can be played in a village scene in an Indian movie. The chord sequence is Bbm, Ab. The beat is 3. The tempo of this song is Allegro. The key of this song is Bb minor.", 200, 3],
        ["This is a live performance of a classical music piece. There is an orchestra performing the piece with a violin lead playing the main melody. The atmosphere is sentimental and heart-touching. This piece could be playing in the background at a classy restaurant. The chord progression in this song is Am7, Gm, Dm, A7, Dm. The beat is 3. This song is in Largo. The key of this song is D minor.", 200, 3],
        ["This is a techno piece with drums and beats and a leading melody. A synth plays chords. The music kicks off with a powerful and relentless drumbeat. Over the pounding beats, a leading melody emerges. In the middle of the song, a flock of seagulls flies over the venue and make loud bird sounds. It has strong danceability and can be played in a club. The tempo is 120 bpm. The chords played by the synth are Am, Cm, Dm, Gm.", 200, 3],
    ],
    cache_examples=False,
)
358
 
359
# Page layout. NOTE: `css` belongs to the gr.Blocks() constructor —
# gradio's launch() has no `css` parameter, so passing it there raises a
# TypeError at startup.
with gr.Blocks(css=css) as demo:
    # gr.HTML renders raw HTML; markdown "#" headings are NOT interpreted,
    # so emit a real <h1>. Use a distinct variable name to avoid rebinding
    # the module-level `title` string to a component.
    heading = gr.HTML(f"<h1><center>{title}</center></h1>")
    dupe = gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    gr_interface.render()


# Launch Gradio app.
demo.queue().launch(ssr_mode=False, mcp_server=True)