import gradio as gr
import torch
from diffusers import AudioLDMPipeline
from transformers import AutoProcessor, ClapModel

# Prefer CUDA, then Apple-silicon MPS, and fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
torch_dtype = torch.float32

# Text-to-audio latent diffusion pipeline.
repo_id = "cvssp/audioldm-m-full"
pipe = AudioLDMPipeline.from_pretrained(repo_id, torch_dtype=torch_dtype).to(device)
pipe.unet = torch.compile(pipe.unet)  # PyTorch 2.0+: speeds up repeated UNet calls

# CLAP model and processor, used to rank candidate waveforms against the prompt.
clap_model = ClapModel.from_pretrained("sanchit-gandhi/clap-htsat-unfused-m-full").to(device)
processor = AutoProcessor.from_pretrained("sanchit-gandhi/clap-htsat-unfused-m-full")

generator = torch.Generator(device)


def text2audio(text, negative_prompt, duration, guidance_scale, random_seed, n_candidates):
    if text is None:
        raise gr.Error("Please provide a text input")
    waveforms = pipe(
        text,
        audio_length_in_s=duration,
        guidance_scale=guidance_scale,
        negative_prompt=negative_prompt,
        num_waveforms_per_prompt=n_candidates if n_candidates else 1,
        generator=generator.manual_seed(int(random_seed)),
    )["audios"]
    # When several candidates are generated, keep the one CLAP scores highest.
    if waveforms.shape[0] > 1:
        waveform = score_waveforms(text, waveforms)
    else:
        waveform = waveforms[0]
    # AudioLDM outputs 16 kHz audio; render it as a waveform video for the UI.
    return gr.make_waveform((16000, waveform))


def score_waveforms(text, waveforms):
    inputs = processor(text=text, audios=list(waveforms), return_tensors="pt", padding=True)
    inputs = {key: inputs[key].to(device) for key in inputs}
    with torch.no_grad():
        logits_per_text = clap_model(**inputs).logits_per_text  # shape: (1, n_candidates)
    probs = logits_per_text.softmax(dim=-1)
    most_probable = torch.argmax(probs)
    return waveforms[most_probable]


iface = gr.Blocks()

with iface:
    with gr.Group():
        with gr.Box():
            textbox = gr.Textbox(
                max_lines=1,
                label="Prompt",
                info="Describe the audio you want to generate",
                elem_id="prompt-in",
            )
            negative_textbox = gr.Textbox(
                max_lines=1,
                label="Negative prompt",
                info="Describe what the audio should not contain",
                elem_id="negative-prompt-in",
            )
            with gr.Accordion("Expand for more options", open=False):
                seed = gr.Number(
                    value=45,
                    label="Seed",
                    info="Different seeds give different results; the same seed gives the same result",
                )
                duration = gr.Slider(2.5, 10, value=5, step=2.5, label="Duration (seconds)")
                guidance_scale = gr.Slider(
                    0,
                    4,
                    value=2.5,
                    step=0.5,
                    label="Guidance scale",
                    info="Larger values give better quality and closer adherence to the text; smaller values give more diversity",
                )
                n_candidates = gr.Slider(
                    1,
                    3,
                    value=3,
                    step=1,
                    label="Number of candidates",
                    info="Number of waveforms generated per prompt; the best one is selected with CLAP",
                )
        outputs = gr.Video(label="Output", elem_id="output-video")
        btn = gr.Button("Submit").style(full_width=True)
        btn.click(
            text2audio,
            inputs=[textbox, negative_textbox, duration, guidance_scale, seed, n_candidates],
            outputs=[outputs],
        )

iface.queue(max_size=10).launch(debug=True)
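
# Optional headless usage — a minimal sketch, not part of the demo above.
# It assumes scipy is installed; the prompt string and "output.wav" filename
# are illustrative placeholders. Run this instead of launching the UI.
#
# from scipy.io.wavfile import write as write_wav
#
# audio = pipe(
#     "a hammer hitting a wooden surface",
#     audio_length_in_s=5.0,
#     guidance_scale=2.5,
# ).audios[0]
# write_wav("output.wav", 16000, audio)  # AudioLDM generates 16 kHz audio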