import gradio as gr
import torch
from diffusers import AudioLDMPipeline
from transformers import AutoProcessor, ClapModel
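
# Run everything on the CPU in full float32 precision; the diffusion pipeline
# and the CLAP scorer share the same device and dtype settings.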
device = "cpu"
torch_dtype = torch.float32

# AudioLDM text-to-audio pipeline; torch.compile (PyTorch >= 2.0) speeds up the UNet.
repo_id = "cvssp/audioldm-m-full"
pipe = AudioLDMPipeline.from_pretrained(repo_id, torch_dtype=torch_dtype).to(device)
pipe.unet = torch.compile(pipe.unet)

# CLAP model and processor for scoring generated audio against the text prompt.
clap_model = ClapModel.from_pretrained("sanchit-gandhi/clap-htsat-unfused-m-full").to(device)
processor = AutoProcessor.from_pretrained("sanchit-gandhi/clap-htsat-unfused-m-full")

generator = torch.Generator(device)
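

# Generate audio from text, optionally sampling several candidates and
# returning the one that CLAP scores as the best match to the prompt.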
def text2audio(text, negative_prompt, duration, guidance_scale, random_seed, n_candidates):
    if not text:
        raise gr.Error("Please provide a text input")
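    # Run the diffusion pipeline, sampling n_candidates waveforms for the prompt.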
    waveforms = pipe(
        text,
        audio_length_in_s=duration,
        guidance_scale=guidance_scale,
        negative_prompt=negative_prompt,
        num_waveforms_per_prompt=n_candidates if n_candidates else 1,
        generator=generator.manual_seed(int(random_seed)),
    )["audios"]
    if waveforms.shape[0] > 1:
        waveform = score_waveforms(text, waveforms)
    else:
        waveform = waveforms[0]
    return gr.make_waveform((16000, waveform))
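

# Embed the prompt and every candidate waveform with CLAP, then return the
# waveform whose audio embedding is most similar to the text embedding.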
def score_waveforms(text, waveforms):
    inputs = processor(text=text, audios=list(waveforms), return_tensors="pt", padding=True)
    inputs = {key: inputs[key].to(device) for key in inputs}
    with torch.no_grad():
        logits_per_text = clap_model(**inputs).logits_per_text
    probs = logits_per_text.softmax(dim=-1)
    most_probable = torch.argmax(probs)
    waveform = waveforms[most_probable]
    return waveform
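

# Build the Gradio UI: prompt inputs, generation options in an accordion, and
# a video output rendered from the generated waveform.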
iface = gr.Blocks()

with iface:
    with gr.Group():
        with gr.Box():
            textbox = gr.Textbox(
                max_lines=1,
                label="Prompt",
                info="Text prompt describing the audio to generate",
                elem_id="prompt-in",
            )
            negative_textbox = gr.Textbox(
                max_lines=1,
                label="Negative prompt",
                info="Describe what the generated audio should avoid",
                elem_id="negative-prompt-in",
            )
            with gr.Accordion("Expand for more options", open=False):
                seed = gr.Number(
                    value=45,
                    label="Seed",
                    info="Different seeds give different results; the same seed always gives the same result",
                )
                duration = gr.Slider(2.5, 10, value=5, step=2.5, label="Duration (seconds)")
                guidance_scale = gr.Slider(
                    0,
                    4,
                    value=2.5,
                    step=0.5,
                    label="Guidance scale",
                    info="Larger values give better quality and closer adherence to the text; smaller values give more diversity",
                )
                n_candidates = gr.Slider(
                    1,
                    3,
                    value=3,
                    step=1,
                    label="Number of candidates",
                    info="How many candidate waveforms to generate; the best CLAP-scored one is returned",
                )
        outputs = gr.Video(label="Output", elem_id="output-video")
        btn = gr.Button("Submit").style(full_width=True)
        btn.click(
            text2audio,
            inputs=[textbox, negative_textbox, duration, guidance_scale, seed, n_candidates],
            outputs=[outputs],
        )
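
# Queue incoming requests (generation on CPU takes a while) and launch the demo.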
iface.queue(max_size=10).launch(debug=True)