# hftesting / app.py (by jytole)
# Commit 4de77e2: Fixed ref to "text" variable (file size 2.22 kB)
import gradio as gr
from diffusers import AudioLDMPipeline, DPMSolverMultistepScheduler
from transformers import AutoProcessor, ClapModel
import torch
# import scipy
# --- Model setup ------------------------------------------------------------
# All models run on the CPU; the AudioLDM pipeline is loaded in float32.
device = "cpu"

# AudioLDM text-to-audio diffusion pipeline used by texttoaudio().
repo_id = "cvssp/audioldm-s-full-v2"
pipe = AudioLDMPipeline.from_pretrained(repo_id, torch_dtype=torch.float32).to(device)
# Optional alternative sampler (disabled):
# pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

# CLAP text/audio model + its processor, used by score_waveforms() to rank
# multiple generated candidates against the prompt.
clap_repo_id = "sanchit-gandhi/clap-htsat-unfused-m-full"
clap_model = ClapModel.from_pretrained(clap_repo_id).to(device)
processor = AutoProcessor.from_pretrained(clap_repo_id)

# Shared RNG; texttoaudio() re-seeds it on every request.
generator = torch.Generator(device=device)
def texttoaudio(prompt, neg_prompt, seed, inf_steps, guidance_scale, n_candidates):
    """Generate a 5-second audio clip from a text prompt with AudioLDM.

    When more than one candidate is requested, all candidates are ranked
    against the prompt by ``score_waveforms`` and the best one is returned.

    Args:
        prompt: Text describing the desired audio; must be non-empty.
        neg_prompt: Negative prompt forwarded to the pipeline.
        seed: Seed for the shared torch generator (converted with ``int``).
        inf_steps: Number of inference steps (converted with ``int``).
        guidance_scale: Guidance scale forwarded to the pipeline.
        n_candidates: Number of waveforms to generate; falsy values mean 1.

    Returns:
        ``(16000, waveform)`` -- the sample rate and the chosen waveform, the
        tuple form accepted by Gradio's audio output.

    Raises:
        gr.Error: If the prompt is empty/blank or a numeric setting is missing.
    """
    # Gradio text boxes deliver "" (not None) when left empty, so check for
    # blank text instead of only comparing against None.
    if prompt is None or not str(prompt).strip():
        raise gr.Error("Please provide a text input.")
    # A cleared Number input arrives as None; int(None) would otherwise fail
    # with an opaque TypeError instead of a readable message.
    if seed is None or inf_steps is None or guidance_scale is None:
        raise gr.Error("Please provide a seed, inference steps and guidance scale.")

    waveforms = pipe(
        prompt,
        negative_prompt=neg_prompt,
        num_inference_steps=int(inf_steps),
        guidance_scale=guidance_scale,
        audio_length_in_s=5.0,
        # Re-seed the shared generator so each request is reproducible.
        generator=generator.manual_seed(int(seed)),
        num_waveforms_per_prompt=int(n_candidates) if n_candidates else 1,
    )["audios"]

    if waveforms.shape[0] > 1:
        waveform = score_waveforms(prompt, waveforms)
    else:
        waveform = waveforms[0]
    return (16000, waveform)
def score_waveforms(text, waveforms):
    """Return the candidate waveform that CLAP judges closest to ``text``.

    Args:
        text: The single prompt the candidates were generated from.
        waveforms: Candidate waveforms, indexable along the first axis.

    Returns:
        The element of ``waveforms`` with the highest text-audio similarity.
    """
    # NOTE(review): the processor receives AudioLDM's 16 kHz output with no
    # sampling_rate argument or resampling; confirm the rate this CLAP
    # checkpoint expects if ranking quality matters.
    inputs = processor(text=text, audios=list(waveforms), return_tensors="pt", padding=True)
    inputs = {key: value.to(device) for key, value in inputs.items()}
    with torch.no_grad():
        # Text-to-audio similarity logits: one row per text, one column per audio.
        logits_per_text = clap_model(**inputs).logits_per_text
    # Softmax is monotonic, so ranking the raw logits of the prompt's row
    # selects the same candidate; int() yields a plain Python index.
    best_index = int(logits_per_text[0].argmax().item())
    return waveforms[best_index]
# Gradio UI. Inputs map positionally onto texttoaudio's parameters
# (prompt, neg_prompt, seed, inf_steps, guidance_scale, n_candidates);
# labels make that mapping visible and the starting values avoid sending
# None for untouched number fields.
iface = gr.Interface(
    fn=texttoaudio,
    title="AudioLDM Testing Playground",
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Textbox(label="Negative prompt"),
        gr.Number(label="Seed", value=0, precision=0),
        gr.Number(label="Inference steps", value=10, precision=0),
        gr.Number(label="Guidance scale", value=2.5),
        gr.Number(label="Candidates (best one picked by CLAP)", value=1, precision=0),
    ],
    outputs="audio",
)
iface.launch()