Re-included multiple candidates to improve quality
Browse files
app.py
CHANGED
@@ -5,20 +5,27 @@ from diffusers import AudioLDMPipeline
|
|
5 |
|
6 |
from transformers import AutoProcessor, ClapModel
|
7 |
|
8 |
-
#
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
# load AudioLDM Diffuser Pipeline
|
13 |
pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm-m-full", torch_dtype=torch_dtype).to(device)
|
14 |
pipe.unet = torch.compile(pipe.unet)
|
15 |
|
16 |
-
#
|
|
|
|
|
17 |
|
18 |
generator = torch.Generator(device)
|
19 |
|
20 |
-
#
|
21 |
-
def text2audio(text, negative_prompt, duration, guidance_scale, random_seed):
|
22 |
if text is None:
|
23 |
raise gr.Error("Please provide a text input.")
|
24 |
|
@@ -27,14 +34,27 @@ def text2audio(text, negative_prompt, duration, guidance_scale, random_seed):
|
|
27 |
audio_length_in_s=duration,
|
28 |
guidance_scale=guidance_scale,
|
29 |
negative_prompt=negative_prompt,
|
30 |
-
num_waveforms_per_prompt=1,
|
31 |
generator=generator.manual_seed(int(random_seed)),
|
32 |
)["audios"]
|
33 |
|
34 |
-
|
|
|
|
|
|
|
35 |
|
36 |
return gr.make_waveform((16000, waveform), bg_image="bg.png")
|
37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
# duplicate CSS config
|
39 |
|
40 |
css = """
|
@@ -171,13 +191,21 @@ with iface:
|
|
171 |
label="Guidance scale",
|
172 |
info="Large => better quality and relevancy to text; Small => better diversity",
|
173 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
|
175 |
outputs = gr.Video(label="Output", elem_id="output-video")
|
176 |
btn = gr.Button("Submit").style(full_width=True)
|
177 |
|
178 |
btn.click(
|
179 |
text2audio,
|
180 |
-
inputs=[textbox, negative_textbox, duration, guidance_scale, seed],
|
181 |
outputs=[outputs],
|
182 |
)
|
183 |
|
|
|
5 |
|
6 |
from transformers import AutoProcessor, ClapModel
|
7 |
|
8 |
+
# Device/dtype selection (adapted from AudioLDM's original app.py):
# use half precision on CUDA, and fall back to CPU with float32 otherwise.
use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
torch_dtype = torch.float16 if use_cuda else torch.float32
|
16 |
|
17 |
# Load the AudioLDM text-to-audio diffusion pipeline, move it to the selected
# device, and compile the UNet for faster repeated inference.
pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm-m-full", torch_dtype=torch_dtype)
pipe = pipe.to(device)
pipe.unet = torch.compile(pipe.unet)
|
20 |
|
21 |
+
# CLAP scores text/audio similarity; it is used to rank candidate waveforms
# and return the one that best matches the prompt (improves output quality).
_clap_checkpoint = "sanchit-gandhi/clap-htsat-unfused-m-full"
clap_model = ClapModel.from_pretrained(_clap_checkpoint).to(device)
processor = AutoProcessor.from_pretrained(_clap_checkpoint)
|
24 |
|
25 |
# Shared RNG; text2audio re-seeds it per request for reproducible generation.
generator = torch.Generator(device=device)
|
26 |
|
27 |
+
# from audioldm app.py
|
28 |
+
def text2audio(text, negative_prompt, duration, guidance_scale, random_seed, n_candidates):
|
29 |
if text is None:
|
30 |
raise gr.Error("Please provide a text input.")
|
31 |
|
|
|
34 |
audio_length_in_s=duration,
|
35 |
guidance_scale=guidance_scale,
|
36 |
negative_prompt=negative_prompt,
|
37 |
+
num_waveforms_per_prompt=n_candidates if n_candidates else 1,
|
38 |
generator=generator.manual_seed(int(random_seed)),
|
39 |
)["audios"]
|
40 |
|
41 |
+
if waveforms.shape[0] > 1:
|
42 |
+
waveform = score_waveforms(text, waveforms)
|
43 |
+
else:
|
44 |
+
waveform = waveforms[0]
|
45 |
|
46 |
return gr.make_waveform((16000, waveform), bg_image="bg.png")
|
47 |
|
48 |
+
def score_waveforms(text, waveforms):
    """Return the candidate waveform that CLAP rates closest to the prompt.

    The CLAP processor batches the prompt against every candidate audio, the
    model's text-audio similarity logits are softmaxed over the candidates,
    and the argmax candidate is returned.
    """
    clap_inputs = processor(text=text, audios=list(waveforms), return_tensors="pt", padding=True)
    clap_inputs = {key: clap_inputs[key].to(device) for key in clap_inputs}
    with torch.no_grad():
        # audio-text similarity score, one logit per candidate waveform
        logits_per_text = clap_model(**clap_inputs).logits_per_text
    # softmax to probabilities, then keep the most likely candidate
    best_index = torch.argmax(logits_per_text.softmax(dim=-1))
    return waveforms[best_index]
|
57 |
+
|
58 |
# duplicate CSS config
|
59 |
|
60 |
css = """
|
|
|
191 |
label="Guidance scale",
|
192 |
info="Large => better quality and relevancy to text; Small => better diversity",
|
193 |
)
|
194 |
+
# Number of candidate waveforms to generate per prompt; the best one (as
# scored by CLAP in score_waveforms) is shown to the user.
# Fixes grammatical errors in the user-facing label and info strings.
n_candidates = gr.Slider(
    1,
    3,
    value=3,
    step=1,
    label="Number of waveforms to generate",
    info="Automatic quality control: this number controls how many candidate audios are generated (e.g., generate three audios and show you the best one). A larger value usually leads to better quality at the cost of heavier computation.",
)
|
202 |
|
203 |
# Output widget (the rendered waveform video) and the submit-button wiring.
outputs = gr.Video(label="Output", elem_id="output-video")
btn = gr.Button("Submit").style(full_width=True)

btn.click(
    fn=text2audio,
    inputs=[textbox, negative_textbox, duration, guidance_scale, seed, n_candidates],
    outputs=[outputs],
)
|
211 |
|