jytole committed
Commit 4468558
Parent(s): 43c55ef

Re-included multiple candidates to improve quality

Files changed (1):
  1. app.py +37 -9
app.py CHANGED
@@ -5,20 +5,27 @@ from diffusers import AudioLDMPipeline
 
 from transformers import AutoProcessor, ClapModel
 
-# replace with cuda code from AudioLDM's original app.py if using GPU
-device = "cpu"
-torch_dtype = torch.float32
+# cuda setup from AudioLDM's original app.py when a GPU is present;
+# falls back to CPU otherwise
+if torch.cuda.is_available():
+    device = "cuda"
+    torch_dtype = torch.float16
+else:
+    device = "cpu"
+    torch_dtype = torch.float32
 
 # load AudioLDM Diffuser Pipeline
 pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm-m-full", torch_dtype=torch_dtype).to(device)
 pipe.unet = torch.compile(pipe.unet)
 
-# omit CLAP model because we'll only generate one waveform, no scoring
+# include the CLAP model: scoring multiple candidate waveforms improves quality
+clap_model = ClapModel.from_pretrained("sanchit-gandhi/clap-htsat-unfused-m-full").to(device)
+processor = AutoProcessor.from_pretrained("sanchit-gandhi/clap-htsat-unfused-m-full")
 
 generator = torch.Generator(device)
 
-# modified from audioldm app.py to omit n_candidates
-def text2audio(text, negative_prompt, duration, guidance_scale, random_seed):
+# adapted from audioldm app.py
+def text2audio(text, negative_prompt, duration, guidance_scale, random_seed, n_candidates):
     if text is None:
         raise gr.Error("Please provide a text input.")
 
@@ -27,14 +34,27 @@ def text2audio(text, negative_prompt, duration, guidance_scale, random_seed):
         audio_length_in_s=duration,
         guidance_scale=guidance_scale,
         negative_prompt=negative_prompt,
-        num_waveforms_per_prompt=1,
+        num_waveforms_per_prompt=n_candidates if n_candidates else 1,
         generator=generator.manual_seed(int(random_seed)),
     )["audios"]
 
-    waveform = waveforms[0]
+    if waveforms.shape[0] > 1:
+        waveform = score_waveforms(text, waveforms)
+    else:
+        waveform = waveforms[0]
 
     return gr.make_waveform((16000, waveform), bg_image="bg.png")
 
+def score_waveforms(text, waveforms):
+    inputs = processor(text=text, audios=list(waveforms), return_tensors="pt", padding=True)
+    inputs = {key: inputs[key].to(device) for key in inputs}
+    with torch.no_grad():
+        logits_per_text = clap_model(**inputs).logits_per_text  # audio-text similarity scores
+    probs = logits_per_text.softmax(dim=-1)  # softmax gives per-candidate probabilities
+    most_probable = torch.argmax(probs)  # select the most likely audio waveform
+    waveform = waveforms[most_probable]
+    return waveform
+
 # duplicate CSS config
 
 css = """
@@ -171,13 +191,21 @@ with iface:
            label="Guidance scale",
            info="Large => better quality and relevancy to text; Small => better diversity",
        )
+        n_candidates = gr.Slider(
+            1,
+            3,
+            value=3,
+            step=1,
+            label="Number of waveforms to generate",
+            info="Automatic quality control. This sets the number of candidates (e.g., generate three audios and show you the best one). A larger value usually leads to better quality with heavier computation",
+        )
 
        outputs = gr.Video(label="Output", elem_id="output-video")
        btn = gr.Button("Submit").style(full_width=True)
 
        btn.click(
            text2audio,
-            inputs=[textbox, negative_textbox, duration, guidance_scale, seed],
+            inputs=[textbox, negative_textbox, duration, guidance_scale, seed, n_candidates],
            outputs=[outputs],
        )
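
For reference, the best-of-n flow this commit re-enables can be exercised outside the Gradio UI. A minimal sketch, reusing the pipe, clap_model, processor, generator, and device objects defined in the diff above; the prompt string and generation parameters here are illustrative, not part of the commit:

# Sketch: generate three candidates, then keep the waveform CLAP ranks highest,
# mirroring the text2audio() + score_waveforms() path above.
import torch

prompt = "a hammer striking a wooden surface"  # illustrative prompt
waveforms = pipe(
    prompt,
    audio_length_in_s=5.0,
    guidance_scale=2.5,
    num_waveforms_per_prompt=3,  # n_candidates
    generator=generator.manual_seed(42),
)["audios"]  # numpy array of shape (3, num_samples) at 16 kHz

inputs = processor(text=prompt, audios=list(waveforms), return_tensors="pt", padding=True)
inputs = {key: inputs[key].to(device) for key in inputs}
with torch.no_grad():
    logits_per_text = clap_model(**inputs).logits_per_text  # (1, 3) similarity logits
best = waveforms[int(logits_per_text.softmax(dim=-1).argmax())]

Because the app falls back to waveform = waveforms[0] when only one candidate is requested, the extra CLAP forward pass is only paid when it can actually change the output.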